未验证 提交 7d1ea67d 编写于 作者: G Guo Sheng 提交者: GitHub

Merge pull request #15 from guoshengCS/fix-data-train

Reorganize data from data_loader into inputs and labels.
......@@ -29,6 +29,7 @@ from paddle.fluid.executor import global_scope
from paddle.fluid.io import is_belong_to_optimizer
from paddle.fluid.dygraph.base import to_variable
from paddle.fluid.dygraph.parallel import ParallelEnv
from paddle.fluid.layers.utils import flatten
from paddle.fluid.incubate.fleet.collective import fleet, DistributedStrategy
from paddle.fluid.incubate.fleet.base import role_maker
from paddle.fluid.io import DataLoader, Dataset
......@@ -414,13 +415,7 @@ class StaticGraphAdapter(object):
losses = []
metrics = []
with fluid.program_guard(prog, self._startup_prog):
if isinstance(self.model._inputs, dict):
ins = [
self.model._inputs[n]
for n in extract_args(self.model.forward) if n != 'self'
]
else:
ins = self.model._inputs
ins = self.model._inputs
lbls = self.model._labels if self.model._labels else []
inputs = [k.forward() for k in to_list(ins)]
labels = [k.forward() for k in to_list(lbls)]
......@@ -867,8 +862,10 @@ class Model(fluid.dygraph.Layer):
metric.__class__.__name__)
self._metrics = to_list(metrics)
self._inputs = inputs
self._labels = labels
self._inputs = to_list(inputs) if not isinstance(inputs, dict) else [
inputs[n] for n in extract_args(self.forward) if n != 'self'
]
self._labels = to_list(labels)
if not in_dygraph_mode():
self._adapter.prepare()
......@@ -1174,17 +1171,30 @@ class Model(fluid.dygraph.Layer):
callbacks.on_epoch_begin(epoch)
for step, data in enumerate(data_loader):
if not fluid.in_dygraph_mode():
data = data[0]
batch_size = data[0].shape()[0]
else:
batch_size = data[0].shape[0]
# data might come from different types of data_loader and have
# different format, as following:
# 1. DataLoader in static graph:
# [[input1, input2, ..., label1, lable2, ...]]
# 2. DataLoader in dygraph
# [input1, input2, ..., label1, lable2, ...]
# 3. custumed iterator yield concated inputs and labels:
# [input1, input2, ..., label1, lable2, ...]
# 4. custumed iterator yield seperated inputs and labels:
# ([input1, input2, ...], [label1, lable2, ...])
# To handle all of these, flatten (nested) list to list.
data = flatten(data)
# LoDTensor.shape is callable, where LoDTensor comes from
# DataLoader in static graph
batch_size = data[0].shape()[0] if callable(data[
0].shape) else data[0].shape[0]
callbacks.on_batch_begin(mode, step, logs)
if mode == 'train':
outs = self.train(*data)
outs = self.train(data[:len(self._inputs)],
data[len(self._inputs):])
else:
outs = self.eval(*data)
outs = self.eval(data[:len(self._inputs)],
data[len(self._inputs):])
# losses
loss = outs[0] if self._metrics else outs
......
......@@ -107,7 +107,7 @@ class ProgressBar(object):
eta = time_per_unit * (self._num - current_num)
if eta > 3600:
eta_format = '%d:%02d:%02d' % (eta // 3600, (eta % 3600) //
60, eta % 60)
60, eta % 60)
elif eta > 60:
eta_format = '%d:%02d' % (eta // 60, eta % 60)
else:
......@@ -148,7 +148,7 @@ class ProgressBar(object):
else:
info += ' %.4e' % v
elif isinstance(v, np.ndarray) and \
isinstance(v.size, 1) and \
v.size == 1 and \
isinstance(v.dtype, (np.float32, np.float64)):
if abs(v[0]) > 1e-3:
info += ' %.4f' % v[0]
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册