Commit 4748baa2 authored by linjintao

Move parse_loss into class

Parent d1eb189f
from .inference import inference_recognizer, init_recognizer
from .test import multi_gpu_test, single_gpu_test
-from .train import parse_losses, set_random_seed, train_model
+from .train import set_random_seed, train_model

__all__ = [
    'set_random_seed', 'train_model', 'init_recognizer',
-    'inference_recognizer', 'multi_gpu_test', 'single_gpu_test', 'parse_losses'
+    'inference_recognizer', 'multi_gpu_test', 'single_gpu_test'
]
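With `parse_losses` gone from the public `mmaction.apis` exports, external code that imported the standalone helper now reaches the same logic through the static method on the model base classes. A minimal sketch of the call through the class; the `mmaction.models.BaseRecognizer` import path is an assumption, since file paths are not visible in this diff:

```python
import torch

# Previously: from mmaction.apis import parse_losses
# Now the logic is a @staticmethod, so it can be reached through the class
# without constructing a recognizer (import path assumed, see lead-in).
from mmaction.models import BaseRecognizer

losses = dict(loss_cls=torch.ones(2))
loss, log_vars = BaseRecognizer._parse_losses(losses)
print(loss, log_vars)  # tensor(1.) OrderedDict([('loss_cls', 1.0), ('loss', 1.0)])
```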
import os
import random
from collections import OrderedDict
import numpy as np
import torch
import torch.distributed as dist
from mmcv.parallel import MMDataParallel, MMDistributedDataParallel
from mmcv.runner import (DistSamplerSeedHook, EpochBasedRunner, OptimizerHook,
                         build_optimizer)
@@ -35,40 +33,6 @@ def set_random_seed(seed, deterministic=False):
        torch.backends.cudnn.benchmark = False


-def parse_losses(losses):
-    """Parse the raw outputs (losses) of the network.
-    Args:
-        losses (dict): Raw output of the network, which usually contains
-            losses and other necessary information.
-    Returns:
-        tuple[Tensor, dict]: (loss, log_vars), loss is the loss tensor
-            which may be a weighted sum of all losses, log_vars contains
-            all the variables to be sent to the logger.
-    """
-    log_vars = OrderedDict()
-    for loss_name, loss_value in losses.items():
-        if isinstance(loss_value, torch.Tensor):
-            log_vars[loss_name] = loss_value.mean()
-        elif isinstance(loss_value, list):
-            log_vars[loss_name] = sum(_loss.mean() for _loss in loss_value)
-        else:
-            raise TypeError(f'{loss_name} is not a tensor or list of tensors')
-    loss = sum(_value for _key, _value in log_vars.items() if 'loss' in _key)
-    log_vars['loss'] = loss
-    for loss_name, loss_value in log_vars.items():
-        # reduce loss when distributed training
-        if dist.is_available() and dist.is_initialized():
-            loss_value = loss_value.data.clone()
-            dist.all_reduce(loss_value.div_(dist.get_world_size()))
-        log_vars[loss_name] = loss_value.item()
-    return loss, log_vars


def train_model(model,
                dataset,
                cfg,
......
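To make the docstring of `parse_losses` concrete: every entry is averaged into `log_vars`, but only entries whose key contains 'loss' are summed into the returned total, and a list value contributes the sum of its tensors' means. A small single-process sketch (the distributed all-reduce branch is skipped, and the key names are illustrative); it assumes the `parse_losses` defined above is in scope:

```python
import torch

losses = dict(
    loss_cls=torch.full((4,), 0.5),           # per-sample loss, mean -> 0.5
    loss_aux=[torch.ones(3), torch.ones(3)],  # list of tensors: 1.0 + 1.0
    top1_acc=torch.tensor(0.75),              # logged, but key has no 'loss'
)

# Assumes the parse_losses shown above (identical to the new static method).
loss, log_vars = parse_losses(losses)

assert loss.item() == 2.5                    # 0.5 + 2.0; top1_acc excluded
assert log_vars['top1_acc'] == 0.75          # still sent to the logger
assert list(log_vars) == ['loss_cls', 'loss_aux', 'top1_acc', 'loss']
```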
@@ -37,7 +37,8 @@ class BaseLocalizer(nn.Module, metaclass=ABCMeta):
        else:
            return self.forward_test(imgs)

-    def _parse_losses(self, losses):
+    @staticmethod
+    def _parse_losses(losses):
        """Parse the raw outputs (losses) of the network.
        Args:
......
from abc import ABCMeta, abstractmethod
from collections import OrderedDict
import torch
import torch.distributed as dist
import torch.nn as nn
import torch.nn.functional as F
-from ...apis import parse_losses
from .. import builder
@@ -75,6 +77,42 @@ class BaseRecognizer(nn.Module, metaclass=ABCMeta):
    def forward_test(self, imgs):
        pass

+    @staticmethod
+    def _parse_losses(losses):
+        """Parse the raw outputs (losses) of the network.
+        Args:
+            losses (dict): Raw output of the network, which usually contains
+                losses and other necessary information.
+        Returns:
+            tuple[Tensor, dict]: (loss, log_vars), loss is the loss tensor
+                which may be a weighted sum of all losses, log_vars contains
+                all the variables to be sent to the logger.
+        """
+        log_vars = OrderedDict()
+        for loss_name, loss_value in losses.items():
+            if isinstance(loss_value, torch.Tensor):
+                log_vars[loss_name] = loss_value.mean()
+            elif isinstance(loss_value, list):
+                log_vars[loss_name] = sum(_loss.mean() for _loss in loss_value)
+            else:
+                raise TypeError(
+                    f'{loss_name} is not a tensor or list of tensors')
+        loss = sum(_value for _key, _value in log_vars.items()
+                   if 'loss' in _key)
+        log_vars['loss'] = loss
+        for loss_name, loss_value in log_vars.items():
+            # reduce loss when distributed training
+            if dist.is_available() and dist.is_initialized():
+                loss_value = loss_value.data.clone()
+                dist.all_reduce(loss_value.div_(dist.get_world_size()))
+            log_vars[loss_name] = loss_value.item()
+        return loss, log_vars

    def forward(self, imgs, label=None, return_loss=True):
        if return_loss:
            if label is None:
@@ -89,7 +127,7 @@ class BaseRecognizer(nn.Module, metaclass=ABCMeta):
        losses = self.forward(imgs, label)
-        loss, log_vars = parse_losses(losses)
+        loss, log_vars = self._parse_losses(losses)
        outputs = dict(
            loss=loss,
......
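One detail of the `_parse_losses` method added above that is easy to miss: when `torch.distributed` is initialized, every logged value is all-reduced and divided by the world size, so `log_vars` reports the average across ranks rather than the local value (the summed `loss` tensor itself is returned unreduced for the local backward pass). A minimal sketch of that reduction in isolation, illustrative only:

```python
import torch
import torch.distributed as dist


def average_across_ranks(value):
    """Mirror of the reduce step in _parse_losses: average a scalar tensor
    over all ranks, falling back to the local value in single-process runs."""
    if dist.is_available() and dist.is_initialized():
        value = value.data.clone()
        dist.all_reduce(value.div_(dist.get_world_size()))
    return value.item()


print(average_across_ranks(torch.tensor(0.5)))  # 0.5 when not distributed
```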
import pytest
import torch
from mmaction.apis import parse_losses
def test_parse_loss():
    with pytest.raises(TypeError):
        # loss must be a tensor or list of tensors
        losses = dict(loss=0.5)
        parse_losses(losses)

    # loss values are a tensor and a list of tensors
    a_loss = [torch.randn(5, 5), torch.randn(5, 5)]
    b_loss = torch.randn(5, 5)
    losses = dict(a_loss=a_loss, b_loss=b_loss)
    r_a_loss = sum(_loss.mean() for _loss in a_loss)
    r_b_loss = b_loss.mean()
    r_loss = [r_a_loss, r_b_loss]
    r_loss = sum(r_loss)

    loss, log_vars = parse_losses(losses)

    assert r_loss == loss
    assert set(log_vars.keys()) == set(['a_loss', 'b_loss', 'loss'])
    assert log_vars['a_loss'] == r_a_loss
    assert log_vars['b_loss'] == r_b_loss
    assert log_vars['loss'] == r_loss

    ones_loss = torch.ones(5, 5)
    losses = dict(ones_loss=ones_loss)
    loss, _ = parse_losses(losses)
    loss.requires_grad_(True)
    assert float(loss.item()) == 1.0
    loss.backward()
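With `parse_losses` no longer importable from `mmaction.apis`, an equivalent test would go through the class instead. A sketch of how the check above could be adapted; the `mmaction.models.BaseRecognizer` import path and calling the private static method directly are assumptions:

```python
import pytest
import torch

from mmaction.models import BaseRecognizer  # import path assumed


def test_parse_losses_static():
    # Non-tensor values are still rejected.
    with pytest.raises(TypeError):
        BaseRecognizer._parse_losses(dict(loss=0.5))

    # A tensor of ones averages to 1.0, which becomes the total loss.
    loss, log_vars = BaseRecognizer._parse_losses(
        dict(ones_loss=torch.ones(5, 5)))
    assert loss.item() == 1.0
    assert log_vars == {'ones_loss': 1.0, 'loss': 1.0}
```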