Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
magicwindyyd
mindspore
提交
6e523569
M
mindspore
项目概览
magicwindyyd
/
mindspore
与 Fork 源项目一致
Fork自
MindSpore / mindspore
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
M
mindspore
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
6e523569
编写于
8月 11, 2020
作者:
M
mindspore-ci-bot
提交者:
Gitee
8月 11, 2020
浏览文件
操作
浏览文件
下载
差异文件
!4270 [MS][LITE]add op_roipooling and testcase
Merge pull request !4270 from songhonglei413/roi
上级
8461b2f0
49399972
变更
11
隐藏空白更改
内联
并排
Showing
11 changed file
with
460 addition
and
0 deletion
+460
-0
mindspore/lite/schema/model.fbs
mindspore/lite/schema/model.fbs
+1
-0
mindspore/lite/schema/ops.fbs
mindspore/lite/schema/ops.fbs
+6
-0
mindspore/lite/src/ops/ops.cc
mindspore/lite/src/ops/ops.cc
+2
-0
mindspore/lite/src/ops/ops.h
mindspore/lite/src/ops/ops.h
+7
-0
mindspore/lite/src/ops/roi_pooling.cc
mindspore/lite/src/ops/roi_pooling.cc
+58
-0
mindspore/lite/src/populate_parameter.cc
mindspore/lite/src/populate_parameter.cc
+17
-0
mindspore/lite/src/runtime/kernel/arm/fp32/roi_pooling.cc
mindspore/lite/src/runtime/kernel/arm/fp32/roi_pooling.cc
+112
-0
mindspore/lite/src/runtime/kernel/arm/fp32/roi_pooling.h
mindspore/lite/src/runtime/kernel/arm/fp32/roi_pooling.h
+51
-0
mindspore/lite/src/runtime/kernel/arm/nnacl/fp32/roi_pooling.cc
...ore/lite/src/runtime/kernel/arm/nnacl/fp32/roi_pooling.cc
+96
-0
mindspore/lite/src/runtime/kernel/arm/nnacl/fp32/roi_pooling.h
...pore/lite/src/runtime/kernel/arm/nnacl/fp32/roi_pooling.h
+30
-0
mindspore/lite/test/ut/src/runtime/kernel/arm/fp32/roi_pooling_fp32_tests.cc
.../ut/src/runtime/kernel/arm/fp32/roi_pooling_fp32_tests.cc
+80
-0
未找到文件。
mindspore/lite/schema/model.fbs
浏览文件 @
6e523569
...
...
@@ -56,6 +56,7 @@ union PrimitiveType {
BatchNorm,
BiasAdd,
Pooling,
ROIPooling,
DepthwiseConv2D,
DeDepthwiseConv2D,
Resize,
...
...
mindspore/lite/schema/ops.fbs
浏览文件 @
6e523569
...
...
@@ -262,6 +262,12 @@ table BiasAdd {
axis: [int];
}
table ROIPooling {
pooledH: int;
pooledW: int;
scale: float;
}
table Pooling {
format: Format = 0;
poolingMode: PoolMode;
...
...
mindspore/lite/src/ops/ops.cc
浏览文件 @
6e523569
...
...
@@ -35,6 +35,8 @@ Primitive *Primitive::CreatePrimitive(schema::Primitive *primitive) {
return
new
lite
::
Reduce
(
const_cast
<
schema
::
Primitive
*>
(
primitive
));
case
schema
::
PrimitiveType_Pooling
:
return
new
lite
::
Pooling
(
const_cast
<
schema
::
Primitive
*>
(
primitive
));
case
schema
::
PrimitiveType_ROIPooling
:
return
new
lite
::
ROIPooling
(
const_cast
<
schema
::
Primitive
*>
(
primitive
));
case
schema
::
PrimitiveType_DepthwiseConv2D
:
return
new
lite
::
DepthwiseConv2D
(
const_cast
<
schema
::
Primitive
*>
(
primitive
));
case
schema
::
PrimitiveType_FusedBatchNorm
:
...
...
mindspore/lite/src/ops/ops.h
浏览文件 @
6e523569
...
...
@@ -56,6 +56,13 @@ class Primitive {
bool
infer_flag_
=
true
;
};
class
ROIPooling
:
public
Primitive
{
public:
explicit
ROIPooling
(
schema
::
Primitive
*
primitive
)
:
Primitive
(
primitive
)
{}
const
schema
::
ROIPooling
*
GetAttribute
()
const
{
return
this
->
primitive
->
value_as_ROIPooling
();
}
int
InferShape
(
std
::
vector
<
tensor
::
Tensor
*>
inputs_
,
std
::
vector
<
tensor
::
Tensor
*>
outputs_
)
override
;
};
class
Conv2D
:
public
Primitive
{
public:
explicit
Conv2D
(
schema
::
Primitive
*
primitive
)
:
Primitive
(
primitive
)
{}
...
...
mindspore/lite/src/ops/roi_pooling.cc
0 → 100644
浏览文件 @
6e523569
/**
* Copyright 2019-2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "src/ops/ops.h"
#include "include/errorcode.h"
#include "utils/log_adapter.h"
#include "src/ir/tensor.h"
namespace
mindspore
::
lite
{
int
ROIPooling
::
InferShape
(
std
::
vector
<
tensor
::
Tensor
*>
inputs_
,
std
::
vector
<
tensor
::
Tensor
*>
outputs_
)
{
MS_ASSERT
(
this
->
primitive
!=
nullptr
);
if
(
inputs_
.
size
()
!=
kDoubleNum
)
{
MS_LOG
(
ERROR
)
<<
"inputs number is not equal to "
<<
kDoubleNum
;
return
RET_ERROR
;
}
auto
input
=
inputs_
.
front
();
if
(
input
==
nullptr
)
{
return
RET_NULL_PTR
;
}
auto
roi
=
inputs_
.
at
(
1
);
if
(
roi
==
nullptr
)
{
return
RET_NULL_PTR
;
}
auto
output
=
outputs_
.
front
();
if
(
output
==
nullptr
)
{
return
RET_NULL_PTR
;
}
auto
ROIPooling
=
GetAttribute
();
auto
new_h
=
ROIPooling
->
pooledH
();
auto
new_w
=
ROIPooling
->
pooledW
();
auto
shape_data
=
roi
->
shape
();
std
::
vector
<
int
>
output_shape
;
output_shape
.
push_back
(
shape_data
[
0
]);
output_shape
.
push_back
(
new_h
);
output_shape
.
push_back
(
new_w
);
output_shape
.
push_back
(
input
->
Channel
());
output
->
set_shape
(
output_shape
);
output
->
set_data_type
(
input
->
data_type
());
output
->
SetFormat
(
input
->
GetFormat
());
return
RET_OK
;
}
}
// namespace mindspore::lite
mindspore/lite/src/populate_parameter.cc
浏览文件 @
6e523569
...
...
@@ -33,6 +33,7 @@
#include "src/runtime/kernel/arm/nnacl/conv_parameter.h"
#include "src/runtime/kernel/arm/nnacl/fp32/pooling.h"
#include "src/runtime/kernel/arm/nnacl/matmul_parameter.h"
#include "src/runtime/kernel/arm/nnacl/fp32/roi_pooling.h"
#include "src/runtime/kernel/arm/nnacl/softmax_parameter.h"
#include "src/runtime/kernel/arm/nnacl/tile.h"
#include "src/runtime/kernel/arm/nnacl/fp32/topk.h"
...
...
@@ -74,6 +75,21 @@
#include "src/runtime/kernel/arm/nnacl/fp32/elu.h"
namespace
mindspore
::
kernel
{
OpParameter
*
PopulateROIPoolingParameter
(
const
lite
::
Primitive
*
primitive
)
{
auto
pooling_primitive
=
primitive
->
Value
()
->
value_as_ROIPooling
();
ROIPoolingParameter
*
param
=
new
(
std
::
nothrow
)
ROIPoolingParameter
();
if
(
param
==
nullptr
)
{
MS_LOG
(
ERROR
)
<<
"new PoolingParameter failed."
;
return
nullptr
;
}
param
->
op_parameter_
.
type_
=
primitive
->
Type
();
param
->
pooledH_
=
pooling_primitive
->
pooledH
();
param
->
pooledW_
=
pooling_primitive
->
pooledW
();
param
->
scale_
=
pooling_primitive
->
scale
();
return
reinterpret_cast
<
OpParameter
*>
(
param
);
}
OpParameter
*
PopulateBatchNorm
(
const
lite
::
Primitive
*
primitive
)
{
BatchNormParameter
*
batch_norm_param
=
new
(
std
::
nothrow
)
BatchNormParameter
();
if
(
batch_norm_param
==
nullptr
)
{
...
...
@@ -1270,6 +1286,7 @@ PopulateParameterRegistry::PopulateParameterRegistry() {
populate_parameter_funcs_
[
schema
::
PrimitiveType_Reduce
]
=
PopulateReduceParameter
;
populate_parameter_funcs_
[
schema
::
PrimitiveType_Mean
]
=
PopulateMeanParameter
;
populate_parameter_funcs_
[
schema
::
PrimitiveType_Pooling
]
=
PopulatePoolingParameter
;
populate_parameter_funcs_
[
schema
::
PrimitiveType_ROIPooling
]
=
PopulateROIPoolingParameter
;
populate_parameter_funcs_
[
schema
::
PrimitiveType_DepthwiseConv2D
]
=
PopulateConvDwParameter
;
populate_parameter_funcs_
[
schema
::
PrimitiveType_DeDepthwiseConv2D
]
=
PopulateDeconvDwParameter
;
populate_parameter_funcs_
[
schema
::
PrimitiveType_DeConv2D
]
=
PopulateDeconvParameter
;
...
...
mindspore/lite/src/runtime/kernel/arm/fp32/roi_pooling.cc
0 → 100644
浏览文件 @
6e523569
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "src/runtime/kernel/arm/fp32/roi_pooling.h"
#include <vector>
#include "schema/model_generated.h"
#include "src/kernel_registry.h"
#include "src/runtime/runtime_api.h"
#include "include/errorcode.h"
using
mindspore
::
kernel
::
KERNEL_ARCH
::
kCPU
;
using
mindspore
::
lite
::
KernelRegistrar
;
using
mindspore
::
lite
::
RET_ERROR
;
using
mindspore
::
lite
::
RET_OK
;
using
mindspore
::
schema
::
PrimitiveType_ROIPooling
;
namespace
mindspore
::
kernel
{
int
ROIPoolingCPUKernel
::
Init
()
{
if
(
!
InferShapeDone
())
{
return
RET_OK
;
}
return
ReSize
();
}
int
ROIPoolingCPUKernel
::
ReSize
()
{
return
RET_OK
;
}
int
ROIPoolingCPUKernel
::
DoExecute
(
int
task_id
)
{
auto
ret
=
ROIPooling
(
in_ptr_
,
out_ptr_
,
roi_ptr_
,
in_shape_
,
out_shape_
,
dim_
,
task_id
,
param_
);
if
(
ret
!=
RET_OK
)
{
MS_LOG
(
ERROR
)
<<
"ROIPooling Execute error task_id["
<<
task_id
<<
"] error_code["
<<
ret
<<
"]"
;
return
ret
;
}
return
RET_OK
;
}
int
ROIPoolingRun
(
int
task_id
,
LiteParallelGroupEnv
*
penv
,
void
*
cdata
)
{
auto
Data
=
reinterpret_cast
<
ROIPoolingCPUKernel
*>
(
cdata
);
auto
ret
=
Data
->
DoExecute
(
task_id
);
if
(
ret
!=
RET_OK
)
{
MS_LOG
(
ERROR
)
<<
"ROIPooling Run error task_id["
<<
task_id
<<
"] error_code["
<<
ret
<<
"]"
;
return
ret
;
}
return
RET_OK
;
}
int
ROIPoolingCPUKernel
::
Run
()
{
auto
ret
=
Prepare
();
if
(
ret
!=
RET_OK
)
{
MS_LOG
(
ERROR
)
<<
"Prepare fail! ret: "
<<
ret
;
return
ret
;
}
in_ptr_
=
reinterpret_cast
<
float
*>
(
inputs_
.
front
()
->
Data
());
out_ptr_
=
reinterpret_cast
<
float
*>
(
outputs_
.
front
()
->
Data
());
roi_ptr_
=
reinterpret_cast
<
float
*>
(
inputs_
.
at
(
1
)
->
Data
());
in_shape_
=
reinterpret_cast
<
const
int
*>
(
inputs_
.
front
()
->
shape
().
data
());
out_shape_
=
reinterpret_cast
<
const
int
*>
(
outputs_
.
front
()
->
shape
().
data
());
dim_
=
inputs_
.
front
()
->
shape
().
size
();
thread_count_
=
1
;
ret
=
LiteBackendParallelLaunch
(
ROIPoolingRun
,
this
,
thread_count_
);
if
(
ret
!=
RET_OK
)
{
MS_LOG
(
ERROR
)
<<
"ROIPooling error: error_code["
<<
ret
<<
"]"
;
return
ret
;
}
return
ret
;
}
kernel
::
LiteKernel
*
CpuROIPoolingFp32KernelCreator
(
const
std
::
vector
<
lite
::
tensor
::
Tensor
*>
&
inputs
,
const
std
::
vector
<
lite
::
tensor
::
Tensor
*>
&
outputs
,
OpParameter
*
opParameter
,
const
lite
::
Context
*
ctx
,
const
kernel
::
KernelKey
&
desc
,
const
lite
::
Primitive
*
primitive
)
{
if
(
opParameter
==
nullptr
)
{
MS_LOG
(
ERROR
)
<<
"Input opParameter is nullptr!"
;
return
nullptr
;
}
if
(
ctx
==
nullptr
)
{
MS_LOG
(
ERROR
)
<<
"Input context is nullptr!"
;
return
nullptr
;
}
if
(
ctx
->
thread_num_
==
0
)
{
MS_LOG
(
ERROR
)
<<
"context thread num is 0!"
;
return
nullptr
;
}
auto
*
kernel
=
new
(
std
::
nothrow
)
ROIPoolingCPUKernel
(
opParameter
,
inputs
,
outputs
,
ctx
,
primitive
);
if
(
kernel
==
nullptr
)
{
MS_LOG
(
ERROR
)
<<
"new ROIPoolingCPUKernel fail!"
;
return
nullptr
;
}
auto
ret
=
kernel
->
Init
();
if
(
ret
!=
RET_OK
)
{
delete
kernel
;
MS_LOG
(
ERROR
)
<<
"Init kernel failed, name: "
<<
opParameter
->
name_
<<
", type: "
<<
schema
::
EnumNamePrimitiveType
(
static_cast
<
schema
::
PrimitiveType
>
(
opParameter
->
type_
));
return
nullptr
;
}
return
kernel
;
}
REG_KERNEL
(
kCPU
,
kNumberTypeFloat32
,
PrimitiveType_ROIPooling
,
CpuROIPoolingFp32KernelCreator
)
}
// namespace mindspore::kernel
mindspore/lite/src/runtime/kernel/arm/fp32/roi_pooling.h
0 → 100644
浏览文件 @
6e523569
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_ROI_POOLING_H_
#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_ROI_POOLING_H_
#include <vector>
#include "src/lite_kernel.h"
#include "src/runtime/kernel/arm/nnacl/fp32/roi_pooling.h"
namespace
mindspore
::
kernel
{
class
ROIPoolingCPUKernel
:
public
LiteKernel
{
public:
ROIPoolingCPUKernel
(
OpParameter
*
parameter
,
const
std
::
vector
<
lite
::
tensor
::
Tensor
*>
&
inputs
,
const
std
::
vector
<
lite
::
tensor
::
Tensor
*>
&
outputs
,
const
lite
::
Context
*
ctx
,
const
lite
::
Primitive
*
primitive
)
:
LiteKernel
(
parameter
,
inputs
,
outputs
,
ctx
,
primitive
)
{
param_
=
reinterpret_cast
<
ROIPoolingParameter
*>
(
parameter
);
}
~
ROIPoolingCPUKernel
()
override
=
default
;
int
Init
()
override
;
int
ReSize
()
override
;
int
Run
()
override
;
int
DoExecute
(
int
task_id
);
private:
float
*
in_ptr_
;
float
*
out_ptr_
;
float
*
roi_ptr_
;
const
int
*
in_shape_
;
const
int
*
out_shape_
;
ROIPoolingParameter
*
param_
;
int
dim_
;
int
thread_count_
;
};
}
// namespace mindspore::kernel
#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_REVERSE_H_
mindspore/lite/src/runtime/kernel/arm/nnacl/fp32/roi_pooling.cc
0 → 100644
浏览文件 @
6e523569
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "nnacl/fp32/roi_pooling.h"
#include <math.h>
#include "nnacl/errorcode.h"
int
ROIPooling
(
float
*
in_ptr
,
float
*
out_ptr
,
float
*
roi
,
const
int
*
in_shape
,
const
int
*
out_shape
,
int
dim
,
int
tid
,
ROIPoolingParameter
*
param
)
{
int
num_rois
=
out_shape
[
kNHWC_N
];
int
batch_size
=
in_shape
[
kNHWC_N
];
int
height_
=
in_shape
[
kNHWC_H
];
int
width_
=
in_shape
[
kNHWC_W
];
int
channels_
=
in_shape
[
kNHWC_C
];
int
scale
=
param
->
scale_
;
int
pooled_height
=
param
->
pooledH_
;
int
pooled_width
=
param
->
pooledW_
;
int
in_stride
[
DIMENSION_4D
];
int
out_stride
[
DIMENSION_4D
];
int
roi_stride
=
5
;
in_stride
[
DIMENSION_4D
-
1
]
=
1
;
out_stride
[
DIMENSION_4D
-
1
]
=
1
;
for
(
int
i
=
dim
-
2
;
i
>=
0
;
--
i
)
{
in_stride
[
i
]
=
in_stride
[
i
+
1
]
*
in_shape
[
i
+
1
];
out_stride
[
i
]
=
out_stride
[
i
+
1
]
*
out_shape
[
i
+
1
];
}
int
roi_ind_st
=
0
;
for
(
int
i
=
0
;
i
<
num_rois
;
++
i
)
{
int
roi_batch_ind
=
(
int
)
roi
[
roi_ind_st
];
// batch_index
if
(
roi_batch_ind
>=
batch_size
)
{
return
NNACL_ERRCODE_INDEX_OUT_OF_RANGE
;
}
int
roi_start_h
=
(
int
)
roundf
(
roi
[
roi_ind_st
+
1
]
*
scale
);
// top-left x1
int
roi_start_w
=
(
int
)
roundf
(
roi
[
roi_ind_st
+
2
]
*
scale
);
// top-left y1
int
roi_end_h
=
(
int
)
roundf
(
roi
[
roi_ind_st
+
3
]
*
scale
);
// bottom-right x2
int
roi_end_w
=
(
int
)
roundf
(
roi
[
roi_ind_st
+
4
]
*
scale
);
// bottom-fight y2
int
roi_height
=
MSMAX
(
roi_end_h
-
roi_start_h
+
1
,
1
);
int
roi_width
=
MSMAX
(
roi_end_w
-
roi_start_w
+
1
,
1
);
float
bin_size_h
=
(
float
)
roi_height
/
(
float
)
pooled_height
;
float
bin_size_w
=
(
float
)
roi_width
/
(
float
)
pooled_width
;
float
*
batch_data
=
in_ptr
+
in_stride
[
kNHWC_N
]
*
roi_batch_ind
;
int
out_ind
=
i
*
out_stride
[
0
];
for
(
int
c
=
kNHWC_N
;
c
<
channels_
;
++
c
)
{
float
max_v
=
-
__FLT_MAX__
;
for
(
int
ph
=
0
;
ph
<
pooled_height
;
++
ph
)
{
for
(
int
pw
=
0
;
pw
<
pooled_width
;
++
pw
)
{
int
pooled_index
=
i
*
out_stride
[
kNHWC_N
]
+
ph
*
out_stride
[
kNHWC_H
]
+
pw
*
out_stride
[
kNHWC_W
]
+
c
*
out_stride
[
kNHWC_C
];
int
hstart
=
(
int
)
floorf
(
ph
*
bin_size_h
);
// block xi_1
int
wstart
=
(
int
)
floorf
(
pw
*
bin_size_w
);
// block yi_1
int
hend
=
(
int
)
ceilf
((
ph
+
1
)
*
bin_size_h
);
// block xi_2
int
wend
=
(
int
)
ceilf
((
pw
+
1
)
*
bin_size_w
);
// block yi_2
hstart
=
MSMIN
(
MSMAX
(
hstart
+
roi_start_h
,
0
),
height_
);
hend
=
MSMIN
(
MSMAX
(
hend
+
roi_start_h
,
0
),
height_
);
wstart
=
MSMIN
(
MSMAX
(
wstart
+
roi_start_w
,
0
),
width_
);
wend
=
MSMIN
(
MSMAX
(
wend
+
roi_start_w
,
0
),
width_
);
bool
is_empty
=
(
hend
<=
hstart
)
||
(
wend
<=
wstart
);
if
(
is_empty
)
{
max_v
=
0
;
}
int
bd_index
=
c
*
in_stride
[
kNHWC_C
]
+
hstart
*
in_stride
[
kNHWC_H
];
for
(
int
h
=
hstart
;
h
<
hend
;
++
h
)
{
int
wi
=
bd_index
+
wstart
*
in_stride
[
kNHWC_W
];
for
(
int
w
=
wstart
;
w
<
wend
;
++
w
)
{
max_v
=
MSMAX
(
batch_data
[
wi
],
max_v
);
// printf("bd:index: %d, data: %f, max_v: %f\n",wi,batch_data[wi],max_v);
wi
+=
in_stride
[
kNHWC_W
];
}
bd_index
+=
in_stride
[
kNHWC_H
];
}
out_ptr
[
pooled_index
]
=
max_v
;
}
}
}
roi_ind_st
+=
roi_stride
;
}
return
NNACL_OK
;
}
mindspore/lite/src/runtime/kernel/arm/nnacl/fp32/roi_pooling.h
0 → 100644
浏览文件 @
6e523569
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_FP32_ROI_POOLING_H_
#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_FP32_ROI_POOLING_H_
#include "nnacl/op_base.h"
typedef
struct
ROIPoolingParameter
{
OpParameter
op_parameter_
;
int
pooledW_
;
int
pooledH_
;
float
scale_
;
}
ROIPoolingParameter
;
int
ROIPooling
(
float
*
in_ptr
,
float
*
out_ptr
,
float
*
roi
,
const
int
*
in_shape
,
const
int
*
out_shape
,
int
dim
,
int
tid
,
ROIPoolingParameter
*
param
);
#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_FP32_ROI_POOLING_H_
mindspore/lite/test/ut/src/runtime/kernel/arm/fp32/roi_pooling_fp32_tests.cc
0 → 100644
浏览文件 @
6e523569
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "mindspore/core/utils/log_adapter.h"
#include "common/common_test.h"
#include "mindspore/lite/src/runtime/kernel/arm/fp32/roi_pooling.h"
#include "src/kernel_registry.h"
#include "src/lite_kernel.h"
namespace
mindspore
{
class
TestROIPoolingFp32
:
public
mindspore
::
Common
{
public:
TestROIPoolingFp32
()
{}
};
int
ROIPoolingTestInit
(
std
::
vector
<
lite
::
tensor
::
Tensor
*>
*
inputs_
,
std
::
vector
<
lite
::
tensor
::
Tensor
*>
*
outputs_
,
float
*
a_ptr
,
float
*
b_ptr
,
std
::
vector
<
int
>
a_shape
,
std
::
vector
<
int
>
b_shape
,
std
::
vector
<
int
>
c_shape
)
{
auto
in_t
=
new
lite
::
tensor
::
Tensor
(
kNumberTypeFloat
,
a_shape
,
schema
::
Format_NHWC
,
static_cast
<
schema
::
NodeType
>
(
1
));
in_t
->
MallocData
();
memcpy
(
in_t
->
Data
(),
a_ptr
,
sizeof
(
float
)
*
in_t
->
ElementsNum
());
inputs_
->
push_back
(
in_t
);
auto
roi_t
=
new
lite
::
tensor
::
Tensor
(
kNumberTypeFloat
,
b_shape
,
schema
::
Format_NHWC
,
static_cast
<
schema
::
NodeType
>
(
1
));
roi_t
->
MallocData
();
memcpy
(
roi_t
->
Data
(),
b_ptr
,
sizeof
(
float
)
*
roi_t
->
ElementsNum
());
inputs_
->
push_back
(
roi_t
);
auto
out_t
=
new
lite
::
tensor
::
Tensor
(
kNumberTypeFloat
,
c_shape
,
schema
::
Format_NHWC
,
static_cast
<
schema
::
NodeType
>
(
1
));
out_t
->
MallocData
();
outputs_
->
push_back
(
out_t
);
return
out_t
->
ElementsNum
();
}
TEST_F
(
TestROIPoolingFp32
,
Simple
)
{
std
::
vector
<
lite
::
tensor
::
Tensor
*>
inputs_
;
std
::
vector
<
lite
::
tensor
::
Tensor
*>
outputs_
;
auto
param
=
new
ROIPoolingParameter
();
param
->
scale_
=
1
;
param
->
pooledW_
=
2
;
param
->
pooledH_
=
2
;
float
a
[]
=
{
1
,
2
,
3
,
4
,
5
,
11
,
12
,
13
,
14
,
15
,
21
,
22
,
23
,
24
,
25
,
31
,
32
,
33
,
34
,
35
,
1
,
2
,
3
,
4
,
5
,
11
,
12
,
13
,
14
,
15
,
21
,
22
,
23
,
24
,
25
,
31
,
32
,
33
,
34
,
35
};
float
b
[]
=
{
0
,
1
,
1
,
3
,
4
,
1
,
1
,
1
,
3
,
4
};
std
::
vector
<
int
>
a_shape
=
{
2
,
4
,
5
,
1
};
std
::
vector
<
int
>
b_shape
=
{
2
,
5
};
std
::
vector
<
int
>
c_shape
=
{
2
,
2
,
2
,
1
};
int
total_size
=
ROIPoolingTestInit
(
&
inputs_
,
&
outputs_
,
a
,
b
,
a_shape
,
b_shape
,
c_shape
);
auto
ctx
=
new
lite
::
Context
;
ctx
->
thread_num_
=
1
;
kernel
::
ROIPoolingCPUKernel
*
op
=
new
kernel
::
ROIPoolingCPUKernel
(
reinterpret_cast
<
OpParameter
*>
(
param
),
inputs_
,
outputs_
,
ctx
,
nullptr
);
op
->
Init
();
op
->
Run
();
float
correct
[]
=
{
23
,
25
,
33
,
35
,
23
,
25
,
33
,
35
};
float
*
output
=
reinterpret_cast
<
float
*>
(
outputs_
[
0
]
->
Data
());
for
(
int
i
=
0
;
i
<
8
;
++
i
)
printf
(
"%f "
,
output
[
i
]);
printf
(
"
\n
"
);
CompareOutputData
(
reinterpret_cast
<
float
*>
(
outputs_
[
0
]
->
Data
()),
correct
,
total_size
,
0.0001
);
delete
op
;
for
(
auto
t
:
inputs_
)
delete
t
;
for
(
auto
t
:
outputs_
)
delete
t
;
}
}
// namespace mindspore
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录