From fb2b54f4fed25a2e1d335bee14af77a62b23fac3 Mon Sep 17 00:00:00 2001 From: qiaolongfei Date: Sat, 1 Apr 2017 00:16:02 +0800 Subject: [PATCH] refine plot_curve --- python/paddle/v2/config_base.py | 3 +- python/paddle/v2/plot/__init__.py | 3 ++ python/paddle/v2/plot/plot.py | 65 +++++++++++++++++++++++++++++ python/paddle/v2/plot/plot_curve.py | 48 --------------------- 4 files changed, 70 insertions(+), 49 deletions(-) create mode 100644 python/paddle/v2/plot/__init__.py create mode 100644 python/paddle/v2/plot/plot.py delete mode 100644 python/paddle/v2/plot/plot_curve.py diff --git a/python/paddle/v2/config_base.py b/python/paddle/v2/config_base.py index 1ec1d7bbdf..8e2795d20f 100644 --- a/python/paddle/v2/config_base.py +++ b/python/paddle/v2/config_base.py @@ -14,8 +14,9 @@ import collections import re -from paddle.trainer_config_helpers.default_decorators import wrap_name_default + import paddle.trainer_config_helpers as conf_helps +from paddle.trainer_config_helpers.default_decorators import wrap_name_default class LayerType(type): diff --git a/python/paddle/v2/plot/__init__.py b/python/paddle/v2/plot/__init__.py new file mode 100644 index 0000000000..324f28bb65 --- /dev/null +++ b/python/paddle/v2/plot/__init__.py @@ -0,0 +1,3 @@ +from plot import Plot + +__all__ = ['Plot'] diff --git a/python/paddle/v2/plot/plot.py b/python/paddle/v2/plot/plot.py new file mode 100644 index 0000000000..1c56294368 --- /dev/null +++ b/python/paddle/v2/plot/plot.py @@ -0,0 +1,65 @@ +from IPython import display +import os + + +class PlotData(object): + def __init__(self): + self.step = [] + self.value = [] + + def append(self, step, value): + self.step.append(step) + self.value.append(value) + + def reset(self): + self.step = [] + self.value = [] + + +class Plot(object): + def __init__(self, *args): + self.args = args + self.__plot_data__ = {} + for title in args: + self.__plot_data__[title] = PlotData() + self.__disable_plot__ = os.environ.get("DISABLE_PLOT") + if not self.__plot_is_disabled__(): + import matplotlib.pyplot as plt + self.plt = plt + + def __plot_is_disabled__(self): + return self.__disable_plot__ == "True" + + def append(self, title, step, value): + assert isinstance(title, basestring) + assert self.__plot_data__.has_key(title) + data = self.__plot_data__[title] + assert isinstance(data, PlotData) + data.append(step, value) + + def plot(self): + if self.__plot_is_disabled__(): + return + + titles = [] + for title in self.args: + data = self.__plot_data__[title] + assert isinstance(data, PlotData) + if len(data.step) > 0: + titles.append(title) + self.plt.plot(data.step, data.value) + self.plt.legend(titles, loc='upper left') + display.clear_output(wait=True) + display.display(self.plt.gcf()) + self.plt.gcf().clear() + + def reset(self): + self.__plot_data__ = [] + +if __name__ == '__main__': + title = "cost" + plot_test = Plot(title) + plot_test.append(title, 1, 1) + plot_test.append(title, 2, 2) + for k, v in plot_test.__plot_data__.iteritems(): + print k, v.step, v.value diff --git a/python/paddle/v2/plot/plot_curve.py b/python/paddle/v2/plot/plot_curve.py deleted file mode 100644 index 0f62674cb2..0000000000 --- a/python/paddle/v2/plot/plot_curve.py +++ /dev/null @@ -1,48 +0,0 @@ -from IPython import display -import os - - -class PlotCost(object): - """ - append train and test cost in event_handle and then call plot. - """ - - def __init__(self): - self.train_costs = ([], []) - self.test_costs = ([], []) - - self.__disable_plot__ = os.environ.get("DISABLE_PLOT") - if not self.__plot_is_disabled__(): - import matplotlib.pyplot as plt - self.plt = plt - - def __plot_is_disabled__(self): - return self.__disable_plot__ == "True" - - def plot(self): - if self.__plot_is_disabled__(): - return - - self.plt.plot(*self.train_costs) - self.plt.plot(*self.test_costs) - title = [] - if len(self.train_costs[0]) > 0: - title.append('Train Cost') - if len(self.test_costs[0]) > 0: - title.append('Test Cost') - self.plt.legend(title, loc='upper left') - display.clear_output(wait=True) - display.display(self.plt.gcf()) - self.plt.gcf().clear() - - def append_train_cost(self, step, cost): - self.train_costs[0].append(step) - self.train_costs[1].append(cost) - - def append_test_cost(self, step, cost): - self.test_costs[0].append(step) - self.test_costs[1].append(cost) - - def reset(self): - self.train_costs = ([], []) - self.test_costs = ([], []) -- GitLab