{-# LANGUAGE AllowAmbiguousTypes #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE DeriveGeneric #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE FunctionalDependencies #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE InstanceSigs #-}
{-# LANGUAGE PartialTypeSignatures #-}
{-# LANGUAGE PolyKinds #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE StandaloneDeriving #-}
{-# LANGUAGE StandaloneKindSignatures #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE UndecidableInstances #-}
{-# OPTIONS_GHC -Wno-partial-type-signatures #-}

module Inference.Conjugate where

import Control.Monad (replicateM)
import Control.Monad.Primitive
  ( PrimMonad
  , PrimState
  )
import Control.Monad.Reader
  ( ReaderT
  , runReaderT
  )
import Control.Monad.Reader.Class (MonadReader (ask))
import Control.Monad.State
  ( State
  , StateT
  , evalState
  , evalStateT
  , execStateT
  , runStateT
  )
import Control.Monad.State.Class
  ( get
  , modify
  , put
  )
import Control.Monad.Trans.Maybe
  ( MaybeT (MaybeT)
  , runMaybeT
  )
import Control.Monad.Writer
import Data.Dynamic
  ( Dynamic
  , Typeable
  , fromDynamic
  , toDyn
  )
import Data.Kind
import Data.Maybe (fromMaybe)
import Data.MultiSet as MS
import qualified Data.Sequence as S
import Data.Typeable
  ( Proxy (Proxy)
  , typeRep
  )
import qualified Data.Vector as V
import qualified Debug.Trace as DT
import GHC.Float (int2Double)
import GHC.Generics
import GHC.TypeNats
import Lens.Micro
import Lens.Micro.Extras
import Lens.Micro.TH (makeLenses)
import Numeric.SpecFunctions
  ( logBeta
  , logChoose
  , logFactorial
  , logGamma
  )
import System.Random.MWC.Probability hiding
  ( Uniform
  )

----------------------------------------------------------------
-- type classes and families for distributions and conjugates --
----------------------------------------------------------------

{- | Describes a family of distributions with a fixed form.
For example, a 'Bernoulli' distribution is parameterized by a probability @p@
and produces binary samples
(@True@ with probability @p@, @False@ with probability @1-p@).

Its 'Distribution' instance is:

> instance Distribution Bernoulli where
>   type Params Bernoulli = Double
>   type Support Bernoulli = Bool
>   distSample _ = bernoulli
>   distLogP _ p True = log p
>   distLogP _ p False = log (1 - p)
-}
class Distribution a where
  -- | The type of the parameters that select one member of the family.
  type Params a :: Type
  -- | The type of the values produced by the distribution.
  type Support a :: Type
  -- | Draws a sample from the distribution with the given parameters.
  distSample :: (PrimMonad m) => a -> Params a -> Prob m (Support a)
  -- | The logarithm of the probability of a value under the given parameters.
  distLogP :: a -> Params a -> Support a -> Double

-- -- | Used as a type-lifted kind for conjugate distribution pairs.
-- data Conj a b = Conj a bb

-- | A type-level marker for treating a distribution as a prior.
-- It serves as the index at which the 'Hyper' and 'Probs' type families
-- are instantiated for a single distribution @p@.
newtype AsPrior p = AsPrior p

-- -- | A type-level marker for treating a distribution as a likelihood.
-- newtype AsLk l = AsLk l

{- | Marks two distributions as a conjugate pair of prior and likelihood.
The property of such a pair is that the posterior has the same form as the prior
(including the same 'Params' and 'Support'),
and that its parameters can be obtained analytically from the parameters of the prior
and a set of observations.
The superclass constraint @Support p ~ Params l@ states that
a value drawn from the prior is a parameter setting of the likelihood.

The class method 'updatePrior' returns the parameters of the posterior
given the prior parameters after a single observation.
-}
class (Distribution p, Distribution l, Support p ~ Params l) => Conjugate p l where
  priorSingleton :: p
    -- ^ provides a singleton instance of the prior distribution
    -- in order to make sampling from priors easier.

  -- | Takes the likelihood, the current prior parameters, and one observation,
  -- and returns the updated (posterior) parameters.
  updatePrior :: l -> Params p -> Support l -> Params p
  -- | NOTE(review): presumably the log-probability of an observation
  -- given only the prior parameters (i.e. under the predictive distribution)
  -- -- confirm against the instances.
  predLogP :: l -> Params p -> Support l -> Double

-- | Maps a (poly-kinded) index to the type of its hyperparameters.
-- For a single prior distribution marked with 'AsPrior',
-- the hyperparameters are the parameters of that distribution.
type family Hyper (a :: k) :: Type
type instance Hyper (AsPrior p) = Params p

-- | Maps a (poly-kinded) index to the type of its concrete probabilities.
-- For a single prior distribution marked with 'AsPrior',
-- these are the values in the support of that distribution.
type family Probs (a :: k) :: Type
type instance Probs (AsPrior p) = Support p

-- type family Value (a :: k) :: Type
-- type instance Value (Conj p l) = Support l

-- helper types for instantiating hyperparameters, parameters, and values
-- ----------------------------------------------------------------------

-- | Wraps the hyperparameters of a single prior distribution @p@.
-- Used as the type argument of a parameter record
-- to obtain the record of all its hyperparameters.
newtype HyperRep p = HyperRep {runHyper :: Hyper (AsPrior p)}

deriving instance (Show (Hyper (AsPrior p))) => Show (HyperRep p)

-- | The hyperparameters of a parameter record @a@
-- are the record instantiated with 'HyperRep'.
type instance Hyper (a :: (Type -> Type) -> Type) = a HyperRep

-- | Wraps the concrete probabilities (a value drawn from the prior) of a
-- single prior distribution @p@.
-- Used as the type argument of a parameter record
-- to obtain the record of all its probabilities.
newtype ProbsRep p = ProbsRep {runProbs :: Probs (AsPrior p)}

deriving instance (Show (Probs (AsPrior p))) => Show (ProbsRep p)

-- | The probabilities of a parameter record @a@
-- are the record instantiated with 'ProbsRep'.
type instance Probs (a :: (Type -> Type) -> Type) = a ProbsRep

-- newtype ValueRep p l = ValueRep { runValue :: Value (Conj p l) }
-- deriving instance Show (Value (Conj p l)) => Show (ValueRep p l)

-- type instance Value (a :: (Type -> Type -> Type) -> Type) = a ValueRep

-----------------------------------------------------
-- generic magic for conjugates and parameter sets --
-----------------------------------------------------

-- Jeffreys prior
-- ---------------

-- | Types (single priors or whole parameter records)
-- for which hyperparameters of a Jeffreys prior can be provided.
class Jeffreys a where
  -- | The hyperparameters of the Jeffreys prior for @a@.
  -- The method takes no arguments, so @a@ is selected by type application.
  jeffreysPrior :: Hyper a

-- | Generic counterpart of 'Jeffreys' that works on
-- "GHC.Generics" representation types.
class GJeffreys t where
  -- | Builds the generic representation with Jeffreys hyperparameters in every field.
  gjeffreysPrior :: forall p. t p

-- | 'V1' represents a type without constructors,
-- so there is no value that could be returned here.
instance GJeffreys V1 where
  gjeffreysPrior = error "Inference.Conjugate.gjeffreysPrior: V1 has no values"

-- | A constructor without fields has no hyperparameters to fill in.
instance GJeffreys U1 where
  gjeffreysPrior = U1

-- base case: the field holds the hyperparameters of a single prior distribution,
-- which are taken from the 'Jeffreys' instance of that prior.
instance (Jeffreys (AsPrior p)) => GJeffreys (K1 i (HyperRep p)) where
  gjeffreysPrior = K1 $ HyperRep $ jeffreysPrior @(AsPrior p)

-- recursive case: the field is itself a parameter record @k@,
-- whose hyperparameters are obtained from its own 'Jeffreys' instance.
instance (Jeffreys k, k HyperRep ~ Hyper k) => GJeffreys (K1 i (k HyperRep)) where
  gjeffreysPrior = K1 $ jeffreysPrior @k

-- | Metadata wrappers are transparent: build the wrapped representation.
instance (GJeffreys t) => GJeffreys (M1 i c (t :: Type -> Type)) where
  gjeffreysPrior = M1 gjeffreysPrior

-- | Products (records with several fields) are built field by field.
instance (GJeffreys ta, GJeffreys tb) => GJeffreys (ta :*: tb) where
  gjeffreysPrior = gjeffreysPrior @ta :*: gjeffreysPrior @tb

-- | Any parameter record whose 'HyperRep' instantiation is 'Generic'
-- gets its Jeffreys hyperparameters by building the generic representation
-- and converting it back to the record.
instance (Generic (t HyperRep), GJeffreys (Rep (t HyperRep))) => Jeffreys (t :: (Type -> Type) -> Type) where
  -- 'to' is qualified because "Lens.Micro" exports a function of the same name.
  jeffreysPrior = GHC.Generics.to (gjeffreysPrior @(Rep (t HyperRep)))

-- uniform prior
-- -------------

-- | Types (single priors or whole parameter records)
-- for which hyperparameters of a uniform prior can be provided.
class Uniform a where
  -- | The hyperparameters of the uniform prior for @a@.
  -- The method takes no arguments, so @a@ is selected by type application.
  uniformPrior :: Hyper a

-- | Generic counterpart of 'Uniform' that works on
-- "GHC.Generics" representation types.
class GUniform t where
  -- | Builds the generic representation with uniform hyperparameters in every field.
  guniformPrior :: forall p. t p

-- | 'V1' represents a type without constructors,
-- so there is no value that could be returned here.
instance GUniform V1 where
  guniformPrior = error "Inference.Conjugate.guniformPrior: V1 has no values"

-- | A constructor without fields has no hyperparameters to fill in.
instance GUniform U1 where
  guniformPrior = U1

-- base case: the field holds the hyperparameters of a single prior distribution,
-- which are taken from the 'Uniform' instance of that prior.
instance (Uniform (AsPrior p)) => GUniform (K1 i (HyperRep p)) where
  guniformPrior = K1 $ HyperRep $ uniformPrior @(AsPrior p)

-- recursive case: the field is itself a parameter record @k@,
-- whose hyperparameters are obtained from its own 'Uniform' instance.
instance (Uniform k, k HyperRep ~ Hyper k) => GUniform (K1 i (k HyperRep)) where
  guniformPrior = K1 $ uniformPrior @k

-- | Metadata wrappers are transparent: build the wrapped representation.
instance (GUniform t) => GUniform (M1 i c (t :: Type -> Type)) where
  guniformPrior = M1 guniformPrior

-- | Products (records with several fields) are built field by field.
instance (GUniform ta, GUniform tb) => GUniform (ta :*: tb) where
  guniformPrior = guniformPrior @ta :*: guniformPrior @tb

-- | Any parameter record whose 'HyperRep' instantiation is 'Generic'
-- gets its uniform hyperparameters by building the generic representation
-- and converting it back to the record.
instance (Generic (t HyperRep), GUniform (Rep (t HyperRep))) => Uniform (t :: (Type -> Type) -> Type) where
  -- 'to' is qualified because "Lens.Micro" exports a function of the same name.
  uniformPrior = GHC.Generics.to (guniformPrior @(Rep (t HyperRep)))

-- sampling from prior
-- ------------------

-- | Types (single priors or whole parameter records)
-- that turn hyperparameters into concrete probabilities.
class Prior a where
  -- | Randomly draws probabilities given the hyperparameters.
  sampleProbs :: (PrimMonad m) => Hyper a -> Prob m (Probs a)
  -- | Deterministically computes probabilities from the hyperparameters
  -- (going by the name, the expected value under the prior).
  expectedProbs :: Hyper a -> Probs a

-- | Generic counterpart of 'Prior'.
-- @i@ is the representation type of the hyperparameter record
-- and @o@ the representation type of the resulting probabilities record.
class GPrior i o where
  -- | Generic version of 'sampleProbs'.
  gsampleProbs :: forall m p. (PrimMonad m) => i p -> Prob m (o p)
  -- | Generic version of 'expectedProbs'.
  gexpectedProbs :: i p -> o p

-- | 'V1' represents a type without constructors:
-- no input can be given and no output can be produced,
-- so both methods are necessarily bottom.
instance GPrior V1 V1 where
  gsampleProbs = error "Inference.Conjugate.gsampleProbs: V1 has no values"
  gexpectedProbs = error "Inference.Conjugate.gexpectedProbs: V1 has no values"

-- | A constructor without fields carries no hyperparameters,
-- so the result is the empty constructor again.
instance GPrior U1 U1 where
  gsampleProbs _ = pure U1
  gexpectedProbs _ = U1

-- base case: the field holds the hyperparameters of a single prior distribution.
-- The hyperparameters are unwrapped, passed to the 'Prior' instance of that prior,
-- and the result is wrapped as 'ProbsRep'.
instance (Prior (AsPrior p)) => GPrior (K1 i (HyperRep p)) (K1 i (ProbsRep p)) where
  gsampleProbs (K1 (HyperRep hyper)) =
    K1 . ProbsRep <$> sampleProbs @(AsPrior p) hyper
  gexpectedProbs (K1 (HyperRep hyper)) =
    K1 $ ProbsRep $ expectedProbs @(AsPrior p) hyper

-- recursive case: the field is itself a parameter record @k@,
-- which is handled by its own 'Prior' instance.
instance
  (Prior k, k HyperRep ~ Hyper k, k ProbsRep ~ Probs k)
  => GPrior (K1 i (k HyperRep)) (K1 i (k ProbsRep))
  where
  gsampleProbs (K1 hyper) = K1 <$> sampleProbs @k hyper
  gexpectedProbs (K1 hyper) = K1 $ expectedProbs @k hyper

-- | Metadata wrappers are transparent: unwrap the input,
-- process the wrapped representation, and wrap the output again.
-- The metadata of input and output are allowed to differ.
instance (GPrior ti to) => GPrior (M1 i c ti) (M1 i' c' to) where
  gsampleProbs (M1 x) = M1 <$> gsampleProbs x
  gexpectedProbs (M1 x) = M1 $ gexpectedProbs x

-- | Products (records with several fields) are processed field by field;
-- when sampling, the left component is sampled before the right one.
instance (GPrior ia oa, GPrior ib ob) => GPrior (ia :*: ib) (oa :*: ob) where
  gsampleProbs (a :*: b) = (:*:) <$> gsampleProbs a <*> gsampleProbs b
  gexpectedProbs (a :*: b) = gexpectedProbs a :*: gexpectedProbs b

-- | Sums (types with several constructors) keep the constructor of the input:
-- a left alternative yields a left alternative, and likewise for the right.
instance (GPrior ia oa, GPrior ib ob) => GPrior (ia :+: ib) (oa :+: ob) where
  gsampleProbs (L1 a) = L1 <$> gsampleProbs a
  gsampleProbs (R1 b) = R1 <$> gsampleProbs b
  gexpectedProbs (L1 a) = L1 $ gexpectedProbs a
  gexpectedProbs (R1 b) = R1 $ gexpectedProbs b

-- | Any parameter record whose 'HyperRep' and 'ProbsRep' instantiations
-- are 'Generic' is a 'Prior':
-- the hyperparameter record is converted to its generic representation,
-- processed by 'GPrior', and converted back to a probabilities record.
instance
  ( Generic (a HyperRep)
  , Generic (a ProbsRep)
  , GPrior (Rep (a HyperRep)) (Rep (a ProbsRep))
  )
  => Prior (a :: (Type -> Type) -> Type)
  where
  -- 'to' is qualified because "Lens.Micro" exports a function of the same name.
  sampleProbs hyper = GHC.Generics.to <$> gsampleProbs (from hyper)
  expectedProbs hyper = GHC.Generics.to $ gexpectedProbs $ from hyper

-----------------------
-- likelihood monads --
-----------------------

-- | A lens that selects the field belonging to the prior @p@ in the record @r@.
-- It is polymorphic in the wrapper @f@, so the same accessor works
-- for every instantiation of the record (e.g. @r HyperRep@ and @r ProbsRep@).
type Accessor r p = forall f. Lens' (r f) (f p)

-- | A monad @m@ that gives meaning to the random choices of a probabilistic
-- program over the parameter record @r@.
-- The functional dependency states that @m@ determines @r@.
class (Monad m) => RandomInterpreter m r | m -> r where
  -- | An additional constraint that the interpreter places on a distribution @a@.
  type SampleCtx m a :: Constraint
  -- | Interprets a random choice from the likelihood @l@
  -- whose parameters belong to the record field selected by the accessor.
  -- The 'String' is presumably a name for the choice -- confirm against the instances.
  sampleValue :: (Conjugate p l, SampleCtx m l) => String -> l -> Accessor r p -> m (Support l)
  -- | Interprets a random choice from the distribution @d@
  -- with explicitly given parameters.
  -- The 'String' is presumably a name for the choice -- confirm against the instances.
  sampleConst :: (Distribution d, SampleCtx m d) => String -> d -> Params d -> m (Support d)
  -- | NOTE(review): presumably runs the given action once per index
  -- (the 'Int' being the number of repetitions) and collects the results,
  -- with 'Ord' allowing interpreters to disregard the order of the results
  -- -- confirm against the instances.
  permutationPlate :: (Ord a) => Int -> (Int -> m a) -> m [a]

-- | A sequence of named, dynamically typed values.
-- The parameter record @r@ is a phantom type.
newtype Trace (r :: (Type -> Type) -> Type) = Trace {runTrace :: S.Seq (String, Dynamic)}
  deriving (Show)

-- | Records an observed value of a conjugate variable
-- by appending it (together with its name) to the end of the trace.
-- The likelihood and the accessor are not inspected;
-- they only determine the type of the value.
observeValue
  :: (Conjugate p l, Typeable (Support l), Monad m)
  => String
  -> l
  -> Accessor r p
  -> Support l
  -> StateT (Trace r) m ()
observeValue name _ _ val =
  modify $ \(Trace st) -> Trace $ st S.|> (name, toDyn val)

-- | Records an observed value of a variable with fixed parameters
-- by appending it (together with its name) to the end of the trace.
-- The distribution and its parameters are not inspected;
-- they only determine the type of the value.
observeConst
  :: (Distribution d, Typeable (Support d), Monad m)
  => String
  -> d
  -> Params d
  -> Support d
  -> StateT (Trace r) m ()
observeConst name _ _ val =
  modify $ \(Trace st) -> Trace $ st S.|> (name, toDyn val)

-- | Removes the first entry from the trace and returns its name and value
-- (cast to the requested type) together with the remaining trace.
-- Returns 'Nothing' if the trace is empty
-- or if the stored value does not have the requested type.
takeTrace :: (Typeable a) => Trace r -> Maybe ((String, a), Trace r)
takeTrace (Trace t) = case S.viewl t of
  S.EmptyL -> Nothing
  (name, valDyn) S.:< rest -> do
    val <- fromDynamic valDyn
    pure ((name, val), Trace rest)

-- | Returns the first entry of the trace without removing it
-- and without casting its value,
-- or 'Nothing' if the trace is empty.
peekTrace :: Trace r -> Maybe (String, Dynamic)
peekTrace (Trace t) = case S.viewl t of
  S.EmptyL -> Nothing
  item S.:< _rest -> Just item

-- just sample
-- -----------

-- | An interpreter monad that only samples:
-- a computation in 'Prob' with read access to the probabilities record @r ProbsRep@.
-- The 'Functor', 'Applicative', and 'Monad' instances
-- are taken over from the wrapped 'ReaderT'.
newtype SampleI m r a = SampleI (ReaderT (r ProbsRep) (Prob m) a)
  deriving (Functor, Applicative, Monad)

-- | Interprets random choices by drawing samples.
-- The names of the choices are ignored.
instance (PrimMonad m) => RandomInterpreter (SampleI m r) r where
  -- sampling places no additional constraints on the distributions
  type SampleCtx (SampleI m r) a = ()

  -- Looks up the likelihood's parameters in the probabilities record
  -- (using the accessor) and samples from the likelihood with them.
  sampleValue
    :: forall p l
     . (Conjugate p l)
    => String
    -> l
    -> Accessor r p
    -> SampleI m r (Support l)
  sampleValue _ lk getProbs = SampleI $ do
    probs <- ask
    lift $ distSample lk $ runProbs $ view getProbs probs

  -- Samples directly from the distribution with the given parameters.
  sampleConst
    :: forall d
     . (Distribution d)
    => String
    -> d
    -> Params d
    -> SampleI m r (Support d)
  sampleConst _ dist params = SampleI $ lift $ distSample dist params

  -- Delegates to 'replicateMWithI'.
  permutationPlate = replicateMWithI

-- | Run a 'SampleI' model: supply the probabilities as the reader
-- environment and sample the resulting 'Prob' action with the given generator.
sampleResult :: p ProbsRep -> SampleI m p a -> Gen (PrimState m) -> m a
sampleResult probs (SampleI a) = sample (runReaderT a probs)

-- sample and trace the execution process
-- --------------------------------------

-- | Interpreter that samples a model and records what it sampled.
-- It is a reader over the probabilities (@r ProbsRep@) layered on a state
-- holding the 'Trace' built so far, on top of the sampling monad @Prob m@.
newtype TraceI m r a = TraceI (ReaderT (r ProbsRep) (StateT (Trace r) (Prob m)) a)
  deriving (Functor, Applicative, Monad)

-- Forward sampling that appends every sampled value, under its variable
-- name, to the trace held in the state.
instance (PrimMonad m) => RandomInterpreter (TraceI m r) r where
  -- Sampled values are stored as 'Dynamic', hence the 'Typeable' requirement.
  type SampleCtx (TraceI m r) l = Typeable (Support l)

  -- Draw from the likelihood @lk@ using the parameters selected from the
  -- reader environment, then record @(name, value)@ at the end of the trace.
  sampleValue
    :: forall p l
     . (Conjugate p l, Typeable (Support l))
    => String
    -> l
    -> Accessor r p
    -> TraceI m r (Support l)
  sampleValue name lk getProbs = TraceI $ do
    probs <- ask
    val <- lift $ lift $ distSample lk $ runProbs $ view getProbs probs
    modify $ \(Trace obs) -> Trace $ obs S.|> (name, toDyn val)
    pure val

  -- Draw from @dist@ with explicit parameters and record the value likewise.
  sampleConst
    :: forall d
     . (Distribution d, Typeable (Support d))
    => String
    -> d
    -> Params d
    -> TraceI m r (Support d)
  sampleConst name dist params = TraceI $ do
    val <- lift $ lift $ distSample dist params
    modify $ \(Trace obs) -> Trace $ obs S.|> (name, toDyn val)
    pure val

  -- Delegates directly to 'replicateMWithI' (defined elsewhere in this file).
  permutationPlate = replicateMWithI

-- | Run a 'TraceI' model with the given probabilities, starting from an
-- empty trace, and return the model's result together with the recorded trace.
sampleTrace
  :: r ProbsRep -> TraceI m r a -> Gen (PrimState m) -> m (a, Trace r)
sampleTrace probs (TraceI a) = do
  let st = runReaderT a probs -- discharge the reader layer
      pr = runStateT st (Trace mempty) -- discharge the state layer, empty initial trace
  sample pr

-- evaluate the probability of a trace
-- -----------------------------------

-- | Interpreter that replays a trace against a model and accumulates its
-- log-probability under fixed probabilities (@r ProbsRep@).
-- The state pairs the not-yet-consumed trace with the running log-probability;
-- the 'Maybe' base lets evaluation fail.
newtype EvalTraceI r a = EvalTraceI (ReaderT (r ProbsRep) (StateT (Trace r, Double) Maybe) a)
  deriving (Functor, Applicative, Monad)

-- Replays values from the trace instead of sampling, adding the
-- log-probability of each replayed value to the running total.
-- Any failure of 'takeTrace' aborts the whole evaluation with 'Nothing'.
instance RandomInterpreter (EvalTraceI r) r where
  type SampleCtx (EvalTraceI r) l = Typeable (Support l)

  -- Take the next value via 'takeTrace' (presumably the head of the trace,
  -- decoded at the expected type -- confirm against its definition) and score
  -- it under @lk@ with the parameters selected from the reader environment.
  -- The variable name is not consulted here.
  sampleValue
    :: forall p l
     . (Conjugate p l, Typeable (Support l))
    => String
    -> l
    -> Accessor r p
    -> EvalTraceI r (Support l)
  sampleValue _ lk getProbs = EvalTraceI $ do
    probs <- ask
    (trace, totalLogP) <- get
    ((_name, val), trace') <- lift $ lift $ takeTrace trace
    let logP = distLogP lk (runProbs $ view getProbs probs) val
    put (trace', totalLogP + logP)
    pure val

  -- Same as 'sampleValue', but scored under @dist@ with explicit parameters.
  sampleConst
    :: forall d
     . (Distribution d, Typeable (Support d))
    => String
    -> d
    -> Params d
    -> EvalTraceI r (Support d)
  sampleConst _ dist params = EvalTraceI $ do
    (trace, totalLogP) <- get
    ((_name, val), trace') <- lift $ lift $ takeTrace trace
    let logP = distLogP dist params val
    put (trace', totalLogP + logP)
    pure val

  -- Evaluate @n@ copies of the submodel on the current trace (with a fresh
  -- log-probability of 0), then add the log of the number of distinct
  -- orderings of the results: log n! - sum_i log k_i!, where the k_i are the
  -- multiplicities of the distinct result values.
  permutationPlate n submodel = EvalTraceI $ do
    probs <- ask
    (trace, totalLogP) <- get
    (results, (trace', logP)) <-
      lift $
        lift $
          runTraceLogP
            probs
            trace
            (replicateMWithI n submodel)
    let unique = MS.fromList results
        permutations =
          logFactorial n - sum (logFactorial . snd <$> MS.toOccurList unique)
    put (trace', totalLogP + logP + permutations)
    pure results

-- | Evaluate a model against a trace under the given probabilities, starting
-- from a log-probability of 0. Returns the model's result together with the
-- final state (the trace left over and the accumulated log-probability), or
-- 'Nothing' if evaluation fails.
runTraceLogP
  :: r ProbsRep -> Trace r -> EvalTraceI r a -> Maybe (a, (Trace r, Double))
runTraceLogP probs trace (EvalTraceI model) =
  runStateT (runReaderT model probs) (trace, 0)

-- | Like 'runTraceLogP', but discards the leftover trace and returns only the
-- model's result and the accumulated log-probability.
evalTraceLogP :: r ProbsRep -> Trace r -> EvalTraceI r a -> Maybe (a, Double)
evalTraceLogP probs trace model = do
  (val, (_trace, logp)) <- runTraceLogP probs trace model
  pure (val, logp)

-- evaluate the predictive probability of a trace
-- -----------------------------------

-- | Interpreter that replays a trace against a model and accumulates its
-- predictive log-probability, reading hyperparameters (@r HyperRep@) rather
-- than fixed probabilities from the environment.
-- The state pairs the not-yet-consumed trace with the running log-probability;
-- the 'Maybe' base lets evaluation fail.
newtype EvalPredTraceI r a = EvalPredTraceI (ReaderT (r HyperRep) (StateT (Trace r, Double) Maybe) a)
  deriving (Functor, Applicative, Monad)

-- Replays values from the trace instead of sampling, adding each value's
-- log-probability to the running total. Conjugate variables are scored with
-- 'predLogP' on the hyperparameters; constant distributions with 'distLogP'.
-- Any failure of 'takeTrace' aborts the whole evaluation with 'Nothing'.
instance RandomInterpreter (EvalPredTraceI r) r where
  type SampleCtx (EvalPredTraceI r) l = Typeable (Support l)

  -- Take the next value via 'takeTrace' (presumably the head of the trace,
  -- decoded at the expected type -- confirm against its definition) and score
  -- it with 'predLogP' under the hyperparameters selected by the accessor.
  -- The variable name is not consulted here.
  sampleValue
    :: forall p l
     . (Conjugate p l, Typeable (Support l))
    => String
    -> l
    -> Accessor r p
    -> EvalPredTraceI r (Support l)
  sampleValue _ lk getHyper = EvalPredTraceI $ do
    hyper <- ask
    (trace, totalLogP) <- get
    ((_name, val), trace') <- lift $ lift $ takeTrace trace
    let logP = predLogP @p lk (runHyper $ view getHyper hyper) val
    put (trace', totalLogP + logP)
    pure val

  -- Constant distributions have explicit parameters, so they are scored
  -- directly with 'distLogP'.
  sampleConst
    :: forall d
     . (Distribution d, Typeable (Support d))
    => String
    -> d
    -> Params d
    -> EvalPredTraceI r (Support d)
  sampleConst _ dist params = EvalPredTraceI $ do
    (trace, totalLogP) <- get
    ((_name, val), trace') <- lift $ lift $ takeTrace trace
    let logP = distLogP dist params val
    put (trace', totalLogP + logP)
    pure val

  -- Evaluate @n@ copies of the submodel on the current trace (with a fresh
  -- log-probability of 0), then add the log of the number of distinct
  -- orderings of the results: log n! - sum_i log k_i!, where the k_i are the
  -- multiplicities of the distinct result values.
  permutationPlate n submodel = EvalPredTraceI $ do
    hyper <- ask
    (trace, totalLogP) <- get
    (results, (trace', logP)) <-
      lift $
        lift $
          runTracePredLogP
            hyper
            trace
            (replicateMWithI n submodel)
    let unique = MS.fromList results
        permutations =
          logFactorial n - sum (logFactorial . snd <$> MS.toOccurList unique)
    put (trace', totalLogP + logP + permutations)
    pure results

-- | Evaluate a model against a trace under the given hyperparameters,
-- starting from a log-probability of 0. Returns the model's result together
-- with the final state (the trace left over and the accumulated
-- log-probability), or 'Nothing' if evaluation fails.
runTracePredLogP
  :: r HyperRep -> Trace r -> EvalPredTraceI r a -> Maybe (a, (Trace r, Double))
runTracePredLogP hyper trace (EvalPredTraceI model) =
  runStateT (runReaderT model hyper) (trace, 0)

{- | Like 'runTracePredLogP', but discards the final trace.

Returns the result of the model together with the accumulated
log probability, or 'Nothing' if 'runTracePredLogP' fails.
-}
evalTracePredLogP
  :: r HyperRep -> Trace r -> EvalPredTraceI r a -> Maybe (a, Double)
evalTracePredLogP hyper trace model =
  dropTrace <$> runTracePredLogP hyper trace model
 where
  -- keep the value and the log probability, forget the leftover trace
  dropTrace (val, (_trace, logp)) = (val, logp)

-- update priors
-- -------------

{- | Interpreter that replays a trace through a model while updating
the hyperparameters.

The state is a pair of the (remaining) trace and the current
hyperparameters; the 'Maybe' base monad carries failure
(e.g. when an item cannot be taken from the trace).
-}
newtype UpdatePriorsI r a = UpdatePriorsI (StateT (Trace r, r HyperRep) Maybe a)
  deriving (Functor, Applicative, Monad)

{- | Interpret a model by consuming observed values from the trace and
folding each conjugate observation into the hyperparameters.
-}
instance RandomInterpreter (UpdatePriorsI r) r where
  type SampleCtx (UpdatePriorsI r) l = Typeable (Support l)

  -- Take the next value from the trace and update the hyperparameter
  -- selected by @accessor@ with it. The variable name is ignored here.
  sampleValue
    :: forall p l
     . (Conjugate p l, Typeable (Support l))
    => String
    -> l
    -> Accessor r p
    -> UpdatePriorsI r (Support l)
  sampleValue _ lk accessor = UpdatePriorsI $ do
    (trace, priors) <- get
    -- fails (Nothing) if 'takeTrace' cannot produce a value of the right type
    ((_name, val), trace') <- lift $ takeTrace trace
    let priors' :: r HyperRep
        priors' =
          over
            accessor
            (\(HyperRep pr) -> HyperRep $ updatePrior @p @l lk pr val)
            priors
    put (trace', priors')
    pure val

  -- Constant-parameter draws carry no hyperparameters:
  -- only advance the trace and return the recorded value.
  sampleConst _ _ _ = UpdatePriorsI $ do
    (trace, priors) <- get
    ((_name, val), trace') <- lift $ takeTrace trace
    put (trace', priors)
    pure val

  -- Plain indexed replication of the sub-model.
  permutationPlate = replicateMWithI

{- | Compute posterior hyperparameters by running a model in the
'UpdatePriorsI' interpreter.

Starts from the state @(trace, priors)@ and returns the hyperparameter
component of the final state, or 'Nothing' if the run fails.
The leftover trace and the value of the model are discarded.
-}
getPosterior
  :: r HyperRep -> Trace r -> UpdatePriorsI r a -> Maybe (r HyperRep)
getPosterior priors trace (UpdatePriorsI model) =
  snd <$> execStateT model (trace, priors)

-- show how a trace is generated by a model
-- ----------------------------------------

{- | Interpreter that walks a trace alongside a model and collects a
textual description of every sampled item.

The monad stack is, from the outside in: 'MaybeT' for failure,
'WriterT' 'String' for the collected text, and 'State' over the
(remaining) trace.
-}
newtype ShowTraceI r a = ShowTraceI (MaybeT (WriterT String (State (Trace r))) a)
  deriving (Functor, Applicative, Monad)

{- | Take the next item from the trace and log a line describing it.

The type argument @l@ is the distribution; its 'typeRep' is shown in
the message together with the sampled value and the given @name@.
If 'takeTrace' yields 'Nothing' the computation fails (via 'MaybeT')
without logging anything for this item.
-}
showTraceItem
  :: forall l r
   . (Show (Support l), Typeable l, Typeable (Support l))
  => String
  -> ShowTraceI r (Support l)
showTraceItem name = ShowTraceI $ do
  trace <- get
  ((_name, val), trace') <- MaybeT $ pure $ takeTrace trace
  put trace'
  let distName = show (typeRep (Proxy :: Proxy l))
  tell $
    "Sampled value "
      <> show val
      <> " from a "
      <> distName
      <> " at "
      <> name
      <> ".\n"
  pure val

{- | Interpret a model by describing each item of the trace.
Both kinds of sampling statements delegate to 'showTraceItem';
the distribution value and the accessor or parameters are not inspected.
-}
instance RandomInterpreter (ShowTraceI r) r where
  type
    SampleCtx (ShowTraceI r) l =
      (Typeable (Support l), Typeable l, Show (Support l))

  sampleValue
    :: forall p l
     . (Conjugate p l, Typeable (Support l), Typeable l, Show (Support l))
    => String
    -> l
    -> Accessor r p
    -> ShowTraceI r (Support l)
  sampleValue name _ _ = showTraceItem @l name

  sampleConst
    :: forall d
     . (Distribution d, SampleCtx (ShowTraceI r) d)
    => String
    -> d
    -> Params d
    -> ShowTraceI r (Support d)
  sampleConst name _ _ = showTraceItem @d name

  -- Plain indexed replication of the sub-model.
  permutationPlate = replicateMWithI

{- | Run a model in the 'ShowTraceI' interpreter on the given trace.

Returns the result of the model ('Nothing' if the run failed part-way)
together with the text that was collected up to that point.
The leftover trace is discarded.
-}
showTrace :: Trace r -> ShowTraceI r a -> (Maybe a, String)
showTrace trace (ShowTraceI model) =
  evalState (runWriterT (runMaybeT model)) trace

{- | Print the description produced by 'showTrace' to standard output.

If the run did not produce a result, an additional line is printed
saying that the trace does not match the model.
-}
printTrace :: Trace r -> ShowTraceI r a -> IO ()
printTrace trace model = do
  let (res, txt) = showTrace trace model
  putStrLn txt
  case res of
    Nothing -> putStrLn "Trace does not match the model (stops here)"
    Just _ -> pure ()

-- tracing interpreter (for debugging)
-- -----------------------------------

{- | Debugging interpreter: a plain 'State' computation over the
(remaining) trace. Unlike 'ShowTraceI' there is no failure layer;
see 'traceTraceItem' for how mismatches are reported.
-}
newtype TraceTraceI r a = TraceTraceI (State (Trace r) a)
  deriving (Functor, Applicative, Monad)

{- | Take the next item from the trace and report it via "Debug.Trace".

The next item is first inspected with 'peekTrace' so that problems can
be reported precisely. This function calls 'error' (it is meant for
debugging only) when

* the trace is empty,
* the name of the next item differs from @name@, or
* the names agree but 'takeTrace' still returns 'Nothing'
  (reported as an incompatible trace).

On success the trace is advanced and the value is returned.
-}
traceTraceItem
  :: forall l r
   . (Show (Support l), Typeable l, Typeable (Support l))
  => String
  -> TraceTraceI r (Support l)
traceTraceItem name = TraceTraceI $ do
  trace <- get
  let loc = show (typeRep (Proxy :: Proxy l)) <> " at " <> name
  case peekTrace trace of
    Nothing -> error $ "Expected " <> name <> " but trace is empty."
    Just (tname, tval)
      | name /= tname ->
          error $
            "RV names don't match. expected: "
              <> name
              <> ". actual: "
              <> tname
              <> "."
      | otherwise -> case takeTrace trace of
          Just ((_tname, val), trace') -> do
            put trace'
            DT.traceM $
              "Sampled value " <> show val <> " from a " <> loc <> "."
            pure val
          Nothing ->
            error $
              "Incompatible trace at "
                <> loc
                <> ". Got "
                <> show (tname, tval)

{- | Interpret a model by tracing each item of the trace.
Both kinds of sampling statements delegate to 'traceTraceItem';
the distribution value and the accessor or parameters are not inspected.
-}
instance RandomInterpreter (TraceTraceI r) r where
  type
    SampleCtx (TraceTraceI r) l =
      (Typeable (Support l), Typeable l, Show (Support l))

  sampleValue
    :: forall p l
     . (Conjugate p l, Typeable (Support l), Typeable l, Show (Support l))
    => String
    -> l
    -> Accessor r p
    -> TraceTraceI r (Support l)
  sampleValue name _ _ = traceTraceItem @l name

  sampleConst
    :: forall d
     . (Distribution d, SampleCtx (TraceTraceI r) d)
    => String
    -> d
    -> Params d
    -> TraceTraceI r (Support d)
  sampleConst name _ _ = traceTraceItem @d name

  -- Plain indexed replication of the sub-model.
  permutationPlate = replicateMWithI

{- | Run a model in the 'TraceTraceI' interpreter on the given trace
and return its result; the leftover trace is discarded.
-}
traceTrace :: Trace r -> TraceTraceI r a -> a
traceTrace trace (TraceTraceI model) = evalState model trace

-------------------
-- distributions --
-------------------

-- Beta
-- ----

-- | Tag type for the Beta distribution (it carries no data itself).
data Beta = Beta
  deriving (Eq, Ord, Show, Generic)

{- | The Beta distribution, parameterised by a pair @(a, b)@,
with support in 'Double'.
-}
instance Distribution Beta where
  type Params Beta = (Double, Double)
  type Support Beta = Double

  -- Sampling is delegated to 'beta', applied to the two parameters.
  distSample _ = uncurry beta

  -- Log density: (a-1) log p + (b-1) log (1-p) - logBeta a b.
  distLogP _ (a, b) p =
    logPow p (a - 1) + logPow (1 - p) (b - 1) - logBeta a b
   where
    -- @logPow x e@ equals @log (x ** e)@ but avoids forming the power,
    -- which underflows or overflows for small @x@ or large @e@.
    -- The zero-exponent case is kept explicit so that @x == 0@ still
    -- gives 0 (as @log (0 ** 0)@ does) rather than @0 * (-Infinity)@.
    logPow :: Double -> Double -> Double
    logPow x e
      | e == 0 = 0
      | otherwise = e * log x

-- | Jeffreys prior for a Beta prior: both hyperparameters are 0.5.
instance Jeffreys (AsPrior Beta) where
  jeffreysPrior = (0.5, 0.5)

-- | Uniform prior for a Beta prior: both hyperparameters are 1.
instance Uniform (AsPrior Beta) where
  uniformPrior = (1, 1)

-- | Using Beta as a prior over a single probability.
instance Prior (AsPrior Beta) where
  -- Draw a probability by sampling from the Beta distribution itself.
  sampleProbs = distSample Beta

  -- Expected probability for hyperparameters (a, b): a / (a + b).
  expectedProbs (a, b) = a / (a + b)

-- Bernoulli
-- ---------

-- | Tag type for the Bernoulli distribution (it carries no data itself).
data Bernoulli = Bernoulli
  deriving (Eq, Ord, Show, Generic)

{- | A Bernoulli distribution: the parameter is a single 'Double' @p@,
 the outcomes are 'Bool's, and @p@ is the probability of 'True'.
-}
instance Distribution Bernoulli where
  type Params Bernoulli = Double
  type Support Bernoulli = Bool

  -- The 'Bernoulli' singleton carries no data, so sampling is delegated
  -- directly to 'bernoulli'.
  distSample _ = bernoulli

  -- log p for 'True', log (1 - p) for 'False'.
  distLogP _ p outcome
    | outcome = log p
    | otherwise = log (1 - p)

-- Binomial
-- --------

{- | Tag for a binomial distribution.
 The wrapped 'Int' is the (fixed) number of trials.
-}
newtype Binomial = Binomial Int
  deriving (Eq, Ord, Show, Generic)

{- | A binomial distribution over the number of successes in a fixed
 number of trials. The trial count lives in the 'Binomial' value itself;
 the parameter is the per-trial success probability.
-}
instance Distribution Binomial where
  type Params Binomial = Double
  type Support Binomial = Int

  -- Sampling is delegated to 'binomial' with the stored trial count.
  distSample (Binomial n) = binomial n

  -- log (n choose k) + k * log p + (n - k) * log (1 - p)
  distLogP (Binomial n) p k
    -- Outcomes outside 0..n are impossible; use the same convention as
    -- the geometric distributions in this module (log 0 = -Infinity).
    | k < 0 || k > n = log 0
    | otherwise =
        logChoose n k
          + scaled (int2Double k) (log p)
          + scaled (int2Double (n - k)) (log (1 - p))
   where
    -- A factor raised to the power 0 contributes nothing. Spelling this
    -- out avoids 0 * (-Infinity) = NaN when p is exactly 0 or 1.
    scaled :: Double -> Double -> Double
    scaled 0 _ = 0
    scaled count logProb = count * logProb

-- Categorical
-- -----------

{- | Tag for a categorical distribution.
 The type-level natural @n@ is a phantom index; the value itself carries
 no data.
-}
type Categorical :: Nat -> Type
data Categorical n = Categorical
  deriving (Eq, Ord, Show, Generic)

{- | A categorical distribution: the parameters are a vector of category
 probabilities and the outcomes are 'Int' indices into that vector.
 The type-level @n@ is not consulted by this instance.
-}
instance Distribution (Categorical n) where
  type Params (Categorical n) = V.Vector Double
  type Support (Categorical n) = Int

  -- Sampling is delegated to 'categorical'.
  distSample _ = categorical

  -- The log of the probability stored at the given index. An index that
  -- is out of range counts as probability 0 (i.e. -Infinity).
  distLogP _ ps cat = log (fromMaybe 0 (ps V.!? cat))

-- Dirichlet
-- ---------

{- | Tag for a Dirichlet distribution.
 The type-level natural @n@ is a phantom index; the value itself carries
 no data.
-}
type Dirichlet :: Nat -> Type
data Dirichlet n = Dirichlet
  deriving (Eq, Ord, Show, Generic)

{- | A Dirichlet distribution: the parameters are a vector of
 concentration values ("counts") and the outcomes are probability vectors.
-}
instance Distribution (Dirichlet n) where
  type Params (Dirichlet n) = V.Vector Double
  type Support (Dirichlet n) = V.Vector Double

  -- Sampling is delegated to 'dirichlet'.
  distSample _ = dirichlet

  -- Log density: unnormalized term plus log normalization constant.
  distLogP _ counts probs = logp + logz
   where
    -- sum_i (a_i - 1) * log x_i
    -- ('V.zipWith' pairs entries positionally and stops at the shorter
    -- vector, so the lengths are presumably expected to agree.)
    logp = sum (V.zipWith (\a x -> log x * (a - 1)) counts probs)
    -- log Gamma(sum_i a_i) - sum_i log Gamma(a_i)
    logz = logGamma (sum counts) - sum (logGamma <$> counts)

{- | The Jeffreys prior for an @n@-dimensional Dirichlet:
 a vector of length @n@ filled with 0.5.
-}
instance (KnownNat n) => Jeffreys (AsPrior (Dirichlet n)) where
  jeffreysPrior = V.replicate dimension 0.5
   where
    -- The vector length is read off the type-level index @n@.
    dimension :: Int
    dimension = fromIntegral (natVal (Proxy :: Proxy n))

{- | The uniform prior for an @n@-dimensional Dirichlet:
 a vector of length @n@ filled with 1.
-}
instance (KnownNat n) => Uniform (AsPrior (Dirichlet n)) where
  uniformPrior = V.replicate dimension 1
   where
    -- The vector length is read off the type-level index @n@.
    dimension :: Int
    dimension = fromIntegral (natVal (Proxy :: Proxy n))

{- | A Dirichlet used as a prior: sampling draws a probability vector,
 and the expected probabilities are the normalized hyperparameters.
-}
instance Prior (AsPrior (Dirichlet n)) where
  -- Draw from the Dirichlet described by the hyperparameters.
  -- The annotation pins the phantom index to this instance's @n@
  -- instead of leaving it ambiguous.
  sampleProbs = distSample (Dirichlet :: Dirichlet n)

  -- Divide every entry by the sum of all entries.
  expectedProbs :: V.Vector Double -> V.Vector Double
  expectedProbs counts = (/ total) <$> counts
   where
    total = V.sum counts

-- Geometric (from 0)
-- ------------------

{- | Tag for a geometric distribution whose support starts at 0.
 The value carries no data.
-}
data Geometric0 = Geometric0
  deriving (Eq, Ord, Show, Generic)

{- | A geometric distribution counting the failures before the first
 success, so the smallest outcome is 0. The parameter is the success
 probability of a single trial.
-}
instance Distribution Geometric0 where
  type Params Geometric0 = Double
  type Support Geometric0 = Int

  -- Flip a 'bernoulli' coin until it comes up 'True' and count the
  -- 'False' results seen before that.
  -- NOTE(review): presumably never terminates for p = 0 -- confirm that
  -- callers cannot pass it.
  distSample _ = countFailures
   where
    countFailures p = do
      success <- bernoulli p
      if success then pure 0 else (1 +) <$> countFailures p

  -- log ((1 - p)^val * p); negative outcomes have probability 0.
  distLogP _ p val
    | val < 0 = log 0
    -- Handled separately: the general formula below would compute
    -- 0 * log 0 = NaN for p == 1.
    | val == 0 = log p
    | otherwise = (log (1 - p) * int2Double val) + log p

-- Geometric (from 1)
-- ------------------

{- | Tag for a geometric distribution whose support starts at 1.
 The value carries no data.
-}
data Geometric1 = Geometric1
  deriving (Eq, Ord, Show, Generic)

{- | A geometric distribution counting the trials up to and including the
 first success, so the smallest outcome is 1. The parameter is the
 success probability of a single trial.
-}
instance Distribution Geometric1 where
  type Params Geometric1 = Double
  type Support Geometric1 = Int

  -- Flip a 'bernoulli' coin until it comes up 'True' and count all
  -- flips, including the successful one.
  -- NOTE(review): presumably never terminates for p = 0 -- confirm that
  -- callers cannot pass it.
  distSample _ = countTrials
   where
    countTrials p = do
      success <- bernoulli p
      if success then pure 1 else (1 +) <$> countTrials p

  -- log ((1 - p)^(val - 1) * p); outcomes below 1 have probability 0.
  distLogP _ p val
    | val < 1 = log 0
    -- Handled separately: the general formula below would compute
    -- 0 * log 0 = NaN for p == 1.
    | val == 1 = log p
    | otherwise = (log (1 - p) * int2Double (val - 1)) + log p

-----------------------------
-- conjugate distributions --
-----------------------------

-- beta bernoulli
-- --------------

{- | Beta prior with a Bernoulli likelihood.
 In the hyperparameters @(a, b)@, @a@ counts 'True' observations and @b@
 counts 'False' observations.
-}
instance Conjugate Beta Bernoulli where
  priorSingleton = Beta

  -- Each observation increments the pseudo-count of its own outcome.
  updatePrior _ (a, b) True = (a + 1, b)
  updatePrior _ (a, b) False = (a, b + 1)

  -- Posterior predictive probability of an outcome: its own pseudo-count
  -- divided by the total, i.e. B(a + 1, b) / B(a, b) = a / (a + b) for
  -- 'True' and, symmetrically, b / (a + b) for 'False'.
  -- This must agree with 'updatePrior': 'True' is the outcome that
  -- increments @a@, so it is the one predicted with a / (a + b).
  predLogP _ (a, b) True = log a - log (a + b)
  predLogP _ (a, b) False = log b - log (a + b)

-- B(1+a, b)   G(1+a) G(b) G(a+b)   G(1+a) G(a+b)   a G(a) G(a+b)        a
-- --------- = ------------------ = ------------- = ----------------- = ---
-- B(a, b)     G(a) G(b) G(a+b+1)   G(a) G(a+b+1)   G(a) (a+b) G(a+b)   a+b

-- beta binomial
-- -------------

{- | Beta prior with a binomial likelihood.
 In the hyperparameters @(a, b)@, @a@ accumulates successes and @b@
 accumulates failures.
-}
instance Conjugate Beta Binomial where
  priorSingleton = Beta

  -- Observing x successes in n trials adds x to a and (n - x) to b.
  updatePrior (Binomial n) (a, b) x = (a + successes, b + failures)
   where
    successes = int2Double x
    failures = int2Double n - successes

  -- Beta-binomial log probability:
  --   log (n choose k) + log B(k + a, n - k + b) - log B(a, b)
  -- The arguments of the first Beta function are exactly the updated
  -- hyperparameters produced by 'updatePrior', so the failure count has
  -- to be added to b (not to a).
  predLogP (Binomial n) (a, b) k =
    logChoose n k + logBeta (k' + a) (n' - k' + b) - logBeta a b
   where
    n' = int2Double n
    k' = int2Double k

-- beta geometric0
-- ---------------

{- | Beta prior with a geometric likelihood whose outcomes start at 0
 (the outcome is the number of failures before the first success).
-}
instance Conjugate Beta Geometric0 where
  priorSingleton = Beta

  -- An outcome k stands for one success and k failures.
  updatePrior _ (a, b) k = (a + 1, b + int2Double k)

  -- B(a + 1, b + k) / B(a, b), expanded into Gamma functions:
  --   a * G(a + b) * G(b + k) / (G(b) * G(a + b + k + 1))
  -- and evaluated in log space.
  predLogP _ (a, b) k = numerator - denominator
   where
    failures = int2Double k
    numerator = log a + logGamma (a + b) + logGamma (failures + b)
    denominator = logGamma b + logGamma (a + b + failures + 1)

-- beta geometric1
-- ---------------

{- | Beta prior with a geometric likelihood whose outcomes start at 1
 (the outcome is the number of trials up to and including the first
 success).
-}
instance Conjugate Beta Geometric1 where
  priorSingleton = Beta

  -- An outcome k stands for one success and (k - 1) failures.
  updatePrior _ (a, b) k = (a + 1, b + int2Double (k - 1))

  -- Same formula as for 'Geometric0', applied to the number of failures
  -- (k - 1):
  --   a * G(a + b) * G(b + f) / (G(b) * G(a + b + f + 1)),  f = k - 1
  -- evaluated in log space.
  predLogP _ (a, b) k = numerator - denominator
   where
    failures = int2Double k - 1
    numerator = log a + logGamma (a + b) + logGamma (failures + b)
    denominator = logGamma b + logGamma (a + b + failures + 1)

-- dirichlet categorical
-- ---------------------

{- | Dirichlet prior with a categorical likelihood.
 The hyperparameters are a vector of pseudo-counts, one per category.
-}
instance Conjugate (Dirichlet n) (Categorical n) where
  priorSingleton = Dirichlet

  -- Increment the pseudo-count of the observed category. Observations
  -- that are not a valid index leave the counts unchanged.
  updatePrior _ counts obs
    | inRange = counts V.// [(obs, (counts V.! obs) + 1)]
    | otherwise = counts
   where
    -- The guard makes the 'V.!' lookup above safe.
    inRange = obs >= 0 && obs < V.length counts

  -- Posterior predictive probability: the category's pseudo-count divided
  -- by the sum of all pseudo-counts. An invalid index counts as 0, which
  -- yields -Infinity.
  predLogP _ counts obs =
    log (fromMaybe 0 (counts V.!? obs)) - log (V.sum counts)

-- utilities and helpers
-- =====================

{- | Like 'Control.Monad.replicateM', but the action receives its index.
 Runs the action for the indices @0@ to @count - 1@ (in that order) and
 collects the results. A count of 0 or less yields an empty list.
-}
replicateMWithI :: (Applicative m) => Int -> (Int -> m a) -> m [a]
replicateMWithI count action = traverse action [0 .. count - 1]