diff --git a/python/paddle/v2/plot/plot.py b/python/paddle/v2/plot/plot.py index 6f7bd039b07db4832295c2374293bffa588eb4ef..c18e63dd5f60481ba804738a6a9238dfea35d9f3 100644 --- a/python/paddle/v2/plot/plot.py +++ b/python/paddle/v2/plot/plot.py @@ -56,7 +56,7 @@ class Ploter(object): assert isinstance(data, PlotData) data.append(step, value) - def plot(self): + def plot(self, path=None): if self.__plot_is_disabled__(): return @@ -68,8 +68,11 @@ class Ploter(object): titles.append(title) self.plt.plot(data.step, data.value) self.plt.legend(titles, loc='upper left') - self.display.clear_output(wait=True) - self.display.display(self.plt.gcf()) + if path is None: + self.display.clear_output(wait=True) + self.display.display(self.plt.gcf()) + else: + self.plt.savefig(path) self.plt.gcf().clear() def reset(self):