提交 3d5b8e61 编写于 作者: M minqiyang

Polish code

上级 19320b31
...@@ -219,80 +219,6 @@ PaddlePaddle提供了读取数据者发生器机制来读取训练数据。读 ...@@ -219,80 +219,6 @@ PaddlePaddle提供了读取数据者发生器机制来读取训练数据。读
feed_order=['x', 'y'] feed_order=['x', 'y']
``` ```
以及一个绘画器来进行绘制:
```python
import six
import os
class PlotData(object):
def __init__(self):
self.step = []
self.value = []
def append(self, step, value):
self.step.append(step)
self.value.append(value)
def reset(self):
self.step = []
self.value = []
class Ploter(object):
def __init__(self, *args):
self.__args__ = args
self.__plot_data__ = {}
for title in args:
self.__plot_data__[title] = PlotData()
# demo in notebooks will use Ploter to plot figure, but when we convert
# the ipydb to py file for testing, the import of matplotlib will make the
# script crash. So we can use `export DISABLE_PLOT=True` to disable import
# these libs
self.__disable_plot__ = os.environ.get("DISABLE_PLOT")
if not self.__plot_is_disabled__():
import matplotlib.pyplot as plt
from IPython import display
self.plt = plt
self.display = display
def __plot_is_disabled__(self):
return self.__disable_plot__ == "True"
def append(self, title, step, value):
assert isinstance(title, six.string_types)
assert title in self.__plot_data__
data = self.__plot_data__[title]
assert isinstance(data, PlotData)
data.append(step, value)
def plot(self, path=None):
if self.__plot_is_disabled__():
return
titles = []
for title in self.__args__:
data = self.__plot_data__[title]
assert isinstance(data, PlotData)
if len(data.step) > 0:
titles.append(title)
self.plt.plot(data.step, data.value)
self.plt.legend(titles, loc='upper left')
if path is None:
self.display.clear_output(wait=True)
self.display.display(self.plt.gcf())
else:
self.plt.savefig(path)
self.plt.gcf().clear()
def reset(self):
for key in self.__plot_data__:
data = self.__plot_data__[key]
assert isinstance(data, PlotData)
data.reset()
```
除此之外,可以定义一个事件响应器来处理类似`打印训练进程`的事件: 除此之外,可以定义一个事件响应器来处理类似`打印训练进程`的事件:
```python ```python
...@@ -300,6 +226,8 @@ class Ploter(object): ...@@ -300,6 +226,8 @@ class Ploter(object):
params_dirname = "fit_a_line.inference.model" params_dirname = "fit_a_line.inference.model"
# Plot data # Plot data
from paddle.utils import Ploter
train_title = "Train cost" train_title = "Train cost"
test_title = "Test cost" test_title = "Test cost"
plot_cost = Ploter(train_title, test_title) plot_cost = Ploter(train_title, test_title)
......
...@@ -238,80 +238,6 @@ for loading the training data. A reader may return multiple columns, and we need ...@@ -238,80 +238,6 @@ for loading the training data. A reader may return multiple columns, and we need
feed_order=['x', 'y'] feed_order=['x', 'y']
``` ```
And a ploter to plot metrics:
```python
import six
import os
class PlotData(object):
def __init__(self):
self.step = []
self.value = []
def append(self, step, value):
self.step.append(step)
self.value.append(value)
def reset(self):
self.step = []
self.value = []
class Ploter(object):
def __init__(self, *args):
self.__args__ = args
self.__plot_data__ = {}
for title in args:
self.__plot_data__[title] = PlotData()
# demo in notebooks will use Ploter to plot figure, but when we convert
# the ipydb to py file for testing, the import of matplotlib will make the
# script crash. So we can use `export DISABLE_PLOT=True` to disable import
# these libs
self.__disable_plot__ = os.environ.get("DISABLE_PLOT")
if not self.__plot_is_disabled__():
import matplotlib.pyplot as plt
from IPython import display
self.plt = plt
self.display = display
def __plot_is_disabled__(self):
return self.__disable_plot__ == "True"
def append(self, title, step, value):
assert isinstance(title, six.string_types)
assert title in self.__plot_data__
data = self.__plot_data__[title]
assert isinstance(data, PlotData)
data.append(step, value)
def plot(self, path=None):
if self.__plot_is_disabled__():
return
titles = []
for title in self.__args__:
data = self.__plot_data__[title]
assert isinstance(data, PlotData)
if len(data.step) > 0:
titles.append(title)
self.plt.plot(data.step, data.value)
self.plt.legend(titles, loc='upper left')
if path is None:
self.display.clear_output(wait=True)
self.display.display(self.plt.gcf())
else:
self.plt.savefig(path)
self.plt.gcf().clear()
def reset(self):
for key in self.__plot_data__:
data = self.__plot_data__[key]
assert isinstance(data, PlotData)
data.reset()
```
Moreover, an event handler is provided to print the training progress: Moreover, an event handler is provided to print the training progress:
```python ```python
...@@ -319,6 +245,8 @@ Moreover, an event handler is provided to print the training progress: ...@@ -319,6 +245,8 @@ Moreover, an event handler is provided to print the training progress:
params_dirname = "fit_a_line.inference.model" params_dirname = "fit_a_line.inference.model"
# Plot data # Plot data
from paddle.utils import Ploter
train_title = "Train cost" train_title = "Train cost"
test_title = "Test cost" test_title = "Test cost"
plot_cost = Ploter(train_title, test_title) plot_cost = Ploter(train_title, test_title)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册