{-# LANGUAGE AllowAmbiguousTypes #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE DeriveAnyClass #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE PartialTypeSignatures #-}
{-# LANGUAGE Strict #-}
{-# OPTIONS_GHC -Wno-partial-type-signatures #-}

module RL.A2C where

import Common
import GreedyParser
import PVGrammar
import PVGrammar.Prob.Simple (PVParams)
import RL.A2CHelpers
import RL.Encoding
import RL.Model
import RL.ModelTypes
import RL.Plotting
import RL.TorchHelpers

import Control.DeepSeq (NFData, force)
import Control.Foldl qualified as Foldl
import Control.Monad (foldM, forM, forM_, when)
import Control.Monad.Except qualified as ET
import Control.Monad.Trans (lift)
import Control.Monad.Trans.Except qualified as ET
import Data.Either (partitionEithers)
import Data.Foldable qualified as F
import Data.List qualified as L
import Data.List.NonEmpty qualified as NE
import Data.Maybe (mapMaybe)
import Data.Text.Lazy qualified as Txt
import Data.Vector qualified as V
import Debug.Trace qualified as DT
import GHC.Generics
import Inference.Conjugate (Hyper)
import Musicology.Pitch (SPitch)
import NoThunks.Class (NoThunks (noThunks), ThunkInfo (thunkContext))
import StrictList qualified as SL
import System.IO (hFlush, stdout)
import System.Mem (performGC)
import System.ProgressBar qualified as PB
import System.Random.MWC.Distributions (categorical)
import System.Random.Stateful (StatefulGen)
import System.Random.Stateful qualified as Rand
import Torch qualified as T
import Torch.Typed qualified as TT
import Torch.Typed.Optim.CppOptim qualified as TT
import Torch.Typed.Optim.CppOptim qualified as TTC

-- global settings
-- ===============

-- | Discount factor.
-- In 'pieceStep' it scales the critic's value of the successor state
-- and is passed on to 'updateEligCritic'.
gamma :: QType
gamma = 0.99

-- | Eligibility decay factor for the value function (critic).
-- Passed to 'updateEligCritic' in 'pieceStep'.
lambdaV :: QType
lambdaV = 0.3

-- | Eligibility decay factor for the policy.
lambdaP :: QType
lambdaP = 0.3

-- learning rate
-- learningRate :: TT.LearningRate QDevice QDType
-- learningRate = 0.01

-- | Number of workers.
-- NOTE(review): no use site is visible in this part of the file;
-- presumably the number of parallel workers -- confirm.
nWorkers :: Int
nWorkers = 2

-- A2C
-- ===

-- | Debugging helper: print the second entry of a list of model tensors.
-- All other entries are ignored.
printTensors :: TT.HList (ModelTensors dev) -> IO ()
printTensors (_ TT.:. t TT.:. _) = print t

-- | Debugging helper: print the second entry of a list of model parameters.
-- All other entries are ignored.
printParams :: TT.HList (ModelParams dev) -> IO ()
printParams (_ TT.:. t TT.:. _) = print t

-- | Learner state of the A2C algorithm: two models and one optimizer for each.
-- All fields are strict.
data A2CState dev = A2CState
  { a2cActor :: !(QModel dev)
  -- ^ model used to compute the policy (see 'runBatchedPolicy' in 'pieceStep')
  , a2cCritic :: !(QModel dev)
  -- ^ model used to compute state values (see 'forwardValue' in 'pieceStep')
  , a2cOptActor :: !TT.GD
  -- ^ optimizer paired with the actor
  -- (commented-out alternatives: @TT.CppOptimizerState TT.AdamOptions ModelParams@,
  -- @TT.Adam ModelTensors@)
  , a2cOptCritic :: !TT.GD
  -- ^ optimizer paired with the critic
  -- (commented-out alternative: @TT.Adam ModelTensors@)
  }
  deriving (Generic)

-- | Per-piece state that is threaded through the individual steps of an episode.
-- All fields are strict.
data A2CStepState dev = A2CStepState
  { a2cStepZV :: !(TT.HList (ModelTensors dev))
  -- ^ trace tensors for the critic (fed to 'updateEligCritic' in 'pieceStep')
  , a2cStepZP :: !(TT.HList (ModelTensors dev))
  -- ^ trace tensors for the policy
  -- NOTE(review): presumably the actor's eligibility trace -- confirm
  , a2cStepIntensity :: !QType
  -- ^ starts at 1 (see 'initPieceState')
  -- NOTE(review): presumably the accumulated discount -- confirm
  , a2cStepReward :: !QType
  -- ^ starts at 0 (see 'initPieceState')
  -- NOTE(review): presumably the reward accumulated so far -- confirm
  , a2cStepState
      :: !( GreedyState
              (Edges SPitch)
              [Edge SPitch]
              (Notes SPitch)
              (PVLeftmost SPitch)
          )
  -- ^ current state of the greedy parser
  , a2cStepActions :: !(NE.NonEmpty PVAction)
  -- ^ actions to choose from in the current state (never empty)
  }

-- | Set up the initial step state for a piece.
--
-- The parser is initialised from @eval@ and @input@ and its first (at most
-- 200) actions are listed.
--
-- * If there is no action at all, the result is @Right (-inf)@.
-- * Otherwise the result is @Left@ with a step state in which both trace
--   fields are set to @z0@, the intensity is 1, the reward is 0, and the
--   listed actions are stored as a non-empty list.
initPieceState
  :: (TT.KnownDevice dev)
  => Eval (Edges SPitch) [Edge SPitch] (Notes SPitch) [Note SPitch] (Spread SPitch) (PVLeftmost SPitch)
  -> Path [Note SPitch] [Edge SPitch]
  -> TT.HList (ModelTensors dev)
  -> Either (A2CStepState dev) QType
initPieceState eval input z0 =
  let
    state = initParseState eval input
    -- cap the number of candidate actions (same cap as in 'pieceStep')
    actions = take 200 $ getActions eval state
   in
    case actions of
      [] -> Right (-inf)
      (a : as) -> Left $ A2CStepState z0 z0 1 0 state (a NE.:| as)

pieceStep
  :: forall dev label
   . (IsValidDevice dev)
  => Eval (Edges SPitch) [Edge SPitch] (Notes SPitch) [Note SPitch] (Spread SPitch) (PVLeftmost SPitch)
  -> Rand.IOGenM Rand.StdGen
  -> PVRewardFn label
  -> label
  -> QType
  -- ^ learning rate
  -> QType
  -- ^ temperature
  -> Int
  -- ^ iteration
  -> A2CState dev
  -> A2CStepState dev
  -> ET.ExceptT String IO (A2CState dev, Either (A2CStepState dev) QType, QType)
pieceStep :: forall (dev :: (DeviceType, Nat)) label.
IsValidDevice dev =>
Eval
  (Edges SPitch)
  [Edge SPitch]
  (Notes SPitch)
  [Note SPitch]
  (Spread SPitch)
  (PVLeftmost SPitch)
-> IOGenM StdGen
-> PVRewardFn label
-> label
-> QType
-> QType
-> Int
-> A2CState dev
-> A2CStepState dev
-> ExceptT
     String IO (A2CState dev, Either (A2CStepState dev) QType, QType)
pieceStep Eval
  (Edges SPitch)
  [Edge SPitch]
  (Notes SPitch)
  [Note SPitch]
  (Spread SPitch)
  (PVLeftmost SPitch)
eval IOGenM StdGen
gen PVRewardFn label
fReward label
len QType
lr QType
temp Int
i (A2CState QModel dev
actor QModel dev
critic GD
opta GD
optc) (A2CStepState HList (ModelTensors dev)
zV HList (ModelTensors dev)
zP QType
intensity QType
reward GreedyState
  (Edges SPitch) [Edge SPitch] (Notes SPitch) (PVLeftmost SPitch)
state NonEmpty PVAction
actions) = do
  -- EitherT String IO
  -- preparation: list actions, compute policy
  -- TODO: smarter cap than taking 200 actions
  let
    -- encodings = encodeStep state <$> actions
    -- policy = T.softmax (T.Dim 0) $ T.cat (T.Dim 0) $ TT.toDynamic . forwardPolicy actor <$> encodings
    policy :: Tensor
policy = QType -> Tensor -> Tensor
forall a. Scalar a => a -> Tensor -> Tensor
T.pow (QType
1 QType -> QType -> QType
forall a. Fractional a => a -> a -> a
/ QType
temp) (Tensor -> Tensor) -> Tensor -> Tensor
forall a b. (a -> b) -> a -> b
$ GreedyState
  (Edges SPitch) [Edge SPitch] (Notes SPitch) (PVLeftmost SPitch)
-> NonEmpty PVAction
-> (forall (n :: Nat). KnownNat n => QEncoding dev '[n] -> Tensor)
-> Tensor
forall (dev :: (DeviceType, Nat)) r.
KnownDevice dev =>
GreedyState
  (Edges SPitch) [Edge SPitch] (Notes SPitch) (PVLeftmost SPitch)
-> NonEmpty PVAction
-> (forall (n :: Nat). KnownNat n => QEncoding dev '[n] -> r)
-> r
withBatchedEncoding GreedyState
  (Edges SPitch) [Edge SPitch] (Notes SPitch) (PVLeftmost SPitch)
state NonEmpty PVAction
actions (QModel dev -> QEncoding dev '[n] -> Tensor
forall (dev :: (DeviceType, Nat)) (batchSize :: Nat).
(IsValidDevice dev, KnownNat batchSize) =>
QModel dev -> QEncoding dev '[batchSize] -> Tensor
runBatchedPolicy QModel dev
actor)
  -- choose action according to policy
  actionIndex <- IO Int -> ExceptT String IO Int
forall (m :: * -> *) a. Monad m => m a -> ExceptT String m a
forall (t :: (* -> *) -> * -> *) (m :: * -> *) a.
(MonadTrans t, Monad m) =>
m a -> t m a
lift (IO Int -> ExceptT String IO Int)
-> IO Int -> ExceptT String IO Int
forall a b. (a -> b) -> a -> b
$ Vector QType -> IOGenM StdGen -> IO Int
forall g (m :: * -> *) (v :: * -> *).
(StatefulGen g m, Vector v QType) =>
v QType -> g -> m Int
categorical ([QType] -> Vector QType
forall a. [a] -> Vector a
V.fromList ([QType] -> Vector QType) -> [QType] -> Vector QType
forall a b. (a -> b) -> a -> b
$ Tensor -> [QType]
forall a. TensorLike a => Tensor -> a
T.asValue (Tensor -> [QType]) -> Tensor -> [QType]
forall a b. (a -> b) -> a -> b
$ DType -> Tensor -> Tensor
T.toDType DType
T.Double Tensor
policy) IOGenM StdGen
gen
  let action = NonEmpty PVAction
actions NonEmpty PVAction -> Int -> PVAction
forall a. HasCallStack => NonEmpty a -> Int -> a
NE.!! Int
actionIndex
  -- apply action
  state' <- ET.except $ applyAction state action
  let actions' = case Either
  (GreedyState
     (Edges SPitch) [Edge SPitch] (Notes SPitch) (PVLeftmost SPitch))
  (Edges SPitch, [PVLeftmost SPitch])
state' of
        Left GreedyState
  (Edges SPitch) [Edge SPitch] (Notes SPitch) (PVLeftmost SPitch)
newState -> [PVAction] -> Maybe (NonEmpty PVAction)
forall a. [a] -> Maybe (NonEmpty a)
NE.nonEmpty ([PVAction] -> Maybe (NonEmpty PVAction))
-> [PVAction] -> Maybe (NonEmpty PVAction)
forall a b. (a -> b) -> a -> b
$ Int -> [PVAction] -> [PVAction]
forall a. Int -> [a] -> [a]
take Int
200 ([PVAction] -> [PVAction]) -> [PVAction] -> [PVAction]
forall a b. (a -> b) -> a -> b
$ Eval
  (Edges SPitch)
  [Edge SPitch]
  (Notes SPitch)
  [Note SPitch]
  (Spread SPitch)
  (PVLeftmost SPitch)
-> GreedyState
     (Edges SPitch) [Edge SPitch] (Notes SPitch) (PVLeftmost SPitch)
-> [PVAction]
forall {k} (m :: k) tr tr' slc slc' s f h.
Eval tr tr' slc slc' h (Leftmost s f h)
-> GreedyState tr tr' slc (Leftmost s f h) -> [Action slc tr s f h]
getActions Eval
  (Edges SPitch)
  [Edge SPitch]
  (Notes SPitch)
  [Note SPitch]
  (Spread SPitch)
  (PVLeftmost SPitch)
eval GreedyState
  (Edges SPitch) [Edge SPitch] (Notes SPitch) (PVLeftmost SPitch)
newState
        Right (Edges SPitch, [PVLeftmost SPitch])
_ -> Maybe (NonEmpty PVAction)
forall a. Maybe a
Nothing
  -- compute A2C update
  r <- lift $ fReward state' actions' action len
  let vS = QModel dev -> StateEncoding dev -> Tensor dev 'Double '[1]
forall (dev :: (DeviceType, Nat)).
IsValidDevice dev =>
QModel dev -> StateEncoding dev -> QTensor dev '[1]
forwardValue QModel dev
critic (StateEncoding dev -> Tensor dev 'Double '[1])
-> StateEncoding dev -> Tensor dev 'Double '[1]
forall a b. (a -> b) -> a -> b
$ GreedyState
  (Edges SPitch) [Edge SPitch] (Notes SPitch) (PVLeftmost SPitch)
-> StateEncoding dev
forall (dev :: (DeviceType, Nat)).
KnownDevice dev =>
GreedyState
  (Edges SPitch) [Edge SPitch] (Notes SPitch) (PVLeftmost SPitch)
-> StateEncoding dev
encodePVState GreedyState
  (Edges SPitch) [Edge SPitch] (Notes SPitch) (PVLeftmost SPitch)
state
      vS' = case Either
  (GreedyState
     (Edges SPitch) [Edge SPitch] (Notes SPitch) (PVLeftmost SPitch))
  (Edges SPitch, [PVLeftmost SPitch])
state' of
        Left GreedyState
  (Edges SPitch) [Edge SPitch] (Notes SPitch) (PVLeftmost SPitch)
s' -> QModel dev -> StateEncoding dev -> Tensor dev 'Double '[1]
forall (dev :: (DeviceType, Nat)).
IsValidDevice dev =>
QModel dev -> StateEncoding dev -> QTensor dev '[1]
forwardValue QModel dev
critic (StateEncoding dev -> Tensor dev 'Double '[1])
-> StateEncoding dev -> Tensor dev 'Double '[1]
forall a b. (a -> b) -> a -> b
$ GreedyState
  (Edges SPitch) [Edge SPitch] (Notes SPitch) (PVLeftmost SPitch)
-> StateEncoding dev
forall (dev :: (DeviceType, Nat)).
KnownDevice dev =>
GreedyState
  (Edges SPitch) [Edge SPitch] (Notes SPitch) (PVLeftmost SPitch)
-> StateEncoding dev
encodePVState GreedyState
  (Edges SPitch) [Edge SPitch] (Notes SPitch) (PVLeftmost SPitch)
s'
        Right (Edges SPitch, [PVLeftmost SPitch])
_ -> Tensor dev 'Double '[1]
0
      delta = QType -> Tensor dev 'Double '[] -> Tensor dev 'Double '[]
forall a (shape :: [Nat]) (dtype :: DType)
       (device :: (DeviceType, Nat)).
Scalar a =>
a -> Tensor device dtype shape -> Tensor device dtype shape
TT.addScalar QType
r (Tensor dev 'Double '[] -> Tensor dev 'Double '[])
-> Tensor dev 'Double '[] -> Tensor dev 'Double '[]
forall a b. (a -> b) -> a -> b
$ Tensor dev 'Double '[1] -> Tensor dev 'Double '[]
forall (shape :: [Nat]) (shape' :: [Nat]) (dtype :: DType)
       (device :: (DeviceType, Nat)).
(shape' ~ SqueezeAll shape) =>
Tensor device dtype shape -> Tensor device dtype shape'
TT.squeezeAll (Tensor dev 'Double '[1] -> Tensor dev 'Double '[])
-> Tensor dev 'Double '[1] -> Tensor dev 'Double '[]
forall a b. (a -> b) -> a -> b
$ QType -> Tensor dev 'Double '[1] -> Tensor dev 'Double '[1]
forall a (shape :: [Nat]) (dtype :: DType)
       (device :: (DeviceType, Nat)).
Scalar a =>
a -> Tensor device dtype shape -> Tensor device dtype shape
TT.mulScalar QType
gamma Tensor dev 'Double '[1]
vS' Tensor dev 'Double '[1]
-> Tensor dev 'Double '[1] -> Tensor dev 'Double '[1]
forall a. Num a => a -> a -> a
- Tensor dev 'Double '[1]
vS
      gradV = Tensor dev 'Double '[]
-> HList
     '[Parameter dev 'Double '[8, 1, 1, 1], Parameter dev 'Double '[8],
       Parameter dev 'Double '[8, 8, 13, 5], Parameter dev 'Double '[8],
       Parameter dev 'Double '[8, 13, 5],
       Parameter dev 'Double '[8, 13, 5],
       Parameter dev 'Double '[8, 2, 13, 5], Parameter dev 'Double '[8],
       Parameter dev 'Double '[8, 2, 13, 5], Parameter dev 'Double '[8],
       Parameter dev 'Double '[8, 1, 1, 1], Parameter dev 'Double '[8],
       Parameter dev 'Double '[8, 1, 1, 1], Parameter dev 'Double '[8],
       Parameter dev 'Double '[8], Parameter dev 'Double '[8, 8, 13, 5],
       Parameter dev 'Double '[8], Parameter dev 'Double '[8, 8, 13, 5],
       Parameter dev 'Double '[8], Parameter dev 'Double '[8, 8, 13, 5],
       Parameter dev 'Double '[8], Parameter dev 'Double '[8, 8, 13, 5],
       Parameter dev 'Double '[8], Parameter dev 'Double '[8, 8, 13, 5],
       Parameter dev 'Double '[8], Parameter dev 'Double '[8, 8, 13, 5],
       Parameter dev 'Double '[8], Parameter dev 'Double '[8, 8, 13, 5],
       Parameter dev 'Double '[8], Parameter dev 'Double '[5],
       Parameter dev 'Double '[5], Parameter dev 'Double '[5],
       Parameter dev 'Double '[8, 8, 13, 5], Parameter dev 'Double '[8],
       Parameter dev 'Double '[8, 8, 13, 5], Parameter dev 'Double '[8],
       Parameter dev 'Double '[8, 8, 13, 5], Parameter dev 'Double '[8],
       Parameter dev 'Double '[8, 8, 13, 5], Parameter dev 'Double '[8],
       Parameter dev 'Double '[8, 8, 13, 5], Parameter dev 'Double '[8],
       Parameter dev 'Double '[8, 8, 13, 5], Parameter dev 'Double '[8],
       Parameter dev 'Double '[8, 8, 13, 5], Parameter dev 'Double '[8],
       Parameter dev 'Double '[8, 8, 13, 5], Parameter dev 'Double '[8],
       Parameter dev 'Double '[8], Parameter dev 'Double '[8],
       Parameter dev 'Double '[1, 8], Parameter dev 'Double '[1],
       Parameter dev 'Double '[8, 8], Parameter dev 'Double '[8],
       Parameter dev 'Double '[8], Parameter dev 'Double '[8],
       Parameter dev 'Double '[1, 8], Parameter dev 'Double '[1]]
-> HList
     '[Tensor dev 'Double '[8, 1, 1, 1], Tensor dev 'Double '[8],
       Tensor dev 'Double '[8, 8, 13, 5], Tensor dev 'Double '[8],
       Tensor dev 'Double '[8, 13, 5], Tensor dev 'Double '[8, 13, 5],
       Tensor dev 'Double '[8, 2, 13, 5], Tensor dev 'Double '[8],
       Tensor dev 'Double '[8, 2, 13, 5], Tensor dev 'Double '[8],
       Tensor dev 'Double '[8, 1, 1, 1], Tensor dev 'Double '[8],
       Tensor dev 'Double '[8, 1, 1, 1], Tensor dev 'Double '[8],
       Tensor dev 'Double '[8], Tensor dev 'Double '[8, 8, 13, 5],
       Tensor dev 'Double '[8], Tensor dev 'Double '[8, 8, 13, 5],
       Tensor dev 'Double '[8], Tensor dev 'Double '[8, 8, 13, 5],
       Tensor dev 'Double '[8], Tensor dev 'Double '[8, 8, 13, 5],
       Tensor dev 'Double '[8], Tensor dev 'Double '[8, 8, 13, 5],
       Tensor dev 'Double '[8], Tensor dev 'Double '[8, 8, 13, 5],
       Tensor dev 'Double '[8], Tensor dev 'Double '[8, 8, 13, 5],
       Tensor dev 'Double '[8], Tensor dev 'Double '[5],
       Tensor dev 'Double '[5], Tensor dev 'Double '[5],
       Tensor dev 'Double '[8, 8, 13, 5], Tensor dev 'Double '[8],
       Tensor dev 'Double '[8, 8, 13, 5], Tensor dev 'Double '[8],
       Tensor dev 'Double '[8, 8, 13, 5], Tensor dev 'Double '[8],
       Tensor dev 'Double '[8, 8, 13, 5], Tensor dev 'Double '[8],
       Tensor dev 'Double '[8, 8, 13, 5], Tensor dev 'Double '[8],
       Tensor dev 'Double '[8, 8, 13, 5], Tensor dev 'Double '[8],
       Tensor dev 'Double '[8, 8, 13, 5], Tensor dev 'Double '[8],
       Tensor dev 'Double '[8, 8, 13, 5], Tensor dev 'Double '[8],
       Tensor dev 'Double '[8], Tensor dev 'Double '[8],
       Tensor dev 'Double '[1, 8], Tensor dev 'Double '[1],
       Tensor dev 'Double '[8, 8], Tensor dev 'Double '[8],
       Tensor dev 'Double '[8], Tensor dev 'Double '[8],
       Tensor dev 'Double '[1, 8], Tensor dev 'Double '[1]]
forall a b (dtype :: DType) (device :: (DeviceType, Nat)).
HasGrad a b =>
Tensor device dtype '[] -> a -> b
forall (dtype :: DType) (device :: (DeviceType, Nat)).
Tensor device dtype '[]
-> HList
     '[Parameter dev 'Double '[8, 1, 1, 1], Parameter dev 'Double '[8],
       Parameter dev 'Double '[8, 8, 13, 5], Parameter dev 'Double '[8],
       Parameter dev 'Double '[8, 13, 5],
       Parameter dev 'Double '[8, 13, 5],
       Parameter dev 'Double '[8, 2, 13, 5], Parameter dev 'Double '[8],
       Parameter dev 'Double '[8, 2, 13, 5], Parameter dev 'Double '[8],
       Parameter dev 'Double '[8, 1, 1, 1], Parameter dev 'Double '[8],
       Parameter dev 'Double '[8, 1, 1, 1], Parameter dev 'Double '[8],
       Parameter dev 'Double '[8], Parameter dev 'Double '[8, 8, 13, 5],
       Parameter dev 'Double '[8], Parameter dev 'Double '[8, 8, 13, 5],
       Parameter dev 'Double '[8], Parameter dev 'Double '[8, 8, 13, 5],
       Parameter dev 'Double '[8], Parameter dev 'Double '[8, 8, 13, 5],
       Parameter dev 'Double '[8], Parameter dev 'Double '[8, 8, 13, 5],
       Parameter dev 'Double '[8], Parameter dev 'Double '[8, 8, 13, 5],
       Parameter dev 'Double '[8], Parameter dev 'Double '[8, 8, 13, 5],
       Parameter dev 'Double '[8], Parameter dev 'Double '[5],
       Parameter dev 'Double '[5], Parameter dev 'Double '[5],
       Parameter dev 'Double '[8, 8, 13, 5], Parameter dev 'Double '[8],
       Parameter dev 'Double '[8, 8, 13, 5], Parameter dev 'Double '[8],
       Parameter dev 'Double '[8, 8, 13, 5], Parameter dev 'Double '[8],
       Parameter dev 'Double '[8, 8, 13, 5], Parameter dev 'Double '[8],
       Parameter dev 'Double '[8, 8, 13, 5], Parameter dev 'Double '[8],
       Parameter dev 'Double '[8, 8, 13, 5], Parameter dev 'Double '[8],
       Parameter dev 'Double '[8, 8, 13, 5], Parameter dev 'Double '[8],
       Parameter dev 'Double '[8, 8, 13, 5], Parameter dev 'Double '[8],
       Parameter dev 'Double '[8], Parameter dev 'Double '[8],
       Parameter dev 'Double '[1, 8], Parameter dev 'Double '[1],
       Parameter dev 'Double '[8, 8], Parameter dev 'Double '[8],
       Parameter dev 'Double '[8], Parameter dev 'Double '[8],
       Parameter dev 'Double '[1, 8], Parameter dev 'Double '[1]]
-> HList
     '[Tensor dev 'Double '[8, 1, 1, 1], Tensor dev 'Double '[8],
       Tensor dev 'Double '[8, 8, 13, 5], Tensor dev 'Double '[8],
       Tensor dev 'Double '[8, 13, 5], Tensor dev 'Double '[8, 13, 5],
       Tensor dev 'Double '[8, 2, 13, 5], Tensor dev 'Double '[8],
       Tensor dev 'Double '[8, 2, 13, 5], Tensor dev 'Double '[8],
       Tensor dev 'Double '[8, 1, 1, 1], Tensor dev 'Double '[8],
       Tensor dev 'Double '[8, 1, 1, 1], Tensor dev 'Double '[8],
       Tensor dev 'Double '[8], Tensor dev 'Double '[8, 8, 13, 5],
       Tensor dev 'Double '[8], Tensor dev 'Double '[8, 8, 13, 5],
       Tensor dev 'Double '[8], Tensor dev 'Double '[8, 8, 13, 5],
       Tensor dev 'Double '[8], Tensor dev 'Double '[8, 8, 13, 5],
       Tensor dev 'Double '[8], Tensor dev 'Double '[8, 8, 13, 5],
       Tensor dev 'Double '[8], Tensor dev 'Double '[8, 8, 13, 5],
       Tensor dev 'Double '[8], Tensor dev 'Double '[8, 8, 13, 5],
       Tensor dev 'Double '[8], Tensor dev 'Double '[5],
       Tensor dev 'Double '[5], Tensor dev 'Double '[5],
       Tensor dev 'Double '[8, 8, 13, 5], Tensor dev 'Double '[8],
       Tensor dev 'Double '[8, 8, 13, 5], Tensor dev 'Double '[8],
       Tensor dev 'Double '[8, 8, 13, 5], Tensor dev 'Double '[8],
       Tensor dev 'Double '[8, 8, 13, 5], Tensor dev 'Double '[8],
       Tensor dev 'Double '[8, 8, 13, 5], Tensor dev 'Double '[8],
       Tensor dev 'Double '[8, 8, 13, 5], Tensor dev 'Double '[8],
       Tensor dev 'Double '[8, 8, 13, 5], Tensor dev 'Double '[8],
       Tensor dev 'Double '[8, 8, 13, 5], Tensor dev 'Double '[8],
       Tensor dev 'Double '[8], Tensor dev 'Double '[8],
       Tensor dev 'Double '[1, 8], Tensor dev 'Double '[1],
       Tensor dev 'Double '[8, 8], Tensor dev 'Double '[8],
       Tensor dev 'Double '[8], Tensor dev 'Double '[8],
       Tensor dev 'Double '[1, 8], Tensor dev 'Double '[1]]
TT.grad (Tensor dev 'Double '[1] -> Tensor dev 'Double '[]
forall (shape :: [Nat]) (shape' :: [Nat]) (dtype :: DType)
       (device :: (DeviceType, Nat)).
(shape' ~ SqueezeAll shape) =>
Tensor device dtype shape -> Tensor device dtype shape'
TT.squeezeAll Tensor dev 'Double '[1]
vS Tensor dev 'Double '[]
-> Tensor dev 'Double '[] -> Tensor dev 'Double '[]
forall a. Num a => a -> a -> a
+ QModel dev -> Tensor dev 'Double '[]
forall (dev :: (DeviceType, Nat)) (ps :: [*]).
(IsValidDevice dev, ps ~ Parameters (QModel dev)) =>
QModel dev -> QTensor dev '[]
fakeLoss QModel dev
critic) (QModel dev -> HList (Parameters (QModel dev))
forall f. Parameterized f => f -> HList (Parameters f)
TT.flattenParameters QModel dev
critic)
      zV' = QType
-> QType
-> HList (ModelTensors dev)
-> HList (ModelTensors dev)
-> HList (ModelTensors dev)
forall (dev :: (DeviceType, Nat)).
KnownDevice dev =>
QType
-> QType
-> HList (ModelTensors dev)
-> HList (ModelTensors dev)
-> HList (ModelTensors dev)
updateEligCritic QType
gamma QType
lambdaV HList (ModelTensors dev)
zV HList
  '[Tensor dev 'Double '[8, 1, 1, 1], Tensor dev 'Double '[8],
    Tensor dev 'Double '[8, 8, 13, 5], Tensor dev 'Double '[8],
    Tensor dev 'Double '[8, 13, 5], Tensor dev 'Double '[8, 13, 5],
    Tensor dev 'Double '[8, 2, 13, 5], Tensor dev 'Double '[8],
    Tensor dev 'Double '[8, 2, 13, 5], Tensor dev 'Double '[8],
    Tensor dev 'Double '[8, 1, 1, 1], Tensor dev 'Double '[8],
    Tensor dev 'Double '[8, 1, 1, 1], Tensor dev 'Double '[8],
    Tensor dev 'Double '[8], Tensor dev 'Double '[8, 8, 13, 5],
    Tensor dev 'Double '[8], Tensor dev 'Double '[8, 8, 13, 5],
    Tensor dev 'Double '[8], Tensor dev 'Double '[8, 8, 13, 5],
    Tensor dev 'Double '[8], Tensor dev 'Double '[8, 8, 13, 5],
    Tensor dev 'Double '[8], Tensor dev 'Double '[8, 8, 13, 5],
    Tensor dev 'Double '[8], Tensor dev 'Double '[8, 8, 13, 5],
    Tensor dev 'Double '[8], Tensor dev 'Double '[8, 8, 13, 5],
    Tensor dev 'Double '[8], Tensor dev 'Double '[5],
    Tensor dev 'Double '[5], Tensor dev 'Double '[5],
    Tensor dev 'Double '[8, 8, 13, 5], Tensor dev 'Double '[8],
    Tensor dev 'Double '[8, 8, 13, 5], Tensor dev 'Double '[8],
    Tensor dev 'Double '[8, 8, 13, 5], Tensor dev 'Double '[8],
    Tensor dev 'Double '[8, 8, 13, 5], Tensor dev 'Double '[8],
    Tensor dev 'Double '[8, 8, 13, 5], Tensor dev 'Double '[8],
    Tensor dev 'Double '[8, 8, 13, 5], Tensor dev 'Double '[8],
    Tensor dev 'Double '[8, 8, 13, 5], Tensor dev 'Double '[8],
    Tensor dev 'Double '[8, 8, 13, 5], Tensor dev 'Double '[8],
    Tensor dev 'Double '[8], Tensor dev 'Double '[8],
    Tensor dev 'Double '[1, 8], Tensor dev 'Double '[1],
    Tensor dev 'Double '[8, 8], Tensor dev 'Double '[8],
    Tensor dev 'Double '[8], Tensor dev 'Double '[8],
    Tensor dev 'Double '[1, 8], Tensor dev 'Double '[1]]
HList (ModelTensors dev)
gradV
      actionLogProb :: QTensor dev '[]
      actionLogProb = Tensor dev 'Double '[] -> Tensor dev 'Double '[]
forall (shape :: [Nat]) (dtype :: DType)
       (device :: (DeviceType, Nat)).
Tensor device dtype shape -> Tensor device dtype shape
TT.log (Tensor dev 'Double '[] -> Tensor dev 'Double '[])
-> Tensor dev 'Double '[] -> Tensor dev 'Double '[]
forall a b. (a -> b) -> a -> b
$ Tensor -> Tensor dev 'Double '[]
forall (device :: (DeviceType, Nat)) (dtype :: DType)
       (shape :: [Nat]).
Tensor -> Tensor device dtype shape
TT.UnsafeMkTensor (Tensor -> Tensor
T.squeezeAll (Tensor
policy Tensor -> Int -> Tensor
forall a. TensorIndex a => Tensor -> a -> Tensor
T.! Int
actionIndex))
      gradP = Tensor dev 'Double '[]
-> HList
     '[Parameter dev 'Double '[8, 1, 1, 1], Parameter dev 'Double '[8],
       Parameter dev 'Double '[8, 8, 13, 5], Parameter dev 'Double '[8],
       Parameter dev 'Double '[8, 13, 5],
       Parameter dev 'Double '[8, 13, 5],
       Parameter dev 'Double '[8, 2, 13, 5], Parameter dev 'Double '[8],
       Parameter dev 'Double '[8, 2, 13, 5], Parameter dev 'Double '[8],
       Parameter dev 'Double '[8, 1, 1, 1], Parameter dev 'Double '[8],
       Parameter dev 'Double '[8, 1, 1, 1], Parameter dev 'Double '[8],
       Parameter dev 'Double '[8], Parameter dev 'Double '[8, 8, 13, 5],
       Parameter dev 'Double '[8], Parameter dev 'Double '[8, 8, 13, 5],
       Parameter dev 'Double '[8], Parameter dev 'Double '[8, 8, 13, 5],
       Parameter dev 'Double '[8], Parameter dev 'Double '[8, 8, 13, 5],
       Parameter dev 'Double '[8], Parameter dev 'Double '[8, 8, 13, 5],
       Parameter dev 'Double '[8], Parameter dev 'Double '[8, 8, 13, 5],
       Parameter dev 'Double '[8], Parameter dev 'Double '[8, 8, 13, 5],
       Parameter dev 'Double '[8], Parameter dev 'Double '[5],
       Parameter dev 'Double '[5], Parameter dev 'Double '[5],
       Parameter dev 'Double '[8, 8, 13, 5], Parameter dev 'Double '[8],
       Parameter dev 'Double '[8, 8, 13, 5], Parameter dev 'Double '[8],
       Parameter dev 'Double '[8, 8, 13, 5], Parameter dev 'Double '[8],
       Parameter dev 'Double '[8, 8, 13, 5], Parameter dev 'Double '[8],
       Parameter dev 'Double '[8, 8, 13, 5], Parameter dev 'Double '[8],
       Parameter dev 'Double '[8, 8, 13, 5], Parameter dev 'Double '[8],
       Parameter dev 'Double '[8, 8, 13, 5], Parameter dev 'Double '[8],
       Parameter dev 'Double '[8, 8, 13, 5], Parameter dev 'Double '[8],
       Parameter dev 'Double '[8], Parameter dev 'Double '[8],
       Parameter dev 'Double '[1, 8], Parameter dev 'Double '[1],
       Parameter dev 'Double '[8, 8], Parameter dev 'Double '[8],
       Parameter dev 'Double '[8], Parameter dev 'Double '[8],
       Parameter dev 'Double '[1, 8], Parameter dev 'Double '[1]]
-> HList
     '[Tensor dev 'Double '[8, 1, 1, 1], Tensor dev 'Double '[8],
       Tensor dev 'Double '[8, 8, 13, 5], Tensor dev 'Double '[8],
       Tensor dev 'Double '[8, 13, 5], Tensor dev 'Double '[8, 13, 5],
       Tensor dev 'Double '[8, 2, 13, 5], Tensor dev 'Double '[8],
       Tensor dev 'Double '[8, 2, 13, 5], Tensor dev 'Double '[8],
       Tensor dev 'Double '[8, 1, 1, 1], Tensor dev 'Double '[8],
       Tensor dev 'Double '[8, 1, 1, 1], Tensor dev 'Double '[8],
       Tensor dev 'Double '[8], Tensor dev 'Double '[8, 8, 13, 5],
       Tensor dev 'Double '[8], Tensor dev 'Double '[8, 8, 13, 5],
       Tensor dev 'Double '[8], Tensor dev 'Double '[8, 8, 13, 5],
       Tensor dev 'Double '[8], Tensor dev 'Double '[8, 8, 13, 5],
       Tensor dev 'Double '[8], Tensor dev 'Double '[8, 8, 13, 5],
       Tensor dev 'Double '[8], Tensor dev 'Double '[8, 8, 13, 5],
       Tensor dev 'Double '[8], Tensor dev 'Double '[8, 8, 13, 5],
       Tensor dev 'Double '[8], Tensor dev 'Double '[5],
       Tensor dev 'Double '[5], Tensor dev 'Double '[5],
       Tensor dev 'Double '[8, 8, 13, 5], Tensor dev 'Double '[8],
       Tensor dev 'Double '[8, 8, 13, 5], Tensor dev 'Double '[8],
       Tensor dev 'Double '[8, 8, 13, 5], Tensor dev 'Double '[8],
       Tensor dev 'Double '[8, 8, 13, 5], Tensor dev 'Double '[8],
       Tensor dev 'Double '[8, 8, 13, 5], Tensor dev 'Double '[8],
       Tensor dev 'Double '[8, 8, 13, 5], Tensor dev 'Double '[8],
       Tensor dev 'Double '[8, 8, 13, 5], Tensor dev 'Double '[8],
       Tensor dev 'Double '[8, 8, 13, 5], Tensor dev 'Double '[8],
       Tensor dev 'Double '[8], Tensor dev 'Double '[8],
       Tensor dev 'Double '[1, 8], Tensor dev 'Double '[1],
       Tensor dev 'Double '[8, 8], Tensor dev 'Double '[8],
       Tensor dev 'Double '[8], Tensor dev 'Double '[8],
       Tensor dev 'Double '[1, 8], Tensor dev 'Double '[1]]
forall a b (dtype :: DType) (device :: (DeviceType, Nat)).
HasGrad a b =>
Tensor device dtype '[] -> a -> b
forall (dtype :: DType) (device :: (DeviceType, Nat)).
Tensor device dtype '[]
-> HList
     '[Parameter dev 'Double '[8, 1, 1, 1], Parameter dev 'Double '[8],
       Parameter dev 'Double '[8, 8, 13, 5], Parameter dev 'Double '[8],
       Parameter dev 'Double '[8, 13, 5],
       Parameter dev 'Double '[8, 13, 5],
       Parameter dev 'Double '[8, 2, 13, 5], Parameter dev 'Double '[8],
       Parameter dev 'Double '[8, 2, 13, 5], Parameter dev 'Double '[8],
       Parameter dev 'Double '[8, 1, 1, 1], Parameter dev 'Double '[8],
       Parameter dev 'Double '[8, 1, 1, 1], Parameter dev 'Double '[8],
       Parameter dev 'Double '[8], Parameter dev 'Double '[8, 8, 13, 5],
       Parameter dev 'Double '[8], Parameter dev 'Double '[8, 8, 13, 5],
       Parameter dev 'Double '[8], Parameter dev 'Double '[8, 8, 13, 5],
       Parameter dev 'Double '[8], Parameter dev 'Double '[8, 8, 13, 5],
       Parameter dev 'Double '[8], Parameter dev 'Double '[8, 8, 13, 5],
       Parameter dev 'Double '[8], Parameter dev 'Double '[8, 8, 13, 5],
       Parameter dev 'Double '[8], Parameter dev 'Double '[8, 8, 13, 5],
       Parameter dev 'Double '[8], Parameter dev 'Double '[5],
       Parameter dev 'Double '[5], Parameter dev 'Double '[5],
       Parameter dev 'Double '[8, 8, 13, 5], Parameter dev 'Double '[8],
       Parameter dev 'Double '[8, 8, 13, 5], Parameter dev 'Double '[8],
       Parameter dev 'Double '[8, 8, 13, 5], Parameter dev 'Double '[8],
       Parameter dev 'Double '[8, 8, 13, 5], Parameter dev 'Double '[8],
       Parameter dev 'Double '[8, 8, 13, 5], Parameter dev 'Double '[8],
       Parameter dev 'Double '[8, 8, 13, 5], Parameter dev 'Double '[8],
       Parameter dev 'Double '[8, 8, 13, 5], Parameter dev 'Double '[8],
       Parameter dev 'Double '[8, 8, 13, 5], Parameter dev 'Double '[8],
       Parameter dev 'Double '[8], Parameter dev 'Double '[8],
       Parameter dev 'Double '[1, 8], Parameter dev 'Double '[1],
       Parameter dev 'Double '[8, 8], Parameter dev 'Double '[8],
       Parameter dev 'Double '[8], Parameter dev 'Double '[8],
       Parameter dev 'Double '[1, 8], Parameter dev 'Double '[1]]
-> HList
     '[Tensor dev 'Double '[8, 1, 1, 1], Tensor dev 'Double '[8],
       Tensor dev 'Double '[8, 8, 13, 5], Tensor dev 'Double '[8],
       Tensor dev 'Double '[8, 13, 5], Tensor dev 'Double '[8, 13, 5],
       Tensor dev 'Double '[8, 2, 13, 5], Tensor dev 'Double '[8],
       Tensor dev 'Double '[8, 2, 13, 5], Tensor dev 'Double '[8],
       Tensor dev 'Double '[8, 1, 1, 1], Tensor dev 'Double '[8],
       Tensor dev 'Double '[8, 1, 1, 1], Tensor dev 'Double '[8],
       Tensor dev 'Double '[8], Tensor dev 'Double '[8, 8, 13, 5],
       Tensor dev 'Double '[8], Tensor dev 'Double '[8, 8, 13, 5],
       Tensor dev 'Double '[8], Tensor dev 'Double '[8, 8, 13, 5],
       Tensor dev 'Double '[8], Tensor dev 'Double '[8, 8, 13, 5],
       Tensor dev 'Double '[8], Tensor dev 'Double '[8, 8, 13, 5],
       Tensor dev 'Double '[8], Tensor dev 'Double '[8, 8, 13, 5],
       Tensor dev 'Double '[8], Tensor dev 'Double '[8, 8, 13, 5],
       Tensor dev 'Double '[8], Tensor dev 'Double '[5],
       Tensor dev 'Double '[5], Tensor dev 'Double '[5],
       Tensor dev 'Double '[8, 8, 13, 5], Tensor dev 'Double '[8],
       Tensor dev 'Double '[8, 8, 13, 5], Tensor dev 'Double '[8],
       Tensor dev 'Double '[8, 8, 13, 5], Tensor dev 'Double '[8],
       Tensor dev 'Double '[8, 8, 13, 5], Tensor dev 'Double '[8],
       Tensor dev 'Double '[8, 8, 13, 5], Tensor dev 'Double '[8],
       Tensor dev 'Double '[8, 8, 13, 5], Tensor dev 'Double '[8],
       Tensor dev 'Double '[8, 8, 13, 5], Tensor dev 'Double '[8],
       Tensor dev 'Double '[8, 8, 13, 5], Tensor dev 'Double '[8],
       Tensor dev 'Double '[8], Tensor dev 'Double '[8],
       Tensor dev 'Double '[1, 8], Tensor dev 'Double '[1],
       Tensor dev 'Double '[8, 8], Tensor dev 'Double '[8],
       Tensor dev 'Double '[8], Tensor dev 'Double '[8],
       Tensor dev 'Double '[1, 8], Tensor dev 'Double '[1]]
TT.grad (Tensor dev 'Double '[]
actionLogProb Tensor dev 'Double '[]
-> Tensor dev 'Double '[] -> Tensor dev 'Double '[]
forall a. Num a => a -> a -> a
+ QModel dev -> Tensor dev 'Double '[]
forall (dev :: (DeviceType, Nat)) (ps :: [*]).
(IsValidDevice dev, ps ~ Parameters (QModel dev)) =>
QModel dev -> QTensor dev '[]
fakeLoss QModel dev
actor) (QModel dev -> HList (Parameters (QModel dev))
forall f. Parameterized f => f -> HList (Parameters f)
TT.flattenParameters QModel dev
actor)
      zP' = QType
-> QType
-> QType
-> HList (ModelTensors dev)
-> HList (ModelTensors dev)
-> HList (ModelTensors dev)
forall (dev :: (DeviceType, Nat)).
KnownDevice dev =>
QType
-> QType
-> QType
-> HList (ModelTensors dev)
-> HList (ModelTensors dev)
-> HList (ModelTensors dev)
updateEligActor QType
gamma QType
lambdaP QType
intensity HList (ModelTensors dev)
zP HList
  '[Tensor dev 'Double '[8, 1, 1, 1], Tensor dev 'Double '[8],
    Tensor dev 'Double '[8, 8, 13, 5], Tensor dev 'Double '[8],
    Tensor dev 'Double '[8, 13, 5], Tensor dev 'Double '[8, 13, 5],
    Tensor dev 'Double '[8, 2, 13, 5], Tensor dev 'Double '[8],
    Tensor dev 'Double '[8, 2, 13, 5], Tensor dev 'Double '[8],
    Tensor dev 'Double '[8, 1, 1, 1], Tensor dev 'Double '[8],
    Tensor dev 'Double '[8, 1, 1, 1], Tensor dev 'Double '[8],
    Tensor dev 'Double '[8], Tensor dev 'Double '[8, 8, 13, 5],
    Tensor dev 'Double '[8], Tensor dev 'Double '[8, 8, 13, 5],
    Tensor dev 'Double '[8], Tensor dev 'Double '[8, 8, 13, 5],
    Tensor dev 'Double '[8], Tensor dev 'Double '[8, 8, 13, 5],
    Tensor dev 'Double '[8], Tensor dev 'Double '[8, 8, 13, 5],
    Tensor dev 'Double '[8], Tensor dev 'Double '[8, 8, 13, 5],
    Tensor dev 'Double '[8], Tensor dev 'Double '[8, 8, 13, 5],
    Tensor dev 'Double '[8], Tensor dev 'Double '[5],
    Tensor dev 'Double '[5], Tensor dev 'Double '[5],
    Tensor dev 'Double '[8, 8, 13, 5], Tensor dev 'Double '[8],
    Tensor dev 'Double '[8, 8, 13, 5], Tensor dev 'Double '[8],
    Tensor dev 'Double '[8, 8, 13, 5], Tensor dev 'Double '[8],
    Tensor dev 'Double '[8, 8, 13, 5], Tensor dev 'Double '[8],
    Tensor dev 'Double '[8, 8, 13, 5], Tensor dev 'Double '[8],
    Tensor dev 'Double '[8, 8, 13, 5], Tensor dev 'Double '[8],
    Tensor dev 'Double '[8, 8, 13, 5], Tensor dev 'Double '[8],
    Tensor dev 'Double '[8, 8, 13, 5], Tensor dev 'Double '[8],
    Tensor dev 'Double '[8], Tensor dev 'Double '[8],
    Tensor dev 'Double '[1, 8], Tensor dev 'Double '[1],
    Tensor dev 'Double '[8, 8], Tensor dev 'Double '[8],
    Tensor dev 'Double '[8], Tensor dev 'Double '[8],
    Tensor dev 'Double '[1, 8], Tensor dev 'Double '[1]]
HList (ModelTensors dev)
gradP
      intensity' = QType
gamma QType -> QType -> QType
forall a. Num a => a -> a -> a
* QType
intensity
      learningRate = QType -> Tensor dev 'Double '[]
forall (dev :: (DeviceType, Nat)).
KnownDevice dev =>
QType -> QTensor dev '[]
toQTensor (QType -> QType
forall a. Num a => a -> a
negate QType
lr)
  (!actor', !opta') <- lift $ TT.runStep' actor opta learningRate $ mulModelTensors delta zP'
  (!critic', !optc') <- lift $ TT.runStep' critic optc learningRate $ mulModelTensors delta zV'
  let loss' = Tensor -> QType
forall a. TensorLike a => Tensor -> a
T.asValue (Tensor -> QType) -> Tensor -> QType
forall a b. (a -> b) -> a -> b
$ Tensor dev 'Double '[] -> Tensor
forall t. Unnamed t => t -> Tensor
TT.toDynamic Tensor dev 'Double '[]
delta
      reward' = QType
reward QType -> QType -> QType
forall a. Num a => a -> a -> a
+ QType
r
  let pieceState' = case (Either
  (GreedyState
     (Edges SPitch) [Edge SPitch] (Notes SPitch) (PVLeftmost SPitch))
  (Edges SPitch, [PVLeftmost SPitch])
state', Maybe (NonEmpty PVAction)
actions') of
        (Left GreedyState
  (Edges SPitch) [Edge SPitch] (Notes SPitch) (PVLeftmost SPitch)
s', Just NonEmpty PVAction
a') -> A2CStepState dev -> Either (A2CStepState dev) QType
forall a b. a -> Either a b
Left (A2CStepState dev -> Either (A2CStepState dev) QType)
-> A2CStepState dev -> Either (A2CStepState dev) QType
forall a b. (a -> b) -> a -> b
$ HList (ModelTensors dev)
-> HList (ModelTensors dev)
-> QType
-> QType
-> GreedyState
     (Edges SPitch) [Edge SPitch] (Notes SPitch) (PVLeftmost SPitch)
-> NonEmpty PVAction
-> A2CStepState dev
forall (dev :: (DeviceType, Nat)).
HList (ModelTensors dev)
-> HList (ModelTensors dev)
-> QType
-> QType
-> GreedyState
     (Edges SPitch) [Edge SPitch] (Notes SPitch) (PVLeftmost SPitch)
-> NonEmpty PVAction
-> A2CStepState dev
A2CStepState HList
  '[Tensor dev 'Double '[8, 1, 1, 1], Tensor dev 'Double '[8],
    Tensor dev 'Double '[8, 8, 13, 5], Tensor dev 'Double '[8],
    Tensor dev 'Double '[8, 13, 5], Tensor dev 'Double '[8, 13, 5],
    Tensor dev 'Double '[8, 2, 13, 5], Tensor dev 'Double '[8],
    Tensor dev 'Double '[8, 2, 13, 5], Tensor dev 'Double '[8],
    Tensor dev 'Double '[8, 1, 1, 1], Tensor dev 'Double '[8],
    Tensor dev 'Double '[8, 1, 1, 1], Tensor dev 'Double '[8],
    Tensor dev 'Double '[8], Tensor dev 'Double '[8, 8, 13, 5],
    Tensor dev 'Double '[8], Tensor dev 'Double '[8, 8, 13, 5],
    Tensor dev 'Double '[8], Tensor dev 'Double '[8, 8, 13, 5],
    Tensor dev 'Double '[8], Tensor dev 'Double '[8, 8, 13, 5],
    Tensor dev 'Double '[8], Tensor dev 'Double '[8, 8, 13, 5],
    Tensor dev 'Double '[8], Tensor dev 'Double '[8, 8, 13, 5],
    Tensor dev 'Double '[8], Tensor dev 'Double '[8, 8, 13, 5],
    Tensor dev 'Double '[8], Tensor dev 'Double '[5],
    Tensor dev 'Double '[5], Tensor dev 'Double '[5],
    Tensor dev 'Double '[8, 8, 13, 5], Tensor dev 'Double '[8],
    Tensor dev 'Double '[8, 8, 13, 5], Tensor dev 'Double '[8],
    Tensor dev 'Double '[8, 8, 13, 5], Tensor dev 'Double '[8],
    Tensor dev 'Double '[8, 8, 13, 5], Tensor dev 'Double '[8],
    Tensor dev 'Double '[8, 8, 13, 5], Tensor dev 'Double '[8],
    Tensor dev 'Double '[8, 8, 13, 5], Tensor dev 'Double '[8],
    Tensor dev 'Double '[8, 8, 13, 5], Tensor dev 'Double '[8],
    Tensor dev 'Double '[8, 8, 13, 5], Tensor dev 'Double '[8],
    Tensor dev 'Double '[8], Tensor dev 'Double '[8],
    Tensor dev 'Double '[1, 8], Tensor dev 'Double '[1],
    Tensor dev 'Double '[8, 8], Tensor dev 'Double '[8],
    Tensor dev 'Double '[8], Tensor dev 'Double '[8],
    Tensor dev 'Double '[1, 8], Tensor dev 'Double '[1]]
HList (ModelTensors dev)
zV' HList
  '[Tensor dev 'Double '[8, 1, 1, 1], Tensor dev 'Double '[8],
    Tensor dev 'Double '[8, 8, 13, 5], Tensor dev 'Double '[8],
    Tensor dev 'Double '[8, 13, 5], Tensor dev 'Double '[8, 13, 5],
    Tensor dev 'Double '[8, 2, 13, 5], Tensor dev 'Double '[8],
    Tensor dev 'Double '[8, 2, 13, 5], Tensor dev 'Double '[8],
    Tensor dev 'Double '[8, 1, 1, 1], Tensor dev 'Double '[8],
    Tensor dev 'Double '[8, 1, 1, 1], Tensor dev 'Double '[8],
    Tensor dev 'Double '[8], Tensor dev 'Double '[8, 8, 13, 5],
    Tensor dev 'Double '[8], Tensor dev 'Double '[8, 8, 13, 5],
    Tensor dev 'Double '[8], Tensor dev 'Double '[8, 8, 13, 5],
    Tensor dev 'Double '[8], Tensor dev 'Double '[8, 8, 13, 5],
    Tensor dev 'Double '[8], Tensor dev 'Double '[8, 8, 13, 5],
    Tensor dev 'Double '[8], Tensor dev 'Double '[8, 8, 13, 5],
    Tensor dev 'Double '[8], Tensor dev 'Double '[8, 8, 13, 5],
    Tensor dev 'Double '[8], Tensor dev 'Double '[5],
    Tensor dev 'Double '[5], Tensor dev 'Double '[5],
    Tensor dev 'Double '[8, 8, 13, 5], Tensor dev 'Double '[8],
    Tensor dev 'Double '[8, 8, 13, 5], Tensor dev 'Double '[8],
    Tensor dev 'Double '[8, 8, 13, 5], Tensor dev 'Double '[8],
    Tensor dev 'Double '[8, 8, 13, 5], Tensor dev 'Double '[8],
    Tensor dev 'Double '[8, 8, 13, 5], Tensor dev 'Double '[8],
    Tensor dev 'Double '[8, 8, 13, 5], Tensor dev 'Double '[8],
    Tensor dev 'Double '[8, 8, 13, 5], Tensor dev 'Double '[8],
    Tensor dev 'Double '[8, 8, 13, 5], Tensor dev 'Double '[8],
    Tensor dev 'Double '[8], Tensor dev 'Double '[8],
    Tensor dev 'Double '[1, 8], Tensor dev 'Double '[1],
    Tensor dev 'Double '[8, 8], Tensor dev 'Double '[8],
    Tensor dev 'Double '[8], Tensor dev 'Double '[8],
    Tensor dev 'Double '[1, 8], Tensor dev 'Double '[1]]
HList (ModelTensors dev)
zP' QType
intensity' QType
reward' GreedyState
  (Edges SPitch) [Edge SPitch] (Notes SPitch) (PVLeftmost SPitch)
s' NonEmpty PVAction
a'
        (Left GreedyState
  (Edges SPitch) [Edge SPitch] (Notes SPitch) (PVLeftmost SPitch)
s', Maybe (NonEmpty PVAction)
Nothing) ->
          -- DT.trace ("incomplete parse:\n" <> show s') $
          QType -> Either (A2CStepState dev) QType
forall a b. b -> Either a b
Right QType
reward'
        (Right (Edges SPitch, [PVLeftmost SPitch])
_, Maybe (NonEmpty PVAction)
_) -> QType -> Either (A2CStepState dev) QType
forall a b. b -> Either a b
Right QType
reward' -- TT.toDouble (TT.squeezeAll vS) - r
  pure (A2CState actor' critic' opta' optc', pieceState', loss')

{- | Run a single training episode on one input piece.

The piece state is initialised with 'initPieceState' using all-zero
tensors ('modelZeros' applied to the actor of the incoming model state).
If initialisation already yields a final reward (@Right@), no step is
taken and the incoming model state is returned unchanged together with
that reward and a loss of @0@.

Otherwise 'pieceStep' is applied repeatedly, threading the updated model
state and piece state through, until it reports a final reward. The
result is then the last model state, that reward, and the 'mean' of the
loss values returned by the individual steps.

Any @Left@ error raised inside the 'ET.ExceptT' computation of
'pieceStep' is returned as @Left@ in the resulting 'IO' action.
-}
runEpisode
  :: forall dev label
   . (_)
  => Eval (Edges SPitch) [Edge SPitch] (Notes SPitch) [Note SPitch] (Spread SPitch) (PVLeftmost SPitch)
  -- ^ evaluator handed to 'initPieceState' and 'pieceStep'
  -> Rand.IOGenM Rand.StdGen
  -- ^ random generator handed to 'pieceStep'
  -> PVRewardFn label
  -- ^ reward function handed to 'pieceStep'
  -> (QType -> QType)
  -- ^ learning-rate schedule, applied to @fromIntegral i@
  -> (QType -> QType)
  -- ^ temperature schedule, applied to @fromIntegral i@
  -> Path [Note SPitch] [Edge SPitch]
  -- ^ the input piece
  -> label
  -- ^ label of the piece, handed to 'pieceStep'
  -> A2CState dev
  -- ^ model state at the start of the episode
  -> Int
  -- ^ step counter fed to both schedules and to 'pieceStep'
  --   (presumably the episode index; confirm against the caller)
  -> IO (Either String (A2CState dev, QType, QType))
runEpisode !eval !gen !fReward !fLr !fTemp !input !label !modelState !i =
  case initPieceState eval input z0 of
    Left s0 -> ET.runExceptT $ go modelState s0 SL.Nil
    -- Nothing to do: report the reward without touching the model.
    Right reward -> pure $ pure (modelState, reward, 0)
 where
  -- All-zero tensors shaped like the actor's parameters.
  z0 :: TT.HList (ModelTensors dev)
  z0 = modelZeros $ a2cActor modelState

  -- Learning rate and temperature are fixed for the whole episode.
  lr :: QType
  lr = fLr $ fromIntegral i

  temp :: QType
  temp = fTemp $ fromIntegral i

  -- len = pathLen input

  -- Step loop: @st@ is the current model state, @losses@ collects the
  -- per-step losses (most recent first) in a strict list.
  go st pieceState losses = do
    (st', pieceState', loss) <- pieceStep eval gen fReward label lr temp i st pieceState
    let losses' = loss `SL.Cons` losses
    case pieceState' of
      Left ps' -> go st' ps' losses'
      Right reward -> pure (st', reward, mean losses')

runAccuracy
  :: (IsValidDevice dev)
  => Eval (Edges SPitch) [Edge SPitch] (Notes SPitch) slc' (Spread SPitch) (PVLeftmost SPitch)
  -> PVRewardFn label
  -> QModel dev
  -> (Path slc' [Edge SPitch], label)
  -> IO (Either String (QType, PVAnalysis SPitch))
runAccuracy :: forall (dev :: (DeviceType, Nat)) slc' label.
IsValidDevice dev =>
Eval
  (Edges SPitch)
  [Edge SPitch]
  (Notes SPitch)
  slc'
  (Spread SPitch)
  (PVLeftmost SPitch)
-> PVRewardFn label
-> QModel dev
-> (Path slc' [Edge SPitch], label)
-> IO (Either String (QType, PVAnalysis SPitch))
runAccuracy !Eval
  (Edges SPitch)
  [Edge SPitch]
  (Notes SPitch)
  slc'
  (Spread SPitch)
  (PVLeftmost SPitch)
eval !PVRewardFn label
fReward !QModel dev
actor (!Path slc' [Edge SPitch]
input, !label
label) = case Int -> [PVAction] -> [PVAction]
forall a. Int -> [a] -> [a]
take Int
200 ([PVAction] -> [PVAction]) -> [PVAction] -> [PVAction]
forall a b. (a -> b) -> a -> b
$ Eval
  (Edges SPitch)
  [Edge SPitch]
  (Notes SPitch)
  slc'
  (Spread SPitch)
  (PVLeftmost SPitch)
-> GreedyState
     (Edges SPitch) [Edge SPitch] (Notes SPitch) (PVLeftmost SPitch)
-> [PVAction]
forall {k} (m :: k) tr tr' slc slc' s f h.
Eval tr tr' slc slc' h (Leftmost s f h)
-> GreedyState tr tr' slc (Leftmost s f h) -> [Action slc tr s f h]
getActions Eval
  (Edges SPitch)
  [Edge SPitch]
  (Notes SPitch)
  slc'
  (Spread SPitch)
  (PVLeftmost SPitch)
eval GreedyState
  (Edges SPitch) [Edge SPitch] (Notes SPitch) (PVLeftmost SPitch)
forall {op}.
GreedyState (Edges SPitch) [Edge SPitch] (Notes SPitch) op
s0 of
  [] -> Either String (QType, PVAnalysis SPitch)
-> IO (Either String (QType, PVAnalysis SPitch))
forall a. a -> IO a
forall (f :: * -> *) a. Applicative f => a -> f a
pure (Either String (QType, PVAnalysis SPitch)
 -> IO (Either String (QType, PVAnalysis SPitch)))
-> Either String (QType, PVAnalysis SPitch)
-> IO (Either String (QType, PVAnalysis SPitch))
forall a b. (a -> b) -> a -> b
$ String -> Either String (QType, PVAnalysis SPitch)
forall a b. a -> Either a b
Left String
"cannot parse: no possible actions for first step!"
  (PVAction
a : [PVAction]
as) -> ExceptT String IO (QType, PVAnalysis SPitch)
-> IO (Either String (QType, PVAnalysis SPitch))
forall e (m :: * -> *) a. ExceptT e m a -> m (Either e a)
ET.runExceptT (ExceptT String IO (QType, PVAnalysis SPitch)
 -> IO (Either String (QType, PVAnalysis SPitch)))
-> ExceptT String IO (QType, PVAnalysis SPitch)
-> IO (Either String (QType, PVAnalysis SPitch))
forall a b. (a -> b) -> a -> b
$ Tensor
-> QType
-> GreedyState
     (Edges SPitch) [Edge SPitch] (Notes SPitch) (PVLeftmost SPitch)
-> NonEmpty PVAction
-> ExceptT String IO (QType, PVAnalysis SPitch)
forall {slc}.
Tensor
-> QType
-> GreedyState
     (Edges SPitch) [Edge SPitch] (Notes SPitch) (PVLeftmost SPitch)
-> NonEmpty PVAction
-> ExceptT
     String
     IO
     (QType,
      Analysis
        (Split SPitch) (Freeze SPitch) (Spread SPitch) (Edges SPitch) slc)
go Tensor
0 QType
0 GreedyState
  (Edges SPitch) [Edge SPitch] (Notes SPitch) (PVLeftmost SPitch)
forall {op}.
GreedyState (Edges SPitch) [Edge SPitch] (Notes SPitch) op
s0 (PVAction
a PVAction -> [PVAction] -> NonEmpty PVAction
forall a. a -> [a] -> NonEmpty a
NE.:| [PVAction]
as)
 where
  s0 :: GreedyState (Edges SPitch) [Edge SPitch] (Notes SPitch) op
s0 = Eval
  (Edges SPitch)
  [Edge SPitch]
  (Notes SPitch)
  slc'
  (Spread SPitch)
  (PVLeftmost SPitch)
-> Path slc' [Edge SPitch]
-> GreedyState (Edges SPitch) [Edge SPitch] (Notes SPitch) op
forall tr tr' slc slc' h v op.
Eval tr tr' slc slc' h v
-> Path slc' tr' -> GreedyState tr tr' slc op
initParseState Eval
  (Edges SPitch)
  [Edge SPitch]
  (Notes SPitch)
  slc'
  (Spread SPitch)
  (PVLeftmost SPitch)
eval Path slc' [Edge SPitch]
input
  go :: Tensor
-> QType
-> GreedyState
     (Edges SPitch) [Edge SPitch] (Notes SPitch) (PVLeftmost SPitch)
-> NonEmpty PVAction
-> ExceptT
     String
     IO
     (QType,
      Analysis
        (Split SPitch) (Freeze SPitch) (Spread SPitch) (Edges SPitch) slc)
go !Tensor
cost !QType
reward !GreedyState
  (Edges SPitch) [Edge SPitch] (Notes SPitch) (PVLeftmost SPitch)
state !NonEmpty PVAction
actions = do
    let
      -- encodings = encodeStep state <$> actions
      -- probs = T.softmax (T.Dim 0) $ T.cat (T.Dim 0) $ TT.toDynamic . forwardPolicy actor <$> encodings
      probs :: Tensor
probs = GreedyState
  (Edges SPitch) [Edge SPitch] (Notes SPitch) (PVLeftmost SPitch)
-> NonEmpty PVAction
-> (forall (n :: Nat). KnownNat n => QEncoding dev '[n] -> Tensor)
-> Tensor
forall (dev :: (DeviceType, Nat)) r.
KnownDevice dev =>
GreedyState
  (Edges SPitch) [Edge SPitch] (Notes SPitch) (PVLeftmost SPitch)
-> NonEmpty PVAction
-> (forall (n :: Nat). KnownNat n => QEncoding dev '[n] -> r)
-> r
withBatchedEncoding GreedyState
  (Edges SPitch) [Edge SPitch] (Notes SPitch) (PVLeftmost SPitch)
state NonEmpty PVAction
actions (QModel dev -> QEncoding dev '[n] -> Tensor
forall (dev :: (DeviceType, Nat)) (batchSize :: Nat).
(IsValidDevice dev, KnownNat batchSize) =>
QModel dev -> QEncoding dev '[batchSize] -> Tensor
runBatchedPolicy QModel dev
actor)
      best :: Int
best = Tensor -> Int
forall a. TensorLike a => Tensor -> a
T.asValue (Tensor -> Int) -> Tensor -> Int
forall a b. (a -> b) -> a -> b
$ Dim -> KeepDim -> Tensor -> Tensor
T.argmax (Int -> Dim
T.Dim Int
0) KeepDim
T.KeepDim Tensor
probs
      action :: PVAction
action = NonEmpty PVAction
actions NonEmpty PVAction -> Int -> PVAction
forall a. HasCallStack => NonEmpty a -> Int -> a
NE.!! Int
best
      bestprob :: Tensor
bestprob = Tensor
probs Tensor -> Int -> Tensor
forall a. TensorIndex a => Tensor -> a -> Tensor
T.! Int
best
      cost' :: Tensor
cost' = Tensor
cost Tensor -> Tensor -> Tensor
forall a. Num a => a -> a -> a
+ Tensor -> Tensor
T.log Tensor
bestprob
    state' <- Either
  String
  (Either
     (GreedyState
        (Edges SPitch) [Edge SPitch] (Notes SPitch) (PVLeftmost SPitch))
     (Edges SPitch, [PVLeftmost SPitch]))
-> ExceptT
     String
     IO
     (Either
        (GreedyState
           (Edges SPitch) [Edge SPitch] (Notes SPitch) (PVLeftmost SPitch))
        (Edges SPitch, [PVLeftmost SPitch]))
forall (m :: * -> *) e a. Monad m => Either e a -> ExceptT e m a
ET.except (Either
   String
   (Either
      (GreedyState
         (Edges SPitch) [Edge SPitch] (Notes SPitch) (PVLeftmost SPitch))
      (Edges SPitch, [PVLeftmost SPitch]))
 -> ExceptT
      String
      IO
      (Either
         (GreedyState
            (Edges SPitch) [Edge SPitch] (Notes SPitch) (PVLeftmost SPitch))
         (Edges SPitch, [PVLeftmost SPitch])))
-> Either
     String
     (Either
        (GreedyState
           (Edges SPitch) [Edge SPitch] (Notes SPitch) (PVLeftmost SPitch))
        (Edges SPitch, [PVLeftmost SPitch]))
-> ExceptT
     String
     IO
     (Either
        (GreedyState
           (Edges SPitch) [Edge SPitch] (Notes SPitch) (PVLeftmost SPitch))
        (Edges SPitch, [PVLeftmost SPitch]))
forall a b. (a -> b) -> a -> b
$ GreedyState
  (Edges SPitch) [Edge SPitch] (Notes SPitch) (PVLeftmost SPitch)
-> PVAction
-> Either
     String
     (Either
        (GreedyState
           (Edges SPitch) [Edge SPitch] (Notes SPitch) (PVLeftmost SPitch))
        (Edges SPitch, [PVLeftmost SPitch]))
forall {k1} {k2} (m :: k1) tr tr' slc (slc' :: k2) s f h.
GreedyState tr tr' slc (Leftmost s f h)
-> Action slc tr s f h
-> Either
     String
     (Either
        (GreedyState tr tr' slc (Leftmost s f h)) (tr, [Leftmost s f h]))
applyAction GreedyState
  (Edges SPitch) [Edge SPitch] (Notes SPitch) (PVLeftmost SPitch)
state PVAction
action
    let actions' = case Either
  (GreedyState
     (Edges SPitch) [Edge SPitch] (Notes SPitch) (PVLeftmost SPitch))
  (Edges SPitch, [PVLeftmost SPitch])
state' of
          Left GreedyState
  (Edges SPitch) [Edge SPitch] (Notes SPitch) (PVLeftmost SPitch)
newState -> [PVAction] -> Maybe (NonEmpty PVAction)
forall a. [a] -> Maybe (NonEmpty a)
NE.nonEmpty ([PVAction] -> Maybe (NonEmpty PVAction))
-> [PVAction] -> Maybe (NonEmpty PVAction)
forall a b. (a -> b) -> a -> b
$ Int -> [PVAction] -> [PVAction]
forall a. Int -> [a] -> [a]
take Int
200 ([PVAction] -> [PVAction]) -> [PVAction] -> [PVAction]
forall a b. (a -> b) -> a -> b
$ Eval
  (Edges SPitch)
  [Edge SPitch]
  (Notes SPitch)
  slc'
  (Spread SPitch)
  (PVLeftmost SPitch)
-> GreedyState
     (Edges SPitch) [Edge SPitch] (Notes SPitch) (PVLeftmost SPitch)
-> [PVAction]
forall {k} (m :: k) tr tr' slc slc' s f h.
Eval tr tr' slc slc' h (Leftmost s f h)
-> GreedyState tr tr' slc (Leftmost s f h) -> [Action slc tr s f h]
getActions Eval
  (Edges SPitch)
  [Edge SPitch]
  (Notes SPitch)
  slc'
  (Spread SPitch)
  (PVLeftmost SPitch)
eval GreedyState
  (Edges SPitch) [Edge SPitch] (Notes SPitch) (PVLeftmost SPitch)
newState
          Right (Edges SPitch, [PVLeftmost SPitch])
_ -> Maybe (NonEmpty PVAction)
forall a. Maybe a
Nothing
    actionReward <- lift $ fReward state' actions' action label
    let reward' = QType
reward QType -> QType -> QType
forall a. Num a => a -> a -> a
+ QType
actionReward
    -- lift $ print probs
    case (state', actions') of
      (Left GreedyState
  (Edges SPitch) [Edge SPitch] (Notes SPitch) (PVLeftmost SPitch)
_, Maybe (NonEmpty PVAction)
Nothing) ->
        String
-> ExceptT
     String
     IO
     (QType,
      Analysis
        (Split SPitch) (Freeze SPitch) (Spread SPitch) (Edges SPitch) slc)
forall (m :: * -> *) e a. Monad m => e -> ExceptT e m a
ET.throwE String
"cannot parse: no possible actions in non-terminal state!"
      (Left GreedyState
  (Edges SPitch) [Edge SPitch] (Notes SPitch) (PVLeftmost SPitch)
s', Just NonEmpty PVAction
a') -> Tensor
-> QType
-> GreedyState
     (Edges SPitch) [Edge SPitch] (Notes SPitch) (PVLeftmost SPitch)
-> NonEmpty PVAction
-> ExceptT
     String
     IO
     (QType,
      Analysis
        (Split SPitch) (Freeze SPitch) (Spread SPitch) (Edges SPitch) slc)
go Tensor
cost' QType
reward' GreedyState
  (Edges SPitch) [Edge SPitch] (Notes SPitch) (PVLeftmost SPitch)
s' NonEmpty PVAction
a'
      (Right (Edges SPitch
top, [PVLeftmost SPitch]
deriv), Maybe (NonEmpty PVAction)
_) -> do
        IO () -> ExceptT String IO ()
forall (m :: * -> *) a. Monad m => m a -> ExceptT String m a
forall (t :: (* -> *) -> * -> *) (m :: * -> *) a.
(MonadTrans t, Monad m) =>
m a -> t m a
lift (IO () -> ExceptT String IO ()) -> IO () -> ExceptT String IO ()
forall a b. (a -> b) -> a -> b
$ String -> IO ()
putStrLn (String -> IO ()) -> String -> IO ()
forall a b. (a -> b) -> a -> b
$ String
"accuracy cost: " String -> String -> String
forall a. Semigroup a => a -> a -> a
<> Tensor -> String
forall a. Show a => a -> String
show Tensor
cost'
        Bool -> ExceptT String IO () -> ExceptT String IO ()
forall (f :: * -> *). Applicative f => Bool -> f () -> f ()
when (Tensor -> QType
forall a. TensorLike a => Tensor -> a
T.asValue Tensor
cost QType -> QType -> Bool
forall a. Eq a => a -> a -> Bool
== (QType
0 :: Double)) (ExceptT String IO () -> ExceptT String IO ())
-> ExceptT String IO () -> ExceptT String IO ()
forall a b. (a -> b) -> a -> b
$ do
          IO () -> ExceptT String IO ()
forall (m :: * -> *) a. Monad m => m a -> ExceptT String m a
forall (t :: (* -> *) -> * -> *) (m :: * -> *) a.
(MonadTrans t, Monad m) =>
m a -> t m a
lift (IO () -> ExceptT String IO ()) -> IO () -> ExceptT String IO ()
forall a b. (a -> b) -> a -> b
$ String -> IO ()
putStrLn (String -> IO ()) -> String -> IO ()
forall a b. (a -> b) -> a -> b
$ Tensor -> String
forall a. Show a => a -> String
show Tensor
bestprob
        let ana :: Analysis
  (Split SPitch) (Freeze SPitch) (Spread SPitch) (Edges SPitch) slc
ana = [PVLeftmost SPitch]
-> Path (Edges SPitch) slc
-> Analysis
     (Split SPitch) (Freeze SPitch) (Spread SPitch) (Edges SPitch) slc
forall s f h tr slc.
[Leftmost s f h] -> Path tr slc -> Analysis s f h tr slc
Analysis [PVLeftmost SPitch]
deriv (Edges SPitch -> Path (Edges SPitch) slc
forall around between. around -> Path around between
PathEnd Edges SPitch
top)
        (QType,
 Analysis
   (Split SPitch) (Freeze SPitch) (Spread SPitch) (Edges SPitch) slc)
-> ExceptT
     String
     IO
     (QType,
      Analysis
        (Split SPitch) (Freeze SPitch) (Spread SPitch) (Edges SPitch) slc)
forall a. a -> ExceptT String IO a
forall (f :: * -> *) a. Applicative f => a -> f a
pure (QType
reward', Analysis
  (Split SPitch) (Freeze SPitch) (Spread SPitch) (Edges SPitch) slc
forall {slc}.
Analysis
  (Split SPitch) (Freeze SPitch) (Spread SPitch) (Edges SPitch) slc
ana)

-- | Orphan 'NoThunks' instance for strict lists ('SL.List'), obtained by
-- standalone deriving, so that values containing such lists can be checked
-- with 'noThunks'.
--
-- NOTE(review): no deriving strategy is given; with @DeriveAnyClass@ enabled
-- this presumably uses the class's Generic-based default methods — confirm
-- that 'SL.List' provides the required 'Generic' instance.
deriving instance (NoThunks a) => NoThunks (SL.List a)

-- | State threaded through the epoch loop of 'trainA2C'.
--
-- Every history field holds one inner list per training piece. New values
-- are prepended with 'SL.Cons', so each inner list is stored newest-first;
-- 'trainA2C' converts them with 'SL.toListReversed' before returning or
-- plotting them.
data A2CLoopState dev = A2CLoopState
  { a2clState :: A2CState dev
  -- ^ current actor, critic and their optimizers
  , a2clRewards :: SL.List (SL.List QType)
  -- ^ per-piece history of episode rewards
  , a2clLosses :: SL.List (SL.List QType)
  -- ^ per-piece history of episode losses
  , a2clAccs :: SL.List (SL.List QType)
  -- ^ per-piece history of the periodically evaluated greedy reward ("accuracy")
  }
  deriving (Generic)

-- | A2C training loop: runs @n@ iterations of 'trainEpoch' (one episode per
-- piece each) starting from the given actor and critic, both driven by a
-- plain 'TT.GD' optimizer.
--
-- Returns, in this order:
--
-- * the reward history of every piece (inner lists oldest-first),
-- * the mean of every piece's loss history,
-- * the trained actor,
-- * the trained critic.
trainA2C
  :: forall dev label
   . (IsValidDevice dev)
  => Eval (Edges SPitch) [Edge SPitch] (Notes SPitch) [Note SPitch] (Spread SPitch) (PVLeftmost SPitch)
  -- ^ grammar evaluator, forwarded to 'runEpisode' and 'runAccuracy'
  -> Rand.IOGenM Rand.StdGen
  -- ^ random generator, forwarded to 'runEpisode'
  -> PVRewardFn label
  -- ^ reward function
  -> (QType -> QType)
  -- ^ learning rate schedule
  -> (QType -> QType)
  -- ^ temperature schedule
  -> Maybe [QType]
  -- ^ optional target values; when present they are handed to
  -- 'plotHistories'' for the reward and accuracy plots
  -> QModel dev
  -- ^ initial actor
  -> QModel dev
  -- ^ initial critic
  -> [(Path [Note SPitch] [Edge SPitch], label)]
  -- ^ training pieces, each paired with its label
  -> Int
  -- ^ iteration count given to 'T.foldLoop' (number of epochs)
  -> IO ([[QType]], [QType], QModel dev, QModel dev)
trainA2C eval gen fReward fLr fTemp targets actor0 critic0 pieces n = do
  -- print $ qModelFinal2 model0
  -- opta <- TT.initOptimizer (TT.AdamOptions 0.0001 (0.9, 0.999) 1e-8 0 False) actor0
  let
    opta = TT.GD -- TT.mkAdam 0 0.9 0.99 (TT.flattenParameters actor0)
    optc = TT.GD -- TT.mkAdam 0 0.9 0.99 (TT.flattenParameters critic0)
    -- one empty history per piece
    emptyStat = SL.fromListReversed (replicate (length pieces) SL.Nil)
    state0 = A2CState actor0 critic0 opta optc
  -- the accuracy histories are only used for plotting inside 'trainEpoch',
  -- so they are dropped here
  (A2CLoopState (A2CState actorTrained criticTrained _ _) rewards losses _) <-
    T.foldLoop (A2CLoopState state0 emptyStat emptyStat emptyStat) n trainEpoch
  pure
    ( SL.toListReversed $ SL.toListReversed <$> rewards
    , SL.toListReversed $ mean <$> losses
    , actorTrained
    , criticTrained
    )
 where
  -- \| Train a single episode on a single piece.
  --
  -- Step function for the per-epoch 'foldM': takes the progress bar, the
  -- epoch index, the accumulator (current 'A2CState' plus the rewards and
  -- losses collected so far in this epoch) and a piece with its label and
  -- position. Runs 'runEpisode', advances the progress bar by one, and on
  -- success returns the updated state with the episode's reward and loss
  -- prepended to the accumulators.
  --
  -- A failed episode is reported on stdout and leaves the accumulator
  -- untouched.
  --
  -- NOTE(review): because nothing is appended for a failed episode, the
  -- epoch's reward/loss lists become shorter than the per-piece history
  -- lists they are later zipped with; confirm that 'zipWithStrict' handles
  -- this as intended, or record a placeholder value here instead.
  trainPiece pb i (!state, !rewards, !losses) ((!piece, label), !_pieceIdx) = do
    !result <- runEpisode eval gen fReward fLr fTemp piece label state i
    PB.incProgress pb 1
    case result of
      Left err -> do
        putStrLn $ "Episode error: " <> err
        pure (state, rewards, losses)
      Right (state', r, loss) ->
        -- putStrLn $ "loss " <> show _pieceIdx <> ": " <> show loss
        pure (state', r `SL.Cons` rewards, loss `SL.Cons` losses)
  -- \| Train one episode on each piece, then update histories, plots and
  -- checkpoints.
  --
  -- Takes the loop state and the epoch index @i@ and returns the loop state
  -- for the next epoch. Per epoch it
  --
  -- 1. folds 'trainPiece' over all pieces (with a progress bar),
  -- 2. prepends the epoch's rewards and losses to the per-piece histories,
  -- 3. every 10th epoch evaluates 'runAccuracy' on every piece (a failed
  --    evaluation is printed and recorded as @-inf@); every 100th epoch the
  --    resulting derivations are printed as well,
  -- 4. plots all histories and writes @actor_checkpoint.ht@ and
  --    @critic_checkpoint.ht@.
  --
  -- The accuracy evaluation and the checkpoints use the parameters produced
  -- by this epoch (@state'@), so neither lags one epoch behind the state
  -- that is returned.
  trainEpoch fullstate@(A2CLoopState !state !rewardsHist !lossHist !accuracies) !i = do
    -- putStrLn $ "\nepoch " <> show i
    pb <-
      PB.newProgressBar
        ( PB.defStyle
            { PB.stylePrefix = "Epoch " <> (PB.msg $ Txt.show i) <> ": " <> (PB.elapsedTime PB.renderDuration)
            , PB.stylePostfix = PB.exact <> " (" <> PB.percentage <> ")"
            , PB.styleWidth = PB.ConstantWidth 80
            }
        )
        10
        (PB.Progress 0 (length pieces) ())
    -- performGC
    -- thunkCheck <- noThunks ["trainA2C", "trainEpoch"] fullstate
    -- case thunkCheck of
    --   Nothing -> pure ()
    --   Just thunkInfo -> error $ "Unexpected thunk at " <> show (thunkContext thunkInfo)
    -- run epoch
    (!state', !rewards, !losses) <-
      foldM (trainPiece pb i) (state, SL.Nil, SL.Nil) (zip pieces [1 ..])
    -- prepend this epoch's values to the matching per-piece histories
    let rewardsHist' = zipWithStrict SL.Cons rewards rewardsHist
        lossHist' = zipWithStrict SL.Cons losses lossHist
    -- compute greedy reward ("accuracy")
    accuracies' <-
      if (i `mod` 10) == 0
        then do
          results <- mapM (runAccuracy eval fReward (a2cActor state')) pieces
          newAccs <- forM (zip results [1 ..]) $ \(result, j) ->
            case result of
              Left err -> do
                putStrLn err
                pure (-inf)
              Right (acc, Analysis deriv top) -> do
                when ((i `mod` 100) == 0) $ do
                  putStrLn $ "current best analysis (piece " <> show j <> "):"
                  mapM_ print deriv
                -- plotDeriv ("rl/deriv" <> show j <> ".tex") deriv
                pure acc
          -- newAccs is in piece order; the histories are stored reversed
          pure $ zipWithStrict SL.Cons (SL.fromListReversed newAccs) accuracies
        else pure accuracies
    -- logging (the interval is 1, i.e. currently every epoch)
    when ((i `mod` 1) == 0) $ do
      -- putStrLn $ "epoch " <> show i
      -- mapM_ print $ take 10 bcontent
      let rews = SL.toListReversed <$> SL.toListReversed rewardsHist'
          accs = SL.toListReversed <$> SL.toListReversed accuracies'
          lossHists = SL.toListReversed <$> SL.toListReversed lossHist'
          -- transposing turns per-piece histories into per-epoch rows
          avgReward = mean <$> L.transpose rews
          avgAbsLoss = (mean . fmap abs) <$> L.transpose lossHists
      plotHistories "losses" lossHists
      case targets of
        Nothing -> do
          plotHistories "rewards" rews
          plotHistories "accuracy" accs
        Just ts -> do
          plotHistories' "rewards" ts rews
          plotHistories' "accuracy" ts accs
      plotHistory "mean_reward" avgReward
      plotHistory "mean_loss" avgAbsLoss
      -- print $ qModelFinal2 (a2cModel state)
      TT.save (TT.hmap' TT.ToDependent $ TT.flattenParameters $ a2cActor state') "actor_checkpoint.ht"
      TT.save (TT.hmap' TT.ToDependent $ TT.flattenParameters $ a2cCritic state') "critic_checkpoint.ht"
    pure $ A2CLoopState state' rewardsHist' lossHist' accuracies'