未验证 提交 bd66eed5 编写于 作者: J Jeff Wang 提交者: GitHub

Trainer save load params (#10386)

* Load/save the params from the params_path

* Switch to use load_persistables and save_persistables

* Instaed of setup the executor to run program and scope. Pass the program to the load_persistables
上级 5812076e
...@@ -13,7 +13,9 @@ ...@@ -13,7 +13,9 @@
# limitations under the License. # limitations under the License.
import core import core
import framework
import executor
import io
__all__ = ['Inferencer', ] __all__ = ['Inferencer', ]
...@@ -29,6 +31,15 @@ class Inferencer(object): ...@@ -29,6 +31,15 @@ class Inferencer(object):
# 4. load params from param_path into scope # 4. load params from param_path into scope
self.scope = core.Scope() self.scope = core.Scope()
self.place = place self.place = place
self.startup_program = framework.Program()
# TODO: generate the startup_program with network_func
exe = executor.Executor(place)
exe.run(self.startup_program, scope=self.scope)
if param_path:
# load params from param_path into scope
io.load_persistables(exe, dirname=param_path)
def infer(self, inputs): def infer(self, inputs):
# run self.program # run self.program
......
...@@ -18,6 +18,7 @@ import framework ...@@ -18,6 +18,7 @@ import framework
import executor import executor
import data_feeder import data_feeder
import contextlib import contextlib
import io
# optimizer is same as the parameter of Trainer.__init__. Rename it to opt_module # optimizer is same as the parameter of Trainer.__init__. Rename it to opt_module
import optimizer as opt_module import optimizer as opt_module
...@@ -93,8 +94,7 @@ class Trainer(object): ...@@ -93,8 +94,7 @@ class Trainer(object):
if param_path: if param_path:
# load params from param_path into scope # load params from param_path into scope
# TODO(yuyang): This depends on parameters implementation. io.load_persistables(exe, dirname=param_path)
pass
def dist_transpile_if_necessary(self, optimize_ops, params_grads): def dist_transpile_if_necessary(self, optimize_ops, params_grads):
if "PADDLE_TRAINING_ROLE" not in os.environ: if "PADDLE_TRAINING_ROLE" not in os.environ:
...@@ -172,7 +172,9 @@ class Trainer(object): ...@@ -172,7 +172,9 @@ class Trainer(object):
def save_params(self, param_path): def save_params(self, param_path):
# reference: save_persistables in io.py # reference: save_persistables in io.py
pass exe = executor.Executor(self.place)
io.save_persistables(
exe, dirname=param_path, main_program=self.startup_program)
@staticmethod @staticmethod
def _check_and_get_place(place): def _check_and_get_place(place):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册