From 226ca797ce07739fcd21b422370f932eb364ad33 Mon Sep 17 00:00:00 2001 From: qiaolongfei Date: Tue, 28 Mar 2017 13:21:38 +0800 Subject: [PATCH] add plot cost in v2 api --- python/paddle/v2/__init__.py | 3 ++- python/paddle/v2/plot_curve.py | 36 ++++++++++++++++++++++++++++++++++ 2 files changed, 38 insertions(+), 1 deletion(-) create mode 100644 python/paddle/v2/plot_curve.py diff --git a/python/paddle/v2/__init__.py b/python/paddle/v2/__init__.py index 25526bf409c..848719cddda 100644 --- a/python/paddle/v2/__init__.py +++ b/python/paddle/v2/__init__.py @@ -29,11 +29,12 @@ import inference import networks import py_paddle.swig_paddle as api import minibatch +import plot_curve __all__ = [ 'optimizer', 'layer', 'activation', 'parameters', 'init', 'trainer', 'event', 'data_type', 'attr', 'pooling', 'data_feeder', 'dataset', 'reader', - 'topology', 'networks', 'infer' + 'topology', 'networks', 'infer', 'plot_curve' ] diff --git a/python/paddle/v2/plot_curve.py b/python/paddle/v2/plot_curve.py new file mode 100644 index 00000000000..9d24a87442b --- /dev/null +++ b/python/paddle/v2/plot_curve.py @@ -0,0 +1,36 @@ +import matplotlib.pyplot as plt +from IPython import display + + +class PlotCost(object): + """ + append train and test cost in event_handle and then call plot. + """ + def __init__(self): + self.train_costs = ([], []) + self.test_costs = ([], []) + + def plot(self): + plt.plot(*self.train_costs) + plt.plot(*self.test_costs) + title = [] + if len(self.train_costs[0]) > 0: + title.append('Train Cost') + if len(self.test_costs[0]) > 0: + title.append('Test Cost') + plt.legend(title, loc='upper left') + display.clear_output(wait=True) + display.display(plt.gcf()) + plt.gcf().clear() + + def append_train_cost(self, step, cost): + self.train_costs[0].append(step) + self.train_costs[1].append(cost) + + def append_test_cost(self, step, cost): + self.test_costs[0].append(step) + self.test_costs[1].append(cost) + + def reset(self): + self.train_costs = ([], []) + self.test_costs = ([], []) -- GitLab