Commit a08ac369 authored by: Q qingqing01

Fix multi-loss

Parent 59b986c7
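In effect, a Loss subclass can now return a list of losses and Model will track and log each entry. A hedged usage sketch assembled only from pieces that appear in this diff (the MLP network and MyCrossEntropy loss are the ones added to the test file below; optimizer settings mirror that test and are illustrative):

from paddle import fluid
from model import Input, Loss
from metrics import Accuracy


class MyCrossEntropy(Loss):
    # One cross-entropy term per network output; returning a list is the
    # multi-loss case this commit fixes.
    def forward(self, outputs, labels):
        loss1 = fluid.layers.cross_entropy(outputs[0], labels[0])
        loss2 = fluid.layers.cross_entropy(outputs[1], labels[0])
        return [loss1, loss2]


# MLP is the two-output network defined in the test file further down.
inputs = [Input((-1, 784), 'float32', name='image')]
labels = [Input([None, 1], 'int64', name='label')]
model = MLP()
optim = fluid.optimizer.Momentum(
    learning_rate=0.01, momentum=0.9, parameter_list=model.parameters())
model.prepare(optim, MyCrossEntropy(), Accuracy(), inputs, labels)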
@@ -18,7 +18,6 @@ import inspect
 import os
 import pickle
 import numpy as np
-import itertools
 from collections import Iterable
 from collections import OrderedDict
@@ -742,14 +741,17 @@ class Model(fluid.dygraph.Layer):
             else:
                 outs = self.eval(*data)
-            metrics = list(itertools.chain.from_iterable(outs))
-            metrics = [np.mean(metrics[0])]
+            # losses
+            loss = outs[0] if self._metrics else outs
+            metrics = [[l[0] for l in loss]]
+
+            # metrics
             for metric in self._metrics:
                 res = metric.accumulate()
                 metrics.extend(to_list(res))
             assert len(metrics_name) == len(metrics)
             for k, v in zip(metrics_name, metrics):
-                logs[k] = np.mean(v)
+                logs[k] = v
             logs['step'] = step
             logs['batch_size'] = data[0].shape[0]
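For orientation, a small hedged sketch of what the new bookkeeping produces (values invented): with metrics attached, outs is [losses, ...metric outputs], each loss arrives as a length-1 numpy array, and [[l[0] for l in loss]] keeps one scalar per loss so a single metrics_name entry (e.g. 'loss') maps to the whole list.

import numpy as np

# Invented example values: two losses returned by a multi-loss Loss.forward,
# followed by whatever the attached metric produced for this batch.
outs = [[np.array([2.31]), np.array([2.05])], np.array([0.12])]
metrics_attached = [object()]  # stand-in for self._metrics being non-empty

loss = outs[0] if metrics_attached else outs   # the list of per-loss arrays
metrics = [[l[0] for l in loss]]               # -> [[2.31, 2.05]]
# zip(metrics_name, metrics) then logs e.g. logs['loss'] = [2.31, 2.05],
# which the updated ProgressBar below prints element by element.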
@@ -761,7 +763,7 @@ class Model(fluid.dygraph.Layer):
         cbks.on_begin('train')
         for epoch in range(epochs):
             cbks.on_epoch_begin(epoch)
-            # FIXME: adapte to DataLoader
+            # FIXME: adapt to DataLoader
             loader = train_loader
             if not isinstance(train_loader, Iterable):
                 loader = train_loader()
@@ -770,7 +772,7 @@ class Model(fluid.dygraph.Layer):
             if do_eval and epoch % eval_freq == 0:
                 cbks.on_begin('eval', logs)
-                # FIXME: adapte to DataLoader
+                # FIXME: adapt to DataLoader
                 loader = eval_loader
                 if not isinstance(eval_loader, Iterable):
                     loader = eval_loader()
...
@@ -91,15 +91,17 @@ class ProgressBar(object):
             self._total_width = len(bar_chars)
             sys.stdout.write(bar_chars)
-            for k, v in values:
+            for k, val in values:
                 info += ' - %s:' % k
-                if isinstance(v, (float, np.float32, np.float64)):
-                    if abs(v) > 1e-3:
-                        info += ' %.4f' % v
-                    else:
-                        info += ' %.4e' % v
-                else:
-                    info += ' %s' % v
+                val = val if isinstance(val, list) else [val]
+                for i, v in enumerate(val):
+                    if isinstance(v, (float, np.float32, np.float64)):
+                        if abs(v) > 1e-3:
+                            info += ' %.4f' % v
+                        else:
+                            info += ' %.4e' % v
+                    else:
+                        info += ' %s' % v

             if self._num is not None and current_num < self._num:
                 eta = time_per_unit * (self._num - current_num)
@@ -136,22 +138,24 @@ class ProgressBar(object):
             count = 'step %3d' % current_num
             info = count + info
-            for k, v in values:
+            for k, val in values:
                 info += ' - %s:' % k
-                if isinstance(v, (float, np.float32, np.float64)):
-                    if abs(v) > 1e-3:
-                        info += ' %.4f' % v
-                    else:
-                        info += ' %.4e' % v
-                elif isinstance(v, np.ndarray) and \
-                    isinstance(v.size, 1) and \
-                    isinstance(v.dtype, (np.float32, np.float64)):
-                    if abs(v[0]) > 1e-3:
-                        info += ' %.4f' % v[0]
-                    else:
-                        info += ' %.4e' % v[0]
-                else:
-                    info += ' %s' % v
+                val = val if isinstance(val, list) else [val]
+                for v in val:
+                    if isinstance(v, (float, np.float32, np.float64)):
+                        if abs(v) > 1e-3:
+                            info += ' %.4f' % v
+                        else:
+                            info += ' %.4e' % v
+                    elif isinstance(v, np.ndarray) and \
+                        isinstance(v.size, 1) and \
+                        isinstance(v.dtype, (np.float32, np.float64)):
+                        if abs(v[0]) > 1e-3:
+                            info += ' %.4f' % v[0]
+                        else:
+                            info += ' %.4e' % v[0]
+                    else:
+                        info += ' %s' % v

             info += fps
             info += '\n'
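The formatting change can be checked in isolation. A hedged, standalone mirror of the new inner loop with made-up values, showing that a list-valued 'loss' entry is printed element by element:

import numpy as np

# Every value is normalized to a list, then each element is formatted with
# %.4f (or %.4e when its magnitude drops below 1e-3), as in the diff above.
values = [('loss', [2.3102, 0.0004219]), ('acc_top1', 0.1172)]
info = ''
for k, val in values:
    info += ' - %s:' % k
    val = val if isinstance(val, list) else [val]
    for v in val:
        if isinstance(v, (float, np.float32, np.float64)):
            info += (' %.4f' % v) if abs(v) > 1e-3 else (' %.4e' % v)
        else:
            info += ' %s' % v

print(info)  # " - loss: 2.3102 4.2190e-04 - acc_top1: 0.1172"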
...
@@ -24,7 +24,7 @@ import contextlib
 import paddle
 from paddle import fluid
 from paddle.fluid.dygraph.nn import Conv2D, Pool2D, Linear
-from model import Model, CrossEntropy, Input
+from model import Model, CrossEntropy, Input, Loss
 from metrics import Accuracy
 from callbacks import ProgBarLogger
@@ -103,47 +103,65 @@ class MNIST(Model):
         return x


-def accuracy(pred, label, topk=(1, )):
-    maxk = max(topk)
-    pred = np.argsort(pred)[:, ::-1][:, :maxk]
-    correct = (pred == np.repeat(label, maxk, 1))
-    batch_size = label.shape[0]
-    res = []
-    for k in topk:
-        correct_k = correct[:, :k].sum()
-        res.append(100.0 * correct_k / batch_size)
-    return res
-
-
 @contextlib.contextmanager
 def null_guard():
     yield


+class MLP(Model):
+    def __init__(self):
+        super(MLP, self).__init__()
+        SIZE = 10
+        self._fc1 = Linear(784, 200, act="relu")
+        self._fc2 = Linear(200, 200, act="relu")
+        self._fc3 = Linear(200, 200, act="relu")
+        self._fc4 = Linear(200, 10, act="softmax")
+        self._fc5 = Linear(200, 10, act="softmax")
+
+    def forward(self, inputs):
+        x1 = self._fc1(inputs)
+        x2 = self._fc2(x1)
+        x3 = self._fc3(x2)
+        o1 = self._fc5(x3)
+        o2 = self._fc4(x2)
+        return o1, o2
+
+
+class MyCrossEntropy(Loss):
+    def __init__(self, average=True):
+        super(MyCrossEntropy, self).__init__()
+
+    def forward(self, outputs, labels):
+        loss1 = fluid.layers.cross_entropy(outputs[0], labels[0])
+        loss2 = fluid.layers.cross_entropy(outputs[1], labels[0])
+        return [loss1, loss2]
+
+
 class TestModel(unittest.TestCase):
-    def fit(self, dynamic):
+    def fit(self, dynamic, is_mlp=False):
+        im_shape = (-1, 784) if is_mlp else (-1, 1, 28, 28)
         guard = fluid.dygraph.guard() if dynamic else null_guard()
         batch_size = 128
         train_loader = fluid.io.xmap_readers(
-            lambda b: [np.array([x[0] for x in b]).reshape(-1, 1, 28, 28),
+            lambda b: [np.array([x[0] for x in b]).reshape(im_shape),
                        np.array([x[1] for x in b]).reshape(-1, 1)],
             paddle.batch(fluid.io.shuffle(paddle.dataset.mnist.train(), 6e4),
                          batch_size=batch_size, drop_last=True), 1, 1)
         val_loader = fluid.io.xmap_readers(
-            lambda b: [np.array([x[0] for x in b]).reshape(-1, 1, 28, 28),
+            lambda b: [np.array([x[0] for x in b]).reshape(im_shape),
                        np.array([x[1] for x in b]).reshape(-1, 1)],
             paddle.batch(paddle.dataset.mnist.test(),
                          batch_size=batch_size, drop_last=False), 1, 1)
         with guard:
-            inputs = [Input([None, 1, 28, 28], 'float32', name='image')]
+            inputs = [Input(im_shape, 'float32', name='image')]
             labels = [Input([None, 1], 'int64', name='label')]
-            model = MNIST()
+            model = MNIST() if not is_mlp else MLP()
             optim = fluid.optimizer.Momentum(
                 learning_rate=0.01,
                 momentum=.9,
                 parameter_list=model.parameters())
-            model.prepare(optim, CrossEntropy(), Accuracy(), inputs, labels)
+            loss = CrossEntropy() if not is_mlp else MyCrossEntropy()
+            model.prepare(optim, loss, Accuracy(), inputs, labels)
             cbk = ProgBarLogger(50)
             model.fit(train_loader, val_loader, epochs=2, callbacks=cbk)
@@ -153,6 +171,12 @@ class TestModel(unittest.TestCase):
     def test_fit_dygraph(self):
         self.fit(True)

+    def test_fit_static_multi_loss(self):
+        self.fit(False, MyCrossEntropy())
+
+    def test_fit_dygraph_multi_loss(self):
+        self.fit(True, MyCrossEntropy())
+

 if __name__ == '__main__':
     unittest.main()