{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
module Torch.Typed.Serialize where
import Torch.HList
import qualified Torch.Internal.Cast as ATen
import qualified Torch.Internal.Class as ATen
import qualified Torch.Internal.Managed.Serialize as S
import qualified Torch.Internal.Type as ATen
import qualified Torch.Tensor as D
import Torch.Typed.Tensor
import Torch.Typed.Parameter
import Torch.Typed.NN
import Torch.Typed.Autograd
-- | Serialize an 'HList' of typed tensors to the file at the given path.
--
-- The heterogeneous list is cast (via 'ATen.Castable') to the raw ATen
-- representation and passed to the managed binding 'S.save'. The on-disk
-- format is whatever 'S.save' writes.
--
-- NOTE(review): the tensor order in the list presumably matters for a later
-- 'load'; confirm against the 'S.save' / 'S.load' bindings.
save ::
  forall tensors.
  ATen.Castable (HList tensors) [D.ATenTensor] =>
  -- | tensors to write
  HList tensors ->
  -- | destination file
  FilePath ->
  IO ()
save = ATen.cast2 S.save
-- | Deserialize an 'HList' of typed tensors from the file at the given path.
--
-- Calls the managed binding 'S.load' and casts its raw ATen result back to
-- @'HList' tensors@ via 'ATen.Castable'. The element types are fixed by the
-- caller, typically with a type application such as @load \@tensors path@.
--
-- NOTE(review): whether a mismatch between the file contents and @tensors@
-- is detected (and how) depends on the 'ATen.Castable' instance, which is
-- not visible here; confirm before relying on it.
load ::
  forall tensors.
  ATen.Castable (HList tensors) [D.ATenTensor] =>
  -- | source file
  FilePath ->
  IO (HList tensors)
load = ATen.cast1 S.load
-- | Save the parameters of a 'Parameterized' model to a file.
--
-- The model's parameters are flattened with 'flattenParameters', mapped to
-- plain tensors with 'ToDependent', and written with 'save'.
--
-- The @dtype@ and @device@ type variables are not used by the body. They
-- are kept in the @forall@ so that existing callers using positional type
-- applications keep compiling.
saveParameters ::
  forall model parameters tensors dtype device.
  ( Parameterized model,
    parameters ~ Parameters model,
    HMap' ToDependent parameters tensors,
    HMapM' IO MakeIndependent tensors parameters,
    HFoldrM IO TensorListFold [D.ATenTensor] tensors [D.ATenTensor],
    Apply TensorListUnfold [D.ATenTensor] (HUnfoldMRes IO [D.ATenTensor] tensors),
    HUnfoldM IO TensorListUnfold (HUnfoldMRes IO [D.ATenTensor] tensors) tensors
  ) =>
  -- | model whose parameters are written
  model ->
  -- | destination file
  FilePath ->
  IO ()
saveParameters model filePath =
  save (hmap' ToDependent . flattenParameters $ model) filePath
-- | Load parameters from a file into an existing model.
--
-- Tensors are read with 'load' (at the type @tensors@), converted to
-- parameters with 'MakeIndependent', and installed into @model@ via
-- 'replaceParameters'. The returned model has the structure of the argument
-- model with the loaded parameters.
--
-- NOTE(review): this presumably relies on the file having been produced by
-- 'saveParameters' for the same model type, so the tensor order matches
-- 'flattenParameters'; confirm.
--
-- @dtype@ and @device@ are unused and kept only for type-application
-- compatibility.
loadParameters ::
  forall model parameters tensors dtype device.
  ( Parameterized model,
    parameters ~ Parameters model,
    HMap' ToDependent parameters tensors,
    HMapM' IO MakeIndependent tensors parameters,
    HFoldrM IO TensorListFold [D.ATenTensor] tensors [D.ATenTensor],
    Apply TensorListUnfold [D.ATenTensor] (HUnfoldMRes IO [D.ATenTensor] tensors),
    HUnfoldM IO TensorListUnfold (HUnfoldMRes IO [D.ATenTensor] tensors) tensors
  ) =>
  -- | model providing the structure to fill
  model ->
  -- | source file
  FilePath ->
  IO model
loadParameters model filePath = do
  -- The type application pins which tensor types are read back
  -- (ScopedTypeVariables brings @tensors@ into scope here).
  tensors <- load @tensors filePath
  params <- hmapM' MakeIndependent tensors
  pure $ replaceParameters model params
-- | Build a model from a spec and fill it with parameters loaded from a file.
--
-- A fresh model is first created with 'sample' from @spec@, which performs
-- IO. Tensors are then read with 'load', converted with 'MakeIndependent',
-- and installed with 'replaceParameters'.
--
-- NOTE(review): the sampled initial parameter values are presumably
-- overwritten by 'replaceParameters'; only the model's structure seems to
-- matter. Confirm against 'replaceParameters'.
--
-- @dtype@ and @device@ are unused and kept only for type-application
-- compatibility.
loadParametersWithSpec ::
  forall spec model parameters tensors dtype device.
  ( Randomizable spec model,
    Parameterized model,
    parameters ~ Parameters model,
    HMap' ToDependent parameters tensors,
    HMapM' IO MakeIndependent tensors parameters,
    HFoldrM IO TensorListFold [D.ATenTensor] tensors [D.ATenTensor],
    Apply TensorListUnfold [D.ATenTensor] (HUnfoldMRes IO [D.ATenTensor] tensors),
    HUnfoldM IO TensorListUnfold (HUnfoldMRes IO [D.ATenTensor] tensors) tensors
  ) =>
  -- | spec used to construct the model skeleton
  spec ->
  -- | source file
  FilePath ->
  IO model
loadParametersWithSpec spec filePath = do
  model <- sample spec
  tensors <- load @tensors filePath
  params <- hmapM' MakeIndependent tensors
  pure $ replaceParameters model params