Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
openvinotoolkit
mmaction2
提交
4748baa2
M
mmaction2
项目概览
openvinotoolkit
/
mmaction2
大约 1 年 前同步成功
通知
2
Star
5
Fork
3
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
M
mmaction2
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
体验新版 GitCode,发现更多精彩内容 >>
提交
4748baa2
编写于
7月 01, 2020
作者:
L
linjintao
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Move parse_loss into class
上级
d1eb189f
变更
5
隐藏空白更改
内联
并排
Showing
5 changed file
with
44 addition
and
77 deletion
+44
-77
mmaction/apis/__init__.py
mmaction/apis/__init__.py
+2
-2
mmaction/apis/train.py
mmaction/apis/train.py
+0
-36
mmaction/models/localizers/base.py
mmaction/models/localizers/base.py
+2
-1
mmaction/models/recognizers/base.py
mmaction/models/recognizers/base.py
+40
-2
tests/test_core.py
tests/test_core.py
+0
-36
未找到文件。
mmaction/apis/__init__.py
浏览文件 @
4748baa2
from
.inference
import
inference_recognizer
,
init_recognizer
from
.test
import
multi_gpu_test
,
single_gpu_test
from
.train
import
parse_losses
,
set_random_seed
,
train_model
from
.train
import
set_random_seed
,
train_model
__all__
=
[
'set_random_seed'
,
'train_model'
,
'init_recognizer'
,
'inference_recognizer'
,
'multi_gpu_test'
,
'single_gpu_test'
,
'parse_losses'
'inference_recognizer'
,
'multi_gpu_test'
,
'single_gpu_test'
]
mmaction/apis/train.py
浏览文件 @
4748baa2
import
os
import
random
from
collections
import
OrderedDict
import
numpy
as
np
import
torch
import
torch.distributed
as
dist
from
mmcv.parallel
import
MMDataParallel
,
MMDistributedDataParallel
from
mmcv.runner
import
(
DistSamplerSeedHook
,
EpochBasedRunner
,
OptimizerHook
,
build_optimizer
)
...
...
@@ -35,40 +33,6 @@ def set_random_seed(seed, deterministic=False):
torch
.
backends
.
cudnn
.
benchmark
=
False
def
parse_losses
(
losses
):
"""Parse the raw outputs (losses) of the network.
Args:
losses (dict): Raw output of the network, which usually contain
losses and other necessary information.
Returns:
tuple[Tensor, dict]: (loss, log_vars), loss is the loss tensor
which may be a weighted sum of all losses, log_vars contains
all the variables to be sent to the logger.
"""
log_vars
=
OrderedDict
()
for
loss_name
,
loss_value
in
losses
.
items
():
if
isinstance
(
loss_value
,
torch
.
Tensor
):
log_vars
[
loss_name
]
=
loss_value
.
mean
()
elif
isinstance
(
loss_value
,
list
):
log_vars
[
loss_name
]
=
sum
(
_loss
.
mean
()
for
_loss
in
loss_value
)
else
:
raise
TypeError
(
f
'
{
loss_name
}
is not a tensor or list of tensors'
)
loss
=
sum
(
_value
for
_key
,
_value
in
log_vars
.
items
()
if
'loss'
in
_key
)
log_vars
[
'loss'
]
=
loss
for
loss_name
,
loss_value
in
log_vars
.
items
():
# reduce loss when distributed training
if
dist
.
is_available
()
and
dist
.
is_initialized
():
loss_value
=
loss_value
.
data
.
clone
()
dist
.
all_reduce
(
loss_value
.
div_
(
dist
.
get_world_size
()))
log_vars
[
loss_name
]
=
loss_value
.
item
()
return
loss
,
log_vars
def
train_model
(
model
,
dataset
,
cfg
,
...
...
mmaction/models/localizers/base.py
浏览文件 @
4748baa2
...
...
@@ -37,7 +37,8 @@ class BaseLocalizer(nn.Module, metaclass=ABCMeta):
else
:
return
self
.
forward_test
(
imgs
)
def
_parse_losses
(
self
,
losses
):
@
staticmethod
def
_parse_losses
(
losses
):
"""Parse the raw outputs (losses) of the network.
Args:
...
...
mmaction/models/recognizers/base.py
浏览文件 @
4748baa2
from
abc
import
ABCMeta
,
abstractmethod
from
collections
import
OrderedDict
import
torch
import
torch.distributed
as
dist
import
torch.nn
as
nn
import
torch.nn.functional
as
F
from
...apis
import
parse_losses
from
..
import
builder
...
...
@@ -75,6 +77,42 @@ class BaseRecognizer(nn.Module, metaclass=ABCMeta):
def
forward_test
(
self
,
imgs
):
pass
@
staticmethod
def
_parse_losses
(
losses
):
"""Parse the raw outputs (losses) of the network.
Args:
losses (dict): Raw output of the network, which usually contain
losses and other necessary information.
Returns:
tuple[Tensor, dict]: (loss, log_vars), loss is the loss tensor
which may be a weighted sum of all losses, log_vars contains
all the variables to be sent to the logger.
"""
log_vars
=
OrderedDict
()
for
loss_name
,
loss_value
in
losses
.
items
():
if
isinstance
(
loss_value
,
torch
.
Tensor
):
log_vars
[
loss_name
]
=
loss_value
.
mean
()
elif
isinstance
(
loss_value
,
list
):
log_vars
[
loss_name
]
=
sum
(
_loss
.
mean
()
for
_loss
in
loss_value
)
else
:
raise
TypeError
(
f
'
{
loss_name
}
is not a tensor or list of tensors'
)
loss
=
sum
(
_value
for
_key
,
_value
in
log_vars
.
items
()
if
'loss'
in
_key
)
log_vars
[
'loss'
]
=
loss
for
loss_name
,
loss_value
in
log_vars
.
items
():
# reduce loss when distributed training
if
dist
.
is_available
()
and
dist
.
is_initialized
():
loss_value
=
loss_value
.
data
.
clone
()
dist
.
all_reduce
(
loss_value
.
div_
(
dist
.
get_world_size
()))
log_vars
[
loss_name
]
=
loss_value
.
item
()
return
loss
,
log_vars
def
forward
(
self
,
imgs
,
label
=
None
,
return_loss
=
True
):
if
return_loss
:
if
label
is
None
:
...
...
@@ -89,7 +127,7 @@ class BaseRecognizer(nn.Module, metaclass=ABCMeta):
losses
=
self
.
forward
(
imgs
,
label
)
loss
,
log_vars
=
parse_losses
(
losses
)
loss
,
log_vars
=
self
.
_
parse_losses
(
losses
)
outputs
=
dict
(
loss
=
loss
,
...
...
tests/test_core.py
已删除
100644 → 0
浏览文件 @
d1eb189f
import
pytest
import
torch
from
mmaction.apis
import
parse_losses
def
test_parse_loss
():
with
pytest
.
raises
(
TypeError
):
# loss must be a tensor or list of tensors
losses
=
dict
(
loss
=
0.5
)
parse_losses
(
losses
)
# loss values are a tenor and list of tensors
a_loss
=
[
torch
.
randn
(
5
,
5
),
torch
.
randn
(
5
,
5
)]
b_loss
=
torch
.
randn
(
5
,
5
)
losses
=
dict
(
a_loss
=
a_loss
,
b_loss
=
b_loss
)
r_a_loss
=
sum
(
_loss
.
mean
()
for
_loss
in
a_loss
)
r_b_loss
=
b_loss
.
mean
()
r_loss
=
[
r_a_loss
,
r_b_loss
]
r_loss
=
sum
(
r_loss
)
loss
,
log_vars
=
parse_losses
(
losses
)
assert
r_loss
==
loss
assert
set
(
log_vars
.
keys
())
==
set
([
'a_loss'
,
'b_loss'
,
'loss'
])
assert
log_vars
[
'a_loss'
]
==
r_a_loss
assert
log_vars
[
'b_loss'
]
==
r_b_loss
assert
log_vars
[
'loss'
]
==
r_loss
ones_loss
=
torch
.
ones
(
5
,
5
)
losses
=
dict
(
ones_loss
=
ones_loss
)
loss
,
_
=
parse_losses
(
losses
)
loss
.
requires_grad_
(
True
)
assert
float
(
loss
.
item
())
==
1.0
loss
.
backward
()
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录