Commit a8fec662 authored by lyuwenyu, committed by jzhang533

fix doc, last iter, and test for amp

Parent 77eae14a
@@ -1584,7 +1584,9 @@ class Model(object):
             callbacks (Callback|None): A list of `Callback` instances to apply
                 during training. If None, `ProgBarLogger` and `ModelCheckpoint`
                 are automatically inserted. Default: None.
-            accumulate (int): The number of steps to accumulate gradident during training process before optimizer updates. It can mimic large batch size. Default: 1.
+            accumulate (int): The number of steps over which to accumulate
+                gradients before each optimizer update. It can mimic a large
+                batch size. Default: 1.
 
         Returns:
             None
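For context, a minimal sketch of how the `accumulate` argument documented above is used from `Model.fit`. The network, dataset, and hyperparameters here are illustrative, not part of this change:

```python
import paddle
from paddle import Model
from paddle.nn import CrossEntropyLoss
from paddle.static import InputSpec
from paddle.vision.datasets import MNIST
from paddle.vision.transforms import ToTensor

net = paddle.vision.models.LeNet()
inputs = [InputSpec([None, 1, 28, 28], 'float32', 'image')]
labels = [InputSpec([None, 1], 'int64', 'label')]
model = Model(net, inputs, labels)
optim = paddle.optimizer.Adam(learning_rate=1e-3,
                              parameters=model.parameters())
model.prepare(optim, loss=CrossEntropyLoss())

# accumulate=4 sums gradients over 4 batches before each optimizer
# step, mimicking an effective batch size of 4 * 32 = 128.
model.fit(MNIST(mode='train', transform=ToTensor()),
          epochs=1, batch_size=32, accumulate=4)
```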
@@ -2044,7 +2046,8 @@ class Model(object):
             _inputs = [data[:len(self._inputs)], data[len(self._inputs):]]
             if mode == 'train':
-                _inputs.append((step + 1) % self._accumulate == 0)
+                _inputs.append((step + 1) % self._accumulate == 0 or
+                               step + 1 == len(data_loader))
 
             outs = getattr(self, mode + '_batch')(*_inputs)
...
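The added `or step + 1 == len(data_loader)` clause makes the optimizer also step on the last batch of the loader, so a trailing window with fewer than `accumulate` batches is not silently dropped at the end of an epoch. A standalone sketch of the resulting update schedule (plain Python; the counts are illustrative):

```python
accumulate = 4    # steps to accumulate before an optimizer update
num_batches = 10  # stand-in for len(data_loader)

# A batch triggers an update when the accumulation window is full,
# or (the fix above) when it is the last batch of the epoch.
update_steps = [step for step in range(num_batches)
                if (step + 1) % accumulate == 0
                or step + 1 == num_batches]
print(update_steps)  # [3, 7, 9]: after batches 4, 8, and the final 10th
```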
@@ -727,12 +727,20 @@ class TestModelFunction(unittest.TestCase):
             parameter_list=net.parameters())
         inputs = [InputSpec([None, dim], 'float32', 'x')]
         labels = [InputSpec([None, 1], 'int64', 'label')]
 
         model = Model(net, inputs, labels)
         model.prepare(optim, loss=CrossEntropyLoss(reduction="sum"))
         loss1, = model.train_batch([data], [label], update=False)
         loss2, = model.train_batch([data], [label], update=True)
         np.testing.assert_almost_equal(loss1, loss2, decimal=4)
+
+        model = Model(net, inputs, labels)
+        model.prepare(
+            optim, loss=CrossEntropyLoss(reduction="sum"), amp_configs='O1')
+        loss1, = model.train_batch([data], [label], update=False)
+        loss2, = model.train_batch([data], [label], update=True)
+        np.testing.assert_almost_equal(loss1, loss2, decimal=4)
+
 
 class TestModelWithLRScheduler(unittest.TestCase):
     def test_fit_by_step(self):
...
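The new test exercises the same `update=False`/`update=True` pair with mixed precision enabled via `prepare(..., amp_configs='O1')`. A self-contained variant of that pattern, with a made-up toy network and random data (note that 'O1' AMP generally requires a CUDA device):

```python
import numpy as np
import paddle
from paddle import Model
from paddle.nn import CrossEntropyLoss
from paddle.static import InputSpec

dim = 10
net = paddle.nn.Sequential(paddle.nn.Linear(dim, 2))  # toy classifier
optim = paddle.optimizer.SGD(learning_rate=1e-3,
                             parameters=net.parameters())

inputs = [InputSpec([None, dim], 'float32', 'x')]
labels = [InputSpec([None, 1], 'int64', 'label')]
model = Model(net, inputs, labels)
model.prepare(optim, loss=CrossEntropyLoss(reduction="sum"),
              amp_configs='O1')  # enable auto mixed precision (level O1)

data = np.random.random((4, dim)).astype('float32')
label = np.random.randint(0, 2, size=(4, 1)).astype('int64')

# update=False accumulates gradients without stepping the optimizer;
# update=True performs the step. Since no update happened in between,
# both calls see identical parameters and return the same loss.
loss1, = model.train_batch([data], [label], update=False)
loss2, = model.train_batch([data], [label], update=True)
np.testing.assert_almost_equal(loss1, loss2, decimal=4)
```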