Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
magicwindyyd
mindspore
提交
97f08e74
M
mindspore
项目概览
magicwindyyd
/
mindspore
与 Fork 源项目一致
Fork自
MindSpore / mindspore
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
M
mindspore
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
97f08e74
编写于
8月 18, 2020
作者:
D
danish
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
nms_sorting fix
lint py fix 2 nms_py_file test value fix lint fix
上级
e60c0b60
变更
4
隐藏空白更改
内联
并排
Showing
4 changed file
with
65 addition
and
115 deletion
+65
-115
mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/nms_with_mask_impl.cu
...ckend/kernel_compiler/gpu/cuda_impl/nms_with_mask_impl.cu
+37
-51
mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/nms_with_mask_impl.cuh
...kend/kernel_compiler/gpu/cuda_impl/nms_with_mask_impl.cuh
+7
-4
mindspore/ccsrc/backend/kernel_compiler/gpu/math/nms_with_mask_gpu_kernel.h
...ckend/kernel_compiler/gpu/math/nms_with_mask_gpu_kernel.h
+18
-12
tests/st/ops/gpu/test_nms_with_mask_op.py
tests/st/ops/gpu/test_nms_with_mask_op.py
+3
-48
未找到文件。
mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/nms_with_mask_impl.cu
浏览文件 @
97f08e74
...
@@ -18,7 +18,7 @@
...
@@ -18,7 +18,7 @@
#include <limits>
#include <limits>
#include <algorithm>
#include <algorithm>
int
RoundUpPower2M
(
int
v
)
{
int
NMSRoundUpPower2
(
int
v
)
{
v
--
;
v
--
;
v
|=
v
>>
1
;
v
|=
v
>>
1
;
v
|=
v
>>
2
;
v
|=
v
>>
2
;
...
@@ -30,12 +30,22 @@ int RoundUpPower2M(int v) {
...
@@ -30,12 +30,22 @@ int RoundUpPower2M(int v) {
}
}
template
<
typename
T
>
template
<
typename
T
>
__inline__
__device__
void
Swap
M
(
T
*
lhs
,
T
*
rhs
)
{
__inline__
__device__
void
Swap
(
T
*
lhs
,
T
*
rhs
)
{
T
tmp
=
lhs
[
0
];
T
tmp
=
lhs
[
0
];
lhs
[
0
]
=
rhs
[
0
];
lhs
[
0
]
=
rhs
[
0
];
rhs
[
0
]
=
tmp
;
rhs
[
0
]
=
tmp
;
}
}
template
<
typename
T
>
__global__
void
PopulateOutput
(
T
*
data_in
,
T
*
data_out
,
int
*
index_buff
,
const
int
num
,
int
box_size_
)
{
for
(
int
box_num
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
box_num
<
num
;
box_num
+=
blockDim
.
x
*
gridDim
.
x
)
{
int
correct_index
=
index_buff
[(
num
-
1
)
-
box_num
];
// flip the array around
for
(
int
x
=
0
;
x
<
5
;
x
++
)
{
data_out
[(
box_num
*
box_size_
)
+
x
]
=
data_in
[(
correct_index
*
box_size_
)
+
x
];
}
}
}
template
<
typename
T
>
template
<
typename
T
>
__inline__
__device__
bool
IOUDecision
(
T
*
output
,
int
box_A_ix
,
int
box_B_ix
,
int
box_A_start
,
int
box_B_start
,
T
*
area
,
__inline__
__device__
bool
IOUDecision
(
T
*
output
,
int
box_A_ix
,
int
box_B_ix
,
int
box_A_start
,
int
box_B_start
,
T
*
area
,
float
IOU_value
)
{
float
IOU_value
)
{
...
@@ -96,38 +106,29 @@ __global__ void FinalPass(const int num, const float IOU_value, T *output, T *ar
...
@@ -96,38 +106,29 @@ __global__ void FinalPass(const int num, const float IOU_value, T *output, T *ar
}
}
}
}
template
<
typename
T
,
typename
S
>
template
<
typename
T
>
__global__
void
BitonicSortByKeyKernelM
(
const
int
outer
,
const
int
inner
,
const
int
ceil_power2
,
S
*
data_in
,
__global__
void
NMS_BitonicSortByKeyKernel
(
const
int
outer
,
const
int
inner
,
const
int
ceil_power2
,
T
*
input
,
S
*
data_out
,
T
*
index_buff
,
S
*
data_buff
,
int
box_size_
)
{
T
*
data_buff
,
int
*
index_buff
,
int
box_size_
)
{
// default: sort with share memory
extern
__shared__
T
share_mem_NMS
[];
T
*
index_arr
=
share_mem_NMS
;
S
*
data_arr
=
reinterpret_cast
<
S
*>
(
index_arr
+
ceil_power2
);
// sort with RAM
if
(
index_buff
!=
nullptr
&&
data_buff
!=
nullptr
)
{
index_arr
=
index_buff
+
blockIdx
.
x
*
ceil_power2
;
data_arr
=
data_buff
+
blockIdx
.
x
*
ceil_power2
;
}
for
(
int
i
=
threadIdx
.
x
;
i
<
ceil_power2
;
i
+=
blockDim
.
x
)
{
for
(
int
i
=
threadIdx
.
x
;
i
<
ceil_power2
;
i
+=
blockDim
.
x
)
{
index_arr
[
i
]
=
(
i
<
inner
)
?
T
(
i
)
:
std
::
numeric_limits
<
T
>::
max
();
data_buff
[
i
]
=
(
i
<
inner
)
?
input
[(
i
*
box_size_
)
+
4
]
:
std
::
numeric_limits
<
T
>::
max
();
// populated directly from input data
index_buff
[
i
]
=
i
;
data_arr
[
i
]
=
(
i
<
inner
)
?
data_in
[(
blockIdx
.
x
*
inner
+
i
)
*
box_size_
+
4
]
:
std
::
numeric_limits
<
S
>::
max
();
}
}
__syncthreads
();
__syncthreads
();
for
(
size_t
i
=
2
;
i
<=
ceil_power2
;
i
<<=
1
)
{
for
(
size_t
i
=
2
;
i
<=
ceil_power2
;
i
<<=
1
)
{
for
(
size_t
j
=
(
i
>>
1
);
j
>
0
;
j
>>=
1
)
{
for
(
size_t
j
=
(
i
>>
1
);
j
>
0
;
j
>>=
1
)
{
for
(
size_t
tid
=
threadIdx
.
x
;
tid
<
ceil_power2
;
tid
+=
blockDim
.
x
)
{
for
(
size_t
tid
=
threadIdx
.
x
;
tid
<
ceil_power2
;
tid
+=
blockDim
.
x
)
{
size_t
tid_comp
=
tid
^
j
;
size_t
tid_comp
=
tid
^
j
;
if
(
tid_comp
>
tid
)
{
if
(
tid_comp
>
tid
)
{
if
((
tid
&
i
)
==
0
)
{
if
((
tid
&
i
)
==
0
)
{
if
(
data_
arr
[
tid
]
>
data_arr
[
tid_comp
])
{
if
(
data_
buff
[
tid
]
>
data_buff
[
tid_comp
])
{
Swap
M
(
&
index_arr
[
tid
],
&
index_arr
[
tid_comp
]);
Swap
(
&
data_buff
[
tid
],
&
data_buff
[
tid_comp
]);
Swap
M
(
&
data_arr
[
tid
],
&
data_arr
[
tid_comp
]);
Swap
(
&
index_buff
[
tid
],
&
index_buff
[
tid_comp
]);
}
}
}
else
{
}
else
{
if
(
data_
arr
[
tid
]
<
data_arr
[
tid_comp
])
{
if
(
data_
buff
[
tid
]
<
data_buff
[
tid_comp
])
{
Swap
M
(
&
index_arr
[
tid
],
&
index_arr
[
tid_comp
]);
Swap
(
&
data_buff
[
tid
],
&
data_buff
[
tid_comp
]);
Swap
M
(
&
data_arr
[
tid
],
&
data_arr
[
tid_comp
]);
Swap
(
&
index_buff
[
tid
],
&
index_buff
[
tid_comp
]);
}
}
}
}
}
}
...
@@ -135,36 +136,21 @@ __global__ void BitonicSortByKeyKernelM(const int outer, const int inner, const
...
@@ -135,36 +136,21 @@ __global__ void BitonicSortByKeyKernelM(const int outer, const int inner, const
__syncthreads
();
__syncthreads
();
}
}
}
}
T
correct_index
;
for
(
size_t
tid
=
threadIdx
.
x
;
tid
<
inner
;
tid
+=
blockDim
.
x
)
{
correct_index
=
index_arr
[(
inner
-
1
)
-
tid
];
// moved data from input to output, correct ordering using sorted index array
for
(
auto
i
:
{
0
,
1
,
2
,
3
,
4
})
{
data_out
[(
blockIdx
.
x
*
inner
+
tid
)
*
box_size_
+
i
]
=
data_in
[(
blockIdx
.
x
*
inner
+
correct_index
)
*
box_size_
+
i
];
}
}
}
}
template
<
typename
T
>
template
<
typename
T
>
void
CalPreprocess
(
const
int
num
,
int
*
sel_idx
,
T
*
area
,
T
*
output
,
int
box_size_
,
cudaStream_t
cuda_stream
)
{
void
CalPreprocess
(
const
int
num
,
int
*
sel_idx
,
T
*
area
,
T
*
input
,
T
*
output
,
int
*
index_buff
,
int
box_size_
,
cudaStream_t
cuda_stream
)
{
PopulateOutput
<<<
GET_BLOCKS
(
num
),
GET_THREADS
,
0
,
cuda_stream
>>>
(
input
,
output
,
index_buff
,
num
,
box_size_
);
Preprocess
<<<
GET_BLOCKS
(
num
),
GET_THREADS
,
0
,
cuda_stream
>>>
(
num
,
sel_idx
,
area
,
output
,
box_size_
);
Preprocess
<<<
GET_BLOCKS
(
num
),
GET_THREADS
,
0
,
cuda_stream
>>>
(
num
,
sel_idx
,
area
,
output
,
box_size_
);
}
}
template
<
typename
T
,
typename
S
>
template
<
typename
T
>
void
BitonicSortByKeyM
(
const
int
&
outer
,
const
int
&
inner
,
S
*
data_in
,
S
*
data_out
,
T
*
index_buff
,
S
*
data_buff
,
void
CalSortInit
(
const
int
&
num
,
T
*
data_in
,
T
*
data_out
,
int
*
index_buff
,
T
*
data_buff
,
int
box_size_
,
int
box_size_
,
cudaStream_t
stream
)
{
cudaStream_t
stream
)
{
int
ceil_power2
=
RoundUpPower2M
(
inner
);
int
ceil_p_2
=
NMSRoundUpPower2
(
num
);
size_t
share_mem
=
ceil_power2
*
(
sizeof
(
T
)
+
sizeof
(
S
));
int
thread
=
std
::
min
(
ceil_p_2
,
GET_THREADS
);
if
(
share_mem
>
SHARED_MEM_PER_BLOCK
)
{
NMS_BitonicSortByKeyKernel
<<<
1
,
thread
,
0
,
stream
>>>
(
1
,
num
,
ceil_p_2
,
data_in
,
data_buff
,
index_buff
,
box_size_
);
share_mem
=
0
;
}
else
{
data_buff
=
nullptr
;
index_buff
=
nullptr
;
}
int
thread
=
std
::
min
(
ceil_power2
,
GET_THREADS
);
BitonicSortByKeyKernelM
<<<
outer
,
thread
,
share_mem
,
stream
>>>
(
outer
,
inner
,
ceil_power2
,
data_in
,
data_out
,
index_buff
,
data_buff
,
box_size_
);
}
}
template
<
typename
T
>
template
<
typename
T
>
...
@@ -180,11 +166,11 @@ void CalFinalPass(const int num, const float IOU_value, T *output, T *area, bool
...
@@ -180,11 +166,11 @@ void CalFinalPass(const int num, const float IOU_value, T *output, T *area, bool
FinalPass
<<<
1
,
1
,
0
,
cuda_stream
>>>
(
num
,
IOU_value
,
output
,
area
,
sel_boxes
,
box_size_
);
FinalPass
<<<
1
,
1
,
0
,
cuda_stream
>>>
(
num
,
IOU_value
,
output
,
area
,
sel_boxes
,
box_size_
);
}
}
template
void
CalPreprocess
<
float
>(
const
int
num
,
int
*
sel_idx
,
float
*
area
,
float
*
output
,
int
box_size_
,
template
void
CalPreprocess
<
float
>(
const
int
num
,
int
*
sel_idx
,
float
*
area
,
float
*
input
,
float
*
output
,
cudaStream_t
cuda_stream
);
int
*
index_buff
,
int
box_size_
,
cudaStream_t
cuda_stream
);
template
void
BitonicSortByKeyM
(
const
int
&
outer
,
const
int
&
inner
,
float
*
data_in
,
float
*
data_out
,
int
*
index
_buff
,
template
void
CalSortInit
<
float
>(
const
int
&
inner
,
float
*
data_in
,
float
*
data_out
,
int
*
index_buff
,
float
*
data
_buff
,
float
*
data_buff
,
int
box_size_
,
cudaStream_t
stream
);
int
box_size_
,
cudaStream_t
stream
);
template
void
CalNMSWithMask
<
float
>(
const
int
num
,
const
float
IOU_value
,
float
*
output
,
float
*
area
,
bool
*
sel_boxes
,
template
void
CalNMSWithMask
<
float
>(
const
int
num
,
const
float
IOU_value
,
float
*
output
,
float
*
area
,
bool
*
sel_boxes
,
int
box_size_
,
cudaStream_t
cuda_stream
);
int
box_size_
,
cudaStream_t
cuda_stream
);
...
...
mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/nms_with_mask_impl.cuh
浏览文件 @
97f08e74
...
@@ -20,18 +20,21 @@
...
@@ -20,18 +20,21 @@
#include "runtime/device/gpu/cuda_common.h"
#include "runtime/device/gpu/cuda_common.h"
template
<
typename
T
>
template
<
typename
T
>
void
CalPreprocess
(
const
int
num
,
int
*
sel_idx
,
T
*
area
,
T
*
output
,
int
box_size_
,
cudaStream_t
cuda_stream
);
void
CalPreprocess
(
const
int
num
,
int
*
sel_idx
,
T
*
area
,
T
*
input
,
T
*
output
,
int
*
index_buff
,
int
box_size_
,
cudaStream_t
cuda_stream
);
template
<
typename
T
>
template
<
typename
T
>
void
CalNMSWithMask
(
const
int
num
,
const
float
IOU_value
,
T
*
output
,
T
*
area
,
bool
*
sel_boxes
,
int
box_size_
,
void
CalNMSWithMask
(
const
int
num
,
const
float
IOU_value
,
T
*
output
,
T
*
area
,
bool
*
sel_boxes
,
int
box_size_
,
cudaStream_t
cuda_stream
);
cudaStream_t
cuda_stream
);
template
<
typename
T
,
typename
S
>
template
<
typename
T
>
void
BitonicSortByKeyM
(
const
int
&
outer
,
const
int
&
inner
,
S
*
data_in
,
S
*
data_out
,
T
*
index_buff
,
S
*
data_buff
,
void
CalSortInit
(
const
int
&
inner
,
T
*
data_in
,
T
*
data_out
,
int
*
index_buff
,
T
*
data_buff
,
int
box_size_
,
int
box_size_
,
cudaStream_t
stream
);
cudaStream_t
stream
);
template
<
typename
T
>
template
<
typename
T
>
void
CalFinalPass
(
const
int
num
,
const
float
IOU_value
,
T
*
output
,
T
*
area
,
bool
*
sel_boxes
,
int
box_size_
,
void
CalFinalPass
(
const
int
num
,
const
float
IOU_value
,
T
*
output
,
T
*
area
,
bool
*
sel_boxes
,
int
box_size_
,
cudaStream_t
cuda_stream
);
cudaStream_t
cuda_stream
);
int
NMSRoundUpPower2
(
int
v
);
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_CUDA_IMPL_NMS_WITH_MASK_IMPL_H_
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_CUDA_IMPL_NMS_WITH_MASK_IMPL_H_
mindspore/ccsrc/backend/kernel_compiler/gpu/math/nms_with_mask_gpu_kernel.h
浏览文件 @
97f08e74
...
@@ -30,7 +30,8 @@ namespace kernel {
...
@@ -30,7 +30,8 @@ namespace kernel {
template
<
typename
T
>
template
<
typename
T
>
class
NMSWithMaskGpuFwdKernel
:
public
GpuKernel
{
class
NMSWithMaskGpuFwdKernel
:
public
GpuKernel
{
public:
public:
NMSWithMaskGpuFwdKernel
()
:
num_input_
(
0
),
iou_value_
(
0.5
),
input_size_
(
0
),
output_size_
(
0
),
workspace_size_
(
0
)
{}
NMSWithMaskGpuFwdKernel
()
:
num_input_
(
0
),
iou_value_
(
0.5
),
input_size_
(
0
),
output_size_
(
0
),
workspace_size_
(
0
),
ceil_power_2
(
0
)
{}
~
NMSWithMaskGpuFwdKernel
()
override
=
default
;
~
NMSWithMaskGpuFwdKernel
()
override
=
default
;
const
std
::
vector
<
size_t
>
&
GetInputSizeList
()
const
override
{
return
input_size_list_
;
}
const
std
::
vector
<
size_t
>
&
GetInputSizeList
()
const
override
{
return
input_size_list_
;
}
...
@@ -40,22 +41,24 @@ class NMSWithMaskGpuFwdKernel : public GpuKernel {
...
@@ -40,22 +41,24 @@ class NMSWithMaskGpuFwdKernel : public GpuKernel {
bool
Launch
(
const
std
::
vector
<
AddressPtr
>
&
inputs
,
const
std
::
vector
<
AddressPtr
>
&
workspace
,
bool
Launch
(
const
std
::
vector
<
AddressPtr
>
&
inputs
,
const
std
::
vector
<
AddressPtr
>
&
workspace
,
const
std
::
vector
<
AddressPtr
>
&
outputs
,
void
*
stream_ptr
)
override
{
const
std
::
vector
<
AddressPtr
>
&
outputs
,
void
*
stream_ptr
)
override
{
T
*
input
=
GetDeviceAddress
<
T
>
(
inputs
,
0
);
T
*
input
=
GetDeviceAddress
<
T
>
(
inputs
,
0
);
T
*
data_buff
=
GetDeviceAddress
<
T
>
(
workspace
,
0
);
// sort buffer
T
*
area
=
GetDeviceAddress
<
T
>
(
workspace
,
0
);
// store area values for all boxes
int
*
index_buff
=
GetDeviceAddress
<
int
>
(
workspace
,
1
);
T
*
data_buff
=
GetDeviceAddress
<
T
>
(
workspace
,
1
);
// sort buffer
T
*
area
=
GetDeviceAddress
<
T
>
(
workspace
,
2
);
// store area values for all boxes
int
*
index_buff
=
GetDeviceAddress
<
int
>
(
workspace
,
2
);
T
*
output
=
GetDeviceAddress
<
T
>
(
outputs
,
0
);
T
*
output
=
GetDeviceAddress
<
T
>
(
outputs
,
0
);
int
*
sel_idx
=
GetDeviceAddress
<
int
>
(
outputs
,
1
);
int
*
sel_idx
=
GetDeviceAddress
<
int
>
(
outputs
,
1
);
bool
*
sel_boxes
=
GetDeviceAddress
<
bool
>
(
outputs
,
2
);
bool
*
sel_boxes
=
GetDeviceAddress
<
bool
>
(
outputs
,
2
);
BitonicSortByKeyM
(
num_input_
,
num_input_
,
input
,
output
,
index_buff
,
data_buff
,
box_size_
,
CalSortInit
(
num_input_
,
input
,
output
,
index_buff
,
data_buff
,
box_size_
,
reinterpret_cast
<
cudaStream_t
>
(
stream_ptr
));
reinterpret_cast
<
cudaStream_t
>
(
stream_ptr
));
CalPreprocess
(
num_input_
,
sel_idx
,
area
,
output
,
box_size_
,
reinterpret_cast
<
cudaStream_t
>
(
stream_ptr
));
CalPreprocess
(
num_input_
,
sel_idx
,
area
,
input
,
output
,
index_buff
,
box_size_
,
reinterpret_cast
<
cudaStream_t
>
(
stream_ptr
));
CalNMSWithMask
(
num_input_
,
iou_value_
,
output
,
area
,
sel_boxes
,
box_size_
,
CalNMSWithMask
(
num_input_
,
iou_value_
,
output
,
area
,
sel_boxes
,
box_size_
,
reinterpret_cast
<
cudaStream_t
>
(
stream_ptr
));
reinterpret_cast
<
cudaStream_t
>
(
stream_ptr
));
CalFinalPass
(
num_input_
,
iou_value_
,
output
,
area
,
sel_boxes
,
box_size_
,
CalFinalPass
(
num_input_
,
iou_value_
,
output
,
area
,
sel_boxes
,
box_size_
,
reinterpret_cast
<
cudaStream_t
>
(
stream_ptr
));
reinterpret_cast
<
cudaStream_t
>
(
stream_ptr
));
return
true
;
return
true
;
}
}
bool
Init
(
const
CNodePtr
&
kernel_node
)
override
{
bool
Init
(
const
CNodePtr
&
kernel_node
)
override
{
iou_value_
=
GetAttr
<
float
>
(
kernel_node
,
"iou_threshold"
);
iou_value_
=
GetAttr
<
float
>
(
kernel_node
,
"iou_threshold"
);
...
@@ -79,10 +82,13 @@ class NMSWithMaskGpuFwdKernel : public GpuKernel {
...
@@ -79,10 +82,13 @@ class NMSWithMaskGpuFwdKernel : public GpuKernel {
}
}
num_input_
=
input_shape
[
0
];
// Get N value in [N,5] data
num_input_
=
input_shape
[
0
];
// Get N value in [N,5] data
ceil_power_2
=
NMSRoundUpPower2
(
num_input_
);
input_size_
=
num_input_
*
sizeof
(
T
)
*
box_size_
;
// 5 values per bbox
input_size_
=
num_input_
*
sizeof
(
T
)
*
box_size_
;
// 5 values per bbox
output_size_
=
(
input_size_
)
+
(
num_input_
*
sizeof
(
int
))
+
(
num_input_
*
sizeof
(
bool
));
output_size_
=
(
input_size_
)
+
(
num_input_
*
sizeof
(
int
))
+
(
num_input_
*
sizeof
(
bool
));
workspace_size_
=
(
2
*
num_input_
*
sizeof
(
T
))
+
(
1
*
num_input_
*
sizeof
(
int
));
workspace_size_
=
num_input_
*
sizeof
(
int
);
workspace_size_
+=
ceil_power_2
*
(
sizeof
(
T
)
+
sizeof
(
int
));
InitSizeLists
();
InitSizeLists
();
return
true
;
return
true
;
...
@@ -97,20 +103,20 @@ class NMSWithMaskGpuFwdKernel : public GpuKernel {
...
@@ -97,20 +103,20 @@ class NMSWithMaskGpuFwdKernel : public GpuKernel {
output_size_list_
.
push_back
(
num_input_
*
sizeof
(
bool
));
output_size_list_
.
push_back
(
num_input_
*
sizeof
(
bool
));
// N sized workspace arrs
// N sized workspace arrs
workspace_size_list_
.
push_back
(
num_input_
*
sizeof
(
T
));
workspace_size_list_
.
push_back
(
num_input_
*
sizeof
(
T
));
// area list
workspace_size_list_
.
push_back
(
num_input_
*
sizeof
(
int
));
workspace_size_list_
.
push_back
(
ceil_power_2
*
sizeof
(
T
));
// data buff
workspace_size_list_
.
push_back
(
num_input_
*
sizeof
(
T
));
workspace_size_list_
.
push_back
(
ceil_power_2
*
sizeof
(
int
));
// index buff
}
}
private:
private:
int
num_input_
;
int
num_input_
;
float
iou_value_
;
float
iou_value_
;
static
const
int
box_size_
=
5
;
// pre_defined box width
static
const
int
box_size_
=
5
;
// pre_defined box width
// int box_size__ = 5; // current size of bboxes
// default values
// default values
size_t
input_size_
;
size_t
input_size_
;
size_t
output_size_
;
size_t
output_size_
;
size_t
workspace_size_
;
size_t
workspace_size_
;
size_t
ceil_power_2
;
std
::
vector
<
size_t
>
input_size_list_
;
std
::
vector
<
size_t
>
input_size_list_
;
std
::
vector
<
size_t
>
output_size_list_
;
std
::
vector
<
size_t
>
output_size_list_
;
std
::
vector
<
size_t
>
workspace_size_list_
;
std
::
vector
<
size_t
>
workspace_size_list_
;
...
...
tests/st/ops/gpu/test_nms_with_mask_op.py
浏览文件 @
97f08e74
...
@@ -21,29 +21,6 @@ import mindspore
...
@@ -21,29 +21,6 @@ import mindspore
from
mindspore
import
Tensor
from
mindspore
import
Tensor
from
mindspore.ops
import
operations
as
P
from
mindspore.ops
import
operations
as
P
def
manualNMS
(
bbox
,
overlap_val_iou
):
mask
=
[
True
]
*
len
(
bbox
)
for
box_a_index
,
_
in
enumerate
(
bbox
):
if
not
mask
[
box_a_index
]:
continue
# ignore if not in list
box_a
=
bbox
[
box_a_index
]
# select box for value extraction
for
box_b_index
in
range
(
box_a_index
+
1
,
len
(
bbox
)):
if
not
mask
[
box_b_index
]:
continue
# ignore if not in list
box_b
=
bbox
[
box_b_index
]
areaA
=
(
box_a
[
2
]
-
box_a
[
0
])
*
(
box_a
[
3
]
-
box_a
[
1
])
areaB
=
(
box_b
[
2
]
-
box_b
[
0
])
*
(
box_b
[
3
]
-
box_b
[
1
])
overlap_x1
=
max
(
box_a
[
0
],
box_b
[
0
])
overlap_y1
=
max
(
box_a
[
1
],
box_b
[
1
])
overlap_x2
=
min
(
box_a
[
2
],
box_b
[
2
])
overlap_y2
=
min
(
box_a
[
3
],
box_b
[
3
])
width
=
max
((
overlap_x2
-
overlap_x1
),
0
)
height
=
max
((
overlap_y2
-
overlap_y1
),
0
)
# generate IOU decision
mask
[
box_b_index
]
=
not
(
(
width
*
height
)
/
(
areaA
+
areaB
-
(
width
*
height
)))
>
overlap_val_iou
return
mask
def
runMSRun
(
op
,
bbox
):
def
runMSRun
(
op
,
bbox
):
inputs
=
Tensor
(
bbox
,
mindspore
.
float32
)
inputs
=
Tensor
(
bbox
,
mindspore
.
float32
)
...
@@ -60,10 +37,10 @@ def runMSRun(op, bbox):
...
@@ -60,10 +37,10 @@ def runMSRun(op, bbox):
@
pytest
.
mark
.
platform_x86_gpu_training
@
pytest
.
mark
.
platform_x86_gpu_training
@
pytest
.
mark
.
env_onecard
@
pytest
.
mark
.
env_onecard
def
test_nms_with_mask_check_order
():
def
test_nms_with_mask_check_order
():
context
.
set_context
(
mode
=
context
.
GRAPH
_MODE
,
device_target
=
"GPU"
)
context
.
set_context
(
mode
=
context
.
PYNATIVE
_MODE
,
device_target
=
"GPU"
)
nms_op
=
P
.
NMSWithMask
(
0.5
)
nms_op
=
P
.
NMSWithMask
(
0.5
)
for
_
in
range
(
50
0
):
for
_
in
range
(
1
0
):
count
=
2
0
count
=
800
0
box
=
np
.
random
.
randint
(
1
,
100
,
size
=
(
count
,
4
))
box
=
np
.
random
.
randint
(
1
,
100
,
size
=
(
count
,
4
))
box
[:,
2
]
=
box
[:,
0
]
+
box
[:,
2
]
box
[:,
2
]
=
box
[:,
0
]
+
box
[:,
2
]
box
[:,
3
]
=
box
[:,
1
]
+
box
[:,
3
]
box
[:,
3
]
=
box
[:,
1
]
+
box
[:,
3
]
...
@@ -77,28 +54,6 @@ def test_nms_with_mask_check_order():
...
@@ -77,28 +54,6 @@ def test_nms_with_mask_check_order():
ms_sorted_scores
,
np_sorted_scores
)
ms_sorted_scores
,
np_sorted_scores
)
@
pytest
.
mark
.
level0
@
pytest
.
mark
.
platform_x86_gpu_training
@
pytest
.
mark
.
env_onecard
def
test_nms_with_masl_check_result
():
context
.
set_context
(
mode
=
context
.
GRAPH_MODE
,
device_target
=
"GPU"
)
test_count
=
500
for
x
in
range
(
1
,
test_count
+
1
):
count
=
20
# size of bbox lists
nms_op
=
P
.
NMSWithMask
(
x
*
0.002
)
# will test full range b/w 0 and 1
box
=
np
.
random
.
randint
(
1
,
100
,
size
=
(
count
,
4
))
box
[:,
2
]
=
box
[:,
0
]
+
box
[:,
2
]
box
[:,
3
]
=
box
[:,
1
]
+
box
[:,
3
]
unsorted_scores
=
np
.
random
.
rand
(
count
,
1
)
sorted_scores
=
np
.
sort
(
unsorted_scores
,
axis
=
0
)[::
-
1
]
bbox
=
np
.
hstack
((
box
,
sorted_scores
))
bbox
=
Tensor
(
bbox
,
dtype
=
mindspore
.
float32
)
_
,
_
,
mask
=
nms_op
(
bbox
)
mask
=
mask
.
asnumpy
()
manual_mask
=
manualNMS
(
box
,
x
*
0.002
)
np
.
testing
.
assert_array_equal
(
mask
,
np
.
array
(
manual_mask
))
@
pytest
.
mark
.
level0
@
pytest
.
mark
.
level0
@
pytest
.
mark
.
platform_x86_gpu_training
@
pytest
.
mark
.
platform_x86_gpu_training
@
pytest
.
mark
.
env_onecard
@
pytest
.
mark
.
env_onecard
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录