{-# LANGUAGE RecordWildCards #-}
{-# LANGUAGE DeriveGeneric #-}

module Torch.Optim where

import Control.Monad.State
import Control.Monad (foldM)
import System.Mem (performGC)
import Torch.Autograd
import Torch.Functional
import Torch.Internal.GC (mallocTrim)
import Torch.NN
import Torch.Tensor
import Torch.TensorFactories
import Prelude hiding (sqrt)
import GHC.Generics (Generic)
import Control.DeepSeq (NFData, force)

type LearningRate = Tensor

type Loss = Tensor

newtype Gradients = Gradients [Tensor] deriving (Show)

newtype OptimizerState option = OptimizerState option

grad' :: Loss -> [Parameter] -> Gradients
grad' t p = Gradients (grad t p)
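
-- A minimal usage sketch (hedged; 'model' and 'loss' are hypothetical names):
-- a Gradients value is obtained from a scalar loss and the flattened parameters, e.g.
--
-- > let gradients = grad' loss (flattenParameters model)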

class Optimizer optimizer where
  step :: LearningRate -> Gradients -> [Tensor] -> optimizer -> ([Tensor], optimizer)

  -- | run a single iteration of an optimizer, returning new parameters and updated optimizer state
  runStep :: (Parameterized model) => model -> optimizer -> Loss -> LearningRate -> IO (model, optimizer)
  runStep paramState optState lossValue = runStep' paramState optState (grad' lossValue $ flattenParameters paramState)

  -- | run a single iteration of an optimizer, returning new parameters and updated optimizer state
  runStep' :: (Parameterized model) => model -> optimizer -> Gradients -> LearningRate -> IO (model, optimizer)
  runStep' paramState optState gradients lr = do
    performGC
    mallocTrim 0
    let (flatParameters', optState') = step lr gradients depParameters optState
    newFlatParam <- mapM makeIndependent flatParameters'
    pure (replaceParameters paramState newFlatParam, optState')
    where
      flatParameters = flattenParameters paramState
      depParameters = fmap toDependent flatParameters
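
-- A sketch of a single optimization step using the class interface above;
-- 'model', 'optState', 'computeLoss', and 'batch' are hypothetical names, and
-- the learning rate is a scalar tensor (Tensor has a Fractional instance):
--
-- > (model', optState') <- runStep model optState (computeLoss model batch) 1e-3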

--
-- Gradient Descent
--

data GD = GD deriving (Show)

-- | Stateless gradient descent step
gd :: LearningRate -> Gradients -> [Tensor] -> [Tensor]
gd lr (Gradients gradients) parameters = zipWith step parameters gradients
  where
    step p dp = p - (lr * dp)

-- | Gradient descent step with a dummy state variable
gd' :: LearningRate -> Gradients -> [Tensor] -> GD -> ([Tensor], GD)
gd' lr gradients depParameters dummy = (gd lr gradients depParameters, dummy)

instance Optimizer GD where
  step = gd'
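
-- GD carries no optimizer state, so a step can be run directly through the
-- class interface; a sketch (names hypothetical):
--
-- > (model', _) <- runStep model GD loss 5e-4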

sgd :: LearningRate -> [Parameter] -> [Tensor] -> [Tensor]
sgd lr parameters = zipWith step depParameters
  where
    step p dp = p - (lr * dp)
    depParameters = map toDependent parameters
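
-- sgd performs the same update as gd but takes the Parameter list directly;
-- a sketch ('model' and 'gradientTensors' are hypothetical names):
--
-- > let newParams = sgd 1e-3 (flattenParameters model) gradientTensors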

--
-- Gradient Descent with Momentum
--

data GDM = GDM {beta :: Float, momentum :: [Tensor]} deriving (Show)

-- | Gradient descent with momentum step
gdm ::
  -- | learning rate
  LearningRate ->
  -- | model parameter gradients
  Gradients ->
  -- | model parameters
  [Tensor] ->
  -- | beta & momentum
  GDM ->
  -- | returns new parameters + updated momentum
  ([Tensor], GDM)
gdm lr (Gradients gradients) parameters (GDM beta momentum) =
  (fmap fst runStep, GDM beta (fmap snd runStep))
  where
    step p dp z = let z' = mulScalar beta z + dp in (p - lr * z', z')
    runStep = zipWith3 step parameters gradients momentum

instance Optimizer GDM where
  step = gdm
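
-- The momentum buffers must line up with (and match the shapes of) the model
-- parameters; one way to build an initial state is to use zero tensors, e.g.
-- (a sketch, assuming a Parameterized 'model'):
--
-- > let gdmState = GDM 0.9 (map (zerosLike . toDependent) (flattenParameters model))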

--
-- Adam
--

-- | State representation for Adam Optimizer
data Adam = Adam
  { beta1 :: Float, -- 1st moment forgetting factor
    beta2 :: Float, -- 2nd moment forgetting factor
    m1 :: [Tensor], -- 1st moment
    m2 :: [Tensor], -- 2nd moment
    iter :: Int -- iteration
  }
  deriving (Show, Generic)

instance NFData Adam

mkAdam ::
  -- | initial iteration counter
  Int ->
  -- | beta1 (1st moment forgetting factor)
  Float ->
  -- | beta2 (2nd moment forgetting factor)
  Float ->
  -- | model parameters
  [Parameter] ->
  Adam
mkAdam iter beta1 beta2 parameters =
  Adam
    beta1
    beta2
    (initZeros <$> parameters)
    (initZeros <$> parameters)
    iter
  where
    initZeros = zerosLike . toDependent
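
-- A usage sketch: start at iteration 0 with the conventional beta values
-- ('model' is a hypothetical Parameterized value):
--
-- > let adamState = mkAdam 0 0.9 0.999 (flattenParameters model)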

-- | Adam step
adam ::
  -- | learning rate
  LearningRate ->
  -- | model parameter gradients
  Gradients ->
  -- | model parameters
  [Tensor] ->
  -- | adam parameters - beta1, beta2, moments, iteration
  Adam ->
  -- | returns new parameters + updated adam parameters
  ([Tensor], Adam)
adam lr (Gradients gradients) parameters Adam {..} = (parameters', Adam beta1 beta2 m1' m2' (iter + 1))
  where
    -- decaying averages of 1st & 2nd moments
    f1 m1 dp = mulScalar beta1 m1 + mulScalar (1 - beta1) dp
    f2 m2 dp = mulScalar beta2 m2 + mulScalar (1 - beta2) (dp * dp)
    -- force to prevent spine laziness. See https://github.com/hasktorch/hasktorch/pull/728
    m1' = force $ zipWith f1 m1 gradients
    m2' = force $ zipWith f2 m2 gradients
    -- bias adjustment
    a beta = divScalar (1 - beta ^ (iter + 1))
    a1 = fmap (a beta1) m1'
    a2 = fmap (a beta2) m2'
    -- parameter update
    eps = 1e-37
    update prevParam a1' a2' = prevParam - lr * a1' / (sqrt a2' + eps)
    parameters' = zipWith3 update parameters a1 a2
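
-- For reference, the update implemented above corresponds to the standard Adam
-- recurrences with bias-corrected moment estimates (g is the gradient, t = iter):
--
--   m1' = beta1 * m1 + (1 - beta1) * g
--   m2' = beta2 * m2 + (1 - beta2) * g^2
--   p'  = p - lr * (m1' / (1 - beta1^(t+1))) / (sqrt (m2' / (1 - beta2^(t+1))) + eps)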

instance Optimizer Adam where
  step = adam

--
-- Adagrad
--

-- | State representation for Adagrad Optimizer
data Adagrad = Adagrad {gsum :: [Tensor]} -- sum of squared gradients
  deriving (Show)

-- | Adagrad step
adagrad ::
  -- | learning rate
  LearningRate ->
  -- | model parameter gradients
  Gradients ->
  -- | model parameters
  [Tensor] ->
  -- | adagrad parameters - gsum (running sum of squared gradients)
  Adagrad ->
  -- | returns new parameters + updated adagrad parameters
  ([Tensor], Adagrad)
adagrad lr (Gradients gradients) parameters Adagrad {..} = (parameters', Adagrad gsum')
  where
    -- add gradient squared to running total
    f gsum dp = gsum + dp * dp
    gsum' = zipWith f gsum gradients

    -- parameter update
    eps = 1e-37
    update prevParam a1' a2' = prevParam - lr * a1' / (sqrt (a2' + eps))
    parameters' = zipWith3 update parameters gradients gsum'

instance Optimizer Adagrad where
  step = adagrad
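
-- As with GDM, the running sums must be initialised to zero tensors matching
-- the parameter shapes; a sketch ('model' is hypothetical):
--
-- > let adagradState = Adagrad (map (zerosLike . toDependent) (flattenParameters model))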

-- | syntactic sugar for looping with foldM
foldLoop :: a -> Int -> (a -> Int -> IO a) -> IO a
foldLoop x count block = foldM block x [1 .. count]
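
-- A sketch of an epoch loop built with foldLoop; 'trainEpoch' is a hypothetical
-- function of type (model, Adam) -> Int -> IO (model, Adam):
--
-- > (trainedModel, finalState) <- foldLoop (model, adamState) numEpochs trainEpoch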