| Safe Haskell | None |
|---|---|
| Language | GHC2021 |
RL.A2C
Contents
Synopsis
- gamma :: QType
- lambdaV :: QType
- lambdaP :: QType
- nWorkers :: Int
- printTensors :: forall (dev :: (DeviceType, Nat)). HList (ModelTensors dev) -> IO ()
- printParams :: forall (dev :: (DeviceType, Nat)). HList (ModelParams dev) -> IO ()
- data A2CState (dev :: (DeviceType, Nat)) = A2CState {
- a2cActor :: !(QModel dev)
- a2cCritic :: !(QModel dev)
- a2cOptActor :: !GD
- a2cOptCritic :: !GD }
- data A2CStepState (dev :: (DeviceType, Nat)) = A2CStepState {
- a2cStepZV :: !(HList (ModelTensors dev))
- a2cStepZP :: !(HList (ModelTensors dev))
- a2cStepIntensity :: !QType
- a2cStepReward :: !QType
- a2cStepState :: !(GreedyState (Edges SPitch) [Edge SPitch] (Notes SPitch) (PVLeftmost SPitch))
- a2cStepActions :: !(NonEmpty PVAction) }
- initPieceState :: forall (dev :: (DeviceType, Nat)). KnownDevice dev => Eval (Edges SPitch) [Edge SPitch] (Notes SPitch) [Note SPitch] (Spread SPitch) (PVLeftmost SPitch) -> Path [Note SPitch] [Edge SPitch] -> HList (ModelTensors dev) -> Either (A2CStepState dev) QType
- pieceStep :: forall (dev :: (DeviceType, Nat)) label. IsValidDevice dev => Eval (Edges SPitch) [Edge SPitch] (Notes SPitch) [Note SPitch] (Spread SPitch) (PVLeftmost SPitch) -> IOGenM StdGen -> PVRewardFn label -> label -> QType -> QType -> Int -> A2CState dev -> A2CStepState dev -> ExceptT String IO (A2CState dev, Either (A2CStepState dev) QType, QType)
- runEpisode :: forall (dev :: (DeviceType, Nat)) label. (GeluDTypeIsValid dev QDType, RandDTypeIsValid dev QDType, BasicArithmeticDTypeIsValid dev QDType, SumDTypeIsValid dev QDType, MeanDTypeValidation dev QDType, StandardFloatingPointDTypeValidation dev QDType, KnownDevice dev) => Eval (Edges SPitch) [Edge SPitch] (Notes SPitch) [Note SPitch] (Spread SPitch) (PVLeftmost SPitch) -> IOGenM StdGen -> PVRewardFn label -> (QType -> QType) -> (QType -> QType) -> Path [Note SPitch] [Edge SPitch] -> label -> A2CState dev -> Int -> IO (Either String (A2CState dev, QType, QType))
- runAccuracy :: forall (dev :: (DeviceType, Nat)) slc' label. IsValidDevice dev => Eval (Edges SPitch) [Edge SPitch] (Notes SPitch) slc' (Spread SPitch) (PVLeftmost SPitch) -> PVRewardFn label -> QModel dev -> (Path slc' [Edge SPitch], label) -> IO (Either String (QType, PVAnalysis SPitch))
- data A2CLoopState (dev :: (DeviceType, Nat)) = A2CLoopState {}
- trainA2C :: forall (dev :: (DeviceType, Nat)) label. IsValidDevice dev => Eval (Edges SPitch) [Edge SPitch] (Notes SPitch) [Note SPitch] (Spread SPitch) (PVLeftmost SPitch) -> IOGenM StdGen -> PVRewardFn label -> (QType -> QType) -> (QType -> QType) -> Maybe [QType] -> QModel dev -> QModel dev -> [(Path [Note SPitch] [Edge SPitch], label)] -> Int -> IO ([[QType]], [QType], QModel dev, QModel dev)
Documentation
printTensors :: forall (dev :: (DeviceType, Nat)). HList (ModelTensors dev) -> IO () Source #
printParams :: forall (dev :: (DeviceType, Nat)). HList (ModelParams dev) -> IO () Source #
data A2CState (dev :: (DeviceType, Nat)) Source #
Constructors
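| A2CState | |
Fields
| a2cActor :: !(QModel dev) | |
| a2cCritic :: !(QModel dev) | |
| a2cOptActor :: !GD | |
| a2cOptCritic :: !GD | |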
Instances
data A2CStepState (dev :: (DeviceType, Nat)) Source #
Constructors
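| A2CStepState | |
Fields
| a2cStepZV :: !(HList (ModelTensors dev)) | |
| a2cStepZP :: !(HList (ModelTensors dev)) | |
| a2cStepIntensity :: !QType | |
| a2cStepReward :: !QType | |
| a2cStepState :: !(GreedyState (Edges SPitch) [Edge SPitch] (Notes SPitch) (PVLeftmost SPitch)) | |
| a2cStepActions :: !(NonEmpty PVAction) | |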
initPieceState :: forall (dev :: (DeviceType, Nat)). KnownDevice dev => Eval (Edges SPitch) [Edge SPitch] (Notes SPitch) [Note SPitch] (Spread SPitch) (PVLeftmost SPitch) -> Path [Note SPitch] [Edge SPitch] -> HList (ModelTensors dev) -> Either (A2CStepState dev) QType Source #
pieceStep Source #
Arguments
| :: forall (dev :: (DeviceType, Nat)) label. IsValidDevice dev | |
| => Eval (Edges SPitch) [Edge SPitch] (Notes SPitch) [Note SPitch] (Spread SPitch) (PVLeftmost SPitch) | |
| -> IOGenM StdGen | |
| -> PVRewardFn label | |
| -> label | |
| -> QType | learning rate |
| -> QType | temperature |
| -> Int | iteration |
| -> A2CState dev | |
| -> A2CStepState dev | |
| -> ExceptT String IO (A2CState dev, Either (A2CStepState dev) QType, QType) | |
runEpisode :: forall (dev :: (DeviceType, Nat)) label. (GeluDTypeIsValid dev QDType, RandDTypeIsValid dev QDType, BasicArithmeticDTypeIsValid dev QDType, SumDTypeIsValid dev QDType, MeanDTypeValidation dev QDType, StandardFloatingPointDTypeValidation dev QDType, KnownDevice dev) => Eval (Edges SPitch) [Edge SPitch] (Notes SPitch) [Note SPitch] (Spread SPitch) (PVLeftmost SPitch) -> IOGenM StdGen -> PVRewardFn label -> (QType -> QType) -> (QType -> QType) -> Path [Note SPitch] [Edge SPitch] -> label -> A2CState dev -> Int -> IO (Either String (A2CState dev, QType, QType)) Source #
Run an episode.
runAccuracy :: forall (dev :: (DeviceType, Nat)) slc' label. IsValidDevice dev => Eval (Edges SPitch) [Edge SPitch] (Notes SPitch) slc' (Spread SPitch) (PVLeftmost SPitch) -> PVRewardFn label -> QModel dev -> (Path slc' [Edge SPitch], label) -> IO (Either String (QType, PVAnalysis SPitch)) Source #
data A2CLoopState (dev :: (DeviceType, Nat)) Source #
Constructors
| A2CLoopState | |
Instances
| Generic (A2CLoopState dev) Source # | |
Defined in RL.A2C Associated Types
Methods from :: A2CLoopState dev -> Rep (A2CLoopState dev) x # to :: Rep (A2CLoopState dev) x -> A2CLoopState dev #
| type Rep (A2CLoopState dev) Source # | |
Defined in RL.A2C type Rep (A2CLoopState dev) = D1 ('MetaData "A2CLoopState" "RL.A2C" "protovoices-rl-0.1.0.0-JjFFM1P77sPCI8QyjRIHUO" 'False) (C1 ('MetaCons "A2CLoopState" 'PrefixI 'True) ((S1 ('MetaSel ('Just "a2clState") 'NoSourceUnpackedness 'NoSourceStrictness 'DecidedStrict) (Rec0 (A2CState dev)) :*: S1 ('MetaSel ('Just "a2clRewards") 'NoSourceUnpackedness 'NoSourceStrictness 'DecidedStrict) (Rec0 (List (List QType)))) :*: (S1 ('MetaSel ('Just "a2clLosses") 'NoSourceUnpackedness 'NoSourceStrictness 'DecidedStrict) (Rec0 (List (List QType))) :*: S1 ('MetaSel ('Just "a2clAccs") 'NoSourceUnpackedness 'NoSourceStrictness 'DecidedStrict) (Rec0 (List (List QType))))))
trainA2C Source #
Arguments
| :: forall (dev :: (DeviceType, Nat)) label. IsValidDevice dev | |
| => Eval (Edges SPitch) [Edge SPitch] (Notes SPitch) [Note SPitch] (Spread SPitch) (PVLeftmost SPitch) | |
| -> IOGenM StdGen | |
| -> PVRewardFn label | |
| -> (QType -> QType) | learning rate schedule |
| -> (QType -> QType) | temperature schedule |
| -> Maybe [QType] | |
| -> QModel dev | |
| -> QModel dev | |
| -> [(Path [Note SPitch] [Edge SPitch], label)] | |
| -> Int | |
| -> IO ([[QType]], [QType], QModel dev, QModel dev) | |