未验证 提交 f1f206cc 编写于 作者: J Jintao Lin 提交者: GitHub

Modify 'val_step()' to validate data for each val mode epoch (#123)

Co-authored-by: Nlizz <innerlee@users.noreply.github.com>
上级 18e3c779
......@@ -164,7 +164,7 @@ class BaseRecognizer(nn.Module, metaclass=ABCMeta):
imgs = data_batch['imgs']
label = data_batch['label']
losses = self.forward(imgs, label)
losses = self(imgs, label)
loss, log_vars = self._parse_losses(losses)
......@@ -183,9 +183,15 @@ class BaseRecognizer(nn.Module, metaclass=ABCMeta):
not implemented with this method, but an evaluation hook.
"""
imgs = data_batch['imgs']
label = data_batch['label']
losses = self(imgs, label)
results = self.forward(imgs, None, return_loss=False)
loss, log_vars = self._parse_losses(losses)
outputs = dict(results=results)
outputs = dict(
loss=loss,
log_vars=log_vars,
num_samples=len(next(iter(data_batch.values()))))
return outputs
......@@ -3,6 +3,7 @@ import copy
import os
import os.path as osp
import time
import warnings
import mmcv
import torch
......@@ -124,6 +125,10 @@ def main():
datasets = [build_dataset(cfg.data.train)]
if len(cfg.workflow) == 2:
if args.validate:
warnings.warn('val workflow is duplicated with `--validate`, '
'it is recommended to use `--validate`. see '
'https://github.com/open-mmlab/mmaction2/pull/123')
val_dataset = copy.deepcopy(cfg.data.val)
datasets.append(build_dataset(val_dataset))
if cfg.checkpoint_config is not None:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册