From 3d5b8e614895eebd2a2153383b6abaed793e2434 Mon Sep 17 00:00:00 2001 From: minqiyang Date: Mon, 8 Oct 2018 15:56:46 +0800 Subject: [PATCH] Polish code --- 01.fit_a_line/index.cn.html | 76 +------------------------------------ 01.fit_a_line/index.html | 76 +------------------------------------ 2 files changed, 4 insertions(+), 148 deletions(-) diff --git a/01.fit_a_line/index.cn.html b/01.fit_a_line/index.cn.html index c6ad4e2..4d33d00 100644 --- a/01.fit_a_line/index.cn.html +++ b/01.fit_a_line/index.cn.html @@ -219,80 +219,6 @@ PaddlePaddle提供了读取数据者发生器机制来读取训练数据。读 feed_order=['x', 'y'] ``` -以及一个绘画器来进行绘制: - -```python -import six -import os - - -class PlotData(object): - def __init__(self): - self.step = [] - self.value = [] - - def append(self, step, value): - self.step.append(step) - self.value.append(value) - - def reset(self): - self.step = [] - self.value = [] - - -class Ploter(object): - def __init__(self, *args): - self.__args__ = args - self.__plot_data__ = {} - for title in args: - self.__plot_data__[title] = PlotData() - # demo in notebooks will use Ploter to plot figure, but when we convert - # the ipydb to py file for testing, the import of matplotlib will make the - # script crash. So we can use `export DISABLE_PLOT=True` to disable import - # these libs - self.__disable_plot__ = os.environ.get("DISABLE_PLOT") - if not self.__plot_is_disabled__(): - import matplotlib.pyplot as plt - from IPython import display - self.plt = plt - self.display = display - - def __plot_is_disabled__(self): - return self.__disable_plot__ == "True" - - def append(self, title, step, value): - assert isinstance(title, six.string_types) - assert title in self.__plot_data__ - data = self.__plot_data__[title] - assert isinstance(data, PlotData) - data.append(step, value) - - def plot(self, path=None): - if self.__plot_is_disabled__(): - return - - titles = [] - for title in self.__args__: - data = self.__plot_data__[title] - assert isinstance(data, PlotData) - if len(data.step) > 0: - titles.append(title) - self.plt.plot(data.step, data.value) - self.plt.legend(titles, loc='upper left') - if path is None: - self.display.clear_output(wait=True) - self.display.display(self.plt.gcf()) - else: - self.plt.savefig(path) - self.plt.gcf().clear() - - def reset(self): - for key in self.__plot_data__: - data = self.__plot_data__[key] - assert isinstance(data, PlotData) - data.reset() -``` - 除此之外,可以定义一个事件响应器来处理类似`打印训练进程`的事件: ```python @@ -300,6 +226,8 @@ class Ploter(object): params_dirname = "fit_a_line.inference.model" # Plot data +from paddle.utils import Ploter + train_title = "Train cost" test_title = "Test cost" plot_cost = Ploter(train_title, test_title) diff --git a/01.fit_a_line/index.html b/01.fit_a_line/index.html index d17a3c2..587c728 100644 --- a/01.fit_a_line/index.html +++ b/01.fit_a_line/index.html @@ -238,80 +238,6 @@ for loading the training data. A reader may return multiple columns, and we need feed_order=['x', 'y'] ``` -And a ploter to plot metrics: - -```python -import six -import os - - -class PlotData(object): - def __init__(self): - self.step = [] - self.value = [] - - def append(self, step, value): - self.step.append(step) - self.value.append(value) - - def reset(self): - self.step = [] - self.value = [] - - -class Ploter(object): - def __init__(self, *args): - self.__args__ = args - self.__plot_data__ = {} - for title in args: - self.__plot_data__[title] = PlotData() - # demo in notebooks will use Ploter to plot figure, but when we convert - # the ipydb to py file for testing, the import of matplotlib will make the - # script crash. So we can use `export DISABLE_PLOT=True` to disable import - # these libs - self.__disable_plot__ = os.environ.get("DISABLE_PLOT") - if not self.__plot_is_disabled__(): - import matplotlib.pyplot as plt - from IPython import display - self.plt = plt - self.display = display - - def __plot_is_disabled__(self): - return self.__disable_plot__ == "True" - - def append(self, title, step, value): - assert isinstance(title, six.string_types) - assert title in self.__plot_data__ - data = self.__plot_data__[title] - assert isinstance(data, PlotData) - data.append(step, value) - - def plot(self, path=None): - if self.__plot_is_disabled__(): - return - - titles = [] - for title in self.__args__: - data = self.__plot_data__[title] - assert isinstance(data, PlotData) - if len(data.step) > 0: - titles.append(title) - self.plt.plot(data.step, data.value) - self.plt.legend(titles, loc='upper left') - if path is None: - self.display.clear_output(wait=True) - self.display.display(self.plt.gcf()) - else: - self.plt.savefig(path) - self.plt.gcf().clear() - - def reset(self): - for key in self.__plot_data__: - data = self.__plot_data__[key] - assert isinstance(data, PlotData) - data.reset() -``` - Moreover, an event handler is provided to print the training progress: ```python @@ -319,6 +245,8 @@ Moreover, an event handler is provided to print the training progress: params_dirname = "fit_a_line.inference.model" # Plot data +from paddle.utils import Ploter + train_title = "Train cost" test_title = "Test cost" plot_cost = Ploter(train_title, test_title) -- GitLab