{-# LANGUAGE AllowAmbiguousTypes #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE DerivingStrategies #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE PartialTypeSignatures #-}
{-# OPTIONS_GHC -Wno-partial-type-signatures #-}
{-# HLINT ignore "Use <$>" #-}
{-# OPTIONS_GHC -Wno-unrecognised-pragmas #-}
module RL.DQN where
import Common
import Display (replayDerivation, viewGraph)
import GreedyParser (Action, ActionDouble (ActionDouble), ActionSingle (ActionSingle), GreedyState, getActions, initParseState, parseGreedy, parseStep, pickRandom)
import PVGrammar (Edge, Edges (Edges), Freeze (FreezeOp), Notes (Notes), PVAnalysis, PVLeftmost, Split, Spread)
import PVGrammar.Generate (derivationPlayerPV)
import PVGrammar.Parse (protoVoiceEvaluator)
import PVGrammar.Prob.Simple (PVParams, evalDoubleStep, evalSingleStep, observeDerivation, observeDerivation', observeDoubleStepParsing, observeSingleStepParsing, sampleDerivation', sampleDoubleStepParsing, sampleSingleStepParsing)
import RL.Callbacks
import RL.Encoding
import RL.Model
import RL.ModelTypes
import RL.Plotting
import RL.ReplayBuffer
import RL.TorchHelpers qualified as TH
import Control.Exception (Exception, catch, onException)
import Control.Monad (foldM, foldM_, forM_, replicateM, when)
import Control.Monad.Except qualified as ET
import Control.Monad.Primitive (RealWorld)
import Control.Monad.State qualified as ST
import Control.Monad.Trans (lift)
import Data.Foldable qualified as F
import Data.List.Extra qualified as E
import Data.Vector qualified as V
import Debug.Trace qualified as DT
import GHC.Float (double2Float)
import Inference.Conjugate (Hyper, HyperRep, Prior (expectedProbs), evalTraceLogP, printTrace, sampleProbs)
import Musicology.Pitch
import System.Random.MWC.Distributions (categorical)
import System.Random.MWC.Probability qualified as MWC
import System.Random.Stateful as Rand (StatefulGen, UniformRange (uniformRM), split)
import Torch qualified as T
import Torch.HList qualified as TT
import Torch.Lens qualified
import Torch.Typed qualified as TT
-- | Scalar constant @0.99@ as a zero-dimensional tensor on device @dev@.
--
-- NOTE(review): presumably the reward discount factor of the Q-learning
-- update; its use site is not visible in this part of the file.
gamma :: (TT.KnownDevice dev) => QTensor dev '[]
gamma = 0.99
-- | Scalar constant @0.1@.
--
-- NOTE(review): presumably the interpolation rate used when blending policy
-- network parameters into the target network; confirm at the use site.
tau :: QType
tau = 0.1
-- | Learning-rate schedule as a function of training progress.
--
-- Computes @0.1 * exp (progress * log 0.1)@, i.e. @0.1 * 0.1 ** progress@:
-- the rate is @0.1@ at @progress = 0@ and @0.01@ at @progress = 1@.
-- 'trainLoop' calls this with @fromIntegral i / fromIntegral n@.
learningRate :: (IsValidDevice dev) => Double -> TT.LearningRate dev QDType
learningRate progress = 0.1 * TT.exp (TT.mulScalar progress (TT.log 0.1))
-- | Size constant for the replay buffer (@1000@).
--
-- NOTE(review): the place where the buffer is created with this capacity is
-- not visible in this part of the file.
bufferSize :: Int
bufferSize = 1_000
-- | Number of replay steps requested from the buffer per optimisation step
-- (passed to 'sampleSteps' in 'trainLoop').
replayN :: Int
replayN = 200
-- | First endpoint handed to 'expSchedule' by 'eps' (@0.9@).
epsStart :: QType
epsStart = 0.9
-- | Second endpoint handed to 'expSchedule' by 'eps' (@0.2@).
epsEnd :: QType
epsEnd = 0.2
-- | Exploration rate for iteration @i@ out of @n@.
--
-- Delegates to 'expSchedule' with 'epsStart' and 'epsEnd' as endpoints.
-- Note the argument order: the total @n@ is passed before the index @i@.
--
-- NOTE(review): presumably an exponential decay from 'epsStart' to 'epsEnd';
-- the definition of 'expSchedule' is not visible here.
eps
  :: Int
  -- ^ current iteration @i@
  -> Int
  -- ^ total number of iterations @n@
  -> QType
eps i n = expSchedule epsStart epsEnd (fromIntegral n) (fromIntegral i)
-- | Mutable-in-spirit training state threaded through 'trainLoop'.
--
-- All fields are strict. The last type parameter @r@ is phantom: it does not
-- occur in any field.
data DQNState dev opt tr tr' slc s f h r = DQNState
  { pnet :: !(QModel dev)
  -- ^ network used to score actions during an episode and updated by the
  -- optimiser step (presumably the "policy" network).
  , tnet :: !(QModel dev)
  -- ^ second network of the same architecture (presumably the "target"
  -- network).
  , opt :: !opt
  -- ^ optimiser state.
  , buffer :: !(ReplayBuffer dev tr tr' slc s f h)
  -- ^ replay buffer of recorded steps.
  }
-- | Pick the index of the candidate with the highest Q-value.
--
-- Applies @q@ to every candidate, concatenates the resulting one-element
-- tensors along dimension 0 and returns the argmax index.
--
-- NOTE(review): 'T.cat' is applied to the raw list, so an empty candidate
-- list is not handled here; callers are expected to pass at least one.
greedyPolicy
  :: (Applicative m)
  => (embedding -> QTensor dev '[1])
  -- ^ Q-function scoring a single candidate
  -> [embedding]
  -- ^ candidate encodings
  -> m Int
greedyPolicy q actions =
  pure $
    T.asValue $
      T.argmax (T.Dim 0) T.RemoveDim $
        T.cat (T.Dim 0) (TT.toDynamic . q <$> actions)
-- | Wrap a policy with epsilon-style random exploration.
--
-- Draws a coin with @uniformRM (0, 1)@. If the coin is at least @epsilon@
-- the wrapped policy decides; otherwise an index is drawn with
-- @uniformRM (0, length actions - 1)@.
--
-- NOTE(review): the candidate list is assumed to be non-empty; with an empty
-- list the random branch would be asked for a range of @(0, -1)@.
epsilonic
  :: (StatefulGen gen m)
  => gen
  -- ^ random generator
  -> QType
  -- ^ exploration probability @epsilon@
  -> ([embedding] -> m Int)
  -- ^ policy used when not exploring
  -> [embedding]
  -- ^ candidate encodings
  -> m Int
epsilonic gen epsilon policy actions = do
  coin <- uniformRM (0, 1) gen
  if coin >= epsilon
    then policy actions
    else uniformRM (0, length actions - 1) gen
-- | Sample a candidate index from the softmax of the Q-values.
--
-- Applies @q@ to every candidate, concatenates the one-element results along
-- dimension 0, takes the softmax over that dimension, converts the result to
-- a vector of doubles and samples an index with 'categorical'.
--
-- NOTE(review): like 'greedyPolicy', this does not guard against an empty
-- candidate list.
softmaxPolicy
  :: (StatefulGen gen m)
  => gen
  -- ^ random generator
  -> (embedding -> QTensor dev '[1])
  -- ^ Q-function scoring a single candidate
  -> [embedding]
  -- ^ candidate encodings
  -> m Int
softmaxPolicy gen q actions = do
  let qValues = T.cat (T.Dim 0) $ TT.toDynamic . q <$> actions
      probs = T.softmax (T.Dim 0) qValues
  categorical (V.fromList $ T.asValue $ T.toDType T.Double probs) gen
-- | Run one complete parsing episode and record every decision taken.
--
-- Starting from @initParseState eval input@, 'parseStep' is applied
-- repeatedly. Whenever 'parseStep' asks for a decision, the candidate actions
-- are encoded with @encode@ and @policyF@ chooses an index into them.
--
-- On success the result contains
--
-- * the list of recorded steps, /most recent first/ (steps are consed onto
--   the accumulator), each being
--   @(state, chosen action, encoding of the chosen action,
--   Maybe (next state, encodings of the candidates in the next state),
--   goLeft flag)@; the final step has 'Nothing' as its successor;
-- * the resulting 'Analysis' (derivation plus the final top transition).
--
-- On failure a 'Left' with an error message is returned, either from
-- 'parseStep', or because no action was available, or because @policyF@
-- returned an index outside the candidate list.
runEpisode
  :: forall dev tr tr' slc slc' s f h gen state action encoding step
   . ( state ~ GreedyState tr tr' slc (Leftmost s f h)
     , action ~ Action slc tr s f h
     , encoding ~ QEncoding dev '[]
     , step ~ (state, action, encoding, Maybe (state, [encoding]), Maybe Bool)
     )
  => Eval tr tr' slc slc' h (Leftmost s f h)
  -> (state -> action -> encoding)
  -> ([encoding] -> IO Int)
  -> Path slc' tr'
  -> IO
      ( Either
          String
          ([step], Analysis s f h tr slc)
      )
runEpisode !eval !encode !policyF !input =
  ST.evalStateT
    (ET.runExceptT $ go [] Nothing $ initParseState eval input)
    (Nothing, [])
 where
  -- Main loop. @transitions@ is the accumulator of finished steps, @prev@ is
  -- the previous state together with the action chosen there (its step can
  -- only be completed once the successor state and its candidates are known),
  -- and @state@ is the state to advance from.
  --
  -- The underlying 'ST.StateT' is a side channel out of @policy@: it carries
  -- the chosen (action, encoding) pair and the encodings of all candidates.
  go transitions prev state = do
    -- Reset the side channel so stale values from the previous iteration
    -- are never read back if @policy@ is not consulted in this step.
    ST.put (Nothing, [])
    result <- parseStep eval policy state
    (actionAndEncoding, actions) <- ST.get
    -- Complete the previous step now that its successor is known.
    let transitions' = case prev of
          Nothing -> transitions
          Just (prevState, prevAction, goLeft) ->
            addStep prevState prevAction (Just (state, actions)) goLeft transitions
    -- Flag derived from the *previous* action: @Just True@ if it was a
    -- double action with a freeze-left or split-left operation, @Just False@
    -- for any other double action, 'Nothing' otherwise.
    let goLeft' = case prev of
          Just (_, Just (Right (ActionDouble _ op), _), _) -> case op of
            LMDoubleFreezeLeft _ -> Just True
            LMDoubleSplitLeft _ -> Just True
            _ -> Just False
          _ -> Nothing
    case result of
      -- Parsing finished: record the last step (no successor) and stop.
      Right (top, deriv) ->
        pure
          ( addStep state actionAndEncoding Nothing goLeft' transitions'
          , Analysis deriv $ PathEnd top
          )
      -- More to parse: continue from the new state.
      Left state' ->
        go transitions' (Just (state, actionAndEncoding, goLeft')) state'
   where
    -- Prepend a step to the accumulator; nothing is recorded if no action
    -- was chosen in that state.
    addStep
      :: state
      -> Maybe (action, encoding)
      -> Maybe (state, [encoding])
      -> Maybe Bool
      -> [(state, action, encoding, Maybe (state, [encoding]), Maybe Bool)]
      -> [(state, action, encoding, Maybe (state, [encoding]), Maybe Bool)]
    addStep _ Nothing _next _goLeft ts = ts
    addStep st (Just (act, actEnc)) next goLeft ts =
      (st, act, actEnc, next, goLeft) : ts

    -- Decision callback handed to 'parseStep': encodes all candidates
    -- relative to the current @state@, lets @policyF@ pick an index, and
    -- publishes the choice and all candidate encodings via the state monad.
    policy :: [action] -> ET.ExceptT String (ST.StateT (Maybe (action, encoding), [encoding]) IO) action
    policy [] = ET.throwError "no actions to select from"
    policy actions = do
      let encodings = encode state <$> actions
      actionIndex <- lift $ lift $ policyF encodings
      -- Bounds-checked lookup: an index outside the candidate list is
      -- reported as an error instead of crashing in a partial @!!@.
      case drop actionIndex (zip actions encodings) of
        (chosen@(action, _) : _) | actionIndex >= 0 -> do
          ST.put (Just chosen, encodings)
          pure action
        _ ->
          ET.throwError $
            "policy returned out-of-range action index "
              <> show actionIndex
              <> " for "
              <> show (length actions)
              <> " actions"
trainLoop
:: forall dev tr tr' slc slc' s f h gen opt
.
(_)
=> gen
-> Eval tr tr' slc slc' h (Leftmost s f h)
-> (GreedyState tr tr' slc (Leftmost s f h) -> Action slc tr s f h -> QEncoding dev '[])
-> (Analysis s f h tr slc -> IO QType)
-> (Action slc tr s f h -> Maybe Bool -> IO QType)
-> Path slc' tr'
-> DQNState dev opt tr tr' slc s f h QType
-> Int
-> Int
-> IO (DQNState dev opt tr tr' slc s f h QType, QType, QType)
trainLoop :: gen
-> Eval tr tr' slc slc' h (Leftmost s f h)
-> (GreedyState tr tr' slc (Leftmost s f h)
-> Action slc tr s f h -> QEncoding dev '[])
-> (Analysis s f h tr slc -> IO QType)
-> (Action slc tr s f h -> Maybe Bool -> IO QType)
-> Path slc' tr'
-> DQNState dev opt tr tr' slc s f h QType
-> Int
-> Int
-> IO (DQNState dev opt tr tr' slc s f h QType, QType, QType)
trainLoop !gen
gen !Eval tr tr' slc slc' h (Leftmost s f h)
eval !GreedyState tr tr' slc (Leftmost s f h)
-> Action slc tr s f h -> QEncoding dev '[]
encode !Analysis s f h tr slc -> IO QType
reward !Action slc tr s f h -> Maybe Bool -> IO QType
rewardStep !Path slc' tr'
piece oldstate :: DQNState dev opt tr tr' slc s f h QType
oldstate@(DQNState !QModel dev
pnet !QModel dev
tnet !opt
opt !ReplayBuffer dev tr tr' slc s f h
buffer) Int
i Int
n = do
let policy :: (QEncoding dev '[] -> QTensor dev '[1])
-> [QEncoding dev '[]] -> IO Int
policy QEncoding dev '[] -> QTensor dev '[1]
q = gen
-> QType
-> ([QEncoding dev '[]] -> IO Int)
-> [QEncoding dev '[]]
-> IO Int
forall gen (m :: * -> *) embedding.
StatefulGen gen m =>
gen -> QType -> ([embedding] -> m Int) -> [embedding] -> m Int
epsilonic gen
gen (Int -> Int -> QType
eps Int
i Int
n) (([QEncoding dev '[]] -> IO Int) -> [QEncoding dev '[]] -> IO Int)
-> ([QEncoding dev '[]] -> IO Int) -> [QEncoding dev '[]] -> IO Int
forall a b. (a -> b) -> a -> b
$ gen
-> (QEncoding dev '[] -> QTensor dev '[1])
-> [QEncoding dev '[]]
-> IO Int
forall gen (m :: * -> *) embedding (dev :: (DeviceType, Nat)).
StatefulGen gen m =>
gen -> (embedding -> QTensor dev '[1]) -> [embedding] -> m Int
softmaxPolicy gen
gen QEncoding dev '[] -> QTensor dev '[1]
q
result <- Eval tr tr' slc slc' h (Leftmost s f h)
-> (GreedyState tr tr' slc (Leftmost s f h)
-> Action slc tr s f h -> QEncoding dev '[])
-> ([QEncoding dev '[]] -> IO Int)
-> Path slc' tr'
-> IO
(Either
String
([(GreedyState tr tr' slc (Leftmost s f h), Action slc tr s f h,
QEncoding dev '[],
Maybe
(GreedyState tr tr' slc (Leftmost s f h), [QEncoding dev '[]]),
Maybe Bool)],
Analysis s f h tr slc))
forall {k} (dev :: (DeviceType, Nat)) tr tr' slc slc' s f h
(gen :: k) state action encoding step.
(state ~ GreedyState tr tr' slc (Leftmost s f h),
action ~ Action slc tr s f h, encoding ~ QEncoding dev '[],
step
~ (state, action, encoding, Maybe (state, [encoding]),
Maybe Bool)) =>
Eval tr tr' slc slc' h (Leftmost s f h)
-> (state -> action -> encoding)
-> ([encoding] -> IO Int)
-> Path slc' tr'
-> IO (Either String ([step], Analysis s f h tr slc))
runEpisode Eval tr tr' slc slc' h (Leftmost s f h)
eval GreedyState tr tr' slc (Leftmost s f h)
-> Action slc tr s f h -> QEncoding dev '[]
encode ((QEncoding dev '[] -> QTensor dev '[1])
-> [QEncoding dev '[]] -> IO Int
policy ((QEncoding dev '[] -> QTensor dev '[1])
-> [QEncoding dev '[]] -> IO Int)
-> (QEncoding dev '[] -> QTensor dev '[1])
-> [QEncoding dev '[]]
-> IO Int
forall a b. (a -> b) -> a -> b
$ QModel dev -> QEncoding dev '[] -> QTensor dev '[1]
forall f a b. HasForward f a b => f -> a -> b
T.forward QModel dev
pnet) Path slc' tr'
piece
case result of
Left String
error -> do
String -> IO ()
forall a. Show a => a -> IO ()
print String
error
(DQNState dev opt tr tr' slc s f h QType, QType, QType)
-> IO (DQNState dev opt tr tr' slc s f h QType, QType, QType)
forall a. a -> IO a
forall (f :: * -> *) a. Applicative f => a -> f a
pure (DQNState dev opt tr tr' slc s f h QType
oldstate, QType
0, QType
0)
Right ([(GreedyState tr tr' slc (Leftmost s f h), Action slc tr s f h,
QEncoding dev '[],
Maybe
(GreedyState tr tr' slc (Leftmost s f h), [QEncoding dev '[]]),
Maybe Bool)]
steps, Analysis s f h tr slc
analysis) -> do
(steps', r) <- [(GreedyState tr tr' slc (Leftmost s f h), Action slc tr s f h,
QEncoding dev '[],
Maybe
(GreedyState tr tr' slc (Leftmost s f h), [QEncoding dev '[]]),
Maybe Bool)]
-> Analysis s f h tr slc
-> IO ([ReplayStep dev tr tr' slc s f h], QType)
rewardEpisode [(GreedyState tr tr' slc (Leftmost s f h), Action slc tr s f h,
QEncoding dev '[],
Maybe
(GreedyState tr tr' slc (Leftmost s f h), [QEncoding dev '[]]),
Maybe Bool)]
steps Analysis s f h tr slc
analysis
let buffer' = (ReplayBuffer dev tr tr' slc s f h
-> ReplayStep dev tr tr' slc s f h
-> ReplayBuffer dev tr tr' slc s f h)
-> ReplayBuffer dev tr tr' slc s f h
-> [ReplayStep dev tr tr' slc s f h]
-> ReplayBuffer dev tr tr' slc s f h
forall b a. (b -> a -> b) -> b -> [a] -> b
forall (t :: * -> *) b a.
Foldable t =>
(b -> a -> b) -> b -> t a -> b
F.foldl' ReplayBuffer dev tr tr' slc s f h
-> ReplayStep dev tr tr' slc s f h
-> ReplayBuffer dev tr tr' slc s f h
forall (dev :: (DeviceType, Nat)) tr tr' slc s f h.
ReplayBuffer dev tr tr' slc s f h
-> ReplayStep dev tr tr' slc s f h
-> ReplayBuffer dev tr tr' slc s f h
pushStep ReplayBuffer dev tr tr' slc s f h
buffer [ReplayStep dev tr tr' slc s f h]
steps'
(pnet', tnet', opt', loss) <- optimizeModels buffer'
pure (DQNState pnet' tnet' opt' buffer', r, loss)
where
mkReplay :: QType
-> (GreedyState tr tr' slc (Leftmost s f h), Action slc tr s f h,
QEncoding dev '[],
Maybe
(GreedyState tr tr' slc (Leftmost s f h), [QEncoding dev '[]]),
e)
-> ReplayStep dev tr tr' slc s f h
mkReplay QType
r (!GreedyState tr tr' slc (Leftmost s f h)
state, !Action slc tr s f h
action, !QEncoding dev '[]
actEnc, !Maybe
(GreedyState tr tr' slc (Leftmost s f h), [QEncoding dev '[]])
next, e
goLeft) =
RPState tr tr' slc s f h
-> RPAction slc tr s f h
-> QEncoding dev '[]
-> Maybe (RPState tr tr' slc s f h)
-> [QEncoding dev '[]]
-> QType
-> ReplayStep dev tr tr' slc s f h
forall (dev :: (DeviceType, Nat)) tr tr' slc s f h.
RPState tr tr' slc s f h
-> RPAction slc tr s f h
-> QEncoding dev '[]
-> Maybe (RPState tr tr' slc s f h)
-> [QEncoding dev '[]]
-> QType
-> ReplayStep dev tr tr' slc s f h
ReplayStep (GreedyState tr tr' slc (Leftmost s f h) -> RPState tr tr' slc s f h
forall tr tr' slc s f h.
GreedyState tr tr' slc (Leftmost s f h) -> RPState tr tr' slc s f h
RPState GreedyState tr tr' slc (Leftmost s f h)
state) (Action slc tr s f h -> RPAction slc tr s f h
forall slc tr s f h. Action slc tr s f h -> RPAction slc tr s f h
RPAction Action slc tr s f h
action) QEncoding dev '[]
actEnc Maybe (RPState tr tr' slc s f h)
state' [QEncoding dev '[]]
steps' QType
r
where
(!Maybe (RPState tr tr' slc s f h)
state', ![QEncoding dev '[]]
steps') = case Maybe
(GreedyState tr tr' slc (Leftmost s f h), [QEncoding dev '[]])
next of
Maybe
(GreedyState tr tr' slc (Leftmost s f h), [QEncoding dev '[]])
Nothing -> (Maybe (RPState tr tr' slc s f h)
forall a. Maybe a
Nothing, [])
Just (GreedyState tr tr' slc (Leftmost s f h)
s, [QEncoding dev '[]]
acts) -> (RPState tr tr' slc s f h -> Maybe (RPState tr tr' slc s f h)
forall a. a -> Maybe a
Just (RPState tr tr' slc s f h -> Maybe (RPState tr tr' slc s f h))
-> RPState tr tr' slc s f h -> Maybe (RPState tr tr' slc s f h)
forall a b. (a -> b) -> a -> b
$ GreedyState tr tr' slc (Leftmost s f h) -> RPState tr tr' slc s f h
forall tr tr' slc s f h.
GreedyState tr tr' slc (Leftmost s f h) -> RPState tr tr' slc s f h
RPState GreedyState tr tr' slc (Leftmost s f h)
s, [QEncoding dev '[]]
acts)
rewardEpisode :: [(GreedyState tr tr' slc (Leftmost s f h), Action slc tr s f h,
QEncoding dev '[],
Maybe
(GreedyState tr tr' slc (Leftmost s f h), [QEncoding dev '[]]),
Maybe Bool)]
-> Analysis s f h tr slc
-> IO ([ReplayStep dev tr tr' slc s f h], QType)
rewardEpisode [(GreedyState tr tr' slc (Leftmost s f h), Action slc tr s f h,
QEncoding dev '[],
Maybe
(GreedyState tr tr' slc (Leftmost s f h), [QEncoding dev '[]]),
Maybe Bool)]
steps Analysis s f h tr slc
analysis = do
r <- Analysis s f h tr slc -> IO QType
reward Analysis s f h tr slc
analysis
let steps' = case [(GreedyState tr tr' slc (Leftmost s f h), Action slc tr s f h,
QEncoding dev '[],
Maybe
(GreedyState tr tr' slc (Leftmost s f h), [QEncoding dev '[]]),
Maybe Bool)]
steps of
[] -> []
(GreedyState tr tr' slc (Leftmost s f h), Action slc tr s f h,
QEncoding dev '[],
Maybe
(GreedyState tr tr' slc (Leftmost s f h), [QEncoding dev '[]]),
Maybe Bool)
last : [(GreedyState tr tr' slc (Leftmost s f h), Action slc tr s f h,
QEncoding dev '[],
Maybe
(GreedyState tr tr' slc (Leftmost s f h), [QEncoding dev '[]]),
Maybe Bool)]
rest -> QType
-> (GreedyState tr tr' slc (Leftmost s f h), Action slc tr s f h,
QEncoding dev '[],
Maybe
(GreedyState tr tr' slc (Leftmost s f h), [QEncoding dev '[]]),
Maybe Bool)
-> ReplayStep dev tr tr' slc s f h
forall {tr} {tr'} {slc} {s} {f} {h} {dev :: (DeviceType, Nat)} {e}.
QType
-> (GreedyState tr tr' slc (Leftmost s f h), Action slc tr s f h,
QEncoding dev '[],
Maybe
(GreedyState tr tr' slc (Leftmost s f h), [QEncoding dev '[]]),
e)
-> ReplayStep dev tr tr' slc s f h
mkReplay QType
r (GreedyState tr tr' slc (Leftmost s f h), Action slc tr s f h,
QEncoding dev '[],
Maybe
(GreedyState tr tr' slc (Leftmost s f h), [QEncoding dev '[]]),
Maybe Bool)
last ReplayStep dev tr tr' slc s f h
-> [ReplayStep dev tr tr' slc s f h]
-> [ReplayStep dev tr tr' slc s f h]
forall a. a -> [a] -> [a]
: ((GreedyState tr tr' slc (Leftmost s f h), Action slc tr s f h,
QEncoding dev '[],
Maybe
(GreedyState tr tr' slc (Leftmost s f h), [QEncoding dev '[]]),
Maybe Bool)
-> ReplayStep dev tr tr' slc s f h)
-> [(GreedyState tr tr' slc (Leftmost s f h), Action slc tr s f h,
QEncoding dev '[],
Maybe
(GreedyState tr tr' slc (Leftmost s f h), [QEncoding dev '[]]),
Maybe Bool)]
-> [ReplayStep dev tr tr' slc s f h]
forall a b. (a -> b) -> [a] -> [b]
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
fmap (QType
-> (GreedyState tr tr' slc (Leftmost s f h), Action slc tr s f h,
QEncoding dev '[],
Maybe
(GreedyState tr tr' slc (Leftmost s f h), [QEncoding dev '[]]),
Maybe Bool)
-> ReplayStep dev tr tr' slc s f h
forall {tr} {tr'} {slc} {s} {f} {h} {dev :: (DeviceType, Nat)} {e}.
QType
-> (GreedyState tr tr' slc (Leftmost s f h), Action slc tr s f h,
QEncoding dev '[],
Maybe
(GreedyState tr tr' slc (Leftmost s f h), [QEncoding dev '[]]),
e)
-> ReplayStep dev tr tr' slc s f h
mkReplay QType
0) [(GreedyState tr tr' slc (Leftmost s f h), Action slc tr s f h,
QEncoding dev '[],
Maybe
(GreedyState tr tr' slc (Leftmost s f h), [QEncoding dev '[]]),
Maybe Bool)]
rest
pure (steps', r)
optimizeModels :: ReplayBuffer dev tr tr' slc s f h
-> IO (QModel dev, QModel dev, opt, QType)
optimizeModels ReplayBuffer dev tr tr' slc s f h
buffer' = do
batch <- ReplayBuffer dev tr tr' slc s f h
-> Int -> IO [ReplayStep dev tr tr' slc s f h]
forall (dev :: (DeviceType, Nat)) tr tr' slc s f h.
ReplayBuffer dev tr tr' slc s f h
-> Int -> IO [ReplayStep dev tr tr' slc s f h]
sampleSteps ReplayBuffer dev tr tr' slc s f h
buffer' Int
replayN
let (qsNow, qsExpected) = unzip (dqnValues <$> batch)
expectedDetached <- T.detach $ T.stack (T.Dim 0) qsExpected
let !loss =
Reduction -> Tensor -> Tensor -> Tensor
T.smoothL1Loss
Reduction
T.ReduceMean
(Dim -> [Tensor] -> Tensor
T.stack (Int -> Dim
T.Dim Int
0) [Tensor]
qsNow)
Tensor
expectedDetached
!lossWithFake = Tensor -> Tensor device QDType '[]
forall (device :: (DeviceType, Nat)) (dtype :: DType)
(shape :: [Nat]).
Tensor -> Tensor device dtype shape
TT.UnsafeMkTensor (Tensor -> Tensor device QDType '[])
-> Tensor -> Tensor device QDType '[]
forall a b. (a -> b) -> a -> b
$ Tensor
loss Tensor -> Tensor -> Tensor
forall a. Num a => a -> a -> a
+ QTensor dev '[] -> Tensor
forall t. Unnamed t => t -> Tensor
TT.toDynamic (QModel dev -> QTensor dev '[]
forall (dev :: (DeviceType, Nat)) (ps :: [*]).
(IsValidDevice dev, ps ~ Parameters (QModel dev)) =>
QModel dev -> QTensor dev '[]
fakeLoss QModel dev
pnet)
putStr $ "loss: " <> show (T.asValue @QType $ TT.toDynamic lossWithFake)
putStrLn $ "\tavgq: " <> show (T.asValue @QType $ T.mean $ T.stack (T.Dim 0) qsNow)
let lr = QType -> Tensor device QDType '[]
forall (dev :: (DeviceType, Nat)).
IsValidDevice dev =>
QType -> LearningRate dev QDType
learningRate (QType -> Tensor device QDType '[])
-> QType -> Tensor device QDType '[]
forall a b. (a -> b) -> a -> b
$ Int -> QType
forall a b. (Integral a, Num b) => a -> b
fromIntegral Int
i QType -> QType -> QType
forall a. Fractional a => a -> a -> a
/ Int -> QType
forall a b. (Integral a, Num b) => a -> b
fromIntegral Int
n
(pnet', opt') <- TT.runStep pnet opt lossWithFake lr
tparams <- TT.hmapM' TH.Detach $ TT.hmap' TT.ToDependent $ TT.flattenParameters tnet
pparams <- TT.hmapM' TH.Detach $ TT.hmap' TT.ToDependent $ TT.flattenParameters pnet'
let tparams' = Interpolate QType
-> HList
'[Tensor dev QDType '[8, 1, 1, 1], Tensor dev QDType '[8],
Tensor dev QDType '[8, 8, 13, 5], Tensor dev QDType '[8],
Tensor dev QDType '[8, 13, 5], Tensor dev QDType '[8, 13, 5],
Tensor dev QDType '[8, 2, 13, 5], Tensor dev QDType '[8],
Tensor dev QDType '[8, 2, 13, 5], Tensor dev QDType '[8],
Tensor dev QDType '[8, 1, 1, 1], Tensor dev QDType '[8],
Tensor dev QDType '[8, 1, 1, 1], Tensor dev QDType '[8],
Tensor dev QDType '[8], Tensor dev QDType '[8, 8, 13, 5],
Tensor dev QDType '[8], Tensor dev QDType '[8, 8, 13, 5],
Tensor dev QDType '[8], Tensor dev QDType '[8, 8, 13, 5],
Tensor dev QDType '[8], Tensor dev QDType '[8, 8, 13, 5],
Tensor dev QDType '[8], Tensor dev QDType '[8, 8, 13, 5],
Tensor dev QDType '[8], Tensor dev QDType '[8, 8, 13, 5],
Tensor dev QDType '[8], Tensor dev QDType '[8, 8, 13, 5],
Tensor dev QDType '[8], Tensor dev QDType '[5],
Tensor dev QDType '[5], Tensor dev QDType '[5],
Tensor dev QDType '[8, 8, 13, 5], Tensor dev QDType '[8],
Tensor dev QDType '[8, 8, 13, 5], Tensor dev QDType '[8],
Tensor dev QDType '[8, 8, 13, 5], Tensor dev QDType '[8],
Tensor dev QDType '[8, 8, 13, 5], Tensor dev QDType '[8],
Tensor dev QDType '[8, 8, 13, 5], Tensor dev QDType '[8],
Tensor dev QDType '[8, 8, 13, 5], Tensor dev QDType '[8],
Tensor dev QDType '[8, 8, 13, 5], Tensor dev QDType '[8],
Tensor dev QDType '[8, 8, 13, 5], Tensor dev QDType '[8],
Tensor dev QDType '[8], Tensor dev QDType '[8],
Tensor dev QDType '[1, 8], QTensor dev '[1],
Tensor dev QDType '[8, 8], Tensor dev QDType '[8],
Tensor dev QDType '[8], Tensor dev QDType '[8],
Tensor dev QDType '[1, 8], QTensor dev '[1]]
-> HList
'[Tensor dev QDType '[8, 1, 1, 1], Tensor dev QDType '[8],
Tensor dev QDType '[8, 8, 13, 5], Tensor dev QDType '[8],
Tensor dev QDType '[8, 13, 5], Tensor dev QDType '[8, 13, 5],
Tensor dev QDType '[8, 2, 13, 5], Tensor dev QDType '[8],
Tensor dev QDType '[8, 2, 13, 5], Tensor dev QDType '[8],
Tensor dev QDType '[8, 1, 1, 1], Tensor dev QDType '[8],
Tensor dev QDType '[8, 1, 1, 1], Tensor dev QDType '[8],
Tensor dev QDType '[8], Tensor dev QDType '[8, 8, 13, 5],
Tensor dev QDType '[8], Tensor dev QDType '[8, 8, 13, 5],
Tensor dev QDType '[8], Tensor dev QDType '[8, 8, 13, 5],
Tensor dev QDType '[8], Tensor dev QDType '[8, 8, 13, 5],
Tensor dev QDType '[8], Tensor dev QDType '[8, 8, 13, 5],
Tensor dev QDType '[8], Tensor dev QDType '[8, 8, 13, 5],
Tensor dev QDType '[8], Tensor dev QDType '[8, 8, 13, 5],
Tensor dev QDType '[8], Tensor dev QDType '[5],
Tensor dev QDType '[5], Tensor dev QDType '[5],
Tensor dev QDType '[8, 8, 13, 5], Tensor dev QDType '[8],
Tensor dev QDType '[8, 8, 13, 5], Tensor dev QDType '[8],
Tensor dev QDType '[8, 8, 13, 5], Tensor dev QDType '[8],
Tensor dev QDType '[8, 8, 13, 5], Tensor dev QDType '[8],
Tensor dev QDType '[8, 8, 13, 5], Tensor dev QDType '[8],
Tensor dev QDType '[8, 8, 13, 5], Tensor dev QDType '[8],
Tensor dev QDType '[8, 8, 13, 5], Tensor dev QDType '[8],
Tensor dev QDType '[8, 8, 13, 5], Tensor dev QDType '[8],
Tensor dev QDType '[8], Tensor dev QDType '[8],
Tensor dev QDType '[1, 8], QTensor dev '[1],
Tensor dev QDType '[8, 8], Tensor dev QDType '[8],
Tensor dev QDType '[8], Tensor dev QDType '[8],
Tensor dev QDType '[1, 8], QTensor dev '[1]]
-> HList
'[Tensor dev QDType '[8, 1, 1, 1], Tensor dev QDType '[8],
Tensor dev QDType '[8, 8, 13, 5], Tensor dev QDType '[8],
Tensor dev QDType '[8, 13, 5], Tensor dev QDType '[8, 13, 5],
Tensor dev QDType '[8, 2, 13, 5], Tensor dev QDType '[8],
Tensor dev QDType '[8, 2, 13, 5], Tensor dev QDType '[8],
Tensor dev QDType '[8, 1, 1, 1], Tensor dev QDType '[8],
Tensor dev QDType '[8, 1, 1, 1], Tensor dev QDType '[8],
Tensor dev QDType '[8], Tensor dev QDType '[8, 8, 13, 5],
Tensor dev QDType '[8], Tensor dev QDType '[8, 8, 13, 5],
Tensor dev QDType '[8], Tensor dev QDType '[8, 8, 13, 5],
Tensor dev QDType '[8], Tensor dev QDType '[8, 8, 13, 5],
Tensor dev QDType '[8], Tensor dev QDType '[8, 8, 13, 5],
Tensor dev QDType '[8], Tensor dev QDType '[8, 8, 13, 5],
Tensor dev QDType '[8], Tensor dev QDType '[8, 8, 13, 5],
Tensor dev QDType '[8], Tensor dev QDType '[5],
Tensor dev QDType '[5], Tensor dev QDType '[5],
Tensor dev QDType '[8, 8, 13, 5], Tensor dev QDType '[8],
Tensor dev QDType '[8, 8, 13, 5], Tensor dev QDType '[8],
Tensor dev QDType '[8, 8, 13, 5], Tensor dev QDType '[8],
Tensor dev QDType '[8, 8, 13, 5], Tensor dev QDType '[8],
Tensor dev QDType '[8, 8, 13, 5], Tensor dev QDType '[8],
Tensor dev QDType '[8, 8, 13, 5], Tensor dev QDType '[8],
Tensor dev QDType '[8, 8, 13, 5], Tensor dev QDType '[8],
Tensor dev QDType '[8, 8, 13, 5], Tensor dev QDType '[8],
Tensor dev QDType '[8], Tensor dev QDType '[8],
Tensor dev QDType '[1, 8], QTensor dev '[1],
Tensor dev QDType '[8, 8], Tensor dev QDType '[8],
Tensor dev QDType '[8], Tensor dev QDType '[8],
Tensor dev QDType '[1, 8], QTensor dev '[1]]
forall k f (xs :: [k]) (ys :: [k]) (zs :: [k]).
HZipWith f xs ys zs =>
f -> HList xs -> HList ys -> HList zs
TT.hzipWith (QType -> Interpolate QType
forall num. num -> Interpolate num
TH.Interpolate QType
tau) HList
'[Tensor dev QDType '[8, 1, 1, 1], Tensor dev QDType '[8],
Tensor dev QDType '[8, 8, 13, 5], Tensor dev QDType '[8],
Tensor dev QDType '[8, 13, 5], Tensor dev QDType '[8, 13, 5],
Tensor dev QDType '[8, 2, 13, 5], Tensor dev QDType '[8],
Tensor dev QDType '[8, 2, 13, 5], Tensor dev QDType '[8],
Tensor dev QDType '[8, 1, 1, 1], Tensor dev QDType '[8],
Tensor dev QDType '[8, 1, 1, 1], Tensor dev QDType '[8],
Tensor dev QDType '[8], Tensor dev QDType '[8, 8, 13, 5],
Tensor dev QDType '[8], Tensor dev QDType '[8, 8, 13, 5],
Tensor dev QDType '[8], Tensor dev QDType '[8, 8, 13, 5],
Tensor dev QDType '[8], Tensor dev QDType '[8, 8, 13, 5],
Tensor dev QDType '[8], Tensor dev QDType '[8, 8, 13, 5],
Tensor dev QDType '[8], Tensor dev QDType '[8, 8, 13, 5],
Tensor dev QDType '[8], Tensor dev QDType '[8, 8, 13, 5],
Tensor dev QDType '[8], Tensor dev QDType '[5],
Tensor dev QDType '[5], Tensor dev QDType '[5],
Tensor dev QDType '[8, 8, 13, 5], Tensor dev QDType '[8],
Tensor dev QDType '[8, 8, 13, 5], Tensor dev QDType '[8],
Tensor dev QDType '[8, 8, 13, 5], Tensor dev QDType '[8],
Tensor dev QDType '[8, 8, 13, 5], Tensor dev QDType '[8],
Tensor dev QDType '[8, 8, 13, 5], Tensor dev QDType '[8],
Tensor dev QDType '[8, 8, 13, 5], Tensor dev QDType '[8],
Tensor dev QDType '[8, 8, 13, 5], Tensor dev QDType '[8],
Tensor dev QDType '[8, 8, 13, 5], Tensor dev QDType '[8],
Tensor dev QDType '[8], Tensor dev QDType '[8],
Tensor dev QDType '[1, 8], QTensor dev '[1],
Tensor dev QDType '[8, 8], Tensor dev QDType '[8],
Tensor dev QDType '[8], Tensor dev QDType '[8],
Tensor dev QDType '[1, 8], QTensor dev '[1]]
pparams HList
'[Tensor dev QDType '[8, 1, 1, 1], Tensor dev QDType '[8],
Tensor dev QDType '[8, 8, 13, 5], Tensor dev QDType '[8],
Tensor dev QDType '[8, 13, 5], Tensor dev QDType '[8, 13, 5],
Tensor dev QDType '[8, 2, 13, 5], Tensor dev QDType '[8],
Tensor dev QDType '[8, 2, 13, 5], Tensor dev QDType '[8],
Tensor dev QDType '[8, 1, 1, 1], Tensor dev QDType '[8],
Tensor dev QDType '[8, 1, 1, 1], Tensor dev QDType '[8],
Tensor dev QDType '[8], Tensor dev QDType '[8, 8, 13, 5],
Tensor dev QDType '[8], Tensor dev QDType '[8, 8, 13, 5],
Tensor dev QDType '[8], Tensor dev QDType '[8, 8, 13, 5],
Tensor dev QDType '[8], Tensor dev QDType '[8, 8, 13, 5],
Tensor dev QDType '[8], Tensor dev QDType '[8, 8, 13, 5],
Tensor dev QDType '[8], Tensor dev QDType '[8, 8, 13, 5],
Tensor dev QDType '[8], Tensor dev QDType '[8, 8, 13, 5],
Tensor dev QDType '[8], Tensor dev QDType '[5],
Tensor dev QDType '[5], Tensor dev QDType '[5],
Tensor dev QDType '[8, 8, 13, 5], Tensor dev QDType '[8],
Tensor dev QDType '[8, 8, 13, 5], Tensor dev QDType '[8],
Tensor dev QDType '[8, 8, 13, 5], Tensor dev QDType '[8],
Tensor dev QDType '[8, 8, 13, 5], Tensor dev QDType '[8],
Tensor dev QDType '[8, 8, 13, 5], Tensor dev QDType '[8],
Tensor dev QDType '[8, 8, 13, 5], Tensor dev QDType '[8],
Tensor dev QDType '[8, 8, 13, 5], Tensor dev QDType '[8],
Tensor dev QDType '[8, 8, 13, 5], Tensor dev QDType '[8],
Tensor dev QDType '[8], Tensor dev QDType '[8],
Tensor dev QDType '[1, 8], QTensor dev '[1],
Tensor dev QDType '[8, 8], Tensor dev QDType '[8],
Tensor dev QDType '[8], Tensor dev QDType '[8],
Tensor dev QDType '[1, 8], QTensor dev '[1]]
tparams
tparamsNew <- TT.hmapM' TT.MakeIndependent tparams'
let tnet' = QModel dev -> HList (Parameters (QModel dev)) -> QModel dev
forall f. Parameterized f => f -> HList (Parameters f) -> f
TT.replaceParameters QModel dev
tnet HList
'[Parameter dev QDType '[8, 1, 1, 1], Parameter dev QDType '[8],
Parameter dev QDType '[8, 8, 13, 5], Parameter dev QDType '[8],
Parameter dev QDType '[8, 13, 5], Parameter dev QDType '[8, 13, 5],
Parameter dev QDType '[8, 2, 13, 5], Parameter dev QDType '[8],
Parameter dev QDType '[8, 2, 13, 5], Parameter dev QDType '[8],
Parameter dev QDType '[8, 1, 1, 1], Parameter dev QDType '[8],
Parameter dev QDType '[8, 1, 1, 1], Parameter dev QDType '[8],
Parameter dev QDType '[8], Parameter dev QDType '[8, 8, 13, 5],
Parameter dev QDType '[8], Parameter dev QDType '[8, 8, 13, 5],
Parameter dev QDType '[8], Parameter dev QDType '[8, 8, 13, 5],
Parameter dev QDType '[8], Parameter dev QDType '[8, 8, 13, 5],
Parameter dev QDType '[8], Parameter dev QDType '[8, 8, 13, 5],
Parameter dev QDType '[8], Parameter dev QDType '[8, 8, 13, 5],
Parameter dev QDType '[8], Parameter dev QDType '[8, 8, 13, 5],
Parameter dev QDType '[8], Parameter dev QDType '[5],
Parameter dev QDType '[5], Parameter dev QDType '[5],
Parameter dev QDType '[8, 8, 13, 5], Parameter dev QDType '[8],
Parameter dev QDType '[8, 8, 13, 5], Parameter dev QDType '[8],
Parameter dev QDType '[8, 8, 13, 5], Parameter dev QDType '[8],
Parameter dev QDType '[8, 8, 13, 5], Parameter dev QDType '[8],
Parameter dev QDType '[8, 8, 13, 5], Parameter dev QDType '[8],
Parameter dev QDType '[8, 8, 13, 5], Parameter dev QDType '[8],
Parameter dev QDType '[8, 8, 13, 5], Parameter dev QDType '[8],
Parameter dev QDType '[8, 8, 13, 5], Parameter dev QDType '[8],
Parameter dev QDType '[8], Parameter dev QDType '[8],
Parameter dev QDType '[1, 8], Parameter dev QDType '[1],
Parameter dev QDType '[8, 8], Parameter dev QDType '[8],
Parameter dev QDType '[8], Parameter dev QDType '[8],
Parameter dev QDType '[1, 8], Parameter dev QDType '[1]]
HList (Parameters (QModel dev))
tparamsNew
pure (pnet', tnet', opt', T.asValue loss)
-- | Produce the two Q-values that enter the loss for a single replayed step.
--
-- The first component is the output of @pnet@ (the network being optimised)
-- for the step's own encoding. The second component is the target
-- @r + gamma * q'@, where @q'@ is a zero tensor when the step has no successor
-- state, and otherwise the largest @tnet@ output over the encodings of the
-- follow-up actions.
--
-- Both results are returned as untyped tensors so that the caller can stack
-- them into a batch.
--
-- NOTE(review): @pnet@ and @tnet@ are free variables bound by the enclosing
-- definition; this helper has to stay inside that scope.
-- NOTE(review): 'E.maximumOn' is partial. A step that has a successor state
-- but an empty list of follow-up encodings would raise an error here —
-- confirm that such steps are never stored in the replay buffer.
dqnValues :: ReplayStep dev tr tr' slc s f h -> (T.Tensor, T.Tensor)
dqnValues (ReplayStep _ _ step0Enc s' step1Encs r) =
  (TT.toDynamic currentQ, TT.toDynamic targetQ)
 where
  -- estimate of the trained network for the action that was actually taken
  currentQ = T.forward pnet step0Enc

  -- one-step target: immediate reward plus discounted bootstrap value
  targetQ = TT.addScalar r (TT.mul gamma bootstrapQ)

  -- value of the best follow-up action, or zero if there is no successor
  bootstrapQ = case s' of
    Just (RPState _) -> E.maximumOn scalarValue candidateQs
    Nothing -> TT.zeros

  -- tnet's value for every follow-up action encoding
  candidateQs :: [QTensor dev '[1]]
  candidateQs = map (TT.forward tnet) step1Encs

  -- read a one-element tensor back as a plain number for comparison
  scalarValue :: QTensor dev '[1] -> QType
  scalarValue q = T.asValue @QType (TT.toDynamic q)
trainDQN
:: forall dev gen tr tr' slc slc' s f h
. ( IsValidDevice dev
, StatefulGen gen IO
, Show s
, Show f
, Show h
, s ~ Split SPitch
, f ~ Freeze SPitch
, h ~ Spread SPitch
, Show slc
, Show tr
)
=> gen
-> Eval tr tr' slc slc' h (Leftmost s f h)
-> (GreedyState tr tr' slc (Leftmost s f h) -> Action slc tr s f h -> QEncoding dev '[])
-> (Analysis s f h tr slc -> IO QType)
-> (Action slc tr s f h -> Maybe Bool -> IO QType)
-> [Path slc' tr']
-> Int
-> IO ([QType], [QType], QModel dev)
trainDQN :: forall (dev :: (DeviceType, Nat)) gen tr tr' slc slc' s f h.
(IsValidDevice dev, StatefulGen gen IO, Show s, Show f, Show h,
s ~ Split SPitch, f ~ Freeze SPitch, h ~ Spread SPitch, Show slc,
Show tr) =>
gen
-> Eval tr tr' slc slc' h (Leftmost s f h)
-> (GreedyState tr tr' slc (Leftmost s f h)
-> Action slc tr s f h -> QEncoding dev '[])
-> (Analysis s f h tr slc -> IO QType)
-> (Action slc tr s f h -> Maybe Bool -> IO QType)
-> [Path slc' tr']
-> Int
-> IO ([QType], [QType], QModel dev)
trainDQN gen
gen Eval tr tr' slc slc' h (Leftmost s f h)
eval GreedyState tr tr' slc (Leftmost s f h)
-> Action slc tr s f h -> QEncoding dev '[]
encode Analysis s f h tr slc -> IO QType
reward Action slc tr s f h -> Maybe Bool -> IO QType
rewardStep [Path slc' tr']
pieces Int
n = do
model0 <- IO (QModel dev)
forall (dev :: (DeviceType, Nat)).
IsValidDevice dev =>
IO (QModel dev)
mkQModel
let opt = AdamIter
-> Float
-> Float
-> HList
'[Parameter dev QDType '[8, 1, 1, 1], Parameter dev QDType '[8],
Parameter dev QDType '[8, 8, 13, 5], Parameter dev QDType '[8],
Parameter dev QDType '[8, 13, 5], Parameter dev QDType '[8, 13, 5],
Parameter dev QDType '[8, 2, 13, 5], Parameter dev QDType '[8],
Parameter dev QDType '[8, 2, 13, 5], Parameter dev QDType '[8],
Parameter dev QDType '[8, 1, 1, 1], Parameter dev QDType '[8],
Parameter dev QDType '[8, 1, 1, 1], Parameter dev QDType '[8],
Parameter dev QDType '[8], Parameter dev QDType '[8, 8, 13, 5],
Parameter dev QDType '[8], Parameter dev QDType '[8, 8, 13, 5],
Parameter dev QDType '[8], Parameter dev QDType '[8, 8, 13, 5],
Parameter dev QDType '[8], Parameter dev QDType '[8, 8, 13, 5],
Parameter dev QDType '[8], Parameter dev QDType '[8, 8, 13, 5],
Parameter dev QDType '[8], Parameter dev QDType '[8, 8, 13, 5],
Parameter dev QDType '[8], Parameter dev QDType '[8, 8, 13, 5],
Parameter dev QDType '[8], Parameter dev QDType '[5],
Parameter dev QDType '[5], Parameter dev QDType '[5],
Parameter dev QDType '[8, 8, 13, 5], Parameter dev QDType '[8],
Parameter dev QDType '[8, 8, 13, 5], Parameter dev QDType '[8],
Parameter dev QDType '[8, 8, 13, 5], Parameter dev QDType '[8],
Parameter dev QDType '[8, 8, 13, 5], Parameter dev QDType '[8],
Parameter dev QDType '[8, 8, 13, 5], Parameter dev QDType '[8],
Parameter dev QDType '[8, 8, 13, 5], Parameter dev QDType '[8],
Parameter dev QDType '[8, 8, 13, 5], Parameter dev QDType '[8],
Parameter dev QDType '[8, 8, 13, 5], Parameter dev QDType '[8],
Parameter dev QDType '[8], Parameter dev QDType '[8],
Parameter dev QDType '[1, 8], Parameter dev QDType '[1],
Parameter dev QDType '[8, 8], Parameter dev QDType '[8],
Parameter dev QDType '[8], Parameter dev QDType '[8],
Parameter dev QDType '[1, 8], Parameter dev QDType '[1]]
-> Adam
'[Tensor dev QDType '[8, 1, 1, 1], Tensor dev QDType '[8],
Tensor dev QDType '[8, 8, 13, 5], Tensor dev QDType '[8],
Tensor dev QDType '[8, 13, 5], Tensor dev QDType '[8, 13, 5],
Tensor dev QDType '[8, 2, 13, 5], Tensor dev QDType '[8],
Tensor dev QDType '[8, 2, 13, 5], Tensor dev QDType '[8],
Tensor dev QDType '[8, 1, 1, 1], Tensor dev QDType '[8],
Tensor dev QDType '[8, 1, 1, 1], Tensor dev QDType '[8],
Tensor dev QDType '[8], Tensor dev QDType '[8, 8, 13, 5],
Tensor dev QDType '[8], Tensor dev QDType '[8, 8, 13, 5],
Tensor dev QDType '[8], Tensor dev QDType '[8, 8, 13, 5],
Tensor dev QDType '[8], Tensor dev QDType '[8, 8, 13, 5],
Tensor dev QDType '[8], Tensor dev QDType '[8, 8, 13, 5],
Tensor dev QDType '[8], Tensor dev QDType '[8, 8, 13, 5],
Tensor dev QDType '[8], Tensor dev QDType '[8, 8, 13, 5],
Tensor dev QDType '[8], Tensor dev QDType '[5],
Tensor dev QDType '[5], Tensor dev QDType '[5],
Tensor dev QDType '[8, 8, 13, 5], Tensor dev QDType '[8],
Tensor dev QDType '[8, 8, 13, 5], Tensor dev QDType '[8],
Tensor dev QDType '[8, 8, 13, 5], Tensor dev QDType '[8],
Tensor dev QDType '[8, 8, 13, 5], Tensor dev QDType '[8],
Tensor dev QDType '[8, 8, 13, 5], Tensor dev QDType '[8],
Tensor dev QDType '[8, 8, 13, 5], Tensor dev QDType '[8],
Tensor dev QDType '[8, 8, 13, 5], Tensor dev QDType '[8],
Tensor dev QDType '[8, 8, 13, 5], Tensor dev QDType '[8],
Tensor dev QDType '[8], Tensor dev QDType '[8],
Tensor dev QDType '[1, 8], Tensor dev QDType '[1],
Tensor dev QDType '[8, 8], Tensor dev QDType '[8],
Tensor dev QDType '[8], Tensor dev QDType '[8],
Tensor dev QDType '[1, 8], Tensor dev QDType '[1]]
forall (parameters :: [*]) (momenta :: [*]).
HMap' ZerosLike parameters momenta =>
AdamIter -> Float -> Float -> HList parameters -> Adam momenta
TT.mkAdam AdamIter
0 Float
0.9 Float
0.99 (QModel dev -> HList (Parameters (QModel dev))
forall f. Parameterized f => f -> HList (Parameters f)
TT.flattenParameters QModel dev
model0)
buffer = Int -> ReplayBuffer dev tr tr' slc s f h
forall (dev :: (DeviceType, Nat)) tr tr' slc s f h.
Int -> ReplayBuffer dev tr tr' slc s f h
mkReplayBuffer Int
bufferSize
state0 = QModel dev
-> QModel dev
-> Adam
'[Tensor dev QDType '[8, 1, 1, 1], Tensor dev QDType '[8],
Tensor dev QDType '[8, 8, 13, 5], Tensor dev QDType '[8],
Tensor dev QDType '[8, 13, 5], Tensor dev QDType '[8, 13, 5],
Tensor dev QDType '[8, 2, 13, 5], Tensor dev QDType '[8],
Tensor dev QDType '[8, 2, 13, 5], Tensor dev QDType '[8],
Tensor dev QDType '[8, 1, 1, 1], Tensor dev QDType '[8],
Tensor dev QDType '[8, 1, 1, 1], Tensor dev QDType '[8],
Tensor dev QDType '[8], Tensor dev QDType '[8, 8, 13, 5],
Tensor dev QDType '[8], Tensor dev QDType '[8, 8, 13, 5],
Tensor dev QDType '[8], Tensor dev QDType '[8, 8, 13, 5],
Tensor dev QDType '[8], Tensor dev QDType '[8, 8, 13, 5],
Tensor dev QDType '[8], Tensor dev QDType '[8, 8, 13, 5],
Tensor dev QDType '[8], Tensor dev QDType '[8, 8, 13, 5],
Tensor dev QDType '[8], Tensor dev QDType '[8, 8, 13, 5],
Tensor dev QDType '[8], Tensor dev QDType '[5],
Tensor dev QDType '[5], Tensor dev QDType '[5],
Tensor dev QDType '[8, 8, 13, 5], Tensor dev QDType '[8],
Tensor dev QDType '[8, 8, 13, 5], Tensor dev QDType '[8],
Tensor dev QDType '[8, 8, 13, 5], Tensor dev QDType '[8],
Tensor dev QDType '[8, 8, 13, 5], Tensor dev QDType '[8],
Tensor dev QDType '[8, 8, 13, 5], Tensor dev QDType '[8],
Tensor dev QDType '[8, 8, 13, 5], Tensor dev QDType '[8],
Tensor dev QDType '[8, 8, 13, 5], Tensor dev QDType '[8],
Tensor dev QDType '[8, 8, 13, 5], Tensor dev QDType '[8],
Tensor dev QDType '[8], Tensor dev QDType '[8],
Tensor dev QDType '[1, 8], Tensor dev QDType '[1],
Tensor dev QDType '[8, 8], Tensor dev QDType '[8],
Tensor dev QDType '[8], Tensor dev QDType '[8],
Tensor dev QDType '[1, 8], Tensor dev QDType '[1]]
-> ReplayBuffer dev tr tr' slc s f h
-> DQNState
dev
(Adam
'[Tensor dev QDType '[8, 1, 1, 1], Tensor dev QDType '[8],
Tensor dev QDType '[8, 8, 13, 5], Tensor dev QDType '[8],
Tensor dev QDType '[8, 13, 5], Tensor dev QDType '[8, 13, 5],
Tensor dev QDType '[8, 2, 13, 5], Tensor dev QDType '[8],
Tensor dev QDType '[8, 2, 13, 5], Tensor dev QDType '[8],
Tensor dev QDType '[8, 1, 1, 1], Tensor dev QDType '[8],
Tensor dev QDType '[8, 1, 1, 1], Tensor dev QDType '[8],
Tensor dev QDType '[8], Tensor dev QDType '[8, 8, 13, 5],
Tensor dev QDType '[8], Tensor dev QDType '[8, 8, 13, 5],
Tensor dev QDType '[8], Tensor dev QDType '[8, 8, 13, 5],
Tensor dev QDType '[8], Tensor dev QDType '[8, 8, 13, 5],
Tensor dev QDType '[8], Tensor dev QDType '[8, 8, 13, 5],
Tensor dev QDType '[8], Tensor dev QDType '[8, 8, 13, 5],
Tensor dev QDType '[8], Tensor dev QDType '[8, 8, 13, 5],
Tensor dev QDType '[8], Tensor dev QDType '[5],
Tensor dev QDType '[5], Tensor dev QDType '[5],
Tensor dev QDType '[8, 8, 13, 5], Tensor dev QDType '[8],
Tensor dev QDType '[8, 8, 13, 5], Tensor dev QDType '[8],
Tensor dev QDType '[8, 8, 13, 5], Tensor dev QDType '[8],
Tensor dev QDType '[8, 8, 13, 5], Tensor dev QDType '[8],
Tensor dev QDType '[8, 8, 13, 5], Tensor dev QDType '[8],
Tensor dev QDType '[8, 8, 13, 5], Tensor dev QDType '[8],
Tensor dev QDType '[8, 8, 13, 5], Tensor dev QDType '[8],
Tensor dev QDType '[8, 8, 13, 5], Tensor dev QDType '[8],
Tensor dev QDType '[8], Tensor dev QDType '[8],
Tensor dev QDType '[1, 8], Tensor dev QDType '[1],
Tensor dev QDType '[8, 8], Tensor dev QDType '[8],
Tensor dev QDType '[8], Tensor dev QDType '[8],
Tensor dev QDType '[1, 8], Tensor dev QDType '[1]])
tr
tr'
slc
s
f
h
QType
forall {k} (dev :: (DeviceType, Nat)) opt tr tr' slc s f h
(r :: k).
QModel dev
-> QModel dev
-> opt
-> ReplayBuffer dev tr tr' slc s f h
-> DQNState dev opt tr tr' slc s f h r
DQNState QModel dev
model0 QModel dev
model0 Adam
'[Tensor dev QDType '[8, 1, 1, 1], Tensor dev QDType '[8],
Tensor dev QDType '[8, 8, 13, 5], Tensor dev QDType '[8],
Tensor dev QDType '[8, 13, 5], Tensor dev QDType '[8, 13, 5],
Tensor dev QDType '[8, 2, 13, 5], Tensor dev QDType '[8],
Tensor dev QDType '[8, 2, 13, 5], Tensor dev QDType '[8],
Tensor dev QDType '[8, 1, 1, 1], Tensor dev QDType '[8],
Tensor dev QDType '[8, 1, 1, 1], Tensor dev QDType '[8],
Tensor dev QDType '[8], Tensor dev QDType '[8, 8, 13, 5],
Tensor dev QDType '[8], Tensor dev QDType '[8, 8, 13, 5],
Tensor dev QDType '[8], Tensor dev QDType '[8, 8, 13, 5],
Tensor dev QDType '[8], Tensor dev QDType '[8, 8, 13, 5],
Tensor dev QDType '[8], Tensor dev QDType '[8, 8, 13, 5],
Tensor dev QDType '[8], Tensor dev QDType '[8, 8, 13, 5],
Tensor dev QDType '[8], Tensor dev QDType '[8, 8, 13, 5],
Tensor dev QDType '[8], Tensor dev QDType '[5],
Tensor dev QDType '[5], Tensor dev QDType '[5],
Tensor dev QDType '[8, 8, 13, 5], Tensor dev QDType '[8],
Tensor dev QDType '[8, 8, 13, 5], Tensor dev QDType '[8],
Tensor dev QDType '[8, 8, 13, 5], Tensor dev QDType '[8],
Tensor dev QDType '[8, 8, 13, 5], Tensor dev QDType '[8],
Tensor dev QDType '[8, 8, 13, 5], Tensor dev QDType '[8],
Tensor dev QDType '[8, 8, 13, 5], Tensor dev QDType '[8],
Tensor dev QDType '[8, 8, 13, 5], Tensor dev QDType '[8],
Tensor dev QDType '[8, 8, 13, 5], Tensor dev QDType '[8],
Tensor dev QDType '[8], Tensor dev QDType '[8],
Tensor dev QDType '[1, 8], Tensor dev QDType '[1],
Tensor dev QDType '[8, 8], Tensor dev QDType '[8],
Tensor dev QDType '[8], Tensor dev QDType '[8],
Tensor dev QDType '[1, 8], Tensor dev QDType '[1]]
opt ReplayBuffer dev tr tr' slc s f h
forall {dev :: (DeviceType, Nat)} {tr} {tr'} {slc} {s} {f} {h}.
ReplayBuffer dev tr tr' slc s f h
buffer
(DQNState modelTrained _ _ _, rewards, losses, accs) <- T.foldLoop (state0, [], [], []) n trainEpoch
pure (reverse rewards, reverse losses, modelTrained)
where
trainPiece :: Int
-> (DQNState
dev
(Adam
'[Tensor dev QDType '[8, 1, 1, 1], Tensor dev QDType '[8],
Tensor dev QDType '[8, 8, 13, 5], Tensor dev QDType '[8],
Tensor dev QDType '[8, 13, 5], Tensor dev QDType '[8, 13, 5],
Tensor dev QDType '[8, 2, 13, 5], Tensor dev QDType '[8],
Tensor dev QDType '[8, 2, 13, 5], Tensor dev QDType '[8],
Tensor dev QDType '[8, 1, 1, 1], Tensor dev QDType '[8],
Tensor dev QDType '[8, 1, 1, 1], Tensor dev QDType '[8],
Tensor dev QDType '[8], Tensor dev QDType '[8, 8, 13, 5],
Tensor dev QDType '[8], Tensor dev QDType '[8, 8, 13, 5],
Tensor dev QDType '[8], Tensor dev QDType '[8, 8, 13, 5],
Tensor dev QDType '[8], Tensor dev QDType '[8, 8, 13, 5],
Tensor dev QDType '[8], Tensor dev QDType '[8, 8, 13, 5],
Tensor dev QDType '[8], Tensor dev QDType '[8, 8, 13, 5],
Tensor dev QDType '[8], Tensor dev QDType '[8, 8, 13, 5],
Tensor dev QDType '[8], Tensor dev QDType '[5],
Tensor dev QDType '[5], Tensor dev QDType '[5],
Tensor dev QDType '[8, 8, 13, 5], Tensor dev QDType '[8],
Tensor dev QDType '[8, 8, 13, 5], Tensor dev QDType '[8],
Tensor dev QDType '[8, 8, 13, 5], Tensor dev QDType '[8],
Tensor dev QDType '[8, 8, 13, 5], Tensor dev QDType '[8],
Tensor dev QDType '[8, 8, 13, 5], Tensor dev QDType '[8],
Tensor dev QDType '[8, 8, 13, 5], Tensor dev QDType '[8],
Tensor dev QDType '[8, 8, 13, 5], Tensor dev QDType '[8],
Tensor dev QDType '[8, 8, 13, 5], Tensor dev QDType '[8],
Tensor dev QDType '[8], Tensor dev QDType '[8],
Tensor dev QDType '[1, 8], Tensor dev QDType '[1],
Tensor dev QDType '[8, 8], Tensor dev QDType '[8],
Tensor dev QDType '[8], Tensor dev QDType '[8],
Tensor dev QDType '[1, 8], Tensor dev QDType '[1]])
tr
tr'
slc
s
f
h
QType,
[QType], [QType])
-> Path slc' tr'
-> IO
(DQNState
dev
(Adam
'[Tensor dev QDType '[8, 1, 1, 1], Tensor dev QDType '[8],
Tensor dev QDType '[8, 8, 13, 5], Tensor dev QDType '[8],
Tensor dev QDType '[8, 13, 5], Tensor dev QDType '[8, 13, 5],
Tensor dev QDType '[8, 2, 13, 5], Tensor dev QDType '[8],
Tensor dev QDType '[8, 2, 13, 5], Tensor dev QDType '[8],
Tensor dev QDType '[8, 1, 1, 1], Tensor dev QDType '[8],
Tensor dev QDType '[8, 1, 1, 1], Tensor dev QDType '[8],
Tensor dev QDType '[8], Tensor dev QDType '[8, 8, 13, 5],
Tensor dev QDType '[8], Tensor dev QDType '[8, 8, 13, 5],
Tensor dev QDType '[8], Tensor dev QDType '[8, 8, 13, 5],
Tensor dev QDType '[8], Tensor dev QDType '[8, 8, 13, 5],
Tensor dev QDType '[8], Tensor dev QDType '[8, 8, 13, 5],
Tensor dev QDType '[8], Tensor dev QDType '[8, 8, 13, 5],
Tensor dev QDType '[8], Tensor dev QDType '[8, 8, 13, 5],
Tensor dev QDType '[8], Tensor dev QDType '[5],
Tensor dev QDType '[5], Tensor dev QDType '[5],
Tensor dev QDType '[8, 8, 13, 5], Tensor dev QDType '[8],
Tensor dev QDType '[8, 8, 13, 5], Tensor dev QDType '[8],
Tensor dev QDType '[8, 8, 13, 5], Tensor dev QDType '[8],
Tensor dev QDType '[8, 8, 13, 5], Tensor dev QDType '[8],
Tensor dev QDType '[8, 8, 13, 5], Tensor dev QDType '[8],
Tensor dev QDType '[8, 8, 13, 5], Tensor dev QDType '[8],
Tensor dev QDType '[8, 8, 13, 5], Tensor dev QDType '[8],
Tensor dev QDType '[8, 8, 13, 5], Tensor dev QDType '[8],
Tensor dev QDType '[8], Tensor dev QDType '[8],
Tensor dev QDType '[1, 8], Tensor dev QDType '[1],
Tensor dev QDType '[8, 8], Tensor dev QDType '[8],
Tensor dev QDType '[8], Tensor dev QDType '[8],
Tensor dev QDType '[1, 8], Tensor dev QDType '[1]])
tr
tr'
slc
s
f
h
QType,
[QType], [QType])
trainPiece Int
i (DQNState
dev
(Adam
'[Tensor dev QDType '[8, 1, 1, 1], Tensor dev QDType '[8],
Tensor dev QDType '[8, 8, 13, 5], Tensor dev QDType '[8],
Tensor dev QDType '[8, 13, 5], Tensor dev QDType '[8, 13, 5],
Tensor dev QDType '[8, 2, 13, 5], Tensor dev QDType '[8],
Tensor dev QDType '[8, 2, 13, 5], Tensor dev QDType '[8],
Tensor dev QDType '[8, 1, 1, 1], Tensor dev QDType '[8],
Tensor dev QDType '[8, 1, 1, 1], Tensor dev QDType '[8],
Tensor dev QDType '[8], Tensor dev QDType '[8, 8, 13, 5],
Tensor dev QDType '[8], Tensor dev QDType '[8, 8, 13, 5],
Tensor dev QDType '[8], Tensor dev QDType '[8, 8, 13, 5],
Tensor dev QDType '[8], Tensor dev QDType '[8, 8, 13, 5],
Tensor dev QDType '[8], Tensor dev QDType '[8, 8, 13, 5],
Tensor dev QDType '[8], Tensor dev QDType '[8, 8, 13, 5],
Tensor dev QDType '[8], Tensor dev QDType '[8, 8, 13, 5],
Tensor dev QDType '[8], Tensor dev QDType '[5],
Tensor dev QDType '[5], Tensor dev QDType '[5],
Tensor dev QDType '[8, 8, 13, 5], Tensor dev QDType '[8],
Tensor dev QDType '[8, 8, 13, 5], Tensor dev QDType '[8],
Tensor dev QDType '[8, 8, 13, 5], Tensor dev QDType '[8],
Tensor dev QDType '[8, 8, 13, 5], Tensor dev QDType '[8],
Tensor dev QDType '[8, 8, 13, 5], Tensor dev QDType '[8],
Tensor dev QDType '[8, 8, 13, 5], Tensor dev QDType '[8],
Tensor dev QDType '[8, 8, 13, 5], Tensor dev QDType '[8],
Tensor dev QDType '[8, 8, 13, 5], Tensor dev QDType '[8],
Tensor dev QDType '[8], Tensor dev QDType '[8],
Tensor dev QDType '[1, 8], Tensor dev QDType '[1],
Tensor dev QDType '[8, 8], Tensor dev QDType '[8],
Tensor dev QDType '[8], Tensor dev QDType '[8],
Tensor dev QDType '[1, 8], Tensor dev QDType '[1]])
tr
tr'
slc
s
f
h
QType
state, [QType]
rewards, [QType]
losses) Path slc' tr'
piece = do
(state', r, loss) <- gen
-> Eval tr tr' slc slc' h (Leftmost s f h)
-> (GreedyState tr tr' slc (Leftmost s f h)
-> Action slc tr s f h -> QEncoding dev '[])
-> (Analysis s f h tr slc -> IO QType)
-> (Action slc tr s f h -> Maybe Bool -> IO QType)
-> Path slc' tr'
-> DQNState
dev
(Adam
'[Tensor dev QDType '[8, 1, 1, 1], Tensor dev QDType '[8],
Tensor dev QDType '[8, 8, 13, 5], Tensor dev QDType '[8],
Tensor dev QDType '[8, 13, 5], Tensor dev QDType '[8, 13, 5],
Tensor dev QDType '[8, 2, 13, 5], Tensor dev QDType '[8],
Tensor dev QDType '[8, 2, 13, 5], Tensor dev QDType '[8],
Tensor dev QDType '[8, 1, 1, 1], Tensor dev QDType '[8],
Tensor dev QDType '[8, 1, 1, 1], Tensor dev QDType '[8],
Tensor dev QDType '[8], Tensor dev QDType '[8, 8, 13, 5],
Tensor dev QDType '[8], Tensor dev QDType '[8, 8, 13, 5],
Tensor dev QDType '[8], Tensor dev QDType '[8, 8, 13, 5],
Tensor dev QDType '[8], Tensor dev QDType '[8, 8, 13, 5],
Tensor dev QDType '[8], Tensor dev QDType '[8, 8, 13, 5],
Tensor dev QDType '[8], Tensor dev QDType '[8, 8, 13, 5],
Tensor dev QDType '[8], Tensor dev QDType '[8, 8, 13, 5],
Tensor dev QDType '[8], Tensor dev QDType '[5],
Tensor dev QDType '[5], Tensor dev QDType '[5],
Tensor dev QDType '[8, 8, 13, 5], Tensor dev QDType '[8],
Tensor dev QDType '[8, 8, 13, 5], Tensor dev QDType '[8],
Tensor dev QDType '[8, 8, 13, 5], Tensor dev QDType '[8],
Tensor dev QDType '[8, 8, 13, 5], Tensor dev QDType '[8],
Tensor dev QDType '[8, 8, 13, 5], Tensor dev QDType '[8],
Tensor dev QDType '[8, 8, 13, 5], Tensor dev QDType '[8],
Tensor dev QDType '[8, 8, 13, 5], Tensor dev QDType '[8],
Tensor dev QDType '[8, 8, 13, 5], Tensor dev QDType '[8],
Tensor dev QDType '[8], Tensor dev QDType '[8],
Tensor dev QDType '[1, 8], Tensor dev QDType '[1],
Tensor dev QDType '[8, 8], Tensor dev QDType '[8],
Tensor dev QDType '[8], Tensor dev QDType '[8],
Tensor dev QDType '[1, 8], Tensor dev QDType '[1]])
tr
tr'
slc
s
f
h
QType
-> Int
-> Int
-> IO
(DQNState
dev
(Adam
'[Tensor dev QDType '[8, 1, 1, 1], Tensor dev QDType '[8],
Tensor dev QDType '[8, 8, 13, 5], Tensor dev QDType '[8],
Tensor dev QDType '[8, 13, 5], Tensor dev QDType '[8, 13, 5],
Tensor dev QDType '[8, 2, 13, 5], Tensor dev QDType '[8],
Tensor dev QDType '[8, 2, 13, 5], Tensor dev QDType '[8],
Tensor dev QDType '[8, 1, 1, 1], Tensor dev QDType '[8],
Tensor dev QDType '[8, 1, 1, 1], Tensor dev QDType '[8],
Tensor dev QDType '[8], Tensor dev QDType '[8, 8, 13, 5],
Tensor dev QDType '[8], Tensor dev QDType '[8, 8, 13, 5],
Tensor dev QDType '[8], Tensor dev QDType '[8, 8, 13, 5],
Tensor dev QDType '[8], Tensor dev QDType '[8, 8, 13, 5],
Tensor dev QDType '[8], Tensor dev QDType '[8, 8, 13, 5],
Tensor dev QDType '[8], Tensor dev QDType '[8, 8, 13, 5],
Tensor dev QDType '[8], Tensor dev QDType '[8, 8, 13, 5],
Tensor dev QDType '[8], Tensor dev QDType '[5],
Tensor dev QDType '[5], Tensor dev QDType '[5],
Tensor dev QDType '[8, 8, 13, 5], Tensor dev QDType '[8],
Tensor dev QDType '[8, 8, 13, 5], Tensor dev QDType '[8],
Tensor dev QDType '[8, 8, 13, 5], Tensor dev QDType '[8],
Tensor dev QDType '[8, 8, 13, 5], Tensor dev QDType '[8],
Tensor dev QDType '[8, 8, 13, 5], Tensor dev QDType '[8],
Tensor dev QDType '[8, 8, 13, 5], Tensor dev QDType '[8],
Tensor dev QDType '[8, 8, 13, 5], Tensor dev QDType '[8],
Tensor dev QDType '[8, 8, 13, 5], Tensor dev QDType '[8],
Tensor dev QDType '[8], Tensor dev QDType '[8],
Tensor dev QDType '[1, 8], Tensor dev QDType '[1],
Tensor dev QDType '[8, 8], Tensor dev QDType '[8],
Tensor dev QDType '[8], Tensor dev QDType '[8],
Tensor dev QDType '[1, 8], Tensor dev QDType '[1]])
tr
tr'
slc
s
f
h
QType,
QType, QType)
forall (dev :: (DeviceType, Nat)) tr tr' slc slc' s f h gen opt
{device :: (DeviceType, Nat)}.
(GeluDTypeIsValid dev QDType, RandDTypeIsValid dev QDType,
SumDTypeIsValid dev QDType, MeanDTypeValidation dev QDType,
StandardFloatingPointDTypeValidation dev QDType,
GeluDTypeIsValid device QDType, RandDTypeIsValid device QDType,
BasicArithmeticDTypeIsValid device QDType,
SumDTypeIsValid device QDType, MeanDTypeValidation device QDType,
StandardFloatingPointDTypeValidation device QDType,
BasicArithmeticDTypeIsValid dev QDType, StatefulGen gen IO,
Optimizer
opt
'[Tensor dev QDType '[8, 1, 1, 1], Tensor dev QDType '[8],
Tensor dev QDType '[8, 8, 13, 5], Tensor dev QDType '[8],
Tensor dev QDType '[8, 13, 5], Tensor dev QDType '[8, 13, 5],
Tensor dev QDType '[8, 2, 13, 5], Tensor dev QDType '[8],
Tensor dev QDType '[8, 2, 13, 5], Tensor dev QDType '[8],
Tensor dev QDType '[8, 1, 1, 1], Tensor dev QDType '[8],
Tensor dev QDType '[8, 1, 1, 1], Tensor dev QDType '[8],
Tensor dev QDType '[8], Tensor dev QDType '[8, 8, 13, 5],
Tensor dev QDType '[8], Tensor dev QDType '[8, 8, 13, 5],
Tensor dev QDType '[8], Tensor dev QDType '[8, 8, 13, 5],
Tensor dev QDType '[8], Tensor dev QDType '[8, 8, 13, 5],
Tensor dev QDType '[8], Tensor dev QDType '[8, 8, 13, 5],
Tensor dev QDType '[8], Tensor dev QDType '[8, 8, 13, 5],
Tensor dev QDType '[8], Tensor dev QDType '[8, 8, 13, 5],
Tensor dev QDType '[8], Tensor dev QDType '[5],
Tensor dev QDType '[5], Tensor dev QDType '[5],
Tensor dev QDType '[8, 8, 13, 5], Tensor dev QDType '[8],
Tensor dev QDType '[8, 8, 13, 5], Tensor dev QDType '[8],
Tensor dev QDType '[8, 8, 13, 5], Tensor dev QDType '[8],
Tensor dev QDType '[8, 8, 13, 5], Tensor dev QDType '[8],
Tensor dev QDType '[8, 8, 13, 5], Tensor dev QDType '[8],
Tensor dev QDType '[8, 8, 13, 5], Tensor dev QDType '[8],
Tensor dev QDType '[8, 8, 13, 5], Tensor dev QDType '[8],
Tensor dev QDType '[8, 8, 13, 5], Tensor dev QDType '[8],
Tensor dev QDType '[8], Tensor dev QDType '[8],
Tensor dev QDType '[1, 8], QTensor dev '[1],
Tensor dev QDType '[8, 8], Tensor dev QDType '[8],
Tensor dev QDType '[8], Tensor dev QDType '[8],
Tensor dev QDType '[1, 8], QTensor dev '[1]]
'[Tensor dev QDType '[8, 1, 1, 1], Tensor dev QDType '[8],
Tensor dev QDType '[8, 8, 13, 5], Tensor dev QDType '[8],
Tensor dev QDType '[8, 13, 5], Tensor dev QDType '[8, 13, 5],
Tensor dev QDType '[8, 2, 13, 5], Tensor dev QDType '[8],
Tensor dev QDType '[8, 2, 13, 5], Tensor dev QDType '[8],
Tensor dev QDType '[8, 1, 1, 1], Tensor dev QDType '[8],
Tensor dev QDType '[8, 1, 1, 1], Tensor dev QDType '[8],
Tensor dev QDType '[8], Tensor dev QDType '[8, 8, 13, 5],
Tensor dev QDType '[8], Tensor dev QDType '[8, 8, 13, 5],
Tensor dev QDType '[8], Tensor dev QDType '[8, 8, 13, 5],
Tensor dev QDType '[8], Tensor dev QDType '[8, 8, 13, 5],
Tensor dev QDType '[8], Tensor dev QDType '[8, 8, 13, 5],
Tensor dev QDType '[8], Tensor dev QDType '[8, 8, 13, 5],
Tensor dev QDType '[8], Tensor dev QDType '[8, 8, 13, 5],
Tensor dev QDType '[8], Tensor dev QDType '[5],
Tensor dev QDType '[5], Tensor dev QDType '[5],
Tensor dev QDType '[8, 8, 13, 5], Tensor dev QDType '[8],
Tensor dev QDType '[8, 8, 13, 5], Tensor dev QDType '[8],
Tensor dev QDType '[8, 8, 13, 5], Tensor dev QDType '[8],
Tensor dev QDType '[8, 8, 13, 5], Tensor dev QDType '[8],
Tensor dev QDType '[8, 8, 13, 5], Tensor dev QDType '[8],
Tensor dev QDType '[8, 8, 13, 5], Tensor dev QDType '[8],
Tensor dev QDType '[8, 8, 13, 5], Tensor dev QDType '[8],
Tensor dev QDType '[8, 8, 13, 5], Tensor dev QDType '[8],
Tensor dev QDType '[8], Tensor dev QDType '[8],
Tensor dev QDType '[1, 8], QTensor dev '[1],
Tensor dev QDType '[8, 8], Tensor dev QDType '[8],
Tensor dev QDType '[8], Tensor dev QDType '[8],
Tensor dev QDType '[1, 8], QTensor dev '[1]]
QDType
device,
KnownDevice device, KnownDevice dev) =>
gen
-> Eval tr tr' slc slc' h (Leftmost s f h)
-> (GreedyState tr tr' slc (Leftmost s f h)
-> Action slc tr s f h -> QEncoding dev '[])
-> (Analysis s f h tr slc -> IO QType)
-> (Action slc tr s f h -> Maybe Bool -> IO QType)
-> Path slc' tr'
-> DQNState dev opt tr tr' slc s f h QType
-> Int
-> Int
-> IO (DQNState dev opt tr tr' slc s f h QType, QType, QType)
trainLoop gen
gen Eval tr tr' slc slc' h (Leftmost s f h)
eval GreedyState tr tr' slc (Leftmost s f h)
-> Action slc tr s f h -> QEncoding dev '[]
encode Analysis s f h tr slc -> IO QType
reward Action slc tr s f h -> Maybe Bool -> IO QType
rewardStep Path slc' tr'
piece DQNState
dev
(Adam
'[Tensor dev QDType '[8, 1, 1, 1], Tensor dev QDType '[8],
Tensor dev QDType '[8, 8, 13, 5], Tensor dev QDType '[8],
Tensor dev QDType '[8, 13, 5], Tensor dev QDType '[8, 13, 5],
Tensor dev QDType '[8, 2, 13, 5], Tensor dev QDType '[8],
Tensor dev QDType '[8, 2, 13, 5], Tensor dev QDType '[8],
Tensor dev QDType '[8, 1, 1, 1], Tensor dev QDType '[8],
Tensor dev QDType '[8, 1, 1, 1], Tensor dev QDType '[8],
Tensor dev QDType '[8], Tensor dev QDType '[8, 8, 13, 5],
Tensor dev QDType '[8], Tensor dev QDType '[8, 8, 13, 5],
Tensor dev QDType '[8], Tensor dev QDType '[8, 8, 13, 5],
Tensor dev QDType '[8], Tensor dev QDType '[8, 8, 13, 5],
Tensor dev QDType '[8], Tensor dev QDType '[8, 8, 13, 5],
Tensor dev QDType '[8], Tensor dev QDType '[8, 8, 13, 5],
Tensor dev QDType '[8], Tensor dev QDType '[8, 8, 13, 5],
Tensor dev QDType '[8], Tensor dev QDType '[5],
Tensor dev QDType '[5], Tensor dev QDType '[5],
Tensor dev QDType '[8, 8, 13, 5], Tensor dev QDType '[8],
Tensor dev QDType '[8, 8, 13, 5], Tensor dev QDType '[8],
Tensor dev QDType '[8, 8, 13, 5], Tensor dev QDType '[8],
Tensor dev QDType '[8, 8, 13, 5], Tensor dev QDType '[8],
Tensor dev QDType '[8, 8, 13, 5], Tensor dev QDType '[8],
Tensor dev QDType '[8, 8, 13, 5], Tensor dev QDType '[8],
Tensor dev QDType '[8, 8, 13, 5], Tensor dev QDType '[8],
Tensor dev QDType '[8, 8, 13, 5], Tensor dev QDType '[8],
Tensor dev QDType '[8], Tensor dev QDType '[8],
Tensor dev QDType '[1, 8], Tensor dev QDType '[1],
Tensor dev QDType '[8, 8], Tensor dev QDType '[8],
Tensor dev QDType '[8], Tensor dev QDType '[8],
Tensor dev QDType '[1, 8], Tensor dev QDType '[1]])
tr
tr'
slc
s
f
h
QType
state Int
i Int
n
pure (state', r : rewards, loss : losses)
-- | Run one training epoch and extend the running histories.
--
-- Folds @trainPiece i@ over @pieces@, starting from @state@ with empty
-- per-piece reward and loss lists, then conses the mean reward and the mean
-- loss of the epoch onto @meanRewards@ and @meanLosses@ (newest entry first).
--
-- When @i@ is a multiple of 10, every piece is additionally run through
-- 'runEpisode' with the greedy policy of the updated policy network, the
-- resulting analyses are scored with @reward@, and their mean is consed onto
-- @accuracies@. If any episode yields a 'Left', its message is printed and
-- @-inf@ is recorded instead. On those epochs the epoch number and the
-- replay-buffer size are printed and the three histories are plotted in
-- chronological order. When @i@ is also a multiple of 100, every derivation
-- is printed and plotted to @rl/deriv<k>.tex@.
--
-- Returns the updated state together with the extended histories.
--
-- NOTE(review): @pieces@, @eval@, @encode@, @reward@ and @trainPiece@ are not
-- bound in this block, so this is presumably a local definition of an
-- enclosing function whose indentation was lost in this copy of the file --
-- confirm the intended nesting when restoring it.
trainEpoch (state, meanRewards, meanLosses, accuracies) i = do
  (state', rewards, losses) <-
    foldM (trainPiece i) (state, [], []) pieces
  let meanRewards' = mean rewards : meanRewards
      meanLosses' = mean losses : meanLosses
  accuracies' <-
    if (i `mod` 10) == 0
      then do
        -- evaluate the current greedy policy on every piece
        results <- mapM (runEpisode eval encode $ greedyPolicy (T.forward (pnet state'))) pieces
        case sequence results of
          Left err -> do
            -- a failed episode is reported and recorded as -inf
            putStrLn err
            pure $ (-inf) : accuracies
          Right episodes -> do
            let analyses = map snd episodes
            accs <- mapM reward analyses
            when ((i `mod` 100) == 0) $ do
              putStrLn "current best analyses:"
              -- idx numbers the pieces from 1 and is unrelated to the epoch i
              forM_ (zip analyses [1 ..]) $ \(Analysis deriv _, idx) -> do
                mapM_ print deriv
                plotDeriv ("rl/deriv" <> show idx <> ".tex") deriv
            pure $ mean accs : accuracies
      else pure accuracies
  when ((i `mod` 10) == 0) $ do
    putStrLn $ "epoch " <> show i
    let (ReplayBuffer _ bcontent) = buffer state'
    putStrLn $ "buffer size: " <> show (length bcontent)
    -- histories are stored newest-first, so reverse them for plotting
    plotHistory "rewards" $ reverse meanRewards'
    plotHistory "losses" $ reverse meanLosses'
    plotHistory "accuracy" $ reverse accuracies'
  pure (state', meanRewards', meanLosses', accuracies')
-- | Print the given message to stdout, prefixed with
-- @"Found the Exception:"@ (no separator is inserted between prefix and
-- message).
hi :: String -> IO ()
hi s = putStrLn $ "Found the Exception:" <> s