{-# LANGUAGE DataKinds #-}
{-# LANGUAGE DeriveAnyClass #-}
{-# LANGUAGE DerivingVia #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE PartialTypeSignatures #-}
{-# LANGUAGE RecordWildCards #-}
-- {-# LANGUAGE Strict #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE UndecidableInstances #-}
{-# LANGUAGE NoStarIsType #-}
{-# HLINT ignore "Use <$>" #-}
-- {-# OPTIONS_GHC -O0 #-}
-- {-# OPTIONS_GHC -v #-}
{-# OPTIONS_GHC -Wno-partial-type-signatures #-}
{-# OPTIONS_GHC -Wno-unrecognised-pragmas #-}
{-# OPTIONS_GHC -Wredundant-constraints #-}

-- {-# OPTIONS_GHC -fplugin GHC.TypeLits.KnownNat.Solver #-}
-- {-# OPTIONS_GHC -fplugin GHC.TypeLits.Normalise #-}

module RL.Model where

import Common
import Control.DeepSeq
import Data.Foldable qualified as F
import Data.Kind (Type)
import Data.Proxy (Proxy (Proxy))
import Data.Type.Equality (type (:~:) (Refl), type (==))
import Data.TypeNums (KnownNat, Nat, TInt (..), intVal, intVal', type (*), type (+), type (-), type (<=))
import Debug.Trace qualified as DT
import GHC.ForeignPtr qualified as Ptr
import GHC.Generics (Generic)
import GHC.TypeLits (OrderingI (..), cmpNat, sameNat)
import GreedyParser (DoubleParent (DoubleParent), SingleParent (SingleParent))
import NoThunks.Class (NoThunks (..), OnlyCheckWhnf (..), allNoThunks)
import RL.Encoding
import RL.ModelTypes
import RL.TorchHelpers (withBatchDim)
import RL.TorchHelpers qualified as TH
import Torch qualified as T
import Torch.Jit qualified as TJit
import Torch.Lens qualified as TL
import Torch.Typed qualified as TT

import System.IO.Unsafe
import Torch.Internal.Cast (cast2)
import Torch.Internal.Managed.Type.Tensor qualified as ATen

-- Global Settings
-- ===============

-- | The nonlinearity applied after every layer of the Q network.
--
-- Currently GELU ('TT.gelu'); it is defined once here so all layers share it.
activation :: (IsValidDevice dev) => QTensor dev shape -> QTensor dev shape
activation x = TT.gelu x

-- helpers
-- =======

-- | Untyped wrapper around libtorch's @expand_as@: expands the first tensor
-- to the shape of the second.
--
-- NOTE(review): 'unsafePerformIO' is used the same way hasktorch's own pure
-- wrappers use it, i.e. it assumes the ATen call has no observable side
-- effects beyond returning a new tensor handle — confirm if this is reused
-- with mutating ATen functions.
expandAs :: T.Tensor -> T.Tensor -> T.Tensor
expandAs source target =
  unsafePerformIO (cast2 ATen.tensor_expand_as_t source target)

-- | Debugging helper: when the result is evaluated, print the runtime shape
-- of a typed tensor (via 'DT.traceShow') and return the tensor unchanged.
traceDyn :: TT.Tensor a b c -> TT.Tensor a b c
traceDyn tensor = DT.traceShow runtimeShape tensor
  where
    runtimeShape = T.shape (TT.toDynamic tensor)

-- | Reshape a typed tensor to the runtime shape given as the first argument
-- and retag it with the caller-chosen type-level shape @shape'@.
--
-- Nothing here checks that the runtime list agrees with @shape'@; keeping
-- them consistent is the caller's responsibility.
unsafeReshape :: [Int] -> TT.Tensor dev dtype shape -> TT.Tensor dev dtype shape'
unsafeReshape shape = TT.UnsafeMkTensor . T.reshape shape . TT.toDynamic

-- Q net
-- =====

-- Learned Constant Embeddings
-- ---------------------------

-- | Specification for sampling a fresh 'ConstEmb'. It carries no data; the
-- device and shape of the embedding are fixed by the type parameters.
data ConstEmbSpec dev (shape :: [Nat])
  = ConstEmbSpec

-- | A learned constant embedding: one trainable tensor of type-level shape
-- @shape@ that does not depend on any input.
newtype ConstEmb dev shape = ConstEmb (TT.Parameter dev QDType shape)
  deriving stock (Show, Generic)
  deriving newtype (TT.Parameterized, NFData, NoThunks)

-- | Sample a 'ConstEmb' by drawing a tensor with 'TT.randn' and registering
-- it as an independent trainable parameter.
instance
  (IsValidDevice dev, TT.TensorOptions shape QDType dev)
  => T.Randomizable (ConstEmbSpec dev shape) (ConstEmb dev shape)
  where
  sample :: ConstEmbSpec dev shape -> IO (ConstEmb dev shape)
  sample ConstEmbSpec = do
    initial <- TT.randn
    param <- TT.makeIndependent initial
    pure (ConstEmb param)

-- | Running a 'ConstEmb' takes a unit input and returns the current value of
-- the learned tensor.
instance T.HasForward (ConstEmb dev size) () (QTensor dev size) where
  forward :: ConstEmb dev size -> () -> QTensor dev size
  forward (ConstEmb param) () = TT.toDependent param
  forwardStoch :: ConstEmb dev size -> () -> IO (QTensor dev size)
  -- no stochastic behaviour: reuse the deterministic forward pass
  forwardStoch emb = pure . T.forward emb

-- Slice Encoder
-- -------------

-- | Specification for sampling a fresh 'SliceEncoder'. It has no fields; all
-- layer sizes are fixed at the type level.
data SliceSpec dev
  = SliceSpec

-- TODO: learn embedding for empty slice

-- | Encoder turning a slice grid (type-level shape 'PShape') into an
-- embedding of shape 'EmbShape', plus learned embeddings for the start and
-- stop markers.
data SliceEncoder dev = SliceEncoder
  { _slcL1 :: !(TT.Conv2d 1 QSliceHidden 1 1 QDType dev)
  -- ^ 1x1 convolution from the single input channel to 'QSliceHidden' channels
  , _slcL2 :: !(TT.Conv2d QSliceHidden EmbSize FifthSize OctaveSize QDType dev)
  -- ^ convolution with a 'FifthSize' x 'OctaveSize' kernel producing 'EmbSize' channels
  , _slcStart :: !(ConstEmb dev EmbShape)
  -- ^ learned embedding for the start case of 'QStartStop' inputs
  , _slcStop :: !(ConstEmb dev EmbShape)
  -- ^ learned embedding for the stop case of 'QStartStop' inputs
  }
  deriving stock (Show, Generic)
  deriving anyclass (TT.Parameterized, NoThunks, NFData)

-- | Sample both convolutions and both constant embeddings with their default
-- random initialisers.
instance (IsValidDevice dev) => T.Randomizable (SliceSpec dev) (SliceEncoder dev) where
  sample :: SliceSpec dev -> IO (SliceEncoder dev)
  sample _ = do
    -- sampled in field order: l1, l2, start, stop
    l1 <- T.sample TT.Conv2dSpec
    l2 <- T.sample TT.Conv2dSpec
    start <- T.sample (ConstEmbSpec @dev)
    stop <- T.sample (ConstEmbSpec @dev)
    pure
      SliceEncoder
        { _slcL1 = l1
        , _slcL2 = l2
        , _slcStart = start
        , _slcStop = stop
        }

-- | HasForward for a single slice (unbatched).
--
-- The slice grid is given a batch axis and a channel axis (both of size 1),
-- passed through both convolutions, each followed by 'activation' exactly as
-- in the batched instance, and the batch axis is removed again.
instance
  (embshape ~ EmbShape, IsValidDevice dev)
  => T.HasForward (SliceEncoder dev) (SliceEncoding dev '[]) (QTensor dev embshape)
  where
  forward (SliceEncoder l1 l2 _ _) slice = TT.squeezeDim @0 out2
   where
    -- add batch and channel axes: [1, 1, FifthSize, OctaveSize]
    input = TT.unsqueeze @0 $ TT.unsqueeze @0 $ getSlice slice
    -- The hidden layer needs its nonlinearity too: without it the two
    -- convolutions collapse into one linear map and this path computes a
    -- different function than the batched instance.
    out1 :: QTensor dev (1 : QSliceHidden : PShape)
    out1 = activation $ TT.conv2dForward @'(1, 1) @'(0, 0) l1 input
    out2 :: QTensor dev (1 : EmbShape)
    out2 = activation $ TT.conv2dForward @'(1, 1) @'(FifthPadding, OctavePadding) l2 out1
  -- no stochastic behaviour: reuse the deterministic forward pass
  forwardStoch model = pure . T.forward model

-- | HasForward for a batch of slices.
--
-- Each slice grid gets a singleton channel axis and is passed through both
-- convolutions, each followed by 'activation'.
instance
  ( IsValidDevice dev
  , embshape ~ '[batchSize, EmbSize, FifthSize, OctaveSize]
  )
  => T.HasForward (SliceEncoder dev) (SliceEncoding dev '[batchSize]) (QTensor dev embshape)
  where
  forward (SliceEncoder {_slcL1 = hiddenConv, _slcL2 = embConv}) slice = embedded
   where
    -- insert the channel axis after the batch axis: [batchSize, 1, FifthSize, OctaveSize]
    grid = TT.unsqueeze @1 (getSlice slice)
    hidden :: QTensor dev '[batchSize, QSliceHidden, FifthSize, OctaveSize]
    hidden =
      activation (TH.conv2dForwardRelaxed @'(1, 1) @'(0, 0) hiddenConv grid)
    embedded :: QTensor dev '[batchSize, EmbSize, FifthSize, OctaveSize]
    embedded =
      activation
        (TH.conv2dForwardRelaxed @'(1, 1) @'(FifthPadding, OctavePadding) embConv hidden)
  -- no stochastic behaviour: reuse the deterministic forward pass
  forwardStoch model batch = pure (T.forward model batch)

-- | HasForward for a slice wrapped in 'QStartStop' (unbatched).
--
-- Three candidate embeddings are computed and stacked along a new leading
-- axis: index 0 is the learned @start@ embedding, index 1 is the encoded
-- inner slice, index 2 is the learned @stop@ embedding. The scalar @tag@ is
-- used as the gather index into that axis, so it selects which candidate is
-- returned. All three candidates are always computed.
instance
  (embshape ~ EmbShape, IsValidDevice dev)
  => TT.HasForward (SliceEncoder dev) (QStartStop dev '[] (SliceEncoding dev '[])) (QTensor dev embshape)
  where
  forward model@(SliceEncoder _ _ start stop) (QStartStop tag input) = TT.squeezeDim @0 out
    where
      -- compute the possible outputs for start/stop/inner
      outStart :: QTensor dev (EmbSize : PShape)
      outStart = TT.forward start ()
      outStop :: QTensor dev (EmbSize : PShape)
      outStop = TT.forward stop ()
      outInner :: QTensor dev (EmbSize : PShape)
      outInner = T.forward model input
      -- combine the outputs into one tensor (0 = start, 1 = inner, 2 = stop)
      combined :: QTensor dev (3 : EmbSize : PShape)
      combined = TT.stack @0 $ outStart TT.:. outInner TT.:. outStop TT.:. TT.HNil
      -- use gather to select the right output.
      -- gather can select different elements from 'dim' for each position,
      -- so we reshape the scalar tag to rank 4 and expand it to the output
      -- shape, selecting the *same* 'dim'-index everywhere
      tag' :: TT.Tensor dev 'TT.Int64 (1 : EmbSize : PShape)
      tag' = TT.expand False $ TT.reshape @[1, 1, 1, 1] tag
      out :: QTensor dev (1 : EmbSize : PShape)
      out = TT.gatherDim @0 tag' combined
  forwardStoch model input = pure $ T.forward model input

-- | HasForward for a batch of slices wrapped in 'QStartStop' (batched).
--
-- For every batch element, three candidate embeddings are stacked along a
-- new leading axis (0 = learned @start@ embedding, 1 = encoded inner slice,
-- 2 = learned @stop@ embedding), and that element's entry of @tag@ is used as
-- the gather index to pick one of them. All candidates are always computed.
instance
  ( IsValidDevice dev
  , embshape ~ (batchSize : EmbSize : PShape)
  )
  => TT.HasForward (SliceEncoder dev) (QStartStop dev '[batchSize] (SliceEncoding dev '[batchSize])) (QTensor dev embshape)
  where
  forward model@(SliceEncoder _ _ start stop) (QStartStop tag input) = TT.squeezeDim @0 out
    where
      -- encoded inner slices; also serves as the shape template that the
      -- unbatched constant embeddings and the tag are expanded to
      outInner :: QTensor dev (batchSize : EmbSize : PShape)
      outInner = T.forward model input
      -- evaluate an (unbatched) constant embedding and expand it along the
      -- batch axis so it has the same shape as 'outInner'
      broadcastConst :: ConstEmb dev EmbShape -> QTensor dev (batchSize : EmbSize : PShape)
      broadcastConst emb =
        TT.UnsafeMkTensor $ expandAs (TT.toDynamic $ TT.forward emb ()) $ TT.toDynamic outInner
      -- compute the possible outputs for start/stop/inner
      outStart :: QTensor dev (batchSize : EmbSize : PShape)
      outStart = broadcastConst start
      outStop :: QTensor dev (batchSize : EmbSize : PShape)
      outStop = broadcastConst stop
      -- combine the outputs into one tensor (0 = start, 1 = inner, 2 = stop)
      combined :: QTensor dev (3 : batchSize : EmbSize : PShape)
      combined = TT.stack @0 $ outStart TT.:. outInner TT.:. outStop TT.:. TT.HNil
      -- use gather to select the right output.
      -- gather can select different elements from 'dim' for each position,
      -- so we expand the tag to the right shape, selecting the *same*
      -- 'dim'-index everywhere within a batch element:
      -- [batchSize] -> [batchSize, 1, 1, 1] -> shape of outInner -> add leading axis
      tag' :: TT.Tensor dev 'TT.Int64 (1 : batchSize : EmbSize : PShape)
      tag' =
        TT.UnsafeMkTensor
          $ T.unsqueeze (T.Dim 0)
          $ expandAs (T.reshape [-1, 1, 1, 1] $ TT.toDynamic tag)
          $ TT.toDynamic outInner
      out :: QTensor dev (1 : batchSize : EmbSize : PShape)
      out = TT.gatherDim @0 tag' combined
  forwardStoch model input = pure $ T.forward model input

-- Transition Encoder
-- ------------------

-- | Specification for sampling a fresh 'TransitionEncoder' through
-- 'T.Randomizable'. It has no fields: all layer sizes are fixed by the
-- field types of 'TransitionEncoder', and the @dev@ parameter only ties the
-- spec to the encoder type it produces.
data TransitionSpec dev
  = TransitionSpec

-- | Parameters of the transition encoder, which maps a 'TransitionEncoding'
-- to an embedding of shape @EmbSize : PShape@.
--
-- Layer 1 has one sub-network per component of the transition; their
-- outputs (hidden size 'QTransHidden') are combined and passed through the
-- single layer-2 convolution 'trL2'.
data TransitionEncoder dev = TransitionEncoder
  { trL1Passing :: !(TT.Conv2d 2 QTransHidden FifthSize OctaveSize QDType dev)
  -- ^ layer-1 convolution for the passing-edge inputs (2 input channels)
  , trL1Inner :: !(TT.Conv2d 2 QTransHidden FifthSize OctaveSize QDType dev)
  -- ^ layer-1 convolution for the inner-edge inputs (2 input channels)
  , trL1Left :: !(TT.Conv2d 1 QTransHidden 1 1 QDType dev)
  -- ^ layer-1 1x1 convolution, presumably for the left slice input
  , trL1Right :: !(TT.Conv2d 1 QTransHidden 1 1 QDType dev)
  -- ^ layer-1 1x1 convolution, presumably for the right slice input
  , trL1Root :: !(ConstEmb dev '[QTransHidden])
  -- ^ learned constant hidden vector, presumably used for the root input
  , trL2 :: !(TT.Conv2d QTransHidden EmbSize FifthSize OctaveSize QDType dev)
  -- ^ layer-2 convolution producing the 'EmbSize'-channel output
  }
  deriving stock (Show, Generic)
  deriving anyclass (TT.Parameterized, NoThunks, NFData)

-- | Sample a randomly initialised 'TransitionEncoder'.
--
-- Each layer is sampled from its typed spec; the concrete sizes are inferred
-- from the 'TransitionEncoder' field types. The 'TransitionSpec' argument
-- carries no data and is ignored. Layers are sampled in field order.
instance (IsValidDevice dev) => T.Randomizable (TransitionSpec dev) (TransitionEncoder dev) where
  sample :: TransitionSpec dev -> IO (TransitionEncoder dev)
  sample _ = do
    trL1Passing <- T.sample TT.Conv2dSpec
    trL1Inner <- T.sample TT.Conv2dSpec
    trL1Left <- T.sample TT.Conv2dSpec
    trL1Right <- T.sample TT.Conv2dSpec
    trL1Root <- T.sample $ ConstEmbSpec @dev
    trL2 <- T.sample TT.Conv2dSpec
    pure TransitionEncoder{..}

-- | HasForward for transitions (unbatched)
instance
  forall dev embshape
   . ( IsValidDevice dev
     , embshape ~ (EmbSize : PShape)
     )
  => T.HasForward (TransitionEncoder dev) (TransitionEncoding dev '[]) (QTensor dev embshape)
  where
  forward :: TransitionEncoder dev
-> TransitionEncoding dev '[] -> QTensor dev embshape
forward TransitionEncoder{Conv2d 1 QSliceHidden 1 1 QDType dev
Conv2d 2 QSliceHidden FifthSize OctaveSize QDType dev
Conv2d QSliceHidden QSliceHidden FifthSize OctaveSize QDType dev
ConstEmb dev '[QSliceHidden]
trL1Passing :: forall (dev :: (DeviceType, Nat)).
TransitionEncoder dev
-> Conv2d 2 QSliceHidden FifthSize OctaveSize QDType dev
trL1Inner :: forall (dev :: (DeviceType, Nat)).
TransitionEncoder dev
-> Conv2d 2 QSliceHidden FifthSize OctaveSize QDType dev
trL1Left :: forall (dev :: (DeviceType, Nat)).
TransitionEncoder dev -> Conv2d 1 QSliceHidden 1 1 QDType dev
trL1Right :: forall (dev :: (DeviceType, Nat)).
TransitionEncoder dev -> Conv2d 1 QSliceHidden 1 1 QDType dev
trL1Root :: forall (dev :: (DeviceType, Nat)).
TransitionEncoder dev -> ConstEmb dev '[QSliceHidden]
trL2 :: forall (dev :: (DeviceType, Nat)).
TransitionEncoder dev
-> Conv2d QSliceHidden QSliceHidden FifthSize OctaveSize QDType dev
trL1Passing :: Conv2d 2 QSliceHidden FifthSize OctaveSize QDType dev
trL1Inner :: Conv2d 2 QSliceHidden FifthSize OctaveSize QDType dev
trL1Left :: Conv2d 1 QSliceHidden 1 1 QDType dev
trL1Right :: Conv2d 1 QSliceHidden 1 1 QDType dev
trL1Root :: ConstEmb dev '[QSliceHidden]
trL2 :: Conv2d QSliceHidden QSliceHidden FifthSize OctaveSize QDType dev
..} TransitionEncoding{QTensor dev '[]
SliceEncoding dev '[]
QBoundedList dev QDType QSliceHidden '[] (2 : PShape)
trencPassing :: QBoundedList dev QDType QSliceHidden '[] (2 : PShape)
trencInner :: QBoundedList dev QDType QSliceHidden '[] (2 : PShape)
trencLeft :: SliceEncoding dev '[]
trencRight :: SliceEncoding dev '[]
trencRoot :: QTensor dev '[]
trencRoot :: forall (dev :: (DeviceType, Nat)) (batchShape :: [Nat]).
TransitionEncoding dev batchShape -> QTensor dev batchShape
trencRight :: forall (dev :: (DeviceType, Nat)) (batchShape :: [Nat]).
TransitionEncoding dev batchShape -> SliceEncoding dev batchShape
trencLeft :: forall (dev :: (DeviceType, Nat)) (batchShape :: [Nat]).
TransitionEncoding dev batchShape -> SliceEncoding dev batchShape
trencInner :: forall (dev :: (DeviceType, Nat)) (batchShape :: [Nat]).
TransitionEncoding dev batchShape
-> QBoundedList dev QDType QSliceHidden batchShape (2 : PShape)
trencPassing :: forall (dev :: (DeviceType, Nat)) (batchShape :: [Nat]).
TransitionEncoding dev batchShape
-> QBoundedList dev QDType QSliceHidden batchShape (2 : PShape)
..} =
    forall (dim :: Nat) (shape :: [Nat]) (shape' :: [Nat])
       (dtype :: DType) (device :: (DeviceType, Nat)).
(KnownNat dim, shape' ~ SqueezeDim shape dim) =>
Tensor device dtype shape -> Tensor device dtype shape'
TT.squeezeDim @0 (Tensor
   dev
   QDType
   '[1, QSliceHidden, ((13 + 12) - 13) + 1, ((5 + 4) - 5) + 1]
 -> QTensor dev embshape)
-> Tensor
     dev
     QDType
     '[1, QSliceHidden, ((13 + 12) - 13) + 1, ((5 + 4) - 5) + 1]
-> QTensor dev embshape
forall a b. (a -> b) -> a -> b
$
      Tensor
  dev
  QDType
  '[1, QSliceHidden, ((13 + 12) - 13) + 1, ((5 + 4) - 5) + 1]
-> Tensor
     dev
     QDType
     '[1, QSliceHidden, ((13 + 12) - 13) + 1, ((5 + 4) - 5) + 1]
forall (dev :: (DeviceType, Nat)) (shape :: [Nat]).
IsValidDevice dev =>
QTensor dev shape -> QTensor dev shape
activation (Tensor
   dev
   QDType
   '[1, QSliceHidden, ((13 + 12) - 13) + 1, ((5 + 4) - 5) + 1]
 -> Tensor
      dev
      QDType
      '[1, QSliceHidden, ((13 + 12) - 13) + 1, ((5 + 4) - 5) + 1])
-> Tensor
     dev
     QDType
     '[1, QSliceHidden, ((13 + 12) - 13) + 1, ((5 + 4) - 5) + 1]
-> Tensor
     dev
     QDType
     '[1, QSliceHidden, ((13 + 12) - 13) + 1, ((5 + 4) - 5) + 1]
forall a b. (a -> b) -> a -> b
$
        forall (stride :: (Nat, Nat)) (padding :: (Nat, Nat))
       {kernelSize1 :: Nat} {inputSize1 :: Nat} {kernelSize0 :: Nat}
       {inputSize0 :: Nat} {inputChannelSize :: Nat}
       {outputChannelSize :: Nat} {batchSize :: Nat} {w1 :: DType}
       {w2 :: (DeviceType, Nat)}.
(Assert
   (OrdCond
      (CmpNat kernelSize1 (inputSize1 + (2 * Snd padding)))
      'True
      'True
      'False)
   (TypeError ...)
 ~ (() :: Constraint),
 Assert
   (OrdCond
      (CmpNat kernelSize0 (inputSize0 + (2 * Fst padding)))
      'True
      'True
      'False)
   (TypeError ...)
 ~ (() :: Constraint),
 Assert
   (OrdCond (CmpNat 1 (Snd stride)) 'True 'True 'False)
   (TypeError ...)
 ~ (() :: Constraint),
 Assert
   (OrdCond (CmpNat 1 (Fst stride)) 'True 'True 'False)
   (TypeError ...)
 ~ (() :: Constraint),
 Assert
   (OrdCond (CmpNat 1 kernelSize0) 'True 'True 'False)
   (TypeError ...),
 Assert
   (OrdCond
      (CmpNat (kernelSize0 - 1) (inputSize0 + (2 * Fst padding)))
      'True
      'True
      'False)
   (TypeError ...),
 Assert
   (OrdCond (CmpNat 1 kernelSize1) 'True 'True 'False)
   (TypeError ...),
 Assert
   (OrdCond
      (CmpNat (kernelSize1 - 1) (inputSize1 + (2 * Snd padding)))
      'True
      'True
      'False)
   (TypeError ...),
 KnownNat inputChannelSize, KnownNat outputChannelSize,
 KnownNat kernelSize0, KnownNat kernelSize1, KnownNat inputSize0,
 KnownNat inputSize1, KnownNat batchSize, KnownNat (Fst stride),
 KnownNat (Fst padding), KnownNat (Snd stride),
 KnownNat (Snd padding)) =>
Conv2d
  inputChannelSize outputChannelSize kernelSize0 kernelSize1 w1 w2
-> Tensor
     w2 w1 '[batchSize, inputChannelSize, inputSize0, inputSize1]
-> Tensor
     w2
     w1
     '[batchSize, outputChannelSize,
       Div ((inputSize0 + (2 * Fst padding)) - kernelSize0) (Fst stride)
       + 1,
       Div ((inputSize1 + (2 * Snd padding)) - kernelSize1) (Snd stride)
       + 1]
TT.conv2dForward @'(1, 1) @'(FifthPadding, OctavePadding) Conv2d QSliceHidden QSliceHidden 13 5 QDType dev
Conv2d QSliceHidden QSliceHidden FifthSize OctaveSize QDType dev
trL2 (Tensor dev QDType '[1, QSliceHidden, 13, 5]
 -> Tensor
      dev
      QDType
      '[1, QSliceHidden,
        Div ((13 + (2 * Fst '(FifthPadding, 2))) - 13) (Fst '(1, 1)) + 1,
        Div ((5 + (2 * Snd '(FifthPadding, 2))) - 5) (Snd '(1, 1)) + 1])
-> Tensor dev QDType '[1, QSliceHidden, 13, 5]
-> Tensor
     dev
     QDType
     '[1, QSliceHidden,
       Div ((13 + (2 * Fst '(FifthPadding, 2))) - 13) (Fst '(1, 1)) + 1,
       Div ((5 + (2 * Snd '(FifthPadding, 2))) - 5) (Snd '(1, 1)) + 1]
forall a b. (a -> b) -> a -> b
$
          forall (dim :: Nat) (shape :: [Nat]) (shape' :: [Nat])
       (dtype :: DType) (device :: (DeviceType, Nat)).
(KnownNat dim, shape' ~ Unsqueeze shape dim) =>
Tensor device dtype shape -> Tensor device dtype shape'
TT.unsqueeze @0 Tensor dev QDType '[QSliceHidden, 13, 5]
QTensor dev EmbShape
all
   where
    runConv
      :: (KnownNat nin)
      => TT.Conv2d nin QTransHidden FifthSize OctaveSize QDType dev
      -> QBoundedList dev QDType MaxEdges '[] (nin : PShape)
      -> QTensor dev (QTransHidden : PShape)
    runConv :: forall (nin :: Nat).
KnownNat nin =>
Conv2d nin QSliceHidden FifthSize OctaveSize QDType dev
-> QBoundedList dev QDType QSliceHidden '[] (nin : PShape)
-> QTensor dev EmbShape
runConv Conv2d nin QSliceHidden FifthSize OctaveSize QDType dev
conv (QBoundedList QTensor dev ('[] ++ '[QSliceHidden])
mask Tensor dev QDType (('[] ++ '[QSliceHidden]) ++ (nin : PShape))
edges) = forall (d :: Nat) (shape :: [Nat]) (shape' :: [Nat])
       (dtype :: DType) (dtype' :: DType) (device :: (DeviceType, Nat)).
(KnownNat d, shape' ~ DropValue shape d,
 SumDTypeIsValid device dtype, dtype' ~ SumDType dtype) =>
Tensor device dtype shape -> Tensor device dtype' shape'
TT.sumDim @0 (Tensor
   dev
   QDType
   (CheckBroadcast
      '[QSliceHidden, 1, 1, 1]
      '[QSliceHidden, QSliceHidden, 13, 5]
      (ComputeBroadcast
         '[1, 1, 1, QSliceHidden]
         (ReverseImpl '[QSliceHidden, QSliceHidden, 13, 5] '[])))
 -> QTensor dev EmbShape)
-> Tensor
     dev
     QDType
     (CheckBroadcast
        '[QSliceHidden, 1, 1, 1]
        '[QSliceHidden, QSliceHidden, 13, 5]
        (ComputeBroadcast
           '[1, 1, 1, QSliceHidden]
           (ReverseImpl '[QSliceHidden, QSliceHidden, 13, 5] '[])))
-> QTensor dev EmbShape
forall a b. (a -> b) -> a -> b
$ Tensor dev QDType '[QSliceHidden, 1, 1, 1]
-> Tensor dev QDType '[QSliceHidden, QSliceHidden, 13, 5]
-> Tensor
     dev
     QDType
     (CheckBroadcast
        '[QSliceHidden, 1, 1, 1]
        '[QSliceHidden, QSliceHidden, 13, 5]
        (ComputeBroadcast
           '[1, 1, 1, QSliceHidden]
           (ReverseImpl '[QSliceHidden, QSliceHidden, 13, 5] '[])))
forall (shape'' :: [Nat]) (shape :: [Nat]) (shape' :: [Nat])
       (dtype :: DType) (dtype' :: DType) (dtype'' :: DType)
       (device :: (DeviceType, Nat)).
(dtype'' ~ DTypePromotion dtype dtype',
 shape'' ~ Broadcast shape shape',
 BasicArithmeticDTypeIsValid device dtype,
 BasicArithmeticDTypeIsValid device dtype',
 BasicArithmeticDTypeIsValid device dtype'') =>
Tensor device dtype shape
-> Tensor device dtype' shape' -> Tensor device dtype'' shape''
TT.mul Tensor dev QDType '[QSliceHidden, 1, 1, 1]
mask' Tensor dev QDType '[QSliceHidden, QSliceHidden, 13, 5]
QTensor dev (QSliceHidden : EmbShape)
out
     where
      out :: QTensor dev (MaxEdges : QTransHidden : PShape)
      out :: QTensor dev (QSliceHidden : EmbShape)
out = QTensor dev (QSliceHidden : EmbShape)
-> QTensor dev (QSliceHidden : EmbShape)
forall (dev :: (DeviceType, Nat)) (shape :: [Nat]).
IsValidDevice dev =>
QTensor dev shape -> QTensor dev shape
activation (QTensor dev (QSliceHidden : EmbShape)
 -> QTensor dev (QSliceHidden : EmbShape))
-> QTensor dev (QSliceHidden : EmbShape)
-> QTensor dev (QSliceHidden : EmbShape)
forall a b. (a -> b) -> a -> b
$ forall (stride :: (Nat, Nat)) (padding :: (Nat, Nat))
       {kernelSize1 :: Nat} {inputSize1 :: Nat} {kernelSize0 :: Nat}
       {inputSize0 :: Nat} {inputChannelSize :: Nat}
       {outputChannelSize :: Nat} {batchSize :: Nat} {w1 :: DType}
       {w2 :: (DeviceType, Nat)}.
(Assert
   (OrdCond
      (CmpNat kernelSize1 (inputSize1 + (2 * Snd padding)))
      'True
      'True
      'False)
   (TypeError ...)
 ~ (() :: Constraint),
 Assert
   (OrdCond
      (CmpNat kernelSize0 (inputSize0 + (2 * Fst padding)))
      'True
      'True
      'False)
   (TypeError ...)
 ~ (() :: Constraint),
 Assert
   (OrdCond (CmpNat 1 (Snd stride)) 'True 'True 'False)
   (TypeError ...)
 ~ (() :: Constraint),
 Assert
   (OrdCond (CmpNat 1 (Fst stride)) 'True 'True 'False)
   (TypeError ...)
 ~ (() :: Constraint),
 Assert
   (OrdCond (CmpNat 1 kernelSize0) 'True 'True 'False)
   (TypeError ...),
 Assert
   (OrdCond
      (CmpNat (kernelSize0 - 1) (inputSize0 + (2 * Fst padding)))
      'True
      'True
      'False)
   (TypeError ...),
 Assert
   (OrdCond (CmpNat 1 kernelSize1) 'True 'True 'False)
   (TypeError ...),
 Assert
   (OrdCond
      (CmpNat (kernelSize1 - 1) (inputSize1 + (2 * Snd padding)))
      'True
      'True
      'False)
   (TypeError ...),
 KnownNat inputChannelSize, KnownNat outputChannelSize,
 KnownNat kernelSize0, KnownNat kernelSize1, KnownNat inputSize0,
 KnownNat inputSize1, KnownNat batchSize, KnownNat (Fst stride),
 KnownNat (Fst padding), KnownNat (Snd stride),
 KnownNat (Snd padding)) =>
Conv2d
  inputChannelSize outputChannelSize kernelSize0 kernelSize1 w1 w2
-> Tensor
     w2 w1 '[batchSize, inputChannelSize, inputSize0, inputSize1]
-> Tensor
     w2
     w1
     '[batchSize, outputChannelSize,
       Div ((inputSize0 + (2 * Fst padding)) - kernelSize0) (Fst stride)
       + 1,
       Div ((inputSize1 + (2 * Snd padding)) - kernelSize1) (Snd stride)
       + 1]
TT.conv2dForward @'(1, 1) @'(FifthPadding, OctavePadding) Conv2d nin QSliceHidden 13 5 QDType dev
Conv2d nin QSliceHidden FifthSize OctaveSize QDType dev
conv Tensor dev QDType '[QSliceHidden, nin, 13, 5]
Tensor dev QDType (('[] ++ '[QSliceHidden]) ++ (nin : PShape))
edges
      mask' :: QTensor dev '[MaxEdges, 1, 1, 1]
      mask' :: Tensor dev QDType '[QSliceHidden, 1, 1, 1]
mask' = QTensor dev '[QSliceHidden]
-> Tensor dev QDType '[QSliceHidden, 1, 1, 1]
forall (shape' :: [Nat]) (shape :: [Nat]) (dtype :: DType)
       (device :: (DeviceType, Nat)).
(KnownShape shape', Numel shape ~ Numel shape') =>
Tensor device dtype shape -> Tensor device dtype shape'
TT.reshape QTensor dev '[QSliceHidden]
QTensor dev ('[] ++ '[QSliceHidden])
mask
    runSlice :: Conv2d 1 outputChannelSize kernelSize0 kernelSize1 dtype device
-> Tensor device dtype '[inputSize0, inputSize1]
-> Tensor
     device
     dtype
     (SqueezeDimCheck
        '[1, outputChannelSize, (inputSize0 - kernelSize0) + 1,
          (inputSize1 - kernelSize1) + 1]
        0
        (SqueezeDimImpl
           '[1, outputChannelSize, (inputSize0 - kernelSize0) + 1,
             (inputSize1 - kernelSize1) + 1]
           0))
runSlice Conv2d 1 outputChannelSize kernelSize0 kernelSize1 dtype device
conv Tensor device dtype '[inputSize0, inputSize1]
slice = forall (dim :: Nat) (shape :: [Nat]) (shape' :: [Nat])
       (dtype :: DType) (device :: (DeviceType, Nat)).
(KnownNat dim, shape' ~ SqueezeDim shape dim) =>
Tensor device dtype shape -> Tensor device dtype shape'
TT.squeezeDim @0 (Tensor
   device
   dtype
   '[1, outputChannelSize, (inputSize0 - kernelSize0) + 1,
     (inputSize1 - kernelSize1) + 1]
 -> Tensor
      device
      dtype
      (SqueezeDimCheck
         '[1, outputChannelSize, (inputSize0 - kernelSize0) + 1,
           (inputSize1 - kernelSize1) + 1]
         0
         (SqueezeDimImpl
            '[1, outputChannelSize, (inputSize0 - kernelSize0) + 1,
              (inputSize1 - kernelSize1) + 1]
            0)))
-> Tensor
     device
     dtype
     '[1, outputChannelSize, (inputSize0 - kernelSize0) + 1,
       (inputSize1 - kernelSize1) + 1]
-> Tensor
     device
     dtype
     (SqueezeDimCheck
        '[1, outputChannelSize, (inputSize0 - kernelSize0) + 1,
          (inputSize1 - kernelSize1) + 1]
        0
        (SqueezeDimImpl
           '[1, outputChannelSize, (inputSize0 - kernelSize0) + 1,
             (inputSize1 - kernelSize1) + 1]
           0))
forall a b. (a -> b) -> a -> b
$ forall (stride :: (Nat, Nat)) (padding :: (Nat, Nat))
       {kernelSize1 :: Nat} {inputSize1 :: Nat} {kernelSize0 :: Nat}
       {inputSize0 :: Nat} {inputChannelSize :: Nat}
       {outputChannelSize :: Nat} {batchSize :: Nat} {w1 :: DType}
       {w2 :: (DeviceType, Nat)}.
(Assert
   (OrdCond
      (CmpNat kernelSize1 (inputSize1 + (2 * Snd padding)))
      'True
      'True
      'False)
   (TypeError ...)
 ~ (() :: Constraint),
 Assert
   (OrdCond
      (CmpNat kernelSize0 (inputSize0 + (2 * Fst padding)))
      'True
      'True
      'False)
   (TypeError ...)
 ~ (() :: Constraint),
 Assert
   (OrdCond (CmpNat 1 (Snd stride)) 'True 'True 'False)
   (TypeError ...)
 ~ (() :: Constraint),
 Assert
   (OrdCond (CmpNat 1 (Fst stride)) 'True 'True 'False)
   (TypeError ...)
 ~ (() :: Constraint),
 Assert
   (OrdCond (CmpNat 1 kernelSize0) 'True 'True 'False)
   (TypeError ...),
 Assert
   (OrdCond
      (CmpNat (kernelSize0 - 1) (inputSize0 + (2 * Fst padding)))
      'True
      'True
      'False)
   (TypeError ...),
 Assert
   (OrdCond (CmpNat 1 kernelSize1) 'True 'True 'False)
   (TypeError ...),
 Assert
   (OrdCond
      (CmpNat (kernelSize1 - 1) (inputSize1 + (2 * Snd padding)))
      'True
      'True
      'False)
   (TypeError ...),
 KnownNat inputChannelSize, KnownNat outputChannelSize,
 KnownNat kernelSize0, KnownNat kernelSize1, KnownNat inputSize0,
 KnownNat inputSize1, KnownNat batchSize, KnownNat (Fst stride),
 KnownNat (Fst padding), KnownNat (Snd stride),
 KnownNat (Snd padding)) =>
Conv2d
  inputChannelSize outputChannelSize kernelSize0 kernelSize1 w1 w2
-> Tensor
     w2 w1 '[batchSize, inputChannelSize, inputSize0, inputSize1]
-> Tensor
     w2
     w1
     '[batchSize, outputChannelSize,
       Div ((inputSize0 + (2 * Fst padding)) - kernelSize0) (Fst stride)
       + 1,
       Div ((inputSize1 + (2 * Snd padding)) - kernelSize1) (Snd stride)
       + 1]
TT.conv2dForward @'(1, 1) @'(0, 0) Conv2d 1 outputChannelSize kernelSize0 kernelSize1 dtype device
conv Tensor device dtype '[1, 1, inputSize0, inputSize1]
input
     where
      input :: Tensor device dtype '[1, 1, inputSize0, inputSize1]
input = forall (dim :: Nat) (shape :: [Nat]) (shape' :: [Nat])
       (dtype :: DType) (device :: (DeviceType, Nat)).
(KnownNat dim, shape' ~ Unsqueeze shape dim) =>
Tensor device dtype shape -> Tensor device dtype shape'
TT.unsqueeze @0 (Tensor device dtype '[1, inputSize0, inputSize1]
 -> Tensor device dtype '[1, 1, inputSize0, inputSize1])
-> Tensor device dtype '[1, inputSize0, inputSize1]
-> Tensor device dtype '[1, 1, inputSize0, inputSize1]
forall a b. (a -> b) -> a -> b
$ forall (dim :: Nat) (shape :: [Nat]) (shape' :: [Nat])
       (dtype :: DType) (device :: (DeviceType, Nat)).
(KnownNat dim, shape' ~ Unsqueeze shape dim) =>
Tensor device dtype shape -> Tensor device dtype shape'
TT.unsqueeze @0 Tensor device dtype '[inputSize0, inputSize1]
slice
    pass :: QTensor dev (QTransHidden : PShape)
    pass :: QTensor dev EmbShape
pass = Conv2d 2 QSliceHidden FifthSize OctaveSize QDType dev
-> QBoundedList dev QDType QSliceHidden '[] (2 : PShape)
-> QTensor dev EmbShape
forall (nin :: Nat).
KnownNat nin =>
Conv2d nin QSliceHidden FifthSize OctaveSize QDType dev
-> QBoundedList dev QDType QSliceHidden '[] (nin : PShape)
-> QTensor dev EmbShape
runConv Conv2d 2 QSliceHidden FifthSize OctaveSize QDType dev
trL1Passing QBoundedList dev QDType QSliceHidden '[] (2 : PShape)
trencPassing
    inner :: QTensor dev (QTransHidden : PShape)
    inner :: QTensor dev EmbShape
inner = Conv2d 2 QSliceHidden FifthSize OctaveSize QDType dev
-> QBoundedList dev QDType QSliceHidden '[] (2 : PShape)
-> QTensor dev EmbShape
forall (nin :: Nat).
KnownNat nin =>
Conv2d nin QSliceHidden FifthSize OctaveSize QDType dev
-> QBoundedList dev QDType QSliceHidden '[] (nin : PShape)
-> QTensor dev EmbShape
runConv Conv2d 2 QSliceHidden FifthSize OctaveSize QDType dev
trL1Inner QBoundedList dev QDType QSliceHidden '[] (2 : PShape)
trencInner
    left :: QTensor dev (QTransHidden : PShape)
    left :: QTensor dev EmbShape
left = Conv2d 1 QSliceHidden 1 1 QDType dev
-> Tensor dev QDType '[13, 5]
-> Tensor dev QDType '[QSliceHidden, (13 - 1) + 1, (5 - 1) + 1]
forall {kernelSize0 :: Nat} {inputSize0 :: Nat}
       {kernelSize1 :: Nat} {inputSize1 :: Nat} {outputChannelSize :: Nat}
       {dtype :: DType} {device :: (DeviceType, Nat)}.
(Assert
   (OrdCond (CmpNat kernelSize0 inputSize0) 'True 'True 'False)
   (TypeError ...)
 ~ (() :: Constraint),
 Assert
   (OrdCond (CmpNat kernelSize1 inputSize1) 'True 'True 'False)
   (TypeError ...)
 ~ (() :: Constraint),
 Assert
   (OrdCond (CmpNat 1 kernelSize0) 'True 'True 'False)
   (TypeError ...),
 Assert
   (OrdCond (CmpNat (kernelSize0 - 1) inputSize0) 'True 'True 'False)
   (TypeError ...),
 Assert
   (OrdCond (CmpNat 1 kernelSize1) 'True 'True 'False)
   (TypeError ...),
 Assert
   (OrdCond (CmpNat (kernelSize1 - 1) inputSize1) 'True 'True 'False)
   (TypeError ...),
 KnownNat outputChannelSize, KnownNat kernelSize0,
 KnownNat kernelSize1, KnownNat inputSize0, KnownNat inputSize1) =>
Conv2d 1 outputChannelSize kernelSize0 kernelSize1 dtype device
-> Tensor device dtype '[inputSize0, inputSize1]
-> Tensor
     device
     dtype
     '[outputChannelSize, (inputSize0 - kernelSize0) + 1,
       (inputSize1 - kernelSize1) + 1]
runSlice Conv2d 1 QSliceHidden 1 1 QDType dev
trL1Left (Tensor dev QDType '[13, 5]
 -> Tensor dev QDType '[QSliceHidden, (13 - 1) + 1, (5 - 1) + 1])
-> Tensor dev QDType '[13, 5]
-> Tensor dev QDType '[QSliceHidden, (13 - 1) + 1, (5 - 1) + 1]
forall a b. (a -> b) -> a -> b
$ SliceEncoding dev '[] -> QTensor dev ('[] ++ PShape)
forall (dev :: (DeviceType, Nat)) (batchShape :: [Nat]).
SliceEncoding dev batchShape -> QTensor dev (batchShape ++ PShape)
getSlice SliceEncoding dev '[]
trencLeft
    right :: QTensor dev (QTransHidden : PShape)
    right :: QTensor dev EmbShape
right = Conv2d 1 QSliceHidden 1 1 QDType dev
-> Tensor dev QDType '[13, 5]
-> Tensor dev QDType '[QSliceHidden, (13 - 1) + 1, (5 - 1) + 1]
forall {kernelSize0 :: Nat} {inputSize0 :: Nat}
       {kernelSize1 :: Nat} {inputSize1 :: Nat} {outputChannelSize :: Nat}
       {dtype :: DType} {device :: (DeviceType, Nat)}.
(Assert
   (OrdCond (CmpNat kernelSize0 inputSize0) 'True 'True 'False)
   (TypeError ...)
 ~ (() :: Constraint),
 Assert
   (OrdCond (CmpNat kernelSize1 inputSize1) 'True 'True 'False)
   (TypeError ...)
 ~ (() :: Constraint),
 Assert
   (OrdCond (CmpNat 1 kernelSize0) 'True 'True 'False)
   (TypeError ...),
 Assert
   (OrdCond (CmpNat (kernelSize0 - 1) inputSize0) 'True 'True 'False)
   (TypeError ...),
 Assert
   (OrdCond (CmpNat 1 kernelSize1) 'True 'True 'False)
   (TypeError ...),
 Assert
   (OrdCond (CmpNat (kernelSize1 - 1) inputSize1) 'True 'True 'False)
   (TypeError ...),
 KnownNat outputChannelSize, KnownNat kernelSize0,
 KnownNat kernelSize1, KnownNat inputSize0, KnownNat inputSize1) =>
Conv2d 1 outputChannelSize kernelSize0 kernelSize1 dtype device
-> Tensor device dtype '[inputSize0, inputSize1]
-> Tensor
     device
     dtype
     '[outputChannelSize, (inputSize0 - kernelSize0) + 1,
       (inputSize1 - kernelSize1) + 1]
runSlice Conv2d 1 QSliceHidden 1 1 QDType dev
trL1Right (Tensor dev QDType '[13, 5]
 -> Tensor dev QDType '[QSliceHidden, (13 - 1) + 1, (5 - 1) + 1])
-> Tensor dev QDType '[13, 5]
-> Tensor dev QDType '[QSliceHidden, (13 - 1) + 1, (5 - 1) + 1]
forall a b. (a -> b) -> a -> b
$ SliceEncoding dev '[] -> QTensor dev ('[] ++ PShape)
forall (dev :: (DeviceType, Nat)) (batchShape :: [Nat]).
SliceEncoding dev batchShape -> QTensor dev (batchShape ++ PShape)
getSlice SliceEncoding dev '[]
trencRight
    root :: QTensor dev '[QTransHidden, 1, 1]
    root :: QTensor dev '[QSliceHidden, 1, 1]
root = Tensor
  dev QDType (ReverseImpl (ReverseImpl '[QSliceHidden] '[]) '[])
-> QTensor dev '[QSliceHidden, 1, 1]
forall (shape' :: [Nat]) (shape :: [Nat]) (dtype :: DType)
       (device :: (DeviceType, Nat)).
(KnownShape shape', Numel shape ~ Numel shape') =>
Tensor device dtype shape -> Tensor device dtype shape'
TT.reshape (Tensor
   dev QDType (ReverseImpl (ReverseImpl '[QSliceHidden] '[]) '[])
 -> QTensor dev '[QSliceHidden, 1, 1])
-> Tensor
     dev QDType (ReverseImpl (ReverseImpl '[QSliceHidden] '[]) '[])
-> QTensor dev '[QSliceHidden, 1, 1]
forall a b. (a -> b) -> a -> b
$ QTensor dev '[]
-> QTensor dev '[QSliceHidden]
-> Tensor
     dev QDType (ReverseImpl (ReverseImpl '[QSliceHidden] '[]) '[])
forall (shape'' :: [Nat]) (shape :: [Nat]) (shape' :: [Nat])
       (dtype :: DType) (dtype' :: DType) (dtype'' :: DType)
       (device :: (DeviceType, Nat)).
(dtype'' ~ DTypePromotion dtype dtype',
 shape'' ~ Broadcast shape shape',
 BasicArithmeticDTypeIsValid device dtype,
 BasicArithmeticDTypeIsValid device dtype',
 BasicArithmeticDTypeIsValid device dtype'') =>
Tensor device dtype shape
-> Tensor device dtype' shape' -> Tensor device dtype'' shape''
TT.mul QTensor dev '[]
trencRoot (QTensor dev '[QSliceHidden] -> QTensor dev '[QSliceHidden]
forall (dev :: (DeviceType, Nat)) (shape :: [Nat]).
IsValidDevice dev =>
QTensor dev shape -> QTensor dev shape
activation (ConstEmb dev '[QSliceHidden] -> () -> QTensor dev '[QSliceHidden]
forall f a b. HasForward f a b => f -> a -> b
T.forward ConstEmb dev '[QSliceHidden]
trL1Root ()))
    all :: QTensor dev (QTransHidden : PShape)
    all :: QTensor dev EmbShape
all = (Tensor dev QDType '[QSliceHidden, 13, 5]
QTensor dev EmbShape
pass Tensor dev QDType '[QSliceHidden, 13, 5]
-> Tensor dev QDType '[QSliceHidden, 13, 5]
-> Tensor dev QDType '[QSliceHidden, 13, 5]
forall a. Num a => a -> a -> a
+ Tensor dev QDType '[QSliceHidden, 13, 5]
QTensor dev EmbShape
inner Tensor dev QDType '[QSliceHidden, 13, 5]
-> Tensor dev QDType '[QSliceHidden, 13, 5]
-> Tensor dev QDType '[QSliceHidden, 13, 5]
forall a. Num a => a -> a -> a
+ Tensor dev QDType '[QSliceHidden, 13, 5]
QTensor dev EmbShape
left Tensor dev QDType '[QSliceHidden, 13, 5]
-> Tensor dev QDType '[QSliceHidden, 13, 5]
-> Tensor dev QDType '[QSliceHidden, 13, 5]
forall a. Num a => a -> a -> a
+ Tensor dev QDType '[QSliceHidden, 13, 5]
QTensor dev EmbShape
right) Tensor dev QDType '[QSliceHidden, 13, 5]
-> QTensor dev '[QSliceHidden, 1, 1]
-> Tensor
     dev
     QDType
     (CheckBroadcast
        '[QSliceHidden, 13, 5]
        '[QSliceHidden, 1, 1]
        (ComputeBroadcast
           (ReverseImpl '[QSliceHidden, 13, 5] '[]) '[1, 1, QSliceHidden]))
forall (shape'' :: [Nat]) (shape :: [Nat]) (shape' :: [Nat])
       (dtype :: DType) (dtype' :: DType) (dtype'' :: DType)
       (device :: (DeviceType, Nat)).
(dtype'' ~ DTypePromotion dtype dtype',
 shape'' ~ Broadcast shape shape',
 BasicArithmeticDTypeIsValid device dtype,
 BasicArithmeticDTypeIsValid device dtype',
 BasicArithmeticDTypeIsValid device dtype'') =>
Tensor device dtype shape
-> Tensor device dtype' shape' -> Tensor device dtype'' shape''
`TT.add` QTensor dev '[QSliceHidden, 1, 1]
root

  forwardStoch :: TransitionEncoder dev
-> TransitionEncoding dev '[] -> IO (QTensor dev embshape)
forwardStoch TransitionEncoder dev
tr TransitionEncoding dev '[]
input = QTensor dev embshape -> IO (QTensor dev embshape)
forall a. a -> IO a
forall (f :: Type -> Type) a. Applicative f => a -> f a
pure (QTensor dev embshape -> IO (QTensor dev embshape))
-> QTensor dev embshape -> IO (QTensor dev embshape)
forall a b. (a -> b) -> a -> b
$ TransitionEncoder dev
-> TransitionEncoding dev '[] -> QTensor dev embshape
forall f a b. HasForward f a b => f -> a -> b
T.forward TransitionEncoder dev
tr TransitionEncoding dev '[]
input

-- | HasForward for transitions (batched)
instance
  forall dev batchSize embshape
   . ( IsValidDevice dev
     , embshape ~ (batchSize : EmbSize : PShape)
     )
  => T.HasForward (TransitionEncoder dev) (TransitionEncoding dev '[batchSize]) (QTensor dev embshape)
  where
  forward :: TransitionEncoder dev
-> TransitionEncoding dev '[batchSize] -> QTensor dev embshape
forward TransitionEncoder{Conv2d 1 QSliceHidden 1 1 QDType dev
Conv2d 2 QSliceHidden FifthSize OctaveSize QDType dev
Conv2d QSliceHidden QSliceHidden FifthSize OctaveSize QDType dev
ConstEmb dev '[QSliceHidden]
trL1Passing :: forall (dev :: (DeviceType, Nat)).
TransitionEncoder dev
-> Conv2d 2 QSliceHidden FifthSize OctaveSize QDType dev
trL1Inner :: forall (dev :: (DeviceType, Nat)).
TransitionEncoder dev
-> Conv2d 2 QSliceHidden FifthSize OctaveSize QDType dev
trL1Left :: forall (dev :: (DeviceType, Nat)).
TransitionEncoder dev -> Conv2d 1 QSliceHidden 1 1 QDType dev
trL1Right :: forall (dev :: (DeviceType, Nat)).
TransitionEncoder dev -> Conv2d 1 QSliceHidden 1 1 QDType dev
trL1Root :: forall (dev :: (DeviceType, Nat)).
TransitionEncoder dev -> ConstEmb dev '[QSliceHidden]
trL2 :: forall (dev :: (DeviceType, Nat)).
TransitionEncoder dev
-> Conv2d QSliceHidden QSliceHidden FifthSize OctaveSize QDType dev
trL1Passing :: Conv2d 2 QSliceHidden FifthSize OctaveSize QDType dev
trL1Inner :: Conv2d 2 QSliceHidden FifthSize OctaveSize QDType dev
trL1Left :: Conv2d 1 QSliceHidden 1 1 QDType dev
trL1Right :: Conv2d 1 QSliceHidden 1 1 QDType dev
trL1Root :: ConstEmb dev '[QSliceHidden]
trL2 :: Conv2d QSliceHidden QSliceHidden FifthSize OctaveSize QDType dev
..} TransitionEncoding{QTensor dev '[batchSize]
SliceEncoding dev '[batchSize]
QBoundedList dev QDType QSliceHidden '[batchSize] (2 : PShape)
trencRoot :: forall (dev :: (DeviceType, Nat)) (batchShape :: [Nat]).
TransitionEncoding dev batchShape -> QTensor dev batchShape
trencRight :: forall (dev :: (DeviceType, Nat)) (batchShape :: [Nat]).
TransitionEncoding dev batchShape -> SliceEncoding dev batchShape
trencLeft :: forall (dev :: (DeviceType, Nat)) (batchShape :: [Nat]).
TransitionEncoding dev batchShape -> SliceEncoding dev batchShape
trencInner :: forall (dev :: (DeviceType, Nat)) (batchShape :: [Nat]).
TransitionEncoding dev batchShape
-> QBoundedList dev QDType QSliceHidden batchShape (2 : PShape)
trencPassing :: forall (dev :: (DeviceType, Nat)) (batchShape :: [Nat]).
TransitionEncoding dev batchShape
-> QBoundedList dev QDType QSliceHidden batchShape (2 : PShape)
trencPassing :: QBoundedList dev QDType QSliceHidden '[batchSize] (2 : PShape)
trencInner :: QBoundedList dev QDType QSliceHidden '[batchSize] (2 : PShape)
trencLeft :: SliceEncoding dev '[batchSize]
trencRight :: SliceEncoding dev '[batchSize]
trencRoot :: QTensor dev '[batchSize]
..} =
    QTensor dev embshape -> QTensor dev embshape
forall (dev :: (DeviceType, Nat)) (shape :: [Nat]).
IsValidDevice dev =>
QTensor dev shape -> QTensor dev shape
activation (QTensor dev embshape -> QTensor dev embshape)
-> QTensor dev embshape -> QTensor dev embshape
forall a b. (a -> b) -> a -> b
$ forall (stride :: (Nat, Nat)) (padding :: (Nat, Nat))
       (inputChannelSize :: Nat) (outputChannelSize :: Nat)
       (kernelSize0 :: Nat) (kernelSize1 :: Nat) (inputSize0 :: Nat)
       (inputSize1 :: Nat) (batchSize :: Nat) (outputSize0 :: Nat)
       (outputSize1 :: Nat) (dtype :: DType)
       (device :: (DeviceType, Nat)).
(All KnownNat '[Fst stride, Snd stride, Fst padding, Snd padding],
 ConvSideCheck
   inputSize0 kernelSize0 (Fst stride) (Fst padding) outputSize0,
 ConvSideCheck
   inputSize1 kernelSize1 (Snd stride) (Snd padding) outputSize1) =>
Conv2d
  inputChannelSize
  outputChannelSize
  kernelSize0
  kernelSize1
  dtype
  device
-> Tensor
     device dtype '[batchSize, inputChannelSize, inputSize0, inputSize1]
-> Tensor
     device
     dtype
     '[batchSize, outputChannelSize, outputSize0, outputSize1]
TH.conv2dForwardRelaxed @'(1, 1) @'(FifthPadding, OctavePadding) Conv2d QSliceHidden QSliceHidden 13 5 QDType dev
Conv2d QSliceHidden QSliceHidden FifthSize OctaveSize QDType dev
trL2 Tensor dev QDType '[batchSize, QSliceHidden, 13, 5]
QTensor dev (batchSize : EmbShape)
all
   where
    runConv
      :: forall nin
       . (KnownNat nin)
      => TT.Conv2d nin QTransHidden FifthSize OctaveSize QDType dev
      -> QBoundedList dev QDType MaxEdges '[batchSize] (nin : PShape)
      -> QTensor dev (batchSize : QTransHidden : PShape)
    runConv :: forall (nin :: Nat).
KnownNat nin =>
Conv2d nin QSliceHidden FifthSize OctaveSize QDType dev
-> QBoundedList dev QDType QSliceHidden '[batchSize] (nin : PShape)
-> QTensor dev (batchSize : EmbShape)
runConv Conv2d nin QSliceHidden FifthSize OctaveSize QDType dev
conv (QBoundedList QTensor dev ('[batchSize] ++ '[QSliceHidden])
mask Tensor
  dev QDType (('[batchSize] ++ '[QSliceHidden]) ++ (nin : PShape))
edges) = forall (d :: Nat) (shape :: [Nat]) (shape' :: [Nat])
       (dtype :: DType) (dtype' :: DType) (device :: (DeviceType, Nat)).
(KnownNat d, shape' ~ DropValue shape d,
 SumDTypeIsValid device dtype, dtype' ~ SumDType dtype) =>
Tensor device dtype shape -> Tensor device dtype' shape'
TT.sumDim @1 (Tensor
   dev
   QDType
   (CheckBroadcast
      '[batchSize, QSliceHidden, 1, 1, 1]
      '[batchSize, QSliceHidden, QSliceHidden, 13, 5]
      (ComputeBroadcast
         '[1, 1, 1, QSliceHidden, batchSize]
         (ReverseImpl '[batchSize, QSliceHidden, QSliceHidden, 13, 5] '[])))
 -> QTensor dev (batchSize : EmbShape))
-> Tensor
     dev
     QDType
     (CheckBroadcast
        '[batchSize, QSliceHidden, 1, 1, 1]
        '[batchSize, QSliceHidden, QSliceHidden, 13, 5]
        (ComputeBroadcast
           '[1, 1, 1, QSliceHidden, batchSize]
           (ReverseImpl '[batchSize, QSliceHidden, QSliceHidden, 13, 5] '[])))
-> QTensor dev (batchSize : EmbShape)
forall a b. (a -> b) -> a -> b
$ Tensor dev QDType '[batchSize, QSliceHidden, 1, 1, 1]
-> Tensor
     dev QDType '[batchSize, QSliceHidden, QSliceHidden, 13, 5]
-> Tensor
     dev
     QDType
     (CheckBroadcast
        '[batchSize, QSliceHidden, 1, 1, 1]
        '[batchSize, QSliceHidden, QSliceHidden, 13, 5]
        (ComputeBroadcast
           '[1, 1, 1, QSliceHidden, batchSize]
           (ReverseImpl '[batchSize, QSliceHidden, QSliceHidden, 13, 5] '[])))
forall (shape'' :: [Nat]) (shape :: [Nat]) (shape' :: [Nat])
       (dtype :: DType) (dtype' :: DType) (dtype'' :: DType)
       (device :: (DeviceType, Nat)).
(dtype'' ~ DTypePromotion dtype dtype',
 shape'' ~ Broadcast shape shape',
 BasicArithmeticDTypeIsValid device dtype,
 BasicArithmeticDTypeIsValid device dtype',
 BasicArithmeticDTypeIsValid device dtype'') =>
Tensor device dtype shape
-> Tensor device dtype' shape' -> Tensor device dtype'' shape''
TT.mul Tensor dev QDType '[batchSize, QSliceHidden, 1, 1, 1]
mask' Tensor dev QDType '[batchSize, QSliceHidden, QSliceHidden, 13, 5]
QTensor dev (batchSize : QSliceHidden : EmbShape)
outReshaped
     where
      shape :: [Int]
shape = forall (shape :: [Nat]). KnownShape shape => [Int]
TT.shapeVal @(nin : PShape)
      shape' :: [Int]
shape' = forall (shape :: [Nat]). KnownShape shape => [Int]
TT.shapeVal @(MaxEdges : QTransHidden : PShape)
      inputShaped :: QTensor dev (batchSize * MaxEdges : nin : PShape)
      inputShaped :: QTensor dev ((batchSize * QSliceHidden) : nin : PShape)
inputShaped = [Int]
-> Tensor dev QDType '[batchSize, QSliceHidden, nin, 13, 5]
-> Tensor dev QDType '[batchSize * QSliceHidden, nin, 13, 5]
forall (dev :: (DeviceType, Nat)) (dtype :: DType) (shape :: [Nat])
       (shape' :: [Nat]).
[Int] -> Tensor dev dtype shape -> Tensor dev dtype shape'
unsafeReshape (-Int
1 Int -> [Int] -> [Int]
forall a. a -> [a] -> [a]
: [Int]
shape) Tensor dev QDType '[batchSize, QSliceHidden, nin, 13, 5]
Tensor
  dev QDType (('[batchSize] ++ '[QSliceHidden]) ++ (nin : PShape))
edges
      out :: QTensor dev (batchSize * MaxEdges : QTransHidden : PShape)
      out :: QTensor dev ((batchSize * QSliceHidden) : EmbShape)
out = QTensor dev ((batchSize * QSliceHidden) : EmbShape)
-> QTensor dev ((batchSize * QSliceHidden) : EmbShape)
forall (dev :: (DeviceType, Nat)) (shape :: [Nat]).
IsValidDevice dev =>
QTensor dev shape -> QTensor dev shape
activation (QTensor dev ((batchSize * QSliceHidden) : EmbShape)
 -> QTensor dev ((batchSize * QSliceHidden) : EmbShape))
-> QTensor dev ((batchSize * QSliceHidden) : EmbShape)
-> QTensor dev ((batchSize * QSliceHidden) : EmbShape)
forall a b. (a -> b) -> a -> b
$ forall (stride :: (Nat, Nat)) (padding :: (Nat, Nat))
       (inputChannelSize :: Nat) (outputChannelSize :: Nat)
       (kernelSize0 :: Nat) (kernelSize1 :: Nat) (inputSize0 :: Nat)
       (inputSize1 :: Nat) (batchSize :: Nat) (outputSize0 :: Nat)
       (outputSize1 :: Nat) (dtype :: DType)
       (device :: (DeviceType, Nat)).
(All KnownNat '[Fst stride, Snd stride, Fst padding, Snd padding],
 ConvSideCheck
   inputSize0 kernelSize0 (Fst stride) (Fst padding) outputSize0,
 ConvSideCheck
   inputSize1 kernelSize1 (Snd stride) (Snd padding) outputSize1) =>
Conv2d
  inputChannelSize
  outputChannelSize
  kernelSize0
  kernelSize1
  dtype
  device
-> Tensor
     device dtype '[batchSize, inputChannelSize, inputSize0, inputSize1]
-> Tensor
     device
     dtype
     '[batchSize, outputChannelSize, outputSize0, outputSize1]
TH.conv2dForwardRelaxed @'(1, 1) @'(FifthPadding, OctavePadding) Conv2d nin QSliceHidden 13 5 QDType dev
Conv2d nin QSliceHidden FifthSize OctaveSize QDType dev
conv Tensor dev QDType '[batchSize * QSliceHidden, nin, 13, 5]
QTensor dev ((batchSize * QSliceHidden) : nin : PShape)
inputShaped
      outReshaped :: QTensor dev (batchSize : MaxEdges : QTransHidden : PShape)
      outReshaped :: QTensor dev (batchSize : QSliceHidden : EmbShape)
outReshaped = [Int]
-> Tensor
     dev QDType '[batchSize * QSliceHidden, QSliceHidden, 13, 5]
-> Tensor
     dev QDType '[batchSize, QSliceHidden, QSliceHidden, 13, 5]
forall (dev :: (DeviceType, Nat)) (dtype :: DType) (shape :: [Nat])
       (shape' :: [Nat]).
[Int] -> Tensor dev dtype shape -> Tensor dev dtype shape'
unsafeReshape (-Int
1 Int -> [Int] -> [Int]
forall a. a -> [a] -> [a]
: [Int]
shape') Tensor dev QDType '[batchSize * QSliceHidden, QSliceHidden, 13, 5]
QTensor dev ((batchSize * QSliceHidden) : EmbShape)
out
      mask' :: QTensor dev '[batchSize, MaxEdges, 1, 1, 1]
      mask' :: Tensor dev QDType '[batchSize, QSliceHidden, 1, 1, 1]
mask' = [Int]
-> Tensor dev QDType '[batchSize, QSliceHidden]
-> Tensor dev QDType '[batchSize, QSliceHidden, 1, 1, 1]
forall (dev :: (DeviceType, Nat)) (dtype :: DType) (shape :: [Nat])
       (shape' :: [Nat]).
[Int] -> Tensor dev dtype shape -> Tensor dev dtype shape'
unsafeReshape [-Int
1, forall (n :: Nat). KnownNat n => Int
TT.natValI @MaxEdges, Int
1, Int
1, Int
1] Tensor dev QDType '[batchSize, QSliceHidden]
QTensor dev ('[batchSize] ++ '[QSliceHidden])
mask
    runSlice :: Conv2d
  inputChannelSize
  outputChannelSize
  kernelSize0
  kernelSize1
  dtype
  device
-> Tensor device dtype shape
-> Tensor
     device
     dtype
     '[batchSize, outputChannelSize, (inputSize0 - kernelSize0) + 1,
       (inputSize1 - kernelSize1) + 1]
runSlice Conv2d
  inputChannelSize
  outputChannelSize
  kernelSize0
  kernelSize1
  dtype
  device
conv Tensor device dtype shape
slice = forall (stride :: (Nat, Nat)) (padding :: (Nat, Nat))
       (inputChannelSize :: Nat) (outputChannelSize :: Nat)
       (kernelSize0 :: Nat) (kernelSize1 :: Nat) (inputSize0 :: Nat)
       (inputSize1 :: Nat) (batchSize :: Nat) (outputSize0 :: Nat)
       (outputSize1 :: Nat) (dtype :: DType)
       (device :: (DeviceType, Nat)).
(All KnownNat '[Fst stride, Snd stride, Fst padding, Snd padding],
 ConvSideCheck
   inputSize0 kernelSize0 (Fst stride) (Fst padding) outputSize0,
 ConvSideCheck
   inputSize1 kernelSize1 (Snd stride) (Snd padding) outputSize1) =>
Conv2d
  inputChannelSize
  outputChannelSize
  kernelSize0
  kernelSize1
  dtype
  device
-> Tensor
     device dtype '[batchSize, inputChannelSize, inputSize0, inputSize1]
-> Tensor
     device
     dtype
     '[batchSize, outputChannelSize, outputSize0, outputSize1]
TH.conv2dForwardRelaxed @'(1, 1) @'(0, 0) Conv2d
  inputChannelSize
  outputChannelSize
  kernelSize0
  kernelSize1
  dtype
  device
conv Tensor
  device dtype '[batchSize, inputChannelSize, inputSize0, inputSize1]
input
     where
      input :: Tensor
  device dtype '[batchSize, inputChannelSize, inputSize0, inputSize1]
input = forall (dim :: Nat) (shape :: [Nat]) (shape' :: [Nat])
       (dtype :: DType) (device :: (DeviceType, Nat)).
(KnownNat dim, shape' ~ Unsqueeze shape dim) =>
Tensor device dtype shape -> Tensor device dtype shape'
TT.unsqueeze @1 Tensor device dtype shape
slice
    pass :: QTensor dev (batchSize : QTransHidden : PShape)
    pass :: QTensor dev (batchSize : EmbShape)
pass = Conv2d 2 QSliceHidden FifthSize OctaveSize QDType dev
-> QBoundedList dev QDType QSliceHidden '[batchSize] (2 : PShape)
-> QTensor dev (batchSize : EmbShape)
forall (nin :: Nat).
KnownNat nin =>
Conv2d nin QSliceHidden FifthSize OctaveSize QDType dev
-> QBoundedList dev QDType QSliceHidden '[batchSize] (nin : PShape)
-> QTensor dev (batchSize : EmbShape)
runConv Conv2d 2 QSliceHidden FifthSize OctaveSize QDType dev
trL1Passing QBoundedList dev QDType QSliceHidden '[batchSize] (2 : PShape)
trencPassing
    inner :: QTensor dev (batchSize : QTransHidden : PShape)
    inner :: QTensor dev (batchSize : EmbShape)
inner = Conv2d 2 QSliceHidden FifthSize OctaveSize QDType dev
-> QBoundedList dev QDType QSliceHidden '[batchSize] (2 : PShape)
-> QTensor dev (batchSize : EmbShape)
forall (nin :: Nat).
KnownNat nin =>
Conv2d nin QSliceHidden FifthSize OctaveSize QDType dev
-> QBoundedList dev QDType QSliceHidden '[batchSize] (nin : PShape)
-> QTensor dev (batchSize : EmbShape)
runConv Conv2d 2 QSliceHidden FifthSize OctaveSize QDType dev
trL1Inner QBoundedList dev QDType QSliceHidden '[batchSize] (2 : PShape)
trencInner
    left :: QTensor dev (batchSize : QTransHidden : PShape)
    left :: QTensor dev (batchSize : EmbShape)
left = Conv2d 1 QSliceHidden 1 1 QDType dev
-> Tensor dev QDType '[batchSize, 13, 5]
-> Tensor
     dev QDType '[batchSize, QSliceHidden, (13 - 1) + 1, (5 - 1) + 1]
forall {shape :: [Nat]} {batchSize :: Nat}
       {inputChannelSize :: Nat} {inputSize0 :: Nat} {inputSize1 :: Nat}
       {kernelSize0 :: Nat} {kernelSize1 :: Nat}
       {outputChannelSize :: Nat} {dtype :: DType}
       {device :: (DeviceType, Nat)}.
(UnsqueezeCheck shape 1 (UnsqueezeImpl shape 1)
 ~ '[batchSize, inputChannelSize, inputSize0, inputSize1],
 Assert
   (OrdCond (CmpNat 1 kernelSize0) 'True 'True 'False)
   (TypeError ...),
 Assert
   (OrdCond (CmpNat (kernelSize0 - 1) inputSize0) 'True 'True 'False)
   (TypeError ...),
 Assert
   (OrdCond
      (CmpNat 1 ((inputSize0 - kernelSize0) + 1)) 'True 'True 'False)
   (TypeError ...),
 Assert
   (OrdCond (CmpNat 1 kernelSize1) 'True 'True 'False)
   (TypeError ...),
 Assert
   (OrdCond (CmpNat (kernelSize1 - 1) inputSize1) 'True 'True 'False)
   (TypeError ...),
 Assert
   (OrdCond
      (CmpNat 1 ((inputSize1 - kernelSize1) + 1)) 'True 'True 'False)
   (TypeError ...)) =>
Conv2d
  inputChannelSize
  outputChannelSize
  kernelSize0
  kernelSize1
  dtype
  device
-> Tensor device dtype shape
-> Tensor
     device
     dtype
     '[batchSize, outputChannelSize, (inputSize0 - kernelSize0) + 1,
       (inputSize1 - kernelSize1) + 1]
runSlice Conv2d 1 QSliceHidden 1 1 QDType dev
trL1Left (Tensor dev QDType '[batchSize, 13, 5]
 -> Tensor
      dev QDType '[batchSize, QSliceHidden, (13 - 1) + 1, (5 - 1) + 1])
-> Tensor dev QDType '[batchSize, 13, 5]
-> Tensor
     dev QDType '[batchSize, QSliceHidden, (13 - 1) + 1, (5 - 1) + 1]
forall a b. (a -> b) -> a -> b
$ SliceEncoding dev '[batchSize]
-> QTensor dev ('[batchSize] ++ PShape)
forall (dev :: (DeviceType, Nat)) (batchShape :: [Nat]).
SliceEncoding dev batchShape -> QTensor dev (batchShape ++ PShape)
getSlice SliceEncoding dev '[batchSize]
trencLeft
    right :: QTensor dev (batchSize : QTransHidden : PShape)
    right :: QTensor dev (batchSize : EmbShape)
right = Conv2d 1 QSliceHidden 1 1 QDType dev
-> Tensor dev QDType '[batchSize, 13, 5]
-> Tensor
     dev QDType '[batchSize, QSliceHidden, (13 - 1) + 1, (5 - 1) + 1]
forall {shape :: [Nat]} {batchSize :: Nat}
       {inputChannelSize :: Nat} {inputSize0 :: Nat} {inputSize1 :: Nat}
       {kernelSize0 :: Nat} {kernelSize1 :: Nat}
       {outputChannelSize :: Nat} {dtype :: DType}
       {device :: (DeviceType, Nat)}.
(UnsqueezeCheck shape 1 (UnsqueezeImpl shape 1)
 ~ '[batchSize, inputChannelSize, inputSize0, inputSize1],
 Assert
   (OrdCond (CmpNat 1 kernelSize0) 'True 'True 'False)
   (TypeError ...),
 Assert
   (OrdCond (CmpNat (kernelSize0 - 1) inputSize0) 'True 'True 'False)
   (TypeError ...),
 Assert
   (OrdCond
      (CmpNat 1 ((inputSize0 - kernelSize0) + 1)) 'True 'True 'False)
   (TypeError ...),
 Assert
   (OrdCond (CmpNat 1 kernelSize1) 'True 'True 'False)
   (TypeError ...),
 Assert
   (OrdCond (CmpNat (kernelSize1 - 1) inputSize1) 'True 'True 'False)
   (TypeError ...),
 Assert
   (OrdCond
      (CmpNat 1 ((inputSize1 - kernelSize1) + 1)) 'True 'True 'False)
   (TypeError ...)) =>
Conv2d
  inputChannelSize
  outputChannelSize
  kernelSize0
  kernelSize1
  dtype
  device
-> Tensor device dtype shape
-> Tensor
     device
     dtype
     '[batchSize, outputChannelSize, (inputSize0 - kernelSize0) + 1,
       (inputSize1 - kernelSize1) + 1]
runSlice Conv2d 1 QSliceHidden 1 1 QDType dev
trL1Right (Tensor dev QDType '[batchSize, 13, 5]
 -> Tensor
      dev QDType '[batchSize, QSliceHidden, (13 - 1) + 1, (5 - 1) + 1])
-> Tensor dev QDType '[batchSize, 13, 5]
-> Tensor
     dev QDType '[batchSize, QSliceHidden, (13 - 1) + 1, (5 - 1) + 1]
forall a b. (a -> b) -> a -> b
$ SliceEncoding dev '[batchSize]
-> QTensor dev ('[batchSize] ++ PShape)
forall (dev :: (DeviceType, Nat)) (batchShape :: [Nat]).
SliceEncoding dev batchShape -> QTensor dev (batchShape ++ PShape)
getSlice SliceEncoding dev '[batchSize]
trencRight
    root :: QTensor dev '[batchSize, QTransHidden, 1, 1]
    root :: QTensor dev '[batchSize, QSliceHidden, 1, 1]
root = [Int]
-> Tensor
     dev
     QDType
     (CheckBroadcast
        '[batchSize, 1]
        '[QSliceHidden]
        (ComputeBroadcast
           (ReverseImpl '[batchSize, 1] '[])
           (ReverseImpl '[QSliceHidden] '[])))
-> QTensor dev '[batchSize, QSliceHidden, 1, 1]
forall (dev :: (DeviceType, Nat)) (dtype :: DType) (shape :: [Nat])
       (shape' :: [Nat]).
[Int] -> Tensor dev dtype shape -> Tensor dev dtype shape'
unsafeReshape [-Int
1, forall (n :: Nat). KnownNat n => Int
TT.natValI @QTransHidden, Int
1, Int
1] (Tensor
   dev
   QDType
   (CheckBroadcast
      '[batchSize, 1]
      '[QSliceHidden]
      (ComputeBroadcast
         (ReverseImpl '[batchSize, 1] '[])
         (ReverseImpl '[QSliceHidden] '[])))
 -> QTensor dev '[batchSize, QSliceHidden, 1, 1])
-> Tensor
     dev
     QDType
     (CheckBroadcast
        '[batchSize, 1]
        '[QSliceHidden]
        (ComputeBroadcast
           (ReverseImpl '[batchSize, 1] '[])
           (ReverseImpl '[QSliceHidden] '[])))
-> QTensor dev '[batchSize, QSliceHidden, 1, 1]
forall a b. (a -> b) -> a -> b
$ Tensor dev QDType '[batchSize, 1]
-> QTensor dev '[QSliceHidden]
-> Tensor
     dev
     QDType
     (CheckBroadcast
        '[batchSize, 1]
        '[QSliceHidden]
        (ComputeBroadcast
           (ReverseImpl '[batchSize, 1] '[])
           (ReverseImpl '[QSliceHidden] '[])))
forall (shape'' :: [Nat]) (shape :: [Nat]) (shape' :: [Nat])
       (dtype :: DType) (dtype' :: DType) (dtype'' :: DType)
       (device :: (DeviceType, Nat)).
(dtype'' ~ DTypePromotion dtype dtype',
 shape'' ~ Broadcast shape shape',
 BasicArithmeticDTypeIsValid device dtype,
 BasicArithmeticDTypeIsValid device dtype',
 BasicArithmeticDTypeIsValid device dtype'') =>
Tensor device dtype shape
-> Tensor device dtype' shape' -> Tensor device dtype'' shape''
TT.mul (forall (dim :: Nat) (shape :: [Nat]) (shape' :: [Nat])
       (dtype :: DType) (device :: (DeviceType, Nat)).
(KnownNat dim, shape' ~ Unsqueeze shape dim) =>
Tensor device dtype shape -> Tensor device dtype shape'
TT.unsqueeze @1 QTensor dev '[batchSize]
trencRoot) (QTensor dev '[QSliceHidden]
 -> Tensor
      dev
      QDType
      (CheckBroadcast
         '[batchSize, 1]
         '[QSliceHidden]
         (ComputeBroadcast
            (ReverseImpl '[batchSize, 1] '[])
            (ReverseImpl '[QSliceHidden] '[]))))
-> QTensor dev '[QSliceHidden]
-> Tensor
     dev
     QDType
     (CheckBroadcast
        '[batchSize, 1]
        '[QSliceHidden]
        (ComputeBroadcast
           (ReverseImpl '[batchSize, 1] '[])
           (ReverseImpl '[QSliceHidden] '[])))
forall a b. (a -> b) -> a -> b
$ QTensor dev '[QSliceHidden] -> QTensor dev '[QSliceHidden]
forall (dev :: (DeviceType, Nat)) (shape :: [Nat]).
IsValidDevice dev =>
QTensor dev shape -> QTensor dev shape
activation (QTensor dev '[QSliceHidden] -> QTensor dev '[QSliceHidden])
-> QTensor dev '[QSliceHidden] -> QTensor dev '[QSliceHidden]
forall a b. (a -> b) -> a -> b
$ ConstEmb dev '[QSliceHidden] -> () -> QTensor dev '[QSliceHidden]
forall f a b. HasForward f a b => f -> a -> b
T.forward ConstEmb dev '[QSliceHidden]
trL1Root ()
    all :: QTensor dev (batchSize : QTransHidden : PShape)
    all :: QTensor dev (batchSize : EmbShape)
all = (Tensor dev QDType '[batchSize, QSliceHidden, 13, 5]
QTensor dev (batchSize : EmbShape)
pass Tensor dev QDType '[batchSize, QSliceHidden, 13, 5]
-> Tensor dev QDType '[batchSize, QSliceHidden, 13, 5]
-> Tensor dev QDType '[batchSize, QSliceHidden, 13, 5]
forall a. Num a => a -> a -> a
+ Tensor dev QDType '[batchSize, QSliceHidden, 13, 5]
QTensor dev (batchSize : EmbShape)
inner Tensor dev QDType '[batchSize, QSliceHidden, 13, 5]
-> Tensor dev QDType '[batchSize, QSliceHidden, 13, 5]
-> Tensor dev QDType '[batchSize, QSliceHidden, 13, 5]
forall a. Num a => a -> a -> a
+ Tensor dev QDType '[batchSize, QSliceHidden, 13, 5]
QTensor dev (batchSize : EmbShape)
left Tensor dev QDType '[batchSize, QSliceHidden, 13, 5]
-> Tensor dev QDType '[batchSize, QSliceHidden, 13, 5]
-> Tensor dev QDType '[batchSize, QSliceHidden, 13, 5]
forall a. Num a => a -> a -> a
+ Tensor dev QDType '[batchSize, QSliceHidden, 13, 5]
QTensor dev (batchSize : EmbShape)
right) Tensor dev QDType '[batchSize, QSliceHidden, 13, 5]
-> QTensor dev '[batchSize, QSliceHidden, 1, 1]
-> Tensor
     dev
     QDType
     (CheckBroadcast
        '[batchSize, QSliceHidden, 13, 5]
        '[batchSize, QSliceHidden, 1, 1]
        (ComputeBroadcast
           (ReverseImpl '[batchSize, QSliceHidden, 13, 5] '[])
           '[1, 1, QSliceHidden, batchSize]))
forall (shape'' :: [Nat]) (shape :: [Nat]) (shape' :: [Nat])
       (dtype :: DType) (dtype' :: DType) (dtype'' :: DType)
       (device :: (DeviceType, Nat)).
(dtype'' ~ DTypePromotion dtype dtype',
 shape'' ~ Broadcast shape shape',
 BasicArithmeticDTypeIsValid device dtype,
 BasicArithmeticDTypeIsValid device dtype',
 BasicArithmeticDTypeIsValid device dtype'') =>
Tensor device dtype shape
-> Tensor device dtype' shape' -> Tensor device dtype'' shape''
`TT.add` QTensor dev '[batchSize, QSliceHidden, 1, 1]
root

  forwardStoch :: TransitionEncoder dev
-> TransitionEncoding dev '[batchSize] -> IO (QTensor dev embshape)
forwardStoch TransitionEncoder dev
tr TransitionEncoding dev '[batchSize]
input = QTensor dev embshape -> IO (QTensor dev embshape)
forall a. a -> IO a
forall (f :: Type -> Type) a. Applicative f => a -> f a
pure (QTensor dev embshape -> IO (QTensor dev embshape))
-> QTensor dev embshape -> IO (QTensor dev embshape)
forall a b. (a -> b) -> a -> b
$ TransitionEncoder dev
-> TransitionEncoding dev '[batchSize] -> QTensor dev embshape
forall f a b. HasForward f a b => f -> a -> b
T.forward TransitionEncoder dev
tr TransitionEncoding dev '[batchSize]
input

-- ActionEncoder
-- -------------

data ActionSpec dev = ActionSpec

data ActionEncoder dev = ActionEncoder
  { forall (dev :: (DeviceType, Nat)).
ActionEncoder dev
-> Conv2d QSliceHidden QSliceHidden FifthSize OctaveSize QDType dev
actTop1sl :: !(TT.Conv2d EmbSize QActionHidden FifthSize OctaveSize QDType dev) -- TT.Linear (EmbSize) hidden QDType dev
  , forall (dev :: (DeviceType, Nat)).
ActionEncoder dev
-> Conv2d QSliceHidden QSliceHidden FifthSize OctaveSize QDType dev
actTop1sm :: !(TT.Conv2d EmbSize QActionHidden FifthSize OctaveSize QDType dev) -- TT.Linear (EmbSize) hidden QDType dev
  , forall (dev :: (DeviceType, Nat)).
ActionEncoder dev
-> Conv2d QSliceHidden QSliceHidden FifthSize OctaveSize QDType dev
actTop1sr :: !(TT.Conv2d EmbSize QActionHidden FifthSize OctaveSize QDType dev) -- TT.Linear (EmbSize) hidden QDType dev
  , forall (dev :: (DeviceType, Nat)).
ActionEncoder dev
-> Conv2d QSliceHidden QSliceHidden FifthSize OctaveSize QDType dev
actTop1t1 :: !(TT.Conv2d EmbSize QActionHidden FifthSize OctaveSize QDType dev) -- TT.Linear (EmbSize) hidden QDType dev
  , forall (dev :: (DeviceType, Nat)).
ActionEncoder dev
-> Conv2d QSliceHidden QSliceHidden FifthSize OctaveSize QDType dev
actTop1t2 :: !(TT.Conv2d EmbSize QActionHidden FifthSize OctaveSize QDType dev) -- TT.Linear (EmbSize) hidden QDType dev
  , forall (dev :: (DeviceType, Nat)).
ActionEncoder dev
-> Conv2d QSliceHidden QSliceHidden FifthSize OctaveSize QDType dev
actTop2 :: !(TT.Conv2d QActionHidden EmbSize FifthSize OctaveSize QDType dev) -- TT.Linear hidden (EmbSize) QDType dev
  , forall (dev :: (DeviceType, Nat)).
ActionEncoder dev -> ConstEmb dev '[QSliceHidden - 3]
actSplit :: ConstEmb dev '[EmbSize - 3] -- TODO: fill in with actual module
  , forall (dev :: (DeviceType, Nat)).
ActionEncoder dev -> ConstEmb dev '[QSliceHidden - 3]
actSpread :: ConstEmb dev '[EmbSize - 3] -- TODO: fill in with actual module
  , forall (dev :: (DeviceType, Nat)).
ActionEncoder dev -> ConstEmb dev '[QSliceHidden - 3]
actFreeze :: ConstEmb dev '[EmbSize - 3]
  }
  deriving (Int -> ActionEncoder dev -> ShowS
[ActionEncoder dev] -> ShowS
ActionEncoder dev -> String
(Int -> ActionEncoder dev -> ShowS)
-> (ActionEncoder dev -> String)
-> ([ActionEncoder dev] -> ShowS)
-> Show (ActionEncoder dev)
forall a.
(Int -> a -> ShowS) -> (a -> String) -> ([a] -> ShowS) -> Show a
forall (dev :: (DeviceType, Nat)).
Int -> ActionEncoder dev -> ShowS
forall (dev :: (DeviceType, Nat)). [ActionEncoder dev] -> ShowS
forall (dev :: (DeviceType, Nat)). ActionEncoder dev -> String
$cshowsPrec :: forall (dev :: (DeviceType, Nat)).
Int -> ActionEncoder dev -> ShowS
showsPrec :: Int -> ActionEncoder dev -> ShowS
$cshow :: forall (dev :: (DeviceType, Nat)). ActionEncoder dev -> String
show :: ActionEncoder dev -> String
$cshowList :: forall (dev :: (DeviceType, Nat)). [ActionEncoder dev] -> ShowS
showList :: [ActionEncoder dev] -> ShowS
Show, (forall x. ActionEncoder dev -> Rep (ActionEncoder dev) x)
-> (forall x. Rep (ActionEncoder dev) x -> ActionEncoder dev)
-> Generic (ActionEncoder dev)
forall x. Rep (ActionEncoder dev) x -> ActionEncoder dev
forall x. ActionEncoder dev -> Rep (ActionEncoder dev) x
forall a.
(forall x. a -> Rep a x) -> (forall x. Rep a x -> a) -> Generic a
forall (dev :: (DeviceType, Nat)) x.
Rep (ActionEncoder dev) x -> ActionEncoder dev
forall (dev :: (DeviceType, Nat)) x.
ActionEncoder dev -> Rep (ActionEncoder dev) x
$cfrom :: forall (dev :: (DeviceType, Nat)) x.
ActionEncoder dev -> Rep (ActionEncoder dev) x
from :: forall x. ActionEncoder dev -> Rep (ActionEncoder dev) x
$cto :: forall (dev :: (DeviceType, Nat)) x.
Rep (ActionEncoder dev) x -> ActionEncoder dev
to :: forall x. Rep (ActionEncoder dev) x -> ActionEncoder dev
Generic, ActionEncoder dev -> HList (Parameters (ActionEncoder dev))
ActionEncoder dev
-> HList (Parameters (ActionEncoder dev)) -> ActionEncoder dev
(ActionEncoder dev -> HList (Parameters (ActionEncoder dev)))
-> (ActionEncoder dev
    -> HList (Parameters (ActionEncoder dev)) -> ActionEncoder dev)
-> Parameterized (ActionEncoder dev)
forall f.
(f -> HList (Parameters f))
-> (f -> HList (Parameters f) -> f) -> Parameterized f
forall (dev :: (DeviceType, Nat)).
ActionEncoder dev -> HList (Parameters (ActionEncoder dev))
forall (dev :: (DeviceType, Nat)).
ActionEncoder dev
-> HList (Parameters (ActionEncoder dev)) -> ActionEncoder dev
$cflattenParameters :: forall (dev :: (DeviceType, Nat)).
ActionEncoder dev -> HList (Parameters (ActionEncoder dev))
flattenParameters :: ActionEncoder dev -> HList (Parameters (ActionEncoder dev))
$creplaceParameters :: forall (dev :: (DeviceType, Nat)).
ActionEncoder dev
-> HList (Parameters (ActionEncoder dev)) -> ActionEncoder dev
replaceParameters :: ActionEncoder dev
-> HList (Parameters (ActionEncoder dev)) -> ActionEncoder dev
TT.Parameterized, Context -> ActionEncoder dev -> IO (Maybe ThunkInfo)
Proxy (ActionEncoder dev) -> String
(Context -> ActionEncoder dev -> IO (Maybe ThunkInfo))
-> (Context -> ActionEncoder dev -> IO (Maybe ThunkInfo))
-> (Proxy (ActionEncoder dev) -> String)
-> NoThunks (ActionEncoder dev)
forall a.
(Context -> a -> IO (Maybe ThunkInfo))
-> (Context -> a -> IO (Maybe ThunkInfo))
-> (Proxy a -> String)
-> NoThunks a
forall (dev :: (DeviceType, Nat)).
Context -> ActionEncoder dev -> IO (Maybe ThunkInfo)
forall (dev :: (DeviceType, Nat)).
Proxy (ActionEncoder dev) -> String
$cnoThunks :: forall (dev :: (DeviceType, Nat)).
Context -> ActionEncoder dev -> IO (Maybe ThunkInfo)
noThunks :: Context -> ActionEncoder dev -> IO (Maybe ThunkInfo)
$cwNoThunks :: forall (dev :: (DeviceType, Nat)).
Context -> ActionEncoder dev -> IO (Maybe ThunkInfo)
wNoThunks :: Context -> ActionEncoder dev -> IO (Maybe ThunkInfo)
$cshowTypeOf :: forall (dev :: (DeviceType, Nat)).
Proxy (ActionEncoder dev) -> String
showTypeOf :: Proxy (ActionEncoder dev) -> String
NoThunks, ActionEncoder dev -> ()
(ActionEncoder dev -> ()) -> NFData (ActionEncoder dev)
forall a. (a -> ()) -> NFData a
forall (dev :: (DeviceType, Nat)). ActionEncoder dev -> ()
$crnf :: forall (dev :: (DeviceType, Nat)). ActionEncoder dev -> ()
rnf :: ActionEncoder dev -> ()
NFData)

instance (IsValidDevice dev) => T.Randomizable (ActionSpec dev) (ActionEncoder dev) where
  sample :: ActionSpec dev -> IO (ActionEncoder dev)
  sample :: ActionSpec dev -> IO (ActionEncoder dev)
sample ActionSpec dev
ActionSpec = do
    actTop1sl <- Conv2dSpec QSliceHidden QSliceHidden 13 5 QDType dev
-> IO (Conv2d QSliceHidden QSliceHidden 13 5 QDType dev)
forall spec f. Randomizable spec f => spec -> IO f
T.sample Conv2dSpec QSliceHidden QSliceHidden 13 5 QDType dev
forall (inputChannelSize :: Nat) (outputChannelSize :: Nat)
       (kernelSize0 :: Nat) (kernelSize1 :: Nat) (dtype :: DType)
       (device :: (DeviceType, Nat)).
Conv2dSpec
  inputChannelSize
  outputChannelSize
  kernelSize0
  kernelSize1
  dtype
  device
TT.Conv2dSpec
    actTop1sm <- T.sample TT.Conv2dSpec
    actTop1sr <- T.sample TT.Conv2dSpec
    actTop1t1 <- T.sample TT.Conv2dSpec
    actTop1t2 <- T.sample TT.Conv2dSpec
    actTop2 <- T.sample TT.Conv2dSpec
    actSplit <- T.sample $ ConstEmbSpec @dev
    actSpread <- T.sample $ ConstEmbSpec @dev
    actFreeze <- T.sample $ ConstEmbSpec @dev
    pure ActionEncoder{..}

opTypes :: forall dev. (TT.KnownDevice dev) => QTensor dev '[6, 3]
opTypes :: forall (dev :: (DeviceType, Nat)).
KnownDevice dev =>
QTensor dev '[FifthPadding, 3]
opTypes =
  Tensor -> Tensor dev QDType '[FifthPadding, 3]
forall (device :: (DeviceType, Nat)) (dtype :: DType)
       (shape :: [Nat]).
Tensor -> Tensor device dtype shape
TT.UnsafeMkTensor
    (Tensor -> Tensor dev QDType '[FifthPadding, 3])
-> Tensor -> Tensor dev QDType '[FifthPadding, 3]
forall a b. (a -> b) -> a -> b
$! forall a opt.
(TensorLike a, TensorOptionLike opt) =>
a -> opt -> Tensor
T.asTensor' @[[QType]]
      [ [QType
0, QType
0, QType
0] -- freeze only
      , [QType
0, QType
1, QType
0] -- split only
      , [QType
1, QType
0, QType
0] -- freeze left
      , [QType
1, QType
0, QType
1] -- spread
      , [QType
1, QType
1, QType
0] -- freeze left
      , [QType
1, QType
1, QType
1] -- freeze right
      ]
    (TensorOptions -> Tensor) -> TensorOptions -> Tensor
forall a b. (a -> b) -> a -> b
$ forall (dev :: (DeviceType, Nat)). KnownDevice dev => TensorOptions
opts @dev

-- | HasForward for actions (batched)
instance
  forall dev batchSize outShape
   . ( IsValidDevice dev
     , outShape ~ (batchSize : EmbSize : PShape)
     , 1 <= batchSize
     )
  => T.HasForward
      (ActionEncoder dev)
      (SliceEncoder dev, TransitionEncoder dev, ActionEncoding dev '[batchSize])
      (QTensor dev outShape)
  where
  forward :: ActionEncoder dev
-> (SliceEncoder dev, TransitionEncoder dev,
    ActionEncoding dev '[batchSize])
-> QTensor dev outShape
forward ActionEncoder{Conv2d QSliceHidden QSliceHidden FifthSize OctaveSize QDType dev
ConstEmb dev '[QSliceHidden - 3]
actTop1sl :: forall (dev :: (DeviceType, Nat)).
ActionEncoder dev
-> Conv2d QSliceHidden QSliceHidden FifthSize OctaveSize QDType dev
actTop1sm :: forall (dev :: (DeviceType, Nat)).
ActionEncoder dev
-> Conv2d QSliceHidden QSliceHidden FifthSize OctaveSize QDType dev
actTop1sr :: forall (dev :: (DeviceType, Nat)).
ActionEncoder dev
-> Conv2d QSliceHidden QSliceHidden FifthSize OctaveSize QDType dev
actTop1t1 :: forall (dev :: (DeviceType, Nat)).
ActionEncoder dev
-> Conv2d QSliceHidden QSliceHidden FifthSize OctaveSize QDType dev
actTop1t2 :: forall (dev :: (DeviceType, Nat)).
ActionEncoder dev
-> Conv2d QSliceHidden QSliceHidden FifthSize OctaveSize QDType dev
actTop2 :: forall (dev :: (DeviceType, Nat)).
ActionEncoder dev
-> Conv2d QSliceHidden QSliceHidden FifthSize OctaveSize QDType dev
actSplit :: forall (dev :: (DeviceType, Nat)).
ActionEncoder dev -> ConstEmb dev '[QSliceHidden - 3]
actSpread :: forall (dev :: (DeviceType, Nat)).
ActionEncoder dev -> ConstEmb dev '[QSliceHidden - 3]
actFreeze :: forall (dev :: (DeviceType, Nat)).
ActionEncoder dev -> ConstEmb dev '[QSliceHidden - 3]
actTop1sl :: Conv2d QSliceHidden QSliceHidden FifthSize OctaveSize QDType dev
actTop1sm :: Conv2d QSliceHidden QSliceHidden FifthSize OctaveSize QDType dev
actTop1sr :: Conv2d QSliceHidden QSliceHidden FifthSize OctaveSize QDType dev
actTop1t1 :: Conv2d QSliceHidden QSliceHidden FifthSize OctaveSize QDType dev
actTop1t2 :: Conv2d QSliceHidden QSliceHidden FifthSize OctaveSize QDType dev
actTop2 :: Conv2d QSliceHidden QSliceHidden FifthSize OctaveSize QDType dev
actSplit :: ConstEmb dev '[QSliceHidden - 3]
actSpread :: ConstEmb dev '[QSliceHidden - 3]
actFreeze :: ConstEmb dev '[QSliceHidden - 3]
..} (SliceEncoder dev
slc, TransitionEncoder dev
tr, ActionEncoding (ActionTop QStartStop dev '[batchSize] (SliceEncoding dev '[batchSize])
sl TransitionEncoding dev '[batchSize]
t1 (QMaybe QTensor dev '[batchSize]
smMask SliceEncoding dev '[batchSize]
sm) (QMaybe QTensor dev '[batchSize]
t2Mask TransitionEncoding dev '[batchSize]
t2) QStartStop dev '[batchSize] (SliceEncoding dev '[batchSize])
sr) Tensor dev 'Int64 '[batchSize]
opIndex) = QTensor dev '[batchSize, QSliceHidden, 13, 5]
QTensor dev (batchSize : EmbShape)
topEmb QTensor dev '[batchSize, QSliceHidden, 13, 5]
-> Tensor dev QDType '[batchSize, QSliceHidden, 1, 1]
-> QTensor dev outShape
forall (shape'' :: [Nat]) (shape :: [Nat]) (shape' :: [Nat])
       (dtype :: DType) (dtype' :: DType) (dtype'' :: DType)
       (device :: (DeviceType, Nat)).
(dtype'' ~ DTypePromotion dtype dtype',
 shape'' ~ Broadcast shape shape',
 BasicArithmeticDTypeIsValid device dtype,
 BasicArithmeticDTypeIsValid device dtype',
 BasicArithmeticDTypeIsValid device dtype'') =>
Tensor device dtype shape
-> Tensor device dtype' shape' -> Tensor device dtype'' shape''
`TT.add` Tensor dev QDType '[batchSize, QSliceHidden, 1, 1]
opEmbReshaped
   where
    runConv
      :: TT.Conv2d nin nout FifthSize OctaveSize QDType dev
      -> QTensor dev (batchSize : nin : PShape)
      -> QTensor dev (batchSize : nout : PShape)
    runConv :: forall (nin :: Nat) (nout :: Nat).
Conv2d nin nout FifthSize OctaveSize QDType dev
-> QTensor dev (batchSize : nin : PShape)
-> QTensor dev (batchSize : nout : PShape)
runConv Conv2d nin nout FifthSize OctaveSize QDType dev
conv QTensor dev (batchSize : nin : PShape)
input =
      QTensor dev (batchSize : nout : PShape)
-> QTensor dev (batchSize : nout : PShape)
forall (dev :: (DeviceType, Nat)) (shape :: [Nat]).
IsValidDevice dev =>
QTensor dev shape -> QTensor dev shape
activation (QTensor dev (batchSize : nout : PShape)
 -> QTensor dev (batchSize : nout : PShape))
-> QTensor dev (batchSize : nout : PShape)
-> QTensor dev (batchSize : nout : PShape)
forall a b. (a -> b) -> a -> b
$ forall (stride :: (Nat, Nat)) (padding :: (Nat, Nat))
       (inputChannelSize :: Nat) (outputChannelSize :: Nat)
       (kernelSize0 :: Nat) (kernelSize1 :: Nat) (inputSize0 :: Nat)
       (inputSize1 :: Nat) (batchSize :: Nat) (outputSize0 :: Nat)
       (outputSize1 :: Nat) (dtype :: DType)
       (device :: (DeviceType, Nat)).
(All KnownNat '[Fst stride, Snd stride, Fst padding, Snd padding],
 ConvSideCheck
   inputSize0 kernelSize0 (Fst stride) (Fst padding) outputSize0,
 ConvSideCheck
   inputSize1 kernelSize1 (Snd stride) (Snd padding) outputSize1) =>
Conv2d
  inputChannelSize
  outputChannelSize
  kernelSize0
  kernelSize1
  dtype
  device
-> Tensor
     device dtype '[batchSize, inputChannelSize, inputSize0, inputSize1]
-> Tensor
     device
     dtype
     '[batchSize, outputChannelSize, outputSize0, outputSize1]
TH.conv2dForwardRelaxed @'(1, 1) @'(FifthPadding, OctavePadding) Conv2d nin nout 13 5 QDType dev
Conv2d nin nout FifthSize OctaveSize QDType dev
conv Tensor dev QDType '[batchSize, nin, 13, 5]
QTensor dev (batchSize : nin : PShape)
input
    runConvMasked
      :: QTensor dev '[batchSize]
      -> TT.Conv2d nin nout FifthSize OctaveSize QDType dev
      -> QTensor dev (batchSize : nin : PShape)
      -> QTensor dev (batchSize : nout : PShape)
    runConvMasked :: forall (nin :: Nat) (nout :: Nat).
QTensor dev '[batchSize]
-> Conv2d nin nout FifthSize OctaveSize QDType dev
-> QTensor dev (batchSize : nin : PShape)
-> QTensor dev (batchSize : nout : PShape)
runConvMasked QTensor dev '[batchSize]
mask Conv2d nin nout FifthSize OctaveSize QDType dev
conv QTensor dev (batchSize : nin : PShape)
input =
      Tensor dev QDType '[batchSize, 1, 1, 1]
-> Tensor dev QDType '[batchSize, nout, 13, 5]
-> Tensor dev QDType (batchSize : nout : PShape)
forall (shape'' :: [Nat]) (shape :: [Nat]) (shape' :: [Nat])
       (dtype :: DType) (dtype' :: DType) (dtype'' :: DType)
       (device :: (DeviceType, Nat)).
(dtype'' ~ DTypePromotion dtype dtype',
 shape'' ~ Broadcast shape shape',
 BasicArithmeticDTypeIsValid device dtype,
 BasicArithmeticDTypeIsValid device dtype',
 BasicArithmeticDTypeIsValid device dtype'') =>
Tensor device dtype shape
-> Tensor device dtype' shape' -> Tensor device dtype'' shape''
TT.mul ([Int]
-> QTensor dev '[batchSize]
-> Tensor dev QDType '[batchSize, 1, 1, 1]
forall (dev :: (DeviceType, Nat)) (dtype :: DType) (shape :: [Nat])
       (shape' :: [Nat]).
[Int] -> Tensor dev dtype shape -> Tensor dev dtype shape'
unsafeReshape [-Int
1, Int
1, Int
1, Int
1] QTensor dev '[batchSize]
mask :: QTensor dev '[batchSize, 1, 1, 1]) (Tensor dev QDType '[batchSize, nout, 13, 5]
 -> Tensor dev QDType (batchSize : nout : PShape))
-> Tensor dev QDType '[batchSize, nout, 13, 5]
-> Tensor dev QDType (batchSize : nout : PShape)
forall a b. (a -> b) -> a -> b
$ Conv2d nin nout FifthSize OctaveSize QDType dev
-> QTensor dev (batchSize : nin : PShape)
-> Tensor dev QDType (batchSize : nout : PShape)
forall (nin :: Nat) (nout :: Nat).
Conv2d nin nout FifthSize OctaveSize QDType dev
-> QTensor dev (batchSize : nin : PShape)
-> QTensor dev (batchSize : nout : PShape)
runConv Conv2d nin nout FifthSize OctaveSize QDType dev
conv QTensor dev (batchSize : nin : PShape)
input
    -- top embedding
    embl :: QTensor dev (batchSize : QActionHidden : PShape)
    embl :: QTensor dev (batchSize : EmbShape)
embl = Conv2d QSliceHidden QSliceHidden FifthSize OctaveSize QDType dev
-> QTensor dev (batchSize : EmbShape)
-> QTensor dev (batchSize : EmbShape)
forall (nin :: Nat) (nout :: Nat).
Conv2d nin nout FifthSize OctaveSize QDType dev
-> QTensor dev (batchSize : nin : PShape)
-> QTensor dev (batchSize : nout : PShape)
runConv Conv2d QSliceHidden QSliceHidden FifthSize OctaveSize QDType dev
actTop1sl (QTensor dev (batchSize : EmbShape)
 -> QTensor dev (batchSize : EmbShape))
-> QTensor dev (batchSize : EmbShape)
-> QTensor dev (batchSize : EmbShape)
forall a b. (a -> b) -> a -> b
$ SliceEncoder dev
-> QStartStop dev '[batchSize] (SliceEncoding dev '[batchSize])
-> QTensor dev '[batchSize, QSliceHidden, 13, 5]
forall f a b. HasForward f a b => f -> a -> b
T.forward SliceEncoder dev
slc QStartStop dev '[batchSize] (SliceEncoding dev '[batchSize])
sl
    embm :: QTensor dev (batchSize : EmbShape)
embm = Conv2d QSliceHidden QSliceHidden FifthSize OctaveSize QDType dev
-> QTensor dev (batchSize : EmbShape)
-> QTensor dev (batchSize : EmbShape)
forall (nin :: Nat) (nout :: Nat).
Conv2d nin nout FifthSize OctaveSize QDType dev
-> QTensor dev (batchSize : nin : PShape)
-> QTensor dev (batchSize : nout : PShape)
runConv Conv2d QSliceHidden QSliceHidden FifthSize OctaveSize QDType dev
actTop1sm (QTensor dev (batchSize : EmbShape)
 -> QTensor dev (batchSize : EmbShape))
-> QTensor dev (batchSize : EmbShape)
-> QTensor dev (batchSize : EmbShape)
forall a b. (a -> b) -> a -> b
$ SliceEncoder dev
-> SliceEncoding dev '[batchSize]
-> QTensor dev '[batchSize, QSliceHidden, 13, 5]
forall f a b. HasForward f a b => f -> a -> b
T.forward SliceEncoder dev
slc SliceEncoding dev '[batchSize]
sm
    embr :: QTensor dev (batchSize : EmbShape)
embr = QTensor dev '[batchSize]
-> Conv2d QSliceHidden QSliceHidden FifthSize OctaveSize QDType dev
-> QTensor dev (batchSize : EmbShape)
-> QTensor dev (batchSize : EmbShape)
forall (nin :: Nat) (nout :: Nat).
QTensor dev '[batchSize]
-> Conv2d nin nout FifthSize OctaveSize QDType dev
-> QTensor dev (batchSize : nin : PShape)
-> QTensor dev (batchSize : nout : PShape)
runConvMasked QTensor dev '[batchSize]
smMask Conv2d QSliceHidden QSliceHidden FifthSize OctaveSize QDType dev
actTop1sr (QTensor dev (batchSize : EmbShape)
 -> QTensor dev (batchSize : EmbShape))
-> QTensor dev (batchSize : EmbShape)
-> QTensor dev (batchSize : EmbShape)
forall a b. (a -> b) -> a -> b
$ SliceEncoder dev
-> QStartStop dev '[batchSize] (SliceEncoding dev '[batchSize])
-> QTensor dev '[batchSize, QSliceHidden, 13, 5]
forall f a b. HasForward f a b => f -> a -> b
T.forward SliceEncoder dev
slc QStartStop dev '[batchSize] (SliceEncoding dev '[batchSize])
sr
    embt1 :: QTensor dev (batchSize : EmbShape)
embt1 = Conv2d QSliceHidden QSliceHidden FifthSize OctaveSize QDType dev
-> QTensor dev (batchSize : EmbShape)
-> QTensor dev (batchSize : EmbShape)
forall (nin :: Nat) (nout :: Nat).
Conv2d nin nout FifthSize OctaveSize QDType dev
-> QTensor dev (batchSize : nin : PShape)
-> QTensor dev (batchSize : nout : PShape)
runConv Conv2d QSliceHidden QSliceHidden FifthSize OctaveSize QDType dev
actTop1t1 (QTensor dev (batchSize : EmbShape)
 -> QTensor dev (batchSize : EmbShape))
-> QTensor dev (batchSize : EmbShape)
-> QTensor dev (batchSize : EmbShape)
forall a b. (a -> b) -> a -> b
$ TransitionEncoder dev
-> TransitionEncoding dev '[batchSize]
-> QTensor dev '[batchSize, QSliceHidden, 13, 5]
forall f a b. HasForward f a b => f -> a -> b
T.forward TransitionEncoder dev
tr TransitionEncoding dev '[batchSize]
t1
    embt2 :: QTensor dev (batchSize : EmbShape)
embt2 = QTensor dev '[batchSize]
-> Conv2d QSliceHidden QSliceHidden FifthSize OctaveSize QDType dev
-> QTensor dev (batchSize : EmbShape)
-> QTensor dev (batchSize : EmbShape)
forall (nin :: Nat) (nout :: Nat).
QTensor dev '[batchSize]
-> Conv2d nin nout FifthSize OctaveSize QDType dev
-> QTensor dev (batchSize : nin : PShape)
-> QTensor dev (batchSize : nout : PShape)
runConvMasked QTensor dev '[batchSize]
t2Mask Conv2d QSliceHidden QSliceHidden FifthSize OctaveSize QDType dev
actTop1t2 (QTensor dev (batchSize : EmbShape)
 -> QTensor dev (batchSize : EmbShape))
-> QTensor dev (batchSize : EmbShape)
-> QTensor dev (batchSize : EmbShape)
forall a b. (a -> b) -> a -> b
$ TransitionEncoder dev
-> TransitionEncoding dev '[batchSize]
-> QTensor dev '[batchSize, QSliceHidden, 13, 5]
forall f a b. HasForward f a b => f -> a -> b
T.forward TransitionEncoder dev
tr TransitionEncoding dev '[batchSize]
t2
    topCombined :: QTensor dev (batchSize : QActionHidden : PShape)
    topCombined :: QTensor dev (batchSize : EmbShape)
topCombined = QTensor dev '[batchSize, QSliceHidden, 13, 5]
QTensor dev (batchSize : EmbShape)
embl QTensor dev '[batchSize, QSliceHidden, 13, 5]
-> QTensor dev '[batchSize, QSliceHidden, 13, 5]
-> QTensor dev '[batchSize, QSliceHidden, 13, 5]
forall a. Num a => a -> a -> a
+ QTensor dev '[batchSize, QSliceHidden, 13, 5]
QTensor dev (batchSize : EmbShape)
embm QTensor dev '[batchSize, QSliceHidden, 13, 5]
-> QTensor dev '[batchSize, QSliceHidden, 13, 5]
-> QTensor dev '[batchSize, QSliceHidden, 13, 5]
forall a. Num a => a -> a -> a
+ QTensor dev '[batchSize, QSliceHidden, 13, 5]
QTensor dev (batchSize : EmbShape)
embr QTensor dev '[batchSize, QSliceHidden, 13, 5]
-> QTensor dev '[batchSize, QSliceHidden, 13, 5]
-> QTensor dev '[batchSize, QSliceHidden, 13, 5]
forall a. Num a => a -> a -> a
+ QTensor dev '[batchSize, QSliceHidden, 13, 5]
QTensor dev (batchSize : EmbShape)
embt1 QTensor dev '[batchSize, QSliceHidden, 13, 5]
-> QTensor dev '[batchSize, QSliceHidden, 13, 5]
-> QTensor dev '[batchSize, QSliceHidden, 13, 5]
forall a. Num a => a -> a -> a
+ QTensor dev '[batchSize, QSliceHidden, 13, 5]
QTensor dev (batchSize : EmbShape)
embt2
    topEmb :: QTensor dev (batchSize : EmbSize : PShape)
    topEmb :: QTensor dev (batchSize : EmbShape)
topEmb = Conv2d QSliceHidden QSliceHidden FifthSize OctaveSize QDType dev
-> QTensor dev (batchSize : EmbShape)
-> QTensor dev (batchSize : EmbShape)
forall (nin :: Nat) (nout :: Nat).
Conv2d nin nout FifthSize OctaveSize QDType dev
-> QTensor dev (batchSize : nin : PShape)
-> QTensor dev (batchSize : nout : PShape)
runConv Conv2d QSliceHidden QSliceHidden FifthSize OctaveSize QDType dev
actTop2 QTensor dev (batchSize : EmbShape)
topCombined
    -- operation embedding
    opFreeze :: QTensor dev '[5]
opFreeze = ConstEmb dev '[5] -> () -> QTensor dev '[5]
forall f a b. HasForward f a b => f -> a -> b
T.forward ConstEmb dev '[5]
ConstEmb dev '[QSliceHidden - 3]
actFreeze ()
    opSplit :: QTensor dev '[5]
opSplit = ConstEmb dev '[5] -> () -> QTensor dev '[5]
forall f a b. HasForward f a b => f -> a -> b
T.forward ConstEmb dev '[5]
ConstEmb dev '[QSliceHidden - 3]
actSplit ()
    opSpread :: QTensor dev '[5]
opSpread = ConstEmb dev '[5] -> () -> QTensor dev '[5]
forall f a b. HasForward f a b => f -> a -> b
T.forward ConstEmb dev '[5]
ConstEmb dev '[QSliceHidden - 3]
actSpread ()
    opCombined :: Tensor dev QDType '[FifthPadding, 5]
opCombined = forall (dim :: Nat) (shape :: [Nat]) (dtype :: DType)
       (device :: (DeviceType, Nat)) (tensors :: [Type]).
(KnownNat dim, '(shape, dtype, device) ~ Stack dim tensors,
 Castable (HList tensors) [ForeignPtr Tensor]) =>
HList tensors -> Tensor device dtype shape
forall {k} (dim :: Nat) (shape :: [Nat]) (dtype :: DType)
       (device :: (DeviceType, Nat)) (tensors :: [k]).
(KnownNat dim, '(shape, dtype, device) ~ Stack dim tensors,
 Castable (HList tensors) [ForeignPtr Tensor]) =>
HList tensors -> Tensor device dtype shape
TT.stack @0 (HList
   '[QTensor dev '[5], QTensor dev '[5], QTensor dev '[5],
     QTensor dev '[5], QTensor dev '[5], QTensor dev '[5]]
 -> Tensor dev QDType '[FifthPadding, 5])
-> HList
     '[QTensor dev '[5], QTensor dev '[5], QTensor dev '[5],
       QTensor dev '[5], QTensor dev '[5], QTensor dev '[5]]
-> Tensor dev QDType '[FifthPadding, 5]
forall a b. (a -> b) -> a -> b
$ QTensor dev '[5]
opFreeze QTensor dev '[5]
-> HList
     '[QTensor dev '[5], QTensor dev '[5], QTensor dev '[5],
       QTensor dev '[5], QTensor dev '[5]]
-> HList
     '[QTensor dev '[5], QTensor dev '[5], QTensor dev '[5],
       QTensor dev '[5], QTensor dev '[5], QTensor dev '[5]]
forall x (xs :: [Type]). x -> HList xs -> HList (x : xs)
TT.:. QTensor dev '[5]
opSplit QTensor dev '[5]
-> HList
     '[QTensor dev '[5], QTensor dev '[5], QTensor dev '[5],
       QTensor dev '[5]]
-> HList
     '[QTensor dev '[5], QTensor dev '[5], QTensor dev '[5],
       QTensor dev '[5], QTensor dev '[5]]
forall x (xs :: [Type]). x -> HList xs -> HList (x : xs)
TT.:. QTensor dev '[5]
opFreeze QTensor dev '[5]
-> HList '[QTensor dev '[5], QTensor dev '[5], QTensor dev '[5]]
-> HList
     '[QTensor dev '[5], QTensor dev '[5], QTensor dev '[5],
       QTensor dev '[5]]
forall x (xs :: [Type]). x -> HList xs -> HList (x : xs)
TT.:. QTensor dev '[5]
opSpread QTensor dev '[5]
-> HList '[QTensor dev '[5], QTensor dev '[5]]
-> HList '[QTensor dev '[5], QTensor dev '[5], QTensor dev '[5]]
forall x (xs :: [Type]). x -> HList xs -> HList (x : xs)
TT.:. QTensor dev '[5]
opSplit QTensor dev '[5]
-> HList '[QTensor dev '[5]]
-> HList '[QTensor dev '[5], QTensor dev '[5]]
forall x (xs :: [Type]). x -> HList xs -> HList (x : xs)
TT.:. QTensor dev '[5]
opSplit QTensor dev '[5] -> HList '[] -> HList '[QTensor dev '[5]]
forall x (xs :: [Type]). x -> HList xs -> HList (x : xs)
TT.:. HList '[]
forall k. HList '[]
TT.HNil
    opEmbeddings :: QTensor dev '[6, EmbSize]
    opEmbeddings :: QTensor dev '[FifthPadding, QSliceHidden]
opEmbeddings = forall (dim :: Nat) (shape :: [Nat]) (dtype :: DType)
       (device :: (DeviceType, Nat)) (tensors :: [Type]).
(KnownNat dim, '(shape, dtype, device) ~ Cat dim tensors,
 Castable (HList tensors) [ForeignPtr Tensor]) =>
HList tensors -> Tensor device dtype shape
forall {k} (dim :: Nat) (shape :: [Nat]) (dtype :: DType)
       (device :: (DeviceType, Nat)) (tensors :: [k]).
(KnownNat dim, '(shape, dtype, device) ~ Cat dim tensors,
 Castable (HList tensors) [ForeignPtr Tensor]) =>
HList tensors -> Tensor device dtype shape
TT.cat @1 (HList
   '[QTensor dev '[FifthPadding, 3],
     Tensor dev QDType '[FifthPadding, 5]]
 -> QTensor dev '[FifthPadding, QSliceHidden])
-> HList
     '[QTensor dev '[FifthPadding, 3],
       Tensor dev QDType '[FifthPadding, 5]]
-> QTensor dev '[FifthPadding, QSliceHidden]
forall a b. (a -> b) -> a -> b
$ forall (dev :: (DeviceType, Nat)).
KnownDevice dev =>
QTensor dev '[FifthPadding, 3]
opTypes @dev QTensor dev '[FifthPadding, 3]
-> HList '[Tensor dev QDType '[FifthPadding, 5]]
-> HList
     '[QTensor dev '[FifthPadding, 3],
       Tensor dev QDType '[FifthPadding, 5]]
forall x (xs :: [Type]). x -> HList xs -> HList (x : xs)
TT.:. Tensor dev QDType '[FifthPadding, 5]
opCombined Tensor dev QDType '[FifthPadding, 5]
-> HList '[] -> HList '[Tensor dev QDType '[FifthPadding, 5]]
forall x (xs :: [Type]). x -> HList xs -> HList (x : xs)
TT.:. HList '[]
forall k. HList '[]
TT.HNil
    opIndex' :: TT.Tensor dev TT.Int64 [batchSize, EmbSize]
    opIndex' :: Tensor dev 'Int64 '[batchSize, QSliceHidden]
opIndex' = Tensor -> Tensor dev 'Int64 '[batchSize, QSliceHidden]
forall (device :: (DeviceType, Nat)) (dtype :: DType)
       (shape :: [Nat]).
Tensor -> Tensor device dtype shape
TT.UnsafeMkTensor (Tensor -> Tensor dev 'Int64 '[batchSize, QSliceHidden])
-> Tensor -> Tensor dev 'Int64 '[batchSize, QSliceHidden]
forall a b. (a -> b) -> a -> b
$ Tensor -> Bool -> [Int] -> Tensor
T.expand (Tensor dev 'Int64 '[batchSize, 1] -> Tensor
forall t. Unnamed t => t -> Tensor
TT.toDynamic (Tensor dev 'Int64 '[batchSize, 1] -> Tensor)
-> Tensor dev 'Int64 '[batchSize, 1] -> Tensor
forall a b. (a -> b) -> a -> b
$ forall (dim :: Nat) (shape :: [Nat]) (shape' :: [Nat])
       (dtype :: DType) (device :: (DeviceType, Nat)).
(KnownNat dim, shape' ~ Unsqueeze shape dim) =>
Tensor device dtype shape -> Tensor device dtype shape'
TT.unsqueeze @1 Tensor dev 'Int64 '[batchSize]
opIndex) Bool
False [-Int
1, forall (n :: Nat). KnownNat n => Int
TT.natValI @EmbSize]
    opEmb :: QTensor dev '[batchSize, EmbSize]
    opEmb :: QTensor dev '[batchSize, QSliceHidden]
opEmb = forall (dim :: Nat) (shape :: [Nat]) (shape' :: [Nat])
       (dtype :: DType) (device :: (DeviceType, Nat)).
(KnownNat dim, shape' ~ GatherDim shape shape' dim) =>
Tensor device 'Int64 shape'
-> Tensor device dtype shape -> Tensor device dtype shape'
TT.gatherDim @0 Tensor dev 'Int64 '[batchSize, QSliceHidden]
opIndex' QTensor dev '[FifthPadding, QSliceHidden]
opEmbeddings
    opEmbReshaped :: QTensor dev '[batchSize, EmbSize, 1, 1]
    opEmbReshaped :: Tensor dev QDType '[batchSize, QSliceHidden, 1, 1]
opEmbReshaped = forall (dim :: Nat) (shape :: [Nat]) (shape' :: [Nat])
       (dtype :: DType) (device :: (DeviceType, Nat)).
(KnownNat dim, shape' ~ Unsqueeze shape dim) =>
Tensor device dtype shape -> Tensor device dtype shape'
TT.unsqueeze @3 (Tensor dev QDType '[batchSize, QSliceHidden, 1]
 -> Tensor dev QDType '[batchSize, QSliceHidden, 1, 1])
-> Tensor dev QDType '[batchSize, QSliceHidden, 1]
-> Tensor dev QDType '[batchSize, QSliceHidden, 1, 1]
forall a b. (a -> b) -> a -> b
$ forall (dim :: Nat) (shape :: [Nat]) (shape' :: [Nat])
       (dtype :: DType) (device :: (DeviceType, Nat)).
(KnownNat dim, shape' ~ Unsqueeze shape dim) =>
Tensor device dtype shape -> Tensor device dtype shape'
TT.unsqueeze @2 QTensor dev '[batchSize, QSliceHidden]
opEmb
  -- The action encoder has no stochastic layers here: the IO variant just
  -- lifts the pure 'T.forward' result into IO.
  forwardStoch encoder = pure . T.forward encoder

-- State Encoder
-- -------------

-- | Spec used to sample a fresh 'StateEncoder' (see its 'T.Randomizable'
-- instance). It carries no runtime data; all layer sizes are type-level.
data StateSpec dev = StateSpec

-- | Encoder for a complete parsing state.
--
-- The first layer has one convolution for the mid slice and separate
-- transition / slice convolutions for the frozen and the open segments, all
-- mapping 'EmbSize' to 'QStateHidden' channels. 'stL2' keeps 'QStateHidden'
-- channels and 'stL3' maps back to 'EmbSize' channels.
--
-- Fields are strict (as in 'QModel'), so an encoder rebuilt e.g. by
-- 'TT.replaceParameters' never stores its layers as unevaluated thunks,
-- which the derived 'NoThunks' instance would otherwise flag.
data StateEncoder dev = StateEncoder
  { stL1mid :: !(TT.Conv2d EmbSize QStateHidden FifthSize OctaveSize QDType dev)
  -- ^ first layer applied to the mid slice
  , stL1frozenSlc :: !(TT.Conv2d EmbSize QStateHidden FifthSize OctaveSize QDType dev)
  -- ^ first layer applied to the slices of the frozen segments
  , stL1frozenTr :: !(TT.Conv2d EmbSize QStateHidden FifthSize OctaveSize QDType dev)
  -- ^ first layer applied to the transitions of the frozen segments
  , stL1openSlc :: !(TT.Conv2d EmbSize QStateHidden FifthSize OctaveSize QDType dev)
  -- ^ first layer applied to the slices of the open segments
  , stL1openTr :: !(TT.Conv2d EmbSize QStateHidden FifthSize OctaveSize QDType dev)
  -- ^ first layer applied to the transitions of the open segments
  , stL2 :: !(TT.Conv2d QStateHidden QStateHidden FifthSize OctaveSize QDType dev)
  -- ^ second layer, applied to the combined embedding
  , stL3 :: !(TT.Conv2d QStateHidden EmbSize FifthSize OctaveSize QDType dev)
  -- ^ output layer, mapping back to 'EmbSize' channels
  }
  deriving (Show, Generic, TT.Parameterized, NoThunks, NFData)

-- | Initialise a 'StateEncoder' by sampling every convolution from the
-- default 'TT.Conv2dSpec'. Layers are sampled one after another in field
-- order; the 'StateSpec' argument carries no data and is ignored.
instance (IsValidDevice dev) => T.Randomizable (StateSpec dev) (StateEncoder dev) where
  sample _spec =
    StateEncoder
      <$> TT.sample TT.Conv2dSpec -- stL1mid
      <*> TT.sample TT.Conv2dSpec -- stL1frozenSlc
      <*> TT.sample TT.Conv2dSpec -- stL1frozenTr
      <*> TT.sample TT.Conv2dSpec -- stL1openSlc
      <*> TT.sample TT.Conv2dSpec -- stL1openTr
      <*> TT.sample TT.Conv2dSpec -- stL2
      <*> TT.sample TT.Conv2dSpec -- stL3

-- | HasForward for the parsing state (doesn't need batching).
--
-- The mid slice, the frozen segments and the open segments are each embedded
-- into 'QStateHidden' channels by their first-layer convolutions, summed, and
-- passed through 'stL2' and 'stL3'; the result has 'EmbSize' channels.
-- Segment collections are embedded along their 'FakeSize' axis, multiplied by
-- the collection's (scalar) 'QMaybe' mask, and then averaged over that axis.
instance
  forall dev outShape
   . ( IsValidDevice dev
     , outShape ~ (EmbSize : PShape)
     )
  => T.HasForward
      (StateEncoder dev)
      (SliceEncoder dev, TransitionEncoder dev, StateEncoding dev)
      (QTensor dev outShape)
  where
  forward StateEncoder{..} (slc, tr, StateEncoding mid frozen open) = out3
   where
    -- helpers: running convolutions (batched and unbatched)
    -- Stride (1, 1) with (FifthPadding, OctavePadding) padding; the types
    -- guarantee the spatial shape PShape is preserved.
    runConv'
      :: (KnownNat nin, KnownNat nout, KnownNat batch)
      => TT.Conv2d nin nout FifthSize OctaveSize QDType dev
      -> QTensor dev (batch : nin : PShape)
      -> QTensor dev (batch : nout : PShape)
    runConv' conv input = TT.conv2dForward @'(1, 1) @'(FifthPadding, OctavePadding) conv input

    -- unbatched variant: add a batch dimension of 1, convolve, remove it again
    runConv
      :: (KnownNat nin, KnownNat nout)
      => TT.Conv2d nin nout FifthSize OctaveSize QDType dev
      -> QTensor dev (nin : PShape)
      -> QTensor dev (nout : PShape)
    runConv conv input = TT.squeezeDim @0 $ runConv' conv $ TT.unsqueeze @0 input

    -- embedding segments (open and frozen)
    -- The convolutions map EmbSize to QStateHidden channels, so the result
    -- has QStateHidden channels (annotating it with EmbSize only compiled
    -- while the two sizes happened to coincide).
    embedSegments
      :: TT.Conv2d EmbSize QStateHidden FifthSize OctaveSize QDType dev
      -> TT.Conv2d EmbSize QStateHidden FifthSize OctaveSize QDType dev
      -> QMaybe dev '[] (TransitionEncoding dev '[FakeSize], QStartStop dev '[FakeSize] (SliceEncoding dev '[FakeSize]))
      -> QTensor dev (FakeSize : QStateHidden : PShape)
    embedSegments trEnc slcEnc (QMaybe mask (ft, fs)) =
      -- the scalar mask is broadcast over all segments and channels
      TT.mul (TT.reshape @[1, 1, 1, 1] mask) $ ftEmb + fsEmb
     where
      ftEmb :: QTensor dev (FakeSize : QStateHidden : PShape)
      ftEmb = activation $ runConv' trEnc $ T.forward tr ft
      fsEmb :: QTensor dev (FakeSize : QStateHidden : PShape)
      fsEmb = activation $ runConv' slcEnc $ T.forward slc fs

    -- embed frozen segments (mean over the FakeSize axis)
    frozenEmb :: QTensor dev (QStateHidden : PShape)
    frozenEmb = TT.meanDim @0 $ embedSegments stL1frozenTr stL1frozenSlc frozen
    -- embed open segments (mean over the FakeSize axis)
    openEmb :: QTensor dev (QStateHidden : PShape)
    openEmb = TT.meanDim @0 $ embedSegments stL1openTr stL1openSlc open
    -- embed the mid slice
    midEmb :: QTensor dev (QStateHidden : PShape)
    midEmb = activation $ runConv stL1mid $ T.forward slc mid

    -- combined embeddings and compute output
    fullEmb :: QTensor dev (QStateHidden : PShape)
    fullEmb = midEmb + frozenEmb + openEmb
    out2 :: QTensor dev (QStateHidden : PShape)
    out2 = activation $ runConv stL2 fullEmb
    out3 :: QTensor dev (EmbSize : PShape)
    out3 = activation $ runConv stL3 out2
  forwardStoch a i = pure $ T.forward a i

-- Full Q Model
-- ------------

-- | Data-free spec for the full Q model. The type parameter is unused at the
-- value level; presumably it mirrors the device parameter of 'QModel'
-- (NOTE(review): confirm against the 'QModel' sampling instance).
data QSpec dev
  = QSpec

data QModel dev = QModel
  { forall (dev :: (DeviceType, Nat)). QModel dev -> SliceEncoder dev
qModelSlc :: !(SliceEncoder dev)
  , forall (dev :: (DeviceType, Nat)).
QModel dev -> TransitionEncoder dev
qModelTr :: !(TransitionEncoder dev)
  , forall (dev :: (DeviceType, Nat)). QModel dev -> ActionEncoder dev
qModelAct :: !(ActionEncoder dev)
  , forall (dev :: (DeviceType, Nat)). QModel dev -> StateEncoder dev
qModelSt :: !(StateEncoder dev)
  , forall (dev :: (DeviceType, Nat)).
QModel dev
-> Conv2d QSliceHidden QSliceHidden FifthSize OctaveSize QDType dev
qModelFinal1 :: !(TT.Conv2d EmbSize QOutHidden FifthSize OctaveSize QDType dev) -- !(TT.Linear (EmbSize (QSpecGeneral DefaultQSpec)) QOutHidden QDType dev)
  , forall (dev :: (DeviceType, Nat)).
QModel dev -> LayerNorm '[QSliceHidden] QDType dev
qModelNorm1 :: !(TT.LayerNorm '[QOutHidden] QDType dev)
  , forall (dev :: (DeviceType, Nat)).
QModel dev -> Linear QSliceHidden 1 QDType dev
qModelFinal2 :: !(TT.Linear QOutHidden 1 QDType dev)
  , forall (dev :: (DeviceType, Nat)).
QModel dev -> Linear QSliceHidden QSliceHidden QDType dev
qModelValue1 :: !(TT.Linear EmbSize QOutHidden QDType dev)
  , forall (dev :: (DeviceType, Nat)).
QModel dev -> LayerNorm '[QSliceHidden] QDType dev
qModelValueNorm :: !(TT.LayerNorm '[QOutHidden] QDType dev)
  , forall (dev :: (DeviceType, Nat)).
QModel dev -> Linear QSliceHidden 1 QDType dev
qModelValue2 :: !(TT.Linear QOutHidden 1 QDType dev)
  }
  deriving (Int -> QModel dev -> ShowS
[QModel dev] -> ShowS
QModel dev -> String
(Int -> QModel dev -> ShowS)
-> (QModel dev -> String)
-> ([QModel dev] -> ShowS)
-> Show (QModel dev)
forall a.
(Int -> a -> ShowS) -> (a -> String) -> ([a] -> ShowS) -> Show a
forall (dev :: (DeviceType, Nat)). Int -> QModel dev -> ShowS
forall (dev :: (DeviceType, Nat)). [QModel dev] -> ShowS
forall (dev :: (DeviceType, Nat)). QModel dev -> String
$cshowsPrec :: forall (dev :: (DeviceType, Nat)). Int -> QModel dev -> ShowS
showsPrec :: Int -> QModel dev -> ShowS
$cshow :: forall (dev :: (DeviceType, Nat)). QModel dev -> String
show :: QModel dev -> String
$cshowList :: forall (dev :: (DeviceType, Nat)). [QModel dev] -> ShowS
showList :: [QModel dev] -> ShowS
Show, (forall x. QModel dev -> Rep (QModel dev) x)
-> (forall x. Rep (QModel dev) x -> QModel dev)
-> Generic (QModel dev)
forall x. Rep (QModel dev) x -> QModel dev
forall x. QModel dev -> Rep (QModel dev) x
forall a.
(forall x. a -> Rep a x) -> (forall x. Rep a x -> a) -> Generic a
forall (dev :: (DeviceType, Nat)) x.
Rep (QModel dev) x -> QModel dev
forall (dev :: (DeviceType, Nat)) x.
QModel dev -> Rep (QModel dev) x
$cfrom :: forall (dev :: (DeviceType, Nat)) x.
QModel dev -> Rep (QModel dev) x
from :: forall x. QModel dev -> Rep (QModel dev) x
$cto :: forall (dev :: (DeviceType, Nat)) x.
Rep (QModel dev) x -> QModel dev
to :: forall x. Rep (QModel dev) x -> QModel dev
Generic, QModel dev -> HList (Parameters (QModel dev))
QModel dev -> HList (Parameters (QModel dev)) -> QModel dev
(QModel dev -> HList (Parameters (QModel dev)))
-> (QModel dev -> HList (Parameters (QModel dev)) -> QModel dev)
-> Parameterized (QModel dev)
forall f.
(f -> HList (Parameters f))
-> (f -> HList (Parameters f) -> f) -> Parameterized f
forall (dev :: (DeviceType, Nat)).
QModel dev -> HList (Parameters (QModel dev))
forall (dev :: (DeviceType, Nat)).
QModel dev -> HList (Parameters (QModel dev)) -> QModel dev
$cflattenParameters :: forall (dev :: (DeviceType, Nat)).
QModel dev -> HList (Parameters (QModel dev))
flattenParameters :: QModel dev -> HList (Parameters (QModel dev))
$creplaceParameters :: forall (dev :: (DeviceType, Nat)).
QModel dev -> HList (Parameters (QModel dev)) -> QModel dev
replaceParameters :: QModel dev -> HList (Parameters (QModel dev)) -> QModel dev
TT.Parameterized, Context -> QModel dev -> IO (Maybe ThunkInfo)
Proxy (QModel dev) -> String
(Context -> QModel dev -> IO (Maybe ThunkInfo))
-> (Context -> QModel dev -> IO (Maybe ThunkInfo))
-> (Proxy (QModel dev) -> String)
-> NoThunks (QModel dev)
forall a.
(Context -> a -> IO (Maybe ThunkInfo))
-> (Context -> a -> IO (Maybe ThunkInfo))
-> (Proxy a -> String)
-> NoThunks a
forall (dev :: (DeviceType, Nat)).
Context -> QModel dev -> IO (Maybe ThunkInfo)
forall (dev :: (DeviceType, Nat)). Proxy (QModel dev) -> String
$cnoThunks :: forall (dev :: (DeviceType, Nat)).
Context -> QModel dev -> IO (Maybe ThunkInfo)
noThunks :: Context -> QModel dev -> IO (Maybe ThunkInfo)
$cwNoThunks :: forall (dev :: (DeviceType, Nat)).
Context -> QModel dev -> IO (Maybe ThunkInfo)
wNoThunks :: Context -> QModel dev -> IO (Maybe ThunkInfo)
$cshowTypeOf :: forall (dev :: (DeviceType, Nat)). Proxy (QModel dev) -> String
showTypeOf :: Proxy (QModel dev) -> String
NoThunks, QModel dev -> ()
(QModel dev -> ()) -> NFData (QModel dev)
forall a. (a -> ()) -> NFData a
forall (dev :: (DeviceType, Nat)). QModel dev -> ()
$crnf :: forall (dev :: (DeviceType, Nat)). QModel dev -> ()
rnf :: QModel dev -> ()
NFData)

-- | Randomly initialise a 'QModel'.
--
-- 'QSpec' carries no parameters; every sub-module is sampled from a fixed
-- spec, in the same order as the record fields. Both layer norms use
-- epsilon 1e-05.
instance (IsValidDevice dev) => T.Randomizable (QSpec dev) (QModel dev) where
  sample :: QSpec dev -> IO (QModel dev)
  sample QSpec = do
    -- shared encoders
    qModelSlc <- T.sample (SliceSpec @dev)
    qModelTr <- T.sample (TransitionSpec @dev)
    qModelAct <- T.sample (ActionSpec @dev)
    qModelSt <- T.sample (StateSpec @dev)
    -- policy head (used by forwardQModelBatched)
    qModelFinal1 <- T.sample TT.Conv2dSpec
    qModelNorm1 <- T.sample (TT.LayerNormSpec layerNormEps)
    qModelFinal2 <- T.sample TT.LinearSpec
    -- value head (used by forwardValue)
    qModelValue1 <- T.sample TT.LinearSpec
    qModelValueNorm <- T.sample (TT.LayerNormSpec layerNormEps)
    qModelValue2 <- T.sample TT.LinearSpec
    pure QModel{..}
    where
      layerNormEps = 1e-05

-- | HasForward for model (unbatched)
--
-- 'T.forward' maps a single encoding to @log (sigmoid s)@, where @s@ is the
-- raw score computed by 'forwardQModel'. 'T.forwardStoch' is deterministic:
-- it only wraps 'T.forward' in 'IO'.
--
-- NOTE(review): for large negative scores the sigmoid can round to 0 in
-- floating point, so the log yields -inf; a fused log-sigmoid would avoid
-- this — confirm whether that range is reachable in practice.
instance
  (IsValidDevice dev, TT.CheckIsSuffixOf '[QOutHidden] [1, QOutHidden] (QOutHidden == QOutHidden))
  => T.HasForward (QModel dev) (QEncoding dev '[]) (QTensor dev '[1])
  where
  forward :: QModel dev -> QEncoding dev '[] -> QTensor dev '[1]
  forward model = TT.log . TT.sigmoid . forwardQModel model

  forwardStoch :: QModel dev -> QEncoding dev '[] -> IO (QTensor dev '[1])
  forwardStoch model = pure . T.forward model

-- | HasForward for model (batched)
--
-- 'T.forward' applies @log . sigmoid@ element-wise to the
-- @[batchSize, 1]@ scores computed by 'forwardQModelBatched'.
-- 'T.forwardStoch' is deterministic: it only wraps 'T.forward' in 'IO'.
instance
  ( IsValidDevice dev
  , KnownNat batchSize
  , 1 <= batchSize
  , TT.CheckIsSuffixOf '[QOutHidden] [batchSize, QOutHidden] (QOutHidden == QOutHidden)
  )
  => T.HasForward (QModel dev) (QEncoding dev '[batchSize]) (QTensor dev '[batchSize, 1])
  where
  forward :: QModel dev -> QEncoding dev '[batchSize] -> QTensor dev '[batchSize, 1]
  forward model = TT.log . TT.sigmoid . forwardQModelBatched model

  forwardStoch
    :: QModel dev -> QEncoding dev '[batchSize] -> IO (QTensor dev '[batchSize, 1])
  forwardStoch model = pure . T.forward model

-- | Unbatched policy forward pass.
--
-- Wraps the encoding in a batch of size 1 ('addBatchDim'), runs
-- 'forwardQModelBatched', and removes the batch axis again with
-- @squeezeDim \@0@. The result is the raw score; the 'T.HasForward'
-- instances apply @log . sigmoid@ on top of it.
forwardQModel
  :: (IsValidDevice dev)
  => QModel dev
  -> QEncoding dev '[]
  -> QTensor dev '[1]
forwardQModel model = TT.squeezeDim @0 . forwardQModelBatched model . addBatchDim

-- | Batched policy forward pass, returning one raw score per batch element.
--
-- Pipeline:
--
--   1. embed every action (batched) and the state (unbatched) with the
--      shared slice/transition encoders;
--   2. add them, broadcasting the state embedding over the batch axis;
--   3. apply the policy convolution (stride (1, 1), padding
--      (FifthPadding, OctavePadding)), which keeps the 'PShape' grid;
--   4. sum over both grid axes, layer-normalise, apply 'activation';
--   5. project to a single score.
--
-- Only the encoders and the policy-head fields are read; the value-head
-- fields are not used here.
forwardQModelBatched
  :: forall dev batchSize
   . ( IsValidDevice dev
     , 1 <= batchSize
     )
  => QModel dev
  -> QEncoding dev '[batchSize]
  -> QTensor dev '[batchSize, 1]
forwardQModelBatched
  QModel
    { qModelSlc = sliceEnc
    , qModelTr = transitionEnc
    , qModelAct = actionEnc
    , qModelSt = stateEnc
    , qModelFinal1 = conv
    , qModelNorm1 = norm
    , qModelFinal2 = projection
    }
  (QEncoding actionEncodings stateEncoding) = score
 where
  -- one embedding per action in the batch
  actEmb :: QTensor dev (batchSize : EmbSize : PShape)
  actEmb = T.forward actionEnc (sliceEnc, transitionEnc, actionEncodings)

  -- a single state embedding, shared by the whole batch
  stEmb :: QTensor dev (EmbSize : PShape)
  stEmb = T.forward stateEnc (sliceEnc, transitionEnc, stateEncoding)

  -- broadcasting add: [batchSize, EmbSize, grid..] + [EmbSize, grid..]
  combined = actEmb `TT.add` stEmb

  convolved :: QTensor dev (batchSize : QOutHidden : PShape)
  convolved =
    TH.conv2dForwardRelaxed @'(1, 1) @'(FifthPadding, OctavePadding) conv combined

  -- sumDim drops the summed axis, so summing dim 2 twice removes both grid axes
  pooled :: QTensor dev '[batchSize, QOutHidden]
  pooled = TT.sumDim @2 (TT.sumDim @2 convolved)

  hidden :: QTensor dev '[batchSize, QOutHidden]
  hidden = activation (T.forward norm pooled)

  score :: QTensor dev '[batchSize, 1]
  score = T.forward projection hidden

-- | Unbatched policy head; an alias for 'forwardQModel'.
--
-- Returns the raw score, without the @log . sigmoid@ that the
-- 'T.HasForward' instances apply.
--
-- The constraint is spelled out (it is exactly what 'forwardQModel'
-- requires) instead of the extra-constraints wildcard @(_)@, so the
-- contract is visible in the signature.
forwardPolicy
  :: (IsValidDevice dev)
  => QModel dev
  -> QEncoding dev '[]
  -> QTensor dev '[1]
forwardPolicy = forwardQModel

-- | Batched policy head; an alias for 'forwardQModelBatched'.
--
-- Returns one raw score per batch element, without @log . sigmoid@.
--
-- The constraints are written out explicitly, matching
-- 'forwardQModelBatched', instead of the extra-constraints wildcard @(_)@.
forwardPolicyBatched
  :: forall dev batchSize
   . ( IsValidDevice dev
     , 1 <= batchSize
     )
  => QModel dev
  -> QEncoding dev '[batchSize]
  -> QTensor dev '[batchSize, 1]
forwardPolicyBatched = forwardQModelBatched

-- | Value head: scores a state on its own, without any action.
--
-- The state embedding from the state encoder is summed over its two grid
-- axes. The result then goes through the first value layer, the value
-- layer norm, 'activation', and the final value layer. It is mapped through
-- @log . sigmoid@, so the result is never positive.
--
-- Only the slice/transition/state encoders and the value-head fields are
-- read; the action encoder and the policy-head fields are ignored.
--
-- NOTE(review): as with the policy instances, @log . sigmoid@ can give
-- -inf for very negative pre-activations — confirm this is acceptable.
forwardValue
  :: (IsValidDevice dev)
  => QModel dev
  -> StateEncoding dev
  -> QTensor dev '[1]
forwardValue
  QModel
    { qModelSlc = sliceEnc
    , qModelTr = transitionEnc
    , qModelSt = stateEnc
    , qModelValue1 = valueHidden
    , qModelValueNorm = valueNorm
    , qModelValue2 = valueOut
    }
  stateEncoding = TT.log . TT.sigmoid $ T.forward valueOut hidden
 where
  -- the embedding is [EmbSize, grid..]; summing dim 1 twice drops both grid axes
  pooled =
    TT.sumDim @1 . TT.sumDim @1 $
      T.forward stateEnc (sliceEnc, transitionEnc, stateEncoding)
  hidden = activation . T.forward valueNorm . T.forward valueHidden $ pooled

{- | A loss for any model with 0 gradients everywhere.
Can be used to ensure that all parameters have a gradient,
if not all parameters are used in the real loss.
-}
fakeLoss
  :: forall dev ps
   . (IsValidDevice dev, ps ~ TT.Parameters (QModel dev))
  => QModel dev
  -> QTensor dev '[]
fakeLoss :: forall (dev :: (DeviceType, Nat)) (ps :: [Type]).
(IsValidDevice dev, ps ~ Parameters (QModel dev)) =>
QModel dev -> QTensor dev '[]
fakeLoss QModel dev
model = Tensor dev QDType '[]
tzero Tensor dev QDType '[]
-> Tensor dev QDType '[] -> Tensor dev QDType '[]
forall a. Num a => a -> a -> a
* Tensor dev QDType '[]
total
 where
  tzero :: QTensor dev '[]
  tzero :: Tensor dev QDType '[]
tzero = Tensor dev QDType '[]
forall (shape :: [Nat]) (dtype :: DType)
       (device :: (DeviceType, Nat)).
TensorOptions shape dtype device =>
Tensor device dtype shape
TT.zeros
  params :: HList (Parameters (QModel dev))
params = QModel dev -> HList (Parameters (QModel dev))
forall f. Parameterized f => f -> HList (Parameters f)
TT.flattenParameters QModel dev
model
  deps :: (TT.HMap' TT.ToDependent ps ys) => TT.HList ys
  deps :: forall (ys :: [Type]). HMap' ToDependent ps ys => HList ys
deps = ToDependent
-> HList
     '[Parameter dev QDType '[QSliceHidden, 1, 1, 1],
       Parameter dev QDType '[QSliceHidden],
       Parameter dev QDType '[QSliceHidden, QSliceHidden, 13, 5],
       Parameter dev QDType '[QSliceHidden],
       Parameter dev QDType '[QSliceHidden, 13, 5],
       Parameter dev QDType '[QSliceHidden, 13, 5],
       Parameter dev QDType '[QSliceHidden, 2, 13, 5],
       Parameter dev QDType '[QSliceHidden],
       Parameter dev QDType '[QSliceHidden, 2, 13, 5],
       Parameter dev QDType '[QSliceHidden],
       Parameter dev QDType '[QSliceHidden, 1, 1, 1],
       Parameter dev QDType '[QSliceHidden],
       Parameter dev QDType '[QSliceHidden, 1, 1, 1],
       Parameter dev QDType '[QSliceHidden],
       Parameter dev QDType '[QSliceHidden],
       Parameter dev QDType '[QSliceHidden, QSliceHidden, 13, 5],
       Parameter dev QDType '[QSliceHidden],
       Parameter dev QDType '[QSliceHidden, QSliceHidden, 13, 5],
       Parameter dev QDType '[QSliceHidden],
       Parameter dev QDType '[QSliceHidden, QSliceHidden, 13, 5],
       Parameter dev QDType '[QSliceHidden],
       Parameter dev QDType '[QSliceHidden, QSliceHidden, 13, 5],
       Parameter dev QDType '[QSliceHidden],
       Parameter dev QDType '[QSliceHidden, QSliceHidden, 13, 5],
       Parameter dev QDType '[QSliceHidden],
       Parameter dev QDType '[QSliceHidden, QSliceHidden, 13, 5],
       Parameter dev QDType '[QSliceHidden],
       Parameter dev QDType '[QSliceHidden, QSliceHidden, 13, 5],
       Parameter dev QDType '[QSliceHidden], Parameter dev QDType '[5],
       Parameter dev QDType '[5], Parameter dev QDType '[5],
       Parameter dev QDType '[QSliceHidden, QSliceHidden, 13, 5],
       Parameter dev QDType '[QSliceHidden],
       Parameter dev QDType '[QSliceHidden, QSliceHidden, 13, 5],
       Parameter dev QDType '[QSliceHidden],
       Parameter dev QDType '[QSliceHidden, QSliceHidden, 13, 5],
       Parameter dev QDType '[QSliceHidden],
       Parameter dev QDType '[QSliceHidden, QSliceHidden, 13, 5],
       Parameter dev QDType '[QSliceHidden],
       Parameter dev QDType '[QSliceHidden, QSliceHidden, 13, 5],
       Parameter dev QDType '[QSliceHidden],
       Parameter dev QDType '[QSliceHidden, QSliceHidden, 13, 5],
       Parameter dev QDType '[QSliceHidden],
       Parameter dev QDType '[QSliceHidden, QSliceHidden, 13, 5],
       Parameter dev QDType '[QSliceHidden],
       Parameter dev QDType '[QSliceHidden, QSliceHidden, 13, 5],
       Parameter dev QDType '[QSliceHidden],
       Parameter dev QDType '[QSliceHidden],
       Parameter dev QDType '[QSliceHidden],
       Parameter dev QDType '[1, QSliceHidden], Parameter dev QDType '[1],
       Parameter dev QDType '[QSliceHidden, QSliceHidden],
       Parameter dev QDType '[QSliceHidden],
       Parameter dev QDType '[QSliceHidden],
       Parameter dev QDType '[QSliceHidden],
       Parameter dev QDType '[1, QSliceHidden], Parameter dev QDType '[1]]
-> HList ys
forall k f (xs :: [k]) (ys :: [k]).
HMap' f xs ys =>
f -> HList xs -> HList ys
TT.hmap' ToDependent
TT.ToDependent HList
  '[Parameter dev QDType '[QSliceHidden, 1, 1, 1],
    Parameter dev QDType '[QSliceHidden],
    Parameter dev QDType '[QSliceHidden, QSliceHidden, 13, 5],
    Parameter dev QDType '[QSliceHidden],
    Parameter dev QDType '[QSliceHidden, 13, 5],
    Parameter dev QDType '[QSliceHidden, 13, 5],
    Parameter dev QDType '[QSliceHidden, 2, 13, 5],
    Parameter dev QDType '[QSliceHidden],
    Parameter dev QDType '[QSliceHidden, 2, 13, 5],
    Parameter dev QDType '[QSliceHidden],
    Parameter dev QDType '[QSliceHidden, 1, 1, 1],
    Parameter dev QDType '[QSliceHidden],
    Parameter dev QDType '[QSliceHidden, 1, 1, 1],
    Parameter dev QDType '[QSliceHidden],
    Parameter dev QDType '[QSliceHidden],
    Parameter dev QDType '[QSliceHidden, QSliceHidden, 13, 5],
    Parameter dev QDType '[QSliceHidden],
    Parameter dev QDType '[QSliceHidden, QSliceHidden, 13, 5],
    Parameter dev QDType '[QSliceHidden],
    Parameter dev QDType '[QSliceHidden, QSliceHidden, 13, 5],
    Parameter dev QDType '[QSliceHidden],
    Parameter dev QDType '[QSliceHidden, QSliceHidden, 13, 5],
    Parameter dev QDType '[QSliceHidden],
    Parameter dev QDType '[QSliceHidden, QSliceHidden, 13, 5],
    Parameter dev QDType '[QSliceHidden],
    Parameter dev QDType '[QSliceHidden, QSliceHidden, 13, 5],
    Parameter dev QDType '[QSliceHidden],
    Parameter dev QDType '[QSliceHidden, QSliceHidden, 13, 5],
    Parameter dev QDType '[QSliceHidden], Parameter dev QDType '[5],
    Parameter dev QDType '[5], Parameter dev QDType '[5],
    Parameter dev QDType '[QSliceHidden, QSliceHidden, 13, 5],
    Parameter dev QDType '[QSliceHidden],
    Parameter dev QDType '[QSliceHidden, QSliceHidden, 13, 5],
    Parameter dev QDType '[QSliceHidden],
    Parameter dev QDType '[QSliceHidden, QSliceHidden, 13, 5],
    Parameter dev QDType '[QSliceHidden],
    Parameter dev QDType '[QSliceHidden, QSliceHidden, 13, 5],
    Parameter dev QDType '[QSliceHidden],
    Parameter dev QDType '[QSliceHidden, QSliceHidden, 13, 5],
    Parameter dev QDType '[QSliceHidden],
    Parameter dev QDType '[QSliceHidden, QSliceHidden, 13, 5],
    Parameter dev QDType '[QSliceHidden],
    Parameter dev QDType '[QSliceHidden, QSliceHidden, 13, 5],
    Parameter dev QDType '[QSliceHidden],
    Parameter dev QDType '[QSliceHidden, QSliceHidden, 13, 5],
    Parameter dev QDType '[QSliceHidden],
    Parameter dev QDType '[QSliceHidden],
    Parameter dev QDType '[QSliceHidden],
    Parameter dev QDType '[1, QSliceHidden], Parameter dev QDType '[1],
    Parameter dev QDType '[QSliceHidden, QSliceHidden],
    Parameter dev QDType '[QSliceHidden],
    Parameter dev QDType '[QSliceHidden],
    Parameter dev QDType '[QSliceHidden],
    Parameter dev QDType '[1, QSliceHidden], Parameter dev QDType '[1]]
HList (Parameters (QModel dev))
params
  sums :: HList
  '[Tensor dev QDType '[], Tensor dev QDType '[],
    Tensor dev QDType '[], Tensor dev QDType '[],
    Tensor dev QDType '[], Tensor dev QDType '[],
    Tensor dev QDType '[], Tensor dev QDType '[],
    Tensor dev QDType '[], Tensor dev QDType '[],
    Tensor dev QDType '[], Tensor dev QDType '[],
    Tensor dev QDType '[], Tensor dev QDType '[],
    Tensor dev QDType '[], Tensor dev QDType '[],
    Tensor dev QDType '[], Tensor dev QDType '[],
    Tensor dev QDType '[], Tensor dev QDType '[],
    Tensor dev QDType '[], Tensor dev QDType '[],
    Tensor dev QDType '[], Tensor dev QDType '[],
    Tensor dev QDType '[], Tensor dev QDType '[],
    Tensor dev QDType '[], Tensor dev QDType '[],
    Tensor dev QDType '[], Tensor dev QDType '[],
    Tensor dev QDType '[], Tensor dev QDType '[],
    Tensor dev QDType '[], Tensor dev QDType '[],
    Tensor dev QDType '[], Tensor dev QDType '[],
    Tensor dev QDType '[], Tensor dev QDType '[],
    Tensor dev QDType '[], Tensor dev QDType '[],
    Tensor dev QDType '[], Tensor dev QDType '[],
    Tensor dev QDType '[], Tensor dev QDType '[],
    Tensor dev QDType '[], Tensor dev QDType '[],
    Tensor dev QDType '[], Tensor dev QDType '[],
    Tensor dev QDType '[], Tensor dev QDType '[],
    Tensor dev QDType '[], Tensor dev QDType '[],
    Tensor dev QDType '[], Tensor dev QDType '[],
    Tensor dev QDType '[], Tensor dev QDType '[],
    Tensor dev QDType '[], Tensor dev QDType '[]]
sums = SumAll
-> HList
     '[Tensor dev QDType '[QSliceHidden, 1, 1, 1],
       Tensor dev QDType '[QSliceHidden],
       Tensor dev QDType '[QSliceHidden, QSliceHidden, 13, 5],
       Tensor dev QDType '[QSliceHidden],
       Tensor dev QDType '[QSliceHidden, 13, 5],
       Tensor dev QDType '[QSliceHidden, 13, 5],
       Tensor dev QDType '[QSliceHidden, 2, 13, 5],
       Tensor dev QDType '[QSliceHidden],
       Tensor dev QDType '[QSliceHidden, 2, 13, 5],
       Tensor dev QDType '[QSliceHidden],
       Tensor dev QDType '[QSliceHidden, 1, 1, 1],
       Tensor dev QDType '[QSliceHidden],
       Tensor dev QDType '[QSliceHidden, 1, 1, 1],
       Tensor dev QDType '[QSliceHidden],
       Tensor dev QDType '[QSliceHidden],
       Tensor dev QDType '[QSliceHidden, QSliceHidden, 13, 5],
       Tensor dev QDType '[QSliceHidden],
       Tensor dev QDType '[QSliceHidden, QSliceHidden, 13, 5],
       Tensor dev QDType '[QSliceHidden],
       Tensor dev QDType '[QSliceHidden, QSliceHidden, 13, 5],
       Tensor dev QDType '[QSliceHidden],
       Tensor dev QDType '[QSliceHidden, QSliceHidden, 13, 5],
       Tensor dev QDType '[QSliceHidden],
       Tensor dev QDType '[QSliceHidden, QSliceHidden, 13, 5],
       Tensor dev QDType '[QSliceHidden],
       Tensor dev QDType '[QSliceHidden, QSliceHidden, 13, 5],
       Tensor dev QDType '[QSliceHidden],
       Tensor dev QDType '[QSliceHidden, QSliceHidden, 13, 5],
       Tensor dev QDType '[QSliceHidden], Tensor dev QDType '[5],
       Tensor dev QDType '[5], Tensor dev QDType '[5],
       Tensor dev QDType '[QSliceHidden, QSliceHidden, 13, 5],
       Tensor dev QDType '[QSliceHidden],
       Tensor dev QDType '[QSliceHidden, QSliceHidden, 13, 5],
       Tensor dev QDType '[QSliceHidden],
       Tensor dev QDType '[QSliceHidden, QSliceHidden, 13, 5],
       Tensor dev QDType '[QSliceHidden],
       Tensor dev QDType '[QSliceHidden, QSliceHidden, 13, 5],
       Tensor dev QDType '[QSliceHidden],
       Tensor dev QDType '[QSliceHidden, QSliceHidden, 13, 5],
       Tensor dev QDType '[QSliceHidden],
       Tensor dev QDType '[QSliceHidden, QSliceHidden, 13, 5],
       Tensor dev QDType '[QSliceHidden],
       Tensor dev QDType '[QSliceHidden, QSliceHidden, 13, 5],
       Tensor dev QDType '[QSliceHidden],
       Tensor dev QDType '[QSliceHidden, QSliceHidden, 13, 5],
       Tensor dev QDType '[QSliceHidden],
       Tensor dev QDType '[QSliceHidden],
       Tensor dev QDType '[QSliceHidden],
       Tensor dev QDType '[1, QSliceHidden], Tensor dev QDType '[1],
       Tensor dev QDType '[QSliceHidden, QSliceHidden],
       Tensor dev QDType '[QSliceHidden],
       Tensor dev QDType '[QSliceHidden],
       Tensor dev QDType '[QSliceHidden],
       Tensor dev QDType '[1, QSliceHidden], Tensor dev QDType '[1]]
-> HList
     '[Tensor dev QDType '[], Tensor dev QDType '[],
       Tensor dev QDType '[], Tensor dev QDType '[],
       Tensor dev QDType '[], Tensor dev QDType '[],
       Tensor dev QDType '[], Tensor dev QDType '[],
       Tensor dev QDType '[], Tensor dev QDType '[],
       Tensor dev QDType '[], Tensor dev QDType '[],
       Tensor dev QDType '[], Tensor dev QDType '[],
       Tensor dev QDType '[], Tensor dev QDType '[],
       Tensor dev QDType '[], Tensor dev QDType '[],
       Tensor dev QDType '[], Tensor dev QDType '[],
       Tensor dev QDType '[], Tensor dev QDType '[],
       Tensor dev QDType '[], Tensor dev QDType '[],
       Tensor dev QDType '[], Tensor dev QDType '[],
       Tensor dev QDType '[], Tensor dev QDType '[],
       Tensor dev QDType '[], Tensor dev QDType '[],
       Tensor dev QDType '[], Tensor dev QDType '[],
       Tensor dev QDType '[], Tensor dev QDType '[],
       Tensor dev QDType '[], Tensor dev QDType '[],
       Tensor dev QDType '[], Tensor dev QDType '[],
       Tensor dev QDType '[], Tensor dev QDType '[],
       Tensor dev QDType '[], Tensor dev QDType '[],
       Tensor dev QDType '[], Tensor dev QDType '[],
       Tensor dev QDType '[], Tensor dev QDType '[],
       Tensor dev QDType '[], Tensor dev QDType '[],
       Tensor dev QDType '[], Tensor dev QDType '[],
       Tensor dev QDType '[], Tensor dev QDType '[],
       Tensor dev QDType '[], Tensor dev QDType '[],
       Tensor dev QDType '[], Tensor dev QDType '[],
       Tensor dev QDType '[], Tensor dev QDType '[]]
forall k f (xs :: [k]) (ys :: [k]).
HMap' f xs ys =>
f -> HList xs -> HList ys
TT.hmap' SumAll
TH.SumAll HList
  '[Tensor dev QDType '[QSliceHidden, 1, 1, 1],
    Tensor dev QDType '[QSliceHidden],
    Tensor dev QDType '[QSliceHidden, QSliceHidden, 13, 5],
    Tensor dev QDType '[QSliceHidden],
    Tensor dev QDType '[QSliceHidden, 13, 5],
    Tensor dev QDType '[QSliceHidden, 13, 5],
    Tensor dev QDType '[QSliceHidden, 2, 13, 5],
    Tensor dev QDType '[QSliceHidden],
    Tensor dev QDType '[QSliceHidden, 2, 13, 5],
    Tensor dev QDType '[QSliceHidden],
    Tensor dev QDType '[QSliceHidden, 1, 1, 1],
    Tensor dev QDType '[QSliceHidden],
    Tensor dev QDType '[QSliceHidden, 1, 1, 1],
    Tensor dev QDType '[QSliceHidden],
    Tensor dev QDType '[QSliceHidden],
    Tensor dev QDType '[QSliceHidden, QSliceHidden, 13, 5],
    Tensor dev QDType '[QSliceHidden],
    Tensor dev QDType '[QSliceHidden, QSliceHidden, 13, 5],
    Tensor dev QDType '[QSliceHidden],
    Tensor dev QDType '[QSliceHidden, QSliceHidden, 13, 5],
    Tensor dev QDType '[QSliceHidden],
    Tensor dev QDType '[QSliceHidden, QSliceHidden, 13, 5],
    Tensor dev QDType '[QSliceHidden],
    Tensor dev QDType '[QSliceHidden, QSliceHidden, 13, 5],
    Tensor dev QDType '[QSliceHidden],
    Tensor dev QDType '[QSliceHidden, QSliceHidden, 13, 5],
    Tensor dev QDType '[QSliceHidden],
    Tensor dev QDType '[QSliceHidden, QSliceHidden, 13, 5],
    Tensor dev QDType '[QSliceHidden], Tensor dev QDType '[5],
    Tensor dev QDType '[5], Tensor dev QDType '[5],
    Tensor dev QDType '[QSliceHidden, QSliceHidden, 13, 5],
    Tensor dev QDType '[QSliceHidden],
    Tensor dev QDType '[QSliceHidden, QSliceHidden, 13, 5],
    Tensor dev QDType '[QSliceHidden],
    Tensor dev QDType '[QSliceHidden, QSliceHidden, 13, 5],
    Tensor dev QDType '[QSliceHidden],
    Tensor dev QDType '[QSliceHidden, QSliceHidden, 13, 5],
    Tensor dev QDType '[QSliceHidden],
    Tensor dev QDType '[QSliceHidden, QSliceHidden, 13, 5],
    Tensor dev QDType '[QSliceHidden],
    Tensor dev QDType '[QSliceHidden, QSliceHidden, 13, 5],
    Tensor dev QDType '[QSliceHidden],
    Tensor dev QDType '[QSliceHidden, QSliceHidden, 13, 5],
    Tensor dev QDType '[QSliceHidden],
    Tensor dev QDType '[QSliceHidden, QSliceHidden, 13, 5],
    Tensor dev QDType '[QSliceHidden],
    Tensor dev QDType '[QSliceHidden],
    Tensor dev QDType '[QSliceHidden],
    Tensor dev QDType '[1, QSliceHidden], Tensor dev QDType '[1],
    Tensor dev QDType '[QSliceHidden, QSliceHidden],
    Tensor dev QDType '[QSliceHidden],
    Tensor dev QDType '[QSliceHidden],
    Tensor dev QDType '[QSliceHidden],
    Tensor dev QDType '[1, QSliceHidden], Tensor dev QDType '[1]]
forall (ys :: [Type]). HMap' ToDependent ps ys => HList ys
deps
  -- total
  total :: Tensor dev QDType '[]
total = Add
-> Tensor dev QDType '[]
-> HList
     '[Tensor dev QDType '[], Tensor dev QDType '[],
       Tensor dev QDType '[], Tensor dev QDType '[],
       Tensor dev QDType '[], Tensor dev QDType '[],
       Tensor dev QDType '[], Tensor dev QDType '[],
       Tensor dev QDType '[], Tensor dev QDType '[],
       Tensor dev QDType '[], Tensor dev QDType '[],
       Tensor dev QDType '[], Tensor dev QDType '[],
       Tensor dev QDType '[], Tensor dev QDType '[],
       Tensor dev QDType '[], Tensor dev QDType '[],
       Tensor dev QDType '[], Tensor dev QDType '[],
       Tensor dev QDType '[], Tensor dev QDType '[],
       Tensor dev QDType '[], Tensor dev QDType '[],
       Tensor dev QDType '[], Tensor dev QDType '[],
       Tensor dev QDType '[], Tensor dev QDType '[],
       Tensor dev QDType '[], Tensor dev QDType '[],
       Tensor dev QDType '[], Tensor dev QDType '[],
       Tensor dev QDType '[], Tensor dev QDType '[],
       Tensor dev QDType '[], Tensor dev QDType '[],
       Tensor dev QDType '[], Tensor dev QDType '[],
       Tensor dev QDType '[], Tensor dev QDType '[],
       Tensor dev QDType '[], Tensor dev QDType '[],
       Tensor dev QDType '[], Tensor dev QDType '[],
       Tensor dev QDType '[], Tensor dev QDType '[],
       Tensor dev QDType '[], Tensor dev QDType '[],
       Tensor dev QDType '[], Tensor dev QDType '[],
       Tensor dev QDType '[], Tensor dev QDType '[],
       Tensor dev QDType '[], Tensor dev QDType '[],
       Tensor dev QDType '[], Tensor dev QDType '[],
       Tensor dev QDType '[], Tensor dev QDType '[]]
-> Tensor dev QDType '[]
forall {k} f acc (xs :: [k]) res.
HFoldr f acc xs res =>
f -> acc -> HList xs -> res
TT.hfoldr Add
TH.Add Tensor dev QDType '[]
tzero HList
  '[Tensor dev QDType '[], Tensor dev QDType '[],
    Tensor dev QDType '[], Tensor dev QDType '[],
    Tensor dev QDType '[], Tensor dev QDType '[],
    Tensor dev QDType '[], Tensor dev QDType '[],
    Tensor dev QDType '[], Tensor dev QDType '[],
    Tensor dev QDType '[], Tensor dev QDType '[],
    Tensor dev QDType '[], Tensor dev QDType '[],
    Tensor dev QDType '[], Tensor dev QDType '[],
    Tensor dev QDType '[], Tensor dev QDType '[],
    Tensor dev QDType '[], Tensor dev QDType '[],
    Tensor dev QDType '[], Tensor dev QDType '[],
    Tensor dev QDType '[], Tensor dev QDType '[],
    Tensor dev QDType '[], Tensor dev QDType '[],
    Tensor dev QDType '[], Tensor dev QDType '[],
    Tensor dev QDType '[], Tensor dev QDType '[],
    Tensor dev QDType '[], Tensor dev QDType '[],
    Tensor dev QDType '[], Tensor dev QDType '[],
    Tensor dev QDType '[], Tensor dev QDType '[],
    Tensor dev QDType '[], Tensor dev QDType '[],
    Tensor dev QDType '[], Tensor dev QDType '[],
    Tensor dev QDType '[], Tensor dev QDType '[],
    Tensor dev QDType '[], Tensor dev QDType '[],
    Tensor dev QDType '[], Tensor dev QDType '[],
    Tensor dev QDType '[], Tensor dev QDType '[],
    Tensor dev QDType '[], Tensor dev QDType '[],
    Tensor dev QDType '[], Tensor dev QDType '[],
    Tensor dev QDType '[], Tensor dev QDType '[],
    Tensor dev QDType '[], Tensor dev QDType '[],
    Tensor dev QDType '[], Tensor dev QDType '[]]
sums

mkQModel :: forall dev. (IsValidDevice dev) => IO (QModel dev)
mkQModel :: forall (dev :: (DeviceType, Nat)).
IsValidDevice dev =>
IO (QModel dev)
mkQModel = QSpec dev -> IO (QModel dev)
forall spec f. Randomizable spec f => spec -> IO f
T.sample (QSpec dev -> IO (QModel dev)) -> QSpec dev -> IO (QModel dev)
forall a b. (a -> b) -> a -> b
$ forall {k} (dev :: k). QSpec dev
forall (dev :: (DeviceType, Nat)). QSpec dev
QSpec @dev

loadModel :: forall dev. (IsValidDevice dev) => FilePath -> IO (QModel dev)
loadModel :: forall (dev :: (DeviceType, Nat)).
IsValidDevice dev =>
String -> IO (QModel dev)
loadModel String
path = do
  modelPlaceholder <- forall (dev :: (DeviceType, Nat)).
IsValidDevice dev =>
IO (QModel dev)
mkQModel @dev
  tensors
    :: (TT.HMap' TT.ToDependent (TT.Parameters (QModel dev)) ts)
    => TT.HList ts <-
    TT.load path
  -- TT.load doesn't move the parameters to the correct device, so we move them manually
  let tensorsCPU = forall (device' :: (DeviceType, Nat)) (device :: (DeviceType, Nat))
       f g.
HasToDevice device' device f g =>
f -> g
TT.toDevice @'(TT.CPU, 0) @dev HList
  '[Tensor dev QDType '[QSliceHidden, 1, 1, 1],
    Tensor dev QDType '[QSliceHidden],
    Tensor dev QDType '[QSliceHidden, QSliceHidden, 13, 5],
    Tensor dev QDType '[QSliceHidden],
    Tensor dev QDType '[QSliceHidden, 13, 5],
    Tensor dev QDType '[QSliceHidden, 13, 5],
    Tensor dev QDType '[QSliceHidden, 2, 13, 5],
    Tensor dev QDType '[QSliceHidden],
    Tensor dev QDType '[QSliceHidden, 2, 13, 5],
    Tensor dev QDType '[QSliceHidden],
    Tensor dev QDType '[QSliceHidden, 1, 1, 1],
    Tensor dev QDType '[QSliceHidden],
    Tensor dev QDType '[QSliceHidden, 1, 1, 1],
    Tensor dev QDType '[QSliceHidden],
    Tensor dev QDType '[QSliceHidden],
    Tensor dev QDType '[QSliceHidden, QSliceHidden, 13, 5],
    Tensor dev QDType '[QSliceHidden],
    Tensor dev QDType '[QSliceHidden, QSliceHidden, 13, 5],
    Tensor dev QDType '[QSliceHidden],
    Tensor dev QDType '[QSliceHidden, QSliceHidden, 13, 5],
    Tensor dev QDType '[QSliceHidden],
    Tensor dev QDType '[QSliceHidden, QSliceHidden, 13, 5],
    Tensor dev QDType '[QSliceHidden],
    Tensor dev QDType '[QSliceHidden, QSliceHidden, 13, 5],
    Tensor dev QDType '[QSliceHidden],
    Tensor dev QDType '[QSliceHidden, QSliceHidden, 13, 5],
    Tensor dev QDType '[QSliceHidden],
    Tensor dev QDType '[QSliceHidden, QSliceHidden, 13, 5],
    Tensor dev QDType '[QSliceHidden], Tensor dev QDType '[5],
    Tensor dev QDType '[5], Tensor dev QDType '[5],
    Tensor dev QDType '[QSliceHidden, QSliceHidden, 13, 5],
    Tensor dev QDType '[QSliceHidden],
    Tensor dev QDType '[QSliceHidden, QSliceHidden, 13, 5],
    Tensor dev QDType '[QSliceHidden],
    Tensor dev QDType '[QSliceHidden, QSliceHidden, 13, 5],
    Tensor dev QDType '[QSliceHidden],
    Tensor dev QDType '[QSliceHidden, QSliceHidden, 13, 5],
    Tensor dev QDType '[QSliceHidden],
    Tensor dev QDType '[QSliceHidden, QSliceHidden, 13, 5],
    Tensor dev QDType '[QSliceHidden],
    Tensor dev QDType '[QSliceHidden, QSliceHidden, 13, 5],
    Tensor dev QDType '[QSliceHidden],
    Tensor dev QDType '[QSliceHidden, QSliceHidden, 13, 5],
    Tensor dev QDType '[QSliceHidden],
    Tensor dev QDType '[QSliceHidden, QSliceHidden, 13, 5],
    Tensor dev QDType '[QSliceHidden],
    Tensor dev QDType '[QSliceHidden],
    Tensor dev QDType '[QSliceHidden],
    Tensor dev QDType '[1, QSliceHidden], Tensor dev QDType '[1],
    Tensor dev QDType '[QSliceHidden, QSliceHidden],
    Tensor dev QDType '[QSliceHidden],
    Tensor dev QDType '[QSliceHidden],
    Tensor dev QDType '[QSliceHidden],
    Tensor dev QDType '[1, QSliceHidden], Tensor dev QDType '[1]]
HMap'
  ToDependent
  (Parameters (QModel dev))
  '[Tensor dev QDType '[QSliceHidden, 1, 1, 1],
    Tensor dev QDType '[QSliceHidden],
    Tensor dev QDType '[QSliceHidden, QSliceHidden, 13, 5],
    Tensor dev QDType '[QSliceHidden],
    Tensor dev QDType '[QSliceHidden, 13, 5],
    Tensor dev QDType '[QSliceHidden, 13, 5],
    Tensor dev QDType '[QSliceHidden, 2, 13, 5],
    Tensor dev QDType '[QSliceHidden],
    Tensor dev QDType '[QSliceHidden, 2, 13, 5],
    Tensor dev QDType '[QSliceHidden],
    Tensor dev QDType '[QSliceHidden, 1, 1, 1],
    Tensor dev QDType '[QSliceHidden],
    Tensor dev QDType '[QSliceHidden, 1, 1, 1],
    Tensor dev QDType '[QSliceHidden],
    Tensor dev QDType '[QSliceHidden],
    Tensor dev QDType '[QSliceHidden, QSliceHidden, 13, 5],
    Tensor dev QDType '[QSliceHidden],
    Tensor dev QDType '[QSliceHidden, QSliceHidden, 13, 5],
    Tensor dev QDType '[QSliceHidden],
    Tensor dev QDType '[QSliceHidden, QSliceHidden, 13, 5],
    Tensor dev QDType '[QSliceHidden],
    Tensor dev QDType '[QSliceHidden, QSliceHidden, 13, 5],
    Tensor dev QDType '[QSliceHidden],
    Tensor dev QDType '[QSliceHidden, QSliceHidden, 13, 5],
    Tensor dev QDType '[QSliceHidden],
    Tensor dev QDType '[QSliceHidden, QSliceHidden, 13, 5],
    Tensor dev QDType '[QSliceHidden],
    Tensor dev QDType '[QSliceHidden, QSliceHidden, 13, 5],
    Tensor dev QDType '[QSliceHidden], Tensor dev QDType '[5],
    Tensor dev QDType '[5], Tensor dev QDType '[5],
    Tensor dev QDType '[QSliceHidden, QSliceHidden, 13, 5],
    Tensor dev QDType '[QSliceHidden],
    Tensor dev QDType '[QSliceHidden, QSliceHidden, 13, 5],
    Tensor dev QDType '[QSliceHidden],
    Tensor dev QDType '[QSliceHidden, QSliceHidden, 13, 5],
    Tensor dev QDType '[QSliceHidden],
    Tensor dev QDType '[QSliceHidden, QSliceHidden, 13, 5],
    Tensor dev QDType '[QSliceHidden],
    Tensor dev QDType '[QSliceHidden, QSliceHidden, 13, 5],
    Tensor dev QDType '[QSliceHidden],
    Tensor dev QDType '[QSliceHidden, QSliceHidden, 13, 5],
    Tensor dev QDType '[QSliceHidden],
    Tensor dev QDType '[QSliceHidden, QSliceHidden, 13, 5],
    Tensor dev QDType '[QSliceHidden],
    Tensor dev QDType '[QSliceHidden, QSliceHidden, 13, 5],
    Tensor dev QDType '[QSliceHidden],
    Tensor dev QDType '[QSliceHidden],
    Tensor dev QDType '[QSliceHidden],
    Tensor dev QDType '[1, QSliceHidden], Tensor dev QDType '[1],
    Tensor dev QDType '[QSliceHidden, QSliceHidden],
    Tensor dev QDType '[QSliceHidden],
    Tensor dev QDType '[QSliceHidden],
    Tensor dev QDType '[QSliceHidden],
    Tensor dev QDType '[1, QSliceHidden], Tensor dev QDType '[1]] =>
HList
  '[Tensor dev QDType '[QSliceHidden, 1, 1, 1],
    Tensor dev QDType '[QSliceHidden],
    Tensor dev QDType '[QSliceHidden, QSliceHidden, 13, 5],
    Tensor dev QDType '[QSliceHidden],
    Tensor dev QDType '[QSliceHidden, 13, 5],
    Tensor dev QDType '[QSliceHidden, 13, 5],
    Tensor dev QDType '[QSliceHidden, 2, 13, 5],
    Tensor dev QDType '[QSliceHidden],
    Tensor dev QDType '[QSliceHidden, 2, 13, 5],
    Tensor dev QDType '[QSliceHidden],
    Tensor dev QDType '[QSliceHidden, 1, 1, 1],
    Tensor dev QDType '[QSliceHidden],
    Tensor dev QDType '[QSliceHidden, 1, 1, 1],
    Tensor dev QDType '[QSliceHidden],
    Tensor dev QDType '[QSliceHidden],
    Tensor dev QDType '[QSliceHidden, QSliceHidden, 13, 5],
    Tensor dev QDType '[QSliceHidden],
    Tensor dev QDType '[QSliceHidden, QSliceHidden, 13, 5],
    Tensor dev QDType '[QSliceHidden],
    Tensor dev QDType '[QSliceHidden, QSliceHidden, 13, 5],
    Tensor dev QDType '[QSliceHidden],
    Tensor dev QDType '[QSliceHidden, QSliceHidden, 13, 5],
    Tensor dev QDType '[QSliceHidden],
    Tensor dev QDType '[QSliceHidden, QSliceHidden, 13, 5],
    Tensor dev QDType '[QSliceHidden],
    Tensor dev QDType '[QSliceHidden, QSliceHidden, 13, 5],
    Tensor dev QDType '[QSliceHidden],
    Tensor dev QDType '[QSliceHidden, QSliceHidden, 13, 5],
    Tensor dev QDType '[QSliceHidden], Tensor dev QDType '[5],
    Tensor dev QDType '[5], Tensor dev QDType '[5],
    Tensor dev QDType '[QSliceHidden, QSliceHidden, 13, 5],
    Tensor dev QDType '[QSliceHidden],
    Tensor dev QDType '[QSliceHidden, QSliceHidden, 13, 5],
    Tensor dev QDType '[QSliceHidden],
    Tensor dev QDType '[QSliceHidden, QSliceHidden, 13, 5],
    Tensor dev QDType '[QSliceHidden],
    Tensor dev QDType '[QSliceHidden, QSliceHidden, 13, 5],
    Tensor dev QDType '[QSliceHidden],
    Tensor dev QDType '[QSliceHidden, QSliceHidden, 13, 5],
    Tensor dev QDType '[QSliceHidden],
    Tensor dev QDType '[QSliceHidden, QSliceHidden, 13, 5],
    Tensor dev QDType '[QSliceHidden],
    Tensor dev QDType '[QSliceHidden, QSliceHidden, 13, 5],
    Tensor dev QDType '[QSliceHidden],
    Tensor dev QDType '[QSliceHidden, QSliceHidden, 13, 5],
    Tensor dev QDType '[QSliceHidden],
    Tensor dev QDType '[QSliceHidden],
    Tensor dev QDType '[QSliceHidden],
    Tensor dev QDType '[1, QSliceHidden], Tensor dev QDType '[1],
    Tensor dev QDType '[QSliceHidden, QSliceHidden],
    Tensor dev QDType '[QSliceHidden],
    Tensor dev QDType '[QSliceHidden],
    Tensor dev QDType '[QSliceHidden],
    Tensor dev QDType '[1, QSliceHidden], Tensor dev QDType '[1]]
tensors
  let tensorsDevice = forall (device' :: (DeviceType, Nat)) (device :: (DeviceType, Nat))
       f g.
HasToDevice device' device f g =>
f -> g
TT.toDevice @dev @'(TT.CPU, 0) HList
  '[Tensor '( 'CPU, 0) QDType '[QSliceHidden, 1, 1, 1],
    Tensor '( 'CPU, 0) QDType '[QSliceHidden],
    Tensor '( 'CPU, 0) QDType '[QSliceHidden, QSliceHidden, 13, 5],
    Tensor '( 'CPU, 0) QDType '[QSliceHidden],
    Tensor '( 'CPU, 0) QDType '[QSliceHidden, 13, 5],
    Tensor '( 'CPU, 0) QDType '[QSliceHidden, 13, 5],
    Tensor '( 'CPU, 0) QDType '[QSliceHidden, 2, 13, 5],
    Tensor '( 'CPU, 0) QDType '[QSliceHidden],
    Tensor '( 'CPU, 0) QDType '[QSliceHidden, 2, 13, 5],
    Tensor '( 'CPU, 0) QDType '[QSliceHidden],
    Tensor '( 'CPU, 0) QDType '[QSliceHidden, 1, 1, 1],
    Tensor '( 'CPU, 0) QDType '[QSliceHidden],
    Tensor '( 'CPU, 0) QDType '[QSliceHidden, 1, 1, 1],
    Tensor '( 'CPU, 0) QDType '[QSliceHidden],
    Tensor '( 'CPU, 0) QDType '[QSliceHidden],
    Tensor '( 'CPU, 0) QDType '[QSliceHidden, QSliceHidden, 13, 5],
    Tensor '( 'CPU, 0) QDType '[QSliceHidden],
    Tensor '( 'CPU, 0) QDType '[QSliceHidden, QSliceHidden, 13, 5],
    Tensor '( 'CPU, 0) QDType '[QSliceHidden],
    Tensor '( 'CPU, 0) QDType '[QSliceHidden, QSliceHidden, 13, 5],
    Tensor '( 'CPU, 0) QDType '[QSliceHidden],
    Tensor '( 'CPU, 0) QDType '[QSliceHidden, QSliceHidden, 13, 5],
    Tensor '( 'CPU, 0) QDType '[QSliceHidden],
    Tensor '( 'CPU, 0) QDType '[QSliceHidden, QSliceHidden, 13, 5],
    Tensor '( 'CPU, 0) QDType '[QSliceHidden],
    Tensor '( 'CPU, 0) QDType '[QSliceHidden, QSliceHidden, 13, 5],
    Tensor '( 'CPU, 0) QDType '[QSliceHidden],
    Tensor '( 'CPU, 0) QDType '[QSliceHidden, QSliceHidden, 13, 5],
    Tensor '( 'CPU, 0) QDType '[QSliceHidden],
    Tensor '( 'CPU, 0) QDType '[5], Tensor '( 'CPU, 0) QDType '[5],
    Tensor '( 'CPU, 0) QDType '[5],
    Tensor '( 'CPU, 0) QDType '[QSliceHidden, QSliceHidden, 13, 5],
    Tensor '( 'CPU, 0) QDType '[QSliceHidden],
    Tensor '( 'CPU, 0) QDType '[QSliceHidden, QSliceHidden, 13, 5],
    Tensor '( 'CPU, 0) QDType '[QSliceHidden],
    Tensor '( 'CPU, 0) QDType '[QSliceHidden, QSliceHidden, 13, 5],
    Tensor '( 'CPU, 0) QDType '[QSliceHidden],
    Tensor '( 'CPU, 0) QDType '[QSliceHidden, QSliceHidden, 13, 5],
    Tensor '( 'CPU, 0) QDType '[QSliceHidden],
    Tensor '( 'CPU, 0) QDType '[QSliceHidden, QSliceHidden, 13, 5],
    Tensor '( 'CPU, 0) QDType '[QSliceHidden],
    Tensor '( 'CPU, 0) QDType '[QSliceHidden, QSliceHidden, 13, 5],
    Tensor '( 'CPU, 0) QDType '[QSliceHidden],
    Tensor '( 'CPU, 0) QDType '[QSliceHidden, QSliceHidden, 13, 5],
    Tensor '( 'CPU, 0) QDType '[QSliceHidden],
    Tensor '( 'CPU, 0) QDType '[QSliceHidden, QSliceHidden, 13, 5],
    Tensor '( 'CPU, 0) QDType '[QSliceHidden],
    Tensor '( 'CPU, 0) QDType '[QSliceHidden],
    Tensor '( 'CPU, 0) QDType '[QSliceHidden],
    Tensor '( 'CPU, 0) QDType '[1, QSliceHidden],
    Tensor '( 'CPU, 0) QDType '[1],
    Tensor '( 'CPU, 0) QDType '[QSliceHidden, QSliceHidden],
    Tensor '( 'CPU, 0) QDType '[QSliceHidden],
    Tensor '( 'CPU, 0) QDType '[QSliceHidden],
    Tensor '( 'CPU, 0) QDType '[QSliceHidden],
    Tensor '( 'CPU, 0) QDType '[1, QSliceHidden],
    Tensor '( 'CPU, 0) QDType '[1]]
tensorsCPU
  params <- TT.hmapM' TT.MakeIndependent tensorsDevice
  pure $ TT.replaceParameters modelPlaceholder params

saveModel :: FilePath -> QModel dev -> IO ()
saveModel :: forall (dev :: (DeviceType, Nat)). String -> QModel dev -> IO ()
saveModel String
path QModel dev
model = HList
  '[Tensor dev QDType '[QSliceHidden, 1, 1, 1],
    Tensor dev QDType '[QSliceHidden],
    Tensor dev QDType '[QSliceHidden, QSliceHidden, 13, 5],
    Tensor dev QDType '[QSliceHidden],
    Tensor dev QDType '[QSliceHidden, 13, 5],
    Tensor dev QDType '[QSliceHidden, 13, 5],
    Tensor dev QDType '[QSliceHidden, 2, 13, 5],
    Tensor dev QDType '[QSliceHidden],
    Tensor dev QDType '[QSliceHidden, 2, 13, 5],
    Tensor dev QDType '[QSliceHidden],
    Tensor dev QDType '[QSliceHidden, 1, 1, 1],
    Tensor dev QDType '[QSliceHidden],
    Tensor dev QDType '[QSliceHidden, 1, 1, 1],
    Tensor dev QDType '[QSliceHidden],
    Tensor dev QDType '[QSliceHidden],
    Tensor dev QDType '[QSliceHidden, QSliceHidden, 13, 5],
    Tensor dev QDType '[QSliceHidden],
    Tensor dev QDType '[QSliceHidden, QSliceHidden, 13, 5],
    Tensor dev QDType '[QSliceHidden],
    Tensor dev QDType '[QSliceHidden, QSliceHidden, 13, 5],
    Tensor dev QDType '[QSliceHidden],
    Tensor dev QDType '[QSliceHidden, QSliceHidden, 13, 5],
    Tensor dev QDType '[QSliceHidden],
    Tensor dev QDType '[QSliceHidden, QSliceHidden, 13, 5],
    Tensor dev QDType '[QSliceHidden],
    Tensor dev QDType '[QSliceHidden, QSliceHidden, 13, 5],
    Tensor dev QDType '[QSliceHidden],
    Tensor dev QDType '[QSliceHidden, QSliceHidden, 13, 5],
    Tensor dev QDType '[QSliceHidden], Tensor dev QDType '[5],
    Tensor dev QDType '[5], Tensor dev QDType '[5],
    Tensor dev QDType '[QSliceHidden, QSliceHidden, 13, 5],
    Tensor dev QDType '[QSliceHidden],
    Tensor dev QDType '[QSliceHidden, QSliceHidden, 13, 5],
    Tensor dev QDType '[QSliceHidden],
    Tensor dev QDType '[QSliceHidden, QSliceHidden, 13, 5],
    Tensor dev QDType '[QSliceHidden],
    Tensor dev QDType '[QSliceHidden, QSliceHidden, 13, 5],
    Tensor dev QDType '[QSliceHidden],
    Tensor dev QDType '[QSliceHidden, QSliceHidden, 13, 5],
    Tensor dev QDType '[QSliceHidden],
    Tensor dev QDType '[QSliceHidden, QSliceHidden, 13, 5],
    Tensor dev QDType '[QSliceHidden],
    Tensor dev QDType '[QSliceHidden, QSliceHidden, 13, 5],
    Tensor dev QDType '[QSliceHidden],
    Tensor dev QDType '[QSliceHidden, QSliceHidden, 13, 5],
    Tensor dev QDType '[QSliceHidden],
    Tensor dev QDType '[QSliceHidden],
    Tensor dev QDType '[QSliceHidden],
    Tensor dev QDType '[1, QSliceHidden], Tensor dev QDType '[1],
    Tensor dev QDType '[QSliceHidden, QSliceHidden],
    Tensor dev QDType '[QSliceHidden],
    Tensor dev QDType '[QSliceHidden],
    Tensor dev QDType '[QSliceHidden],
    Tensor dev QDType '[1, QSliceHidden], Tensor dev QDType '[1]]
-> String -> IO ()
forall (tensors :: [Type]).
Castable (HList tensors) [ForeignPtr Tensor] =>
HList tensors -> String -> IO ()
TT.save (ToDependent
-> HList
     '[Parameter dev QDType '[QSliceHidden, 1, 1, 1],
       Parameter dev QDType '[QSliceHidden],
       Parameter dev QDType '[QSliceHidden, QSliceHidden, 13, 5],
       Parameter dev QDType '[QSliceHidden],
       Parameter dev QDType '[QSliceHidden, 13, 5],
       Parameter dev QDType '[QSliceHidden, 13, 5],
       Parameter dev QDType '[QSliceHidden, 2, 13, 5],
       Parameter dev QDType '[QSliceHidden],
       Parameter dev QDType '[QSliceHidden, 2, 13, 5],
       Parameter dev QDType '[QSliceHidden],
       Parameter dev QDType '[QSliceHidden, 1, 1, 1],
       Parameter dev QDType '[QSliceHidden],
       Parameter dev QDType '[QSliceHidden, 1, 1, 1],
       Parameter dev QDType '[QSliceHidden],
       Parameter dev QDType '[QSliceHidden],
       Parameter dev QDType '[QSliceHidden, QSliceHidden, 13, 5],
       Parameter dev QDType '[QSliceHidden],
       Parameter dev QDType '[QSliceHidden, QSliceHidden, 13, 5],
       Parameter dev QDType '[QSliceHidden],
       Parameter dev QDType '[QSliceHidden, QSliceHidden, 13, 5],
       Parameter dev QDType '[QSliceHidden],
       Parameter dev QDType '[QSliceHidden, QSliceHidden, 13, 5],
       Parameter dev QDType '[QSliceHidden],
       Parameter dev QDType '[QSliceHidden, QSliceHidden, 13, 5],
       Parameter dev QDType '[QSliceHidden],
       Parameter dev QDType '[QSliceHidden, QSliceHidden, 13, 5],
       Parameter dev QDType '[QSliceHidden],
       Parameter dev QDType '[QSliceHidden, QSliceHidden, 13, 5],
       Parameter dev QDType '[QSliceHidden], Parameter dev QDType '[5],
       Parameter dev QDType '[5], Parameter dev QDType '[5],
       Parameter dev QDType '[QSliceHidden, QSliceHidden, 13, 5],
       Parameter dev QDType '[QSliceHidden],
       Parameter dev QDType '[QSliceHidden, QSliceHidden, 13, 5],
       Parameter dev QDType '[QSliceHidden],
       Parameter dev QDType '[QSliceHidden, QSliceHidden, 13, 5],
       Parameter dev QDType '[QSliceHidden],
       Parameter dev QDType '[QSliceHidden, QSliceHidden, 13, 5],
       Parameter dev QDType '[QSliceHidden],
       Parameter dev QDType '[QSliceHidden, QSliceHidden, 13, 5],
       Parameter dev QDType '[QSliceHidden],
       Parameter dev QDType '[QSliceHidden, QSliceHidden, 13, 5],
       Parameter dev QDType '[QSliceHidden],
       Parameter dev QDType '[QSliceHidden, QSliceHidden, 13, 5],
       Parameter dev QDType '[QSliceHidden],
       Parameter dev QDType '[QSliceHidden, QSliceHidden, 13, 5],
       Parameter dev QDType '[QSliceHidden],
       Parameter dev QDType '[QSliceHidden],
       Parameter dev QDType '[QSliceHidden],
       Parameter dev QDType '[1, QSliceHidden], Parameter dev QDType '[1],
       Parameter dev QDType '[QSliceHidden, QSliceHidden],
       Parameter dev QDType '[QSliceHidden],
       Parameter dev QDType '[QSliceHidden],
       Parameter dev QDType '[QSliceHidden],
       Parameter dev QDType '[1, QSliceHidden], Parameter dev QDType '[1]]
-> HList
     '[Tensor dev QDType '[QSliceHidden, 1, 1, 1],
       Tensor dev QDType '[QSliceHidden],
       Tensor dev QDType '[QSliceHidden, QSliceHidden, 13, 5],
       Tensor dev QDType '[QSliceHidden],
       Tensor dev QDType '[QSliceHidden, 13, 5],
       Tensor dev QDType '[QSliceHidden, 13, 5],
       Tensor dev QDType '[QSliceHidden, 2, 13, 5],
       Tensor dev QDType '[QSliceHidden],
       Tensor dev QDType '[QSliceHidden, 2, 13, 5],
       Tensor dev QDType '[QSliceHidden],
       Tensor dev QDType '[QSliceHidden, 1, 1, 1],
       Tensor dev QDType '[QSliceHidden],
       Tensor dev QDType '[QSliceHidden, 1, 1, 1],
       Tensor dev QDType '[QSliceHidden],
       Tensor dev QDType '[QSliceHidden],
       Tensor dev QDType '[QSliceHidden, QSliceHidden, 13, 5],
       Tensor dev QDType '[QSliceHidden],
       Tensor dev QDType '[QSliceHidden, QSliceHidden, 13, 5],
       Tensor dev QDType '[QSliceHidden],
       Tensor dev QDType '[QSliceHidden, QSliceHidden, 13, 5],
       Tensor dev QDType '[QSliceHidden],
       Tensor dev QDType '[QSliceHidden, QSliceHidden, 13, 5],
       Tensor dev QDType '[QSliceHidden],
       Tensor dev QDType '[QSliceHidden, QSliceHidden, 13, 5],
       Tensor dev QDType '[QSliceHidden],
       Tensor dev QDType '[QSliceHidden, QSliceHidden, 13, 5],
       Tensor dev QDType '[QSliceHidden],
       Tensor dev QDType '[QSliceHidden, QSliceHidden, 13, 5],
       Tensor dev QDType '[QSliceHidden], Tensor dev QDType '[5],
       Tensor dev QDType '[5], Tensor dev QDType '[5],
       Tensor dev QDType '[QSliceHidden, QSliceHidden, 13, 5],
       Tensor dev QDType '[QSliceHidden],
       Tensor dev QDType '[QSliceHidden, QSliceHidden, 13, 5],
       Tensor dev QDType '[QSliceHidden],
       Tensor dev QDType '[QSliceHidden, QSliceHidden, 13, 5],
       Tensor dev QDType '[QSliceHidden],
       Tensor dev QDType '[QSliceHidden, QSliceHidden, 13, 5],
       Tensor dev QDType '[QSliceHidden],
       Tensor dev QDType '[QSliceHidden, QSliceHidden, 13, 5],
       Tensor dev QDType '[QSliceHidden],
       Tensor dev QDType '[QSliceHidden, QSliceHidden, 13, 5],
       Tensor dev QDType '[QSliceHidden],
       Tensor dev QDType '[QSliceHidden, QSliceHidden, 13, 5],
       Tensor dev QDType '[QSliceHidden],
       Tensor dev QDType '[QSliceHidden, QSliceHidden, 13, 5],
       Tensor dev QDType '[QSliceHidden],
       Tensor dev QDType '[QSliceHidden],
       Tensor dev QDType '[QSliceHidden],
       Tensor dev QDType '[1, QSliceHidden], Tensor dev QDType '[1],
       Tensor dev QDType '[QSliceHidden, QSliceHidden],
       Tensor dev QDType '[QSliceHidden],
       Tensor dev QDType '[QSliceHidden],
       Tensor dev QDType '[QSliceHidden],
       Tensor dev QDType '[1, QSliceHidden], Tensor dev QDType '[1]]
forall k f (xs :: [k]) (ys :: [k]).
HMap' f xs ys =>
f -> HList xs -> HList ys
TT.hmap' ToDependent
TT.ToDependent (HList
   '[Parameter dev QDType '[QSliceHidden, 1, 1, 1],
     Parameter dev QDType '[QSliceHidden],
     Parameter dev QDType '[QSliceHidden, QSliceHidden, 13, 5],
     Parameter dev QDType '[QSliceHidden],
     Parameter dev QDType '[QSliceHidden, 13, 5],
     Parameter dev QDType '[QSliceHidden, 13, 5],
     Parameter dev QDType '[QSliceHidden, 2, 13, 5],
     Parameter dev QDType '[QSliceHidden],
     Parameter dev QDType '[QSliceHidden, 2, 13, 5],
     Parameter dev QDType '[QSliceHidden],
     Parameter dev QDType '[QSliceHidden, 1, 1, 1],
     Parameter dev QDType '[QSliceHidden],
     Parameter dev QDType '[QSliceHidden, 1, 1, 1],
     Parameter dev QDType '[QSliceHidden],
     Parameter dev QDType '[QSliceHidden],
     Parameter dev QDType '[QSliceHidden, QSliceHidden, 13, 5],
     Parameter dev QDType '[QSliceHidden],
     Parameter dev QDType '[QSliceHidden, QSliceHidden, 13, 5],
     Parameter dev QDType '[QSliceHidden],
     Parameter dev QDType '[QSliceHidden, QSliceHidden, 13, 5],
     Parameter dev QDType '[QSliceHidden],
     Parameter dev QDType '[QSliceHidden, QSliceHidden, 13, 5],
     Parameter dev QDType '[QSliceHidden],
     Parameter dev QDType '[QSliceHidden, QSliceHidden, 13, 5],
     Parameter dev QDType '[QSliceHidden],
     Parameter dev QDType '[QSliceHidden, QSliceHidden, 13, 5],
     Parameter dev QDType '[QSliceHidden],
     Parameter dev QDType '[QSliceHidden, QSliceHidden, 13, 5],
     Parameter dev QDType '[QSliceHidden], Parameter dev QDType '[5],
     Parameter dev QDType '[5], Parameter dev QDType '[5],
     Parameter dev QDType '[QSliceHidden, QSliceHidden, 13, 5],
     Parameter dev QDType '[QSliceHidden],
     Parameter dev QDType '[QSliceHidden, QSliceHidden, 13, 5],
     Parameter dev QDType '[QSliceHidden],
     Parameter dev QDType '[QSliceHidden, QSliceHidden, 13, 5],
     Parameter dev QDType '[QSliceHidden],
     Parameter dev QDType '[QSliceHidden, QSliceHidden, 13, 5],
     Parameter dev QDType '[QSliceHidden],
     Parameter dev QDType '[QSliceHidden, QSliceHidden, 13, 5],
     Parameter dev QDType '[QSliceHidden],
     Parameter dev QDType '[QSliceHidden, QSliceHidden, 13, 5],
     Parameter dev QDType '[QSliceHidden],
     Parameter dev QDType '[QSliceHidden, QSliceHidden, 13, 5],
     Parameter dev QDType '[QSliceHidden],
     Parameter dev QDType '[QSliceHidden, QSliceHidden, 13, 5],
     Parameter dev QDType '[QSliceHidden],
     Parameter dev QDType '[QSliceHidden],
     Parameter dev QDType '[QSliceHidden],
     Parameter dev QDType '[1, QSliceHidden], Parameter dev QDType '[1],
     Parameter dev QDType '[QSliceHidden, QSliceHidden],
     Parameter dev QDType '[QSliceHidden],
     Parameter dev QDType '[QSliceHidden],
     Parameter dev QDType '[QSliceHidden],
     Parameter dev QDType '[1, QSliceHidden], Parameter dev QDType '[1]]
 -> HList
      '[Tensor dev QDType '[QSliceHidden, 1, 1, 1],
        Tensor dev QDType '[QSliceHidden],
        Tensor dev QDType '[QSliceHidden, QSliceHidden, 13, 5],
        Tensor dev QDType '[QSliceHidden],
        Tensor dev QDType '[QSliceHidden, 13, 5],
        Tensor dev QDType '[QSliceHidden, 13, 5],
        Tensor dev QDType '[QSliceHidden, 2, 13, 5],
        Tensor dev QDType '[QSliceHidden],
        Tensor dev QDType '[QSliceHidden, 2, 13, 5],
        Tensor dev QDType '[QSliceHidden],
        Tensor dev QDType '[QSliceHidden, 1, 1, 1],
        Tensor dev QDType '[QSliceHidden],
        Tensor dev QDType '[QSliceHidden, 1, 1, 1],
        Tensor dev QDType '[QSliceHidden],
        Tensor dev QDType '[QSliceHidden],
        Tensor dev QDType '[QSliceHidden, QSliceHidden, 13, 5],
        Tensor dev QDType '[QSliceHidden],
        Tensor dev QDType '[QSliceHidden, QSliceHidden, 13, 5],
        Tensor dev QDType '[QSliceHidden],
        Tensor dev QDType '[QSliceHidden, QSliceHidden, 13, 5],
        Tensor dev QDType '[QSliceHidden],
        Tensor dev QDType '[QSliceHidden, QSliceHidden, 13, 5],
        Tensor dev QDType '[QSliceHidden],
        Tensor dev QDType '[QSliceHidden, QSliceHidden, 13, 5],
        Tensor dev QDType '[QSliceHidden],
        Tensor dev QDType '[QSliceHidden, QSliceHidden, 13, 5],
        Tensor dev QDType '[QSliceHidden],
        Tensor dev QDType '[QSliceHidden, QSliceHidden, 13, 5],
        Tensor dev QDType '[QSliceHidden], Tensor dev QDType '[5],
        Tensor dev QDType '[5], Tensor dev QDType '[5],
        Tensor dev QDType '[QSliceHidden, QSliceHidden, 13, 5],
        Tensor dev QDType '[QSliceHidden],
        Tensor dev QDType '[QSliceHidden, QSliceHidden, 13, 5],
        Tensor dev QDType '[QSliceHidden],
        Tensor dev QDType '[QSliceHidden, QSliceHidden, 13, 5],
        Tensor dev QDType '[QSliceHidden],
        Tensor dev QDType '[QSliceHidden, QSliceHidden, 13, 5],
        Tensor dev QDType '[QSliceHidden],
        Tensor dev QDType '[QSliceHidden, QSliceHidden, 13, 5],
        Tensor dev QDType '[QSliceHidden],
        Tensor dev QDType '[QSliceHidden, QSliceHidden, 13, 5],
        Tensor dev QDType '[QSliceHidden],
        Tensor dev QDType '[QSliceHidden, QSliceHidden, 13, 5],
        Tensor dev QDType '[QSliceHidden],
        Tensor dev QDType '[QSliceHidden, QSliceHidden, 13, 5],
        Tensor dev QDType '[QSliceHidden],
        Tensor dev QDType '[QSliceHidden],
        Tensor dev QDType '[QSliceHidden],
        Tensor dev QDType '[1, QSliceHidden], Tensor dev QDType '[1],
        Tensor dev QDType '[QSliceHidden, QSliceHidden],
        Tensor dev QDType '[QSliceHidden],
        Tensor dev QDType '[QSliceHidden],
        Tensor dev QDType '[QSliceHidden],
        Tensor dev QDType '[1, QSliceHidden], Tensor dev QDType '[1]])
-> HList
     '[Parameter dev QDType '[QSliceHidden, 1, 1, 1],
       Parameter dev QDType '[QSliceHidden],
       Parameter dev QDType '[QSliceHidden, QSliceHidden, 13, 5],
       Parameter dev QDType '[QSliceHidden],
       Parameter dev QDType '[QSliceHidden, 13, 5],
       Parameter dev QDType '[QSliceHidden, 13, 5],
       Parameter dev QDType '[QSliceHidden, 2, 13, 5],
       Parameter dev QDType '[QSliceHidden],
       Parameter dev QDType '[QSliceHidden, 2, 13, 5],
       Parameter dev QDType '[QSliceHidden],
       Parameter dev QDType '[QSliceHidden, 1, 1, 1],
       Parameter dev QDType '[QSliceHidden],
       Parameter dev QDType '[QSliceHidden, 1, 1, 1],
       Parameter dev QDType '[QSliceHidden],
       Parameter dev QDType '[QSliceHidden],
       Parameter dev QDType '[QSliceHidden, QSliceHidden, 13, 5],
       Parameter dev QDType '[QSliceHidden],
       Parameter dev QDType '[QSliceHidden, QSliceHidden, 13, 5],
       Parameter dev QDType '[QSliceHidden],
       Parameter dev QDType '[QSliceHidden, QSliceHidden, 13, 5],
       Parameter dev QDType '[QSliceHidden],
       Parameter dev QDType '[QSliceHidden, QSliceHidden, 13, 5],
       Parameter dev QDType '[QSliceHidden],
       Parameter dev QDType '[QSliceHidden, QSliceHidden, 13, 5],
       Parameter dev QDType '[QSliceHidden],
       Parameter dev QDType '[QSliceHidden, QSliceHidden, 13, 5],
       Parameter dev QDType '[QSliceHidden],
       Parameter dev QDType '[QSliceHidden, QSliceHidden, 13, 5],
       Parameter dev QDType '[QSliceHidden], Parameter dev QDType '[5],
       Parameter dev QDType '[5], Parameter dev QDType '[5],
       Parameter dev QDType '[QSliceHidden, QSliceHidden, 13, 5],
       Parameter dev QDType '[QSliceHidden],
       Parameter dev QDType '[QSliceHidden, QSliceHidden, 13, 5],
       Parameter dev QDType '[QSliceHidden],
       Parameter dev QDType '[QSliceHidden, QSliceHidden, 13, 5],
       Parameter dev QDType '[QSliceHidden],
       Parameter dev QDType '[QSliceHidden, QSliceHidden, 13, 5],
       Parameter dev QDType '[QSliceHidden],
       Parameter dev QDType '[QSliceHidden, QSliceHidden, 13, 5],
       Parameter dev QDType '[QSliceHidden],
       Parameter dev QDType '[QSliceHidden, QSliceHidden, 13, 5],
       Parameter dev QDType '[QSliceHidden],
       Parameter dev QDType '[QSliceHidden, QSliceHidden, 13, 5],
       Parameter dev QDType '[QSliceHidden],
       Parameter dev QDType '[QSliceHidden, QSliceHidden, 13, 5],
       Parameter dev QDType '[QSliceHidden],
       Parameter dev QDType '[QSliceHidden],
       Parameter dev QDType '[QSliceHidden],
       Parameter dev QDType '[1, QSliceHidden], Parameter dev QDType '[1],
       Parameter dev QDType '[QSliceHidden, QSliceHidden],
       Parameter dev QDType '[QSliceHidden],
       Parameter dev QDType '[QSliceHidden],
       Parameter dev QDType '[QSliceHidden],
       Parameter dev QDType '[1, QSliceHidden], Parameter dev QDType '[1]]
-> HList
     '[Tensor dev QDType '[QSliceHidden, 1, 1, 1],
       Tensor dev QDType '[QSliceHidden],
       Tensor dev QDType '[QSliceHidden, QSliceHidden, 13, 5],
       Tensor dev QDType '[QSliceHidden],
       Tensor dev QDType '[QSliceHidden, 13, 5],
       Tensor dev QDType '[QSliceHidden, 13, 5],
       Tensor dev QDType '[QSliceHidden, 2, 13, 5],
       Tensor dev QDType '[QSliceHidden],
       Tensor dev QDType '[QSliceHidden, 2, 13, 5],
       Tensor dev QDType '[QSliceHidden],
       Tensor dev QDType '[QSliceHidden, 1, 1, 1],
       Tensor dev QDType '[QSliceHidden],
       Tensor dev QDType '[QSliceHidden, 1, 1, 1],
       Tensor dev QDType '[QSliceHidden],
       Tensor dev QDType '[QSliceHidden],
       Tensor dev QDType '[QSliceHidden, QSliceHidden, 13, 5],
       Tensor dev QDType '[QSliceHidden],
       Tensor dev QDType '[QSliceHidden, QSliceHidden, 13, 5],
       Tensor dev QDType '[QSliceHidden],
       Tensor dev QDType '[QSliceHidden, QSliceHidden, 13, 5],
       Tensor dev QDType '[QSliceHidden],
       Tensor dev QDType '[QSliceHidden, QSliceHidden, 13, 5],
       Tensor dev QDType '[QSliceHidden],
       Tensor dev QDType '[QSliceHidden, QSliceHidden, 13, 5],
       Tensor dev QDType '[QSliceHidden],
       Tensor dev QDType '[QSliceHidden, QSliceHidden, 13, 5],
       Tensor dev QDType '[QSliceHidden],
       Tensor dev QDType '[QSliceHidden, QSliceHidden, 13, 5],
       Tensor dev QDType '[QSliceHidden], Tensor dev QDType '[5],
       Tensor dev QDType '[5], Tensor dev QDType '[5],
       Tensor dev QDType '[QSliceHidden, QSliceHidden, 13, 5],
       Tensor dev QDType '[QSliceHidden],
       Tensor dev QDType '[QSliceHidden, QSliceHidden, 13, 5],
       Tensor dev QDType '[QSliceHidden],
       Tensor dev QDType '[QSliceHidden, QSliceHidden, 13, 5],
       Tensor dev QDType '[QSliceHidden],
       Tensor dev QDType '[QSliceHidden, QSliceHidden, 13, 5],
       Tensor dev QDType '[QSliceHidden],
       Tensor dev QDType '[QSliceHidden, QSliceHidden, 13, 5],
       Tensor dev QDType '[QSliceHidden],
       Tensor dev QDType '[QSliceHidden, QSliceHidden, 13, 5],
       Tensor dev QDType '[QSliceHidden],
       Tensor dev QDType '[QSliceHidden, QSliceHidden, 13, 5],
       Tensor dev QDType '[QSliceHidden],
       Tensor dev QDType '[QSliceHidden, QSliceHidden, 13, 5],
       Tensor dev QDType '[QSliceHidden],
       Tensor dev QDType '[QSliceHidden],
       Tensor dev QDType '[QSliceHidden],
       Tensor dev QDType '[1, QSliceHidden], Tensor dev QDType '[1],
       Tensor dev QDType '[QSliceHidden, QSliceHidden],
       Tensor dev QDType '[QSliceHidden],
       Tensor dev QDType '[QSliceHidden],
       Tensor dev QDType '[QSliceHidden],
       Tensor dev QDType '[1, QSliceHidden], Tensor dev QDType '[1]]
forall a b. (a -> b) -> a -> b
$ QModel dev -> HList (Parameters (QModel dev))
forall f. Parameterized f => f -> HList (Parameters f)
TT.flattenParameters QModel dev
model) String
path

-- | Total number of scalar parameters in the model.
--
-- Every parameter tensor returned by 'TT.flattenParameters' contributes the
-- product of its dimensions. The result is the sum of those contributions.
modelSize :: QModel dev -> Int
modelSize model = sum [product shape | shape <- paramShapes]
 where
  -- Runtime shape of each parameter tensor, read via 'TH.ShapeVal' and
  -- collected into a plain list with 'TH.ToList'.
  paramShapes :: [[Int]]
  paramShapes =
    TT.hfoldr TH.ToList ([] :: [[Int]]) (TT.hmap' TH.ShapeVal (TT.flattenParameters model))

-- | Evaluate the model on a single state/action pair and return the result
-- as a plain 'QType' scalar.
--
-- The forward pass is shared with 'runQ''. Its one-element tensor result is
-- converted to a dynamic tensor and read out with 'T.asValue'. The encoder
-- and model are forced to WHNF, as in 'runQ''.
runQ
  :: (IsValidDevice dev)
  => (s -> a -> QEncoding dev '[])
  -> QModel dev
  -> s
  -> a
  -> QType
runQ !encode !model s a = T.asValue (TT.toDynamic (runQ' encode model s a))

-- | Evaluate the model on a single state/action pair and keep the result as
-- a typed one-element tensor.
--
-- The encoder turns the pair into an unbatched 'QEncoding' (empty batch
-- shape @'[]@), which is passed to the model's 'T.forward'. The encoder and
-- model arguments are forced to WHNF before use.
runQ'
  :: (IsValidDevice dev)
  => (s -> a -> QEncoding dev '[])
  -> QModel dev
  -> s
  -> a
  -> QTensor dev '[1]
runQ' !encode !model s a =
  let input = encode s a
   in T.forward model input

-- | Run the batched policy head on a batch of encodings and normalise the
-- outputs with a softmax over the batch dimension (dim 0).
--
-- The @batchSize@ entries of the returned @[batchSize, 1]@ tensor therefore
-- sum to one. The error message suggests each batch entry is a candidate
-- action; confirm against callers. The result is returned as an untyped
-- 'T.Tensor'.
--
-- Calls 'error' when @batchSize@ is 0, since no distribution can be formed.
runBatchedPolicy
  :: forall dev batchSize
   . (IsValidDevice dev, KnownNat batchSize)
  => QModel dev
  -> QEncoding dev '[batchSize]
  -> T.Tensor
runBatchedPolicy actor encoding = TT.toDynamic (TT.softmax @0 rawPolicy)
 where
  rawPolicy :: QTensor dev '[batchSize, 1]
  rawPolicy =
    -- Matching on LTI / EQI brings the evidence @1 <= batchSize@ into scope,
    -- which 'forwardPolicyBatched' requires. The two branches cannot be
    -- merged into a wildcard, because a wildcard would not carry that
    -- evidence.
    case cmpNat (Proxy :: Proxy 1) (Proxy :: Proxy batchSize) of
      LTI -> forwardPolicyBatched @dev @batchSize actor encoding
      EQI -> forwardPolicyBatched @dev @batchSize actor encoding
      -- 1 > batchSize, i.e. an empty batch
      GTI -> error "batched policy: no actions"