{-# LANGUAGE DeriveFoldable #-}
{-# LANGUAGE DeriveFunctor #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE CPP #-}
{-# OPTIONS_GHC -Wall #-}
module System.Random.MWC.Probability (
module MWC
, Prob(..)
, samples
, uniform
, uniformR
, normal
, standardNormal
, isoNormal
, logNormal
, exponential
, inverseGaussian
, laplace
, gamma
, inverseGamma
, normalGamma
, weibull
, chiSquare
, beta
, gstudent
, student
, pareto
, dirichlet
, symmetricDirichlet
, discreteUniform
, zipf
, categorical
, discrete
, bernoulli
, binomial
, negativeBinomial
, multinomial
, poisson
, crp
) where
import Control.Applicative
import Control.Monad
import Control.Monad.Primitive
import Control.Monad.IO.Class
import Control.Monad.Trans.Class
import Data.Monoid (Sum(..))
#if __GLASGOW_HASKELL__ < 710
import Data.Foldable (Foldable)
#endif
import qualified Data.Foldable as F
import Data.List (findIndex)
import qualified Data.IntMap as IM
import System.Random.MWC as MWC hiding (uniform, uniformR)
import qualified System.Random.MWC as QMWC
import qualified System.Random.MWC.Distributions as MWC.Dist
import System.Random.MWC.CondensedTable
-- | A probability distribution over values of type @a@, represented by its
--   sampler: given an MWC generator whose state lives in the primitive monad
--   @m@, it produces one draw in @m@.
--
--   Use 'sample' for a single draw and 'samples' for several.
newtype Prob m a = Prob
  { sample :: Gen (PrimState m) -> m a
    -- ^ Draw one value from the distribution with the given generator.
  }
-- | Draw @n@ values from a distribution by running its sampler @n@ times in
--   sequence, all with the same generator. A non-positive @n@ gives @[]@.
samples :: PrimMonad m => Int -> Prob m a -> Gen (PrimState m) -> m [a]
samples n model gen = replicateM n (sample model gen)
{-# INLINABLE samples #-}
-- | Mapping a function over a distribution transforms every value it
--   produces; the way the generator is consumed is unchanged.
instance Functor m => Functor (Prob m) where
  fmap h model = Prob $ \gen -> h <$> sample model gen
-- | 'pure' yields a constant without consulting the generator; '<*>' samples
--   the function first and then its argument, both with the same generator.
instance Monad m => Applicative (Prob m) where
  pure x = Prob (\_ -> pure x)
  mf <*> mx = do
    f <- mf
    x <- mx
    return (f x)
-- | Sequencing: sample the first distribution, hand its value to the
--   continuation, then sample the distribution it returns with the same
--   generator.
instance Monad m => Monad (Prob m) where
  return = pure
  model >>= k = Prob $ \gen ->
    sample model gen >>= \z -> sample (k z) gen
  {-# INLINABLE (>>=) #-}
-- | Pointwise arithmetic on distributions. Each binary operator samples its
--   left operand, then its right operand (same generator), and combines the
--   two draws. 'abs' and 'signum' post-process a single draw, and integer
--   literals denote distributions that always return that number.
instance (Monad m, Num a) => Num (Prob m a) where
  x + y         = (+) <$> x <*> y
  x - y         = (-) <$> x <*> y
  x * y         = (*) <$> x <*> y
  abs x         = abs <$> x
  signum x      = signum <$> x
  fromInteger k = pure (fromInteger k)
-- | Lifting an action gives a distribution that ignores the generator and
--   just runs the action.
instance MonadTrans Prob where
  lift action = Prob (\_ -> action)
-- | Run an 'IO' action inside a distribution without touching the generator.
instance MonadIO m => MonadIO (Prob m) where
  liftIO = lift . liftIO
-- | A distribution over a primitive monad is itself a primitive monad with
--   the same state token, so primitive operations can run inside sampling
--   code; they are lifted and never see the generator.
instance PrimMonad m => PrimMonad (Prob m) where
  type PrimState (Prob m) = PrimState m
  primitive op = lift (primitive op)
  {-# INLINE primitive #-}
-- | The uniform distribution for any 'Variate' type, with whatever range
--   'QMWC.uniform' uses for that type.
uniform :: (PrimMonad m, Variate a) => Prob m a
uniform = Prob (\gen -> QMWC.uniform gen)
{-# INLINABLE uniform #-}
-- | The uniform distribution over the given @(low, high)@ bounds, with the
--   endpoint conventions of 'QMWC.uniformR'.
uniformR :: (PrimMonad m, Variate a) => (a, a) -> Prob m a
uniformR bounds = Prob (\gen -> QMWC.uniformR bounds gen)
{-# INLINABLE uniformR #-}
-- | Pick one element of a container, every position being equally likely.
--
--   Sampling from an empty container is an 'error', since there is nothing
--   to choose from.
discreteUniform :: (PrimMonad m, Foldable f) => f a -> Prob m a
discreteUniform cs = case F.toList cs of
  -- Report the empty case explicitly instead of letting 'uniformR (0, -1)'
  -- feed an out-of-range index to '!!'.
  [] -> error "mwc-probability: discreteUniform: empty container"
  xs -> do
    -- Int ranges are inclusive, so j is always a valid index into xs.
    j <- uniformR (0, length xs - 1)
    return (xs !! j)
{-# INLINABLE discreteUniform #-}
-- | The standard normal distribution, sampled with 'MWC.Dist.standard'.
standardNormal :: PrimMonad m => Prob m Double
standardNormal = Prob (\gen -> MWC.Dist.standard gen)
{-# INLINABLE standardNormal #-}
-- | The normal distribution with mean @m@ and standard deviation @sd@,
--   sampled with 'MWC.Dist.normal'.
normal :: PrimMonad m => Double -> Double -> Prob m Double
normal m sd = Prob (\gen -> MWC.Dist.normal m sd gen)
{-# INLINABLE normal #-}
-- | The log-normal distribution: @exp x@ for @x@ drawn from @'normal' m sd@,
--   so @m@ and @sd@ parameterise the underlying normal.
logNormal :: PrimMonad m => Double -> Double -> Prob m Double
logNormal m sd = fmap exp (normal m sd)
{-# INLINABLE logNormal #-}
-- | The exponential distribution; @r@ is passed unchanged to
--   'MWC.Dist.exponential' as its rate parameter.
exponential :: PrimMonad m => Double -> Prob m Double
exponential r = Prob (\gen -> MWC.Dist.exponential r gen)
{-# INLINABLE exponential #-}
-- | The Laplace distribution with location @mu@ and standard deviation
--   @sigma@, sampled by inverting its CDF at a centred uniform draw.
laplace :: (Floating a, Variate a, PrimMonad m) => a -> a -> Prob m a
laplace mu sigma = do
  u <- uniformR (-0.5, 0.5)
  -- A Laplace scale of sigma / sqrt 2 gives standard deviation sigma.
  let scale = sigma / sqrt 2
  pure (mu - scale * signum u * log (1 - 2 * abs u))
{-# INLINABLE laplace #-}
-- | A Weibull variate, sampled by inverting the CDF at a uniform draw @x@.
--
--   The result is @(-log (1 - x) / a) ** (1 / b)@. So @b@ is the shape
--   parameter and the scale is @a ** (-1 / b)@; that is, @a@ acts as a rate
--   on the @b@-th power of the result.
weibull :: (Floating a, Variate a, PrimMonad m) => a -> a -> Prob m a
weibull a b = do
  x <- uniform
  -- The exponent must be parenthesised: '**' (infixr 8) binds tighter than
  -- '/' (infixl 7), so @e ** 1/b@ would mean @(e ** 1) / b@, an exponential
  -- draw divided by b, rather than the b-th root.
  return $ (- 1/a * log (1 - x)) ** (1/b)
{-# INLINABLE weibull #-}
-- | The gamma distribution with shape @a@ and scale @b@, in the argument
--   order of 'MWC.Dist.gamma'.
gamma :: PrimMonad m => Double -> Double -> Prob m Double
gamma a b = Prob (\gen -> MWC.Dist.gamma a b gen)
{-# INLINABLE gamma #-}
-- | The inverse-gamma distribution: the reciprocal of a @'gamma' a b@ draw.
inverseGamma :: PrimMonad m => Double -> Double -> Prob m Double
inverseGamma a b = fmap recip (gamma a b)
{-# INLINABLE inverseGamma #-}
-- | The @x@ component of a normal-gamma draw.
--
--   A precision @tau@ is drawn from @'gamma' a b@, then @x@ is drawn from a
--   normal with mean @mu@ and precision @lambda * tau@ (standard deviation
--   @sqrt (1 / (lambda * tau))@). Only @x@ is returned.
normalGamma :: PrimMonad m => Double -> Double -> Double -> Double -> Prob m Double
normalGamma mu lambda a b =
  gamma a b >>= \tau -> normal mu (sqrt (recip (lambda * tau)))
{-# INLINABLE normalGamma #-}
-- | The chi-squared distribution with @k@ degrees of freedom, sampled with
--   'MWC.Dist.chiSquare'.
chiSquare :: PrimMonad m => Int -> Prob m Double
chiSquare k = Prob (\gen -> MWC.Dist.chiSquare k gen)
{-# INLINABLE chiSquare #-}
-- | The beta distribution with shape parameters @a@ and @b@.
--
--   It is built from two unit-scale gamma draws @u@ (shape @a@, drawn first)
--   and @w@ (shape @b@) as @u / (u + w)@.
beta :: PrimMonad m => Double -> Double -> Prob m Double
beta a b = ratio <$> gamma a 1 <*> gamma b 1
  where
    ratio u w = u / (u + w)
{-# INLINABLE beta #-}
-- | The Pareto distribution with minimum value @xmin@: @xmin * exp y@ where
--   @y@ is drawn from @'exponential' a@, so @a@ controls the tail.
pareto :: PrimMonad m => Double -> Double -> Prob m Double
pareto a xmin = (\y -> xmin * exp y) <$> exponential a
{-# INLINABLE pareto #-}
-- | The Dirichlet distribution with concentration parameters @as@.
--
--   One unit-scale gamma draw is taken per parameter, and the draws are then
--   divided by their sum. The shape of the container is preserved.
dirichlet
  :: (Traversable f, PrimMonad m) => f Double -> Prob m (f Double)
dirichlet as = do
  zs <- traverse (\alpha -> gamma alpha 1) as
  let total = sum zs
  return (fmap (/ total) zs)
{-# INLINABLE dirichlet #-}
-- | A Dirichlet distribution over @n@ components that all share the
--   concentration parameter @a@.
symmetricDirichlet :: PrimMonad m => Int -> Double -> Prob m [Double]
symmetricDirichlet n = dirichlet . replicate n
{-# INLINABLE symmetricDirichlet #-}
-- | The Bernoulli distribution: 'True' with probability @p@.
--
--   mwc-random documents that 'uniform' at 'Double' draws from (0, 1], so
--   the test is @u <= p@. With that test @bernoulli 0@ is always 'False' and
--   @bernoulli 1@ is always 'True'. A strict @<@ would make @bernoulli 1@
--   return 'False' whenever the draw is exactly 1.
bernoulli :: PrimMonad m => Double -> Prob m Bool
bernoulli p = (<= p) <$> uniform
{-# INLINABLE bernoulli #-}
-- | The binomial distribution: the number of 'True' outcomes among @n@
--   draws from @'bernoulli' p@.
binomial :: PrimMonad m => Int -> Double -> Prob m Int
binomial n p = do
  trials <- replicateM n (bernoulli p)
  pure (length (filter id trials))
{-# INLINABLE binomial #-}
-- | The negative binomial distribution, sampled as a gamma-Poisson mixture.
--
--   A rate is drawn from @'gamma' n ((1 - p) / p)@ and then passed to
--   'poisson'.
negativeBinomial :: (PrimMonad m, Integral a) => a -> Double -> Prob m Int
negativeBinomial n p =
  gamma (fromIntegral n) ((1 - p) / p) >>= poisson
{-# INLINABLE negativeBinomial #-}
-- | Draw @n@ 0-based category indices, each chosen with probability
--   proportional to its weight in @ps@. The weights need not sum to one.
--
--   Each draw takes @z@ from @'uniformR' (0, total)@ and returns the first
--   index whose running total reaches @z@. Sampling is an 'error' when the
--   weights do not have a positive total (e.g. an empty, all-zero or NaN
--   vector). With @n <= 0@ nothing is sampled and @[]@ is returned.
multinomial :: (Foldable f, PrimMonad m) => Int -> f Double -> Prob m [Int]
multinomial n ps = do
    let (cumulative, total) = runningTotals (F.toList ps)
    replicateM n $
      -- 'not (total > 0)' also catches NaN totals.
      if not (total > 0)
        then invalid
        else do
          z <- uniformR (0, total)
          -- uniformR for Double can return the upper bound itself. With a
          -- strict '>', a draw of exactly 'total' matched no running total
          -- and raised a spurious error. Because z > 0, '>=' never selects a
          -- zero-weight category.
          case findIndex (>= z) cumulative of
            Just g  -> return g
            Nothing -> invalid
  where
    invalid = error "mwc-probability: invalid probability vector"

    -- Running sums, plus the total taken as the last running sum (0 for an
    -- empty list). This keeps 'total' bit-identical to the final cumulative
    -- entry, so a draw equal to 'total' is always matched.
    runningTotals :: Num a => [a] -> ([a], a)
    runningTotals xs =
      let adds = scanl1 (+) xs
      in (adds, if null adds then 0 else last adds)
{-# INLINABLE multinomial #-}
-- | A generalised Student's t variate with location @m@, sampled as a scale
--   mixture of normals.
--
--   A variance is drawn from @'inverseGamma' (k / 2) (s * 2 / k)@, and the
--   result is drawn from a normal with mean @m@ and that variance.
--
--   NOTE(review): 'gamma' is used with shape and scale, so the mixing
--   variance is inverse-gamma with shape @k / 2@ and scale @k / (2 * s)@.
--   That makes the t scale @1 / sqrt s@, i.e. @s@ behaves like a precision
--   rather than a scale. Confirm this is intended before relying on
--   @s /= 1@.
gstudent :: PrimMonad m => Double -> Double -> Double -> Prob m Double
gstudent m s k = do
  variance <- inverseGamma (k / 2) (s * 2 / k)
  normal m (sqrt variance)
{-# INLINABLE gstudent #-}
-- | Student's t distribution with @k@ degrees of freedom: 'gstudent' with
--   location 0 and @s = 1@.
student :: PrimMonad m => Double -> Prob m Double
student k = gstudent 0 1 k
{-# INLINABLE student #-}
-- | An isotropic normal over a container.
--
--   Each component is drawn, in traversal order, from a normal centred at
--   the matching entry of @ms@; all components share standard deviation @sd@.
isoNormal
  :: (Traversable f, PrimMonad m) => f Double -> Double -> Prob m (f Double)
isoNormal ms sd = traverse (\mu -> normal mu sd) ms
{-# INLINABLE isoNormal #-}
-- | The inverse Gaussian distribution with shape @lambda@ and mean @mu@.
--
--   It uses the transformation-with-multiple-roots method of Michael,
--   Schucany and Haas. A squared standard normal gives the smaller root @x@;
--   a uniform draw then keeps @x@ with probability @mu / (mu + x)@ and
--   otherwise returns the other root @mu^2 / x@.
inverseGaussian :: PrimMonad m => Double -> Double -> Prob m Double
inverseGaussian lambda mu = do
  nu <- standardNormal
  let y      = nu ** 2
      root   = sqrt (4 * mu * lambda * y + mu ** 2 * y ** 2)
      x      = mu * (1 + 1 / (2 * lambda) * (mu * y - root))
      thresh = mu / (mu + x)
  z <- uniform
  return $ if z <= thresh then x else mu ** 2 / x
{-# INLINABLE inverseGaussian #-}
-- | The Poisson distribution with mean @l@, sampled from the condensed
--   lookup table produced by 'tablePoisson'. The table depends only on @l@,
--   not on the generator.
poisson :: PrimMonad m => Double -> Prob m Int
poisson l = Prob (genFromTable table)
  where
    table = tablePoisson l
{-# INLINABLE poisson #-}
-- | Draw a single 0-based category index, with probability proportional to
--   its weight in @ps@. This is a one-draw 'multinomial'.
categorical :: (Foldable f, PrimMonad m) => f Double -> Prob m Int
categorical ps = multinomial 1 ps >>= \draws -> case draws of
  [x] -> return x
  _   -> error "mwc-probability: invalid probability vector"
{-# INLINABLE categorical #-}
-- | Choose one of the values in @d@, each with probability proportional to
--   the weight paired with it.
discrete :: (Foldable f, PrimMonad m) => f (Double, a) -> Prob m a
discrete d = (values !!) <$> categorical weights
  where
    (weights, values) = unzip (F.toList d)
{-# INLINABLE discrete #-}
-- | The Zipf (zeta) distribution with exponent @a@, sampled with Devroye's
--   rejection algorithm. The sampler divides by @a - 1@, so it is only
--   meaningful for @a > 1@.
zipf :: (PrimMonad m, Integral b) => Double -> Prob m b
zipf a = loop
  where
    b = 2 ** (a - 1) :: Double
    -- Propose a candidate from two uniforms; retry until it is accepted.
    loop = do
      u <- uniform
      v <- uniform
      let xInt = floor (u ** (- 1 / (a - 1)))
          x    = fromIntegral xInt
          t    = (1 + 1 / x) ** (a - 1)
      if v * x * (t - 1) / (b - 1) <= t / b
        then return xInt
        else loop
{-# INLINABLE zipf #-}
-- | Sample a Chinese Restaurant Process seating of @n@ customers with
--   concentration @a@.
--
--   The result is the per-table customer counts stored in the final
--   'CRPTables', in their 'F.toList' order. A non-positive @n@ seats nobody
--   and returns @[]@.
crp
  :: PrimMonad m
  => Double
  -> Int
  -> Prob m [Integer]
crp a n
  -- The seating loop starts at 1 and counts up to n, so without this guard
  -- n <= 0 would never terminate.
  | n <= 0    = pure []
  | otherwise = do
      ts <- go crpInitial 1
      pure $ F.toList (fmap getSum ts)
  where
    -- Seat one more customer per step; crpSingle receives the step index i.
    go acc i
      | i >= n    = pure acc
      | otherwise = do
          acc' <- crpSingle i acc a
          go acc' (i + 1)
{-# INLINABLE crp #-}
-- | Seat one more customer in a Chinese Restaurant Process at step @i@.
--
--   Each existing table with @c@ customers gets weight @c / (i - 1 + a)@, and
--   a new table gets weight @a / (i - 1 + a)@, listed last. 'categorical'
--   picks an index over these weights, and the index is passed to
--   'crpInsert' along with the current tables.
crpSingle :: (PrimMonad m, Integral b) =>
             Int
          -> CRPTables (Sum b)
          -> Double
          -> Prob m (CRPTables (Sum b))
crpSingle i zs a = do
    zn1 <- categorical (existing ++ [fresh])
    pure (crpInsert zn1 zs)
  where
    denom    = fromIntegral i - 1 + a
    existing = [fromIntegral (getSum c) / denom | c <- F.toList zs]
    fresh    = a / denom
newtype CRPTables c = CRP {
forall c. CRPTables c -> IntMap c
getCRPTables :: IM.IntMap c
} deriving (CRPTables c -> CRPTables c -> Bool
(CRPTables c -> CRPTables c -> Bool)
-> (CRPTables c -> CRPTables c -> Bool) -> Eq (CRPTables c)
forall c. Eq c => CRPTables c -> CRPTables c -> Bool
forall a. (a -> a -> Bool) -> (a -> a -> Bool) -> Eq a
$c== :: forall c. Eq c => CRPTables c -> CRPTables c -> Bool
== :: CRPTables c -> CRPTables c -> Bool
$c/= :: forall c. Eq c => CRPTables c -> CRPTables c -> Bool
/= :: CRPTables c -> CRPTables c -> Bool
Eq, Int -> CRPTables c -> ShowS
[CRPTables c] -> ShowS
CRPTables c -> [Char]
(Int -> CRPTables c -> ShowS)
-> (CRPTables c -> [Char])
-> ([CRPTables c] -> ShowS)
-> Show (CRPTables c)
forall c. Show c => Int -> CRPTables c -> ShowS
forall c. Show c => [CRPTables c] -> ShowS
forall c. Show c => CRPTables c -> [Char]
forall a.
(Int -> a -> ShowS) -> (a -> [Char]) -> ([a] -> ShowS) -> Show a
$cshowsPrec :: forall c. Show c => Int -> CRPTables c -> ShowS
showsPrec :: Int -> CRPTables c -> ShowS
$cshow :: forall c. Show c => CRPTables c -> [Char]
show :: CRPTables c -> [Char]
$cshowList :: forall c. Show c => [CRPTables c] -> ShowS
showList :: [CRPTables c] -> ShowS
Show, (forall a b. (a -> b) -> CRPTables a -> CRPTables b)
-> (forall a b. a -> CRPTables b -> CRPTables a)
-> Functor CRPTables
forall a b. a -> CRPTables b -> CRPTables a
forall a b. (a -> b) -> CRPTables a -> CRPTables b
forall (f :: * -> *).
(forall a b. (a -> b) -> f a -> f b)
-> (forall a b. a -> f b -> f a) -> Functor f
$cfmap :: forall a b. (a -> b) -> CRPTables a -> CRPTables b
fmap :: forall a b. (a -> b) -> CRPTables a -> CRPTables b
$c<$ :: forall a b. a -> CRPTables b -> CRPTables a
<$ :: forall a b. a -> CRPTables b -> CRPTables a
Functor, (forall m. Monoid m => CRPTables m -> m)
-> (forall m a. Monoid m => (a -> m) -> CRPTables a -> m)
-> (forall m a. Monoid m => (a -> m) -> CRPTables a -> m)
-> (forall a b. (a -> b -> b) -> b -> CRPTables a -> b)
-> (forall a b. (a -> b -> b) -> b -> CRPTables a -> b)
-> (forall b a. (b -> a -> b) -> b -> CRPTables a -> b)
-> (forall b a. (b -> a -> b) -> b -> CRPTables a -> b)
-> (forall a. (a -> a -> a) -> CRPTables a -> a)
-> (forall a. (a -> a -> a) -> CRPTables a -> a)
-> (forall a. CRPTables a -> [a])
-> (forall a. CRPTables a -> Bool)
-> (forall a. CRPTables a -> Int)
-> (forall a. Eq a => a -> CRPTables a -> Bool)
-> (forall a. Ord a => CRPTables a -> a)
-> (forall a. Ord a => CRPTables a -> a)
-> (forall a. Num a => CRPTables a -> a)
-> (forall a. Num a => CRPTables a -> a)
-> Foldable CRPTables
forall a. Eq a => a -> CRPTables a -> Bool
forall a. Num a => CRPTables a -> a
forall a. Ord a => CRPTables a -> a
forall m. Monoid m => CRPTables m -> m
forall a. CRPTables a -> Bool
forall a. CRPTables a -> Int
forall a. CRPTables a -> [a]
forall a. (a -> a -> a) -> CRPTables a -> a
forall m a. Monoid m => (a -> m) -> CRPTables a -> m
forall b a. (b -> a -> b) -> b -> CRPTables a -> b
forall a b. (a -> b -> b) -> b -> CRPTables a -> b
forall (t :: * -> *).
(forall m. Monoid m => t m -> m)
-> (forall m a. Monoid m => (a -> m) -> t a -> m)
-> (forall m a. Monoid m => (a -> m) -> t a -> m)
-> (forall a b. (a -> b -> b) -> b -> t a -> b)
-> (forall a b. (a -> b -> b) -> b -> t a -> b)
-> (forall b a. (b -> a -> b) -> b -> t a -> b)
-> (forall b a. (b -> a -> b) -> b -> t a -> b)
-> (forall a. (a -> a -> a) -> t a -> a)
-> (forall a. (a -> a -> a) -> t a -> a)
-> (forall a. t a -> [a])
-> (forall a. t a -> Bool)
-> (forall a. t a -> Int)
-> (forall a. Eq a => a -> t a -> Bool)
-> (forall a. Ord a => t a -> a)
-> (forall a. Ord a => t a -> a)
-> (forall a. Num a => t a -> a)
-> (forall a. Num a => t a -> a)
-> Foldable t
$cfold :: forall m. Monoid m => CRPTables m -> m
fold :: forall m. Monoid m => CRPTables m -> m
$cfoldMap :: forall m a. Monoid m => (a -> m) -> CRPTables a -> m
foldMap :: forall m a. Monoid m => (a -> m) -> CRPTables a -> m
$cfoldMap' :: forall m a. Monoid m => (a -> m) -> CRPTables a -> m
foldMap' :: forall m a. Monoid m => (a -> m) -> CRPTables a -> m
$cfoldr :: forall a b. (a -> b -> b) -> b -> CRPTables a -> b
foldr :: forall a b. (a -> b -> b) -> b -> CRPTables a -> b
$cfoldr' :: forall a b. (a -> b -> b) -> b -> CRPTables a -> b
foldr' :: forall a b. (a -> b -> b) -> b -> CRPTables a -> b
$cfoldl :: forall b a. (b -> a -> b) -> b -> CRPTables a -> b
foldl :: forall b a. (b -> a -> b) -> b -> CRPTables a -> b
$cfoldl' :: forall b a. (b -> a -> b) -> b -> CRPTables a -> b
foldl' :: forall b a. (b -> a -> b) -> b -> CRPTables a -> b
$cfoldr1 :: forall a. (a -> a -> a) -> CRPTables a -> a
foldr1 :: forall a. (a -> a -> a) -> CRPTables a -> a
$cfoldl1 :: forall a. (a -> a -> a) -> CRPTables a -> a
foldl1 :: forall a. (a -> a -> a) -> CRPTables a -> a
$ctoList :: forall a. CRPTables a -> [a]
toList :: forall a. CRPTables a -> [a]
$cnull :: forall a. CRPTables a -> Bool
null :: forall a. CRPTables a -> Bool
$clength :: forall a. CRPTables a -> Int
length :: forall a. CRPTables a -> Int
$celem :: forall a. Eq a => a -> CRPTables a -> Bool
elem :: forall a. Eq a => a -> CRPTables a -> Bool
$cmaximum :: forall a. Ord a => CRPTables a -> a
maximum :: forall a. Ord a => CRPTables a -> a
$cminimum :: forall a. Ord a => CRPTables a -> a
minimum :: forall a. Ord a => CRPTables a -> a
$csum :: forall a. Num a => CRPTables a -> a
sum :: forall a. Num a => CRPTables a -> a
$cproduct :: forall a. Num a => CRPTables a -> a
product :: forall a. Num a => CRPTables a -> a
Foldable, NonEmpty (CRPTables c) -> CRPTables c
CRPTables c -> CRPTables c -> CRPTables c
(CRPTables c -> CRPTables c -> CRPTables c)
-> (NonEmpty (CRPTables c) -> CRPTables c)
-> (forall b. Integral b => b -> CRPTables c -> CRPTables c)
-> Semigroup (CRPTables c)
forall b. Integral b => b -> CRPTables c -> CRPTables c
forall c. NonEmpty (CRPTables c) -> CRPTables c
forall c. CRPTables c -> CRPTables c -> CRPTables c
forall a.
(a -> a -> a)
-> (NonEmpty a -> a)
-> (forall b. Integral b => b -> a -> a)
-> Semigroup a
forall c b. Integral b => b -> CRPTables c -> CRPTables c
$c<> :: forall c. CRPTables c -> CRPTables c -> CRPTables c
<> :: CRPTables c -> CRPTables c -> CRPTables c
$csconcat :: forall c. NonEmpty (CRPTables c) -> CRPTables c
sconcat :: NonEmpty (CRPTables c) -> CRPTables c
$cstimes :: forall c b. Integral b => b -> CRPTables c -> CRPTables c
stimes :: forall b. Integral b => b -> CRPTables c -> CRPTables c
Semigroup, Semigroup (CRPTables c)
CRPTables c
Semigroup (CRPTables c) =>
CRPTables c
-> (CRPTables c -> CRPTables c -> CRPTables c)
-> ([CRPTables c] -> CRPTables c)
-> Monoid (CRPTables c)
[CRPTables c] -> CRPTables c
CRPTables c -> CRPTables c -> CRPTables c
forall c. Semigroup (CRPTables c)
forall c. CRPTables c
forall a.
Semigroup a =>
a -> (a -> a -> a) -> ([a] -> a) -> Monoid a
forall c. [CRPTables c] -> CRPTables c
forall c. CRPTables c -> CRPTables c -> CRPTables c
$cmempty :: forall c. CRPTables c
mempty :: CRPTables c
$cmappend :: forall c. CRPTables c -> CRPTables c -> CRPTables c
mappend :: CRPTables c -> CRPTables c -> CRPTables c
$cmconcat :: forall c. [CRPTables c] -> CRPTables c
mconcat :: [CRPTables c] -> CRPTables c
Monoid)
-- | Starting state of the Chinese restaurant process: exactly one customer,
-- seated at table @0@.
crpInitial :: CRPTables (Sum Integer)
crpInitial = CRP (IM.singleton 0 (Sum 1))
-- | Seat one more customer at table @k@.
--
-- If table @k@ already has a count, it is incremented by one (computed as
-- @Sum 1 <> old@). Otherwise the table is created with a count of one.
--
-- The module uses the lazy "Data.IntMap". With that module, @IM.insertWith@
-- would store @Sum 1 <> old@ as an unevaluated thunk. That lets a chain of
-- pending additions build up when the same table is hit repeatedly. Here
-- 'IM.alter' must inspect the 'Maybe' returned by @bump@, and @Just $!@
-- forces the new count before it is stored, so counts stay evaluated.
crpInsert :: Num a => IM.Key -> CRPTables (Sum a) -> CRPTables (Sum a)
crpInsert k (CRP ts) = CRP (IM.alter bump k ts)
  where
    -- Same update as @insertWith (<>) k (Sum 1)@, but strict in the value.
    bump old = Just $! maybe (Sum 1) (Sum 1 <>) old