Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
Oneflow-Inc
of-maskrcnn-benchmark
提交
b0ff535b
O
of-maskrcnn-benchmark
项目概览
Oneflow-Inc
/
of-maskrcnn-benchmark
10 个月 前同步成功
通知
2
Star
0
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
O
of-maskrcnn-benchmark
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
前往新版Gitcode,体验更适合开发者的 AI 搜索 >>
提交
b0ff535b
编写于
3月 20, 2019
作者:
S
ScXfjiang
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
dump momentum_buffer for each iteration
上级
2fe865bf
变更
3
隐藏空白更改
内联
并排
Showing
3 changed file
with
17 addition
and
12 deletion
+17
-12
maskrcnn_benchmark/engine/trainer.py
maskrcnn_benchmark/engine/trainer.py
+15
-1
tools/train_net.py
tools/train_net.py
+1
-10
train.sh
train.sh
+1
-1
未找到文件。
maskrcnn_benchmark/engine/trainer.py
浏览文件 @
b0ff535b
...
...
@@ -18,6 +18,8 @@ import os
from
functools
import
partial
import
pickle
as
pkl
def
reduce_loss_dict
(
loss_dict
):
"""
Reduce the loss dictionary from all processes so that process with rank
...
...
@@ -44,6 +46,7 @@ def reduce_loss_dict(loss_dict):
def
do_train
(
cfg
,
model
,
data_loader
,
optimizer
,
...
...
@@ -103,6 +106,17 @@ def do_train(
losses
.
backward
()
optimizer
.
step
()
if
not
os
.
path
.
exists
(
"model_name2momentum_buffer/"
):
os
.
makedirs
(
"model_name2momentum_buffer/"
)
state_dict
=
optimizer
.
state_dict
()
model_name2momentum_buffer
=
{}
for
key
,
value
in
model
.
named_parameters
():
if
value
.
requires_grad
:
momentum_buffer
=
state_dict
[
'state'
][
id
(
value
)][
'momentum_buffer'
].
cpu
().
detach
().
numpy
()
model_name2momentum_buffer
[
key
]
=
momentum_buffer
pkl
.
dump
(
model_name2momentum_buffer
,
open
(
"model_name2momentum_buffer/"
+
os
.
path
.
basename
(
cfg
.
MODEL
.
WEIGHT
)
\
+
"-iteration-"
+
str
(
iteration
)
+
'-model_name2momentum_buffer.pkl'
,
'w'
))
batch_time
=
time
.
time
()
-
end
end
=
time
.
time
()
meters
.
update
(
time
=
batch_time
,
data
=
data_time
)
...
...
@@ -110,7 +124,7 @@ def do_train(
eta_seconds
=
meters
.
time
.
global_avg
*
(
max_iter
-
iteration
)
eta_string
=
str
(
datetime
.
timedelta
(
seconds
=
int
(
eta_seconds
)))
if
iteration
%
20
==
0
or
iteration
==
max_iter
:
if
iteration
%
1
==
0
or
iteration
==
max_iter
:
logger
.
info
(
meters
.
delimiter
.
join
(
[
...
...
tools/train_net.py
浏览文件 @
b0ff535b
...
...
@@ -25,8 +25,6 @@ from maskrcnn_benchmark.utils.imports import import_file
from
maskrcnn_benchmark.utils.logger
import
setup_logger
from
maskrcnn_benchmark.utils.miscellaneous
import
mkdir
import
pickle
as
pkl
def
train
(
cfg
,
local_rank
,
distributed
):
model
=
build_detection_model
(
cfg
)
device
=
torch
.
device
(
cfg
.
MODEL
.
DEVICE
)
...
...
@@ -53,14 +51,6 @@ def train(cfg, local_rank, distributed):
)
extra_checkpoint_data
=
checkpointer
.
load
(
cfg
.
MODEL
.
WEIGHT
)
arguments
.
update
(
extra_checkpoint_data
)
state_dict
=
optimizer
.
state_dict
()
model_name2momentum_buffer
=
{}
for
key
,
value
in
model
.
named_parameters
():
if
value
.
requires_grad
:
momentum_buffer
=
state_dict
[
'state'
][
id
(
value
)][
'momentum_buffer'
].
cpu
().
detach
().
numpy
()
model_name2momentum_buffer
[
key
]
=
momentum_buffer
pkl
.
dump
(
model_name2momentum_buffer
,
open
(
os
.
path
.
basename
(
cfg
.
MODEL
.
WEIGHT
)
+
'.model_name2momentum_buffer.pkl'
,
'w'
))
data_loader
=
make_data_loader
(
cfg
,
...
...
@@ -73,6 +63,7 @@ def train(cfg, local_rank, distributed):
arguments
[
"fake_image"
]
=
cfg
.
DATALOADER
.
FAKE_IMAGE_DATA_PATH
do_train
(
cfg
,
model
,
data_loader
,
optimizer
,
...
...
train.sh
浏览文件 @
b0ff535b
...
...
@@ -6,7 +6,7 @@ rm -f last_checkpoint
rm
-f
model_final.pth
rm
-f
log.txt
rm
-f
model_0090000.pth
rm
-
f
e2e_mask_rcnn_R_50_FPN_1x.pth.model_name2momentum_buffer.pkl
rm
-
rf
model_name2momentum_buffer
CUDA_VISIBLE_DEVICES
=
1
\
python ./tools/train_net.py
\
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录