{-# LANGUAGE DataKinds #-}

-- | Run the batched Q-policy through the Torch JIT ('TJit.jit').
module RL.Jit where

import RL.Encoding
import RL.Model
import Data.TypeNums (KnownNat)
import RL.ModelTypes (IsValidDevice)
import Torch qualified as T
import Torch.Jit qualified as TJit
import Torch.Lens qualified as TL

-- | Evaluate 'runBatchedPolicy' on a model and a batch of encodings via
-- 'TJit.jit', using the supplied script cache.
--
-- 'TJit.jit' only works on functions of type @[Tensor] -> [Tensor]@. The
-- function therefore does three things:
--
-- 1. It flattens every 'T.Tensor' reachable from @(model, encoding)@ into a
--    list with 'TL.flattenValues' 'TL.types'.
-- 2. It passes that list to 'TJit.jit' with the local @policy@ function.
-- 3. @policy@ rebuilds the model and encoding from the incoming tensors with
--    'TL.replaceValues' and returns the single policy output.
--
-- NOTE(review): how the 'TJit.ScriptCache' is reused across calls (for
-- example, whether a traced graph is cached per input shape) is decided by
-- "Torch.Jit", not here. Confirm this before relying on it.
compileBatchedPolicy
  :: forall dev bs
   . (IsValidDevice dev, KnownNat bs)
  => TJit.ScriptCache
  -- ^ cache handed to 'TJit.jit'
  -> QModel dev
  -- ^ Q-network whose tensors become JIT inputs
  -> QEncoding dev '[bs]
  -- ^ batch of @bs@ encoded inputs
  -> T.Tensor
  -- ^ the single tensor produced by 'runBatchedPolicy'
compileBatchedPolicy scriptCache model encoding =
  -- @policy@ always returns a one-element list, so the first JIT output is
  -- the result. The empty case is handled explicitly instead of calling the
  -- partial 'head', so a failure names this function.
  case TJit.jit scriptCache policy inputs of
    output : _ -> output
    [] ->
      error
        "RL.Jit.compileBatchedPolicy: JIT-compiled policy returned no outputs"
 where
  -- All tensors in the (model, encoding) pair, in traversal order.
  inputs :: [T.Tensor]
  inputs = TL.flattenValues TL.types (model, encoding)

  -- The function handed to the JIT. It writes the incoming tensors back into
  -- the original structures (same traversal as 'inputs', so the positions
  -- line up) and runs the batched policy on the rebuilt values.
  policy :: [T.Tensor] -> [T.Tensor]
  policy tensors = [runBatchedPolicy model' encoding']
   where
    (model', encoding') = TL.replaceValues TL.types (model, encoding) tensors