{-# LANGUAGE AllowAmbiguousTypes #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE PartialTypeSignatures #-}
{-# LANGUAGE RecordWildCards #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE UndecidableInstances #-}
{-# OPTIONS_GHC -Wno-partial-type-signatures #-}
module RL.TorchHelpers where
import Data.Kind (Type)
import GHC.TypeLits
import System.IO.Unsafe (unsafePerformIO)
import Torch qualified as T
import Torch qualified as TD
import Torch.Internal.Cast qualified as ATen
import Torch.Internal.Managed.Native qualified as ATen.Managed
import Torch.Internal.Type qualified as ATen
import Torch.Typed qualified as TT
import Torch.Typed.Auxiliary qualified
-- | Tag for 'TT.Apply'' that collapses a tensor to a rank-0 tensor by summing
-- all of its elements with 'TT.sumAll'.
--
-- The result dtype is @TT.SumDType dtype@, which may differ from the input
-- dtype.
data SumAll = SumAll

instance
  ( TT.SumDTypeIsValid dev dtype
  , dtype' ~ TT.SumDType dtype
  )
  => TT.Apply' SumAll (TT.Tensor dev dtype shape) (TT.Tensor dev dtype' '[])
  where
    -- The tag carries no data, so it is ignored.
    apply' _ tensor = TT.sumAll tensor
-- | Tag for 'TT.Apply'' that adds a pair of tensors with 'TT.add'.
--
-- Both operands must share device and dtype. The output shape is the
-- broadcast of the two input shapes. It is spelled out here as the expanded
-- 'TT.CheckBroadcast' / 'TT.ComputeBroadcast' form.
data Add = Add

instance
  ( TT.BasicArithmeticDTypeIsValid dev dtype
  , TT.CheckBroadcast
      shape1
      shape2
      ( TT.ComputeBroadcast
          (TT.ReverseImpl shape1 '[])
          (TT.ReverseImpl shape2 '[])
      )
      ~ shapeOut
  )
  => TT.Apply'
       Add
       (TT.Tensor dev dtype shape1, TT.Tensor dev dtype shape2)
       (TT.Tensor dev dtype shapeOut)
  where
    apply' _ (lhs, rhs) = lhs `TT.add` rhs
-- | Tag for 'TT.Apply'' that multiplies a tensor by a host-side scalar
-- using 'TT.mulScalar'. The tensor's device, dtype and shape are unchanged.
newtype Mul num = Mul num

instance
  (TT.Scalar num)
  => TT.Apply' (Mul num) (TT.Tensor dev dtype shape) (TT.Tensor dev dtype shape)
  where
    apply' (Mul factor) tensor = TT.mulScalar factor tensor
-- | Tag for 'TT.Apply'' that multiplies a tensor by a rank-0 tensor stored
-- in the tag, using 'TT.mul'.
--
-- The scalar tensor goes on the left. The constraint
-- @shape ~ TT.Broadcast '[] shape@ pins the result shape to the input shape.
newtype Mul' dev dtype = Mul' (TT.Tensor dev dtype '[])

instance
  ( TT.BasicArithmeticDTypeIsValid dev dtype
  , shape ~ TT.Broadcast '[] shape
  )
  => TT.Apply' (Mul' dev dtype) (TT.Tensor dev dtype shape) (TT.Tensor dev dtype shape)
  where
    apply' (Mul' scale) tensor = scale `TT.mul` tensor
-- | Detach a typed tensor, keeping its device, dtype and shape indices.
--
-- The tensor is first dropped to the untyped API with 'TT.toDynamic'. The
-- untyped @Torch.detach@ then runs in 'IO'. Its result is re-wrapped with
-- 'TT.UnsafeMkTensor' under the same type indices.
detach :: TT.Tensor dev dtype shape -> IO (TT.Tensor dev dtype shape)
detach tensor = TT.UnsafeMkTensor <$> TD.detach (TT.toDynamic tensor)
-- | Tag for 'TT.Apply'' that runs 'detach' on a tensor.
-- The result is an 'IO' action yielding a tensor of the same type.
data Detach = Detach

instance TT.Apply' Detach (TT.Tensor dev dtype shape) (IO (TT.Tensor dev dtype shape)) where
  apply' _ tensor = detach tensor
-- | Tag for 'TT.Apply'' that blends two same-typed tensors linearly.
--
-- For a pair @(p, t)@ and coefficient @tau@ it computes
-- @tau * p + (1 - tau) * t@.
--
-- NOTE(review): presumably used for soft (Polyak) target-network updates,
-- with @p@ the online parameters and @t@ the target. Confirm at call sites.
newtype Interpolate num = Interpolate num

instance
  ( TT.Scalar num
  , Num num
  , TT.BasicArithmeticDTypeIsValid dev dtype
  , TT.CheckBroadcast
      shape
      shape
      ( TT.ComputeBroadcast
          (TT.ReverseImpl shape '[])
          (TT.ReverseImpl shape '[])
      )
      ~ shape
  )
  => TT.Apply'
       (Interpolate num)
       (TT.Tensor dev dtype shape, TT.Tensor dev dtype shape)
       (TT.Tensor dev dtype shape)
  where
    apply' (Interpolate tau) (p, t) =
      let weighted = TT.mulScalar tau p
          complement = TT.mulScalar (1 - tau) t
       in TT.add weighted complement
-- | Tag for 'TT.Apply'' that reports the statically known shape of a tensor
-- or parameter as a list of dimension sizes.
--
-- Only the type-level @shape@ is consulted, through 'TT.shapeVal'. The
-- argument value itself is never inspected, so it is matched with a wildcard
-- rather than an unused binding.
data ShapeVal = ShapeVal

-- | Shape of a 'TT.Tensor', read from its type index.
instance (TT.KnownShape shape) => TT.Apply' ShapeVal (TT.Tensor dev dtype shape) [Int] where
  apply' _ _ = TT.shapeVal @shape

-- | Shape of a 'TT.Parameter', read from its type index.
instance (TT.KnownShape shape) => TT.Apply' ShapeVal (TT.Parameter dev dtype shape) [Int] where
  apply' _ _ = TT.shapeVal @shape
-- | Tag for 'TT.Apply'' that conses an element onto a list of the same
-- element type.
--
-- NOTE(review): presumably used as the step function of a right fold over a
-- heterogeneous list whose elements all share one type. Confirm at call sites.
data ToList = ToList

instance TT.Apply' ToList (t, [t]) [t] where
  apply' _ (element, rest) = (:) element rest
-- | Map a type-level list of 'TT.Parameter's to the corresponding list of
-- 'TT.Tensor's, keeping each element's device, dtype and shape.
--
-- There is no equation for a head that is not a 'TT.Parameter'. Such a list
-- leaves the family unreduced (stuck).
type family ToModelTensors (params :: [Type]) :: [Type] where
  ToModelTensors '[] = '[]
  ToModelTensors (TT.Parameter dev dtype shape ': rest) =
    TT.Tensor dev dtype shape ': ToModelTensors rest
-- | Run a batched operation on a single, unbatched tensor.
--
-- The input gets a leading dimension of size 1 ('TT.unsqueeze' at dim 0).
-- Then @op@ is applied, and that leading dimension is removed from the
-- result ('TT.squeezeDim' at dim 0).
withBatchDim
  :: forall dev1 dtype1 shape1 dev2 dtype2 shape2
   . (TT.Tensor dev1 dtype1 (1 : shape1) -> TT.Tensor dev2 dtype2 (1 : shape2))
  -> TT.Tensor dev1 dtype1 shape1
  -> TT.Tensor dev2 dtype2 shape2
withBatchDim op input =
  let singletonBatch :: TT.Tensor dev1 dtype1 (1 : shape1)
      singletonBatch = TT.unsqueeze @0 input
   in TT.squeezeDim @0 (op singletonBatch)
-- | 2-D convolution over a typed NCHW-indexed tensor, calling ATen's
-- @conv2d@ directly.
--
-- Constraints: only the stride and padding components need 'KnownNat'.
-- The channel, kernel, input, batch and output sizes are checked through
-- 'TT.ConvSideCheck' but need not be 'KnownNat'.
-- NOTE(review): presumably this is what "Relaxed" means relative to the
-- stock typed conv2d. Confirm.
--
-- Argument order follows @conv2d_tttllll@:
-- input, weight, bias, stride, padding, then the constant lists @[1, 1]@
-- and @1@. Per ATen's @conv2d(input, weight, bias, stride, padding,
-- dilation, groups)@ signature these are dilation and groups.
--
-- The FFI call lives in 'IO' and is exposed as a pure function via
-- 'unsafePerformIO'.
-- NOTE(review): this assumes the ATen call has no observable side effects.
-- Confirm.
conv2dRelaxed
  :: forall
    (stride :: (Nat, Nat))
    (padding :: (Nat, Nat))
    inputChannelSize
    outputChannelSize
    kernelSize0
    kernelSize1
    inputSize0
    inputSize1
    batchSize
    outputSize0
    outputSize1
    dtype
    device
   . ( TT.All
         KnownNat
         '[ Torch.Typed.Auxiliary.Fst stride
          , Torch.Typed.Auxiliary.Snd stride
          , Torch.Typed.Auxiliary.Fst padding
          , Torch.Typed.Auxiliary.Snd padding
          ]
     , TT.ConvSideCheck inputSize0 kernelSize0 (Torch.Typed.Auxiliary.Fst stride) (Torch.Typed.Auxiliary.Fst padding) outputSize0
     , TT.ConvSideCheck inputSize1 kernelSize1 (Torch.Typed.Auxiliary.Snd stride) (Torch.Typed.Auxiliary.Snd padding) outputSize1
     )
  => TT.Tensor device dtype '[outputChannelSize, inputChannelSize, kernelSize0, kernelSize1]
  -> TT.Tensor device dtype '[outputChannelSize]
  -> TT.Tensor device dtype '[batchSize, inputChannelSize, inputSize0, inputSize1]
  -> TT.Tensor device dtype '[batchSize, outputChannelSize, outputSize0, outputSize1]
conv2dRelaxed weight bias input =
  unsafePerformIO
    ( ATen.cast7
        ATen.Managed.conv2d_tttllll
        input
        weight
        bias
        strideDims
        paddingDims
        dilationDims
        groupCount
    )
  where
    -- (height, width) stride, read from the type-level pair.
    strideDims :: [Int]
    strideDims =
      [ TT.natValI @(Torch.Typed.Auxiliary.Fst stride)
      , TT.natValI @(Torch.Typed.Auxiliary.Snd stride)
      ]

    -- (height, width) padding, read from the type-level pair.
    paddingDims :: [Int]
    paddingDims =
      [ TT.natValI @(Torch.Typed.Auxiliary.Fst padding)
      , TT.natValI @(Torch.Typed.Auxiliary.Snd padding)
      ]

    -- Fixed at [1, 1] (presumably dilation, per the ATen argument order).
    dilationDims :: [Int]
    dilationDims = [1, 1]

    -- Fixed at 1 (presumably groups, per the ATen argument order).
    groupCount :: Int
    groupCount = 1
-- | Forward pass of a typed 'TT.Conv2d' layer using 'conv2dRelaxed'.
--
-- The layer's @weight@ and @bias@ parameters are unwrapped with
-- 'TT.toDependent'. They are then passed, together with the input, to
-- @conv2dRelaxed \@stride \@padding@. Stride and padding are chosen by the
-- caller via type application.
conv2dForwardRelaxed
  :: forall
    (stride :: (Nat, Nat))
    (padding :: (Nat, Nat))
    inputChannelSize
    outputChannelSize
    kernelSize0
    kernelSize1
    inputSize0
    inputSize1
    batchSize
    outputSize0
    outputSize1
    dtype
    device
   . ( TT.All
         KnownNat
         '[ Torch.Typed.Auxiliary.Fst stride
          , Torch.Typed.Auxiliary.Snd stride
          , Torch.Typed.Auxiliary.Fst padding
          , Torch.Typed.Auxiliary.Snd padding
          ]
     , TT.ConvSideCheck inputSize0 kernelSize0 (Torch.Typed.Auxiliary.Fst stride) (Torch.Typed.Auxiliary.Fst padding) outputSize0
     , TT.ConvSideCheck inputSize1 kernelSize1 (Torch.Typed.Auxiliary.Snd stride) (Torch.Typed.Auxiliary.Snd padding) outputSize1
     )
  => TT.Conv2d inputChannelSize outputChannelSize kernelSize0 kernelSize1 dtype device
  -> TT.Tensor device dtype '[batchSize, inputChannelSize, inputSize0, inputSize1]
  -> TT.Tensor device dtype '[batchSize, outputChannelSize, outputSize0, outputSize1]
conv2dForwardRelaxed TT.Conv2d{..} input =
  let kernel = TT.toDependent weight
      offset = TT.toDependent bias
   in conv2dRelaxed @stride @padding kernel offset input