Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
magicwindyyd
mindspore
提交
2aaba750
M
mindspore
项目概览
magicwindyyd
/
mindspore
与 Fork 源项目一致
Fork自
MindSpore / mindspore
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
M
mindspore
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
2aaba750
编写于
8月 19, 2020
作者:
C
chenjianping
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
kernels support resize
上级
10be5005
变更
24
隐藏空白更改
内联
并排
Showing
24 changed file
with
260 addition
and
185 deletion
+260
-185
mindspore/lite/src/lite_session.cc
mindspore/lite/src/lite_session.cc
+7
-6
mindspore/lite/src/lite_session.h
mindspore/lite/src/lite_session.h
+0
-1
mindspore/lite/src/runtime/kernel/arm/base/reduce_base.cc
mindspore/lite/src/runtime/kernel/arm/base/reduce_base.cc
+4
-4
mindspore/lite/src/runtime/kernel/arm/base/reduce_base.h
mindspore/lite/src/runtime/kernel/arm/base/reduce_base.h
+1
-1
mindspore/lite/src/runtime/kernel/arm/fp16/reduce_fp16.cc
mindspore/lite/src/runtime/kernel/arm/fp16/reduce_fp16.cc
+5
-1
mindspore/lite/src/runtime/kernel/arm/fp32/batchnorm.cc
mindspore/lite/src/runtime/kernel/arm/fp32/batchnorm.cc
+12
-11
mindspore/lite/src/runtime/kernel/arm/fp32/fullconnection.cc
mindspore/lite/src/runtime/kernel/arm/fp32/fullconnection.cc
+9
-10
mindspore/lite/src/runtime/kernel/arm/fp32/fullconnection.h
mindspore/lite/src/runtime/kernel/arm/fp32/fullconnection.h
+4
-4
mindspore/lite/src/runtime/kernel/arm/fp32/fused_batchnorm.cc
...spore/lite/src/runtime/kernel/arm/fp32/fused_batchnorm.cc
+17
-17
mindspore/lite/src/runtime/kernel/arm/fp32/fused_batchnorm.h
mindspore/lite/src/runtime/kernel/arm/fp32/fused_batchnorm.h
+1
-0
mindspore/lite/src/runtime/kernel/arm/fp32/matmul.cc
mindspore/lite/src/runtime/kernel/arm/fp32/matmul.cc
+38
-15
mindspore/lite/src/runtime/kernel/arm/fp32/matmul.h
mindspore/lite/src/runtime/kernel/arm/fp32/matmul.h
+5
-4
mindspore/lite/src/runtime/kernel/arm/fp32/reduce.cc
mindspore/lite/src/runtime/kernel/arm/fp32/reduce.cc
+7
-1
mindspore/lite/src/runtime/kernel/arm/fp32/reverse_sequence.cc
...pore/lite/src/runtime/kernel/arm/fp32/reverse_sequence.cc
+28
-27
mindspore/lite/src/runtime/kernel/arm/fp32/tile.cc
mindspore/lite/src/runtime/kernel/arm/fp32/tile.cc
+14
-13
mindspore/lite/src/runtime/kernel/arm/int8/pad_int8.cc
mindspore/lite/src/runtime/kernel/arm/int8/pad_int8.cc
+10
-13
mindspore/lite/src/runtime/kernel/arm/int8/pad_int8.h
mindspore/lite/src/runtime/kernel/arm/int8/pad_int8.h
+3
-3
mindspore/lite/src/runtime/kernel/arm/int8/pooling_int8.cc
mindspore/lite/src/runtime/kernel/arm/int8/pooling_int8.cc
+0
-4
mindspore/lite/src/runtime/kernel/arm/int8/prelu_int8.cc
mindspore/lite/src/runtime/kernel/arm/int8/prelu_int8.cc
+22
-13
mindspore/lite/src/runtime/kernel/arm/int8/topk_int8.cc
mindspore/lite/src/runtime/kernel/arm/int8/topk_int8.cc
+14
-5
mindspore/lite/src/runtime/kernel/arm/int8/topk_int8.h
mindspore/lite/src/runtime/kernel/arm/int8/topk_int8.h
+8
-4
mindspore/lite/src/runtime/kernel/arm/int8/unsqueeze_int8.cc
mindspore/lite/src/runtime/kernel/arm/int8/unsqueeze_int8.cc
+5
-8
mindspore/lite/src/scheduler.cc
mindspore/lite/src/scheduler.cc
+12
-20
mindspore/lite/test/ut/src/runtime/kernel/arm/fp32/argminmax_fp32_test.cc
...est/ut/src/runtime/kernel/arm/fp32/argminmax_fp32_test.cc
+34
-0
未找到文件。
mindspore/lite/src/lite_session.cc
浏览文件 @
2aaba750
...
...
@@ -219,7 +219,6 @@ std::vector<mindspore::tensor::MSTensor *> LiteSession::GetInputs() const {
int
LiteSession
::
RunGraph
(
const
session
::
KernelCallBack
&
before
,
const
session
::
KernelCallBack
&
after
)
{
MS_EXCEPTION_IF_NULL
(
this
->
context_
);
SetMaxWokerNum
(
context_
->
thread_num_
);
context_
->
running_
=
true
;
if
(
before
==
nullptr
&&
after
==
nullptr
)
{
return
executor
->
Run
(
this
->
inputs_
,
this
->
outputs_
,
this
->
kernels_
,
this
->
context_
->
allocator
.
get
());
}
else
{
...
...
@@ -333,19 +332,21 @@ int LiteSession::ResizeInputs(const std::vector<mindspore::tensor::MSTensor *> &
}
int
LiteSession
::
Resize
(
const
std
::
vector
<
mindspore
::
tensor
::
MSTensor
*>
&
inputs
)
{
inputs_old_
.
clear
();
inputs_old_
=
inputs_
;
std
::
vector
<
tensor
::
Tensor
*>
inputs_old
(
inputs_
);
auto
ret
=
ResizeInputs
(
inputs
);
if
(
ret
!=
RET_OK
)
{
inputs_
=
inputs_old
_
;
inputs_
=
inputs_old
;
return
ret
;
}
Scheduler
scheduler
(
context_
);
ret
=
scheduler
.
ReSizeKernels
(
kernels_
);
if
(
ret
!=
RET_OK
)
{
inputs_
=
inputs_old_
;
scheduler
.
ReSizeKernels
(
kernels_
);
inputs_
=
inputs_old
;
auto
resize_ret
=
scheduler
.
ReSizeKernels
(
kernels_
);
if
(
resize_ret
!=
RET_OK
)
{
MS_LOG
(
ERROR
)
<<
"restore kernel size fail!ret: "
<<
resize_ret
;
}
return
ret
;
}
return
RET_OK
;
...
...
mindspore/lite/src/lite_session.h
浏览文件 @
2aaba750
...
...
@@ -79,7 +79,6 @@ class LiteSession : public session::LiteSession {
std
::
vector
<
tensor
::
Tensor
*>
tensors_
;
// graph input tensors
std
::
vector
<
tensor
::
Tensor
*>
inputs_
;
std
::
vector
<
tensor
::
Tensor
*>
inputs_old_
;
// graph output tensors
std
::
vector
<
tensor
::
Tensor
*>
outputs_
;
// graph input MSTensors
...
...
mindspore/lite/src/runtime/kernel/arm/base/reduce_base.cc
浏览文件 @
2aaba750
...
...
@@ -98,14 +98,14 @@ int ReduceBaseCPUKernel::Init() {
if
(
ret
!=
RET_OK
)
{
return
ret
;
}
ret
=
CheckParameters
();
if
(
ret
!=
RET_OK
)
{
return
ret
;
}
return
RET_OK
;
}
int
ReduceBaseCPUKernel
::
ReSize
()
{
return
CheckParameters
();
}
kernel
::
LiteKernel
*
CpuReduceFp32KernelCreator
(
const
std
::
vector
<
lite
::
tensor
::
Tensor
*>
&
inputs
,
const
std
::
vector
<
lite
::
tensor
::
Tensor
*>
&
outputs
,
OpParameter
*
opParameter
,
const
lite
::
Context
*
ctx
,
...
...
mindspore/lite/src/runtime/kernel/arm/base/reduce_base.h
浏览文件 @
2aaba750
...
...
@@ -32,7 +32,7 @@ class ReduceBaseCPUKernel : public LiteKernel {
virtual
~
ReduceBaseCPUKernel
()
=
default
;
int
Init
()
override
;
int
ReSize
()
override
{
return
0
;
}
;
int
ReSize
()
override
;
private:
int
CheckInputsOutputs
();
...
...
mindspore/lite/src/runtime/kernel/arm/fp16/reduce_fp16.cc
浏览文件 @
2aaba750
...
...
@@ -59,7 +59,11 @@ int ReduceFp16CPUKernel::Init() {
int
ReduceFp16CPUKernel
::
ReSize
()
{
FreeTmpBuffer
();
auto
ret
=
MallocTmpBuffer
();
auto
ret
=
ReduceBaseCPUKernel
::
ReSize
();
if
(
ret
!=
RET_OK
)
{
return
ret
;
}
ret
=
MallocTmpBuffer
();
if
(
ret
!=
RET_OK
)
{
FreeTmpBuffer
();
return
ret
;
...
...
mindspore/lite/src/runtime/kernel/arm/fp32/batchnorm.cc
浏览文件 @
2aaba750
...
...
@@ -60,11 +60,21 @@ int BatchnormCPUKernel::InitConstTensor() {
}
int
BatchnormCPUKernel
::
Init
()
{
if
(
context_
->
infer_shape_interrupt_
&&
!
context_
->
running_
)
{
set_need_reinit
();
if
(
!
InferShapeDone
())
{
return
RET_OK
;
}
return
ReSize
();
}
int
BatchnormCPUKernel
::
ReSize
()
{
if
(
mean_addr_
!=
nullptr
)
{
free
(
mean_addr_
);
mean_addr_
=
nullptr
;
}
if
(
var_addr_
!=
nullptr
)
{
free
(
var_addr_
);
var_addr_
=
nullptr
;
}
auto
input_shapes
=
in_tensors_
[
0
]
->
shape
();
auto
n_dim
=
input_shapes
.
size
();
batchnorm_param_
->
channel_
=
input_shapes
[
n_dim
-
1
];
...
...
@@ -83,15 +93,6 @@ int BatchnormCPUKernel::Init() {
return
RET_OK
;
}
int
BatchnormCPUKernel
::
ReSize
()
{
auto
input_shapes
=
in_tensors_
[
0
]
->
shape
();
batchnorm_param_
->
unit_
=
1
;
for
(
int
i
=
0
;
i
<
input_shapes
.
size
()
-
1
;
i
++
)
{
batchnorm_param_
->
unit_
*=
input_shapes
[
i
];
}
return
RET_OK
;
}
int
BatchnormCPUKernel
::
DoExecute
(
int
task_id
)
{
BatchNorm
(
out_addr_
,
in_addr_
,
mean_addr_
,
var_addr_
,
task_id
,
batchnorm_param_
);
return
RET_OK
;
...
...
mindspore/lite/src/runtime/kernel/arm/fp32/fullconnection.cc
浏览文件 @
2aaba750
...
...
@@ -16,7 +16,6 @@
#include "src/runtime/kernel/arm/fp32/fullconnection.h"
#include "src/runtime/runtime_api.h"
using
mindspore
::
lite
::
RET_ERROR
;
using
mindspore
::
lite
::
RET_MEMORY_FAILED
;
using
mindspore
::
lite
::
RET_OK
;
...
...
@@ -48,15 +47,6 @@ void FullconnectionCPUKernel::FreeBuf() {
int
FullconnectionCPUKernel
::
ReSize
()
{
FreeBuf
();
Init
();
return
RET_OK
;
}
int
FullconnectionCPUKernel
::
Init
()
{
if
(
context_
->
infer_shape_interrupt_
&&
!
context_
->
running_
)
{
set_need_reinit
();
return
RET_OK
;
}
fc_param_
->
row_
=
(
in_tensors_
[
0
]
->
shape
())[
0
];
fc_param_
->
col_
=
(
in_tensors_
[
1
]
->
shape
())[
0
];
fc_param_
->
deep_
=
(
in_tensors_
[
1
]
->
shape
())[
1
];
...
...
@@ -81,12 +71,14 @@ int FullconnectionCPUKernel::Init() {
b_r8_ptr_
=
reinterpret_cast
<
float
*>
(
malloc
(
fc_param_
->
col_8_
*
fc_param_
->
deep_
*
sizeof
(
float
)));
if
(
b_r8_ptr_
==
nullptr
)
{
FreeBuf
();
return
RET_MEMORY_FAILED
;
}
memset
(
b_r8_ptr_
,
0
,
fc_param_
->
col_8_
*
fc_param_
->
deep_
*
sizeof
(
float
));
c_r8x8_ptr_
=
reinterpret_cast
<
float
*>
(
malloc
(
fc_param_
->
row_8_
*
fc_param_
->
col_8_
*
sizeof
(
float
)));
if
(
c_r8x8_ptr_
==
nullptr
)
{
FreeBuf
();
return
RET_MEMORY_FAILED
;
}
memset
(
c_r8x8_ptr_
,
0
,
fc_param_
->
row_8_
*
fc_param_
->
col_8_
*
sizeof
(
float
));
...
...
@@ -98,6 +90,13 @@ int FullconnectionCPUKernel::Init() {
return
RET_OK
;
}
int
FullconnectionCPUKernel
::
Init
()
{
if
(
!
InferShapeDone
())
{
return
RET_OK
;
}
return
ReSize
();
}
void
FullconnectionCPUKernel
::
InitMatrixA
(
float
*
src_ptr
,
float
*
dst_ptr
)
{
if
(
fc_param_
->
a_const_
==
true
)
{
return
;
...
...
mindspore/lite/src/runtime/kernel/arm/fp32/fullconnection.h
浏览文件 @
2aaba750
...
...
@@ -47,10 +47,10 @@ class FullconnectionCPUKernel : public FullconnectionBaseCPUKernel {
void
InitMatrixB
(
float
*
src_ptr
,
float
*
dst_ptr
);
private:
float
*
a_c8_ptr_
;
float
*
b_r8_ptr_
;
float
*
c_r8x8_ptr_
;
float
*
bias_ptr_
;
float
*
a_c8_ptr_
=
nullptr
;
float
*
b_r8_ptr_
=
nullptr
;
float
*
c_r8x8_ptr_
=
nullptr
;
float
*
bias_ptr_
=
nullptr
;
};
}
// namespace mindspore::kernel
#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_FULLCONNECTION_H_
mindspore/lite/src/runtime/kernel/arm/fp32/fused_batchnorm.cc
浏览文件 @
2aaba750
...
...
@@ -29,7 +29,9 @@ using mindspore::lite::RET_OK;
using
mindspore
::
schema
::
PrimitiveType_FusedBatchNorm
;
namespace
mindspore
::
kernel
{
FusedBatchnormCPUKernel
::~
FusedBatchnormCPUKernel
()
{
FusedBatchnormCPUKernel
::~
FusedBatchnormCPUKernel
()
{
FreeTmpBuffer
();
}
void
FusedBatchnormCPUKernel
::
FreeTmpBuffer
()
{
if
(
scale_addr_
!=
nullptr
)
{
free
(
scale_addr_
);
scale_addr_
=
nullptr
;
...
...
@@ -84,10 +86,14 @@ int FusedBatchnormCPUKernel::InitConstTensor() {
}
int
FusedBatchnormCPUKernel
::
Init
()
{
if
(
context_
->
infer_shape_interrupt_
&&
!
context_
->
running_
)
{
set_need_reinit
();
if
(
!
InferShapeDone
())
{
return
RET_OK
;
}
return
ReSize
();
}
int
FusedBatchnormCPUKernel
::
ReSize
()
{
FreeTmpBuffer
();
auto
input_shapes
=
in_tensors_
[
0
]
->
shape
();
auto
n_dim
=
input_shapes
.
size
();
batchnorm_param_
->
channel_
=
input_shapes
[
n_dim
-
1
];
...
...
@@ -106,15 +112,6 @@ int FusedBatchnormCPUKernel::Init() {
return
RET_OK
;
}
int
FusedBatchnormCPUKernel
::
ReSize
()
{
auto
input_shapes
=
in_tensors_
[
0
]
->
shape
();
batchnorm_param_
->
unit_
=
1
;
for
(
int
i
=
0
;
i
<
input_shapes
.
size
()
-
1
;
i
++
)
{
batchnorm_param_
->
unit_
*=
input_shapes
[
i
];
}
return
RET_OK
;
}
int
FusedBatchnormCPUKernel
::
Execute
(
int
task_id
)
{
FusedBatchNorm
(
out_addr_
,
in_addr_
,
scale_addr_
,
offset_addr_
,
mean_addr_
,
var_addr_
,
task_id
,
batchnorm_param_
);
return
RET_OK
;
...
...
@@ -149,13 +146,16 @@ int FusedBatchnormCPUKernel::Run() {
kernel
::
LiteKernel
*
CpuFusedBatchnormKernelCreator
(
const
std
::
vector
<
lite
::
tensor
::
Tensor
*>
&
inputs
,
const
std
::
vector
<
lite
::
tensor
::
Tensor
*>
&
outputs
,
OpParameter
*
op
P
arameter
,
const
lite
::
Context
*
ctx
,
OpParameter
*
op
_p
arameter
,
const
lite
::
Context
*
ctx
,
const
kernel
::
KernelKey
&
desc
,
const
mindspore
::
lite
::
PrimitiveC
*
primitive
)
{
MS_ASSERT
(
opParameter
!=
nullptr
);
if
(
op_parameter
==
nullptr
)
{
MS_LOG
(
ERROR
)
<<
"Input parameter is nullptr!"
;
return
nullptr
;
}
MS_ASSERT
(
desc
.
type
==
schema
::
PrimitiveType_FusedBatchNorm
);
FusedBatchnormCPUKernel
*
kernel
=
new
(
std
::
nothrow
)
FusedBatchnormCPUKernel
(
op
P
arameter
,
inputs
,
outputs
,
ctx
,
primitive
);
new
(
std
::
nothrow
)
FusedBatchnormCPUKernel
(
op
_p
arameter
,
inputs
,
outputs
,
ctx
,
primitive
);
if
(
kernel
==
nullptr
)
{
MS_LOG
(
ERROR
)
<<
"new FusedBatchnormCPUKernel fail!"
;
return
nullptr
;
...
...
@@ -163,8 +163,8 @@ kernel::LiteKernel *CpuFusedBatchnormKernelCreator(const std::vector<lite::tenso
auto
ret
=
kernel
->
Init
();
if
(
ret
!=
RET_OK
)
{
delete
kernel
;
MS_LOG
(
ERROR
)
<<
"Init kernel failed, name: "
<<
op
P
arameter
->
name_
<<
", type: "
<<
schema
::
EnumNamePrimitiveType
(
static_cast
<
schema
::
PrimitiveType
>
(
op
P
arameter
->
type_
));
MS_LOG
(
ERROR
)
<<
"Init kernel failed, name: "
<<
op
_p
arameter
->
name_
<<
", type: "
<<
schema
::
EnumNamePrimitiveType
(
static_cast
<
schema
::
PrimitiveType
>
(
op
_p
arameter
->
type_
));
return
nullptr
;
}
return
kernel
;
...
...
mindspore/lite/src/runtime/kernel/arm/fp32/fused_batchnorm.h
浏览文件 @
2aaba750
...
...
@@ -40,6 +40,7 @@ class FusedBatchnormCPUKernel : public LiteKernel {
int
Execute
(
int
task_id
);
private:
void
FreeTmpBuffer
();
float
*
in_addr_
=
nullptr
;
float
*
mean_addr_
=
nullptr
;
float
*
var_addr_
=
nullptr
;
...
...
mindspore/lite/src/runtime/kernel/arm/fp32/matmul.cc
浏览文件 @
2aaba750
...
...
@@ -25,20 +25,29 @@ using mindspore::lite::RET_MEMORY_FAILED;
using
mindspore
::
lite
::
RET_OK
;
namespace
mindspore
::
kernel
{
MatmulCPUKernel
::~
MatmulCPUKernel
()
{
ctx_
->
allocator
->
Free
(
a_c8_ptr_
);
ctx_
->
allocator
->
Free
(
b_r8_ptr_
);
ctx_
->
allocator
->
Free
(
c_r8x8_ptr_
);
ctx_
->
allocator
->
Free
(
bias_ptr_
);
}
int
MatmulCPUKernel
::
ReSize
()
{
return
RET_OK
;
}
MatmulCPUKernel
::~
MatmulCPUKernel
()
{
FreeTmpBuffer
();
}
int
MatmulCPUKernel
::
Init
()
{
if
(
context_
->
infer_shape_interrupt_
&&
!
context_
->
running_
)
{
set_need_reinit
();
return
RET_OK
;
void
MatmulCPUKernel
::
FreeTmpBuffer
()
{
if
(
a_c8_ptr_
!=
nullptr
)
{
ctx_
->
allocator
->
Free
(
a_c8_ptr_
);
a_c8_ptr_
=
nullptr
;
}
if
(
b_r8_ptr_
!=
nullptr
)
{
ctx_
->
allocator
->
Free
(
b_r8_ptr_
);
b_r8_ptr_
=
nullptr
;
}
if
(
c_r8x8_ptr_
!=
nullptr
)
{
ctx_
->
allocator
->
Free
(
c_r8x8_ptr_
);
c_r8x8_ptr_
=
nullptr
;
}
if
(
bias_ptr_
!=
nullptr
)
{
ctx_
->
allocator
->
Free
(
bias_ptr_
);
bias_ptr_
=
nullptr
;
}
}
int
MatmulCPUKernel
::
ReSize
()
{
FreeTmpBuffer
();
int
batch
=
1
;
auto
a_shape
=
in_tensors_
[
0
]
->
shape
();
auto
c_shape
=
out_tensors_
[
0
]
->
shape
();
...
...
@@ -63,17 +72,20 @@ int MatmulCPUKernel::Init() {
thread_stride_
=
UP_DIV
(
UP_DIV
(
params_
->
col_8_
,
8
),
thread_count_
);
a_c8_ptr_
=
reinterpret_cast
<
float
*>
(
ctx_
->
allocator
->
Malloc
(
params_
->
row_8_
*
params_
->
deep_
*
sizeof
(
float
)));
if
(
!
a_c8_ptr_
)
{
if
(
a_c8_ptr_
==
nullptr
)
{
FreeTmpBuffer
();
return
RET_MEMORY_FAILED
;
}
memset
(
a_c8_ptr_
,
0
,
params_
->
row_8_
*
params_
->
deep_
*
sizeof
(
float
));
b_r8_ptr_
=
reinterpret_cast
<
float
*>
(
ctx_
->
allocator
->
Malloc
(
params_
->
col_8_
*
params_
->
deep_
*
sizeof
(
float
)));
if
(
!
b_r8_ptr_
)
{
if
(
b_r8_ptr_
==
nullptr
)
{
FreeTmpBuffer
();
return
RET_MEMORY_FAILED
;
}
memset
(
b_r8_ptr_
,
0
,
params_
->
col_8_
*
params_
->
deep_
*
sizeof
(
float
));
c_r8x8_ptr_
=
reinterpret_cast
<
float
*>
(
ctx_
->
allocator
->
Malloc
(
params_
->
row_8_
*
params_
->
col_8_
*
sizeof
(
float
)));
if
(
!
c_r8x8_ptr_
)
{
if
(
c_r8x8_ptr_
==
nullptr
)
{
FreeTmpBuffer
();
return
RET_MEMORY_FAILED
;
}
memset
(
c_r8x8_ptr_
,
0
,
params_
->
row_8_
*
params_
->
col_8_
*
sizeof
(
float
));
...
...
@@ -85,6 +97,10 @@ int MatmulCPUKernel::Init() {
if
(
in_tensors_
.
size
()
==
3
)
{
bias_ptr_
=
reinterpret_cast
<
float
*>
(
malloc
(
params_
->
col_8_
*
sizeof
(
float
)));
if
(
bias_ptr_
==
nullptr
)
{
FreeTmpBuffer
();
return
RET_MEMORY_FAILED
;
}
memset
(
bias_ptr_
,
0
,
params_
->
col_8_
*
sizeof
(
float
));
memcpy
(
bias_ptr_
,
in_tensors_
[
2
]
->
Data
(),
params_
->
col_
*
sizeof
(
float
));
}
else
{
...
...
@@ -128,6 +144,13 @@ void MatmulCPUKernel::InitMatrixB(float *src_ptr, float *dst_ptr) {
return
;
}
int
MatmulCPUKernel
::
Init
()
{
if
(
!
InferShapeDone
())
{
return
RET_OK
;
}
return
ReSize
();
}
int
MatmulCPUKernel
::
RunImpl
(
int
task_id
)
{
int
cur_oc
=
MSMIN
(
thread_stride_
,
UP_DIV
(
params_
->
col_8_
,
8
)
-
task_id
*
thread_stride_
);
if
(
cur_oc
<=
0
)
{
...
...
mindspore/lite/src/runtime/kernel/arm/fp32/matmul.h
浏览文件 @
2aaba750
...
...
@@ -38,12 +38,13 @@ class MatmulCPUKernel : public MatmulBaseCPUKernel {
private:
void
InitMatrixA
(
float
*
src_ptr
,
float
*
dst_ptr
);
void
InitMatrixB
(
float
*
src_ptr
,
float
*
dst_ptr
);
void
FreeTmpBuffer
();
private:
float
*
a_c8_ptr_
;
float
*
b_r8_ptr_
;
float
*
c_r8x8_ptr_
;
float
*
bias_ptr_
;
float
*
a_c8_ptr_
=
nullptr
;
float
*
b_r8_ptr_
=
nullptr
;
float
*
c_r8x8_ptr_
=
nullptr
;
float
*
bias_ptr_
=
nullptr
;
};
}
// namespace mindspore::kernel
...
...
mindspore/lite/src/runtime/kernel/arm/fp32/reduce.cc
浏览文件 @
2aaba750
...
...
@@ -81,7 +81,13 @@ int ReduceCPUKernel::Init() {
return
ReSize
();
}
int
ReduceCPUKernel
::
ReSize
()
{
return
MallocTmpBuffer
();
}
int
ReduceCPUKernel
::
ReSize
()
{
auto
ret
=
ReduceBaseCPUKernel
::
ReSize
();
if
(
ret
!=
RET_OK
)
{
return
ret
;
}
return
MallocTmpBuffer
();
}
int
ReduceCPUKernel
::
CallReduceUnit
(
int
task_id
)
{
auto
ret
=
reducer_
(
outer_size_
,
inner_size_
,
axis_size_
,
src_data_
,
tmp_shape_
.
data
(),
dst_data_
,
task_id
,
...
...
mindspore/lite/src/runtime/kernel/arm/fp32/reverse_sequence.cc
浏览文件 @
2aaba750
...
...
@@ -24,10 +24,34 @@ using mindspore::schema::PrimitiveType_ReverseSequence;
namespace
mindspore
::
kernel
{
int
ReverseSequenceCPUKernel
::
Init
()
{
if
(
context_
->
infer_shape_interrupt_
&&
!
context_
->
running_
)
{
set_need_reinit
();
if
(
!
InferShapeDone
())
{
return
RET_OK
;
}
return
ReSize
();
}
void
ReverseSequenceCPUKernel
::
ConvertAxisToPositive
(
const
std
::
vector
<
int
>
shape
,
int
*
axis
)
{
if
(
axis
!=
nullptr
&&
*
axis
<
0
)
{
*
axis
+=
shape
.
size
();
}
}
int
ReverseSequenceCPUKernel
::
CalcCountPreAxis
(
const
std
::
vector
<
int
>
shape
,
int
axis
)
{
int
count
=
1
;
for
(
int
i
=
0
;
i
<
axis
;
++
i
)
{
count
*=
shape
[
i
];
}
return
count
;
}
int
ReverseSequenceCPUKernel
::
CalcCountAfterAxis
(
const
std
::
vector
<
int
>
shape
,
int
axis
)
{
int
count
=
1
;
for
(
int
i
=
axis
+
1
;
i
<
shape
.
size
();
++
i
)
{
count
*=
shape
[
i
];
}
return
count
;
}
int
ReverseSequenceCPUKernel
::
ReSize
()
{
auto
input0
=
in_tensors_
.
at
(
0
);
auto
input1
=
in_tensors_
.
at
(
1
);
auto
output
=
out_tensors_
.
at
(
0
);
...
...
@@ -64,34 +88,11 @@ int ReverseSequenceCPUKernel::Init() {
return
RET_OK
;
}
void
ReverseSequenceCPUKernel
::
ConvertAxisToPositive
(
const
std
::
vector
<
int
>
shape
,
int
*
axis
)
{
if
(
axis
!=
nullptr
&&
*
axis
<
0
)
{
*
axis
+=
shape
.
size
();
}
}
int
ReverseSequenceCPUKernel
::
CalcCountPreAxis
(
const
std
::
vector
<
int
>
shape
,
int
axis
)
{
int
count
=
1
;
for
(
int
i
=
0
;
i
<
axis
;
++
i
)
{
count
*=
shape
[
i
];
}
return
count
;
}
int
ReverseSequenceCPUKernel
::
CalcCountAfterAxis
(
const
std
::
vector
<
int
>
shape
,
int
axis
)
{
int
count
=
1
;
for
(
int
i
=
axis
+
1
;
i
<
shape
.
size
();
++
i
)
{
count
*=
shape
[
i
];
}
return
count
;
}
int
ReverseSequenceCPUKernel
::
ReSize
()
{
return
RET_OK
;
}
int
ReverseSequenceCPUKernel
::
Run
()
{
auto
ret
=
Prepare
();
if
(
ret
!=
RET_OK
)
{
MS_LOG
(
ERROR
)
<<
"Prepare fail
ed."
;
return
RET_ERROR
;
MS_LOG
(
ERROR
)
<<
"Prepare fail
!ret: "
<<
ret
;
return
ret
;
}
float
*
input0
=
reinterpret_cast
<
float
*>
(
in_tensors_
.
at
(
0
)
->
Data
());
int
*
input1
=
reinterpret_cast
<
int
*>
(
in_tensors_
.
at
(
1
)
->
Data
());
...
...
mindspore/lite/src/runtime/kernel/arm/fp32/tile.cc
浏览文件 @
2aaba750
...
...
@@ -25,18 +25,10 @@ using mindspore::schema::PrimitiveType_Tile;
namespace
mindspore
::
kernel
{
int
TileCPUKernel
::
Init
()
{
if
(
context_
->
infer_shape_interrupt_
&&
!
context_
->
running_
)
{
set_need_reinit
();
if
(
!
InferShapeDone
())
{
return
RET_OK
;
}
auto
tile_parameter_
=
reinterpret_cast
<
TileParameter
*>
(
op_parameter_
);
for
(
int
i
=
0
;
i
<
tile_parameter_
->
in_dim_
;
++
i
)
{
tile_parameter_
->
in_shape_
[
i
]
=
in_tensors_
[
0
]
->
shape
()[
i
];
tile_parameter_
->
out_shape_
[
i
]
=
out_tensors_
[
0
]
->
shape
()[
i
];
}
ComputeStrides
(
tile_parameter_
->
in_shape_
,
tile_parameter_
->
in_strides_
,
tile_parameter_
->
in_dim_
);
ComputeStrides
(
tile_parameter_
->
out_shape_
,
tile_parameter_
->
out_strides_
,
tile_parameter_
->
in_dim_
);
return
RET_OK
;
return
ReSize
();
}
void
TileCPUKernel
::
ComputeStrides
(
int
*
shape
,
int
*
strides
,
int
ndim
)
{
...
...
@@ -47,13 +39,22 @@ void TileCPUKernel::ComputeStrides(int *shape, int *strides, int ndim) {
}
}
int
TileCPUKernel
::
ReSize
()
{
return
RET_OK
;
}
int
TileCPUKernel
::
ReSize
()
{
auto
tile_parameter_
=
reinterpret_cast
<
TileParameter
*>
(
op_parameter_
);
for
(
int
i
=
0
;
i
<
tile_parameter_
->
in_dim_
;
++
i
)
{
tile_parameter_
->
in_shape_
[
i
]
=
in_tensors_
[
0
]
->
shape
()[
i
];
tile_parameter_
->
out_shape_
[
i
]
=
out_tensors_
[
0
]
->
shape
()[
i
];
}
ComputeStrides
(
tile_parameter_
->
in_shape_
,
tile_parameter_
->
in_strides_
,
tile_parameter_
->
in_dim_
);
ComputeStrides
(
tile_parameter_
->
out_shape_
,
tile_parameter_
->
out_strides_
,
tile_parameter_
->
in_dim_
);
return
RET_OK
;
}
int
TileCPUKernel
::
Run
()
{
auto
ret
=
Prepare
();
if
(
ret
!=
RET_OK
)
{
MS_LOG
(
ERROR
)
<<
"Prepare fail
ed."
;
return
RET_ERROR
;
MS_LOG
(
ERROR
)
<<
"Prepare fail
!ret: "
<<
ret
;
return
ret
;
}
auto
input_addr
=
reinterpret_cast
<
float
*>
(
in_tensors_
.
at
(
0
)
->
Data
());
auto
output_addr
=
reinterpret_cast
<
float
*>
(
out_tensors_
.
at
(
0
)
->
Data
());
...
...
mindspore/lite/src/runtime/kernel/arm/int8/pad_int8.cc
浏览文件 @
2aaba750
...
...
@@ -88,27 +88,24 @@ int PadInt8CPUKernel::InitPadParam() {
}
int
PadInt8CPUKernel
::
ReSize
()
{
InitPadParam
();
return
RET_OK
;
}
int
PadInt8CPUKernel
::
Init
()
{
if
(
context_
->
infer_shape_interrupt_
&&
!
context_
->
running_
)
{
set_need_reinit
();
return
RET_OK
;
}
int
error_code
=
InitPadParam
();
if
(
error_code
!=
RET_OK
)
{
MS_LOG
(
ERROR
)
<<
"InitPadParam failed. errorcode: "
<<
error_code
;
return
error_code
;
}
return
RET_OK
;
}
error_code
=
SetQuantParam
();
int
PadInt8CPUKernel
::
Init
()
{
auto
error_code
=
SetQuantParam
();
if
(
error_code
!=
RET_OK
)
{
MS_LOG
(
ERROR
)
<<
"SetQuantParam failed. errorcode: "
<<
error_code
;
return
error_code
;
}
return
RET_OK
;
if
(
!
InferShapeDone
())
{
return
RET_OK
;
}
return
ReSize
();
}
int
PadInt8CPUKernel
::
RunImpl
(
int
task_id
)
{
...
...
@@ -128,8 +125,8 @@ int PadInt8Impl(int task_id, LiteParallelGroupEnv *penv, void *cdata) {
int
PadInt8CPUKernel
::
Run
()
{
auto
ret
=
Prepare
();
if
(
ret
!=
RET_OK
)
{
MS_LOG
(
ERROR
)
<<
"Prepare fail
ed."
;
return
RET_ERROR
;
MS_LOG
(
ERROR
)
<<
"Prepare fail
!ret: "
<<
ret
;
return
ret
;
}
in_data_
=
reinterpret_cast
<
int8_t
*>
(
in_tensors_
[
0
]
->
Data
());
out_data_
=
reinterpret_cast
<
int8_t
*>
(
out_tensors_
[
0
]
->
Data
());
...
...
mindspore/lite/src/runtime/kernel/arm/int8/pad_int8.h
浏览文件 @
2aaba750
...
...
@@ -46,9 +46,9 @@ class PadInt8CPUKernel : public LiteKernel {
void
FreeQuantParam
();
private:
PadParameter
*
pad_param_
;
int8_t
*
in_data_
;
int8_t
*
out_data_
;
PadParameter
*
pad_param_
=
nullptr
;
int8_t
*
in_data_
=
nullptr
;
int8_t
*
out_data_
=
nullptr
;
int
in_dims_
[
DEFAULT_PAD_NDIMS
];
int
out_dims_
[
DEFAULT_PAD_NDIMS
];
};
...
...
mindspore/lite/src/runtime/kernel/arm/int8/pooling_int8.cc
浏览文件 @
2aaba750
...
...
@@ -26,10 +26,6 @@ using mindspore::lite::RET_OK;
namespace
mindspore
::
kernel
{
int
PoolingInt8CPUKernel
::
Init
()
{
if
(
context_
->
infer_shape_interrupt_
&&
!
context_
->
running_
)
{
set_need_reinit
();
return
RET_OK
;
}
auto
ret
=
PoolingBaseCPUKernel
::
Init
();
if
(
ret
!=
RET_OK
)
{
MS_LOG
(
ERROR
)
<<
"PoolingBase Init failed."
;
...
...
mindspore/lite/src/runtime/kernel/arm/int8/prelu_int8.cc
浏览文件 @
2aaba750
...
...
@@ -29,34 +29,43 @@ using mindspore::schema::PrimitiveType_Prelu;
namespace
mindspore
::
kernel
{
int
PreluInt8CPUKernel
::
Init
()
{
if
(
context_
->
infer_shape_interrupt_
&&
!
context_
->
running_
)
{
set_need_reinit
();
return
RET_OK
;
}
PreluBaseCPUKernel
::
Init
();
auto
*
input_tensor
=
in_tensors_
.
at
(
kInputIndex
);
auto
in_quant_args
=
input_tensor
->
GetQuantParams
();
quant_prelu_parm_
->
quant_arg
.
in_args_
.
scale_
=
in_quant_args
.
front
().
scale
;
quant_prelu_parm_
->
quant_arg
.
in_args_
.
zp_
=
in_quant_args
.
front
().
zeroPoint
;
auto
input_dim
=
input_tensor
->
shape
().
size
();
MS_ASSERT
(
input_dim
<=
CROP_OFFSET_MAX_SIZE
);
quant_prelu_parm_
->
input_dim_
=
input_dim
;
quant_prelu_parm_
->
element_num
=
in_tensors_
[
0
]
->
Size
();
auto
*
out_tensor
=
out_tensors_
.
at
(
kOutputIndex
);
auto
out_quant_args
=
out_tensor
->
GetQuantParams
();
quant_prelu_parm_
->
quant_arg
.
out_args_
.
scale_
=
out_quant_args
.
front
().
scale
;
quant_prelu_parm_
->
quant_arg
.
out_args_
.
zp_
=
out_quant_args
.
front
().
zeroPoint
;
quant_prelu_parm_
->
in_shape_
=
input_tensor
->
shape
().
data
();
quant_prelu_parm_
->
out_shape_
=
out_tensor
->
shape
().
data
();
quant_prelu_parm_
->
quant_arg
.
output_activation_max_
=
std
::
numeric_limits
<
int8_t
>::
max
();
quant_prelu_parm_
->
quant_arg
.
output_activation_min_
=
std
::
numeric_limits
<
int8_t
>::
min
();
return
RET_OK
;
if
(
!
InferShapeDone
())
{
return
RET_OK
;
}
return
ReSize
();
}
int
PreluInt8CPUKernel
::
ReSize
()
{
return
0
;
}
int
PreluInt8CPUKernel
::
ReSize
()
{
auto
*
input_tensor
=
in_tensors_
.
at
(
kInputIndex
);
auto
*
out_tensor
=
out_tensors_
.
at
(
kOutputIndex
);
auto
input_dim
=
input_tensor
->
shape
().
size
();
MS_ASSERT
(
input_dim
<=
CROP_OFFSET_MAX_SIZE
);
quant_prelu_parm_
->
input_dim_
=
input_dim
;
quant_prelu_parm_
->
element_num
=
in_tensors_
[
0
]
->
Size
();
quant_prelu_parm_
->
in_shape_
=
input_tensor
->
shape
().
data
();
quant_prelu_parm_
->
out_shape_
=
out_tensor
->
shape
().
data
();
}
int
PreluInt8CPUKernel
::
Run
()
{
auto
ret
=
LiteBackendParallelLaunch
(
PreluInt8Run
,
this
,
quant_prelu_parm_
->
op_parameter_
.
thread_num_
);
auto
ret
=
Prepare
();
if
(
ret
!=
RET_OK
)
{
MS_LOG
(
ERROR
)
<<
"Prepare fail!ret: "
<<
ret
;
return
ret
;
}
ret
=
LiteBackendParallelLaunch
(
PreluInt8Run
,
this
,
quant_prelu_parm_
->
op_parameter_
.
thread_num_
);
if
(
ret
!=
RET_OK
)
{
MS_LOG
(
ERROR
)
<<
"RunPreluParam failed. errorcode: "
;
}
...
...
mindspore/lite/src/runtime/kernel/arm/int8/topk_int8.cc
浏览文件 @
2aaba750
...
...
@@ -25,11 +25,18 @@ using mindspore::schema::PrimitiveType_TopK;
namespace
mindspore
::
kernel
{
int
TopKInt8CPUKernel
::
Init
()
{
if
(
context_
->
infer_shape_interrupt_
&&
!
context_
->
running_
)
{
set_need_reinit
();
if
(
!
InferShapeDone
())
{
return
RET_OK
;
}
return
ReSize
();
}
int
TopKInt8CPUKernel
::
ReSize
()
{
TopkParameter
*
parameter
=
reinterpret_cast
<
TopkParameter
*>
(
op_parameter_
);
if
(
parameter
->
topk_node_list_
!=
nullptr
)
{
free
(
parameter
->
topk_node_list_
);
parameter
->
topk_node_list_
=
nullptr
;
}
lite
::
tensor
::
Tensor
*
input
=
in_tensors_
.
at
(
0
);
parameter
->
last_dim_size_
=
input
->
shape
()[
input
->
shape
().
size
()
-
1
];
parameter
->
loop_num_
=
1
;
...
...
@@ -45,8 +52,6 @@ int TopKInt8CPUKernel::Init() {
return
RET_OK
;
}
int
TopKInt8CPUKernel
::
ReSize
()
{
return
RET_OK
;
}
int
TopKInt8CPUKernel
::
Run
()
{
auto
ret
=
Prepare
();
if
(
ret
!=
RET_OK
)
{
...
...
@@ -65,7 +70,11 @@ kernel::LiteKernel *CpuTopKInt8KernelCreator(const std::vector<lite::tensor::Ten
const
std
::
vector
<
lite
::
tensor
::
Tensor
*>
&
outputs
,
OpParameter
*
parameter
,
const
lite
::
Context
*
ctx
,
const
KernelKey
&
desc
,
const
mindspore
::
lite
::
PrimitiveC
*
primitive
)
{
MS_ASSERT
(
parameter
!=
nullptr
);
if
(
parameter
==
nullptr
)
{
MS_LOG
(
ERROR
)
<<
"input parameter is nullptr!"
;
return
nullptr
;
}
TopKInt8CPUKernel
*
kernel
=
new
(
std
::
nothrow
)
TopKInt8CPUKernel
(
parameter
,
inputs
,
outputs
,
ctx
,
primitive
);
if
(
kernel
==
nullptr
)
{
MS_LOG
(
ERROR
)
<<
"new TopKInt8CPUKernel fail!"
;
...
...
mindspore/lite/src/runtime/kernel/arm/int8/topk_int8.h
浏览文件 @
2aaba750
...
...
@@ -26,17 +26,21 @@ class TopKInt8CPUKernel : public LiteKernel {
explicit
TopKInt8CPUKernel
(
OpParameter
*
parameter
,
const
std
::
vector
<
lite
::
tensor
::
Tensor
*>
&
inputs
,
const
std
::
vector
<
lite
::
tensor
::
Tensor
*>
&
outputs
,
const
lite
::
Context
*
ctx
,
const
mindspore
::
lite
::
PrimitiveC
*
primitive
)
:
LiteKernel
(
parameter
,
inputs
,
outputs
,
ctx
,
primitive
)
{}
:
LiteKernel
(
parameter
,
inputs
,
outputs
,
ctx
,
primitive
)
{
TopkParameter
*
param
=
reinterpret_cast
<
TopkParameter
*>
(
op_parameter_
);
param
->
topk_node_list_
=
nullptr
;
}
~
TopKInt8CPUKernel
()
override
{
TopkParameter
*
parameter
=
reinterpret_cast
<
TopkParameter
*>
(
op_parameter_
);
free
(
parameter
->
topk_node_list_
);
if
(
parameter
->
topk_node_list_
!=
nullptr
)
{
free
(
parameter
->
topk_node_list_
);
parameter
->
topk_node_list_
=
nullptr
;
}
}
int
Init
()
override
;
int
ReSize
()
override
;
int
Run
()
override
;
private:
};
}
// namespace mindspore::kernel
...
...
mindspore/lite/src/runtime/kernel/arm/int8/unsqueeze_int8.cc
浏览文件 @
2aaba750
...
...
@@ -29,10 +29,6 @@ using mindspore::schema::PrimitiveType_Unsqueeze;
namespace
mindspore
::
kernel
{
int
Unsqueezeint8CPUKernel
::
Init
()
{
if
(
context_
->
infer_shape_interrupt_
&&
!
context_
->
running_
)
{
set_need_reinit
();
return
RET_OK
;
}
auto
*
input_tensor
=
in_tensors_
.
at
(
0
);
auto
quant_args
=
input_tensor
->
GetQuantParams
();
MS_ASSERT
(
quant_args
.
size
()
==
1
);
...
...
@@ -43,9 +39,10 @@ int Unsqueezeint8CPUKernel::Init() {
Unsq_para_
->
quant_arg
.
out_quant_args_
.
scale_
=
out_quant_args
.
front
().
scale
;
Unsq_para_
->
quant_arg
.
out_quant_args_
.
zp_
=
out_quant_args
.
front
().
zeroPoint
;
Unsq_para_
->
thread_count_
=
thread_count_
;
int
ret
=
ReSize
();
return
ret
;
if
(
!
InferShapeDone
())
{
return
RET_OK
;
}
return
ReSize
();
}
int
Unsqueezeint8CPUKernel
::
ReSize
()
{
...
...
@@ -86,7 +83,7 @@ int UnsqueezeIn8Run(int task_id, LiteParallelGroupEnv *penv, void *cdata) {
int
Unsqueezeint8CPUKernel
::
Run
()
{
auto
ret
=
Prepare
();
if
(
ret
!=
RET_OK
)
{
MS_LOG
(
ERROR
)
<<
"Prepare fail
ed."
;
MS_LOG
(
ERROR
)
<<
"Prepare fail
!ret: "
<<
ret
;
return
ret
;
}
in_ptr_
=
reinterpret_cast
<
float
*>
(
in_tensors_
.
at
(
0
)
->
Data
());
...
...
mindspore/lite/src/scheduler.cc
浏览文件 @
2aaba750
...
...
@@ -82,6 +82,7 @@ int Scheduler::InferShape(const lite::Model *model, std::vector<tensor::Tensor *
MS_EXCEPTION_IF_NULL
(
tensors
);
auto
meta_graph
=
model
->
GetMetaGraph
();
MS_EXCEPTION_IF_NULL
(
meta_graph
);
bool
infer_shape_interrupt
=
false
;
uint32_t
kernelCount
=
meta_graph
->
nodes
()
->
size
();
for
(
uint32_t
i
=
0
;
i
<
kernelCount
;
i
++
)
{
auto
cNode
=
meta_graph
->
nodes
()
->
GetAs
<
schema
::
CNode
>
(
i
);
...
...
@@ -101,27 +102,18 @@ int Scheduler::InferShape(const lite::Model *model, std::vector<tensor::Tensor *
<<
schema
::
EnumNamePrimitiveType
(
cNode
->
primitive
()
->
value_type
());
return
RET_ERROR
;
}
if
(
!
context_
->
infer_shape_interrupt_
)
{
auto
ret
=
primitive
->
InferShape
(
inputs
,
outputs
);
if
(
ret
==
RET_INFER_INVALID
)
{
MS_LOG
(
INFO
)
<<
"InferShape shouldn't be done before runtime, name: "
<<
cNode
->
name
()
->
str
()
<<
", type: "
<<
schema
::
EnumNamePrimitiveType
(
cNode
->
primitive
()
->
value_type
())
<<
"flag set to false."
;
primitive
->
SetInferFlag
(
false
);
context_
->
InferShapeInterrupt
();
}
else
if
(
ret
!=
RET_OK
)
{
MS_LOG
(
ERROR
)
<<
"InferShape failed, name: "
<<
cNode
->
name
()
->
str
()
<<
", type: "
<<
schema
::
EnumNamePrimitiveType
(
cNode
->
primitive
()
->
value_type
());
return
RET_INFER_ERR
;
}
}
else
{
primitive
->
SetInferFlag
(
!
infer_shape_interrupt
);
auto
ret
=
primitive
->
InferShape
(
inputs
,
outputs
);
if
(
ret
==
RET_INFER_INVALID
)
{
MS_LOG
(
INFO
)
<<
"InferShape shouldn't be done before runtime, name: "
<<
cNode
->
name
()
->
str
()
<<
", type: "
<<
schema
::
EnumNamePrimitiveType
(
cNode
->
primitive
()
->
value_type
())
<<
"flag set to false."
;
primitive
->
SetInferFlag
(
false
);
auto
ret
=
primitive
->
InferShape
(
inputs
,
outputs
);
if
(
ret
!=
RET_OK
)
{
MS_LOG
(
ERROR
)
<<
"InferShape fail! name: "
<<
cNode
->
name
()
->
str
()
<<
", type: "
<<
schema
::
EnumNamePrimitiveType
(
cNode
->
primitive
()
->
value_type
());
return
RET_INFER_ERR
;
}
infer_shape_interrupt
=
true
;
}
else
if
(
ret
!=
RET_OK
)
{
MS_LOG
(
ERROR
)
<<
"InferShape failed, name: "
<<
cNode
->
name
()
->
str
()
<<
", type: "
<<
schema
::
EnumNamePrimitiveType
(
cNode
->
primitive
()
->
value_type
());
return
RET_INFER_ERR
;
}
}
return
RET_OK
;
...
...
mindspore/lite/test/ut/src/runtime/kernel/arm/fp32/argminmax_fp32_test.cc
浏览文件 @
2aaba750
...
...
@@ -76,6 +76,40 @@ TEST_F(TestArgMinMaxTestFp32, ArgMaxTest1_keep_dim) {
CompareOutputData
(
out
,
except_out
.
data
(),
except_out
.
size
(),
0.000001
);
}
TEST_F
(
TestArgMinMaxTestFp32
,
ArgMaxTest_axis2_keep_dim
)
{
std
::
vector
<
float
>
in
=
{
10
,
20
,
30
,
11
,
15
,
10
,
5
,
10
,
12
,
10
,
20
,
30
,
11
,
15
,
10
,
5
,
10
,
12
,
10
,
20
,
30
,
11
,
15
,
10
,
5
,
10
,
12
};
std
::
vector
<
float
>
except_out
=
{
1
,
0
,
0
,
1
,
0
,
0
,
1
,
0
,
0
};
std
::
vector
<
int
>
shape
=
{
1
,
3
,
3
,
3
};
float
out
[
9
];
ArgMinMaxParameter
param
;
param
.
topk_
=
1
;
param
.
out_value_
=
false
;
param
.
axis_
=
2
;
param
.
data_type_
=
43
;
param
.
dims_size_
=
4
;
param
.
get_max_
=
true
;
param
.
keep_dims_
=
true
;
param
.
arg_elements_
=
reinterpret_cast
<
ArgElement
*>
(
malloc
(
shape
[
param
.
axis_
]
*
sizeof
(
ArgElement
)));
std
::
vector
<
int
>
out_shape
=
{
1
,
3
,
1
,
3
};
ComputeStrides
(
shape
.
data
(),
param
.
in_strides_
,
shape
.
size
());
ComputeStrides
(
out_shape
.
data
(),
param
.
out_strides_
,
out_shape
.
size
());
ArgMinMax
(
in
.
data
(),
out
,
shape
.
data
(),
&
param
);
for
(
size_t
i
=
0
;
i
<
except_out
.
size
();
++
i
)
{
std
::
cout
<<
out
[
i
]
<<
" "
;
}
std
::
cout
<<
"
\n
"
;
CompareOutputData
(
out
,
except_out
.
data
(),
except_out
.
size
(),
0.000001
);
}
TEST_F
(
TestArgMinMaxTestFp32
,
ArgMaxTest2
)
{
std
::
vector
<
float
>
in
=
{
10
,
20
,
30
,
40
,
90
,
20
,
11
,
15
,
1
,
50
,
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录