{-# LANGUAGE DataKinds #-}

module RL.Jit where

import RL.Encoding
import RL.Model

import Data.TypeNums (KnownNat)
import RL.ModelTypes (IsValidDevice)
import Torch qualified as T
import Torch.Jit qualified as TJit
import Torch.Lens qualified as TL

-- | Run 'runBatchedPolicy' on a batch of @bs@ encodings through the
-- TorchScript JIT.
--
-- 'TJit.jit' only accepts functions of type @[T.Tensor] -> [T.Tensor]@.
-- This function therefore works in three steps:
--
-- 1. Flatten every 'T.Tensor' reachable inside @(model, encoding)@ into a
--    list with 'TL.flattenValues' 'TL.types'.
-- 2. Hand that list to 'TJit.jit' together with @policy@. @policy@ rebuilds
--    the structured arguments from the list with 'TL.replaceValues' before
--    calling 'runBatchedPolicy'.
-- 3. Return the single tensor produced by 'runBatchedPolicy'.
--
-- NOTE(review): caching/re-tracing is whatever 'TJit.jit' does with
-- @scriptCache@ (presumably one compiled graph reused across calls with
-- compatible inputs). Confirm against the "Torch.Jit" documentation.
compileBatchedPolicy
  :: forall dev bs
   . (IsValidDevice dev, KnownNat bs)
  => TJit.ScriptCache
  -- ^ cache passed straight through to 'TJit.jit'
  -> QModel dev
  -- ^ model whose tensor leaves become JIT inputs
  -> QEncoding dev '[bs]
  -- ^ batched encoding whose tensor leaves become JIT inputs
  -> T.Tensor
compileBatchedPolicy scriptCache model encoding =
  -- @policy@ always yields a singleton list, so an empty result means the
  -- JIT layer misbehaved. Fail with a descriptive message instead of the
  -- opaque error of the partial 'head'.
  case TJit.jit scriptCache policy inputs of
    (out : _) -> out
    [] ->
      error
        "RL.Jit.compileBatchedPolicy: TJit.jit returned no output tensors"
 where
  -- Structural template. Its tensor leaves are read out below and replaced
  -- inside the traced function.
  args = (model, encoding)

  -- Flattened tensor leaves. Both this and the 'TL.replaceValues' call in
  -- @policy@ traverse @args@ with 'TL.types', so the list order agrees.
  inputs :: [T.Tensor]
  inputs = TL.flattenValues TL.types args

  -- Rebuild @(model, encoding)@ from the tensors supplied by 'TJit.jit' and
  -- run the policy. 'TJit.jit' requires a list result, hence the singleton.
  policy :: [T.Tensor] -> [T.Tensor]
  policy tensors = [runBatchedPolicy model' encoding']
   where
    (model', encoding') = TL.replaceValues TL.types args tensors