提交 ae2b2185 编写于 作者: L lyuwenyu 提交者: jzhang533

update

上级 87eb929f
......@@ -737,7 +737,7 @@ class DynamicGraphAdapter(object):
if update:
self.model._optimizer.minimize(final_loss)
self.model.network.clear_gradients()
metrics = []
for metric in self.model._metrics:
metric_outs = metric.compute(*(to_list(outputs) + labels))
......@@ -1538,7 +1538,7 @@ class Model(object):
drop_last=False,
shuffle=True,
num_workers=0,
callbacks=None,
callbacks=None,
accumulate=1, ):
"""
Trains the model for a fixed number of epochs. If `eval_data` is set,
......@@ -1703,8 +1703,9 @@ class Model(object):
do_eval = eval_loader is not None
self._test_dataloader = eval_loader
self._accumulate = accumulate
steps = self._len_data_loader(train_loader)
cbks = config_callbacks(
callbacks,
......@@ -1742,7 +1743,7 @@ class Model(object):
cbks.on_end('train', logs)
self._test_dataloader = None
def evaluate(
self,
eval_data,
......@@ -2009,7 +2010,12 @@ class Model(object):
model_filename=model_filename,
params_filename=params_filename)
def _run_one_epoch(self, data_loader, callbacks, mode, logs={},):
def _run_one_epoch(
self,
data_loader,
callbacks,
mode,
logs={}, ):
outputs = []
for step, data in enumerate(data_loader):
# data might come from different types of data_loader and have
......@@ -2033,13 +2039,13 @@ class Model(object):
callbacks.on_batch_begin(mode, step, logs)
if mode != 'predict':
_inputs = [data[:len(self._inputs)], data[len(self._inputs):]]
if mode == 'train':
_inputs.append((step + 1) % self._accumulate == 0)
outs = getattr(self, mode + '_batch')(*_inputs)
# outs = getattr(self, mode + '_batch')(data[:len(self._inputs)],
# data[len(self._inputs):])
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册