提交 77eae14a 编写于 作者: L lyuwenyu 提交者: jzhang533

update

上级 f9f21a5c
......@@ -302,7 +302,7 @@ class StaticGraphAdapter(object):
assert self.model._optimizer, \
"model not ready, please call `model.prepare()` first"
self.mode = 'train'
assert update is True, "Model does not support `update == False` in static mode by now."
assert update is True, "Does not support `update == False` in static mode by now."
return self._run(inputs, labels)
def eval_batch(self, inputs, labels=None):
......@@ -1032,7 +1032,7 @@ class Model(object):
a numpy array or paddle.Tensor, or a list of arrays or tensors
(in case the model has multiple labels). If has no labels,
set None. Default is None.
update (bool): Whether update parameters after loss.backward() computes. Using this to accumulate gradients. Default is True.
update (bool): Whether update parameters after loss.backward() computing. Using it to accumulate gradients. Default is True.
Returns:
A list of scalar training loss if the model has no metrics,
......@@ -1584,7 +1584,7 @@ class Model(object):
callbacks (Callback|None): A list of `Callback` instances to apply
during training. If None, `ProgBarLogger` and `ModelCheckpoint`
are automatically inserted. Default: None.
accumulate (int): The number of steps to accumulate gradident in training process before optimizer update. Using this to mimic large batch size. Default: 1.
accumulate (int): The number of steps to accumulate gradident during training process before optimizer updates. It can mimic large batch size. Default: 1.
Returns:
None
......
......@@ -729,8 +729,9 @@ class TestModelFunction(unittest.TestCase):
labels = [InputSpec([None, 1], 'int64', 'label')]
model = Model(net, inputs, labels)
model.prepare(optim, loss=CrossEntropyLoss(reduction="sum"))
loss1, = model.train_batch([data], [label], update=True)
loss2, = model.train_batch([data], [label], update=False)
loss1, = model.train_batch([data], [label], update=False)
loss2, = model.train_batch([data], [label], update=True)
np.testing.assert_almost_equal(loss1, loss2, decimal=4)
class TestModelWithLRScheduler(unittest.TestCase):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册