Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
hapi
提交
97a365e5
H
hapi
项目概览
PaddlePaddle
/
hapi
通知
11
Star
2
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
4
列表
看板
标记
里程碑
合并请求
7
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
H
hapi
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
4
Issue
4
列表
看板
标记
里程碑
合并请求
7
合并请求
7
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
97a365e5
编写于
3月 31, 2020
作者:
D
dengkaipeng
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
add yolov3
上级
4d22fee0
变更
7
隐藏空白更改
内联
并排
Showing
7 changed file
with
1472 addition
and
539 deletion
+1472
-539
model.py
model.py
+15
-10
yolov3.py
yolov3.py
+134
-494
yolov3/coco.py
yolov3/coco.py
+275
-0
yolov3/coco_metric.py
yolov3/coco_metric.py
+31
-35
yolov3/darknet.py
yolov3/darknet.py
+177
-0
yolov3/modeling.py
yolov3/modeling.py
+220
-0
yolov3/transforms.py
yolov3/transforms.py
+620
-0
未找到文件。
model.py
浏览文件 @
97a365e5
...
...
@@ -1125,19 +1125,19 @@ class Model(fluid.dygraph.Layer):
if
not
isinstance
(
test_loader
,
Iterable
):
loader
=
test_loader
()
outputs
=
None
outputs
=
[]
for
data
in
tqdm
.
tqdm
(
loader
):
if
not
fluid
.
in_dygraph_mode
():
data
=
data
[
0
]
outs
=
self
.
test
(
*
data
)
assert
len
(
data
)
==
len
(
self
.
_inputs
)
+
len
(
self
.
_labels
),
\
"data fileds number mismatch"
inputs_data
=
data
[:
len
(
self
.
_inputs
)]
if
outputs
is
None
:
outputs
=
outs
else
:
outputs
=
[
np
.
vstack
([
x
,
outs
[
i
]])
for
i
,
x
in
enumerate
(
outputs
)
]
outputs
.
append
(
self
.
test
(
inputs_data
))
# sample list to batched data
outputs
=
list
(
zip
(
*
outputs
))
self
.
_test_dataloader
=
None
if
test_loader
is
not
None
and
self
.
_adapter
.
_nranks
>
1
\
...
...
@@ -1180,11 +1180,16 @@ class Model(fluid.dygraph.Layer):
else
:
batch_size
=
data
[
0
].
shape
[
0
]
assert
len
(
data
)
==
len
(
self
.
_inputs
)
+
len
(
self
.
_labels
),
\
"data fileds number mismatch"
inputs_data
=
data
[:
len
(
self
.
_inputs
)]
labels_data
=
data
[
len
(
self
.
_inputs
):]
callbacks
.
on_batch_begin
(
mode
,
step
,
logs
)
if
mode
==
'train'
:
outs
=
self
.
train
(
*
data
)
outs
=
self
.
train
(
inputs_data
,
labels_
data
)
else
:
outs
=
self
.
eval
(
*
data
)
outs
=
self
.
eval
(
inputs_data
,
labels_
data
)
# losses
loss
=
outs
[
0
]
if
self
.
_metrics
else
outs
...
...
yolov3.py
浏览文件 @
97a365e5
...
...
@@ -18,233 +18,29 @@ from __future__ import print_function
import
argparse
import
contextlib
import
os
import
random
import
time
from
functools
import
partial
import
cv2
import
numpy
as
np
from
pycocotools.coco
import
COCO
import
paddle
import
paddle.fluid
as
fluid
from
paddle.fluid.dygraph.nn
import
Conv2D
from
paddle.fluid.param_attr
import
ParamAttr
from
paddle.fluid.regularizer
import
L2Decay
from
model
import
Model
,
Loss
,
Input
from
resnet
import
ResNet
,
ConvBNLayer
import
logging
FORMAT
=
'%(asctime)s-%(levelname)s: %(message)s'
logging
.
basicConfig
(
level
=
logging
.
INFO
,
format
=
FORMAT
)
logger
=
logging
.
getLogger
(
__name__
)
# XXX transfer learning
class ResNetBackBone(ResNet):
    """ResNet variant that drops its ``fc`` attribute and returns every stage output."""

    def __init__(self, depth=50):
        super(ResNetBackBone, self).__init__(depth=depth)
        # forward() below never touches 'fc', so the attribute created by the
        # parent constructor is removed from this instance.
        del self.fc

    def forward(self, inputs):
        """Apply ``conv`` and ``pool``, then collect the output of each entry of ``self.layers``.

        Returns:
            list: one output per element of ``self.layers``, in iteration order.
        """
        feat = self.pool(self.conv(inputs))

        stage_feats = []
        for stage in self.layers:
            feat = stage(feat)
            stage_feats.append(feat)
        return stage_feats
class YoloDetectionBlock(fluid.dygraph.Layer):
    """Chain of six ``ConvBNLayer`` sublayers that yields a ``route`` and a ``tip`` output.

    Args:
        num_channels (int): channel count passed to the first sublayer.
        num_filters (int): base filter count; must be even. Sublayers
            alternate between ``num_filters`` (1x1) and ``num_filters * 2``
            (3x3) output filters.
    """

    def __init__(self, num_channels, num_filters):
        super(YoloDetectionBlock, self).__init__()

        assert num_filters % 2 == 0, \
            "num_filters {} cannot be divided by 2".format(num_filters)

        narrow = num_filters
        wide = num_filters * 2

        def conv_bn(in_ch, out_ch, kernel):
            # Every sublayer shares the same activation; only sizes differ.
            return ConvBNLayer(
                num_channels=in_ch,
                num_filters=out_ch,
                filter_size=kernel,
                act='leaky_relu')

        # Sublayers are created in the same order as they are applied.
        self.conv0 = conv_bn(num_channels, narrow, 1)
        self.conv1 = conv_bn(narrow, wide, 3)
        self.conv2 = conv_bn(wide, narrow, 1)
        self.conv3 = conv_bn(narrow, wide, 3)
        self.route = conv_bn(wide, narrow, 1)
        self.tip = conv_bn(narrow, wide, 3)

    def forward(self, inputs):
        """Return ``(route, tip)`` where ``tip`` is computed from ``route``."""
        hidden = inputs
        for conv in (self.conv0, self.conv1, self.conv2, self.conv3):
            hidden = conv(hidden)
        route = self.route(hidden)
        return route, self.tip(route)
class
YOLOv3
(
Model
):
def
__init__
(
self
,
num_classes
=
80
):
super
(
YOLOv3
,
self
).
__init__
()
self
.
num_classes
=
num_classes
self
.
anchors
=
[
10
,
13
,
16
,
30
,
33
,
23
,
30
,
61
,
62
,
45
,
59
,
119
,
116
,
90
,
156
,
198
,
373
,
326
]
self
.
anchor_masks
=
[[
6
,
7
,
8
],
[
3
,
4
,
5
],
[
0
,
1
,
2
]]
self
.
valid_thresh
=
0.005
self
.
nms_thresh
=
0.45
self
.
nms_topk
=
400
self
.
nms_posk
=
100
self
.
draw_thresh
=
0.5
self
.
backbone
=
ResNetBackBone
()
self
.
block_outputs
=
[]
self
.
yolo_blocks
=
[]
self
.
route_blocks
=
[]
for
idx
,
num_chan
in
enumerate
([
2048
,
1280
,
640
]):
yolo_block
=
self
.
add_sublayer
(
"detecton_block_{}"
.
format
(
idx
),
YoloDetectionBlock
(
num_chan
,
num_filters
=
512
//
(
2
**
idx
)))
self
.
yolo_blocks
.
append
(
yolo_block
)
num_filters
=
len
(
self
.
anchor_masks
[
idx
])
*
(
self
.
num_classes
+
5
)
block_out
=
self
.
add_sublayer
(
"block_out_{}"
.
format
(
idx
),
Conv2D
(
num_channels
=
1024
//
(
2
**
idx
),
num_filters
=
num_filters
,
filter_size
=
1
,
param_attr
=
ParamAttr
(
initializer
=
fluid
.
initializer
.
Normal
(
0.
,
0.02
)),
bias_attr
=
ParamAttr
(
initializer
=
fluid
.
initializer
.
Constant
(
0.0
),
regularizer
=
L2Decay
(
0.
))))
self
.
block_outputs
.
append
(
block_out
)
if
idx
<
2
:
route
=
self
.
add_sublayer
(
"route_{}"
.
format
(
idx
),
ConvBNLayer
(
num_channels
=
512
//
(
2
**
idx
),
num_filters
=
256
//
(
2
**
idx
),
filter_size
=
1
,
act
=
'leaky_relu'
))
self
.
route_blocks
.
append
(
route
)
def
forward
(
self
,
inputs
,
img_info
):
outputs
=
[]
boxes
=
[]
scores
=
[]
downsample
=
32
feats
=
self
.
backbone
(
inputs
)
feats
=
feats
[::
-
1
][:
len
(
self
.
anchor_masks
)]
route
=
None
for
idx
,
feat
in
enumerate
(
feats
):
if
idx
>
0
:
feat
=
fluid
.
layers
.
concat
(
input
=
[
route
,
feat
],
axis
=
1
)
route
,
tip
=
self
.
yolo_blocks
[
idx
](
feat
)
block_out
=
self
.
block_outputs
[
idx
](
tip
)
outputs
.
append
(
block_out
)
if
idx
<
2
:
route
=
self
.
route_blocks
[
idx
](
route
)
route
=
fluid
.
layers
.
resize_nearest
(
route
,
scale
=
2
)
from
paddle
import
fluid
from
paddle.fluid.optimizer
import
Momentum
from
paddle.fluid.io
import
DataLoader
if
self
.
mode
==
'test'
:
anchor_mask
=
self
.
anchor_masks
[
idx
]
mask_anchors
=
[]
for
m
in
anchor_mask
:
mask_anchors
.
append
(
self
.
anchors
[
2
*
m
])
mask_anchors
.
append
(
self
.
anchors
[
2
*
m
+
1
])
img_shape
=
fluid
.
layers
.
slice
(
img_info
,
axes
=
[
1
],
starts
=
[
1
],
ends
=
[
3
])
img_id
=
fluid
.
layers
.
slice
(
img_info
,
axes
=
[
1
],
starts
=
[
0
],
ends
=
[
1
])
b
,
s
=
fluid
.
layers
.
yolo_box
(
x
=
block_out
,
img_size
=
img_shape
,
anchors
=
mask_anchors
,
class_num
=
self
.
num_classes
,
conf_thresh
=
self
.
valid_thresh
,
downsample_ratio
=
downsample
)
from
model
import
Model
,
Input
,
set_device
from
distributed
import
DistributedBatchSampler
from
yolov3.coco
import
*
from
yolov3.transforms
import
*
from
yolov3.modeling
import
*
from
yolov3.coco_metric
import
*
boxes
.
append
(
b
)
scores
.
append
(
fluid
.
layers
.
transpose
(
s
,
perm
=
[
0
,
2
,
1
]))
NUM_MAX_BOXES
=
50
downsample
//=
2
if
self
.
mode
!=
'test'
:
return
outputs
return
[
img_id
,
fluid
.
layers
.
multiclass_nms
(
bboxes
=
fluid
.
layers
.
concat
(
boxes
,
axis
=
1
),
scores
=
fluid
.
layers
.
concat
(
scores
,
axis
=
2
),
score_threshold
=
self
.
valid_thresh
,
nms_top_k
=
self
.
nms_topk
,
keep_top_k
=
self
.
nms_posk
,
nms_threshold
=
self
.
nms_thresh
,
background_label
=-
1
)]
class YoloLoss(Loss):
    """Per-output YOLOv3 loss built on ``fluid.layers.yolov3_loss``.

    Args:
        num_classes (int): number of object classes, default 80.
    """

    def __init__(self, num_classes=80):
        super(YoloLoss, self).__init__()
        self.num_classes = num_classes
        self.ignore_thresh = 0.7
        # Flat list of anchor sizes, two numbers per anchor.
        self.anchors = [
            10, 13, 16, 30, 33, 23, 30, 61, 62, 45, 59, 119, 116, 90, 156,
            198, 373, 326
        ]
        # Anchor indices used by each output, in output order.
        self.anchor_masks = [[6, 7, 8], [3, 4, 5], [0, 1, 2]]

    def forward(self, outputs, labels):
        """Compute one mean-reduced loss per element of ``outputs``.

        Args:
            outputs: iterable of network outputs; the first is evaluated with
                a downsample ratio of 32 and the ratio halves for each
                following one.
            labels: sequence unpacked as ``(gt_box, gt_label, gt_score)``.

        Returns:
            list: the reduced loss of every output, in order.
        """
        gt_box, gt_label, gt_score = labels

        losses = []
        for level, feature in enumerate(outputs):
            raw_loss = fluid.layers.yolov3_loss(
                x=feature,
                gt_box=gt_box,
                gt_label=gt_label,
                gt_score=gt_score,
                anchor_mask=self.anchor_masks[level],
                downsample_ratio=32 // (2 ** level),
                anchors=self.anchors,
                class_num=self.num_classes,
                ignore_thresh=self.ignore_thresh,
                use_label_smooth=True)
            losses.append(fluid.layers.reduce_mean(raw_loss))
        return losses
def
make_optimizer
(
parameter_list
=
None
):
def
make_optimizer
(
step_per_epoch
,
parameter_list
=
None
):
base_lr
=
FLAGS
.
lr
warm_up_iter
=
4
000
warm_up_iter
=
1
000
momentum
=
0.9
weight_decay
=
5e-4
boundaries
=
[
400000
,
450000
]
boundaries
=
[
step_per_epoch
*
e
for
e
in
[
200
,
250
]
]
values
=
[
base_lr
*
(
0.1
**
i
)
for
i
in
range
(
len
(
boundaries
)
+
1
)]
learning_rate
=
fluid
.
layers
.
piecewise_decay
(
boundaries
=
boundaries
,
...
...
@@ -262,307 +58,151 @@ def make_optimizer(parameter_list=None):
return
optimizer
def
_iou_matrix
(
a
,
b
):
tl_i
=
np
.
maximum
(
a
[:,
np
.
newaxis
,
:
2
],
b
[:,
:
2
])
br_i
=
np
.
minimum
(
a
[:,
np
.
newaxis
,
2
:],
b
[:,
2
:])
area_i
=
np
.
prod
(
br_i
-
tl_i
,
axis
=
2
)
*
(
tl_i
<
br_i
).
all
(
axis
=
2
)
area_a
=
np
.
prod
(
a
[:,
2
:]
-
a
[:,
:
2
],
axis
=
1
)
area_b
=
np
.
prod
(
b
[:,
2
:]
-
b
[:,
:
2
],
axis
=
1
)
area_o
=
(
area_a
[:,
np
.
newaxis
]
+
area_b
-
area_i
)
return
area_i
/
(
area_o
+
1e-10
)
def
_crop_box_with_center_constraint
(
box
,
crop
):
cropped_box
=
box
.
copy
()
cropped_box
[:,
:
2
]
=
np
.
maximum
(
box
[:,
:
2
],
crop
[:
2
])
cropped_box
[:,
2
:]
=
np
.
minimum
(
box
[:,
2
:],
crop
[
2
:])
cropped_box
[:,
:
2
]
-=
crop
[:
2
]
cropped_box
[:,
2
:]
-=
crop
[:
2
]
centers
=
(
box
[:,
:
2
]
+
box
[:,
2
:])
/
2
valid
=
np
.
logical_and
(
crop
[:
2
]
<=
centers
,
centers
<
crop
[
2
:]).
all
(
axis
=
1
)
valid
=
np
.
logical_and
(
valid
,
(
cropped_box
[:,
:
2
]
<
cropped_box
[:,
2
:]).
all
(
axis
=
1
))
return
cropped_box
,
np
.
where
(
valid
)[
0
]
def random_crop(inputs):
    """Randomly crop an image so that at least one ground-truth box is kept.

    Args:
        inputs (tuple): ``(img, img_ids, gt_box, gt_label)``.

    Returns:
        tuple: ``(img, img_ids, gt_box, gt_label)`` for the first acceptable
        crop, or ``inputs`` itself when there are no boxes or no acceptable
        crop was found.

    IoU thresholds are tried in random order, with up to 50 random windows
    per threshold. A window is accepted when its best IoU with a box reaches
    the threshold and at least one box passes the centre constraint.
    """
    min_ar, max_ar = .5, 2.
    iou_thresholds = [.0, .1, .3, .5, .7, .9]
    scale_range = (.3, 1.)

    img, img_ids, gt_box, gt_label = inputs
    h, w = img.shape[:2]

    if len(gt_box) == 0:
        return inputs

    np.random.shuffle(iou_thresholds)
    for thresh in iou_thresholds:
        for _ in range(50):
            scale = np.random.uniform(*scale_range)
            # Aspect ratio bounds are tightened so the window fits the image.
            ar = np.random.uniform(
                max(min_ar, scale**2), min(max_ar, scale**-2))
            crop_h = int(h * scale / np.sqrt(ar))
            crop_w = int(w * scale * np.sqrt(ar))
            crop_y = np.random.randint(0, h - crop_h)
            crop_x = np.random.randint(0, w - crop_w)
            window = [crop_x, crop_y, crop_x + crop_w, crop_y + crop_h]

            overlap = _iou_matrix(
                gt_box, np.array([window], dtype=np.float32))
            if overlap.max() < thresh:
                continue

            clipped, kept = _crop_box_with_center_constraint(
                gt_box, np.array(window, dtype=np.float32))
            if kept.size == 0:
                continue

            left, top, right, bottom = window
            return (img[top:bottom, left:right, :], img_ids,
                    np.take(clipped, kept, axis=0),
                    np.take(gt_label, kept, axis=0))

    return inputs
# XXX mix up, color distort and random expand are skipped for simplicity
def sample_transform(inputs, mode='train', num_max_boxes=50):
    """Augment one sample and convert its boxes to padded, normalised form.

    Args:
        inputs (tuple): ``(img, img_id, gt_box, gt_label)``; ``gt_box`` rows
            are ``[x1, y1, x2, y2]`` and ``gt_label`` is indexed as
            ``gt_label[:, 0]``.
        mode (str): ``'train'`` enables random crop and random horizontal
            flip; any other value applies no augmentation.
        num_max_boxes (int): boxes and labels are truncated and zero-padded
            to exactly this many rows.

    Returns:
        tuple: ``(img, img_id, gt_box, gt_label)``. ``gt_box`` has shape
        ``(num_max_boxes, 4)`` holding ``[cx, cy, w, h]`` divided by the image
        width/height, and ``gt_label`` has shape ``(num_max_boxes,)``. Samples
        without labels get all-zero boxes and labels.

    The input arrays are never modified: the same sample arrays may be fed
    through this function repeatedly, so all edits are made on a copy.
    """
    if mode == 'train':
        img, img_id, gt_box, gt_label = random_crop(inputs)
    else:
        img, img_id, gt_box, gt_label = inputs

    # Private copy so that the flip and normalisation below cannot corrupt
    # the caller's array.
    gt_box = np.array(gt_box)

    h, w = img.shape[:2]
    # random flip
    if mode == 'train' and np.random.uniform(0., 1.) > .5:
        img = img[:, ::-1, :]
        if len(gt_box) > 0:
            swap = gt_box.copy()
            gt_box[:, 0] = w - swap[:, 2] - 1
            gt_box[:, 2] = w - swap[:, 0] - 1

    if len(gt_label) == 0:
        gt_box = np.zeros([num_max_boxes, 4], dtype=np.float32)
        gt_label = np.zeros([num_max_boxes], dtype=np.int32)
        # Keep the same 4-field layout as the regular return below so that
        # batches can mix empty and non-empty samples.
        return img, img_id, gt_box, gt_label

    gt_box = gt_box[:num_max_boxes, :]
    gt_label = gt_label[:num_max_boxes, 0]
    # normalize boxes
    gt_box /= np.array([w, h] * 2, dtype=np.float32)
    # corner form -> (centre, size) form
    gt_box[:, 2:] = gt_box[:, 2:] - gt_box[:, :2]
    gt_box[:, :2] = gt_box[:, :2] + gt_box[:, 2:] / 2.

    pad = num_max_boxes - gt_label.size
    gt_box = np.pad(gt_box, ((0, pad), (0, 0)), mode='constant')
    gt_label = np.pad(gt_label, ((0, pad)), mode='constant')

    return img, img_id, gt_box, gt_label
def batch_transform(batch, mode='train'):
    """Resize, normalise and stack a list of samples into network inputs.

    Args:
        batch (list): samples of the form
            ``(img, img_id, gt_box, gt_label)``.
        mode (str): ``'train'`` picks a random square size and a random
            interpolation code per batch; otherwise 608 and
            ``cv2.INTER_CUBIC`` are used.

    Returns:
        tuple: ``([imgs, img_info], [gt_boxes, gt_labels, gt_scores])``.
        ``imgs`` is float32 with the channel axis moved to position 1 and
        ``img_info`` concatenates each image id with its pre-resize
        ``(shape[0], shape[1])``.
    """
    if mode == 'train':
        side = np.random.choice(
            [320, 352, 384, 416, 448, 480, 512, 544, 576, 608])
        interp = np.random.choice(range(5))
    else:
        side, interp = 608, cv2.INTER_CUBIC

    # transpose batch
    imgs, img_ids, gt_boxes, gt_labels = list(zip(*batch))

    # Shapes are recorded before resizing.
    img_shapes = np.array(
        [[im.shape[0], im.shape[1]] for im in imgs]).astype('int32')
    resized = [
        cv2.resize(img, (side, side), interpolation=interp) for img in imgs
    ]

    # Reverse the channel axis, then normalise in place.
    imgs = np.array(resized).astype(np.float32)[..., ::-1]
    mean = np.array([123.675, 116.28, 103.53], dtype=np.float32)
    std = np.array([58.395, 57.120, 57.375], dtype=np.float32)
    imgs -= mean
    imgs *= 1. / std
    imgs = imgs.transpose((0, 3, 1, 2))

    img_info = np.concatenate([np.array(img_ids), img_shapes], axis=1)
    gt_boxes = np.array(gt_boxes)
    gt_labels = np.array(gt_labels)
    # XXX since mix up is not used, scores are all ones
    gt_scores = np.ones_like(gt_labels, dtype=np.float32)
    return [imgs, img_info], [gt_boxes, gt_labels, gt_scores]
def coco2017(root_dir, mode='train'):
    """Build a reader over the COCO 2017 split named by ``mode``.

    Args:
        root_dir (str): dataset root containing ``annotations/`` and the
            ``{mode}2017`` image directory.
        mode (str): split name used in both paths; ``'train'`` additionally
            shuffles the samples on every pass and keeps images that have no
            usable annotation.

    Returns:
        callable: zero-argument generator function yielding
        ``(img, im_id, gt_box, gt_label)`` with ``img`` read by
        ``cv2.imread``.
    """
    coco = COCO(
        os.path.join(root_dir,
                     'annotations/instances_{}2017.json'.format(mode)))
    image_records = coco.loadImgs(coco.getImgIds())
    # Category ids are remapped to consecutive labels starting at 1.
    class_map = {v: i + 1 for i, v in enumerate(coco.getCatIds())}
    image_root = '{}2017'.format(mode)
    samples = []

    for record in image_records:
        file_path = os.path.join(root_dir, image_root, record['file_name'])
        max_x = record['width'] - 1
        max_y = record['height'] - 1
        anns = coco.loadAnns(
            coco.getAnnIds(imgIds=record['id'], iscrowd=False))

        boxes = []
        cat_ids = []
        for ann in anns:
            x1, y1, w, h = ann['bbox']
            x2 = x1 + w - 1
            y2 = y1 + h - 1
            # Clamp both corners to the image.
            x1 = np.clip(x1, 0, max_x)
            x2 = np.clip(x2, 0, max_x)
            y1 = np.clip(y1, 0, max_y)
            y2 = np.clip(y2, 0, max_y)
            if ann['area'] <= 0 or x2 < x1 or y2 < y1:
                continue
            cat_ids.append(ann['category_id'])
            boxes.append([x1, y1, x2, y2])

        gt_box = np.array(boxes, dtype=np.float32)
        gt_label = np.array(
            [class_map[cat] for cat in cat_ids],
            dtype=np.int32)[:, np.newaxis]
        im_id = np.array([record['id']], dtype=np.int32)

        if mode != 'train' and gt_label.size == 0:
            continue
        samples.append(
            (file_path, im_id.copy(), gt_box.copy(), gt_label.copy()))

    def iterator():
        if mode == 'train':
            np.random.shuffle(samples)
        for path, im_id, gt_box, gt_label in samples:
            yield cv2.imread(path), im_id, gt_box, gt_label

    return iterator
# XXX coco metrics not included for simplicity
def run(model, loader, mode='train'):
    """Drive one pass over ``loader`` and log the running loss and step time.

    Args:
        model: object exposing a method named by ``mode`` that accepts
            ``(inputs, labels)`` and returns the losses of that step.
        loader (callable): zero-argument callable returning an iterable of
            batches; ``batch[0]`` and ``batch[1]`` are passed to the model.
        mode (str): name of the model method to call, default ``'train'``.

    Every 10th step the mean loss so far and the mean step time are logged.
    """
    total_loss = 0.
    total_time = 0.
    start = time.time()

    for idx, batch in enumerate(loader()):
        losses = getattr(model, mode)(batch[0], batch[1])

        total_loss += np.sum(losses)
        if idx > 1:  # skip first two steps
            total_time += time.time() - start
        if idx % 10 == 0:
            # idx - 1 steps have been timed, since the first two are skipped.
            logger.info("{:04d}: loss {:0.3f} time: {:0.3f}".format(
                idx, total_loss / (idx + 1), total_time / max(1, (idx - 1))))
        start = time.time()
def
main
():
@
contextlib
.
contextmanager
def
null_guard
():
yield
epoch
=
FLAGS
.
epoch
batch_size
=
FLAGS
.
batch_size
guard
=
fluid
.
dygraph
.
guard
()
if
FLAGS
.
dynamic
else
null_guard
()
device
=
set_device
(
FLAGS
.
device
)
fluid
.
enable_dygraph
(
device
)
if
FLAGS
.
dynamic
else
None
inputs
=
[
Input
([
None
,
3
],
'int32'
,
name
=
'img_info'
),
Input
([
None
,
3
,
None
,
None
],
'float32'
,
name
=
'image'
)]
labels
=
[
Input
([
None
,
NUM_MAX_BOXES
,
4
],
'float32'
,
name
=
'gt_bbox'
),
Input
([
None
,
NUM_MAX_BOXES
],
'int32'
,
name
=
'gt_label'
),
Input
([
None
,
NUM_MAX_BOXES
],
'float32'
,
name
=
'gt_score'
)]
if
not
FLAGS
.
eval_only
:
# training mode
train_transform
=
Compose
([
ColorDistort
(),
RandomExpand
(),
RandomCrop
(),
RandomFlip
(),
NormalizeBox
(),
PadBox
(),
BboxXYXY2XYWH
()])
train_collate_fn
=
BatchCompose
([
RandomShape
(),
NormalizeImage
()])
dataset
=
COCODataset
(
dataset_dir
=
FLAGS
.
data
,
anno_path
=
'annotations/instances_train2017.json'
,
image_dir
=
'train2017'
,
with_background
=
False
,
mixup
=
True
,
transform
=
train_transform
)
batch_sampler
=
DistributedBatchSampler
(
dataset
,
batch_size
=
FLAGS
.
batch_size
,
shuffle
=
True
,
drop_last
=
True
)
loader
=
DataLoader
(
dataset
,
batch_sampler
=
batch_sampler
,
places
=
device
,
feed_list
=
[
i
.
forward
()
for
i
in
inputs
+
labels
]
\
if
not
FLAGS
.
dynamic
else
None
,
num_workers
=
FLAGS
.
num_workers
,
return_list
=
True
,
collate_fn
=
train_collate_fn
)
else
:
# evaluation mode
eval_transform
=
Compose
([
ResizeImage
(
target_size
=
608
),
NormalizeBox
(),
PadBox
(),
BboxXYXY2XYWH
()])
eval_collate_fn
=
BatchCompose
([
NormalizeImage
()])
dataset
=
COCODataset
(
dataset_dir
=
FLAGS
.
data
,
anno_path
=
'annotations/instances_val2017.json'
,
image_dir
=
'val2017'
,
with_background
=
False
,
transform
=
eval_transform
)
# batch_size can only be 1 in evaluation for YOLOv3
# prediction bbox is LoDTensor
batch_sampler
=
DistributedBatchSampler
(
dataset
,
batch_size
=
1
,
shuffle
=
False
,
drop_last
=
False
)
loader
=
DataLoader
(
dataset
,
batch_sampler
=
batch_sampler
,
places
=
device
,
feed_list
=
[
i
.
forward
()
for
i
in
inputs
+
labels
]
\
if
not
FLAGS
.
dynamic
else
None
,
num_workers
=
FLAGS
.
num_workers
,
return_list
=
True
,
collate_fn
=
eval_collate_fn
)
model
=
YOLOv3
(
num_classes
=
dataset
.
num_classes
,
model_mode
=
'eval'
if
FLAGS
.
eval_only
else
'train'
)
if
FLAGS
.
pretrain_weights
is
not
None
:
model
.
load
(
FLAGS
.
pretrain_weights
,
skip_mismatch
=
True
,
reset_optimizer
=
True
)
optim
=
make_optimizer
(
len
(
batch_sampler
),
parameter_list
=
model
.
parameters
())
model
.
prepare
(
optim
,
YoloLoss
(
num_classes
=
dataset
.
num_classes
),
inputs
=
inputs
,
labels
=
labels
,
device
=
FLAGS
.
device
)
# NOTE: we implement COCO metric of YOLOv3 model here, separately
# from the 'prepare' and 'fit' framework for the following reasons:
# 1. YOLOv3 network structure is different between 'train' and
# 'eval' mode, in 'eval' mode, output prediction bbox is not the
# feature map used for YoloLoss calculating
# 2. COCO metric behavior is also different from defined Metric
# for COCO metric should not perform accumulate in each iteration
# but only accumulate at the end of an epoch
if
FLAGS
.
eval_only
:
if
FLAGS
.
weights
is
not
None
:
model
.
load
(
FLAGS
.
weights
)
preds
=
model
.
predict
(
loader
)
_
,
_
,
_
,
img_ids
,
bboxes
=
preds
train_loader
=
fluid
.
io
.
xmap_readers
(
batch_transform
,
paddle
.
batch
(
fluid
.
io
.
xmap_readers
(
sample_transform
,
coco2017
(
FLAGS
.
data
,
'train'
),
process_num
=
8
,
buffer_size
=
4
*
batch_size
),
batch_size
=
batch_size
,
drop_last
=
True
),
process_num
=
2
,
buffer_size
=
4
)
anno_path
=
os
.
path
.
join
(
FLAGS
.
data
,
'annotations/instances_val2017.json'
)
coco_metric
=
COCOMetric
(
anno_path
=
anno_path
,
with_background
=
False
)
for
img_id
,
bbox
in
zip
(
img_ids
,
bboxes
):
coco_metric
.
update
(
img_id
,
bbox
)
coco_metric
.
accumulate
()
coco_metric
.
reset
()
return
val_sample_transform
=
partial
(
sample_transform
,
mode
=
'val'
)
val_batch_transform
=
partial
(
batch_transform
,
mode
=
'val'
)
if
FLAGS
.
resume
is
not
None
:
model
.
load
(
FLAGS
.
resume
)
val_loader
=
fluid
.
io
.
xmap_readers
(
val_batch_transform
,
paddle
.
batch
(
fluid
.
io
.
xmap_readers
(
val_sample_transform
,
coco2017
(
FLAGS
.
data
,
'val'
),
process_num
=
8
,
buffer_size
=
4
*
batch_size
),
batch_size
=
1
),
process_num
=
2
,
buffer_size
=
4
)
model
.
fit
(
train_data
=
loader
,
epochs
=
FLAGS
.
epoch
-
FLAGS
.
no_mixup_epoch
,
save_dir
=
"yolo_checkpoint/mixup"
,
save_freq
=
10
)
if
not
os
.
path
.
exists
(
'yolo_checkpoints'
):
os
.
mkdir
(
'yolo_checkpoints'
)
with
guard
:
NUM_CLASSES
=
7
NUM_MAX_BOXES
=
50
model
=
YOLOv3
(
num_classes
=
NUM_CLASSES
)
# XXX transfer learning
if
FLAGS
.
pretrain_weights
is
not
None
:
model
.
backbone
.
load
(
FLAGS
.
pretrain_weights
)
if
FLAGS
.
weights
is
not
None
:
model
.
load
(
FLAGS
.
weights
)
optim
=
make_optimizer
(
parameter_list
=
model
.
parameters
())
anno_path
=
os
.
path
.
join
(
FLAGS
.
data
,
'annotations'
,
'instances_val2017.json'
)
inputs
=
[
Input
([
None
,
3
,
None
,
None
],
'float32'
,
name
=
'image'
),
Input
([
None
,
3
],
'int32'
,
name
=
'img_info'
)]
labels
=
[
Input
([
None
,
NUM_MAX_BOXES
,
4
],
'float32'
,
name
=
'gt_bbox'
),
Input
([
None
,
NUM_MAX_BOXES
],
'int32'
,
name
=
'gt_label'
),
Input
([
None
,
NUM_MAX_BOXES
],
'float32'
,
name
=
'gt_score'
)]
model
.
prepare
(
optim
,
YoloLoss
(
num_classes
=
NUM_CLASSES
),
# For YOLOv3, output variable in train/eval is different,
# which is not supported by metric, add by callback later?
# metrics=COCOMetric(anno_path, with_background=False)
inputs
=
inputs
,
labels
=
labels
)
for
e
in
range
(
epoch
):
logger
.
info
(
"======== train epoch {} ========"
.
format
(
e
))
run
(
model
,
train_loader
)
model
.
save
(
'yolo_checkpoints/{:02d}'
.
format
(
e
))
logger
.
info
(
"======== eval epoch {} ========"
.
format
(
e
))
run
(
model
,
val_loader
,
mode
=
'eval'
)
# should be called in fit()
for
metric
in
model
.
_metrics
:
metric
.
accumulate
()
metric
.
reset
()
# do not use the image mixup transform in the last FLAGS.no_mixup_epoch epochs
dataset
.
mixup
=
False
model
.
fit
(
train_data
=
loader
,
epochs
=
FLAGS
.
no_mixup_epoch
,
save_dir
=
"yolo_checkpoint/no_mixup"
,
save_freq
=
5
)
if
__name__
==
'__main__'
:
parser
=
argparse
.
ArgumentParser
(
"Yolov3 Training on COCO"
)
parser
.
add_argument
(
'data'
,
metavar
=
'DIR'
,
help
=
'path to COCO dataset'
)
parser
.
add_argument
(
"--device"
,
type
=
str
,
default
=
'gpu'
,
help
=
"device to use, gpu or cpu"
)
parser
.
add_argument
(
"-d"
,
"--dynamic"
,
action
=
'store_true'
,
help
=
"enable dygraph mode"
)
parser
.
add_argument
(
"--eval_only"
,
action
=
'store_true'
,
help
=
"run evaluation only"
)
parser
.
add_argument
(
"-e"
,
"--epoch"
,
default
=
300
,
type
=
int
,
help
=
"number of epoch"
)
parser
.
add_argument
(
"--no_mixup_epoch"
,
default
=
30
,
type
=
int
,
help
=
"number of the last N epoch without image mixup"
)
parser
.
add_argument
(
'--lr'
,
'--learning-rate'
,
default
=
0.001
,
type
=
float
,
metavar
=
'LR'
,
help
=
'initial learning rate'
)
parser
.
add_argument
(
"-b"
,
"--batch_size"
,
default
=
64
,
type
=
int
,
help
=
"batch size"
)
"-b"
,
"--batch_size"
,
default
=
8
,
type
=
int
,
help
=
"batch size"
)
parser
.
add_argument
(
"-n"
,
"--num_devices"
,
default
=
8
,
type
=
int
,
help
=
"number of devices"
)
"-n"
,
"--num_devices"
,
default
=
1
,
type
=
int
,
help
=
"number of devices"
)
parser
.
add_argument
(
"-j"
,
"--num_workers"
,
default
=
4
,
type
=
int
,
help
=
"reader worker number"
)
parser
.
add_argument
(
"-p"
,
"--pretrain_weights"
,
default
=
None
,
type
=
str
,
help
=
"path to pretrained weights"
)
parser
.
add_argument
(
"-
w"
,
"--weights
"
,
default
=
None
,
type
=
str
,
"-
r"
,
"--resume
"
,
default
=
None
,
type
=
str
,
help
=
"path to model weights"
)
parser
.
add_argument
(
"-w"
,
"--weights"
,
default
=
None
,
type
=
str
,
help
=
"path to weights for evaluation"
)
FLAGS
=
parser
.
parse_args
()
assert
FLAGS
.
data
,
"error: must provide data path"
main
()
yolov3/coco.py
0 → 100644
浏览文件 @
97a365e5
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from
__future__
import
division
from
__future__
import
print_function
import
os
import
cv2
import
numpy
as
np
from
pycocotools.coco
import
COCO
from
paddle.fluid.io
import
Dataset
import
logging
logger
=
logging
.
getLogger
(
__name__
)
__all__
=
[
'COCODataset'
]
class
COCODataset
(
Dataset
):
"""
Load dataset with MS-COCO format.
Args:
dataset_dir (str): root directory for dataset.
image_dir (str): directory for images.
anno_path (str): voc annotation file path.
sample_num (int): number of samples to load, -1 means all.
use_default_label (bool): whether use the default mapping of
label to integer index. Default True.
with_background (bool): whether load background as a class,
default True.
transform (callable): callable transform to perform on samples,
default None.
mixup (bool): whether return image mixup samples, default False.
alpha (float): alpha factor of beta distribution to generate
mixup score, used only when mixup is True, default 1.5
beta (float): beta factor of beta distribution to generate
mixup score, used only when mixup is True, default 1.5
"""
def
__init__
(
self
,
dataset_dir
=
''
,
image_dir
=
''
,
anno_path
=
''
,
sample_num
=-
1
,
with_background
=
True
,
transform
=
None
,
mixup
=
False
,
alpha
=
1.5
,
beta
=
1.5
):
# roidbs is list of dict whose structure is:
# {
# 'im_file': im_fname, # image file name
# 'im_id': im_id, # image id
# 'h': im_h, # height of image
# 'w': im_w, # width
# 'is_crowd': is_crowd,
# 'gt_class': gt_class,
# 'gt_bbox': gt_bbox,
# 'gt_score': gt_score,
# 'difficult': difficult
# }
self
.
_anno_path
=
os
.
path
.
join
(
dataset_dir
,
anno_path
)
self
.
_image_dir
=
os
.
path
.
join
(
dataset_dir
,
image_dir
)
assert
os
.
path
.
exists
(
self
.
_anno_path
),
\
"anno_path {} not exists"
.
format
(
anno_path
)
assert
os
.
path
.
exists
(
self
.
_image_dir
),
\
"image_dir {} not exists"
.
format
(
image_dir
)
self
.
_sample_num
=
sample_num
self
.
_with_background
=
with_background
self
.
_transform
=
transform
self
.
_mixup
=
mixup
self
.
_alpha
=
alpha
self
.
_beta
=
beta
# load in dataset roidbs
self
.
_load_roidb_and_cname2cid
()
def
_load_roidb_and_cname2cid
(
self
):
assert
self
.
_anno_path
.
endswith
(
'.json'
),
\
'invalid coco annotation file: '
+
anno_path
coco
=
COCO
(
self
.
_anno_path
)
img_ids
=
coco
.
getImgIds
()
cat_ids
=
coco
.
getCatIds
()
records
=
[]
ct
=
0
# when with_background = True, mapping category to classid, like:
# background:0, first_class:1, second_class:2, ...
catid2clsid
=
dict
({
catid
:
i
+
int
(
self
.
_with_background
)
for
i
,
catid
in
enumerate
(
cat_ids
)
})
cname2cid
=
dict
({
coco
.
loadCats
(
catid
)[
0
][
'name'
]:
clsid
for
catid
,
clsid
in
catid2clsid
.
items
()
})
for
img_id
in
img_ids
:
img_anno
=
coco
.
loadImgs
(
img_id
)[
0
]
im_fname
=
img_anno
[
'file_name'
]
im_w
=
float
(
img_anno
[
'width'
])
im_h
=
float
(
img_anno
[
'height'
])
ins_anno_ids
=
coco
.
getAnnIds
(
imgIds
=
img_id
,
iscrowd
=
False
)
instances
=
coco
.
loadAnns
(
ins_anno_ids
)
bboxes
=
[]
for
inst
in
instances
:
x
,
y
,
box_w
,
box_h
=
inst
[
'bbox'
]
x1
=
max
(
0
,
x
)
y1
=
max
(
0
,
y
)
x2
=
min
(
im_w
-
1
,
x1
+
max
(
0
,
box_w
-
1
))
y2
=
min
(
im_h
-
1
,
y1
+
max
(
0
,
box_h
-
1
))
if
inst
[
'area'
]
>
0
and
x2
>=
x1
and
y2
>=
y1
:
inst
[
'clean_bbox'
]
=
[
x1
,
y1
,
x2
,
y2
]
bboxes
.
append
(
inst
)
else
:
logger
.
warn
(
'Found an invalid bbox in annotations: im_id: {}, '
'area: {} x1: {}, y1: {}, x2: {}, y2: {}.'
.
format
(
img_id
,
float
(
inst
[
'area'
]),
x1
,
y1
,
x2
,
y2
))
num_bbox
=
len
(
bboxes
)
gt_bbox
=
np
.
zeros
((
num_bbox
,
4
),
dtype
=
np
.
float32
)
gt_class
=
np
.
zeros
((
num_bbox
,
1
),
dtype
=
np
.
int32
)
gt_score
=
np
.
ones
((
num_bbox
,
1
),
dtype
=
np
.
float32
)
is_crowd
=
np
.
zeros
((
num_bbox
,
1
),
dtype
=
np
.
int32
)
difficult
=
np
.
zeros
((
num_bbox
,
1
),
dtype
=
np
.
int32
)
gt_poly
=
[
None
]
*
num_bbox
for
i
,
box
in
enumerate
(
bboxes
):
catid
=
box
[
'category_id'
]
gt_class
[
i
][
0
]
=
catid2clsid
[
catid
]
gt_bbox
[
i
,
:]
=
box
[
'clean_bbox'
]
is_crowd
[
i
][
0
]
=
box
[
'iscrowd'
]
if
'segmentation'
in
box
:
gt_poly
[
i
]
=
box
[
'segmentation'
]
im_fname
=
os
.
path
.
join
(
self
.
_image_dir
,
im_fname
)
if
self
.
_image_dir
else
im_fname
coco_rec
=
{
'im_file'
:
im_fname
,
'im_id'
:
np
.
array
([
img_id
]),
'h'
:
im_h
,
'w'
:
im_w
,
'is_crowd'
:
is_crowd
,
'gt_class'
:
gt_class
,
'gt_bbox'
:
gt_bbox
,
'gt_score'
:
gt_score
,
'gt_poly'
:
gt_poly
,
}
records
.
append
(
coco_rec
)
ct
+=
1
if
self
.
_sample_num
>
0
and
ct
>=
self
.
_sample_num
:
break
assert
len
(
records
)
>
0
,
'not found any coco record in %s'
%
(
self
.
_anno_path
)
logger
.
info
(
'{} samples in file {}'
.
format
(
ct
,
self
.
_anno_path
))
self
.
_roidbs
,
self
.
_cname2cid
=
records
,
cname2cid
@
property
def
num_classes
(
self
):
return
len
(
self
.
_cname2cid
)
def
__len__
(
self
):
return
len
(
self
.
_roidbs
)
def
_getitem_by_index
(
self
,
idx
):
roidb
=
self
.
_roidbs
[
idx
]
with
open
(
roidb
[
'im_file'
],
'rb'
)
as
f
:
data
=
np
.
frombuffer
(
f
.
read
(),
dtype
=
'uint8'
)
im
=
cv2
.
imdecode
(
data
,
1
)
im
=
cv2
.
cvtColor
(
im
,
cv2
.
COLOR_BGR2RGB
)
im_info
=
np
.
array
([
roidb
[
'im_id'
][
0
],
roidb
[
'h'
],
roidb
[
'w'
]],
dtype
=
'int32'
)
gt_bbox
=
roidb
[
'gt_bbox'
]
gt_class
=
roidb
[
'gt_class'
]
gt_score
=
roidb
[
'gt_score'
]
return
im_info
,
im
,
gt_bbox
,
gt_class
,
gt_score
def
__getitem__
(
self
,
idx
):
im_info
,
im
,
gt_bbox
,
gt_class
,
gt_score
=
self
.
_getitem_by_index
(
idx
)
if
self
.
_mixup
:
mixup_idx
=
idx
+
np
.
random
.
randint
(
1
,
self
.
__len__
())
mixup_idx
%=
self
.
__len__
()
_
,
mixup_im
,
mixup_bbox
,
mixup_class
,
_
=
\
self
.
_getitem_by_index
(
mixup_idx
)
im
,
gt_bbox
,
gt_class
,
gt_score
=
\
self
.
_mixup_image
(
im
,
gt_bbox
,
gt_class
,
mixup_im
,
mixup_bbox
,
mixup_class
)
if
self
.
_transform
:
im_info
,
im
,
gt_bbox
,
gt_class
,
gt_score
=
\
self
.
_transform
(
im_info
,
im
,
gt_bbox
,
gt_class
,
gt_score
)
return
[
im_info
,
im
,
gt_bbox
,
gt_class
,
gt_score
]
def
_mixup_image
(
self
,
img1
,
bbox1
,
class1
,
img2
,
bbox2
,
class2
):
factor
=
np
.
random
.
beta
(
self
.
_alpha
,
self
.
_beta
)
factor
=
max
(
0.0
,
min
(
1.0
,
factor
))
if
factor
>=
1.0
:
return
img1
,
bbox1
,
class1
,
np
.
ones_like
(
class1
,
dtype
=
"float32"
)
if
factor
<=
0.0
:
return
img2
,
bbox2
,
class2
,
np
.
ones_like
(
class2
,
dtype
=
"float32"
)
h
=
max
(
img1
.
shape
[
0
],
img2
.
shape
[
0
])
w
=
max
(
img1
.
shape
[
1
],
img2
.
shape
[
1
])
img
=
np
.
zeros
((
h
,
w
,
img1
.
shape
[
2
]),
'float32'
)
img
[:
img1
.
shape
[
0
],
:
img1
.
shape
[
1
],
:]
=
\
img1
.
astype
(
'float32'
)
*
factor
img
[:
img2
.
shape
[
0
],
:
img2
.
shape
[
1
],
:]
+=
\
img2
.
astype
(
'float32'
)
*
(
1.0
-
factor
)
gt_bbox
=
np
.
concatenate
((
bbox1
,
bbox2
),
axis
=
0
)
gt_class
=
np
.
concatenate
((
class1
,
class2
),
axis
=
0
)
score1
=
np
.
ones_like
(
class1
,
dtype
=
"float32"
)
*
factor
score2
=
np
.
ones_like
(
class2
,
dtype
=
"float32"
)
*
(
1.0
-
factor
)
gt_score
=
np
.
concatenate
((
score1
,
score2
),
axis
=
0
)
return
img
,
gt_bbox
,
gt_class
,
gt_score
@property
def mixup(self):
    """bool: whether ``__getitem__`` mixes each sample with another one."""
    return self._mixup


@mixup.setter
def mixup(self, value):
    """Switch mixup on or off.

    Raises:
        ValueError: if ``value`` is not a ``bool``.
    """
    if isinstance(value, bool):
        logger.info("{} set mixup to {}".format(self, value))
        self._mixup = value
    else:
        raise ValueError("mixup should be a boolean number")
def pascalvoc_label(with_background=True):
    """Build the Pascal VOC class-name to label-id mapping.

    Args:
        with_background (bool): when true, ids run from 1 to 20 (leaving 0
            free); otherwise they run from 0 to 19.

    Returns:
        dict: class name -> integer label id.
    """
    class_names = (
        'aeroplane', 'bicycle', 'bird', 'boat', 'bottle', 'bus', 'car',
        'cat', 'chair', 'cow', 'diningtable', 'dog', 'horse', 'motorbike',
        'person', 'pottedplant', 'sheep', 'sofa', 'train', 'tvmonitor')
    first_id = 1 if with_background else 0
    return {name: first_id + pos for pos, name in enumerate(class_names)}
yolov3/coco_metric.py
浏览文件 @
97a365e5
...
...
@@ -17,8 +17,6 @@ import json
from
pycocotools.cocoeval
import
COCOeval
from
pycocotools.coco
import
COCO
from
metrics
import
Metric
import
logging
FORMAT
=
'%(asctime)s-%(levelname)s: %(message)s'
logging
.
basicConfig
(
level
=
logging
.
INFO
,
format
=
FORMAT
)
...
...
@@ -31,7 +29,7 @@ OUTFILE = './bbox.json'
# considered to change to a callback later
class
COCOMetric
(
Metric
):
class
COCOMetric
():
"""
Metrci for MS-COCO dataset, only support update with batch
size as 1.
...
...
@@ -43,26 +41,24 @@ class COCOMetric(Metric):
"""
def
__init__
(
self
,
anno_path
,
with_background
=
True
,
**
kwargs
):
super
(
COCOMetric
,
self
).
__init__
(
**
kwargs
)
self
.
anno_path
=
anno_path
self
.
with_background
=
with_background
self
.
bbox_results
=
[]
self
.
coco_gt
=
COCO
(
anno_path
)
cat_ids
=
self
.
coco_gt
.
getCatIds
()
self
.
clsid2catid
=
dict
(
{
i
+
int
(
with_background
):
catid
for
i
,
catid
in
enumerate
(
cat_ids
)})
self
.
clsid2catid
=
dict
(
{
i
+
int
(
with_background
):
catid
for
i
,
catid
in
enumerate
(
cat_ids
)})
def
update
(
self
,
preds
,
*
args
,
**
kwargs
):
im_ids
,
bboxes
=
preds
assert
im_ids
.
shape
[
0
]
==
1
,
\
def
update
(
self
,
img_id
,
bboxes
):
assert
img_id
.
shape
[
0
]
==
1
,
\
"COCOMetric can only update with batch size = 1"
if
bboxes
.
shape
[
1
]
!=
6
:
# no bbox detected in this batch
return
im
_id
=
int
(
im_ids
)
im
g_id
=
int
(
img_id
)
for
i
in
range
(
bboxes
.
shape
[
0
]):
dt
=
bboxes
[
i
,
:]
clsid
,
score
,
xmin
,
ymin
,
xmax
,
ymax
=
dt
.
tolist
()
...
...
@@ -72,7 +68,7 @@ class COCOMetric(Metric):
h
=
ymax
-
ymin
+
1
bbox
=
[
xmin
,
ymin
,
w
,
h
]
coco_res
=
{
'image_id'
:
im_id
,
'image_id'
:
im
g
_id
,
'category_id'
:
catid
,
'bbox'
:
bbox
,
'score'
:
score
...
...
@@ -83,30 +79,30 @@ class COCOMetric(Metric):
self
.
bbox_results
=
[]
def
accumulate
(
self
):
if
len
(
self
.
bbox_results
)
==
0
:
logger
.
warning
(
"The number of valid bbox detected is zero.
\n
\
Please use reasonable model and check input data.
\n
\
stop COCOMetric accumulate!"
)
return
[
0.0
]
with
open
(
OUTFILE
,
'w'
)
as
f
:
json
.
dump
(
self
.
bbox_results
,
f
)
map_stats
=
self
.
cocoapi_eval
(
OUTFILE
,
'bbox'
,
coco_gt
=
self
.
coco_gt
)
# flush coco evaluation result
sys
.
stdout
.
flush
()
if
len
(
self
.
bbox_results
)
==
0
:
logger
.
warning
(
"The number of valid bbox detected is zero.
\n
\
Please use reasonable model and check input data.
\n
\
stop COCOMetric accumulate!"
)
return
[
0.0
]
with
open
(
OUTFILE
,
'w'
)
as
f
:
json
.
dump
(
self
.
bbox_results
,
f
)
map_stats
=
self
.
cocoapi_eval
(
OUTFILE
,
'bbox'
,
coco_gt
=
self
.
coco_gt
)
# flush coco evaluation result
sys
.
stdout
.
flush
()
self
.
result
=
map_stats
[
0
]
return
self
.
result
return
[
self
.
result
]
def
cocoapi_eval
(
self
,
jsonfile
,
style
,
coco_gt
=
None
,
anno_file
=
None
):
assert
coco_gt
!=
None
or
anno_file
!=
None
if
coco_gt
==
None
:
coco_gt
=
COCO
(
anno_file
)
logger
.
info
(
"Start evaluate..."
)
coco_dt
=
coco_gt
.
loadRes
(
jsonfile
)
coco_eval
=
COCOeval
(
coco_gt
,
coco_dt
,
style
)
coco_eval
.
evaluate
()
coco_eval
.
accumulate
()
coco_eval
.
summarize
()
return
coco_eval
.
stats
assert
coco_gt
!=
None
or
anno_file
!=
None
if
coco_gt
==
None
:
coco_gt
=
COCO
(
anno_file
)
logger
.
info
(
"Start evaluate..."
)
coco_dt
=
coco_gt
.
loadRes
(
jsonfile
)
coco_eval
=
COCOeval
(
coco_gt
,
coco_dt
,
style
)
coco_eval
.
evaluate
()
coco_eval
.
accumulate
()
coco_eval
.
summarize
()
return
coco_eval
.
stats
yolov3/darknet.py
0 → 100755
浏览文件 @
97a365e5
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
#Licensed under the Apache License, Version 2.0 (the "License");
#you may not use this file except in compliance with the License.
#You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
#Unless required by applicable law or agreed to in writing, software
#distributed under the License is distributed on an "AS IS" BASIS,
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#See the License for the specific language governing permissions and
#limitations under the License.
import
paddle.fluid
as
fluid
from
paddle.fluid.param_attr
import
ParamAttr
from
paddle.fluid.regularizer
import
L2Decay
from
paddle.fluid.dygraph.nn
import
Conv2D
,
BatchNorm
from
paddle.fluid.dygraph.base
import
to_variable
__all__
=
[
'DarkNet53'
,
'ConvBNLayer'
]
class ConvBNLayer(fluid.dygraph.Layer):
    """Bias-free convolution followed by batch norm and optional leaky ReLU.

    Args:
        ch_in (int): number of input channels.
        ch_out (int): number of output channels.
        filter_size (int): convolution kernel size.
        stride (int): convolution stride.
        groups (int): convolution groups.
        padding (int): convolution padding.
        act (str): leaky ReLU (alpha=0.1) is applied only when this is
            exactly ``"leaky"``; any other value means no activation.
    """

    def __init__(self,
                 ch_in,
                 ch_out,
                 filter_size=3,
                 stride=1,
                 groups=1,
                 padding=0,
                 act="leaky"):
        super(ConvBNLayer, self).__init__()

        conv_weight = ParamAttr(
            initializer=fluid.initializer.Normal(0., 0.02))
        self.conv = Conv2D(
            num_channels=ch_in,
            num_filters=ch_out,
            filter_size=filter_size,
            stride=stride,
            padding=padding,
            groups=groups,
            param_attr=conv_weight,
            bias_attr=False,
            act=None)

        # Batch-norm scale and offset are excluded from weight decay.
        bn_scale = ParamAttr(
            initializer=fluid.initializer.Normal(0., 0.02),
            regularizer=L2Decay(0.))
        bn_offset = ParamAttr(
            initializer=fluid.initializer.Constant(0.0),
            regularizer=L2Decay(0.))
        self.batch_norm = BatchNorm(
            num_channels=ch_out, param_attr=bn_scale, bias_attr=bn_offset)

        self.act = act

    def forward(self, inputs):
        """Apply conv -> batch norm -> (optional) leaky ReLU."""
        features = self.batch_norm(self.conv(inputs))
        if self.act != 'leaky':
            return features
        return fluid.layers.leaky_relu(x=features, alpha=0.1)
class DownSample(fluid.dygraph.Layer):
    """A single ConvBNLayer, strided by default (3x3, stride 2, padding 1).

    Args:
        ch_in (int): number of input channels.
        ch_out (int): number of output channels.
        filter_size (int): convolution kernel size.
        stride (int): convolution stride.
        padding (int): convolution padding.
    """

    def __init__(self, ch_in, ch_out, filter_size=3, stride=2, padding=1):
        super(DownSample, self).__init__()
        self.conv_bn_layer = ConvBNLayer(
            ch_in=ch_in,
            ch_out=ch_out,
            filter_size=filter_size,
            stride=stride,
            padding=padding)
        self.ch_out = ch_out

    def forward(self, inputs):
        """Run the wrapped conv + batch-norm layer on ``inputs``."""
        return self.conv_bn_layer(inputs)
class BasicBlock(fluid.dygraph.Layer):
    """Residual block: 1x1 conv to ``ch_out`` then 3x3 conv to ``2 * ch_out``.

    The block output is ``inputs + conv2(conv1(inputs))``, so ``ch_in`` has
    to equal ``2 * ch_out`` for the element-wise sum to be well formed.

    Args:
        ch_in (int): number of input channels.
        ch_out (int): number of channels after the 1x1 convolution.
    """

    def __init__(self, ch_in, ch_out):
        super(BasicBlock, self).__init__()
        self.conv1 = ConvBNLayer(
            ch_in=ch_in, ch_out=ch_out, filter_size=1, stride=1, padding=0)
        self.conv2 = ConvBNLayer(
            ch_in=ch_out,
            ch_out=ch_out * 2,
            filter_size=3,
            stride=1,
            padding=1)

    def forward(self, inputs):
        """Return ``inputs`` plus the two-convolution residual branch."""
        residual = self.conv2(self.conv1(inputs))
        return fluid.layers.elementwise_add(x=inputs, y=residual, act=None)
class LayerWarp(fluid.dygraph.Layer):
    """A stack of ``count`` BasicBlocks applied one after another.

    Args:
        ch_in (int): input channels of the first block.
        ch_out (int): bottleneck channels; each block outputs ``2 * ch_out``.
        count (int): total number of blocks in the stack.
    """

    def __init__(self, ch_in, ch_out, count):
        super(LayerWarp, self).__init__()
        self.basicblock0 = BasicBlock(ch_in, ch_out)

        # Blocks after the first take the previous block's 2 * ch_out output.
        followers = []
        for block_no in range(1, count):
            followers.append(
                self.add_sublayer("basic_block_%d" % (block_no),
                                  BasicBlock(ch_out * 2, ch_out)))
        self.res_out_list = followers
        self.ch_out = ch_out

    def forward(self, inputs):
        """Feed ``inputs`` through every block of the stack in order."""
        features = self.basicblock0(inputs)
        for follower in self.res_out_list:
            features = follower(features)
        return features
# Number of residual blocks in each of the five DarkNet-53 stages.
DarkNet_cfg = {53: ([1, 2, 8, 8, 4])}


class DarkNet53(fluid.dygraph.Layer):
    """DarkNet-53 feature extractor.

    A 3x3 stem convolution and one down-sampling layer are followed by five
    residual stages with a down-sampling layer between consecutive stages.

    Args:
        ch_in (int): number of channels of the input image.
    """

    def __init__(self, ch_in=3):
        super(DarkNet53, self).__init__()
        self.stages = DarkNet_cfg[53][0:5]

        self.conv0 = ConvBNLayer(
            ch_in=ch_in, ch_out=32, filter_size=3, stride=1, padding=1)
        self.downsample0 = DownSample(ch_in=32, ch_out=32 * 2)

        self.darknet53_conv_block_list = []
        self.downsample_list = []

        # Input channels of each stage; the bottleneck width is half of it.
        stage_channels = [64, 128, 256, 512, 1024]
        for stage_no, num_blocks in enumerate(self.stages):
            stage_layer = self.add_sublayer(
                "stage_%d" % (stage_no),
                LayerWarp(
                    int(stage_channels[stage_no]), 32 * (2**stage_no),
                    num_blocks))
            self.darknet53_conv_block_list.append(stage_layer)

        # One down-sampling layer between every pair of consecutive stages.
        for gap_no in range(len(self.stages) - 1):
            between = self.add_sublayer(
                "stage_%d_downsample" % gap_no,
                DownSample(
                    ch_in=32 * (2**(gap_no + 1)),
                    ch_out=32 * (2**(gap_no + 2))))
            self.downsample_list.append(between)

    def forward(self, inputs):
        """Return the outputs of the last three stages, deepest first."""
        features = self.downsample0(self.conv0(inputs))

        stage_outputs = []
        last_stage = len(self.stages) - 1
        for stage_no, stage_layer in enumerate(
                self.darknet53_conv_block_list):
            features = stage_layer(features)
            stage_outputs.append(features)
            if stage_no < last_stage:
                features = self.downsample_list[stage_no](features)
        return stage_outputs[-1:-4:-1]
yolov3/modeling.py
0 → 100644
浏览文件 @
97a365e5
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from
__future__
import
division
from
__future__
import
print_function
import
paddle.fluid
as
fluid
from
paddle.fluid.dygraph.nn
import
Conv2D
from
paddle.fluid.param_attr
import
ParamAttr
from
paddle.fluid.regularizer
import
L2Decay
from
model
import
Model
,
Loss
from
.darknet
import
DarkNet53
,
ConvBNLayer
__all__
=
[
'YoloLoss'
,
'YOLOv3'
]
class YoloDetectionBlock(fluid.dygraph.Layer):
    """Detection head block: alternating 1x1 and 3x3 ConvBNLayers.

    Args:
        ch_in (int): number of input channels.
        channel (int): width of the 1x1 convolutions; the 3x3 convolutions
            output ``2 * channel``. Must be even.
    """

    def __init__(self, ch_in, channel):
        super(YoloDetectionBlock, self).__init__()

        assert channel % 2 == 0, \
            "channel {} cannot be divided by 2".format(channel)

        wide = channel * 2
        self.conv0 = ConvBNLayer(
            ch_in=ch_in, ch_out=channel, filter_size=1, stride=1, padding=0)
        self.conv1 = ConvBNLayer(
            ch_in=channel, ch_out=wide, filter_size=3, stride=1, padding=1)
        self.conv2 = ConvBNLayer(
            ch_in=wide, ch_out=channel, filter_size=1, stride=1, padding=0)
        self.conv3 = ConvBNLayer(
            ch_in=channel, ch_out=wide, filter_size=3, stride=1, padding=1)
        self.route = ConvBNLayer(
            ch_in=wide, ch_out=channel, filter_size=1, stride=1, padding=0)
        self.tip = ConvBNLayer(
            ch_in=channel, ch_out=wide, filter_size=3, stride=1, padding=1)

    def forward(self, inputs):
        """Return ``(route, tip)``.

        ``route`` is the ``channel``-wide feature map produced by the
        ``route`` layer; ``tip`` is the ``2 * channel``-wide map computed
        from it by the ``tip`` layer.
        """
        hidden = inputs
        for layer in (self.conv0, self.conv1, self.conv2, self.conv3):
            hidden = layer(hidden)
        route = self.route(hidden)
        tip = self.tip(route)
        return route, tip
class YOLOv3(Model):
    """YOLOv3 detector on a DarkNet-53 backbone.

    Args:
        num_classes (int): number of object classes.
        model_mode (str): ``'train'`` or ``'eval'`` (case-insensitive). In
            train mode ``forward`` returns the three raw head outputs; in
            eval mode it additionally returns the image id and the NMS
            result.
    """

    def __init__(self, num_classes=80, model_mode='train'):
        super(YOLOv3, self).__init__()
        self.num_classes = num_classes
        # The message used to advertise 'val', which the check rejects.
        assert str.lower(model_mode) in ['train', 'eval'], \
            "model_mode should be 'train' or 'eval', but got " \
            "{}".format(model_mode)
        self.model_mode = str.lower(model_mode)
        # Flat list of (width, height) anchor pairs.
        self.anchors = [
            10, 13, 16, 30, 33, 23, 30, 61, 62, 45, 59, 119, 116, 90, 156,
            198, 373, 326
        ]
        # Anchor indices used by each head, coarsest feature map first.
        self.anchor_masks = [[6, 7, 8], [3, 4, 5], [0, 1, 2]]
        self.valid_thresh = 0.005
        self.nms_thresh = 0.45
        self.nms_topk = 400
        self.nms_posk = 100
        self.draw_thresh = 0.5

        self.block = DarkNet53()
        self.block_outputs = []
        self.yolo_blocks = []
        self.route_blocks = []

        for idx, num_chan in enumerate([1024, 768, 384]):
            yolo_block = self.add_sublayer(
                "yolo_detecton_block_{}".format(idx),
                YoloDetectionBlock(num_chan, 512 // (2**idx)))
            self.yolo_blocks.append(yolo_block)

            # Per anchor: 4 box coordinates + 1 objectness + class scores.
            num_filters = len(self.anchor_masks[idx]) * (self.num_classes + 5)

            block_out = self.add_sublayer(
                "block_out_{}".format(idx),
                Conv2D(
                    num_channels=1024 // (2**idx),
                    num_filters=num_filters,
                    filter_size=1,
                    act=None,
                    param_attr=ParamAttr(
                        initializer=fluid.initializer.Normal(0., 0.02)),
                    bias_attr=ParamAttr(
                        initializer=fluid.initializer.Constant(0.0),
                        regularizer=L2Decay(0.))))
            self.block_outputs.append(block_out)

            if idx < 2:
                # NOTE(review): ConvBNLayer.forward only applies leaky ReLU
                # when act == 'leaky', so 'leaky_relu' here yields no
                # activation. Left unchanged because existing weights may
                # depend on it -- confirm whether this is intended.
                route = self.add_sublayer(
                    "route2_{}".format(idx),
                    ConvBNLayer(
                        ch_in=512 // (2**idx),
                        ch_out=256 // (2**idx),
                        filter_size=1,
                        act='leaky_relu'))
                self.route_blocks.append(route)

    def forward(self, img_info, inputs):
        """Run the detector.

        Args:
            img_info: per-image info; column 0 is read as the image id and
                columns 1-2 as the image size passed to ``yolo_box``.
            inputs: input image batch fed to the backbone.

        Returns:
            list: the three head outputs in train mode; in eval mode the
            same list followed by the first image id and the
            ``multiclass_nms`` result.
        """
        outputs = []
        boxes = []
        scores = []
        downsample = 32

        feats = self.block(inputs)
        route = None
        for idx, feat in enumerate(feats):
            if idx > 0:
                # Fuse the up-sampled route of the previous head.
                feat = fluid.layers.concat(input=[route, feat], axis=1)
            route, tip = self.yolo_blocks[idx](feat)
            block_out = self.block_outputs[idx](tip)
            outputs.append(block_out)

            if idx < 2:
                route = self.route_blocks[idx](route)
                route = fluid.layers.resize_nearest(route, scale=2)

            if self.model_mode == 'eval':
                anchor_mask = self.anchor_masks[idx]
                mask_anchors = []
                for m in anchor_mask:
                    mask_anchors.append(self.anchors[2 * m])
                    mask_anchors.append(self.anchors[2 * m + 1])
                img_shape = fluid.layers.slice(
                    img_info, axes=[1], starts=[1], ends=[3])
                img_id = fluid.layers.slice(
                    img_info, axes=[1], starts=[0], ends=[1])
                b, s = fluid.layers.yolo_box(
                    x=block_out,
                    img_size=img_shape,
                    anchors=mask_anchors,
                    class_num=self.num_classes,
                    conf_thresh=self.valid_thresh,
                    downsample_ratio=downsample)

                boxes.append(b)
                scores.append(fluid.layers.transpose(s, perm=[0, 2, 1]))

            downsample //= 2

        if self.model_mode == 'train':
            return outputs

        return outputs + [
            img_id[0, :], fluid.layers.multiclass_nms(
                bboxes=fluid.layers.concat(
                    boxes, axis=1),
                scores=fluid.layers.concat(
                    scores, axis=2),
                score_threshold=self.valid_thresh,
                nms_top_k=self.nms_topk,
                keep_top_k=self.nms_posk,
                nms_threshold=self.nms_thresh,
                background_label=-1)
        ]
class YoloLoss(Loss):
    """YOLOv3 loss computed on the three detection-head outputs.

    Args:
        num_classes (int): number of object classes.
        num_max_boxes (int): stored on the instance; not read by ``forward``.
    """

    def __init__(self, num_classes=80, num_max_boxes=50):
        super(YoloLoss, self).__init__()
        self.num_classes = num_classes
        self.num_max_boxes = num_max_boxes
        self.ignore_thresh = 0.7
        self.anchors = [
            10, 13, 16, 30, 33, 23, 30, 61, 62, 45, 59, 119, 116, 90, 156,
            198, 373, 326
        ]
        self.anchor_masks = [[6, 7, 8], [3, 4, 5], [0, 1, 2]]

    def forward(self, outputs, labels):
        """Return one mean ``yolov3_loss`` per detection head.

        Only the first three entries of ``outputs`` are used; anything after
        them is ignored. The down-sample ratio starts at 32 and halves for
        each following head.
        """
        gt_box, gt_label, gt_score = labels
        losses = []
        for idx, out in enumerate(outputs[:3]):
            head_loss = fluid.layers.yolov3_loss(
                x=out,
                gt_box=gt_box,
                gt_label=gt_label,
                gt_score=gt_score,
                anchor_mask=self.anchor_masks[idx],
                downsample_ratio=32 // (2**idx),
                anchors=self.anchors,
                class_num=self.num_classes,
                ignore_thresh=self.ignore_thresh,
                use_label_smooth=True)
            losses.append(fluid.layers.reduce_mean(head_loss))
        return losses
yolov3/transforms.py
0 → 100644
浏览文件 @
97a365e5
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from
__future__
import
division
from
__future__
import
print_function
import
cv2
import
traceback
import
numpy
as
np
import
logging
logger
=
logging
.
getLogger
(
__name__
)
__all__
=
[
'ColorDistort'
,
'RandomExpand'
,
'RandomCrop'
,
'RandomFlip'
,
'NormalizeBox'
,
'PadBox'
,
'RandomShape'
,
'NormalizeImage'
,
'BboxXYXY2XYWH'
,
'ResizeImage'
,
'Compose'
,
'BatchCompose'
]
class Compose(object):
    """Chain per-sample transforms.

    Every transform is called with the unpacked sample fields and must
    return the (possibly modified) fields as a sequence.

    Args:
        transforms (list): callables applied in order. Defaults to no
            transforms.
    """

    def __init__(self, transforms=None):
        # A fresh list per instance: a mutable default would be shared by
        # every Compose created without arguments.
        self.transforms = [] if transforms is None else transforms

    def __call__(self, *data):
        """Apply every transform in order and return the resulting fields."""
        for f in self.transforms:
            try:
                data = f(*data)
            except Exception as e:
                stack_info = traceback.format_exc()
                logger.info("fail to perform transform [{}] with error: "
                            "{} and stack:\n{}".format(f, e, str(stack_info)))
                # Bare raise keeps the original traceback intact.
                raise
        return data
class BatchCompose(object):
    """Chain batch-level transforms and regroup the samples by field.

    Args:
        transforms (list): callables that take and return a list of
            samples. Defaults to no transforms.
    """

    def __init__(self, transforms=None):
        # A fresh list per instance: a mutable default would be shared by
        # every BatchCompose created without arguments.
        self.transforms = [] if transforms is None else transforms

    def __call__(self, data):
        """Apply the transforms, then turn the sample list into field tuples.

        Args:
            data (list): list of samples, each a sequence of fields.

        Returns:
            list: one tuple per field, holding that field of every sample.
        """
        for f in self.transforms:
            try:
                data = f(data)
            except Exception as e:
                stack_info = traceback.format_exc()
                logger.info("fail to perform batch transform [{}] with error: "
                            "{} and stack:\n{}".format(f, e, str(stack_info)))
                # Bare raise keeps the original traceback intact.
                raise

        # sample list to batch data
        batch = list(zip(*data))
        return batch
class ColorDistort(object):
    """Random color distortion.

    Args:
        hue (list): hue settings.
            in [lower, upper, probability] format.
        saturation (list): saturation settings.
            in [lower, upper, probability] format.
        contrast (list): contrast settings.
            in [lower, upper, probability] format.
        brightness (list): brightness settings.
            in [lower, upper, probability] format.
        random_apply (bool): whether to apply in random (yolo) or fixed (SSD)
            order.

    Note: each ``apply_*`` method returns its input untouched when a uniform
    draw in [0, 1) is *below* the configured probability, so the third entry
    acts as the chance of skipping the distortion.
    """

    # RGB <-> YIQ conversion matrices used by the hue rotation.
    _RGB_TO_YIQ = [[0.299, 0.587, 0.114], [0.596, -0.274, -0.321],
                   [0.211, -0.523, 0.311]]
    _YIQ_TO_RGB = [[1.0, 0.956, 0.621], [1.0, -0.272, -0.647],
                   [1.0, -1.107, 1.705]]

    def __init__(self,
                 hue=[-18, 18, 0.5],
                 saturation=[0.5, 1.5, 0.5],
                 contrast=[0.5, 1.5, 0.5],
                 brightness=[0.5, 1.5, 0.5],
                 random_apply=True):
        self.hue = hue
        self.saturation = saturation
        self.contrast = contrast
        self.brightness = brightness
        self.random_apply = random_apply

    @staticmethod
    def _sample_delta(setting):
        """Draw a distortion amount, or return None when it is skipped."""
        low, high, prob = setting
        if np.random.uniform(0., 1.) < prob:
            return None
        return np.random.uniform(low, high)

    def apply_hue(self, img):
        """Rotate the chroma plane of ``img`` in YIQ space."""
        delta = self._sample_delta(self.hue)
        if delta is None:
            return img

        # XXX works, but result differ from HSV version
        cos_d = np.cos(delta * np.pi)
        sin_d = np.sin(delta * np.pi)
        rotation = np.array([[1.0, 0.0, 0.0], [0.0, cos_d, -sin_d],
                             [0.0, sin_d, cos_d]])
        to_yiq = np.array(self._RGB_TO_YIQ)
        to_rgb = np.array(self._YIQ_TO_RGB)
        transform = np.dot(np.dot(to_rgb, rotation), to_yiq).T
        return np.dot(img.astype(np.float32), transform)

    def apply_saturation(self, img):
        """Blend ``img`` with its gray-scale version."""
        delta = self._sample_delta(self.saturation)
        if delta is None:
            return img

        img = img.astype(np.float32)
        luma_weights = np.array([[[0.299, 0.587, 0.114]]], dtype=np.float32)
        gray = (img * luma_weights).sum(axis=2, keepdims=True)
        gray *= (1.0 - delta)
        img *= delta
        img += gray
        return img

    def apply_contrast(self, img):
        """Multiply ``img`` by a random factor."""
        delta = self._sample_delta(self.contrast)
        if delta is None:
            return img

        img = img.astype(np.float32)
        img *= delta
        return img

    def apply_brightness(self, img):
        """Add a random offset to ``img``."""
        delta = self._sample_delta(self.brightness)
        if delta is None:
            return img

        img = img.astype(np.float32)
        img += delta
        return img

    def __call__(self, im_info, im, gt_bbox, gt_class, gt_score):
        """Distort ``im``; every other field is passed through unchanged."""
        if self.random_apply:
            pipeline = np.random.permutation([
                self.apply_brightness, self.apply_contrast,
                self.apply_saturation, self.apply_hue
            ])
        else:
            # Fixed mode: brightness first, then contrast either before or
            # after the saturation/hue pair, chosen by a coin flip.
            im = self.apply_brightness(im)
            if np.random.randint(0, 2):
                pipeline = (self.apply_contrast, self.apply_saturation,
                            self.apply_hue)
            else:
                pipeline = (self.apply_saturation, self.apply_hue,
                            self.apply_contrast)

        for distort in pipeline:
            im = distort(im)
        return [im_info, im, gt_bbox, gt_class, gt_score]
class RandomExpand(object):
    """Random expand the canvas.

    Args:
        ratio (float): maximum expansion ratio.
        prob (float): the sample is returned unchanged when a uniform draw
            in [0, 1) is below this value.
        fill_value (list): color value used to fill the canvas. in RGB order.
    """

    def __init__(self, ratio=4., prob=0.5, fill_value=[123.675, 116.28, 103.53]):
        assert ratio > 1.01, "expand ratio must be larger than 1.01"
        self.ratio = ratio
        self.prob = prob
        self.fill_value = fill_value

    def __call__(self, im_info, im, gt_bbox, gt_class, gt_score):
        """Paste ``im`` at a random position of a larger, filled canvas.

        The boxes are shifted by the paste offset. The input ``gt_bbox`` is
        never modified; a shifted copy is returned instead, so arrays owned
        by the dataset stay intact across epochs.

        Returns:
            list: ``[im_info, canvas, gt_bbox, gt_class, gt_score]``, or the
            inputs unchanged when the expansion is skipped.
        """
        if np.random.uniform(0., 1.) < self.prob:
            return [im_info, im, gt_bbox, gt_class, gt_score]

        height, width, _ = im.shape
        expand_ratio = np.random.uniform(1., self.ratio)
        h = int(height * expand_ratio)
        w = int(width * expand_ratio)
        # Integer truncation may leave a side unchanged; nothing to do then.
        if not h > height or not w > width:
            return [im_info, im, gt_bbox, gt_class, gt_score]

        y = np.random.randint(0, h - height)
        x = np.random.randint(0, w - width)

        canvas = np.ones((h, w, 3), dtype=np.uint8)
        canvas *= np.array(self.fill_value, dtype=np.uint8)
        canvas[y:y + height, x:x + width, :] = im.astype(np.uint8)

        # Out-of-place add: "+=" would write into the caller's array.
        gt_bbox = gt_bbox + np.array([x, y, x, y], dtype=np.float32)

        return [im_info, canvas, gt_bbox, gt_class, gt_score]
class RandomCrop():
    """Random crop image and bboxes.

    Args:
        aspect_ratio (list): aspect ratio of cropped region.
            in [min, max] format.
        thresholds (list): iou thresholds for decide a valid bbox crop.
        scaling (list): ratio between a cropped region and the original image.
            in [min, max] format.
        num_attempts (int): number of tries before giving up.
        allow_no_crop (bool): allow return without actually cropping them.
        cover_all_box (bool): ensure all bboxes are covered in the final crop.
    """

    def __init__(self,
                 aspect_ratio=[.5, 2.],
                 thresholds=[.0, .1, .3, .5, .7, .9],
                 scaling=[.3, 1.],
                 num_attempts=50,
                 allow_no_crop=True,
                 cover_all_box=False):
        self.aspect_ratio = aspect_ratio
        self.thresholds = thresholds
        self.scaling = scaling
        self.num_attempts = num_attempts
        self.allow_no_crop = allow_no_crop
        self.cover_all_box = cover_all_box

    def __call__(self, im_info, im, gt_bbox, gt_class, gt_score):
        """Crop ``im`` and keep the boxes whose centre lies inside the crop.

        Returns:
            list: ``[im_info, im, gt_bbox, gt_class, gt_score]``; the inputs
            are returned unchanged when there are no boxes, when 'no_crop'
            is drawn, or when no valid crop is found.
        """
        if len(gt_bbox) == 0:
            return [im_info, im, gt_bbox, gt_class, gt_score]

        # NOTE Original method attempts to generate one candidate for each
        # threshold then randomly sample one from the resulting list.
        # Here a short circuit approach is taken, i.e., randomly choose a
        # threshold and attempt to find a valid crop, and simply return the
        # first one found.
        # The probability is not exactly the same, kinda resembling the
        # "Monty Hall" problem. Actually carrying out the attempts will affect
        # observability (just like opening doors in the "Monty Hall" game).
        thresholds = list(self.thresholds)
        if self.allow_no_crop:
            thresholds.append('no_crop')
        np.random.shuffle(thresholds)

        for thresh in thresholds:
            if thresh == 'no_crop':
                return [im_info, im, gt_bbox, gt_class, gt_score]

            h, w, _ = im.shape
            found = False

            for i in range(self.num_attempts):
                scale = np.random.uniform(*self.scaling)
                min_ar, max_ar = self.aspect_ratio
                aspect_ratio = np.random.uniform(
                    max(min_ar, scale**2), min(max_ar, scale**-2))
                crop_h = int(h * scale / np.sqrt(aspect_ratio))
                crop_w = int(w * scale * np.sqrt(aspect_ratio))
                # randint's upper bound is exclusive: "+ 1" makes the last
                # valid offset reachable and avoids a ValueError when the
                # crop is as large as the image (offset range would be empty).
                crop_y = np.random.randint(0, h - crop_h + 1)
                crop_x = np.random.randint(0, w - crop_w + 1)
                crop_box = [crop_x, crop_y, crop_x + crop_w, crop_y + crop_h]
                iou = self._iou_matrix(
                    gt_bbox, np.array(
                        [crop_box], dtype=np.float32))
                if iou.max() < thresh:
                    continue

                if self.cover_all_box and iou.min() < thresh:
                    continue

                cropped_box, valid_ids = self._crop_box_with_center_constraint(
                    gt_bbox, np.array(
                        crop_box, dtype=np.float32))
                if valid_ids.size > 0:
                    found = True
                    break

            if found:
                im = self._crop_image(im, crop_box)
                gt_bbox = np.take(cropped_box, valid_ids, axis=0)
                gt_class = np.take(gt_class, valid_ids, axis=0)
                gt_score = np.take(gt_score, valid_ids, axis=0)
                return [im_info, im, gt_bbox, gt_class, gt_score]

        return [im_info, im, gt_bbox, gt_class, gt_score]

    def _iou_matrix(self, a, b):
        """Return the pairwise IoU between the boxes of ``a`` and ``b``."""
        tl_i = np.maximum(a[:, np.newaxis, :2], b[:, :2])
        br_i = np.minimum(a[:, np.newaxis, 2:], b[:, 2:])

        # The mask zeroes pairs that do not overlap at all.
        area_i = np.prod(br_i - tl_i, axis=2) * (tl_i < br_i).all(axis=2)
        area_a = np.prod(a[:, 2:] - a[:, :2], axis=1)
        area_b = np.prod(b[:, 2:] - b[:, :2], axis=1)
        area_o = (area_a[:, np.newaxis] + area_b - area_i)
        # Small epsilon guards against division by zero.
        return area_i / (area_o + 1e-10)

    def _crop_box_with_center_constraint(self, box, crop):
        """Clip ``box`` to ``crop`` and shift it into crop coordinates.

        Returns:
            tuple: the clipped boxes and the indices of the boxes whose
            centre lies inside the crop and that keep a positive size.
        """
        cropped_box = box.copy()

        cropped_box[:, :2] = np.maximum(box[:, :2], crop[:2])
        cropped_box[:, 2:] = np.minimum(box[:, 2:], crop[2:])
        cropped_box[:, :2] -= crop[:2]
        cropped_box[:, 2:] -= crop[:2]

        centers = (box[:, :2] + box[:, 2:]) / 2
        valid = np.logical_and(crop[:2] <= centers,
                               centers < crop[2:]).all(axis=1)
        valid = np.logical_and(
            valid, (cropped_box[:, :2] < cropped_box[:, 2:]).all(axis=1))

        return cropped_box, np.where(valid)[0]

    def _crop_image(self, img, crop):
        """Return the ``[x1, y1, x2, y2]`` region of ``img``."""
        x1, y1, x2, y2 = crop
        return img[y1:y2, x1:x2, :]
class RandomFlip():
    """Randomly flip the image and its boxes horizontally."""

    def __init__(self, prob=0.5, is_normalized=False):
        """
        Args:
            prob (float): the probability of flipping image
            is_normalized (bool): whether the bbox scale to [0,1]

        Raises:
            TypeError: if ``prob`` is not a float or ``is_normalized`` is
                not a bool.
        """
        self.prob = prob
        self.is_normalized = is_normalized
        if not (isinstance(self.prob, float) and
                isinstance(self.is_normalized, bool)):
            raise TypeError("{}: input type is invalid.".format(self))

    def __call__(self, im_info, im, gt_bbox, gt_class, gt_score):
        """Filp the image and bounding box.

        Operators:
            1. Flip the image numpy.
            2. Transform the bboxes' x coordinates.
              (Must judge whether the coordinates are normalized!)

        The input ``gt_bbox`` is never modified; flipped boxes are written
        to a copy.

        Raises:
            TypeError: if ``im`` is not a numpy array.
            ValueError: if ``im`` is not 3-dimensional, or if every flipped
                box ends up with x2 < x1.
        """
        if not isinstance(im, np.ndarray):
            raise TypeError("{}: image is not a numpy array.".format(self))
        if len(im.shape) != 3:
            # The original raised an undefined ImageError (a NameError).
            raise ValueError("{}: image is not 3-dimensional.".format(self))

        height, width, _ = im.shape
        if np.random.uniform(0, 1) < self.prob:
            im = im[:, ::-1, :]
            if gt_bbox.shape[0] > 0:
                # Work on a copy so the caller's array stays untouched.
                gt_bbox = gt_bbox.copy()
                oldx1 = gt_bbox[:, 0].copy()
                oldx2 = gt_bbox[:, 2].copy()
                if self.is_normalized:
                    gt_bbox[:, 0] = 1 - oldx2
                    gt_bbox[:, 2] = 1 - oldx1
                else:
                    gt_bbox[:, 0] = width - oldx2 - 1
                    gt_bbox[:, 2] = width - oldx1 - 1
                if (gt_bbox[:, 2] < gt_bbox[:, 0]).all():
                    m = "{}: invalid box, x2 should be greater than x1".format(
                        self)
                    raise ValueError(m)

        return [im_info, im, gt_bbox, gt_class, gt_score]
class NormalizeBox(object):
    """Transform the bounding box's coornidates to [0,1]."""

    def __call__(self, im_info, im, gt_bbox, gt_class, gt_score):
        """Divide x coordinates by the image width and y by its height.

        The input ``gt_bbox`` is never modified; the normalized boxes are
        returned as a new array.

        Returns:
            list: ``[im_info, im, gt_bbox, gt_class, gt_score]``.
        """
        height, width, _ = im.shape
        # Nothing to scale (also covers shape-(0,) placeholders, which
        # could not be broadcast against the 4-element divisor).
        if gt_bbox.shape[0] == 0:
            return [im_info, im, gt_bbox, gt_class, gt_score]

        # One vectorized division instead of a per-row Python loop.
        divisor = np.array([width, height, width, height], dtype=gt_bbox.dtype)
        gt_bbox = gt_bbox / divisor
        return [im_info, im, gt_bbox, gt_class, gt_score]
class PadBox(object):
    def __init__(self, num_max_boxes=50):
        """
        Pad zeros to bboxes if number of bboxes is less than num_max_boxes.

        Args:
            num_max_boxes (int): the max number of bboxes
        """
        self.num_max_boxes = num_max_boxes

    def __call__(self, im_info, im, gt_bbox, gt_class, gt_score):
        """Pad (or truncate) boxes, classes and scores to a fixed length.

        ``gt_class`` and ``gt_score`` are read through their first column,
        so they are expected to be two-dimensional.

        Returns:
            list: ``[im_info, im, gt_bbox, gt_class, gt_score]`` where the
            boxes are float32 of shape (num_max_boxes, 4), the classes int32
            and the scores float32, both of shape (num_max_boxes,).
        """
        capacity = self.num_max_boxes
        kept = min(capacity, len(gt_bbox))

        padded_bbox = np.zeros((capacity, 4), dtype=np.float32)
        padded_class = np.zeros((capacity), dtype=np.int32)
        padded_score = np.zeros((capacity), dtype=np.float32)

        if kept > 0:
            padded_bbox[:kept, :] = gt_bbox[:kept, :]
            padded_class[:kept] = gt_class[:kept, 0]
            padded_score[:kept] = gt_score[:kept, 0]

        return [im_info, im, padded_bbox, padded_class, padded_score]
class BboxXYXY2XYWH(object):
    """
    Convert bbox XYXY format to XYWH format.

    The resulting x, y are the box *centre* coordinates, followed by the
    box width and height.
    """

    def __call__(self, im_info, im, gt_bbox, gt_class, gt_score):
        """Return the sample with ``gt_bbox`` converted to centre/size form.

        The input ``gt_bbox`` is never modified; the conversion is done on
        a copy.
        """
        # Copy first: slice assignment would write into the caller's array.
        gt_bbox = gt_bbox.copy()
        # Width/height must be computed before the corner is overwritten.
        gt_bbox[:, 2:4] = gt_bbox[:, 2:4] - gt_bbox[:, :2]
        gt_bbox[:, :2] = gt_bbox[:, :2] + gt_bbox[:, 2:4] / 2.
        return [im_info, im, gt_bbox, gt_class, gt_score]
class RandomShape(object):
    """
    Randomly reshape a batch. If random_inter is True, also randomly
    select one an interpolation algorithm [cv2.INTER_NEAREST, cv2.INTER_LINEAR,
    cv2.INTER_AREA, cv2.INTER_CUBIC, cv2.INTER_LANCZOS4]. If random_inter is
    False, use cv2.INTER_NEAREST.

    Args:
        sizes (list): list of int, random choose a size from these
        random_inter (bool): whether to randomly interpolation, defalut true.
    """

    def __init__(self,
                 sizes=[320, 352, 384, 416, 448, 480, 512, 544, 576, 608],
                 random_inter=True):
        self.sizes = sizes
        self.random_inter = random_inter
        if random_inter:
            self.interps = [
                cv2.INTER_NEAREST,
                cv2.INTER_LINEAR,
                cv2.INTER_AREA,
                cv2.INTER_CUBIC,
                cv2.INTER_LANCZOS4,
            ]
        else:
            self.interps = []

    def __call__(self, samples):
        """Resize the image (field 1) of every sample to one random size.

        A single size and a single interpolation method are drawn per call
        and shared by the whole batch. The samples are updated in place and
        the same list is returned.
        """
        shape = np.random.choice(self.sizes)
        if self.random_inter:
            method = np.random.choice(self.interps)
        else:
            method = cv2.INTER_NEAREST

        target = float(shape)
        for sample in samples:
            h, w = sample[1].shape[:2]
            sample[1] = cv2.resize(
                sample[1],
                None,
                None,
                fx=target / w,
                fy=target / h,
                interpolation=method)
        return samples
class NormalizeImage(object):
    """Scale, standardize and optionally transpose the images of a batch."""

    def __init__(self,
                 mean=[0.485, 0.456, 0.406],
                 std=[0.229, 0.224, 0.225],
                 scale=True,
                 channel_first=True):
        """
        Args:
            mean (list): the pixel mean
            std (list): the pixel variance
            scale (bool): whether scale image to [0, 1]
            channel_first (bool): whehter change [h, w, c] to [c, h, w]

        Raises:
            TypeError: if ``mean``/``std`` are not lists or ``scale`` is not
                a bool.
            ValueError: if any entry of ``std`` is zero.
        """
        self.mean = mean
        self.std = std
        self.scale = scale
        self.channel_first = channel_first
        if not (isinstance(self.mean, list) and isinstance(self.std, list) and
                isinstance(self.scale, bool)):
            raise TypeError("{}: input type is invalid.".format(self))
        # A zero std would divide by zero in __call__.
        if any(value == 0 for value in self.std):
            raise ValueError('{}: std is invalid!'.format(self))

    def __call__(self, samples):
        """Normalize the image.

        Operators:
            1. (optional) Scale the image to [0,1]
            2. Each pixel minus mean and is divided by std
            3. (optional) permute channel

        The image (field 1) of each sample is replaced by a new float32
        array; the input image arrays themselves are never modified.
        """
        # Loop invariants: built once per batch instead of once per sample.
        mean = np.array(self.mean)[np.newaxis, np.newaxis, :]
        std = np.array(self.std)[np.newaxis, np.newaxis, :]

        for sample in samples:
            # np.array always copies, so the in-place ops below cannot
            # reach the caller's image even when it is already float32.
            im = np.array(sample[1], dtype=np.float32)
            if self.scale:
                im /= 255.0
            im -= mean
            im /= std
            if self.channel_first:
                im = im.transpose((2, 0, 1))
            sample[1] = im
        return samples
def _iou_matrix(a, b):
    """Return the pairwise IoU between two sets of [x1, y1, x2, y2] boxes.

    Args:
        a (np.ndarray): boxes of shape (N, 4).
        b (np.ndarray): boxes of shape (M, 4).

    Returns:
        np.ndarray: (N, M) matrix; entry (i, j) is the IoU of a[i] and b[j].
    """
    a_rows = a[:, np.newaxis, :]
    inter_tl = np.maximum(a_rows[:, :, :2], b[:, :2])
    inter_br = np.minimum(a_rows[:, :, 2:], b[:, 2:])

    # Pairs without a real overlap contribute zero intersection area.
    overlaps = (inter_tl < inter_br).all(axis=2)
    inter_area = np.prod(inter_br - inter_tl, axis=2) * overlaps

    area_a = np.prod(a[:, 2:] - a[:, :2], axis=1)
    area_b = np.prod(b[:, 2:] - b[:, :2], axis=1)
    union = (area_a[:, np.newaxis] + area_b - inter_area)
    # Small epsilon guards against division by zero.
    return inter_area / (union + 1e-10)
def _crop_box_with_center_constraint(box, crop):
    """Clip boxes to a crop window and shift them into crop coordinates.

    Args:
        box (np.ndarray): boxes of shape (N, 4) in [x1, y1, x2, y2] form.
        crop (np.ndarray): crop window ``[x1, y1, x2, y2]``.

    Returns:
        tuple: the clipped, shifted boxes (all N of them) and the indices of
        the boxes whose centre lies inside the crop and whose clipped
        extent is positive on both axes.
    """
    origin, corner = crop[:2], crop[2:]

    clipped = box.copy()
    clipped[:, :2] = np.maximum(box[:, :2], origin)
    clipped[:, 2:] = np.minimum(box[:, 2:], corner)
    clipped[:, :2] -= origin
    clipped[:, 2:] -= origin

    centers = (box[:, :2] + box[:, 2:]) / 2
    center_inside = np.logical_and(origin <= centers,
                                   centers < corner).all(axis=1)
    non_degenerate = (clipped[:, :2] < clipped[:, 2:]).all(axis=1)
    keep = np.logical_and(center_inside, non_degenerate)

    return clipped, np.where(keep)[0]
def random_crop(inputs):
    """Randomly crop an image around its ground-truth boxes.

    Args:
        inputs: sequence unpacked as (img, img_ids, gt_box, gt_label).

    Returns:
        `inputs` itself when there are no boxes or when no acceptable crop
        is found; otherwise a tuple (cropped img, img_ids, kept boxes
        expressed relative to the crop, kept labels).

    IoU thresholds are visited in random order, with up to 50 random
    windows tried per threshold. A window is accepted when its best IoU
    with any box reaches the threshold and at least one box survives
    `_crop_box_with_center_constraint`.
    """
    min_ratio, max_ratio = .5, 2.
    iou_thresholds = [.0, .1, .3, .5, .7, .9]
    scale_range = (.3, 1.)
    max_trials = 50

    img, img_ids, gt_box, gt_label = inputs
    img_h, img_w = img.shape[:2]
    if len(gt_box) == 0:
        return inputs

    np.random.shuffle(iou_thresholds)
    for min_iou in iou_thresholds:
        for _ in range(max_trials):
            # The order of the random draws below is significant for
            # reproducibility under a fixed seed.
            scale = np.random.uniform(*scale_range)
            # Bounding the aspect ratio by scale**2 and scale**-2 keeps the
            # window no larger than the image on either side.
            ratio = np.random.uniform(
                max(min_ratio, scale**2), min(max_ratio, scale**-2))
            root = np.sqrt(ratio)
            crop_h = int(img_h * scale / root)
            crop_w = int(img_w * scale * root)
            top = np.random.randint(0, img_h - crop_h)
            left = np.random.randint(0, img_w - crop_w)
            bottom = top + crop_h
            right = left + crop_w
            window = np.array([left, top, right, bottom], dtype=np.float32)

            best_iou = _iou_matrix(gt_box, window[np.newaxis]).max()
            if best_iou < min_iou:
                continue

            shifted_box, keep = _crop_box_with_center_constraint(gt_box,
                                                                 window)
            if keep.size == 0:
                continue

            return (img[top:bottom, left:right, :], img_ids,
                    np.take(shifted_box, keep, axis=0),
                    np.take(gt_label, keep, axis=0))

    return inputs
class ResizeImage(object):
    """Resize an image to a square of `target_size` pixels per side."""

    def __init__(self, target_size=0, interp=cv2.INTER_CUBIC):
        """
        Rescale image to the specified target size.
        If target_size is list, selected a scale randomly as the specified
        target size.
        Args:
            target_size (int|list): the target size of both image sides,
                multi-scale training is adopted when type is list.
            interp (int): the interpolation method passed to cv2.resize

        Raises:
            TypeError: if target_size is neither an int nor a list.
        """
        self.interp = int(interp)
        if not isinstance(target_size, (int, list)):
            raise TypeError(
                "Type of target_size is invalid. Must be Integer or List, now is {}".
                format(type(target_size)))
        self.target_size = target_size

    def __call__(self, im_info, im, gt_bbox, gt_class, gt_score):
        """ Resize the image numpy.

        Only `im` is changed; im_info, gt_bbox, gt_class and gt_score are
        returned untouched.

        Returns:
            list: [im_info, resized im, gt_bbox, gt_class, gt_score]

        Raises:
            TypeError: if im is not a numpy array.
            ImageError: if im is not 3-dimensional.
            ValueError: if target_size is an empty list (raised by
                np.random.choice).
        """
        if not isinstance(im, np.ndarray):
            raise TypeError("{}: image type is not numpy.".format(self))
        if len(im.shape) != 3:
            # NOTE(review): ImageError is not defined in this block; it is
            # assumed to be declared elsewhere in this module - confirm.
            raise ImageError('{}: image is not 3-dimensional.'.format(self))

        # A list means multi-scale training: draw one size per call.
        # Without this, float() on the list below raised TypeError even
        # though the constructor accepts lists.
        if isinstance(self.target_size, list):
            selected_size = np.random.choice(self.target_size)
        else:
            selected_size = self.target_size

        im_h, im_w = im.shape[:2]
        # Independent x/y factors: the aspect ratio is not preserved.
        im_scale_x = float(selected_size) / float(im_w)
        im_scale_y = float(selected_size) / float(im_h)

        im = cv2.resize(
            im,
            None,
            None,
            fx=im_scale_x,
            fy=im_scale_y,
            interpolation=self.interp)
        return [im_info, im, gt_bbox, gt_class, gt_score]
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录