Commit e1f93ec2

Authored on Jul 31, 2020 by mindspore-ci-bot; committed via Gitee on Jul 31, 2020.

!3619 NMSWithMask - CUDA Impl

Merge pull request !3619 from danishnxt/GPU_One

Parents: 1165b27f, a2ffc953
Showing 5 changed files with 534 additions and 0 deletions.
mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/nms_with_mask_impl.cu   +193 −0
mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/nms_with_mask_impl.cuh  +37 −0
mindspore/ccsrc/backend/kernel_compiler/gpu/math/nms_with_mask_gpu_kernel.cc  +29 −0
mindspore/ccsrc/backend/kernel_compiler/gpu/math/nms_with_mask_gpu_kernel.h   +121 −0
tests/st/ops/gpu/test_nms_with_mask_op.py                                     +154 −0
mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/nms_with_mask_impl.cu  0 → 100644
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "nms_with_mask_impl.cuh"
#include <limits>
#include <algorithm>
int RoundUpPower2M(int v) {
  v--;
  v |= v >> 1;
  v |= v >> 2;
  v |= v >> 4;
  v |= v >> 8;
  v |= v >> 16;
  v++;
  return v;
}
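RoundUpPower2M rounds v up to the next power of two by smearing the highest set bit into every lower bit and then adding one; the bitonic sort below requires a power-of-two element count. A minimal host-side cross-check (illustrative sketch, not part of the commit):

def round_up_power2(v):
    # mirror of RoundUpPower2M: smear the top bit downward, then add 1
    v -= 1
    for shift in (1, 2, 4, 8, 16):
        v |= v >> shift
    return v + 1

assert [round_up_power2(n) for n in (1, 3, 20, 64)] == [1, 4, 32, 64]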
template <typename T>
__inline__ __device__ void SwapM(T *lhs, T *rhs) {
  T tmp = lhs[0];
  lhs[0] = rhs[0];
  rhs[0] = tmp;
}
template <typename T>
__inline__ __device__ bool IOUDecision(T *output, int box_A_ix, int box_B_ix, int box_A_start, int box_B_start,
                                       T *area, float IOU_value) {
  T x_1 = max(output[box_A_start + 0], output[box_B_start + 0]);
  T y_1 = max(output[box_A_start + 1], output[box_B_start + 1]);
  T x_2 = min(output[box_A_start + 2], output[box_B_start + 2]);
  T y_2 = min(output[box_A_start + 3], output[box_B_start + 3]);
  T width = max(x_2 - x_1, T(0));  // in case of no overlap
  T height = max(y_2 - y_1, T(0));
  T combined_area = area[box_A_ix] + area[box_B_ix];
  // return decision to keep or remove box
  return !(((width * height) / (combined_area - (width * height))) > IOU_value);
}
template <typename T>
__global__ void Preprocess(const int num, int *sel_idx, T *area, T *output, int box_size_) {
  for (int box_num = blockIdx.x * blockDim.x + threadIdx.x; box_num < num; box_num += blockDim.x * gridDim.x) {
    sel_idx[box_num] = box_num;
    area[box_num] = (output[(box_num * box_size_) + 2] - output[(box_num * box_size_) + 0]) *
                    (output[(box_num * box_size_) + 3] - output[(box_num * box_size_) + 1]);
  }
}
template <typename T>
__global__ void NMSWithMaskKernel(const int num, const float IOU_value, T *output, T *area, bool *sel_boxes,
                                  int box_size_) {
  for (int box_num = blockIdx.x * blockDim.x + threadIdx.x; box_num < num; box_num += blockDim.x * gridDim.x) {
    // represents highest score box in that GPU block
    if (threadIdx.x == 0) {
      sel_boxes[box_num] = true;
      continue;
    }
    int box_start_index = box_num * box_size_;  // start index adjustment
    int block_max_box_num = ((blockIdx.x * blockDim.x) + 0);
    int block_max_box_start_index = block_max_box_num * box_size_;  // start index adjustment
    sel_boxes[box_num] = IOUDecision(output, box_num, block_max_box_num, block_max_box_start_index, box_start_index,
                                     area, IOU_value);  // update mask
  }
}
template <typename T>
__global__ void FinalPass(const int num, const float IOU_value, T *output, T *area, bool *sel_boxes, int box_size_) {
  int box_i, box_j;                          // access all shared mem meta data with these
  int box_i_start_index, box_j_start_index;  // actual input data indexing
  for (int i = 0; i < num - 1; i++) {
    box_i = i;
    box_i_start_index = box_i * box_size_;  // adjust starting index
    if (sel_boxes[box_i]) {
      for (int j = i + 1; j < num; j++) {
        box_j = j;
        box_j_start_index = box_j * box_size_;
        if (sel_boxes[box_j]) {
          sel_boxes[box_j] = IOUDecision(output, box_i, box_j, box_i_start_index, box_j_start_index, area, IOU_value);
        }
      }
    }
  }
}
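Taken together, NMSWithMaskKernel is only a coarse first pass: thread 0 of each block keeps the block's leading box (the highest-scored one, since the input is pre-sorted), and every other thread tests its box against that one leader. FinalPass then performs the exact pairwise suppression serially; it is launched <<<1, 1>>> below. A rough host-side mirror of the two phases (Python sketch, assuming boxes arrive sorted by descending score; grid-stride looping is ignored for clarity):

def iou_keep(a, b, t):
    # True -> keep box b; boxes are (x1, y1, x2, y2), same rule as IOUDecision
    w = max(min(a[2], b[2]) - max(a[0], b[0]), 0)
    h = max(min(a[3], b[3]) - max(a[1], b[1]), 0)
    union = (a[2] - a[0]) * (a[3] - a[1]) + (b[2] - b[0]) * (b[3] - b[1]) - w * h
    if union <= 0:
        return True  # mirrors the GPU, where 0/0 -> NaN and NaN > t is false
    return not (w * h / union > t)

def two_phase_nms(boxes, t, block_dim=256):
    n = len(boxes)
    # phase 1 (NMSWithMaskKernel): keep block leaders, test others against their leader only
    mask = [tid % block_dim == 0 or
            iou_keep(boxes[(tid // block_dim) * block_dim], boxes[tid], t) for tid in range(n)]
    # phase 2 (FinalPass): exact pairwise suppression, highest score first
    for i in range(n - 1):
        if mask[i]:
            for j in range(i + 1, n):
                if mask[j]:
                    mask[j] = iou_keep(boxes[i], boxes[j], t)
    return mask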
template <typename T, typename S>
__global__ void BitonicSortByKeyKernelM(const int outer, const int inner, const int ceil_power2, S *data_in,
                                        S *data_out, T *index_buff, S *data_buff, int box_size_) {
  // default: sort with shared memory
  extern __shared__ T share_mem_NMS[];
  T *index_arr = share_mem_NMS;
  S *data_arr = reinterpret_cast<S *>(index_arr + ceil_power2);
  // sort with RAM
  if (index_buff != nullptr && data_buff != nullptr) {
    index_arr = index_buff + blockIdx.x * ceil_power2;
    data_arr = data_buff + blockIdx.x * ceil_power2;
  }
  for (int i = threadIdx.x; i < ceil_power2; i += blockDim.x) {
    index_arr[i] = (i < inner) ? T(i) : std::numeric_limits<T>::max();
    // populated directly from input data
    data_arr[i] = (i < inner) ? data_in[(blockIdx.x * inner + i) * box_size_ + 4] : std::numeric_limits<S>::max();
  }
  __syncthreads();
  for (size_t i = 2; i <= ceil_power2; i <<= 1) {
    for (size_t j = (i >> 1); j > 0; j >>= 1) {
      for (size_t tid = threadIdx.x; tid < ceil_power2; tid += blockDim.x) {
        size_t tid_comp = tid ^ j;
        if (tid_comp > tid) {
          if ((tid & i) == 0) {
            if (data_arr[tid] > data_arr[tid_comp]) {
              SwapM(&index_arr[tid], &index_arr[tid_comp]);
              SwapM(&data_arr[tid], &data_arr[tid_comp]);
            }
          } else {
            if (data_arr[tid] < data_arr[tid_comp]) {
              SwapM(&index_arr[tid], &index_arr[tid_comp]);
              SwapM(&data_arr[tid], &data_arr[tid_comp]);
            }
          }
        }
      }
      __syncthreads();
    }
  }
  T correct_index;
  for (size_t tid = threadIdx.x; tid < inner; tid += blockDim.x) {
    correct_index = index_arr[(inner - 1) - tid];
    // move data from input to output, correcting the ordering using the sorted index array
    for (auto i : {0, 1, 2, 3, 4}) {
      data_out[(blockIdx.x * inner + tid) * box_size_ + i] =
        data_in[(blockIdx.x * inner + correct_index) * box_size_ + i];
    }
  }
}
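The kernel above is a standard in-block bitonic network: the stage width i doubles each round, the compare-exchange partner within a stage is tid ^ j, and the direction is chosen by (tid & i) == 0; the final read-out walks index_arr backwards so boxes are emitted in descending score order. A compact single-threaded mirror of the compare-exchange pattern (illustrative sketch; len(keys) must already be a power of two):

def bitonic_sort_by_key(keys, idx):
    n = len(keys)
    i = 2
    while i <= n:
        j = i >> 1
        while j > 0:
            for tid in range(n):
                tc = tid ^ j  # compare-exchange partner
                if tc > tid:
                    ascending = (tid & i) == 0
                    if (keys[tid] > keys[tc]) == ascending:
                        keys[tid], keys[tc] = keys[tc], keys[tid]
                        idx[tid], idx[tc] = idx[tc], idx[tid]
            j >>= 1
        i <<= 1
    return keys, idx  # ascending by key; the kernel reads idx in reverse for descending order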
template <typename T>
void CalPreprocess(const int num, int *sel_idx, T *area, T *output, int box_size_, cudaStream_t cuda_stream) {
  Preprocess<<<GET_BLOCKS(num), GET_THREADS, 0, cuda_stream>>>(num, sel_idx, area, output, box_size_);
}
template <typename T, typename S>
void BitonicSortByKeyM(const int &outer, const int &inner, S *data_in, S *data_out, T *index_buff, S *data_buff,
                       int box_size_, cudaStream_t stream) {
  int ceil_power2 = RoundUpPower2M(inner);
  size_t share_mem = ceil_power2 * (sizeof(T) + sizeof(S));
  if (share_mem > SHARED_MEM_PER_BLOCK) {
    share_mem = 0;
  } else {
    data_buff = nullptr;
    index_buff = nullptr;
  }
  int thread = std::min(ceil_power2, GET_THREADS);
  BitonicSortByKeyKernelM<<<outer, thread, share_mem, stream>>>(outer, inner, ceil_power2, data_in, data_out,
                                                                index_buff, data_buff, box_size_);
}
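The launcher reserves ceil_power2 * (sizeof(T) + sizeof(S)) bytes of dynamic shared memory for the index and key arrays; if that exceeds SHARED_MEM_PER_BLOCK it passes share_mem = 0 and keeps the global-memory buffers, otherwise it nulls the buffers so the kernel sorts in shared memory. A quick sanity check of the sizing arithmetic (values assumed for illustration only):

ceil_power2 = 32                   # RoundUpPower2M(20), matching the 20-box tests below
share_mem = ceil_power2 * (4 + 4)  # sizeof(int) + sizeof(float), assuming 4 bytes each
assert share_mem == 256            # comfortably below a typical 48 KiB per-block limit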
template <typename T>
void CalNMSWithMask(const int num, const float IOU_value, T *output, T *area, bool *sel_boxes, int box_size_,
                    cudaStream_t cuda_stream) {
  NMSWithMaskKernel<<<GET_BLOCKS(num), GET_THREADS, 0, cuda_stream>>>(num, IOU_value, output, area, sel_boxes,
                                                                      box_size_);
}
template <typename T>
void CalFinalPass(const int num, const float IOU_value, T *output, T *area, bool *sel_boxes, int box_size_,
                  cudaStream_t cuda_stream) {
  FinalPass<<<1, 1, 0, cuda_stream>>>(num, IOU_value, output, area, sel_boxes, box_size_);
}
template void CalPreprocess<float>(const int num, int *sel_idx, float *area, float *output, int box_size_,
                                   cudaStream_t cuda_stream);
template void BitonicSortByKeyM(const int &outer, const int &inner, float *data_in, float *data_out, int *index_buff,
                                float *data_buff, int box_size_, cudaStream_t stream);
template void CalNMSWithMask<float>(const int num, const float IOU_value, float *output, float *area, bool *sel_boxes,
                                    int box_size_, cudaStream_t cuda_stream);
template void CalFinalPass<float>(const int num, const float IOU_value, float *output, float *area, bool *sel_boxes,
                                  int box_size_, cudaStream_t cuda_stream);
mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/nms_with_mask_impl.cuh  0 → 100644
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_CUDA_IMPL_NMS_WITH_MASK_IMPL_H_
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_CUDA_IMPL_NMS_WITH_MASK_IMPL_H_
#include "runtime/device/gpu/cuda_common.h"
template <typename T>
void CalPreprocess(const int num, int *sel_idx, T *area, T *output, int box_size_, cudaStream_t cuda_stream);
template <typename T>
void CalNMSWithMask(const int num, const float IOU_value, T *output, T *area, bool *sel_boxes, int box_size_,
                    cudaStream_t cuda_stream);
template <typename T, typename S>
void BitonicSortByKeyM(const int &outer, const int &inner, S *data_in, S *data_out, T *index_buff, S *data_buff,
                       int box_size_, cudaStream_t stream);
template <typename T>
void CalFinalPass(const int num, const float IOU_value, T *output, T *area, bool *sel_boxes, int box_size_,
                  cudaStream_t cuda_stream);
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_CUDA_IMPL_NMS_WITH_MASK_IMPL_H_
mindspore/ccsrc/backend/kernel_compiler/gpu/math/nms_with_mask_gpu_kernel.cc  0 → 100644
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "backend/kernel_compiler/gpu/math/nms_with_mask_gpu_kernel.h"
namespace mindspore {
namespace kernel {
MS_REG_GPU_KERNEL_ONE(NMSWithMask,
                      KernelAttr()
                        .AddInputAttr(kNumberTypeFloat32)
                        .AddOutputAttr(kNumberTypeFloat32)
                        .AddOutputAttr(kNumberTypeInt32)
                        .AddOutputAttr(kNumberTypeBool),
                      NMSWithMaskGpuFwdKernel, float)
}  // namespace kernel
}  // namespace mindspore
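This registration binds NMSWithMaskGpuFwdKernel to the NMSWithMask primitive for float32 input, declaring three outputs: the score-sorted boxes (float32), the selected indices (int32), and the keep mask (bool). A minimal usage sketch mirroring the tests below (illustrative):

import numpy as np
import mindspore
import mindspore.context as context
from mindspore import Tensor
from mindspore.ops import operations as P

context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
nms = P.NMSWithMask(0.5)  # the "iou_threshold" attribute read in Init()
bbox = Tensor(np.array([[12, 4, 33, 17, 0.6], [20, 11, 38, 23, 0.1]]), mindspore.float32)
sorted_boxes, sel_idx, sel_mask = nms(bbox)  # boxes come back sorted by descending score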
mindspore/ccsrc/backend/kernel_compiler/gpu/math/nms_with_mask_gpu_kernel.h  0 → 100644
/**
* Copyright 2019 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_MATH_NMS_WITH_MASK_IMPL_H_
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_MATH_NMS_WITH_MASK_IMPL_H_
#include <vector>
#include <memory>
#include <iostream>
#include "backend/kernel_compiler/gpu/gpu_kernel.h"
#include "backend/kernel_compiler/gpu/gpu_kernel_factory.h"
#include "backend/kernel_compiler/gpu/cuda_impl/nms_with_mask_impl.cuh"
#include "backend/kernel_compiler/gpu/kernel_constants.h"
namespace mindspore {
namespace kernel {
template <typename T>
class NMSWithMaskGpuFwdKernel : public GpuKernel {
 public:
  NMSWithMaskGpuFwdKernel()
      : num_input_(0), iou_value_(0.5), input_size_(0), output_size_(0), workspace_size_(0) {}
  ~NMSWithMaskGpuFwdKernel() override = default;

  const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; }
  const std::vector<size_t> &GetOutputSizeList() const override { return output_size_list_; }
  const std::vector<size_t> &GetWorkspaceSizeList() const override { return workspace_size_list_; }

  bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
              const std::vector<AddressPtr> &outputs, void *stream_ptr) override {
    T *input = GetDeviceAddress<T>(inputs, 0);
    T *data_buff = GetDeviceAddress<T>(workspace, 0);  // sort buffer
    int *index_buff = GetDeviceAddress<int>(workspace, 1);
    T *area = GetDeviceAddress<T>(workspace, 2);  // store area values for all boxes
    T *output = GetDeviceAddress<T>(outputs, 0);
    int *sel_idx = GetDeviceAddress<int>(outputs, 1);
    bool *sel_boxes = GetDeviceAddress<bool>(outputs, 2);

    BitonicSortByKeyM(num_input_, num_input_, input, output, index_buff, data_buff, box_size_,
                      reinterpret_cast<cudaStream_t>(stream_ptr));
    CalPreprocess(num_input_, sel_idx, area, output, box_size_, reinterpret_cast<cudaStream_t>(stream_ptr));
    CalNMSWithMask(num_input_, iou_value_, output, area, sel_boxes, box_size_,
                   reinterpret_cast<cudaStream_t>(stream_ptr));
    CalFinalPass(num_input_, iou_value_, output, area, sel_boxes, box_size_,
                 reinterpret_cast<cudaStream_t>(stream_ptr));
    return true;
  }

  bool Init(const CNodePtr &kernel_node) override {
    iou_value_ = GetAttr<float>(kernel_node, "iou_threshold");
    size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node);
    if (input_num != 1) {
      MS_LOG(ERROR) << "Input number is " << input_num << ", but NMSWithMask needs 1 input.";
      return false;
    }
    size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node);
    if (output_num != 3) {
      MS_LOG(ERROR) << "Output number is " << output_num << ", but NMSWithMask needs 3 outputs.";
      return false;
    }
    auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
    if (CHECK_NULL_INPUT(input_shape)) {
      MS_LOG(WARNING) << "NMSWithMask input is null";
      InitSizeLists();
      return true;
    }
    num_input_ = input_shape[0];  // get N value in [N, 5] data
    input_size_ = num_input_ * sizeof(T) * box_size_;  // 5 values per bbox
    output_size_ = (input_size_) + (num_input_ * sizeof(int)) + (num_input_ * sizeof(bool));
    workspace_size_ = (2 * num_input_ * sizeof(T)) + (1 * num_input_ * sizeof(int));
    InitSizeLists();
    return true;
  }

 protected:
  void InitSizeLists() override {
    // N sized input/output data
    input_size_list_.push_back(num_input_ * sizeof(T) * box_size_);
    output_size_list_.push_back(num_input_ * sizeof(T) * box_size_);
    output_size_list_.push_back(num_input_ * sizeof(int));
    output_size_list_.push_back(num_input_ * sizeof(bool));
    // N sized workspace arrays
    workspace_size_list_.push_back(num_input_ * sizeof(T));
    workspace_size_list_.push_back(num_input_ * sizeof(int));
    workspace_size_list_.push_back(num_input_ * sizeof(T));
  }

 private:
  int num_input_;
  float iou_value_;
  static const int box_size_ = 5;  // pre-defined box width (5 values per bbox)
  size_t input_size_;
  size_t output_size_;
  size_t workspace_size_;
  std::vector<size_t> input_size_list_;
  std::vector<size_t> output_size_list_;
  std::vector<size_t> workspace_size_list_;
};
}  // namespace kernel
}  // namespace mindspore
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_MATH_NMS_WITH_MASK_IMPL_H_
tests/st/ops/gpu/test_nms_with_mask_op.py  0 → 100644
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import numpy as np
import pytest

import mindspore.context as context
import mindspore
from mindspore import Tensor
from mindspore.ops import operations as P
def manualNMS(bbox, overlap_val_iou):
    mask = [True] * len(bbox)
    for box_a_index, _ in enumerate(bbox):
        if not mask[box_a_index]:
            continue  # ignore if not in list
        box_a = bbox[box_a_index]  # select box for value extraction
        for box_b_index in range(box_a_index + 1, len(bbox)):
            if not mask[box_b_index]:
                continue  # ignore if not in list
            box_b = bbox[box_b_index]
            areaA = (box_a[2] - box_a[0]) * (box_a[3] - box_a[1])
            areaB = (box_b[2] - box_b[0]) * (box_b[3] - box_b[1])
            overlap_x1 = max(box_a[0], box_b[0])
            overlap_y1 = max(box_a[1], box_b[1])
            overlap_x2 = min(box_a[2], box_b[2])
            overlap_y2 = min(box_a[3], box_b[3])
            width = max((overlap_x2 - overlap_x1), 0)
            height = max((overlap_y2 - overlap_y1), 0)
            # generate IOU decision
            mask[box_b_index] = not ((width * height) / (areaA + areaB - (width * height))) > overlap_val_iou
    return mask
def runMSRun(op, bbox):
    inputs = Tensor(bbox, mindspore.float32)
    box, _, mask = op(inputs)
    box = box.asnumpy()
    mask = mask.asnumpy()
    sel_idx = np.where(mask)
    sel_rows = box[sel_idx][:, 0:4]
    sel_score = box[sel_idx][:, -1]
    return sel_rows, sel_score
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_nms_with_mask_check_order():
    context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
    nms_op = P.NMSWithMask(0.5)
    for _ in range(500):
        count = 20
        box = np.random.randint(1, 100, size=(count, 4))
        box[:, 2] = box[:, 0] + box[:, 2]
        box[:, 3] = box[:, 1] + box[:, 3]
        unsorted_scores = np.random.rand(count, 1)
        bbox = np.hstack((box, unsorted_scores))
        bbox = Tensor(bbox, dtype=mindspore.float32)
        prop, _, _ = nms_op(bbox)
        ms_sorted_scores = (prop.asnumpy()[:, -1])  # select just scores
        np_sorted_scores = (np.sort(unsorted_scores, axis=0)[::-1][:, 0])  # sort manually
        np.testing.assert_array_almost_equal(ms_sorted_scores, np_sorted_scores)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_nms_with_mask_check_result():
    context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
    test_count = 500
    for x in range(1, test_count + 1):
        count = 20  # size of bbox lists
        nms_op = P.NMSWithMask(x * 0.002)  # will test full range b/w 0 and 1
        box = np.random.randint(1, 100, size=(count, 4))
        box[:, 2] = box[:, 0] + box[:, 2]
        box[:, 3] = box[:, 1] + box[:, 3]
        unsorted_scores = np.random.rand(count, 1)
        sorted_scores = np.sort(unsorted_scores, axis=0)[::-1]
        bbox = np.hstack((box, sorted_scores))
        bbox = Tensor(bbox, dtype=mindspore.float32)
        _, _, mask = nms_op(bbox)
        mask = mask.asnumpy()
        manual_mask = manualNMS(box, x * 0.002)
        np.testing.assert_array_equal(mask, np.array(manual_mask))
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_nms_with_mask_edge_case_1():
    context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
    # CASE 1 - FULL OVERLAP BOXES - every box is duplicated and has a different score
    nms_op1 = P.NMSWithMask(0.3)
    bbox1 = [[12, 4, 33, 17, 0.6], [20, 11, 38, 23, 0.1], [20, 10, 45, 26, 0.9], [15, 17, 35, 38, 0.5],
             [10, 20, 30, 40, 0.4], [35, 35, 89, 90, 0.8], [12, 4, 33, 17, 0.3], [20, 11, 38, 23, 0.2],
             [20, 10, 45, 26, 0.1], [15, 17, 35, 38, 0.8], [10, 20, 30, 40, 0.41], [35, 35, 89, 90, 0.82]]
    expected_bbox = np.array([[20., 10., 45., 26.],
                              [35., 35., 89., 90.],
                              [15., 17., 35., 38.],
                              [12., 4., 33., 17.]])
    expected_score = np.array([0.9, 0.82, 0.8, 0.6])

    sel_rows, sel_score = runMSRun(nms_op1, bbox1)
    np.testing.assert_almost_equal(sel_rows, expected_bbox)
    np.testing.assert_almost_equal(sel_score, expected_score)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_nms_with_mask_edge_case_2():
    context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
    # CASE 2 - 0 value boxes - with valid scores
    nms_op2 = P.NMSWithMask(0.5)
    bbox2 = [[0, 0, 0, 0, 0.6], [0, 0, 0, 0, 0.1]]
    expected_bbox = np.array([[0., 0., 0., 0.],
                              [0., 0., 0., 0.]])
    expected_score = np.array([0.6, 0.1])

    sel_rows, sel_score = runMSRun(nms_op2, bbox2)
    np.testing.assert_almost_equal(sel_rows, expected_bbox)
    np.testing.assert_almost_equal(sel_score, expected_score)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_nms_with_mask_edge_case_3():
    context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
    # CASE 3 - x2/x1 and y2/y1 sequence out of place
    nms_op3 = P.NMSWithMask(0.7)
    bbox3 = [[70, 70, 45, 75, 0.6], [30, 33, 43, 29, 0.1]]
    expected_bbox = np.array([[70., 70., 45., 75.],
                              [30., 33., 43., 29.]])
    expected_score = np.array([0.6, 0.1])

    sel_rows, sel_score = runMSRun(nms_op3, bbox3)
    np.testing.assert_almost_equal(sel_rows, expected_bbox)
    np.testing.assert_almost_equal(sel_score, expected_score)