{-# LANGUAGE AllowAmbiguousTypes #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE DerivingStrategies #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE PartialTypeSignatures #-}
{-# OPTIONS_GHC -Wno-partial-type-signatures #-}
{-# HLINT ignore "Use <$>" #-}
{-# OPTIONS_GHC -Wno-unrecognised-pragmas #-}

module RL.DQN where

import Common
import Display (replayDerivation, viewGraph)
import GreedyParser (Action, ActionDouble (ActionDouble), ActionSingle (ActionSingle), GreedyState, getActions, initParseState, parseGreedy, parseStep, pickRandom)
import PVGrammar (Edge, Edges (Edges), Freeze (FreezeOp), Notes (Notes), PVAnalysis, PVLeftmost, Split, Spread)
import PVGrammar.Generate (derivationPlayerPV)
import PVGrammar.Parse (protoVoiceEvaluator)
import PVGrammar.Prob.Simple (PVParams, evalDoubleStep, evalSingleStep, observeDerivation, observeDerivation', observeDoubleStepParsing, observeSingleStepParsing, sampleDerivation', sampleDoubleStepParsing, sampleSingleStepParsing)
import RL.Callbacks
import RL.Encoding
import RL.Model
import RL.ModelTypes
import RL.Plotting
import RL.ReplayBuffer
import RL.TorchHelpers qualified as TH

-- import Control.DeepSeq (force)
import Control.Exception (Exception, catch, onException)
import Control.Monad (foldM, foldM_, forM_, replicateM, when)
import Control.Monad.Except qualified as ET
import Control.Monad.Primitive (RealWorld)
import Control.Monad.State qualified as ST
import Control.Monad.Trans (lift)
import Data.Foldable qualified as F
import Data.List.Extra qualified as E
import Data.Vector qualified as V
import Debug.Trace qualified as DT
import GHC.Float (double2Float)
import Inference.Conjugate (Hyper, HyperRep, Prior (expectedProbs), evalTraceLogP, printTrace, sampleProbs)
import Musicology.Pitch
import System.Random.MWC.Distributions (categorical)
import System.Random.MWC.Probability qualified as MWC
import System.Random.Stateful as Rand (StatefulGen, UniformRange (uniformRM), split)
import Torch qualified as T
import Torch.HList qualified as TT
import Torch.Lens qualified
import Torch.Typed qualified as TT

-- Notes
-- -----

{-
Idea: a variant of Q-learning:
- instead of the Q value (expected total reward under the optimal policy),
  learn a "P value": the expected probability under a random policy
- does this lead to a policy where p(as) (the probability of the action sequence) ∝ reward?
  - then you learn a method of sampling from the reward distribution
  - if the reward is a probability (e.g. p(deriv)), you learn to sample from that!
    - useful for unsupervised inference
- changes:
  - use a proportional random policy (is this MC tree search?)
  - the loss uses the expectation E[·] over next actions instead of the max.
-}

-- global settings
-- ---------------

-- | Discount factor applied to future rewards.
gamma :: (TT.KnownDevice dev) => QTensor dev '[]
-- previously: gamma = toOpts $ T.asTensor @Double 0.99
gamma = 0.99

-- | Interpolation factor between the target network and the policy network.
--
-- (Earlier it was a tensor: @toOpts $ T.asTensor 0.05@.)
tau :: QType
tau = 0.1

-- | Learning rate as a function of training progress.
--
-- Decays exponentially as @initial * factor ** progress@ (computed as
-- @initial * exp (progress * log factor)@). That gives 0.1 at @progress = 0@
-- and 0.01 at @progress = 1@.
-- NOTE(review): progress is presumably in [0, 1]; confirm at the call site.
--
-- Earlier schedules that were tried:
--
-- > learningRate _ = 0.1
-- > learningRate progress = 0.01 + TT.mulScalar progress (-0.009)
learningRate :: (IsValidDevice dev) => Double -> TT.LearningRate dev QDType
learningRate progress = initial * TT.exp (TT.mulScalar progress (TT.log factor))
 where
  -- rate used at progress 0
  initial = 0.1
  -- overall multiplicative decay reached at progress 1
  factor = 0.1

-- replay buffer
-- | Capacity setting for the replay buffer.
-- NOTE(review): presumably measured in stored steps; confirm against "RL.ReplayBuffer".
bufferSize :: Int
bufferSize = 1000

-- | Replay count setting.
-- NOTE(review): presumably the number of samples drawn from the replay
-- buffer per optimisation step; confirm against its users.
replayN :: Int
replayN = 200

-- exploration factors
-- | Exploration rate at the start of training (first argument to
-- 'expSchedule' in 'eps').
epsStart :: QType
epsStart = 0.9

-- | Exploration rate at the end of training (second argument to
-- 'expSchedule' in 'eps').
epsEnd :: QType
epsEnd = 0.2

-- epsDecay :: QType
-- epsDecay = 2

-- | Exploration rate for iteration @i@ of @n@.
--
-- Moves from 'epsStart' towards 'epsEnd' according to 'expSchedule'
-- (presumably an exponential schedule, judging by its name).
eps :: Int -> Int -> QType
eps i n = expSchedule epsStart epsEnd total current
 where
  -- total number of iterations
  total = fromIntegral n
  -- current iteration
  current = fromIntegral i

-- device = T.Device T.CPU 0

-- Deep Q-Learning
-- ---------------

-- | Everything 'trainLoop' threads from one episode to the next.
--
-- The last type parameter @r@ is not used by any field ('trainLoop'
-- instantiates it with 'QType').
data DQNState dev opt tr tr' slc s f h r = DQNState
  { pnet :: !(QModel dev)
  -- ^ policy network; its forward pass drives action selection in 'trainLoop'
  , tnet :: !(QModel dev)
  -- ^ target network
  , opt :: !opt
  -- ^ optimizer state
  , buffer :: !(ReplayBuffer dev tr tr' slc s f h)
  -- ^ replay buffer that recorded episode steps are pushed into
  }

-- epsilonGreedyPolicy
--   :: (StatefulGen gen m)
--   => gen
--   -> QType
--   -> (embedding -> QTensor '[1])
--   -> [embedding]
--   -> m Int
-- epsilonGreedyPolicy gen epsilon q actions = do
--   coin <- uniformRM (0, 1) gen
--   if coin >= epsilon
--     then pure $ T.asValue $ T.argmax (T.Dim 0) T.RemoveDim $ T.cat (T.Dim 0) (TT.toDynamic . q <$> actions)
--     else do
--       uniformRM (0, length actions - 1) gen

-- | Deterministic policy: return the index of the action with the highest
-- Q value.
--
-- Each action embedding is scored with @q@. The @[1]@-shaped scores are
-- concatenated along dimension 0, and 'T.argmax' picks the winner.
-- NOTE(review): behaviour on an empty action list depends on 'T.cat' of
-- @[]@; 'runEpisode' rejects empty action lists before consulting its policy.
greedyPolicy
  :: (Applicative m)
  => (embedding -> QTensor dev '[1])
  -> [embedding]
  -> m Int
greedyPolicy q actions = pure (T.asValue best)
 where
  -- one score per action, in input order
  scores = T.cat (T.Dim 0) (map (TT.toDynamic . q) actions)
  -- position of the largest score, as a scalar tensor
  best = T.argmax (T.Dim 0) T.RemoveDim scores

-- | Add epsilon-style random exploration to a policy.
--
-- A coin is drawn uniformly from [0, 1]. If it is at least @epsilon@, the
-- wrapped @policy@ chooses. Otherwise an action index is drawn uniformly
-- from @[0, length actions - 1]@. The wrapped policy is only run when it
-- is actually chosen.
epsilonic
  :: (StatefulGen gen m)
  => gen
  -> QType
  -> ([embedding] -> m Int)
  -> [embedding]
  -> m Int
epsilonic gen epsilon policy actions =
  uniformRM (0, 1) gen >>= \coin ->
    if coin >= epsilon then exploit else explore
 where
  -- defer to the wrapped policy
  exploit = policy actions
  -- pick any available action uniformly at random
  explore = uniformRM (0, length actions - 1) gen

-- | Stochastic policy: sample an action index with probability proportional
-- to the softmax of the actions' Q values.
--
-- The softmax is computed on the tensor side, converted to 'T.Double', and
-- handed to 'categorical' as a vector of weights.
softmaxPolicy
  :: (StatefulGen gen m)
  => gen
  -> (embedding -> QTensor dev '[1])
  -> [embedding]
  -> m Int
softmaxPolicy gen q actions = categorical weights gen
 where
  -- raw Q values, one per action, concatenated along dimension 0
  logits = T.cat (T.Dim 0) [TT.toDynamic (q a) | a <- actions]
  -- normalised probabilities as a Haskell vector
  weights = V.fromList (T.asValue (T.toDType T.Double (T.softmax (T.Dim 0) logits)))

-- | Run one parsing episode and record every step taken.
--
-- Parsing starts from @initParseState eval input@ and advances with
-- 'parseStep' until it finishes. Whenever the parser asks for a decision,
-- the available actions are encoded with @encode@ and @policyF@ returns the
-- index of the chosen one.
--
-- The result is either an error message or the recorded steps together with
-- the final 'Analysis'. Possible errors:
--
-- * no actions were available;
-- * @policyF@ returned an index outside the action list;
-- * any error raised by 'parseStep'.
--
-- Each step is @(state, chosen action, its encoding, next, goLeft)@, where:
--
-- * @next@ is @Just (successor state, encodings of the actions offered there)@,
--   or 'Nothing' for the final step. The encoding list is empty if the policy
--   was not consulted in the successor.
-- * @goLeft@ is derived from the action taken /before/ this step:
--   @Just True@ for a double-parent left freeze or left split,
--   @Just False@ for any other double-parent action, and 'Nothing' otherwise.
--
-- Steps are returned newest-first, so the final step is the head of the list.
runEpisode
  :: forall dev tr tr' slc slc' s f h gen state action encoding step
   . ( state ~ GreedyState tr tr' slc (Leftmost s f h)
     , action ~ Action slc tr s f h
     , encoding ~ QEncoding dev '[]
     , step ~ (state, action, encoding, Maybe (state, [encoding]), Maybe Bool)
     )
  => Eval tr tr' slc slc' h (Leftmost s f h)
  -> (state -> action -> encoding)
  -> ([encoding] -> IO Int)
  -> Path slc' tr'
  -> IO
      ( Either
          String
          ([step], Analysis s f h tr slc)
      )
runEpisode !eval !encode !policyF !input =
  ST.evalStateT (ET.runExceptT $ go [] Nothing $ initParseState eval input) (Nothing, [])
 where
  -- Parse from @state@. Finished steps accumulate newest-first in @transitions@.
  -- @prev@ holds the previous state, the action chosen there (if any) and its
  -- goLeft flag. It only becomes a step once the encodings of the actions
  -- offered in its successor are known.
  go
    :: [step]
    -> Maybe (state, Maybe (action, encoding), Maybe Bool)
    -> state
    -> ET.ExceptT String (ST.StateT (Maybe (action, encoding), [encoding]) IO) ([step], Analysis s f h tr slc)
  go transitions prev state = do
    -- The policy reports its choice through the State layer, so clear it first.
    ST.put (Nothing, []) -- TODO: have parseStep return the action instead of using State
    result <- parseStep eval policy state
    (actionAndEncoding, actions) <- ST.get
    -- Complete the previous step now that its successor is known.
    let transitions' = case prev of
          Nothing -> transitions
          Just (prevState, prevAction, goLeft) ->
            addStep prevState prevAction (Just (state, actions)) goLeft transitions
    -- goLeft flag of the current step, derived from the previous action.
    let goLeft' = case prev of
          Just (_, Just (Right (ActionDouble _ op), _), _) -> case op of
            LMDoubleFreezeLeft _ -> Just True
            LMDoubleSplitLeft _ -> Just True
            _ -> Just False
          _ -> Nothing
    case result of
      -- Parsing finished: the current step is final and has no successor.
      Right (top, deriv) ->
        pure
          ( addStep state actionAndEncoding Nothing goLeft' transitions'
          , Analysis deriv $ PathEnd top
          )
      -- Parsing continues: the current step stays pending until the next round.
      Left state' -> go transitions' (Just (state, actionAndEncoding, goLeft')) state'
   where
    -- Prepend a step, but only if an action was actually chosen in @st@.
    addStep
      :: state
      -> Maybe (action, encoding)
      -> Maybe (state, [encoding])
      -> Maybe Bool
      -> [(state, action, encoding, Maybe (state, [encoding]), Maybe Bool)]
      -> [(state, action, encoding, Maybe (state, [encoding]), Maybe Bool)]
    addStep _ Nothing _next _goLeft ts = ts
    addStep st (Just (act, actEnc)) next goLeft ts = (st, act, actEnc, next, goLeft) : ts

    -- Let @policyF@ choose among the actions available in @state@. The choice
    -- and all action encodings are recorded in the State layer. An
    -- out-of-range index is reported as an episode error rather than crashing.
    policy :: [action] -> ET.ExceptT String (ST.StateT (Maybe (action, encoding), [encoding]) IO) action
    policy [] = ET.throwError "no actions to select from"
    policy actions = do
      let encodings = encode state <$> actions
      actionIndex <- lift $ lift $ policyF encodings
      case nth actionIndex (zip actions encodings) of
        Nothing ->
          ET.throwError $
            "policy returned action index "
              <> show actionIndex
              <> ", but only "
              <> show (length actions)
              <> " actions are available"
        Just (action, actEnc) -> do
          ST.put (Just (action, actEnc), encodings)
          pure action

  -- Total list indexing: 'Nothing' if the index is negative or too large.
  nth :: Int -> [a] -> Maybe a
  nth i xs
    | i < 0 = Nothing
    | otherwise = case drop i xs of
        (x : _) -> Just x
        [] -> Nothing

trainLoop
  :: forall dev tr tr' slc slc' s f h gen opt -- params (grads :: [Type])
   . -- . ( StatefulGen gen IO
  --   , params ~ TT.Parameters QModel
  --   , TT.HasGrad (TT.HList params) (TT.HList grads)
  --   , TT.Optimizer opt grads grads T.Double QDevice
  --   , TT.HMap' TT.ToDependent params grads
  --   , TT.HFoldrM IO TT.TensorListFold [T.ATenTensor] grads [T.ATenTensor]
  --   , TT.Apply TT.TensorListUnfold [T.ATenTensor] (TT.HUnfoldMRes IO [T.ATenTensor] grads)
  --   , TT.HUnfoldM IO TT.TensorListUnfold (TT.HUnfoldMRes IO [T.ATenTensor] grads) grads
  --   -- , Show opt
  --   )
  (_)
  => gen
  -> Eval tr tr' slc slc' h (Leftmost s f h)
  -> (GreedyState tr tr' slc (Leftmost s f h) -> Action slc tr s f h -> QEncoding dev '[])
  -> (Analysis s f h tr slc -> IO QType)
  -> (Action slc tr s f h -> Maybe Bool -> IO QType)
  -> Path slc' tr'
  -> DQNState dev opt tr tr' slc s f h QType
  -> Int
  -> Int
  -> IO (DQNState dev opt tr tr' slc s f h QType, QType, QType)
trainLoop :: gen
-> Eval tr tr' slc slc' h (Leftmost s f h)
-> (GreedyState tr tr' slc (Leftmost s f h)
    -> Action slc tr s f h -> QEncoding dev '[])
-> (Analysis s f h tr slc -> IO QType)
-> (Action slc tr s f h -> Maybe Bool -> IO QType)
-> Path slc' tr'
-> DQNState dev opt tr tr' slc s f h QType
-> Int
-> Int
-> IO (DQNState dev opt tr tr' slc s f h QType, QType, QType)
trainLoop !gen
gen !Eval tr tr' slc slc' h (Leftmost s f h)
eval !GreedyState tr tr' slc (Leftmost s f h)
-> Action slc tr s f h -> QEncoding dev '[]
encode !Analysis s f h tr slc -> IO QType
reward !Action slc tr s f h -> Maybe Bool -> IO QType
rewardStep !Path slc' tr'
piece oldstate :: DQNState dev opt tr tr' slc s f h QType
oldstate@(DQNState !QModel dev
pnet !QModel dev
tnet !opt
opt !ReplayBuffer dev tr tr' slc s f h
buffer) Int
i Int
n = do
  -- 1. run episode, collect results
  -- let policy q = epsilonic gen (eps i n) (greedyPolicy q)
  -- let policy = softmaxPolicy gen
  let policy :: (QEncoding dev '[] -> QTensor dev '[1])
-> [QEncoding dev '[]] -> IO Int
policy QEncoding dev '[] -> QTensor dev '[1]
q = gen
-> QType
-> ([QEncoding dev '[]] -> IO Int)
-> [QEncoding dev '[]]
-> IO Int
forall gen (m :: * -> *) embedding.
StatefulGen gen m =>
gen -> QType -> ([embedding] -> m Int) -> [embedding] -> m Int
epsilonic gen
gen (Int -> Int -> QType
eps Int
i Int
n) (([QEncoding dev '[]] -> IO Int) -> [QEncoding dev '[]] -> IO Int)
-> ([QEncoding dev '[]] -> IO Int) -> [QEncoding dev '[]] -> IO Int
forall a b. (a -> b) -> a -> b
$ gen
-> (QEncoding dev '[] -> QTensor dev '[1])
-> [QEncoding dev '[]]
-> IO Int
forall gen (m :: * -> *) embedding (dev :: (DeviceType, Nat)).
StatefulGen gen m =>
gen -> (embedding -> QTensor dev '[1]) -> [embedding] -> m Int
softmaxPolicy gen
gen QEncoding dev '[] -> QTensor dev '[1]
q
  result <- Eval tr tr' slc slc' h (Leftmost s f h)
-> (GreedyState tr tr' slc (Leftmost s f h)
    -> Action slc tr s f h -> QEncoding dev '[])
-> ([QEncoding dev '[]] -> IO Int)
-> Path slc' tr'
-> IO
     (Either
        String
        ([(GreedyState tr tr' slc (Leftmost s f h), Action slc tr s f h,
           QEncoding dev '[],
           Maybe
             (GreedyState tr tr' slc (Leftmost s f h), [QEncoding dev '[]]),
           Maybe Bool)],
         Analysis s f h tr slc))
forall {k} (dev :: (DeviceType, Nat)) tr tr' slc slc' s f h
       (gen :: k) state action encoding step.
(state ~ GreedyState tr tr' slc (Leftmost s f h),
 action ~ Action slc tr s f h, encoding ~ QEncoding dev '[],
 step
 ~ (state, action, encoding, Maybe (state, [encoding]),
    Maybe Bool)) =>
Eval tr tr' slc slc' h (Leftmost s f h)
-> (state -> action -> encoding)
-> ([encoding] -> IO Int)
-> Path slc' tr'
-> IO (Either String ([step], Analysis s f h tr slc))
runEpisode Eval tr tr' slc slc' h (Leftmost s f h)
eval GreedyState tr tr' slc (Leftmost s f h)
-> Action slc tr s f h -> QEncoding dev '[]
encode ((QEncoding dev '[] -> QTensor dev '[1])
-> [QEncoding dev '[]] -> IO Int
policy ((QEncoding dev '[] -> QTensor dev '[1])
 -> [QEncoding dev '[]] -> IO Int)
-> (QEncoding dev '[] -> QTensor dev '[1])
-> [QEncoding dev '[]]
-> IO Int
forall a b. (a -> b) -> a -> b
$ QModel dev -> QEncoding dev '[] -> QTensor dev '[1]
forall f a b. HasForward f a b => f -> a -> b
T.forward QModel dev
pnet) Path slc' tr'
piece
  case result of
    -- error? skip
    Left String
error -> do
      String -> IO ()
forall a. Show a => a -> IO ()
print String
error
      (DQNState dev opt tr tr' slc s f h QType, QType, QType)
-> IO (DQNState dev opt tr tr' slc s f h QType, QType, QType)
forall a. a -> IO a
forall (f :: * -> *) a. Applicative f => a -> f a
pure (DQNState dev opt tr tr' slc s f h QType
oldstate, QType
0, QType
0)
    Right ([(GreedyState tr tr' slc (Leftmost s f h), Action slc tr s f h,
  QEncoding dev '[],
  Maybe
    (GreedyState tr tr' slc (Leftmost s f h), [QEncoding dev '[]]),
  Maybe Bool)]
steps, Analysis s f h tr slc
analysis) -> do
      -- 2. compute reward and add steps to replay buffer
      (steps', r) <- [(GreedyState tr tr' slc (Leftmost s f h), Action slc tr s f h,
  QEncoding dev '[],
  Maybe
    (GreedyState tr tr' slc (Leftmost s f h), [QEncoding dev '[]]),
  Maybe Bool)]
-> Analysis s f h tr slc
-> IO ([ReplayStep dev tr tr' slc s f h], QType)
rewardEpisode [(GreedyState tr tr' slc (Leftmost s f h), Action slc tr s f h,
  QEncoding dev '[],
  Maybe
    (GreedyState tr tr' slc (Leftmost s f h), [QEncoding dev '[]]),
  Maybe Bool)]
steps Analysis s f h tr slc
analysis
      -- rall <- reward analysis
      -- putStrLn $ "total episode reward: " <> show r
      -- putStrLn $ "hypothetical reward: " <> show rall
      -- mapM_ print (anaDerivation analysis)
      -- mapM_ print steps'
      let buffer' = (ReplayBuffer dev tr tr' slc s f h
 -> ReplayStep dev tr tr' slc s f h
 -> ReplayBuffer dev tr tr' slc s f h)
-> ReplayBuffer dev tr tr' slc s f h
-> [ReplayStep dev tr tr' slc s f h]
-> ReplayBuffer dev tr tr' slc s f h
forall b a. (b -> a -> b) -> b -> [a] -> b
forall (t :: * -> *) b a.
Foldable t =>
(b -> a -> b) -> b -> t a -> b
F.foldl' ReplayBuffer dev tr tr' slc s f h
-> ReplayStep dev tr tr' slc s f h
-> ReplayBuffer dev tr tr' slc s f h
forall (dev :: (DeviceType, Nat)) tr tr' slc s f h.
ReplayBuffer dev tr tr' slc s f h
-> ReplayStep dev tr tr' slc s f h
-> ReplayBuffer dev tr tr' slc s f h
pushStep ReplayBuffer dev tr tr' slc s f h
buffer [ReplayStep dev tr tr' slc s f h]
steps'
      -- 3. optimize models
      (pnet', tnet', opt', loss) <- optimizeModels buffer'
      pure (DQNState pnet' tnet' opt' buffer', r, loss)
 where
  mkReplay :: QType
-> (GreedyState tr tr' slc (Leftmost s f h), Action slc tr s f h,
    QEncoding dev '[],
    Maybe
      (GreedyState tr tr' slc (Leftmost s f h), [QEncoding dev '[]]),
    e)
-> ReplayStep dev tr tr' slc s f h
mkReplay QType
r (!GreedyState tr tr' slc (Leftmost s f h)
state, !Action slc tr s f h
action, !QEncoding dev '[]
actEnc, !Maybe
  (GreedyState tr tr' slc (Leftmost s f h), [QEncoding dev '[]])
next, e
goLeft) =
    RPState tr tr' slc s f h
-> RPAction slc tr s f h
-> QEncoding dev '[]
-> Maybe (RPState tr tr' slc s f h)
-> [QEncoding dev '[]]
-> QType
-> ReplayStep dev tr tr' slc s f h
forall (dev :: (DeviceType, Nat)) tr tr' slc s f h.
RPState tr tr' slc s f h
-> RPAction slc tr s f h
-> QEncoding dev '[]
-> Maybe (RPState tr tr' slc s f h)
-> [QEncoding dev '[]]
-> QType
-> ReplayStep dev tr tr' slc s f h
ReplayStep (GreedyState tr tr' slc (Leftmost s f h) -> RPState tr tr' slc s f h
forall tr tr' slc s f h.
GreedyState tr tr' slc (Leftmost s f h) -> RPState tr tr' slc s f h
RPState GreedyState tr tr' slc (Leftmost s f h)
state) (Action slc tr s f h -> RPAction slc tr s f h
forall slc tr s f h. Action slc tr s f h -> RPAction slc tr s f h
RPAction Action slc tr s f h
action) QEncoding dev '[]
actEnc Maybe (RPState tr tr' slc s f h)
state' [QEncoding dev '[]]
steps' QType
r
   where
    (!Maybe (RPState tr tr' slc s f h)
state', ![QEncoding dev '[]]
steps') = case Maybe
  (GreedyState tr tr' slc (Leftmost s f h), [QEncoding dev '[]])
next of
      Maybe
  (GreedyState tr tr' slc (Leftmost s f h), [QEncoding dev '[]])
Nothing -> (Maybe (RPState tr tr' slc s f h)
forall a. Maybe a
Nothing, [])
      Just (GreedyState tr tr' slc (Leftmost s f h)
s, [QEncoding dev '[]]
acts) -> (RPState tr tr' slc s f h -> Maybe (RPState tr tr' slc s f h)
forall a. a -> Maybe a
Just (RPState tr tr' slc s f h -> Maybe (RPState tr tr' slc s f h))
-> RPState tr tr' slc s f h -> Maybe (RPState tr tr' slc s f h)
forall a b. (a -> b) -> a -> b
$ GreedyState tr tr' slc (Leftmost s f h) -> RPState tr tr' slc s f h
forall tr tr' slc s f h.
GreedyState tr tr' slc (Leftmost s f h) -> RPState tr tr' slc s f h
RPState GreedyState tr tr' slc (Leftmost s f h)
s, [QEncoding dev '[]]
acts)
  -- Convert the recorded steps of one finished episode into replay steps.
  --
  -- The analysis-level reward (obtained from 'reward') is attached to the
  -- head of @steps@ only; every other step is stored with an immediate
  -- reward of 0.
  -- NOTE(review): treating the head as the terminal step implies @steps@ is
  -- accumulated most-recent-first — confirm with the caller.
  --
  -- Returns the replay steps (same order as @steps@) and the episode reward.
  rewardEpisode steps analysis = do
    episodeReward <- reward analysis
    let toReplay (final : earlier) = mkReplay episodeReward final : map (mkReplay 0) earlier
        toReplay [] = []
    pure (toReplay steps, episodeReward)
  -- rewardSteps steps _analysis = do
  --   steps' <- mapM mkStep steps
  --   let r = sum $ replayReward <$> steps'
  --   pure (steps', r)
  --  where
  --   mkStep step@(_, action, _, _, goLeft) = do
  --     rstep <- rewardStep action goLeft
  --     pure $ mkReplay rstep step

  -- A single optimization step for deep Q-learning (DQN).
  --
  -- 1. Sample 'replayN' steps from @buffer'@ and compute, via 'dqnValues',
  --    the current Q-values and their targets. The stacked targets are
  --    detached so that gradients only flow through the current Q-values.
  -- 2. Take one optimizer step on 'pnet'. The loss is the mean smooth-L1
  --    loss plus @fakeLoss pnet@, and the learning rate comes from
  --    'learningRate' applied to the progress ratio @i / n@.
  -- 3. Move 'tnet' towards the updated policy net by combining their
  --    detached parameters with @TH.Interpolate tau@.
  --    NOTE(review): which side 'tau' weights is defined in RL.TorchHelpers
  --    — confirm there.
  --
  -- Prints the loss (including 'fakeLoss') and the mean current Q-value.
  -- Returns @(pnet', tnet', opt', loss)@, where @loss@ excludes 'fakeLoss'.
  -- NOTE(review): 'fakeLoss' presumably keeps every parameter in the
  -- autograd graph; confirm in RL.Model.
  optimizeModels buffer' = do
    batch <- sampleSteps buffer' replayN
    let (currentQs, targetQs) = unzip (map dqnValues batch)
        currentStacked = T.stack (T.Dim 0) currentQs
    -- targets are constants for the regression
    targetStacked <- T.detach (T.stack (T.Dim 0) targetQs)
    let !batchLoss = T.smoothL1Loss T.ReduceMean currentStacked targetStacked
        !stepLoss = TT.UnsafeMkTensor (batchLoss + TT.toDynamic (fakeLoss pnet))
    putStr ("loss: " <> show (T.asValue @QType (TT.toDynamic stepLoss)))
    putStrLn ("\tavgq: " <> show (T.asValue @QType (T.mean currentStacked)))
    let progress = fromIntegral i / fromIntegral n :: QType
        rate = learningRate progress
    -- update the policy net
    (pnet', opt') <- TT.runStep pnet opt stepLoss rate
    -- update the target net from the freshly optimized policy net
    targetParams <- TT.hmapM' TH.Detach (TT.hmap' TT.ToDependent (TT.flattenParameters tnet))
    policyParams <- TT.hmapM' TH.Detach (TT.hmap' TT.ToDependent (TT.flattenParameters pnet'))
    let blendedParams = TT.hzipWith (TH.Interpolate tau) policyParams targetParams
    newTargetParams <- TT.hmapM' TT.MakeIndependent blendedParams
    let tnet' = TT.replaceParameters tnet newTargetParams
    pure (pnet', tnet', opt', T.asValue batchLoss)

  -- The regression pair of a single replay step.
  --
  -- Returns @(qnow, qexpected)@ as untyped tensors, each holding a
  -- @QTensor dev '[1]@ value:
  --
  -- * @qnow@: the Q-value that 'pnet' assigns to the step's stored
  --   state/action encoding.
  -- * @qexpected@: @r + gamma * qnext@. Here @qnext@ is the largest Q-value
  --   that 'tnet' assigns to the successor's action encodings, or 0 when
  --   there is no successor state.
  --
  -- A successor state whose list of action encodings is empty also gets a
  -- future value of 0. 'E.maximumOn' is partial on empty lists, so this case
  -- is handled explicitly instead of aborting training.
  dqnValues :: ReplayStep dev tr tr' slc s f h -> (T.Tensor, T.Tensor) -- (QTensor '[1], QTensor '[1])
  dqnValues (ReplayStep _ _ step0Enc s' step1Encs r) = (qnow, qexpected)
   where
    -- future value used when there is nothing to bootstrap from
    qzero :: QTensor dev '[1]
    qzero = TT.zeros

    -- best target-network value over the successor's actions (0 if none)
    qnext :: QTensor dev '[1]
    qnext = case s' of
      Nothing -> qzero
      Just _ -> case nextQs of
        [] -> qzero
        qs -> E.maximumOn toVal qs

    nextQs :: [QTensor dev '[1]]
    nextQs = TT.forward tnet <$> step1Encs
    -- nextQs = runQ' encode tnet state' <$> getActions eval state'

    -- read a single-element Q tensor back as a scalar, for comparison
    toVal :: QTensor dev '[1] -> QType
    toVal = T.asValue @QType . TT.toDynamic

    qnow :: T.Tensor
    qnow = TT.toDynamic $ T.forward pnet step0Enc -- (encode s a)

    qexpected :: T.Tensor
    qexpected = TT.toDynamic $ TT.addScalar r (gamma `TT.mul` qnext)

-- delta = qnow - qexpected

trainDQN
  :: forall dev gen tr tr' slc slc' s f h
   . ( IsValidDevice dev
     , StatefulGen gen IO
     , Show s
     , Show f
     , Show h
     , s ~ Split SPitch -- TODO: keep fully open or specialize
     , f ~ Freeze SPitch
     , h ~ Spread SPitch
     , Show slc
     , Show tr
     )
  => gen
  -> Eval tr tr' slc slc' h (Leftmost s f h)
  -> (GreedyState tr tr' slc (Leftmost s f h) -> Action slc tr s f h -> QEncoding dev '[])
  -> (Analysis s f h tr slc -> IO QType)
  -> (Action slc tr s f h -> Maybe Bool -> IO QType)
  -> [Path slc' tr']
  -> Int
  -> IO ([QType], [QType], QModel dev)
trainDQN :: forall (dev :: (DeviceType, Nat)) gen tr tr' slc slc' s f h.
(IsValidDevice dev, StatefulGen gen IO, Show s, Show f, Show h,
 s ~ Split SPitch, f ~ Freeze SPitch, h ~ Spread SPitch, Show slc,
 Show tr) =>
gen
-> Eval tr tr' slc slc' h (Leftmost s f h)
-> (GreedyState tr tr' slc (Leftmost s f h)
    -> Action slc tr s f h -> QEncoding dev '[])
-> (Analysis s f h tr slc -> IO QType)
-> (Action slc tr s f h -> Maybe Bool -> IO QType)
-> [Path slc' tr']
-> Int
-> IO ([QType], [QType], QModel dev)
trainDQN gen
gen Eval tr tr' slc slc' h (Leftmost s f h)
eval GreedyState tr tr' slc (Leftmost s f h)
-> Action slc tr s f h -> QEncoding dev '[]
encode Analysis s f h tr slc -> IO QType
reward Action slc tr s f h -> Maybe Bool -> IO QType
rewardStep [Path slc' tr']
pieces Int
n = do
  model0 <- IO (QModel dev)
forall (dev :: (DeviceType, Nat)).
IsValidDevice dev =>
IO (QModel dev)
mkQModel
  let opt = AdamIter
-> Float
-> Float
-> HList
     '[Parameter dev QDType '[8, 1, 1, 1], Parameter dev QDType '[8],
       Parameter dev QDType '[8, 8, 13, 5], Parameter dev QDType '[8],
       Parameter dev QDType '[8, 13, 5], Parameter dev QDType '[8, 13, 5],
       Parameter dev QDType '[8, 2, 13, 5], Parameter dev QDType '[8],
       Parameter dev QDType '[8, 2, 13, 5], Parameter dev QDType '[8],
       Parameter dev QDType '[8, 1, 1, 1], Parameter dev QDType '[8],
       Parameter dev QDType '[8, 1, 1, 1], Parameter dev QDType '[8],
       Parameter dev QDType '[8], Parameter dev QDType '[8, 8, 13, 5],
       Parameter dev QDType '[8], Parameter dev QDType '[8, 8, 13, 5],
       Parameter dev QDType '[8], Parameter dev QDType '[8, 8, 13, 5],
       Parameter dev QDType '[8], Parameter dev QDType '[8, 8, 13, 5],
       Parameter dev QDType '[8], Parameter dev QDType '[8, 8, 13, 5],
       Parameter dev QDType '[8], Parameter dev QDType '[8, 8, 13, 5],
       Parameter dev QDType '[8], Parameter dev QDType '[8, 8, 13, 5],
       Parameter dev QDType '[8], Parameter dev QDType '[5],
       Parameter dev QDType '[5], Parameter dev QDType '[5],
       Parameter dev QDType '[8, 8, 13, 5], Parameter dev QDType '[8],
       Parameter dev QDType '[8, 8, 13, 5], Parameter dev QDType '[8],
       Parameter dev QDType '[8, 8, 13, 5], Parameter dev QDType '[8],
       Parameter dev QDType '[8, 8, 13, 5], Parameter dev QDType '[8],
       Parameter dev QDType '[8, 8, 13, 5], Parameter dev QDType '[8],
       Parameter dev QDType '[8, 8, 13, 5], Parameter dev QDType '[8],
       Parameter dev QDType '[8, 8, 13, 5], Parameter dev QDType '[8],
       Parameter dev QDType '[8, 8, 13, 5], Parameter dev QDType '[8],
       Parameter dev QDType '[8], Parameter dev QDType '[8],
       Parameter dev QDType '[1, 8], Parameter dev QDType '[1],
       Parameter dev QDType '[8, 8], Parameter dev QDType '[8],
       Parameter dev QDType '[8], Parameter dev QDType '[8],
       Parameter dev QDType '[1, 8], Parameter dev QDType '[1]]
-> Adam
     '[Tensor dev QDType '[8, 1, 1, 1], Tensor dev QDType '[8],
       Tensor dev QDType '[8, 8, 13, 5], Tensor dev QDType '[8],
       Tensor dev QDType '[8, 13, 5], Tensor dev QDType '[8, 13, 5],
       Tensor dev QDType '[8, 2, 13, 5], Tensor dev QDType '[8],
       Tensor dev QDType '[8, 2, 13, 5], Tensor dev QDType '[8],
       Tensor dev QDType '[8, 1, 1, 1], Tensor dev QDType '[8],
       Tensor dev QDType '[8, 1, 1, 1], Tensor dev QDType '[8],
       Tensor dev QDType '[8], Tensor dev QDType '[8, 8, 13, 5],
       Tensor dev QDType '[8], Tensor dev QDType '[8, 8, 13, 5],
       Tensor dev QDType '[8], Tensor dev QDType '[8, 8, 13, 5],
       Tensor dev QDType '[8], Tensor dev QDType '[8, 8, 13, 5],
       Tensor dev QDType '[8], Tensor dev QDType '[8, 8, 13, 5],
       Tensor dev QDType '[8], Tensor dev QDType '[8, 8, 13, 5],
       Tensor dev QDType '[8], Tensor dev QDType '[8, 8, 13, 5],
       Tensor dev QDType '[8], Tensor dev QDType '[5],
       Tensor dev QDType '[5], Tensor dev QDType '[5],
       Tensor dev QDType '[8, 8, 13, 5], Tensor dev QDType '[8],
       Tensor dev QDType '[8, 8, 13, 5], Tensor dev QDType '[8],
       Tensor dev QDType '[8, 8, 13, 5], Tensor dev QDType '[8],
       Tensor dev QDType '[8, 8, 13, 5], Tensor dev QDType '[8],
       Tensor dev QDType '[8, 8, 13, 5], Tensor dev QDType '[8],
       Tensor dev QDType '[8, 8, 13, 5], Tensor dev QDType '[8],
       Tensor dev QDType '[8, 8, 13, 5], Tensor dev QDType '[8],
       Tensor dev QDType '[8, 8, 13, 5], Tensor dev QDType '[8],
       Tensor dev QDType '[8], Tensor dev QDType '[8],
       Tensor dev QDType '[1, 8], Tensor dev QDType '[1],
       Tensor dev QDType '[8, 8], Tensor dev QDType '[8],
       Tensor dev QDType '[8], Tensor dev QDType '[8],
       Tensor dev QDType '[1, 8], Tensor dev QDType '[1]]
forall (parameters :: [*]) (momenta :: [*]).
HMap' ZerosLike parameters momenta =>
AdamIter -> Float -> Float -> HList parameters -> Adam momenta
TT.mkAdam AdamIter
0 Float
0.9 Float
0.99 (QModel dev -> HList (Parameters (QModel dev))
forall f. Parameterized f => f -> HList (Parameters f)
TT.flattenParameters QModel dev
model0) -- T.GD
      buffer = Int -> ReplayBuffer dev tr tr' slc s f h
forall (dev :: (DeviceType, Nat)) tr tr' slc s f h.
Int -> ReplayBuffer dev tr tr' slc s f h
mkReplayBuffer Int
bufferSize
      state0 = QModel dev
-> QModel dev
-> Adam
     '[Tensor dev QDType '[8, 1, 1, 1], Tensor dev QDType '[8],
       Tensor dev QDType '[8, 8, 13, 5], Tensor dev QDType '[8],
       Tensor dev QDType '[8, 13, 5], Tensor dev QDType '[8, 13, 5],
       Tensor dev QDType '[8, 2, 13, 5], Tensor dev QDType '[8],
       Tensor dev QDType '[8, 2, 13, 5], Tensor dev QDType '[8],
       Tensor dev QDType '[8, 1, 1, 1], Tensor dev QDType '[8],
       Tensor dev QDType '[8, 1, 1, 1], Tensor dev QDType '[8],
       Tensor dev QDType '[8], Tensor dev QDType '[8, 8, 13, 5],
       Tensor dev QDType '[8], Tensor dev QDType '[8, 8, 13, 5],
       Tensor dev QDType '[8], Tensor dev QDType '[8, 8, 13, 5],
       Tensor dev QDType '[8], Tensor dev QDType '[8, 8, 13, 5],
       Tensor dev QDType '[8], Tensor dev QDType '[8, 8, 13, 5],
       Tensor dev QDType '[8], Tensor dev QDType '[8, 8, 13, 5],
       Tensor dev QDType '[8], Tensor dev QDType '[8, 8, 13, 5],
       Tensor dev QDType '[8], Tensor dev QDType '[5],
       Tensor dev QDType '[5], Tensor dev QDType '[5],
       Tensor dev QDType '[8, 8, 13, 5], Tensor dev QDType '[8],
       Tensor dev QDType '[8, 8, 13, 5], Tensor dev QDType '[8],
       Tensor dev QDType '[8, 8, 13, 5], Tensor dev QDType '[8],
       Tensor dev QDType '[8, 8, 13, 5], Tensor dev QDType '[8],
       Tensor dev QDType '[8, 8, 13, 5], Tensor dev QDType '[8],
       Tensor dev QDType '[8, 8, 13, 5], Tensor dev QDType '[8],
       Tensor dev QDType '[8, 8, 13, 5], Tensor dev QDType '[8],
       Tensor dev QDType '[8, 8, 13, 5], Tensor dev QDType '[8],
       Tensor dev QDType '[8], Tensor dev QDType '[8],
       Tensor dev QDType '[1, 8], Tensor dev QDType '[1],
       Tensor dev QDType '[8, 8], Tensor dev QDType '[8],
       Tensor dev QDType '[8], Tensor dev QDType '[8],
       Tensor dev QDType '[1, 8], Tensor dev QDType '[1]]
-> ReplayBuffer dev tr tr' slc s f h
-> DQNState
     dev
     (Adam
        '[Tensor dev QDType '[8, 1, 1, 1], Tensor dev QDType '[8],
          Tensor dev QDType '[8, 8, 13, 5], Tensor dev QDType '[8],
          Tensor dev QDType '[8, 13, 5], Tensor dev QDType '[8, 13, 5],
          Tensor dev QDType '[8, 2, 13, 5], Tensor dev QDType '[8],
          Tensor dev QDType '[8, 2, 13, 5], Tensor dev QDType '[8],
          Tensor dev QDType '[8, 1, 1, 1], Tensor dev QDType '[8],
          Tensor dev QDType '[8, 1, 1, 1], Tensor dev QDType '[8],
          Tensor dev QDType '[8], Tensor dev QDType '[8, 8, 13, 5],
          Tensor dev QDType '[8], Tensor dev QDType '[8, 8, 13, 5],
          Tensor dev QDType '[8], Tensor dev QDType '[8, 8, 13, 5],
          Tensor dev QDType '[8], Tensor dev QDType '[8, 8, 13, 5],
          Tensor dev QDType '[8], Tensor dev QDType '[8, 8, 13, 5],
          Tensor dev QDType '[8], Tensor dev QDType '[8, 8, 13, 5],
          Tensor dev QDType '[8], Tensor dev QDType '[8, 8, 13, 5],
          Tensor dev QDType '[8], Tensor dev QDType '[5],
          Tensor dev QDType '[5], Tensor dev QDType '[5],
          Tensor dev QDType '[8, 8, 13, 5], Tensor dev QDType '[8],
          Tensor dev QDType '[8, 8, 13, 5], Tensor dev QDType '[8],
          Tensor dev QDType '[8, 8, 13, 5], Tensor dev QDType '[8],
          Tensor dev QDType '[8, 8, 13, 5], Tensor dev QDType '[8],
          Tensor dev QDType '[8, 8, 13, 5], Tensor dev QDType '[8],
          Tensor dev QDType '[8, 8, 13, 5], Tensor dev QDType '[8],
          Tensor dev QDType '[8, 8, 13, 5], Tensor dev QDType '[8],
          Tensor dev QDType '[8, 8, 13, 5], Tensor dev QDType '[8],
          Tensor dev QDType '[8], Tensor dev QDType '[8],
          Tensor dev QDType '[1, 8], Tensor dev QDType '[1],
          Tensor dev QDType '[8, 8], Tensor dev QDType '[8],
          Tensor dev QDType '[8], Tensor dev QDType '[8],
          Tensor dev QDType '[1, 8], Tensor dev QDType '[1]])
     tr
     tr'
     slc
     s
     f
     h
     QType
forall {k} (dev :: (DeviceType, Nat)) opt tr tr' slc s f h
       (r :: k).
QModel dev
-> QModel dev
-> opt
-> ReplayBuffer dev tr tr' slc s f h
-> DQNState dev opt tr tr' slc s f h r
DQNState QModel dev
model0 QModel dev
model0 Adam
  '[Tensor dev QDType '[8, 1, 1, 1], Tensor dev QDType '[8],
    Tensor dev QDType '[8, 8, 13, 5], Tensor dev QDType '[8],
    Tensor dev QDType '[8, 13, 5], Tensor dev QDType '[8, 13, 5],
    Tensor dev QDType '[8, 2, 13, 5], Tensor dev QDType '[8],
    Tensor dev QDType '[8, 2, 13, 5], Tensor dev QDType '[8],
    Tensor dev QDType '[8, 1, 1, 1], Tensor dev QDType '[8],
    Tensor dev QDType '[8, 1, 1, 1], Tensor dev QDType '[8],
    Tensor dev QDType '[8], Tensor dev QDType '[8, 8, 13, 5],
    Tensor dev QDType '[8], Tensor dev QDType '[8, 8, 13, 5],
    Tensor dev QDType '[8], Tensor dev QDType '[8, 8, 13, 5],
    Tensor dev QDType '[8], Tensor dev QDType '[8, 8, 13, 5],
    Tensor dev QDType '[8], Tensor dev QDType '[8, 8, 13, 5],
    Tensor dev QDType '[8], Tensor dev QDType '[8, 8, 13, 5],
    Tensor dev QDType '[8], Tensor dev QDType '[8, 8, 13, 5],
    Tensor dev QDType '[8], Tensor dev QDType '[5],
    Tensor dev QDType '[5], Tensor dev QDType '[5],
    Tensor dev QDType '[8, 8, 13, 5], Tensor dev QDType '[8],
    Tensor dev QDType '[8, 8, 13, 5], Tensor dev QDType '[8],
    Tensor dev QDType '[8, 8, 13, 5], Tensor dev QDType '[8],
    Tensor dev QDType '[8, 8, 13, 5], Tensor dev QDType '[8],
    Tensor dev QDType '[8, 8, 13, 5], Tensor dev QDType '[8],
    Tensor dev QDType '[8, 8, 13, 5], Tensor dev QDType '[8],
    Tensor dev QDType '[8, 8, 13, 5], Tensor dev QDType '[8],
    Tensor dev QDType '[8, 8, 13, 5], Tensor dev QDType '[8],
    Tensor dev QDType '[8], Tensor dev QDType '[8],
    Tensor dev QDType '[1, 8], Tensor dev QDType '[1],
    Tensor dev QDType '[8, 8], Tensor dev QDType '[8],
    Tensor dev QDType '[8], Tensor dev QDType '[8],
    Tensor dev QDType '[1, 8], Tensor dev QDType '[1]]
opt ReplayBuffer dev tr tr' slc s f h
forall {dev :: (DeviceType, Nat)} {tr} {tr'} {slc} {s} {f} {h}.
ReplayBuffer dev tr tr' slc s f h
buffer
  (DQNState modelTrained _ _ _, rewards, losses, accs) <- T.foldLoop (state0, [], [], []) n trainEpoch
  pure (reverse rewards, reverse losses, modelTrained) -- (modelTrained, rewards)
 where
  trainPiece :: Int
-> (DQNState
      dev
      (Adam
         '[Tensor dev QDType '[8, 1, 1, 1], Tensor dev QDType '[8],
           Tensor dev QDType '[8, 8, 13, 5], Tensor dev QDType '[8],
           Tensor dev QDType '[8, 13, 5], Tensor dev QDType '[8, 13, 5],
           Tensor dev QDType '[8, 2, 13, 5], Tensor dev QDType '[8],
           Tensor dev QDType '[8, 2, 13, 5], Tensor dev QDType '[8],
           Tensor dev QDType '[8, 1, 1, 1], Tensor dev QDType '[8],
           Tensor dev QDType '[8, 1, 1, 1], Tensor dev QDType '[8],
           Tensor dev QDType '[8], Tensor dev QDType '[8, 8, 13, 5],
           Tensor dev QDType '[8], Tensor dev QDType '[8, 8, 13, 5],
           Tensor dev QDType '[8], Tensor dev QDType '[8, 8, 13, 5],
           Tensor dev QDType '[8], Tensor dev QDType '[8, 8, 13, 5],
           Tensor dev QDType '[8], Tensor dev QDType '[8, 8, 13, 5],
           Tensor dev QDType '[8], Tensor dev QDType '[8, 8, 13, 5],
           Tensor dev QDType '[8], Tensor dev QDType '[8, 8, 13, 5],
           Tensor dev QDType '[8], Tensor dev QDType '[5],
           Tensor dev QDType '[5], Tensor dev QDType '[5],
           Tensor dev QDType '[8, 8, 13, 5], Tensor dev QDType '[8],
           Tensor dev QDType '[8, 8, 13, 5], Tensor dev QDType '[8],
           Tensor dev QDType '[8, 8, 13, 5], Tensor dev QDType '[8],
           Tensor dev QDType '[8, 8, 13, 5], Tensor dev QDType '[8],
           Tensor dev QDType '[8, 8, 13, 5], Tensor dev QDType '[8],
           Tensor dev QDType '[8, 8, 13, 5], Tensor dev QDType '[8],
           Tensor dev QDType '[8, 8, 13, 5], Tensor dev QDType '[8],
           Tensor dev QDType '[8, 8, 13, 5], Tensor dev QDType '[8],
           Tensor dev QDType '[8], Tensor dev QDType '[8],
           Tensor dev QDType '[1, 8], Tensor dev QDType '[1],
           Tensor dev QDType '[8, 8], Tensor dev QDType '[8],
           Tensor dev QDType '[8], Tensor dev QDType '[8],
           Tensor dev QDType '[1, 8], Tensor dev QDType '[1]])
      tr
      tr'
      slc
      s
      f
      h
      QType,
    [QType], [QType])
-> Path slc' tr'
-> IO
     (DQNState
        dev
        (Adam
           '[Tensor dev QDType '[8, 1, 1, 1], Tensor dev QDType '[8],
             Tensor dev QDType '[8, 8, 13, 5], Tensor dev QDType '[8],
             Tensor dev QDType '[8, 13, 5], Tensor dev QDType '[8, 13, 5],
             Tensor dev QDType '[8, 2, 13, 5], Tensor dev QDType '[8],
             Tensor dev QDType '[8, 2, 13, 5], Tensor dev QDType '[8],
             Tensor dev QDType '[8, 1, 1, 1], Tensor dev QDType '[8],
             Tensor dev QDType '[8, 1, 1, 1], Tensor dev QDType '[8],
             Tensor dev QDType '[8], Tensor dev QDType '[8, 8, 13, 5],
             Tensor dev QDType '[8], Tensor dev QDType '[8, 8, 13, 5],
             Tensor dev QDType '[8], Tensor dev QDType '[8, 8, 13, 5],
             Tensor dev QDType '[8], Tensor dev QDType '[8, 8, 13, 5],
             Tensor dev QDType '[8], Tensor dev QDType '[8, 8, 13, 5],
             Tensor dev QDType '[8], Tensor dev QDType '[8, 8, 13, 5],
             Tensor dev QDType '[8], Tensor dev QDType '[8, 8, 13, 5],
             Tensor dev QDType '[8], Tensor dev QDType '[5],
             Tensor dev QDType '[5], Tensor dev QDType '[5],
             Tensor dev QDType '[8, 8, 13, 5], Tensor dev QDType '[8],
             Tensor dev QDType '[8, 8, 13, 5], Tensor dev QDType '[8],
             Tensor dev QDType '[8, 8, 13, 5], Tensor dev QDType '[8],
             Tensor dev QDType '[8, 8, 13, 5], Tensor dev QDType '[8],
             Tensor dev QDType '[8, 8, 13, 5], Tensor dev QDType '[8],
             Tensor dev QDType '[8, 8, 13, 5], Tensor dev QDType '[8],
             Tensor dev QDType '[8, 8, 13, 5], Tensor dev QDType '[8],
             Tensor dev QDType '[8, 8, 13, 5], Tensor dev QDType '[8],
             Tensor dev QDType '[8], Tensor dev QDType '[8],
             Tensor dev QDType '[1, 8], Tensor dev QDType '[1],
             Tensor dev QDType '[8, 8], Tensor dev QDType '[8],
             Tensor dev QDType '[8], Tensor dev QDType '[8],
             Tensor dev QDType '[1, 8], Tensor dev QDType '[1]])
        tr
        tr'
        slc
        s
        f
        h
        QType,
      [QType], [QType])
trainPiece Int
i (DQNState
  dev
  (Adam
     '[Tensor dev QDType '[8, 1, 1, 1], Tensor dev QDType '[8],
       Tensor dev QDType '[8, 8, 13, 5], Tensor dev QDType '[8],
       Tensor dev QDType '[8, 13, 5], Tensor dev QDType '[8, 13, 5],
       Tensor dev QDType '[8, 2, 13, 5], Tensor dev QDType '[8],
       Tensor dev QDType '[8, 2, 13, 5], Tensor dev QDType '[8],
       Tensor dev QDType '[8, 1, 1, 1], Tensor dev QDType '[8],
       Tensor dev QDType '[8, 1, 1, 1], Tensor dev QDType '[8],
       Tensor dev QDType '[8], Tensor dev QDType '[8, 8, 13, 5],
       Tensor dev QDType '[8], Tensor dev QDType '[8, 8, 13, 5],
       Tensor dev QDType '[8], Tensor dev QDType '[8, 8, 13, 5],
       Tensor dev QDType '[8], Tensor dev QDType '[8, 8, 13, 5],
       Tensor dev QDType '[8], Tensor dev QDType '[8, 8, 13, 5],
       Tensor dev QDType '[8], Tensor dev QDType '[8, 8, 13, 5],
       Tensor dev QDType '[8], Tensor dev QDType '[8, 8, 13, 5],
       Tensor dev QDType '[8], Tensor dev QDType '[5],
       Tensor dev QDType '[5], Tensor dev QDType '[5],
       Tensor dev QDType '[8, 8, 13, 5], Tensor dev QDType '[8],
       Tensor dev QDType '[8, 8, 13, 5], Tensor dev QDType '[8],
       Tensor dev QDType '[8, 8, 13, 5], Tensor dev QDType '[8],
       Tensor dev QDType '[8, 8, 13, 5], Tensor dev QDType '[8],
       Tensor dev QDType '[8, 8, 13, 5], Tensor dev QDType '[8],
       Tensor dev QDType '[8, 8, 13, 5], Tensor dev QDType '[8],
       Tensor dev QDType '[8, 8, 13, 5], Tensor dev QDType '[8],
       Tensor dev QDType '[8, 8, 13, 5], Tensor dev QDType '[8],
       Tensor dev QDType '[8], Tensor dev QDType '[8],
       Tensor dev QDType '[1, 8], Tensor dev QDType '[1],
       Tensor dev QDType '[8, 8], Tensor dev QDType '[8],
       Tensor dev QDType '[8], Tensor dev QDType '[8],
       Tensor dev QDType '[1, 8], Tensor dev QDType '[1]])
  tr
  tr'
  slc
  s
  f
  h
  QType
state, [QType]
rewards, [QType]
losses) Path slc' tr'
piece = do
    (state', r, loss) <- gen
-> Eval tr tr' slc slc' h (Leftmost s f h)
-> (GreedyState tr tr' slc (Leftmost s f h)
    -> Action slc tr s f h -> QEncoding dev '[])
-> (Analysis s f h tr slc -> IO QType)
-> (Action slc tr s f h -> Maybe Bool -> IO QType)
-> Path slc' tr'
-> DQNState
     dev
     (Adam
        '[Tensor dev QDType '[8, 1, 1, 1], Tensor dev QDType '[8],
          Tensor dev QDType '[8, 8, 13, 5], Tensor dev QDType '[8],
          Tensor dev QDType '[8, 13, 5], Tensor dev QDType '[8, 13, 5],
          Tensor dev QDType '[8, 2, 13, 5], Tensor dev QDType '[8],
          Tensor dev QDType '[8, 2, 13, 5], Tensor dev QDType '[8],
          Tensor dev QDType '[8, 1, 1, 1], Tensor dev QDType '[8],
          Tensor dev QDType '[8, 1, 1, 1], Tensor dev QDType '[8],
          Tensor dev QDType '[8], Tensor dev QDType '[8, 8, 13, 5],
          Tensor dev QDType '[8], Tensor dev QDType '[8, 8, 13, 5],
          Tensor dev QDType '[8], Tensor dev QDType '[8, 8, 13, 5],
          Tensor dev QDType '[8], Tensor dev QDType '[8, 8, 13, 5],
          Tensor dev QDType '[8], Tensor dev QDType '[8, 8, 13, 5],
          Tensor dev QDType '[8], Tensor dev QDType '[8, 8, 13, 5],
          Tensor dev QDType '[8], Tensor dev QDType '[8, 8, 13, 5],
          Tensor dev QDType '[8], Tensor dev QDType '[5],
          Tensor dev QDType '[5], Tensor dev QDType '[5],
          Tensor dev QDType '[8, 8, 13, 5], Tensor dev QDType '[8],
          Tensor dev QDType '[8, 8, 13, 5], Tensor dev QDType '[8],
          Tensor dev QDType '[8, 8, 13, 5], Tensor dev QDType '[8],
          Tensor dev QDType '[8, 8, 13, 5], Tensor dev QDType '[8],
          Tensor dev QDType '[8, 8, 13, 5], Tensor dev QDType '[8],
          Tensor dev QDType '[8, 8, 13, 5], Tensor dev QDType '[8],
          Tensor dev QDType '[8, 8, 13, 5], Tensor dev QDType '[8],
          Tensor dev QDType '[8, 8, 13, 5], Tensor dev QDType '[8],
          Tensor dev QDType '[8], Tensor dev QDType '[8],
          Tensor dev QDType '[1, 8], Tensor dev QDType '[1],
          Tensor dev QDType '[8, 8], Tensor dev QDType '[8],
          Tensor dev QDType '[8], Tensor dev QDType '[8],
          Tensor dev QDType '[1, 8], Tensor dev QDType '[1]])
     tr
     tr'
     slc
     s
     f
     h
     QType
-> Int
-> Int
-> IO
     (DQNState
        dev
        (Adam
           '[Tensor dev QDType '[8, 1, 1, 1], Tensor dev QDType '[8],
             Tensor dev QDType '[8, 8, 13, 5], Tensor dev QDType '[8],
             Tensor dev QDType '[8, 13, 5], Tensor dev QDType '[8, 13, 5],
             Tensor dev QDType '[8, 2, 13, 5], Tensor dev QDType '[8],
             Tensor dev QDType '[8, 2, 13, 5], Tensor dev QDType '[8],
             Tensor dev QDType '[8, 1, 1, 1], Tensor dev QDType '[8],
             Tensor dev QDType '[8, 1, 1, 1], Tensor dev QDType '[8],
             Tensor dev QDType '[8], Tensor dev QDType '[8, 8, 13, 5],
             Tensor dev QDType '[8], Tensor dev QDType '[8, 8, 13, 5],
             Tensor dev QDType '[8], Tensor dev QDType '[8, 8, 13, 5],
             Tensor dev QDType '[8], Tensor dev QDType '[8, 8, 13, 5],
             Tensor dev QDType '[8], Tensor dev QDType '[8, 8, 13, 5],
             Tensor dev QDType '[8], Tensor dev QDType '[8, 8, 13, 5],
             Tensor dev QDType '[8], Tensor dev QDType '[8, 8, 13, 5],
             Tensor dev QDType '[8], Tensor dev QDType '[5],
             Tensor dev QDType '[5], Tensor dev QDType '[5],
             Tensor dev QDType '[8, 8, 13, 5], Tensor dev QDType '[8],
             Tensor dev QDType '[8, 8, 13, 5], Tensor dev QDType '[8],
             Tensor dev QDType '[8, 8, 13, 5], Tensor dev QDType '[8],
             Tensor dev QDType '[8, 8, 13, 5], Tensor dev QDType '[8],
             Tensor dev QDType '[8, 8, 13, 5], Tensor dev QDType '[8],
             Tensor dev QDType '[8, 8, 13, 5], Tensor dev QDType '[8],
             Tensor dev QDType '[8, 8, 13, 5], Tensor dev QDType '[8],
             Tensor dev QDType '[8, 8, 13, 5], Tensor dev QDType '[8],
             Tensor dev QDType '[8], Tensor dev QDType '[8],
             Tensor dev QDType '[1, 8], Tensor dev QDType '[1],
             Tensor dev QDType '[8, 8], Tensor dev QDType '[8],
             Tensor dev QDType '[8], Tensor dev QDType '[8],
             Tensor dev QDType '[1, 8], Tensor dev QDType '[1]])
        tr
        tr'
        slc
        s
        f
        h
        QType,
      QType, QType)
forall (dev :: (DeviceType, Nat)) tr tr' slc slc' s f h gen opt
       {device :: (DeviceType, Nat)}.
(GeluDTypeIsValid dev QDType, RandDTypeIsValid dev QDType,
 SumDTypeIsValid dev QDType, MeanDTypeValidation dev QDType,
 StandardFloatingPointDTypeValidation dev QDType,
 GeluDTypeIsValid device QDType, RandDTypeIsValid device QDType,
 BasicArithmeticDTypeIsValid device QDType,
 SumDTypeIsValid device QDType, MeanDTypeValidation device QDType,
 StandardFloatingPointDTypeValidation device QDType,
 BasicArithmeticDTypeIsValid dev QDType, StatefulGen gen IO,
 Optimizer
   opt
   '[Tensor dev QDType '[8, 1, 1, 1], Tensor dev QDType '[8],
     Tensor dev QDType '[8, 8, 13, 5], Tensor dev QDType '[8],
     Tensor dev QDType '[8, 13, 5], Tensor dev QDType '[8, 13, 5],
     Tensor dev QDType '[8, 2, 13, 5], Tensor dev QDType '[8],
     Tensor dev QDType '[8, 2, 13, 5], Tensor dev QDType '[8],
     Tensor dev QDType '[8, 1, 1, 1], Tensor dev QDType '[8],
     Tensor dev QDType '[8, 1, 1, 1], Tensor dev QDType '[8],
     Tensor dev QDType '[8], Tensor dev QDType '[8, 8, 13, 5],
     Tensor dev QDType '[8], Tensor dev QDType '[8, 8, 13, 5],
     Tensor dev QDType '[8], Tensor dev QDType '[8, 8, 13, 5],
     Tensor dev QDType '[8], Tensor dev QDType '[8, 8, 13, 5],
     Tensor dev QDType '[8], Tensor dev QDType '[8, 8, 13, 5],
     Tensor dev QDType '[8], Tensor dev QDType '[8, 8, 13, 5],
     Tensor dev QDType '[8], Tensor dev QDType '[8, 8, 13, 5],
     Tensor dev QDType '[8], Tensor dev QDType '[5],
     Tensor dev QDType '[5], Tensor dev QDType '[5],
     Tensor dev QDType '[8, 8, 13, 5], Tensor dev QDType '[8],
     Tensor dev QDType '[8, 8, 13, 5], Tensor dev QDType '[8],
     Tensor dev QDType '[8, 8, 13, 5], Tensor dev QDType '[8],
     Tensor dev QDType '[8, 8, 13, 5], Tensor dev QDType '[8],
     Tensor dev QDType '[8, 8, 13, 5], Tensor dev QDType '[8],
     Tensor dev QDType '[8, 8, 13, 5], Tensor dev QDType '[8],
     Tensor dev QDType '[8, 8, 13, 5], Tensor dev QDType '[8],
     Tensor dev QDType '[8, 8, 13, 5], Tensor dev QDType '[8],
     Tensor dev QDType '[8], Tensor dev QDType '[8],
     Tensor dev QDType '[1, 8], QTensor dev '[1],
     Tensor dev QDType '[8, 8], Tensor dev QDType '[8],
     Tensor dev QDType '[8], Tensor dev QDType '[8],
     Tensor dev QDType '[1, 8], QTensor dev '[1]]
   '[Tensor dev QDType '[8, 1, 1, 1], Tensor dev QDType '[8],
     Tensor dev QDType '[8, 8, 13, 5], Tensor dev QDType '[8],
     Tensor dev QDType '[8, 13, 5], Tensor dev QDType '[8, 13, 5],
     Tensor dev QDType '[8, 2, 13, 5], Tensor dev QDType '[8],
     Tensor dev QDType '[8, 2, 13, 5], Tensor dev QDType '[8],
     Tensor dev QDType '[8, 1, 1, 1], Tensor dev QDType '[8],
     Tensor dev QDType '[8, 1, 1, 1], Tensor dev QDType '[8],
     Tensor dev QDType '[8], Tensor dev QDType '[8, 8, 13, 5],
     Tensor dev QDType '[8], Tensor dev QDType '[8, 8, 13, 5],
     Tensor dev QDType '[8], Tensor dev QDType '[8, 8, 13, 5],
     Tensor dev QDType '[8], Tensor dev QDType '[8, 8, 13, 5],
     Tensor dev QDType '[8], Tensor dev QDType '[8, 8, 13, 5],
     Tensor dev QDType '[8], Tensor dev QDType '[8, 8, 13, 5],
     Tensor dev QDType '[8], Tensor dev QDType '[8, 8, 13, 5],
     Tensor dev QDType '[8], Tensor dev QDType '[5],
     Tensor dev QDType '[5], Tensor dev QDType '[5],
     Tensor dev QDType '[8, 8, 13, 5], Tensor dev QDType '[8],
     Tensor dev QDType '[8, 8, 13, 5], Tensor dev QDType '[8],
     Tensor dev QDType '[8, 8, 13, 5], Tensor dev QDType '[8],
     Tensor dev QDType '[8, 8, 13, 5], Tensor dev QDType '[8],
     Tensor dev QDType '[8, 8, 13, 5], Tensor dev QDType '[8],
     Tensor dev QDType '[8, 8, 13, 5], Tensor dev QDType '[8],
     Tensor dev QDType '[8, 8, 13, 5], Tensor dev QDType '[8],
     Tensor dev QDType '[8, 8, 13, 5], Tensor dev QDType '[8],
     Tensor dev QDType '[8], Tensor dev QDType '[8],
     Tensor dev QDType '[1, 8], QTensor dev '[1],
     Tensor dev QDType '[8, 8], Tensor dev QDType '[8],
     Tensor dev QDType '[8], Tensor dev QDType '[8],
     Tensor dev QDType '[1, 8], QTensor dev '[1]]
   QDType
   device,
 KnownDevice device, KnownDevice dev) =>
gen
-> Eval tr tr' slc slc' h (Leftmost s f h)
-> (GreedyState tr tr' slc (Leftmost s f h)
    -> Action slc tr s f h -> QEncoding dev '[])
-> (Analysis s f h tr slc -> IO QType)
-> (Action slc tr s f h -> Maybe Bool -> IO QType)
-> Path slc' tr'
-> DQNState dev opt tr tr' slc s f h QType
-> Int
-> Int
-> IO (DQNState dev opt tr tr' slc s f h QType, QType, QType)
trainLoop gen
gen Eval tr tr' slc slc' h (Leftmost s f h)
eval GreedyState tr tr' slc (Leftmost s f h)
-> Action slc tr s f h -> QEncoding dev '[]
encode Analysis s f h tr slc -> IO QType
reward Action slc tr s f h -> Maybe Bool -> IO QType
rewardStep Path slc' tr'
piece DQNState
  dev
  (Adam
     '[Tensor dev QDType '[8, 1, 1, 1], Tensor dev QDType '[8],
       Tensor dev QDType '[8, 8, 13, 5], Tensor dev QDType '[8],
       Tensor dev QDType '[8, 13, 5], Tensor dev QDType '[8, 13, 5],
       Tensor dev QDType '[8, 2, 13, 5], Tensor dev QDType '[8],
       Tensor dev QDType '[8, 2, 13, 5], Tensor dev QDType '[8],
       Tensor dev QDType '[8, 1, 1, 1], Tensor dev QDType '[8],
       Tensor dev QDType '[8, 1, 1, 1], Tensor dev QDType '[8],
       Tensor dev QDType '[8], Tensor dev QDType '[8, 8, 13, 5],
       Tensor dev QDType '[8], Tensor dev QDType '[8, 8, 13, 5],
       Tensor dev QDType '[8], Tensor dev QDType '[8, 8, 13, 5],
       Tensor dev QDType '[8], Tensor dev QDType '[8, 8, 13, 5],
       Tensor dev QDType '[8], Tensor dev QDType '[8, 8, 13, 5],
       Tensor dev QDType '[8], Tensor dev QDType '[8, 8, 13, 5],
       Tensor dev QDType '[8], Tensor dev QDType '[8, 8, 13, 5],
       Tensor dev QDType '[8], Tensor dev QDType '[5],
       Tensor dev QDType '[5], Tensor dev QDType '[5],
       Tensor dev QDType '[8, 8, 13, 5], Tensor dev QDType '[8],
       Tensor dev QDType '[8, 8, 13, 5], Tensor dev QDType '[8],
       Tensor dev QDType '[8, 8, 13, 5], Tensor dev QDType '[8],
       Tensor dev QDType '[8, 8, 13, 5], Tensor dev QDType '[8],
       Tensor dev QDType '[8, 8, 13, 5], Tensor dev QDType '[8],
       Tensor dev QDType '[8, 8, 13, 5], Tensor dev QDType '[8],
       Tensor dev QDType '[8, 8, 13, 5], Tensor dev QDType '[8],
       Tensor dev QDType '[8, 8, 13, 5], Tensor dev QDType '[8],
       Tensor dev QDType '[8], Tensor dev QDType '[8],
       Tensor dev QDType '[1, 8], Tensor dev QDType '[1],
       Tensor dev QDType '[8, 8], Tensor dev QDType '[8],
       Tensor dev QDType '[8], Tensor dev QDType '[8],
       Tensor dev QDType '[1, 8], Tensor dev QDType '[1]])
  tr
  tr'
  slc
  s
  f
  h
  QType
state Int
i Int
n
    pure (state', r : rewards, loss : losses)
  trainEpoch :: (DQNState
   dev
   (Adam
      '[Tensor dev QDType '[8, 1, 1, 1], Tensor dev QDType '[8],
        Tensor dev QDType '[8, 8, 13, 5], Tensor dev QDType '[8],
        Tensor dev QDType '[8, 13, 5], Tensor dev QDType '[8, 13, 5],
        Tensor dev QDType '[8, 2, 13, 5], Tensor dev QDType '[8],
        Tensor dev QDType '[8, 2, 13, 5], Tensor dev QDType '[8],
        Tensor dev QDType '[8, 1, 1, 1], Tensor dev QDType '[8],
        Tensor dev QDType '[8, 1, 1, 1], Tensor dev QDType '[8],
        Tensor dev QDType '[8], Tensor dev QDType '[8, 8, 13, 5],
        Tensor dev QDType '[8], Tensor dev QDType '[8, 8, 13, 5],
        Tensor dev QDType '[8], Tensor dev QDType '[8, 8, 13, 5],
        Tensor dev QDType '[8], Tensor dev QDType '[8, 8, 13, 5],
        Tensor dev QDType '[8], Tensor dev QDType '[8, 8, 13, 5],
        Tensor dev QDType '[8], Tensor dev QDType '[8, 8, 13, 5],
        Tensor dev QDType '[8], Tensor dev QDType '[8, 8, 13, 5],
        Tensor dev QDType '[8], Tensor dev QDType '[5],
        Tensor dev QDType '[5], Tensor dev QDType '[5],
        Tensor dev QDType '[8, 8, 13, 5], Tensor dev QDType '[8],
        Tensor dev QDType '[8, 8, 13, 5], Tensor dev QDType '[8],
        Tensor dev QDType '[8, 8, 13, 5], Tensor dev QDType '[8],
        Tensor dev QDType '[8, 8, 13, 5], Tensor dev QDType '[8],
        Tensor dev QDType '[8, 8, 13, 5], Tensor dev QDType '[8],
        Tensor dev QDType '[8, 8, 13, 5], Tensor dev QDType '[8],
        Tensor dev QDType '[8, 8, 13, 5], Tensor dev QDType '[8],
        Tensor dev QDType '[8, 8, 13, 5], Tensor dev QDType '[8],
        Tensor dev QDType '[8], Tensor dev QDType '[8],
        Tensor dev QDType '[1, 8], Tensor dev QDType '[1],
        Tensor dev QDType '[8, 8], Tensor dev QDType '[8],
        Tensor dev QDType '[8], Tensor dev QDType '[8],
        Tensor dev QDType '[1, 8], Tensor dev QDType '[1]])
   tr
   tr'
   slc
   s
   f
   h
   QType,
 [QType], [QType], [QType])
-> Int
-> IO
     (DQNState
        dev
        (Adam
           '[Tensor dev QDType '[8, 1, 1, 1], Tensor dev QDType '[8],
             Tensor dev QDType '[8, 8, 13, 5], Tensor dev QDType '[8],
             Tensor dev QDType '[8, 13, 5], Tensor dev QDType '[8, 13, 5],
             Tensor dev QDType '[8, 2, 13, 5], Tensor dev QDType '[8],
             Tensor dev QDType '[8, 2, 13, 5], Tensor dev QDType '[8],
             Tensor dev QDType '[8, 1, 1, 1], Tensor dev QDType '[8],
             Tensor dev QDType '[8, 1, 1, 1], Tensor dev QDType '[8],
             Tensor dev QDType '[8], Tensor dev QDType '[8, 8, 13, 5],
             Tensor dev QDType '[8], Tensor dev QDType '[8, 8, 13, 5],
             Tensor dev QDType '[8], Tensor dev QDType '[8, 8, 13, 5],
             Tensor dev QDType '[8], Tensor dev QDType '[8, 8, 13, 5],
             Tensor dev QDType '[8], Tensor dev QDType '[8, 8, 13, 5],
             Tensor dev QDType '[8], Tensor dev QDType '[8, 8, 13, 5],
             Tensor dev QDType '[8], Tensor dev QDType '[8, 8, 13, 5],
             Tensor dev QDType '[8], Tensor dev QDType '[5],
             Tensor dev QDType '[5], Tensor dev QDType '[5],
             Tensor dev QDType '[8, 8, 13, 5], Tensor dev QDType '[8],
             Tensor dev QDType '[8, 8, 13, 5], Tensor dev QDType '[8],
             Tensor dev QDType '[8, 8, 13, 5], Tensor dev QDType '[8],
             Tensor dev QDType '[8, 8, 13, 5], Tensor dev QDType '[8],
             Tensor dev QDType '[8, 8, 13, 5], Tensor dev QDType '[8],
             Tensor dev QDType '[8, 8, 13, 5], Tensor dev QDType '[8],
             Tensor dev QDType '[8, 8, 13, 5], Tensor dev QDType '[8],
             Tensor dev QDType '[8, 8, 13, 5], Tensor dev QDType '[8],
             Tensor dev QDType '[8], Tensor dev QDType '[8],
             Tensor dev QDType '[1, 8], Tensor dev QDType '[1],
             Tensor dev QDType '[8, 8], Tensor dev QDType '[8],
             Tensor dev QDType '[8], Tensor dev QDType '[8],
             Tensor dev QDType '[1, 8], Tensor dev QDType '[1]])
        tr
        tr'
        slc
        s
        f
        h
        QType,
      [QType], [QType], [QType])
trainEpoch (DQNState
  dev
  (Adam
     '[Tensor dev QDType '[8, 1, 1, 1], Tensor dev QDType '[8],
       Tensor dev QDType '[8, 8, 13, 5], Tensor dev QDType '[8],
       Tensor dev QDType '[8, 13, 5], Tensor dev QDType '[8, 13, 5],
       Tensor dev QDType '[8, 2, 13, 5], Tensor dev QDType '[8],
       Tensor dev QDType '[8, 2, 13, 5], Tensor dev QDType '[8],
       Tensor dev QDType '[8, 1, 1, 1], Tensor dev QDType '[8],
       Tensor dev QDType '[8, 1, 1, 1], Tensor dev QDType '[8],
       Tensor dev QDType '[8], Tensor dev QDType '[8, 8, 13, 5],
       Tensor dev QDType '[8], Tensor dev QDType '[8, 8, 13, 5],
       Tensor dev QDType '[8], Tensor dev QDType '[8, 8, 13, 5],
       Tensor dev QDType '[8], Tensor dev QDType '[8, 8, 13, 5],
       Tensor dev QDType '[8], Tensor dev QDType '[8, 8, 13, 5],
       Tensor dev QDType '[8], Tensor dev QDType '[8, 8, 13, 5],
       Tensor dev QDType '[8], Tensor dev QDType '[8, 8, 13, 5],
       Tensor dev QDType '[8], Tensor dev QDType '[5],
       Tensor dev QDType '[5], Tensor dev QDType '[5],
       Tensor dev QDType '[8, 8, 13, 5], Tensor dev QDType '[8],
       Tensor dev QDType '[8, 8, 13, 5], Tensor dev QDType '[8],
       Tensor dev QDType '[8, 8, 13, 5], Tensor dev QDType '[8],
       Tensor dev QDType '[8, 8, 13, 5], Tensor dev QDType '[8],
       Tensor dev QDType '[8, 8, 13, 5], Tensor dev QDType '[8],
       Tensor dev QDType '[8, 8, 13, 5], Tensor dev QDType '[8],
       Tensor dev QDType '[8, 8, 13, 5], Tensor dev QDType '[8],
       Tensor dev QDType '[8, 8, 13, 5], Tensor dev QDType '[8],
       Tensor dev QDType '[8], Tensor dev QDType '[8],
       Tensor dev QDType '[1, 8], Tensor dev QDType '[1],
       Tensor dev QDType '[8, 8], Tensor dev QDType '[8],
       Tensor dev QDType '[8], Tensor dev QDType '[8],
       Tensor dev QDType '[1, 8], Tensor dev QDType '[1]])
  tr
  tr'
  slc
  s
  f
  h
  QType
state, [QType]
meanRewards, [QType]
meanLosses, [QType]
accuracies) Int
i = do
    -- run epoch
    (state', rewards, losses) <-
      ((DQNState
    dev
    (Adam
       '[Tensor dev QDType '[8, 1, 1, 1], Tensor dev QDType '[8],
         Tensor dev QDType '[8, 8, 13, 5], Tensor dev QDType '[8],
         Tensor dev QDType '[8, 13, 5], Tensor dev QDType '[8, 13, 5],
         Tensor dev QDType '[8, 2, 13, 5], Tensor dev QDType '[8],
         Tensor dev QDType '[8, 2, 13, 5], Tensor dev QDType '[8],
         Tensor dev QDType '[8, 1, 1, 1], Tensor dev QDType '[8],
         Tensor dev QDType '[8, 1, 1, 1], Tensor dev QDType '[8],
         Tensor dev QDType '[8], Tensor dev QDType '[8, 8, 13, 5],
         Tensor dev QDType '[8], Tensor dev QDType '[8, 8, 13, 5],
         Tensor dev QDType '[8], Tensor dev QDType '[8, 8, 13, 5],
         Tensor dev QDType '[8], Tensor dev QDType '[8, 8, 13, 5],
         Tensor dev QDType '[8], Tensor dev QDType '[8, 8, 13, 5],
         Tensor dev QDType '[8], Tensor dev QDType '[8, 8, 13, 5],
         Tensor dev QDType '[8], Tensor dev QDType '[8, 8, 13, 5],
         Tensor dev QDType '[8], Tensor dev QDType '[5],
         Tensor dev QDType '[5], Tensor dev QDType '[5],
         Tensor dev QDType '[8, 8, 13, 5], Tensor dev QDType '[8],
         Tensor dev QDType '[8, 8, 13, 5], Tensor dev QDType '[8],
         Tensor dev QDType '[8, 8, 13, 5], Tensor dev QDType '[8],
         Tensor dev QDType '[8, 8, 13, 5], Tensor dev QDType '[8],
         Tensor dev QDType '[8, 8, 13, 5], Tensor dev QDType '[8],
         Tensor dev QDType '[8, 8, 13, 5], Tensor dev QDType '[8],
         Tensor dev QDType '[8, 8, 13, 5], Tensor dev QDType '[8],
         Tensor dev QDType '[8, 8, 13, 5], Tensor dev QDType '[8],
         Tensor dev QDType '[8], Tensor dev QDType '[8],
         Tensor dev QDType '[1, 8], Tensor dev QDType '[1],
         Tensor dev QDType '[8, 8], Tensor dev QDType '[8],
         Tensor dev QDType '[8], Tensor dev QDType '[8],
         Tensor dev QDType '[1, 8], Tensor dev QDType '[1]])
    tr
    tr'
    slc
    s
    f
    h
    QType,
  [QType], [QType])
 -> Path slc' tr'
 -> IO
      (DQNState
         dev
         (Adam
            '[Tensor dev QDType '[8, 1, 1, 1], Tensor dev QDType '[8],
              Tensor dev QDType '[8, 8, 13, 5], Tensor dev QDType '[8],
              Tensor dev QDType '[8, 13, 5], Tensor dev QDType '[8, 13, 5],
              Tensor dev QDType '[8, 2, 13, 5], Tensor dev QDType '[8],
              Tensor dev QDType '[8, 2, 13, 5], Tensor dev QDType '[8],
              Tensor dev QDType '[8, 1, 1, 1], Tensor dev QDType '[8],
              Tensor dev QDType '[8, 1, 1, 1], Tensor dev QDType '[8],
              Tensor dev QDType '[8], Tensor dev QDType '[8, 8, 13, 5],
              Tensor dev QDType '[8], Tensor dev QDType '[8, 8, 13, 5],
              Tensor dev QDType '[8], Tensor dev QDType '[8, 8, 13, 5],
              Tensor dev QDType '[8], Tensor dev QDType '[8, 8, 13, 5],
              Tensor dev QDType '[8], Tensor dev QDType '[8, 8, 13, 5],
              Tensor dev QDType '[8], Tensor dev QDType '[8, 8, 13, 5],
              Tensor dev QDType '[8], Tensor dev QDType '[8, 8, 13, 5],
              Tensor dev QDType '[8], Tensor dev QDType '[5],
              Tensor dev QDType '[5], Tensor dev QDType '[5],
              Tensor dev QDType '[8, 8, 13, 5], Tensor dev QDType '[8],
              Tensor dev QDType '[8, 8, 13, 5], Tensor dev QDType '[8],
              Tensor dev QDType '[8, 8, 13, 5], Tensor dev QDType '[8],
              Tensor dev QDType '[8, 8, 13, 5], Tensor dev QDType '[8],
              Tensor dev QDType '[8, 8, 13, 5], Tensor dev QDType '[8],
              Tensor dev QDType '[8, 8, 13, 5], Tensor dev QDType '[8],
              Tensor dev QDType '[8, 8, 13, 5], Tensor dev QDType '[8],
              Tensor dev QDType '[8, 8, 13, 5], Tensor dev QDType '[8],
              Tensor dev QDType '[8], Tensor dev QDType '[8],
              Tensor dev QDType '[1, 8], Tensor dev QDType '[1],
              Tensor dev QDType '[8, 8], Tensor dev QDType '[8],
              Tensor dev QDType '[8], Tensor dev QDType '[8],
              Tensor dev QDType '[1, 8], Tensor dev QDType '[1]])
         tr
         tr'
         slc
         s
         f
         h
         QType,
       [QType], [QType]))
-> (DQNState
      dev
      (Adam
         '[Tensor dev QDType '[8, 1, 1, 1], Tensor dev QDType '[8],
           Tensor dev QDType '[8, 8, 13, 5], Tensor dev QDType '[8],
           Tensor dev QDType '[8, 13, 5], Tensor dev QDType '[8, 13, 5],
           Tensor dev QDType '[8, 2, 13, 5], Tensor dev QDType '[8],
           Tensor dev QDType '[8, 2, 13, 5], Tensor dev QDType '[8],
           Tensor dev QDType '[8, 1, 1, 1], Tensor dev QDType '[8],
           Tensor dev QDType '[8, 1, 1, 1], Tensor dev QDType '[8],
           Tensor dev QDType '[8], Tensor dev QDType '[8, 8, 13, 5],
           Tensor dev QDType '[8], Tensor dev QDType '[8, 8, 13, 5],
           Tensor dev QDType '[8], Tensor dev QDType '[8, 8, 13, 5],
           Tensor dev QDType '[8], Tensor dev QDType '[8, 8, 13, 5],
           Tensor dev QDType '[8], Tensor dev QDType '[8, 8, 13, 5],
           Tensor dev QDType '[8], Tensor dev QDType '[8, 8, 13, 5],
           Tensor dev QDType '[8], Tensor dev QDType '[8, 8, 13, 5],
           Tensor dev QDType '[8], Tensor dev QDType '[5],
           Tensor dev QDType '[5], Tensor dev QDType '[5],
           Tensor dev QDType '[8, 8, 13, 5], Tensor dev QDType '[8],
           Tensor dev QDType '[8, 8, 13, 5], Tensor dev QDType '[8],
           Tensor dev QDType '[8, 8, 13, 5], Tensor dev QDType '[8],
           Tensor dev QDType '[8, 8, 13, 5], Tensor dev QDType '[8],
           Tensor dev QDType '[8, 8, 13, 5], Tensor dev QDType '[8],
           Tensor dev QDType '[8, 8, 13, 5], Tensor dev QDType '[8],
           Tensor dev QDType '[8, 8, 13, 5], Tensor dev QDType '[8],
           Tensor dev QDType '[8, 8, 13, 5], Tensor dev QDType '[8],
           Tensor dev QDType '[8], Tensor dev QDType '[8],
           Tensor dev QDType '[1, 8], Tensor dev QDType '[1],
           Tensor dev QDType '[8, 8], Tensor dev QDType '[8],
           Tensor dev QDType '[8], Tensor dev QDType '[8],
           Tensor dev QDType '[1, 8], Tensor dev QDType '[1]])
      tr
      tr'
      slc
      s
      f
      h
      QType,
    [QType], [QType])
-> [Path slc' tr']
-> IO
     (DQNState
        dev
        (Adam
           '[Tensor dev QDType '[8, 1, 1, 1], Tensor dev QDType '[8],
             Tensor dev QDType '[8, 8, 13, 5], Tensor dev QDType '[8],
             Tensor dev QDType '[8, 13, 5], Tensor dev QDType '[8, 13, 5],
             Tensor dev QDType '[8, 2, 13, 5], Tensor dev QDType '[8],
             Tensor dev QDType '[8, 2, 13, 5], Tensor dev QDType '[8],
             Tensor dev QDType '[8, 1, 1, 1], Tensor dev QDType '[8],
             Tensor dev QDType '[8, 1, 1, 1], Tensor dev QDType '[8],
             Tensor dev QDType '[8], Tensor dev QDType '[8, 8, 13, 5],
             Tensor dev QDType '[8], Tensor dev QDType '[8, 8, 13, 5],
             Tensor dev QDType '[8], Tensor dev QDType '[8, 8, 13, 5],
             Tensor dev QDType '[8], Tensor dev QDType '[8, 8, 13, 5],
             Tensor dev QDType '[8], Tensor dev QDType '[8, 8, 13, 5],
             Tensor dev QDType '[8], Tensor dev QDType '[8, 8, 13, 5],
             Tensor dev QDType '[8], Tensor dev QDType '[8, 8, 13, 5],
             Tensor dev QDType '[8], Tensor dev QDType '[5],
             Tensor dev QDType '[5], Tensor dev QDType '[5],
             Tensor dev QDType '[8, 8, 13, 5], Tensor dev QDType '[8],
             Tensor dev QDType '[8, 8, 13, 5], Tensor dev QDType '[8],
             Tensor dev QDType '[8, 8, 13, 5], Tensor dev QDType '[8],
             Tensor dev QDType '[8, 8, 13, 5], Tensor dev QDType '[8],
             Tensor dev QDType '[8, 8, 13, 5], Tensor dev QDType '[8],
             Tensor dev QDType '[8, 8, 13, 5], Tensor dev QDType '[8],
             Tensor dev QDType '[8, 8, 13, 5], Tensor dev QDType '[8],
             Tensor dev QDType '[8, 8, 13, 5], Tensor dev QDType '[8],
             Tensor dev QDType '[8], Tensor dev QDType '[8],
             Tensor dev QDType '[1, 8], Tensor dev QDType '[1],
             Tensor dev QDType '[8, 8], Tensor dev QDType '[8],
             Tensor dev QDType '[8], Tensor dev QDType '[8],
             Tensor dev QDType '[1, 8], Tensor dev QDType '[1]])
        tr
        tr'
        slc
        s
        f
        h
        QType,
      [QType], [QType])
forall (t :: * -> *) (m :: * -> *) b a.
(Foldable t, Monad m) =>
(b -> a -> m b) -> b -> t a -> m b
foldM (Int
-> (DQNState
      dev
      (Adam
         '[Tensor dev QDType '[8, 1, 1, 1], Tensor dev QDType '[8],
           Tensor dev QDType '[8, 8, 13, 5], Tensor dev QDType '[8],
           Tensor dev QDType '[8, 13, 5], Tensor dev QDType '[8, 13, 5],
           Tensor dev QDType '[8, 2, 13, 5], Tensor dev QDType '[8],
           Tensor dev QDType '[8, 2, 13, 5], Tensor dev QDType '[8],
           Tensor dev QDType '[8, 1, 1, 1], Tensor dev QDType '[8],
           Tensor dev QDType '[8, 1, 1, 1], Tensor dev QDType '[8],
           Tensor dev QDType '[8], Tensor dev QDType '[8, 8, 13, 5],
           Tensor dev QDType '[8], Tensor dev QDType '[8, 8, 13, 5],
           Tensor dev QDType '[8], Tensor dev QDType '[8, 8, 13, 5],
           Tensor dev QDType '[8], Tensor dev QDType '[8, 8, 13, 5],
           Tensor dev QDType '[8], Tensor dev QDType '[8, 8, 13, 5],
           Tensor dev QDType '[8], Tensor dev QDType '[8, 8, 13, 5],
           Tensor dev QDType '[8], Tensor dev QDType '[8, 8, 13, 5],
           Tensor dev QDType '[8], Tensor dev QDType '[5],
           Tensor dev QDType '[5], Tensor dev QDType '[5],
           Tensor dev QDType '[8, 8, 13, 5], Tensor dev QDType '[8],
           Tensor dev QDType '[8, 8, 13, 5], Tensor dev QDType '[8],
           Tensor dev QDType '[8, 8, 13, 5], Tensor dev QDType '[8],
           Tensor dev QDType '[8, 8, 13, 5], Tensor dev QDType '[8],
           Tensor dev QDType '[8, 8, 13, 5], Tensor dev QDType '[8],
           Tensor dev QDType '[8, 8, 13, 5], Tensor dev QDType '[8],
           Tensor dev QDType '[8, 8, 13, 5], Tensor dev QDType '[8],
           Tensor dev QDType '[8, 8, 13, 5], Tensor dev QDType '[8],
           Tensor dev QDType '[8], Tensor dev QDType '[8],
           Tensor dev QDType '[1, 8], Tensor dev QDType '[1],
           Tensor dev QDType '[8, 8], Tensor dev QDType '[8],
           Tensor dev QDType '[8], Tensor dev QDType '[8],
           Tensor dev QDType '[1, 8], Tensor dev QDType '[1]])
      tr
      tr'
      slc
      s
      f
      h
      QType,
    [QType], [QType])
-> Path slc' tr'
-> IO
     (DQNState
        dev
        (Adam
           '[Tensor dev QDType '[8, 1, 1, 1], Tensor dev QDType '[8],
             Tensor dev QDType '[8, 8, 13, 5], Tensor dev QDType '[8],
             Tensor dev QDType '[8, 13, 5], Tensor dev QDType '[8, 13, 5],
             Tensor dev QDType '[8, 2, 13, 5], Tensor dev QDType '[8],
             Tensor dev QDType '[8, 2, 13, 5], Tensor dev QDType '[8],
             Tensor dev QDType '[8, 1, 1, 1], Tensor dev QDType '[8],
             Tensor dev QDType '[8, 1, 1, 1], Tensor dev QDType '[8],
             Tensor dev QDType '[8], Tensor dev QDType '[8, 8, 13, 5],
             Tensor dev QDType '[8], Tensor dev QDType '[8, 8, 13, 5],
             Tensor dev QDType '[8], Tensor dev QDType '[8, 8, 13, 5],
             Tensor dev QDType '[8], Tensor dev QDType '[8, 8, 13, 5],
             Tensor dev QDType '[8], Tensor dev QDType '[8, 8, 13, 5],
             Tensor dev QDType '[8], Tensor dev QDType '[8, 8, 13, 5],
             Tensor dev QDType '[8], Tensor dev QDType '[8, 8, 13, 5],
             Tensor dev QDType '[8], Tensor dev QDType '[5],
             Tensor dev QDType '[5], Tensor dev QDType '[5],
             Tensor dev QDType '[8, 8, 13, 5], Tensor dev QDType '[8],
             Tensor dev QDType '[8, 8, 13, 5], Tensor dev QDType '[8],
             Tensor dev QDType '[8, 8, 13, 5], Tensor dev QDType '[8],
             Tensor dev QDType '[8, 8, 13, 5], Tensor dev QDType '[8],
             Tensor dev QDType '[8, 8, 13, 5], Tensor dev QDType '[8],
             Tensor dev QDType '[8, 8, 13, 5], Tensor dev QDType '[8],
             Tensor dev QDType '[8, 8, 13, 5], Tensor dev QDType '[8],
             Tensor dev QDType '[8, 8, 13, 5], Tensor dev QDType '[8],
             Tensor dev QDType '[8], Tensor dev QDType '[8],
             Tensor dev QDType '[1, 8], Tensor dev QDType '[1],
             Tensor dev QDType '[8, 8], Tensor dev QDType '[8],
             Tensor dev QDType '[8], Tensor dev QDType '[8],
             Tensor dev QDType '[1, 8], Tensor dev QDType '[1]])
        tr
        tr'
        slc
        s
        f
        h
        QType,
      [QType], [QType])
trainPiece Int
i) (DQNState
  dev
  (Adam
     '[Tensor dev QDType '[8, 1, 1, 1], Tensor dev QDType '[8],
       Tensor dev QDType '[8, 8, 13, 5], Tensor dev QDType '[8],
       Tensor dev QDType '[8, 13, 5], Tensor dev QDType '[8, 13, 5],
       Tensor dev QDType '[8, 2, 13, 5], Tensor dev QDType '[8],
       Tensor dev QDType '[8, 2, 13, 5], Tensor dev QDType '[8],
       Tensor dev QDType '[8, 1, 1, 1], Tensor dev QDType '[8],
       Tensor dev QDType '[8, 1, 1, 1], Tensor dev QDType '[8],
       Tensor dev QDType '[8], Tensor dev QDType '[8, 8, 13, 5],
       Tensor dev QDType '[8], Tensor dev QDType '[8, 8, 13, 5],
       Tensor dev QDType '[8], Tensor dev QDType '[8, 8, 13, 5],
       Tensor dev QDType '[8], Tensor dev QDType '[8, 8, 13, 5],
       Tensor dev QDType '[8], Tensor dev QDType '[8, 8, 13, 5],
       Tensor dev QDType '[8], Tensor dev QDType '[8, 8, 13, 5],
       Tensor dev QDType '[8], Tensor dev QDType '[8, 8, 13, 5],
       Tensor dev QDType '[8], Tensor dev QDType '[5],
       Tensor dev QDType '[5], Tensor dev QDType '[5],
       Tensor dev QDType '[8, 8, 13, 5], Tensor dev QDType '[8],
       Tensor dev QDType '[8, 8, 13, 5], Tensor dev QDType '[8],
       Tensor dev QDType '[8, 8, 13, 5], Tensor dev QDType '[8],
       Tensor dev QDType '[8, 8, 13, 5], Tensor dev QDType '[8],
       Tensor dev QDType '[8, 8, 13, 5], Tensor dev QDType '[8],
       Tensor dev QDType '[8, 8, 13, 5], Tensor dev QDType '[8],
       Tensor dev QDType '[8, 8, 13, 5], Tensor dev QDType '[8],
       Tensor dev QDType '[8, 8, 13, 5], Tensor dev QDType '[8],
       Tensor dev QDType '[8], Tensor dev QDType '[8],
       Tensor dev QDType '[1, 8], Tensor dev QDType '[1],
       Tensor dev QDType '[8, 8], Tensor dev QDType '[8],
       Tensor dev QDType '[8], Tensor dev QDType '[8],
       Tensor dev QDType '[1, 8], Tensor dev QDType '[1]])
  tr
  tr'
  slc
  s
  f
  h
  QType
state, [], []) [Path slc' tr']
pieces
    let meanRewards' = [QType] -> QType
forall (t :: * -> *). Foldable t => t QType -> QType
mean [QType]
rewards QType -> [QType] -> [QType]
forall a. a -> [a] -> [a]
: [QType]
meanRewards
        meanLosses' = [QType] -> QType
forall (t :: * -> *). Foldable t => t QType -> QType
mean [QType]
losses QType -> [QType] -> [QType]
forall a. a -> [a] -> [a]
: [QType]
meanLosses
    -- compute greedy reward ("accuracy")
    accuracies' <-
      if (i `mod` 10) == 0
        then do
          results <- mapM (runEpisode eval encode $ greedyPolicy (T.forward (pnet state'))) pieces
          case sequence results of
            Left String
error -> do
              String -> IO ()
putStrLn String
error
              [QType] -> IO [QType]
forall a. a -> IO a
forall (f :: * -> *) a. Applicative f => a -> f a
pure ([QType] -> IO [QType]) -> [QType] -> IO [QType]
forall a b. (a -> b) -> a -> b
$ (-QType
inf) QType -> [QType] -> [QType]
forall a. a -> [a] -> [a]
: [QType]
accuracies
            Right [([(GreedyState
      tr
      tr'
      slc
      (Leftmost (Split SPitch) (Freeze SPitch) (Spread SPitch)),
    Action slc tr (Split SPitch) (Freeze SPitch) (Spread SPitch),
    QEncoding dev '[],
    Maybe
      (GreedyState
         tr
         tr'
         slc
         (Leftmost (Split SPitch) (Freeze SPitch) (Spread SPitch)),
       [QEncoding dev '[]]),
    Maybe Bool)],
  Analysis s f h tr slc)]
episodes -> do
              let analyses :: [Analysis s f h tr slc]
analyses = (([(GreedyState
      tr
      tr'
      slc
      (Leftmost (Split SPitch) (Freeze SPitch) (Spread SPitch)),
    Action slc tr (Split SPitch) (Freeze SPitch) (Spread SPitch),
    QEncoding dev '[],
    Maybe
      (GreedyState
         tr
         tr'
         slc
         (Leftmost (Split SPitch) (Freeze SPitch) (Spread SPitch)),
       [QEncoding dev '[]]),
    Maybe Bool)],
  Analysis s f h tr slc)
 -> Analysis s f h tr slc)
-> [([(GreedyState
         tr
         tr'
         slc
         (Leftmost (Split SPitch) (Freeze SPitch) (Spread SPitch)),
       Action slc tr (Split SPitch) (Freeze SPitch) (Spread SPitch),
       QEncoding dev '[],
       Maybe
         (GreedyState
            tr
            tr'
            slc
            (Leftmost (Split SPitch) (Freeze SPitch) (Spread SPitch)),
          [QEncoding dev '[]]),
       Maybe Bool)],
     Analysis s f h tr slc)]
-> [Analysis s f h tr slc]
forall a b. (a -> b) -> [a] -> [b]
map ([(GreedyState
     tr
     tr'
     slc
     (Leftmost (Split SPitch) (Freeze SPitch) (Spread SPitch)),
   Action slc tr (Split SPitch) (Freeze SPitch) (Spread SPitch),
   QEncoding dev '[],
   Maybe
     (GreedyState
        tr
        tr'
        slc
        (Leftmost (Split SPitch) (Freeze SPitch) (Spread SPitch)),
      [QEncoding dev '[]]),
   Maybe Bool)],
 Analysis s f h tr slc)
-> Analysis s f h tr slc
forall a b. (a, b) -> b
snd [([(GreedyState
      tr
      tr'
      slc
      (Leftmost (Split SPitch) (Freeze SPitch) (Spread SPitch)),
    Action slc tr (Split SPitch) (Freeze SPitch) (Spread SPitch),
    QEncoding dev '[],
    Maybe
      (GreedyState
         tr
         tr'
         slc
         (Leftmost (Split SPitch) (Freeze SPitch) (Spread SPitch)),
       [QEncoding dev '[]]),
    Maybe Bool)],
  Analysis s f h tr slc)]
episodes
              accs <- (Analysis s f h tr slc -> IO QType)
-> [Analysis s f h tr slc] -> IO [QType]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
forall (m :: * -> *) a b. Monad m => (a -> m b) -> [a] -> m [b]
mapM Analysis s f h tr slc -> IO QType
reward [Analysis s f h tr slc]
analyses
              when ((i `mod` 100) == 0) $ do
                putStrLn "current best analyses:"
                forM_ (zip analyses [1 ..]) $ \(Analysis [Leftmost s f h]
deriv Path tr slc
_, Integer
i) -> do
                  (Leftmost s f h -> IO ()) -> [Leftmost s f h] -> IO ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
(a -> m b) -> t a -> m ()
mapM_ Leftmost s f h -> IO ()
forall a. Show a => a -> IO ()
print [Leftmost s f h]
deriv
                  String
-> [Leftmost (Split SPitch) (Freeze SPitch) (Spread SPitch)]
-> IO ()
forall (t :: * -> *).
Foldable t =>
String
-> t (Leftmost (Split SPitch) (Freeze SPitch) (Spread SPitch))
-> IO ()
plotDeriv (String
"rl/deriv" String -> String -> String
forall a. Semigroup a => a -> a -> a
<> Integer -> String
forall a. Show a => a -> String
show Integer
i String -> String -> String
forall a. Semigroup a => a -> a -> a
<> String
".tex") [Leftmost s f h]
[Leftmost (Split SPitch) (Freeze SPitch) (Spread SPitch)]
deriv
              pure $ mean accs : accuracies
        else pure accuracies
    -- logging
    when ((i `mod` 10) == 0) $ do
      putStrLn $ "epoch " <> show i
      let (ReplayBuffer _ bcontent) = buffer state'
      putStrLn $ "buffer size: " <> show (length bcontent)
      -- mapM_ print $ take 10 bcontent
      plotHistory "rewards" $ reverse meanRewards'
      plotHistory "losses" $ reverse meanLosses'
      plotHistory "accuracy" $ reverse accuracies'
    pure (state', meanRewards', meanLosses', accuracies')

-- Plotting
-- --------

-- | Write a one-line diagnostic to stdout announcing an exception.
--
-- The argument is appended verbatim after the fixed prefix
-- @"Found the Exception:"@; no separator is inserted between them.
hi :: String -> IO ()
hi = putStrLn . ("Found the Exception:" <>)