diff --git a/01.fit_a_line/README.cn.md b/01.fit_a_line/README.cn.md index d4b1b9412fe95a50c756bbdfab5259df3a195004..9f29b5a2d398920529be00c077fe8ef3b1980cd0 100644 --- a/01.fit_a_line/README.cn.md +++ b/01.fit_a_line/README.cn.md @@ -177,80 +177,6 @@ PaddlePaddle提供了读取数据者发生器机制来读取训练数据。读 feed_order=['x', 'y'] ``` -以及一个绘画器来进行绘制: - -```python -import six -import os - - -class PlotData(object): - def __init__(self): - self.step = [] - self.value = [] - - def append(self, step, value): - self.step.append(step) - self.value.append(value) - - def reset(self): - self.step = [] - self.value = [] - - -class Ploter(object): - def __init__(self, *args): - self.__args__ = args - self.__plot_data__ = {} - for title in args: - self.__plot_data__[title] = PlotData() - # demo in notebooks will use Ploter to plot figure, but when we convert - # the ipydb to py file for testing, the import of matplotlib will make the - # script crash. So we can use `export DISABLE_PLOT=True` to disable import - # these libs - self.__disable_plot__ = os.environ.get("DISABLE_PLOT") - if not self.__plot_is_disabled__(): - import matplotlib.pyplot as plt - from IPython import display - self.plt = plt - self.display = display - - def __plot_is_disabled__(self): - return self.__disable_plot__ == "True" - - def append(self, title, step, value): - assert isinstance(title, six.string_types) - assert title in self.__plot_data__ - data = self.__plot_data__[title] - assert isinstance(data, PlotData) - data.append(step, value) - - def plot(self, path=None): - if self.__plot_is_disabled__(): - return - - titles = [] - for title in self.__args__: - data = self.__plot_data__[title] - assert isinstance(data, PlotData) - if len(data.step) > 0: - titles.append(title) - self.plt.plot(data.step, data.value) - self.plt.legend(titles, loc='upper left') - if path is None: - self.display.clear_output(wait=True) - self.display.display(self.plt.gcf()) - else: - self.plt.savefig(path) - self.plt.gcf().clear() - - def reset(self): - for key in self.__plot_data__: - data = self.__plot_data__[key] - assert isinstance(data, PlotData) - data.reset() -``` - 除此之外,可以定义一个事件响应器来处理类似`打印训练进程`的事件: ```python @@ -258,6 +184,8 @@ class Ploter(object): params_dirname = "fit_a_line.inference.model" # Plot data +from paddle.utils import Ploter + train_title = "Train cost" test_title = "Test cost" plot_cost = Ploter(train_title, test_title) diff --git a/01.fit_a_line/README.md b/01.fit_a_line/README.md index dd2fc04728beb11edfa23cac224781abb969696a..e1f05860fc64f581d19dcee3869b7fd2bf623b64 100644 --- a/01.fit_a_line/README.md +++ b/01.fit_a_line/README.md @@ -196,80 +196,6 @@ for loading the training data. A reader may return multiple columns, and we need feed_order=['x', 'y'] ``` -And a ploter to plot metrics: - -```python -import six -import os - - -class PlotData(object): - def __init__(self): - self.step = [] - self.value = [] - - def append(self, step, value): - self.step.append(step) - self.value.append(value) - - def reset(self): - self.step = [] - self.value = [] - - -class Ploter(object): - def __init__(self, *args): - self.__args__ = args - self.__plot_data__ = {} - for title in args: - self.__plot_data__[title] = PlotData() - # demo in notebooks will use Ploter to plot figure, but when we convert - # the ipydb to py file for testing, the import of matplotlib will make the - # script crash. So we can use `export DISABLE_PLOT=True` to disable import - # these libs - self.__disable_plot__ = os.environ.get("DISABLE_PLOT") - if not self.__plot_is_disabled__(): - import matplotlib.pyplot as plt - from IPython import display - self.plt = plt - self.display = display - - def __plot_is_disabled__(self): - return self.__disable_plot__ == "True" - - def append(self, title, step, value): - assert isinstance(title, six.string_types) - assert title in self.__plot_data__ - data = self.__plot_data__[title] - assert isinstance(data, PlotData) - data.append(step, value) - - def plot(self, path=None): - if self.__plot_is_disabled__(): - return - - titles = [] - for title in self.__args__: - data = self.__plot_data__[title] - assert isinstance(data, PlotData) - if len(data.step) > 0: - titles.append(title) - self.plt.plot(data.step, data.value) - self.plt.legend(titles, loc='upper left') - if path is None: - self.display.clear_output(wait=True) - self.display.display(self.plt.gcf()) - else: - self.plt.savefig(path) - self.plt.gcf().clear() - - def reset(self): - for key in self.__plot_data__: - data = self.__plot_data__[key] - assert isinstance(data, PlotData) - data.reset() -``` - Moreover, an event handler is provided to print the training progress: ```python @@ -277,6 +203,8 @@ Moreover, an event handler is provided to print the training progress: params_dirname = "fit_a_line.inference.model" # Plot data +from paddle.utils import Ploter + train_title = "Train cost" test_title = "Test cost" plot_cost = Ploter(train_title, test_title) diff --git a/01.fit_a_line/plot.py b/01.fit_a_line/plot.py deleted file mode 100644 index 847f7fa697a6811a8ecc6a0105864dc0e4d0e435..0000000000000000000000000000000000000000 --- a/01.fit_a_line/plot.py +++ /dev/null @@ -1,83 +0,0 @@ -# Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import six -import os - - -class PlotData(object): - def __init__(self): - self.step = [] - self.value = [] - - def append(self, step, value): - self.step.append(step) - self.value.append(value) - - def reset(self): - self.step = [] - self.value = [] - - -class Ploter(object): - def __init__(self, *args): - self.__args__ = args - self.__plot_data__ = {} - for title in args: - self.__plot_data__[title] = PlotData() - # demo in notebooks will use Ploter to plot figure, but when we convert - # the ipydb to py file for testing, the import of matplotlib will make the - # script crash. So we can use `export DISABLE_PLOT=True` to disable import - # these libs - self.__disable_plot__ = os.environ.get("DISABLE_PLOT") - if not self.__plot_is_disabled__(): - import matplotlib.pyplot as plt - from IPython import display - self.plt = plt - self.display = display - - def __plot_is_disabled__(self): - return self.__disable_plot__ == "True" - - def append(self, title, step, value): - assert isinstance(title, six.string_types) - assert title in self.__plot_data__ - data = self.__plot_data__[title] - assert isinstance(data, PlotData) - data.append(step, value) - - def plot(self, path=None): - if self.__plot_is_disabled__(): - return - - titles = [] - for title in self.__args__: - data = self.__plot_data__[title] - assert isinstance(data, PlotData) - if len(data.step) > 0: - titles.append(title) - self.plt.plot(data.step, data.value) - self.plt.legend(titles, loc='upper left') - if path is None: - self.display.clear_output(wait=True) - self.display.display(self.plt.gcf()) - else: - self.plt.savefig(path) - self.plt.gcf().clear() - - def reset(self): - for key in self.__plot_data__: - data = self.__plot_data__[key] - assert isinstance(data, PlotData) - data.reset() diff --git a/01.fit_a_line/train.py b/01.fit_a_line/train.py index 7953eeeedbedab1e4d6488bb6044ffe485fb4f19..2503921089f35bdccc1200d8c8460c4ebdd35cb2 100644 --- a/01.fit_a_line/train.py +++ b/01.fit_a_line/train.py @@ -70,7 +70,7 @@ feed_order = ['x', 'y'] params_dirname = "fit_a_line.inference.model" # Plot data -from plot import Ploter +from paddle.utils import Ploter train_title = "Train cost" test_title = "Test cost"