{-# LANGUAGE DataKinds #-}
{-# LANGUAGE Strict #-}
module RL.A2CHelpers where
import Control.DeepSeq (force)
import RL.Model
import RL.ModelTypes
import RL.TorchHelpers
import Torch.Typed qualified as TT
-- | Type-level list of the parameters of a 'QModel' living on @device@,
-- as given by the 'TT.Parameters' type family of "Torch.Typed".
type ModelParams device = TT.Parameters (QModel device)

-- | Type-level list of tensors derived from 'ModelParams' via
-- 'ToModelTensors' (defined elsewhere; presumably one plain tensor per
-- parameter, with the same shape — confirm in its definition). The trace,
-- gradient and zero HLists handled in this module all use this layout.
type ModelTensors device = ToModelTensors (ModelParams device)
-- | Per-tensor step used by 'updateEligCritic' through 'TT.hzipWith'.
-- The wrapped scalar is the decay applied to the existing trace.
newtype UpdateEligCritic = UpdateEligCritic QType

-- | For a trace tensor @z@ and a gradient @g@ of the same shape, returns
-- @decay * z + g@.
instance (TT.KnownDevice dev) => TT.Apply' UpdateEligCritic (QTensor dev shape, QTensor dev shape) (QTensor dev shape) where
  apply' (UpdateEligCritic decay) (elig, gradient) = decayed + gradient
    where
      -- Shrink the old trace before accumulating the fresh gradient.
      decayed = TT.mulScalar decay elig
-- | Update the critic's eligibility traces: for every pair of tensors
-- @(z, g)@ taken positionally from the two lists, the result holds
-- @(gamma * lambdaV) * z + g@. Only the product @gamma * lambdaV@ is used.
--
-- The result is deep-evaluated with 'force'. The definition is deliberately
-- eta-expanded. The previous point-free form applied 'force' to the
-- partially applied @TT.hzipWith ...@, which is a function value, and
-- deepseq's @NFData (a -> b)@ instance only evaluates to WHNF. That meant
-- the returned tensors were never forced.
updateEligCritic
  :: (TT.KnownDevice dev)
  => QType                        -- ^ @gamma@, multiplied with @lambdaV@
  -> QType                        -- ^ @lambdaV@
  -> TT.HList (ModelTensors dev)  -- ^ current traces @z@
  -> TT.HList (ModelTensors dev)  -- ^ gradients @g@, same layout as the traces
  -> TT.HList (ModelTensors dev)  -- ^ updated traces, fully evaluated
updateEligCritic gamma lambdaV traces grads =
  force (TT.hzipWith (UpdateEligCritic (gamma * lambdaV)) traces grads)
{-# NOINLINE updateEligCritic #-}
-- | Per-tensor step used by 'updateEligActor' through 'TT.hzipWith'.
-- The first field scales the gradient (called @intensity@ by the caller;
-- presumably the accumulated discount @I@ of actor-critic — confirm). The
-- second field is the decay applied to the existing trace.
data UpdateEligActor = UpdateEligActor QType QType

-- | For a trace tensor @z@ and a gradient @g@ of the same shape, returns
-- @decay * z + intensity * g@.
instance (TT.KnownDevice dev) => TT.Apply' UpdateEligActor (QTensor dev shape, QTensor dev shape) (QTensor dev shape) where
  apply' (UpdateEligActor gradScale decay) (elig, gradient) =
    decayed + scaledGradient
    where
      decayed = TT.mulScalar decay elig
      scaledGradient = TT.mulScalar gradScale gradient
-- | Update the actor's eligibility traces: for every pair of tensors
-- @(z, g)@ taken positionally from the two lists, the result holds
-- @(gamma * lambdaP) * z + intensity * g@.
--
-- The result is deep-evaluated with 'force'. The definition is deliberately
-- eta-expanded. The previous point-free form applied 'force' to the
-- partially applied @TT.hzipWith ...@, which is a function value, and
-- deepseq's @NFData (a -> b)@ instance only evaluates to WHNF. That meant
-- the returned tensors were never forced.
updateEligActor
  :: (TT.KnownDevice dev)
  => QType                        -- ^ @gamma@, multiplied with @lambdaP@
  -> QType                        -- ^ @lambdaP@
  -> QType                        -- ^ @intensity@, scales each gradient
  -> TT.HList (ModelTensors dev)  -- ^ current traces @z@
  -> TT.HList (ModelTensors dev)  -- ^ gradients @g@, same layout as the traces
  -> TT.HList (ModelTensors dev)  -- ^ updated traces, fully evaluated
updateEligActor gamma lambdaP intensity traces grads =
  force (TT.hzipWith (UpdateEligActor intensity (gamma * lambdaP)) traces grads)
{-# NOINLINE updateEligActor #-}
-- | Apply @Mul' factor@ to every tensor in the list. 'Mul'' is defined
-- elsewhere; presumably it multiplies each tensor by the scalar @factor@
-- (confirm in its 'TT.Apply'' instance).
--
-- The result is deep-evaluated with 'force', as in 'modelZeros'. The
-- definition is deliberately eta-expanded. The previous point-free form
-- applied 'force' to the partially applied @TT.hmap' ...@, which is a
-- function value, so only WHNF was reached and the returned tensors were
-- never forced.
mulModelTensors
  :: (IsValidDevice dev)
  => QTensor dev '[]              -- ^ scalar @factor@
  -> TT.HList (ModelTensors dev)  -- ^ tensors to transform
  -> TT.HList (ModelTensors dev)  -- ^ transformed tensors, fully evaluated
mulModelTensors factor tensors = force (TT.hmap' (Mul' factor) tensors)
{-# NOINLINE mulModelTensors #-}
-- | One tensor per parameter of @model@, produced by applying hasktorch's
-- 'TT.ZerosLike' to each flattened parameter, so each tensor is zero and
-- has its parameter's shape. Presumably used to initialise traces or
-- accumulators — confirm at call sites. The resulting list is
-- deep-evaluated with 'force'.
modelZeros :: (IsValidDevice dev) => QModel dev -> TT.HList (ModelTensors dev)
modelZeros = force . TT.hmap' TT.ZerosLike . TT.flattenParameters
{-# NOINLINE modelZeros #-}
-- | Reduce the whole tensor list to one scalar. Each tensor is first
-- reduced with 'SumAll'. The per-tensor scalars are then folded right to
-- left with 'Add' (defined elsewhere; presumably addition — confirm),
-- starting from a scalar zero.
sumTensorList :: forall dev. (IsValidDevice dev) => TT.HList (ModelTensors dev) -> QTensor dev '[]
sumTensorList ts = TT.hfoldr Add scalarZero (TT.hmap' SumAll ts)
  where
    -- Seed of the fold: a scalar zero on the same device as the inputs.
    scalarZero :: QTensor dev '[]
    scalarZero = TT.zeros