Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
s920243400
PaddleDetection
提交
e4ccc4d7
P
PaddleDetection
项目概览
s920243400
/
PaddleDetection
与 Fork 源项目一致
Fork自
PaddlePaddle / PaddleDetection
通知
2
Star
0
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleDetection
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
e4ccc4d7
编写于
8月 19, 2020
作者:
S
sunxl1988
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
test=dygraph update eval
上级
37f7e6ee
变更
15
隐藏空白更改
内联
并排
Showing
15 changed file
with
90 addition
and
172 deletion
+90
-172
configs/architechture/cascade_mask_rcnn.yml
configs/architechture/cascade_mask_rcnn.yml
+0
-91
configs/base/mask_rcnn_r50_1x.yml
configs/base/mask_rcnn_r50_1x.yml
+2
-2
configs/base/mask_rcnn_r50_fpn_1x.yml
configs/base/mask_rcnn_r50_fpn_1x.yml
+2
-2
configs/example/mask_rcnn_r50_1x.yml
configs/example/mask_rcnn_r50_1x.yml
+2
-0
configs/example/mask_rcnn_r50_fpn_1x.yml
configs/example/mask_rcnn_r50_fpn_1x.yml
+1
-0
configs/reader/faster_rcnn.yml
configs/reader/faster_rcnn.yml
+5
-3
configs/reader/mask_rcnn.yml
configs/reader/mask_rcnn.yml
+7
-5
ppdet/core/workspace.py
ppdet/core/workspace.py
+37
-51
ppdet/data/loader.py
ppdet/data/loader.py
+1
-2
ppdet/data/source/coco.py
ppdet/data/source/coco.py
+0
-1
ppdet/modeling/head/bbox_head.py
ppdet/modeling/head/bbox_head.py
+28
-5
ppdet/utils/checkpoint.py
ppdet/utils/checkpoint.py
+1
-1
ppdet/utils/eval_utils.py
ppdet/utils/eval_utils.py
+3
-3
tools/eval.py
tools/eval.py
+0
-5
tools/train.py
tools/train.py
+1
-1
未找到文件。
configs/architechture/cascade_mask_rcnn.yml
已删除
100644 → 0
浏览文件 @
37f7e6ee
# Model Achitecture
CascadeRCNN
:
# model anchor info flow
anchor
:
AnchorRPN
proposal
:
Proposal
mask
:
Mask
# model feat info flow
backbone
:
ResNet
rpn_head
:
RPNHead
bbox_head
:
BBoxHead
mask_head
:
MaskHead
ResNet
:
norm_type
:
'
affine'
depth
:
50
freeze_at
:
'
res2'
RPNHead
:
rpn_feat
:
name
:
RPNFeat
feat_in
:
1024
feat_out
:
1024
anchor_per_position
:
15
BBoxHead
:
bbox_feat
:
name
:
BBoxFeat
feat_in
:
1024
feat_out
:
512
roi_extractor
:
resolution
:
14
sampling_ratio
:
0
spatial_scale
:
0.0625
extractor_type
:
'
RoIAlign'
MaskHead
:
mask_feat
:
name
:
MaskFeat
feat_in
:
2048
feat_out
:
256
feat_in
:
256
resolution
:
14
AnchorRPN
:
anchor_generator
:
name
:
AnchorGeneratorRPN
anchor_sizes
:
[
32
,
64
,
128
,
256
,
512
]
aspect_ratios
:
[
0.5
,
1.0
,
2.0
]
stride
:
[
16.0
,
16.0
]
variance
:
[
1.0
,
1.0
,
1.0
,
1.0
]
anchor_target_generator
:
name
:
AnchorTargetGeneratorRPN
batch_size_per_im
:
256
fg_fraction
:
0.5
negative_overlap
:
0.3
positive_overlap
:
0.7
straddle_thresh
:
0.0
Proposal
:
proposal_generator
:
name
:
ProposalGenerator
min_size
:
0.0
nms_thresh
:
0.7
train_pre_nms_top_n
:
2000
train_post_nms_top_n
:
2000
infer_pre_nms_top_n
:
2000
infer_post_nms_top_n
:
2000
return_rois_num
:
True
proposal_target_generator
:
name
:
ProposalTargetGenerator
batch_size_per_im
:
512
bbox_reg_weights
:
[[
0.1
,
0.1
,
0.2
,
0.2
],[
0.05
,
0.05
,
0.1
,
0.1
],[
0.333333
,
0.333333
,
0.666666
,
0.666666
]]
bg_thresh_hi
:
[
0.5
,
0.6
,
0.7
]
bg_thresh_lo
:
[
0.0
,
0.0
,
0.0
]
fg_thresh
:
[
0.5
,
0.6
,
0.7
]
fg_fraction
:
0.25
bbox_post_process
:
# used in infer
name
:
BBoxPostProcess
# decode -> clip -> nms
decode_clip_nms
:
name
:
DecodeClipNms
keep_top_k
:
100
score_threshold
:
0.05
nms_threshold
:
0.5
Mask
:
mask_target_generator
:
name
:
MaskTargetGenerator
resolution
:
14
mask_post_process
:
name
:
MaskPostProcess
configs/base/mask_rcnn_r50_1x.yml
浏览文件 @
e4ccc4d7
architecture
:
MaskRCNN
pretrain_weights
:
https://paddle
models.bj.bcebos.com/object_detection/dygraph/resnet50.pdparams
pretrain_weights
:
https://paddle
-imagenet-models-name.bj.bcebos.com/ResNet50_cos_pretrained.tar
weights
:
output/mask_rcnn_r50_1x/model_final
use_gpu
:
true
epoch
:
24
...
...
@@ -9,7 +9,7 @@ log_smooth_window: 20
save_dir
:
output
metric
:
COCO
num_classes
:
81
load_static_weights
:
tru
e
load_static_weights
:
fals
e
_READER_
:
'
../reader/mask_rcnn.yml'
_ARCHITECHTURE_
:
'
../architechture/mask_rcnn.yml'
...
...
configs/base/mask_rcnn_r50_fpn_1x.yml
浏览文件 @
e4ccc4d7
architecture
:
MaskRCNN
pretrain_weights
:
https://paddle
models.bj.bcebos.com/object_detection/dygraph/resnet50.pdparams
pretrain_weights
:
https://paddle
-imagenet-models-name.bj.bcebos.com/ResNet50_cos_pretrained.tar
weights
:
output/mask_rcnn_r50_1x/model_final
use_gpu
:
true
epoch
:
24
...
...
@@ -9,7 +9,7 @@ log_smooth_window: 20
save_dir
:
output
metric
:
COCO
num_classes
:
81
load_static_weights
:
tru
e
load_static_weights
:
fals
e
_READER_
:
'
../reader/mask_rcnn.yml'
_ARCHITECHTURE_
:
'
../architechture/mask_rcnn_fpn.yml'
...
...
configs/example/mask_rcnn_r50_1x.yml
浏览文件 @
e4ccc4d7
...
...
@@ -7,6 +7,8 @@ log_smooth_window: 20
save_dir
:
output
metric
:
COCO
num_classes
:
81
load_static_weights
:
true
log_iter
:
20
TrainReader
:
inputs_def
:
...
...
configs/example/mask_rcnn_r50_fpn_1x.yml
浏览文件 @
e4ccc4d7
...
...
@@ -8,6 +8,7 @@ save_dir: output
metric
:
COCO
num_classes
:
81
weights
:
output/mask_r50_fpn_1x/model_final.pdparams
EvalReader
:
dataset
:
dataset_dir
:
/home/ai/dataset/COCO17
configs/reader/faster_rcnn.yml
浏览文件 @
e4ccc4d7
...
...
@@ -3,9 +3,10 @@ TrainReader:
fields
:
[
'
image'
,
'
im_info'
,
'
im_id'
,
'
gt_bbox'
,
'
gt_class'
,
'
is_crowd'
]
dataset
:
name
:
COCODataset
dataset_dir
:
/home/ai/dataset/COCO17/
dataset_dir
:
dataset/coco
image_dir
:
train2017
anno_path
:
annotations/instances_train2017.json
with_background
:
true
sample_transforms
:
-
DecodeImage
:
{
to_rgb
:
true
}
-
RandomFlipImage
:
{
prob
:
0.5
}
...
...
@@ -24,14 +25,15 @@ EvalReader:
name
:
COCODataset
image_dir
:
val2017
anno_path
:
annotations/instances_val2017.json
dataset_dir
:
/home/ai/dataset/COCO17
dataset_dir
:
dataset/coco
with_background
:
true
sample_transforms
:
-
DecodeImage
:
{
to_rgb
:
true
}
-
NormalizeImage
:
{
is_channel_first
:
false
,
is_scale
:
true
,
mean
:
[
0.485
,
0.456
,
0.406
],
std
:
[
0.229
,
0.224
,
0.225
]}
-
ResizeImage
:
{
interp
:
1
,
max_size
:
1333
,
target_size
:
800
,
use_cv2
:
true
}
-
Permute
:
{
channel_first
:
true
,
to_bgr
:
false
}
batch_transforms
:
-
PadBatch
:
{
pad_to_stride
:
0
,
use_padded_im_info
:
false
,
pad_gt
:
Tru
e
}
-
PadBatch
:
{
pad_to_stride
:
0
,
use_padded_im_info
:
false
,
pad_gt
:
Fals
e
}
batch_size
:
2
shuffle
:
false
drop_empty
:
false
...
...
configs/reader/mask_rcnn.yml
浏览文件 @
e4ccc4d7
worker_num
:
0
use_prefetch
:
F
alse
use_prefetch
:
f
alse
TrainReader
:
inputs_def
:
...
...
@@ -9,6 +9,7 @@ TrainReader:
dataset_dir
:
dataset/coco
image_dir
:
train2017
anno_path
:
annotations/instances_train2017.json
with_background
:
true
sample_transforms
:
-
DecodeImage
:
{
to_rgb
:
true
}
-
RandomFlipImage
:
{
prob
:
0.5
,
is_mask_flip
:
true
}
...
...
@@ -16,7 +17,7 @@ TrainReader:
-
ResizeImage
:
{
target_size
:
800
,
max_size
:
1333
,
interp
:
1
,
use_cv2
:
true
}
-
Permute
:
{
to_bgr
:
false
,
channel_first
:
true
}
batch_transforms
:
-
PadBatch
:
{
pad_to_stride
:
32
,
use_padded_im_info
:
false
,
pad_gt
:
T
rue
}
-
PadBatch
:
{
pad_to_stride
:
32
,
use_padded_im_info
:
false
,
pad_gt
:
t
rue
}
batch_size
:
1
shuffle
:
true
drop_last
:
false
...
...
@@ -24,19 +25,20 @@ TrainReader:
EvalReader
:
inputs_def
:
fields
:
[
'
image'
,
'
im_info'
,
'
im_id'
,
'
im_shape'
]
fields
:
[
'
image'
,
'
im_info'
,
'
im_id'
]
dataset
:
name
:
COCODataset
dataset_dir
:
dataset/coco
image_dir
:
val2017
anno_path
:
annotations/instances_val2017.json
with_background
:
true
sample_transforms
:
-
DecodeImage
:
{
to_rgb
:
true
}
-
NormalizeImage
:
{
is_channel_first
:
false
,
is_scale
:
true
,
mean
:
[
0.485
,
0.456
,
0.406
],
std
:
[
0.229
,
0.224
,
0.225
]}
-
ResizeImage
:
{
interp
:
1
,
max_size
:
1333
,
target_size
:
800
,
use_cv2
:
true
}
-
Permute
:
{
channel_first
:
true
,
to_bgr
:
false
}
batch_transforms
:
-
PadBatch
:
{
pad_to_stride
:
32
,
use_padded_im_info
:
false
,
pad_gt
:
Tru
e
}
-
PadBatch
:
{
pad_to_stride
:
32
,
use_padded_im_info
:
false
,
pad_gt
:
fals
e
}
batch_size
:
1
shuffle
:
false
drop_last
:
false
...
...
@@ -45,7 +47,7 @@ EvalReader:
TestReader
:
inputs_def
:
fields
:
[
'
image'
,
'
im_info'
,
'
im_id'
,
'
im_shape'
]
fields
:
[
'
image'
,
'
im_info'
,
'
im_id'
]
dataset
:
name
:
ImageFolder
anno_path
:
annotations/instances_val2017.json
...
...
ppdet/core/workspace.py
浏览文件 @
e4ccc4d7
...
...
@@ -66,73 +66,62 @@ class AttrDict(dict):
global_config
=
AttrDict
()
READER_KEY
=
'_READER_'
def
load_config
(
file_path
):
"""
Load config from file.
cfg
=
load_cfg
(
file_path
)
if
'_BASE_'
in
cfg
.
keys
():
base_cfg
=
load_cfg
(
cfg
[
'_BASE_'
])
del
cfg
[
'_BASE_'
]
# merge cfg into base_cfg
cfg
=
merge_config
(
cfg
,
base_cfg
)
# merge cfg int global_config
merge_config
(
cfg
)
return
global_config
Args:
file_path (str): Path of the config file to be loaded.
Returns: global config
"""
def
load_cfg
(
file_path
):
_
,
ext
=
os
.
path
.
splitext
(
file_path
)
assert
ext
in
[
'.yml'
,
'.yaml'
],
"only support yaml files for now"
cfg
=
AttrDict
()
with
open
(
file_path
)
as
f
:
cfg
=
merge_config
(
yaml
.
load
(
f
,
Loader
=
yaml
.
Loader
),
cfg
)
cfg
=
load_part_cfg
(
cfg
,
file_path
)
return
cfg
if
READER_KEY
in
cfg
:
reader_cfg
=
cfg
[
READER_KEY
]
if
reader_cfg
.
startswith
(
"~"
):
reader_cfg
=
os
.
path
.
expanduser
(
reader_cfg
)
if
not
reader_cfg
.
startswith
(
'/'
):
reader_cfg
=
os
.
path
.
join
(
os
.
path
.
dirname
(
file_path
),
reader_cfg
)
with
open
(
reader_cfg
)
as
f
:
merge_config
(
yaml
.
load
(
f
,
Loader
=
yaml
.
Loader
))
del
cfg
[
READER_KEY
]
PART_KEY
=
[
'_READER_'
,
'_ARCHITECHTURE_'
,
'_OPTIMIZE_'
]
merge_config
(
cfg
)
return
global_config
def
load_part_cfg
(
cfg
,
file_path
):
for
part_k
in
PART_KEY
:
if
part_k
in
cfg
:
part_cfg
=
cfg
[
part_k
]
if
part_cfg
.
startswith
(
"~"
):
part_cfg
=
os
.
path
.
expanduser
(
part_cfg
)
if
not
part_cfg
.
startswith
(
'/'
):
part_cfg
=
os
.
path
.
join
(
os
.
path
.
dirname
(
file_path
),
part_cfg
)
def
dict_merge
(
dct
,
merge_dct
):
""" Recursive dict merge. Inspired by :meth:``dict.update()``, instead of
updating only top-level keys, dict_merge recurses down into dicts nested
to an arbitrary depth, updating keys. The ``merge_dct`` is merged into
``dct``.
with
open
(
part_cfg
)
as
f
:
merge_config
(
yaml
.
load
(
f
,
Loader
=
yaml
.
Loader
),
cfg
)
del
cfg
[
part_k
]
Args:
dct: dict onto which the merge is executed
merge_dct: dct merged into dct
Returns: dct
"""
for
k
,
v
in
merge_dct
.
items
():
if
(
k
in
dct
and
isinstance
(
dct
[
k
],
dict
)
and
isinstance
(
merge_dct
[
k
],
collections
.
Mapping
)):
dict_merge
(
dct
[
k
],
merge_dct
[
k
])
else
:
dct
[
k
]
=
merge_dct
[
k
]
return
dct
return
cfg
def
merge_config
(
config
,
another_cfg
=
None
):
"""
Merge config into global config or another_cfg.
def
merge_config
(
config
,
other_cfg
=
None
):
global
global_config
dct
=
other_cfg
if
other_cfg
is
not
None
else
global_config
return
merge_dict
(
dct
,
config
)
Args:
config (dict): Config to be merged.
Returns: global config
"""
global
global_config
dct
=
another_cfg
if
another_cfg
is
not
None
else
global_config
return
dict_merge
(
dct
,
config
)
def
merge_dict
(
dct
,
other_dct
):
for
k
,
v
in
other_dct
.
items
():
if
(
k
in
dct
and
isinstance
(
dct
[
k
],
dict
)
and
isinstance
(
other_dct
[
k
],
collections
.
Mapping
)):
merge_dict
(
dct
[
k
],
other_dct
[
k
])
else
:
dct
[
k
]
=
other_dct
[
k
]
return
dct
def
get_registered_modules
():
...
...
@@ -251,7 +240,4 @@ def create(cls_or_name, **kwargs):
kwargs
[
k
]
=
new_dict
else
:
raise
ValueError
(
"Unsupported injection type:"
,
target_key
)
# prevent modification of global config values of reference types
# (e.g., list, dict) from within the created module instances
#kwargs = copy.deepcopy(kwargs)
return
cls
(
**
kwargs
)
ppdet/data/loader.py
浏览文件 @
e4ccc4d7
...
...
@@ -24,7 +24,6 @@ class Compose(object):
self
.
transforms_cls
=
[]
for
t
in
self
.
transforms
:
for
k
,
v
in
t
.
items
():
print
(
k
,
v
)
op_cls
=
getattr
(
from_
,
k
)
self
.
transforms_cls
.
append
(
op_cls
(
**
v
))
if
hasattr
(
op_cls
,
'num_classes'
):
...
...
@@ -131,7 +130,7 @@ class BaseDataLoader(object):
drop_empty
=
True
,
num_classes
=
81
):
# dataset
self
.
_dataset
=
dataset
#create(dataset['name'])
self
.
_dataset
=
dataset
self
.
_dataset
.
parse_dataset
()
# out fields
self
.
_fields
=
copy
.
deepcopy
(
inputs_def
[
...
...
ppdet/data/source/coco.py
浏览文件 @
e4ccc4d7
...
...
@@ -20,7 +20,6 @@ class COCODataset(DetDataset):
with_background
,
sample_num
)
self
.
load_image_only
=
False
self
.
load_semantic
=
False
#self.parse_dataset()
def
parse_dataset
(
self
):
anno_path
=
os
.
path
.
join
(
self
.
dataset_dir
,
self
.
anno_path
)
...
...
ppdet/modeling/head/bbox_head.py
浏览文件 @
e4ccc4d7
...
...
@@ -5,15 +5,34 @@ from paddle.fluid.initializer import Normal, Xavier
from
paddle.fluid.regularizer
import
L2Decay
from
paddle.fluid.dygraph.nn
import
Conv2D
,
Pool2D
,
Linear
from
ppdet.core.workspace
import
register
from
..backbone.name_adapter
import
NameAdapter
from
..backbone.resnet
import
Blocks
@
register
class
TwoFCHead
(
Layer
):
class
Res5Feat
(
Layer
):
def
__init__
(
self
,
feat_in
=
1024
,
feat_out
=
512
):
super
(
Res5Feat
,
self
).
__init__
()
na
=
NameAdapter
(
self
)
self
.
res5_conv
=
[]
self
.
res5
=
self
.
add_sublayer
(
'res5_roi_feat'
,
Blocks
(
feat_in
,
feat_out
,
count
=
3
,
name_adapter
=
na
,
stage_num
=
5
))
self
.
feat_out
=
feat_out
*
4
def
forward
(
self
,
roi_feat
,
stage
=
0
):
y
=
self
.
res5
(
roi_feat
)
return
y
@
register
class
TwoFCFeat
(
Layer
):
__shared__
=
[
'num_stages'
]
def
__init__
(
self
,
in_dim
=
256
,
mlp_dim
=
1024
,
resolution
=
7
,
num_stages
=
1
):
super
(
TwoFC
Head
,
self
).
__init__
()
super
(
TwoFC
Feat
,
self
).
__init__
()
self
.
in_dim
=
in_dim
self
.
mlp_dim
=
mlp_dim
self
.
num_stages
=
num_stages
...
...
@@ -94,6 +113,7 @@ class BBoxHead(Layer):
self
.
bbox_score_list
=
[]
self
.
bbox_delta_list
=
[]
self
.
with_pool
=
with_pool
for
stage
in
range
(
num_stages
):
score_name
=
'bbox_score_{}'
.
format
(
stage
)
delta_name
=
'bbox_delta_{}'
.
format
(
stage
)
...
...
@@ -132,11 +152,14 @@ class BBoxHead(Layer):
def
forward
(
self
,
body_feats
,
rois
,
spatial_scale
,
stage
=
0
):
bbox_feat
=
self
.
bbox_feat
(
body_feats
,
rois
,
spatial_scale
,
stage
)
if
self
.
with_pool
:
bbox_feat
=
fluid
.
layers
.
pool2d
(
bbox_feat
_
=
fluid
.
layers
.
pool2d
(
bbox_feat
,
pool_type
=
'avg'
,
global_pooling
=
True
)
bbox_feat_
=
fluid
.
layers
.
squeeze
(
bbox_feat_
,
axes
=
[
2
,
3
])
else
:
bbox_feat_
=
bbox_feat
bbox_head_out
=
[]
scores
=
self
.
bbox_score_list
[
stage
](
bbox_feat
)
deltas
=
self
.
bbox_delta_list
[
stage
](
bbox_feat
)
scores
=
self
.
bbox_score_list
[
stage
](
bbox_feat
_
)
deltas
=
self
.
bbox_delta_list
[
stage
](
bbox_feat
_
)
bbox_head_out
.
append
((
scores
,
deltas
))
return
bbox_feat
,
bbox_head_out
...
...
ppdet/utils/checkpoint.py
浏览文件 @
e4ccc4d7
...
...
@@ -59,8 +59,8 @@ def load_dygraph_ckpt(model,
assert
os
.
path
.
exists
(
ckpt
),
"Path {} does not exist."
.
format
(
ckpt
)
if
load_static_weights
:
pre_state_dict
=
fluid
.
load_program_state
(
ckpt
)
param_state_dict
=
{}
model_dict
=
model
.
state_dict
()
param_state_dict
=
{}
for
key
in
model_dict
.
keys
():
weight_name
=
model_dict
[
key
].
name
if
weight_name
in
pre_state_dict
.
keys
():
...
...
ppdet/utils/eval_utils.py
浏览文件 @
e4ccc4d7
...
...
@@ -11,7 +11,7 @@ def json_eval_results(metric, json_directory=None, dataset=None):
"""
assert
metric
==
'COCO'
from
ppdet.utils.coco_eval
import
cocoapi_eval
anno_file
=
dataset
.
get_anno
(
)
anno_file
=
os
.
path
.
join
(
dataset
[
'dataset_dir'
],
dataset
[
'anno_path'
]
)
json_file_list
=
[
'proposal.json'
,
'bbox.json'
,
'mask.json'
]
if
json_directory
:
assert
os
.
path
.
exists
(
...
...
@@ -36,10 +36,10 @@ def coco_eval_results(outs_res=None, include_mask=False, dataset=None):
from
pycocotools.coco
import
COCO
from
pycocotools.cocoeval
import
COCOeval
from
ppdet.py_op.post_process
import
get_det_res
,
get_seg_res
anno_file
=
os
.
path
.
join
(
dataset
.
dataset_dir
,
dataset
.
anno_path
)
anno_file
=
os
.
path
.
join
(
dataset
[
'dataset_dir'
],
dataset
[
'anno_path'
]
)
cocoGt
=
COCO
(
anno_file
)
catid
=
{
i
+
dataset
.
with_background
:
v
i
+
1
if
dataset
[
'with_background'
]
else
0
:
v
for
i
,
v
in
enumerate
(
cocoGt
.
getCatIds
())
}
...
...
tools/eval.py
浏览文件 @
e4ccc4d7
...
...
@@ -41,11 +41,6 @@ def parse_args():
def
run
(
FLAGS
,
cfg
,
place
):
if
FLAGS
.
use_gpu
:
devices_num
=
1
else
:
devices_num
=
int
(
os
.
environ
.
get
(
'CPU_NUM'
,
1
))
# Data
eval_loader
,
_
=
create
(
'EvalReader'
)(
cfg
[
'worker_num'
],
place
)
...
...
tools/train.py
浏览文件 @
e4ccc4d7
...
...
@@ -166,7 +166,7 @@ def run(FLAGS, cfg, place):
logger
.
info
(
strs
)
# Save Stage
if
fluid
.
dygraph
.
parallel
.
Env
().
local_rank
==
0
:
if
Parallel
Env
().
local_rank
==
0
:
cfg_name
=
os
.
path
.
basename
(
FLAGS
.
config
).
split
(
'.'
)[
0
]
save_name
=
str
(
e_id
+
1
)
if
e_id
+
1
!=
int
(
cfg
.
epoch
)
else
"model_final"
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录