{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}

module Torch.Typed.Serialize where

import Torch.HList
import qualified Torch.Internal.Cast as ATen
import qualified Torch.Internal.Class as ATen
import qualified Torch.Internal.Managed.Serialize as S
import qualified Torch.Internal.Type as ATen
import qualified Torch.Tensor as D
import Torch.Typed.Tensor
import Torch.Typed.Parameter
import Torch.Typed.NN
import Torch.Typed.Autograd

-- | Save a heterogeneous list of typed tensors to a file.
--
-- The 'HList' is marshalled to the untyped ATen representation through its
-- 'ATen.Castable' instance and handed to the foreign 'S.save'.
save ::
  forall tensors.
  ATen.Castable (HList tensors) [D.ATenTensor] =>
  -- | list of input tensors
  HList tensors ->
  -- | file
  FilePath ->
  IO ()
-- 'ATen.cast2' converts both Haskell arguments to their C-side counterparts
-- (the tensor list and the path) and converts the result back.
save = ATen.cast2 S.save

-- | Load a heterogeneous list of typed tensors from a file.
--
-- The expected element types are fixed by the caller (typically via
-- @load \@tensors@). The untyped list returned by the foreign 'S.load' is
-- cast back to @'HList' tensors@ through 'ATen.Castable'.
--
-- NOTE(review): whether a mismatch between the file contents and the
-- requested @tensors@ is detected depends on the 'ATen.Castable' instance,
-- which is not visible here — confirm.
load ::
  forall tensors.
  ATen.Castable (HList tensors) [D.ATenTensor] =>
  -- | file
  FilePath ->
  IO (HList tensors)
load = ATen.cast1 S.load

-- | Save all parameters of a model to a file.
--
-- The parameters are collected with 'flattenParameters', and each one is
-- converted to a plain tensor with 'ToDependent'. The resulting list is
-- written with 'save'. 'loadParameters' and 'loadParametersWithSpec' read
-- files in this order back into a model.
--
-- NOTE(review): @dtype@ and @device@ do not occur anywhere in the signature.
-- They are kept so the order of the type variables (relevant to callers
-- using explicit type applications) stays unchanged.
saveParameters ::
  forall model parameters tensors dtype device.
  ( Parameterized model,
    parameters ~ Parameters model,
    HMap' ToDependent parameters tensors,
    HMapM' IO MakeIndependent tensors parameters,
    HFoldrM IO TensorListFold [D.ATenTensor] tensors [D.ATenTensor],
    Apply TensorListUnfold [D.ATenTensor] (HUnfoldMRes IO [D.ATenTensor] tensors),
    HUnfoldM IO TensorListUnfold (HUnfoldMRes IO [D.ATenTensor] tensors) tensors
  ) =>
  -- | model whose parameters are written
  model ->
  -- | destination file
  FilePath ->
  IO ()
saveParameters model filePath =
  save (hmap' ToDependent (flattenParameters model)) filePath

-- | Load parameters from a file into an existing model.
--
-- The tensors are read with 'load'. Each one is wrapped as a parameter with
-- 'MakeIndependent', and the results are installed with 'replaceParameters'.
-- The file is presumably expected to hold the tensors in the order produced
-- by 'flattenParameters', as written by 'saveParameters'.
--
-- NOTE(review): @dtype@ and @device@ are unused type variables, kept to
-- preserve the type-application order of this signature.
loadParameters ::
  forall model parameters tensors dtype device.
  ( Parameterized model,
    parameters ~ Parameters model,
    HMap' ToDependent parameters tensors,
    HMapM' IO MakeIndependent tensors parameters,
    HFoldrM IO TensorListFold [D.ATenTensor] tensors [D.ATenTensor],
    Apply TensorListUnfold [D.ATenTensor] (HUnfoldMRes IO [D.ATenTensor] tensors),
    HUnfoldM IO TensorListUnfold (HUnfoldMRes IO [D.ATenTensor] tensors) tensors
  ) =>
  -- | model whose parameters are replaced
  model ->
  -- | file to read
  FilePath ->
  IO model
loadParameters model filePath = do
  -- The type application pins the element types of the loaded list to the
  -- dependent-tensor types of this model's parameters.
  tensors <- load @tensors filePath
  params <- hmapM' MakeIndependent tensors
  pure $ replaceParameters model params

-- | Build a model from a spec and load its parameters from a file.
--
-- A model is first created with 'sample' from the 'Randomizable' spec. Its
-- parameters are then replaced by the tensors read with 'load', each wrapped
-- with 'MakeIndependent', via 'replaceParameters'.
--
-- NOTE(review): @dtype@ and @device@ are unused type variables, kept to
-- preserve the type-application order of this signature.
loadParametersWithSpec ::
  forall spec model parameters tensors dtype device.
  ( Randomizable spec model,
    Parameterized model,
    parameters ~ Parameters model,
    HMap' ToDependent parameters tensors,
    HMapM' IO MakeIndependent tensors parameters,
    HFoldrM IO TensorListFold [D.ATenTensor] tensors [D.ATenTensor],
    Apply TensorListUnfold [D.ATenTensor] (HUnfoldMRes IO [D.ATenTensor] tensors),
    HUnfoldM IO TensorListUnfold (HUnfoldMRes IO [D.ATenTensor] tensors) tensors
  ) =>
  -- | spec used to construct the model
  spec ->
  -- | file to read
  FilePath ->
  IO model
loadParametersWithSpec spec filePath = do
  -- The sampled model supplies the structure; its sampled parameter values
  -- are then replaced by the ones read from the file.
  model <- sample spec
  tensors <- load @tensors filePath
  params <- hmapM' MakeIndependent tensors
  pure $ replaceParameters model params