diff --git a/python/paddle/fluid/inferencer.py b/python/paddle/fluid/inferencer.py index 58e027695a7100245dd424583e2cedeed3d165e6..b38526bc574a103ece86aecdecf06b0bcfd6cad0 100644 --- a/python/paddle/fluid/inferencer.py +++ b/python/paddle/fluid/inferencer.py @@ -13,7 +13,9 @@ # limitations under the License. import core - +import framework +import executor +import io __all__ = ['Inferencer', ] @@ -29,6 +31,15 @@ class Inferencer(object): # 4. load params from param_path into scope self.scope = core.Scope() self.place = place + self.startup_program = framework.Program() + # TODO: generate the startup_program with network_func + + exe = executor.Executor(place) + exe.run(self.startup_program, scope=self.scope) + + if param_path: + # load params from param_path into scope + io.load_persistables(exe, dirname=param_path) def infer(self, inputs): # run self.program diff --git a/python/paddle/fluid/trainer.py b/python/paddle/fluid/trainer.py index 5385d798ea848e99edfa5fc7344116e21e31dc4d..8252592c8ce0ea0a9959f882170d42bdc74e996a 100644 --- a/python/paddle/fluid/trainer.py +++ b/python/paddle/fluid/trainer.py @@ -18,6 +18,7 @@ import framework import executor import data_feeder import contextlib +import io # optimizer is same as the parameter of Trainer.__init__. Rename it to opt_module import optimizer as opt_module @@ -93,8 +94,7 @@ class Trainer(object): if param_path: # load params from param_path into scope - # TODO(yuyang): This depends on parameters implementation. - pass + io.load_persistables(exe, dirname=param_path) def dist_transpile_if_necessary(self, optimize_ops, params_grads): if "PADDLE_TRAINING_ROLE" not in os.environ: @@ -172,7 +172,9 @@ class Trainer(object): def save_params(self, param_path): # reference: save_persistables in io.py - pass + exe = executor.Executor(self.place) + io.save_persistables( + exe, dirname=param_path, main_program=self.startup_program) @staticmethod def _check_and_get_place(place):