{-# LANGUAGE AllowAmbiguousTypes #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE DefaultSignatures #-}
{-# LANGUAGE DeriveAnyClass #-}
{-# LANGUAGE DerivingStrategies #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE MagicHash #-}
{-# LANGUAGE PartialTypeSignatures #-}
{-# LANGUAGE Strict #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE UndecidableInstances #-}
{-# LANGUAGE NoStarIsType #-}
{-# OPTIONS_GHC -fplugin GHC.TypeLits.KnownNat.Solver #-}
{-# OPTIONS_GHC -fplugin GHC.TypeLits.Normalise #-}

module RL.Encoding where

import Common
import GreedyParser
import Internal.MultiSet qualified as MS
import PVGrammar (Edge, Edges (Edges), Freeze (FreezeOp), InnerEdge, Note (..), Notes (Notes), PVAnalysis, PVLeftmost, Split, Spread)
import PVGrammar.Generate (derivationPlayerPV)
import PVGrammar.Parse (protoVoiceEvaluator, pvThaw)
import RL.ModelTypes

import Control.DeepSeq
import Data.Foldable qualified as F
import Data.HashSet qualified as HS
import Data.Hashable (Hashable)
import Data.List qualified
import Data.List.NonEmpty (NonEmpty (..))
import Data.Maybe (catMaybes, mapMaybe)
import Data.Proxy (Proxy (..))
import Data.Type.Equality ((:~:) (..))
import Data.TypeNums (KnownInt, KnownNat, Nat, TInt (..), intVal, intVal', type (*), type (+), type (-), type (>=))
import Data.Vector qualified as V
import Data.Vector.Generic.Sized.Internal qualified as VSU
import Data.Vector.Sized qualified as VS
import Debug.Trace qualified as DT
import GHC.Exts (Proxy#, proxy#)
import GHC.Generics
import Musicology.Pitch
import Torch qualified as T
import Torch.Lens qualified as T
import Torch.Typed qualified as TT
import Unsafe.Coerce (unsafeCoerce)

-- Utilities
-- =========

-- -- Tensorized: get all tensors out of a data structure
-- -- ---------------------------------------------------

-- class GTensorized f where
--   gFlattenTensors :: forall a. f a -> [T.Tensor]

-- instance GTensorized U1 where
--   gFlattenTensors U1 = []

-- instance (GTensorized f, GTensorized g) => GTensorized (f :+: g) where
--   gFlattenTensors (L1 x) = gFlattenTensors x
--   gFlattenTensors (R1 x) = gFlattenTensors x

-- instance (GTensorized f, GTensorized g) => GTensorized (f :*: g) where
--   gFlattenTensors (x :*: y) = gFlattenTensors x ++ gFlattenTensors y

-- instance (Tensorized c) => GTensorized (K1 i c) where
--   gFlattenTensors (K1 x) = flattenTensors x

-- instance (GTensorized f) => GTensorized (M1 i t f) where
--   gFlattenTensors (M1 x) = gFlattenTensors x

-- class Tensorized a where
--   flattenTensors :: a -> [T.Tensor]
--   default flattenTensors :: (Generic a, GTensorized (Rep a)) => a -> [T.Tensor]
--   flattenTensors f = gFlattenTensors (from f)

-- instance Tensorized (TT.Tensor dev dtype shape) where
--   flattenTensors t = [TT.toDynamic t]

-- instance Tensorized (TT.Parameter dev dtype shape) where
--   flattenTensors t = [TT.toDynamic $ TT.toDependent t]

-- instance (Tensorized a) => Tensorized (StartStop a)

-- instance (Tensorized a, Tensorized b) => Tensorized (a, b)

-- instance (Tensorized a) => Tensorized [a]

-- instance Tensorized Double where
--   flattenTensors _ = []

-- instance Tensorized (TT.Conv2d a b c d dtype dev)

-- instance Tensorized (TT.Linear i o dtype device)

-- instance Tensorized (TT.LayerNorm shape dtype device)

-- Stackable and Batchable class
-- -----------------------------

-- | Values that can be combined from a statically sized, non-empty vector
-- into a single "stacked" value.
--
-- In this module, instances stack along a new leading axis (see the
-- 'QMaybe' and 'QBoundedList' instances, which use @TT.vecStack \@0@).
class Stackable a where
  -- | The type obtained by stacking @n@ values of type @a@.
  type Stacked a (n :: Nat)

  -- | Stack @1 + n@ values (so at least one) into one value.
  stack
    :: (KnownNat n, KnownNat (1 + n))
    => VS.Vector (1 + n) a
    -> Stacked a (1 + n)

-- | Stack a plain list as if it were a sized vector of length 'FakeSize'.
--
-- __Unsafe:__ the list length is /not/ checked against 'FakeSize'. The sized
-- vector is built directly with the internal 'VSU.Vector' constructor, so the
-- type-level length can disagree with the real number of elements.
-- NOTE(review): an empty list is passed straight to 'stack'; for the
-- tensor-backed instances this presumably fails inside torch — confirm
-- callers never pass @[]@.
stackUnsafe :: (Stackable a) => [a] -> Stacked a FakeSize
stackUnsafe things = stack $ VSU.Vector $ V.fromList things

-- | Values that can be given an extra batch dimension.
--
-- All instances in this module add a leading axis of size 1
-- (via @TT.unsqueeze \@0@).
class Batchable a where
  -- | The type of @a@ once the batch dimension has been added.
  type Batched a

  -- | Wrap a single value as a batch.
  addBatchDim :: a -> Batched a

-- | A typed tensor is batched by inserting a new leading axis of size 1.
instance Batchable (TT.Tensor dev dtype shape) where
  type Batched (TT.Tensor dev dtype shape) = TT.Tensor dev dtype (1 ': shape)
  addBatchDim = TT.unsqueeze @0

-- Masked Maybe
-- ------------

-- | A batched, masked optional value.
--
-- 'qmMask' is a tensor shaped like the batch. 'qNothing' fills it with zeros
-- and 'qJust' fills it with ones. 'qmContent' is always stored, even where
-- the mask marks the value as absent.
data QMaybe dev (batchShape :: [Nat]) a = QMaybe
  { qmMask :: QTensor dev batchShape
  -- ^ presence mask, one entry per batch element
  , qmContent :: a
  -- ^ the content (stored regardless of the mask)
  }
  deriving stock (Show, Generic)
  deriving anyclass (NFData)

-- | Mark a value as absent: the mask is all zeros.
--
-- The content argument is still stored; only the mask signals absence.
qNothing
  :: (TT.TensorOptions batchShape QDType dev)
  => a
  -> QMaybe dev batchShape a
qNothing = QMaybe TT.zeros

-- | Mark a value as present: the mask is all ones.
qJust
  :: (TT.TensorOptions batchShape QDType dev)
  => a
  -> QMaybe dev batchShape a
qJust = QMaybe TT.ones

-- | Stacking 'QMaybe's stacks the masks along a new leading axis and stacks
-- the contents with the content type's own 'Stackable' instance.
instance (Stackable a) => Stackable (QMaybe dev batchShape a) where
  type Stacked (QMaybe dev batchShape a) n = QMaybe dev (n ': batchShape) (Stacked a n)
  stack ms = QMaybe masks contents
   where
    masks = TT.vecStack @0 $ qmMask <$> ms
    contents = stack $ qmContent <$> ms

-- | Batching a 'QMaybe' adds a leading size-1 axis to the mask and batches
-- the content with its own 'Batchable' instance.
instance (Batchable a) => Batchable (QMaybe dev shape a) where
  type Batched (QMaybe dev shape a) = QMaybe dev (1 ': shape) (Batched a)
  addBatchDim (QMaybe mask content) =
    QMaybe (TT.unsqueeze @0 mask) (addBatchDim content)

-- instance (T.HasTypes a T.Tensor) => T.HasTypes (QMaybe shape a) T.Tensor

-- Masked List
-- -----------

-- | A batched list with static capacity @maxLen@, stored as padded tensors.
--
-- For every batch index there are @maxLen@ slots. 'qlMask' has one entry per
-- slot, and 'qlContent' holds one @innerShape@-shaped element per slot.
-- 'qBoundedList' shows how a single (unbatched) list is encoded.
data QBoundedList dev (dtype :: TT.DType) (maxLen :: Nat) (batchShape :: [Nat]) (innerShape :: [Nat])
  = QBoundedList
  { qlMask :: QTensor dev (batchShape TT.++ '[maxLen])
  -- ^ slot mask, shape @batchShape ++ [maxLen]@
  , qlContent :: TT.Tensor dev dtype (batchShape TT.++ '[maxLen] TT.++ innerShape)
  -- ^ slot contents, shape @batchShape ++ [maxLen] ++ innerShape@
  }
  deriving stock (Show, Generic)
  deriving anyclass (NFData)

-- | Encode a list of tensors as a masked list with @maxLen@ slots.
--
-- At most @maxLen@ elements are kept; any further elements are dropped. The
-- kept elements are stacked along a new leading axis and zero-padded at the
-- end up to @maxLen@. The mask is ones for the occupied slots, followed by
-- zeros for the padding slots.
--
-- If nothing is kept (the input is empty, or @maxLen@ is 0), mask and content
-- are all zeros. This avoids calling 'T.stack' on an empty list, which would
-- fail.
--
-- NOTE(review): the non-empty branch builds the mask with @opts \@dev@, while
-- the empty branch uses a 'QDType' zeros tensor. This assumes @opts \@dev@
-- yields 'QDType' — confirm.
qBoundedList
  :: forall dev dtype maxLen innerShape
   . ( KnownNat maxLen
     , TT.KnownDevice dev
     , TT.KnownShape innerShape
     , TT.TensorOptions innerShape QDType dev
     , TT.TensorOptions innerShape dtype dev
     )
  => [TT.Tensor dev dtype innerShape]
  -> QBoundedList dev dtype maxLen '[] innerShape
qBoundedList lst = case take capacity lst of
  -- Nothing to stack (empty input or zero capacity).
  [] -> QBoundedList TT.zeros TT.zeros
  -- The helpers live in this alternative's `where`, so that under the
  -- module-wide `Strict` extension they are only forced when non-empty.
  kept -> QBoundedList (TT.UnsafeMkTensor mask) (TT.UnsafeMkTensor paddedContent)
    where
      len :: Int
      len = length kept
      padLen :: Int
      padLen = capacity - len
      innerDims :: [Int]
      innerDims = TT.shapeVal @innerShape
      -- stack the kept elements along a new leading (list) axis
      content :: T.Tensor
      content = T.stack (T.Dim 0) $ TT.toDynamic <$> kept
      -- padSpec: (before, after) pairs, ordered from the last dimension
      -- backwards, so the inner dims come first (no padding) and the leading
      -- list dim comes last (post-padding only)
      padSpec :: [Int]
      padSpec = replicate (2 * length innerDims) 0 ++ [0, padLen]
      paddedContent :: T.Tensor
      paddedContent = T.constantPadNd1d padSpec 0 content
      mask :: T.Tensor
      mask = T.cat (T.Dim 0) [T.ones [len] $ opts @dev, T.zeros [padLen] $ opts @dev]
 where
  capacity :: Int
  capacity = TT.natValI @maxLen

-- | Stacking 'QBoundedList's stacks both the masks and the contents along a
-- new leading batch axis.
instance Stackable (QBoundedList dev dtype maxLen batchShape innerShape) where
  type
    Stacked (QBoundedList dev dtype maxLen batchShape innerShape) n =
      QBoundedList dev dtype maxLen (n ': batchShape) innerShape
  stack xs = QBoundedList masks contents
   where
    masks = TT.vecStack @0 $ qlMask <$> xs
    contents = TT.vecStack @0 $ qlContent <$> xs

-- | Batching a 'QBoundedList' adds a leading size-1 axis to both the mask and
-- the content.
instance Batchable (QBoundedList dev dtype maxLen batchShape innerShape) where
  type
    Batched (QBoundedList dev dtype maxLen batchShape innerShape) =
      QBoundedList dev dtype maxLen (1 ': batchShape) innerShape
  addBatchDim (QBoundedList mask content) =
    QBoundedList (TT.unsqueeze @0 mask) (TT.unsqueeze @0 content)

-- instance T.HasTypes (QBoundedList dtype maxLen batchShape innerShape) T.Tensor

-- Tagged StartStop
-- ----------------

-- | A batched value with an integer tag tensor.
--
-- NOTE(review): presumably this is the tensor encoding of 'StartStop', with
-- the tag distinguishing start, inner and stop elements ('qInner' uses 1) —
-- confirm against the tag constructors.
data QStartStop dev (batchShape :: [Nat]) a = QStartStop
  { qssTag :: TT.Tensor dev TT.Int64 batchShape
  -- ^ tag, one entry per batch element
  , qssContent :: a
  -- ^ the content (stored for every tag)
  }
  deriving stock (Show, Generic)
  deriving anyclass (NFData)

-- | Wrap a payload as an /inner/ (non-boundary) element.
-- The tag tensor is filled with the value @1@.
qInner :: (TT.TensorOptions batchShape TT.Int64 dev) => a -> QStartStop dev batchShape a
qInner content = QStartStop (TT.full (1 :: Int)) content

-- | Wrap a payload as a /start/ boundary element.
-- The tag tensor is filled with the value @0@.
qStart :: (TT.TensorOptions batchShape TT.Int64 dev) => a -> QStartStop dev batchShape a
qStart content = QStartStop (TT.full (0 :: Int)) content

-- | Wrap a payload as a /stop/ boundary element.
-- The tag tensor is filled with the value @2@.
qStop :: (TT.TensorOptions batchShape TT.Int64 dev) => a -> QStartStop dev batchShape a
qStop content = QStartStop (TT.full (2 :: Int)) content

-- | Convert a 'StartStop' value into its tagged tensor representation.
--
-- * 'Start' and 'Stop' carry the default payload @def@ (via 'qStart' / 'qStop').
-- * @'Inner' x@ carries the converted payload @f x@ (via 'qInner').
qStartStop
  :: (TT.TensorOptions batchShape TT.Int64 dev)
  => (a -> b)
  -- ^ conversion for the payload of an inner value
  -> b
  -- ^ default payload used for the 'Start' and 'Stop' markers
  -> StartStop a
  -> QStartStop dev batchShape b
-- Unused arguments are bound to (underscore-prefixed) variables rather than
-- wildcards so that they are forced exactly as before under the Strict extension.
qStartStop _f def Start = qStart def
qStartStop _f def Stop = qStop def
qStartStop f _def (Inner x) = qInner (f x)

-- | Stack tagged values along a new leading axis.
-- The integer tags are stacked with 'TT.vecStack' at axis 0. The payloads are
-- stacked through their own 'Stackable' instance.
instance (Stackable a) => Stackable (QStartStop dev batchShape a) where
  type Stacked (QStartStop dev batchShape a) n = QStartStop dev (n ': batchShape) (Stacked a n)
  stack xs =
    QStartStop
      (TT.vecStack @0 (fmap qssTag xs))
      (stack (fmap qssContent xs))

-- | Add a leading batch dimension of size 1 to a tagged value.
-- The tag tensor is unsqueezed at axis 0, and the payload gets its own batch
-- dimension through its 'Batchable' instance.
instance (Batchable a) => Batchable (QStartStop dev shape a) where
  type Batched (QStartStop dev shape a) = QStartStop dev (1 : shape) (Batched a)
  addBatchDim qss =
    QStartStop
      (TT.unsqueeze @0 (qssTag qss))
      (addBatchDim (qssContent qss))

-- instance (T.HasTypes a T.Tensor) => T.HasTypes (QStartStop shape a) T.Tensor

-- Slice Encoding
-- ==============

-- types of slice encodings
-- ------------------------

-- | Sparse encoding of a slice: a bounded list of at most 'MaxPitches'
-- entries, each an integer vector of length 2.
-- NOTE(review): the two components are presumably the fifth and octave index
-- of a pitch (cf. 'pitch2index'); confirm against the producer of this type.
newtype SliceEncodingSparse dev batchShape = SliceEncodingSparse
  { getSliceEncodingSparse :: QBoundedList dev TT.Int64 MaxPitches batchShape '[2]
  -- ^ the underlying bounded list of index vectors
  }
  deriving stock (Show, Generic)
  deriving newtype (NFData)

-- | Stack sparse slice encodings by stacking the underlying bounded lists
-- and re-wrapping the result.
instance Stackable (SliceEncodingSparse dev batchShape) where
  type Stacked (SliceEncodingSparse dev batchShape) n = SliceEncodingSparse dev (n ': batchShape)
  stack slices = SliceEncodingSparse (stack (fmap getSliceEncodingSparse slices))

-- | Add a leading batch dimension by delegating to the underlying bounded list.
instance Batchable (SliceEncodingSparse dev shape) where
  type Batched (SliceEncodingSparse dev shape) = SliceEncodingSparse dev (1 ': shape)
  addBatchDim = SliceEncodingSparse . addBatchDim . getSliceEncodingSparse

-- instance T.HasTypes (SliceEncodingSparse shape) T.Tensor

-- | Dense encoding of a slice: one 'PShape'-shaped tensor, i.e. a
-- (fifth x octave) grid, per element of the batch.
newtype SliceEncodingDense dev batchShape = SliceEncodingDense
  { getSliceEncodingDense :: QTensor dev (batchShape TT.++ PShape)
  -- ^ the underlying tensor; the batch axes come first, the pitch grid last
  }
  deriving stock (Show, Generic)
  deriving newtype (NFData)

-- | Stack dense slice encodings along a new leading axis with 'TT.vecStack'.
instance Stackable (SliceEncodingDense dev batchShape) where
  type Stacked (SliceEncodingDense dev batchShape) n = SliceEncodingDense dev (n ': batchShape)
  stack slices = SliceEncodingDense (TT.vecStack @0 (fmap getSliceEncodingDense slices))

-- | Add a leading batch dimension by delegating to the underlying tensor.
instance Batchable (SliceEncodingDense dev batchShape) where
  type Batched (SliceEncodingDense dev batchShape) = SliceEncodingDense dev (1 ': batchShape)
  addBatchDim = SliceEncodingDense . addBatchDim . getSliceEncodingDense

-- instance T.HasTypes (SliceEncodingDense shape) T.Tensor

-- choose slice encoding type:
-- ---------------------------

-- | The slice encoding used throughout this module (currently the dense one).
-- Kept as an unapplied synonym so it can be used partially applied,
-- e.g. @SliceEncoding dev@.
type SliceEncoding = SliceEncodingDense

-- | Extract the dense pitch tensor from a slice encoding.
--
-- Because 'SliceEncoding' is currently 'SliceEncodingDense', this only unwraps
-- the newtype. If the sparse encoding were selected instead, a conversion
-- (such as the commented-out @sliceIndex2OneHot@) would be needed here.
getSlice
  :: forall dev batchShape
   . SliceEncoding dev batchShape
  -> QTensor dev (batchShape TT.++ PShape)
getSlice (SliceEncodingDense dense) = dense

-- | Encode the pitches of one slice with the selected 'SliceEncoding'.
-- This delegates to 'pitchesOneHotSum'.
encodePitches
  :: (TT.KnownDevice dev)
  => [SPitch]
  -> SliceEncoding dev '[]
encodePitches pitches = pitchesOneHotSum pitches

-- sliceIndex2OneHot
--   :: forall dev batchShape
--    . ( TT.KnownShape batchShape
--      )
--   => SliceEncodingSparse dev batchShape
--   -> SliceEncodingDense dev batchShape
-- sliceIndex2OneHot (SliceEncodingSparse (QBoundedList mask values)) =
--   SliceEncodingDense $ QBoundedList mask values'
--  where
--   fifthSize = TT.natValI @FifthSize
--   octaveSize = TT.natValI @OctaveSize
--   shape = TT.shapeVal @batchShape
--   hotF = T.toType qDType $ T.oneHot fifthSize $ T.select (-1) 0 $ TT.toDynamic values
--   hotO = T.toType qDType $ T.oneHot octaveSize $ T.select (-1) 1 $ TT.toDynamic values
--   outer = T.einsum "...i,...j->...ij" [hotF, hotO] [1, 0]
--   values' = TT.UnsafeMkTensor $ T.unsqueeze (T.Dim (-3)) outer

-- slice variants
-- --------------

-- | Map a spelled pitch to its @[fifth, octave]@ index into the pitch grid
-- (axis sizes 'FifthSize' and 'OctaveSize').
--
-- Each coordinate is shifted by its lower bound ('FifthLow' / 'OctaveLow') and
-- then clamped into the valid index range @[0, size - 1]@ of its axis. Pitches
-- outside the grid are therefore mapped to its border instead of producing an
-- out-of-bounds index.
pitch2index
  :: SPitch
  -> [Int]
pitch2index p =
  [ clamp fifthSize (fifths p - fifthLow)
  , clamp octaveSize (octaves p - octaveLow)
  ]
 where
  -- Clamp an index into [0, size - 1]. The upper bound must be @size - 1@:
  -- @size@ itself is not a valid index into an axis of length @size@.
  clamp :: Int -> Int -> Int
  clamp size i = max 0 $ min (size - 1) i
  fifthLow :: Int
  fifthLow = fromIntegral $ intVal' @FifthLow proxy#
  octaveLow :: Int
  octaveLow = fromIntegral $ intVal' @OctaveLow proxy#
  fifthSize :: Int
  fifthSize = TT.natValI @FifthSize
  octaveSize :: Int
  octaveSize = TT.natValI @OctaveSize

-- | Encode a set of pitches as a multi-hot tensor over the pitch grid
-- (@[FifthSize, OctaveSize]@).
--
-- Every cell addressed by at least one pitch (via 'pitch2index') is set to
-- exactly 1, and all other cells are 0. The tensor is built on the CPU and
-- then moved to @dev@.
pitchesMultiHot
  :: forall dev
   . (TT.KnownDevice dev)
  => HS.HashSet SPitch
  -> QTensor dev PShape
pitchesMultiHot ps = TT.UnsafeMkTensor out
 where
  -- The empty set is special-cased; presumably this avoids calling
  -- 'T.indexPut' with an empty index list.
  -- The put is non-accumulating (False): distinct pitches can be clamped
  -- onto the same cell by 'pitch2index', and that cell must still be 1,
  -- not a count.
  out =
    T.toDevice (TT.deviceVal @dev) $
      if HS.null ps
        then zeros
        else T.indexPut False indices values zeros
  -- Irrefutable binding, so the module's Strict extension does not force it eagerly.
  ~indices = T.asTensor <$> Data.List.transpose (pitch2index <$> F.toList ps)
  values = T.ones [F.length ps] $ opts @'(TT.CPU, 0)
  zeros = T.zeros dims $ opts @'(TT.CPU, 0)
  fifthSize = TT.natValI @FifthSize
  octaveSize = TT.natValI @OctaveSize
  dims = [fifthSize, octaveSize]

-- | Dense slice encoding of a list of pitches. Each cell of the pitch grid
-- holds the number of list entries that 'pitch2index' maps to it, so
-- duplicates are counted.
--
-- An empty list yields an all-zero encoding. It has its own equation, so the
-- tensor-building bindings of the second equation (which are strict under
-- the module's Strict extension) are never evaluated for it.
pitchesOneHotSum
  :: forall dev
   . (TT.KnownDevice dev)
  => [SPitch]
  -> SliceEncodingDense dev '[]
pitchesOneHotSum [] = SliceEncodingDense TT.zeros
pitchesOneHotSum pitchList = SliceEncodingDense summed
 where
  cpu = opts @'(TT.CPU, 0)
  count = length pitchList
  -- One row per pitch: [row number, fifth index, octave index].
  rows = [row : pitch2index p | (row, p) <- zip [0 ..] pitchList]
  -- Transposed into one index tensor per axis, as 'T.indexPut' expects.
  axisIndices = map T.asTensor (Data.List.transpose rows)
  ones = T.ones [count] cpu
  -- One (fifth x octave) plane per pitch; each receives a single 1.
  planes = T.zeros (count : TT.shapeVal @PShape) cpu
  -- Write the one-hots, sum the planes away, and move the result to @dev@.
  summed :: QTensor dev PShape
  summed =
    ( TT.UnsafeMkTensor
        . T.toDevice (TT.deviceVal @dev)
        . T.sumDim (T.Dim 0) T.RemoveDim qDType
        . T.indexPut True axisIndices ones
    )
      planes

-- | One-hot encode a list of pitches as a bounded list of at most 'MaxPitches'
-- entries.
--
-- The pitch at position @i@ sets a single 1 at @i : pitch2index pitch@ in a
-- @MaxPitches ': PShape@ tensor, which is then reshaped to add a singleton
-- dimension. The mask has ones for the first @n@ used slots and zeros for the
-- rest. Pitches beyond 'MaxPitches' are dropped by 'take'.
pitchesOneHots
  :: forall dev
   . (TT.KnownDevice dev)
  => [SPitch]
  -> QBoundedList dev QDType MaxPitches '[] (1 : PShape)
pitchesOneHots [] = QBoundedList TT.zeros TT.zeros
pitchesOneHots ps = QBoundedList mask (TT.reshape out)
 where
  maxPitches = TT.natValI @MaxPitches
  kept = take maxPitches ps
  n = length kept
  -- 'T.indexPut' wants one index tensor per dimension, so the per-pitch
  -- coordinate lists (slot index first) are transposed.
  indices =
    T.asTensor
      <$> Data.List.transpose [i : pitch2index p | (i, p) <- zip [0 ..] kept]
  values = T.ones [n] $ opts @'(TT.CPU, 0)
  -- The scatter is done on the CPU and the result is moved to @dev@ afterwards.
  zeros :: QTensor '(TT.CPU, 0) (MaxPitches ': PShape)
  zeros = TT.zeros
  out :: QTensor dev (MaxPitches : PShape)
  out =
    TT.UnsafeMkTensor
      . T.toDevice (TT.deviceVal @dev)
      . T.indexPut True indices values
      $ TT.toDynamic zeros
  devOpts = opts @dev
  mask :: QTensor dev '[MaxPitches]
  mask =
    TT.UnsafeMkTensor $
      T.cat (T.Dim 0) [T.ones [n] devOpts, T.zeros [maxPitches - n] devOpts]

-- | Encode every pitch as a dense token vector and collect them with
-- 'qBoundedList'.
--
-- A token is the concatenation of a one-hot vector over fifths (length
-- 'FifthSize', offset by 'FifthLow') and a one-hot vector over octaves
-- (length 'OctaveSize', offset by 'OctaveLow'), converted to 'qDType'.
--
-- NOTE(review): a pitch whose shifted fifth or octave lies outside
-- @[0, size)@ makes 'T.oneHot' fail at runtime. No clamping is done here.
pitchesTokens
  :: forall dev
   . (TT.KnownDevice dev)
  => [SPitch]
  -> QBoundedList dev QDType MaxPitches '[] '[PSize] -- SliceEncoding '[]
pitchesTokens ps = qBoundedList (mkToken <$> ps)
 where
  -- todo: batch oneHot
  mkToken p =
    TT.UnsafeMkTensor $
      T.toDType qDType $
        T.cat (T.Dim 0) [T.oneHot fifthSize f, T.oneHot octaveSize o]
   where
    -- scalar Int64 class indices, as required by 'T.oneHot'
    f = T.asTensor' (fifths p - fifthLow) $ T.withDType T.Int64 $ opts @dev
    o = T.asTensor' (octaves p - octaveLow) $ T.withDType T.Int64 $ opts @dev
  fifthLow :: Int
  fifthLow = fromIntegral $ intVal' @FifthLow proxy#
  octaveLow :: Int
  octaveLow = fromIntegral $ intVal' @OctaveLow proxy#
  fifthSize :: Int
  fifthSize = TT.natValI @FifthSize
  octaveSize :: Int
  octaveSize = TT.natValI @OctaveSize

-- | Sparse slice encoding. Every pitch becomes an 'T.Int64' vector holding its
-- 'pitch2index' coordinates. The vectors are collected with 'qBoundedList' and
-- wrapped in 'SliceEncodingSparse'.
pitchesIndices
  :: forall dev
   . (TT.KnownDevice dev)
  => [SPitch]
  -> SliceEncodingSparse dev '[]
pitchesIndices ps = SliceEncodingSparse (qBoundedList tokens)
 where
  int64Opts = T.withDType T.Int64 (opts @dev)
  tokens = [TT.UnsafeMkTensor (T.asTensor' (pitch2index p) int64Opts) | p <- ps]

-- | Encode a slice by encoding the pitches of all its notes with
-- 'encodePitches'. Note identities are discarded, and the pitch order follows
-- 'HS.toList'.
encodeSlice
  :: (TT.KnownDevice dev)
  => Notes SPitch
  -> SliceEncoding dev '[]
-- encodeSlice = encodeSliceIndices
encodeSlice (Notes notes) = encodePitches [notePitch note | note <- HS.toList notes]

-- | Encoding of a slice without any pitches ('encodePitches' on an empty list).
emptySlice
  :: (TT.KnownDevice dev) => SliceEncoding dev '[]
emptySlice = encodePitches mempty

-- Transition Encoding
-- ===================

-- | Tensor encoding of a transition, parameterised by a batch shape.
--
-- The two edge fields have the same element type as the result of
-- 'edgesOneHots' (@2 ': PShape@ per edge). The slice fields reuse
-- 'SliceEncoding'.
data TransitionEncoding dev batchShape = TransitionEncoding
  { trencPassing :: QBoundedList dev QDType MaxEdges batchShape (2 ': PShape)
  -- ^ encoding of the passing edges
  , trencInner :: QBoundedList dev QDType MaxEdges batchShape (2 ': PShape)
  -- ^ encoding of the inner edges
  , trencLeft :: SliceEncoding dev batchShape
  -- ^ encoding of the left slice
  -- (alternative type: QBoundedList dev QDType MaxEdges batchShape (1 ': PShape))
  , trencRight :: SliceEncoding dev batchShape
  -- ^ encoding of the right slice
  -- (alternative type: QBoundedList dev QDType MaxEdges batchShape (1 ': PShape))
  , trencRoot :: QTensor dev batchShape
  -- ^ one scalar per batch element; presumably a root-transition flag (confirm)
  }
  deriving stock (Show, Generic)
  deriving anyclass (NFData)

instance Stackable (TransitionEncoding dev batchShape) where
  type
    Stacked (TransitionEncoding dev batchShape) n =
      TransitionEncoding dev (n ': batchShape)

  -- Stack each field separately using the field's own 'Stackable' instance.
  -- The root tensors are stacked with 'TT.vecStack' along a new dimension 0.
  stack xs =
    TransitionEncoding
      { trencPassing = stack (trencPassing <$> xs)
      , trencInner = stack (trencInner <$> xs)
      , trencLeft = stack (trencLeft <$> xs)
      , trencRight = stack (trencRight <$> xs)
      , trencRoot = TT.vecStack @0 (trencRoot <$> xs)
      }

instance Batchable (TransitionEncoding dev shape) where
  type Batched (TransitionEncoding dev shape) = TransitionEncoding dev (1 : shape)

  -- Add a leading singleton batch dimension to every field. The root tensor is
  -- unsqueezed at dimension 0 directly.
  addBatchDim enc =
    TransitionEncoding
      { trencPassing = addBatchDim (trencPassing enc)
      , trencInner = addBatchDim (trencInner enc)
      , trencLeft = addBatchDim (trencLeft enc)
      , trencRight = addBatchDim (trencRight enc)
      , trencRoot = TT.unsqueeze @0 (trencRoot enc)
      }

-- | Multi-hot encoding of a set of edges.
--
-- Each edge @(p1, p2)@ scatters a 1 (accumulating) at
-- @pitch2index p1 ++ pitch2index p2@ in a CPU tensor of shape
-- @[FifthSize, OctaveSize, FifthSize, OctaveSize]@. The result is moved to
-- @dev@. An empty set yields the all-zero tensor.
--
-- NOTE(review): 'TT.UnsafeMkTensor' does not check the 4-D runtime shape built
-- here against 'EShape''. Confirm that the two agree.
edgesMultiHot
  :: forall dev
   . (TT.KnownDevice dev)
  => HS.HashSet (InnerEdge SPitch)
  -> QTensor dev EShape'
edgesMultiHot es = TT.UnsafeMkTensor (T.toDevice (TT.deviceVal @dev) dense)
 where
  cpuOpts = opts @'(TT.CPU, 0)
  fifthSize = TT.natValI @FifthSize
  octaveSize = TT.natValI @OctaveSize
  blank = T.zeros [fifthSize, octaveSize, fifthSize, octaveSize] cpuOpts
  coords = [pitch2index p1 ++ pitch2index p2 | (Note p1 _, Note p2 _) <- F.toList es]
  -- kept lazy: only needed (and only meaningful) when the set is non-empty
  ~coordTensors = T.asTensor <$> Data.List.transpose coords
  hitValues = T.ones [F.length es] cpuOpts
  dense
    | HS.null es = blank
    | otherwise = T.indexPut True coordTensors hitValues blank

-- | One-hot encode a list of edges.
--
-- Both endpoints of every edge are encoded with 'pitchesOneHots'. The two
-- results are concatenated along dimension 1, giving @2 ': PShape@ per edge.
-- The mask is the one from the first endpoints. Both lists have the same
-- length, so it is also valid for the second endpoints.
edgesOneHots
  :: forall dev
   . (TT.KnownDevice dev)
  => [InnerEdge SPitch]
  -> QBoundedList dev QDType MaxEdges '[] (2 ': PShape)
edgesOneHots es = QBoundedList mask (TT.cat @1 (srcHots TT.:. dstHots TT.:. TT.HNil))
 where
  QBoundedList mask srcHots = pitchesOneHots @dev (notePitch . fst <$> es)
  QBoundedList _ dstHots = pitchesOneHots @dev (notePitch . snd <$> es)

-- | Encode a list of inner edges as a bounded list of edge tokens.
--
-- Each edge @(n1, n2)@ becomes one flat vector built by concatenating
-- four one-hot segments, in this order:
-- fifths of @n1@, octaves of @n1@, fifths of @n2@, octaves of @n2@.
-- Fifths are shifted by 'FifthLow' and octaves by 'OctaveLow' before
-- being one-hot encoded with 'FifthSize' / 'OctaveSize' classes.
-- The resulting tokens are packed with 'qBoundedList'.
--
-- NOTE(review): an index outside @[0, size)@ is handed straight to
-- 'T.oneHot'; whatever libtorch does for it happens here as well.
edgesTokens
  :: forall dev
   . (TT.KnownDevice dev)
  => [InnerEdge SPitch]
  -> QBoundedList dev QDType MaxEdges '[] '[ESize]
edgesTokens edges = qBoundedList (map edgeToken edges)
 where
  -- Offsets and class counts of the one-hot encodings, read from the
  -- type-level constants.
  fifthOffset :: Int
  fifthOffset = fromIntegral (intVal' @FifthLow proxy#)
  octaveOffset :: Int
  octaveOffset = fromIntegral (intVal' @OctaveLow proxy#)
  nFifths :: Int
  nFifths = TT.natValI @FifthSize
  nOctaves :: Int
  nOctaves = TT.natValI @OctaveSize

  -- Scalar Int64 index tensor created with dev's tensor options.
  indexTensor :: Int -> T.Tensor
  indexTensor i = T.asTensor' i (T.withDType T.Int64 (opts @dev))

  -- The two one-hot segments for a single note: fifths first, then octaves.
  noteSegments :: Note SPitch -> [T.Tensor]
  noteSegments note =
    [ T.oneHot nFifths (indexTensor (fifths p - fifthOffset))
    , T.oneHot nOctaves (indexTensor (octaves p - octaveOffset))
    ]
   where
    p = notePitch note

  -- todo: batch oneHot
  edgeToken :: InnerEdge SPitch -> TT.Tensor dev QDType '[ESize]
  edgeToken (n1, n2) =
    let flat = toOpts @dev (T.cat (T.Dim 0) (noteSegments n1 <> noteSegments n2))
     in TT.UnsafeMkTensor $! flat

-- | Encode the edges of a transition as an unbatched 'TransitionEncoding'.
--
-- * 'trencPassing': the passing edges (listed with 'MS.toList'),
--   encoded with 'edgesOneHots'.
-- * 'trencInner': the regular edges whose two ends are both inner notes,
--   encoded with 'edgesOneHots'.
-- * 'trencLeft': pitches of the notes connected to 'Start', encoded with
--   'pitchesOneHotSum'.
-- * 'trencRight': pitches of the notes connected to 'Stop', encoded with
--   'pitchesOneHotSum'.
-- * 'trencRoot': a scalar that is one when the edge @(Start, Stop)@ is
--   among the regular edges, and zero otherwise.
encodeTransition
  :: (TT.KnownDevice dev)
  => Edges SPitch
  -> TransitionEncoding dev '[]
encodeTransition (Edges reg pass) =
  TransitionEncoding
    { trencPassing = edgesOneHots (MS.toList pass)
    , trencInner = edgesOneHots innerEdges
    , trencLeft = pitchesOneHotSum (map notePitch fromStart)
    , trencRight = pitchesOneHotSum (map notePitch toStop)
    , trencRoot =
        if HS.member (Start, Stop) reg then TT.ones else TT.zeros
    }
 where
  regularEdges :: [Edge SPitch]
  regularEdges = HS.toList reg
  -- Each comprehension keeps only the edges that match its pattern,
  -- preserving the order of 'regularEdges'.
  innerEdges = [(l, r) | (Inner l, Inner r) <- regularEdges]
  fromStart = [r | (Start, Inner r) <- regularEdges]
  toStop = [l | (Inner l, Stop) <- regularEdges]

-- | Encoding of a transition with no edges at all: no regular edges and
-- no passing edges, encoded with 'encodeTransition'.
emptyTransition
  :: (TT.KnownDevice dev) => TransitionEncoding dev '[]
emptyTransition = encodeTransition noEdges
 where
  noEdges :: Edges SPitch
  noEdges = Edges HS.empty MS.empty

-- Action Encoding
-- ---------------

-- | Encoding of the part of the parse state that an action operates on.
-- It consists of five components, all sharing the batch shape @batchShape@:
-- a slice, a transition, an optional slice, an optional transition, and
-- a final slice.
--
-- NOTE(review): the field names suggest a left / middle / right reading.
-- The optional middle parts presumably distinguish single from double
-- tops; confirm this against the code that builds these values.
data ActionTop dev batchShape = ActionTop
  { atopSl :: !(QStartStop dev batchShape (SliceEncoding dev batchShape))
  -- ^ first (left) slice, wrapped in 'QStartStop'
  , atopT1 :: !(TransitionEncoding dev batchShape)
  -- ^ first transition
  , atopSm :: !(QMaybe dev batchShape (SliceEncoding dev batchShape))
  -- ^ optional middle slice
  , atopT2 :: !(QMaybe dev batchShape (TransitionEncoding dev batchShape))
  -- ^ optional second transition
  , atopSr :: !(QStartStop dev batchShape (SliceEncoding dev batchShape))
  -- ^ last (right) slice, wrapped in 'QStartStop'
  }
  deriving stock (Show, Generic)
  deriving anyclass (NFData)

-- | Stacking a vector of 'ActionTop's stacks each component separately.
-- The vector length becomes a new leading batch dimension.
instance Stackable (ActionTop dev batchShape) where
  type Stacked (ActionTop dev batchShape) n = ActionTop dev (n ': batchShape)
  stack tops =
    ActionTop
      (stack (fmap atopSl tops))
      (stack (fmap atopT1 tops))
      (stack (fmap atopSm tops))
      (stack (fmap atopT2 tops))
      (stack (fmap atopSr tops))

-- | Adds a leading batch dimension of size 1 to every component of an
-- 'ActionTop'.
instance Batchable (ActionTop dev shape) where
  type Batched (ActionTop dev shape) = ActionTop dev (1 : shape)
  addBatchDim top =
    ActionTop
      { atopSl = addBatchDim (atopSl top)
      , atopT1 = addBatchDim (atopT1 top)
      , atopSm = addBatchDim (atopSm top)
      , atopT2 = addBatchDim (atopT2 top)
      , atopSr = addBatchDim (atopSr top)
      }

-- | Encoding of a complete action. It combines the 'ActionTop' the action
-- operates on with an operation tensor of dtype 'TT.Int64' and shape
-- @batchShape@.
--
-- Alternative representations noted for these fields were
-- @Either (SingleTop batchShape) (DoubleTop batchShape)@ for the top and
-- @Leftmost () () ()@ for the operation.
-- NOTE(review): the operation tensor presumably holds an index that
-- identifies the kind of operation; confirm this against the code that
-- fills it.
data ActionEncoding dev batchShape = ActionEncoding
  { actionEncodingTop :: !(ActionTop dev batchShape)
  -- ^ the slices and transitions the action operates on
  , actionEncodingOp :: !(TT.Tensor dev 'TT.Int64 batchShape)
  -- ^ the operation, as an Int64 tensor of shape @batchShape@
  }
  deriving stock (Show, Generic)
  deriving anyclass (NFData)

-- | Stacking a vector of action encodings stacks their two components
-- independently: the 'ActionTop' parts through their own 'Stackable'
-- instance, and the operation tensors along a new leading axis (dim 0).
instance Stackable (ActionEncoding dev batchShape) where
  type Stacked (ActionEncoding dev batchShape) n = ActionEncoding dev (n ': batchShape)
  stack xs = ActionEncoding tops ops
   where
    -- component-wise stacking of the parent ("top") encodings
    tops = stack $ actionEncodingTop <$> xs
    -- the per-action operation tensors, stacked along a new axis 0
    ops = TT.vecStack @0 $ actionEncodingOp <$> xs

-- | Adding a batch dimension to an action encoding adds it to both parts:
-- the top encoding via its own 'Batchable' instance, and the operation
-- tensor via a size-1 leading axis ('TT.unsqueeze' at dim 0).
instance Batchable (ActionEncoding dev shape) where
  type Batched (ActionEncoding dev shape) = ActionEncoding dev (1 : shape)
  addBatchDim (ActionEncoding top op) =
    ActionEncoding (addBatchDim top) (TT.unsqueeze @0 op)

-- instance T.HasTypes (ActionEncoding shape) T.Tensor

-- | Encode a single protovoice parser action as tensors.
--
-- The parent ("top") is encoded as an 'ActionTop'. For a single-parent action,
-- the middle slice and second transition are absent ('qNothing' over empty
-- placeholders). For a double-parent action, they are present ('qJust').
--
-- The operation is encoded as a scalar 'Int64' index:
--
--   * 0: single freeze
--   * 1: single split
--   * 2: double freeze-left
--   * 3: double spread
--   * 4: double split-left
--   * 5: double split-right
encodePVAction
  :: (TT.KnownDevice dev)
  => PVAction
  -> ActionEncoding dev '[]
encodePVAction (Left (ActionSingle top action)) = ActionEncoding encTop encAction
 where
  (SingleParent sl t sr) = top
  encTop =
    ActionTop
      (qStartStop encodeSlice emptySlice sl)
      (encodeTransition t)
      (qNothing emptySlice)
      (qNothing emptyTransition)
      (qStartStop encodeSlice emptySlice sr)
  encAction = case action of
    LMSingleFreeze _freeze -> TT.full (0 :: Int) --  LMFreezeOnly ()
    LMSingleSplit _split -> TT.full (1 :: Int) -- LMSplitOnly ()
encodePVAction (Right (ActionDouble top action)) = ActionEncoding encTop encAction
 where
  (DoubleParent sl t1 sm t2 sr) = top
  encTop =
    ActionTop
      (qStartStop encodeSlice emptySlice sl)
      (encodeTransition t1)
      (qJust $ encodeSlice sm)
      (qJust $ encodeTransition t2)
      (qStartStop encodeSlice emptySlice sr)
  encAction = case action of
    LMDoubleFreezeLeft _freeze -> TT.full (2 :: Int) -- LMFreezeLeft ()
    LMDoubleSpread _spread -> TT.full (3 :: Int) -- LMSpread ()
    LMDoubleSplitLeft _split -> TT.full (4 :: Int) -- LMSplitLeft ()
    LMDoubleSplitRight _split -> TT.full (5 :: Int) -- LMSplitRight ()

-- State Encoding
-- --------------

-- | Tensor encoding of a greedy-parser state.
data StateEncoding dev = StateEncoding
  { stateEncodingMid :: !(QStartStop dev '[] (SliceEncoding dev '[]))
  -- ^ the middle slice between the frozen and open parts (or a start/stop marker)
  , stateEncodingFrozen :: !(QMaybe dev '[] (TransitionEncoding dev '[FakeSize], QStartStop dev '[FakeSize] (SliceEncoding dev '[FakeSize])))
  -- ^ stacked encodings of the frozen part, if present
  , stateEncodingOpen :: !(QMaybe dev '[] (TransitionEncoding dev '[FakeSize], QStartStop dev '[FakeSize] (SliceEncoding dev '[FakeSize])))
  -- ^ stacked encodings of the open part, if present
  }
  deriving stock (Show, Generic)
  deriving anyclass (NFData)

-- | Encode (a prefix of) the frozen part of a parser state.
--
-- Takes up to 'window' transition/slice pairs from the path using
-- 'pathTake'. Each frozen transition is thawed into 'Edges' with 'pvThaw'
-- before encoding. Slices are wrapped in 'Inner'. The 'Start' argument is
-- presumably the filler for the slice slot at the path's end — confirm
-- against 'pathTake'. Both lists are then stacked with 'stackUnsafe' into a
-- 'FakeSize'-sized batch.
getFrozen
  :: forall dev t
   . (Foldable t, TT.KnownDevice dev)
  => Path (Maybe (t (Edge SPitch))) (Notes SPitch)
  -> (TransitionEncoding dev '[FakeSize], QStartStop dev '[FakeSize] (SliceEncoding dev '[FakeSize]))
getFrozen frozen = (stackUnsafe trEncs, stackUnsafe slcEncs)
 where
  -- maximum number of transition/slice pairs that are encoded
  window = 8
  (trs, slcs) = unzip $ pathTake window Inner Start frozen
  trEncs = encodeTransition . pvThaw <$> trs
  slcEncs = qStartStop encodeSlice emptySlice <$> slcs

-- | Encode (a prefix of) the open part of a parser state.
--
-- Takes up to 'window' transition/slice pairs from the path using
-- 'pathTake'. Slices are wrapped in 'Inner'. The 'Stop' argument is
-- presumably the filler for the slice slot at the path's end — confirm
-- against 'pathTake'. Each element is encoded, and both lists are stacked
-- with 'stackUnsafe' into a 'FakeSize'-sized batch.
getOpen
  :: (TT.KnownDevice dev)
  => Path (Edges SPitch) (Notes SPitch)
  -> (TransitionEncoding dev '[FakeSize], QStartStop dev '[FakeSize] (SliceEncoding dev '[FakeSize]))
getOpen open = (stackUnsafe trEncs, stackUnsafe slcEncs)
 where
  -- maximum number of transition/slice pairs that are encoded
  window = 8
  (trs, slcs) = unzip $ pathTake window Inner Stop open
  trEncs = encodeTransition <$> trs
  slcEncs = qStartStop encodeSlice emptySlice <$> slcs

encodePVState
  :: (TT.KnownDevice dev)
  => PVState
  -> StateEncoding dev
encodePVState :: forall (dev :: (DeviceType, Natural)).
KnownDevice dev =>
PVState -> StateEncoding dev
encodePVState (GSFrozen Path
  (Maybe [(StartStop (Note SPitch), StartStop (Note SPitch))])
  (Notes SPitch)
frozen) =
  QStartStop dev '[] (SliceEncoding dev '[])
-> QMaybe
     dev
     '[]
     (TransitionEncoding dev '[FakeSize],
      QStartStop dev '[FakeSize] (SliceEncoding dev '[FakeSize]))
-> QMaybe
     dev
     '[]
     (TransitionEncoding dev '[FakeSize],
      QStartStop dev '[FakeSize] (SliceEncoding dev '[FakeSize]))
-> StateEncoding dev
forall (dev :: (DeviceType, Natural)).
QStartStop dev '[] (SliceEncoding dev '[])
-> QMaybe
     dev
     '[]
     (TransitionEncoding dev '[FakeSize],
      QStartStop dev '[FakeSize] (SliceEncoding dev '[FakeSize]))
-> QMaybe
     dev
     '[]
     (TransitionEncoding dev '[FakeSize],
      QStartStop dev '[FakeSize] (SliceEncoding dev '[FakeSize]))
-> StateEncoding dev
StateEncoding
    (SliceEncoding dev '[] -> QStartStop dev '[] (SliceEncoding dev '[])
forall (batchShape :: [Natural]) (dev :: (DeviceType, Natural)) a.
TensorOptions batchShape 'Int64 dev =>
a -> QStartStop dev batchShape a
qStop SliceEncoding dev '[]
forall (dev :: (DeviceType, Natural)).
KnownDevice dev =>
SliceEncoding dev '[]
emptySlice)
    ((TransitionEncoding dev '[FakeSize],
 QStartStop dev '[FakeSize] (SliceEncoding dev '[FakeSize]))
-> QMaybe
     dev
     '[]
     (TransitionEncoding dev '[FakeSize],
      QStartStop dev '[FakeSize] (SliceEncoding dev '[FakeSize]))
forall (batchShape :: [Natural]) (dev :: (DeviceType, Natural)) a.
TensorOptions batchShape QDType dev =>
a -> QMaybe dev batchShape a
qJust ((TransitionEncoding dev '[FakeSize],
  QStartStop dev '[FakeSize] (SliceEncoding dev '[FakeSize]))
 -> QMaybe
      dev
      '[]
      (TransitionEncoding dev '[FakeSize],
       QStartStop dev '[FakeSize] (SliceEncoding dev '[FakeSize])))
-> (TransitionEncoding dev '[FakeSize],
    QStartStop dev '[FakeSize] (SliceEncoding dev '[FakeSize]))
-> QMaybe
     dev
     '[]
     (TransitionEncoding dev '[FakeSize],
      QStartStop dev '[FakeSize] (SliceEncoding dev '[FakeSize]))
forall a b. (a -> b) -> a -> b
$ Path
  (Maybe [(StartStop (Note SPitch), StartStop (Note SPitch))])
  (Notes SPitch)
-> (TransitionEncoding dev '[FakeSize],
    QStartStop dev '[FakeSize] (SliceEncoding dev '[FakeSize]))
forall (dev :: (DeviceType, Natural)) (t :: Type -> Type).
(Foldable t, KnownDevice dev) =>
Path
  (Maybe (t (StartStop (Note SPitch), StartStop (Note SPitch))))
  (Notes SPitch)
-> (TransitionEncoding dev '[FakeSize],
    QStartStop dev '[FakeSize] (SliceEncoding dev '[FakeSize]))
getFrozen Path
  (Maybe [(StartStop (Note SPitch), StartStop (Note SPitch))])
  (Notes SPitch)
frozen)
    ((TransitionEncoding dev '[FakeSize],
 QStartStop dev '[FakeSize] (SliceEncoding dev '[FakeSize]))
-> QMaybe
     dev
     '[]
     (TransitionEncoding dev '[FakeSize],
      QStartStop dev '[FakeSize] (SliceEncoding dev '[FakeSize]))
forall (batchShape :: [Natural]) (dev :: (DeviceType, Natural)) a.
TensorOptions batchShape QDType dev =>
a -> QMaybe dev batchShape a
qNothing ([TransitionEncoding dev '[]]
-> Stacked (TransitionEncoding dev '[]) FakeSize
forall a. Stackable a => [a] -> Stacked a FakeSize
stackUnsafe [TransitionEncoding dev '[]
forall (dev :: (DeviceType, Natural)).
KnownDevice dev =>
TransitionEncoding dev '[]
emptyTransition], [QStartStop dev '[] (SliceEncoding dev '[])]
-> Stacked (QStartStop dev '[] (SliceEncoding dev '[])) FakeSize
forall a. Stackable a => [a] -> Stacked a FakeSize
stackUnsafe [SliceEncoding dev '[] -> QStartStop dev '[] (SliceEncoding dev '[])
forall (batchShape :: [Natural]) (dev :: (DeviceType, Natural)) a.
TensorOptions batchShape 'Int64 dev =>
a -> QStartStop dev batchShape a
qStop SliceEncoding dev '[]
forall (dev :: (DeviceType, Natural)).
KnownDevice dev =>
SliceEncoding dev '[]
emptySlice]))
encodePVState (GSOpen Path (Edges SPitch) (Notes SPitch)
open [PVLeftmost SPitch]
_) =
  QStartStop dev '[] (SliceEncoding dev '[])
-> QMaybe
     dev
     '[]
     (TransitionEncoding dev '[FakeSize],
      QStartStop dev '[FakeSize] (SliceEncoding dev '[FakeSize]))
-> QMaybe
     dev
     '[]
     (TransitionEncoding dev '[FakeSize],
      QStartStop dev '[FakeSize] (SliceEncoding dev '[FakeSize]))
-> StateEncoding dev
forall (dev :: (DeviceType, Natural)).
QStartStop dev '[] (SliceEncoding dev '[])
-> QMaybe
     dev
     '[]
     (TransitionEncoding dev '[FakeSize],
      QStartStop dev '[FakeSize] (SliceEncoding dev '[FakeSize]))
-> QMaybe
     dev
     '[]
     (TransitionEncoding dev '[FakeSize],
      QStartStop dev '[FakeSize] (SliceEncoding dev '[FakeSize]))
-> StateEncoding dev
StateEncoding
    (SliceEncoding dev '[] -> QStartStop dev '[] (SliceEncoding dev '[])
forall (batchShape :: [Natural]) (dev :: (DeviceType, Natural)) a.
TensorOptions batchShape 'Int64 dev =>
a -> QStartStop dev batchShape a
qStart SliceEncoding dev '[]
forall (dev :: (DeviceType, Natural)).
KnownDevice dev =>
SliceEncoding dev '[]
emptySlice)
    ((TransitionEncoding dev '[FakeSize],
 QStartStop dev '[FakeSize] (SliceEncoding dev '[FakeSize]))
-> QMaybe
     dev
     '[]
     (TransitionEncoding dev '[FakeSize],
      QStartStop dev '[FakeSize] (SliceEncoding dev '[FakeSize]))
forall (batchShape :: [Natural]) (dev :: (DeviceType, Natural)) a.
TensorOptions batchShape QDType dev =>
a -> QMaybe dev batchShape a
qNothing ([TransitionEncoding dev '[]]
-> Stacked (TransitionEncoding dev '[]) FakeSize
forall a. Stackable a => [a] -> Stacked a FakeSize
stackUnsafe [TransitionEncoding dev '[]
forall (dev :: (DeviceType, Natural)).
KnownDevice dev =>
TransitionEncoding dev '[]
emptyTransition], [QStartStop dev '[] (SliceEncoding dev '[])]
-> Stacked (QStartStop dev '[] (SliceEncoding dev '[])) FakeSize
forall a. Stackable a => [a] -> Stacked a FakeSize
stackUnsafe [SliceEncoding dev '[] -> QStartStop dev '[] (SliceEncoding dev '[])
forall (batchShape :: [Natural]) (dev :: (DeviceType, Natural)) a.
TensorOptions batchShape 'Int64 dev =>
a -> QStartStop dev batchShape a
qStart SliceEncoding dev '[]
forall (dev :: (DeviceType, Natural)).
KnownDevice dev =>
SliceEncoding dev '[]
emptySlice]))
    ((TransitionEncoding dev '[FakeSize],
 QStartStop dev '[FakeSize] (SliceEncoding dev '[FakeSize]))
-> QMaybe
     dev
     '[]
     (TransitionEncoding dev '[FakeSize],
      QStartStop dev '[FakeSize] (SliceEncoding dev '[FakeSize]))
forall (batchShape :: [Natural]) (dev :: (DeviceType, Natural)) a.
TensorOptions batchShape QDType dev =>
a -> QMaybe dev batchShape a
qJust ((TransitionEncoding dev '[FakeSize],
  QStartStop dev '[FakeSize] (SliceEncoding dev '[FakeSize]))
 -> QMaybe
      dev
      '[]
      (TransitionEncoding dev '[FakeSize],
       QStartStop dev '[FakeSize] (SliceEncoding dev '[FakeSize])))
-> (TransitionEncoding dev '[FakeSize],
    QStartStop dev '[FakeSize] (SliceEncoding dev '[FakeSize]))
-> QMaybe
     dev
     '[]
     (TransitionEncoding dev '[FakeSize],
      QStartStop dev '[FakeSize] (SliceEncoding dev '[FakeSize]))
forall a b. (a -> b) -> a -> b
$ Path (Edges SPitch) (Notes SPitch)
-> (TransitionEncoding dev '[FakeSize],
    QStartStop dev '[FakeSize] (SliceEncoding dev '[FakeSize]))
forall (dev :: (DeviceType, Natural)).
KnownDevice dev =>
Path (Edges SPitch) (Notes SPitch)
-> (TransitionEncoding dev '[FakeSize],
    QStartStop dev '[FakeSize] (SliceEncoding dev '[FakeSize]))
getOpen Path (Edges SPitch) (Notes SPitch)
open)
encodePVState (GSSemiOpen Path
  (Maybe [(StartStop (Note SPitch), StartStop (Note SPitch))])
  (Notes SPitch)
frozen Notes SPitch
mid Path (Edges SPitch) (Notes SPitch)
open [PVLeftmost SPitch]
_) =
  QStartStop dev '[] (SliceEncoding dev '[])
-> QMaybe
     dev
     '[]
     (TransitionEncoding dev '[FakeSize],
      QStartStop dev '[FakeSize] (SliceEncoding dev '[FakeSize]))
-> QMaybe
     dev
     '[]
     (TransitionEncoding dev '[FakeSize],
      QStartStop dev '[FakeSize] (SliceEncoding dev '[FakeSize]))
-> StateEncoding dev
forall (dev :: (DeviceType, Natural)).
QStartStop dev '[] (SliceEncoding dev '[])
-> QMaybe
     dev
     '[]
     (TransitionEncoding dev '[FakeSize],
      QStartStop dev '[FakeSize] (SliceEncoding dev '[FakeSize]))
-> QMaybe
     dev
     '[]
     (TransitionEncoding dev '[FakeSize],
      QStartStop dev '[FakeSize] (SliceEncoding dev '[FakeSize]))
-> StateEncoding dev
StateEncoding
    (SliceEncoding dev '[] -> QStartStop dev '[] (SliceEncoding dev '[])
forall (batchShape :: [Natural]) (dev :: (DeviceType, Natural)) a.
TensorOptions batchShape 'Int64 dev =>
a -> QStartStop dev batchShape a
qInner (SliceEncoding dev '[]
 -> QStartStop dev '[] (SliceEncoding dev '[]))
-> SliceEncoding dev '[]
-> QStartStop dev '[] (SliceEncoding dev '[])
forall a b. (a -> b) -> a -> b
$ Notes SPitch -> SliceEncoding dev '[]
forall (dev :: (DeviceType, Natural)).
KnownDevice dev =>
Notes SPitch -> SliceEncoding dev '[]
encodeSlice Notes SPitch
mid)
    ((TransitionEncoding dev '[FakeSize],
 QStartStop dev '[FakeSize] (SliceEncoding dev '[FakeSize]))
-> QMaybe
     dev
     '[]
     (TransitionEncoding dev '[FakeSize],
      QStartStop dev '[FakeSize] (SliceEncoding dev '[FakeSize]))
forall (batchShape :: [Natural]) (dev :: (DeviceType, Natural)) a.
TensorOptions batchShape QDType dev =>
a -> QMaybe dev batchShape a
qJust ((TransitionEncoding dev '[FakeSize],
  QStartStop dev '[FakeSize] (SliceEncoding dev '[FakeSize]))
 -> QMaybe
      dev
      '[]
      (TransitionEncoding dev '[FakeSize],
       QStartStop dev '[FakeSize] (SliceEncoding dev '[FakeSize])))
-> (TransitionEncoding dev '[FakeSize],
    QStartStop dev '[FakeSize] (SliceEncoding dev '[FakeSize]))
-> QMaybe
     dev
     '[]
     (TransitionEncoding dev '[FakeSize],
      QStartStop dev '[FakeSize] (SliceEncoding dev '[FakeSize]))
forall a b. (a -> b) -> a -> b
$ Path
  (Maybe [(StartStop (Note SPitch), StartStop (Note SPitch))])
  (Notes SPitch)
-> (TransitionEncoding dev '[FakeSize],
    QStartStop dev '[FakeSize] (SliceEncoding dev '[FakeSize]))
forall (dev :: (DeviceType, Natural)) (t :: Type -> Type).
(Foldable t, KnownDevice dev) =>
Path
  (Maybe (t (StartStop (Note SPitch), StartStop (Note SPitch))))
  (Notes SPitch)
-> (TransitionEncoding dev '[FakeSize],
    QStartStop dev '[FakeSize] (SliceEncoding dev '[FakeSize]))
getFrozen Path
  (Maybe [(StartStop (Note SPitch), StartStop (Note SPitch))])
  (Notes SPitch)
frozen)
    ((TransitionEncoding dev '[FakeSize],
 QStartStop dev '[FakeSize] (SliceEncoding dev '[FakeSize]))
-> QMaybe
     dev
     '[]
     (TransitionEncoding dev '[FakeSize],
      QStartStop dev '[FakeSize] (SliceEncoding dev '[FakeSize]))
forall (batchShape :: [Natural]) (dev :: (DeviceType, Natural)) a.
TensorOptions batchShape QDType dev =>
a -> QMaybe dev batchShape a
qJust ((TransitionEncoding dev '[FakeSize],
  QStartStop dev '[FakeSize] (SliceEncoding dev '[FakeSize]))
 -> QMaybe
      dev
      '[]
      (TransitionEncoding dev '[FakeSize],
       QStartStop dev '[FakeSize] (SliceEncoding dev '[FakeSize])))
-> (TransitionEncoding dev '[FakeSize],
    QStartStop dev '[FakeSize] (SliceEncoding dev '[FakeSize]))
-> QMaybe
     dev
     '[]
     (TransitionEncoding dev '[FakeSize],
      QStartStop dev '[FakeSize] (SliceEncoding dev '[FakeSize]))
forall a b. (a -> b) -> a -> b
$ Path (Edges SPitch) (Notes SPitch)
-> (TransitionEncoding dev '[FakeSize],
    QStartStop dev '[FakeSize] (SliceEncoding dev '[FakeSize]))
forall (dev :: (DeviceType, Natural)).
KnownDevice dev =>
Path (Edges SPitch) (Notes SPitch)
-> (TransitionEncoding dev '[FakeSize],
    QStartStop dev '[FakeSize] (SliceEncoding dev '[FakeSize]))
getOpen Path (Edges SPitch) (Notes SPitch)
open)

-- Step Encoding
-- -------------

-- | Input to the Q-function: the encoding of one or more candidate actions
-- together with the encoding of the state they are taken in.
--
-- Only the action part is indexed by @batchShape@; the state encoding has no
-- batch dimensions and is shared by every action in the batch.
data QEncoding dev batchShape = QEncoding
  { qActionEncoding :: !(ActionEncoding dev batchShape)
  -- ^ encoded action(s), carrying the batch dimensions
  , qStateEncoding :: !(StateEncoding dev)
  -- ^ encoded parser state (unbatched)
  }
  deriving stock (Show, Generic)
  deriving anyclass (NFData)

-- | Adding a batch dimension to a 'QEncoding' only affects the action part:
-- a leading dimension of size 1 is prepended to it via the action encoding's
-- own 'addBatchDim'. The state encoding is unbatched and is passed through
-- unchanged.
instance Batchable (QEncoding dev shape) where
  type Batched (QEncoding dev shape) = QEncoding dev (1 : shape)
  addBatchDim :: QEncoding dev shape -> Batched (QEncoding dev shape)
  addBatchDim (QEncoding actionEnc stateEnc) =
    QEncoding
      { qActionEncoding = addBatchDim actionEnc
      , qStateEncoding = stateEnc
      }

-- | Encode a single (state, action) pair for the Q-function.
--
-- The result has an empty batch shape: the action is encoded with
-- 'encodePVAction' and the state with 'encodePVState'.
encodeStep
  :: (TT.KnownDevice dev)
  => PVState
  -> PVAction
  -> QEncoding dev '[]
encodeStep state action =
  QEncoding
    { qActionEncoding = encodePVAction action
    , qStateEncoding = encodePVState state
    }

-- | Encode a state together with a non-empty list of candidate actions as one
-- batched 'QEncoding', and pass it to a continuation.
--
-- All actions are encoded individually and stacked along a new leading
-- dimension of size @1 + length rest@ (the head action first). The state is
-- encoded once and shared by the whole batch. The batch size is only known at
-- runtime, so the result is handed to a rank-N continuation that receives the
-- size as a 'KnownNat'.
withBatchedEncoding
  :: forall dev r
   . (TT.KnownDevice dev)
  => PVState
  -> NonEmpty PVAction
  -> (forall n. (KnownNat n) => QEncoding dev '[n] -> r)
  -> r
withBatchedEncoding state (firstAction :| restActions) k =
  VS.withSizedList restEncs withRest
 where
  -- Prepend the head action's encoding to the sized tail and stack the batch.
  -- The explicit signature lets the KnownNat plugin derive KnownNat (1 + n).
  withRest :: forall n. (KnownNat n) => VS.Vector n (ActionEncoding dev '[]) -> r
  withRest sizedRest =
    k (QEncoding (stack (VS.cons firstEnc sizedRest)) stateEnc)

  firstEnc :: ActionEncoding dev '[]
  firstEnc = encodePVAction firstAction

  restEncs :: [ActionEncoding dev '[]]
  restEncs = fmap encodePVAction restActions

  stateEnc :: StateEncoding dev
  stateEnc = encodePVState state