{-# LANGUAGE FlexibleContexts #-}
module System.Random.MWC.CondensedTable (
CondensedTable
, CondensedTableV
, CondensedTableU
, genFromTable
, tableFromProbabilities
, tableFromWeights
, tableFromIntWeights
, tablePoisson
, tableBinomial
) where
import Control.Arrow (second,(***))
import Data.Word
import Data.Int
import Data.Bits
import qualified Data.Vector.Generic as G
import Data.Vector.Generic ((++))
import qualified Data.Vector.Generic.Mutable as M
import qualified Data.Vector.Unboxed as U
import qualified Data.Vector as V
import Data.Vector.Generic (Vector)
import Numeric.SpecFunctions (logFactorial)
import System.Random.Stateful
import Prelude hiding ((++))
-- | A lookup table for sampling from a discrete distribution over values
-- of type @a@, with the outcomes stored in vectors of type @v@.
--
-- The table consists of four outcome vectors. Each of the first three
-- is paired with a 'Word64' limit; @lookupTable@ compares its argument
-- against the limits in order, picks the first vector whose limit is
-- strictly greater, and falls through to the last vector otherwise.
data CondensedTable v a =
  CondensedTable
    {-# UNPACK #-} !Word64 !(v a) -- limit and outcomes of the first table
    {-# UNPACK #-} !Word64 !(v a) -- limit and outcomes of the second table
    {-# UNPACK #-} !Word64 !(v a) -- limit and outcomes of the third table
    !(v a)                        -- outcomes of the last table
-- | A 'CondensedTable' backed by unboxed vectors ("Data.Vector.Unboxed").
type CondensedTableU = CondensedTable U.Vector
-- | A 'CondensedTable' backed by boxed vectors ("Data.Vector").
type CondensedTableV = CondensedTable V.Vector
-- | Draw a single outcome from a condensed table.
--
-- One 'Word32' is requested from the generator via 'uniformM', widened
-- to 'Word64' and passed to @lookupTable@ to select the outcome.
genFromTable :: (StatefulGen g m, Vector v a) => CondensedTable v a -> g -> m a
{-# INLINE genFromTable #-}
genFromTable table gen = do
  w <- uniformM gen
  -- The annotation fixes the type drawn from the generator to Word32.
  return $ lookupTable table $ fromIntegral (w :: Word32)
-- | Select the outcome that a number maps to in a condensed table.
--
-- The argument is compared against the three stored limits in order.
-- Below the first limit the first table is indexed with @i `shiftR` 24@;
-- below the second, the second table with @(i - na) `shiftR` 16@; below
-- the third, the third table with @(i - nb) `shiftR` 8@; otherwise the
-- last table is indexed with @i - nc@.
--
-- Indexing is unchecked, so the argument must lie in the range the table
-- was built for. NOTE(review): the only caller visible here,
-- 'genFromTable', passes a widened 'Word32'.
lookupTable :: Vector v a => CondensedTable v a -> Word64 -> a
{-# INLINE lookupTable #-}
lookupTable (CondensedTable na aa nb bb nc cc dd) i
  | i < na    = aa `at` ( i       `shiftR` 24)
  | i < nb    = bb `at` ((i - na) `shiftR` 16)
  | i < nc    = cc `at` ((i - nb) `shiftR` 8 )
  | otherwise = dd `at` ( i - nc)
  where
    -- Index without bounds checking, converting the offset to Int.
    at arr j = G.unsafeIndex arr (fromIntegral j)
-- | Build a condensed table from outcomes paired with probabilities.
--
-- Entries whose probability is not strictly positive are dropped. Each
-- remaining probability is multiplied by @2^32@, converted to a 'Word32'
-- weight and passed on to 'tableFromIntWeights'.
--
-- Calls 'error' (through @pkgError@) if no entry with a positive
-- probability remains. The probabilities are not checked to sum to 1
-- here.
tableFromProbabilities
    :: (Vector v (a,Word32), Vector v (a,Double), Vector v a, Vector v Word32)
       => v (a, Double) -> CondensedTable v a
{-# INLINE tableFromProbabilities #-}
tableFromProbabilities v
  | G.null tbl = pkgError "tableFromProbabilities" "empty vector of outcomes"
  | otherwise  = tableFromIntWeights $ G.map (second $ toWeight . (* mlt)) tbl
  where
    -- 2^32, the scale that turns a probability into an integer weight.
    mlt :: Double
    mlt = 4.294967296e9
    -- Only strictly positive probabilities are kept.
    tbl = G.filter ((> 0) . snd) v
    -- Round the scaled probability, clamping at 2^32 - 1 so that a
    -- probability of (almost) 1 does not round to 2^32 and overflow the
    -- Word32 weight.
    toWeight w | w > mlt - 1 = 2^(32::Int) - 1
               | otherwise   = round w
-- | Build a condensed table from outcomes paired with relative weights.
--
-- Entries whose weight is not strictly positive are discarded, the
-- remaining weights are divided by their sum, and the result is handed
-- to 'tableFromProbabilities'.
--
-- Calls 'error' (through @pkgError@) if no positive weight is present.
tableFromWeights
    :: (Vector v (a,Word32), Vector v (a,Double), Vector v a, Vector v Word32)
       => v (a, Double) -> CondensedTable v a
{-# INLINE tableFromWeights #-}
tableFromWeights = tableFromProbabilities . normalize . G.filter ((> 0) . snd)
  where
    -- Scale the second components so that they sum to 1.
    normalize v
      | G.null v  = pkgError "tableFromWeights" "no positive weights"
      | otherwise = G.map (second (/ s)) v
      where
        -- Total weight, folded directly over the pairs so that no
        -- separate vector of weights has to be materialised.
        s = G.foldl' (flip $ (+) . snd) 0 v
-- | Build a condensed table from outcomes paired with 'Word32' weights.
--
-- Zero weights are dropped first. The remaining weights are passed
-- through @correctWeights@, which adjusts them until they sum to exactly
-- @2^32@; the weights are therefore expected to sum to roughly @2^32@
-- already. Each corrected weight is then split into four base-256
-- digits, and an outcome is replicated in table @k@ as many times as its
-- @k@-th digit says.
--
-- Calls 'error' (through @pkgError@) if no non-zero weight is present.
tableFromIntWeights :: (Vector v (a,Word32), Vector v a, Vector v Word32)
                    => v (a, Word32)
                    -> CondensedTable v a
{-# INLINE tableFromIntWeights #-}
tableFromIntWeights v
  | n == 0    = pkgError "tableFromIntWeights" "empty table"
    -- A single outcome is handled directly and bypasses weight
    -- correction: every limit is 2^32 - 1 and the first table holds the
    -- outcome 256 times, one entry for each possible top byte.
  | n == 1    = let m :: Word64
                    m = 2^(32::Int) - 1
                in CondensedTable
                   m (G.replicate 256 $ fst $ G.head tbl)
                   m  G.empty
                   m  G.empty
                      G.empty
  | otherwise = CondensedTable
                na aa
                nb bb
                nc cc
                   dd
  where
    -- Outcomes with non-zero weight.
    tbl   = G.filter ((/=0) . snd) v
    n     = G.length tbl
    -- Same outcomes, paired with the corrected weights.
    table = uncurry G.zip $ id *** correctWeights $ G.unzip tbl
    -- Table for digit @d@: each outcome repeated @digit d w@ times.
    mkTable d =
      G.concatMap (\(x,w) -> G.replicate (fromIntegral $ digit d w) x) table
    len = fromIntegral . G.length
    -- The four tables, most significant digit first.
    aa = mkTable 0
    bb = mkTable 1
    cc = mkTable 2
    dd = mkTable 3
    -- Cumulative limits: an entry of table k covers 2^(24 - 8*k) numbers.
    na =       len aa `shiftL` 24
    nb = na + (len bb `shiftL` 16)
    nc = nb + (len cc `shiftL` 8)
-- | Extract the N-th base-256 digit of a 'Word32', where digit 0 is the
-- most significant byte and digit 3 the least significant one.
--
-- Any other digit index calls 'error' through @pkgError@.
digit :: Int -> Word32 -> Word32
digit 0 x =  x `shiftR` 24
digit 1 x = (x `shiftR` 16) .&. 0xff
digit 2 x = (x `shiftR` 8 ) .&. 0xff
digit 3 x =  x .&. 0xff
digit _ _ = pkgError "digit" "the impossible happened!?"
{-# INLINE digit #-}
-- | Adjust integer weights so that they sum to exactly @2^32@.
--
-- The excess (or deficit) of the sum relative to @2^32@ is removed one
-- unit at a time by sweeping over a mutable copy of the vector and
-- decrementing (or incrementing) eligible entries. During the first
-- sweep only entries that are at least 255 are touched; on every later
-- sweep the threshold is 1.
--
-- Termination requires that eligible entries exist: if the sum differs
-- from @2^32@ and every entry is 0, no entry ever passes the threshold
-- and the loop does not terminate.
correctWeights :: G.Vector v Word32 => v Word32 -> v Word32
{-# INLINE correctWeights #-}
correctWeights v = G.create $ do
  let
    -- Sum of the weights, accumulated in Int64 so that totals beyond
    -- the Word32 range are representable.
    s = G.foldl' (flip $ (+) . fromIntegral) 0 v :: Int64
    -- Number of weights.
    n = G.length v
  arr <- G.thaw v
  -- @lim@ is the smallest weight that may be adjusted in this sweep,
  -- @i@ the current position and @delta@ the remaining difference
  -- between the sum and 2^32.
  let loop lim i delta
        | delta == 0 = return ()
          -- End of a sweep: start over with the threshold lowered to 1.
        | i >= n     = loop 1 0 delta
        | otherwise  = do
            w <- M.read arr i
            case () of
              _| w < lim   -> loop lim (i+1) delta
               | delta < 0 -> M.write arr i (w + 1) >> loop lim (i+1) (delta + 1)
               | otherwise -> M.write arr i (w - 1) >> loop lim (i+1) (delta - 1)
  loop 255 0 (s - 2^(32::Int))
  return arr
-- | Build a condensed table for the Poisson distribution with the given
-- mean.
--
-- Probabilities are generated by the recurrence
-- @p(i+1) = p(i) * lam / (i+1)@ and the support is truncated where the
-- probability drops below @minP@. For small means the recurrence starts
-- at 0; otherwise it starts at @floor lam@ and runs in both directions.
--
-- Calls 'error' (through @pkgError@) for a negative mean.
tablePoisson :: Double -> CondensedTableU Int
tablePoisson = tableFromProbabilities . make
  where
    make lam
      | lam < 0    = pkgError "tablePoisson" "negative lambda"
        -- Below this threshold exp (-lam) is still above minP, so the
        -- recurrence can be started at 0.
      | lam < 22.8 = U.unfoldr unfoldForward (exp (-lam), 0)
        -- Otherwise start at floor lam and walk both ways. The backward
        -- run repeats the starting point, hence the 'U.tail'.
      | otherwise  = U.unfoldr unfoldForward (pMax, nMax)
                  ++ U.tail (U.unfoldr unfoldBackward (pMax, nMax))
      where
        -- Starting point of the two-sided walk and its probability,
        -- computed in log space: exp (n * log lam - lam - log n!).
        nMax = floor lam :: Int
        pMax = exp $ fromIntegral nMax * log lam - lam - logFactorial nMax
        -- Emit (i, p) and step to i+1 until p drops below minP.
        unfoldForward (p,i)
          | p < minP  = Nothing
          | otherwise = Just ( (i,p)
                             , (p * lam / fromIntegral (i+1), i+1)
                             )
        -- Emit (i, p) and step to i-1 until p drops below minP.
        unfoldBackward (p,i)
          | p < minP  = Nothing
          | otherwise = Just ( (i,p)
                             , (p / lam * fromIntegral i, i-1)
                             )
    -- Truncation threshold for probabilities, equal to 2^(-33).
    minP :: Double
    minP = 1.1641532182693481e-10
-- | Build a condensed table for the binomial distribution.
--
-- For @0 < p < 1@ the probabilities of the outcomes @0 .. n@ are
-- generated by the recurrence
-- @t(i+1) = t(i) * (n - i) / (i + 1) * p / (1 - p)@, starting from
-- @t(0) = (1 - p)^n@. For @p == 0@ and @p == 1@ the table holds the
-- single outcome @0@ or @n@ respectively.
--
-- Calls 'error' (through @pkgError@) if the number of tries is not
-- positive or if the probability lies outside @[0, 1]@.
tableBinomial :: Int            -- ^ Number of tries
              -> Double         -- ^ Probability of success
              -> CondensedTableU Int
tableBinomial n p = tableFromProbabilities makeBinom
  where
    makeBinom
      | n <= 0         = pkgError "tableBinomial" "non-positive number of tries"
      | p == 0         = U.singleton (0,1)
      | p == 1         = U.singleton (n,1)
      | p > 0 && p < 1 = U.unfoldrN (n + 1) unfolder ((1-p)^n, 0)
      | otherwise      = pkgError "tableBinomial" "probability is out of range"
      where
        -- Odds of success; the constant factor of the recurrence.
        h :: Double
        h = p / (1 - p)
        -- Emit (i, t) and advance to the next outcome. The unfold never
        -- stops by itself; 'U.unfoldrN' limits it to n + 1 elements.
        unfolder (t,i) = Just ( (i,t)
                              , (t * fromIntegral (n + 1 - i1) * h / fromIntegral i1, i1) )
          where i1 = i + 1
-- | Abort with 'error', using a message of the form
-- @System.Random.MWC.CondensedTable.\<func\>: \<err\>@.
--
-- The first argument is the name of the reporting function, the second
-- the description of the problem.
pkgError :: String -> String -> a
pkgError func err =
  error $ concat ["System.Random.MWC.CondensedTable.", func, ": ", err]