From bd66eed50ad2eecf5192bfc15d6a40d5123e9f6d Mon Sep 17 00:00:00 2001 From: Jeff Wang Date: Fri, 4 May 2018 11:30:49 -0700 Subject: [PATCH] Trainer save load params (#10386) * Load/save the params from the params_path * Switch to use load_persistables and save_persistables * Instaed of setup the executor to run program and scope. Pass the program to the load_persistables --- python/paddle/fluid/inferencer.py | 13 ++++++++++++- python/paddle/fluid/trainer.py | 8 +++++--- 2 files changed, 17 insertions(+), 4 deletions(-) diff --git a/python/paddle/fluid/inferencer.py b/python/paddle/fluid/inferencer.py index 58e027695a..b38526bc57 100644 --- a/python/paddle/fluid/inferencer.py +++ b/python/paddle/fluid/inferencer.py @@ -13,7 +13,9 @@ # limitations under the License. import core - +import framework +import executor +import io __all__ = ['Inferencer', ] @@ -29,6 +31,15 @@ class Inferencer(object): # 4. load params from param_path into scope self.scope = core.Scope() self.place = place + self.startup_program = framework.Program() + # TODO: generate the startup_program with network_func + + exe = executor.Executor(place) + exe.run(self.startup_program, scope=self.scope) + + if param_path: + # load params from param_path into scope + io.load_persistables(exe, dirname=param_path) def infer(self, inputs): # run self.program diff --git a/python/paddle/fluid/trainer.py b/python/paddle/fluid/trainer.py index 5385d798ea..8252592c8c 100644 --- a/python/paddle/fluid/trainer.py +++ b/python/paddle/fluid/trainer.py @@ -18,6 +18,7 @@ import framework import executor import data_feeder import contextlib +import io # optimizer is same as the parameter of Trainer.__init__. Rename it to opt_module import optimizer as opt_module @@ -93,8 +94,7 @@ class Trainer(object): if param_path: # load params from param_path into scope - # TODO(yuyang): This depends on parameters implementation. - pass + io.load_persistables(exe, dirname=param_path) def dist_transpile_if_necessary(self, optimize_ops, params_grads): if "PADDLE_TRAINING_ROLE" not in os.environ: @@ -172,7 +172,9 @@ class Trainer(object): def save_params(self, param_path): # reference: save_persistables in io.py - pass + exe = executor.Executor(self.place) + io.save_persistables( + exe, dirname=param_path, main_program=self.startup_program) @staticmethod def _check_and_get_place(place): -- GitLab