{-# LANGUAGE AllowAmbiguousTypes #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE DeriveAnyClass #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE PartialTypeSignatures #-}
{-# LANGUAGE Strict #-}
{-# OPTIONS_GHC -Wno-partial-type-signatures #-}
module RL.A2C where
import Common
import GreedyParser
import PVGrammar
import PVGrammar.Prob.Simple (PVParams)
import RL.A2CHelpers
import RL.Encoding
import RL.Model
import RL.ModelTypes
import RL.Plotting
import RL.TorchHelpers
import Control.DeepSeq (NFData, force)
import Control.Foldl qualified as Foldl
import Control.Monad (foldM, forM, forM_, when)
import Control.Monad.Except qualified as ET
import Control.Monad.Trans (lift)
import Control.Monad.Trans.Except qualified as ET
import Data.Either (partitionEithers)
import Data.Foldable qualified as F
import Data.List qualified as L
import Data.List.NonEmpty qualified as NE
import Data.Maybe (mapMaybe)
import Data.Text.Lazy qualified as Txt
import Data.Vector qualified as V
import Debug.Trace qualified as DT
import GHC.Generics
import Inference.Conjugate (Hyper)
import Musicology.Pitch (SPitch)
import NoThunks.Class (NoThunks (noThunks), ThunkInfo (thunkContext))
import StrictList qualified as SL
import System.IO (hFlush, stdout)
import System.Mem (performGC)
import System.ProgressBar qualified as PB
import System.Random.MWC.Distributions (categorical)
import System.Random.Stateful (StatefulGen)
import System.Random.Stateful qualified as Rand
import Torch qualified as T
import Torch.Typed qualified as TT
import Torch.Typed.Optim.CppOptim qualified as TT
import Torch.Typed.Optim.CppOptim qualified as TTC
-- | Discount factor. In 'pieceStep' it scales the critic's value of the
-- successor state when forming the TD error
-- (@r + gamma * V(s') - V(s)@), and it is passed to 'updateEligCritic'.
gamma :: QType
gamma = 0.99
-- | Trace-decay parameter (lambda) for the critic; passed to
-- 'updateEligCritic' together with 'gamma' in 'pieceStep'.
lambdaV :: QType
lambdaV = 0.3
-- | Trace-decay parameter (lambda), presumably for the actor (policy),
-- mirroring 'lambdaV' for the critic.
-- NOTE(review): role inferred from naming; confirm at its use site.
lambdaP :: QType
lambdaP = 0.3
-- | Number of workers.
-- NOTE(review): presumably the number of pieces processed concurrently;
-- confirm at its use site.
nWorkers :: Int
nWorkers = 2
-- | Debug helper: print the second entry of a model-tensor list.
--
-- Given the layout of 'ModelTensors', this is the @'[8]@-shaped tensor that
-- follows the leading @'[8, 1, 1, 1]@ tensor. All other entries are ignored.
printTensors :: TT.HList (ModelTensors dev) -> IO ()
printTensors (_ TT.:. t TT.:. _) = print t
-- | Debug helper: print the second entry of a model-parameter list.
--
-- Given the layout of 'ModelParams', this is the @'[8]@-shaped parameter
-- that follows the leading @'[8, 1, 1, 1]@ parameter. All other entries are
-- ignored.
printParams :: TT.HList (ModelParams dev) -> IO ()
printParams (_ TT.:. t TT.:. _) = print t
-- | Learner state for actor-critic training: the actor and critic networks
-- and one optimiser state for each.
data A2CState dev = A2CState
  { a2cActor :: !(QModel dev)
  -- ^ policy network (queried with 'runBatchedPolicy' in 'pieceStep')
  , a2cCritic :: !(QModel dev)
  -- ^ value network (queried with 'forwardValue' in 'pieceStep')
  , a2cOptActor :: !TT.GD
  -- ^ gradient-descent optimiser state intended for the actor
  , a2cOptCritic :: !TT.GD
  -- ^ gradient-descent optimiser state intended for the critic
  }
  deriving (Generic)
-- | State carried from one 'pieceStep' to the next while processing a
-- single piece. Created by 'initPieceState'.
data A2CStepState dev = A2CStepState
  { a2cStepZV :: !(TT.HList (ModelTensors dev))
  -- ^ critic eligibility trace, one tensor per critic parameter
  --   (updated with 'updateEligCritic' in 'pieceStep')
  , a2cStepZP :: !(TT.HList (ModelTensors dev))
  -- ^ presumably the actor's eligibility trace, the counterpart of
  --   'a2cStepZV' (NOTE(review): confirm at its use site)
  , a2cStepIntensity :: !QType
  -- ^ set to 1 by 'initPieceState'; presumably the accumulated discount
  --   factor of episodic actor-critic (NOTE(review): confirm)
  , a2cStepReward :: !QType
  -- ^ set to 0 by 'initPieceState'; presumably the accumulated reward
  , a2cStepState
      :: !( GreedyState
              (Edges SPitch)
              [Edge SPitch]
              (Notes SPitch)
              (PVLeftmost SPitch)
          )
  -- ^ current greedy parser state
  , a2cStepActions :: !(NE.NonEmpty PVAction)
  -- ^ candidate actions available in 'a2cStepState' (never empty)
  }
-- | Build the initial step state for one piece.
--
-- The greedy parse state is initialised from @eval@ and the input path, and
-- at most 200 candidate actions are enumerated from it. The same cap is
-- applied when actions are re-enumerated in 'pieceStep'.
--
-- Returns
--
-- * @'Left' st@ when at least one action is available. In @st@, both trace
--   fields ('a2cStepZV', 'a2cStepZP') are set to @z0@, the intensity is @1@
--   and the reward is @0@.
-- * @'Right' (-inf)@ when the initial state admits no action.
initPieceState
  :: (TT.KnownDevice dev)
  => Eval (Edges SPitch) [Edge SPitch] (Notes SPitch) [Note SPitch] (Spread SPitch) (PVLeftmost SPitch)
  -> Path [Note SPitch] [Edge SPitch]
  -> TT.HList (ModelTensors dev)
  -> Either (A2CStepState dev) QType
initPieceState eval input z0 =
  let state = initParseState eval input
      actions = take 200 $ getActions eval state
   in case NE.nonEmpty actions of
        -- no legal first move: report the worst possible score
        Nothing -> Right (-inf)
        Just available -> Left $ A2CStepState z0 z0 1 0 state available
pieceStep
:: forall dev label
. (IsValidDevice dev)
=> Eval (Edges SPitch) [Edge SPitch] (Notes SPitch) [Note SPitch] (Spread SPitch) (PVLeftmost SPitch)
-> Rand.IOGenM Rand.StdGen
-> PVRewardFn label
-> label
-> QType
-> QType
-> Int
-> A2CState dev
-> A2CStepState dev
-> ET.ExceptT String IO (A2CState dev, Either (A2CStepState dev) QType, QType)
pieceStep :: forall (dev :: (DeviceType, Nat)) label.
IsValidDevice dev =>
Eval
(Edges SPitch)
[Edge SPitch]
(Notes SPitch)
[Note SPitch]
(Spread SPitch)
(PVLeftmost SPitch)
-> IOGenM StdGen
-> PVRewardFn label
-> label
-> QType
-> QType
-> Int
-> A2CState dev
-> A2CStepState dev
-> ExceptT
String IO (A2CState dev, Either (A2CStepState dev) QType, QType)
pieceStep Eval
(Edges SPitch)
[Edge SPitch]
(Notes SPitch)
[Note SPitch]
(Spread SPitch)
(PVLeftmost SPitch)
eval IOGenM StdGen
gen PVRewardFn label
fReward label
len QType
lr QType
temp Int
i (A2CState QModel dev
actor QModel dev
critic GD
opta GD
optc) (A2CStepState HList (ModelTensors dev)
zV HList (ModelTensors dev)
zP QType
intensity QType
reward GreedyState
(Edges SPitch) [Edge SPitch] (Notes SPitch) (PVLeftmost SPitch)
state NonEmpty PVAction
actions) = do
let
policy :: Tensor
policy = QType -> Tensor -> Tensor
forall a. Scalar a => a -> Tensor -> Tensor
T.pow (QType
1 QType -> QType -> QType
forall a. Fractional a => a -> a -> a
/ QType
temp) (Tensor -> Tensor) -> Tensor -> Tensor
forall a b. (a -> b) -> a -> b
$ GreedyState
(Edges SPitch) [Edge SPitch] (Notes SPitch) (PVLeftmost SPitch)
-> NonEmpty PVAction
-> (forall (n :: Nat). KnownNat n => QEncoding dev '[n] -> Tensor)
-> Tensor
forall (dev :: (DeviceType, Nat)) r.
KnownDevice dev =>
GreedyState
(Edges SPitch) [Edge SPitch] (Notes SPitch) (PVLeftmost SPitch)
-> NonEmpty PVAction
-> (forall (n :: Nat). KnownNat n => QEncoding dev '[n] -> r)
-> r
withBatchedEncoding GreedyState
(Edges SPitch) [Edge SPitch] (Notes SPitch) (PVLeftmost SPitch)
state NonEmpty PVAction
actions (QModel dev -> QEncoding dev '[n] -> Tensor
forall (dev :: (DeviceType, Nat)) (batchSize :: Nat).
(IsValidDevice dev, KnownNat batchSize) =>
QModel dev -> QEncoding dev '[batchSize] -> Tensor
runBatchedPolicy QModel dev
actor)
actionIndex <- IO Int -> ExceptT String IO Int
forall (m :: * -> *) a. Monad m => m a -> ExceptT String m a
forall (t :: (* -> *) -> * -> *) (m :: * -> *) a.
(MonadTrans t, Monad m) =>
m a -> t m a
lift (IO Int -> ExceptT String IO Int)
-> IO Int -> ExceptT String IO Int
forall a b. (a -> b) -> a -> b
$ Vector QType -> IOGenM StdGen -> IO Int
forall g (m :: * -> *) (v :: * -> *).
(StatefulGen g m, Vector v QType) =>
v QType -> g -> m Int
categorical ([QType] -> Vector QType
forall a. [a] -> Vector a
V.fromList ([QType] -> Vector QType) -> [QType] -> Vector QType
forall a b. (a -> b) -> a -> b
$ Tensor -> [QType]
forall a. TensorLike a => Tensor -> a
T.asValue (Tensor -> [QType]) -> Tensor -> [QType]
forall a b. (a -> b) -> a -> b
$ DType -> Tensor -> Tensor
T.toDType DType
T.Double Tensor
policy) IOGenM StdGen
gen
let action = NonEmpty PVAction
actions NonEmpty PVAction -> Int -> PVAction
forall a. HasCallStack => NonEmpty a -> Int -> a
NE.!! Int
actionIndex
state' <- ET.except $ applyAction state action
let actions' = case Either
(GreedyState
(Edges SPitch) [Edge SPitch] (Notes SPitch) (PVLeftmost SPitch))
(Edges SPitch, [PVLeftmost SPitch])
state' of
Left GreedyState
(Edges SPitch) [Edge SPitch] (Notes SPitch) (PVLeftmost SPitch)
newState -> [PVAction] -> Maybe (NonEmpty PVAction)
forall a. [a] -> Maybe (NonEmpty a)
NE.nonEmpty ([PVAction] -> Maybe (NonEmpty PVAction))
-> [PVAction] -> Maybe (NonEmpty PVAction)
forall a b. (a -> b) -> a -> b
$ Int -> [PVAction] -> [PVAction]
forall a. Int -> [a] -> [a]
take Int
200 ([PVAction] -> [PVAction]) -> [PVAction] -> [PVAction]
forall a b. (a -> b) -> a -> b
$ Eval
(Edges SPitch)
[Edge SPitch]
(Notes SPitch)
[Note SPitch]
(Spread SPitch)
(PVLeftmost SPitch)
-> GreedyState
(Edges SPitch) [Edge SPitch] (Notes SPitch) (PVLeftmost SPitch)
-> [PVAction]
forall {k} (m :: k) tr tr' slc slc' s f h.
Eval tr tr' slc slc' h (Leftmost s f h)
-> GreedyState tr tr' slc (Leftmost s f h) -> [Action slc tr s f h]
getActions Eval
(Edges SPitch)
[Edge SPitch]
(Notes SPitch)
[Note SPitch]
(Spread SPitch)
(PVLeftmost SPitch)
eval GreedyState
(Edges SPitch) [Edge SPitch] (Notes SPitch) (PVLeftmost SPitch)
newState
Right (Edges SPitch, [PVLeftmost SPitch])
_ -> Maybe (NonEmpty PVAction)
forall a. Maybe a
Nothing
r <- lift $ fReward state' actions' action len
let vS = QModel dev -> StateEncoding dev -> Tensor dev 'Double '[1]
forall (dev :: (DeviceType, Nat)).
IsValidDevice dev =>
QModel dev -> StateEncoding dev -> QTensor dev '[1]
forwardValue QModel dev
critic (StateEncoding dev -> Tensor dev 'Double '[1])
-> StateEncoding dev -> Tensor dev 'Double '[1]
forall a b. (a -> b) -> a -> b
$ GreedyState
(Edges SPitch) [Edge SPitch] (Notes SPitch) (PVLeftmost SPitch)
-> StateEncoding dev
forall (dev :: (DeviceType, Nat)).
KnownDevice dev =>
GreedyState
(Edges SPitch) [Edge SPitch] (Notes SPitch) (PVLeftmost SPitch)
-> StateEncoding dev
encodePVState GreedyState
(Edges SPitch) [Edge SPitch] (Notes SPitch) (PVLeftmost SPitch)
state
vS' = case Either
(GreedyState
(Edges SPitch) [Edge SPitch] (Notes SPitch) (PVLeftmost SPitch))
(Edges SPitch, [PVLeftmost SPitch])
state' of
Left GreedyState
(Edges SPitch) [Edge SPitch] (Notes SPitch) (PVLeftmost SPitch)
s' -> QModel dev -> StateEncoding dev -> Tensor dev 'Double '[1]
forall (dev :: (DeviceType, Nat)).
IsValidDevice dev =>
QModel dev -> StateEncoding dev -> QTensor dev '[1]
forwardValue QModel dev
critic (StateEncoding dev -> Tensor dev 'Double '[1])
-> StateEncoding dev -> Tensor dev 'Double '[1]
forall a b. (a -> b) -> a -> b
$ GreedyState
(Edges SPitch) [Edge SPitch] (Notes SPitch) (PVLeftmost SPitch)
-> StateEncoding dev
forall (dev :: (DeviceType, Nat)).
KnownDevice dev =>
GreedyState
(Edges SPitch) [Edge SPitch] (Notes SPitch) (PVLeftmost SPitch)
-> StateEncoding dev
encodePVState GreedyState
(Edges SPitch) [Edge SPitch] (Notes SPitch) (PVLeftmost SPitch)
s'
Right (Edges SPitch, [PVLeftmost SPitch])
_ -> Tensor dev 'Double '[1]
0
delta = QType -> Tensor dev 'Double '[] -> Tensor dev 'Double '[]
forall a (shape :: [Nat]) (dtype :: DType)
(device :: (DeviceType, Nat)).
Scalar a =>
a -> Tensor device dtype shape -> Tensor device dtype shape
TT.addScalar QType
r (Tensor dev 'Double '[] -> Tensor dev 'Double '[])
-> Tensor dev 'Double '[] -> Tensor dev 'Double '[]
forall a b. (a -> b) -> a -> b
$ Tensor dev 'Double '[1] -> Tensor dev 'Double '[]
forall (shape :: [Nat]) (shape' :: [Nat]) (dtype :: DType)
(device :: (DeviceType, Nat)).
(shape' ~ SqueezeAll shape) =>
Tensor device dtype shape -> Tensor device dtype shape'
TT.squeezeAll (Tensor dev 'Double '[1] -> Tensor dev 'Double '[])
-> Tensor dev 'Double '[1] -> Tensor dev 'Double '[]
forall a b. (a -> b) -> a -> b
$ QType -> Tensor dev 'Double '[1] -> Tensor dev 'Double '[1]
forall a (shape :: [Nat]) (dtype :: DType)
(device :: (DeviceType, Nat)).
Scalar a =>
a -> Tensor device dtype shape -> Tensor device dtype shape
TT.mulScalar QType
gamma Tensor dev 'Double '[1]
vS' Tensor dev 'Double '[1]
-> Tensor dev 'Double '[1] -> Tensor dev 'Double '[1]
forall a. Num a => a -> a -> a
- Tensor dev 'Double '[1]
vS
gradV = Tensor dev 'Double '[]
-> HList
'[Parameter dev 'Double '[8, 1, 1, 1], Parameter dev 'Double '[8],
Parameter dev 'Double '[8, 8, 13, 5], Parameter dev 'Double '[8],
Parameter dev 'Double '[8, 13, 5],
Parameter dev 'Double '[8, 13, 5],
Parameter dev 'Double '[8, 2, 13, 5], Parameter dev 'Double '[8],
Parameter dev 'Double '[8, 2, 13, 5], Parameter dev 'Double '[8],
Parameter dev 'Double '[8, 1, 1, 1], Parameter dev 'Double '[8],
Parameter dev 'Double '[8, 1, 1, 1], Parameter dev 'Double '[8],
Parameter dev 'Double '[8], Parameter dev 'Double '[8, 8, 13, 5],
Parameter dev 'Double '[8], Parameter dev 'Double '[8, 8, 13, 5],
Parameter dev 'Double '[8], Parameter dev 'Double '[8, 8, 13, 5],
Parameter dev 'Double '[8], Parameter dev 'Double '[8, 8, 13, 5],
Parameter dev 'Double '[8], Parameter dev 'Double '[8, 8, 13, 5],
Parameter dev 'Double '[8], Parameter dev 'Double '[8, 8, 13, 5],
Parameter dev 'Double '[8], Parameter dev 'Double '[8, 8, 13, 5],
Parameter dev 'Double '[8], Parameter dev 'Double '[5],
Parameter dev 'Double '[5], Parameter dev 'Double '[5],
Parameter dev 'Double '[8, 8, 13, 5], Parameter dev 'Double '[8],
Parameter dev 'Double '[8, 8, 13, 5], Parameter dev 'Double '[8],
Parameter dev 'Double '[8, 8, 13, 5], Parameter dev 'Double '[8],
Parameter dev 'Double '[8, 8, 13, 5], Parameter dev 'Double '[8],
Parameter dev 'Double '[8, 8, 13, 5], Parameter dev 'Double '[8],
Parameter dev 'Double '[8, 8, 13, 5], Parameter dev 'Double '[8],
Parameter dev 'Double '[8, 8, 13, 5], Parameter dev 'Double '[8],
Parameter dev 'Double '[8, 8, 13, 5], Parameter dev 'Double '[8],
Parameter dev 'Double '[8], Parameter dev 'Double '[8],
Parameter dev 'Double '[1, 8], Parameter dev 'Double '[1],
Parameter dev 'Double '[8, 8], Parameter dev 'Double '[8],
Parameter dev 'Double '[8], Parameter dev 'Double '[8],
Parameter dev 'Double '[1, 8], Parameter dev 'Double '[1]]
-> HList
'[Tensor dev 'Double '[8, 1, 1, 1], Tensor dev 'Double '[8],
Tensor dev 'Double '[8, 8, 13, 5], Tensor dev 'Double '[8],
Tensor dev 'Double '[8, 13, 5], Tensor dev 'Double '[8, 13, 5],
Tensor dev 'Double '[8, 2, 13, 5], Tensor dev 'Double '[8],
Tensor dev 'Double '[8, 2, 13, 5], Tensor dev 'Double '[8],
Tensor dev 'Double '[8, 1, 1, 1], Tensor dev 'Double '[8],
Tensor dev 'Double '[8, 1, 1, 1], Tensor dev 'Double '[8],
Tensor dev 'Double '[8], Tensor dev 'Double '[8, 8, 13, 5],
Tensor dev 'Double '[8], Tensor dev 'Double '[8, 8, 13, 5],
Tensor dev 'Double '[8], Tensor dev 'Double '[8, 8, 13, 5],
Tensor dev 'Double '[8], Tensor dev 'Double '[8, 8, 13, 5],
Tensor dev 'Double '[8], Tensor dev 'Double '[8, 8, 13, 5],
Tensor dev 'Double '[8], Tensor dev 'Double '[8, 8, 13, 5],
Tensor dev 'Double '[8], Tensor dev 'Double '[8, 8, 13, 5],
Tensor dev 'Double '[8], Tensor dev 'Double '[5],
Tensor dev 'Double '[5], Tensor dev 'Double '[5],
Tensor dev 'Double '[8, 8, 13, 5], Tensor dev 'Double '[8],
Tensor dev 'Double '[8, 8, 13, 5], Tensor dev 'Double '[8],
Tensor dev 'Double '[8, 8, 13, 5], Tensor dev 'Double '[8],
Tensor dev 'Double '[8, 8, 13, 5], Tensor dev 'Double '[8],
Tensor dev 'Double '[8, 8, 13, 5], Tensor dev 'Double '[8],
Tensor dev 'Double '[8, 8, 13, 5], Tensor dev 'Double '[8],
Tensor dev 'Double '[8, 8, 13, 5], Tensor dev 'Double '[8],
Tensor dev 'Double '[8, 8, 13, 5], Tensor dev 'Double '[8],
Tensor dev 'Double '[8], Tensor dev 'Double '[8],
Tensor dev 'Double '[1, 8], Tensor dev 'Double '[1],
Tensor dev 'Double '[8, 8], Tensor dev 'Double '[8],
Tensor dev 'Double '[8], Tensor dev 'Double '[8],
Tensor dev 'Double '[1, 8], Tensor dev 'Double '[1]]
forall a b (dtype :: DType) (device :: (DeviceType, Nat)).
HasGrad a b =>
Tensor device dtype '[] -> a -> b
forall (dtype :: DType) (device :: (DeviceType, Nat)).
Tensor device dtype '[]
-> HList
'[Parameter dev 'Double '[8, 1, 1, 1], Parameter dev 'Double '[8],
Parameter dev 'Double '[8, 8, 13, 5], Parameter dev 'Double '[8],
Parameter dev 'Double '[8, 13, 5],
Parameter dev 'Double '[8, 13, 5],
Parameter dev 'Double '[8, 2, 13, 5], Parameter dev 'Double '[8],
Parameter dev 'Double '[8, 2, 13, 5], Parameter dev 'Double '[8],
Parameter dev 'Double '[8, 1, 1, 1], Parameter dev 'Double '[8],
Parameter dev 'Double '[8, 1, 1, 1], Parameter dev 'Double '[8],
Parameter dev 'Double '[8], Parameter dev 'Double '[8, 8, 13, 5],
Parameter dev 'Double '[8], Parameter dev 'Double '[8, 8, 13, 5],
Parameter dev 'Double '[8], Parameter dev 'Double '[8, 8, 13, 5],
Parameter dev 'Double '[8], Parameter dev 'Double '[8, 8, 13, 5],
Parameter dev 'Double '[8], Parameter dev 'Double '[8, 8, 13, 5],
Parameter dev 'Double '[8], Parameter dev 'Double '[8, 8, 13, 5],
Parameter dev 'Double '[8], Parameter dev 'Double '[8, 8, 13, 5],
Parameter dev 'Double '[8], Parameter dev 'Double '[5],
Parameter dev 'Double '[5], Parameter dev 'Double '[5],
Parameter dev 'Double '[8, 8, 13, 5], Parameter dev 'Double '[8],
Parameter dev 'Double '[8, 8, 13, 5], Parameter dev 'Double '[8],
Parameter dev 'Double '[8, 8, 13, 5], Parameter dev 'Double '[8],
Parameter dev 'Double '[8, 8, 13, 5], Parameter dev 'Double '[8],
Parameter dev 'Double '[8, 8, 13, 5], Parameter dev 'Double '[8],
Parameter dev 'Double '[8, 8, 13, 5], Parameter dev 'Double '[8],
Parameter dev 'Double '[8, 8, 13, 5], Parameter dev 'Double '[8],
Parameter dev 'Double '[8, 8, 13, 5], Parameter dev 'Double '[8],
Parameter dev 'Double '[8], Parameter dev 'Double '[8],
Parameter dev 'Double '[1, 8], Parameter dev 'Double '[1],
Parameter dev 'Double '[8, 8], Parameter dev 'Double '[8],
Parameter dev 'Double '[8], Parameter dev 'Double '[8],
Parameter dev 'Double '[1, 8], Parameter dev 'Double '[1]]
-> HList
'[Tensor dev 'Double '[8, 1, 1, 1], Tensor dev 'Double '[8],
Tensor dev 'Double '[8, 8, 13, 5], Tensor dev 'Double '[8],
Tensor dev 'Double '[8, 13, 5], Tensor dev 'Double '[8, 13, 5],
Tensor dev 'Double '[8, 2, 13, 5], Tensor dev 'Double '[8],
Tensor dev 'Double '[8, 2, 13, 5], Tensor dev 'Double '[8],
Tensor dev 'Double '[8, 1, 1, 1], Tensor dev 'Double '[8],
Tensor dev 'Double '[8, 1, 1, 1], Tensor dev 'Double '[8],
Tensor dev 'Double '[8], Tensor dev 'Double '[8, 8, 13, 5],
Tensor dev 'Double '[8], Tensor dev 'Double '[8, 8, 13, 5],
Tensor dev 'Double '[8], Tensor dev 'Double '[8, 8, 13, 5],
Tensor dev 'Double '[8], Tensor dev 'Double '[8, 8, 13, 5],
Tensor dev 'Double '[8], Tensor dev 'Double '[8, 8, 13, 5],
Tensor dev 'Double '[8], Tensor dev 'Double '[8, 8, 13, 5],
Tensor dev 'Double '[8], Tensor dev 'Double '[8, 8, 13, 5],
Tensor dev 'Double '[8], Tensor dev 'Double '[5],
Tensor dev 'Double '[5], Tensor dev 'Double '[5],
Tensor dev 'Double '[8, 8, 13, 5], Tensor dev 'Double '[8],
Tensor dev 'Double '[8, 8, 13, 5], Tensor dev 'Double '[8],
Tensor dev 'Double '[8, 8, 13, 5], Tensor dev 'Double '[8],
Tensor dev 'Double '[8, 8, 13, 5], Tensor dev 'Double '[8],
Tensor dev 'Double '[8, 8, 13, 5], Tensor dev 'Double '[8],
Tensor dev 'Double '[8, 8, 13, 5], Tensor dev 'Double '[8],
Tensor dev 'Double '[8, 8, 13, 5], Tensor dev 'Double '[8],
Tensor dev 'Double '[8, 8, 13, 5], Tensor dev 'Double '[8],
Tensor dev 'Double '[8], Tensor dev 'Double '[8],
Tensor dev 'Double '[1, 8], Tensor dev 'Double '[1],
Tensor dev 'Double '[8, 8], Tensor dev 'Double '[8],
Tensor dev 'Double '[8], Tensor dev 'Double '[8],
Tensor dev 'Double '[1, 8], Tensor dev 'Double '[1]]
TT.grad (Tensor dev 'Double '[1] -> Tensor dev 'Double '[]
forall (shape :: [Nat]) (shape' :: [Nat]) (dtype :: DType)
(device :: (DeviceType, Nat)).
(shape' ~ SqueezeAll shape) =>
Tensor device dtype shape -> Tensor device dtype shape'
TT.squeezeAll Tensor dev 'Double '[1]
vS Tensor dev 'Double '[]
-> Tensor dev 'Double '[] -> Tensor dev 'Double '[]
forall a. Num a => a -> a -> a
+ QModel dev -> Tensor dev 'Double '[]
forall (dev :: (DeviceType, Nat)) (ps :: [*]).
(IsValidDevice dev, ps ~ Parameters (QModel dev)) =>
QModel dev -> QTensor dev '[]
fakeLoss QModel dev
critic) (QModel dev -> HList (Parameters (QModel dev))
forall f. Parameterized f => f -> HList (Parameters f)
TT.flattenParameters QModel dev
critic)
zV' = QType
-> QType
-> HList (ModelTensors dev)
-> HList (ModelTensors dev)
-> HList (ModelTensors dev)
forall (dev :: (DeviceType, Nat)).
KnownDevice dev =>
QType
-> QType
-> HList (ModelTensors dev)
-> HList (ModelTensors dev)
-> HList (ModelTensors dev)
updateEligCritic QType
gamma QType
lambdaV HList (ModelTensors dev)
zV HList
'[Tensor dev 'Double '[8, 1, 1, 1], Tensor dev 'Double '[8],
Tensor dev 'Double '[8, 8, 13, 5], Tensor dev 'Double '[8],
Tensor dev 'Double '[8, 13, 5], Tensor dev 'Double '[8, 13, 5],
Tensor dev 'Double '[8, 2, 13, 5], Tensor dev 'Double '[8],
Tensor dev 'Double '[8, 2, 13, 5], Tensor dev 'Double '[8],
Tensor dev 'Double '[8, 1, 1, 1], Tensor dev 'Double '[8],
Tensor dev 'Double '[8, 1, 1, 1], Tensor dev 'Double '[8],
Tensor dev 'Double '[8], Tensor dev 'Double '[8, 8, 13, 5],
Tensor dev 'Double '[8], Tensor dev 'Double '[8, 8, 13, 5],
Tensor dev 'Double '[8], Tensor dev 'Double '[8, 8, 13, 5],
Tensor dev 'Double '[8], Tensor dev 'Double '[8, 8, 13, 5],
Tensor dev 'Double '[8], Tensor dev 'Double '[8, 8, 13, 5],
Tensor dev 'Double '[8], Tensor dev 'Double '[8, 8, 13, 5],
Tensor dev 'Double '[8], Tensor dev 'Double '[8, 8, 13, 5],
Tensor dev 'Double '[8], Tensor dev 'Double '[5],
Tensor dev 'Double '[5], Tensor dev 'Double '[5],
Tensor dev 'Double '[8, 8, 13, 5], Tensor dev 'Double '[8],
Tensor dev 'Double '[8, 8, 13, 5], Tensor dev 'Double '[8],
Tensor dev 'Double '[8, 8, 13, 5], Tensor dev 'Double '[8],
Tensor dev 'Double '[8, 8, 13, 5], Tensor dev 'Double '[8],
Tensor dev 'Double '[8, 8, 13, 5], Tensor dev 'Double '[8],
Tensor dev 'Double '[8, 8, 13, 5], Tensor dev 'Double '[8],
Tensor dev 'Double '[8, 8, 13, 5], Tensor dev 'Double '[8],
Tensor dev 'Double '[8, 8, 13, 5], Tensor dev 'Double '[8],
Tensor dev 'Double '[8], Tensor dev 'Double '[8],
Tensor dev 'Double '[1, 8], Tensor dev 'Double '[1],
Tensor dev 'Double '[8, 8], Tensor dev 'Double '[8],
Tensor dev 'Double '[8], Tensor dev 'Double '[8],
Tensor dev 'Double '[1, 8], Tensor dev 'Double '[1]]
HList (ModelTensors dev)
gradV
actionLogProb :: QTensor dev '[]
actionLogProb = Tensor dev 'Double '[] -> Tensor dev 'Double '[]
forall (shape :: [Nat]) (dtype :: DType)
(device :: (DeviceType, Nat)).
Tensor device dtype shape -> Tensor device dtype shape
TT.log (Tensor dev 'Double '[] -> Tensor dev 'Double '[])
-> Tensor dev 'Double '[] -> Tensor dev 'Double '[]
forall a b. (a -> b) -> a -> b
$ Tensor -> Tensor dev 'Double '[]
forall (device :: (DeviceType, Nat)) (dtype :: DType)
(shape :: [Nat]).
Tensor -> Tensor device dtype shape
TT.UnsafeMkTensor (Tensor -> Tensor
T.squeezeAll (Tensor
policy Tensor -> Int -> Tensor
forall a. TensorIndex a => Tensor -> a -> Tensor
T.! Int
actionIndex))
gradP = Tensor dev 'Double '[]
-> HList
'[Parameter dev 'Double '[8, 1, 1, 1], Parameter dev 'Double '[8],
Parameter dev 'Double '[8, 8, 13, 5], Parameter dev 'Double '[8],
Parameter dev 'Double '[8, 13, 5],
Parameter dev 'Double '[8, 13, 5],
Parameter dev 'Double '[8, 2, 13, 5], Parameter dev 'Double '[8],
Parameter dev 'Double '[8, 2, 13, 5], Parameter dev 'Double '[8],
Parameter dev 'Double '[8, 1, 1, 1], Parameter dev 'Double '[8],
Parameter dev 'Double '[8, 1, 1, 1], Parameter dev 'Double '[8],
Parameter dev 'Double '[8], Parameter dev 'Double '[8, 8, 13, 5],
Parameter dev 'Double '[8], Parameter dev 'Double '[8, 8, 13, 5],
Parameter dev 'Double '[8], Parameter dev 'Double '[8, 8, 13, 5],
Parameter dev 'Double '[8], Parameter dev 'Double '[8, 8, 13, 5],
Parameter dev 'Double '[8], Parameter dev 'Double '[8, 8, 13, 5],
Parameter dev 'Double '[8], Parameter dev 'Double '[8, 8, 13, 5],
Parameter dev 'Double '[8], Parameter dev 'Double '[8, 8, 13, 5],
Parameter dev 'Double '[8], Parameter dev 'Double '[5],
Parameter dev 'Double '[5], Parameter dev 'Double '[5],
Parameter dev 'Double '[8, 8, 13, 5], Parameter dev 'Double '[8],
Parameter dev 'Double '[8, 8, 13, 5], Parameter dev 'Double '[8],
Parameter dev 'Double '[8, 8, 13, 5], Parameter dev 'Double '[8],
Parameter dev 'Double '[8, 8, 13, 5], Parameter dev 'Double '[8],
Parameter dev 'Double '[8, 8, 13, 5], Parameter dev 'Double '[8],
Parameter dev 'Double '[8, 8, 13, 5], Parameter dev 'Double '[8],
Parameter dev 'Double '[8, 8, 13, 5], Parameter dev 'Double '[8],
Parameter dev 'Double '[8, 8, 13, 5], Parameter dev 'Double '[8],
Parameter dev 'Double '[8], Parameter dev 'Double '[8],
Parameter dev 'Double '[1, 8], Parameter dev 'Double '[1],
Parameter dev 'Double '[8, 8], Parameter dev 'Double '[8],
Parameter dev 'Double '[8], Parameter dev 'Double '[8],
Parameter dev 'Double '[1, 8], Parameter dev 'Double '[1]]
-> HList
'[Tensor dev 'Double '[8, 1, 1, 1], Tensor dev 'Double '[8],
Tensor dev 'Double '[8, 8, 13, 5], Tensor dev 'Double '[8],
Tensor dev 'Double '[8, 13, 5], Tensor dev 'Double '[8, 13, 5],
Tensor dev 'Double '[8, 2, 13, 5], Tensor dev 'Double '[8],
Tensor dev 'Double '[8, 2, 13, 5], Tensor dev 'Double '[8],
Tensor dev 'Double '[8, 1, 1, 1], Tensor dev 'Double '[8],
Tensor dev 'Double '[8, 1, 1, 1], Tensor dev 'Double '[8],
Tensor dev 'Double '[8], Tensor dev 'Double '[8, 8, 13, 5],
Tensor dev 'Double '[8], Tensor dev 'Double '[8, 8, 13, 5],
Tensor dev 'Double '[8], Tensor dev 'Double '[8, 8, 13, 5],
Tensor dev 'Double '[8], Tensor dev 'Double '[8, 8, 13, 5],
Tensor dev 'Double '[8], Tensor dev 'Double '[8, 8, 13, 5],
Tensor dev 'Double '[8], Tensor dev 'Double '[8, 8, 13, 5],
Tensor dev 'Double '[8], Tensor dev 'Double '[8, 8, 13, 5],
Tensor dev 'Double '[8], Tensor dev 'Double '[5],
Tensor dev 'Double '[5], Tensor dev 'Double '[5],
Tensor dev 'Double '[8, 8, 13, 5], Tensor dev 'Double '[8],
Tensor dev 'Double '[8, 8, 13, 5], Tensor dev 'Double '[8],
Tensor dev 'Double '[8, 8, 13, 5], Tensor dev 'Double '[8],
Tensor dev 'Double '[8, 8, 13, 5], Tensor dev 'Double '[8],
Tensor dev 'Double '[8, 8, 13, 5], Tensor dev 'Double '[8],
Tensor dev 'Double '[8, 8, 13, 5], Tensor dev 'Double '[8],
Tensor dev 'Double '[8, 8, 13, 5], Tensor dev 'Double '[8],
Tensor dev 'Double '[8, 8, 13, 5], Tensor dev 'Double '[8],
Tensor dev 'Double '[8], Tensor dev 'Double '[8],
Tensor dev 'Double '[1, 8], Tensor dev 'Double '[1],
Tensor dev 'Double '[8, 8], Tensor dev 'Double '[8],
Tensor dev 'Double '[8], Tensor dev 'Double '[8],
Tensor dev 'Double '[1, 8], Tensor dev 'Double '[1]]
forall a b (dtype :: DType) (device :: (DeviceType, Nat)).
HasGrad a b =>
Tensor device dtype '[] -> a -> b
forall (dtype :: DType) (device :: (DeviceType, Nat)).
Tensor device dtype '[]
-> HList
'[Parameter dev 'Double '[8, 1, 1, 1], Parameter dev 'Double '[8],
Parameter dev 'Double '[8, 8, 13, 5], Parameter dev 'Double '[8],
Parameter dev 'Double '[8, 13, 5],
Parameter dev 'Double '[8, 13, 5],
Parameter dev 'Double '[8, 2, 13, 5], Parameter dev 'Double '[8],
Parameter dev 'Double '[8, 2, 13, 5], Parameter dev 'Double '[8],
Parameter dev 'Double '[8, 1, 1, 1], Parameter dev 'Double '[8],
Parameter dev 'Double '[8, 1, 1, 1], Parameter dev 'Double '[8],
Parameter dev 'Double '[8], Parameter dev 'Double '[8, 8, 13, 5],
Parameter dev 'Double '[8], Parameter dev 'Double '[8, 8, 13, 5],
Parameter dev 'Double '[8], Parameter dev 'Double '[8, 8, 13, 5],
Parameter dev 'Double '[8], Parameter dev 'Double '[8, 8, 13, 5],
Parameter dev 'Double '[8], Parameter dev 'Double '[8, 8, 13, 5],
Parameter dev 'Double '[8], Parameter dev 'Double '[8, 8, 13, 5],
Parameter dev 'Double '[8], Parameter dev 'Double '[8, 8, 13, 5],
Parameter dev 'Double '[8], Parameter dev 'Double '[5],
Parameter dev 'Double '[5], Parameter dev 'Double '[5],
Parameter dev 'Double '[8, 8, 13, 5], Parameter dev 'Double '[8],
Parameter dev 'Double '[8, 8, 13, 5], Parameter dev 'Double '[8],
Parameter dev 'Double '[8, 8, 13, 5], Parameter dev 'Double '[8],
Parameter dev 'Double '[8, 8, 13, 5], Parameter dev 'Double '[8],
Parameter dev 'Double '[8, 8, 13, 5], Parameter dev 'Double '[8],
Parameter dev 'Double '[8, 8, 13, 5], Parameter dev 'Double '[8],
Parameter dev 'Double '[8, 8, 13, 5], Parameter dev 'Double '[8],
Parameter dev 'Double '[8, 8, 13, 5], Parameter dev 'Double '[8],
Parameter dev 'Double '[8], Parameter dev 'Double '[8],
Parameter dev 'Double '[1, 8], Parameter dev 'Double '[1],
Parameter dev 'Double '[8, 8], Parameter dev 'Double '[8],
Parameter dev 'Double '[8], Parameter dev 'Double '[8],
Parameter dev 'Double '[1, 8], Parameter dev 'Double '[1]]
-> HList
'[Tensor dev 'Double '[8, 1, 1, 1], Tensor dev 'Double '[8],
Tensor dev 'Double '[8, 8, 13, 5], Tensor dev 'Double '[8],
Tensor dev 'Double '[8, 13, 5], Tensor dev 'Double '[8, 13, 5],
Tensor dev 'Double '[8, 2, 13, 5], Tensor dev 'Double '[8],
Tensor dev 'Double '[8, 2, 13, 5], Tensor dev 'Double '[8],
Tensor dev 'Double '[8, 1, 1, 1], Tensor dev 'Double '[8],
Tensor dev 'Double '[8, 1, 1, 1], Tensor dev 'Double '[8],
Tensor dev 'Double '[8], Tensor dev 'Double '[8, 8, 13, 5],
Tensor dev 'Double '[8], Tensor dev 'Double '[8, 8, 13, 5],
Tensor dev 'Double '[8], Tensor dev 'Double '[8, 8, 13, 5],
Tensor dev 'Double '[8], Tensor dev 'Double '[8, 8, 13, 5],
Tensor dev 'Double '[8], Tensor dev 'Double '[8, 8, 13, 5],
Tensor dev 'Double '[8], Tensor dev 'Double '[8, 8, 13, 5],
Tensor dev 'Double '[8], Tensor dev 'Double '[8, 8, 13, 5],
Tensor dev 'Double '[8], Tensor dev 'Double '[5],
Tensor dev 'Double '[5], Tensor dev 'Double '[5],
Tensor dev 'Double '[8, 8, 13, 5], Tensor dev 'Double '[8],
Tensor dev 'Double '[8, 8, 13, 5], Tensor dev 'Double '[8],
Tensor dev 'Double '[8, 8, 13, 5], Tensor dev 'Double '[8],
Tensor dev 'Double '[8, 8, 13, 5], Tensor dev 'Double '[8],
Tensor dev 'Double '[8, 8, 13, 5], Tensor dev 'Double '[8],
Tensor dev 'Double '[8, 8, 13, 5], Tensor dev 'Double '[8],
Tensor dev 'Double '[8, 8, 13, 5], Tensor dev 'Double '[8],
Tensor dev 'Double '[8, 8, 13, 5], Tensor dev 'Double '[8],
Tensor dev 'Double '[8], Tensor dev 'Double '[8],
Tensor dev 'Double '[1, 8], Tensor dev 'Double '[1],
Tensor dev 'Double '[8, 8], Tensor dev 'Double '[8],
Tensor dev 'Double '[8], Tensor dev 'Double '[8],
Tensor dev 'Double '[1, 8], Tensor dev 'Double '[1]]
TT.grad (Tensor dev 'Double '[]
actionLogProb Tensor dev 'Double '[]
-> Tensor dev 'Double '[] -> Tensor dev 'Double '[]
forall a. Num a => a -> a -> a
+ QModel dev -> Tensor dev 'Double '[]
forall (dev :: (DeviceType, Nat)) (ps :: [*]).
(IsValidDevice dev, ps ~ Parameters (QModel dev)) =>
QModel dev -> QTensor dev '[]
fakeLoss QModel dev
actor) (QModel dev -> HList (Parameters (QModel dev))
forall f. Parameterized f => f -> HList (Parameters f)
TT.flattenParameters QModel dev
actor)
zP' = QType
-> QType
-> QType
-> HList (ModelTensors dev)
-> HList (ModelTensors dev)
-> HList (ModelTensors dev)
forall (dev :: (DeviceType, Nat)).
KnownDevice dev =>
QType
-> QType
-> QType
-> HList (ModelTensors dev)
-> HList (ModelTensors dev)
-> HList (ModelTensors dev)
updateEligActor QType
gamma QType
lambdaP QType
intensity HList (ModelTensors dev)
zP HList
'[Tensor dev 'Double '[8, 1, 1, 1], Tensor dev 'Double '[8],
Tensor dev 'Double '[8, 8, 13, 5], Tensor dev 'Double '[8],
Tensor dev 'Double '[8, 13, 5], Tensor dev 'Double '[8, 13, 5],
Tensor dev 'Double '[8, 2, 13, 5], Tensor dev 'Double '[8],
Tensor dev 'Double '[8, 2, 13, 5], Tensor dev 'Double '[8],
Tensor dev 'Double '[8, 1, 1, 1], Tensor dev 'Double '[8],
Tensor dev 'Double '[8, 1, 1, 1], Tensor dev 'Double '[8],
Tensor dev 'Double '[8], Tensor dev 'Double '[8, 8, 13, 5],
Tensor dev 'Double '[8], Tensor dev 'Double '[8, 8, 13, 5],
Tensor dev 'Double '[8], Tensor dev 'Double '[8, 8, 13, 5],
Tensor dev 'Double '[8], Tensor dev 'Double '[8, 8, 13, 5],
Tensor dev 'Double '[8], Tensor dev 'Double '[8, 8, 13, 5],
Tensor dev 'Double '[8], Tensor dev 'Double '[8, 8, 13, 5],
Tensor dev 'Double '[8], Tensor dev 'Double '[8, 8, 13, 5],
Tensor dev 'Double '[8], Tensor dev 'Double '[5],
Tensor dev 'Double '[5], Tensor dev 'Double '[5],
Tensor dev 'Double '[8, 8, 13, 5], Tensor dev 'Double '[8],
Tensor dev 'Double '[8, 8, 13, 5], Tensor dev 'Double '[8],
Tensor dev 'Double '[8, 8, 13, 5], Tensor dev 'Double '[8],
Tensor dev 'Double '[8, 8, 13, 5], Tensor dev 'Double '[8],
Tensor dev 'Double '[8, 8, 13, 5], Tensor dev 'Double '[8],
Tensor dev 'Double '[8, 8, 13, 5], Tensor dev 'Double '[8],
Tensor dev 'Double '[8, 8, 13, 5], Tensor dev 'Double '[8],
Tensor dev 'Double '[8, 8, 13, 5], Tensor dev 'Double '[8],
Tensor dev 'Double '[8], Tensor dev 'Double '[8],
Tensor dev 'Double '[1, 8], Tensor dev 'Double '[1],
Tensor dev 'Double '[8, 8], Tensor dev 'Double '[8],
Tensor dev 'Double '[8], Tensor dev 'Double '[8],
Tensor dev 'Double '[1, 8], Tensor dev 'Double '[1]]
HList (ModelTensors dev)
gradP
intensity' = QType
gamma QType -> QType -> QType
forall a. Num a => a -> a -> a
* QType
intensity
learningRate = QType -> Tensor dev 'Double '[]
forall (dev :: (DeviceType, Nat)).
KnownDevice dev =>
QType -> QTensor dev '[]
toQTensor (QType -> QType
forall a. Num a => a -> a
negate QType
lr)
(!actor', !opta') <- lift $ TT.runStep' actor opta learningRate $ mulModelTensors delta zP'
(!critic', !optc') <- lift $ TT.runStep' critic optc learningRate $ mulModelTensors delta zV'
let loss' = Tensor -> QType
forall a. TensorLike a => Tensor -> a
T.asValue (Tensor -> QType) -> Tensor -> QType
forall a b. (a -> b) -> a -> b
$ Tensor dev 'Double '[] -> Tensor
forall t. Unnamed t => t -> Tensor
TT.toDynamic Tensor dev 'Double '[]
delta
reward' = QType
reward QType -> QType -> QType
forall a. Num a => a -> a -> a
+ QType
r
let pieceState' = case (Either
(GreedyState
(Edges SPitch) [Edge SPitch] (Notes SPitch) (PVLeftmost SPitch))
(Edges SPitch, [PVLeftmost SPitch])
state', Maybe (NonEmpty PVAction)
actions') of
(Left GreedyState
(Edges SPitch) [Edge SPitch] (Notes SPitch) (PVLeftmost SPitch)
s', Just NonEmpty PVAction
a') -> A2CStepState dev -> Either (A2CStepState dev) QType
forall a b. a -> Either a b
Left (A2CStepState dev -> Either (A2CStepState dev) QType)
-> A2CStepState dev -> Either (A2CStepState dev) QType
forall a b. (a -> b) -> a -> b
$ HList (ModelTensors dev)
-> HList (ModelTensors dev)
-> QType
-> QType
-> GreedyState
(Edges SPitch) [Edge SPitch] (Notes SPitch) (PVLeftmost SPitch)
-> NonEmpty PVAction
-> A2CStepState dev
forall (dev :: (DeviceType, Nat)).
HList (ModelTensors dev)
-> HList (ModelTensors dev)
-> QType
-> QType
-> GreedyState
(Edges SPitch) [Edge SPitch] (Notes SPitch) (PVLeftmost SPitch)
-> NonEmpty PVAction
-> A2CStepState dev
A2CStepState HList
'[Tensor dev 'Double '[8, 1, 1, 1], Tensor dev 'Double '[8],
Tensor dev 'Double '[8, 8, 13, 5], Tensor dev 'Double '[8],
Tensor dev 'Double '[8, 13, 5], Tensor dev 'Double '[8, 13, 5],
Tensor dev 'Double '[8, 2, 13, 5], Tensor dev 'Double '[8],
Tensor dev 'Double '[8, 2, 13, 5], Tensor dev 'Double '[8],
Tensor dev 'Double '[8, 1, 1, 1], Tensor dev 'Double '[8],
Tensor dev 'Double '[8, 1, 1, 1], Tensor dev 'Double '[8],
Tensor dev 'Double '[8], Tensor dev 'Double '[8, 8, 13, 5],
Tensor dev 'Double '[8], Tensor dev 'Double '[8, 8, 13, 5],
Tensor dev 'Double '[8], Tensor dev 'Double '[8, 8, 13, 5],
Tensor dev 'Double '[8], Tensor dev 'Double '[8, 8, 13, 5],
Tensor dev 'Double '[8], Tensor dev 'Double '[8, 8, 13, 5],
Tensor dev 'Double '[8], Tensor dev 'Double '[8, 8, 13, 5],
Tensor dev 'Double '[8], Tensor dev 'Double '[8, 8, 13, 5],
Tensor dev 'Double '[8], Tensor dev 'Double '[5],
Tensor dev 'Double '[5], Tensor dev 'Double '[5],
Tensor dev 'Double '[8, 8, 13, 5], Tensor dev 'Double '[8],
Tensor dev 'Double '[8, 8, 13, 5], Tensor dev 'Double '[8],
Tensor dev 'Double '[8, 8, 13, 5], Tensor dev 'Double '[8],
Tensor dev 'Double '[8, 8, 13, 5], Tensor dev 'Double '[8],
Tensor dev 'Double '[8, 8, 13, 5], Tensor dev 'Double '[8],
Tensor dev 'Double '[8, 8, 13, 5], Tensor dev 'Double '[8],
Tensor dev 'Double '[8, 8, 13, 5], Tensor dev 'Double '[8],
Tensor dev 'Double '[8, 8, 13, 5], Tensor dev 'Double '[8],
Tensor dev 'Double '[8], Tensor dev 'Double '[8],
Tensor dev 'Double '[1, 8], Tensor dev 'Double '[1],
Tensor dev 'Double '[8, 8], Tensor dev 'Double '[8],
Tensor dev 'Double '[8], Tensor dev 'Double '[8],
Tensor dev 'Double '[1, 8], Tensor dev 'Double '[1]]
HList (ModelTensors dev)
zV' HList
'[Tensor dev 'Double '[8, 1, 1, 1], Tensor dev 'Double '[8],
Tensor dev 'Double '[8, 8, 13, 5], Tensor dev 'Double '[8],
Tensor dev 'Double '[8, 13, 5], Tensor dev 'Double '[8, 13, 5],
Tensor dev 'Double '[8, 2, 13, 5], Tensor dev 'Double '[8],
Tensor dev 'Double '[8, 2, 13, 5], Tensor dev 'Double '[8],
Tensor dev 'Double '[8, 1, 1, 1], Tensor dev 'Double '[8],
Tensor dev 'Double '[8, 1, 1, 1], Tensor dev 'Double '[8],
Tensor dev 'Double '[8], Tensor dev 'Double '[8, 8, 13, 5],
Tensor dev 'Double '[8], Tensor dev 'Double '[8, 8, 13, 5],
Tensor dev 'Double '[8], Tensor dev 'Double '[8, 8, 13, 5],
Tensor dev 'Double '[8], Tensor dev 'Double '[8, 8, 13, 5],
Tensor dev 'Double '[8], Tensor dev 'Double '[8, 8, 13, 5],
Tensor dev 'Double '[8], Tensor dev 'Double '[8, 8, 13, 5],
Tensor dev 'Double '[8], Tensor dev 'Double '[8, 8, 13, 5],
Tensor dev 'Double '[8], Tensor dev 'Double '[5],
Tensor dev 'Double '[5], Tensor dev 'Double '[5],
Tensor dev 'Double '[8, 8, 13, 5], Tensor dev 'Double '[8],
Tensor dev 'Double '[8, 8, 13, 5], Tensor dev 'Double '[8],
Tensor dev 'Double '[8, 8, 13, 5], Tensor dev 'Double '[8],
Tensor dev 'Double '[8, 8, 13, 5], Tensor dev 'Double '[8],
Tensor dev 'Double '[8, 8, 13, 5], Tensor dev 'Double '[8],
Tensor dev 'Double '[8, 8, 13, 5], Tensor dev 'Double '[8],
Tensor dev 'Double '[8, 8, 13, 5], Tensor dev 'Double '[8],
Tensor dev 'Double '[8, 8, 13, 5], Tensor dev 'Double '[8],
Tensor dev 'Double '[8], Tensor dev 'Double '[8],
Tensor dev 'Double '[1, 8], Tensor dev 'Double '[1],
Tensor dev 'Double '[8, 8], Tensor dev 'Double '[8],
Tensor dev 'Double '[8], Tensor dev 'Double '[8],
Tensor dev 'Double '[1, 8], Tensor dev 'Double '[1]]
HList (ModelTensors dev)
zP' QType
intensity' QType
reward' GreedyState
(Edges SPitch) [Edge SPitch] (Notes SPitch) (PVLeftmost SPitch)
s' NonEmpty PVAction
a'
(Left GreedyState
(Edges SPitch) [Edge SPitch] (Notes SPitch) (PVLeftmost SPitch)
s', Maybe (NonEmpty PVAction)
Nothing) ->
QType -> Either (A2CStepState dev) QType
forall a b. b -> Either a b
Right QType
reward'
(Right (Edges SPitch, [PVLeftmost SPitch])
_, Maybe (NonEmpty PVAction)
_) -> QType -> Either (A2CStepState dev) QType
forall a b. b -> Either a b
Right QType
reward'
pure (A2CState actor' critic' opta' optc', pieceState', loss')
-- | Run a single A2C training episode on one piece.
--
-- The episode index @i@ is converted to 'QType' and fed through @fLr@ and
-- @fTemp@ to obtain this episode's learning rate and temperature.
-- The piece state is created by 'initPieceState'. It is seeded with the
-- tensors that 'modelZeros' builds for the current actor; these are
-- presumably zero-initialised eligibility traces, so confirm in RL.A2CHelpers.
-- 'pieceStep' is then applied repeatedly until it reports the accumulated
-- reward of the finished piece.
--
-- Returns @(updated model state, accumulated reward, mean per-step loss)@.
-- If 'initPieceState' already yields a terminal reward, the model state is
-- returned unchanged with a loss of 0. Errors raised inside 'pieceStep'
-- (an 'ET.ExceptT' computation) are returned as 'Left'.
runEpisode
  :: forall dev label
   . (_)
  => Eval (Edges SPitch) [Edge SPitch] (Notes SPitch) [Note SPitch] (Spread SPitch) (PVLeftmost SPitch)
  -> Rand.IOGenM Rand.StdGen
  -> PVRewardFn label
  -> (QType -> QType)
  -> (QType -> QType)
  -> Path [Note SPitch] [Edge SPitch]
  -> label
  -> A2CState dev
  -> Int
  -> IO (Either String (A2CState dev, QType, QType))
runEpisode !eval !gen !fReward !fLr !fTemp !input !label !modelState !i =
  either startEpisode finishImmediately (initPieceState eval input traces0)
 where
  -- initial trace tensors handed to 'initPieceState', built from the actor
  traces0 :: TT.HList (ModelTensors dev)
  traces0 = modelZeros (a2cActor modelState)

  -- the episode index drives both the learning-rate and temperature schedule
  episode :: QType
  episode = fromIntegral i

  lr, temp :: QType
  lr = fLr episode
  temp = fTemp episode

  -- the piece was terminal from the start: no step was taken, loss is 0
  finishImmediately reward = pure (Right (modelState, reward, 0))

  startEpisode s0 = ET.runExceptT (loop modelState s0 SL.Nil)

  -- step through the piece, prepending each step's loss to the accumulator
  loop st ps acc = do
    (st', next, loss) <- pieceStep eval gen fReward label lr temp i st ps
    let acc' = SL.Cons loss acc
    case next of
      Left ps' -> loop st' ps' acc'
      Right reward -> pure (st', reward, mean acc')
-- | Greedily parse a piece with the actor's policy and report the outcome.
--
-- Parsing starts from 'initParseState'. At every step, at most 'maxActions'
-- candidate actions from 'getActions' are scored with 'runBatchedPolicy',
-- and the arg-max action is applied with 'applyAction'.
-- The per-action rewards from @fReward@ are summed. The log-probabilities of
-- the chosen actions are accumulated as a cost, which is printed but not
-- returned.
--
-- Returns the summed reward together with the derivation ('Analysis') of the
-- finished parse. It returns 'Left' with a message in three cases: no action
-- exists at the first step, no action exists in a non-terminal state, or
-- 'applyAction' fails.
--
-- NOTE(review): writes diagnostics to stdout. On completion it prints
-- "accuracy cost: ...". When the cost accumulated *before* the final step is
-- exactly 0, it also prints the probability of the final action.
runAccuracy
  :: (IsValidDevice dev)
  => Eval (Edges SPitch) [Edge SPitch] (Notes SPitch) slc' (Spread SPitch) (PVLeftmost SPitch)
  -> PVRewardFn label
  -> QModel dev
  -> (Path slc' [Edge SPitch], label)
  -> IO (Either String (QType, PVAnalysis SPitch))
runAccuracy !eval !fReward !actor (!input, !label) =
  case take maxActions (getActions eval startState) of
    [] -> pure (Left "cannot parse: no possible actions for first step!")
    firstAction : otherActions ->
      ET.runExceptT (step 0 0 startState (firstAction NE.:| otherActions))
 where
  -- upper bound on the number of candidate actions considered per step
  maxActions :: Int
  maxActions = 200

  startState = initParseState eval input

  -- one greedy step: score candidates, apply the best, recurse until done
  step !cost !reward !st !candidates = do
    let probs = withBatchedEncoding st candidates (runBatchedPolicy actor)
        bestIdx :: Int
        bestIdx = T.asValue (T.argmax (T.Dim 0) T.KeepDim probs)
        chosen = candidates NE.!! bestIdx
        chosenProb = probs T.! bestIdx
        cost' = cost + T.log chosenProb
    next <- ET.except (applyAction st chosen)
    let nextCandidates = case next of
          Left nextState -> NE.nonEmpty (take maxActions (getActions eval nextState))
          Right _ -> Nothing
    stepReward <- lift (fReward next nextCandidates chosen label)
    let reward' = reward + stepReward
    case (next, nextCandidates) of
      (Left _, Nothing) ->
        ET.throwE "cannot parse: no possible actions in non-terminal state!"
      (Left st', Just candidates') -> step cost' reward' st' candidates'
      (Right (top, deriv), _) -> do
        lift (putStrLn ("accuracy cost: " <> show cost'))
        -- debug output; note this tests the cost *before* the final step
        when (T.asValue cost == (0 :: Double)) $
          lift (print chosenProb)
        let analysis = Analysis deriv (PathEnd top)
        pure (reward', analysis)
-- | Orphan 'NoThunks' instance for strict lists: neither the class
-- (NoThunks.Class) nor the type (StrictList) is defined in this module.
-- Because this is a standalone deriving clause for a non-stock class, it is
-- derived via @DeriveAnyClass@. It therefore uses the class's default,
-- Generic-based methods (assumes 'SL.List' has a 'Generic' instance).
-- NOTE(review): presumably this exists so that records with 'SL.List' fields,
-- such as 'A2CLoopState', can be checked with 'noThunks'. Confirm at the use
-- site.
deriving instance (NoThunks a) => NoThunks (SL.List a)
-- | State threaded through the A2C training loop: the current model state
-- plus collected statistics.
--
-- All fields are strict because the module enables @Strict@, which implies
-- @StrictData@.
-- NOTE(review): the statistics are lists of lists. Presumably the outer list
-- has one entry per training piece and each inner list accumulates one value
-- per episode. Confirm against the code that builds and updates this record.
data A2CLoopState dev = A2CLoopState
  { a2clState :: A2CState dev
  -- ^ current actor and critic models together with their optimizers
  , a2clRewards :: SL.List (SL.List QType)
  -- ^ collected rewards
  , a2clLosses :: SL.List (SL.List QType)
  -- ^ collected losses
  , a2clAccs :: SL.List (SL.List QType)
  -- ^ collected accuracy values
  }
  deriving (Generic)
-- | Train an actor/critic pair with A2C for @n@ epochs over a list of pieces.
--
-- Each epoch runs one 'runEpisode' per piece, threading the 'A2CState'
-- (actor, critic and their optimisers) from piece to piece. After each epoch
-- the reward and loss histories are plotted and the freshly trained actor and
-- critic parameters are written to @actor_checkpoint.ht@ and
-- @critic_checkpoint.ht@. Every 'accuracyEvery' epochs the trained actor is
-- also scored on every piece with 'runAccuracy'.
--
-- Arguments, in order: the grammar evaluator, the random generator, the
-- reward function, two @QType -> QType@ functions (@fLr@, @fTemp@) that are
-- forwarded unchanged to 'runEpisode', optional per-piece targets (used only
-- for plotting), the initial actor and critic, the labelled training pieces,
-- and the number of epochs.
--
-- Returns the per-piece reward histories (one inner list per piece, oldest
-- epoch first), the mean loss of each piece over all epochs, and the trained
-- actor and critic.
trainA2C
  :: forall dev label
   . (IsValidDevice dev)
  => Eval (Edges SPitch) [Edge SPitch] (Notes SPitch) [Note SPitch] (Spread SPitch) (PVLeftmost SPitch)
  -> Rand.IOGenM Rand.StdGen
  -> PVRewardFn label
  -> (QType -> QType)
  -> (QType -> QType)
  -> Maybe [QType]
  -> QModel dev
  -> QModel dev
  -> [(Path [Note SPitch] [Edge SPitch], label)]
  -> Int
  -> IO ([[QType]], [QType], QModel dev, QModel dev)
trainA2C eval gen fReward fLr fTemp targets actor0 critic0 pieces n = do
  let
    opta = TT.GD
    optc = TT.GD
    -- one empty history per training piece
    emptyStat = SL.fromListReversed (replicate (length pieces) SL.Nil)
    state0 = A2CState actor0 critic0 opta optc
  (A2CLoopState (A2CState actorTrained criticTrained _ _) rewards losses _) <-
    T.foldLoop (A2CLoopState state0 emptyStat emptyStat emptyStat) n trainEpoch
  pure
    ( SL.toListReversed $ SL.toListReversed <$> rewards
    , SL.toListReversed $ mean <$> losses
    , actorTrained
    , criticTrained
    )
 where
  -- evaluate accuracy on all pieces every this many epochs
  accuracyEvery :: Int
  accuracyEvery = 10

  -- print the derivations found during evaluation every this many epochs
  printEvery :: Int
  printEvery = 100

  -- plot statistics and write checkpoints every this many epochs
  plotEvery :: Int
  plotEvery = 1

  -- Run a single episode on one piece, threading the A2C state and prepending
  -- the episode's reward and loss to this epoch's lists.
  trainPiece pb i (!state, !epRewards, !epLosses) (!piece, label) = do
    !result <- runEpisode eval gen fReward fLr fTemp piece label state i
    PB.incProgress pb 1
    case result of
      Left err -> do
        putStrLn $ "Episode error: " <> err
        -- Still record exactly one value for this piece: the epoch lists are
        -- zipped position-wise onto the per-piece histories, so skipping a
        -- piece would shift every following piece onto the wrong history.
        -- -inf is the same failure marker the accuracy evaluation uses.
        pure (state, (-inf) `SL.Cons` epRewards, (-inf) `SL.Cons` epLosses)
      Right (state', r, loss) ->
        pure (state', r `SL.Cons` epRewards, loss `SL.Cons` epLosses)

  -- One training epoch: train on every piece, extend the histories,
  -- optionally evaluate, plot, and checkpoint the updated models.
  trainEpoch (A2CLoopState !state !rewardsHist !lossHist !accuracies) !i = do
    pb <-
      PB.newProgressBar
        ( PB.defStyle
            { PB.stylePrefix = "Epoch " <> PB.msg (Txt.show i) <> ": " <> PB.elapsedTime PB.renderDuration
            , PB.stylePostfix = PB.exact <> " (" <> PB.percentage <> ")"
            , PB.styleWidth = PB.ConstantWidth 80
            }
        )
        10
        (PB.Progress 0 (length pieces) ())
    (!state', !epRewards, !epLosses) <-
      foldM (trainPiece pb i) (state, SL.Nil, SL.Nil) pieces
    let rewardsHist' = zipWithStrict SL.Cons epRewards rewardsHist
        lossHist' = zipWithStrict SL.Cons epLosses lossHist
    accuracies' <-
      if i `mod` accuracyEvery == 0
        then do
          -- score the model produced by this epoch (the one being checkpointed)
          results <- mapM (runAccuracy eval fReward (a2cActor state')) pieces
          newAccs <- forM (zip results [1 ..]) $ \(result, j) ->
            case result of
              Left err -> do
                putStrLn err
                pure (-inf)
              Right (acc, Analysis deriv _) -> do
                when (i `mod` printEvery == 0) $ do
                  putStrLn $ "current best analysis (piece " <> show j <> "):"
                  mapM_ print deriv
                pure acc
          pure $ zipWithStrict SL.Cons (SL.fromListReversed newAccs) accuracies
        else pure accuracies
    when (i `mod` plotEvery == 0) $ do
      let rewardCurves = SL.toListReversed <$> SL.toListReversed rewardsHist'
          accCurves = SL.toListReversed <$> SL.toListReversed accuracies'
          lossCurves = SL.toListReversed <$> SL.toListReversed lossHist'
          -- per-epoch averages across pieces
          avgReward = mean <$> L.transpose rewardCurves
          avgAbsLoss = (mean . fmap abs) <$> L.transpose lossCurves
      plotHistories "losses" lossCurves
      case targets of
        Nothing -> do
          plotHistories "rewards" rewardCurves
          plotHistories "accuracy" accCurves
        Just ts -> do
          plotHistories' "rewards" ts rewardCurves
          plotHistories' "accuracy" ts accCurves
      plotHistory "mean_reward" avgReward
      plotHistory "mean_loss" avgAbsLoss
      -- save the models *after* this epoch's updates so the final checkpoint
      -- matches the actor/critic returned by 'trainA2C'
      TT.save (TT.hmap' TT.ToDependent $ TT.flattenParameters $ a2cActor state') "actor_checkpoint.ht"
      TT.save (TT.hmap' TT.ToDependent $ TT.flattenParameters $ a2cCritic state') "critic_checkpoint.ht"
    pure $ A2CLoopState state' rewardsHist' lossHist' accuracies'