未验证 提交 7d1ea67d 编写于 作者: G Guo Sheng 提交者: GitHub

Merge pull request #15 from guoshengCS/fix-data-train

Reorganize data from data_loader into inputs and labels.
...@@ -29,6 +29,7 @@ from paddle.fluid.executor import global_scope ...@@ -29,6 +29,7 @@ from paddle.fluid.executor import global_scope
from paddle.fluid.io import is_belong_to_optimizer from paddle.fluid.io import is_belong_to_optimizer
from paddle.fluid.dygraph.base import to_variable from paddle.fluid.dygraph.base import to_variable
from paddle.fluid.dygraph.parallel import ParallelEnv from paddle.fluid.dygraph.parallel import ParallelEnv
from paddle.fluid.layers.utils import flatten
from paddle.fluid.incubate.fleet.collective import fleet, DistributedStrategy from paddle.fluid.incubate.fleet.collective import fleet, DistributedStrategy
from paddle.fluid.incubate.fleet.base import role_maker from paddle.fluid.incubate.fleet.base import role_maker
from paddle.fluid.io import DataLoader, Dataset from paddle.fluid.io import DataLoader, Dataset
...@@ -414,13 +415,7 @@ class StaticGraphAdapter(object): ...@@ -414,13 +415,7 @@ class StaticGraphAdapter(object):
losses = [] losses = []
metrics = [] metrics = []
with fluid.program_guard(prog, self._startup_prog): with fluid.program_guard(prog, self._startup_prog):
if isinstance(self.model._inputs, dict): ins = self.model._inputs
ins = [
self.model._inputs[n]
for n in extract_args(self.model.forward) if n != 'self'
]
else:
ins = self.model._inputs
lbls = self.model._labels if self.model._labels else [] lbls = self.model._labels if self.model._labels else []
inputs = [k.forward() for k in to_list(ins)] inputs = [k.forward() for k in to_list(ins)]
labels = [k.forward() for k in to_list(lbls)] labels = [k.forward() for k in to_list(lbls)]
...@@ -867,8 +862,10 @@ class Model(fluid.dygraph.Layer): ...@@ -867,8 +862,10 @@ class Model(fluid.dygraph.Layer):
metric.__class__.__name__) metric.__class__.__name__)
self._metrics = to_list(metrics) self._metrics = to_list(metrics)
self._inputs = inputs self._inputs = to_list(inputs) if not isinstance(inputs, dict) else [
self._labels = labels inputs[n] for n in extract_args(self.forward) if n != 'self'
]
self._labels = to_list(labels)
if not in_dygraph_mode(): if not in_dygraph_mode():
self._adapter.prepare() self._adapter.prepare()
...@@ -1174,17 +1171,30 @@ class Model(fluid.dygraph.Layer): ...@@ -1174,17 +1171,30 @@ class Model(fluid.dygraph.Layer):
callbacks.on_epoch_begin(epoch) callbacks.on_epoch_begin(epoch)
for step, data in enumerate(data_loader): for step, data in enumerate(data_loader):
if not fluid.in_dygraph_mode(): # data might come from different types of data_loader and have
data = data[0] # different format, as following:
batch_size = data[0].shape()[0] # 1. DataLoader in static graph:
else: # [[input1, input2, ..., label1, lable2, ...]]
batch_size = data[0].shape[0] # 2. DataLoader in dygraph
# [input1, input2, ..., label1, lable2, ...]
# 3. custumed iterator yield concated inputs and labels:
# [input1, input2, ..., label1, lable2, ...]
# 4. custumed iterator yield seperated inputs and labels:
# ([input1, input2, ...], [label1, lable2, ...])
# To handle all of these, flatten (nested) list to list.
data = flatten(data)
# LoDTensor.shape is callable, where LoDTensor comes from
# DataLoader in static graph
batch_size = data[0].shape()[0] if callable(data[
0].shape) else data[0].shape[0]
callbacks.on_batch_begin(mode, step, logs) callbacks.on_batch_begin(mode, step, logs)
if mode == 'train': if mode == 'train':
outs = self.train(*data) outs = self.train(data[:len(self._inputs)],
data[len(self._inputs):])
else: else:
outs = self.eval(*data) outs = self.eval(data[:len(self._inputs)],
data[len(self._inputs):])
# losses # losses
loss = outs[0] if self._metrics else outs loss = outs[0] if self._metrics else outs
......
...@@ -107,7 +107,7 @@ class ProgressBar(object): ...@@ -107,7 +107,7 @@ class ProgressBar(object):
eta = time_per_unit * (self._num - current_num) eta = time_per_unit * (self._num - current_num)
if eta > 3600: if eta > 3600:
eta_format = '%d:%02d:%02d' % (eta // 3600, (eta % 3600) // eta_format = '%d:%02d:%02d' % (eta // 3600, (eta % 3600) //
60, eta % 60) 60, eta % 60)
elif eta > 60: elif eta > 60:
eta_format = '%d:%02d' % (eta // 60, eta % 60) eta_format = '%d:%02d' % (eta // 60, eta % 60)
else: else:
...@@ -148,7 +148,7 @@ class ProgressBar(object): ...@@ -148,7 +148,7 @@ class ProgressBar(object):
else: else:
info += ' %.4e' % v info += ' %.4e' % v
elif isinstance(v, np.ndarray) and \ elif isinstance(v, np.ndarray) and \
isinstance(v.size, 1) and \ v.size == 1 and \
isinstance(v.dtype, (np.float32, np.float64)): isinstance(v.dtype, (np.float32, np.float64)):
if abs(v[0]) > 1e-3: if abs(v[0]) > 1e-3:
info += ' %.4f' % v[0] info += ' %.4f' % v[0]
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册