diff --git a/python/paddle/v2/__init__.py b/python/paddle/v2/__init__.py index 25526bf409cf82f26979a84700ce948ac969df0c..848719cddda81b5e19c2176172ce8bdfa23f6e16 100644 --- a/python/paddle/v2/__init__.py +++ b/python/paddle/v2/__init__.py @@ -29,11 +29,12 @@ import inference import networks import py_paddle.swig_paddle as api import minibatch +import plot_curve __all__ = [ 'optimizer', 'layer', 'activation', 'parameters', 'init', 'trainer', 'event', 'data_type', 'attr', 'pooling', 'data_feeder', 'dataset', 'reader', - 'topology', 'networks', 'infer' + 'topology', 'networks', 'infer', 'plot_curve' ] diff --git a/python/paddle/v2/plot_curve.py b/python/paddle/v2/plot_curve.py new file mode 100644 index 0000000000000000000000000000000000000000..9d24a87442b415f61de74d8f2444e06bf817c244 --- /dev/null +++ b/python/paddle/v2/plot_curve.py @@ -0,0 +1,36 @@ +import matplotlib.pyplot as plt +from IPython import display + + +class PlotCost(object): + """ + append train and test cost in event_handle and then call plot. + """ + def __init__(self): + self.train_costs = ([], []) + self.test_costs = ([], []) + + def plot(self): + plt.plot(*self.train_costs) + plt.plot(*self.test_costs) + title = [] + if len(self.train_costs[0]) > 0: + title.append('Train Cost') + if len(self.test_costs[0]) > 0: + title.append('Test Cost') + plt.legend(title, loc='upper left') + display.clear_output(wait=True) + display.display(plt.gcf()) + plt.gcf().clear() + + def append_train_cost(self, step, cost): + self.train_costs[0].append(step) + self.train_costs[1].append(cost) + + def append_test_cost(self, step, cost): + self.test_costs[0].append(step) + self.test_costs[1].append(cost) + + def reset(self): + self.train_costs = ([], []) + self.test_costs = ([], [])