{-# LANGUAGE DataKinds #-}

module RL.ReplayBuffer where

import Common
import GreedyParser
import RL.Encoding
import RL.ModelTypes
import System.Random (RandomGen, getStdRandom)
import System.Random.Shuffle (shuffle')
import System.Random.Stateful as Rand (StatefulGen, UniformRange (uniformRM), split)

-- States and Actions
-- ------------------

-- | A parser state as stored in the replay buffer: a 'GreedyState' whose
-- operations are 'Leftmost' values.
newtype RPState tr tr' slc s f h
  = RPState (GreedyState tr tr' slc (Leftmost s f h))
  deriving Show

-- | An action as stored in the replay buffer; a thin wrapper around 'Action'.
newtype RPAction slc tr s f h
  = RPAction (Action slc tr s f h)
  deriving Show

-- Replay Buffer
-- -------------

-- | One recorded transition kept in a 'ReplayBuffer'.
--
-- All fields are strict, so constructing a step evaluates each field to WHNF.
data ReplayStep dev tr tr' slc s f h = ReplayStep
  { replayState :: !(RPState tr tr' slc s f h)
  -- ^ the state the transition starts from
  , replayAction :: !(RPAction slc tr s f h)
  -- ^ the action taken in 'replayState'
  , replayStep :: !(QEncoding dev '[])
  -- ^ encoding of the taken step (presumably the Q-model input — confirm)
  , replayNextState :: !(Maybe (RPState tr tr' slc s f h))
  -- ^ the state reached; presumably 'Nothing' when the episode ended — confirm
  , replayNextSteps :: ![QEncoding dev '[]]
  -- ^ encodings associated with the next state (assumed: its candidate steps — confirm)
  , replayReward :: !QType
  -- ^ reward recorded for this transition
  }

-- | Renders a step as @state -> nextState reward@, followed by the taken
-- operation on an indented second line. The encodings are not shown.
instance (Show slc, Show s, Show f, Show h, Show tr, Show tr') => Show (ReplayStep dev tr tr' slc s f h) where
  show step =
    mconcat
      [ show (replayState step)
      , " -> "
      , show (replayNextState step)
      , " "
      , show (replayReward step)
      , "\n  "
      , either showSingle showDouble action
      ]
   where
    RPAction action = replayAction step
    -- only the operation is shown; the parent part of the action is dropped
    showSingle (ActionSingle _ op) = show op
    showDouble (ActionDouble _ op) = show op

-- | A bounded collection of replay steps.
--
-- Steps added with 'pushStep' are placed at the front of the list, and the
-- list is truncated to the capacity.
data ReplayBuffer dev tr tr' slc s f h
  = ReplayBuffer
      !Int
      -- ^ capacity: the maximum number of steps kept
      ![ReplayStep dev tr tr' slc s f h]
      -- ^ the stored steps
  deriving Show

-- | An empty replay buffer with the given capacity.
mkReplayBuffer :: Int -> ReplayBuffer dev tr tr' slc s f h
mkReplayBuffer capacity = ReplayBuffer capacity mempty

-- | A replay buffer with capacity @n@, pre-filled with (at most) the first @n@
-- of the given steps.
--
-- The spine of the kept prefix is forced here. The strict field only
-- evaluates the list to WHNF, so an unevaluated @take n steps@ would otherwise
-- keep the whole input list reachable.
seedReplayBuffer :: Int -> [ReplayStep dev tr tr' slc s f h] -> ReplayBuffer dev tr tr' slc s f h
seedReplayBuffer n steps =
  let kept = take n steps
   in length kept `seq` ReplayBuffer n kept

-- | Add a step to the front of the buffer, dropping the oldest steps so that
-- at most the buffer's capacity is kept.
--
-- The spine of the truncated list is forced on every push. The strict field
-- only evaluates the list to WHNF, so without this each push would wrap
-- another lazy 'take' around the previous list. That would keep every evicted
-- step (and its encodings) reachable until something walked the whole list.
pushStep
  :: ReplayBuffer dev tr tr' slc s f h
  -> ReplayStep dev tr tr' slc s f h
  -> ReplayBuffer dev tr tr' slc s f h
pushStep (ReplayBuffer n queue) trans =
  let queue' = take n (trans : queue)
   in length queue' `seq` ReplayBuffer n queue'

-- | Sample up to @n@ steps from the buffer uniformly at random, without
-- replacement. If the buffer holds fewer than @n@ steps, all of them are
-- returned in random order.
--
-- Uses (and advances) the global 'StdGen'.
sampleSteps
  :: ReplayBuffer dev tr tr' slc s f h
  -> Int
  -> IO [ReplayStep dev tr tr' slc s f h]
sampleSteps (ReplayBuffer _ queue) n
  -- shuffle' does not handle an empty input (it never returns), and there is
  -- nothing to sample anyway.
  | null queue = pure []
  | otherwise = do
      -- not great, but shuffle' isn't integrated with StatefulGen, so split a
      -- fresh pure generator off the global one instead
      gen <- getStdRandom Rand.split
      pure $ take n (shuffle' queue (length queue) gen)