提交 ae2b2185 编写于 作者: L lyuwenyu 提交者: jzhang533

update

上级 87eb929f
...@@ -737,7 +737,7 @@ class DynamicGraphAdapter(object): ...@@ -737,7 +737,7 @@ class DynamicGraphAdapter(object):
if update: if update:
self.model._optimizer.minimize(final_loss) self.model._optimizer.minimize(final_loss)
self.model.network.clear_gradients() self.model.network.clear_gradients()
metrics = [] metrics = []
for metric in self.model._metrics: for metric in self.model._metrics:
metric_outs = metric.compute(*(to_list(outputs) + labels)) metric_outs = metric.compute(*(to_list(outputs) + labels))
...@@ -1538,7 +1538,7 @@ class Model(object): ...@@ -1538,7 +1538,7 @@ class Model(object):
drop_last=False, drop_last=False,
shuffle=True, shuffle=True,
num_workers=0, num_workers=0,
callbacks=None, callbacks=None,
accumulate=1, ): accumulate=1, ):
""" """
Trains the model for a fixed number of epochs. If `eval_data` is set, Trains the model for a fixed number of epochs. If `eval_data` is set,
...@@ -1703,8 +1703,9 @@ class Model(object): ...@@ -1703,8 +1703,9 @@ class Model(object):
do_eval = eval_loader is not None do_eval = eval_loader is not None
self._test_dataloader = eval_loader self._test_dataloader = eval_loader
self._accumulate = accumulate self._accumulate = accumulate
steps = self._len_data_loader(train_loader) steps = self._len_data_loader(train_loader)
cbks = config_callbacks( cbks = config_callbacks(
callbacks, callbacks,
...@@ -1742,7 +1743,7 @@ class Model(object): ...@@ -1742,7 +1743,7 @@ class Model(object):
cbks.on_end('train', logs) cbks.on_end('train', logs)
self._test_dataloader = None self._test_dataloader = None
def evaluate( def evaluate(
self, self,
eval_data, eval_data,
...@@ -2009,7 +2010,12 @@ class Model(object): ...@@ -2009,7 +2010,12 @@ class Model(object):
model_filename=model_filename, model_filename=model_filename,
params_filename=params_filename) params_filename=params_filename)
def _run_one_epoch(self, data_loader, callbacks, mode, logs={},): def _run_one_epoch(
self,
data_loader,
callbacks,
mode,
logs={}, ):
outputs = [] outputs = []
for step, data in enumerate(data_loader): for step, data in enumerate(data_loader):
# data might come from different types of data_loader and have # data might come from different types of data_loader and have
...@@ -2033,13 +2039,13 @@ class Model(object): ...@@ -2033,13 +2039,13 @@ class Model(object):
callbacks.on_batch_begin(mode, step, logs) callbacks.on_batch_begin(mode, step, logs)
if mode != 'predict': if mode != 'predict':
_inputs = [data[:len(self._inputs)], data[len(self._inputs):]] _inputs = [data[:len(self._inputs)], data[len(self._inputs):]]
if mode == 'train': if mode == 'train':
_inputs.append((step + 1) % self._accumulate == 0) _inputs.append((step + 1) % self._accumulate == 0)
outs = getattr(self, mode + '_batch')(*_inputs) outs = getattr(self, mode + '_batch')(*_inputs)
# outs = getattr(self, mode + '_batch')(data[:len(self._inputs)], # outs = getattr(self, mode + '_batch')(data[:len(self._inputs)],
# data[len(self._inputs):]) # data[len(self._inputs):])
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册