From 994b52fc2b2a061436729b94cae0e3a35ac17366 Mon Sep 17 00:00:00 2001 From: fengjiayi Date: Fri, 2 Mar 2018 17:53:35 +0800 Subject: [PATCH] Add layers for save/load op --- python/paddle/fluid/layers/nn.py | 70 +++++++++++++++++++++++++++++++- 1 file changed, 69 insertions(+), 1 deletion(-) diff --git a/python/paddle/fluid/layers/nn.py b/python/paddle/fluid/layers/nn.py index a10463b52..1acc74f75 100644 --- a/python/paddle/fluid/layers/nn.py +++ b/python/paddle/fluid/layers/nn.py @@ -3206,7 +3206,7 @@ def one_hot(input, depth): operator. Args: - input(Tensor/LodTensor): A Tensor/LodTensor of indices, last dimension must be 1. + input(variable): A Tensor/LodTensor of indices, last dimension must be 1. depth(scalar): an interger defining the depth of the one hot dimension. Returns: @@ -3265,3 +3265,71 @@ def autoincreased_step_counter(counter_name=None, begin=1, step=1): counter.stop_gradient = True return counter + + +def save(x, file_path, overwrite=True): + """ + Saves a variable as a file. + + Args: + x(variable): The Tensor/LoDTensor to be saved. + file_path(str): The file path where the variable will be saved. + overwrite(bool): Whether or not cover the given file when it has already + existed. If it's set 'False' and the file is existed, a runtime + error will be thrown. + """ + helper = LayerHelper("save", **locals()) + helper.append_op( + type="save", + inputs={"input": x}, + outputs={}, + args={"file_path": file_path, + "overwrite": overwrite}) + + +def save_combine(x, file_path, overwrite=True): + """ + Saves a variable as a file. + + Args: + x(list): A list of Tensor/LoDTensor to be saved together in a single file. + file_path(str): The file path where variables will be saved. + overwrite(bool): Whether or not cover the given file when it has already + existed. If it's set 'False' and the file is existed, a runtime + error will be thrown. + """ + helper = LayerHelper("save_combine", **locals()) + helper.append_op( + type="save_combine", + inputs={"input": x}, + outputs={}, + args={"file_path": file_path, + "overwrite": overwrite}) + + +def load(out, file_path): + """ + Args: + out(variable): The variable to be read from the disk file. + file_path(str): The path of the disk file. + """ + helper = LayerHelper("load", **locals()) + helper.append_op( + type="load", + inputs={}, + output={"Out": out}, + args={"file_path": file_path}) + + +def load_combine(out, file_path): + """ + Args: + out(list): The list of variables to be read from the disk file. + file_path(str): The path of the disk file. + """ + helper = LayerHelper("load_combine", **locals()) + helper.append_op( + type="load_combine", + inputs={}, + output={"Out": out}, + args={"file_path": file_path}) -- GitLab