Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
magicwindyyd
mindspore
提交
b5df3b97
M
mindspore
项目概览
magicwindyyd
/
mindspore
与 Fork 源项目一致
Fork自
MindSpore / mindspore
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
M
mindspore
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
b5df3b97
编写于
8月 17, 2020
作者:
M
mindspore-ci-bot
提交者:
Gitee
8月 17, 2020
浏览文件
操作
浏览文件
下载
差异文件
!4588 add op_constantofshape and testcase
Merge pull request !4588 from songhonglei413/roi
上级
d93bde32
b9e69b27
变更
12
隐藏空白更改
内联
并排
Showing
12 changed file
with
378 addition
and
1 deletion
+378
-1
mindspore/lite/schema/model.fbs
mindspore/lite/schema/model.fbs
+1
-0
mindspore/lite/schema/ops.fbs
mindspore/lite/schema/ops.fbs
+4
-0
mindspore/lite/src/ops/canstant_of_shape.cc
mindspore/lite/src/ops/canstant_of_shape.cc
+55
-0
mindspore/lite/src/ops/ops.cc
mindspore/lite/src/ops/ops.cc
+2
-0
mindspore/lite/src/ops/ops.h
mindspore/lite/src/ops/ops.h
+7
-0
mindspore/lite/src/populate_parameter.cc
mindspore/lite/src/populate_parameter.cc
+14
-0
mindspore/lite/src/runtime/kernel/arm/fp32/constant_of_shape.cc
...ore/lite/src/runtime/kernel/arm/fp32/constant_of_shape.cc
+106
-0
mindspore/lite/src/runtime/kernel/arm/fp32/constant_of_shape.h
...pore/lite/src/runtime/kernel/arm/fp32/constant_of_shape.h
+48
-0
mindspore/lite/src/runtime/kernel/arm/nnacl/fp32/constant_of_shape.c
...ite/src/runtime/kernel/arm/nnacl/fp32/constant_of_shape.c
+28
-0
mindspore/lite/src/runtime/kernel/arm/nnacl/fp32/constant_of_shape.h
...ite/src/runtime/kernel/arm/nnacl/fp32/constant_of_shape.h
+40
-0
mindspore/lite/test/ut/src/runtime/kernel/arm/fp32/constant_of_shape_fp32_test.cc
...rc/runtime/kernel/arm/fp32/constant_of_shape_fp32_test.cc
+72
-0
mindspore/lite/test/ut/src/runtime/kernel/arm/fp32/roi_pooling_fp32_tests.cc
.../ut/src/runtime/kernel/arm/fp32/roi_pooling_fp32_tests.cc
+1
-1
未找到文件。
mindspore/lite/schema/model.fbs
浏览文件 @
b5df3b97
...
...
@@ -155,6 +155,7 @@ union PrimitiveType {
Select,
Scatter,
ScatterND,
ConstantOfShape,
Unique,
Unstack,
LogicalAnd,
...
...
mindspore/lite/schema/ops.fbs
浏览文件 @
b5df3b97
...
...
@@ -249,6 +249,10 @@ table PoolingGrad {
table Shape {
}
table ConstantOfShape{
value: float = 0;
}
table Nchw2Nhwc {
}
...
...
mindspore/lite/src/ops/canstant_of_shape.cc
0 → 100644
浏览文件 @
b5df3b97
/**
* Copyright 2019-2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "src/ops/ops.h"
#include "include/errorcode.h"
#include "utils/log_adapter.h"
#include "src/ir/tensor.h"
namespace
mindspore
::
lite
{
namespace
{
constexpr
int
kShapeInputNum
=
1
;
constexpr
int
kShapeOutputNum
=
1
;
}
// namespace
int
ConstantOfShape
::
InferShape
(
std
::
vector
<
tensor
::
Tensor
*>
inputs_
,
std
::
vector
<
tensor
::
Tensor
*>
outputs_
)
{
if
(
inputs_
.
size
()
!=
kShapeInputNum
)
{
MS_LOG
(
ERROR
)
<<
"inputs to ConstantOfShape operator should be 1, but "
<<
inputs_
.
size
()
<<
" is given."
;
return
RET_ERROR
;
}
if
(
inputs_
.
front
()
==
nullptr
)
{
MS_LOG
(
ERROR
)
<<
"primitive is nullptr!"
;
return
RET_PARAM_INVALID
;
}
if
(
outputs_
.
size
()
!=
kShapeOutputNum
)
{
MS_LOG
(
ERROR
)
<<
"outputs to ConstantOfShape operator should be 1, but "
<<
outputs_
.
size
()
<<
" is given."
;
return
RET_ERROR
;
}
auto
in_tensor
=
inputs_
.
front
();
auto
in_data
=
reinterpret_cast
<
int
*>
(
in_tensor
->
Data
());
auto
out_tensor
=
outputs_
.
front
();
int
size
=
in_tensor
->
ElementsNum
();
std
::
vector
<
int
>
out_shape
(
size
);
for
(
int
i
=
0
;
i
<
size
;
++
i
)
{
out_shape
[
i
]
=
in_data
[
i
];
}
out_tensor
->
set_shape
(
out_shape
);
out_tensor
->
set_data_type
(
kNumberTypeFloat32
);
out_tensor
->
SetFormat
(
in_tensor
->
GetFormat
());
return
RET_OK
;
}
}
// namespace mindspore::lite
mindspore/lite/src/ops/ops.cc
浏览文件 @
b5df3b97
...
...
@@ -145,6 +145,8 @@ Primitive *Primitive::CreatePrimitive(schema::Primitive *primitive) {
return
new
lite
::
MatMul
(
const_cast
<
schema
::
Primitive
*>
(
primitive
));
case
schema
::
PrimitiveType_EmbeddingLookup
:
return
new
lite
::
EmbeddingLookup
(
const_cast
<
schema
::
Primitive
*>
(
primitive
));
case
schema
::
PrimitiveType_ConstantOfShape
:
return
new
lite
::
ConstantOfShape
(
const_cast
<
schema
::
Primitive
*>
(
primitive
));
default:
break
;
}
...
...
mindspore/lite/src/ops/ops.h
浏览文件 @
b5df3b97
...
...
@@ -717,6 +717,13 @@ class Shape : public Primitive {
int
InferShape
(
std
::
vector
<
tensor
::
Tensor
*>
inputs
,
std
::
vector
<
tensor
::
Tensor
*>
outputs
)
override
;
};
class
ConstantOfShape
:
public
Primitive
{
public:
explicit
ConstantOfShape
(
schema
::
Primitive
*
primitive
)
:
Primitive
(
primitive
)
{}
const
schema
::
ConstantOfShape
*
GetAttribute
()
const
{
return
this
->
primitive
->
value_as_ConstantOfShape
();
}
int
InferShape
(
std
::
vector
<
tensor
::
Tensor
*>
inputs
,
std
::
vector
<
tensor
::
Tensor
*>
outputs
)
override
;
};
class
ScatterND
:
public
Primitive
{
public:
explicit
ScatterND
(
schema
::
Primitive
*
primitive
)
:
Primitive
(
primitive
)
{}
...
...
mindspore/lite/src/populate_parameter.cc
浏览文件 @
b5df3b97
...
...
@@ -28,6 +28,7 @@
#include "src/runtime/kernel/arm/nnacl/fp32/broadcast_to.h"
#include "src/runtime/kernel/arm/nnacl/reshape_parameter.h"
#include "src/runtime/kernel/arm/nnacl/shape.h"
#include "src/runtime/kernel/arm/nnacl/fp32/constant_of_shape.h"
#include "src/runtime/kernel/arm/nnacl/fp32/stack.h"
#include "src/runtime/kernel/arm/nnacl/unstack.h"
#include "src/runtime/kernel/arm/nnacl/depth_to_space.h"
...
...
@@ -937,6 +938,18 @@ OpParameter *PopulateShapeParameter(const lite::Primitive *primitive) {
return
reinterpret_cast
<
OpParameter
*>
(
shape_param
);
}
OpParameter
*
PopulateConstantOfShapeParameter
(
const
lite
::
Primitive
*
primitive
)
{
auto
attr
=
primitive
->
Value
()
->
value_as_ConstantOfShape
();
ConstantOfShapeParameter
*
param
=
new
(
std
::
nothrow
)
ConstantOfShapeParameter
();
if
(
param
==
nullptr
)
{
MS_LOG
(
ERROR
)
<<
"new ConstantOfShapeParameter failed."
;
return
nullptr
;
}
param
->
op_parameter_
.
type_
=
primitive
->
Type
();
param
->
value_
=
attr
->
value
();
return
reinterpret_cast
<
OpParameter
*>
(
param
);
}
OpParameter
*
PopulateReverseParameter
(
const
lite
::
Primitive
*
primitive
)
{
auto
reverse_attr
=
primitive
->
Value
()
->
value_as_Reverse
();
ReverseParameter
*
reverse_param
=
new
(
std
::
nothrow
)
ReverseParameter
();
...
...
@@ -1370,6 +1383,7 @@ PopulateParameterRegistry::PopulateParameterRegistry() {
populate_parameter_funcs_
[
schema
::
PrimitiveType_Cast
]
=
PopulateCastParameter
;
populate_parameter_funcs_
[
schema
::
PrimitiveType_Scale
]
=
PopulateScaleParameter
;
populate_parameter_funcs_
[
schema
::
PrimitiveType_Reshape
]
=
PopulateReshapeParameter
;
populate_parameter_funcs_
[
schema
::
PrimitiveType_ConstantOfShape
]
=
PopulateConstantOfShapeParameter
;
populate_parameter_funcs_
[
schema
::
PrimitiveType_Shape
]
=
PopulateShapeParameter
;
populate_parameter_funcs_
[
schema
::
PrimitiveType_Concat
]
=
PopulateConcatParameter
;
populate_parameter_funcs_
[
schema
::
PrimitiveType_Tile
]
=
PopulateTileParameter
;
...
...
mindspore/lite/src/runtime/kernel/arm/fp32/constant_of_shape.cc
0 → 100644
浏览文件 @
b5df3b97
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "src/runtime/kernel/arm/fp32/constant_of_shape.h"
#include <vector>
#include "schema/model_generated.h"
#include "src/kernel_registry.h"
#include "include/errorcode.h"
#include "src/runtime/runtime_api.h"
using
mindspore
::
kernel
::
KERNEL_ARCH
::
kCPU
;
using
mindspore
::
lite
::
KernelRegistrar
;
using
mindspore
::
lite
::
RET_ERROR
;
using
mindspore
::
lite
::
RET_OK
;
using
mindspore
::
schema
::
PrimitiveType_ConstantOfShape
;
namespace
mindspore
::
kernel
{
namespace
{
constexpr
int
kInputNum
=
1
;
constexpr
int
kOutputNum
=
1
;
}
// namespace
int
ConstantOfShapeCPUKernel
::
Init
()
{
return
RET_OK
;
}
int
ConstantOfShapeCPUKernel
::
ReSize
()
{
return
RET_OK
;
}
int
ConstantOfShapeCPUKernel
::
DoExecute
(
int
task_id
)
{
int
ret
=
ConstantOfShape
(
out_ptr_
,
task_id
,
param_
);
if
(
ret
!=
RET_OK
)
{
MS_LOG
(
ERROR
)
<<
"ConstantOfShapeRun error task_id["
<<
task_id
<<
"] error_code["
<<
ret
<<
"]"
;
return
ret
;
}
return
RET_OK
;
}
int
ConstantOfShapeRun
(
int
task_id
,
LiteParallelGroupEnv
*
penv
,
void
*
cdata
)
{
auto
g_kernel
=
reinterpret_cast
<
ConstantOfShapeCPUKernel
*>
(
cdata
);
auto
ret
=
g_kernel
->
DoExecute
(
task_id
);
if
(
ret
!=
RET_OK
)
{
MS_LOG
(
ERROR
)
<<
"ConstantOfShapeRun error task_id["
<<
task_id
<<
"] error_code["
<<
ret
<<
"]"
;
return
ret
;
}
return
RET_OK
;
}
int
ConstantOfShapeCPUKernel
::
Run
()
{
auto
prepare_ret
=
Prepare
();
if
(
prepare_ret
!=
RET_OK
)
{
MS_LOG
(
ERROR
)
<<
"Prepare fail!ret: "
<<
prepare_ret
;
return
prepare_ret
;
}
param_
->
element_sz_
=
out_tensors_
.
front
()
->
ElementsNum
();
int
thread_num
=
MSMIN
(
param_
->
op_parameter_
.
thread_num_
,
param_
->
element_sz_
);
param_
->
unit_
=
UP_DIV
(
param_
->
element_sz_
,
thread_num
);
param_
->
op_parameter_
.
thread_num_
=
thread_num
;
out_ptr_
=
reinterpret_cast
<
float
*>
(
out_tensors_
.
front
()
->
Data
());
auto
ret
=
LiteBackendParallelLaunch
(
ConstantOfShapeRun
,
this
,
thread_num
);
if
(
ret
!=
RET_OK
)
{
MS_LOG
(
ERROR
)
<<
"ConstantOfShapeRun error error_code["
<<
ret
<<
"]"
;
return
ret
;
}
return
ret
;
}
kernel
::
LiteKernel
*
CpuConstantOfShapeFp32KernelCreator
(
const
std
::
vector
<
lite
::
tensor
::
Tensor
*>
&
inputs
,
const
std
::
vector
<
lite
::
tensor
::
Tensor
*>
&
outputs
,
OpParameter
*
opParameter
,
const
lite
::
Context
*
ctx
,
const
kernel
::
KernelKey
&
desc
,
const
lite
::
Primitive
*
primitive
)
{
MS_ASSERT
(
opParameter
!=
nullptr
);
if
(
opParameter
==
nullptr
)
{
MS_LOG
(
ERROR
)
<<
"Create kernel failed, opParameter is nullptr, type: PrimitiveType_ConstantOfShape. "
;
return
nullptr
;
}
MS_ASSERT
(
desc
.
type
==
schema
::
PrimitiveType_ConstantOfShape
);
auto
*
kernel
=
new
(
std
::
nothrow
)
ConstantOfShapeCPUKernel
(
opParameter
,
inputs
,
outputs
,
ctx
,
primitive
);
if
(
kernel
==
nullptr
)
{
MS_LOG
(
ERROR
)
<<
"new ConstantOfShapeCPUKernel fail!"
;
return
nullptr
;
}
auto
ret
=
kernel
->
Init
();
if
(
ret
!=
RET_OK
)
{
MS_LOG
(
ERROR
)
<<
"Init kernel failed, name: "
<<
opParameter
->
name_
<<
", type: "
<<
schema
::
EnumNamePrimitiveType
(
static_cast
<
schema
::
PrimitiveType
>
(
opParameter
->
type_
));
delete
kernel
;
return
nullptr
;
}
return
kernel
;
}
REG_KERNEL
(
kCPU
,
kNumberTypeFloat32
,
PrimitiveType_ConstantOfShape
,
CpuConstantOfShapeFp32KernelCreator
)
}
// namespace mindspore::kernel
mindspore/lite/src/runtime/kernel/arm/fp32/constant_of_shape.h
0 → 100644
浏览文件 @
b5df3b97
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_CONSTANT_OF_SHAPE_H_
#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_CONSTANT_OF_SHAPE_H_
#include <vector>
#include "src/lite_kernel.h"
#include "include/context.h"
#include "src/runtime/kernel/arm/nnacl/fp32/constant_of_shape.h"
using
mindspore
::
lite
::
Context
;
namespace
mindspore
::
kernel
{
class
ConstantOfShapeCPUKernel
:
public
LiteKernel
{
public:
ConstantOfShapeCPUKernel
(
OpParameter
*
parameter
,
const
std
::
vector
<
lite
::
tensor
::
Tensor
*>
&
inputs
,
const
std
::
vector
<
lite
::
tensor
::
Tensor
*>
&
outputs
,
const
lite
::
Context
*
ctx
,
const
lite
::
Primitive
*
primitive
)
:
LiteKernel
(
parameter
,
inputs
,
outputs
,
ctx
,
primitive
)
{
param_
=
reinterpret_cast
<
ConstantOfShapeParameter
*>
(
parameter
);
}
~
ConstantOfShapeCPUKernel
()
override
=
default
;
int
Init
()
override
;
int
ReSize
()
override
;
int
Run
()
override
;
int
DoExecute
(
int
task_id
);
private:
ConstantOfShapeParameter
*
param_
;
float
*
out_ptr_
;
};
}
// namespace mindspore::kernel
#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_CONSTANT_OF_SHAPE_H_
mindspore/lite/src/runtime/kernel/arm/nnacl/fp32/constant_of_shape.c
0 → 100644
浏览文件 @
b5df3b97
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "nnacl/fp32/constant_of_shape.h"
int
ConstantOfShape
(
float
*
output
,
int
tid
,
ConstantOfShapeParameter
*
param
)
{
int
size
=
param
->
unit_
;
float
data
=
param
->
value_
;
int
ind_st
=
MSMIN
(
tid
*
size
,
param
->
element_sz_
);
int
ind_end
=
MSMIN
(
param
->
element_sz_
,
(
tid
+
1
)
*
size
);
for
(
int
i
=
ind_st
;
i
<
ind_end
;
++
i
)
{
output
[
i
]
=
data
;
}
return
NNACL_OK
;
}
mindspore/lite/src/runtime/kernel/arm/nnacl/fp32/constant_of_shape.h
0 → 100644
浏览文件 @
b5df3b97
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_CONSTANT_OF_SHAPE_H_
#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_CONSTANT_OF_SHAPE_H_
#ifdef ENABLE_NEON
#include <arm_neon.h>
#endif
#include "nnacl/op_base.h"
#include "nnacl/errorcode.h"
typedef
struct
ConstantOfShapeParameter
{
OpParameter
op_parameter_
;
float
value_
;
int
unit_
;
int
element_sz_
;
}
ConstantOfShapeParameter
;
#ifdef __cplusplus
extern
"C"
{
#endif
int
ConstantOfShape
(
float
*
output
,
int
tid
,
ConstantOfShapeParameter
*
param
);
#ifdef __cplusplus
}
#endif
#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_CONSTANT_OF_SHAPE_H_
mindspore/lite/test/ut/src/runtime/kernel/arm/fp32/constant_of_shape_fp32_test.cc
0 → 100644
浏览文件 @
b5df3b97
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "mindspore/core/utils/log_adapter.h"
#include "common/common_test.h"
#include "mindspore/lite/src/runtime/kernel/arm/fp32/constant_of_shape.h"
#include "src/kernel_registry.h"
#include "src/lite_kernel.h"
namespace
mindspore
{
class
TestConstantOfShapeFp32
:
public
mindspore
::
CommonTest
{
public:
TestConstantOfShapeFp32
()
{}
};
int
ConstantOfShapeTestInit
(
std
::
vector
<
lite
::
tensor
::
Tensor
*>
*
inputs_
,
std
::
vector
<
lite
::
tensor
::
Tensor
*>
*
outputs_
,
float
*
a_ptr
,
std
::
vector
<
int
>
a_shape
)
{
auto
in_t
=
new
lite
::
tensor
::
Tensor
(
kNumberTypeInt32
,
a_shape
,
schema
::
Format_NHWC
,
static_cast
<
schema
::
NodeType
>
(
1
));
in_t
->
MallocData
();
memcpy
(
in_t
->
Data
(),
a_ptr
,
sizeof
(
float
)
*
in_t
->
ElementsNum
());
inputs_
->
push_back
(
in_t
);
std
::
vector
<
int
>
c_shape
(
in_t
->
ElementsNum
());
for
(
int
i
=
0
;
i
<
c_shape
.
size
();
++
i
)
{
c_shape
[
i
]
=
a_ptr
[
i
];
}
auto
out_t
=
new
lite
::
tensor
::
Tensor
(
kNumberTypeFloat
,
c_shape
,
schema
::
Format_NHWC
,
static_cast
<
schema
::
NodeType
>
(
1
));
out_t
->
MallocData
();
outputs_
->
push_back
(
out_t
);
return
out_t
->
ElementsNum
();
}
TEST_F
(
TestConstantOfShapeFp32
,
Simple
)
{
std
::
vector
<
lite
::
tensor
::
Tensor
*>
inputs_
;
std
::
vector
<
lite
::
tensor
::
Tensor
*>
outputs_
;
auto
param
=
new
ConstantOfShapeParameter
();
param
->
value_
=
1
;
float
a
[]
=
{
1
,
2
,
3
,
4
};
std
::
vector
<
int
>
a_shape
=
{
4
,
1
,
1
,
1
};
// std::vector<int> c_shape = {2, 2, 2, 1};
int
total_size
=
ConstantOfShapeTestInit
(
&
inputs_
,
&
outputs_
,
a
,
a_shape
);
auto
ctx
=
new
lite
::
Context
;
ctx
->
thread_num_
=
4
;
kernel
::
ConstantOfShapeCPUKernel
*
op
=
new
kernel
::
ConstantOfShapeCPUKernel
(
reinterpret_cast
<
OpParameter
*>
(
param
),
inputs_
,
outputs_
,
ctx
,
nullptr
);
op
->
Init
();
op
->
Run
();
float
correct
[]
=
{
1
,
1
,
1
,
1
,
1
,
1
,
1
,
1
,
1
,
1
,
1
,
1
,
1
,
1
,
1
,
1
,
1
,
1
,
1
,
1
,
1
,
1
,
1
,
1
};
float
*
output
=
reinterpret_cast
<
float
*>
(
outputs_
[
0
]
->
Data
());
for
(
int
i
=
0
;
i
<
8
;
++
i
)
printf
(
"%f "
,
output
[
i
]);
printf
(
"
\n
"
);
CompareOutputData
(
reinterpret_cast
<
float
*>
(
outputs_
[
0
]
->
Data
()),
correct
,
total_size
,
0.0001
);
delete
op
;
for
(
auto
t
:
inputs_
)
delete
t
;
for
(
auto
t
:
outputs_
)
delete
t
;
}
}
// namespace mindspore
mindspore/lite/test/ut/src/runtime/kernel/arm/fp32/roi_pooling_fp32_tests.cc
浏览文件 @
b5df3b97
...
...
@@ -63,7 +63,7 @@ TEST_F(TestROIPoolingFp32, Simple) {
std
::
vector
<
int
>
c_shape
=
{
2
,
2
,
2
,
1
};
int
total_size
=
ROIPoolingTestInit
(
&
inputs_
,
&
outputs_
,
a
,
b
,
a_shape
,
b_shape
,
c_shape
);
auto
ctx
=
new
lite
::
Context
;
ctx
->
thread_num_
=
1
;
ctx
->
thread_num_
=
3
;
kernel
::
ROIPoolingCPUKernel
*
op
=
new
kernel
::
ROIPoolingCPUKernel
(
reinterpret_cast
<
OpParameter
*>
(
param
),
inputs_
,
outputs_
,
ctx
,
nullptr
);
op
->
Init
();
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录