提交 b8dd2986 编写于 作者: W wangmeng28

use save_persistables

上级 f08bbbdf
import os
import paddle.v2 as paddle
import paddle.v2.fluid as fluid
import reader
......@@ -134,7 +133,7 @@ def train(learning_rate,
exe.run(fluid.default_startup_program())
if init_model is not None:
fluid.io.load_params(exe, init_model)
fluid.io.load_persistables(exe, init_model)
train_reader = paddle.batch(reader.train(), batch_size=batch_size)
test_reader = paddle.batch(reader.test(), batch_size=batch_size)
......@@ -160,7 +159,9 @@ def train(learning_rate,
pass_id, pass_acc, test_pass_acc))
model_path = os.path.join(model_save_dir, str(pass_id))
fluid.io.save_inference_model(model_path, ['image'], [out], exe)
if not os.path.isdir(model_path):
os.makedirs(model_path)
fluid.io.save_persistables(exe, model_path)
if __name__ == '__main__':
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册