{-# LANGUAGE DataKinds #-}
module RL.ReplayBuffer where
import Common
import GreedyParser
import RL.Encoding
import RL.ModelTypes
import System.Random (RandomGen, getStdRandom)
import System.Random.Shuffle (shuffle')
import System.Random.Stateful as Rand (StatefulGen, UniformRange (uniformRM), split)
-- | Parser state as stored in the replay buffer.
--
-- A thin wrapper around 'GreedyState' whose last type argument is fixed to
-- @'Leftmost' s f h@. The 'Show' instance is the stock derived one.
newtype RPState tr tr' slc s f h
  = RPState (GreedyState tr tr' slc (Leftmost s f h))
  deriving (Show)
-- | Action as stored in the replay buffer.
--
-- A thin wrapper around 'Action'. The 'Show' instance is the stock derived
-- one.
newtype RPAction slc tr s f h
  = RPAction (Action slc tr s f h)
  deriving (Show)
-- | One recorded step of experience.
--
-- @dev@ is a type-level tag of kind @(DeviceType, Nat)@ and is only used by
-- the 'QEncoding' fields. Every field is strict.
data ReplayStep dev tr tr' slc s f h = ReplayStep
  { replayState :: !(RPState tr tr' slc s f h)
  -- ^ The state of this step (presumably the state the action was taken
  -- in; confirm against the code that builds replay steps).
  , replayAction :: !(RPAction slc tr s f h)
  -- ^ The action recorded for this step.
  , replayStep :: !(QEncoding dev '[])
  -- ^ Encoding associated with this step (presumably the encoded
  -- state/action pair; TODO confirm).
  , replayNextState :: !(Maybe (RPState tr tr' slc s f h))
  -- ^ The following state, if any. 'Nothing' presumably marks a terminal
  -- step; verify against the producer.
  , replayNextSteps :: ![QEncoding dev '[]]
  -- ^ Encodings associated with the next state (presumably one per action
  -- available there; TODO confirm).
  , replayReward :: !QType
  -- ^ The reward recorded for this step.
  }
-- | Compact rendering of a replay step.
--
-- The output is the state, the literal @" -> "@, the next state, a space and
-- the reward, followed by a newline, a space and the operation of the
-- action. The parent component of the action and both encoding fields are
-- not shown.
instance
  (Show slc, Show s, Show f, Show h, Show tr, Show tr') =>
  Show (ReplayStep dev tr tr' slc s f h)
  where
  show (ReplayStep state (RPAction action) _ nextState _ reward) =
    concat
      [ show state
      , " -> "
      , show nextState
      , " "
      , show reward
      , "\n "
      , operation
      ]
   where
    -- Only the operation is rendered; the parent (first constructor
    -- argument) of either action kind is dropped.
    operation = case action of
      Left (ActionSingle _ op) -> show op
      Right (ActionDouble _ op) -> show op
-- | Bounded store of replay steps.
--
-- The 'Int' acts as the capacity: 'seedReplayBuffer' and 'pushStep' truncate
-- the list to at most that many steps. 'pushStep' conses onto the front, so
-- pushed steps are ordered newest first.
--
-- Both fields are strict, but for the list that only forces the outermost
-- constructor, not the whole spine.
data ReplayBuffer dev tr tr' slc s f h
  = ReplayBuffer !Int ![ReplayStep dev tr tr' slc s f h]
  deriving (Show)
-- | Create an empty replay buffer that stores @n@ as its capacity.
mkReplayBuffer :: Int -> ReplayBuffer dev tr tr' slc s f h
mkReplayBuffer n = ReplayBuffer n []
-- | Create a replay buffer with capacity @n@ that is pre-filled with the
-- given steps.
--
-- Only the first @n@ steps are kept; the rest of the list is discarded. A
-- non-positive @n@ therefore yields an empty buffer.
seedReplayBuffer
  :: Int
  -> [ReplayStep dev tr tr' slc s f h]
  -> ReplayBuffer dev tr tr' slc s f h
seedReplayBuffer n steps = ReplayBuffer n (take n steps)
-- | Add a step to the front of the buffer.
--
-- The result keeps the buffer's capacity and is truncated to it, so once the
-- buffer is full each push drops the step at the back of the list (the
-- oldest pushed one).
pushStep
  :: ReplayBuffer dev tr tr' slc s f h
  -> ReplayStep dev tr tr' slc s f h
  -> ReplayBuffer dev tr tr' slc s f h
pushStep (ReplayBuffer n queue) trans =
  ReplayBuffer n (take n (trans : queue))
-- | Draw up to @n@ steps from the buffer in random order, without
-- replacement.
--
-- The whole queue is shuffled with a generator split off the global
-- standard generator, and the first @n@ elements of the shuffle are
-- returned. Fewer than @n@ steps are returned when the buffer holds fewer,
-- and an empty buffer yields an empty sample.
sampleSteps
  :: ReplayBuffer dev tr tr' slc s f h
  -> Int
  -> IO [ReplayStep dev tr tr' slc s f h]
sampleSteps (ReplayBuffer _ queue) n = do
  -- Split unconditionally so the global generator advances the same way
  -- whether or not the buffer is empty.
  gen <- getStdRandom Rand.split
  pure $ case queue of
    -- shuffle' must not be called on an empty list: random-shuffle's tree
    -- construction does not terminate for it, so sampling from an empty
    -- buffer would hang as soon as the result is forced.
    [] -> []
    _ -> take n (shuffle' queue (length queue) gen)