{-# LANGUAGE AllowAmbiguousTypes #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE DefaultSignatures #-}
{-# LANGUAGE DeriveAnyClass #-}
{-# LANGUAGE DerivingStrategies #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE MagicHash #-}
{-# LANGUAGE PartialTypeSignatures #-}
{-# LANGUAGE Strict #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE UndecidableInstances #-}
{-# LANGUAGE NoStarIsType #-}
{-# OPTIONS_GHC -fplugin GHC.TypeLits.KnownNat.Solver #-}
{-# OPTIONS_GHC -fplugin GHC.TypeLits.Normalise #-}
module RL.Encoding where
import Common
import GreedyParser
import Internal.MultiSet qualified as MS
import PVGrammar (Edge, Edges (Edges), Freeze (FreezeOp), InnerEdge, Note (..), Notes (Notes), PVAnalysis, PVLeftmost, Split, Spread)
import PVGrammar.Generate (derivationPlayerPV)
import PVGrammar.Parse (protoVoiceEvaluator, pvThaw)
import RL.ModelTypes
import Control.DeepSeq
import Data.Foldable qualified as F
import Data.HashSet qualified as HS
import Data.Hashable (Hashable)
import Data.List qualified
import Data.List.NonEmpty (NonEmpty (..))
import Data.Maybe (catMaybes, mapMaybe)
import Data.Proxy (Proxy (..))
import Data.Type.Equality ((:~:) (..))
import Data.TypeNums (KnownInt, KnownNat, Nat, TInt (..), intVal, intVal', type (*), type (+), type (-), type (>=))
import Data.Vector qualified as V
import Data.Vector.Generic.Sized.Internal qualified as VSU
import Data.Vector.Sized qualified as VS
import Debug.Trace qualified as DT
import GHC.Exts (Proxy#, proxy#)
import GHC.Generics
import Musicology.Pitch
import Torch qualified as T
import Torch.Lens qualified as T
import Torch.Typed qualified as TT
import Unsafe.Coerce (unsafeCoerce)
-- | Values that can be stacked along a new leading (batch) axis.
class Stackable a where
  -- | The stacked representation of @n@ values of type @a@.
  type Stacked a (n :: Nat)

  -- | Stack a vector of @1 + n@ values (never empty, by its type) into a
  -- single value whose leading axis has size @1 + n@.
  stack :: (KnownNat n, KnownNat (1 + n)) => VS.Vector (1 + n) a -> Stacked a (1 + n)
-- | Stack a list of runtime length as if it had the fixed type-level length
-- 'FakeSize'.
--
-- The list is wrapped in a sized vector with the unchecked 'VSU.Vector'
-- constructor, so the type-level length is NOT checked against the actual
-- number of elements.
--
-- NOTE(review): callers should not trust the type-level size of the result.
-- An empty list presumably fails inside 'stack' — confirm.
stackUnsafe :: (Stackable a) => [a] -> Stacked a FakeSize
stackUnsafe things = stack $ VSU.Vector $ V.fromList things
-- | Values that can gain a leading batch axis of size 1.
class Batchable a where
  -- | The representation of @a@ with an extra leading batch axis.
  type Batched a

  -- | Add a leading batch axis of size 1.
  addBatchDim :: a -> Batched a
-- | A typed tensor gains its batch axis by unsqueezing at axis 0.
instance Batchable (TT.Tensor dev dtype shape) where
  type Batched (TT.Tensor dev dtype shape) = TT.Tensor dev dtype (1 : shape)
  addBatchDim = TT.unsqueeze @0
-- | A (batched) optional value.
--
-- 'qmMask' flags which entries are present; 'qmContent' is the payload. The
-- payload is stored even for absent entries.
data QMaybe dev (batchShape :: [Nat]) a = QMaybe
  { qmMask :: QTensor dev batchShape
  -- ^ presence mask, one entry per batch element
  , qmContent :: a
  -- ^ the payload
  }
  deriving stock (Show, Generic)
  deriving anyclass (NFData)
-- | An absent value: the presence mask is all zeros.
--
-- The given payload is stored unchanged. NOTE(review): presumably it is a
-- placeholder of the right shape so that batches stack — confirm with callers.
qNothing
  :: (TT.TensorOptions batchShape QDType dev)
  => a
  -> QMaybe dev batchShape a
qNothing = QMaybe TT.zeros
-- | A present value: the presence mask is all ones.
qJust
  :: (TT.TensorOptions batchShape QDType dev)
  => a
  -> QMaybe dev batchShape a
qJust = QMaybe TT.ones
-- | Stacking optional values stacks the masks along axis 0 and recursively
-- stacks the payloads.
instance (Stackable a) => Stackable (QMaybe dev batchShape a) where
  type Stacked (QMaybe dev batchShape a) n = QMaybe dev (n ': batchShape) (Stacked a n)
  stack ms = QMaybe masks contents
   where
    masks = TT.vecStack @0 $ qmMask <$> ms
    contents = stack $ qmContent <$> ms
-- | Batching an optional value unsqueezes the mask at axis 0 and batches
-- the payload.
instance (Batchable a) => Batchable (QMaybe dev shape a) where
  type Batched (QMaybe dev shape a) = QMaybe dev (1 : shape) (Batched a)
  addBatchDim (QMaybe mask content) =
    QMaybe (TT.unsqueeze @0 mask) (addBatchDim content)
-- | A (batched) list of tensors with fixed capacity @maxLen@.
--
-- Elements are stored along a padded axis of size @maxLen@ that follows the
-- batch axes. 'qlMask' flags which slots along that axis hold real elements.
data QBoundedList dev (dtype :: TT.DType) (maxLen :: Nat) (batchShape :: [Nat]) (innerShape :: [Nat])
  = QBoundedList
  { qlMask :: QTensor dev (batchShape TT.++ '[maxLen])
  -- ^ one entry per list slot
  , qlContent :: TT.Tensor dev dtype (batchShape TT.++ '[maxLen] TT.++ innerShape)
  -- ^ list elements, padded up to @maxLen@
  }
  deriving stock (Show, Generic)
  deriving anyclass (NFData)
-- | Encode a list of tensors as a fixed-capacity masked list.
--
-- At most @maxLen@ elements are kept; any further elements are dropped.
-- The kept elements are stacked along a new leading axis and zero-padded up
-- to @maxLen@. The mask is one for every kept slot and zero for every
-- padding slot.
--
-- If nothing survives truncation (empty input, or @maxLen == 0@), the result
-- is all zeros.
qBoundedList
  :: forall dev dtype maxLen innerShape
   . ( KnownNat maxLen
     , TT.KnownDevice dev
     , TT.KnownShape innerShape
     , TT.TensorOptions innerShape QDType dev
     , TT.TensorOptions innerShape dtype dev
     )
  => [TT.Tensor dev dtype innerShape]
  -> QBoundedList dev dtype maxLen '[] innerShape
qBoundedList lst =
  case take maxLenI lst of
    -- Must not reach 'T.stack', which cannot stack an empty list of
    -- tensors. With 'Strict' on, a where-bound stack would be forced even
    -- when unused, so the non-empty work lives in 'fromKept'.
    [] -> QBoundedList TT.zeros TT.zeros
    kept -> fromKept kept
 where
  -- capacity of the list, reflected from the type level
  maxLenI :: Int
  maxLenI = TT.natValI @maxLen

  -- Build the masked list from a NON-EMPTY list of at most maxLen elements.
  fromKept
    :: [TT.Tensor dev dtype innerShape]
    -> QBoundedList dev dtype maxLen '[] innerShape
  fromKept kept =
    QBoundedList (TT.UnsafeMkTensor mask) (TT.UnsafeMkTensor paddedContent)
   where
    len :: Int
    len = length kept
    padLen :: Int
    padLen = maxLenI - len
    innerDims :: [Int]
    innerDims = TT.shapeVal @innerShape
    -- stack the kept elements along a new axis 0
    content :: T.Tensor
    content = T.stack (T.Dim 0) $ TT.toDynamic <$> kept
    -- Padding pairs are listed starting from the LAST axis. The inner axes
    -- get (0, 0); the leading list axis comes last and gets (0, padLen).
    padSpec :: [Int]
    padSpec = replicate (2 * length innerDims) 0 ++ [0, padLen]
    paddedContent :: T.Tensor
    paddedContent = T.constantPadNd1d padSpec 0 content
    -- ones for the kept elements, zeros for the padding slots
    -- NOTE(review): assumes 'opts @dev' yields QDType-compatible options,
    -- since the result is wrapped unchecked — confirm.
    mask :: T.Tensor
    mask = T.cat (T.Dim 0) [T.ones [len] $ opts @dev, T.zeros [padLen] $ opts @dev]
-- | Stacking bounded lists stacks both the masks and the contents along
-- axis 0, which becomes the new outermost batch axis.
instance Stackable (QBoundedList dev dtype maxLen batchShape innerShape) where
  type
    Stacked (QBoundedList dev dtype maxLen batchShape innerShape) n =
      QBoundedList dev dtype maxLen (n ': batchShape) innerShape
  stack xs = QBoundedList masks contents
   where
    masks = TT.vecStack @0 $ qlMask <$> xs
    contents = TT.vecStack @0 $ qlContent <$> xs
-- | Batching a bounded list unsqueezes both the mask and the contents at
-- axis 0.
instance Batchable (QBoundedList dev dtype maxLen batchShape innerShape) where
  type
    Batched (QBoundedList dev dtype maxLen batchShape innerShape) =
      QBoundedList dev dtype maxLen (1 : batchShape) innerShape
  addBatchDim (QBoundedList mask content) =
    QBoundedList (TT.unsqueeze @0 mask) (TT.unsqueeze @0 content)
-- | A (batched) encoding of a start/inner/stop value.
--
-- 'qssTag' is an Int64 tag per batch entry, written by 'qStart', 'qInner'
-- and 'qStop'. 'qssContent' is the payload.
data QStartStop dev (batchShape :: [Nat]) a = QStartStop
  { qssTag :: TT.Tensor dev TT.Int64 batchShape
  -- ^ per-entry tag
  , qssContent :: a
  -- ^ the payload
  }
  deriving stock (Show, Generic)
  deriving anyclass (NFData)
-- | Tag a payload as an inner value (tag 1).
qInner :: (TT.TensorOptions batchShape TT.Int64 dev) => a -> QStartStop dev batchShape a
qInner = QStartStop (TT.full (1 :: Int))
-- | Tag a payload as the start marker (tag 0).
qStart :: (TT.TensorOptions batchShape TT.Int64 dev) => a -> QStartStop dev batchShape a
qStart = QStartStop (TT.full (0 :: Int))
-- | Tag a payload as the stop marker (tag 2).
qStop :: (TT.TensorOptions batchShape TT.Int64 dev) => a -> QStartStop dev batchShape a
qStop = QStartStop (TT.full (2 :: Int))
-- | Encode a 'StartStop' value.
--
-- 'Start' and 'Stop' carry the default payload @def@; @Inner x@ carries
-- @f x@. The tag is set by 'qStart', 'qStop' or 'qInner' respectively.
qStartStop
  :: (TT.TensorOptions batchShape TT.Int64 dev)
  => (a -> b)
  -> b
  -> StartStop a
  -> QStartStop dev batchShape b
qStartStop f def val = case val of
  Start -> qStart def
  Stop -> qStop def
  Inner x -> qInner $ f x
-- | Stacking tagged values stacks the tag tensors along a new leading axis
-- (axis 0) and delegates stacking of the payloads to their own 'Stackable'
-- instance.
instance (Stackable a) => Stackable (QStartStop dev batchShape a) where
  type Stacked (QStartStop dev batchShape a) n = QStartStop dev (n ': batchShape) (Stacked a n)
  stack items = QStartStop stackedTags stackedContents
   where
    stackedContents = stack (fmap qssContent items)
    stackedTags = TT.vecStack @0 (fmap qssTag items)
-- | Adding a batch dimension inserts a size-1 axis at position 0 of the tag
-- tensor and batches the payload with its own 'Batchable' instance.
instance (Batchable a) => Batchable (QStartStop dev shape a) where
  type Batched (QStartStop dev shape a) = QStartStop dev (1 : shape) (Batched a)
  addBatchDim q =
    QStartStop
      (TT.unsqueeze @0 (qssTag q))
      (addBatchDim (qssContent q))
-- | Sparse encoding of a slice: a bounded list of at most 'MaxPitches'
-- 'TT.Int64' entries of inner shape @[2]@ per batch element.
--
-- NOTE(review): presumably each entry is one pitch's (fifth, octave)
-- coordinate pair; confirm against the producer of this encoding.
newtype SliceEncodingSparse dev batchShape = SliceEncodingSparse
  { getSliceEncodingSparse :: QBoundedList dev TT.Int64 MaxPitches batchShape '[2]
  }
  deriving stock (Show, Generic)
  deriving newtype (NFData)
-- | Stack sparse slice encodings by unwrapping them, stacking the underlying
-- bounded lists, and rewrapping the result.
instance Stackable (SliceEncodingSparse dev batchShape) where
  type Stacked (SliceEncodingSparse dev batchShape) n = SliceEncodingSparse dev (n ': batchShape)
  stack encodings =
    SliceEncodingSparse (stack (fmap getSliceEncodingSparse encodings))
-- | Add a leading batch axis by delegating to the wrapped bounded list.
instance Batchable (SliceEncodingSparse dev shape) where
  type Batched (SliceEncodingSparse dev shape) = SliceEncodingSparse dev (1 ': shape)
  addBatchDim (SliceEncodingSparse boundedList) =
    SliceEncodingSparse (addBatchDim boundedList)
-- | Dense encoding of a slice: one tensor of shape @batchShape ++ PShape@,
-- i.e. a 'PShape' pitch grid per batch element.
newtype SliceEncodingDense dev batchShape = SliceEncodingDense
  { getSliceEncodingDense :: QTensor dev (batchShape TT.++ PShape)
  }
  deriving stock (Show, Generic)
  deriving newtype (NFData)
-- | Stack dense slice encodings along a new leading axis (axis 0) of the
-- underlying tensor.
instance Stackable (SliceEncodingDense dev batchShape) where
  type Stacked (SliceEncodingDense dev batchShape) n = SliceEncodingDense dev (n ': batchShape)
  stack encodings =
    SliceEncodingDense (TT.vecStack @0 (fmap getSliceEncodingDense encodings))
-- | Add a leading batch axis by delegating to the wrapped tensor.
instance Batchable (SliceEncodingDense dev batchShape) where
  type Batched (SliceEncodingDense dev batchShape) = SliceEncodingDense dev (1 ': batchShape)
  addBatchDim (SliceEncodingDense grid) =
    SliceEncodingDense (addBatchDim grid)
-- | The slice encoding currently used throughout the model (the dense one).
type SliceEncoding = SliceEncodingDense
-- | Unwrap a 'SliceEncoding' into its raw pitch-grid tensor.
getSlice
  :: forall dev batchShape
   . SliceEncoding dev batchShape
  -> QTensor dev (batchShape TT.++ PShape)
getSlice = getSliceEncodingDense
-- | Encode an unbatched slice from its pitches.
--
-- Currently this is the summed one-hot encoding ('pitchesOneHotSum').
encodePitches
  :: (TT.KnownDevice dev)
  => [SPitch]
  -> SliceEncoding dev '[]
encodePitches = pitchesOneHotSum
-- | Map a spelled pitch to its @[fifthIndex, octaveIndex]@ coordinates on the
-- pitch grid.
--
-- Each coordinate is offset by the lower bound of its axis ('FifthLow' /
-- 'OctaveLow') and then clamped into the valid index range
-- @[0, size - 1]@ of that axis ('FifthSize' / 'OctaveSize'). A pitch outside
-- the modelled range is therefore mapped onto the nearest border cell
-- instead of producing an out-of-bounds tensor index.
pitch2index
  :: SPitch
  -> [Int]
pitch2index p =
  [ clamp fifthSize (fifths p - fifthLow)
  , clamp octaveSize (octaves p - octaveLow)
  ]
 where
  -- Clamp @i@ into @[0, size - 1]@. The upper bound must exclude @size@:
  -- callers use the result to index tensor axes of length @size@, and index
  -- @size@ would be out of bounds.
  clamp :: Int -> Int -> Int
  clamp size i = max 0 $ min (size - 1) i
  fifthLow :: Int
  fifthLow = fromIntegral $ intVal' @FifthLow proxy#
  octaveLow :: Int
  octaveLow = fromIntegral $ intVal' @OctaveLow proxy#
  fifthSize :: Int
  fifthSize = TT.natValI @FifthSize
  octaveSize :: Int
  octaveSize = TT.natValI @OctaveSize
-- | Multi-hot encoding of a set of pitches on the @[FifthSize, OctaveSize]@
-- pitch grid, moved to device @dev@.
--
-- Each pitch contributes @1@ at its 'pitch2index' cell. Scattering is done
-- with accumulation (the 'True' flag of 'T.indexPut'), so pitches that land
-- on the same (clamped) cell add up. An empty set yields the all-zero grid.
-- The grid is built on the CPU and moved to @dev@ afterwards.
pitchesMultiHot
  :: forall dev
   . (TT.KnownDevice dev)
  => HS.HashSet SPitch
  -> QTensor dev PShape
pitchesMultiHot ps = TT.UnsafeMkTensor onDevice
 where
  onDevice :: T.Tensor
  onDevice = T.toDevice (TT.deviceVal @dev) encoded
  encoded :: T.Tensor
  encoded
    | HS.null ps = grid
    | otherwise = T.indexPut True coords ones grid
  -- Lazy on purpose: only needed when the set is non-empty.
  ~coords = T.asTensor <$> Data.List.transpose (pitch2index <$> F.toList ps)
  ones :: T.Tensor
  ones = T.ones [F.length ps] cpuOpts
  grid :: T.Tensor
  grid = T.zeros gridDims cpuOpts
  gridDims :: [Int]
  gridDims = [TT.natValI @FifthSize, TT.natValI @OctaveSize]
  cpuOpts = opts @'(TT.CPU, 0)
-- | Dense slice encoding obtained by summing one one-hot 'PShape' grid per
-- pitch.
--
-- For @n@ pitches, an @n × PShape@ tensor is built on the CPU. Row @i@ gets a
-- @1@ at the 'pitch2index' cell of the @i@-th pitch. The rows are then
-- summed over axis 0 (with dtype 'qDType') and the result is moved to @dev@.
-- Each cell therefore holds the number of pitches mapped to it. An empty
-- list yields the all-zero grid.
pitchesOneHotSum
  :: forall dev
   . (TT.KnownDevice dev)
  => [SPitch]
  -> SliceEncodingDense dev '[]
pitchesOneHotSum [] = SliceEncodingDense TT.zeros
pitchesOneHotSum ps = SliceEncodingDense summed
 where
  count :: Int
  count = length ps
  cpuOpts = opts @'(TT.CPU, 0)
  -- One index tensor per axis: (row, fifth, octave).
  coords :: [T.Tensor]
  coords =
    T.asTensor
      <$> Data.List.transpose [row : pitch2index p | (row, p) <- zip [0 ..] ps]
  ones :: T.Tensor
  ones = T.ones [count] cpuOpts
  rowsGrid :: T.Tensor
  rowsGrid = T.zeros (count : TT.shapeVal @PShape) cpuOpts
  summed :: QTensor dev PShape
  summed =
    TT.UnsafeMkTensor
      . T.toDevice (TT.deviceVal @dev)
      . T.sumDim (T.Dim 0) T.RemoveDim qDType
      $ T.indexPut True coords ones rowsGrid
-- | Encode pitches as a bounded list of one-hot pitch grids.
--
-- At most 'MaxPitches' pitches are kept; any beyond that are dropped. Slot
-- @i@ holds the one-hot grid of the @i@-th kept pitch, reshaped to
-- @1 : PShape@. The mask is @1@ for the @n@ occupied slots and @0@ for the
-- remaining padding slots. An empty input yields an all-zero mask and
-- all-zero contents.
pitchesOneHots
  :: forall dev
   . (TT.KnownDevice dev)
  => [SPitch]
  -> QBoundedList dev QDType MaxPitches '[] (1 : PShape)
pitchesOneHots [] = QBoundedList TT.zeros TT.zeros
pitchesOneHots ps = QBoundedList validMask (TT.reshape scattered)
 where
  capacity :: Int
  capacity = TT.natValI @MaxPitches
  kept :: [SPitch]
  kept = take capacity ps
  count :: Int
  count = length kept
  cpuOpts = opts @'(TT.CPU, 0)
  devOpts = opts @dev
  -- One index tensor per axis: (slot, fifth, octave).
  coords :: [T.Tensor]
  coords =
    T.asTensor
      <$> Data.List.transpose [slot : pitch2index p | (slot, p) <- zip [0 ..] kept]
  ones :: T.Tensor
  ones = T.ones [count] cpuOpts
  emptyGrid :: QTensor '(TT.CPU, 0) (MaxPitches ': PShape)
  emptyGrid = TT.zeros
  scattered :: QTensor dev (MaxPitches ': PShape)
  scattered =
    TT.UnsafeMkTensor
      . T.toDevice (TT.deviceVal @dev)
      . T.indexPut True coords ones
      $ TT.toDynamic emptyGrid
  validMask :: QTensor dev '[MaxPitches]
  validMask =
    TT.UnsafeMkTensor $
      T.cat (T.Dim 0) [T.ones [count] devOpts, T.zeros [capacity - count] devOpts]
-- | Encode a list of pitches as a 'QBoundedList' of dense pitch tokens.
--
-- Each token is the concatenation of
--
--   * a one-hot vector over fifths (index @fifths p - FifthLow@,
--     @FifthSize@ classes), followed by
--   * a one-hot vector over octaves (index @octaves p - OctaveLow@,
--     @OctaveSize@ classes),
--
-- cast to 'qDType'. Turning the variable-length list of tokens into the
-- fixed-size 'MaxPitches' representation is left to 'qBoundedList'.
--
-- NOTE(review): a pitch whose fifth or octave lies outside the configured
-- range produces an out-of-range index for 'T.oneHot'; confirm that callers
-- only pass pitches inside that range.
pitchesTokens
  :: forall dev
   . (TT.KnownDevice dev)
  => [SPitch]
  -> QBoundedList dev QDType MaxPitches '[] '[PSize]
pitchesTokens ps = qBoundedList (mkToken <$> ps)
 where
  mkToken p =
    TT.UnsafeMkTensor $
      T.toDType qDType $
        T.cat
          (T.Dim 0)
          [ T.oneHot fifthSize (toIndex (fifths p - fifthLow))
          , T.oneHot octaveSize (toIndex (octaves p - octaveLow))
          ]
  -- Scalar index tensor on the target device. 'T.oneHot' needs integer
  -- (Int64) indices, hence the explicit dtype.
  toIndex :: Int -> T.Tensor
  toIndex i = T.asTensor' i $ T.withDType T.Int64 $ opts @dev
  fifthLow :: Int
  fifthLow = fromIntegral $ intVal' @FifthLow proxy#
  octaveLow :: Int
  octaveLow = fromIntegral $ intVal' @OctaveLow proxy#
  fifthSize :: Int
  fifthSize = TT.natValI @FifthSize
  octaveSize :: Int
  octaveSize = TT.natValI @OctaveSize
-- | Sparse slice encoding of a list of pitches.
--
-- Every pitch becomes a small 'T.Int64' tensor holding the indices returned
-- by 'pitch2index' (shape @'[2]@ as fixed by 'SliceEncodingSparse').
-- The tokens are gathered into a bounded list by 'qBoundedList' and wrapped
-- in 'SliceEncodingSparse'.
pitchesIndices
  :: forall dev
   . (TT.KnownDevice dev)
  => [SPitch]
  -> SliceEncodingSparse dev '[]
pitchesIndices pitches =
  SliceEncodingSparse (qBoundedList (map indexToken pitches))
 where
  -- The index tensor is created directly on the target device.
  indexToken pitch =
    TT.UnsafeMkTensor
      (T.asTensor' (pitch2index pitch) (T.withDType T.Int64 (opts @dev)))
-- | Encode a slice by handing the pitches of its notes to 'encodePitches'.
-- Only 'notePitch' of each note is used; the order is that of 'HS.toList'.
encodeSlice
  :: (TT.KnownDevice dev)
  => Notes SPitch
  -> SliceEncoding dev '[]
encodeSlice (Notes noteSet) =
  encodePitches [notePitch note | note <- HS.toList noteSet]
-- | Encoding of a slice that contains no pitches at all.
emptySlice :: (TT.KnownDevice dev) => SliceEncoding dev '[]
emptySlice = encodePitches mempty
-- | Tensor encoding of a transition, parameterised by device and by the
-- (possibly empty) list of leading batch dimensions @batchShape@.
data TransitionEncoding dev batchShape = TransitionEncoding
  { trencPassing :: QBoundedList dev QDType MaxEdges batchShape (2 ': PShape)
  -- ^ Passing edges; the axis of size 2 holds the two endpoints of each edge.
  , trencInner :: QBoundedList dev QDType MaxEdges batchShape (2 ': PShape)
  -- ^ Regular (inner) edges, encoded like 'trencPassing'.
  , trencLeft :: SliceEncoding dev batchShape
  -- ^ Presumably the slice on the left of the transition — confirm.
  , trencRight :: SliceEncoding dev batchShape
  -- ^ Presumably the slice on the right of the transition — confirm.
  , trencRoot :: QTensor dev batchShape
  -- ^ Per-transition scalar (per batch element); meaning set by the encoder.
  }
  deriving stock (Show, Generic)
  deriving anyclass (NFData)
-- | Stack a non-empty vector of transition encodings field by field, adding
-- a new leading batch axis of size @1 + n@. Nested encodings are stacked with
-- their own 'stack'; the root tensors with 'TT.vecStack' along axis 0.
instance Stackable (TransitionEncoding dev batchShape) where
  type Stacked (TransitionEncoding dev batchShape) n = TransitionEncoding dev (n ': batchShape)
  stack encodings =
    TransitionEncoding
      { trencPassing = stack (fmap trencPassing encodings)
      , trencInner = stack (fmap trencInner encodings)
      , trencLeft = stack (fmap trencLeft encodings)
      , trencRight = stack (fmap trencRight encodings)
      , trencRoot = TT.vecStack @0 (fmap trencRoot encodings)
      }
-- | Add a leading batch axis of size 1 to every component of a transition
-- encoding. Nested encodings use their own 'addBatchDim'; the root tensor is
-- unsqueezed at axis 0.
instance Batchable (TransitionEncoding dev shape) where
  type Batched (TransitionEncoding dev shape) = TransitionEncoding dev (1 : shape)
  addBatchDim enc =
    TransitionEncoding
      { trencPassing = addBatchDim (trencPassing enc)
      , trencInner = addBatchDim (trencInner enc)
      , trencLeft = addBatchDim (trencLeft enc)
      , trencRight = addBatchDim (trencRight enc)
      , trencRoot = TT.unsqueeze @0 (trencRoot enc)
      }
-- | Multi-hot encoding of a set of edges as a dense grid.
--
-- Each edge @(Note p1 _, Note p2 _)@ is mapped to the coordinate
-- @pitch2index p1 ++ pitch2index p2@ in a grid of size
-- @[FifthSize, OctaveSize, FifthSize, OctaveSize]@, and a 1 is added there
-- with an accumulating 'T.indexPut'. Note IDs are ignored, so distinct edges
-- between notes of equal pitches add up at the same coordinate. An empty set
-- yields an all-zero grid. The grid is built on the CPU and moved to @dev@.
--
-- NOTE(review): the runtime tensor has the four axes listed in @dims@, but it
-- is wrapped unchecked ('TT.UnsafeMkTensor') as @QTensor dev EShape'@; verify
-- that 'EShape'' actually describes this shape.
edgesMultiHot
  :: forall dev
   . (TT.KnownDevice dev)
  => HS.HashSet (InnerEdge SPitch)
  -> QTensor dev EShape'
edgesMultiHot es = TT.UnsafeMkTensor (T.toDevice (TT.deviceVal @dev) grid)
 where
  cpuOpts = opts @'(TT.CPU, 0)
  fifthSize = TT.natValI @FifthSize
  octaveSize = TT.natValI @OctaveSize
  -- one axis per (fifth, octave) coordinate of each endpoint
  dims = [fifthSize, octaveSize, fifthSize, octaveSize]
  blank = T.zeros dims cpuOpts
  coords = [pitch2index p1 ++ pitch2index p2 | (Note p1 _, Note p2 _) <- F.toList es]
  -- One index tensor per grid axis. Lazy pattern: under the module-wide
  -- Strict extension the binding would otherwise always be evaluated.
  ~axisIndices = map T.asTensor (Data.List.transpose coords)
  grid
    | HS.null es = blank
    | otherwise = T.indexPut True axisIndices (T.ones [F.length es] cpuOpts) blank
-- | One-hot encoding of a list of edges as a bounded list.
--
-- The first and the second endpoint of every edge are encoded separately
-- with 'pitchesOneHots' and concatenated along axis 1, which yields the
-- leading @2@ of the inner shape @2 ': PShape@. Both encodings come from lists
-- of the same length, so the mask of the first one is reused.
--
-- NOTE(review): 'pitchesOneHots' bounds its list by 'MaxPitches' while this
-- signature promises 'MaxEdges'; that only typechecks while both are equal.
edgesOneHots
  :: forall dev
   . (TT.KnownDevice dev)
  => [InnerEdge SPitch]
  -> QBoundedList dev QDType MaxEdges '[] (2 ': PShape)
edgesOneHots es =
  QBoundedList mask (TT.cat @1 (sources TT.:. targets TT.:. TT.HNil))
 where
  QBoundedList mask sources = pitchesOneHots @dev (map (notePitch . fst) es)
  QBoundedList _ targets = pitchesOneHots @dev (map (notePitch . snd) es)
-- | Encode edges as a bounded list of dense tokens.
--
-- The token of an edge @(Note p1 _, Note p2 _)@ is the concatenation of the
-- per-pitch one-hots (fifths, then octaves) of @p1@ followed by those of
-- @p2@. The result is passed through 'toOpts' for @dev@ (presumably fixing
-- dtype/device — confirm). Note IDs are ignored.
--
-- NOTE(review): the bound is declared as 'MaxEdges' but 'qBoundedList' is
-- resolved at 'MaxPitches'; this relies on the two being equal.
edgesTokens
  :: forall dev
   . (TT.KnownDevice dev)
  => [InnerEdge SPitch]
  -> QBoundedList dev QDType MaxEdges '[] '[ESize]
edgesTokens es = qBoundedList (map edgeToken es)
 where
  edgeToken (Note p1 _, Note p2 _) =
    TT.UnsafeMkTensor $!
      toOpts @dev (T.cat (T.Dim 0) (pitchHots p1 ++ pitchHots p2))
  -- one-hot over fifths followed by one-hot over octaves for a single pitch
  pitchHots p =
    [ T.oneHot fifthSize (asIndex (fifths p - fifthLow))
    , T.oneHot octaveSize (asIndex (octaves p - octaveLow))
    ]
  -- scalar Int64 index tensor on the target device
  asIndex :: Int -> T.Tensor
  asIndex i = T.asTensor' i (T.withDType T.Int64 (opts @dev))
  fifthLow :: Int
  fifthLow = fromIntegral (intVal' @FifthLow proxy#)
  octaveLow :: Int
  octaveLow = fromIntegral (intVal' @OctaveLow proxy#)
  fifthSize :: Int
  fifthSize = TT.natValI @FifthSize
  octaveSize :: Int
  octaveSize = TT.natValI @OctaveSize
-- | Encode the edges of a single transition as a 'TransitionEncoding'
-- without batch dimensions.
--
-- * Passing edges (the multiset) are one-hot encoded with 'edgesOneHots'.
-- * Regular edges between two inner notes are one-hot encoded with
--   'edgesOneHots' and stored in 'trencInner'.
-- * For every regular edge @(Start, note)@, the note's pitch goes into
--   'trencLeft'. For every regular edge @(note, Stop)@, it goes into
--   'trencRight'. Both are encoded with 'pitchesOneHotSum'.
-- * 'trencRoot' is a ones tensor if the regular edge @(Start, Stop)@ is
--   present, and a zeros tensor otherwise.
encodeTransition
  :: (TT.KnownDevice dev)
  => Edges SPitch
  -> TransitionEncoding dev '[]
encodeTransition (Edges reg pass) =
  TransitionEncoding
    { trencPassing = edgesOneHots $ MS.toList pass
    , trencInner = edgesOneHots [(l, r) | (Inner l, Inner r) <- regulars]
    , trencLeft = pitchesOneHotSum [notePitch r | (Start, Inner r) <- regulars]
    , trencRight = pitchesOneHotSum [notePitch l | (Inner l, Stop) <- regulars]
    , trencRoot =
        -- Choose the tensor inline. This module enables 'Strict', so
        -- where-bound ones/zeros values would both be forced (and both
        -- tensors allocated) on every call, even though only one is used.
        if HS.member (Start, Stop) reg then TT.ones else TT.zeros
    }
 where
  -- The regular edges, listed once and shared by the three filters above.
  regulars = HS.toList reg
-- | The encoding of a transition without any edges: 'encodeTransition'
-- applied to an empty set of regular edges and an empty multiset of
-- passing edges.
emptyTransition
  :: (TT.KnownDevice dev) => TransitionEncoding dev '[]
emptyTransition = encodeTransition (Edges mempty MS.empty)
-- | Encoding of the top-level context of an action. It holds slice and
-- transition encodings that all share the batch shape @batchShape@.
--
-- NOTE(review): the field names suggest left slice / first transition /
-- middle slice / second transition / right slice. The middle slice and
-- the second transition are wrapped in 'QMaybe' and are presumably absent
-- for single-parent actions. Confirm against 'encodePVAction'.
data ActionTop dev batchShape = ActionTop
  { atopSl :: !(QStartStop dev batchShape (SliceEncoding dev batchShape))
  -- ^ first slice, wrapped in 'QStartStop'
  , atopT1 :: !(TransitionEncoding dev batchShape)
  -- ^ first transition (always present)
  , atopSm :: !(QMaybe dev batchShape (SliceEncoding dev batchShape))
  -- ^ optional middle slice
  , atopT2 :: !(QMaybe dev batchShape (TransitionEncoding dev batchShape))
  -- ^ optional second transition
  , atopSr :: !(QStartStop dev batchShape (SliceEncoding dev batchShape))
  -- ^ last slice, wrapped in 'QStartStop'
  }
  deriving stock (Show, Generic)
  deriving anyclass (NFData)
-- | Stacking 'ActionTop's stacks every field independently. The new
-- dimension is prepended to the batch shape.
instance Stackable (ActionTop dev batchShape) where
  type Stacked (ActionTop dev batchShape) n = ActionTop dev (n ': batchShape)
  stack tops =
    ActionTop
      { atopSl = stack (atopSl <$> tops)
      , atopT1 = stack (atopT1 <$> tops)
      , atopSm = stack (atopSm <$> tops)
      , atopT2 = stack (atopT2 <$> tops)
      , atopSr = stack (atopSr <$> tops)
      }
-- | Adds a leading batch dimension of size 1 by applying 'addBatchDim' to
-- every field.
instance Batchable (ActionTop dev shape) where
  type Batched (ActionTop dev shape) = ActionTop dev (1 : shape)
  addBatchDim top =
    ActionTop
      { atopSl = addBatchDim (atopSl top)
      , atopT1 = addBatchDim (atopT1 top)
      , atopSm = addBatchDim (atopSm top)
      , atopT2 = addBatchDim (atopT2 top)
      , atopSr = addBatchDim (atopSr top)
      }
-- | Encoding of a complete action. It pairs the action's top-level context
-- with an @Int64@ tensor of the same batch shape.
--
-- NOTE(review): 'actionEncodingOp' presumably holds an index identifying
-- the kind of operation. Confirm against 'encodePVAction'.
data ActionEncoding dev batchShape = ActionEncoding
  { actionEncodingTop :: !(ActionTop dev batchShape)
  -- ^ slices and transitions the action refers to
  , actionEncodingOp :: !(TT.Tensor dev 'TT.Int64 batchShape)
  -- ^ integer tensor with the batch shape
  }
  deriving stock (Show, Generic)
  deriving anyclass (NFData)
-- | Stacks a vector of action encodings. The tops are stacked via their
-- own 'Stackable' instance, and the op tensors are stacked along a new
-- leading dimension 0.
instance Stackable (ActionEncoding dev batchShape) where
  type Stacked (ActionEncoding dev batchShape) n = ActionEncoding dev (n ': batchShape)
  stack encodings =
    ActionEncoding
      { actionEncodingTop = stack (actionEncodingTop <$> encodings)
      , actionEncodingOp = TT.vecStack @0 (actionEncodingOp <$> encodings)
      }
-- | Adds a leading batch dimension of size 1. The top is batched via its
-- 'Batchable' instance, and the op tensor is unsqueezed at dimension 0.
instance Batchable (ActionEncoding dev shape) where
  type Batched (ActionEncoding dev shape) = ActionEncoding dev (1 : shape)
  addBatchDim enc =
    ActionEncoding
      { actionEncodingTop = addBatchDim (actionEncodingTop enc)
      , actionEncodingOp = TT.unsqueeze @0 (actionEncodingOp enc)
      }
-- | Encode a single protovoice action as an unbatched 'ActionEncoding'.
--
-- The 'ActionTop' part always has the layout
-- @(left slice, first transition, middle slice?, second transition?, right slice)@.
-- Single-parent actions have no middle slice and no second transition, so
-- those positions hold 'qNothing' around empty placeholders.
--
-- The operation part is a scalar 'Int64' tensor holding an operation code:
--
-- * 0: 'LMSingleFreeze'
-- * 1: 'LMSingleSplit'
-- * 2: 'LMDoubleFreezeLeft'
-- * 3: 'LMDoubleSpread'
-- * 4: 'LMDoubleSplitLeft'
-- * 5: 'LMDoubleSplitRight'
encodePVAction
  :: (TT.KnownDevice dev)
  => PVAction
  -> ActionEncoding dev '[]
encodePVAction (Left (ActionSingle (SingleParent sl t sr) action)) =
  ActionEncoding parents (TT.full opCode)
 where
  -- no middle slice / second transition for single-parent actions
  parents =
    ActionTop
      (qStartStop encodeSlice emptySlice sl)
      (encodeTransition t)
      (qNothing emptySlice)
      (qNothing emptyTransition)
      (qStartStop encodeSlice emptySlice sr)
  opCode :: Int
  opCode = case action of
    LMSingleFreeze _ -> 0
    LMSingleSplit _ -> 1
encodePVAction (Right (ActionDouble (DoubleParent sl t1 sm t2 sr) action)) =
  ActionEncoding parents (TT.full opCode)
 where
  parents =
    ActionTop
      (qStartStop encodeSlice emptySlice sl)
      (encodeTransition t1)
      (qJust (encodeSlice sm))
      (qJust (encodeTransition t2))
      (qStartStop encodeSlice emptySlice sr)
  opCode :: Int
  opCode = case action of
    LMDoubleFreezeLeft _ -> 2
    LMDoubleSpread _ -> 3
    LMDoubleSplitLeft _ -> 4
    LMDoubleSplitRight _ -> 5
-- | Unbatched tensor encoding of a greedy-parser state (see 'encodePVState').
data StateEncoding dev = StateEncoding
  { stateEncodingMid :: !(QStartStop dev '[] (SliceEncoding dev '[]))
  -- ^ The slice at the boundary between the frozen and the open part.
  -- 'encodePVState' sets it to 'qStop' (fully frozen state), 'qStart'
  -- (fully open state) or 'qInner' of the encoded middle slice (semi-open state).
  , stateEncodingFrozen
      :: !( QMaybe
              dev
              '[]
              ( TransitionEncoding dev '[FakeSize]
              , QStartStop dev '[FakeSize] (SliceEncoding dev '[FakeSize])
              )
          )
  -- ^ Stacked transitions and slices of the frozen part, if there is one.
  , stateEncodingOpen
      :: !( QMaybe
              dev
              '[]
              ( TransitionEncoding dev '[FakeSize]
              , QStartStop dev '[FakeSize] (SliceEncoding dev '[FakeSize])
              )
          )
  -- ^ Stacked transitions and slices of the open part, if there is one.
  }
  deriving stock (Show, Generic)
  deriving anyclass (NFData)
-- | Encode the frozen part of a parser state.
--
-- Only the first 8 (transition, slice) pairs returned by 'pathTake' are used.
-- Inner slices are wrapped with 'Inner', and 'Start' is passed as the boundary
-- value. NOTE(review): presumably this marks the far end of the frozen path;
-- confirm against 'pathTake'.
--
-- Frozen transitions are thawed with 'pvThaw' before being encoded. The
-- per-element encodings are then stacked into a 'FakeSize'-long leading axis
-- with 'stackUnsafe'.
getFrozen
  :: forall dev t
   . (Foldable t, TT.KnownDevice dev)
  => Path (Maybe (t (Edge SPitch))) (Notes SPitch)
  -> ( TransitionEncoding dev '[FakeSize]
     , QStartStop dev '[FakeSize] (SliceEncoding dev '[FakeSize])
     )
getFrozen frozen = (stackUnsafe transitionEncs, stackUnsafe sliceEncs)
 where
  prefix = pathTake 8 Inner Start frozen
  transitionEncs :: [TransitionEncoding dev '[]]
  transitionEncs = [encodeTransition (pvThaw tr) | (tr, _) <- prefix]
  sliceEncs :: [QStartStop dev '[] (SliceEncoding dev '[])]
  sliceEncs = [qStartStop encodeSlice emptySlice slc | (_, slc) <- prefix]
-- | Encode the open part of a parser state.
--
-- Only the first 8 (transition, slice) pairs returned by 'pathTake' are used.
-- Inner slices are wrapped with 'Inner', and 'Stop' is passed as the boundary
-- value. NOTE(review): presumably this marks the far end of the open path;
-- confirm against 'pathTake'.
--
-- The per-element encodings are stacked into a 'FakeSize'-long leading axis
-- with 'stackUnsafe'.
getOpen
  :: (TT.KnownDevice dev)
  => Path (Edges SPitch) (Notes SPitch)
  -> ( TransitionEncoding dev '[FakeSize]
     , QStartStop dev '[FakeSize] (SliceEncoding dev '[FakeSize])
     )
getOpen open = (stackUnsafe transitionEncs, stackUnsafe sliceEncs)
 where
  prefix = pathTake 8 Inner Stop open
  transitionEncs = [encodeTransition tr | (tr, _) <- prefix]
  sliceEncs = [qStartStop encodeSlice emptySlice slc | (_, slc) <- prefix]
-- | Encode a greedy-parser state as a 'StateEncoding'.
--
-- * Fully frozen ('GSFrozen'):
--   * the middle boundary is 'qStop';
--   * the frozen part is encoded;
--   * the open part is 'qNothing' around a one-element placeholder
--     (empty transition, 'qStop' empty slice).
-- * Fully open ('GSOpen'):
--   * the middle boundary is 'qStart';
--   * the frozen part is 'qNothing' around a one-element placeholder
--     (empty transition, 'qStart' empty slice);
--   * the open part is encoded.
-- * Semi-open ('GSSemiOpen'):
--   * the middle slice is encoded with 'qInner';
--   * both the frozen and the open part are encoded.
--
-- The list of already-applied operations stored in open and semi-open states
-- is not part of the encoding.
encodePVState
  :: (TT.KnownDevice dev)
  => PVState
  -> StateEncoding dev
encodePVState st = case st of
  GSFrozen frozen ->
    StateEncoding
      (qStop emptySlice)
      (qJust (getFrozen frozen))
      (qNothing (stackUnsafe [emptyTransition], stackUnsafe [qStop emptySlice]))
  GSOpen open _ ->
    StateEncoding
      (qStart emptySlice)
      (qNothing (stackUnsafe [emptyTransition], stackUnsafe [qStart emptySlice]))
      (qJust (getOpen open))
  GSSemiOpen frozen mid open _ ->
    StateEncoding
      (qInner (encodeSlice mid))
      (qJust (getFrozen frozen))
      (qJust (getOpen open))
-- | Input to the Q function: an encoded action together with the encoded
-- state it is applied in. Only the action part carries the batch shape.
data QEncoding dev batchShape = QEncoding
  { qActionEncoding :: !(ActionEncoding dev batchShape)
  -- ^ The encoded action (see 'encodePVAction').
  , qStateEncoding :: !(StateEncoding dev)
  -- ^ The encoded parser state (see 'encodePVState').
  }
  deriving stock (Show, Generic)
  deriving anyclass (NFData)
-- | Batching a 'QEncoding' adds a leading batch dimension of size 1
-- (@shape@ becomes @1 : shape@).
--
-- Only the action part is batched, using the 'Batchable' instance of the
-- action encoding. The state encoding is carried over unchanged.
instance Batchable (QEncoding dev shape) where
  type Batched (QEncoding dev shape) = QEncoding dev (1 : shape)
  addBatchDim qenc = case qenc of
    QEncoding actionEnc stateEnc ->
      QEncoding (addBatchDim actionEnc) stateEnc
-- | Encode a single (state, action) pair as an unbatched 'QEncoding'.
--
-- The action is encoded with 'encodePVAction' and the state with
-- 'encodePVState'. The two results are combined into one 'QEncoding' whose
-- batch shape is empty (@'[]@).
encodeStep
  :: (TT.KnownDevice dev)
  => PVState
  -> PVAction
  -> QEncoding dev '[]
encodeStep pvState pvAction = QEncoding actionPart statePart
 where
  actionPart = encodePVAction pvAction
  statePart = encodePVState pvState
-- | Encode one state together with a non-empty list of actions as a single
-- batched 'QEncoding', and hand the result to a continuation.
--
-- Each action is encoded with 'encodePVAction', and the encodings are
-- combined with 'stack' into an encoding of shape @'[n]@. The state is
-- encoded once with 'encodePVState' and shared by the whole batch.
--
-- The batch size is only known at runtime, so it reaches the continuation as
-- a universally quantified @n@ together with a 'KnownNat' dictionary. This is
-- continuation-passing style, built on 'VS.withSizedList'.
withBatchedEncoding
  :: forall dev r
   . (TT.KnownDevice dev)
  => PVState
  -> NonEmpty PVAction
  -> (forall n. (KnownNat n) => QEncoding dev '[n] -> r)
  -> r
withBatchedEncoding pvState (firstAction :| restActions) k =
  VS.withSizedList restEncs withRest
 where
  -- Encoded state, shared by every action in the batch.
  stateEnc :: StateEncoding dev
  stateEnc = encodePVState pvState

  -- The head of the non-empty list is encoded separately. Consing it onto
  -- the sized tail gives a vector of static size @1 + n@, which 'stack'
  -- requires.
  firstEnc :: ActionEncoding dev '[]
  firstEnc = encodePVAction firstAction

  restEncs :: [ActionEncoding dev '[]]
  restEncs = map encodePVAction restActions

  withRest :: forall n. (KnownNat n) => VS.Vector n (ActionEncoding dev '[]) -> r
  withRest restVec =
    k (QEncoding (stack (VS.cons firstEnc restVec)) stateEnc)