{-# LANGUAGE DataKinds #-}
{-# LANGUAGE DeriveAnyClass #-}
{-# LANGUAGE DerivingVia #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE PartialTypeSignatures #-}
{-# LANGUAGE RecordWildCards #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE UndecidableInstances #-}
{-# LANGUAGE NoStarIsType #-}
{-# HLINT ignore "Use <$>" #-}
{-# OPTIONS_GHC -Wno-partial-type-signatures #-}
{-# OPTIONS_GHC -Wno-unrecognised-pragmas #-}
{-# OPTIONS_GHC -Wredundant-constraints #-}
module RL.Model where
import Common
import Control.DeepSeq
import Data.Foldable qualified as F
import Data.Kind (Type)
import Data.Proxy (Proxy (Proxy))
import Data.Type.Equality (type (:~:) (Refl), type (==))
import Data.TypeNums (KnownNat, Nat, TInt (..), intVal, intVal', type (*), type (+), type (-), type (<=))
import Debug.Trace qualified as DT
import GHC.ForeignPtr qualified as Ptr
import GHC.Generics (Generic)
import GHC.TypeLits (OrderingI (..), cmpNat, sameNat)
import GreedyParser (DoubleParent (DoubleParent), SingleParent (SingleParent))
import NoThunks.Class (NoThunks (..), OnlyCheckWhnf (..), allNoThunks)
import RL.Encoding
import RL.ModelTypes
import RL.TorchHelpers (withBatchDim)
import RL.TorchHelpers qualified as TH
import Torch qualified as T
import Torch.Jit qualified as TJit
import Torch.Lens qualified as TL
import Torch.Typed qualified as TT
import System.IO.Unsafe
import Torch.Internal.Cast (cast2)
import Torch.Internal.Managed.Type.Tensor qualified as ATen
-- | Element-wise nonlinearity used after each convolution of the slice
-- encoder. It is GELU ('TT.gelu') and leaves the tensor shape unchanged.
activation :: (IsValidDevice dev) => QTensor dev shape -> QTensor dev shape
activation t = TT.gelu t
-- | Broadcast the first (untyped) tensor to the shape of the second, using
-- libtorch's @expand_as@ through the raw binding 'ATen.tensor_expand_as_t'.
--
-- The FFI call lives in 'IO' and is escaped with 'unsafePerformIO'. This
-- treats @expand_as@ as a non-mutating operation that only yields a new tensor
-- handle. NOTE(review): in libtorch the result is a view that shares the
-- storage of @self@; this is only safe while neither tensor is mutated in place.
expandAs :: T.Tensor -> T.Tensor -> T.Tensor
expandAs self other =
  let expanded = cast2 ATen.tensor_expand_as_t self other
   in unsafePerformIO expanded
-- | Debugging aid. Prints the runtime shape of a typed tensor with
-- 'DT.traceShow' and returns the tensor unchanged.
traceDyn :: TT.Tensor a b c -> TT.Tensor a b c
traceDyn tensor = DT.traceShow dims tensor
  where
    dims = T.shape (TT.toDynamic tensor)
-- | Reshape a typed tensor to the runtime dimensions @newShape@, then retag
-- the result with an arbitrary type-level shape @shape'@.
--
-- It is "unsafe" because nothing checks that @newShape@ matches @shape'@.
-- The caller must keep the two consistent. NOTE(review): presumably
-- 'T.reshape' raises a libtorch error when the element counts differ.
unsafeReshape :: [Int] -> TT.Tensor dev dtype shape -> TT.Tensor dev dtype shape'
unsafeReshape newShape = TT.UnsafeMkTensor . T.reshape newShape . TT.toDynamic
-- | Sampling spec for a 'ConstEmb'. It holds no runtime data: the device and
-- shape exist only at the type level.
data ConstEmbSpec dev (shape :: [Nat]) = ConstEmbSpec

-- | A single trainable parameter of the given shape, used as a constant
-- (input-independent) embedding.
--
-- 'TT.Parameterized', 'NFData' and 'NoThunks' come directly from the wrapped
-- 'TT.Parameter' via newtype deriving.
newtype ConstEmb dev shape = ConstEmb (TT.Parameter dev QDType shape)
  deriving stock (Show, Generic)
  deriving newtype (TT.Parameterized, NFData, NoThunks)
-- | Initialise a 'ConstEmb'. Draw a tensor with 'TT.randn', then register it
-- as an independent (trainable) parameter.
instance
  (IsValidDevice dev, TT.TensorOptions shape QDType dev)
  => T.Randomizable (ConstEmbSpec dev shape) (ConstEmb dev shape)
  where
  sample :: ConstEmbSpec dev shape -> IO (ConstEmb dev shape)
  sample ConstEmbSpec = do
    initial <- TT.randn
    param <- TT.makeIndependent initial
    pure (ConstEmb param)
-- | Running a 'ConstEmb' ignores its unit input and returns the parameter's
-- current value. The stochastic variant is deterministic and simply wraps
-- 'T.forward' in 'pure'.
instance T.HasForward (ConstEmb dev size) () (QTensor dev size) where
  forward :: ConstEmb dev size -> () -> QTensor dev size
  forward (ConstEmb param) () = TT.toDependent param

  forwardStoch :: ConstEmb dev size -> () -> IO (QTensor dev size)
  forwardStoch model = pure . T.forward model
-- | Sampling spec for a 'SliceEncoder'. It holds no data; all sizes are fixed
-- by type-level constants.
data SliceSpec dev = SliceSpec

-- | Convolutional encoder for slices, together with learned start/stop
-- embeddings of the same shape as an encoded slice.
data SliceEncoder dev = SliceEncoder
  { _slcL1 :: !(TT.Conv2d 1 QSliceHidden 1 1 QDType dev)
  -- ^ 1x1 convolution from 1 input channel to 'QSliceHidden' channels
  , _slcL2 :: !(TT.Conv2d QSliceHidden EmbSize FifthSize OctaveSize QDType dev)
  -- ^ convolution with a 'FifthSize' x 'OctaveSize' kernel, producing
  -- 'EmbSize' channels
  , _slcStart :: !(ConstEmb dev EmbShape)
  -- ^ constant embedding for the start marker (not used by the plain slice
  -- forward passes)
  , _slcStop :: !(ConstEmb dev EmbShape)
  -- ^ constant embedding for the stop marker (not used by the plain slice
  -- forward passes)
  }
  deriving stock (Show, Generic)
  deriving anyclass (TT.Parameterized, NoThunks, NFData)
-- | Initialise every sub-module of a 'SliceEncoder' with its own default
-- sampler, in field order. The spec carries no data and is ignored.
instance (IsValidDevice dev) => T.Randomizable (SliceSpec dev) (SliceEncoder dev) where
  sample :: SliceSpec dev -> IO (SliceEncoder dev)
  sample _ = do
    conv1 <- T.sample TT.Conv2dSpec
    conv2 <- T.sample TT.Conv2dSpec
    startEmb <- T.sample (ConstEmbSpec @dev)
    stopEmb <- T.sample (ConstEmbSpec @dev)
    pure
      SliceEncoder
        { _slcL1 = conv1
        , _slcL2 = conv2
        , _slcStart = startEmb
        , _slcStop = stopEmb
        }
-- | Unbatched slice encoder. Maps one slice (shape 'PShape') to an embedding
-- of shape 'EmbShape'.
--
-- The slice gets a singleton batch axis and a singleton channel axis. It then
-- passes through
--
--   1. '_slcL1' (1x1 kernel, stride 1, no padding) followed by 'activation', and
--   2. '_slcL2' (stride 1, padding @(FifthPadding, OctavePadding)@) followed by
--      'activation',
--
-- and finally the batch axis is squeezed away. The start/stop embeddings are
-- not used here.
--
-- This is the same layer stack as the batched @'[batchSize]@ instance, so a
-- slice encodes identically whether it is passed alone or inside a batch.
instance
  (embshape ~ EmbShape, IsValidDevice dev)
  => T.HasForward (SliceEncoder dev) (SliceEncoding dev '[]) (QTensor dev embshape)
  where
  forward (SliceEncoder l1 l2 _ _) slice = TT.squeezeDim @0 out2
    where
      -- PShape -> 1 : 1 : PShape (batch axis, then channel axis)
      input = TT.unsqueeze @0 $ TT.unsqueeze @0 $ getSlice slice
      out1 :: QTensor dev (1 : QSliceHidden : PShape)
      -- The activation after l1 matches the batched instance. Without it the
      -- two code paths disagree, and l1 followed by l2 is a single affine map.
      out1 = activation $ TT.conv2dForward @'(1, 1) @'(0, 0) l1 input
      out2 :: QTensor dev (1 : EmbShape)
      out2 = activation $ TT.conv2dForward @'(1, 1) @'(FifthPadding, OctavePadding) l2 out1
  forwardStoch model = pure . T.forward model
-- | Batched slice encoder. It applies the two-layer convolution stack to a
-- slice tensor with a leading batch axis.
--
-- The slice (shape @batchSize : PShape@) gets a singleton channel axis at
-- position 1. It then passes through
--
--   1. '_slcL1' (1x1 kernel, stride 1, no padding) followed by 'activation', and
--   2. '_slcL2' (stride 1, padding @(FifthPadding, OctavePadding)@) followed by
--      'activation'.
--
-- 'TH.conv2dForwardRelaxed' leaves the output sizes to the local type
-- annotations below (presumably validated by its ConvSideCheck constraints).
-- The start/stop embeddings are not used here.
instance
  ( IsValidDevice dev
  , embshape ~ '[batchSize, EmbSize, FifthSize, OctaveSize]
  )
  => T.HasForward (SliceEncoder dev) (SliceEncoding dev '[batchSize]) (QTensor dev embshape)
  where
  forward SliceEncoder {_slcL1 = conv1, _slcL2 = conv2} slice = embedded
    where
      -- insert the channel axis: batchSize : PShape -> batchSize : 1 : PShape
      withChannel = TT.unsqueeze @1 (getSlice slice)

      hidden :: QTensor dev '[batchSize, QSliceHidden, FifthSize, OctaveSize]
      hidden =
        activation (TH.conv2dForwardRelaxed @'(1, 1) @'(0, 0) conv1 withChannel)

      embedded :: QTensor dev '[batchSize, EmbSize, FifthSize, OctaveSize]
      embedded =
        activation
          (TH.conv2dForwardRelaxed @'(1, 1) @'(FifthPadding, OctavePadding) conv2 hidden)

  forwardStoch model input = pure (T.forward model input)
-- | Embed an unbatched 'QStartStop' slice.
--
-- Three candidate embeddings are computed: the learned start constant, the
-- encoding of the inner slice, and the learned stop constant. They are stacked
-- along a new leading axis (index 0 = start, 1 = inner, 2 = stop), and the
-- scalar tag selects one of them with 'TT.gatherDim'.
instance
  (embshape ~ EmbShape, IsValidDevice dev)
  => TT.HasForward (SliceEncoder dev) (QStartStop dev '[] (SliceEncoding dev '[])) (QTensor dev embshape)
  where
  forward model@(SliceEncoder _ _ startConst stopConst) (QStartStop selector slice) =
    TT.squeezeDim @0 picked
   where
    -- The three candidates, all with the plain embedding shape.
    embStart, embInner, embStop :: QTensor dev (EmbSize : PShape)
    embStart = TT.forward startConst ()
    embInner = T.forward model slice
    embStop = TT.forward stopConst ()

    -- Candidate axis first: 0 = start, 1 = inner, 2 = stop.
    stacked :: QTensor dev (3 : EmbSize : PShape)
    stacked = TT.stack @0 (embStart TT.:. embInner TT.:. embStop TT.:. TT.HNil)

    -- The scalar tag is reshaped to rank 4 and broadcast to the full embedding
    -- shape, so that gathering along axis 0 yields stacked[tag, ...].
    -- NOTE(review): assumes the tag value lies in {0, 1, 2}.
    index :: TT.Tensor dev 'TT.Int64 (1 : EmbSize : PShape)
    index = TT.expand False (TT.reshape @[1, 1, 1, 1] selector)

    picked :: QTensor dev (1 : EmbSize : PShape)
    picked = TT.gatherDim @0 index stacked

  -- Deterministic: simply wraps 'forward'.
  forwardStoch model input = pure (T.forward model input)
-- | Embed a batch of 'QStartStop' slices.
--
-- For each batch item @b@ the result is @stacked[tag[b], b, ...]@, where the
-- stack holds (index 0) the start constant, (1) the inner slice embedding and
-- (2) the stop constant. The two constants are shared by all items and are
-- broadcast over the batch dimension.
instance
  ( IsValidDevice dev
  , embshape ~ (batchSize : EmbSize : PShape)
  )
  => TT.HasForward (SliceEncoder dev) (QStartStop dev '[batchSize] (SliceEncoding dev '[batchSize])) (QTensor dev embshape)
  where
  forward model@(SliceEncoder _ _ startConst stopConst) (QStartStop selectors slices) =
    TT.squeezeDim @0 picked
   where
    -- Per-item embeddings of the inner slices.
    embInner :: QTensor dev (batchSize : EmbSize : PShape)
    embInner = T.forward model slices

    -- Untyped view of embInner. It is the broadcast target for both the
    -- constants and the tag index.
    innerDyn :: T.Tensor
    innerDyn = TT.toDynamic embInner

    -- Repeat an unbatched constant embedding along the batch axis via
    -- 'expandAs' (presumably torch-style expand_as semantics).
    broadcastConst
      :: QTensor dev (EmbSize : PShape)
      -> QTensor dev (batchSize : EmbSize : PShape)
    broadcastConst emb = TT.UnsafeMkTensor (expandAs (TT.toDynamic emb) innerDyn)

    embStart, embStop :: QTensor dev (batchSize : EmbSize : PShape)
    embStart = broadcastConst (TT.forward startConst ())
    embStop = broadcastConst (TT.forward stopConst ())

    -- Candidate axis first: 0 = start, 1 = inner, 2 = stop.
    stacked :: QTensor dev (3 : batchSize : EmbSize : PShape)
    stacked = TT.stack @0 (embStart TT.:. embInner TT.:. embStop TT.:. TT.HNil)

    -- The tags are reshaped to [batchSize, 1, 1, 1], broadcast to the
    -- embedding shape, and given a leading axis of size 1 for the gather.
    -- NOTE(review): assumes every tag value lies in {0, 1, 2}.
    index :: TT.Tensor dev 'TT.Int64 (1 : batchSize : EmbSize : PShape)
    index =
      TT.UnsafeMkTensor
        . T.unsqueeze (T.Dim 0)
        $ expandAs (T.reshape [-1, 1, 1, 1] (TT.toDynamic selectors)) innerDyn

    picked :: QTensor dev (1 : batchSize : EmbSize : PShape)
    picked = TT.gatherDim @0 index stacked

  -- Deterministic: simply wraps 'forward'.
  forwardStoch model input = pure (T.forward model input)
-- | Specification used to sample a fresh 'TransitionEncoder' through
-- 'T.Randomizable'. It carries no data: every layer size is fixed by the field
-- types of 'TransitionEncoder'.
data TransitionSpec dev = TransitionSpec
-- | Learnable parameters of the transition encoder.
--
-- The first layer has:
--
-- * two convolutions with 2 input channels;
-- * two convolutions with 1 input channel and 1x1 kernels;
-- * a learned constant, all with 'QTransHidden' output channels.
--
-- The second layer maps 'QTransHidden' channels to 'EmbSize' channels. Field
-- order matters: it determines the parameter order exposed by
-- 'TT.Parameterized'.
data TransitionEncoder dev = TransitionEncoder
  { -- | Layer-1 convolution, presumably applied to the passing-edge encodings.
    trL1Passing :: !(TT.Conv2d 2 QTransHidden FifthSize OctaveSize QDType dev)
  , -- | Layer-1 convolution, presumably applied to the inner-edge encodings.
    trL1Inner :: !(TT.Conv2d 2 QTransHidden FifthSize OctaveSize QDType dev)
  , -- | Layer-1 1x1 convolution, presumably for the left slice.
    trL1Left :: !(TT.Conv2d 1 QTransHidden 1 1 QDType dev)
  , -- | Layer-1 1x1 convolution, presumably for the right slice.
    trL1Right :: !(TT.Conv2d 1 QTransHidden 1 1 QDType dev)
  , -- | Learned per-channel constant, presumably added for root transitions.
    trL1Root :: !(ConstEmb dev '[QTransHidden])
  , -- | Layer-2 convolution from the hidden channels to the embedding channels.
    trL2 :: !(TT.Conv2d QTransHidden EmbSize FifthSize OctaveSize QDType dev)
  }
  deriving stock (Show, Generic)
  deriving anyclass (TT.Parameterized, NoThunks, NFData)
-- | Sample a randomly initialised 'TransitionEncoder'.
--
-- The spec argument is ignored and is not forced. Each layer is sampled from
-- its default spec, one after another, in record-field order.
instance (IsValidDevice dev) => T.Randomizable (TransitionSpec dev) (TransitionEncoder dev) where
  sample :: TransitionSpec dev -> IO (TransitionEncoder dev)
  sample _ =
    TransitionEncoder
      <$> T.sample TT.Conv2dSpec -- trL1Passing
      <*> T.sample TT.Conv2dSpec -- trL1Inner
      <*> T.sample TT.Conv2dSpec -- trL1Left
      <*> T.sample TT.Conv2dSpec -- trL1Right
      <*> T.sample (ConstEmbSpec @dev) -- trL1Root
      <*> T.sample TT.Conv2dSpec -- trL2
instance
forall dev embshape
. ( IsValidDevice dev
, embshape ~ (EmbSize : PShape)
)
=> T.HasForward (TransitionEncoder dev) (TransitionEncoding dev '[]) (QTensor dev embshape)
where
forward :: TransitionEncoder dev
-> TransitionEncoding dev '[] -> QTensor dev embshape
forward TransitionEncoder{Conv2d 1 QSliceHidden 1 1 QDType dev
Conv2d 2 QSliceHidden FifthSize OctaveSize QDType dev
Conv2d QSliceHidden QSliceHidden FifthSize OctaveSize QDType dev
ConstEmb dev '[QSliceHidden]
trL1Passing :: forall (dev :: (DeviceType, Nat)).
TransitionEncoder dev
-> Conv2d 2 QSliceHidden FifthSize OctaveSize QDType dev
trL1Inner :: forall (dev :: (DeviceType, Nat)).
TransitionEncoder dev
-> Conv2d 2 QSliceHidden FifthSize OctaveSize QDType dev
trL1Left :: forall (dev :: (DeviceType, Nat)).
TransitionEncoder dev -> Conv2d 1 QSliceHidden 1 1 QDType dev
trL1Right :: forall (dev :: (DeviceType, Nat)).
TransitionEncoder dev -> Conv2d 1 QSliceHidden 1 1 QDType dev
trL1Root :: forall (dev :: (DeviceType, Nat)).
TransitionEncoder dev -> ConstEmb dev '[QSliceHidden]
trL2 :: forall (dev :: (DeviceType, Nat)).
TransitionEncoder dev
-> Conv2d QSliceHidden QSliceHidden FifthSize OctaveSize QDType dev
trL1Passing :: Conv2d 2 QSliceHidden FifthSize OctaveSize QDType dev
trL1Inner :: Conv2d 2 QSliceHidden FifthSize OctaveSize QDType dev
trL1Left :: Conv2d 1 QSliceHidden 1 1 QDType dev
trL1Right :: Conv2d 1 QSliceHidden 1 1 QDType dev
trL1Root :: ConstEmb dev '[QSliceHidden]
trL2 :: Conv2d QSliceHidden QSliceHidden FifthSize OctaveSize QDType dev
..} TransitionEncoding{QTensor dev '[]
SliceEncoding dev '[]
QBoundedList dev QDType QSliceHidden '[] (2 : PShape)
trencPassing :: QBoundedList dev QDType QSliceHidden '[] (2 : PShape)
trencInner :: QBoundedList dev QDType QSliceHidden '[] (2 : PShape)
trencLeft :: SliceEncoding dev '[]
trencRight :: SliceEncoding dev '[]
trencRoot :: QTensor dev '[]
trencRoot :: forall (dev :: (DeviceType, Nat)) (batchShape :: [Nat]).
TransitionEncoding dev batchShape -> QTensor dev batchShape
trencRight :: forall (dev :: (DeviceType, Nat)) (batchShape :: [Nat]).
TransitionEncoding dev batchShape -> SliceEncoding dev batchShape
trencLeft :: forall (dev :: (DeviceType, Nat)) (batchShape :: [Nat]).
TransitionEncoding dev batchShape -> SliceEncoding dev batchShape
trencInner :: forall (dev :: (DeviceType, Nat)) (batchShape :: [Nat]).
TransitionEncoding dev batchShape
-> QBoundedList dev QDType QSliceHidden batchShape (2 : PShape)
trencPassing :: forall (dev :: (DeviceType, Nat)) (batchShape :: [Nat]).
TransitionEncoding dev batchShape
-> QBoundedList dev QDType QSliceHidden batchShape (2 : PShape)
..} =
forall (dim :: Nat) (shape :: [Nat]) (shape' :: [Nat])
(dtype :: DType) (device :: (DeviceType, Nat)).
(KnownNat dim, shape' ~ SqueezeDim shape dim) =>
Tensor device dtype shape -> Tensor device dtype shape'
TT.squeezeDim @0 (Tensor
dev
QDType
'[1, QSliceHidden, ((13 + 12) - 13) + 1, ((5 + 4) - 5) + 1]
-> QTensor dev embshape)
-> Tensor
dev
QDType
'[1, QSliceHidden, ((13 + 12) - 13) + 1, ((5 + 4) - 5) + 1]
-> QTensor dev embshape
forall a b. (a -> b) -> a -> b
$
Tensor
dev
QDType
'[1, QSliceHidden, ((13 + 12) - 13) + 1, ((5 + 4) - 5) + 1]
-> Tensor
dev
QDType
'[1, QSliceHidden, ((13 + 12) - 13) + 1, ((5 + 4) - 5) + 1]
forall (dev :: (DeviceType, Nat)) (shape :: [Nat]).
IsValidDevice dev =>
QTensor dev shape -> QTensor dev shape
activation (Tensor
dev
QDType
'[1, QSliceHidden, ((13 + 12) - 13) + 1, ((5 + 4) - 5) + 1]
-> Tensor
dev
QDType
'[1, QSliceHidden, ((13 + 12) - 13) + 1, ((5 + 4) - 5) + 1])
-> Tensor
dev
QDType
'[1, QSliceHidden, ((13 + 12) - 13) + 1, ((5 + 4) - 5) + 1]
-> Tensor
dev
QDType
'[1, QSliceHidden, ((13 + 12) - 13) + 1, ((5 + 4) - 5) + 1]
forall a b. (a -> b) -> a -> b
$
forall (stride :: (Nat, Nat)) (padding :: (Nat, Nat))
{kernelSize1 :: Nat} {inputSize1 :: Nat} {kernelSize0 :: Nat}
{inputSize0 :: Nat} {inputChannelSize :: Nat}
{outputChannelSize :: Nat} {batchSize :: Nat} {w1 :: DType}
{w2 :: (DeviceType, Nat)}.
(Assert
(OrdCond
(CmpNat kernelSize1 (inputSize1 + (2 * Snd padding)))
'True
'True
'False)
(TypeError ...)
~ (() :: Constraint),
Assert
(OrdCond
(CmpNat kernelSize0 (inputSize0 + (2 * Fst padding)))
'True
'True
'False)
(TypeError ...)
~ (() :: Constraint),
Assert
(OrdCond (CmpNat 1 (Snd stride)) 'True 'True 'False)
(TypeError ...)
~ (() :: Constraint),
Assert
(OrdCond (CmpNat 1 (Fst stride)) 'True 'True 'False)
(TypeError ...)
~ (() :: Constraint),
Assert
(OrdCond (CmpNat 1 kernelSize0) 'True 'True 'False)
(TypeError ...),
Assert
(OrdCond
(CmpNat (kernelSize0 - 1) (inputSize0 + (2 * Fst padding)))
'True
'True
'False)
(TypeError ...),
Assert
(OrdCond (CmpNat 1 kernelSize1) 'True 'True 'False)
(TypeError ...),
Assert
(OrdCond
(CmpNat (kernelSize1 - 1) (inputSize1 + (2 * Snd padding)))
'True
'True
'False)
(TypeError ...),
KnownNat inputChannelSize, KnownNat outputChannelSize,
KnownNat kernelSize0, KnownNat kernelSize1, KnownNat inputSize0,
KnownNat inputSize1, KnownNat batchSize, KnownNat (Fst stride),
KnownNat (Fst padding), KnownNat (Snd stride),
KnownNat (Snd padding)) =>
Conv2d
inputChannelSize outputChannelSize kernelSize0 kernelSize1 w1 w2
-> Tensor
w2 w1 '[batchSize, inputChannelSize, inputSize0, inputSize1]
-> Tensor
w2
w1
'[batchSize, outputChannelSize,
Div ((inputSize0 + (2 * Fst padding)) - kernelSize0) (Fst stride)
+ 1,
Div ((inputSize1 + (2 * Snd padding)) - kernelSize1) (Snd stride)
+ 1]
TT.conv2dForward @'(1, 1) @'(FifthPadding, OctavePadding) Conv2d QSliceHidden QSliceHidden 13 5 QDType dev
Conv2d QSliceHidden QSliceHidden FifthSize OctaveSize QDType dev
trL2 (Tensor dev QDType '[1, QSliceHidden, 13, 5]
-> Tensor
dev
QDType
'[1, QSliceHidden,
Div ((13 + (2 * Fst '(FifthPadding, 2))) - 13) (Fst '(1, 1)) + 1,
Div ((5 + (2 * Snd '(FifthPadding, 2))) - 5) (Snd '(1, 1)) + 1])
-> Tensor dev QDType '[1, QSliceHidden, 13, 5]
-> Tensor
dev
QDType
'[1, QSliceHidden,
Div ((13 + (2 * Fst '(FifthPadding, 2))) - 13) (Fst '(1, 1)) + 1,
Div ((5 + (2 * Snd '(FifthPadding, 2))) - 5) (Snd '(1, 1)) + 1]
forall a b. (a -> b) -> a -> b
$
forall (dim :: Nat) (shape :: [Nat]) (shape' :: [Nat])
(dtype :: DType) (device :: (DeviceType, Nat)).
(KnownNat dim, shape' ~ Unsqueeze shape dim) =>
Tensor device dtype shape -> Tensor device dtype shape'
TT.unsqueeze @0 Tensor dev QDType '[QSliceHidden, 13, 5]
QTensor dev EmbShape
all
where
runConv
:: (KnownNat nin)
=> TT.Conv2d nin QTransHidden FifthSize OctaveSize QDType dev
-> QBoundedList dev QDType MaxEdges '[] (nin : PShape)
-> QTensor dev (QTransHidden : PShape)
runConv :: forall (nin :: Nat).
KnownNat nin =>
Conv2d nin QSliceHidden FifthSize OctaveSize QDType dev
-> QBoundedList dev QDType QSliceHidden '[] (nin : PShape)
-> QTensor dev EmbShape
runConv Conv2d nin QSliceHidden FifthSize OctaveSize QDType dev
conv (QBoundedList QTensor dev ('[] ++ '[QSliceHidden])
mask Tensor dev QDType (('[] ++ '[QSliceHidden]) ++ (nin : PShape))
edges) = forall (d :: Nat) (shape :: [Nat]) (shape' :: [Nat])
(dtype :: DType) (dtype' :: DType) (device :: (DeviceType, Nat)).
(KnownNat d, shape' ~ DropValue shape d,
SumDTypeIsValid device dtype, dtype' ~ SumDType dtype) =>
Tensor device dtype shape -> Tensor device dtype' shape'
TT.sumDim @0 (Tensor
dev
QDType
(CheckBroadcast
'[QSliceHidden, 1, 1, 1]
'[QSliceHidden, QSliceHidden, 13, 5]
(ComputeBroadcast
'[1, 1, 1, QSliceHidden]
(ReverseImpl '[QSliceHidden, QSliceHidden, 13, 5] '[])))
-> QTensor dev EmbShape)
-> Tensor
dev
QDType
(CheckBroadcast
'[QSliceHidden, 1, 1, 1]
'[QSliceHidden, QSliceHidden, 13, 5]
(ComputeBroadcast
'[1, 1, 1, QSliceHidden]
(ReverseImpl '[QSliceHidden, QSliceHidden, 13, 5] '[])))
-> QTensor dev EmbShape
forall a b. (a -> b) -> a -> b
$ Tensor dev QDType '[QSliceHidden, 1, 1, 1]
-> Tensor dev QDType '[QSliceHidden, QSliceHidden, 13, 5]
-> Tensor
dev
QDType
(CheckBroadcast
'[QSliceHidden, 1, 1, 1]
'[QSliceHidden, QSliceHidden, 13, 5]
(ComputeBroadcast
'[1, 1, 1, QSliceHidden]
(ReverseImpl '[QSliceHidden, QSliceHidden, 13, 5] '[])))
forall (shape'' :: [Nat]) (shape :: [Nat]) (shape' :: [Nat])
(dtype :: DType) (dtype' :: DType) (dtype'' :: DType)
(device :: (DeviceType, Nat)).
(dtype'' ~ DTypePromotion dtype dtype',
shape'' ~ Broadcast shape shape',
BasicArithmeticDTypeIsValid device dtype,
BasicArithmeticDTypeIsValid device dtype',
BasicArithmeticDTypeIsValid device dtype'') =>
Tensor device dtype shape
-> Tensor device dtype' shape' -> Tensor device dtype'' shape''
TT.mul Tensor dev QDType '[QSliceHidden, 1, 1, 1]
mask' Tensor dev QDType '[QSliceHidden, QSliceHidden, 13, 5]
QTensor dev (QSliceHidden : EmbShape)
out
where
out :: QTensor dev (MaxEdges : QTransHidden : PShape)
out :: QTensor dev (QSliceHidden : EmbShape)
out = QTensor dev (QSliceHidden : EmbShape)
-> QTensor dev (QSliceHidden : EmbShape)
forall (dev :: (DeviceType, Nat)) (shape :: [Nat]).
IsValidDevice dev =>
QTensor dev shape -> QTensor dev shape
activation (QTensor dev (QSliceHidden : EmbShape)
-> QTensor dev (QSliceHidden : EmbShape))
-> QTensor dev (QSliceHidden : EmbShape)
-> QTensor dev (QSliceHidden : EmbShape)
forall a b. (a -> b) -> a -> b
$ forall (stride :: (Nat, Nat)) (padding :: (Nat, Nat))
{kernelSize1 :: Nat} {inputSize1 :: Nat} {kernelSize0 :: Nat}
{inputSize0 :: Nat} {inputChannelSize :: Nat}
{outputChannelSize :: Nat} {batchSize :: Nat} {w1 :: DType}
{w2 :: (DeviceType, Nat)}.
(Assert
(OrdCond
(CmpNat kernelSize1 (inputSize1 + (2 * Snd padding)))
'True
'True
'False)
(TypeError ...)
~ (() :: Constraint),
Assert
(OrdCond
(CmpNat kernelSize0 (inputSize0 + (2 * Fst padding)))
'True
'True
'False)
(TypeError ...)
~ (() :: Constraint),
Assert
(OrdCond (CmpNat 1 (Snd stride)) 'True 'True 'False)
(TypeError ...)
~ (() :: Constraint),
Assert
(OrdCond (CmpNat 1 (Fst stride)) 'True 'True 'False)
(TypeError ...)
~ (() :: Constraint),
Assert
(OrdCond (CmpNat 1 kernelSize0) 'True 'True 'False)
(TypeError ...),
Assert
(OrdCond
(CmpNat (kernelSize0 - 1) (inputSize0 + (2 * Fst padding)))
'True
'True
'False)
(TypeError ...),
Assert
(OrdCond (CmpNat 1 kernelSize1) 'True 'True 'False)
(TypeError ...),
Assert
(OrdCond
(CmpNat (kernelSize1 - 1) (inputSize1 + (2 * Snd padding)))
'True
'True
'False)
(TypeError ...),
KnownNat inputChannelSize, KnownNat outputChannelSize,
KnownNat kernelSize0, KnownNat kernelSize1, KnownNat inputSize0,
KnownNat inputSize1, KnownNat batchSize, KnownNat (Fst stride),
KnownNat (Fst padding), KnownNat (Snd stride),
KnownNat (Snd padding)) =>
Conv2d
inputChannelSize outputChannelSize kernelSize0 kernelSize1 w1 w2
-> Tensor
w2 w1 '[batchSize, inputChannelSize, inputSize0, inputSize1]
-> Tensor
w2
w1
'[batchSize, outputChannelSize,
Div ((inputSize0 + (2 * Fst padding)) - kernelSize0) (Fst stride)
+ 1,
Div ((inputSize1 + (2 * Snd padding)) - kernelSize1) (Snd stride)
+ 1]
TT.conv2dForward @'(1, 1) @'(FifthPadding, OctavePadding) Conv2d nin QSliceHidden 13 5 QDType dev
Conv2d nin QSliceHidden FifthSize OctaveSize QDType dev
conv Tensor dev QDType '[QSliceHidden, nin, 13, 5]
Tensor dev QDType (('[] ++ '[QSliceHidden]) ++ (nin : PShape))
edges
mask' :: QTensor dev '[MaxEdges, 1, 1, 1]
mask' :: Tensor dev QDType '[QSliceHidden, 1, 1, 1]
mask' = QTensor dev '[QSliceHidden]
-> Tensor dev QDType '[QSliceHidden, 1, 1, 1]
forall (shape' :: [Nat]) (shape :: [Nat]) (dtype :: DType)
(device :: (DeviceType, Nat)).
(KnownShape shape', Numel shape ~ Numel shape') =>
Tensor device dtype shape -> Tensor device dtype shape'
TT.reshape QTensor dev '[QSliceHidden]
QTensor dev ('[] ++ '[QSliceHidden])
mask
runSlice :: Conv2d 1 outputChannelSize kernelSize0 kernelSize1 dtype device
-> Tensor device dtype '[inputSize0, inputSize1]
-> Tensor
device
dtype
(SqueezeDimCheck
'[1, outputChannelSize, (inputSize0 - kernelSize0) + 1,
(inputSize1 - kernelSize1) + 1]
0
(SqueezeDimImpl
'[1, outputChannelSize, (inputSize0 - kernelSize0) + 1,
(inputSize1 - kernelSize1) + 1]
0))
runSlice Conv2d 1 outputChannelSize kernelSize0 kernelSize1 dtype device
conv Tensor device dtype '[inputSize0, inputSize1]
slice = forall (dim :: Nat) (shape :: [Nat]) (shape' :: [Nat])
(dtype :: DType) (device :: (DeviceType, Nat)).
(KnownNat dim, shape' ~ SqueezeDim shape dim) =>
Tensor device dtype shape -> Tensor device dtype shape'
TT.squeezeDim @0 (Tensor
device
dtype
'[1, outputChannelSize, (inputSize0 - kernelSize0) + 1,
(inputSize1 - kernelSize1) + 1]
-> Tensor
device
dtype
(SqueezeDimCheck
'[1, outputChannelSize, (inputSize0 - kernelSize0) + 1,
(inputSize1 - kernelSize1) + 1]
0
(SqueezeDimImpl
'[1, outputChannelSize, (inputSize0 - kernelSize0) + 1,
(inputSize1 - kernelSize1) + 1]
0)))
-> Tensor
device
dtype
'[1, outputChannelSize, (inputSize0 - kernelSize0) + 1,
(inputSize1 - kernelSize1) + 1]
-> Tensor
device
dtype
(SqueezeDimCheck
'[1, outputChannelSize, (inputSize0 - kernelSize0) + 1,
(inputSize1 - kernelSize1) + 1]
0
(SqueezeDimImpl
'[1, outputChannelSize, (inputSize0 - kernelSize0) + 1,
(inputSize1 - kernelSize1) + 1]
0))
forall a b. (a -> b) -> a -> b
$ forall (stride :: (Nat, Nat)) (padding :: (Nat, Nat))
{kernelSize1 :: Nat} {inputSize1 :: Nat} {kernelSize0 :: Nat}
{inputSize0 :: Nat} {inputChannelSize :: Nat}
{outputChannelSize :: Nat} {batchSize :: Nat} {w1 :: DType}
{w2 :: (DeviceType, Nat)}.
(Assert
(OrdCond
(CmpNat kernelSize1 (inputSize1 + (2 * Snd padding)))
'True
'True
'False)
(TypeError ...)
~ (() :: Constraint),
Assert
(OrdCond
(CmpNat kernelSize0 (inputSize0 + (2 * Fst padding)))
'True
'True
'False)
(TypeError ...)
~ (() :: Constraint),
Assert
(OrdCond (CmpNat 1 (Snd stride)) 'True 'True 'False)
(TypeError ...)
~ (() :: Constraint),
Assert
(OrdCond (CmpNat 1 (Fst stride)) 'True 'True 'False)
(TypeError ...)
~ (() :: Constraint),
Assert
(OrdCond (CmpNat 1 kernelSize0) 'True 'True 'False)
(TypeError ...),
Assert
(OrdCond
(CmpNat (kernelSize0 - 1) (inputSize0 + (2 * Fst padding)))
'True
'True
'False)
(TypeError ...),
Assert
(OrdCond (CmpNat 1 kernelSize1) 'True 'True 'False)
(TypeError ...),
Assert
(OrdCond
(CmpNat (kernelSize1 - 1) (inputSize1 + (2 * Snd padding)))
'True
'True
'False)
(TypeError ...),
KnownNat inputChannelSize, KnownNat outputChannelSize,
KnownNat kernelSize0, KnownNat kernelSize1, KnownNat inputSize0,
KnownNat inputSize1, KnownNat batchSize, KnownNat (Fst stride),
KnownNat (Fst padding), KnownNat (Snd stride),
KnownNat (Snd padding)) =>
Conv2d
inputChannelSize outputChannelSize kernelSize0 kernelSize1 w1 w2
-> Tensor
w2 w1 '[batchSize, inputChannelSize, inputSize0, inputSize1]
-> Tensor
w2
w1
'[batchSize, outputChannelSize,
Div ((inputSize0 + (2 * Fst padding)) - kernelSize0) (Fst stride)
+ 1,
Div ((inputSize1 + (2 * Snd padding)) - kernelSize1) (Snd stride)
+ 1]
TT.conv2dForward @'(1, 1) @'(0, 0) Conv2d 1 outputChannelSize kernelSize0 kernelSize1 dtype device
conv Tensor device dtype '[1, 1, inputSize0, inputSize1]
input
where
input :: Tensor device dtype '[1, 1, inputSize0, inputSize1]
input = forall (dim :: Nat) (shape :: [Nat]) (shape' :: [Nat])
(dtype :: DType) (device :: (DeviceType, Nat)).
(KnownNat dim, shape' ~ Unsqueeze shape dim) =>
Tensor device dtype shape -> Tensor device dtype shape'
TT.unsqueeze @0 (Tensor device dtype '[1, inputSize0, inputSize1]
-> Tensor device dtype '[1, 1, inputSize0, inputSize1])
-> Tensor device dtype '[1, inputSize0, inputSize1]
-> Tensor device dtype '[1, 1, inputSize0, inputSize1]
forall a b. (a -> b) -> a -> b
$ forall (dim :: Nat) (shape :: [Nat]) (shape' :: [Nat])
(dtype :: DType) (device :: (DeviceType, Nat)).
(KnownNat dim, shape' ~ Unsqueeze shape dim) =>
Tensor device dtype shape -> Tensor device dtype shape'
TT.unsqueeze @0 Tensor device dtype '[inputSize0, inputSize1]
slice
pass :: QTensor dev (QTransHidden : PShape)
pass :: QTensor dev EmbShape
pass = Conv2d 2 QSliceHidden FifthSize OctaveSize QDType dev
-> QBoundedList dev QDType QSliceHidden '[] (2 : PShape)
-> QTensor dev EmbShape
forall (nin :: Nat).
KnownNat nin =>
Conv2d nin QSliceHidden FifthSize OctaveSize QDType dev
-> QBoundedList dev QDType QSliceHidden '[] (nin : PShape)
-> QTensor dev EmbShape
runConv Conv2d 2 QSliceHidden FifthSize OctaveSize QDType dev
trL1Passing QBoundedList dev QDType QSliceHidden '[] (2 : PShape)
trencPassing
inner :: QTensor dev (QTransHidden : PShape)
inner :: QTensor dev EmbShape
inner = Conv2d 2 QSliceHidden FifthSize OctaveSize QDType dev
-> QBoundedList dev QDType QSliceHidden '[] (2 : PShape)
-> QTensor dev EmbShape
forall (nin :: Nat).
KnownNat nin =>
Conv2d nin QSliceHidden FifthSize OctaveSize QDType dev
-> QBoundedList dev QDType QSliceHidden '[] (nin : PShape)
-> QTensor dev EmbShape
runConv Conv2d 2 QSliceHidden FifthSize OctaveSize QDType dev
trL1Inner QBoundedList dev QDType QSliceHidden '[] (2 : PShape)
trencInner
left :: QTensor dev (QTransHidden : PShape)
left :: QTensor dev EmbShape
left = Conv2d 1 QSliceHidden 1 1 QDType dev
-> Tensor dev QDType '[13, 5]
-> Tensor dev QDType '[QSliceHidden, (13 - 1) + 1, (5 - 1) + 1]
forall {kernelSize0 :: Nat} {inputSize0 :: Nat}
{kernelSize1 :: Nat} {inputSize1 :: Nat} {outputChannelSize :: Nat}
{dtype :: DType} {device :: (DeviceType, Nat)}.
(Assert
(OrdCond (CmpNat kernelSize0 inputSize0) 'True 'True 'False)
(TypeError ...)
~ (() :: Constraint),
Assert
(OrdCond (CmpNat kernelSize1 inputSize1) 'True 'True 'False)
(TypeError ...)
~ (() :: Constraint),
Assert
(OrdCond (CmpNat 1 kernelSize0) 'True 'True 'False)
(TypeError ...),
Assert
(OrdCond (CmpNat (kernelSize0 - 1) inputSize0) 'True 'True 'False)
(TypeError ...),
Assert
(OrdCond (CmpNat 1 kernelSize1) 'True 'True 'False)
(TypeError ...),
Assert
(OrdCond (CmpNat (kernelSize1 - 1) inputSize1) 'True 'True 'False)
(TypeError ...),
KnownNat outputChannelSize, KnownNat kernelSize0,
KnownNat kernelSize1, KnownNat inputSize0, KnownNat inputSize1) =>
Conv2d 1 outputChannelSize kernelSize0 kernelSize1 dtype device
-> Tensor device dtype '[inputSize0, inputSize1]
-> Tensor
device
dtype
'[outputChannelSize, (inputSize0 - kernelSize0) + 1,
(inputSize1 - kernelSize1) + 1]
runSlice Conv2d 1 QSliceHidden 1 1 QDType dev
trL1Left (Tensor dev QDType '[13, 5]
-> Tensor dev QDType '[QSliceHidden, (13 - 1) + 1, (5 - 1) + 1])
-> Tensor dev QDType '[13, 5]
-> Tensor dev QDType '[QSliceHidden, (13 - 1) + 1, (5 - 1) + 1]
forall a b. (a -> b) -> a -> b
$ SliceEncoding dev '[] -> QTensor dev ('[] ++ PShape)
forall (dev :: (DeviceType, Nat)) (batchShape :: [Nat]).
SliceEncoding dev batchShape -> QTensor dev (batchShape ++ PShape)
getSlice SliceEncoding dev '[]
trencLeft
right :: QTensor dev (QTransHidden : PShape)
right :: QTensor dev EmbShape
right = Conv2d 1 QSliceHidden 1 1 QDType dev
-> Tensor dev QDType '[13, 5]
-> Tensor dev QDType '[QSliceHidden, (13 - 1) + 1, (5 - 1) + 1]
forall {kernelSize0 :: Nat} {inputSize0 :: Nat}
{kernelSize1 :: Nat} {inputSize1 :: Nat} {outputChannelSize :: Nat}
{dtype :: DType} {device :: (DeviceType, Nat)}.
(Assert
(OrdCond (CmpNat kernelSize0 inputSize0) 'True 'True 'False)
(TypeError ...)
~ (() :: Constraint),
Assert
(OrdCond (CmpNat kernelSize1 inputSize1) 'True 'True 'False)
(TypeError ...)
~ (() :: Constraint),
Assert
(OrdCond (CmpNat 1 kernelSize0) 'True 'True 'False)
(TypeError ...),
Assert
(OrdCond (CmpNat (kernelSize0 - 1) inputSize0) 'True 'True 'False)
(TypeError ...),
Assert
(OrdCond (CmpNat 1 kernelSize1) 'True 'True 'False)
(TypeError ...),
Assert
(OrdCond (CmpNat (kernelSize1 - 1) inputSize1) 'True 'True 'False)
(TypeError ...),
KnownNat outputChannelSize, KnownNat kernelSize0,
KnownNat kernelSize1, KnownNat inputSize0, KnownNat inputSize1) =>
Conv2d 1 outputChannelSize kernelSize0 kernelSize1 dtype device
-> Tensor device dtype '[inputSize0, inputSize1]
-> Tensor
device
dtype
'[outputChannelSize, (inputSize0 - kernelSize0) + 1,
(inputSize1 - kernelSize1) + 1]
runSlice Conv2d 1 QSliceHidden 1 1 QDType dev
trL1Right (Tensor dev QDType '[13, 5]
-> Tensor dev QDType '[QSliceHidden, (13 - 1) + 1, (5 - 1) + 1])
-> Tensor dev QDType '[13, 5]
-> Tensor dev QDType '[QSliceHidden, (13 - 1) + 1, (5 - 1) + 1]
forall a b. (a -> b) -> a -> b
$ SliceEncoding dev '[] -> QTensor dev ('[] ++ PShape)
forall (dev :: (DeviceType, Nat)) (batchShape :: [Nat]).
SliceEncoding dev batchShape -> QTensor dev (batchShape ++ PShape)
getSlice SliceEncoding dev '[]
trencRight
root :: QTensor dev '[QTransHidden, 1, 1]
root :: QTensor dev '[QSliceHidden, 1, 1]
root = Tensor
dev QDType (ReverseImpl (ReverseImpl '[QSliceHidden] '[]) '[])
-> QTensor dev '[QSliceHidden, 1, 1]
forall (shape' :: [Nat]) (shape :: [Nat]) (dtype :: DType)
(device :: (DeviceType, Nat)).
(KnownShape shape', Numel shape ~ Numel shape') =>
Tensor device dtype shape -> Tensor device dtype shape'
TT.reshape (Tensor
dev QDType (ReverseImpl (ReverseImpl '[QSliceHidden] '[]) '[])
-> QTensor dev '[QSliceHidden, 1, 1])
-> Tensor
dev QDType (ReverseImpl (ReverseImpl '[QSliceHidden] '[]) '[])
-> QTensor dev '[QSliceHidden, 1, 1]
forall a b. (a -> b) -> a -> b
$ QTensor dev '[]
-> QTensor dev '[QSliceHidden]
-> Tensor
dev QDType (ReverseImpl (ReverseImpl '[QSliceHidden] '[]) '[])
forall (shape'' :: [Nat]) (shape :: [Nat]) (shape' :: [Nat])
(dtype :: DType) (dtype' :: DType) (dtype'' :: DType)
(device :: (DeviceType, Nat)).
(dtype'' ~ DTypePromotion dtype dtype',
shape'' ~ Broadcast shape shape',
BasicArithmeticDTypeIsValid device dtype,
BasicArithmeticDTypeIsValid device dtype',
BasicArithmeticDTypeIsValid device dtype'') =>
Tensor device dtype shape
-> Tensor device dtype' shape' -> Tensor device dtype'' shape''
TT.mul QTensor dev '[]
trencRoot (QTensor dev '[QSliceHidden] -> QTensor dev '[QSliceHidden]
forall (dev :: (DeviceType, Nat)) (shape :: [Nat]).
IsValidDevice dev =>
QTensor dev shape -> QTensor dev shape
activation (ConstEmb dev '[QSliceHidden] -> () -> QTensor dev '[QSliceHidden]
forall f a b. HasForward f a b => f -> a -> b
T.forward ConstEmb dev '[QSliceHidden]
trL1Root ()))
all :: QTensor dev (QTransHidden : PShape)
all :: QTensor dev EmbShape
all = (Tensor dev QDType '[QSliceHidden, 13, 5]
QTensor dev EmbShape
pass Tensor dev QDType '[QSliceHidden, 13, 5]
-> Tensor dev QDType '[QSliceHidden, 13, 5]
-> Tensor dev QDType '[QSliceHidden, 13, 5]
forall a. Num a => a -> a -> a
+ Tensor dev QDType '[QSliceHidden, 13, 5]
QTensor dev EmbShape
inner Tensor dev QDType '[QSliceHidden, 13, 5]
-> Tensor dev QDType '[QSliceHidden, 13, 5]
-> Tensor dev QDType '[QSliceHidden, 13, 5]
forall a. Num a => a -> a -> a
+ Tensor dev QDType '[QSliceHidden, 13, 5]
QTensor dev EmbShape
left Tensor dev QDType '[QSliceHidden, 13, 5]
-> Tensor dev QDType '[QSliceHidden, 13, 5]
-> Tensor dev QDType '[QSliceHidden, 13, 5]
forall a. Num a => a -> a -> a
+ Tensor dev QDType '[QSliceHidden, 13, 5]
QTensor dev EmbShape
right) Tensor dev QDType '[QSliceHidden, 13, 5]
-> QTensor dev '[QSliceHidden, 1, 1]
-> Tensor
dev
QDType
(CheckBroadcast
'[QSliceHidden, 13, 5]
'[QSliceHidden, 1, 1]
(ComputeBroadcast
(ReverseImpl '[QSliceHidden, 13, 5] '[]) '[1, 1, QSliceHidden]))
forall (shape'' :: [Nat]) (shape :: [Nat]) (shape' :: [Nat])
(dtype :: DType) (dtype' :: DType) (dtype'' :: DType)
(device :: (DeviceType, Nat)).
(dtype'' ~ DTypePromotion dtype dtype',
shape'' ~ Broadcast shape shape',
BasicArithmeticDTypeIsValid device dtype,
BasicArithmeticDTypeIsValid device dtype',
BasicArithmeticDTypeIsValid device dtype'') =>
Tensor device dtype shape
-> Tensor device dtype' shape' -> Tensor device dtype'' shape''
`TT.add` QTensor dev '[QSliceHidden, 1, 1]
root
forwardStoch :: TransitionEncoder dev
-> TransitionEncoding dev '[] -> IO (QTensor dev embshape)
forwardStoch TransitionEncoder dev
tr TransitionEncoding dev '[]
input = QTensor dev embshape -> IO (QTensor dev embshape)
forall a. a -> IO a
forall (f :: Type -> Type) a. Applicative f => a -> f a
pure (QTensor dev embshape -> IO (QTensor dev embshape))
-> QTensor dev embshape -> IO (QTensor dev embshape)
forall a b. (a -> b) -> a -> b
$ TransitionEncoder dev
-> TransitionEncoding dev '[] -> QTensor dev embshape
forall f a b. HasForward f a b => f -> a -> b
T.forward TransitionEncoder dev
tr TransitionEncoding dev '[]
input
instance
forall dev batchSize embshape
. ( IsValidDevice dev
, embshape ~ (batchSize : EmbSize : PShape)
)
=> T.HasForward (TransitionEncoder dev) (TransitionEncoding dev '[batchSize]) (QTensor dev embshape)
where
forward :: TransitionEncoder dev
-> TransitionEncoding dev '[batchSize] -> QTensor dev embshape
forward TransitionEncoder{Conv2d 1 QSliceHidden 1 1 QDType dev
Conv2d 2 QSliceHidden FifthSize OctaveSize QDType dev
Conv2d QSliceHidden QSliceHidden FifthSize OctaveSize QDType dev
ConstEmb dev '[QSliceHidden]
trL1Passing :: forall (dev :: (DeviceType, Nat)).
TransitionEncoder dev
-> Conv2d 2 QSliceHidden FifthSize OctaveSize QDType dev
trL1Inner :: forall (dev :: (DeviceType, Nat)).
TransitionEncoder dev
-> Conv2d 2 QSliceHidden FifthSize OctaveSize QDType dev
trL1Left :: forall (dev :: (DeviceType, Nat)).
TransitionEncoder dev -> Conv2d 1 QSliceHidden 1 1 QDType dev
trL1Right :: forall (dev :: (DeviceType, Nat)).
TransitionEncoder dev -> Conv2d 1 QSliceHidden 1 1 QDType dev
trL1Root :: forall (dev :: (DeviceType, Nat)).
TransitionEncoder dev -> ConstEmb dev '[QSliceHidden]
trL2 :: forall (dev :: (DeviceType, Nat)).
TransitionEncoder dev
-> Conv2d QSliceHidden QSliceHidden FifthSize OctaveSize QDType dev
trL1Passing :: Conv2d 2 QSliceHidden FifthSize OctaveSize QDType dev
trL1Inner :: Conv2d 2 QSliceHidden FifthSize OctaveSize QDType dev
trL1Left :: Conv2d 1 QSliceHidden 1 1 QDType dev
trL1Right :: Conv2d 1 QSliceHidden 1 1 QDType dev
trL1Root :: ConstEmb dev '[QSliceHidden]
trL2 :: Conv2d QSliceHidden QSliceHidden FifthSize OctaveSize QDType dev
..} TransitionEncoding{QTensor dev '[batchSize]
SliceEncoding dev '[batchSize]
QBoundedList dev QDType QSliceHidden '[batchSize] (2 : PShape)
trencRoot :: forall (dev :: (DeviceType, Nat)) (batchShape :: [Nat]).
TransitionEncoding dev batchShape -> QTensor dev batchShape
trencRight :: forall (dev :: (DeviceType, Nat)) (batchShape :: [Nat]).
TransitionEncoding dev batchShape -> SliceEncoding dev batchShape
trencLeft :: forall (dev :: (DeviceType, Nat)) (batchShape :: [Nat]).
TransitionEncoding dev batchShape -> SliceEncoding dev batchShape
trencInner :: forall (dev :: (DeviceType, Nat)) (batchShape :: [Nat]).
TransitionEncoding dev batchShape
-> QBoundedList dev QDType QSliceHidden batchShape (2 : PShape)
trencPassing :: forall (dev :: (DeviceType, Nat)) (batchShape :: [Nat]).
TransitionEncoding dev batchShape
-> QBoundedList dev QDType QSliceHidden batchShape (2 : PShape)
trencPassing :: QBoundedList dev QDType QSliceHidden '[batchSize] (2 : PShape)
trencInner :: QBoundedList dev QDType QSliceHidden '[batchSize] (2 : PShape)
trencLeft :: SliceEncoding dev '[batchSize]
trencRight :: SliceEncoding dev '[batchSize]
trencRoot :: QTensor dev '[batchSize]
..} =
QTensor dev embshape -> QTensor dev embshape
forall (dev :: (DeviceType, Nat)) (shape :: [Nat]).
IsValidDevice dev =>
QTensor dev shape -> QTensor dev shape
activation (QTensor dev embshape -> QTensor dev embshape)
-> QTensor dev embshape -> QTensor dev embshape
forall a b. (a -> b) -> a -> b
$ forall (stride :: (Nat, Nat)) (padding :: (Nat, Nat))
(inputChannelSize :: Nat) (outputChannelSize :: Nat)
(kernelSize0 :: Nat) (kernelSize1 :: Nat) (inputSize0 :: Nat)
(inputSize1 :: Nat) (batchSize :: Nat) (outputSize0 :: Nat)
(outputSize1 :: Nat) (dtype :: DType)
(device :: (DeviceType, Nat)).
(All KnownNat '[Fst stride, Snd stride, Fst padding, Snd padding],
ConvSideCheck
inputSize0 kernelSize0 (Fst stride) (Fst padding) outputSize0,
ConvSideCheck
inputSize1 kernelSize1 (Snd stride) (Snd padding) outputSize1) =>
Conv2d
inputChannelSize
outputChannelSize
kernelSize0
kernelSize1
dtype
device
-> Tensor
device dtype '[batchSize, inputChannelSize, inputSize0, inputSize1]
-> Tensor
device
dtype
'[batchSize, outputChannelSize, outputSize0, outputSize1]
TH.conv2dForwardRelaxed @'(1, 1) @'(FifthPadding, OctavePadding) Conv2d QSliceHidden QSliceHidden 13 5 QDType dev
Conv2d QSliceHidden QSliceHidden FifthSize OctaveSize QDType dev
trL2 Tensor dev QDType '[batchSize, QSliceHidden, 13, 5]
QTensor dev (batchSize : EmbShape)
all
where
runConv
:: forall nin
. (KnownNat nin)
=> TT.Conv2d nin QTransHidden FifthSize OctaveSize QDType dev
-> QBoundedList dev QDType MaxEdges '[batchSize] (nin : PShape)
-> QTensor dev (batchSize : QTransHidden : PShape)
runConv :: forall (nin :: Nat).
KnownNat nin =>
Conv2d nin QSliceHidden FifthSize OctaveSize QDType dev
-> QBoundedList dev QDType QSliceHidden '[batchSize] (nin : PShape)
-> QTensor dev (batchSize : EmbShape)
runConv Conv2d nin QSliceHidden FifthSize OctaveSize QDType dev
conv (QBoundedList QTensor dev ('[batchSize] ++ '[QSliceHidden])
mask Tensor
dev QDType (('[batchSize] ++ '[QSliceHidden]) ++ (nin : PShape))
edges) = forall (d :: Nat) (shape :: [Nat]) (shape' :: [Nat])
(dtype :: DType) (dtype' :: DType) (device :: (DeviceType, Nat)).
(KnownNat d, shape' ~ DropValue shape d,
SumDTypeIsValid device dtype, dtype' ~ SumDType dtype) =>
Tensor device dtype shape -> Tensor device dtype' shape'
TT.sumDim @1 (Tensor
dev
QDType
(CheckBroadcast
'[batchSize, QSliceHidden, 1, 1, 1]
'[batchSize, QSliceHidden, QSliceHidden, 13, 5]
(ComputeBroadcast
'[1, 1, 1, QSliceHidden, batchSize]
(ReverseImpl '[batchSize, QSliceHidden, QSliceHidden, 13, 5] '[])))
-> QTensor dev (batchSize : EmbShape))
-> Tensor
dev
QDType
(CheckBroadcast
'[batchSize, QSliceHidden, 1, 1, 1]
'[batchSize, QSliceHidden, QSliceHidden, 13, 5]
(ComputeBroadcast
'[1, 1, 1, QSliceHidden, batchSize]
(ReverseImpl '[batchSize, QSliceHidden, QSliceHidden, 13, 5] '[])))
-> QTensor dev (batchSize : EmbShape)
forall a b. (a -> b) -> a -> b
$ Tensor dev QDType '[batchSize, QSliceHidden, 1, 1, 1]
-> Tensor
dev QDType '[batchSize, QSliceHidden, QSliceHidden, 13, 5]
-> Tensor
dev
QDType
(CheckBroadcast
'[batchSize, QSliceHidden, 1, 1, 1]
'[batchSize, QSliceHidden, QSliceHidden, 13, 5]
(ComputeBroadcast
'[1, 1, 1, QSliceHidden, batchSize]
(ReverseImpl '[batchSize, QSliceHidden, QSliceHidden, 13, 5] '[])))
forall (shape'' :: [Nat]) (shape :: [Nat]) (shape' :: [Nat])
(dtype :: DType) (dtype' :: DType) (dtype'' :: DType)
(device :: (DeviceType, Nat)).
(dtype'' ~ DTypePromotion dtype dtype',
shape'' ~ Broadcast shape shape',
BasicArithmeticDTypeIsValid device dtype,
BasicArithmeticDTypeIsValid device dtype',
BasicArithmeticDTypeIsValid device dtype'') =>
Tensor device dtype shape
-> Tensor device dtype' shape' -> Tensor device dtype'' shape''
TT.mul Tensor dev QDType '[batchSize, QSliceHidden, 1, 1, 1]
mask' Tensor dev QDType '[batchSize, QSliceHidden, QSliceHidden, 13, 5]
QTensor dev (batchSize : QSliceHidden : EmbShape)
outReshaped
where
shape :: [Int]
shape = forall (shape :: [Nat]). KnownShape shape => [Int]
TT.shapeVal @(nin : PShape)
shape' :: [Int]
shape' = forall (shape :: [Nat]). KnownShape shape => [Int]
TT.shapeVal @(MaxEdges : QTransHidden : PShape)
inputShaped :: QTensor dev (batchSize * MaxEdges : nin : PShape)
inputShaped :: QTensor dev ((batchSize * QSliceHidden) : nin : PShape)
inputShaped = [Int]
-> Tensor dev QDType '[batchSize, QSliceHidden, nin, 13, 5]
-> Tensor dev QDType '[batchSize * QSliceHidden, nin, 13, 5]
forall (dev :: (DeviceType, Nat)) (dtype :: DType) (shape :: [Nat])
(shape' :: [Nat]).
[Int] -> Tensor dev dtype shape -> Tensor dev dtype shape'
unsafeReshape (-Int
1 Int -> [Int] -> [Int]
forall a. a -> [a] -> [a]
: [Int]
shape) Tensor dev QDType '[batchSize, QSliceHidden, nin, 13, 5]
Tensor
dev QDType (('[batchSize] ++ '[QSliceHidden]) ++ (nin : PShape))
edges
out :: QTensor dev (batchSize * MaxEdges : QTransHidden : PShape)
out :: QTensor dev ((batchSize * QSliceHidden) : EmbShape)
out = QTensor dev ((batchSize * QSliceHidden) : EmbShape)
-> QTensor dev ((batchSize * QSliceHidden) : EmbShape)
forall (dev :: (DeviceType, Nat)) (shape :: [Nat]).
IsValidDevice dev =>
QTensor dev shape -> QTensor dev shape
activation (QTensor dev ((batchSize * QSliceHidden) : EmbShape)
-> QTensor dev ((batchSize * QSliceHidden) : EmbShape))
-> QTensor dev ((batchSize * QSliceHidden) : EmbShape)
-> QTensor dev ((batchSize * QSliceHidden) : EmbShape)
forall a b. (a -> b) -> a -> b
$ forall (stride :: (Nat, Nat)) (padding :: (Nat, Nat))
(inputChannelSize :: Nat) (outputChannelSize :: Nat)
(kernelSize0 :: Nat) (kernelSize1 :: Nat) (inputSize0 :: Nat)
(inputSize1 :: Nat) (batchSize :: Nat) (outputSize0 :: Nat)
(outputSize1 :: Nat) (dtype :: DType)
(device :: (DeviceType, Nat)).
(All KnownNat '[Fst stride, Snd stride, Fst padding, Snd padding],
ConvSideCheck
inputSize0 kernelSize0 (Fst stride) (Fst padding) outputSize0,
ConvSideCheck
inputSize1 kernelSize1 (Snd stride) (Snd padding) outputSize1) =>
Conv2d
inputChannelSize
outputChannelSize
kernelSize0
kernelSize1
dtype
device
-> Tensor
device dtype '[batchSize, inputChannelSize, inputSize0, inputSize1]
-> Tensor
device
dtype
'[batchSize, outputChannelSize, outputSize0, outputSize1]
TH.conv2dForwardRelaxed @'(1, 1) @'(FifthPadding, OctavePadding) Conv2d nin QSliceHidden 13 5 QDType dev
Conv2d nin QSliceHidden FifthSize OctaveSize QDType dev
conv Tensor dev QDType '[batchSize * QSliceHidden, nin, 13, 5]
QTensor dev ((batchSize * QSliceHidden) : nin : PShape)
inputShaped
outReshaped :: QTensor dev (batchSize : MaxEdges : QTransHidden : PShape)
outReshaped :: QTensor dev (batchSize : QSliceHidden : EmbShape)
outReshaped = [Int]
-> Tensor
dev QDType '[batchSize * QSliceHidden, QSliceHidden, 13, 5]
-> Tensor
dev QDType '[batchSize, QSliceHidden, QSliceHidden, 13, 5]
forall (dev :: (DeviceType, Nat)) (dtype :: DType) (shape :: [Nat])
(shape' :: [Nat]).
[Int] -> Tensor dev dtype shape -> Tensor dev dtype shape'
unsafeReshape (-Int
1 Int -> [Int] -> [Int]
forall a. a -> [a] -> [a]
: [Int]
shape') Tensor dev QDType '[batchSize * QSliceHidden, QSliceHidden, 13, 5]
QTensor dev ((batchSize * QSliceHidden) : EmbShape)
out
mask' :: QTensor dev '[batchSize, MaxEdges, 1, 1, 1]
mask' :: Tensor dev QDType '[batchSize, QSliceHidden, 1, 1, 1]
mask' = [Int]
-> Tensor dev QDType '[batchSize, QSliceHidden]
-> Tensor dev QDType '[batchSize, QSliceHidden, 1, 1, 1]
forall (dev :: (DeviceType, Nat)) (dtype :: DType) (shape :: [Nat])
(shape' :: [Nat]).
[Int] -> Tensor dev dtype shape -> Tensor dev dtype shape'
unsafeReshape [-Int
1, forall (n :: Nat). KnownNat n => Int
TT.natValI @MaxEdges, Int
1, Int
1, Int
1] Tensor dev QDType '[batchSize, QSliceHidden]
QTensor dev ('[batchSize] ++ '[QSliceHidden])
mask
runSlice :: Conv2d
inputChannelSize
outputChannelSize
kernelSize0
kernelSize1
dtype
device
-> Tensor device dtype shape
-> Tensor
device
dtype
'[batchSize, outputChannelSize, (inputSize0 - kernelSize0) + 1,
(inputSize1 - kernelSize1) + 1]
runSlice Conv2d
inputChannelSize
outputChannelSize
kernelSize0
kernelSize1
dtype
device
conv Tensor device dtype shape
slice = forall (stride :: (Nat, Nat)) (padding :: (Nat, Nat))
(inputChannelSize :: Nat) (outputChannelSize :: Nat)
(kernelSize0 :: Nat) (kernelSize1 :: Nat) (inputSize0 :: Nat)
(inputSize1 :: Nat) (batchSize :: Nat) (outputSize0 :: Nat)
(outputSize1 :: Nat) (dtype :: DType)
(device :: (DeviceType, Nat)).
(All KnownNat '[Fst stride, Snd stride, Fst padding, Snd padding],
ConvSideCheck
inputSize0 kernelSize0 (Fst stride) (Fst padding) outputSize0,
ConvSideCheck
inputSize1 kernelSize1 (Snd stride) (Snd padding) outputSize1) =>
Conv2d
inputChannelSize
outputChannelSize
kernelSize0
kernelSize1
dtype
device
-> Tensor
device dtype '[batchSize, inputChannelSize, inputSize0, inputSize1]
-> Tensor
device
dtype
'[batchSize, outputChannelSize, outputSize0, outputSize1]
TH.conv2dForwardRelaxed @'(1, 1) @'(0, 0) Conv2d
inputChannelSize
outputChannelSize
kernelSize0
kernelSize1
dtype
device
conv Tensor
device dtype '[batchSize, inputChannelSize, inputSize0, inputSize1]
input
where
input :: Tensor
device dtype '[batchSize, inputChannelSize, inputSize0, inputSize1]
input = forall (dim :: Nat) (shape :: [Nat]) (shape' :: [Nat])
(dtype :: DType) (device :: (DeviceType, Nat)).
(KnownNat dim, shape' ~ Unsqueeze shape dim) =>
Tensor device dtype shape -> Tensor device dtype shape'
TT.unsqueeze @1 Tensor device dtype shape
slice
pass :: QTensor dev (batchSize : QTransHidden : PShape)
pass :: QTensor dev (batchSize : EmbShape)
pass = Conv2d 2 QSliceHidden FifthSize OctaveSize QDType dev
-> QBoundedList dev QDType QSliceHidden '[batchSize] (2 : PShape)
-> QTensor dev (batchSize : EmbShape)
forall (nin :: Nat).
KnownNat nin =>
Conv2d nin QSliceHidden FifthSize OctaveSize QDType dev
-> QBoundedList dev QDType QSliceHidden '[batchSize] (nin : PShape)
-> QTensor dev (batchSize : EmbShape)
runConv Conv2d 2 QSliceHidden FifthSize OctaveSize QDType dev
trL1Passing QBoundedList dev QDType QSliceHidden '[batchSize] (2 : PShape)
trencPassing
inner :: QTensor dev (batchSize : QTransHidden : PShape)
inner :: QTensor dev (batchSize : EmbShape)
inner = Conv2d 2 QSliceHidden FifthSize OctaveSize QDType dev
-> QBoundedList dev QDType QSliceHidden '[batchSize] (2 : PShape)
-> QTensor dev (batchSize : EmbShape)
forall (nin :: Nat).
KnownNat nin =>
Conv2d nin QSliceHidden FifthSize OctaveSize QDType dev
-> QBoundedList dev QDType QSliceHidden '[batchSize] (nin : PShape)
-> QTensor dev (batchSize : EmbShape)
runConv Conv2d 2 QSliceHidden FifthSize OctaveSize QDType dev
trL1Inner QBoundedList dev QDType QSliceHidden '[batchSize] (2 : PShape)
trencInner
left :: QTensor dev (batchSize : QTransHidden : PShape)
left :: QTensor dev (batchSize : EmbShape)
left = Conv2d 1 QSliceHidden 1 1 QDType dev
-> Tensor dev QDType '[batchSize, 13, 5]
-> Tensor
dev QDType '[batchSize, QSliceHidden, (13 - 1) + 1, (5 - 1) + 1]
forall {shape :: [Nat]} {batchSize :: Nat}
{inputChannelSize :: Nat} {inputSize0 :: Nat} {inputSize1 :: Nat}
{kernelSize0 :: Nat} {kernelSize1 :: Nat}
{outputChannelSize :: Nat} {dtype :: DType}
{device :: (DeviceType, Nat)}.
(UnsqueezeCheck shape 1 (UnsqueezeImpl shape 1)
~ '[batchSize, inputChannelSize, inputSize0, inputSize1],
Assert
(OrdCond (CmpNat 1 kernelSize0) 'True 'True 'False)
(TypeError ...),
Assert
(OrdCond (CmpNat (kernelSize0 - 1) inputSize0) 'True 'True 'False)
(TypeError ...),
Assert
(OrdCond
(CmpNat 1 ((inputSize0 - kernelSize0) + 1)) 'True 'True 'False)
(TypeError ...),
Assert
(OrdCond (CmpNat 1 kernelSize1) 'True 'True 'False)
(TypeError ...),
Assert
(OrdCond (CmpNat (kernelSize1 - 1) inputSize1) 'True 'True 'False)
(TypeError ...),
Assert
(OrdCond
(CmpNat 1 ((inputSize1 - kernelSize1) + 1)) 'True 'True 'False)
(TypeError ...)) =>
Conv2d
inputChannelSize
outputChannelSize
kernelSize0
kernelSize1
dtype
device
-> Tensor device dtype shape
-> Tensor
device
dtype
'[batchSize, outputChannelSize, (inputSize0 - kernelSize0) + 1,
(inputSize1 - kernelSize1) + 1]
runSlice Conv2d 1 QSliceHidden 1 1 QDType dev
trL1Left (Tensor dev QDType '[batchSize, 13, 5]
-> Tensor
dev QDType '[batchSize, QSliceHidden, (13 - 1) + 1, (5 - 1) + 1])
-> Tensor dev QDType '[batchSize, 13, 5]
-> Tensor
dev QDType '[batchSize, QSliceHidden, (13 - 1) + 1, (5 - 1) + 1]
forall a b. (a -> b) -> a -> b
$ SliceEncoding dev '[batchSize]
-> QTensor dev ('[batchSize] ++ PShape)
forall (dev :: (DeviceType, Nat)) (batchShape :: [Nat]).
SliceEncoding dev batchShape -> QTensor dev (batchShape ++ PShape)
getSlice SliceEncoding dev '[batchSize]
trencLeft
right :: QTensor dev (batchSize : QTransHidden : PShape)
right :: QTensor dev (batchSize : EmbShape)
right = Conv2d 1 QSliceHidden 1 1 QDType dev
-> Tensor dev QDType '[batchSize, 13, 5]
-> Tensor
dev QDType '[batchSize, QSliceHidden, (13 - 1) + 1, (5 - 1) + 1]
forall {shape :: [Nat]} {batchSize :: Nat}
{inputChannelSize :: Nat} {inputSize0 :: Nat} {inputSize1 :: Nat}
{kernelSize0 :: Nat} {kernelSize1 :: Nat}
{outputChannelSize :: Nat} {dtype :: DType}
{device :: (DeviceType, Nat)}.
(UnsqueezeCheck shape 1 (UnsqueezeImpl shape 1)
~ '[batchSize, inputChannelSize, inputSize0, inputSize1],
Assert
(OrdCond (CmpNat 1 kernelSize0) 'True 'True 'False)
(TypeError ...),
Assert
(OrdCond (CmpNat (kernelSize0 - 1) inputSize0) 'True 'True 'False)
(TypeError ...),
Assert
(OrdCond
(CmpNat 1 ((inputSize0 - kernelSize0) + 1)) 'True 'True 'False)
(TypeError ...),
Assert
(OrdCond (CmpNat 1 kernelSize1) 'True 'True 'False)
(TypeError ...),
Assert
(OrdCond (CmpNat (kernelSize1 - 1) inputSize1) 'True 'True 'False)
(TypeError ...),
Assert
(OrdCond
(CmpNat 1 ((inputSize1 - kernelSize1) + 1)) 'True 'True 'False)
(TypeError ...)) =>
Conv2d
inputChannelSize
outputChannelSize
kernelSize0
kernelSize1
dtype
device
-> Tensor device dtype shape
-> Tensor
device
dtype
'[batchSize, outputChannelSize, (inputSize0 - kernelSize0) + 1,
(inputSize1 - kernelSize1) + 1]
runSlice Conv2d 1 QSliceHidden 1 1 QDType dev
trL1Right (Tensor dev QDType '[batchSize, 13, 5]
-> Tensor
dev QDType '[batchSize, QSliceHidden, (13 - 1) + 1, (5 - 1) + 1])
-> Tensor dev QDType '[batchSize, 13, 5]
-> Tensor
dev QDType '[batchSize, QSliceHidden, (13 - 1) + 1, (5 - 1) + 1]
forall a b. (a -> b) -> a -> b
$ SliceEncoding dev '[batchSize]
-> QTensor dev ('[batchSize] ++ PShape)
forall (dev :: (DeviceType, Nat)) (batchShape :: [Nat]).
SliceEncoding dev batchShape -> QTensor dev (batchShape ++ PShape)
getSlice SliceEncoding dev '[batchSize]
trencRight
root :: QTensor dev '[batchSize, QTransHidden, 1, 1]
root :: QTensor dev '[batchSize, QSliceHidden, 1, 1]
root = [Int]
-> Tensor
dev
QDType
(CheckBroadcast
'[batchSize, 1]
'[QSliceHidden]
(ComputeBroadcast
(ReverseImpl '[batchSize, 1] '[])
(ReverseImpl '[QSliceHidden] '[])))
-> QTensor dev '[batchSize, QSliceHidden, 1, 1]
forall (dev :: (DeviceType, Nat)) (dtype :: DType) (shape :: [Nat])
(shape' :: [Nat]).
[Int] -> Tensor dev dtype shape -> Tensor dev dtype shape'
unsafeReshape [-Int
1, forall (n :: Nat). KnownNat n => Int
TT.natValI @QTransHidden, Int
1, Int
1] (Tensor
dev
QDType
(CheckBroadcast
'[batchSize, 1]
'[QSliceHidden]
(ComputeBroadcast
(ReverseImpl '[batchSize, 1] '[])
(ReverseImpl '[QSliceHidden] '[])))
-> QTensor dev '[batchSize, QSliceHidden, 1, 1])
-> Tensor
dev
QDType
(CheckBroadcast
'[batchSize, 1]
'[QSliceHidden]
(ComputeBroadcast
(ReverseImpl '[batchSize, 1] '[])
(ReverseImpl '[QSliceHidden] '[])))
-> QTensor dev '[batchSize, QSliceHidden, 1, 1]
forall a b. (a -> b) -> a -> b
$ Tensor dev QDType '[batchSize, 1]
-> QTensor dev '[QSliceHidden]
-> Tensor
dev
QDType
(CheckBroadcast
'[batchSize, 1]
'[QSliceHidden]
(ComputeBroadcast
(ReverseImpl '[batchSize, 1] '[])
(ReverseImpl '[QSliceHidden] '[])))
forall (shape'' :: [Nat]) (shape :: [Nat]) (shape' :: [Nat])
(dtype :: DType) (dtype' :: DType) (dtype'' :: DType)
(device :: (DeviceType, Nat)).
(dtype'' ~ DTypePromotion dtype dtype',
shape'' ~ Broadcast shape shape',
BasicArithmeticDTypeIsValid device dtype,
BasicArithmeticDTypeIsValid device dtype',
BasicArithmeticDTypeIsValid device dtype'') =>
Tensor device dtype shape
-> Tensor device dtype' shape' -> Tensor device dtype'' shape''
TT.mul (forall (dim :: Nat) (shape :: [Nat]) (shape' :: [Nat])
(dtype :: DType) (device :: (DeviceType, Nat)).
(KnownNat dim, shape' ~ Unsqueeze shape dim) =>
Tensor device dtype shape -> Tensor device dtype shape'
TT.unsqueeze @1 QTensor dev '[batchSize]
trencRoot) (QTensor dev '[QSliceHidden]
-> Tensor
dev
QDType
(CheckBroadcast
'[batchSize, 1]
'[QSliceHidden]
(ComputeBroadcast
(ReverseImpl '[batchSize, 1] '[])
(ReverseImpl '[QSliceHidden] '[]))))
-> QTensor dev '[QSliceHidden]
-> Tensor
dev
QDType
(CheckBroadcast
'[batchSize, 1]
'[QSliceHidden]
(ComputeBroadcast
(ReverseImpl '[batchSize, 1] '[])
(ReverseImpl '[QSliceHidden] '[])))
forall a b. (a -> b) -> a -> b
$ QTensor dev '[QSliceHidden] -> QTensor dev '[QSliceHidden]
forall (dev :: (DeviceType, Nat)) (shape :: [Nat]).
IsValidDevice dev =>
QTensor dev shape -> QTensor dev shape
activation (QTensor dev '[QSliceHidden] -> QTensor dev '[QSliceHidden])
-> QTensor dev '[QSliceHidden] -> QTensor dev '[QSliceHidden]
forall a b. (a -> b) -> a -> b
$ ConstEmb dev '[QSliceHidden] -> () -> QTensor dev '[QSliceHidden]
forall f a b. HasForward f a b => f -> a -> b
T.forward ConstEmb dev '[QSliceHidden]
trL1Root ()
all :: QTensor dev (batchSize : QTransHidden : PShape)
all :: QTensor dev (batchSize : EmbShape)
all = (Tensor dev QDType '[batchSize, QSliceHidden, 13, 5]
QTensor dev (batchSize : EmbShape)
pass Tensor dev QDType '[batchSize, QSliceHidden, 13, 5]
-> Tensor dev QDType '[batchSize, QSliceHidden, 13, 5]
-> Tensor dev QDType '[batchSize, QSliceHidden, 13, 5]
forall a. Num a => a -> a -> a
+ Tensor dev QDType '[batchSize, QSliceHidden, 13, 5]
QTensor dev (batchSize : EmbShape)
inner Tensor dev QDType '[batchSize, QSliceHidden, 13, 5]
-> Tensor dev QDType '[batchSize, QSliceHidden, 13, 5]
-> Tensor dev QDType '[batchSize, QSliceHidden, 13, 5]
forall a. Num a => a -> a -> a
+ Tensor dev QDType '[batchSize, QSliceHidden, 13, 5]
QTensor dev (batchSize : EmbShape)
left Tensor dev QDType '[batchSize, QSliceHidden, 13, 5]
-> Tensor dev QDType '[batchSize, QSliceHidden, 13, 5]
-> Tensor dev QDType '[batchSize, QSliceHidden, 13, 5]
forall a. Num a => a -> a -> a
+ Tensor dev QDType '[batchSize, QSliceHidden, 13, 5]
QTensor dev (batchSize : EmbShape)
right) Tensor dev QDType '[batchSize, QSliceHidden, 13, 5]
-> QTensor dev '[batchSize, QSliceHidden, 1, 1]
-> Tensor
dev
QDType
(CheckBroadcast
'[batchSize, QSliceHidden, 13, 5]
'[batchSize, QSliceHidden, 1, 1]
(ComputeBroadcast
(ReverseImpl '[batchSize, QSliceHidden, 13, 5] '[])
'[1, 1, QSliceHidden, batchSize]))
forall (shape'' :: [Nat]) (shape :: [Nat]) (shape' :: [Nat])
(dtype :: DType) (dtype' :: DType) (dtype'' :: DType)
(device :: (DeviceType, Nat)).
(dtype'' ~ DTypePromotion dtype dtype',
shape'' ~ Broadcast shape shape',
BasicArithmeticDTypeIsValid device dtype,
BasicArithmeticDTypeIsValid device dtype',
BasicArithmeticDTypeIsValid device dtype'') =>
Tensor device dtype shape
-> Tensor device dtype' shape' -> Tensor device dtype'' shape''
`TT.add` QTensor dev '[batchSize, QSliceHidden, 1, 1]
root
forwardStoch :: TransitionEncoder dev
-> TransitionEncoding dev '[batchSize] -> IO (QTensor dev embshape)
forwardStoch TransitionEncoder dev
tr TransitionEncoding dev '[batchSize]
input = QTensor dev embshape -> IO (QTensor dev embshape)
forall a. a -> IO a
forall (f :: Type -> Type) a. Applicative f => a -> f a
pure (QTensor dev embshape -> IO (QTensor dev embshape))
-> QTensor dev embshape -> IO (QTensor dev embshape)
forall a b. (a -> b) -> a -> b
$ TransitionEncoder dev
-> TransitionEncoding dev '[batchSize] -> QTensor dev embshape
forall f a b. HasForward f a b => f -> a -> b
T.forward TransitionEncoder dev
tr TransitionEncoding dev '[batchSize]
input
data ActionSpec dev = ActionSpec
data ActionEncoder dev = ActionEncoder
{ forall (dev :: (DeviceType, Nat)).
ActionEncoder dev
-> Conv2d QSliceHidden QSliceHidden FifthSize OctaveSize QDType dev
actTop1sl :: !(TT.Conv2d EmbSize QActionHidden FifthSize OctaveSize QDType dev)
, forall (dev :: (DeviceType, Nat)).
ActionEncoder dev
-> Conv2d QSliceHidden QSliceHidden FifthSize OctaveSize QDType dev
actTop1sm :: !(TT.Conv2d EmbSize QActionHidden FifthSize OctaveSize QDType dev)
, forall (dev :: (DeviceType, Nat)).
ActionEncoder dev
-> Conv2d QSliceHidden QSliceHidden FifthSize OctaveSize QDType dev
actTop1sr :: !(TT.Conv2d EmbSize QActionHidden FifthSize OctaveSize QDType dev)
, forall (dev :: (DeviceType, Nat)).
ActionEncoder dev
-> Conv2d QSliceHidden QSliceHidden FifthSize OctaveSize QDType dev
actTop1t1 :: !(TT.Conv2d EmbSize QActionHidden FifthSize OctaveSize QDType dev)
, forall (dev :: (DeviceType, Nat)).
ActionEncoder dev
-> Conv2d QSliceHidden QSliceHidden FifthSize OctaveSize QDType dev
actTop1t2 :: !(TT.Conv2d EmbSize QActionHidden FifthSize OctaveSize QDType dev)
, forall (dev :: (DeviceType, Nat)).
ActionEncoder dev
-> Conv2d QSliceHidden QSliceHidden FifthSize OctaveSize QDType dev
actTop2 :: !(TT.Conv2d QActionHidden EmbSize FifthSize OctaveSize QDType dev)
, forall (dev :: (DeviceType, Nat)).
ActionEncoder dev -> ConstEmb dev '[QSliceHidden - 3]
actSplit :: ConstEmb dev '[EmbSize - 3]
, forall (dev :: (DeviceType, Nat)).
ActionEncoder dev -> ConstEmb dev '[QSliceHidden - 3]
actSpread :: ConstEmb dev '[EmbSize - 3]
, forall (dev :: (DeviceType, Nat)).
ActionEncoder dev -> ConstEmb dev '[QSliceHidden - 3]
actFreeze :: ConstEmb dev '[EmbSize - 3]
}
deriving (Int -> ActionEncoder dev -> ShowS
[ActionEncoder dev] -> ShowS
ActionEncoder dev -> String
(Int -> ActionEncoder dev -> ShowS)
-> (ActionEncoder dev -> String)
-> ([ActionEncoder dev] -> ShowS)
-> Show (ActionEncoder dev)
forall a.
(Int -> a -> ShowS) -> (a -> String) -> ([a] -> ShowS) -> Show a
forall (dev :: (DeviceType, Nat)).
Int -> ActionEncoder dev -> ShowS
forall (dev :: (DeviceType, Nat)). [ActionEncoder dev] -> ShowS
forall (dev :: (DeviceType, Nat)). ActionEncoder dev -> String
$cshowsPrec :: forall (dev :: (DeviceType, Nat)).
Int -> ActionEncoder dev -> ShowS
showsPrec :: Int -> ActionEncoder dev -> ShowS
$cshow :: forall (dev :: (DeviceType, Nat)). ActionEncoder dev -> String
show :: ActionEncoder dev -> String
$cshowList :: forall (dev :: (DeviceType, Nat)). [ActionEncoder dev] -> ShowS
showList :: [ActionEncoder dev] -> ShowS
Show, (forall x. ActionEncoder dev -> Rep (ActionEncoder dev) x)
-> (forall x. Rep (ActionEncoder dev) x -> ActionEncoder dev)
-> Generic (ActionEncoder dev)
forall x. Rep (ActionEncoder dev) x -> ActionEncoder dev
forall x. ActionEncoder dev -> Rep (ActionEncoder dev) x
forall a.
(forall x. a -> Rep a x) -> (forall x. Rep a x -> a) -> Generic a
forall (dev :: (DeviceType, Nat)) x.
Rep (ActionEncoder dev) x -> ActionEncoder dev
forall (dev :: (DeviceType, Nat)) x.
ActionEncoder dev -> Rep (ActionEncoder dev) x
$cfrom :: forall (dev :: (DeviceType, Nat)) x.
ActionEncoder dev -> Rep (ActionEncoder dev) x
from :: forall x. ActionEncoder dev -> Rep (ActionEncoder dev) x
$cto :: forall (dev :: (DeviceType, Nat)) x.
Rep (ActionEncoder dev) x -> ActionEncoder dev
to :: forall x. Rep (ActionEncoder dev) x -> ActionEncoder dev
Generic, ActionEncoder dev -> HList (Parameters (ActionEncoder dev))
ActionEncoder dev
-> HList (Parameters (ActionEncoder dev)) -> ActionEncoder dev
(ActionEncoder dev -> HList (Parameters (ActionEncoder dev)))
-> (ActionEncoder dev
-> HList (Parameters (ActionEncoder dev)) -> ActionEncoder dev)
-> Parameterized (ActionEncoder dev)
forall f.
(f -> HList (Parameters f))
-> (f -> HList (Parameters f) -> f) -> Parameterized f
forall (dev :: (DeviceType, Nat)).
ActionEncoder dev -> HList (Parameters (ActionEncoder dev))
forall (dev :: (DeviceType, Nat)).
ActionEncoder dev
-> HList (Parameters (ActionEncoder dev)) -> ActionEncoder dev
$cflattenParameters :: forall (dev :: (DeviceType, Nat)).
ActionEncoder dev -> HList (Parameters (ActionEncoder dev))
flattenParameters :: ActionEncoder dev -> HList (Parameters (ActionEncoder dev))
$creplaceParameters :: forall (dev :: (DeviceType, Nat)).
ActionEncoder dev
-> HList (Parameters (ActionEncoder dev)) -> ActionEncoder dev
replaceParameters :: ActionEncoder dev
-> HList (Parameters (ActionEncoder dev)) -> ActionEncoder dev
TT.Parameterized, Context -> ActionEncoder dev -> IO (Maybe ThunkInfo)
Proxy (ActionEncoder dev) -> String
(Context -> ActionEncoder dev -> IO (Maybe ThunkInfo))
-> (Context -> ActionEncoder dev -> IO (Maybe ThunkInfo))
-> (Proxy (ActionEncoder dev) -> String)
-> NoThunks (ActionEncoder dev)
forall a.
(Context -> a -> IO (Maybe ThunkInfo))
-> (Context -> a -> IO (Maybe ThunkInfo))
-> (Proxy a -> String)
-> NoThunks a
forall (dev :: (DeviceType, Nat)).
Context -> ActionEncoder dev -> IO (Maybe ThunkInfo)
forall (dev :: (DeviceType, Nat)).
Proxy (ActionEncoder dev) -> String
$cnoThunks :: forall (dev :: (DeviceType, Nat)).
Context -> ActionEncoder dev -> IO (Maybe ThunkInfo)
noThunks :: Context -> ActionEncoder dev -> IO (Maybe ThunkInfo)
$cwNoThunks :: forall (dev :: (DeviceType, Nat)).
Context -> ActionEncoder dev -> IO (Maybe ThunkInfo)
wNoThunks :: Context -> ActionEncoder dev -> IO (Maybe ThunkInfo)
$cshowTypeOf :: forall (dev :: (DeviceType, Nat)).
Proxy (ActionEncoder dev) -> String
showTypeOf :: Proxy (ActionEncoder dev) -> String
NoThunks, ActionEncoder dev -> ()
(ActionEncoder dev -> ()) -> NFData (ActionEncoder dev)
forall a. (a -> ()) -> NFData a
forall (dev :: (DeviceType, Nat)). ActionEncoder dev -> ()
$crnf :: forall (dev :: (DeviceType, Nat)). ActionEncoder dev -> ()
rnf :: ActionEncoder dev -> ()
NFData)
instance (IsValidDevice dev) => T.Randomizable (ActionSpec dev) (ActionEncoder dev) where
sample :: ActionSpec dev -> IO (ActionEncoder dev)
sample :: ActionSpec dev -> IO (ActionEncoder dev)
sample ActionSpec dev
ActionSpec = do
actTop1sl <- Conv2dSpec QSliceHidden QSliceHidden 13 5 QDType dev
-> IO (Conv2d QSliceHidden QSliceHidden 13 5 QDType dev)
forall spec f. Randomizable spec f => spec -> IO f
T.sample Conv2dSpec QSliceHidden QSliceHidden 13 5 QDType dev
forall (inputChannelSize :: Nat) (outputChannelSize :: Nat)
(kernelSize0 :: Nat) (kernelSize1 :: Nat) (dtype :: DType)
(device :: (DeviceType, Nat)).
Conv2dSpec
inputChannelSize
outputChannelSize
kernelSize0
kernelSize1
dtype
device
TT.Conv2dSpec
actTop1sm <- T.sample TT.Conv2dSpec
actTop1sr <- T.sample TT.Conv2dSpec
actTop1t1 <- T.sample TT.Conv2dSpec
actTop1t2 <- T.sample TT.Conv2dSpec
actTop2 <- T.sample TT.Conv2dSpec
actSplit <- T.sample $ ConstEmbSpec @dev
actSpread <- T.sample $ ConstEmbSpec @dev
actFreeze <- T.sample $ ConstEmbSpec @dev
pure ActionEncoder{..}
opTypes :: forall dev. (TT.KnownDevice dev) => QTensor dev '[6, 3]
opTypes :: forall (dev :: (DeviceType, Nat)).
KnownDevice dev =>
QTensor dev '[FifthPadding, 3]
opTypes =
Tensor -> Tensor dev QDType '[FifthPadding, 3]
forall (device :: (DeviceType, Nat)) (dtype :: DType)
(shape :: [Nat]).
Tensor -> Tensor device dtype shape
TT.UnsafeMkTensor
(Tensor -> Tensor dev QDType '[FifthPadding, 3])
-> Tensor -> Tensor dev QDType '[FifthPadding, 3]
forall a b. (a -> b) -> a -> b
$! forall a opt.
(TensorLike a, TensorOptionLike opt) =>
a -> opt -> Tensor
T.asTensor' @[[QType]]
[ [QType
0, QType
0, QType
0]
, [QType
0, QType
1, QType
0]
, [QType
1, QType
0, QType
0]
, [QType
1, QType
0, QType
1]
, [QType
1, QType
1, QType
0]
, [QType
1, QType
1, QType
1]
]
(TensorOptions -> Tensor) -> TensorOptions -> Tensor
forall a b. (a -> b) -> a -> b
$ forall (dev :: (DeviceType, Nat)). KnownDevice dev => TensorOptions
opts @dev
instance
forall dev batchSize outShape
. ( IsValidDevice dev
, outShape ~ (batchSize : EmbSize : PShape)
, 1 <= batchSize
)
=> T.HasForward
(ActionEncoder dev)
(SliceEncoder dev, TransitionEncoder dev, ActionEncoding dev '[batchSize])
(QTensor dev outShape)
where
forward :: ActionEncoder dev
-> (SliceEncoder dev, TransitionEncoder dev,
ActionEncoding dev '[batchSize])
-> QTensor dev outShape
forward ActionEncoder{Conv2d QSliceHidden QSliceHidden FifthSize OctaveSize QDType dev
ConstEmb dev '[QSliceHidden - 3]
actTop1sl :: forall (dev :: (DeviceType, Nat)).
ActionEncoder dev
-> Conv2d QSliceHidden QSliceHidden FifthSize OctaveSize QDType dev
actTop1sm :: forall (dev :: (DeviceType, Nat)).
ActionEncoder dev
-> Conv2d QSliceHidden QSliceHidden FifthSize OctaveSize QDType dev
actTop1sr :: forall (dev :: (DeviceType, Nat)).
ActionEncoder dev
-> Conv2d QSliceHidden QSliceHidden FifthSize OctaveSize QDType dev
actTop1t1 :: forall (dev :: (DeviceType, Nat)).
ActionEncoder dev
-> Conv2d QSliceHidden QSliceHidden FifthSize OctaveSize QDType dev
actTop1t2 :: forall (dev :: (DeviceType, Nat)).
ActionEncoder dev
-> Conv2d QSliceHidden QSliceHidden FifthSize OctaveSize QDType dev
actTop2 :: forall (dev :: (DeviceType, Nat)).
ActionEncoder dev
-> Conv2d QSliceHidden QSliceHidden FifthSize OctaveSize QDType dev
actSplit :: forall (dev :: (DeviceType, Nat)).
ActionEncoder dev -> ConstEmb dev '[QSliceHidden - 3]
actSpread :: forall (dev :: (DeviceType, Nat)).
ActionEncoder dev -> ConstEmb dev '[QSliceHidden - 3]
actFreeze :: forall (dev :: (DeviceType, Nat)).
ActionEncoder dev -> ConstEmb dev '[QSliceHidden - 3]
actTop1sl :: Conv2d QSliceHidden QSliceHidden FifthSize OctaveSize QDType dev
actTop1sm :: Conv2d QSliceHidden QSliceHidden FifthSize OctaveSize QDType dev
actTop1sr :: Conv2d QSliceHidden QSliceHidden FifthSize OctaveSize QDType dev
actTop1t1 :: Conv2d QSliceHidden QSliceHidden FifthSize OctaveSize QDType dev
actTop1t2 :: Conv2d QSliceHidden QSliceHidden FifthSize OctaveSize QDType dev
actTop2 :: Conv2d QSliceHidden QSliceHidden FifthSize OctaveSize QDType dev
actSplit :: ConstEmb dev '[QSliceHidden - 3]
actSpread :: ConstEmb dev '[QSliceHidden - 3]
actFreeze :: ConstEmb dev '[QSliceHidden - 3]
..} (SliceEncoder dev
slc, TransitionEncoder dev
tr, ActionEncoding (ActionTop QStartStop dev '[batchSize] (SliceEncoding dev '[batchSize])
sl TransitionEncoding dev '[batchSize]
t1 (QMaybe QTensor dev '[batchSize]
smMask SliceEncoding dev '[batchSize]
sm) (QMaybe QTensor dev '[batchSize]
t2Mask TransitionEncoding dev '[batchSize]
t2) QStartStop dev '[batchSize] (SliceEncoding dev '[batchSize])
sr) Tensor dev 'Int64 '[batchSize]
opIndex) = QTensor dev '[batchSize, QSliceHidden, 13, 5]
QTensor dev (batchSize : EmbShape)
topEmb QTensor dev '[batchSize, QSliceHidden, 13, 5]
-> Tensor dev QDType '[batchSize, QSliceHidden, 1, 1]
-> QTensor dev outShape
forall (shape'' :: [Nat]) (shape :: [Nat]) (shape' :: [Nat])
(dtype :: DType) (dtype' :: DType) (dtype'' :: DType)
(device :: (DeviceType, Nat)).
(dtype'' ~ DTypePromotion dtype dtype',
shape'' ~ Broadcast shape shape',
BasicArithmeticDTypeIsValid device dtype,
BasicArithmeticDTypeIsValid device dtype',
BasicArithmeticDTypeIsValid device dtype'') =>
Tensor device dtype shape
-> Tensor device dtype' shape' -> Tensor device dtype'' shape''
`TT.add` Tensor dev QDType '[batchSize, QSliceHidden, 1, 1]
opEmbReshaped
where
runConv
:: TT.Conv2d nin nout FifthSize OctaveSize QDType dev
-> QTensor dev (batchSize : nin : PShape)
-> QTensor dev (batchSize : nout : PShape)
runConv :: forall (nin :: Nat) (nout :: Nat).
Conv2d nin nout FifthSize OctaveSize QDType dev
-> QTensor dev (batchSize : nin : PShape)
-> QTensor dev (batchSize : nout : PShape)
runConv Conv2d nin nout FifthSize OctaveSize QDType dev
conv QTensor dev (batchSize : nin : PShape)
input =
QTensor dev (batchSize : nout : PShape)
-> QTensor dev (batchSize : nout : PShape)
forall (dev :: (DeviceType, Nat)) (shape :: [Nat]).
IsValidDevice dev =>
QTensor dev shape -> QTensor dev shape
activation (QTensor dev (batchSize : nout : PShape)
-> QTensor dev (batchSize : nout : PShape))
-> QTensor dev (batchSize : nout : PShape)
-> QTensor dev (batchSize : nout : PShape)
forall a b. (a -> b) -> a -> b
$ forall (stride :: (Nat, Nat)) (padding :: (Nat, Nat))
(inputChannelSize :: Nat) (outputChannelSize :: Nat)
(kernelSize0 :: Nat) (kernelSize1 :: Nat) (inputSize0 :: Nat)
(inputSize1 :: Nat) (batchSize :: Nat) (outputSize0 :: Nat)
(outputSize1 :: Nat) (dtype :: DType)
(device :: (DeviceType, Nat)).
(All KnownNat '[Fst stride, Snd stride, Fst padding, Snd padding],
ConvSideCheck
inputSize0 kernelSize0 (Fst stride) (Fst padding) outputSize0,
ConvSideCheck
inputSize1 kernelSize1 (Snd stride) (Snd padding) outputSize1) =>
Conv2d
inputChannelSize
outputChannelSize
kernelSize0
kernelSize1
dtype
device
-> Tensor
device dtype '[batchSize, inputChannelSize, inputSize0, inputSize1]
-> Tensor
device
dtype
'[batchSize, outputChannelSize, outputSize0, outputSize1]
TH.conv2dForwardRelaxed @'(1, 1) @'(FifthPadding, OctavePadding) Conv2d nin nout 13 5 QDType dev
Conv2d nin nout FifthSize OctaveSize QDType dev
conv Tensor dev QDType '[batchSize, nin, 13, 5]
QTensor dev (batchSize : nin : PShape)
input
runConvMasked
:: QTensor dev '[batchSize]
-> TT.Conv2d nin nout FifthSize OctaveSize QDType dev
-> QTensor dev (batchSize : nin : PShape)
-> QTensor dev (batchSize : nout : PShape)
runConvMasked :: forall (nin :: Nat) (nout :: Nat).
QTensor dev '[batchSize]
-> Conv2d nin nout FifthSize OctaveSize QDType dev
-> QTensor dev (batchSize : nin : PShape)
-> QTensor dev (batchSize : nout : PShape)
runConvMasked QTensor dev '[batchSize]
mask Conv2d nin nout FifthSize OctaveSize QDType dev
conv QTensor dev (batchSize : nin : PShape)
input =
Tensor dev QDType '[batchSize, 1, 1, 1]
-> Tensor dev QDType '[batchSize, nout, 13, 5]
-> Tensor dev QDType (batchSize : nout : PShape)
forall (shape'' :: [Nat]) (shape :: [Nat]) (shape' :: [Nat])
(dtype :: DType) (dtype' :: DType) (dtype'' :: DType)
(device :: (DeviceType, Nat)).
(dtype'' ~ DTypePromotion dtype dtype',
shape'' ~ Broadcast shape shape',
BasicArithmeticDTypeIsValid device dtype,
BasicArithmeticDTypeIsValid device dtype',
BasicArithmeticDTypeIsValid device dtype'') =>
Tensor device dtype shape
-> Tensor device dtype' shape' -> Tensor device dtype'' shape''
TT.mul ([Int]
-> QTensor dev '[batchSize]
-> Tensor dev QDType '[batchSize, 1, 1, 1]
forall (dev :: (DeviceType, Nat)) (dtype :: DType) (shape :: [Nat])
(shape' :: [Nat]).
[Int] -> Tensor dev dtype shape -> Tensor dev dtype shape'
unsafeReshape [-Int
1, Int
1, Int
1, Int
1] QTensor dev '[batchSize]
mask :: QTensor dev '[batchSize, 1, 1, 1]) (Tensor dev QDType '[batchSize, nout, 13, 5]
-> Tensor dev QDType (batchSize : nout : PShape))
-> Tensor dev QDType '[batchSize, nout, 13, 5]
-> Tensor dev QDType (batchSize : nout : PShape)
forall a b. (a -> b) -> a -> b
$ Conv2d nin nout FifthSize OctaveSize QDType dev
-> QTensor dev (batchSize : nin : PShape)
-> Tensor dev QDType (batchSize : nout : PShape)
forall (nin :: Nat) (nout :: Nat).
Conv2d nin nout FifthSize OctaveSize QDType dev
-> QTensor dev (batchSize : nin : PShape)
-> QTensor dev (batchSize : nout : PShape)
runConv Conv2d nin nout FifthSize OctaveSize QDType dev
conv QTensor dev (batchSize : nin : PShape)
input
embl :: QTensor dev (batchSize : QActionHidden : PShape)
embl :: QTensor dev (batchSize : EmbShape)
embl = Conv2d QSliceHidden QSliceHidden FifthSize OctaveSize QDType dev
-> QTensor dev (batchSize : EmbShape)
-> QTensor dev (batchSize : EmbShape)
forall (nin :: Nat) (nout :: Nat).
Conv2d nin nout FifthSize OctaveSize QDType dev
-> QTensor dev (batchSize : nin : PShape)
-> QTensor dev (batchSize : nout : PShape)
runConv Conv2d QSliceHidden QSliceHidden FifthSize OctaveSize QDType dev
actTop1sl (QTensor dev (batchSize : EmbShape)
-> QTensor dev (batchSize : EmbShape))
-> QTensor dev (batchSize : EmbShape)
-> QTensor dev (batchSize : EmbShape)
forall a b. (a -> b) -> a -> b
$ SliceEncoder dev
-> QStartStop dev '[batchSize] (SliceEncoding dev '[batchSize])
-> QTensor dev '[batchSize, QSliceHidden, 13, 5]
forall f a b. HasForward f a b => f -> a -> b
T.forward SliceEncoder dev
slc QStartStop dev '[batchSize] (SliceEncoding dev '[batchSize])
sl
embm :: QTensor dev (batchSize : EmbShape)
embm = Conv2d QSliceHidden QSliceHidden FifthSize OctaveSize QDType dev
-> QTensor dev (batchSize : EmbShape)
-> QTensor dev (batchSize : EmbShape)
forall (nin :: Nat) (nout :: Nat).
Conv2d nin nout FifthSize OctaveSize QDType dev
-> QTensor dev (batchSize : nin : PShape)
-> QTensor dev (batchSize : nout : PShape)
runConv Conv2d QSliceHidden QSliceHidden FifthSize OctaveSize QDType dev
actTop1sm (QTensor dev (batchSize : EmbShape)
-> QTensor dev (batchSize : EmbShape))
-> QTensor dev (batchSize : EmbShape)
-> QTensor dev (batchSize : EmbShape)
forall a b. (a -> b) -> a -> b
$ SliceEncoder dev
-> SliceEncoding dev '[batchSize]
-> QTensor dev '[batchSize, QSliceHidden, 13, 5]
forall f a b. HasForward f a b => f -> a -> b
T.forward SliceEncoder dev
slc SliceEncoding dev '[batchSize]
sm
embr :: QTensor dev (batchSize : EmbShape)
embr = QTensor dev '[batchSize]
-> Conv2d QSliceHidden QSliceHidden FifthSize OctaveSize QDType dev
-> QTensor dev (batchSize : EmbShape)
-> QTensor dev (batchSize : EmbShape)
forall (nin :: Nat) (nout :: Nat).
QTensor dev '[batchSize]
-> Conv2d nin nout FifthSize OctaveSize QDType dev
-> QTensor dev (batchSize : nin : PShape)
-> QTensor dev (batchSize : nout : PShape)
runConvMasked QTensor dev '[batchSize]
smMask Conv2d QSliceHidden QSliceHidden FifthSize OctaveSize QDType dev
actTop1sr (QTensor dev (batchSize : EmbShape)
-> QTensor dev (batchSize : EmbShape))
-> QTensor dev (batchSize : EmbShape)
-> QTensor dev (batchSize : EmbShape)
forall a b. (a -> b) -> a -> b
$ SliceEncoder dev
-> QStartStop dev '[batchSize] (SliceEncoding dev '[batchSize])
-> QTensor dev '[batchSize, QSliceHidden, 13, 5]
forall f a b. HasForward f a b => f -> a -> b
T.forward SliceEncoder dev
slc QStartStop dev '[batchSize] (SliceEncoding dev '[batchSize])
sr
embt1 :: QTensor dev (batchSize : EmbShape)
embt1 = Conv2d QSliceHidden QSliceHidden FifthSize OctaveSize QDType dev
-> QTensor dev (batchSize : EmbShape)
-> QTensor dev (batchSize : EmbShape)
forall (nin :: Nat) (nout :: Nat).
Conv2d nin nout FifthSize OctaveSize QDType dev
-> QTensor dev (batchSize : nin : PShape)
-> QTensor dev (batchSize : nout : PShape)
runConv Conv2d QSliceHidden QSliceHidden FifthSize OctaveSize QDType dev
actTop1t1 (QTensor dev (batchSize : EmbShape)
-> QTensor dev (batchSize : EmbShape))
-> QTensor dev (batchSize : EmbShape)
-> QTensor dev (batchSize : EmbShape)
forall a b. (a -> b) -> a -> b
$ TransitionEncoder dev
-> TransitionEncoding dev '[batchSize]
-> QTensor dev '[batchSize, QSliceHidden, 13, 5]
forall f a b. HasForward f a b => f -> a -> b
T.forward TransitionEncoder dev
tr TransitionEncoding dev '[batchSize]
t1
embt2 :: QTensor dev (batchSize : EmbShape)
embt2 = QTensor dev '[batchSize]
-> Conv2d QSliceHidden QSliceHidden FifthSize OctaveSize QDType dev
-> QTensor dev (batchSize : EmbShape)
-> QTensor dev (batchSize : EmbShape)
forall (nin :: Nat) (nout :: Nat).
QTensor dev '[batchSize]
-> Conv2d nin nout FifthSize OctaveSize QDType dev
-> QTensor dev (batchSize : nin : PShape)
-> QTensor dev (batchSize : nout : PShape)
runConvMasked QTensor dev '[batchSize]
t2Mask Conv2d QSliceHidden QSliceHidden FifthSize OctaveSize QDType dev
actTop1t2 (QTensor dev (batchSize : EmbShape)
-> QTensor dev (batchSize : EmbShape))
-> QTensor dev (batchSize : EmbShape)
-> QTensor dev (batchSize : EmbShape)
forall a b. (a -> b) -> a -> b
$ TransitionEncoder dev
-> TransitionEncoding dev '[batchSize]
-> QTensor dev '[batchSize, QSliceHidden, 13, 5]
forall f a b. HasForward f a b => f -> a -> b
T.forward TransitionEncoder dev
tr TransitionEncoding dev '[batchSize]
t2
topCombined :: QTensor dev (batchSize : QActionHidden : PShape)
topCombined :: QTensor dev (batchSize : EmbShape)
topCombined = QTensor dev '[batchSize, QSliceHidden, 13, 5]
QTensor dev (batchSize : EmbShape)
embl QTensor dev '[batchSize, QSliceHidden, 13, 5]
-> QTensor dev '[batchSize, QSliceHidden, 13, 5]
-> QTensor dev '[batchSize, QSliceHidden, 13, 5]
forall a. Num a => a -> a -> a
+ QTensor dev '[batchSize, QSliceHidden, 13, 5]
QTensor dev (batchSize : EmbShape)
embm QTensor dev '[batchSize, QSliceHidden, 13, 5]
-> QTensor dev '[batchSize, QSliceHidden, 13, 5]
-> QTensor dev '[batchSize, QSliceHidden, 13, 5]
forall a. Num a => a -> a -> a
+ QTensor dev '[batchSize, QSliceHidden, 13, 5]
QTensor dev (batchSize : EmbShape)
embr QTensor dev '[batchSize, QSliceHidden, 13, 5]
-> QTensor dev '[batchSize, QSliceHidden, 13, 5]
-> QTensor dev '[batchSize, QSliceHidden, 13, 5]
forall a. Num a => a -> a -> a
+ QTensor dev '[batchSize, QSliceHidden, 13, 5]
QTensor dev (batchSize : EmbShape)
embt1 QTensor dev '[batchSize, QSliceHidden, 13, 5]
-> QTensor dev '[batchSize, QSliceHidden, 13, 5]
-> QTensor dev '[batchSize, QSliceHidden, 13, 5]
forall a. Num a => a -> a -> a
+ QTensor dev '[batchSize, QSliceHidden, 13, 5]
QTensor dev (batchSize : EmbShape)
embt2
topEmb :: QTensor dev (batchSize : EmbSize : PShape)
topEmb :: QTensor dev (batchSize : EmbShape)
topEmb = Conv2d QSliceHidden QSliceHidden FifthSize OctaveSize QDType dev
-> QTensor dev (batchSize : EmbShape)
-> QTensor dev (batchSize : EmbShape)
forall (nin :: Nat) (nout :: Nat).
Conv2d nin nout FifthSize OctaveSize QDType dev
-> QTensor dev (batchSize : nin : PShape)
-> QTensor dev (batchSize : nout : PShape)
runConv Conv2d QSliceHidden QSliceHidden FifthSize OctaveSize QDType dev
actTop2 QTensor dev (batchSize : EmbShape)
topCombined
opFreeze :: QTensor dev '[5]
opFreeze = ConstEmb dev '[5] -> () -> QTensor dev '[5]
forall f a b. HasForward f a b => f -> a -> b
T.forward ConstEmb dev '[5]
ConstEmb dev '[QSliceHidden - 3]
actFreeze ()
opSplit :: QTensor dev '[5]
opSplit = ConstEmb dev '[5] -> () -> QTensor dev '[5]
forall f a b. HasForward f a b => f -> a -> b
T.forward ConstEmb dev '[5]
ConstEmb dev '[QSliceHidden - 3]
actSplit ()
opSpread :: QTensor dev '[5]
opSpread = ConstEmb dev '[5] -> () -> QTensor dev '[5]
forall f a b. HasForward f a b => f -> a -> b
T.forward ConstEmb dev '[5]
ConstEmb dev '[QSliceHidden - 3]
actSpread ()
opCombined :: Tensor dev QDType '[FifthPadding, 5]
opCombined = forall (dim :: Nat) (shape :: [Nat]) (dtype :: DType)
(device :: (DeviceType, Nat)) (tensors :: [Type]).
(KnownNat dim, '(shape, dtype, device) ~ Stack dim tensors,
Castable (HList tensors) [ForeignPtr Tensor]) =>
HList tensors -> Tensor device dtype shape
forall {k} (dim :: Nat) (shape :: [Nat]) (dtype :: DType)
(device :: (DeviceType, Nat)) (tensors :: [k]).
(KnownNat dim, '(shape, dtype, device) ~ Stack dim tensors,
Castable (HList tensors) [ForeignPtr Tensor]) =>
HList tensors -> Tensor device dtype shape
TT.stack @0 (HList
'[QTensor dev '[5], QTensor dev '[5], QTensor dev '[5],
QTensor dev '[5], QTensor dev '[5], QTensor dev '[5]]
-> Tensor dev QDType '[FifthPadding, 5])
-> HList
'[QTensor dev '[5], QTensor dev '[5], QTensor dev '[5],
QTensor dev '[5], QTensor dev '[5], QTensor dev '[5]]
-> Tensor dev QDType '[FifthPadding, 5]
forall a b. (a -> b) -> a -> b
$ QTensor dev '[5]
opFreeze QTensor dev '[5]
-> HList
'[QTensor dev '[5], QTensor dev '[5], QTensor dev '[5],
QTensor dev '[5], QTensor dev '[5]]
-> HList
'[QTensor dev '[5], QTensor dev '[5], QTensor dev '[5],
QTensor dev '[5], QTensor dev '[5], QTensor dev '[5]]
forall x (xs :: [Type]). x -> HList xs -> HList (x : xs)
TT.:. QTensor dev '[5]
opSplit QTensor dev '[5]
-> HList
'[QTensor dev '[5], QTensor dev '[5], QTensor dev '[5],
QTensor dev '[5]]
-> HList
'[QTensor dev '[5], QTensor dev '[5], QTensor dev '[5],
QTensor dev '[5], QTensor dev '[5]]
forall x (xs :: [Type]). x -> HList xs -> HList (x : xs)
TT.:. QTensor dev '[5]
opFreeze QTensor dev '[5]
-> HList '[QTensor dev '[5], QTensor dev '[5], QTensor dev '[5]]
-> HList
'[QTensor dev '[5], QTensor dev '[5], QTensor dev '[5],
QTensor dev '[5]]
forall x (xs :: [Type]). x -> HList xs -> HList (x : xs)
TT.:. QTensor dev '[5]
opSpread QTensor dev '[5]
-> HList '[QTensor dev '[5], QTensor dev '[5]]
-> HList '[QTensor dev '[5], QTensor dev '[5], QTensor dev '[5]]
forall x (xs :: [Type]). x -> HList xs -> HList (x : xs)
TT.:. QTensor dev '[5]
opSplit QTensor dev '[5]
-> HList '[QTensor dev '[5]]
-> HList '[QTensor dev '[5], QTensor dev '[5]]
forall x (xs :: [Type]). x -> HList xs -> HList (x : xs)
TT.:. QTensor dev '[5]
opSplit QTensor dev '[5] -> HList '[] -> HList '[QTensor dev '[5]]
forall x (xs :: [Type]). x -> HList xs -> HList (x : xs)
TT.:. HList '[]
forall k. HList '[]
TT.HNil
opEmbeddings :: QTensor dev '[6, EmbSize]
opEmbeddings :: QTensor dev '[FifthPadding, QSliceHidden]
opEmbeddings = forall (dim :: Nat) (shape :: [Nat]) (dtype :: DType)
(device :: (DeviceType, Nat)) (tensors :: [Type]).
(KnownNat dim, '(shape, dtype, device) ~ Cat dim tensors,
Castable (HList tensors) [ForeignPtr Tensor]) =>
HList tensors -> Tensor device dtype shape
forall {k} (dim :: Nat) (shape :: [Nat]) (dtype :: DType)
(device :: (DeviceType, Nat)) (tensors :: [k]).
(KnownNat dim, '(shape, dtype, device) ~ Cat dim tensors,
Castable (HList tensors) [ForeignPtr Tensor]) =>
HList tensors -> Tensor device dtype shape
TT.cat @1 (HList
'[QTensor dev '[FifthPadding, 3],
Tensor dev QDType '[FifthPadding, 5]]
-> QTensor dev '[FifthPadding, QSliceHidden])
-> HList
'[QTensor dev '[FifthPadding, 3],
Tensor dev QDType '[FifthPadding, 5]]
-> QTensor dev '[FifthPadding, QSliceHidden]
forall a b. (a -> b) -> a -> b
$ forall (dev :: (DeviceType, Nat)).
KnownDevice dev =>
QTensor dev '[FifthPadding, 3]
opTypes @dev QTensor dev '[FifthPadding, 3]
-> HList '[Tensor dev QDType '[FifthPadding, 5]]
-> HList
'[QTensor dev '[FifthPadding, 3],
Tensor dev QDType '[FifthPadding, 5]]
forall x (xs :: [Type]). x -> HList xs -> HList (x : xs)
TT.:. Tensor dev QDType '[FifthPadding, 5]
opCombined Tensor dev QDType '[FifthPadding, 5]
-> HList '[] -> HList '[Tensor dev QDType '[FifthPadding, 5]]
forall x (xs :: [Type]). x -> HList xs -> HList (x : xs)
TT.:. HList '[]
forall k. HList '[]
TT.HNil
opIndex' :: TT.Tensor dev TT.Int64 [batchSize, EmbSize]
opIndex' :: Tensor dev 'Int64 '[batchSize, QSliceHidden]
opIndex' = Tensor -> Tensor dev 'Int64 '[batchSize, QSliceHidden]
forall (device :: (DeviceType, Nat)) (dtype :: DType)
(shape :: [Nat]).
Tensor -> Tensor device dtype shape
TT.UnsafeMkTensor (Tensor -> Tensor dev 'Int64 '[batchSize, QSliceHidden])
-> Tensor -> Tensor dev 'Int64 '[batchSize, QSliceHidden]
forall a b. (a -> b) -> a -> b
$ Tensor -> Bool -> [Int] -> Tensor
T.expand (Tensor dev 'Int64 '[batchSize, 1] -> Tensor
forall t. Unnamed t => t -> Tensor
TT.toDynamic (Tensor dev 'Int64 '[batchSize, 1] -> Tensor)
-> Tensor dev 'Int64 '[batchSize, 1] -> Tensor
forall a b. (a -> b) -> a -> b
$ forall (dim :: Nat) (shape :: [Nat]) (shape' :: [Nat])
(dtype :: DType) (device :: (DeviceType, Nat)).
(KnownNat dim, shape' ~ Unsqueeze shape dim) =>
Tensor device dtype shape -> Tensor device dtype shape'
TT.unsqueeze @1 Tensor dev 'Int64 '[batchSize]
opIndex) Bool
False [-Int
1, forall (n :: Nat). KnownNat n => Int
TT.natValI @EmbSize]
opEmb :: QTensor dev '[batchSize, EmbSize]
opEmb :: QTensor dev '[batchSize, QSliceHidden]
opEmb = forall (dim :: Nat) (shape :: [Nat]) (shape' :: [Nat])
(dtype :: DType) (device :: (DeviceType, Nat)).
(KnownNat dim, shape' ~ GatherDim shape shape' dim) =>
Tensor device 'Int64 shape'
-> Tensor device dtype shape -> Tensor device dtype shape'
TT.gatherDim @0 Tensor dev 'Int64 '[batchSize, QSliceHidden]
opIndex' QTensor dev '[FifthPadding, QSliceHidden]
opEmbeddings
opEmbReshaped :: QTensor dev '[batchSize, EmbSize, 1, 1]
opEmbReshaped :: Tensor dev QDType '[batchSize, QSliceHidden, 1, 1]
opEmbReshaped = forall (dim :: Nat) (shape :: [Nat]) (shape' :: [Nat])
(dtype :: DType) (device :: (DeviceType, Nat)).
(KnownNat dim, shape' ~ Unsqueeze shape dim) =>
Tensor device dtype shape -> Tensor device dtype shape'
TT.unsqueeze @3 (Tensor dev QDType '[batchSize, QSliceHidden, 1]
-> Tensor dev QDType '[batchSize, QSliceHidden, 1, 1])
-> Tensor dev QDType '[batchSize, QSliceHidden, 1]
-> Tensor dev QDType '[batchSize, QSliceHidden, 1, 1]
forall a b. (a -> b) -> a -> b
$ forall (dim :: Nat) (shape :: [Nat]) (shape' :: [Nat])
(dtype :: DType) (device :: (DeviceType, Nat)).
(KnownNat dim, shape' ~ Unsqueeze shape dim) =>
Tensor device dtype shape -> Tensor device dtype shape'
TT.unsqueeze @2 QTensor dev '[batchSize, QSliceHidden]
opEmb
forwardStoch :: ActionEncoder dev
-> (SliceEncoder dev, TransitionEncoder dev,
ActionEncoding dev '[batchSize])
-> IO (QTensor dev outShape)
forwardStoch ActionEncoder dev
a (SliceEncoder dev, TransitionEncoder dev,
ActionEncoding dev '[batchSize])
i = QTensor dev outShape -> IO (QTensor dev outShape)
forall a. a -> IO a
forall (f :: Type -> Type) a. Applicative f => a -> f a
pure (QTensor dev outShape -> IO (QTensor dev outShape))
-> QTensor dev outShape -> IO (QTensor dev outShape)
forall a b. (a -> b) -> a -> b
$ ActionEncoder dev
-> (SliceEncoder dev, TransitionEncoder dev,
ActionEncoding dev '[batchSize])
-> QTensor dev outShape
forall f a b. HasForward f a b => f -> a -> b
T.forward ActionEncoder dev
a (SliceEncoder dev, TransitionEncoder dev,
ActionEncoding dev '[batchSize])
i
data StateSpec dev = StateSpec
data StateEncoder dev = StateEncoder
{ forall (dev :: (DeviceType, Nat)).
StateEncoder dev
-> Conv2d QSliceHidden QSliceHidden FifthSize OctaveSize QDType dev
stL1mid :: TT.Conv2d (EmbSize) QStateHidden FifthSize OctaveSize QDType dev
, forall (dev :: (DeviceType, Nat)).
StateEncoder dev
-> Conv2d QSliceHidden QSliceHidden FifthSize OctaveSize QDType dev
stL1frozenSlc :: TT.Conv2d (EmbSize) QStateHidden FifthSize OctaveSize QDType dev
, forall (dev :: (DeviceType, Nat)).
StateEncoder dev
-> Conv2d QSliceHidden QSliceHidden FifthSize OctaveSize QDType dev
stL1frozenTr :: TT.Conv2d (EmbSize) QStateHidden FifthSize OctaveSize QDType dev
, forall (dev :: (DeviceType, Nat)).
StateEncoder dev
-> Conv2d QSliceHidden QSliceHidden FifthSize OctaveSize QDType dev
stL1openSlc :: TT.Conv2d (EmbSize) QStateHidden FifthSize OctaveSize QDType dev
, forall (dev :: (DeviceType, Nat)).
StateEncoder dev
-> Conv2d QSliceHidden QSliceHidden FifthSize OctaveSize QDType dev
stL1openTr :: TT.Conv2d (EmbSize) QStateHidden FifthSize OctaveSize QDType dev
, forall (dev :: (DeviceType, Nat)).
StateEncoder dev
-> Conv2d QSliceHidden QSliceHidden FifthSize OctaveSize QDType dev
stL2 :: TT.Conv2d QStateHidden QStateHidden FifthSize OctaveSize QDType dev
, forall (dev :: (DeviceType, Nat)).
StateEncoder dev
-> Conv2d QSliceHidden QSliceHidden FifthSize OctaveSize QDType dev
stL3 :: TT.Conv2d QStateHidden (EmbSize) FifthSize OctaveSize QDType dev
}
deriving (Int -> StateEncoder dev -> ShowS
[StateEncoder dev] -> ShowS
StateEncoder dev -> String
(Int -> StateEncoder dev -> ShowS)
-> (StateEncoder dev -> String)
-> ([StateEncoder dev] -> ShowS)
-> Show (StateEncoder dev)
forall a.
(Int -> a -> ShowS) -> (a -> String) -> ([a] -> ShowS) -> Show a
forall (dev :: (DeviceType, Nat)). Int -> StateEncoder dev -> ShowS
forall (dev :: (DeviceType, Nat)). [StateEncoder dev] -> ShowS
forall (dev :: (DeviceType, Nat)). StateEncoder dev -> String
$cshowsPrec :: forall (dev :: (DeviceType, Nat)). Int -> StateEncoder dev -> ShowS
showsPrec :: Int -> StateEncoder dev -> ShowS
$cshow :: forall (dev :: (DeviceType, Nat)). StateEncoder dev -> String
show :: StateEncoder dev -> String
$cshowList :: forall (dev :: (DeviceType, Nat)). [StateEncoder dev] -> ShowS
showList :: [StateEncoder dev] -> ShowS
Show, (forall x. StateEncoder dev -> Rep (StateEncoder dev) x)
-> (forall x. Rep (StateEncoder dev) x -> StateEncoder dev)
-> Generic (StateEncoder dev)
forall x. Rep (StateEncoder dev) x -> StateEncoder dev
forall x. StateEncoder dev -> Rep (StateEncoder dev) x
forall a.
(forall x. a -> Rep a x) -> (forall x. Rep a x -> a) -> Generic a
forall (dev :: (DeviceType, Nat)) x.
Rep (StateEncoder dev) x -> StateEncoder dev
forall (dev :: (DeviceType, Nat)) x.
StateEncoder dev -> Rep (StateEncoder dev) x
$cfrom :: forall (dev :: (DeviceType, Nat)) x.
StateEncoder dev -> Rep (StateEncoder dev) x
from :: forall x. StateEncoder dev -> Rep (StateEncoder dev) x
$cto :: forall (dev :: (DeviceType, Nat)) x.
Rep (StateEncoder dev) x -> StateEncoder dev
to :: forall x. Rep (StateEncoder dev) x -> StateEncoder dev
Generic, StateEncoder dev -> HList (Parameters (StateEncoder dev))
StateEncoder dev
-> HList (Parameters (StateEncoder dev)) -> StateEncoder dev
(StateEncoder dev -> HList (Parameters (StateEncoder dev)))
-> (StateEncoder dev
-> HList (Parameters (StateEncoder dev)) -> StateEncoder dev)
-> Parameterized (StateEncoder dev)
forall f.
(f -> HList (Parameters f))
-> (f -> HList (Parameters f) -> f) -> Parameterized f
forall (dev :: (DeviceType, Nat)).
StateEncoder dev -> HList (Parameters (StateEncoder dev))
forall (dev :: (DeviceType, Nat)).
StateEncoder dev
-> HList (Parameters (StateEncoder dev)) -> StateEncoder dev
$cflattenParameters :: forall (dev :: (DeviceType, Nat)).
StateEncoder dev -> HList (Parameters (StateEncoder dev))
flattenParameters :: StateEncoder dev -> HList (Parameters (StateEncoder dev))
$creplaceParameters :: forall (dev :: (DeviceType, Nat)).
StateEncoder dev
-> HList (Parameters (StateEncoder dev)) -> StateEncoder dev
replaceParameters :: StateEncoder dev
-> HList (Parameters (StateEncoder dev)) -> StateEncoder dev
TT.Parameterized, Context -> StateEncoder dev -> IO (Maybe ThunkInfo)
Proxy (StateEncoder dev) -> String
(Context -> StateEncoder dev -> IO (Maybe ThunkInfo))
-> (Context -> StateEncoder dev -> IO (Maybe ThunkInfo))
-> (Proxy (StateEncoder dev) -> String)
-> NoThunks (StateEncoder dev)
forall a.
(Context -> a -> IO (Maybe ThunkInfo))
-> (Context -> a -> IO (Maybe ThunkInfo))
-> (Proxy a -> String)
-> NoThunks a
forall (dev :: (DeviceType, Nat)).
Context -> StateEncoder dev -> IO (Maybe ThunkInfo)
forall (dev :: (DeviceType, Nat)).
Proxy (StateEncoder dev) -> String
$cnoThunks :: forall (dev :: (DeviceType, Nat)).
Context -> StateEncoder dev -> IO (Maybe ThunkInfo)
noThunks :: Context -> StateEncoder dev -> IO (Maybe ThunkInfo)
$cwNoThunks :: forall (dev :: (DeviceType, Nat)).
Context -> StateEncoder dev -> IO (Maybe ThunkInfo)
wNoThunks :: Context -> StateEncoder dev -> IO (Maybe ThunkInfo)
$cshowTypeOf :: forall (dev :: (DeviceType, Nat)).
Proxy (StateEncoder dev) -> String
showTypeOf :: Proxy (StateEncoder dev) -> String
NoThunks, StateEncoder dev -> ()
(StateEncoder dev -> ()) -> NFData (StateEncoder dev)
forall a. (a -> ()) -> NFData a
forall (dev :: (DeviceType, Nat)). StateEncoder dev -> ()
$crnf :: forall (dev :: (DeviceType, Nat)). StateEncoder dev -> ()
rnf :: StateEncoder dev -> ()
NFData)
-- | Randomly initialise a 'StateEncoder'.
--
-- Nothing is read from the 'StateSpec'; every convolution is sampled
-- independently from a default 'TT.Conv2dSpec'. The draws happen in the order
-- listed below, which fixes which random values end up in which layer.
instance (IsValidDevice dev) => T.Randomizable (StateSpec dev) (StateEncoder dev) where
  sample :: StateSpec dev -> IO (StateEncoder dev)
  sample _ignoredSpec = do
    -- first layer for the middle slice
    stL1mid <- TT.sample TT.Conv2dSpec
    -- first layer for the frozen segments (slice branch, transition branch)
    stL1frozenSlc <- TT.sample TT.Conv2dSpec
    stL1frozenTr <- TT.sample TT.Conv2dSpec
    -- first layer for the open segments (slice branch, transition branch)
    stL1openSlc <- TT.sample TT.Conv2dSpec
    stL1openTr <- TT.sample TT.Conv2dSpec
    -- layers applied to the combined embedding
    stL2 <- TT.sample TT.Conv2dSpec
    stL3 <- TT.sample TT.Conv2dSpec
    pure StateEncoder{..}
-- | Embed a search state into a single tensor of shape @EmbSize : PShape@.
--
-- Three parts of the 'StateEncoding' are embedded separately and summed:
--
-- * the middle slice, through the slice encoder and 'stL1mid';
-- * the frozen segments, through 'stL1frozenTr' / 'stL1frozenSlc', then
--   averaged over the leading 'FakeSize' axis;
-- * the open segments, through 'stL1openTr' / 'stL1openSlc', averaged the same way.
--
-- The sum then goes through 'stL2' and 'stL3'. Every convolution uses stride
-- @(1, 1)@ and padding @(FifthPadding, OctavePadding)@ and is followed by
-- 'activation'.
instance
  forall dev outShape
   . ( IsValidDevice dev
     , outShape ~ (EmbSize : PShape)
     )
  => T.HasForward
      (StateEncoder dev)
      (SliceEncoder dev, TransitionEncoder dev, StateEncoding dev)
      (QTensor dev outShape)
  where
  forward
    :: StateEncoder dev
    -> (SliceEncoder dev, TransitionEncoder dev, StateEncoding dev)
    -> QTensor dev outShape
  forward StateEncoder{..} (sliceNet, transNet, StateEncoding midSlice frozenSegs openSegs) =
    layer3
   where
    -- Convolve a batch of inputs; spatial size is preserved by the padding.
    convBatched
      :: (KnownNat nin, KnownNat nout, KnownNat batch)
      => TT.Conv2d nin nout FifthSize OctaveSize QDType dev
      -> QTensor dev (batch : nin : PShape)
      -> QTensor dev (batch : nout : PShape)
    convBatched layer x =
      TT.conv2dForward @'(1, 1) @'(FifthPadding, OctavePadding) layer x

    -- Convolve one unbatched input: add a batch axis of size 1, convolve,
    -- then remove that axis again.
    convSingle
      :: (KnownNat nin, KnownNat nout)
      => TT.Conv2d nin nout FifthSize OctaveSize QDType dev
      -> QTensor dev (nin : PShape)
      -> QTensor dev (nout : PShape)
    convSingle layer =
      TT.squeezeDim @0 . convBatched layer . TT.unsqueeze @0

    -- Embed a group of (transition, slice) pairs: each branch gets its own
    -- convolution, the two branches are added, and the sum is multiplied by
    -- the group's scalar mask (reshaped so it broadcasts over all four axes).
    embedSegments
      :: TT.Conv2d EmbSize QStateHidden FifthSize OctaveSize QDType dev
      -> TT.Conv2d EmbSize QStateHidden FifthSize OctaveSize QDType dev
      -> QMaybe dev '[] (TransitionEncoding dev '[FakeSize], QStartStop dev '[FakeSize] (SliceEncoding dev '[FakeSize]))
      -> QTensor dev (FakeSize : EmbSize : PShape)
    embedSegments trConv slcConv (QMaybe mask (transitions, slices)) =
      let transEmb :: QTensor dev (FakeSize : EmbSize : PShape)
          transEmb = activation (convBatched trConv (T.forward transNet transitions))
          sliceEmb :: QTensor dev (FakeSize : EmbSize : PShape)
          sliceEmb = activation (convBatched slcConv (T.forward sliceNet slices))
       in TT.mul (TT.reshape @'[1, 1, 1, 1] mask) (transEmb + sliceEmb)

    frozenEmb :: QTensor dev (EmbSize : PShape)
    frozenEmb = TT.meanDim @0 (embedSegments stL1frozenTr stL1frozenSlc frozenSegs)

    openEmb :: QTensor dev (EmbSize : PShape)
    openEmb = TT.meanDim @0 (embedSegments stL1openTr stL1openSlc openSegs)

    midEmb :: QTensor dev (QStateHidden : PShape)
    midEmb = activation (convSingle stL1mid (T.forward sliceNet midSlice))

    -- Summed left to right (mid, frozen, open).
    combined :: QTensor dev (EmbSize : PShape)
    combined = (midEmb + frozenEmb) + openEmb

    layer2 :: QTensor dev (QStateHidden : PShape)
    layer2 = activation (convSingle stL2 combined)

    layer3 :: QTensor dev (EmbSize : PShape)
    layer3 = activation (convSingle stL3 layer2)

  -- No stochastic layers here: the stochastic pass is the deterministic one.
  forwardStoch encoder = pure . T.forward encoder
-- | Specification for randomly initialising a 'QModel'. It has a single
-- nullary constructor and carries no configuration; @dev@ is a phantom index.
data QSpec dev = QSpec
-- | The complete Q-network.
--
-- The slice and transition encoders are passed to both the action encoder and
-- the state encoder. The remaining fields are output layers.
--
-- NOTE(review): the 'Generic'-based 'TT.Parameterized' instance is derived
-- from this declaration, so reordering fields presumably changes the layout of
-- the flattened parameter list (and of anything saved from it). Confirm before
-- reordering.
data QModel dev = QModel
  { qModelSlc :: !(SliceEncoder dev)
  -- ^ slice encoder, shared by the action and state encoders
  , qModelTr :: !(TransitionEncoder dev)
  -- ^ transition encoder, shared by the action and state encoders
  , qModelAct :: !(ActionEncoder dev)
  -- ^ encoder for the batched action encodings
  , qModelSt :: !(StateEncoder dev)
  -- ^ encoder for the state encoding
  , qModelFinal1 :: !(TT.Conv2d EmbSize QOutHidden FifthSize OctaveSize QDType dev)
  -- ^ first output layer (convolution)
  , qModelNorm1 :: !(TT.LayerNorm '[QOutHidden] QDType dev)
  -- ^ layer norm over the 'QOutHidden' features
  , qModelFinal2 :: !(TT.Linear QOutHidden 1 QDType dev)
  -- ^ final projection to a single output
  , qModelValue1 :: !(TT.Linear EmbSize QOutHidden QDType dev)
  -- ^ value layer 1 (ignored by 'forwardQModelBatched')
  , qModelValueNorm :: !(TT.LayerNorm '[QOutHidden] QDType dev)
  -- ^ value layer norm (ignored by 'forwardQModelBatched')
  , qModelValue2 :: !(TT.Linear QOutHidden 1 QDType dev)
  -- ^ value layer 2 (ignored by 'forwardQModelBatched')
  }
  deriving stock (Show, Generic)
  deriving anyclass (TT.Parameterized, NoThunks, NFData)
-- | Randomly initialise a 'QModel'.
--
-- Each sub-encoder is sampled from its own default spec. Each layer is sampled
-- from a default hasktorch spec, and both layer norms use epsilon @1e-05@. The
-- draws happen in field order.
instance (IsValidDevice dev) => T.Randomizable (QSpec dev) (QModel dev) where
  sample :: QSpec dev -> IO (QModel dev)
  sample QSpec = do
    -- encoders
    qModelSlc <- T.sample (SliceSpec @dev)
    qModelTr <- T.sample (TransitionSpec @dev)
    qModelAct <- T.sample (ActionSpec @dev)
    qModelSt <- T.sample (StateSpec @dev)
    -- output layers
    qModelFinal1 <- T.sample TT.Conv2dSpec
    qModelNorm1 <- T.sample (TT.LayerNormSpec normEpsilon)
    qModelFinal2 <- T.sample TT.LinearSpec
    -- value layers
    qModelValue1 <- T.sample TT.LinearSpec
    qModelValueNorm <- T.sample (TT.LayerNormSpec normEpsilon)
    qModelValue2 <- T.sample TT.LinearSpec
    pure QModel{..}
   where
    -- shared epsilon for both layer norms
    normEpsilon = 1e-05
-- | Score a single, unbatched 'QEncoding'.
--
-- The result is @log (sigmoid x)@, where @x@ is the output of 'forwardQModel'.
-- It is computed with the fused 'TT.logSigmoid' rather than
-- @TT.log . TT.sigmoid@. The composed form has two problems for very negative
-- logits: the sigmoid underflows to 0, the log then returns @-inf@, and
-- backpropagation produces @inf * 0 = NaN@ gradients. The fused version stays
-- finite and is mathematically identical everywhere else.
instance
  (IsValidDevice dev, TT.CheckIsSuffixOf '[QOutHidden] [1, QOutHidden] (QOutHidden == QOutHidden))
  => T.HasForward (QModel dev) (QEncoding dev '[]) (QTensor dev '[1])
  where
  forward :: QModel dev -> QEncoding dev '[] -> QTensor dev '[1]
  forward model encoding = TT.logSigmoid $ forwardQModel model encoding

  -- The model has no stochastic layers, so the stochastic pass is just 'T.forward'.
  forwardStoch :: QModel dev -> QEncoding dev '[] -> IO (QTensor dev '[1])
  forwardStoch model input = pure $ T.forward model input
-- | Score a batch of 'QEncoding's, returning one value per batch element.
--
-- The result is @log (sigmoid x)@ of each output of 'forwardQModelBatched'.
-- It uses the fused 'TT.logSigmoid' instead of @TT.log . TT.sigmoid@, which
-- underflows to @-inf@ (and gives NaN gradients) for very negative logits.
instance
  ( IsValidDevice dev
  , KnownNat batchSize
  , 1 <= batchSize
  , TT.CheckIsSuffixOf '[QOutHidden] [batchSize, QOutHidden] (QOutHidden == QOutHidden)
  )
  => T.HasForward (QModel dev) (QEncoding dev '[batchSize]) (QTensor dev '[batchSize, 1])
  where
  forward :: QModel dev -> QEncoding dev '[batchSize] -> QTensor dev '[batchSize, 1]
  forward model encoding =
    TT.logSigmoid $ forwardQModelBatched model encoding

  -- The model has no stochastic layers, so the stochastic pass is just 'T.forward'.
  forwardStoch
    :: QModel dev -> QEncoding dev '[batchSize] -> IO (QTensor dev '[batchSize, 1])
  forwardStoch model input = pure $ T.forward model input
-- | Run the Q-network on a single, unbatched encoding.
--
-- The encoding is turned into a batch of one with 'addBatchDim' and passed
-- through 'forwardQModelBatched'. The leading batch axis is then squeezed
-- away, leaving shape @[1]@. The 'T.HasForward' instances apply the
-- (log-)sigmoid to this result.
forwardQModel
  :: (IsValidDevice dev)
  => QModel dev
  -> QEncoding dev '[]
  -> QTensor dev '[1]
forwardQModel model =
  TT.squeezeDim @0 . forwardQModelBatched model . addBatchDim
-- | Score @batchSize@ candidate action encodings against one shared state.
--
-- Pipeline:
--
--   1. embed every action (batched) and the state (unbatched);
--   2. broadcast-add the state embedding onto each action embedding;
--   3. apply the convolution with stride @(1, 1)@ and padding
--      @(FifthPadding, OctavePadding)@;
--   4. sum out the two trailing axes (index 2, twice);
--   5. layer-normalise, apply 'activation', and project to one score per
--      batch element.
--
-- The last three 'QModel' fields are ignored here; 'forwardValue' uses them.
forwardQModelBatched
  :: forall dev batchSize
   . ( IsValidDevice dev
     , 1 <= batchSize
     )
  => QModel dev
  -> QEncoding dev '[batchSize]
  -> QTensor dev '[batchSize, 1]
forwardQModelBatched
  (QModel sliceEnc transEnc actionEnc stateEnc conv convNorm scoreLayer _ _ _)
  (QEncoding actionEncodings stEncoding) =
    T.forward scoreLayer hidden
 where
  -- one embedding per candidate action
  actionEmbedding :: QTensor dev (batchSize : EmbSize : PShape)
  actionEmbedding = T.forward actionEnc (sliceEnc, transEnc, actionEncodings)

  -- a single state embedding, later broadcast over the batch axis
  stateEmbedding :: QTensor dev (EmbSize : PShape)
  stateEmbedding = T.forward stateEnc (sliceEnc, transEnc, stEncoding)

  convolved :: QTensor dev (batchSize : QOutHidden : PShape)
  convolved =
    TH.conv2dForwardRelaxed @'(1, 1) @'(FifthPadding, OctavePadding)
      conv
      (actionEmbedding `TT.add` stateEmbedding)

  -- collapse both trailing axes: each sumDim @2 drops the axis at index 2
  pooled :: QTensor dev '[batchSize, QOutHidden]
  pooled = TT.sumDim @2 (TT.sumDim @2 convolved)

  hidden :: QTensor dev '[batchSize, QOutHidden]
  hidden = activation (T.forward convNorm pooled)
-- | Policy score for a single, unbatched action encoding.
--
-- This is an alias for 'forwardQModel'. The constraint is written out
-- explicitly instead of using an extra-constraints wildcard @(_)@. It is the
-- same constraint GHC would infer, but spelling it out documents the contract
-- for callers and keeps this signature consistent with 'forwardQModel' and
-- 'forwardValue'.
forwardPolicy
  :: (IsValidDevice dev)
  => QModel dev
  -> QEncoding dev '[]
  -> QTensor dev '[1]
forwardPolicy = forwardQModel
-- | Batched policy scores: one score per action encoding in the batch.
--
-- This is an alias for 'forwardQModelBatched'. Its constraints are spelled
-- out (instead of an extra-constraints wildcard @(_)@) so that they match
-- 'forwardQModelBatched' exactly and are visible to callers.
forwardPolicyBatched
  :: forall dev batchSize
   . ( IsValidDevice dev
     , 1 <= batchSize
     )
  => QModel dev
  -> QEncoding dev '[batchSize]
  -> QTensor dev '[batchSize, 1]
forwardPolicyBatched = forwardQModelBatched
-- | Value head: map a state encoding to a single log-probability-like value
-- in @(-inf, 0]@, i.e. @log (sigmoid logit)@.
--
-- Only the slice, transition and state encoders and the value-head layers
-- (@value1@, @norm@, @value2@) of the 'QModel' are used.
forwardValue
  :: (IsValidDevice dev)
  => QModel dev
  -> StateEncoding dev
  -> QTensor dev '[1]
forwardValue (QModel slc tr _ st _ _ _ value1 norm value2) stateEncoding = out2
 where
  -- Encode the state (shape [QSliceHidden, 13, 5]), then sum out both
  -- trailing axes (each sumDim @1 drops the axis at index 1), leaving a
  -- single feature vector of length QSliceHidden.
  outSlc = TT.sumDim @1 $ TT.sumDim @1 $ T.forward st (slc, tr, stateEncoding)
  -- Hidden layer: linear -> layer norm -> activation.
  out1 = activation $ T.forward norm $ T.forward value1 outSlc
  -- Output layer. 'TT.logSigmoid' computes log (sigmoid x) stably. The
  -- previous @TT.log . TT.sigmoid@ underflows to log 0 = -inf for strongly
  -- negative logits, which then turns gradients into NaN.
  out2 = TT.logSigmoid $ T.forward value2 out1
-- | A scalar that is identically zero yet depends on every model parameter.
--
-- Every parameter is converted to its dependent tensor, reduced to a scalar
-- with 'TH.SumAll', and the scalars are folded together with 'TH.Add'
-- starting from zero. The grand total is then multiplied by zero.
-- Presumably this exists so that every parameter takes part in the autograd
-- graph; confirm against callers.
--
-- NOTE(review): if any parameter holds a non-finite value, @0 * total@ is
-- NaN rather than 0.
fakeLoss
  :: forall dev ps
   . (IsValidDevice dev, ps ~ TT.Parameters (QModel dev))
  => QModel dev
  -> QTensor dev '[]
fakeLoss model = zero * grandTotal
 where
  zero :: QTensor dev '[]
  zero = TT.zeros

  -- parameters as plain (dependent) tensors, in flattenParameters order
  dependents :: (TT.HMap' TT.ToDependent ps ys) => TT.HList ys
  dependents = TT.hmap' TT.ToDependent (TT.flattenParameters model)

  -- one scalar per parameter, folded into a single scalar
  grandTotal = TT.hfoldr TH.Add zero (TT.hmap' TH.SumAll dependents)
-- | Build a fresh 'QModel' on device @dev@ by sampling its 'QSpec' through
-- the 'T.sample' ('Randomizable') machinery.
mkQModel :: forall dev. (IsValidDevice dev) => IO (QModel dev)
mkQModel = T.sample (QSpec @dev)
-- | Load a 'QModel' for device @dev@ from a parameter file at @path@
-- (the format written by 'saveModel': the model's flattened parameters,
-- stored as plain tensors with 'TT.save').
--
-- A placeholder model is built with 'mkQModel' only to provide the
-- structure; every one of its parameters is replaced by the loaded tensors.
loadModel :: forall dev. (IsValidDevice dev) => FilePath -> IO (QModel dev)
loadModel path = do
  modelPlaceholder <- mkQModel @dev
  -- The constraint fixes the element types of the loaded 'TT.HList' to the
  -- dependent-tensor versions of the model's parameter list, so the file
  -- contents are typed as tensors on @dev@.
  tensors ::
    (TT.HMap' TT.ToDependent (TT.Parameters (QModel dev)) ts) =>
    TT.HList ts <-
    TT.load path
  -- Round-trip dev -> CPU -> dev. The loaded tensors are only *typed* as
  -- living on @dev@; presumably this forces the data to physically reside
  -- on @dev@ regardless of where it was when saved.
  -- NOTE(review): confirm this is the intended reason for the round-trip.
  let tensorsCPU = TT.toDevice @'(TT.CPU, 0) @dev tensors
  let tensorsDevice = TT.toDevice @dev @'(TT.CPU, 0) tensorsCPU
  -- Turn the plain tensors back into trainable parameters.
  params <- TT.hmapM' TT.MakeIndependent tensorsDevice
  pure $ TT.replaceParameters modelPlaceholder params
-- | Save the parameters of a 'QModel' to @path@.
--
-- The parameters are flattened ('TT.flattenParameters'), converted to plain
-- tensors ('TT.ToDependent') and written with 'TT.save'. No device transfer
-- happens here; tensors are saved from whatever device @dev@ they are on.
-- The file can be read back with 'loadModel'.
saveModel :: FilePath -> QModel dev -> IO ()
saveModel path model =
  TT.save (TT.hmap' TT.ToDependent $ TT.flattenParameters model) path
modelSize :: QModel dev -> Int
modelSize :: forall (dev :: (DeviceType, Nat)). QModel dev -> Int
modelSize QModel dev
model = [Int] -> Int
forall a. Num a => [a] -> a
forall (t :: Type -> Type) a. (Foldable t, Num a) => t a -> a
sum ([Int] -> Int) -> [Int] -> Int
forall a b. (a -> b) -> a -> b
$ [Int] -> Int
forall a. Num a => [a] -> a
forall (t :: Type -> Type) a. (Foldable t, Num a) => t a -> a
product ([Int] -> Int) -> [[Int]] -> [Int]
forall (f :: Type -> Type) a b. Functor f => (a -> b) -> f a -> f b
<$> [[Int]]
sizes
where
sizes :: [[Int]]
sizes = ToList
-> [[Int]]
-> HList
'[[Int], [Int], [Int], [Int], [Int], [Int], [Int], [Int], [Int],
[Int], [Int], [Int], [Int], [Int], [Int], [Int], [Int], [Int],
[Int], [Int], [Int], [Int], [Int], [Int], [Int], [Int], [Int],
[Int], [Int], [Int], [Int], [Int], [Int], [Int], [Int], [Int],
[Int], [Int], [Int], [Int], [Int], [Int], [Int], [Int], [Int],
[Int], [Int], [Int], [Int], [Int], [Int], [Int], [Int], [Int],
[Int], [Int], [Int], [Int]]
-> [[Int]]
forall {k} f acc (xs :: [k]) res.
HFoldr f acc xs res =>
f -> acc -> HList xs -> res
TT.hfoldr ToList
TH.ToList ([] :: [[Int]]) (HList
'[[Int], [Int], [Int], [Int], [Int], [Int], [Int], [Int], [Int],
[Int], [Int], [Int], [Int], [Int], [Int], [Int], [Int], [Int],
[Int], [Int], [Int], [Int], [Int], [Int], [Int], [Int], [Int],
[Int], [Int], [Int], [Int], [Int], [Int], [Int], [Int], [Int],
[Int], [Int], [Int], [Int], [Int], [Int], [Int], [Int], [Int],
[Int], [Int], [Int], [Int], [Int], [Int], [Int], [Int], [Int],
[Int], [Int], [Int], [Int]]
-> [[Int]])
-> HList
'[[Int], [Int], [Int], [Int], [Int], [Int], [Int], [Int], [Int],
[Int], [Int], [Int], [Int], [Int], [Int], [Int], [Int], [Int],
[Int], [Int], [Int], [Int], [Int], [Int], [Int], [Int], [Int],
[Int], [Int], [Int], [Int], [Int], [Int], [Int], [Int], [Int],
[Int], [Int], [Int], [Int], [Int], [Int], [Int], [Int], [Int],
[Int], [Int], [Int], [Int], [Int], [Int], [Int], [Int], [Int],
[Int], [Int], [Int], [Int]]
-> [[Int]]
forall a b. (a -> b) -> a -> b
$ ShapeVal
-> HList
'[Parameter dev QDType '[QSliceHidden, 1, 1, 1],
Parameter dev QDType '[QSliceHidden],
Parameter dev QDType '[QSliceHidden, QSliceHidden, 13, 5],
Parameter dev QDType '[QSliceHidden],
Parameter dev QDType '[QSliceHidden, 13, 5],
Parameter dev QDType '[QSliceHidden, 13, 5],
Parameter dev QDType '[QSliceHidden, 2, 13, 5],
Parameter dev QDType '[QSliceHidden],
Parameter dev QDType '[QSliceHidden, 2, 13, 5],
Parameter dev QDType '[QSliceHidden],
Parameter dev QDType '[QSliceHidden, 1, 1, 1],
Parameter dev QDType '[QSliceHidden],
Parameter dev QDType '[QSliceHidden, 1, 1, 1],
Parameter dev QDType '[QSliceHidden],
Parameter dev QDType '[QSliceHidden],
Parameter dev QDType '[QSliceHidden, QSliceHidden, 13, 5],
Parameter dev QDType '[QSliceHidden],
Parameter dev QDType '[QSliceHidden, QSliceHidden, 13, 5],
Parameter dev QDType '[QSliceHidden],
Parameter dev QDType '[QSliceHidden, QSliceHidden, 13, 5],
Parameter dev QDType '[QSliceHidden],
Parameter dev QDType '[QSliceHidden, QSliceHidden, 13, 5],
Parameter dev QDType '[QSliceHidden],
Parameter dev QDType '[QSliceHidden, QSliceHidden, 13, 5],
Parameter dev QDType '[QSliceHidden],
Parameter dev QDType '[QSliceHidden, QSliceHidden, 13, 5],
Parameter dev QDType '[QSliceHidden],
Parameter dev QDType '[QSliceHidden, QSliceHidden, 13, 5],
Parameter dev QDType '[QSliceHidden], Parameter dev QDType '[5],
Parameter dev QDType '[5], Parameter dev QDType '[5],
Parameter dev QDType '[QSliceHidden, QSliceHidden, 13, 5],
Parameter dev QDType '[QSliceHidden],
Parameter dev QDType '[QSliceHidden, QSliceHidden, 13, 5],
Parameter dev QDType '[QSliceHidden],
Parameter dev QDType '[QSliceHidden, QSliceHidden, 13, 5],
Parameter dev QDType '[QSliceHidden],
Parameter dev QDType '[QSliceHidden, QSliceHidden, 13, 5],
Parameter dev QDType '[QSliceHidden],
Parameter dev QDType '[QSliceHidden, QSliceHidden, 13, 5],
Parameter dev QDType '[QSliceHidden],
Parameter dev QDType '[QSliceHidden, QSliceHidden, 13, 5],
Parameter dev QDType '[QSliceHidden],
Parameter dev QDType '[QSliceHidden, QSliceHidden, 13, 5],
Parameter dev QDType '[QSliceHidden],
Parameter dev QDType '[QSliceHidden, QSliceHidden, 13, 5],
Parameter dev QDType '[QSliceHidden],
Parameter dev QDType '[QSliceHidden],
Parameter dev QDType '[QSliceHidden],
Parameter dev QDType '[1, QSliceHidden], Parameter dev QDType '[1],
Parameter dev QDType '[QSliceHidden, QSliceHidden],
Parameter dev QDType '[QSliceHidden],
Parameter dev QDType '[QSliceHidden],
Parameter dev QDType '[QSliceHidden],
Parameter dev QDType '[1, QSliceHidden], Parameter dev QDType '[1]]
-> HList
'[[Int], [Int], [Int], [Int], [Int], [Int], [Int], [Int], [Int],
[Int], [Int], [Int], [Int], [Int], [Int], [Int], [Int], [Int],
[Int], [Int], [Int], [Int], [Int], [Int], [Int], [Int], [Int],
[Int], [Int], [Int], [Int], [Int], [Int], [Int], [Int], [Int],
[Int], [Int], [Int], [Int], [Int], [Int], [Int], [Int], [Int],
[Int], [Int], [Int], [Int], [Int], [Int], [Int], [Int], [Int],
[Int], [Int], [Int], [Int]]
forall k f (xs :: [k]) (ys :: [k]).
HMap' f xs ys =>
f -> HList xs -> HList ys
TT.hmap' ShapeVal
TH.ShapeVal (HList
'[Parameter dev QDType '[QSliceHidden, 1, 1, 1],
Parameter dev QDType '[QSliceHidden],
Parameter dev QDType '[QSliceHidden, QSliceHidden, 13, 5],
Parameter dev QDType '[QSliceHidden],
Parameter dev QDType '[QSliceHidden, 13, 5],
Parameter dev QDType '[QSliceHidden, 13, 5],
Parameter dev QDType '[QSliceHidden, 2, 13, 5],
Parameter dev QDType '[QSliceHidden],
Parameter dev QDType '[QSliceHidden, 2, 13, 5],
Parameter dev QDType '[QSliceHidden],
Parameter dev QDType '[QSliceHidden, 1, 1, 1],
Parameter dev QDType '[QSliceHidden],
Parameter dev QDType '[QSliceHidden, 1, 1, 1],
Parameter dev QDType '[QSliceHidden],
Parameter dev QDType '[QSliceHidden],
Parameter dev QDType '[QSliceHidden, QSliceHidden, 13, 5],
Parameter dev QDType '[QSliceHidden],
Parameter dev QDType '[QSliceHidden, QSliceHidden, 13, 5],
Parameter dev QDType '[QSliceHidden],
Parameter dev QDType '[QSliceHidden, QSliceHidden, 13, 5],
Parameter dev QDType '[QSliceHidden],
Parameter dev QDType '[QSliceHidden, QSliceHidden, 13, 5],
Parameter dev QDType '[QSliceHidden],
Parameter dev QDType '[QSliceHidden, QSliceHidden, 13, 5],
Parameter dev QDType '[QSliceHidden],
Parameter dev QDType '[QSliceHidden, QSliceHidden, 13, 5],
Parameter dev QDType '[QSliceHidden],
Parameter dev QDType '[QSliceHidden, QSliceHidden, 13, 5],
Parameter dev QDType '[QSliceHidden], Parameter dev QDType '[5],
Parameter dev QDType '[5], Parameter dev QDType '[5],
Parameter dev QDType '[QSliceHidden, QSliceHidden, 13, 5],
Parameter dev QDType '[QSliceHidden],
Parameter dev QDType '[QSliceHidden, QSliceHidden, 13, 5],
Parameter dev QDType '[QSliceHidden],
Parameter dev QDType '[QSliceHidden, QSliceHidden, 13, 5],
Parameter dev QDType '[QSliceHidden],
Parameter dev QDType '[QSliceHidden, QSliceHidden, 13, 5],
Parameter dev QDType '[QSliceHidden],
Parameter dev QDType '[QSliceHidden, QSliceHidden, 13, 5],
Parameter dev QDType '[QSliceHidden],
Parameter dev QDType '[QSliceHidden, QSliceHidden, 13, 5],
Parameter dev QDType '[QSliceHidden],
Parameter dev QDType '[QSliceHidden, QSliceHidden, 13, 5],
Parameter dev QDType '[QSliceHidden],
Parameter dev QDType '[QSliceHidden, QSliceHidden, 13, 5],
Parameter dev QDType '[QSliceHidden],
Parameter dev QDType '[QSliceHidden],
Parameter dev QDType '[QSliceHidden],
Parameter dev QDType '[1, QSliceHidden], Parameter dev QDType '[1],
Parameter dev QDType '[QSliceHidden, QSliceHidden],
Parameter dev QDType '[QSliceHidden],
Parameter dev QDType '[QSliceHidden],
Parameter dev QDType '[QSliceHidden],
Parameter dev QDType '[1, QSliceHidden], Parameter dev QDType '[1]]
-> HList
'[[Int], [Int], [Int], [Int], [Int], [Int], [Int], [Int], [Int],
[Int], [Int], [Int], [Int], [Int], [Int], [Int], [Int], [Int],
[Int], [Int], [Int], [Int], [Int], [Int], [Int], [Int], [Int],
[Int], [Int], [Int], [Int], [Int], [Int], [Int], [Int], [Int],
[Int], [Int], [Int], [Int], [Int], [Int], [Int], [Int], [Int],
[Int], [Int], [Int], [Int], [Int], [Int], [Int], [Int], [Int],
[Int], [Int], [Int], [Int]])
-> HList
'[Parameter dev QDType '[QSliceHidden, 1, 1, 1],
Parameter dev QDType '[QSliceHidden],
Parameter dev QDType '[QSliceHidden, QSliceHidden, 13, 5],
Parameter dev QDType '[QSliceHidden],
Parameter dev QDType '[QSliceHidden, 13, 5],
Parameter dev QDType '[QSliceHidden, 13, 5],
Parameter dev QDType '[QSliceHidden, 2, 13, 5],
Parameter dev QDType '[QSliceHidden],
Parameter dev QDType '[QSliceHidden, 2, 13, 5],
Parameter dev QDType '[QSliceHidden],
Parameter dev QDType '[QSliceHidden, 1, 1, 1],
Parameter dev QDType '[QSliceHidden],
Parameter dev QDType '[QSliceHidden, 1, 1, 1],
Parameter dev QDType '[QSliceHidden],
Parameter dev QDType '[QSliceHidden],
Parameter dev QDType '[QSliceHidden, QSliceHidden, 13, 5],
Parameter dev QDType '[QSliceHidden],
Parameter dev QDType '[QSliceHidden, QSliceHidden, 13, 5],
Parameter dev QDType '[QSliceHidden],
Parameter dev QDType '[QSliceHidden, QSliceHidden, 13, 5],
Parameter dev QDType '[QSliceHidden],
Parameter dev QDType '[QSliceHidden, QSliceHidden, 13, 5],
Parameter dev QDType '[QSliceHidden],
Parameter dev QDType '[QSliceHidden, QSliceHidden, 13, 5],
Parameter dev QDType '[QSliceHidden],
Parameter dev QDType '[QSliceHidden, QSliceHidden, 13, 5],
Parameter dev QDType '[QSliceHidden],
Parameter dev QDType '[QSliceHidden, QSliceHidden, 13, 5],
Parameter dev QDType '[QSliceHidden], Parameter dev QDType '[5],
Parameter dev QDType '[5], Parameter dev QDType '[5],
Parameter dev QDType '[QSliceHidden, QSliceHidden, 13, 5],
Parameter dev QDType '[QSliceHidden],
Parameter dev QDType '[QSliceHidden, QSliceHidden, 13, 5],
Parameter dev QDType '[QSliceHidden],
Parameter dev QDType '[QSliceHidden, QSliceHidden, 13, 5],
Parameter dev QDType '[QSliceHidden],
Parameter dev QDType '[QSliceHidden, QSliceHidden, 13, 5],
Parameter dev QDType '[QSliceHidden],
Parameter dev QDType '[QSliceHidden, QSliceHidden, 13, 5],
Parameter dev QDType '[QSliceHidden],
Parameter dev QDType '[QSliceHidden, QSliceHidden, 13, 5],
Parameter dev QDType '[QSliceHidden],
Parameter dev QDType '[QSliceHidden, QSliceHidden, 13, 5],
Parameter dev QDType '[QSliceHidden],
Parameter dev QDType '[QSliceHidden, QSliceHidden, 13, 5],
Parameter dev QDType '[QSliceHidden],
Parameter dev QDType '[QSliceHidden],
Parameter dev QDType '[QSliceHidden],
Parameter dev QDType '[1, QSliceHidden], Parameter dev QDType '[1],
Parameter dev QDType '[QSliceHidden, QSliceHidden],
Parameter dev QDType '[QSliceHidden],
Parameter dev QDType '[QSliceHidden],
Parameter dev QDType '[QSliceHidden],
Parameter dev QDType '[1, QSliceHidden], Parameter dev QDType '[1]]
-> HList
'[[Int], [Int], [Int], [Int], [Int], [Int], [Int], [Int], [Int],
[Int], [Int], [Int], [Int], [Int], [Int], [Int], [Int], [Int],
[Int], [Int], [Int], [Int], [Int], [Int], [Int], [Int], [Int],
[Int], [Int], [Int], [Int], [Int], [Int], [Int], [Int], [Int],
[Int], [Int], [Int], [Int], [Int], [Int], [Int], [Int], [Int],
[Int], [Int], [Int], [Int], [Int], [Int], [Int], [Int], [Int],
[Int], [Int], [Int], [Int]]
forall a b. (a -> b) -> a -> b
$ QModel dev -> HList (Parameters (QModel dev))
forall f. Parameterized f => f -> HList (Parameters f)
TT.flattenParameters QModel dev
model
-- | Evaluate the Q model on a single (state, action) pair.
--
-- The pair is first turned into an unbatched 'QEncoding' by @encode@. The
-- model's forward pass yields a one-element tensor (shape @[1]@), which is
-- erased to an untyped tensor and read back as a plain 'QType' value.
--
-- Both the encoder and the model are forced (bang patterns) when the
-- function is fully applied.
runQ
  :: (IsValidDevice dev)
  => (s -> a -> QEncoding dev '[])
  -> QModel dev
  -> s
  -> a
  -> QType
runQ !encode !model s a = T.asValue (TT.toDynamic qOut)
 where
  -- typed forward pass; result is a @[1]@-shaped tensor
  qOut = T.forward model (encode s a)
-- | Like 'runQ', but keeps the model output as a typed one-element tensor
-- (shape @[1]@) instead of converting it to a plain 'QType'.
--
-- The encoder and model are forced (bang patterns) when all four
-- arguments are supplied.
runQ'
  :: (IsValidDevice dev)
  => (s -> a -> QEncoding dev '[])
  -> QModel dev
  -> s
  -> a
  -> QTensor dev '[1]
runQ' !encode !model s a =
  let encoded = encode s a
   in T.forward model encoded
-- | Score a batch of encodings with 'forwardPolicyBatched' and apply a
-- softmax over dimension 0 (the batch dimension) of the resulting
-- @[batchSize, 1]@ tensor. The result is returned untyped via
-- 'TT.toDynamic'.
--
-- NOTE(review): the error message suggests each batch entry is one
-- candidate action, making the result a distribution over actions. Confirm
-- against callers.
--
-- 'forwardPolicyBatched' requires @1 <= batchSize@. That requirement is
-- discharged by comparing the type-level naturals with 'cmpNat'. An empty
-- batch (@batchSize = 0@) calls 'error'.
runBatchedPolicy
  :: forall dev batchSize
   . (IsValidDevice dev, KnownNat batchSize)
  => QModel dev
  -> QEncoding dev '[batchSize]
  -> T.Tensor
runBatchedPolicy actor encoding =
  TT.toDynamic . TT.softmax @0 $ scoresFor (cmpNat (Proxy @1) (Proxy @batchSize))
 where
  -- The 'EQI' and 'LTI' equations each bring their own type-level evidence
  -- for @1 <= batchSize@. That is why the identical-looking calls cannot be
  -- collapsed into a single wildcard equation.
  scoresFor :: OrderingI 1 batchSize -> QTensor dev '[batchSize, 1]
  scoresFor EQI = forwardPolicyBatched @dev @batchSize actor encoding
  scoresFor LTI = forwardPolicyBatched @dev @batchSize actor encoding
  scoresFor GTI = error "batched policy: no actions"