Commit 87eb929f authored by lyuwenyu, committed by jzhang533

add gradient accumulation for dygraph

Parent 921b0418
@@ -701,7 +701,7 @@ class DynamicGraphAdapter(object):
         self.model.mode = value

     # TODO multi device in dygraph mode not implemented at present time
-    def train_batch(self, inputs, labels=None):
+    def train_batch(self, inputs, labels=None, update=True):
         assert self.model._optimizer, \
             "model not ready, please call `model.prepare()` first"
         self.model.network.train()
@@ -729,13 +729,15 @@ class DynamicGraphAdapter(object):
         if self._amp_level != "O0":
             scaled = scaler.scale(final_loss)
             scaled.backward()
-            scaler.minimize(self.model._optimizer, scaled)
-            self.model.network.clear_gradients()
+            if update:
+                scaler.minimize(self.model._optimizer, scaled)
+                self.model.network.clear_gradients()
         else:
             final_loss.backward()
-            self.model._optimizer.minimize(final_loss)
-            self.model.network.clear_gradients()
+            if update:
+                self.model._optimizer.minimize(final_loss)
+                self.model.network.clear_gradients()

         metrics = []
         for metric in self.model._metrics:
             metric_outs = metric.compute(*(to_list(outputs) + labels))
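This hunk is the core of the feature: `backward()` keeps adding into the stored gradients, and the optimizer step plus `clear_gradients()` now run only when `update` is true. A minimal standalone sketch of the same pattern in a plain dygraph training loop; the `layer`, `loss_fn`, `opt`, and `loader` names are placeholders, not part of this commit:

    accumulate = 4
    for step, (x, y) in enumerate(loader):
        loss = loss_fn(layer(x), y)
        loss.backward()                    # gradients accumulate across calls
        if (step + 1) % accumulate == 0:   # step the optimizer every `accumulate` batches
            opt.minimize(loss)
            layer.clear_gradients()        # reset so the next window starts fresh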
@@ -1017,7 +1019,7 @@ class Model(object):
         else:
             self._adapter = StaticGraphAdapter(self)

-    def train_batch(self, inputs, labels=None):
+    def train_batch(self, inputs, labels=None, update=True):
        """
        Run one training step on a batch of data.
@@ -1062,7 +1064,7 @@ class Model(object):
                   loss = model.train_batch([data], [label])
                   print(loss)
        """
-        loss = self._adapter.train_batch(inputs, labels)
+        loss = self._adapter.train_batch(inputs, labels, update)
         if fluid.in_dygraph_mode() and self._input_info is None:
             self._update_inputs()
         return loss
@@ -1536,7 +1538,8 @@ class Model(object):
             drop_last=False,
             shuffle=True,
             num_workers=0,
-            callbacks=None, ):
+            callbacks=None,
+            accumulate=1, ):
        """
        Trains the model for a fixed number of epochs. If `eval_data` is set,
        evaluation will be done at the end of each epoch.
@@ -1579,7 +1582,8 @@ class Model(object):
            callbacks (Callback|None): A list of `Callback` instances to apply
                during training. If None, `ProgBarLogger` and `ModelCheckpoint`
                are automatically inserted. Default: None.
+            accumulate (int): The number of steps over which to accumulate gradients before each optimizer update; use this to mimic a large batch size. Default: 1.

        Returns:
            None
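Given the new argument documented above, enabling accumulation from user code is a one-argument change to `fit`. A hedged usage sketch; the model setup, `optimizer`, `loss`, and `train_dataset` here are placeholder assumptions, not part of this commit:

    model.prepare(optimizer, loss)
    # Per-step batch size stays 32; the effective batch size the optimizer
    # sees becomes roughly 32 * 4 = 128.
    model.fit(train_dataset, batch_size=32, epochs=2, accumulate=4)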
@@ -1699,7 +1703,8 @@ class Model(object):
         do_eval = eval_loader is not None
         self._test_dataloader = eval_loader
+        self._accumulate = accumulate

         steps = self._len_data_loader(train_loader)
         cbks = config_callbacks(
             callbacks,
@@ -1737,7 +1742,7 @@ class Model(object):
         cbks.on_end('train', logs)
         self._test_dataloader = None

     def evaluate(
             self,
             eval_data,
@@ -2004,7 +2009,7 @@ class Model(object):
             model_filename=model_filename,
             params_filename=params_filename)

-    def _run_one_epoch(self, data_loader, callbacks, mode, logs={}):
+    def _run_one_epoch(self, data_loader, callbacks, mode, logs={},):
         outputs = []
         for step, data in enumerate(data_loader):
             # data might come from different types of data_loader and have
@@ -2028,8 +2033,16 @@ class Model(object):
             callbacks.on_batch_begin(mode, step, logs)

             if mode != 'predict':
-                outs = getattr(self, mode + '_batch')(data[:len(self._inputs)],
-                                                      data[len(self._inputs):])
+                _inputs = [data[:len(self._inputs)], data[len(self._inputs):]]
+                if mode == 'train':
+                    _inputs.append((step + 1) % self._accumulate == 0)
+
+                outs = getattr(self, mode + '_batch')(*_inputs)
+
+                # outs = getattr(self, mode + '_batch')(data[:len(self._inputs)],
+                #                                       data[len(self._inputs):])

                 if self._metrics and self._loss:
                     metrics = [[l[0] for l in outs[0]]]
                 elif self._loss:
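With this hunk, `_run_one_epoch` passes `(step + 1) % self._accumulate == 0` as the `update` flag, so only every `accumulate`-th batch triggers an optimizer step. Since the public `train_batch` API accepts the same flag, the behavior can also be driven by hand; a small sketch with placeholder loop variables:

    accumulate = 4
    for step, (data, label) in enumerate(loader):
        # True only on every `accumulate`-th batch; otherwise gradients just pile up.
        update = (step + 1) % accumulate == 0
        loss = model.train_batch([data], [label], update=update)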