Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
models
提交
c25000a2
M
models
项目概览
PaddlePaddle
/
models
接近 2 年 前同步成功
通知
230
Star
6828
Fork
2962
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
602
列表
看板
标记
里程碑
合并请求
255
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
M
models
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
602
Issue
602
列表
看板
标记
里程碑
合并请求
255
合并请求
255
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
c25000a2
编写于
3月 06, 2019
作者:
D
dengkaipeng
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
refine mixup score and label smooth
上级
6c19ddfd
变更
7
隐藏空白更改
内联
并排
Showing
7 changed file
with
36 addition
and
32 deletion
+36
-32
fluid/PaddleCV/yolov3/box_utils.py
fluid/PaddleCV/yolov3/box_utils.py
+3
-4
fluid/PaddleCV/yolov3/data_utils.py
fluid/PaddleCV/yolov3/data_utils.py
+1
-2
fluid/PaddleCV/yolov3/infer.py
fluid/PaddleCV/yolov3/infer.py
+1
-1
fluid/PaddleCV/yolov3/models.py
fluid/PaddleCV/yolov3/models.py
+2
-2
fluid/PaddleCV/yolov3/reader.py
fluid/PaddleCV/yolov3/reader.py
+23
-17
fluid/PaddleCV/yolov3/train.py
fluid/PaddleCV/yolov3/train.py
+3
-4
fluid/PaddleCV/yolov3/utility.py
fluid/PaddleCV/yolov3/utility.py
+3
-2
未找到文件。
fluid/PaddleCV/yolov3/box_utils.py
浏览文件 @
c25000a2
...
@@ -123,16 +123,15 @@ def box_crop(boxes, labels, scores, crop, img_shape):
...
@@ -123,16 +123,15 @@ def box_crop(boxes, labels, scores, crop, img_shape):
boxes
[:,
1
],
boxes
[:,
3
]
=
(
boxes
[:,
1
]
-
boxes
[:,
3
]
/
2
)
*
im_h
,
(
boxes
[:,
1
]
+
boxes
[:,
3
]
/
2
)
*
im_h
boxes
[:,
1
],
boxes
[:,
3
]
=
(
boxes
[:,
1
]
-
boxes
[:,
3
]
/
2
)
*
im_h
,
(
boxes
[:,
1
]
+
boxes
[:,
3
]
/
2
)
*
im_h
crop_box
=
np
.
array
([
x
,
y
,
x
+
w
,
y
+
h
])
crop_box
=
np
.
array
([
x
,
y
,
x
+
w
,
y
+
h
])
#
centers = (boxes[:, :2] + boxes[:, 2:]) / 2.0
centers
=
(
boxes
[:,
:
2
]
+
boxes
[:,
2
:])
/
2.0
#
mask = np.logical_and(crop_box[:2] <= centers, centers <= crop_box[2:]).all(axis=1)
mask
=
np
.
logical_and
(
crop_box
[:
2
]
<=
centers
,
centers
<=
crop_box
[
2
:]).
all
(
axis
=
1
)
boxes
[:,
:
2
]
=
np
.
maximum
(
boxes
[:,
:
2
],
crop_box
[:
2
])
boxes
[:,
:
2
]
=
np
.
maximum
(
boxes
[:,
:
2
],
crop_box
[:
2
])
boxes
[:,
2
:]
=
np
.
minimum
(
boxes
[:,
2
:],
crop_box
[
2
:])
boxes
[:,
2
:]
=
np
.
minimum
(
boxes
[:,
2
:],
crop_box
[
2
:])
boxes
[:,
:
2
]
-=
crop_box
[:
2
]
boxes
[:,
:
2
]
-=
crop_box
[:
2
]
boxes
[:,
2
:]
-=
crop_box
[:
2
]
boxes
[:,
2
:]
-=
crop_box
[:
2
]
# mask = np.logical_and(mask, (boxes[:, :2] < boxes[:, 2:]).all(axis=1))
mask
=
np
.
logical_and
(
mask
,
(
boxes
[:,
:
2
]
<
boxes
[:,
2
:]).
all
(
axis
=
1
))
mask
=
(
boxes
[:,
:
2
]
<
boxes
[:,
2
:]).
all
(
axis
=
1
)
boxes
=
boxes
*
np
.
expand_dims
(
mask
.
astype
(
'float32'
),
axis
=
1
)
boxes
=
boxes
*
np
.
expand_dims
(
mask
.
astype
(
'float32'
),
axis
=
1
)
labels
=
labels
*
mask
.
astype
(
'float32'
)
labels
=
labels
*
mask
.
astype
(
'float32'
)
scores
=
scores
*
mask
.
astype
(
'float32'
)
scores
=
scores
*
mask
.
astype
(
'float32'
)
...
...
fluid/PaddleCV/yolov3/data_utils.py
浏览文件 @
c25000a2
...
@@ -73,12 +73,11 @@ class GeneratorEnqueuer(object):
...
@@ -73,12 +73,11 @@ class GeneratorEnqueuer(object):
size
=
self
.
random_sizes
[
queue_idx
]
size
=
self
.
random_sizes
[
queue_idx
]
for
g
in
generator_output
:
for
g
in
generator_output
:
g
[
0
]
=
g
[
0
].
transpose
((
1
,
2
,
0
))
g
[
0
]
=
g
[
0
].
transpose
((
1
,
2
,
0
))
g
[
0
]
=
image_utils
.
random_interp
(
g
[
0
],
size
,
cv2
.
INTER_LINEAR
)
g
[
0
]
=
image_utils
.
random_interp
(
g
[
0
],
size
)
g
[
0
]
=
g
[
0
].
transpose
((
2
,
0
,
1
))
g
[
0
]
=
g
[
0
].
transpose
((
2
,
0
,
1
))
try
:
try
:
self
.
queues
[
queue_idx
].
put_nowait
(
generator_output
)
self
.
queues
[
queue_idx
].
put_nowait
(
generator_output
)
except
:
except
:
timw
.
sleep
(
self
.
wait_time
)
continue
continue
else
:
else
:
break
break
...
...
fluid/PaddleCV/yolov3/infer.py
浏览文件 @
c25000a2
...
@@ -35,7 +35,6 @@ def infer():
...
@@ -35,7 +35,6 @@ def infer():
# yapf: enable
# yapf: enable
feeder
=
fluid
.
DataFeeder
(
place
=
place
,
feed_list
=
model
.
feeds
())
feeder
=
fluid
.
DataFeeder
(
place
=
place
,
feed_list
=
model
.
feeds
())
fetch_list
=
[
outputs
]
fetch_list
=
[
outputs
]
# fetch_list = outputs
image_names
=
[]
image_names
=
[]
if
cfg
.
image_name
is
not
None
:
if
cfg
.
image_name
is
not
None
:
image_names
.
append
(
cfg
.
image_name
)
image_names
.
append
(
cfg
.
image_name
)
...
@@ -55,6 +54,7 @@ def infer():
...
@@ -55,6 +54,7 @@ def infer():
bboxes
=
np
.
array
(
outputs
[
0
])
bboxes
=
np
.
array
(
outputs
[
0
])
if
bboxes
.
shape
[
1
]
!=
6
:
if
bboxes
.
shape
[
1
]
!=
6
:
print
(
"No object found in {}"
.
format
(
image_name
))
print
(
"No object found in {}"
.
format
(
image_name
))
continue
labels
=
bboxes
[:,
0
].
astype
(
'int32'
)
labels
=
bboxes
[:,
0
].
astype
(
'int32'
)
scores
=
bboxes
[:,
1
].
astype
(
'float32'
)
scores
=
bboxes
[:,
1
].
astype
(
'float32'
)
boxes
=
bboxes
[:,
2
:].
astype
(
'float32'
)
boxes
=
bboxes
[:,
2
:].
astype
(
'float32'
)
...
...
fluid/PaddleCV/yolov3/models.py
浏览文件 @
c25000a2
...
@@ -206,13 +206,13 @@ class YOLOv3(object):
...
@@ -206,13 +206,13 @@ class YOLOv3(object):
x
=
out
,
x
=
out
,
gtbox
=
self
.
gtbox
,
gtbox
=
self
.
gtbox
,
gtlabel
=
self
.
gtlabel
,
gtlabel
=
self
.
gtlabel
,
#
gtscore=self.gtscore,
gtscore
=
self
.
gtscore
,
anchors
=
anchors
,
anchors
=
anchors
,
anchor_mask
=
anchor_mask
,
anchor_mask
=
anchor_mask
,
class_num
=
class_num
,
class_num
=
class_num
,
ignore_thresh
=
ignore_thresh
,
ignore_thresh
=
ignore_thresh
,
downsample_ratio
=
self
.
downsample
,
downsample_ratio
=
self
.
downsample
,
# use_label_smooth=False
,
use_label_smooth
=
cfg
.
label_smooth
,
name
=
"yolo_loss"
+
str
(
i
))
name
=
"yolo_loss"
+
str
(
i
))
self
.
losses
.
append
(
fluid
.
layers
.
reduce_mean
(
loss
))
self
.
losses
.
append
(
fluid
.
layers
.
reduce_mean
(
loss
))
else
:
else
:
...
...
fluid/PaddleCV/yolov3/reader.py
浏览文件 @
c25000a2
...
@@ -38,10 +38,6 @@ class DataSetReader(object):
...
@@ -38,10 +38,6 @@ class DataSetReader(object):
self
.
has_parsed_categpry
=
False
self
.
has_parsed_categpry
=
False
def
_parse_dataset_dir
(
self
,
mode
):
def
_parse_dataset_dir
(
self
,
mode
):
# cfg.data_dir = "dataset/coco"
# cfg.train_file_list = 'annotations/instances_val2017.json'
# cfg.train_data_dir = 'val2017'
# cfg.dataset = "coco2017"
if
'coco2014'
in
cfg
.
dataset
:
if
'coco2014'
in
cfg
.
dataset
:
cfg
.
train_file_list
=
'annotations/instances_train2014.json'
cfg
.
train_file_list
=
'annotations/instances_train2014.json'
cfg
.
train_data_dir
=
'train2014'
cfg
.
train_data_dir
=
'train2014'
...
@@ -115,7 +111,6 @@ class DataSetReader(object):
...
@@ -115,7 +111,6 @@ class DataSetReader(object):
image_ids
=
self
.
COCO
.
getImgIds
()
image_ids
=
self
.
COCO
.
getImgIds
()
image_ids
.
sort
()
image_ids
.
sort
()
imgs
=
copy
.
deepcopy
(
self
.
COCO
.
loadImgs
(
image_ids
))
imgs
=
copy
.
deepcopy
(
self
.
COCO
.
loadImgs
(
image_ids
))
# imgs = imgs[-8:]
for
img
in
imgs
:
for
img
in
imgs
:
img
[
'image'
]
=
os
.
path
.
join
(
self
.
img_dir
,
img
[
'file_name'
])
img
[
'image'
]
=
os
.
path
.
join
(
self
.
img_dir
,
img
[
'file_name'
])
assert
os
.
path
.
exists
(
img
[
'image'
]),
\
assert
os
.
path
.
exists
(
img
[
'image'
]),
\
...
@@ -141,7 +136,7 @@ class DataSetReader(object):
...
@@ -141,7 +136,7 @@ class DataSetReader(object):
else
:
else
:
return
self
.
_parse_images
(
is_train
=
(
mode
==
'train'
))
return
self
.
_parse_images
(
is_train
=
(
mode
==
'train'
))
def
get_reader
(
self
,
mode
,
size
=
416
,
batch_size
=
None
,
shuffle
=
False
,
random_shape
_iter
=
0
,
random_sizes
=
[],
image
=
None
):
def
get_reader
(
self
,
mode
,
size
=
416
,
batch_size
=
None
,
shuffle
=
False
,
mixup
_iter
=
0
,
random_sizes
=
[],
image
=
None
):
assert
mode
in
[
'train'
,
'test'
,
'infer'
],
"Unknow mode type!"
assert
mode
in
[
'train'
,
'test'
,
'infer'
],
"Unknow mode type!"
if
mode
!=
'infer'
:
if
mode
!=
'infer'
:
assert
batch_size
is
not
None
,
"batch size connot be None in mode {}"
.
format
(
mode
)
assert
batch_size
is
not
None
,
"batch size connot be None in mode {}"
.
format
(
mode
)
...
@@ -172,6 +167,16 @@ class DataSetReader(object):
...
@@ -172,6 +167,16 @@ class DataSetReader(object):
gt_labels
=
img
[
'gt_labels'
].
copy
()
gt_labels
=
img
[
'gt_labels'
].
copy
()
gt_scores
=
np
.
ones_like
(
gt_labels
)
gt_scores
=
np
.
ones_like
(
gt_labels
)
if
mixup_img
:
mixup_im
=
cv2
.
imread
(
mixup_img
[
'image'
])
mixup_im
=
cv2
.
cvtColor
(
mixup_im
,
cv2
.
COLOR_BGR2RGB
)
mixup_gt_boxes
=
np
.
array
(
mixup_img
[
'gt_boxes'
]).
copy
()
mixup_gt_labels
=
np
.
array
(
mixup_img
[
'gt_labels'
]).
copy
()
mixup_gt_scores
=
np
.
ones_like
(
mixup_gt_labels
)
im
,
gt_boxes
,
gt_labels
,
gt_scores
=
image_utils
.
image_mixup
(
im
,
gt_boxes
,
\
gt_labels
,
gt_scores
,
mixup_im
,
mixup_gt_boxes
,
mixup_gt_labels
,
\
mixup_gt_scores
)
im
,
gt_boxes
,
gt_labels
,
gt_scores
=
image_utils
.
image_augment
(
im
,
gt_boxes
,
gt_labels
,
gt_scores
,
size
,
mean
)
im
,
gt_boxes
,
gt_labels
,
gt_scores
=
image_utils
.
image_augment
(
im
,
gt_boxes
,
gt_labels
,
gt_scores
,
size
,
mean
)
mean
=
np
.
array
(
mean
).
reshape
((
1
,
1
,
-
1
))
mean
=
np
.
array
(
mean
).
reshape
((
1
,
1
,
-
1
))
...
@@ -186,6 +191,14 @@ class DataSetReader(object):
...
@@ -186,6 +191,14 @@ class DataSetReader(object):
return
np
.
random
.
choice
(
random_sizes
)
return
np
.
random
.
choice
(
random_sizes
)
return
size
return
size
def
get_mixup_img
(
imgs
,
mixup_iter
,
total_iter
,
read_cnt
):
if
total_iter
>=
mixup_iter
:
return
None
mixup_idx
=
np
.
random
.
randint
(
1
,
len
(
imgs
))
mixup_img
=
imgs
[(
read_cnt
+
mixup_idx
)
%
len
(
imgs
)]
return
mixup_img
def
reader
():
def
reader
():
if
mode
==
'train'
:
if
mode
==
'train'
:
imgs
=
self
.
_parse_images_by_mode
(
mode
)
imgs
=
self
.
_parse_images_by_mode
(
mode
)
...
@@ -195,25 +208,21 @@ class DataSetReader(object):
...
@@ -195,25 +208,21 @@ class DataSetReader(object):
total_iter
=
0
total_iter
=
0
batch_out
=
[]
batch_out
=
[]
img_size
=
get_img_size
(
size
,
random_sizes
)
img_size
=
get_img_size
(
size
,
random_sizes
)
# img_ids = []
while
True
:
while
True
:
img
=
imgs
[
read_cnt
%
len
(
imgs
)]
img
=
imgs
[
read_cnt
%
len
(
imgs
)]
mixup_img
=
None
mixup_img
=
get_mixup_img
(
imgs
,
mixup_iter
,
total_iter
,
read_cnt
)
read_cnt
+=
1
read_cnt
+=
1
if
read_cnt
%
len
(
imgs
)
==
0
and
shuffle
:
if
read_cnt
%
len
(
imgs
)
==
0
and
shuffle
:
np
.
random
.
shuffle
(
imgs
)
np
.
random
.
shuffle
(
imgs
)
im
,
gt_boxes
,
gt_labels
,
gt_scores
=
img_reader_with_augment
(
img
,
img_size
,
cfg
.
pixel_means
,
cfg
.
pixel_stds
,
mixup_img
)
im
,
gt_boxes
,
gt_labels
,
gt_scores
=
img_reader_with_augment
(
img
,
img_size
,
cfg
.
pixel_means
,
cfg
.
pixel_stds
,
mixup_img
)
batch_out
.
append
([
im
,
gt_boxes
,
gt_labels
,
gt_scores
])
batch_out
.
append
([
im
,
gt_boxes
,
gt_labels
,
gt_scores
])
# img_ids.append((img['id'], mixup_img['id'] if mixup_img else -1))
if
len
(
batch_out
)
==
batch_size
:
if
len
(
batch_out
)
==
batch_size
:
# print("img_ids: ", img_ids)
yield
batch_out
yield
batch_out
batch_out
=
[]
batch_out
=
[]
total_iter
+=
1
total_iter
+=
1
if
total_iter
%
10
==
0
:
if
total_iter
%
10
==
0
:
img_size
=
get_img_size
(
size
,
random_sizes
)
img_size
=
get_img_size
(
size
,
random_sizes
)
# img_ids = []
elif
mode
==
'test'
:
elif
mode
==
'test'
:
imgs
=
self
.
_parse_images_by_mode
(
mode
)
imgs
=
self
.
_parse_images_by_mode
(
mode
)
...
@@ -242,14 +251,14 @@ dsr = DataSetReader()
...
@@ -242,14 +251,14 @@ dsr = DataSetReader()
def
train
(
size
=
416
,
def
train
(
size
=
416
,
batch_size
=
64
,
batch_size
=
64
,
shuffle
=
True
,
shuffle
=
True
,
random_shape
_iter
=
0
,
mixup
_iter
=
0
,
random_sizes
=
[],
random_sizes
=
[],
interval
=
10
,
interval
=
10
,
pyreader_num
=
1
,
pyreader_num
=
1
,
num_workers
=
16
,
num_workers
=
16
,
max_queue
=
32
,
max_queue
=
32
,
use_multiprocessing
=
True
):
use_multiprocessing
=
True
):
generator
=
dsr
.
get_reader
(
'train'
,
size
,
batch_size
,
shuffle
,
random_shape_iter
,
random_sizes
)
generator
=
dsr
.
get_reader
(
'train'
,
size
,
batch_size
,
shuffle
,
int
(
mixup_iter
/
pyreader_num
)
,
random_sizes
)
if
not
use_multiprocessing
:
if
not
use_multiprocessing
:
return
generator
return
generator
...
@@ -267,7 +276,6 @@ def train(size=416,
...
@@ -267,7 +276,6 @@ def train(size=416,
generator_out
=
None
generator_out
=
None
np
.
random
.
seed
(
1000
)
np
.
random
.
seed
(
1000
)
intervals
=
pyreader_num
*
interval
intervals
=
pyreader_num
*
interval
total_random_iter
=
pyreader_num
*
random_shape_iter
cnt
=
0
cnt
=
0
idx
=
len
(
random_sizes
)
-
1
idx
=
len
(
random_sizes
)
-
1
while
True
:
while
True
:
...
@@ -282,9 +290,7 @@ def train(size=416,
...
@@ -282,9 +290,7 @@ def train(size=416,
cnt
+=
1
cnt
+=
1
if
cnt
%
intervals
==
0
:
if
cnt
%
intervals
==
0
:
idx
=
np
.
random
.
randint
(
len
(
random_sizes
))
idx
=
np
.
random
.
randint
(
len
(
random_sizes
))
if
cnt
>=
total_random_iter
:
print
(
"Resizing: "
,
random_sizes
[
idx
])
idx
=
-
1
print
(
"Resizing: "
,
random_sizes
[
idx
])
finally
:
finally
:
if
enqueuer
is
not
None
:
if
enqueuer
is
not
None
:
enqueuer
.
stop
()
enqueuer
.
stop
()
...
...
fluid/PaddleCV/yolov3/train.py
浏览文件 @
c25000a2
...
@@ -94,13 +94,13 @@ def train():
...
@@ -94,13 +94,13 @@ def train():
if
cfg
.
random_shape
:
if
cfg
.
random_shape
:
random_sizes
=
[
32
*
i
for
i
in
range
(
10
,
20
)]
random_sizes
=
[
32
*
i
for
i
in
range
(
10
,
20
)]
random_shape_iter
=
cfg
.
max_iter
-
cfg
.
start_iter
-
cfg
.
tune
_iter
mixup_iter
=
cfg
.
max_iter
-
cfg
.
start_iter
-
cfg
.
no_mixup
_iter
if
cfg
.
use_pyreader
:
if
cfg
.
use_pyreader
:
train_reader
=
reader
.
train
(
input_size
,
batch_size
=
int
(
hyperparams
[
'batch'
])
/
devices_num
,
shuffle
=
True
,
random_shape_iter
=
random_shape_iter
,
random_sizes
=
random_sizes
,
interval
=
10
,
pyreader_num
=
devices_num
,
use_multiprocessing
=
cfg
.
use_multiprocess
)
train_reader
=
reader
.
train
(
input_size
,
batch_size
=
int
(
hyperparams
[
'batch'
])
/
devices_num
,
shuffle
=
True
,
mixup_iter
=
mixup_iter
*
devices_num
,
random_sizes
=
random_sizes
,
interval
=
10
,
pyreader_num
=
devices_num
,
use_multiprocessing
=
cfg
.
use_multiprocess
)
py_reader
=
model
.
py_reader
py_reader
=
model
.
py_reader
py_reader
.
decorate_paddle_reader
(
train_reader
)
py_reader
.
decorate_paddle_reader
(
train_reader
)
else
:
else
:
train_reader
=
reader
.
train
(
input_size
,
batch_size
=
int
(
hyperparams
[
'batch'
]),
shuffle
=
True
,
random_shape_iter
=
random_shape_iter
,
random_sizes
=
random_sizes
,
interval
=
10
,
use_multiprocessing
=
cfg
.
use_multiprocess
)
train_reader
=
reader
.
train
(
input_size
,
batch_size
=
int
(
hyperparams
[
'batch'
]),
shuffle
=
True
,
mixup_iter
=
mixup_iter
,
random_sizes
=
random_sizes
,
use_multiprocessing
=
cfg
.
use_multiprocess
)
feeder
=
fluid
.
DataFeeder
(
place
=
place
,
feed_list
=
model
.
feeds
())
feeder
=
fluid
.
DataFeeder
(
place
=
place
,
feed_list
=
model
.
feeds
())
def
save_model
(
postfix
):
def
save_model
(
postfix
):
...
@@ -150,7 +150,6 @@ def train():
...
@@ -150,7 +150,6 @@ def train():
snapshot_loss
=
0
snapshot_loss
=
0
snapshot_time
=
0
snapshot_time
=
0
for
iter_id
,
data
in
enumerate
(
train_reader
()):
for
iter_id
,
data
in
enumerate
(
train_reader
()):
print
(
len
(
data
),
data
[
0
][
0
].
shape
)
iter_id
+=
cfg
.
start_iter
iter_id
+=
cfg
.
start_iter
prev_start_time
=
start_time
prev_start_time
=
start_time
start_time
=
time
.
time
()
start_time
=
time
.
time
()
...
...
fluid/PaddleCV/yolov3/utility.py
浏览文件 @
c25000a2
...
@@ -114,8 +114,9 @@ def parse_args():
...
@@ -114,8 +114,9 @@ def parse_args():
# TRAIN TEST INFER
# TRAIN TEST INFER
add_arg
(
'input_size'
,
int
,
608
,
"Image input size of YOLOv3."
)
add_arg
(
'input_size'
,
int
,
608
,
"Image input size of YOLOv3."
)
add_arg
(
'random_shape'
,
bool
,
True
,
"Resize to random shape for train reader."
)
add_arg
(
'random_shape'
,
bool
,
True
,
"Resize to random shape for train reader."
)
add_arg
(
'tune_iter'
,
int
,
200
,
"Disable random shape in last N iter."
)
add_arg
(
'label_smooth'
,
bool
,
True
,
"Use label smooth in class label."
)
add_arg
(
'valid_thresh'
,
float
,
0.005
,
"Valid confidence score for NMS."
)
add_arg
(
'no_mixup_iter'
,
int
,
40000
,
"Disable mixup in last N iter."
)
add_arg
(
'valid_thresh'
,
float
,
0.01
,
"Valid confidence score for NMS."
)
add_arg
(
'nms_thresh'
,
float
,
0.45
,
"NMS threshold."
)
add_arg
(
'nms_thresh'
,
float
,
0.45
,
"NMS threshold."
)
add_arg
(
'nms_topk'
,
int
,
400
,
"The number of boxes to perform NMS."
)
add_arg
(
'nms_topk'
,
int
,
400
,
"The number of boxes to perform NMS."
)
add_arg
(
'nms_posk'
,
int
,
100
,
"The number of boxes of NMS output."
)
add_arg
(
'nms_posk'
,
int
,
100
,
"The number of boxes of NMS output."
)
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录