{-# LANGUAGE RecordWildCards #-}
{-# LANGUAGE DeriveGeneric #-}
module Torch.Optim where
import Control.Monad.State
import Control.Monad (foldM)
import System.Mem (performGC)
import Torch.Autograd
import Torch.Functional
import Torch.Internal.GC (mallocTrim)
import Torch.NN
import Torch.Tensor
import Torch.TensorFactories
import Prelude hiding (sqrt)
import GHC.Generics (Generic)
import Control.DeepSeq (NFData, force)
-- | A learning rate, represented as a (scalar) tensor.
type LearningRate = Tensor

-- | A loss value, represented as a tensor.
type Loss = Tensor
-- | Gradients of a loss with respect to a flat list of parameters,
-- in the same order as the parameters they were computed from.
newtype Gradients = Gradients [Tensor] deriving (Show)
-- | Newtype wrapper tagging a value as optimizer configuration/state.
newtype OptimizerState option = OptimizerState option
-- | Compute gradients of a loss with respect to the given parameters,
-- wrapping the result in the 'Gradients' newtype.
grad' :: Loss -> [Parameter] -> Gradients
grad' t p = Gradients (grad t p)
-- | An optimizer maps (learning rate, gradients, current parameters, state)
-- to (updated parameters, updated state).
class Optimizer optimizer where
  -- | Pure update rule: produce new flat parameter tensors and new optimizer state.
  step :: LearningRate -> Gradients -> [Tensor] -> optimizer -> ([Tensor], optimizer)

  -- | Run a single optimization step: differentiate the loss with respect to
  -- the model's parameters, then delegate to 'runStep''.
  runStep :: (Parameterized model) => model -> optimizer -> Loss -> LearningRate -> IO (model, optimizer)
  runStep paramState optState lossValue =
    runStep' paramState optState (grad' lossValue $ flattenParameters paramState)

  -- | Run a single optimization step given precomputed gradients.
  runStep' :: (Parameterized model) => model -> optimizer -> Gradients -> LearningRate -> IO (model, optimizer)
  runStep' paramState optState gradients lr = do
    -- Reclaim memory eagerly: tensors hold large foreign allocations that the
    -- Haskell GC does not account for on its own.
    performGC
    mallocTrim 0
    let (flatParameters', optState') = step lr gradients depParameters optState
    -- Re-wrap the updated tensors as independent (leaf) parameters so the next
    -- backward pass tracks gradients against them.
    newFlatParam <- mapM makeIndependent flatParameters'
    pure (replaceParameters paramState newFlatParam, optState')
    where
      flatParameters = flattenParameters paramState
      depParameters = fmap toDependent flatParameters
-- | Stateless gradient descent optimizer.
data GD = GD deriving (Show)
-- | One step of plain gradient descent: @p' = p - lr * dp@ for each
-- (parameter, gradient) pair, position-wise.
gd :: LearningRate -> Gradients -> [Tensor] -> [Tensor]
gd lr (Gradients gradients) parameters = zipWith update parameters gradients
  where
    update p dp = p - (lr * dp)
-- | Stateful-interface adapter for 'gd': threads the (trivial) 'GD' state
-- through unchanged so it can serve as 'step'.
gd' :: LearningRate -> Gradients -> [Tensor] -> GD -> ([Tensor], GD)
gd' lr gradients depParameters dummy = (gd lr gradients depParameters, dummy)
-- | Plain gradient descent carries no optimizer state.
instance Optimizer GD where
  step = gd'
-- | Gradient-descent update taking 'Parameter's directly; the returned
-- function maps a list of gradients to updated parameter tensors.
sgd :: LearningRate -> [Parameter] -> [Tensor] -> [Tensor]
sgd lr parameters = zipWith update depParameters
  where
    update p dp = p - (lr * dp)
    depParameters = map toDependent parameters
-- | State for gradient descent with momentum.
data GDM = GDM
  { -- | momentum decay factor
    beta :: Float,
    -- | running momentum term, one tensor per parameter
    momentum :: [Tensor]
  }
  deriving (Show)
-- | One step of gradient descent with momentum:
-- @z' = beta * z + dp@, @p' = p - lr * z'@.
gdm ::
  -- | learning rate
  LearningRate ->
  -- | gradients of the loss w.r.t. each parameter
  Gradients ->
  -- | current (dependent) parameter tensors
  [Tensor] ->
  -- | current momentum state
  GDM ->
  -- | (updated parameters, updated momentum state)
  ([Tensor], GDM)
gdm lr (Gradients gradients) parameters (GDM beta momentum) =
  (fmap fst results, GDM beta (fmap snd results))
  where
    update p dp z =
      let z' = mulScalar beta z + dp
       in (p - lr * z', z')
    results = zipWith3 update parameters gradients momentum
-- | Momentum optimizer: state is the per-parameter momentum term.
instance Optimizer GDM where
  step = gdm
-- | State for the Adam optimizer (Kingma & Ba, 2014).
data Adam = Adam
  { -- | decay rate for the first-moment (mean) estimate
    beta1 :: Float,
    -- | decay rate for the second-moment (uncentered variance) estimate
    beta2 :: Float,
    -- | first-moment estimates, one tensor per parameter
    m1 :: [Tensor],
    -- | second-moment estimates, one tensor per parameter
    m2 :: [Tensor],
    -- | iteration counter, used for bias correction
    iter :: Int
  }
  deriving (Show, Generic)
-- | NFData via Generic, so the moment estimates can be fully forced each step.
instance NFData Adam
-- | Construct an initial Adam state with zeroed moment estimates shaped
-- like the given parameters.
mkAdam ::
  -- | initial iteration counter (typically 0)
  Int ->
  -- | beta1
  Float ->
  -- | beta2
  Float ->
  -- | parameters to optimize
  [Parameter] ->
  Adam
mkAdam iter beta1 beta2 parameters =
  Adam
    beta1
    beta2
    (initZeros <$> parameters)
    (initZeros <$> parameters)
    iter
  where
    -- zero tensor with the same shape as the parameter
    initZeros = zerosLike . toDependent
-- | One Adam update step: bias-corrected first- and second-moment estimates
-- drive the parameter update @p' = p - lr * m1hat / (sqrt m2hat + eps)@.
adam ::
  -- | learning rate
  LearningRate ->
  -- | gradients of the loss w.r.t. each parameter
  Gradients ->
  -- | current (dependent) parameter tensors
  [Tensor] ->
  -- | current Adam state
  Adam ->
  -- | (updated parameters, updated Adam state)
  ([Tensor], Adam)
adam lr (Gradients gradients) parameters Adam {..} =
  (parameters', Adam beta1 beta2 m1' m2' (iter + 1))
  where
    -- exponential moving average of gradients
    f1 m dp = mulScalar beta1 m + mulScalar (1 - beta1) dp
    -- exponential moving average of squared gradients
    f2 m dp = mulScalar beta2 m + mulScalar (1 - beta2) (dp * dp)
    -- 'force' prevents thunk buildup across iterations
    m1' = force $ zipWith f1 m1 gradients
    m2' = force $ zipWith f2 m2 gradients
    -- bias correction: divide by (1 - beta ^ t)
    a beta = divScalar (1 - beta ^ (iter + 1))
    a1 = fmap (a beta1) m1'
    a2 = fmap (a beta2) m2'
    -- small constant to avoid division by zero
    eps = 1e-37
    update prevParam a1' a2' = prevParam - lr * a1' / (sqrt a2' + eps)
    parameters' = zipWith3 update parameters a1 a2
-- | Adam optimizer: state is the pair of moment estimates plus a step counter.
instance Optimizer Adam where
  step = adam
-- | State for the Adagrad optimizer.
data Adagrad = Adagrad
  { -- | running sum of squared gradients, one tensor per parameter
    gsum :: [Tensor]
  }
  deriving (Show)
-- | One Adagrad update step: accumulate squared gradients and scale each
-- update by the inverse square root of the accumulated sum.
adagrad ::
  -- | learning rate
  LearningRate ->
  -- | gradients of the loss w.r.t. each parameter
  Gradients ->
  -- | current (dependent) parameter tensors
  [Tensor] ->
  -- | current Adagrad state
  Adagrad ->
  -- | (updated parameters, updated Adagrad state)
  ([Tensor], Adagrad)
adagrad lr (Gradients gradients) parameters Adagrad {..} =
  (parameters', Adagrad gsum')
  where
    -- accumulate squared gradient
    accumulate s dp = s + dp * dp
    gsum' = zipWith accumulate gsum gradients
    -- small constant to avoid division by zero
    eps = 1e-37
    update prevParam dp s = prevParam - lr * dp / sqrt (s + eps)
    parameters' = zipWith3 update parameters gradients gsum'
-- | Adagrad optimizer: state is the accumulated squared-gradient sum.
instance Optimizer Adagrad where
  step = adagrad
-- | Run an IO action over the indices @[1 .. count]@, threading an
-- accumulator through each iteration (a monadic counted loop).
foldLoop :: a -> Int -> (a -> Int -> IO a) -> IO a
foldLoop x count block = foldM block x [1 .. count]