{-# LANGUAGE AllowAmbiguousTypes #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE PartialTypeSignatures #-}
{-# LANGUAGE RecordWildCards #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE UndecidableInstances #-}
{-# OPTIONS_GHC -Wno-partial-type-signatures #-}

module RL.TorchHelpers where

import Data.Kind (Type)
import GHC.TypeLits
import System.IO.Unsafe (unsafePerformIO)
import Torch qualified as T
import Torch qualified as TD
import Torch.Internal.Cast qualified as ATen
import Torch.Internal.Managed.Native qualified as ATen.Managed
import Torch.Internal.Type qualified as ATen
import Torch.Typed qualified as TT
import Torch.Typed.Auxiliary qualified

-- | Helper type for mapping 'TT.sumAll' over a HList.
--
-- Each tensor is reduced to a rank-0 tensor (empty shape) on the same device.
-- The result dtype is @'TT.SumDType' dtype@ of the input dtype.
data SumAll = SumAll

instance
  (dtype' ~ TT.SumDType dtype, TT.SumDTypeIsValid dev dtype)
  => TT.Apply' SumAll (TT.Tensor dev dtype shape) (TT.Tensor dev dtype' '[])
  where
  apply' _ = TT.sumAll

-- | Helper type for folding a HList by adding tensors pairwise with 'TT.add'.
--
-- Both operands must share device and dtype. Their shapes are broadcast
-- together, and the broadcast shape is the shape of the result.
data Add = Add

instance
  ( TT.BasicArithmeticDTypeIsValid dev dtype
  , -- Same as the hand-expanded CheckBroadcast/ComputeBroadcast form,
    -- written via the 'TT.Broadcast' synonym that 'TT.add' itself uses.
    TT.Broadcast shape1 shape2 ~ shapeOut
  )
  => TT.Apply' Add (TT.Tensor dev dtype shape1, TT.Tensor dev dtype shape2) (TT.Tensor dev dtype shapeOut)
  where
  apply' _ (a, b) = TT.add a b

-- | Helper type for multiplying every tensor in a HList by the same scalar
-- using 'TT.mulScalar'.
--
-- Device, dtype and shape of each tensor are preserved.
newtype Mul num
  = Mul num

instance
  (TT.Scalar num)
  => TT.Apply' (Mul num) (TT.Tensor dev dtype shape) (TT.Tensor dev dtype shape)
  where
  apply' (Mul factor) = TT.mulScalar factor

-- | Like 'Mul', but the factor is a rank-0 tensor (empty shape list).
--
-- The factor lives on the same device, with the same dtype, as the tensors it
-- multiplies. Multiplication goes through 'TT.mul', which broadcasts the rank-0
-- factor to each tensor's shape, so the shape is unchanged.
newtype Mul' dev dtype
  = Mul' (TT.Tensor dev dtype '[])

instance
  (shape ~ TT.Broadcast '[] shape, TT.BasicArithmeticDTypeIsValid dev dtype)
  => TT.Apply' (Mul' dev dtype) (TT.Tensor dev dtype shape) (TT.Tensor dev dtype shape)
  where
  apply' (Mul' factor) = TT.mul factor

-- | Detach a typed tensor from the computation graph.
--
-- The tensor is:
--
-- 1. converted to an untyped tensor,
-- 2. detached with 'TD.detach',
-- 3. re-wrapped with 'TT.UnsafeMkTensor'.
--
-- Re-wrapping is sound because detaching does not change device, dtype or
-- shape, so the static type information still describes the result.
detach :: TT.Tensor dev dtype shape -> IO (TT.Tensor dev dtype shape)
detach = fmap TT.UnsafeMkTensor . TD.detach . TT.toDynamic

-- | Helper type for mapping 'detach' over a HList via 'TT.Apply''.
--
-- Each application yields an 'IO' action that produces the detached tensor,
-- with the same static device, dtype and shape.
data Detach = Detach

instance TT.Apply' Detach (TT.Tensor dev dtype shape) (IO (TT.Tensor dev dtype shape)) where
  apply' _ = detach

-- | Helper type for interpolating Q-network parameters.
--
-- For a coefficient @tau@, applying @Interpolate tau@ to a pair @(p, t)@ of
-- same-shaped tensors computes, elementwise:
--
-- > tau * p + (1 - tau) * t
--
-- (Presumably @p@ comes from the online network and @t@ from the target
-- network; confirm at the call sites.)
newtype Interpolate num = Interpolate num

instance
  ( TT.Scalar num
  , Num num
  , TT.BasicArithmeticDTypeIsValid dev dtype
  , -- Adding two tensors of identical shape keeps that shape.
    TT.Broadcast shape shape ~ shape
  )
  => TT.Apply' (Interpolate num) (TT.Tensor dev dtype shape, TT.Tensor dev dtype shape) (TT.Tensor dev dtype shape)
  where
  apply' (Interpolate tau) (p, t) =
    TT.add (TT.mulScalar tau p) (TT.mulScalar (1 - tau) t)

-- | Helper type for reading the static shape of each 'TT.Tensor' or
-- 'TT.Parameter' in a HList as a plain @[Int]@.
--
-- The runtime value is never inspected; the shape comes from the type via
-- 'TT.shapeVal'. (Presumably callers take the 'product' of each shape to
-- count a model's parameters; confirm at the call sites.)
data ShapeVal = ShapeVal

instance (TT.KnownShape shape) => TT.Apply' ShapeVal (TT.Tensor dev dtype shape) [Int] where
  apply' _ _ = TT.shapeVal @shape

instance (TT.KnownShape shape) => TT.Apply' ShapeVal (TT.Parameter dev dtype shape) [Int] where
  apply' _ _ = TT.shapeVal @shape

-- | Helper type for turning a homogeneous HList (all elements of type @t@)
-- into an ordinary list.
--
-- @apply'@ conses one element onto an accumulator list. It is therefore
-- presumably meant as the step function of a right fold over the HList,
-- seeded with @[]@.
data ToList = ToList

instance TT.Apply' ToList (t, [t]) [t] where
  apply' _ = uncurry (:)

-- -- | Helper Type for getting zeros like the parameters of a model
-- data ZerosLike = ZerosLike

-- instance TT.Apply' ZerosLike (TT.Tensor dev dtype shape) (TT.Tensor dev dtype shape) where
--   apply' _ = TT.zerosLike

-- | Map a type-level list of 'TT.Parameter' types to the matching list of
-- 'TT.Tensor' types, keeping each element's device, dtype and shape.
--
-- There is no equation for elements that are not a 'TT.Parameter', so such
-- lists leave the family unreduced.
type family ToModelTensors (params :: [Type]) :: [Type] where
  ToModelTensors '[] = '[]
  ToModelTensors (TT.Parameter dev dtype shape ': rest) =
    TT.Tensor dev dtype shape ': ToModelTensors rest

-- | Run a batched operation on a single, unbatched tensor.
--
-- The steps are:
--
-- 1. Add a leading batch dimension of size 1 to the input with
--    'TT.unsqueeze' at dimension 0.
-- 2. Apply @op@.
-- 3. Remove that leading dimension from the result with 'TT.squeezeDim' at
--    dimension 0.
--
-- The operation may change device, dtype and the per-sample shape.
withBatchDim
  :: forall dev1 dtype1 shape1 dev2 dtype2 shape2
   . (TT.Tensor dev1 dtype1 (1 : shape1) -> TT.Tensor dev2 dtype2 (1 : shape2))
  -> TT.Tensor dev1 dtype1 shape1
  -> TT.Tensor dev2 dtype2 shape2
withBatchDim op input = TT.squeezeDim @0 (op batchedIn)
 where
  -- The input with a singleton batch dimension prepended to its shape.
  batchedIn :: TT.Tensor dev1 dtype1 (1 : shape1)
  batchedIn = TT.unsqueeze @0 input

-- | 2-D convolution with the batch-size constraint dropped.
--
-- The only 'KnownNat' requirements are on the components of @stride@ and
-- @padding@. Nothing at all is required of @batchSize@. The output spatial
-- sizes are still checked at the type level by 'TT.ConvSideCheck'.
--
-- The first two type arguments are @stride@ and @padding@ (type-level pairs),
-- normally supplied with type applications. Dilation is fixed to @(1, 1)@ and
-- groups to @1@.
conv2dRelaxed
  :: forall
    (stride :: (Nat, Nat))
    (padding :: (Nat, Nat))
    inputChannelSize
    outputChannelSize
    kernelSize0
    kernelSize1
    inputSize0
    inputSize1
    batchSize
    outputSize0
    outputSize1
    dtype
    device
   . ( TT.All
        KnownNat
        '[ Torch.Typed.Auxiliary.Fst stride
         , Torch.Typed.Auxiliary.Snd stride
         , Torch.Typed.Auxiliary.Fst padding
         , Torch.Typed.Auxiliary.Snd padding
         ]
     , TT.ConvSideCheck inputSize0 kernelSize0 (Torch.Typed.Auxiliary.Fst stride) (Torch.Typed.Auxiliary.Fst padding) outputSize0
     , TT.ConvSideCheck inputSize1 kernelSize1 (Torch.Typed.Auxiliary.Snd stride) (Torch.Typed.Auxiliary.Snd padding) outputSize1
     )
  => TT.Tensor device dtype '[outputChannelSize, inputChannelSize, kernelSize0, kernelSize1]
  -- ^ weight
  -> TT.Tensor device dtype '[outputChannelSize]
  -- ^ bias
  -> TT.Tensor device dtype '[batchSize, inputChannelSize, inputSize0, inputSize1]
  -- ^ input
  -> TT.Tensor device dtype '[batchSize, outputChannelSize, outputSize0, outputSize1]
  -- ^ output
conv2dRelaxed weight bias input =
  -- NOTE: 'unsafePerformIO' presents the native call as a pure function.
  -- This relies on conv2d being a non-in-place op: it is assumed to allocate
  -- a fresh output tensor and leave its inputs untouched, so the result
  -- depends only on the arguments.
  unsafePerformIO $
    ATen.cast7
      ATen.Managed.conv2d_tttllll
      input
      weight
      bias
      strides
      paddings
      dilations
      groups
 where
  -- Arguments follow ATen's conv2d order:
  -- input, weight, bias, stride, padding, dilation, groups.
  strides, paddings, dilations :: [Int]
  strides =
    [ TT.natValI @(Torch.Typed.Auxiliary.Fst stride)
    , TT.natValI @(Torch.Typed.Auxiliary.Snd stride)
    ]
  paddings =
    [ TT.natValI @(Torch.Typed.Auxiliary.Fst padding)
    , TT.natValI @(Torch.Typed.Auxiliary.Snd padding)
    ]
  dilations = [1, 1]

  groups :: Int
  groups = 1

-- | Forward pass of a 'TT.Conv2d' layer through 'conv2dRelaxed', so there is
-- no constraint on @batchSize@.
--
-- The layer's weight and bias parameters are converted to tensors with
-- 'TT.toDependent' and passed to 'conv2dRelaxed' together with the input.
-- The first two type arguments are @stride@ and @padding@; supply them with
-- type applications.
conv2dForwardRelaxed
  :: forall
    (stride :: (Nat, Nat))
    (padding :: (Nat, Nat))
    inputChannelSize
    outputChannelSize
    kernelSize0
    kernelSize1
    inputSize0
    inputSize1
    batchSize
    outputSize0
    outputSize1
    dtype
    device
   . ( TT.All
        KnownNat
        '[ Torch.Typed.Auxiliary.Fst stride
         , Torch.Typed.Auxiliary.Snd stride
         , Torch.Typed.Auxiliary.Fst padding
         , Torch.Typed.Auxiliary.Snd padding
         ]
     , TT.ConvSideCheck inputSize0 kernelSize0 (Torch.Typed.Auxiliary.Fst stride) (Torch.Typed.Auxiliary.Fst padding) outputSize0
     , TT.ConvSideCheck inputSize1 kernelSize1 (Torch.Typed.Auxiliary.Snd stride) (Torch.Typed.Auxiliary.Snd padding) outputSize1
     )
  => TT.Conv2d inputChannelSize outputChannelSize kernelSize0 kernelSize1 dtype device
  -> TT.Tensor device dtype '[batchSize, inputChannelSize, inputSize0, inputSize1]
  -> TT.Tensor device dtype '[batchSize, outputChannelSize, outputSize0, outputSize1]
conv2dForwardRelaxed TT.Conv2d{..} input =
  -- The record wildcard binds the layer's 'weight' and 'bias' parameters.
  conv2dRelaxed @stride @padding
    (TT.toDependent weight)
    (TT.toDependent bias)
    input