protovoices-rl-0.1.0.0
Safe Haskell: None
Language: GHC2021

RL.TorchHelpers

Synopsis

Documentation

data SumAll Source #

Helper type to map sumAll over an HList.

Constructors

SumAll 

Instances

Instances details
(dtype' ~ SumDType dtype, SumDTypeIsValid dev dtype) => Apply' SumAll (Tensor dev dtype shape) (Tensor dev dtype' ('[] :: [Nat])) Source # 
Instance details

Defined in RL.TorchHelpers

Methods

apply' :: SumAll -> Tensor dev dtype shape -> Tensor dev dtype' ('[] :: [Nat]) #

data Add Source #

Helper type to fold an HList of tensors by adding them together.

Constructors

Add 

Instances

Instances details
(BasicArithmeticDTypeIsValid dev dtype, CheckBroadcast shape1 shape2 (ComputeBroadcast (ReverseImpl shape1 ('[] :: [Nat])) (ReverseImpl shape2 ('[] :: [Nat]))) ~ shapeOut) => Apply' Add (Tensor dev dtype shape1, Tensor dev dtype shape2) (Tensor dev dtype shapeOut) Source # 
Instance details

Defined in RL.TorchHelpers

Methods

apply' :: Add -> (Tensor dev dtype shape1, Tensor dev dtype shape2) -> Tensor dev dtype shapeOut #

newtype Mul num Source #

Helper type to multiply each tensor in an HList by a scalar.

Constructors

Mul num 

Instances

Instances details
Scalar num => Apply' (Mul num) (Tensor dev dtype shape) (Tensor dev dtype shape) Source # 
Instance details

Defined in RL.TorchHelpers

Methods

apply' :: Mul num -> Tensor dev dtype shape -> Tensor dev dtype shape #

newtype Mul' (dev :: (DeviceType, Nat)) (dtype :: DType) Source #

Constructors

Mul' (Tensor dev dtype ('[] :: [Nat])) 

Instances

Instances details
(shape ~ Broadcast ('[] :: [Nat]) shape, BasicArithmeticDTypeIsValid dev dtype) => Apply' (Mul' dev dtype) (Tensor dev dtype shape) (Tensor dev dtype shape) Source # 
Instance details

Defined in RL.TorchHelpers

Methods

apply' :: Mul' dev dtype -> Tensor dev dtype shape -> Tensor dev dtype shape #

detach :: forall (dev :: (DeviceType, Nat)) (dtype :: DType) (shape :: [Nat]). Tensor dev dtype shape -> IO (Tensor dev dtype shape) Source #

Detach a typed tensor.

data Detach Source #

Helper type for combining detach and Apply'.

Constructors

Detach 

Instances

Instances details
Apply' Detach (Tensor dev dtype shape) (IO (Tensor dev dtype shape)) Source # 
Instance details

Defined in RL.TorchHelpers

Methods

apply' :: Detach -> Tensor dev dtype shape -> IO (Tensor dev dtype shape) #

newtype Interpolate num Source #

Helper type for interpolating Q-network (qnet) parameters.

Constructors

Interpolate num 

Instances

Instances details
(Scalar num, Num num, BasicArithmeticDTypeIsValid dev dtype, CheckBroadcast shape shape (ComputeBroadcast (ReverseImpl shape ('[] :: [Nat])) (ReverseImpl shape ('[] :: [Nat]))) ~ shape) => Apply' (Interpolate num) (Tensor dev dtype shape, Tensor dev dtype shape) (Tensor dev dtype shape) Source # 
Instance details

Defined in RL.TorchHelpers

Methods

apply' :: Interpolate num -> (Tensor dev dtype shape, Tensor dev dtype shape) -> Tensor dev dtype shape #

data ShapeVal Source #

Helper type for getting the shape of each parameter in a model, e.g. to compute the model's total number of parameters.

Constructors

ShapeVal 

Instances

Instances details
KnownShape shape => Apply' ShapeVal (Parameter dev dtype shape) [Int] Source # 
Instance details

Defined in RL.TorchHelpers

Methods

apply' :: ShapeVal -> Parameter dev dtype shape -> [Int] #

KnownShape shape => Apply' ShapeVal (Tensor dev dtype shape) [Int] Source # 
Instance details

Defined in RL.TorchHelpers

Methods

apply' :: ShapeVal -> Tensor dev dtype shape -> [Int] #

data ToList Source #

Helper type for getting a list out of an HList.

Constructors

ToList 

Instances

Instances details
Apply' ToList (t, [t]) [t] Source # 
Instance details

Defined in RL.TorchHelpers

Methods

apply' :: ToList -> (t, [t]) -> [t] #

type family ToModelTensors (params :: [Type]) :: [Type] where ... Source #

Equations

ToModelTensors ('[] :: [Type]) = '[] :: [Type] 
ToModelTensors (Parameter dev dtype shape ': rst) = Tensor dev dtype shape ': ToModelTensors rst 

withBatchDim :: forall (dev1 :: (DeviceType, Nat)) (dtype1 :: DType) (shape1 :: [Natural]) (dev2 :: (DeviceType, Nat)) (dtype2 :: DType) (shape2 :: [Natural]). (Tensor dev1 dtype1 (1 ': shape1) -> Tensor dev2 dtype2 (1 ': shape2)) -> Tensor dev1 dtype1 shape1 -> Tensor dev2 dtype2 shape2 Source #

Run a batched operation in an unbatched context by temporarily adding a batch dimension of size 1.

conv2dRelaxed Source #

Arguments

:: forall (stride :: (Nat, Nat)) (padding :: (Nat, Nat)) (inputChannelSize :: Nat) (outputChannelSize :: Nat) (kernelSize0 :: Nat) (kernelSize1 :: Nat) (inputSize0 :: Nat) (inputSize1 :: Nat) (batchSize :: Nat) (outputSize0 :: Nat) (outputSize1 :: Nat) (dtype :: DType) (device :: (DeviceType, Nat)). (All KnownNat '[Fst stride, Snd stride, Fst padding, Snd padding], ConvSideCheck inputSize0 kernelSize0 (Fst stride) (Fst padding) outputSize0, ConvSideCheck inputSize1 kernelSize1 (Snd stride) (Snd padding) outputSize1) 
=> Tensor device dtype '[outputChannelSize, inputChannelSize, kernelSize0, kernelSize1]

weight

-> Tensor device dtype '[outputChannelSize]

bias

-> Tensor device dtype '[batchSize, inputChannelSize, inputSize0, inputSize1]

input

-> Tensor device dtype '[batchSize, outputChannelSize, outputSize0, outputSize1]

output

A variant of conv2d with the batch-size constraint dropped.

conv2dForwardRelaxed :: forall (stride :: (Nat, Nat)) (padding :: (Nat, Nat)) (inputChannelSize :: Nat) (outputChannelSize :: Nat) (kernelSize0 :: Nat) (kernelSize1 :: Nat) (inputSize0 :: Nat) (inputSize1 :: Nat) (batchSize :: Nat) (outputSize0 :: Nat) (outputSize1 :: Nat) (dtype :: DType) (device :: (DeviceType, Nat)). (All KnownNat '[Fst stride, Snd stride, Fst padding, Snd padding], ConvSideCheck inputSize0 kernelSize0 (Fst stride) (Fst padding) outputSize0, ConvSideCheck inputSize1 kernelSize1 (Snd stride) (Snd padding) outputSize1) => Conv2d inputChannelSize outputChannelSize kernelSize0 kernelSize1 dtype device -> Tensor device dtype '[batchSize, inputChannelSize, inputSize0, inputSize1] -> Tensor device dtype '[batchSize, outputChannelSize, outputSize0, outputSize1] Source #