diff --git a/python/paddle/v2/plot_curve.py b/python/paddle/v2/plot/plot_curve.py similarity index 60% rename from python/paddle/v2/plot_curve.py rename to python/paddle/v2/plot/plot_curve.py index 178506bbebdf3edaf78647e58178a2a6640cf9a1..0f62674cb2baad9e4ecd9f6655f7e2dc00173dc6 100644 --- a/python/paddle/v2/plot_curve.py +++ b/python/paddle/v2/plot/plot_curve.py @@ -1,5 +1,5 @@ -import matplotlib.pyplot as plt from IPython import display +import os class PlotCost(object): @@ -11,18 +11,29 @@ class PlotCost(object): self.train_costs = ([], []) self.test_costs = ([], []) + self.__disable_plot__ = os.environ.get("DISABLE_PLOT") + if not self.__plot_is_disabled__(): + import matplotlib.pyplot as plt + self.plt = plt + + def __plot_is_disabled__(self): + return self.__disable_plot__ == "True" + def plot(self): - plt.plot(*self.train_costs) - plt.plot(*self.test_costs) + if self.__plot_is_disabled__(): + return + + self.plt.plot(*self.train_costs) + self.plt.plot(*self.test_costs) title = [] if len(self.train_costs[0]) > 0: title.append('Train Cost') if len(self.test_costs[0]) > 0: title.append('Test Cost') - plt.legend(title, loc='upper left') + self.plt.legend(title, loc='upper left') display.clear_output(wait=True) - display.display(plt.gcf()) - plt.gcf().clear() + display.display(self.plt.gcf()) + self.plt.gcf().clear() def append_train_cost(self, step, cost): self.train_costs[0].append(step)