{-# LANGUAGE AllowAmbiguousTypes #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE DeriveAnyClass #-}
{-# LANGUAGE DerivingVia #-}
{-# LANGUAGE ExistentialQuantification #-}
{-# LANGUAGE TypeData #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE UndecidableInstances #-}
{-# LANGUAGE NoStarIsType #-}
module RL.ModelTypes where
import GreedyParser (Action, ActionDouble (ActionDouble), ActionSingle (ActionSingle), DoubleParent (DoubleParent), GreedyState, SingleParent (SingleParent), gsOps, opGoesLeft)
import Musicology.Pitch (SPitch)
import PVGrammar
import Control.DeepSeq
import Data.Kind (Type)
import Data.List.NonEmpty qualified as NE
import Data.Proxy (Proxy (Proxy))
import Data.TypeNums (Nat, TInt (..), type (*), type (+))
import GHC.Generics (Generic)
import NoThunks.Class (NoThunks (..), OnlyCheckWhnf (..), allNoThunks)
import Torch qualified as T
import Torch.Lens qualified
import Torch.Typed qualified as TT
-- | Type-level dtype of every Q tensor: the promoted 'T.Double' dtype.
type QDType = TT.Double

-- | Constraints a device @dev@ must satisfy for 'QDType' tensors.
--
-- Covers GELU, random sampling, basic arithmetic, sum, mean and standard
-- floating-point operations. It also includes 'TT.KnownDevice', so the
-- device can be reflected to a runtime value with 'TT.deviceVal'.
type IsValidDevice dev =
  ( TT.GeluDTypeIsValid dev QDType
  , TT.RandDTypeIsValid dev QDType
  , TT.BasicArithmeticDTypeIsValid dev QDType
  , TT.SumDTypeIsValid dev QDType
  , TT.MeanDTypeValidation dev QDType
  , TT.StandardFloatingPointDTypeValidation dev QDType
  , TT.KnownDevice dev
  )

-- | Haskell scalar type corresponding to 'QDType'.
type QType = Double
-- | Positive infinity as a 'QType', obtained from the IEEE-754 division
-- @1 / 0@.
inf :: QType
inf = 1 / 0
-- | Term-level dtype of Q tensors; the value counterpart of the type-level
-- 'QDType'.
qDType :: TT.DType
qDType = T.Double
-- | Typed tensor on device @dev@ with the given @shape@, whose dtype is
-- fixed to 'QDType'.
type QTensor dev shape = TT.Tensor dev QDType shape
-- | 'T.TensorOptions' for device @dev@ and dtype 'qDType', built on top of
-- 'T.defaultOpts'.
opts :: forall dev. (TT.KnownDevice dev) => T.TensorOptions
opts = T.withDevice device $ T.withDType qDType T.defaultOpts
  where
    -- Reflect the type-level device into a runtime device value.
    device = TT.deviceVal @dev
-- | Converts the tensors inside @a@, reached through its
-- 'Torch.Lens.HasTypes' instance.
--
-- Each tensor is first cast to 'qDType' with 'T.toType' and then moved to
-- device @dev@ with 'T.toDevice'.
toOpts ::
  forall dev a.
  (TT.KnownDevice dev, Torch.Lens.HasTypes a T.Tensor) =>
  a ->
  a
toOpts = T.toDevice device . T.toType qDType
  where
    -- Reflect the type-level device into a runtime device value.
    device = TT.deviceVal @dev
-- | Wraps a single 'QType' value in an untyped tensor.
--
-- The tensor is created with the 'opts' of device @dev@, i.e. on that
-- device with dtype 'qDType'.
toQTensor' :: forall dev. (TT.KnownDevice dev) => QType -> T.Tensor
toQTensor' a = T.asTensor' a (opts @dev)
-- | Typed variant of 'toQTensor'': a shape-@'[]@ 'QTensor' on device @dev@.
--
-- 'TT.UnsafeMkTensor' performs no runtime check, so the @'[]@ shape is
-- asserted rather than verified. It relies on 'T.asTensor'' of a single
-- scalar yielding a 0-dimensional tensor.
toQTensor :: forall dev. (TT.KnownDevice dev) => QType -> QTensor dev '[]
toQTensor = TT.UnsafeMkTensor . toQTensor' @dev
-- | Upper bounds of 8.
-- NOTE(review): presumably the maximum number of pitches and edges
-- encoded per input — confirm against the encoding code.
type MaxPitches = 8 :: Nat

type MaxEdges = 8 :: Nat

-- | Placeholder dimension (1337).
-- NOTE(review): presumably stands in for a size only known at runtime —
-- confirm.
type FakeSize = 1337 :: Nat
-- | The greedy parser's 'Action', specialised to the PVGrammar types over
-- 'SPitch'.
type PVAction =
  Action
    (Notes SPitch)
    (Edges SPitch)
    (Split SPitch)
    (Freeze SPitch)
    (Spread SPitch)

-- | The greedy parser's 'GreedyState', specialised to the PVGrammar types
-- over 'SPitch'.
type PVState =
  GreedyState
    (Edges SPitch)
    [Edge SPitch]
    (Notes SPitch)
    (PVLeftmost SPitch)

-- | Either a 'PVState', or a pair of 'Edges' and a list of 'PVLeftmost'
-- steps.
-- NOTE(review): presumably the outcome of applying a 'PVAction' — confirm.
type PVActionResult = Either PVState (Edges SPitch, [PVLeftmost SPitch])

-- | Reward function yielding a 'QType' in 'IO'.
--
-- Its arguments are, in order:
--
-- * an action result;
-- * an optional non-empty list of actions (NOTE(review): presumably the
--   actions available afterwards — confirm);
-- * a 'PVAction';
-- * a @label@.
type PVRewardFn label =
  PVActionResult ->
  Maybe (NE.NonEmpty PVAction) ->
  PVAction ->
  label ->
  IO QType
-- | Size (8) shared by 'EmbSize' and every @Q*Hidden@ size below.
type CommonHiddenSize = 8

-- Fifth axis of the pitch grid: 2 * FifthPadding + 1 = 13 entries.
-- NOTE(review): 'FifthLow' / 'OctaveLow' look like per-axis reference
-- offsets — confirm against the pitch encoding code.
type FifthLow = Neg 3
type FifthPadding = 6
type FifthSize = (2 * FifthPadding) + 1

-- Octave axis of the pitch grid: 2 * OctavePadding + 1 = 5 entries.
type OctaveLow = Pos 2
type OctavePadding = 2
type OctaveSize = (2 * OctavePadding) + 1

-- | Pitch grid shape, @[FifthSize, OctaveSize]@ = @[13, 5]@.
type PShape = '[FifthSize, OctaveSize]

-- | Sum of both axis sizes (13 + 5 = 18).
type PSize = FifthSize + OctaveSize

-- | Twice 'PSize' (36).
type ESize = PSize + PSize

type EmbSize = CommonHiddenSize

-- | 'EmbSize' prepended to 'PShape', i.e. @[8, 13, 5]@.
type EmbShape = EmbSize ': PShape

-- | @[FakeSize, ESize]@.
type EShape' = '[FakeSize, ESize]

-- Hidden sizes, all equal to 'CommonHiddenSize'.
type QOutHidden = CommonHiddenSize
type QSliceHidden = CommonHiddenSize
type QTransHidden = CommonHiddenSize
type QActionHidden = CommonHiddenSize
type QStateHidden = CommonHiddenSize
-- Orphan Generic / NoThunks / NFData instances for hasktorch tensor,
-- parameter and layer types.
--
-- Deriving strategies are spelled out. With DeriveAnyClass on, and
-- GeneralizedNewtypeDeriving also on (it is part of GHC2021), an
-- unannotated NoThunks/NFData instance for a newtype such as 'TT.Tensor'
-- is ambiguous. GHC then defaults it to @anyclass@ and warns
-- (-Wderiving-defaults). The annotations below keep exactly that choice.

-- Stock Generic for the typed tensor, so the anyclass defaults below apply.
deriving stock instance Generic (TT.Tensor dev dtype shape)

-- The untyped tensor is only checked to WHNF.
-- NOTE(review): presumably because its payload is not ordinary Haskell
-- heap data — confirm.
deriving via (OnlyCheckWhnf T.Tensor) instance NoThunks T.Tensor

deriving anyclass instance NoThunks (TT.Tensor dev dtype shape)
deriving anyclass instance NFData (TT.Tensor dev dtype shape)

deriving newtype instance NoThunks T.IndependentTensor
deriving newtype instance NFData T.IndependentTensor

deriving stock instance Generic (TT.Parameter dev dtype shape)
deriving newtype instance NoThunks (TT.Parameter dev dtype shape)
deriving newtype instance NFData (TT.Parameter dev dtype shape)

deriving anyclass instance NoThunks (TT.Linear nin nout dtype dev)
deriving anyclass instance NFData (TT.Linear nin nout dtype dev)

deriving anyclass instance NoThunks (TT.Conv2d cin cout k0 k1 dtype dev)
deriving anyclass instance NFData (TT.Conv2d cin cout k0 k1 dtype dev)

deriving anyclass instance NoThunks (TT.LayerNorm shape dtype dev)
deriving anyclass instance NFData (TT.LayerNorm shape dtype dev)
-- | The empty 'TT.HList' has no elements, so once matched against
-- 'TT.HNil' there is nothing further to inspect.
instance NoThunks (TT.HList '[]) where
  showTypeOf _ = "HNil"

  -- The context is not needed: there are no fields to report on.
  wNoThunks _ TT.HNil = pure Nothing
-- | A non-empty 'TT.HList' is thunk-free when its head and its tail are.
--
-- The reported type name mentions only the head element's type. The
-- tail's name is added to the context when 'noThunks' recurses into it.
instance
  (NoThunks x, NoThunks (TT.HList xs)) =>
  NoThunks (TT.HList (x ': (xs :: [Type])))
  where
    showTypeOf _ = "HCons " <> showTypeOf (Proxy @x)

    wNoThunks ctxt (hd TT.:. tl) =
      allNoThunks [noThunks ctxt hd, noThunks ctxt tl]
-- | Fully evaluating the empty 'TT.HList' only requires matching 'TT.HNil'.
instance NFData (TT.HList '[]) where
  rnf TT.HNil = ()
-- | Fully evaluates a non-empty 'TT.HList': the head first (via
-- 'deepseq'), then the tail.
instance
  (NFData x, NFData (TT.HList xs)) =>
  NFData (TT.HList (x ': xs :: [Type]))
  where
    rnf (hd TT.:. tl) = hd `deepseq` rnf tl
-- Orphan instances for the 'TT.Adam' and 'TT.GD' types.
-- 'TT.Adam' requires the matching instance for @'TT.HList' momenta@.
deriving stock instance Generic (TT.Adam momenta)
deriving stock instance Generic TT.GD

deriving anyclass instance (NoThunks (TT.HList momenta)) => NoThunks (TT.Adam momenta)
deriving anyclass instance NoThunks TT.GD

deriving anyclass instance (NFData (TT.HList momenta)) => NFData (TT.Adam momenta)
deriving anyclass instance NFData TT.GD