magicwindyyd / mindspore
Forked from MindSpore / mindspore (kept in sync with the upstream project)
Commit 0c81759a
Authored on Apr 01, 2020 by zhaoting
Parent commit: cc0ba93d

    add YOLOv3 infer scipt and change dataset to MindRecord

13 changed files, with 704 additions and 227 deletions (+704 -227)
Changed files:

example/yolov3_coco2017/config.py                   +3    -0
example/yolov3_coco2017/dataset.py                  +97   -168
example/yolov3_coco2017/eval.py                     +107  -0
example/yolov3_coco2017/run_distribute_train.sh     +17   -7
example/yolov3_coco2017/run_eval.sh                 +23   -0
example/yolov3_coco2017/run_standalone_train.sh     +5    -3
example/yolov3_coco2017/train.py                    +78   -36
example/yolov3_coco2017/util.py                     +146  -0
mindspore/ccsrc/kernel/tbe/tbe_adapter.cc           +1    -0
mindspore/model_zoo/yolov3.py                       +84   -13
mindspore/ops/_op_impl/tbe/__init__.py              +2    -0
mindspore/ops/_op_impl/tbe/reduce_min.py            +76   -0
mindspore/ops/_op_impl/tbe/round.py                 +65   -0
example/yolov3_coco2017/config.py

```diff
@@ -26,6 +26,7 @@ class ConfigYOLOV3ResNet18:
     img_shape = [352, 640]
     feature_shape = [32, 3, 352, 640]
     num_classes = 80
+    nms_max_num = 50
     backbone_input_shape = [64, 64, 128, 256]
     backbone_shape = [64, 128, 256, 512]
@@ -33,6 +34,8 @@ class ConfigYOLOV3ResNet18:
     backbone_stride = [1, 2, 2, 2]
     ignore_threshold = 0.5
+    obj_threshold = 0.3
+    nms_threshold = 0.4
     anchor_scales = [(10, 13),
                      (16, 30),
```
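The three added fields drive the new inference post-processing (confidence filtering and NMS in util.py). A minimal sketch of reading them, assuming example/yolov3_coco2017 is on the Python path:

```python
# Minimal sketch: reading the inference-related fields added in this commit.
# Assumes example/yolov3_coco2017 is on sys.path so `config` resolves to this file.
from config import ConfigYOLOV3ResNet18

cfg = ConfigYOLOV3ResNet18()
print(cfg.obj_threshold)   # 0.3  -> boxes below this confidence are dropped
print(cfg.nms_threshold)   # 0.4  -> IoU threshold used by NMS
print(cfg.nms_max_num)     # 50   -> at most 50 boxes kept per class
```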
example/yolov3_coco2017/dataset.py

```diff
@@ -16,16 +16,14 @@
 """YOLOv3 dataset"""
 from __future__ import division

-import abc
-import io
 import os
-import math
-import json
 import numpy as np
 from PIL import Image
 from matplotlib.colors import rgb_to_hsv, hsv_to_rgb
 import mindspore.dataset as de
+from mindspore.mindrecord import FileWriter
 import mindspore.dataset.transforms.vision.py_transforms as P
+import mindspore.dataset.transforms.vision.c_transforms as C
 from config import ConfigYOLOV3ResNet18

 iter_cnt = 0
```
```diff
@@ -114,6 +112,29 @@ def preprocess_fn(image, box, is_training):
         return y_true[0], y_true[1], y_true[2], pad_gt_box0, pad_gt_box1, pad_gt_box2

+    def _infer_data(img_data, input_shape, box):
+        w, h = img_data.size
+        input_h, input_w = input_shape
+        scale = min(float(input_w) / float(w), float(input_h) / float(h))
+        nw = int(w * scale)
+        nh = int(h * scale)
+        img_data = img_data.resize((nw, nh), Image.BICUBIC)
+
+        new_image = np.zeros((input_h, input_w, 3), np.float32)
+        new_image.fill(128)
+        img_data = np.array(img_data)
+        if len(img_data.shape) == 2:
+            img_data = np.expand_dims(img_data, axis=-1)
+            img_data = np.concatenate([img_data, img_data, img_data], axis=-1)
+
+        dh = int((input_h - nh) / 2)
+        dw = int((input_w - nw) / 2)
+        new_image[dh:(nh + dh), dw:(nw + dw), :] = img_data
+        new_image /= 255.
+        new_image = np.transpose(new_image, (2, 0, 1))
+        new_image = np.expand_dims(new_image, 0)
+        return new_image, np.array([h, w], np.float32), box
+
     def _data_aug(image, box, is_training, jitter=0.3, hue=0.1, sat=1.5, val=1.5, image_size=(352, 640)):
         """Data augmentation function."""
         if not isinstance(image, Image.Image):
```
```diff
@@ -124,32 +145,7 @@ def preprocess_fn(image, box, is_training):
         h, w = image_size

         if not is_training:
-            image = image.resize((w, h), Image.BICUBIC)
-            image_data = np.array(image) / 255.
-            if len(image_data.shape) == 2:
-                image_data = np.expand_dims(image_data, axis=-1)
-                image_data = np.concatenate([image_data, image_data, image_data], axis=-1)
-            image_data = image_data.astype(np.float32)
-
-            # correct boxes
-            box_data = np.zeros((max_boxes, 5))
-            if len(box) >= 1:
-                np.random.shuffle(box)
-                if len(box) > max_boxes:
-                    box = box[:max_boxes]
-                # xmin ymin xmax ymax
-                box[:, [0, 2]] = box[:, [0, 2]] * float(w) / float(iw)
-                box[:, [1, 3]] = box[:, [1, 3]] * float(h) / float(ih)
-                box_data[:len(box)] = box
-            else:
-                image_data, box_data = None, None
-
-            # preprocess bounding boxes
-            bbox_true_1, bbox_true_2, bbox_true_3, gt_box1, gt_box2, gt_box3 = \
-                _preprocess_true_boxes(box_data, anchors, image_size)
-
-            return image_data, bbox_true_1, bbox_true_2, bbox_true_3, \
-                   ori_image_shape, gt_box1, gt_box2, gt_box3
+            return _infer_data(image, image_size, box)

         flip = _rand() < .5
         # correct boxes
```
```diff
@@ -235,12 +231,16 @@ def preprocess_fn(image, box, is_training):
         return image_data, bbox_true_1, bbox_true_2, bbox_true_3, \
                ori_image_shape, gt_box1, gt_box2, gt_box3

-    images, bbox_1, bbox_2, bbox_3, _, gt_box1, gt_box2, gt_box3 = _data_aug(image, box, is_training)
-    return images, bbox_1, bbox_2, bbox_3, gt_box1, gt_box2, gt_box3
+    if is_training:
+        images, bbox_1, bbox_2, bbox_3, _, gt_box1, gt_box2, gt_box3 = _data_aug(image, box, is_training)
+        return images, bbox_1, bbox_2, bbox_3, gt_box1, gt_box2, gt_box3
+
+    images, shape, anno = _data_aug(image, box, is_training)
+    return images, shape, anno


 def anno_parser(annos_str):
-    """Annotation parser."""
+    """Parse annotation from string to list."""
     annos = []
     for anno_str in annos_str:
         anno = list(map(int, anno_str.strip().split(',')))
```
```diff
@@ -248,142 +248,71 @@ def anno_parser(annos_str):
     return annos

-def expand_path(path):
-    """Get file list from path."""
-    files = []
-    if os.path.isdir(path):
-        for file in os.listdir(path):
-            if os.path.isfile(os.path.join(path, file)):
-                files.append(file)
-    else:
-        raise RuntimeError("Path given is not valid.")
-    return files
-
-
-def read_image(img_path):
-    """Read image with PIL."""
-    with open(img_path, "rb") as f:
-        img = f.read()
-        data = io.BytesIO(img)
-        img = Image.open(data)
-        return np.array(img)
-
-
-class BaseDataset():
-    """BaseDataset for GeneratorDataset iterator."""
-    def __init__(self, image_dir, anno_path):
-        self.image_dir = image_dir
-        self.anno_path = anno_path
-        self.cur_index = 0
-        self.samples = []
-        self.image_anno_dict = {}
-        self._load_samples()
-
-    def __getitem__(self, item):
-        sample = self.samples[item]
-        return self._next_data(sample, self.image_dir, self.image_anno_dict)
-
-    def __len__(self):
-        return len(self.samples)
-
-    @staticmethod
-    def _next_data(sample, image_dir, image_anno_dict):
-        """Get next data."""
-        image = read_image(os.path.join(image_dir, sample))
-        annos = image_anno_dict[sample]
-        return [np.array(image), np.array(annos)]
-
-    @abc.abstractmethod
-    def _load_samples(self):
-        """Base load samples."""
-
-
-class YoloDataset(BaseDataset):
-    """YoloDataset for GeneratorDataset iterator."""
-    def _load_samples(self):
-        """Load samples."""
-        image_files_raw = expand_path(self.image_dir)
-        self.samples = self._filter_valid_data(self.anno_path, image_files_raw)
-        self.dataset_size = len(self.samples)
-        if self.dataset_size == 0:
-            raise RuntimeError("Valid dataset is none!")
-
-    def _filter_valid_data(self, anno_path, image_files_raw):
-        """Filter valid data."""
-        image_files = []
-        anno_dict = {}
-        print("Start filter valid data.")
-        with open(anno_path, "rb") as f:
-            lines = f.readlines()
-        for line in lines:
-            line_str = line.decode("utf-8")
-            line_split = str(line_str).split(' ')
-            anno_dict[line_split[0].split("/")[-1]] = line_split[1:]
-        anno_set = set(anno_dict.keys())
-        image_set = set(image_files_raw)
-        for image_file in (anno_set & image_set):
-            image_files.append(image_file)
-            self.image_anno_dict[image_file] = anno_parser(anno_dict[image_file])
-        image_files.sort()
-        print("Filter valid data done!")
-        return image_files
-
-
-class DistributedSampler():
-    """DistributedSampler for YOLOv3"""
-    def __init__(self, dataset_size, batch_size, num_replicas=None, rank=None, shuffle=True):
-        if num_replicas is None:
-            num_replicas = 1
-        if rank is None:
-            rank = 0
-        self.dataset_size = dataset_size
-        self.num_replicas = num_replicas
-        self.rank = rank % num_replicas
-        self.epoch = 0
-        self.num_samples = max(batch_size, int(math.ceil(dataset_size * 1.0 / self.num_replicas)))
-        self.total_size = self.num_samples * self.num_replicas
-        self.shuffle = shuffle
-
-    def __iter__(self):
-        # deterministically shuffle based on epoch
-        if self.shuffle:
-            indices = np.random.RandomState(seed=self.epoch).permutation(self.dataset_size)
-            indices = indices.tolist()
-        else:
-            indices = list(range(self.dataset_size))
-
-        # add extra samples to make it evenly divisible
-        indices += indices[:(self.total_size - len(indices))]
-        assert len(indices) == self.total_size
-
-        # subsample
-        indices = indices[self.rank:self.total_size:self.num_replicas]
-        assert len(indices) == self.num_samples
-
-        return iter(indices)
-
-    def __len__(self):
-        return self.num_samples
-
-    def set_epoch(self, epoch):
-        self.epoch = epoch
-
-
-def create_yolo_dataset(image_dir, anno_path, batch_size=32, repeat_num=10, device_num=1, rank=0,
-                        is_training=True, num_parallel_workers=8):
-    """Creatr YOLOv3 dataset with GeneratorDataset."""
-    yolo_dataset = YoloDataset(image_dir=image_dir, anno_path=anno_path)
-    distributed_sampler = DistributedSampler(yolo_dataset.dataset_size, batch_size, device_num, rank)
-    ds = de.GeneratorDataset(yolo_dataset, column_names=["image", "annotation"], sampler=distributed_sampler)
-    ds.set_dataset_size(len(distributed_sampler))
-    compose_map_func = (lambda image, annotation: preprocess_fn(image, annotation, is_training))
-    hwc_to_chw = P.HWC2CHW()
-    ds = ds.map(input_columns=["image", "annotation"],
-                output_columns=["image", "bbox_1", "bbox_2", "bbox_3", "gt_box1", "gt_box2", "gt_box3"],
-                columns_order=["image", "bbox_1", "bbox_2", "bbox_3", "gt_box1", "gt_box2", "gt_box3"],
-                operations=compose_map_func, num_parallel_workers=num_parallel_workers)
-    ds = ds.map(input_columns=["image"], operations=hwc_to_chw, num_parallel_workers=num_parallel_workers)
-    ds = ds.shuffle(buffer_size=256)
-    ds = ds.batch(batch_size, drop_remainder=True)
-    ds = ds.repeat(repeat_num)
-    return ds
+def filter_valid_data(image_dir, anno_path):
+    """Filter valid image file, which both in image_dir and anno_path."""
+    image_files = []
+    image_anno_dict = {}
+    if not os.path.isdir(image_dir):
+        raise RuntimeError("Path given is not valid.")
+    if not os.path.isfile(anno_path):
+        raise RuntimeError("Annotation file is not valid.")
+
+    with open(anno_path, "rb") as f:
+        lines = f.readlines()
+    for line in lines:
+        line_str = line.decode("utf-8").strip()
+        line_split = str(line_str).split(' ')
+        file_name = line_split[0]
+        if os.path.isfile(os.path.join(image_dir, file_name)):
+            image_anno_dict[file_name] = anno_parser(line_split[1:])
+            image_files.append(file_name)
+    return image_files, image_anno_dict
+
+
+def data_to_mindrecord_byte_image(image_dir, anno_path, mindrecord_dir, prefix="yolo.mindrecord", file_num=8):
+    """Create MindRecord file by image_dir and anno_path."""
+    mindrecord_path = os.path.join(mindrecord_dir, prefix)
+    writer = FileWriter(mindrecord_path, file_num)
+    image_files, image_anno_dict = filter_valid_data(image_dir, anno_path)
+
+    yolo_json = {
+        "image": {"type": "bytes"},
+        "annotation": {"type": "int64", "shape": [-1, 5]},
+    }
+    writer.add_schema(yolo_json, "yolo_json")
+
+    for image_name in image_files:
+        image_path = os.path.join(image_dir, image_name)
+        with open(image_path, 'rb') as f:
+            img = f.read()
+        annos = np.array(image_anno_dict[image_name])
+        row = {"image": img, "annotation": annos}
+        writer.write_raw_data([row])
+    writer.commit()
+
+
+def create_yolo_dataset(mindrecord_dir, batch_size=32, repeat_num=10, device_num=1, rank=0,
+                        is_training=True, num_parallel_workers=8):
+    """Creatr YOLOv3 dataset with MindDataset."""
+    ds = de.MindDataset(mindrecord_dir, columns_list=["image", "annotation"], num_shards=device_num, shard_id=rank,
+                        num_parallel_workers=num_parallel_workers, shuffle=is_training)
+    decode = C.Decode()
+    ds = ds.map(input_columns=["image"], operations=decode)
+    compose_map_func = (lambda image, annotation: preprocess_fn(image, annotation, is_training))
+
+    if is_training:
+        hwc_to_chw = P.HWC2CHW()
+        ds = ds.map(input_columns=["image", "annotation"],
+                    output_columns=["image", "bbox_1", "bbox_2", "bbox_3", "gt_box1", "gt_box2", "gt_box3"],
+                    columns_order=["image", "bbox_1", "bbox_2", "bbox_3", "gt_box1", "gt_box2", "gt_box3"],
+                    operations=compose_map_func, num_parallel_workers=num_parallel_workers)
+        ds = ds.map(input_columns=["image"], operations=hwc_to_chw, num_parallel_workers=num_parallel_workers)
+        ds = ds.shuffle(buffer_size=256)
+        ds = ds.batch(batch_size, drop_remainder=True)
+        ds = ds.repeat(repeat_num)
+    else:
+        ds = ds.map(input_columns=["image", "annotation"],
+                    output_columns=["image", "image_shape", "annotation"],
+                    columns_order=["image", "image_shape", "annotation"],
+                    operations=compose_map_func, num_parallel_workers=num_parallel_workers)
+    return ds
```
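A minimal usage sketch of the new MindRecord flow; the paths are made up but mirror the examples echoed by the training scripts below:

```python
# Sketch: serialize (image, annotation) pairs once, then build the pipeline from the record files.
# Paths are hypothetical; FileWriter shards the output, so the first shard ends in "0".
from dataset import data_to_mindrecord_byte_image, create_yolo_dataset

data_to_mindrecord_byte_image("/data/coco/train2017", "/data/train.txt", "/data/Mindrecord_train",
                              prefix="yolo.mindrecord", file_num=8)

ds = create_yolo_dataset("/data/Mindrecord_train/yolo.mindrecord0",
                         batch_size=32, repeat_num=1, device_num=1, rank=0, is_training=True)
print(ds.get_dataset_size())
```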
example/yolov3_coco2017/eval.py (new file, 0 → 100644)

```python
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Evaluation for yolo_v3"""
import os
import argparse
import time
from mindspore import context, Tensor
from mindspore.train.serialization import load_checkpoint, load_param_into_net
from mindspore.model_zoo.yolov3 import yolov3_resnet18, YoloWithEval
from dataset import create_yolo_dataset, data_to_mindrecord_byte_image
from config import ConfigYOLOV3ResNet18
from util import metrics


def yolo_eval(dataset_path, ckpt_path):
    """Yolov3 evaluation."""

    ds = create_yolo_dataset(dataset_path, is_training=False)
    config = ConfigYOLOV3ResNet18()
    net = yolov3_resnet18(config)
    eval_net = YoloWithEval(net, config)
    print("Load Checkpoint!")
    param_dict = load_checkpoint(ckpt_path)
    load_param_into_net(net, param_dict)

    eval_net.set_train(False)
    i = 1.
    total = ds.get_dataset_size()
    start = time.time()
    pred_data = []
    print("\n========================================\n")
    print("total images num: ", total)
    print("Processing, please wait a moment.")
    for data in ds.create_dict_iterator():
        img_np = data['image']
        image_shape = data['image_shape']
        annotation = data['annotation']

        eval_net.set_train(False)
        output = eval_net(Tensor(img_np), Tensor(image_shape))
        for batch_idx in range(img_np.shape[0]):
            pred_data.append({"boxes": output[0].asnumpy()[batch_idx],
                              "box_scores": output[1].asnumpy()[batch_idx],
                              "annotation": annotation})
        percent = round(i / total * 100, 2)

        print('    %s [%d/%d]' % (str(percent) + '%', i, total), end='\r')
        i += 1
    print('    %s [%d/%d] cost %d ms' % (str(100.0) + '%', total, total, int((time.time() - start) * 1000)), end='\n')

    precisions, recalls = metrics(pred_data)
    print("\n========================================\n")
    for i in range(config.num_classes):
        print("class {} precision is {:.2f}%, recall is {:.2f}%".format(i, precisions[i] * 100, recalls[i] * 100))


if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Yolov3 evaluation')
    parser.add_argument("--device_id", type=int, default=0, help="Device id, default is 0.")
    parser.add_argument("--mindrecord_dir", type=str, default="./Mindrecord_eval",
                        help="Mindrecord directory. If the mindrecord_dir is empty, it wil generate mindrecord file by"
                             "image_dir and anno_path. Note if mindrecord_dir isn't empty, it will use mindrecord_dir "
                             "rather than image_dir and anno_path. Default is ./Mindrecord_eval")
    parser.add_argument("--image_dir", type=str, default="", help="Dataset directory, "
                                                                  "the absolute image path is joined by the image_dir "
                                                                  "and the relative path in anno_path.")
    parser.add_argument("--anno_path", type=str, default="", help="Annotation path.")
    parser.add_argument("--ckpt_path", type=str, required=True, help="Checkpoint path.")
    args_opt = parser.parse_args()

    context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", device_id=args_opt.device_id)
    context.set_context(enable_task_sink=True, enable_loop_sink=True, enable_mem_reuse=True,
                        enable_auto_mixed_precision=False)

    # It will generate mindrecord file in args_opt.mindrecord_dir,
    # and the file name is yolo.mindrecord0, 1, ... file_num.
    if not os.path.isdir(args_opt.mindrecord_dir):
        os.makedirs(args_opt.mindrecord_dir)

    prefix = "yolo.mindrecord"
    mindrecord_file = os.path.join(args_opt.mindrecord_dir, prefix + "0")
    if not os.path.exists(mindrecord_file):
        if os.path.isdir(args_opt.image_dir) and os.path.exists(args_opt.anno_path):
            print("Create Mindrecord")
            data_to_mindrecord_byte_image(args_opt.image_dir,
                                          args_opt.anno_path,
                                          args_opt.mindrecord_dir,
                                          prefix=prefix,
                                          file_num=8)
            print("Create Mindrecord Done, at {}".format(args_opt.mindrecord_dir))
        else:
            print("image_dir or anno_path not exits")
    print("Start Eval!")
    yolo_eval(mindrecord_file, args_opt.ckpt_path)
```
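The per-class numbers printed at the end of yolo_eval come from util.metrics(), where a predicted box counts as correct when its IoU with a same-class ground-truth box is at least 0.5. A simplified sketch of the final step of that counting:

```python
def precision_recall(count_correct, count_pred, count_ground):
    # Sketch of the last step in util.metrics(): per-class precision and recall.
    precision = count_correct / count_pred    # correct detections / all detections of that class
    recall = count_correct / count_ground     # correct detections / all ground-truth boxes of that class
    return precision, recall
```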
example/yolov3_coco2017/run_distribute_train.sh

```diff
@@ -14,17 +14,26 @@
 # limitations under the License.
 # ============================================================================

 echo "=============================================================================================================="
 echo "Please run the scipt as: "
-echo "sh run_distribute_train.sh DEVICE_NUM EPOCH_SIZE IMAGE_DIR ANNO_PATH MINDSPORE_HCCL_CONFIG_PATH"
-echo "for example: sh run_distribute_train.sh 8 100 ./dataset/coco/train2017 ./dataset/train.txt ./hccl.json"
-echo "After running the scipt, the network runs in the background. The log will be generated in LOGx/log.txt"
+echo "sh run_distribute_train.sh DEVICE_NUM EPOCH_SIZE MINDRECORD_DIR IMAGE_DIR ANNO_PATH MINDSPORE_HCCL_CONFIG_PATH"
+echo "for example: sh run_distribute_train.sh 8 100 /data/Mindrecord_train /data /data/train.txt /data/hccl.json"
+echo "It is better to use absolute path."
 echo "=============================================================================================================="

-export RANK_SIZE=$1
 EPOCH_SIZE=$2
-IMAGE_DIR=$3
-ANNO_PATH=$4
-export MINDSPORE_HCCL_CONFIG_PATH=$5
+MINDRECORD_DIR=$3
+IMAGE_DIR=$4
+ANNO_PATH=$5
+
+# Before start distribute train, first create mindrecord files.
+python train.py --only_create_dataset=1 --mindrecord_dir=$MINDRECORD_DIR --image_dir=$IMAGE_DIR \
+--anno_path=$ANNO_PATH
+
+echo "After running the scipt, the network runs in the background. The log will be generated in LOGx/log.txt"
+
+export MINDSPORE_HCCL_CONFIG_PATH=$6
+export RANK_SIZE=$1

 for((i=0;i<RANK_SIZE;i++))
 do
@@ -40,6 +49,7 @@ do
         --distribute=1 \
         --device_num=$RANK_SIZE \
         --device_id=$DEVICE_ID \
+        --mindrecord_dir=$MINDRECORD_DIR \
         --image_dir=$IMAGE_DIR \
         --epoch_size=$EPOCH_SIZE \
         --anno_path=$ANNO_PATH > log.txt 2>&1 &
```
example/yolov3_coco2017/run_eval.sh (new file, 0 → 100644)

```bash
#!/bin/bash
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

echo "=============================================================================================================="
echo "Please run the scipt as: "
echo "sh run_eval.sh DEVICE_ID CKPT_PATH MINDRECORD_DIR IMAGE_DIR ANNO_PATH"
echo "for example: sh run_eval.sh 0 yolo.ckpt ./Mindrecord_eval ./dataset ./dataset/eval.txt"
echo "=============================================================================================================="

python eval.py --device_id=$1 --ckpt_path=$2 --mindrecord_dir=$3 --image_dir=$4 --anno_path=$5
```
example/yolov3_coco2017/run_standalone_train.sh

```diff
@@ -14,8 +14,10 @@
 # limitations under the License.
 # ============================================================================

 echo "=============================================================================================================="
 echo "Please run the scipt as: "
-echo "sh run_standalone_train.sh DEVICE_ID EPOCH_SIZE IMAGE_DIR ANNO_PATH"
-echo "for example: sh run_standalone_train.sh 0 50 ./dataset/coco/train2017 ./dataset/train.txt"
+echo "sh run_standalone_train.sh DEVICE_ID EPOCH_SIZE MINDRECORD_DIR IMAGE_DIR ANNO_PATH"
+echo "for example: sh run_standalone_train.sh 0 50 ./Mindrecord_train ./dataset ./dataset/train.txt"
 echo "=============================================================================================================="

-python train.py --device_id=$1 --epoch_size=$2 --image_dir=$3 --anno_path=$4
+python train.py --device_id=$1 --epoch_size=$2 --mindrecord_dir=$3 --image_dir=$4 --anno_path=$5
```
example/yolov3_coco2017/train.py

```diff
@@ -16,26 +16,30 @@
 """
 ######################## train YOLOv3 example ########################
 train YOLOv3 and get network model files(.ckpt) :
-python train.py --image_dir dataset/coco/coco/train2017 --anno_path dataset/coco/train_coco.txt
+python train.py --image_dir /data --anno_path /data/coco/train_coco.txt --mindrecord_dir=/data/Mindrecord_train

+If the mindrecord_dir is empty, it wil generate mindrecord file by image_dir and anno_path.
+Note if mindrecord_dir isn't empty, it will use mindrecord_dir rather than image_dir and anno_path.
 """

+import os
 import argparse
 import numpy as np
 import mindspore.nn as nn
 from mindspore import context, Tensor
-from mindspore.common.initializer import initializer
 from mindspore.communication.management import init
 from mindspore.train.callback import CheckpointConfig, ModelCheckpoint, LossMonitor, TimeMonitor
 from mindspore.train import Model, ParallelMode
 from mindspore.train.serialization import load_checkpoint, load_param_into_net
+from mindspore.common.initializer import initializer
 from mindspore.model_zoo.yolov3 import yolov3_resnet18, YoloWithLossCell, TrainingWrapper
-from dataset import create_yolo_dataset
+from dataset import create_yolo_dataset, data_to_mindrecord_byte_image
 from config import ConfigYOLOV3ResNet18


 def get_lr(learning_rate, start_step, global_step, decay_step, decay_rate, steps=False):
-    """Set learning rate"""
+    """Set learning rate."""
     lr_each_step = []
     lr = learning_rate
     for i in range(global_step):
@@ -57,7 +61,9 @@ def init_net_param(net, init='ones'):

 if __name__ == '__main__':
-    parser = argparse.ArgumentParser(description="YOLOv3")
+    parser = argparse.ArgumentParser(description="YOLOv3 train")
+    parser.add_argument("--only_create_dataset", type=bool, default=False, help="If set it true, only create "
+                                                                                "Mindrecord, default is false.")
     parser.add_argument("--distribute", type=bool, default=False, help="Run distribute, default is false.")
     parser.add_argument("--device_id", type=int, default=0, help="Device id, default is 0.")
     parser.add_argument("--device_num", type=int, default=1, help="Use device nums, default is 1.")
@@ -67,12 +73,19 @@ if __name__ == '__main__':
     parser.add_argument("--checkpoint_path", type=str, default="", help="Checkpoint file path")
     parser.add_argument("--save_checkpoint_epochs", type=int, default=5, help="Save checkpoint epochs, default is 5.")
     parser.add_argument("--loss_scale", type=int, default=1024, help="Loss scale, default is 1024.")
-    parser.add_argument("--image_dir", type=str, required=True, help="Dataset image dir.")
-    parser.add_argument("--anno_path", type=str, required=True, help="Dataset anno path.")
+    parser.add_argument("--mindrecord_dir", type=str, default="./Mindrecord_train",
+                        help="Mindrecord directory. If the mindrecord_dir is empty, it wil generate mindrecord file by"
+                             "image_dir and anno_path. Note if mindrecord_dir isn't empty, it will use mindrecord_dir "
+                             "rather than image_dir and anno_path. Default is ./Mindrecord_train")
+    parser.add_argument("--image_dir", type=str, default="", help="Dataset directory, "
+                                                                  "the absolute image path is joined by the image_dir "
+                                                                  "and the relative path in anno_path")
+    parser.add_argument("--anno_path", type=str, default="", help="Annotation path.")
     args_opt = parser.parse_args()

     context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", device_id=args_opt.device_id)
-    context.set_context(enable_task_sink=True, enable_loop_sink=True, enable_mem_reuse=True)
+    context.set_context(enable_task_sink=True, enable_loop_sink=True, enable_mem_reuse=True,
+                        enable_auto_mixed_precision=False)
     if args_opt.distribute:
         device_num = args_opt.device_num
         context.reset_auto_parallel_context()
@@ -80,36 +93,65 @@ if __name__ == '__main__':
         context.set_auto_parallel_context(parallel_mode=ParallelMode.DATA_PARALLEL, mirror_mean=True,
                                           device_num=device_num)
         init()
-        rank = args_opt.device_id
+        rank = args_opt.device_id % device_num
     else:
         context.set_context(enable_hccl=False)
         rank = 0
         device_num = 1

-    loss_scale = float(args_opt.loss_scale)
-    dataset = create_yolo_dataset(args_opt.image_dir, args_opt.anno_path, repeat_num=args_opt.epoch_size,
-                                  batch_size=args_opt.batch_size, device_num=device_num, rank=rank)
-    dataset_size = dataset.get_dataset_size()
-    net = yolov3_resnet18(ConfigYOLOV3ResNet18())
-    net = YoloWithLossCell(net, ConfigYOLOV3ResNet18())
-    init_net_param(net, "XavierUniform")
-
-    # checkpoint
-    ckpt_config = CheckpointConfig(save_checkpoint_steps=dataset_size * args_opt.save_checkpoint_epochs)
-    ckpoint_cb = ModelCheckpoint(prefix="yolov3", directory=None, config=ckpt_config)
-
-    if args_opt.checkpoint_path != "":
-        param_dict = load_checkpoint(args_opt.checkpoint_path)
-        load_param_into_net(net, param_dict)
-
-    lr = Tensor(get_lr(learning_rate=0.001, start_step=0, global_step=args_opt.epoch_size * dataset_size,
-                       decay_step=1000, decay_rate=0.95))
-    opt = nn.Adam(filter(lambda x: x.requires_grad, net.get_parameters()), lr, loss_scale=loss_scale)
-    net = TrainingWrapper(net, opt, loss_scale)
-    callback = [TimeMonitor(data_size=dataset_size), LossMonitor(), ckpoint_cb]
-
-    model = Model(net)
-    dataset_sink_mode = False
-    if args_opt.mode == "graph":
-        dataset_sink_mode = True
-    print("Start train YOLOv3.")
-    model.train(args_opt.epoch_size, dataset, callbacks=callback, dataset_sink_mode=dataset_sink_mode)
+    print("Start create dataset!")
+
+    # It will generate mindrecord file in args_opt.mindrecord_dir,
+    # and the file name is yolo.mindrecord0, 1, ... file_num.
+    if not os.path.isdir(args_opt.mindrecord_dir):
+        os.makedirs(args_opt.mindrecord_dir)

+    prefix = "yolo.mindrecord"
+    mindrecord_file = os.path.join(args_opt.mindrecord_dir, prefix + "0")
+    if not os.path.exists(mindrecord_file):
+        if os.path.isdir(args_opt.image_dir) and os.path.exists(args_opt.anno_path):
+            print("Create Mindrecord.")
+            data_to_mindrecord_byte_image(args_opt.image_dir,
+                                          args_opt.anno_path,
+                                          args_opt.mindrecord_dir,
+                                          prefix=prefix,
+                                          file_num=8)
+            print("Create Mindrecord Done, at {}".format(args_opt.mindrecord_dir))
+        else:
+            print("image_dir or anno_path not exits.")
+
+    if not args_opt.only_create_dataset:
+        loss_scale = float(args_opt.loss_scale)
+
+        # When create MindDataset, using the fitst mindrecord file, such as yolo.mindrecord0.
+        dataset = create_yolo_dataset(mindrecord_file, repeat_num=args_opt.epoch_size,
+                                      batch_size=args_opt.batch_size, device_num=device_num, rank=rank)
+        dataset_size = dataset.get_dataset_size()
+        print("Create dataset done!")
+
+        net = yolov3_resnet18(ConfigYOLOV3ResNet18())
+        net = YoloWithLossCell(net, ConfigYOLOV3ResNet18())
+        init_net_param(net, "XavierUniform")
+
+        # checkpoint
+        ckpt_config = CheckpointConfig(save_checkpoint_steps=dataset_size * args_opt.save_checkpoint_epochs)
+        ckpoint_cb = ModelCheckpoint(prefix="yolov3", directory=None, config=ckpt_config)
+
+        lr = Tensor(get_lr(learning_rate=0.001, start_step=0, global_step=args_opt.epoch_size * dataset_size,
+                           decay_step=1000, decay_rate=0.95))
+        opt = nn.Adam(filter(lambda x: x.requires_grad, net.get_parameters()), lr, loss_scale=loss_scale)
+        net = TrainingWrapper(net, opt, loss_scale)
+
+        if args_opt.checkpoint_path != "":
+            param_dict = load_checkpoint(args_opt.checkpoint_path)
+            load_param_into_net(net, param_dict)
+
+        callback = [TimeMonitor(data_size=dataset_size), LossMonitor(), ckpoint_cb]
+
+        model = Model(net)
+        dataset_sink_mode = False
+        if args_opt.mode == "graph":
+            print("In graph mode, one epoch return a loss.")
+            dataset_sink_mode = True
+        print("Start train YOLOv3, the first epoch will be slower because of the graph compilation.")
+        model.train(args_opt.epoch_size, dataset, callbacks=callback, dataset_sink_mode=dataset_sink_mode)
```
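train.py builds its learning-rate schedule with get_lr(learning_rate=0.001, ..., decay_step=1000, decay_rate=0.95), but the body of get_lr is outside this diff. The following is only an assumed illustration of a staircase exponential decay consistent with those arguments, not the actual implementation:

```python
import numpy as np

def staircase_decay(learning_rate, global_step, decay_step, decay_rate):
    # Hypothetical sketch: decay the base rate by decay_rate once every decay_step steps.
    return np.array([learning_rate * decay_rate ** (i // decay_step) for i in range(global_step)], np.float32)

lr = staircase_decay(0.001, global_step=10 * 1000, decay_step=1000, decay_rate=0.95)
print(lr[0], lr[-1])   # 0.001 at step 0, roughly 0.00063 after nine full decay intervals
```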
example/yolov3_coco2017/util.py (new file, 0 → 100644)

```python
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""metrics utils"""

import numpy as np
from config import ConfigYOLOV3ResNet18


def calc_iou(bbox_pred, bbox_ground):
    """Calculate iou of predicted bbox and ground truth."""
    x1 = bbox_pred[0]
    y1 = bbox_pred[1]
    width1 = bbox_pred[2] - bbox_pred[0]
    height1 = bbox_pred[3] - bbox_pred[1]

    x2 = bbox_ground[0]
    y2 = bbox_ground[1]
    width2 = bbox_ground[2] - bbox_ground[0]
    height2 = bbox_ground[3] - bbox_ground[1]

    endx = max(x1 + width1, x2 + width2)
    startx = min(x1, x2)
    width = width1 + width2 - (endx - startx)

    endy = max(y1 + height1, y2 + height2)
    starty = min(y1, y2)
    height = height1 + height2 - (endy - starty)

    if width <= 0 or height <= 0:
        iou = 0
    else:
        area = width * height
        area1 = width1 * height1
        area2 = width2 * height2
        iou = area * 1. / (area1 + area2 - area)

    return iou


def apply_nms(all_boxes, all_scores, thres, max_boxes):
    """Apply NMS to bboxes."""
    x1 = all_boxes[:, 0]
    y1 = all_boxes[:, 1]
    x2 = all_boxes[:, 2]
    y2 = all_boxes[:, 3]
    areas = (x2 - x1 + 1) * (y2 - y1 + 1)

    order = all_scores.argsort()[::-1]
    keep = []

    while order.size > 0:
        i = order[0]
        keep.append(i)

        if len(keep) >= max_boxes:
            break

        xx1 = np.maximum(x1[i], x1[order[1:]])
        yy1 = np.maximum(y1[i], y1[order[1:]])
        xx2 = np.minimum(x2[i], x2[order[1:]])
        yy2 = np.minimum(y2[i], y2[order[1:]])

        w = np.maximum(0.0, xx2 - xx1 + 1)
        h = np.maximum(0.0, yy2 - yy1 + 1)
        inter = w * h

        ovr = inter / (areas[i] + areas[order[1:]] - inter)

        inds = np.where(ovr <= thres)[0]

        order = order[inds + 1]
    return keep


def metrics(pred_data):
    """Calculate precision and recall of predicted bboxes."""
    config = ConfigYOLOV3ResNet18()
    num_classes = config.num_classes
    count_corrects = [1e-6 for _ in range(num_classes)]
    count_grounds = [1e-6 for _ in range(num_classes)]
    count_preds = [1e-6 for _ in range(num_classes)]

    for i, sample in enumerate(pred_data):
        gt_anno = sample["annotation"]
        box_scores = sample['box_scores']
        boxes = sample['boxes']
        mask = box_scores >= config.obj_threshold
        boxes_ = []
        scores_ = []
        classes_ = []
        max_boxes = config.nms_max_num
        for c in range(num_classes):
            class_boxes = np.reshape(boxes, [-1, 4])[np.reshape(mask[:, c], [-1])]
            class_box_scores = np.reshape(box_scores[:, c], [-1])[np.reshape(mask[:, c], [-1])]
            nms_index = apply_nms(class_boxes, class_box_scores, config.nms_threshold, max_boxes)
            class_boxes = class_boxes[nms_index]
            class_box_scores = class_box_scores[nms_index]
            classes = np.ones_like(class_box_scores, 'int32') * c
            boxes_.append(class_boxes)
            scores_.append(class_box_scores)
            classes_.append(classes)

        boxes = np.concatenate(boxes_, axis=0)
        classes = np.concatenate(classes_, axis=0)

        # metric
        count_correct = [1e-6 for _ in range(num_classes)]
        count_ground = [1e-6 for _ in range(num_classes)]
        count_pred = [1e-6 for _ in range(num_classes)]

        for anno in gt_anno:
            count_ground[anno[4]] += 1

        for box_index, box in enumerate(boxes):
            bbox_pred = [box[1], box[0], box[3], box[2]]
            count_pred[classes[box_index]] += 1

            for anno in gt_anno:
                class_ground = anno[4]

                if classes[box_index] == class_ground:
                    iou = calc_iou(bbox_pred, anno)
                    if iou >= 0.5:
                        count_correct[class_ground] += 1
                        break

        count_corrects = [count_corrects[i] + count_correct[i] for i in range(num_classes)]
        count_preds = [count_preds[i] + count_pred[i] for i in range(num_classes)]
        count_grounds = [count_grounds[i] + count_ground[i] for i in range(num_classes)]

    precision = np.array([count_corrects[ix] / count_preds[ix] for ix in range(num_classes)])
    recall = np.array([count_corrects[ix] / count_grounds[ix] for ix in range(num_classes)])
    return precision, recall
```
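A small self-contained check of apply_nms, with hypothetical boxes (assumes example/yolov3_coco2017 is on the Python path):

```python
import numpy as np
from util import apply_nms

# Two heavily overlapping boxes plus one far away; boxes are [x1, y1, x2, y2].
boxes = np.array([[10, 10, 60, 60],
                  [12, 12, 62, 62],
                  [100, 100, 150, 150]], np.float32)
scores = np.array([0.9, 0.8, 0.7], np.float32)

keep = apply_nms(boxes, scores, thres=0.4, max_boxes=50)
print(keep)  # [0, 2] -- the second box is suppressed by the higher-scoring first box
```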
mindspore/ccsrc/kernel/tbe/tbe_adapter.cc

```diff
@@ -34,6 +34,7 @@ static std::map<string, string> tbe_func_adapter_map = {
   {"tensor_add", "add"},
   {"reduce_mean", "reduce_mean_d"},
   {"reduce_max", "reduce_max_d"},
+  {"reduce_min", "reduce_min_d"},
   {"conv2d_backprop_filter", "conv2d_backprop_filter_d"},
   {"conv2d_backprop_input", "conv2d_backprop_input_d"},
   {"top_kv2", "top_k"},
```
mindspore/model_zoo/yolov3.py

```diff
@@ -15,6 +15,7 @@
 """YOLOv3 based on ResNet18."""

+import numpy as np
 import mindspore as ms
 import mindspore.nn as nn
 from mindspore import context, Tensor
@@ -31,19 +32,14 @@ def weight_variable():
     return TruncatedNormal(0.02)


-class _conv_with_pad(nn.Cell):
+class _conv2d(nn.Cell):
     """Create Conv2D with padding."""
     def __init__(self, in_channels, out_channels, kernel_size, stride=1):
-        super(_conv_with_pad, self).__init__()
-        total_pad = kernel_size - 1
-        pad_begin = total_pad // 2
-        pad_end = total_pad - pad_begin
-        self.pad = P.Pad(((0, 0), (0, 0), (pad_begin, pad_end), (pad_begin, pad_end)))
+        super(_conv2d, self).__init__()
         self.conv = nn.Conv2d(in_channels, out_channels,
-                              kernel_size=kernel_size, stride=stride, padding=0, pad_mode='valid',
+                              kernel_size=kernel_size, stride=stride, padding=0, pad_mode='same',
                               weight_init=weight_variable())

     def construct(self, x):
-        x = self.pad(x)
         x = self.conv(x)
         return x
@@ -101,15 +97,15 @@ class BasicBlock(nn.Cell):
                  momentum=0.99):
         super(BasicBlock, self).__init__()

-        self.conv1 = _conv_with_pad(in_channels, out_channels, 3, stride=stride)
+        self.conv1 = _conv2d(in_channels, out_channels, 3, stride=stride)
         self.bn1 = _fused_bn(out_channels, momentum=momentum)
-        self.conv2 = _conv_with_pad(out_channels, out_channels, 3)
+        self.conv2 = _conv2d(out_channels, out_channels, 3)
         self.bn2 = _fused_bn(out_channels, momentum=momentum)
         self.relu = P.ReLU()
         self.down_sample_layer = None
         self.downsample = (in_channels != out_channels)
         if self.downsample:
-            self.down_sample_layer = _conv_with_pad(in_channels, out_channels, 1, stride=stride)
+            self.down_sample_layer = _conv2d(in_channels, out_channels, 1, stride=stride)
         self.add = P.TensorAdd()

     def construct(self, x):
@@ -166,7 +162,7 @@ class ResNet(nn.Cell):
             raise ValueError("the length of "
                              "layer_num, inchannel, outchannel list must be 4!")

-        self.conv1 = _conv_with_pad(3, 64, 7, stride=2)
+        self.conv1 = _conv2d(3, 64, 7, stride=2)
         self.bn1 = _fused_bn(64)
         self.relu = P.ReLU()
         self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, pad_mode='same')
@@ -452,7 +448,7 @@ class DetectionBlock(nn.Cell):
         if self.training:
             return grid, prediction, box_xy, box_wh

-        return self.concat((box_xy, box_wh, box_confidence, box_probs))
+        return box_xy, box_wh, box_confidence, box_probs


 class Iou(nn.Cell):
@@ -675,3 +671,78 @@ class TrainingWrapper(nn.Cell):
         # apply grad reducer on grads
         grads = self.grad_reducer(grads)
         return F.depend(loss, self.optimizer(grads))
+
+
+class YoloBoxScores(nn.Cell):
+    """
+    Calculate the boxes of the original picture size and the score of each box.
+
+    Args:
+        config (Class): YOLOv3 config.
+
+    Returns:
+        Tensor, the boxes of the original picture size.
+        Tensor, the score of each box.
+    """
+    def __init__(self, config):
+        super(YoloBoxScores, self).__init__()
+        self.input_shape = Tensor(np.array(config.img_shape), ms.float32)
+        self.num_classes = config.num_classes
+
+    def construct(self, box_xy, box_wh, box_confidence, box_probs, image_shape):
+        batch_size = F.shape(box_xy)[0]
+        x = box_xy[:, :, :, :, 0:1]
+        y = box_xy[:, :, :, :, 1:2]
+        box_yx = P.Concat(-1)((y, x))
+        w = box_wh[:, :, :, :, 0:1]
+        h = box_wh[:, :, :, :, 1:2]
+        box_hw = P.Concat(-1)((h, w))
+
+        new_shape = P.Round()(image_shape * P.ReduceMin()(self.input_shape / image_shape))
+        offset = (self.input_shape - new_shape) / 2.0 / self.input_shape
+        scale = self.input_shape / new_shape
+        box_yx = (box_yx - offset) * scale
+        box_hw = box_hw * scale
+
+        box_min = box_yx - box_hw / 2.0
+        box_max = box_yx + box_hw / 2.0
+        boxes = P.Concat(-1)((box_min[:, :, :, :, 0:1],
+                              box_min[:, :, :, :, 1:2],
+                              box_max[:, :, :, :, 0:1],
+                              box_max[:, :, :, :, 1:2]))
+        image_scale = P.Tile()(image_shape, (1, 2))
+        boxes = boxes * image_scale
+        boxes = F.reshape(boxes, (batch_size, -1, 4))
+        boxes_scores = box_confidence * box_probs
+        boxes_scores = F.reshape(boxes_scores, (batch_size, -1, self.num_classes))
+        return boxes, boxes_scores
+
+
+class YoloWithEval(nn.Cell):
+    """
+    Encapsulation class of YOLOv3 evaluation.
+
+    Args:
+        network (Cell): The training network. Note that loss function and optimizer must not be added.
+        config (Class): YOLOv3 config.
+
+    Returns:
+        Tensor, the boxes of the original picture size.
+        Tensor, the score of each box.
+        Tensor, the original picture size.
+    """
+    def __init__(self, network, config):
+        super(YoloWithEval, self).__init__()
+        self.yolo_network = network
+        self.box_score_0 = YoloBoxScores(config)
+        self.box_score_1 = YoloBoxScores(config)
+        self.box_score_2 = YoloBoxScores(config)
+
+    def construct(self, x, image_shape):
+        yolo_output = self.yolo_network(x)
+        boxes_0, boxes_scores_0 = self.box_score_0(*yolo_output[0], image_shape)
+        boxes_1, boxes_scores_1 = self.box_score_1(*yolo_output[1], image_shape)
+        boxes_2, boxes_scores_2 = self.box_score_2(*yolo_output[2], image_shape)
+        boxes = P.Concat(1)((boxes_0, boxes_1, boxes_2))
+        boxes_scores = P.Concat(1)((boxes_scores_0, boxes_scores_1, boxes_scores_2))
+        return boxes, boxes_scores, image_shape
```
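To make the geometry in YoloBoxScores.construct easier to follow, here is the same letterbox inverse mapping written in NumPy with assumed shapes (a 480x640 original picture fed through the 352x640 network input); this is an illustration, not code from the commit:

```python
import numpy as np

input_shape = np.array([352.0, 640.0])   # network input (h, w), from ConfigYOLOV3ResNet18.img_shape
image_shape = np.array([480.0, 640.0])   # original picture (h, w), as reported by _infer_data

new_shape = np.round(image_shape * np.min(input_shape / image_shape))  # resized picture inside the padding
offset = (input_shape - new_shape) / 2.0 / input_shape                 # normalized padding on each side
scale = input_shape / new_shape

box_yx = np.array([0.5, 0.5])            # a box center predicted in the padded, normalized frame
box_yx_orig = (box_yx - offset) * scale  # same center, normalized to the original picture
print(box_yx_orig * image_shape)         # about [240, 320]: the center of the original picture
```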
mindspore/ops/_op_impl/tbe/__init__.py

```diff
@@ -85,7 +85,9 @@ from .logical_and import _logical_and_tbe
 from .logical_not import _logical_not_tbe
 from .logical_or import _logical_or_tbe
 from .reduce_max import _reduce_max_tbe
+from .reduce_min import _reduce_min_tbe
 from .reduce_sum import _reduce_sum_tbe
+from .round import _round_tbe
 from .tanh import _tanh_tbe
 from .tanh_grad import _tanh_grad_tbe
 from .softmax import _softmax_tbe
```
mindspore/ops/_op_impl/tbe/reduce_min.py (new file, 0 → 100644)

```python
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""ReduceMin op"""
from mindspore.ops.op_info_register import op_info_register


@op_info_register("""{
    "op_name": "ReduceMin",
    "imply_type": "TBE",
    "fusion_type": "OPAQUE",
    "async_flag": false,
    "binfile_name": "reduce_min_d.so",
    "compute_cost": 10,
    "kernel_name": "reduce_min_d",
    "partial_flag": true,
    "attr": [
        {
            "name": "axis",
            "param_type": "required",
            "type": "listInt",
            "value": "all"
        },
        {
            "name": "keep_dims",
            "param_type": "required",
            "type": "bool",
            "value": "all"
        }
    ],
    "inputs": [
        {
            "index": 0,
            "dtype": [
                "float16", "float16", "float", "float", "int8", "int8", "uint8", "uint8"
            ],
            "format": [
                "DefaultFormat", "FracZ", "DefaultFormat", "FracZ", "DefaultFormat", "FracZ", "DefaultFormat", "FracZ"
            ],
            "name": "x",
            "need_compile": false,
            "param_type": "required",
            "shape": "all"
        }
    ],
    "outputs": [
        {
            "index": 0,
            "dtype": [
                "float16", "float16", "float", "float", "int8", "int8", "uint8", "uint8"
            ],
            "format": [
                "DefaultFormat", "FracZ", "DefaultFormat", "FracZ", "DefaultFormat", "FracZ", "DefaultFormat", "FracZ"
            ],
            "name": "y",
            "need_compile": false,
            "param_type": "required",
            "shape": "all"
        }
    ]
}""")
def _reduce_min_tbe():
    """ReduceMin TBE register"""
    return
```
mindspore/ops/_op_impl/tbe/round.py (new file, 0 → 100644)

```python
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Round op"""
from mindspore.ops.op_info_register import op_info_register


@op_info_register("""{
    "op_name": "Round",
    "imply_type": "TBE",
    "fusion_type": "ELEMWISE",
    "async_flag": false,
    "binfile_name": "round.so",
    "compute_cost": 10,
    "kernel_name": "round",
    "partial_flag": true,
    "attr": [
    ],
    "inputs": [
        {
            "index": 0,
            "dtype": [
                "float16", "float16", "float16", "float", "float", "float"
            ],
            "format": [
                "DefaultFormat", "NC1HWC0", "FracZ", "DefaultFormat", "NC1HWC0", "FracZ"
            ],
            "name": "x",
            "need_compile": false,
            "param_type": "required",
            "shape": "all"
        }
    ],
    "outputs": [
        {
            "index": 0,
            "dtype": [
                "float16", "float16", "float16", "float", "float", "float"
            ],
            "format": [
                "DefaultFormat", "NC1HWC0", "FracZ", "DefaultFormat", "NC1HWC0", "FracZ"
            ],
            "name": "y",
            "need_compile": false,
            "param_type": "required",
            "shape": "all"
        }
    ]
}""")
def _round_tbe():
    """Round TBE register"""
    return
```
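The two TBE registrations above enable exactly the primitives YoloBoxScores calls when undoing the letterbox. A rough usage sketch (assumes an Ascend context, since these are TBE kernels):

```python
import numpy as np
from mindspore import Tensor
from mindspore.ops import operations as P

x = Tensor(np.array([[1.4, 2.6], [0.7, 3.2]], np.float32))
print(P.Round()(x))       # element-wise rounding, backed by round.so
print(P.ReduceMin()(x))   # minimum over all elements by default, backed by reduce_min_d.so
```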