Unverified commit eec77713, authored by acosta123, committed by GitHub

Update DyGraph_en.md

Parent df76b93f
@@ -311,7 +311,7 @@ Please refer to contents in [PaddleBook](https://github.com/PaddlePaddle/book/tr
     train_reader = paddle.batch(
         paddle.dataset.mnist.train(), batch_size=BATCH_SIZE, drop_last=True)
-    dy_param_init_value = {}
     np.set_printoptions(precision=3, suppress=True)
     for epoch in range(epoch_num):
         for batch_id, data in enumerate(train_reader()):
@@ -331,9 +331,6 @@ Please refer to contents in [PaddleBook](https://github.com/PaddlePaddle/book/tr
             dy_out = avg_loss.numpy()
-            if epoch == 0 and batch_id == 0:
-                for param in mnist.parameters():
-                    dy_param_init_value[param.name] = param.numpy()
             avg_loss.backward()
             sgd.minimize(avg_loss)
@@ -388,11 +385,11 @@ Please refer to contents in [PaddleBook](https://github.com/PaddlePaddle/book/tr

 In model training, you can use `fluid.dygraph.save_persistables(your_model_object.state_dict(), "save_dir")` to save all model parameters in `your_model_object`. You can also save a Python dict of "parameter name" - "parameter object" pairs that you define yourself.
-Or use the `your_model_object.load_dict(fluid.dygraph.load_persistables(your_model_object.state_dict(), "save_dir"))` interface to restore saved model parameters and continue training.
+Or use the `your_model_object.load_dict(fluid.dygraph.load_persistables("save_dir"))` interface to restore saved model parameters and continue training.

 The following code shows how to save parameters and load saved parameters to continue training in the "Handwriting Digit Recognition" task.

+dy_param_init_value = {}
 for epoch in range(epoch_num):
     for batch_id, data in enumerate(train_reader()):
         dy_x_data = np.array(
@@ -419,7 +416,7 @@ The following codes show how to save parameters and read saved parameters to con
         for param in mnist.parameters():
             dy_param_init_value[param.name] = param.numpy()
-        mnist.load_dict(fluid.dygraph.load_persistables(mnist.state_dict(), "save_dir"))
+        mnist.load_dict(fluid.dygraph.load_persistables("save_dir"))
         restore = mnist.parameters()
         # check save and load
         success = True
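The "check save and load" step in the hunk above compares every restored parameter against the value recorded before saving. A minimal, framework-free sketch of that comparison logic, using plain NumPy arrays in place of Paddle parameters (the parameter names and values here are hypothetical, not from the document):

```python
import numpy as np

# Hypothetical parameter snapshot taken before saving (name -> value),
# standing in for `dy_param_init_value` in the document.
dy_param_init_value = {
    "fc_0.w_0": np.array([[0.1, 0.2], [0.3, 0.4]]),
    "fc_0.b_0": np.array([0.0, 0.0]),
}

# Values "restored" from a checkpoint; identical copies here, so the check passes.
restore = {name: value.copy() for name, value in dy_param_init_value.items()}

# check save and load: every restored tensor must match its saved counterpart
success = True
for name, value in restore.items():
    if not np.array_equal(value, dy_param_init_value[name]):
        success = False
print("checkpoint consistent:", success)
```

The real code iterates over `mnist.parameters()` instead of a dict, but the invariant being verified is the same: loading must reproduce exactly what was saved.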
@@ -478,7 +475,7 @@ In the second `fluid.dygraph.guard()` context we can use previously saved `check
             mnist.train()
         print("Loss at epoch {} , Test avg_loss is: {}, acc is: {}".format(epoch, test_cost, test_acc))

-        fluid.dygraph.save_persistables(mnist.state_dict(), "save_dir")
+        fluid.dygraph.save_persistables("save_dir")
         print("checkpoint saved")

     with fluid.dygraph.guard():
@@ -551,17 +548,7 @@ Take the "Handwriting Digit Recognition" in the last step for example, the same
         avg_loss = fluid.layers.mean(loss)
         sgd.minimize(avg_loss)

-        # initialize params and fetch them
-        static_param_init_value = {}
-        static_param_name_list = []
-        for param in mnist.parameters():
-            static_param_name_list.append(param.name)
-        out = exe.run(fluid.default_startup_program(),
-                      fetch_list=static_param_name_list)
-        for i in range(len(static_param_name_list)):
-            static_param_init_value[static_param_name_list[i]] = out[i]
+        out = exe.run(fluid.default_startup_program())

         for epoch in range(epoch_num):
             for batch_id, data in enumerate(train_reader()):
@@ -572,18 +559,13 @@ Take the "Handwriting Digit Recognition" in the last step for example, the same
                     [x[1] for x in data]).astype('int64').reshape([BATCH_SIZE, 1])

                 fetch_list = [avg_loss.name]
-                fetch_list.extend(static_param_name_list)
                 out = exe.run(
                     fluid.default_main_program(),
                     feed={"pixel": static_x_data,
                           "label": y_data},
                     fetch_list=fetch_list)

-                static_param_value = {}
                 static_out = out[0]
-                for i in range(1, len(out)):
-                    static_param_value[static_param_name_list[i - 1]] = out[i]
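In the static-graph hunk above, each batch's labels are cast to int64 and reshaped into a `[BATCH_SIZE, 1]` column before being fed to the executor. A standalone NumPy illustration of that transform (`BATCH_SIZE` and the sample reader output are made up for the example; images are elided as `None`):

```python
import numpy as np

BATCH_SIZE = 4
# Hypothetical reader output: a list of (image, label) pairs per batch.
data = [(None, 7), (None, 2), (None, 1), (None, 0)]

# Same transform as in the document: collect labels, cast, reshape to a column.
y_data = np.array([x[1] for x in data]).astype('int64').reshape([BATCH_SIZE, 1])
print(y_data.shape)  # (4, 1)
```

The extra trailing dimension matters because the executor's `"label"` feed expects a 2-D tensor of shape `[batch, 1]`, not a flat vector of labels.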