Commit 87eb929f authored by lyuwenyu, committed by jzhang533

Add gradient accumulation for dygraph mode

Parent 921b0418
@@ -701,7 +701,7 @@ class DynamicGraphAdapter(object):
         self.model.mode = value

     # TODO multi device in dygraph mode not implemented at present time
-    def train_batch(self, inputs, labels=None):
+    def train_batch(self, inputs, labels=None, update=True):
         assert self.model._optimizer, \
             "model not ready, please call `model.prepare()` first"
         self.model.network.train()
......@@ -729,13 +729,15 @@ class DynamicGraphAdapter(object):
if self._amp_level != "O0":
scaled = scaler.scale(final_loss)
scaled.backward()
scaler.minimize(self.model._optimizer, scaled)
self.model.network.clear_gradients()
if update:
scaler.minimize(self.model._optimizer, scaled)
self.model.network.clear_gradients()
else:
final_loss.backward()
self.model._optimizer.minimize(final_loss)
self.model.network.clear_gradients()
if update:
self.model._optimizer.minimize(final_loss)
self.model.network.clear_gradients()
metrics = []
for metric in self.model._metrics:
metric_outs = metric.compute(*(to_list(outputs) + labels))
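With this change, `DynamicGraphAdapter.train_batch` always runs the backward pass but only applies the optimizer and clears gradients when `update` is true; in dygraph mode, gradients from successive `backward()` calls accumulate in place until they are cleared. A minimal sketch of the pattern, using illustrative stand-in names rather than the adapter's real members:

```python
# Sketch of the update-gating pattern above. `network`, `optimizer`, and
# `loss_fn` are illustrative stand-ins, not Paddle APIs.
def train_batch(network, optimizer, loss_fn, inputs, labels, update=True):
    loss = loss_fn(network(inputs), labels)
    loss.backward()  # gradients accumulate across calls until cleared
    if update:
        optimizer.minimize(loss)   # apply the accumulated gradients
        network.clear_gradients()  # start a fresh accumulation window
    return loss
```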
@@ -1017,7 +1019,7 @@ class Model(object):
         else:
             self._adapter = StaticGraphAdapter(self)

-    def train_batch(self, inputs, labels=None):
+    def train_batch(self, inputs, labels=None, update=True):
         """
         Run one training step on a batch of data.
@@ -1062,7 +1064,7 @@ class Model(object):
                 loss = model.train_batch([data], [label])
                 print(loss)
         """
-        loss = self._adapter.train_batch(inputs, labels)
+        loss = self._adapter.train_batch(inputs, labels, update)
         if fluid.in_dygraph_mode() and self._input_info is None:
             self._update_inputs()
         return loss
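Callers of the public `Model.train_batch` can now drive accumulation manually by passing `update` only on every k-th step. A hypothetical driver loop (`loader`, `k`, `data`, and `label` are assumptions for illustration):

```python
# Accumulate gradients over k batches, stepping the optimizer on every k-th.
k = 4
for step, (data, label) in enumerate(loader):
    loss = model.train_batch([data], [label],
                             update=((step + 1) % k == 0))
```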
@@ -1536,7 +1538,8 @@ class Model(object):
             drop_last=False,
             shuffle=True,
             num_workers=0,
-            callbacks=None, ):
+            callbacks=None,
+            accumulate=1, ):
         """
         Trains the model for a fixed number of epochs. If `eval_data` is set,
         evaluation will be done at the end of each epoch.
@@ -1579,7 +1582,8 @@ class Model(object):
             callbacks (Callback|None): A list of `Callback` instances to apply
                 during training. If None, `ProgBarLogger` and `ModelCheckpoint`
                 are automatically inserted. Default: None.
+            accumulate (int): The number of steps to accumulate gradients in the training process before each optimizer update, used to mimic a large batch size. Default: 1.

         Returns:
             None
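For users of `fit`, the effective batch size becomes roughly `batch_size * accumulate`, since parameters are updated only once per `accumulate` batches. A usage sketch (the dataset, epoch count, and batch size are placeholder assumptions):

```python
# Mimic an effective batch size of 128 with batches of 32 by accumulating
# gradients over 4 steps per optimizer update.
model.fit(train_dataset, epochs=2, batch_size=32, accumulate=4)
```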
@@ -1699,7 +1703,8 @@ class Model(object):
         do_eval = eval_loader is not None
         self._test_dataloader = eval_loader
+        self._accumulate = accumulate

         steps = self._len_data_loader(train_loader)
         cbks = config_callbacks(
             callbacks,
@@ -1737,7 +1742,7 @@ class Model(object):
         cbks.on_end('train', logs)
         self._test_dataloader = None

     def evaluate(
             self,
             eval_data,
@@ -2004,7 +2009,7 @@ class Model(object):
             model_filename=model_filename,
             params_filename=params_filename)

-    def _run_one_epoch(self, data_loader, callbacks, mode, logs={}):
+    def _run_one_epoch(self, data_loader, callbacks, mode, logs={},):
         outputs = []
         for step, data in enumerate(data_loader):
             # data might come from different types of data_loader and have
@@ -2028,8 +2033,16 @@ class Model(object):
             callbacks.on_batch_begin(mode, step, logs)

             if mode != 'predict':
-                outs = getattr(self, mode + '_batch')(data[:len(self._inputs)],
-                                                      data[len(self._inputs):])
+                _inputs = [data[:len(self._inputs)], data[len(self._inputs):]]
+                if mode == 'train':
+                    _inputs.append((step + 1) % self._accumulate == 0)
+
+                outs = getattr(self, mode + '_batch')(*_inputs)
+
+                # outs = getattr(self, mode + '_batch')(data[:len(self._inputs)],
+                #                                       data[len(self._inputs):])
+
                 if self._metrics and self._loss:
                     metrics = [[l[0] for l in outs[0]]]
                 elif self._loss:
......
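The update flag computed in `_run_one_epoch` is `(step + 1) % self._accumulate == 0`, so with zero-based steps the optimizer fires on batches `accumulate`, `2 * accumulate`, and so on (counting from 1); with the default `accumulate=1` it fires on every batch, preserving the old behavior. A quick standalone check of that schedule:

```python
# Update schedule for accumulate=4 over the first 8 batches (0-based steps).
accumulate = 4
print([(step + 1) % accumulate == 0 for step in range(8)])
# [False, False, False, True, False, False, False, True]
```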