提交 3e532a78 编写于 作者: L LielinJiang

fix reviews

上级 237549f0
......@@ -905,7 +905,7 @@ class Model(fluid.dygraph.Layer):
eval_freq (int): The frequency, in number of epochs, an evalutation
is performed.
log_freq (int): The frequency, in number of steps, the training logs
is printed.
are printed.
save_dir(str|None): The directory to save checkpoint during training.
If None, will not save checkpoint.
save_freq (int): The frequency, in number of epochs, to save checkpoint.
......@@ -987,14 +987,22 @@ class Model(fluid.dygraph.Layer):
loader = train_loader()
logs = self._run_one_epoch(
loader, cbks, 'train', metrics_name, epoch=epoch)
cbks.on_epoch_end(epoch, logs)
if do_eval and epoch % eval_freq == 0:
# FIXME: adapt to DataLoader
loader = eval_loader
if not isinstance(eval_loader, Iterable):
loader = eval_loader()
eval_steps = len(loader) if hasattr(loader,
'__len__') else None
cbks.on_begin('eval', {
'steps': eval_steps,
'metrics_name': metrics_name
})
logs = self._run_one_epoch(loader, cbks, 'eval', metrics_name)
cbks.on_end('eval', logs)
cbks.on_end('train', logs)
......@@ -1017,11 +1025,11 @@ class Model(fluid.dygraph.Layer):
batch_size (int): Integer number. The batch size of train_data and eval_data.
When train_data and eval_data are both the instance of Dataloader, this
parameter will be ignored.
log_freq (int): The frequency, in number of steps, the training logs
is printed.
log_freq (int): The frequency, in number of steps, the eval logs
are printed.
verbose (int): The verbosity mode, should be 0, 1, or 2.
0 = silent, 1 = progress bar, 2 = one line per epoch.
num_workers (int): the number of subprocess to load data, 0 for no subprocess
num_workers (int): The number of subprocess to load data, 0 for no subprocess
used and loading data in main process. When train_data and eval_data are
both the instance of Dataloader, this parameter will be ignored.
callbacks (Callback|None): A list of `Callback` instances to apply
......@@ -1050,12 +1058,9 @@ class Model(fluid.dygraph.Layer):
self._test_dataloader = eval_loader
metrics_name = self._metrics_name()
steps = len(eval_loader)
cbks = config_callbacks(
callbacks,
model=self,
steps=steps,
log_freq=log_freq,
verbose=verbose,
metrics=self._metrics_name(), )
......@@ -1063,7 +1068,14 @@ class Model(fluid.dygraph.Layer):
loader = eval_loader
if not isinstance(eval_loader, Iterable):
loader = eval_loader()
eval_steps = len(loader) if hasattr(loader, '__len__') else None
cbks.on_begin('eval',
{'steps': eval_steps,
'metrics_name': metrics_name})
logs = self._run_one_epoch(loader, cbks, 'eval', metrics_name)
cbks.on_end('eval', logs)
self._test_dataloader = None
......@@ -1157,9 +1169,8 @@ class Model(fluid.dygraph.Layer):
'metrics_name': metrics_name,
}
callbacks.on_begin(mode, logs)
if mode == 'train':
assert epoch is not None, 'when mode is train, '
assert epoch is not None, 'when mode is train, epoch must be given'
callbacks.on_epoch_begin(epoch)
for step, data in enumerate(data_loader):
......@@ -1198,6 +1209,11 @@ class Model(fluid.dygraph.Layer):
callbacks.on_batch_end(mode, step, logs)
self._reset_metrics()
if mode == 'train':
assert epoch is not None, 'when mode is train, epoch must be given'
callbacks.on_epoch_end(epoch)
return logs
def _reset_metrics(self):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册