{-# LANGUAGE DeriveFoldable #-}
{-# LANGUAGE DeriveFunctor #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE CPP #-}
{-# OPTIONS_GHC -Wall #-}

-- |
-- Module: System.Random.MWC.Probability
-- Copyright: (c) 2015-2018 Jared Tobin, Marco Zocca
-- License: MIT
--
-- Maintainer: Jared Tobin <jared@jtobin.ca>, Marco Zocca <zocca.marco gmail>
-- Stability: unstable
-- Portability: ghc
--
-- A probability monad based on sampling functions, implemented as a thin
-- wrapper over the
-- [mwc-random](https://hackage.haskell.org/package/mwc-random) library.
--
-- Probability distributions are abstract constructs that can be represented in
-- a variety of ways.  The sampling function representation is particularly
-- useful -- it's computationally efficient, and collections of samples are
-- amenable to much practical work.
--
-- Probability monads propagate uncertainty under the hood.  An expression like
-- @'beta' 1 8 >>= 'binomial' 10@ corresponds to a
-- <https://en.wikipedia.org/wiki/Beta-binomial_distribution beta-binomial>
-- distribution in which the uncertainty captured by @'beta' 1 8@ has been
-- marginalized out.
--
-- The distribution resulting from a series of effects is called the
-- /predictive distribution/ of the model described by the corresponding
-- expression.  The monadic structure lets one piece together a hierarchical
-- structure from simpler, local conditionals:
--
-- > hierarchicalModel = do
-- >   [c, d, e, f] <- replicateM 4 $ uniformR (1, 10)
-- >   a <- gamma c d
-- >   b <- gamma e f
-- >   p <- beta a b
-- >   n <- uniformR (5, 10)
-- >   binomial n p
--
-- The functor instance allows one to transform the support of a distribution
-- while leaving its density structure invariant.  For example, @'uniform'@ is
-- a distribution over the 0-1 interval, but @fmap (+ 1) uniform@ is the
-- translated distribution over the 1-2 interval:
--
-- >>> create >>= sample (fmap (+ 1) uniform)
-- 1.5480073474340754
--
-- The applicative instance guarantees that the generated samples are generated
-- independently:
--
-- >>> create >>= sample ((,) <$> uniform <*> uniform)

module System.Random.MWC.Probability (
    module MWC
  , Prob(..)
  , samples

  , uniform
  , uniformR
  , normal
  , standardNormal
  , isoNormal
  , logNormal
  , exponential
  , inverseGaussian
  , laplace
  , gamma
  , inverseGamma
  , normalGamma
  , weibull
  , chiSquare
  , beta
  , gstudent
  , student
  , pareto
  , dirichlet
  , symmetricDirichlet
  , discreteUniform
  , zipf
  , categorical
  , discrete
  , bernoulli
  , binomial
  , negativeBinomial
  , multinomial
  , poisson
  , crp
  ) where

import Control.Applicative
import Control.Monad
import Control.Monad.Primitive
import Control.Monad.IO.Class
import Control.Monad.Trans.Class
import Data.Monoid (Sum(..))
#if __GLASGOW_HASKELL__ < 710
import Data.Foldable (Foldable)
#endif
import qualified Data.Foldable as F
import Data.List (findIndex)
import qualified Data.IntMap as IM
import System.Random.MWC as MWC hiding (uniform, uniformR)
import qualified System.Random.MWC as QMWC
import qualified System.Random.MWC.Distributions as MWC.Dist
import System.Random.MWC.CondensedTable

-- | A probability distribution characterized by a sampling function.
--
-- >>> gen <- createSystemRandom
-- >>> sample uniform gen
-- 0.4208881170464097
newtype Prob m a = Prob
  { sample :: Gen (PrimState m) -> m a
    -- ^ Run the wrapped sampling function against a generator.
  }

-- | Sample from a model 'n' times.
--
-- >>> samples 2 uniform gen
-- [0.6738707766845254,0.9730405951541817]
samples :: PrimMonad m => Int -> Prob m a -> Gen (PrimState m) -> m [a]
samples n model gen = replicateM n (sample model gen)
{-# INLINABLE samples #-}

-- Map over the result of the wrapped sampling function.
instance Functor m => Functor (Prob m) where
  fmap h (Prob f) = Prob (fmap h . f)

-- 'pure' ignores the generator; '<*>' is derived from the 'Monad' instance.
instance Monad m => Applicative (Prob m) where
  pure  = Prob . const . pure
  (<*>) = ap

-- Bind threads the same generator through both the first sampler and the
-- sampler produced from its result.
instance Monad m => Monad (Prob m) where
  return = pure
  m >>= h = Prob $ \g -> do
    z <- sample m g
    sample (h z) g
  {-# INLINABLE (>>=) #-}

-- Numeric operations are lifted pointwise over sampled values.
instance (Monad m, Num a) => Num (Prob m a) where
  (+)         = liftA2 (+)
  (-)         = liftA2 (-)
  (*)         = liftA2 (*)
  abs         = fmap abs
  signum      = fmap signum
  fromInteger = pure . fromInteger

-- A lifted action ignores the generator entirely.
instance MonadTrans Prob where
  lift m = Prob $ const m

-- IO actions are lifted into the base monad and ignore the generator.
instance MonadIO m => MonadIO (Prob m) where
  liftIO m = Prob $ const (liftIO m)

-- 'Prob' shares the state token of its base monad; primitives are run in the
-- base monad and lifted.
instance PrimMonad m => PrimMonad (Prob m) where
  type PrimState (Prob m) = PrimState m
  primitive = lift . primitive
  {-# INLINE primitive #-}

-- | The uniform distribution at a specified type.
--
--   Note that `Double` and `Float` variates are defined over the unit
--   interval.
--
--   >>> sample uniform gen :: IO Double
--   0.29308497534914946
--   >>> sample uniform gen :: IO Bool
--   False
uniform :: (PrimMonad m, Variate a) => Prob m a
uniform = Prob QMWC.uniform
{-# INLINABLE uniform #-}

-- | The uniform distribution over the provided interval.
--
--   >>> sample (uniformR (0, 1)) gen
--   0.44984153252922365
uniformR :: (PrimMonad m, Variate a) => (a, a) -> Prob m a
uniformR r = Prob $ QMWC.uniformR r
{-# INLINABLE uniformR #-}

-- | The discrete uniform distribution.
--
--   >>> sample (discreteUniform [0..10]) gen
--   6
--   >>> sample (discreteUniform "abcdefghijklmnopqrstuvwxyz") gen
--   'a'
--
--   Note that the supplied container must be non-empty; an empty container
--   raises an error.
discreteUniform :: (PrimMonad m, Foldable f) => f a -> Prob m a
discreteUniform cs
  | null xs   = error "mwc-probability: discreteUniform: empty support"
  | otherwise = do
      j <- uniformR (0, length xs - 1)
      return $ xs !! j
  where
    -- Materialize the container once so that the length and the indexing
    -- both operate on the same list.
    xs = F.toList cs
{-# INLINABLE discreteUniform #-}

-- | The standard normal or Gaussian distribution with mean 0 and standard
--   deviation 1.
standardNormal :: PrimMonad m => Prob m Double
standardNormal = Prob MWC.Dist.standard
{-# INLINABLE standardNormal #-}

-- | The normal or Gaussian distribution with specified mean and standard
--   deviation.
--
--   Note that `sd` should be positive.
normal :: PrimMonad m => Double -> Double -> Prob m Double
normal m sd = Prob $ MWC.Dist.normal m sd
{-# INLINABLE normal #-}

-- | The log-normal distribution, sampled as the exponential of a 'normal'
--   variate with the specified mean and standard deviation.
--
--   Note that `sd` should be positive.
logNormal :: PrimMonad m => Double -> Double -> Prob m Double
logNormal m sd = exp <$> normal m sd
{-# INLINABLE logNormal #-}

-- | The exponential distribution with provided rate parameter.
--
--   Note that `r` should be positive.
exponential :: PrimMonad m => Double -> Prob m Double
exponential r = Prob $ MWC.Dist.exponential r
{-# INLINABLE exponential #-}

-- | The Laplace or double-exponential distribution with provided location and
--   scale parameters.
--
--   Note that `sigma` should be positive.  Internally the Laplace scale is
--   taken to be @sigma / sqrt 2@.
laplace :: (Floating a, Variate a, PrimMonad m) => a -> a -> Prob m a
laplace mu sigma = do
  -- Inverse-transform sampling from a uniform draw centred on zero.
  u <- uniformR (-0.5, 0.5)
  let b = sigma / sqrt 2
  return $ mu - b * signum u * log (1 - 2 * abs u)
{-# INLINABLE laplace #-}

-- | The Weibull distribution with provided shape and scale parameters.
--
--   Samples are drawn by inverse-transform sampling: for a uniform draw @u@
--   the variate is @b * (- log (1 - u)) ** (1 / a)@, where `a` is the shape
--   and `b` is the scale.
--
--   Note that both parameters should be positive.
weibull :: (Floating a, Variate a, PrimMonad m) => a -> a -> Prob m a
weibull a b = do
  x <- uniform
  -- The exponent must be parenthesized: '**' binds tighter than '/', so
  -- @y ** 1/a@ would parse as @(y ** 1) / a@ and yield an exponential
  -- variate instead of a Weibull one.
  return $ b * negate (log (1 - x)) ** (1 / a)
{-# INLINABLE weibull #-}

-- | The gamma distribution with shape parameter `a` and scale parameter `b`.
--
--   This is the parameterization used more traditionally in frequentist
--   statistics.  It has the following corresponding probability density
--   function:
--
-- > f(x; a, b) = 1 / (Gamma(a) * b ^ a) x ^ (a - 1) e ^ (- x / b)
--
--   Note that both parameters should be positive.
gamma :: PrimMonad m => Double -> Double -> Prob m Double
gamma a b = Prob $ MWC.Dist.gamma a b
{-# INLINABLE gamma #-}

-- | The inverse-gamma distribution, sampled as the reciprocal of a
--   @'gamma' a b@ variate.
--
--   `a` is the shape parameter.  `b` is passed to 'gamma' unchanged, i.e. it
--   is the /scale/ of the underlying gamma variate; in the conventional
--   shape/scale parameterization of the inverse-gamma distribution the
--   result therefore has scale @1 / b@.
--
--   Note that both parameters should be positive.
inverseGamma :: PrimMonad m => Double -> Double -> Prob m Double
inverseGamma a b = recip <$> gamma a b
{-# INLINABLE inverseGamma #-}

-- | The Normal-Gamma distribution.
--
--   Note that the `lambda`, `a`, and `b` parameters should be positive.
normalGamma :: PrimMonad m => Double -> Double -> Double -> Double -> Prob m Double
normalGamma mu lambda a b = do
  -- Draw a precision, then a normal variate whose variance is the
  -- reciprocal of lambda times that precision.
  tau <- gamma a b
  let xsd = sqrt (recip (lambda * tau))
  normal mu xsd
{-# INLINABLE normalGamma #-}

-- | The chi-square distribution with the specified degrees of freedom.
--
--   Note that `k` should be positive.
chiSquare :: PrimMonad m => Int -> Prob m Double
chiSquare k = Prob $ MWC.Dist.chiSquare k
{-# INLINABLE chiSquare #-}

-- | The beta distribution with the specified shape parameters.
--
--   Note that both parameters should be positive.
beta :: PrimMonad m => Double -> Double -> Prob m Double
beta a b = do
  -- Ratio of two unit-scale gamma variates.
  u <- gamma a 1
  w <- gamma b 1
  return $ u / (u + w)
{-# INLINABLE beta #-}

-- | The Pareto distribution with specified index `a` and minimum `xmin`
--   parameters.
--
--   Note that both parameters should be positive.
pareto :: PrimMonad m => Double -> Double -> Prob m Double
pareto a xmin = do
  -- Scale the exponential of an exponential variate by the minimum.
  y <- exponential a
  return $ xmin * exp y
{-# INLINABLE pareto #-}

-- | The Dirichlet distribution with the provided concentration parameters.
--   The dimension of the distribution is determined by the number of
--   concentration parameters supplied.
--
--   >>> sample (dirichlet [0.1, 1, 10]) gen
--   [1.2375387187120799e-5,3.4952460651813816e-3,0.9964923785476316]
--
--   Note that all concentration parameters should be positive.
dirichlet
  :: (Traversable f, PrimMonad m) => f Double -> Prob m (f Double)
dirichlet as = do
  -- One unit-scale gamma draw per concentration parameter, normalized by
  -- their sum.
  zs <- traverse (`gamma` 1) as
  return $ fmap (/ sum zs) zs
{-# INLINABLE dirichlet #-}

-- | The symmetric Dirichlet distribution with dimension `n`.  The provided
--   concentration parameter is simply replicated `n` times.
--
--   Note that `a` should be positive.
symmetricDirichlet :: PrimMonad m => Int -> Double -> Prob m [Double]
symmetricDirichlet n a = dirichlet (replicate n a)
{-# INLINABLE symmetricDirichlet #-}

-- | The Bernoulli distribution with success probability `p`.
bernoulli :: PrimMonad m => Double -> Prob m Bool
bernoulli p = (< p) <$> uniform
{-# INLINABLE bernoulli #-}

-- | The binomial distribution with number of trials `n` and success
--   probability `p`.
--
--   >>> sample (binomial 10 0.3) gen
--   4
binomial :: PrimMonad m => Int -> Double -> Prob m Int
-- Count the successes among `n` independent Bernoulli draws.
binomial n p = fmap (length . filter id) $ replicateM n (bernoulli p)
{-# INLINABLE binomial #-}

-- | The negative binomial distribution with number of trials `n` and success
--   probability `p`.
--
--   >>> sample (negativeBinomial 10 0.3) gen
--   21
negativeBinomial :: (PrimMonad m, Integral a) => a -> Double -> Prob m Int
negativeBinomial n p = do
  -- Gamma-Poisson mixture: draw a rate from a gamma, then a Poisson count.
  y <- gamma (fromIntegral n) ((1 - p) / p)
  poisson y
{-# INLINABLE negativeBinomial #-}

-- | The multinomial distribution of `n` trials and category probabilities
--   `ps`.
--
--   The result has one entry per trial: the zero-based index of the category
--   sampled on that trial.
--
--   Note that the supplied probability container should consist of non-negative
--   values but is not required to sum to one.
multinomial :: (Foldable f, PrimMonad m) => Int -> f Double -> Prob m [Int]
multinomial n ps = do
    let (cumulative, total) = runningTotals (F.toList ps)
    replicateM n $ do
      z <- uniformR (0, total)
      -- The sampled category is the first whose cumulative weight exceeds
      -- the draw.
      case findIndex (> z) cumulative of
        Just g  -> return g
        Nothing -> error "mwc-probability: invalid probability vector"
  where
    -- Note: this is significantly faster than any
    -- of the recursions one might write by hand.
    runningTotals :: Num a => [a] -> ([a], a)
    runningTotals xs = let adds = scanl1 (+) xs in (adds, sum xs)
{-# INLINABLE multinomial #-}

-- | Generalized Student's t distribution with location parameter `m`, scale
--   parameter `s`, and degrees of freedom `k`.
--
--   Note that the `s` and `k` parameters should be positive.
gstudent :: PrimMonad m => Double -> Double -> Double -> Prob m Double
gstudent m s k = do
  -- NOTE(review): `s` is forwarded as part of the second argument of
  -- 'inverseGamma', which hands it to 'gamma' unchanged; confirm this gives
  -- the intended "scale" semantics for `s`.
  sd <- fmap sqrt (inverseGamma (k / 2) (s * 2 / k))
  normal m sd
{-# INLINABLE gstudent #-}

-- | Student's t distribution with `k` degrees of freedom.
--
--   Note that `k` should be positive.
student :: PrimMonad m => Double -> Prob m Double
student = gstudent 0 1
{-# INLINABLE student #-}

-- | An isotropic or spherical Gaussian distribution with specified mean
--   vector and scalar standard deviation parameter.
--
--   Note that `sd` should be positive.
isoNormal
  :: (Traversable f, PrimMonad m) => f Double -> Double -> Prob m (f Double)
isoNormal ms sd = traverse (`normal` sd) ms
{-# INLINABLE isoNormal #-}

-- | The inverse Gaussian (also known as Wald) distribution with mean parameter
--   `mu` and shape parameter `lambda`.
--
--   Note that the arguments are taken in the order `lambda`, then `mu`.
--
--   Note that both 'mu' and 'lambda' should be positive.
inverseGaussian :: PrimMonad m => Double -> Double -> Prob m Double
inverseGaussian lambda mu = do
  nu <- standardNormal
  let y = nu ** 2
      s = sqrt (4 * mu * lambda * y + mu ** 2 * y ** 2)
      -- Candidate root computed from the squared normal draw.
      x = mu * (1 + 1 / (2 * lambda) * (mu * y - s))
      thresh = mu / (mu + x)
  z <- uniform
  -- Accept the candidate with probability `thresh`; otherwise return the
  -- alternative value mu ^ 2 / x.
  if z <= thresh
    then return x
    else return (mu ** 2 / x)
{-# INLINABLE inverseGaussian #-}

-- | The Poisson distribution with rate parameter `l`.
--
--   Note that `l` should be positive.
poisson :: PrimMonad m => Double -> Prob m Int
poisson l = Prob $ genFromTable table where
  -- Lookup table for rate `l`, built by 'tablePoisson'.
  table = tablePoisson l
{-# INLINABLE poisson #-}

-- | A categorical distribution defined by the supplied probabilities.
--
--   Note that the supplied probability container should consist of non-negative
--   values but is not required to sum to one.
categorical :: (Foldable f, PrimMonad m) => f Double -> Prob m Int
categorical ps = do
  -- A single multinomial trial yields exactly one category index.
  xs <- multinomial 1 ps
  case xs of
    [x] -> return x
    _   -> error "mwc-probability: invalid probability vector"
{-# INLINABLE categorical #-}

-- | A categorical distribution defined by the supplied support.
--
--   Note that the supplied probabilities should be non-negative, but are not
--   required to sum to one.
--
--   >>> samples 10 (discrete [(0.1, "yeah"), (0.9, "nah")]) gen
--   ["yeah","nah","nah","nah","nah","yeah","nah","nah","nah","nah"]
discrete :: (Foldable f, PrimMonad m) => f (Double, a) -> Prob m a
discrete d = do
  -- Split weights from outcomes, sample an index by weight, then look up
  -- the corresponding outcome.
  let (ps, xs) = unzip (F.toList d)
  idx <- categorical ps
  pure (xs !! idx)
{-# INLINABLE discrete #-}

-- | The Zipf-Mandelbrot distribution.
--
--  Note that `a` should be greater than 1 (the sampler divides by @a - 1@
--  and by @2 ** (a - 1) - 1@), and that values close to 1 should be avoided
--  as they are very computationally intensive.
--
--  >>> samples 10 (zipf 1.1) gen
--  [11315371987423520,2746946,653,609,2,13,85,4,256184577853,50]
--
--  >>> samples 10 (zipf 1.5) gen
--  [19,3,3,1,1,2,1,191,2,1]
zipf :: (PrimMonad m, Integral b) => Double -> Prob m b
zipf a = do
  let
    b = 2 ** (a - 1)
    -- Rejection sampler: propose a candidate from one uniform draw and
    -- accept or retry based on a second.
    go = do
        u <- uniform
        v <- uniform
        let xInt = floor (u ** (- 1 / (a - 1)))
            x = fromIntegral xInt
            t = (1 + 1 / x) ** (a - 1)
        if v * x * (t - 1) / (b - 1) <= t / b
          then return xInt
          else go
  go
{-# INLINABLE zipf #-}

-- | The Chinese Restaurant Process with concentration parameter `a` and number
--   of customers `n`.
--
--   See Griffiths and Ghahramani, 2011 for details.
--
--   >>> sample (crp 1.8 50) gen
--   [22,10,7,1,2,2,4,1,1]
--
--   A non-positive number of customers yields an empty list.
crp
  :: PrimMonad m
  => Double            -- ^ concentration parameter (> 1)
  -> Int               -- ^ number of customers
  -> Prob m [Integer]
crp a n
  -- Without this guard the counter below, which starts at 1 and only
  -- increases, would never reach a non-positive `n`.
  | n <= 0    = pure []
  | otherwise = do
      ts <- go crpInitial 1
      pure $ F.toList (fmap getSum ts)
  where
    -- Seat customers one at a time until `n` have been processed.
    go acc i
      | i >= n = pure acc
      | otherwise = do
          acc' <- crpSingle i acc a
          go acc' (i + 1)
{-# INLINABLE crp #-}

-- | Update step of the CRP.
--
--   Samples a table index with 'categorical' and inserts it into the current
--   tables with 'crpInsert'.  Each existing table with count @m@ gets weight
--   @m / (i - 1 + a)@; a final extra entry gets weight @a / (i - 1 + a)@.
crpSingle :: (PrimMonad m, Integral b) =>
             Int
          -> CRPTables (Sum b)
          -> Double
          -> Prob m (CRPTables (Sum b))
crpSingle i zs a = do
    zn1 <- categorical probs
    pure $ crpInsert zn1 zs
  where
    -- Weights of the existing tables followed by the weight of the extra
    -- trailing entry.
    probs = pms <> [pm1]
    acc m = fromIntegral m / (fromIntegral i - 1 + a)
    pms = F.toList $ fmap (acc . getSum) zs
    pm1 = a / (fromIntegral i - 1 + a)

-- Tables at the Chinese Restaurant: a map from table index to a per-table
-- value (a customer count wrapped in 'Sum' everywhere in this module).
--
-- 'Semigroup' and 'Monoid' are lifted from the underlying 'IM.IntMap', so
-- '<>' is the map's own union and 'mempty' is the empty restaurant; '<>' does
-- not add up the counts of tables that appear on both sides.
newtype CRPTables c = CRP {
    getCRPTables :: IM.IntMap c
  } deriving (Eq, Show, Functor, Foldable, Semigroup, Monoid)

-- Initial state of the CRP : one customer sitting at table #0
-- (built by seating a single customer in the empty restaurant).
crpInitial :: CRPTables (Sum Integer)
crpInitial = crpInsert 0 mempty

-- Seat one customer at table 'k'.
--
-- If table 'k' already exists its count is combined with @Sum 1@ (i.e.
-- incremented by one); otherwise the table is created with a count of one.
crpInsert :: Num a => IM.Key -> CRPTables (Sum a) -> CRPTables (Sum a)
crpInsert k (CRP ts) = CRP (IM.insertWith (<>) k (Sum 1) ts)