提交 11903b95 编写于 作者: N Nicky Chan 提交者: daminglu

Add plot and add more iterations to draw plot in jupyter (#546)

上级 54ace206
...@@ -405,6 +405,11 @@ For example, we can check the cost by `trainer.test` when `EndStepEvent` occurs ...@@ -405,6 +405,11 @@ For example, we can check the cost by `trainer.test` when `EndStepEvent` occurs
# Specify the directory path to save the parameters # Specify the directory path to save the parameters
params_dirname = "recommender_system.inference.model" params_dirname = "recommender_system.inference.model"
from paddle.v2.plot import Ploter
test_title = "Test cost"
plot_cost = Ploter(test_title)
def event_handler(event): def event_handler(event):
if isinstance(event, fluid.EndStepEvent): if isinstance(event, fluid.EndStepEvent):
avg_cost_set = trainer.test( avg_cost_set = trainer.test(
...@@ -412,12 +417,15 @@ def event_handler(event): ...@@ -412,12 +417,15 @@ def event_handler(event):
# get avg cost # get avg cost
avg_cost = np.array(avg_cost_set).mean() avg_cost = np.array(avg_cost_set).mean()
plot_cost.append(test_title, event.step, avg_cost_set[0])
plot_cost.plot()
print("avg_cost: %s" % avg_cost) print("avg_cost: %s" % avg_cost)
print('BatchID {0}, Test Loss {1:0.2}'.format(event.epoch + 1, print('BatchID {0}, Test Loss {1:0.2}'.format(event.epoch + 1,
float(avg_cost))) float(avg_cost)))
if float(avg_cost) < 4: if event.step == 20: # Adjust this number for accuracy
trainer.save_params(params_dirname) trainer.save_params(params_dirname)
trainer.stop() trainer.stop()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册