未验证 提交 bd66eed5 编写于 作者: J Jeff Wang 提交者: GitHub

Trainer save load params (#10386)

* Load/save the params from the params_path

* Switch to use load_persistables and save_persistables

* Instaed of setup the executor to run program and scope. Pass the program to the load_persistables
上级 5812076e
......@@ -13,7 +13,9 @@
# limitations under the License.
import core
import framework
import executor
import io
__all__ = ['Inferencer', ]
......@@ -29,6 +31,15 @@ class Inferencer(object):
# 4. load params from param_path into scope
self.scope = core.Scope()
self.place = place
self.startup_program = framework.Program()
# TODO: generate the startup_program with network_func
exe = executor.Executor(place)
exe.run(self.startup_program, scope=self.scope)
if param_path:
# load params from param_path into scope
io.load_persistables(exe, dirname=param_path)
def infer(self, inputs):
# run self.program
......
......@@ -18,6 +18,7 @@ import framework
import executor
import data_feeder
import contextlib
import io
# optimizer is same as the parameter of Trainer.__init__. Rename it to opt_module
import optimizer as opt_module
......@@ -93,8 +94,7 @@ class Trainer(object):
if param_path:
# load params from param_path into scope
# TODO(yuyang): This depends on parameters implementation.
pass
io.load_persistables(exe, dirname=param_path)
def dist_transpile_if_necessary(self, optimize_ops, params_grads):
if "PADDLE_TRAINING_ROLE" not in os.environ:
......@@ -172,7 +172,9 @@ class Trainer(object):
def save_params(self, param_path):
# reference: save_persistables in io.py
pass
exe = executor.Executor(self.place)
io.save_persistables(
exe, dirname=param_path, main_program=self.startup_program)
@staticmethod
def _check_and_get_place(place):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册