Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
magicwindyyd
mindspore
提交
e0d14446
M
mindspore
项目概览
magicwindyyd
/
mindspore
与 Fork 源项目一致
Fork自
MindSpore / mindspore
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
M
mindspore
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
e0d14446
编写于
8月 05, 2020
作者:
M
mindspore-ci-bot
提交者:
Gitee
8月 05, 2020
浏览文件
操作
浏览文件
下载
差异文件
!3949 support pretrain for maskrcnn
Merge pull request !3949 from meixiaowei/master
上级
0154bdeb
512900c5
变更
4
隐藏空白更改
内联
并排
Showing
4 changed file
with
11 addition
and
8 deletion
+11
-8
model_zoo/official/cv/maskrcnn/README.md
model_zoo/official/cv/maskrcnn/README.md
+1
-1
model_zoo/official/cv/maskrcnn/src/config.py
model_zoo/official/cv/maskrcnn/src/config.py
+1
-0
model_zoo/official/cv/maskrcnn/src/lr_schedule.py
model_zoo/official/cv/maskrcnn/src/lr_schedule.py
+3
-3
model_zoo/official/cv/maskrcnn/train.py
model_zoo/official/cv/maskrcnn/train.py
+6
-4
未找到文件。
model_zoo/official/cv/maskrcnn/README.md
浏览文件 @
e0d14446
...
...
@@ -35,7 +35,7 @@ MaskRcnn is a two-stage target detection network,This network uses a region prop
└─train2017
```
Notice that the coco2017 dataset will be converted to MindRecord which is a data format in MindSpore. The dataset conversion may take about 4 hours.
2. If your own dataset is used. **Select dataset to other when run script.**
Organize the dataset infomation into a TXT file, each row in the file is as follows:
...
...
model_zoo/official/cv/maskrcnn/src/config.py
浏览文件 @
e0d14446
...
...
@@ -134,6 +134,7 @@ config = ed({
"loss_scale"
:
1
,
"momentum"
:
0.91
,
"weight_decay"
:
1e-4
,
"pretrain_epoch_size"
:
0
,
"epoch_size"
:
12
,
"save_checkpoint"
:
True
,
"save_checkpoint_epochs"
:
1
,
...
...
model_zoo/official/cv/maskrcnn/src/lr_schedule.py
浏览文件 @
e0d14446
...
...
@@ -25,7 +25,7 @@ def a_cosine_learning_rate(current_step, base_lr, warmup_steps, decay_steps):
learning_rate
=
(
1
+
math
.
cos
(
base
*
math
.
pi
))
/
2
*
base_lr
return
learning_rate
def
dynamic_lr
(
config
,
rank_size
=
1
):
def
dynamic_lr
(
config
,
rank_size
=
1
,
start_steps
=
0
):
"""dynamic learning rate generator"""
base_lr
=
config
.
base_lr
...
...
@@ -38,5 +38,5 @@ def dynamic_lr(config, rank_size=1):
lr
.
append
(
linear_warmup_learning_rate
(
i
,
warmup_steps
,
base_lr
,
base_lr
*
config
.
warmup_ratio
))
else
:
lr
.
append
(
a_cosine_learning_rate
(
i
,
base_lr
,
warmup_steps
,
total_steps
))
return
l
r
learning_rate
=
lr
[
start_steps
:]
return
l
earning_rate
model_zoo/official/cv/maskrcnn/train.py
浏览文件 @
e0d14446
...
...
@@ -108,13 +108,15 @@ if __name__ == '__main__':
load_path
=
args_opt
.
pre_trained
if
load_path
!=
""
:
param_dict
=
load_checkpoint
(
load_path
)
for
item
in
list
(
param_dict
.
keys
()):
if
not
(
item
.
startswith
(
'backbone'
)
or
item
.
startswith
(
'rcnn_mask'
)):
param_dict
.
pop
(
item
)
if
config
.
pretrain_epoch_size
==
0
:
for
item
in
list
(
param_dict
.
keys
()):
if
not
(
item
.
startswith
(
'backbone'
)
or
item
.
startswith
(
'rcnn_mask'
)):
param_dict
.
pop
(
item
)
load_param_into_net
(
net
,
param_dict
)
loss
=
LossNet
()
lr
=
Tensor
(
dynamic_lr
(
config
,
rank_size
=
device_num
),
mstype
.
float32
)
lr
=
Tensor
(
dynamic_lr
(
config
,
rank_size
=
device_num
,
start_steps
=
config
.
pretrain_epoch_size
*
dataset_size
),
mstype
.
float32
)
opt
=
SGD
(
params
=
net
.
trainable_params
(),
learning_rate
=
lr
,
momentum
=
config
.
momentum
,
weight_decay
=
config
.
weight_decay
,
loss_scale
=
config
.
loss_scale
)
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录