Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
PaddleDetection
提交
e4ccc4d7
P
PaddleDetection
项目概览
PaddlePaddle
/
PaddleDetection
大约 1 年 前同步成功
通知
695
Star
11112
Fork
2696
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
184
列表
看板
标记
里程碑
合并请求
40
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleDetection
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
184
Issue
184
列表
看板
标记
里程碑
合并请求
40
合并请求
40
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
e4ccc4d7
编写于
8月 19, 2020
作者:
S
sunxl1988
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
test=dygraph update eval
上级
37f7e6ee
变更
15
隐藏空白更改
内联
并排
Showing
15 changed file
with
90 addition
and
172 deletion
+90
-172
configs/architechture/cascade_mask_rcnn.yml
configs/architechture/cascade_mask_rcnn.yml
+0
-91
configs/base/mask_rcnn_r50_1x.yml
configs/base/mask_rcnn_r50_1x.yml
+2
-2
configs/base/mask_rcnn_r50_fpn_1x.yml
configs/base/mask_rcnn_r50_fpn_1x.yml
+2
-2
configs/example/mask_rcnn_r50_1x.yml
configs/example/mask_rcnn_r50_1x.yml
+2
-0
configs/example/mask_rcnn_r50_fpn_1x.yml
configs/example/mask_rcnn_r50_fpn_1x.yml
+1
-0
configs/reader/faster_rcnn.yml
configs/reader/faster_rcnn.yml
+5
-3
configs/reader/mask_rcnn.yml
configs/reader/mask_rcnn.yml
+7
-5
ppdet/core/workspace.py
ppdet/core/workspace.py
+37
-51
ppdet/data/loader.py
ppdet/data/loader.py
+1
-2
ppdet/data/source/coco.py
ppdet/data/source/coco.py
+0
-1
ppdet/modeling/head/bbox_head.py
ppdet/modeling/head/bbox_head.py
+28
-5
ppdet/utils/checkpoint.py
ppdet/utils/checkpoint.py
+1
-1
ppdet/utils/eval_utils.py
ppdet/utils/eval_utils.py
+3
-3
tools/eval.py
tools/eval.py
+0
-5
tools/train.py
tools/train.py
+1
-1
未找到文件。
configs/architechture/cascade_mask_rcnn.yml
已删除
100644 → 0
浏览文件 @
37f7e6ee
# Model Achitecture
CascadeRCNN
:
# model anchor info flow
anchor
:
AnchorRPN
proposal
:
Proposal
mask
:
Mask
# model feat info flow
backbone
:
ResNet
rpn_head
:
RPNHead
bbox_head
:
BBoxHead
mask_head
:
MaskHead
ResNet
:
norm_type
:
'
affine'
depth
:
50
freeze_at
:
'
res2'
RPNHead
:
rpn_feat
:
name
:
RPNFeat
feat_in
:
1024
feat_out
:
1024
anchor_per_position
:
15
BBoxHead
:
bbox_feat
:
name
:
BBoxFeat
feat_in
:
1024
feat_out
:
512
roi_extractor
:
resolution
:
14
sampling_ratio
:
0
spatial_scale
:
0.0625
extractor_type
:
'
RoIAlign'
MaskHead
:
mask_feat
:
name
:
MaskFeat
feat_in
:
2048
feat_out
:
256
feat_in
:
256
resolution
:
14
AnchorRPN
:
anchor_generator
:
name
:
AnchorGeneratorRPN
anchor_sizes
:
[
32
,
64
,
128
,
256
,
512
]
aspect_ratios
:
[
0.5
,
1.0
,
2.0
]
stride
:
[
16.0
,
16.0
]
variance
:
[
1.0
,
1.0
,
1.0
,
1.0
]
anchor_target_generator
:
name
:
AnchorTargetGeneratorRPN
batch_size_per_im
:
256
fg_fraction
:
0.5
negative_overlap
:
0.3
positive_overlap
:
0.7
straddle_thresh
:
0.0
Proposal
:
proposal_generator
:
name
:
ProposalGenerator
min_size
:
0.0
nms_thresh
:
0.7
train_pre_nms_top_n
:
2000
train_post_nms_top_n
:
2000
infer_pre_nms_top_n
:
2000
infer_post_nms_top_n
:
2000
return_rois_num
:
True
proposal_target_generator
:
name
:
ProposalTargetGenerator
batch_size_per_im
:
512
bbox_reg_weights
:
[[
0.1
,
0.1
,
0.2
,
0.2
],[
0.05
,
0.05
,
0.1
,
0.1
],[
0.333333
,
0.333333
,
0.666666
,
0.666666
]]
bg_thresh_hi
:
[
0.5
,
0.6
,
0.7
]
bg_thresh_lo
:
[
0.0
,
0.0
,
0.0
]
fg_thresh
:
[
0.5
,
0.6
,
0.7
]
fg_fraction
:
0.25
bbox_post_process
:
# used in infer
name
:
BBoxPostProcess
# decode -> clip -> nms
decode_clip_nms
:
name
:
DecodeClipNms
keep_top_k
:
100
score_threshold
:
0.05
nms_threshold
:
0.5
Mask
:
mask_target_generator
:
name
:
MaskTargetGenerator
resolution
:
14
mask_post_process
:
name
:
MaskPostProcess
configs/base/mask_rcnn_r50_1x.yml
浏览文件 @
e4ccc4d7
architecture
:
MaskRCNN
pretrain_weights
:
https://paddle
models.bj.bcebos.com/object_detection/dygraph/resnet50.pdparams
pretrain_weights
:
https://paddle
-imagenet-models-name.bj.bcebos.com/ResNet50_cos_pretrained.tar
weights
:
output/mask_rcnn_r50_1x/model_final
use_gpu
:
true
epoch
:
24
...
...
@@ -9,7 +9,7 @@ log_smooth_window: 20
save_dir
:
output
metric
:
COCO
num_classes
:
81
load_static_weights
:
tru
e
load_static_weights
:
fals
e
_READER_
:
'
../reader/mask_rcnn.yml'
_ARCHITECHTURE_
:
'
../architechture/mask_rcnn.yml'
...
...
configs/base/mask_rcnn_r50_fpn_1x.yml
浏览文件 @
e4ccc4d7
architecture
:
MaskRCNN
pretrain_weights
:
https://paddle
models.bj.bcebos.com/object_detection/dygraph/resnet50.pdparams
pretrain_weights
:
https://paddle
-imagenet-models-name.bj.bcebos.com/ResNet50_cos_pretrained.tar
weights
:
output/mask_rcnn_r50_1x/model_final
use_gpu
:
true
epoch
:
24
...
...
@@ -9,7 +9,7 @@ log_smooth_window: 20
save_dir
:
output
metric
:
COCO
num_classes
:
81
load_static_weights
:
tru
e
load_static_weights
:
fals
e
_READER_
:
'
../reader/mask_rcnn.yml'
_ARCHITECHTURE_
:
'
../architechture/mask_rcnn_fpn.yml'
...
...
configs/example/mask_rcnn_r50_1x.yml
浏览文件 @
e4ccc4d7
...
...
@@ -7,6 +7,8 @@ log_smooth_window: 20
save_dir
:
output
metric
:
COCO
num_classes
:
81
load_static_weights
:
true
log_iter
:
20
TrainReader
:
inputs_def
:
...
...
configs/example/mask_rcnn_r50_fpn_1x.yml
浏览文件 @
e4ccc4d7
...
...
@@ -8,6 +8,7 @@ save_dir: output
metric
:
COCO
num_classes
:
81
weights
:
output/mask_r50_fpn_1x/model_final.pdparams
EvalReader
:
dataset
:
dataset_dir
:
/home/ai/dataset/COCO17
configs/reader/faster_rcnn.yml
浏览文件 @
e4ccc4d7
...
...
@@ -3,9 +3,10 @@ TrainReader:
fields
:
[
'
image'
,
'
im_info'
,
'
im_id'
,
'
gt_bbox'
,
'
gt_class'
,
'
is_crowd'
]
dataset
:
name
:
COCODataset
dataset_dir
:
/home/ai/dataset/COCO17/
dataset_dir
:
dataset/coco
image_dir
:
train2017
anno_path
:
annotations/instances_train2017.json
with_background
:
true
sample_transforms
:
-
DecodeImage
:
{
to_rgb
:
true
}
-
RandomFlipImage
:
{
prob
:
0.5
}
...
...
@@ -24,14 +25,15 @@ EvalReader:
name
:
COCODataset
image_dir
:
val2017
anno_path
:
annotations/instances_val2017.json
dataset_dir
:
/home/ai/dataset/COCO17
dataset_dir
:
dataset/coco
with_background
:
true
sample_transforms
:
-
DecodeImage
:
{
to_rgb
:
true
}
-
NormalizeImage
:
{
is_channel_first
:
false
,
is_scale
:
true
,
mean
:
[
0.485
,
0.456
,
0.406
],
std
:
[
0.229
,
0.224
,
0.225
]}
-
ResizeImage
:
{
interp
:
1
,
max_size
:
1333
,
target_size
:
800
,
use_cv2
:
true
}
-
Permute
:
{
channel_first
:
true
,
to_bgr
:
false
}
batch_transforms
:
-
PadBatch
:
{
pad_to_stride
:
0
,
use_padded_im_info
:
false
,
pad_gt
:
Tru
e
}
-
PadBatch
:
{
pad_to_stride
:
0
,
use_padded_im_info
:
false
,
pad_gt
:
Fals
e
}
batch_size
:
2
shuffle
:
false
drop_empty
:
false
...
...
configs/reader/mask_rcnn.yml
浏览文件 @
e4ccc4d7
worker_num
:
0
use_prefetch
:
F
alse
use_prefetch
:
f
alse
TrainReader
:
inputs_def
:
...
...
@@ -9,6 +9,7 @@ TrainReader:
dataset_dir
:
dataset/coco
image_dir
:
train2017
anno_path
:
annotations/instances_train2017.json
with_background
:
true
sample_transforms
:
-
DecodeImage
:
{
to_rgb
:
true
}
-
RandomFlipImage
:
{
prob
:
0.5
,
is_mask_flip
:
true
}
...
...
@@ -16,7 +17,7 @@ TrainReader:
-
ResizeImage
:
{
target_size
:
800
,
max_size
:
1333
,
interp
:
1
,
use_cv2
:
true
}
-
Permute
:
{
to_bgr
:
false
,
channel_first
:
true
}
batch_transforms
:
-
PadBatch
:
{
pad_to_stride
:
32
,
use_padded_im_info
:
false
,
pad_gt
:
T
rue
}
-
PadBatch
:
{
pad_to_stride
:
32
,
use_padded_im_info
:
false
,
pad_gt
:
t
rue
}
batch_size
:
1
shuffle
:
true
drop_last
:
false
...
...
@@ -24,19 +25,20 @@ TrainReader:
EvalReader
:
inputs_def
:
fields
:
[
'
image'
,
'
im_info'
,
'
im_id'
,
'
im_shape'
]
fields
:
[
'
image'
,
'
im_info'
,
'
im_id'
]
dataset
:
name
:
COCODataset
dataset_dir
:
dataset/coco
image_dir
:
val2017
anno_path
:
annotations/instances_val2017.json
with_background
:
true
sample_transforms
:
-
DecodeImage
:
{
to_rgb
:
true
}
-
NormalizeImage
:
{
is_channel_first
:
false
,
is_scale
:
true
,
mean
:
[
0.485
,
0.456
,
0.406
],
std
:
[
0.229
,
0.224
,
0.225
]}
-
ResizeImage
:
{
interp
:
1
,
max_size
:
1333
,
target_size
:
800
,
use_cv2
:
true
}
-
Permute
:
{
channel_first
:
true
,
to_bgr
:
false
}
batch_transforms
:
-
PadBatch
:
{
pad_to_stride
:
32
,
use_padded_im_info
:
false
,
pad_gt
:
Tru
e
}
-
PadBatch
:
{
pad_to_stride
:
32
,
use_padded_im_info
:
false
,
pad_gt
:
fals
e
}
batch_size
:
1
shuffle
:
false
drop_last
:
false
...
...
@@ -45,7 +47,7 @@ EvalReader:
TestReader
:
inputs_def
:
fields
:
[
'
image'
,
'
im_info'
,
'
im_id'
,
'
im_shape'
]
fields
:
[
'
image'
,
'
im_info'
,
'
im_id'
]
dataset
:
name
:
ImageFolder
anno_path
:
annotations/instances_val2017.json
...
...
ppdet/core/workspace.py
浏览文件 @
e4ccc4d7
...
...
@@ -66,73 +66,62 @@ class AttrDict(dict):
global_config
=
AttrDict
()
READER_KEY
=
'_READER_'
def
load_config
(
file_path
):
"""
Load config from file.
cfg
=
load_cfg
(
file_path
)
if
'_BASE_'
in
cfg
.
keys
():
base_cfg
=
load_cfg
(
cfg
[
'_BASE_'
])
del
cfg
[
'_BASE_'
]
# merge cfg into base_cfg
cfg
=
merge_config
(
cfg
,
base_cfg
)
# merge cfg int global_config
merge_config
(
cfg
)
return
global_config
Args:
file_path (str): Path of the config file to be loaded.
Returns: global config
"""
def
load_cfg
(
file_path
):
_
,
ext
=
os
.
path
.
splitext
(
file_path
)
assert
ext
in
[
'.yml'
,
'.yaml'
],
"only support yaml files for now"
cfg
=
AttrDict
()
with
open
(
file_path
)
as
f
:
cfg
=
merge_config
(
yaml
.
load
(
f
,
Loader
=
yaml
.
Loader
),
cfg
)
cfg
=
load_part_cfg
(
cfg
,
file_path
)
return
cfg
if
READER_KEY
in
cfg
:
reader_cfg
=
cfg
[
READER_KEY
]
if
reader_cfg
.
startswith
(
"~"
):
reader_cfg
=
os
.
path
.
expanduser
(
reader_cfg
)
if
not
reader_cfg
.
startswith
(
'/'
):
reader_cfg
=
os
.
path
.
join
(
os
.
path
.
dirname
(
file_path
),
reader_cfg
)
with
open
(
reader_cfg
)
as
f
:
merge_config
(
yaml
.
load
(
f
,
Loader
=
yaml
.
Loader
))
del
cfg
[
READER_KEY
]
PART_KEY
=
[
'_READER_'
,
'_ARCHITECHTURE_'
,
'_OPTIMIZE_'
]
merge_config
(
cfg
)
return
global_config
def
load_part_cfg
(
cfg
,
file_path
):
for
part_k
in
PART_KEY
:
if
part_k
in
cfg
:
part_cfg
=
cfg
[
part_k
]
if
part_cfg
.
startswith
(
"~"
):
part_cfg
=
os
.
path
.
expanduser
(
part_cfg
)
if
not
part_cfg
.
startswith
(
'/'
):
part_cfg
=
os
.
path
.
join
(
os
.
path
.
dirname
(
file_path
),
part_cfg
)
def
dict_merge
(
dct
,
merge_dct
):
""" Recursive dict merge. Inspired by :meth:``dict.update()``, instead of
updating only top-level keys, dict_merge recurses down into dicts nested
to an arbitrary depth, updating keys. The ``merge_dct`` is merged into
``dct``.
with
open
(
part_cfg
)
as
f
:
merge_config
(
yaml
.
load
(
f
,
Loader
=
yaml
.
Loader
),
cfg
)
del
cfg
[
part_k
]
Args:
dct: dict onto which the merge is executed
merge_dct: dct merged into dct
Returns: dct
"""
for
k
,
v
in
merge_dct
.
items
():
if
(
k
in
dct
and
isinstance
(
dct
[
k
],
dict
)
and
isinstance
(
merge_dct
[
k
],
collections
.
Mapping
)):
dict_merge
(
dct
[
k
],
merge_dct
[
k
])
else
:
dct
[
k
]
=
merge_dct
[
k
]
return
dct
return
cfg
def
merge_config
(
config
,
another_cfg
=
None
):
"""
Merge config into global config or another_cfg.
def
merge_config
(
config
,
other_cfg
=
None
):
global
global_config
dct
=
other_cfg
if
other_cfg
is
not
None
else
global_config
return
merge_dict
(
dct
,
config
)
Args:
config (dict): Config to be merged.
Returns: global config
"""
global
global_config
dct
=
another_cfg
if
another_cfg
is
not
None
else
global_config
return
dict_merge
(
dct
,
config
)
def
merge_dict
(
dct
,
other_dct
):
for
k
,
v
in
other_dct
.
items
():
if
(
k
in
dct
and
isinstance
(
dct
[
k
],
dict
)
and
isinstance
(
other_dct
[
k
],
collections
.
Mapping
)):
merge_dict
(
dct
[
k
],
other_dct
[
k
])
else
:
dct
[
k
]
=
other_dct
[
k
]
return
dct
def
get_registered_modules
():
...
...
@@ -251,7 +240,4 @@ def create(cls_or_name, **kwargs):
kwargs
[
k
]
=
new_dict
else
:
raise
ValueError
(
"Unsupported injection type:"
,
target_key
)
# prevent modification of global config values of reference types
# (e.g., list, dict) from within the created module instances
#kwargs = copy.deepcopy(kwargs)
return
cls
(
**
kwargs
)
ppdet/data/loader.py
浏览文件 @
e4ccc4d7
...
...
@@ -24,7 +24,6 @@ class Compose(object):
self
.
transforms_cls
=
[]
for
t
in
self
.
transforms
:
for
k
,
v
in
t
.
items
():
print
(
k
,
v
)
op_cls
=
getattr
(
from_
,
k
)
self
.
transforms_cls
.
append
(
op_cls
(
**
v
))
if
hasattr
(
op_cls
,
'num_classes'
):
...
...
@@ -131,7 +130,7 @@ class BaseDataLoader(object):
drop_empty
=
True
,
num_classes
=
81
):
# dataset
self
.
_dataset
=
dataset
#create(dataset['name'])
self
.
_dataset
=
dataset
self
.
_dataset
.
parse_dataset
()
# out fields
self
.
_fields
=
copy
.
deepcopy
(
inputs_def
[
...
...
ppdet/data/source/coco.py
浏览文件 @
e4ccc4d7
...
...
@@ -20,7 +20,6 @@ class COCODataset(DetDataset):
with_background
,
sample_num
)
self
.
load_image_only
=
False
self
.
load_semantic
=
False
#self.parse_dataset()
def
parse_dataset
(
self
):
anno_path
=
os
.
path
.
join
(
self
.
dataset_dir
,
self
.
anno_path
)
...
...
ppdet/modeling/head/bbox_head.py
浏览文件 @
e4ccc4d7
...
...
@@ -5,15 +5,34 @@ from paddle.fluid.initializer import Normal, Xavier
from
paddle.fluid.regularizer
import
L2Decay
from
paddle.fluid.dygraph.nn
import
Conv2D
,
Pool2D
,
Linear
from
ppdet.core.workspace
import
register
from
..backbone.name_adapter
import
NameAdapter
from
..backbone.resnet
import
Blocks
@
register
class
TwoFCHead
(
Layer
):
class
Res5Feat
(
Layer
):
def
__init__
(
self
,
feat_in
=
1024
,
feat_out
=
512
):
super
(
Res5Feat
,
self
).
__init__
()
na
=
NameAdapter
(
self
)
self
.
res5_conv
=
[]
self
.
res5
=
self
.
add_sublayer
(
'res5_roi_feat'
,
Blocks
(
feat_in
,
feat_out
,
count
=
3
,
name_adapter
=
na
,
stage_num
=
5
))
self
.
feat_out
=
feat_out
*
4
def
forward
(
self
,
roi_feat
,
stage
=
0
):
y
=
self
.
res5
(
roi_feat
)
return
y
@
register
class
TwoFCFeat
(
Layer
):
__shared__
=
[
'num_stages'
]
def
__init__
(
self
,
in_dim
=
256
,
mlp_dim
=
1024
,
resolution
=
7
,
num_stages
=
1
):
super
(
TwoFC
Head
,
self
).
__init__
()
super
(
TwoFC
Feat
,
self
).
__init__
()
self
.
in_dim
=
in_dim
self
.
mlp_dim
=
mlp_dim
self
.
num_stages
=
num_stages
...
...
@@ -94,6 +113,7 @@ class BBoxHead(Layer):
self
.
bbox_score_list
=
[]
self
.
bbox_delta_list
=
[]
self
.
with_pool
=
with_pool
for
stage
in
range
(
num_stages
):
score_name
=
'bbox_score_{}'
.
format
(
stage
)
delta_name
=
'bbox_delta_{}'
.
format
(
stage
)
...
...
@@ -132,11 +152,14 @@ class BBoxHead(Layer):
def
forward
(
self
,
body_feats
,
rois
,
spatial_scale
,
stage
=
0
):
bbox_feat
=
self
.
bbox_feat
(
body_feats
,
rois
,
spatial_scale
,
stage
)
if
self
.
with_pool
:
bbox_feat
=
fluid
.
layers
.
pool2d
(
bbox_feat
_
=
fluid
.
layers
.
pool2d
(
bbox_feat
,
pool_type
=
'avg'
,
global_pooling
=
True
)
bbox_feat_
=
fluid
.
layers
.
squeeze
(
bbox_feat_
,
axes
=
[
2
,
3
])
else
:
bbox_feat_
=
bbox_feat
bbox_head_out
=
[]
scores
=
self
.
bbox_score_list
[
stage
](
bbox_feat
)
deltas
=
self
.
bbox_delta_list
[
stage
](
bbox_feat
)
scores
=
self
.
bbox_score_list
[
stage
](
bbox_feat
_
)
deltas
=
self
.
bbox_delta_list
[
stage
](
bbox_feat
_
)
bbox_head_out
.
append
((
scores
,
deltas
))
return
bbox_feat
,
bbox_head_out
...
...
ppdet/utils/checkpoint.py
浏览文件 @
e4ccc4d7
...
...
@@ -59,8 +59,8 @@ def load_dygraph_ckpt(model,
assert
os
.
path
.
exists
(
ckpt
),
"Path {} does not exist."
.
format
(
ckpt
)
if
load_static_weights
:
pre_state_dict
=
fluid
.
load_program_state
(
ckpt
)
param_state_dict
=
{}
model_dict
=
model
.
state_dict
()
param_state_dict
=
{}
for
key
in
model_dict
.
keys
():
weight_name
=
model_dict
[
key
].
name
if
weight_name
in
pre_state_dict
.
keys
():
...
...
ppdet/utils/eval_utils.py
浏览文件 @
e4ccc4d7
...
...
@@ -11,7 +11,7 @@ def json_eval_results(metric, json_directory=None, dataset=None):
"""
assert
metric
==
'COCO'
from
ppdet.utils.coco_eval
import
cocoapi_eval
anno_file
=
dataset
.
get_anno
(
)
anno_file
=
os
.
path
.
join
(
dataset
[
'dataset_dir'
],
dataset
[
'anno_path'
]
)
json_file_list
=
[
'proposal.json'
,
'bbox.json'
,
'mask.json'
]
if
json_directory
:
assert
os
.
path
.
exists
(
...
...
@@ -36,10 +36,10 @@ def coco_eval_results(outs_res=None, include_mask=False, dataset=None):
from
pycocotools.coco
import
COCO
from
pycocotools.cocoeval
import
COCOeval
from
ppdet.py_op.post_process
import
get_det_res
,
get_seg_res
anno_file
=
os
.
path
.
join
(
dataset
.
dataset_dir
,
dataset
.
anno_path
)
anno_file
=
os
.
path
.
join
(
dataset
[
'dataset_dir'
],
dataset
[
'anno_path'
]
)
cocoGt
=
COCO
(
anno_file
)
catid
=
{
i
+
dataset
.
with_background
:
v
i
+
1
if
dataset
[
'with_background'
]
else
0
:
v
for
i
,
v
in
enumerate
(
cocoGt
.
getCatIds
())
}
...
...
tools/eval.py
浏览文件 @
e4ccc4d7
...
...
@@ -41,11 +41,6 @@ def parse_args():
def
run
(
FLAGS
,
cfg
,
place
):
if
FLAGS
.
use_gpu
:
devices_num
=
1
else
:
devices_num
=
int
(
os
.
environ
.
get
(
'CPU_NUM'
,
1
))
# Data
eval_loader
,
_
=
create
(
'EvalReader'
)(
cfg
[
'worker_num'
],
place
)
...
...
tools/train.py
浏览文件 @
e4ccc4d7
...
...
@@ -166,7 +166,7 @@ def run(FLAGS, cfg, place):
logger
.
info
(
strs
)
# Save Stage
if
fluid
.
dygraph
.
parallel
.
Env
().
local_rank
==
0
:
if
Parallel
Env
().
local_rank
==
0
:
cfg_name
=
os
.
path
.
basename
(
FLAGS
.
config
).
split
(
'.'
)[
0
]
save_name
=
str
(
e_id
+
1
)
if
e_id
+
1
!=
int
(
cfg
.
epoch
)
else
"model_final"
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录