Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
magicwindyyd
mindspore
提交
ecb87385
M
mindspore
项目概览
magicwindyyd
/
mindspore
与 Fork 源项目一致
Fork自
MindSpore / mindspore
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
M
mindspore
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
ecb87385
编写于
8月 03, 2020
作者:
M
mindspore-ci-bot
提交者:
Gitee
8月 03, 2020
浏览文件
操作
浏览文件
下载
差异文件
!3847 modify populate op parameter
Merge pull request !3847 from yangruoqi713/lite
上级
49ba473b
c65d63a4
变更
4
显示空白变更内容
内联
并排
Showing
4 changed file
with
654 addition
and
631 deletion
+654
-631
mindspore/lite/src/populate_parameter.cc
mindspore/lite/src/populate_parameter.cc
+638
-628
mindspore/lite/src/populate_parameter.h
mindspore/lite/src/populate_parameter.h
+14
-1
mindspore/lite/src/runtime/kernel/arm/opclib/arithmetic_common.h
...re/lite/src/runtime/kernel/arm/opclib/arithmetic_common.h
+1
-1
mindspore/lite/src/runtime/kernel/arm/opclib/fp32/softmax.h
mindspore/lite/src/runtime/kernel/arm/opclib/fp32/softmax.h
+1
-1
未找到文件。
mindspore/lite/src/populate_parameter.cc
浏览文件 @
ecb87385
...
@@ -69,320 +69,339 @@
...
@@ -69,320 +69,339 @@
#include "src/runtime/kernel/arm/opclib/fp32/quantize.h"
#include "src/runtime/kernel/arm/opclib/fp32/quantize.h"
namespace
mindspore
::
kernel
{
namespace
mindspore
::
kernel
{
FillParameter
*
PopulateFillParam
(
const
lite
::
Primitive
*
primitive
)
{
OpParameter
*
PopulateFillParameter
(
const
lite
::
Primitive
*
primitive
)
{
auto
param
=
primitive
->
Value
()
->
value_as_Fill
();
auto
param
=
primitive
->
Value
()
->
value_as_Fill
();
FillParameter
*
parameter
=
new
(
std
::
nothrow
)
FillParameter
();
FillParameter
*
fill_param
=
new
(
std
::
nothrow
)
FillParameter
();
if
(
parameter
==
nullptr
)
{
if
(
fill_param
==
nullptr
)
{
MS_LOG
(
ERROR
)
<<
"new FillParameter failed."
;
MS_LOG
(
ERROR
)
<<
"new FillParameter failed."
;
return
nullptr
;
return
nullptr
;
}
}
fill_param
->
op_parameter_
.
type_
=
primitive
->
Type
();
auto
flatDims
=
param
->
dims
();
auto
flatDims
=
param
->
dims
();
parameter
->
num_dims_
=
flatDims
->
size
();
fill_param
->
num_dims_
=
flatDims
->
size
();
int
i
=
0
;
int
i
=
0
;
for
(
auto
iter
=
flatDims
->
begin
();
iter
!=
flatDims
->
end
();
iter
++
)
{
for
(
auto
iter
=
flatDims
->
begin
();
iter
!=
flatDims
->
end
();
iter
++
)
{
parameter
->
dims_
[
i
++
]
=
*
iter
;
fill_param
->
dims_
[
i
++
]
=
*
iter
;
}
}
return
parameter
;
return
reinterpret_cast
<
OpParameter
*>
(
fill_param
)
;
}
}
ExpandDimsParameter
*
PopulateExpandDimsParam
(
const
lite
::
Primitive
*
primitive
)
{
OpParameter
*
PopulateExpandDimsParameter
(
const
lite
::
Primitive
*
primitive
)
{
auto
param
=
primitive
->
Value
()
->
value_as_ExpandDims
();
auto
param
=
primitive
->
Value
()
->
value_as_ExpandDims
();
ExpandDimsParameter
*
parameter
=
new
(
std
::
nothrow
)
ExpandDimsParameter
();
ExpandDimsParameter
*
expand_dims_param
=
new
(
std
::
nothrow
)
ExpandDimsParameter
();
if
(
parameter
==
nullptr
)
{
if
(
expand_dims_param
==
nullptr
)
{
MS_LOG
(
ERROR
)
<<
"new ExpandDimsParameter failed."
;
MS_LOG
(
ERROR
)
<<
"new ExpandDimsParameter failed."
;
return
nullptr
;
return
nullptr
;
}
}
parameter
->
dim_
=
param
->
dim
();
expand_dims_param
->
op_parameter_
.
type_
=
primitive
->
Type
();
return
parameter
;
expand_dims_param
->
dim_
=
param
->
dim
();
return
reinterpret_cast
<
OpParameter
*>
(
expand_dims_param
);
}
}
PoolingParameter
*
PopulatePoolingParam
(
const
lite
::
Primitive
*
primitive
)
{
OpParameter
*
PopulatePoolingParameter
(
const
lite
::
Primitive
*
primitive
)
{
auto
pooling_primitive
=
primitive
->
Value
()
->
value_as_Pooling
();
auto
pooling_primitive
=
primitive
->
Value
()
->
value_as_Pooling
();
// todo use malloc instead
// todo use malloc instead
PoolingParameter
*
p
arameter
=
new
(
std
::
nothrow
)
PoolingParameter
();
PoolingParameter
*
p
ooling_param
=
new
(
std
::
nothrow
)
PoolingParameter
();
if
(
p
arameter
==
nullptr
)
{
if
(
p
ooling_param
==
nullptr
)
{
MS_LOG
(
ERROR
)
<<
"new PoolingParameter failed."
;
MS_LOG
(
ERROR
)
<<
"new PoolingParameter failed."
;
return
nullptr
;
return
nullptr
;
}
}
parameter
->
global_
=
pooling_primitive
->
global
();
pooling_param
->
op_parameter_
.
type_
=
primitive
->
Type
();
parameter
->
window_w_
=
pooling_primitive
->
windowW
();
pooling_param
->
global_
=
pooling_primitive
->
global
();
parameter
->
window_h_
=
pooling_primitive
->
windowH
();
pooling_param
->
window_w_
=
pooling_primitive
->
windowW
();
pooling_param
->
window_h_
=
pooling_primitive
->
windowH
();
// todo format
// todo format
auto
pooling_lite_primitive
=
(
lite
::
Pooling
*
)
primitive
;
auto
pooling_lite_primitive
=
(
lite
::
Pooling
*
)
primitive
;
MS_ASSERT
(
nullptr
!=
pooling_lite_primitive
);
MS_ASSERT
(
nullptr
!=
pooling_lite_primitive
);
p
arameter
->
pad_u_
=
pooling_lite_primitive
->
PadUp
();
p
ooling_param
->
pad_u_
=
pooling_lite_primitive
->
PadUp
();
p
arameter
->
pad_d_
=
pooling_lite_primitive
->
PadDown
();
p
ooling_param
->
pad_d_
=
pooling_lite_primitive
->
PadDown
();
p
arameter
->
pad_l_
=
pooling_lite_primitive
->
PadLeft
();
p
ooling_param
->
pad_l_
=
pooling_lite_primitive
->
PadLeft
();
p
arameter
->
pad_r_
=
pooling_lite_primitive
->
PadRight
();
p
ooling_param
->
pad_r_
=
pooling_lite_primitive
->
PadRight
();
p
arameter
->
stride_w_
=
pooling_primitive
->
strideW
();
p
ooling_param
->
stride_w_
=
pooling_primitive
->
strideW
();
p
arameter
->
stride_h_
=
pooling_primitive
->
strideH
();
p
ooling_param
->
stride_h_
=
pooling_primitive
->
strideH
();
auto
pool_mode
=
pooling_primitive
->
poolingMode
();
auto
pool_mode
=
pooling_primitive
->
poolingMode
();
switch
(
pool_mode
)
{
switch
(
pool_mode
)
{
case
schema
::
PoolMode_MAX_POOLING
:
case
schema
::
PoolMode_MAX_POOLING
:
p
arameter
->
max_pooling_
=
true
;
p
ooling_param
->
max_pooling_
=
true
;
p
arameter
->
avg_pooling_
=
false
;
p
ooling_param
->
avg_pooling_
=
false
;
break
;
break
;
case
schema
::
PoolMode_MEAN_POOLING
:
case
schema
::
PoolMode_MEAN_POOLING
:
p
arameter
->
max_pooling_
=
false
;
p
ooling_param
->
max_pooling_
=
false
;
p
arameter
->
avg_pooling_
=
true
;
p
ooling_param
->
avg_pooling_
=
true
;
break
;
break
;
default:
default:
p
arameter
->
max_pooling_
=
false
;
p
ooling_param
->
max_pooling_
=
false
;
p
arameter
->
avg_pooling_
=
false
;
p
ooling_param
->
avg_pooling_
=
false
;
break
;
break
;
}
}
auto
round_mode
=
pooling_primitive
->
roundMode
();
auto
round_mode
=
pooling_primitive
->
roundMode
();
switch
(
round_mode
)
{
switch
(
round_mode
)
{
case
schema
::
RoundMode_FLOOR
:
case
schema
::
RoundMode_FLOOR
:
p
arameter
->
round_floor_
=
true
;
p
ooling_param
->
round_floor_
=
true
;
p
arameter
->
round_ceil_
=
false
;
p
ooling_param
->
round_ceil_
=
false
;
break
;
break
;
case
schema
::
RoundMode_CEIL
:
case
schema
::
RoundMode_CEIL
:
p
arameter
->
round_floor_
=
false
;
p
ooling_param
->
round_floor_
=
false
;
p
arameter
->
round_ceil_
=
true
;
p
ooling_param
->
round_ceil_
=
true
;
break
;
break
;
default:
default:
p
arameter
->
round_floor_
=
false
;
p
ooling_param
->
round_floor_
=
false
;
p
arameter
->
round_ceil_
=
false
;
p
ooling_param
->
round_ceil_
=
false
;
break
;
break
;
}
}
return
parameter
;
return
reinterpret_cast
<
OpParameter
*>
(
pooling_param
)
;
}
}
MatMul
Parameter
*
PopulateFullconnectionParameter
(
const
lite
::
Primitive
*
primitive
)
{
Op
Parameter
*
PopulateFullconnectionParameter
(
const
lite
::
Primitive
*
primitive
)
{
auto
param
=
primitive
->
Value
()
->
value_as_FullConnection
();
auto
param
=
primitive
->
Value
()
->
value_as_FullConnection
();
MatMulParameter
*
parameter
=
new
(
std
::
nothrow
)
MatMulParameter
();
MatMulParameter
*
matmul_param
=
new
(
std
::
nothrow
)
MatMulParameter
();
if
(
parameter
==
nullptr
)
{
if
(
matmul_param
==
nullptr
)
{
MS_LOG
(
ERROR
)
<<
"new FullconnectionParameter failed."
;
MS_LOG
(
ERROR
)
<<
"new FullconnectionParameter failed."
;
return
nullptr
;
return
nullptr
;
}
}
parameter
->
b_transpose_
=
true
;
matmul_param
->
op_parameter_
.
type_
=
primitive
->
Type
();
parameter
->
a_transpose_
=
false
;
matmul_param
->
b_transpose_
=
true
;
parameter
->
has_bias_
=
param
->
hasBias
();
matmul_param
->
a_transpose_
=
false
;
parameter
->
minf_
=
-
FLT_MAX
;
matmul_param
->
has_bias_
=
param
->
hasBias
();
parameter
->
maxf_
=
FLT_MAX
;
matmul_param
->
minf_
=
-
FLT_MAX
;
return
parameter
;
matmul_param
->
maxf_
=
FLT_MAX
;
return
reinterpret_cast
<
OpParameter
*>
(
matmul_param
);
}
}
MatMul
Parameter
*
PopulateMatMulParameter
(
const
lite
::
Primitive
*
primitive
)
{
Op
Parameter
*
PopulateMatMulParameter
(
const
lite
::
Primitive
*
primitive
)
{
auto
param
=
primitive
->
Value
()
->
value_as_MatMul
();
auto
param
=
primitive
->
Value
()
->
value_as_MatMul
();
MatMulParameter
*
parameter
=
new
(
std
::
nothrow
)
MatMulParameter
();
MatMulParameter
*
matmul_param
=
new
(
std
::
nothrow
)
MatMulParameter
();
if
(
parameter
==
nullptr
)
{
if
(
matmul_param
==
nullptr
)
{
MS_LOG
(
ERROR
)
<<
"new FullconnectionParameter failed."
;
MS_LOG
(
ERROR
)
<<
"new FullconnectionParameter failed."
;
return
nullptr
;
return
nullptr
;
}
}
parameter
->
b_transpose_
=
param
->
transposeB
();
matmul_param
->
op_parameter_
.
type_
=
primitive
->
Type
();
parameter
->
a_transpose_
=
param
->
transposeA
();
matmul_param
->
b_transpose_
=
param
->
transposeB
();
parameter
->
has_bias_
=
false
;
matmul_param
->
a_transpose_
=
param
->
transposeA
();
parameter
->
minf_
=
-
FLT_MAX
;
matmul_param
->
has_bias_
=
false
;
parameter
->
maxf_
=
FLT_MAX
;
matmul_param
->
minf_
=
-
FLT_MAX
;
return
parameter
;
matmul_param
->
maxf_
=
FLT_MAX
;
return
reinterpret_cast
<
OpParameter
*>
(
matmul_param
);
}
}
Conv
Parameter
*
PopulateConvParameter
(
const
lite
::
Primitive
*
primitive
)
{
Op
Parameter
*
PopulateConvParameter
(
const
lite
::
Primitive
*
primitive
)
{
ConvParameter
*
parameter
=
new
(
std
::
nothrow
)
ConvParameter
();
ConvParameter
*
conv_param
=
new
(
std
::
nothrow
)
ConvParameter
();
if
(
parameter
==
nullptr
)
{
if
(
conv_param
==
nullptr
)
{
MS_LOG
(
ERROR
)
<<
"new ConvParameter failed."
;
MS_LOG
(
ERROR
)
<<
"new ConvParameter failed."
;
return
nullptr
;
return
nullptr
;
}
}
conv_param
->
op_parameter_
.
type_
=
primitive
->
Type
();
auto
conv_primitive
=
primitive
->
Value
()
->
value_as_Conv2D
();
auto
conv_primitive
=
primitive
->
Value
()
->
value_as_Conv2D
();
parameter
->
kernel_h_
=
conv_primitive
->
kernelH
();
conv_param
->
kernel_h_
=
conv_primitive
->
kernelH
();
parameter
->
kernel_w_
=
conv_primitive
->
kernelW
();
conv_param
->
kernel_w_
=
conv_primitive
->
kernelW
();
// todo format
// todo format
parameter
->
group_
=
conv_primitive
->
group
();
conv_param
->
group_
=
conv_primitive
->
group
();
parameter
->
stride_h_
=
conv_primitive
->
strideH
();
conv_param
->
stride_h_
=
conv_primitive
->
strideH
();
parameter
->
stride_w_
=
conv_primitive
->
strideW
();
conv_param
->
stride_w_
=
conv_primitive
->
strideW
();
auto
conv2d_lite_primitive
=
(
lite
::
Conv2D
*
)
primitive
;
auto
conv2d_lite_primitive
=
(
lite
::
Conv2D
*
)
primitive
;
MS_ASSERT
(
nullptr
!=
conv2d_lite_primitive
);
MS_ASSERT
(
nullptr
!=
conv2d_lite_primitive
);
parameter
->
pad_u_
=
conv2d_lite_primitive
->
PadUp
();
conv_param
->
pad_u_
=
conv2d_lite_primitive
->
PadUp
();
parameter
->
pad_d_
=
conv2d_lite_primitive
->
PadDown
();
conv_param
->
pad_d_
=
conv2d_lite_primitive
->
PadDown
();
parameter
->
pad_l_
=
conv2d_lite_primitive
->
PadLeft
();
conv_param
->
pad_l_
=
conv2d_lite_primitive
->
PadLeft
();
parameter
->
pad_r_
=
conv2d_lite_primitive
->
PadRight
();
conv_param
->
pad_r_
=
conv2d_lite_primitive
->
PadRight
();
parameter
->
pad_h_
=
conv2d_lite_primitive
->
PadUp
();
conv_param
->
pad_h_
=
conv2d_lite_primitive
->
PadUp
();
parameter
->
pad_w_
=
conv2d_lite_primitive
->
PadLeft
();
conv_param
->
pad_w_
=
conv2d_lite_primitive
->
PadLeft
();
parameter
->
dilation_h_
=
conv_primitive
->
dilateH
();
conv_param
->
dilation_h_
=
conv_primitive
->
dilateH
();
parameter
->
dilation_w_
=
conv_primitive
->
dilateW
();
conv_param
->
dilation_w_
=
conv_primitive
->
dilateW
();
parameter
->
input_channel_
=
conv_primitive
->
channelIn
();
conv_param
->
input_channel_
=
conv_primitive
->
channelIn
();
parameter
->
output_channel_
=
conv_primitive
->
channelOut
();
conv_param
->
output_channel_
=
conv_primitive
->
channelOut
();
parameter
->
group_
=
conv_primitive
->
group
();
conv_param
->
group_
=
conv_primitive
->
group
();
auto
act_type
=
conv_primitive
->
activationType
();
auto
act_type
=
conv_primitive
->
activationType
();
switch
(
act_type
)
{
switch
(
act_type
)
{
case
schema
::
ActivationType_RELU
:
case
schema
::
ActivationType_RELU
:
parameter
->
is_relu_
=
true
;
conv_param
->
is_relu_
=
true
;
parameter
->
is_relu6_
=
false
;
conv_param
->
is_relu6_
=
false
;
break
;
break
;
case
schema
::
ActivationType_RELU6
:
case
schema
::
ActivationType_RELU6
:
parameter
->
is_relu_
=
false
;
conv_param
->
is_relu_
=
false
;
parameter
->
is_relu6_
=
true
;
conv_param
->
is_relu6_
=
true
;
break
;
break
;
default:
default:
parameter
->
is_relu_
=
false
;
conv_param
->
is_relu_
=
false
;
parameter
->
is_relu6_
=
false
;
conv_param
->
is_relu6_
=
false
;
break
;
break
;
}
}
return
parameter
;
return
reinterpret_cast
<
OpParameter
*>
(
conv_param
)
;
}
}
Conv
Parameter
*
PopulateConvDwParameter
(
const
lite
::
Primitive
*
primitive
)
{
Op
Parameter
*
PopulateConvDwParameter
(
const
lite
::
Primitive
*
primitive
)
{
ConvParameter
*
parameter
=
new
(
std
::
nothrow
)
ConvParameter
();
ConvParameter
*
conv_param
=
new
(
std
::
nothrow
)
ConvParameter
();
if
(
parameter
==
nullptr
)
{
if
(
conv_param
==
nullptr
)
{
MS_LOG
(
ERROR
)
<<
"new ConvParameter failed."
;
MS_LOG
(
ERROR
)
<<
"new ConvParameter failed."
;
return
nullptr
;
return
nullptr
;
}
}
conv_param
->
op_parameter_
.
type_
=
primitive
->
Type
();
auto
conv_primitive
=
primitive
->
Value
()
->
value_as_DepthwiseConv2D
();
auto
conv_primitive
=
primitive
->
Value
()
->
value_as_DepthwiseConv2D
();
parameter
->
kernel_h_
=
conv_primitive
->
kernelH
();
conv_param
->
kernel_h_
=
conv_primitive
->
kernelH
();
parameter
->
kernel_w_
=
conv_primitive
->
kernelW
();
conv_param
->
kernel_w_
=
conv_primitive
->
kernelW
();
// todo format, group
// todo format, group
parameter
->
stride_h_
=
conv_primitive
->
strideH
();
conv_param
->
stride_h_
=
conv_primitive
->
strideH
();
parameter
->
stride_w_
=
conv_primitive
->
strideW
();
conv_param
->
stride_w_
=
conv_primitive
->
strideW
();
auto
pad_mode
=
conv_primitive
->
padMode
();
auto
pad_mode
=
conv_primitive
->
padMode
();
auto
convdw_lite_primitive
=
(
lite
::
DepthwiseConv2D
*
)
primitive
;
auto
convdw_lite_primitive
=
(
lite
::
DepthwiseConv2D
*
)
primitive
;
MS_ASSERT
(
nullptr
!=
convdw_lite_primitive
);
MS_ASSERT
(
nullptr
!=
convdw_lite_primitive
);
parameter
->
pad_u_
=
convdw_lite_primitive
->
PadUp
();
conv_param
->
pad_u_
=
convdw_lite_primitive
->
PadUp
();
parameter
->
pad_d_
=
convdw_lite_primitive
->
PadDown
();
conv_param
->
pad_d_
=
convdw_lite_primitive
->
PadDown
();
parameter
->
pad_l_
=
convdw_lite_primitive
->
PadLeft
();
conv_param
->
pad_l_
=
convdw_lite_primitive
->
PadLeft
();
parameter
->
pad_r_
=
convdw_lite_primitive
->
PadRight
();
conv_param
->
pad_r_
=
convdw_lite_primitive
->
PadRight
();
parameter
->
pad_h_
=
convdw_lite_primitive
->
PadUp
();
conv_param
->
pad_h_
=
convdw_lite_primitive
->
PadUp
();
parameter
->
pad_w_
=
convdw_lite_primitive
->
PadLeft
();
conv_param
->
pad_w_
=
convdw_lite_primitive
->
PadLeft
();
parameter
->
dilation_h_
=
conv_primitive
->
dilateH
();
conv_param
->
dilation_h_
=
conv_primitive
->
dilateH
();
parameter
->
dilation_w_
=
conv_primitive
->
dilateW
();
conv_param
->
dilation_w_
=
conv_primitive
->
dilateW
();
auto
act_type
=
conv_primitive
->
activationType
();
auto
act_type
=
conv_primitive
->
activationType
();
switch
(
act_type
)
{
switch
(
act_type
)
{
case
schema
::
ActivationType_RELU
:
case
schema
::
ActivationType_RELU
:
parameter
->
is_relu_
=
true
;
conv_param
->
is_relu_
=
true
;
parameter
->
is_relu6_
=
false
;
conv_param
->
is_relu6_
=
false
;
break
;
break
;
case
schema
::
ActivationType_RELU6
:
case
schema
::
ActivationType_RELU6
:
parameter
->
is_relu_
=
false
;
conv_param
->
is_relu_
=
false
;
parameter
->
is_relu6_
=
true
;
conv_param
->
is_relu6_
=
true
;
break
;
break
;
default:
default:
parameter
->
is_relu_
=
false
;
conv_param
->
is_relu_
=
false
;
parameter
->
is_relu6_
=
false
;
conv_param
->
is_relu6_
=
false
;
break
;
break
;
}
}
return
parameter
;
return
reinterpret_cast
<
OpParameter
*>
(
conv_param
)
;
}
}
ConvParameter
*
PopulateDeconvDwParameter
(
const
lite
::
Primitive
*
primitive
)
{
OpParameter
*
PopulateDeconvDwParameter
(
const
lite
::
Primitive
*
primitive
)
{
ConvParameter
*
parameter
=
new
ConvParameter
();
ConvParameter
*
conv_param
=
new
ConvParameter
();
if
(
conv_param
==
nullptr
)
{
MS_LOG
(
ERROR
)
<<
"new ConvParameter failed."
;
return
nullptr
;
}
conv_param
->
op_parameter_
.
type_
=
primitive
->
Type
();
auto
conv_primitive
=
primitive
->
Value
()
->
value_as_DeDepthwiseConv2D
();
auto
conv_primitive
=
primitive
->
Value
()
->
value_as_DeDepthwiseConv2D
();
parameter
->
kernel_h_
=
conv_primitive
->
kernelH
();
conv_param
->
kernel_h_
=
conv_primitive
->
kernelH
();
parameter
->
kernel_w_
=
conv_primitive
->
kernelW
();
conv_param
->
kernel_w_
=
conv_primitive
->
kernelW
();
// todo format, group
// todo format, group
parameter
->
stride_h_
=
conv_primitive
->
strideH
();
conv_param
->
stride_h_
=
conv_primitive
->
strideH
();
parameter
->
stride_w_
=
conv_primitive
->
strideW
();
conv_param
->
stride_w_
=
conv_primitive
->
strideW
();
auto
deconvdw_lite_primitive
=
(
lite
::
DeconvDepthwiseConv2D
*
)
primitive
;
auto
deconvdw_lite_primitive
=
(
lite
::
DeconvDepthwiseConv2D
*
)
primitive
;
MS_ASSERT
(
nullptr
!=
deconvdw_lite_primitive
);
MS_ASSERT
(
nullptr
!=
deconvdw_lite_primitive
);
parameter
->
pad_u_
=
deconvdw_lite_primitive
->
PadUp
();
conv_param
->
pad_u_
=
deconvdw_lite_primitive
->
PadUp
();
parameter
->
pad_d_
=
deconvdw_lite_primitive
->
PadDown
();
conv_param
->
pad_d_
=
deconvdw_lite_primitive
->
PadDown
();
parameter
->
pad_l_
=
deconvdw_lite_primitive
->
PadLeft
();
conv_param
->
pad_l_
=
deconvdw_lite_primitive
->
PadLeft
();
parameter
->
pad_r_
=
deconvdw_lite_primitive
->
PadRight
();
conv_param
->
pad_r_
=
deconvdw_lite_primitive
->
PadRight
();
parameter
->
pad_h_
=
deconvdw_lite_primitive
->
PadUp
();
conv_param
->
pad_h_
=
deconvdw_lite_primitive
->
PadUp
();
parameter
->
pad_w_
=
deconvdw_lite_primitive
->
PadLeft
();
conv_param
->
pad_w_
=
deconvdw_lite_primitive
->
PadLeft
();
parameter
->
dilation_h_
=
conv_primitive
->
dilateH
();
conv_param
->
dilation_h_
=
conv_primitive
->
dilateH
();
parameter
->
dilation_w_
=
conv_primitive
->
dilateW
();
conv_param
->
dilation_w_
=
conv_primitive
->
dilateW
();
auto
act_type
=
conv_primitive
->
activationType
();
auto
act_type
=
conv_primitive
->
activationType
();
switch
(
act_type
)
{
switch
(
act_type
)
{
case
schema
::
ActivationType_RELU
:
case
schema
::
ActivationType_RELU
:
parameter
->
is_relu_
=
true
;
conv_param
->
is_relu_
=
true
;
parameter
->
is_relu6_
=
false
;
conv_param
->
is_relu6_
=
false
;
break
;
break
;
case
schema
::
ActivationType_RELU6
:
case
schema
::
ActivationType_RELU6
:
parameter
->
is_relu_
=
false
;
conv_param
->
is_relu_
=
false
;
parameter
->
is_relu6_
=
true
;
conv_param
->
is_relu6_
=
true
;
break
;
break
;
default:
default:
parameter
->
is_relu_
=
false
;
conv_param
->
is_relu_
=
false
;
parameter
->
is_relu6_
=
false
;
conv_param
->
is_relu6_
=
false
;
break
;
break
;
}
}
return
parameter
;
return
reinterpret_cast
<
OpParameter
*>
(
conv_param
)
;
}
}
ConvParameter
*
PopulateDeconvParameter
(
const
lite
::
Primitive
*
primitive
)
{
OpParameter
*
PopulateDeconvParameter
(
const
lite
::
Primitive
*
primitive
)
{
ConvParameter
*
parameter
=
new
ConvParameter
();
ConvParameter
*
conv_param
=
new
ConvParameter
();
if
(
conv_param
==
nullptr
)
{
MS_LOG
(
ERROR
)
<<
"new ConvParameter failed."
;
return
nullptr
;
}
conv_param
->
op_parameter_
.
type_
=
primitive
->
Type
();
auto
conv_primitive
=
primitive
->
Value
()
->
value_as_DeConv2D
();
auto
conv_primitive
=
primitive
->
Value
()
->
value_as_DeConv2D
();
parameter
->
kernel_h_
=
conv_primitive
->
kernelH
();
conv_param
->
kernel_h_
=
conv_primitive
->
kernelH
();
parameter
->
kernel_w_
=
conv_primitive
->
kernelW
();
conv_param
->
kernel_w_
=
conv_primitive
->
kernelW
();
parameter
->
stride_h_
=
conv_primitive
->
strideH
();
conv_param
->
stride_h_
=
conv_primitive
->
strideH
();
parameter
->
stride_w_
=
conv_primitive
->
strideW
();
conv_param
->
stride_w_
=
conv_primitive
->
strideW
();
auto
deconv_lite_primitive
=
(
lite
::
DeConv2D
*
)
primitive
;
auto
deconv_lite_primitive
=
(
lite
::
DeConv2D
*
)
primitive
;
MS_ASSERT
(
nullptr
!=
deconvdw_lite_primitive
);
MS_ASSERT
(
nullptr
!=
deconvdw_lite_primitive
);
parameter
->
pad_u_
=
deconv_lite_primitive
->
PadUp
();
conv_param
->
pad_u_
=
deconv_lite_primitive
->
PadUp
();
parameter
->
pad_d_
=
deconv_lite_primitive
->
PadDown
();
conv_param
->
pad_d_
=
deconv_lite_primitive
->
PadDown
();
parameter
->
pad_l_
=
deconv_lite_primitive
->
PadLeft
();
conv_param
->
pad_l_
=
deconv_lite_primitive
->
PadLeft
();
parameter
->
pad_r_
=
deconv_lite_primitive
->
PadRight
();
conv_param
->
pad_r_
=
deconv_lite_primitive
->
PadRight
();
parameter
->
pad_h_
=
deconv_lite_primitive
->
PadUp
();
conv_param
->
pad_h_
=
deconv_lite_primitive
->
PadUp
();
parameter
->
pad_w_
=
deconv_lite_primitive
->
PadLeft
();
conv_param
->
pad_w_
=
deconv_lite_primitive
->
PadLeft
();
parameter
->
dilation_h_
=
conv_primitive
->
dilateH
();
conv_param
->
dilation_h_
=
conv_primitive
->
dilateH
();
parameter
->
dilation_w_
=
conv_primitive
->
dilateW
();
conv_param
->
dilation_w_
=
conv_primitive
->
dilateW
();
auto
act_type
=
conv_primitive
->
activationType
();
auto
act_type
=
conv_primitive
->
activationType
();
switch
(
act_type
)
{
switch
(
act_type
)
{
case
schema
::
ActivationType_RELU
:
case
schema
::
ActivationType_RELU
:
parameter
->
is_relu_
=
true
;
conv_param
->
is_relu_
=
true
;
parameter
->
is_relu6_
=
false
;
conv_param
->
is_relu6_
=
false
;
break
;
break
;
case
schema
::
ActivationType_RELU6
:
case
schema
::
ActivationType_RELU6
:
parameter
->
is_relu_
=
false
;
conv_param
->
is_relu_
=
false
;
parameter
->
is_relu6_
=
true
;
conv_param
->
is_relu6_
=
true
;
break
;
break
;
default:
default:
parameter
->
is_relu_
=
false
;
conv_param
->
is_relu_
=
false
;
parameter
->
is_relu6_
=
false
;
conv_param
->
is_relu6_
=
false
;
break
;
break
;
}
}
return
parameter
;
return
reinterpret_cast
<
OpParameter
*>
(
conv_param
)
;
}
}
Softmax
Parameter
*
PopulateSoftmaxParameter
(
const
lite
::
Primitive
*
primitive
)
{
Op
Parameter
*
PopulateSoftmaxParameter
(
const
lite
::
Primitive
*
primitive
)
{
auto
softmax_primitive
=
primitive
->
Value
()
->
value_as_SoftMax
();
auto
softmax_primitive
=
primitive
->
Value
()
->
value_as_SoftMax
();
SoftmaxParameter
*
parameter
=
new
(
std
::
nothrow
)
SoftmaxParameter
();
SoftmaxParameter
*
softmax_param
=
new
(
std
::
nothrow
)
SoftmaxParameter
();
if
(
parameter
==
nullptr
)
{
if
(
softmax_param
==
nullptr
)
{
MS_LOG
(
ERROR
)
<<
"new SoftmaxParameter failed."
;
MS_LOG
(
ERROR
)
<<
"new SoftmaxParameter failed."
;
return
nullptr
;
return
nullptr
;
}
}
parameter
->
axis_
=
softmax_primitive
->
axis
();
softmax_param
->
op_parameter_
.
type_
=
primitive
->
Type
();
return
parameter
;
softmax_param
->
axis_
=
softmax_primitive
->
axis
();
return
reinterpret_cast
<
OpParameter
*>
(
softmax_param
);
}
}
Reduce
Parameter
*
PopulateReduceParameter
(
const
lite
::
Primitive
*
primitive
)
{
Op
Parameter
*
PopulateReduceParameter
(
const
lite
::
Primitive
*
primitive
)
{
ReduceParameter
*
parameter
=
new
(
std
::
nothrow
)
ReduceParameter
();
ReduceParameter
*
reduce_param
=
new
(
std
::
nothrow
)
ReduceParameter
();
if
(
parameter
==
nullptr
)
{
if
(
reduce_param
==
nullptr
)
{
MS_LOG
(
ERROR
)
<<
"new ReduceParameter failed."
;
MS_LOG
(
ERROR
)
<<
"new ReduceParameter failed."
;
return
nullptr
;
return
nullptr
;
}
}
reduce_param
->
op_parameter_
.
type_
=
primitive
->
Type
();
auto
reduce
=
primitive
->
Value
()
->
value_as_Reduce
();
auto
reduce
=
primitive
->
Value
()
->
value_as_Reduce
();
parameter
->
keep_dims_
=
reduce
->
keepDims
();
reduce_param
->
keep_dims_
=
reduce
->
keepDims
();
auto
axisVector
=
reduce
->
axes
();
auto
axisVector
=
reduce
->
axes
();
if
(
axisVector
->
size
()
>
REDUCE_MAX_AXES_NUM
)
{
if
(
axisVector
->
size
()
>
REDUCE_MAX_AXES_NUM
)
{
MS_LOG
(
ERROR
)
<<
"Reduce axes size "
<<
axisVector
->
size
()
<<
" exceed limit "
<<
REDUCE_MAX_AXES_NUM
;
MS_LOG
(
ERROR
)
<<
"Reduce axes size "
<<
axisVector
->
size
()
<<
" exceed limit "
<<
REDUCE_MAX_AXES_NUM
;
delete
(
parameter
);
delete
(
reduce_param
);
return
nullptr
;
return
nullptr
;
}
}
parameter
->
num_axes_
=
static_cast
<
int
>
(
axisVector
->
size
());
reduce_param
->
num_axes_
=
static_cast
<
int
>
(
axisVector
->
size
());
int
i
=
0
;
int
i
=
0
;
for
(
auto
iter
=
axisVector
->
begin
();
iter
!=
axisVector
->
end
();
iter
++
)
{
for
(
auto
iter
=
axisVector
->
begin
();
iter
!=
axisVector
->
end
();
iter
++
)
{
parameter
->
axes_
[
i
++
]
=
*
iter
;
reduce_param
->
axes_
[
i
++
]
=
*
iter
;
}
}
parameter
->
mode_
=
static_cast
<
int
>
(
reduce
->
mode
());
reduce_param
->
mode_
=
static_cast
<
int
>
(
reduce
->
mode
());
return
parameter
;
return
reinterpret_cast
<
OpParameter
*>
(
reduce_param
)
;
}
}
Pad
Parameter
*
PopulatePadParameter
(
const
lite
::
Primitive
*
primitive
)
{
Op
Parameter
*
PopulatePadParameter
(
const
lite
::
Primitive
*
primitive
)
{
PadParameter
*
pad_param
=
new
(
std
::
nothrow
)
PadParameter
();
PadParameter
*
pad_param
=
new
(
std
::
nothrow
)
PadParameter
();
if
(
pad_param
==
nullptr
)
{
if
(
pad_param
==
nullptr
)
{
MS_LOG
(
ERROR
)
<<
"new PadParameter failed."
;
MS_LOG
(
ERROR
)
<<
"new PadParameter failed."
;
return
nullptr
;
return
nullptr
;
}
}
pad_param
->
op_parameter_
.
type_
=
primitive
->
Type
();
auto
pad_node
=
primitive
->
Value
()
->
value_as_Pad
();
auto
pad_node
=
primitive
->
Value
()
->
value_as_Pad
();
pad_param
->
pad_mode_
=
pad_node
->
paddingMode
();
pad_param
->
pad_mode_
=
pad_node
->
paddingMode
();
if
(
pad_param
->
pad_mode_
==
schema
::
PaddingMode_CONSTANT
)
{
if
(
pad_param
->
pad_mode_
==
schema
::
PaddingMode_CONSTANT
)
{
pad_param
->
constant_value_
=
pad_node
->
constantValue
();
pad_param
->
constant_value_
=
pad_node
->
constantValue
();
...
@@ -402,218 +421,212 @@ PadParameter *PopulatePadParameter(const lite::Primitive *primitive) {
...
@@ -402,218 +421,212 @@ PadParameter *PopulatePadParameter(const lite::Primitive *primitive) {
for
(
size_t
i
=
0
;
i
<
size
;
i
++
)
{
for
(
size_t
i
=
0
;
i
<
size
;
i
++
)
{
pad_param
->
paddings_
[
MAX_PAD_SIZE
-
size
+
i
]
=
(
*
(
pad_node
->
paddings
()))[
i
];
pad_param
->
paddings_
[
MAX_PAD_SIZE
-
size
+
i
]
=
(
*
(
pad_node
->
paddings
()))[
i
];
}
}
return
pad_param
;
return
reinterpret_cast
<
OpParameter
*>
(
pad_param
)
;
}
}
Activation
Parameter
*
PopulateActivationParameter
(
const
lite
::
Primitive
*
primitive
)
{
Op
Parameter
*
PopulateActivationParameter
(
const
lite
::
Primitive
*
primitive
)
{
ActivationParameter
*
parameter
=
new
(
std
::
nothrow
)
ActivationParameter
();
ActivationParameter
*
act_param
=
new
(
std
::
nothrow
)
ActivationParameter
();
if
(
parameter
==
nullptr
)
{
if
(
act_param
==
nullptr
)
{
MS_LOG
(
ERROR
)
<<
"new ActivationParameter failed."
;
MS_LOG
(
ERROR
)
<<
"new ActivationParameter failed."
;
return
nullptr
;
return
nullptr
;
}
}
auto
activation
=
primitive
->
Value
()
->
value_as_Activation
();
auto
activation
=
primitive
->
Value
()
->
value_as_Activation
();
parameter
->
type_
=
static_cast
<
int
>
(
activation
->
type
());
act_param
->
type_
=
static_cast
<
int
>
(
activation
->
type
());
return
parameter
;
return
reinterpret_cast
<
OpParameter
*>
(
act_param
)
;
}
}
FusedBatchNorm
Parameter
*
PopulateFusedBatchNorm
(
const
lite
::
Primitive
*
primitive
)
{
Op
Parameter
*
PopulateFusedBatchNorm
(
const
lite
::
Primitive
*
primitive
)
{
FusedBatchNormParameter
*
parameter
=
new
(
std
::
nothrow
)
FusedBatchNormParameter
();
FusedBatchNormParameter
*
fuse_batch_norm_param
=
new
(
std
::
nothrow
)
FusedBatchNormParameter
();
if
(
parameter
==
nullptr
)
{
if
(
fuse_batch_norm_param
==
nullptr
)
{
MS_LOG
(
ERROR
)
<<
"new FusedBatchNormParameter failed."
;
MS_LOG
(
ERROR
)
<<
"new FusedBatchNormParameter failed."
;
return
nullptr
;
return
nullptr
;
}
}
fuse_batch_norm_param
->
op_parameter_
.
type_
=
primitive
->
Type
();
auto
param
=
primitive
->
Value
()
->
value_as_FusedBatchNorm
();
auto
param
=
primitive
->
Value
()
->
value_as_FusedBatchNorm
();
parameter
->
epsilon_
=
param
->
epsilon
();
fuse_batch_norm_param
->
epsilon_
=
param
->
epsilon
();
return
parameter
;
return
reinterpret_cast
<
OpParameter
*>
(
fuse_batch_norm_param
)
;
}
}
Arithmetic
Parameter
*
PopulateArithmetic
(
const
lite
::
Primitive
*
primitive
)
{
Op
Parameter
*
PopulateArithmetic
(
const
lite
::
Primitive
*
primitive
)
{
ArithmeticParameter
*
parameter
=
new
(
std
::
nothrow
)
ArithmeticParameter
();
ArithmeticParameter
*
arithmetic_param
=
new
(
std
::
nothrow
)
ArithmeticParameter
();
if
(
parameter
==
nullptr
)
{
if
(
arithmetic_param
==
nullptr
)
{
MS_LOG
(
ERROR
)
<<
"new ArithmeticParameter failed."
;
MS_LOG
(
ERROR
)
<<
"new ArithmeticParameter failed."
;
return
nullptr
;
return
nullptr
;
}
}
parameter
->
op_parameter
.
type_
=
primitive
->
Type
();
arithmetic_param
->
op_parameter_
.
type_
=
primitive
->
Type
();
parameter
->
broadcasting_
=
((
lite
::
Arithmetic
*
)
primitive
)
->
Broadcasting
();
arithmetic_param
->
broadcasting_
=
((
lite
::
Arithmetic
*
)
primitive
)
->
Broadcasting
();
parameter
->
ndim_
=
((
lite
::
Arithmetic
*
)
primitive
)
->
NDims
();
arithmetic_param
->
ndim_
=
((
lite
::
Arithmetic
*
)
primitive
)
->
NDims
();
auto
tmp_shape
=
((
lite
::
Arithmetic
*
)
primitive
)
->
InShape0
();
auto
tmp_shape
=
((
lite
::
Arithmetic
*
)
primitive
)
->
InShape0
();
(
void
)
memcpy
(
parameter
->
in_shape0_
,
static_cast
<
void
*>
(
tmp_shape
.
data
()),
tmp_shape
.
size
()
*
sizeof
(
int
));
(
void
)
memcpy
(
arithmetic_param
->
in_shape0_
,
static_cast
<
void
*>
(
tmp_shape
.
data
()),
tmp_shape
.
size
()
*
sizeof
(
int
));
tmp_shape
=
((
lite
::
Arithmetic
*
)
primitive
)
->
InShape1
();
tmp_shape
=
((
lite
::
Arithmetic
*
)
primitive
)
->
InShape1
();
(
void
)
memcpy
(
parameter
->
in_shape1_
,
static_cast
<
void
*>
(
tmp_shape
.
data
()),
tmp_shape
.
size
()
*
sizeof
(
int
));
(
void
)
memcpy
(
arithmetic_param
->
in_shape1_
,
static_cast
<
void
*>
(
tmp_shape
.
data
()),
tmp_shape
.
size
()
*
sizeof
(
int
));
tmp_shape
=
((
lite
::
Arithmetic
*
)
primitive
)
->
OutputShape
();
tmp_shape
=
((
lite
::
Arithmetic
*
)
primitive
)
->
OutputShape
();
(
void
)
memcpy
(
parameter
->
out_shape_
,
static_cast
<
void
*>
(
tmp_shape
.
data
()),
tmp_shape
.
size
()
*
sizeof
(
int
));
(
void
)
memcpy
(
arithmetic_param
->
out_shape_
,
static_cast
<
void
*>
(
tmp_shape
.
data
()),
tmp_shape
.
size
()
*
sizeof
(
int
));
return
parameter
;
return
reinterpret_cast
<
OpParameter
*>
(
arithmetic_param
)
;
}
}
ArithmeticParameter
*
PopulateEltwiseParam
(
const
lite
::
Primitive
*
primitive
)
{
OpParameter
*
PopulateEltwiseParameter
(
const
lite
::
Primitive
*
primitive
)
{
ArithmeticParameter
*
parameter
=
new
(
std
::
nothrow
)
ArithmeticParameter
();
ArithmeticParameter
*
arithmetic_param
=
new
(
std
::
nothrow
)
ArithmeticParameter
();
if
(
parameter
==
nullptr
)
{
if
(
arithmetic_param
==
nullptr
)
{
MS_LOG
(
ERROR
)
<<
"new ArithmeticParameter failed."
;
MS_LOG
(
ERROR
)
<<
"new ArithmeticParameter failed."
;
return
nullptr
;
return
nullptr
;
}
}
auto
eltwise
=
primitive
->
Value
()
->
value_as_Eltwise
();
auto
eltwise
=
primitive
->
Value
()
->
value_as_Eltwise
();
switch
(
eltwise
->
mode
())
{
switch
(
eltwise
->
mode
())
{
case
schema
::
EltwiseMode_PROD
:
case
schema
::
EltwiseMode_PROD
:
parameter
->
op_parameter
.
type_
=
schema
::
PrimitiveType_Mul
;
arithmetic_param
->
op_parameter_
.
type_
=
schema
::
PrimitiveType_Mul
;
break
;
break
;
case
schema
::
EltwiseMode_SUM
:
case
schema
::
EltwiseMode_SUM
:
parameter
->
op_parameter
.
type_
=
schema
::
PrimitiveType_Add
;
arithmetic_param
->
op_parameter_
.
type_
=
schema
::
PrimitiveType_Add
;
break
;
break
;
case
schema
::
EltwiseMode_MAXIMUM
:
case
schema
::
EltwiseMode_MAXIMUM
:
parameter
->
op_parameter
.
type_
=
schema
::
PrimitiveType_Maximum
;
arithmetic_param
->
op_parameter_
.
type_
=
schema
::
PrimitiveType_Maximum
;
break
;
break
;
default:
default:
delete
parameter
;
delete
arithmetic_param
;
return
nullptr
;
return
nullptr
;
}
}
return
parameter
;
return
reinterpret_cast
<
OpParameter
*>
(
arithmetic_param
)
;
}
}
ArithmeticSelf
Parameter
*
PopulateArithmeticSelf
(
const
lite
::
Primitive
*
primitive
)
{
Op
Parameter
*
PopulateArithmeticSelf
(
const
lite
::
Primitive
*
primitive
)
{
ArithmeticSelfParameter
*
parameter
=
new
(
std
::
nothrow
)
ArithmeticSelfParameter
();
ArithmeticSelfParameter
*
arithmetic_self_param
=
new
(
std
::
nothrow
)
ArithmeticSelfParameter
();
if
(
parameter
==
nullptr
)
{
if
(
arithmetic_self_param
==
nullptr
)
{
MS_LOG
(
ERROR
)
<<
"new ArithmeticParameter failed."
;
MS_LOG
(
ERROR
)
<<
"new ArithmeticParameter failed."
;
return
nullptr
;
return
nullptr
;
}
}
parameter
->
op_parameter_
.
type_
=
primitive
->
Type
();
arithmetic_self_param
->
op_parameter_
.
type_
=
primitive
->
Type
();
return
parameter
;
return
reinterpret_cast
<
OpParameter
*>
(
arithmetic_self_param
)
;
}
}
Power
Parameter
*
PopulatePowerParameter
(
const
lite
::
Primitive
*
primitive
)
{
Op
Parameter
*
PopulatePowerParameter
(
const
lite
::
Primitive
*
primitive
)
{
PowerParameter
*
p
arameter
=
new
(
std
::
nothrow
)
PowerParameter
();
PowerParameter
*
p
ower_param
=
new
(
std
::
nothrow
)
PowerParameter
();
if
(
p
arameter
==
nullptr
)
{
if
(
p
ower_param
==
nullptr
)
{
MS_LOG
(
ERROR
)
<<
"new PowerParameter failed."
;
MS_LOG
(
ERROR
)
<<
"new PowerParameter failed."
;
return
nullptr
;
return
nullptr
;
}
}
power_param
->
op_parameter_
.
type_
=
primitive
->
Type
();
auto
power
=
primitive
->
Value
()
->
value_as_Power
();
auto
power
=
primitive
->
Value
()
->
value_as_Power
();
p
arameter
->
power_
=
power
->
power
();
p
ower_param
->
power_
=
power
->
power
();
p
arameter
->
scale_
=
power
->
scale
();
p
ower_param
->
scale_
=
power
->
scale
();
p
arameter
->
shift_
=
power
->
shift
();
p
ower_param
->
shift_
=
power
->
shift
();
return
parameter
;
return
reinterpret_cast
<
OpParameter
*>
(
power_param
)
;
}
}
ArgMinMaxParameter
*
PopulateArgMaxParam
(
const
lite
::
Primitive
*
primitive
)
{
OpParameter
*
PopulateArgMaxParameter
(
const
lite
::
Primitive
*
primitive
)
{
ArgMinMaxParameter
*
parameter
=
new
(
std
::
nothrow
)
ArgMinMaxParameter
();
ArgMinMaxParameter
*
arg_param
=
new
(
std
::
nothrow
)
ArgMinMaxParameter
();
if
(
parameter
==
nullptr
)
{
if
(
arg_param
==
nullptr
)
{
MS_LOG
(
ERROR
)
<<
"new ArgMinMaxParameter failed."
;
MS_LOG
(
ERROR
)
<<
"new ArgMinMaxParameter failed."
;
return
nullptr
;
return
nullptr
;
}
}
arg_param
->
op_parameter_
.
type_
=
primitive
->
Type
();
auto
param
=
primitive
->
Value
()
->
value_as_ArgMax
();
auto
param
=
primitive
->
Value
()
->
value_as_ArgMax
();
parameter
->
op_parameter_
.
type_
=
primitive
->
Type
();
arg_param
->
axis_
=
param
->
axis
();
parameter
->
axis_
=
param
->
axis
();
arg_param
->
topk_
=
param
->
topK
();
parameter
->
topk_
=
param
->
topK
();
arg_param
->
axis_type_
=
param
->
axisType
();
parameter
->
axis_type_
=
param
->
axisType
();
arg_param
->
out_value_
=
param
->
outMaxValue
();
parameter
->
out_value_
=
param
->
outMaxValue
();
arg_param
->
keep_dims_
=
param
->
keepDims
();
parameter
->
keep_dims_
=
param
->
keepDims
();
return
reinterpret_cast
<
OpParameter
*>
(
arg_param
);
return
parameter
;
}
}
ArgMinMaxParameter
*
PopulateArgMinParam
(
const
lite
::
Primitive
*
primitive
)
{
OpParameter
*
PopulateArgMinParameter
(
const
lite
::
Primitive
*
primitive
)
{
ArgMinMaxParameter
*
parameter
=
new
(
std
::
nothrow
)
ArgMinMaxParameter
();
ArgMinMaxParameter
*
arg_param
=
new
(
std
::
nothrow
)
ArgMinMaxParameter
();
if
(
parameter
==
nullptr
)
{
if
(
arg_param
==
nullptr
)
{
MS_LOG
(
ERROR
)
<<
"new ArgMinMaxParameter failed."
;
MS_LOG
(
ERROR
)
<<
"new ArgMinMaxParameter failed."
;
return
nullptr
;
return
nullptr
;
}
}
arg_param
->
op_parameter_
.
type_
=
primitive
->
Type
();
auto
param
=
primitive
->
Value
()
->
value_as_ArgMin
();
auto
param
=
primitive
->
Value
()
->
value_as_ArgMin
();
parameter
->
op_parameter_
.
type_
=
primitive
->
Type
();
arg_param
->
axis_
=
param
->
axis
();
parameter
->
axis_
=
param
->
axis
();
arg_param
->
topk_
=
param
->
topK
();
parameter
->
topk_
=
param
->
topK
();
arg_param
->
axis_type_
=
param
->
axisType
();
parameter
->
axis_type_
=
param
->
axisType
();
arg_param
->
out_value_
=
param
->
outMaxValue
();
parameter
->
out_value_
=
param
->
outMaxValue
();
arg_param
->
keep_dims_
=
param
->
keepDims
();
parameter
->
keep_dims_
=
param
->
keepDims
();
return
reinterpret_cast
<
OpParameter
*>
(
arg_param
);
return
parameter
;
}
}
CastParameter
*
PopulateCastParam
(
const
lite
::
Primitive
*
primitive
)
{
OpParameter
*
PopulateCastParameter
(
const
lite
::
Primitive
*
primitive
)
{
CastParameter
*
parameter
=
new
(
std
::
nothrow
)
CastParameter
();
CastParameter
*
cast_param
=
new
(
std
::
nothrow
)
CastParameter
();
if
(
parameter
==
nullptr
)
{
if
(
cast_param
==
nullptr
)
{
MS_LOG
(
ERROR
)
<<
"new CastParameter failed."
;
MS_LOG
(
ERROR
)
<<
"new CastParameter failed."
;
return
nullptr
;
return
nullptr
;
}
}
cast_param
->
op_parameter_
.
type_
=
primitive
->
Type
();
auto
param
=
primitive
->
Value
()
->
value_as_Cast
();
auto
param
=
primitive
->
Value
()
->
value_as_Cast
();
parameter
->
op_parameter_
.
type_
=
primitive
->
Type
();
cast_param
->
src_type_
=
param
->
srcT
();
parameter
->
src_type_
=
param
->
srcT
();
cast_param
->
dst_type_
=
param
->
dstT
();
parameter
->
dst_type_
=
param
->
dstT
();
return
reinterpret_cast
<
OpParameter
*>
(
cast_param
);
return
parameter
;
}
}
LocalResponseNorm
Parameter
*
PopulateLocalResponseNormParameter
(
const
lite
::
Primitive
*
primitive
)
{
Op
Parameter
*
PopulateLocalResponseNormParameter
(
const
lite
::
Primitive
*
primitive
)
{
auto
local_response_norm_attr
=
primitive
->
Value
()
->
value_as_LocalResponseNormalization
();
auto
local_response_norm_attr
=
primitive
->
Value
()
->
value_as_LocalResponseNormalization
();
LocalResponseNormParameter
*
parameter
=
new
(
std
::
nothrow
)
LocalResponseNormParameter
();
LocalResponseNormParameter
*
lrn_param
=
new
(
std
::
nothrow
)
LocalResponseNormParameter
();
if
(
parameter
==
nullptr
)
{
if
(
lrn_param
==
nullptr
)
{
MS_LOG
(
ERROR
)
<<
"new LocalResponseNormParameter failed."
;
MS_LOG
(
ERROR
)
<<
"new LocalResponseNormParameter failed."
;
return
nullptr
;
return
nullptr
;
}
}
parameter
->
depth_radius_
=
local_response_norm_attr
->
depth_radius
();
lrn_param
->
op_parameter_
.
type_
=
primitive
->
Type
();
parameter
->
bias_
=
local_response_norm_attr
->
bias
();
lrn_param
->
depth_radius_
=
local_response_norm_attr
->
depth_radius
();
parameter
->
alpha_
=
local_response_norm_attr
->
alpha
();
lrn_param
->
bias_
=
local_response_norm_attr
->
bias
();
parameter
->
beta_
=
local_response_norm_attr
->
beta
();
lrn_param
->
alpha_
=
local_response_norm_attr
->
alpha
();
return
parameter
;
lrn_param
->
beta_
=
local_response_norm_attr
->
beta
();
return
reinterpret_cast
<
OpParameter
*>
(
lrn_param
);
}
}
Range
Parameter
*
PopulateRangeParameter
(
const
lite
::
Primitive
*
primitive
)
{
Op
Parameter
*
PopulateRangeParameter
(
const
lite
::
Primitive
*
primitive
)
{
auto
range_attr
=
primitive
->
Value
()
->
value_as_Range
();
auto
range_attr
=
primitive
->
Value
()
->
value_as_Range
();
RangeParameter
*
parameter
=
new
(
std
::
nothrow
)
RangeParameter
();
RangeParameter
*
range_param
=
new
(
std
::
nothrow
)
RangeParameter
();
if
(
parameter
==
nullptr
)
{
if
(
range_param
==
nullptr
)
{
MS_LOG
(
ERROR
)
<<
"new RangeParameter failed."
;
MS_LOG
(
ERROR
)
<<
"new RangeParameter failed."
;
return
nullptr
;
return
nullptr
;
}
}
parameter
->
start_
=
range_attr
->
start
();
range_param
->
op_parameter_
.
type_
=
primitive
->
Type
();
parameter
->
limit_
=
range_attr
->
limit
();
range_param
->
start_
=
range_attr
->
start
();
parameter
->
delta_
=
range_attr
->
delta
();
range_param
->
limit_
=
range_attr
->
limit
();
parameter
->
dType_
=
range_attr
->
dType
();
range_param
->
delta_
=
range_attr
->
delta
();
return
parameter
;
range_param
->
dType_
=
range_attr
->
dType
();
}
return
reinterpret_cast
<
OpParameter
*>
(
range_param
);
OpParameter
*
PopulateCeilParameter
(
const
lite
::
Primitive
*
primitive
)
{
OpParameter
*
parameter
=
new
(
std
::
nothrow
)
OpParameter
();
if
(
parameter
==
nullptr
)
{
MS_LOG
(
ERROR
)
<<
"new OpParameter failed."
;
return
nullptr
;
}
parameter
->
type_
=
primitive
->
Type
();
return
parameter
;
}
}
Concat
Parameter
*
PopulateConcatParameter
(
const
lite
::
Primitive
*
primitive
)
{
Op
Parameter
*
PopulateConcatParameter
(
const
lite
::
Primitive
*
primitive
)
{
ConcatParameter
*
parameter
=
new
(
std
::
nothrow
)
ConcatParameter
();
ConcatParameter
*
concat_param
=
new
(
std
::
nothrow
)
ConcatParameter
();
if
(
parameter
==
nullptr
)
{
if
(
concat_param
==
nullptr
)
{
MS_LOG
(
ERROR
)
<<
"new ConcatParameter failed."
;
MS_LOG
(
ERROR
)
<<
"new ConcatParameter failed."
;
return
nullptr
;
return
nullptr
;
}
}
parameter
->
op_parameter_
.
type_
=
primitive
->
Type
();
concat_param
->
op_parameter_
.
type_
=
primitive
->
Type
();
auto
param
=
primitive
->
Value
()
->
value_as_Concat
();
auto
param
=
primitive
->
Value
()
->
value_as_Concat
();
parameter
->
axis_
=
param
->
axis
();
concat_param
->
axis_
=
param
->
axis
();
return
parameter
;
return
reinterpret_cast
<
OpParameter
*>
(
concat_param
)
;
}
}
Tile
Parameter
*
PopulateTileParameter
(
const
lite
::
Primitive
*
primitive
)
{
Op
Parameter
*
PopulateTileParameter
(
const
lite
::
Primitive
*
primitive
)
{
TileParameter
*
parameter
=
new
(
std
::
nothrow
)
TileParameter
();
TileParameter
*
tile_param
=
new
(
std
::
nothrow
)
TileParameter
();
if
(
parameter
==
nullptr
)
{
if
(
tile_param
==
nullptr
)
{
MS_LOG
(
ERROR
)
<<
"new TileParameter failed."
;
MS_LOG
(
ERROR
)
<<
"new TileParameter failed."
;
return
nullptr
;
return
nullptr
;
}
}
parameter
->
op_parameter_
.
type_
=
primitive
->
Type
();
tile_param
->
op_parameter_
.
type_
=
primitive
->
Type
();
auto
param
=
primitive
->
Value
()
->
value_as_Tile
();
auto
param
=
primitive
->
Value
()
->
value_as_Tile
();
auto
multiples
=
param
->
multiples
();
auto
multiples
=
param
->
multiples
();
parameter
->
in_dim_
=
multiples
->
size
();
tile_param
->
in_dim_
=
multiples
->
size
();
for
(
size_t
i
=
0
;
i
<
parameter
->
in_dim_
;
++
i
)
{
for
(
size_t
i
=
0
;
i
<
tile_param
->
in_dim_
;
++
i
)
{
parameter
->
multiples_
[
i
]
=
multiples
->
Get
(
i
);
tile_param
->
multiples_
[
i
]
=
multiples
->
Get
(
i
);
}
}
return
parameter
;
return
reinterpret_cast
<
OpParameter
*>
(
tile_param
)
;
}
}
Topk
Parameter
*
PopulateTopKParameter
(
const
lite
::
Primitive
*
primitive
)
{
Op
Parameter
*
PopulateTopKParameter
(
const
lite
::
Primitive
*
primitive
)
{
TopkParameter
*
parameter
=
new
(
std
::
nothrow
)
TopkParameter
();
TopkParameter
*
topk_param
=
new
(
std
::
nothrow
)
TopkParameter
();
if
(
parameter
==
nullptr
)
{
if
(
topk_param
==
nullptr
)
{
MS_LOG
(
ERROR
)
<<
"new TopkParameter failed."
;
MS_LOG
(
ERROR
)
<<
"new TopkParameter failed."
;
return
nullptr
;
return
nullptr
;
}
}
parameter
->
op_parameter_
.
type_
=
primitive
->
Type
();
topk_param
->
op_parameter_
.
type_
=
primitive
->
Type
();
auto
param
=
primitive
->
Value
()
->
value_as_TopK
();
auto
param
=
primitive
->
Value
()
->
value_as_TopK
();
parameter
->
k_
=
param
->
k
();
topk_param
->
k_
=
param
->
k
();
parameter
->
sorted_
=
param
->
sorted
();
topk_param
->
sorted_
=
param
->
sorted
();
return
parameter
;
return
reinterpret_cast
<
OpParameter
*>
(
topk_param
)
;
}
}
OpParameter
*
PopulateNhwc2NchwParameter
(
const
lite
::
Primitive
*
primitive
)
{
OpParameter
*
PopulateNhwc2NchwParameter
(
const
lite
::
Primitive
*
primitive
)
{
...
@@ -636,64 +649,64 @@ OpParameter *PopulateNchw2NhwcParameter(const lite::Primitive *primitive) {
...
@@ -636,64 +649,64 @@ OpParameter *PopulateNchw2NhwcParameter(const lite::Primitive *primitive) {
return
parameter
;
return
parameter
;
}
}
Transpose
Parameter
*
PopulateTransposeParameter
(
const
lite
::
Primitive
*
primitive
)
{
Op
Parameter
*
PopulateTransposeParameter
(
const
lite
::
Primitive
*
primitive
)
{
TransposeParameter
*
parameter
=
new
(
std
::
nothrow
)
TransposeParameter
();
TransposeParameter
*
transpose_param
=
new
(
std
::
nothrow
)
TransposeParameter
();
if
(
parameter
==
nullptr
)
{
if
(
transpose_param
==
nullptr
)
{
MS_LOG
(
ERROR
)
<<
"new TransposeParameter failed."
;
MS_LOG
(
ERROR
)
<<
"new TransposeParameter failed."
;
return
nullptr
;
return
nullptr
;
}
}
auto
param
=
primitive
->
Value
()
->
value_as_Transpose
();
auto
param
=
primitive
->
Value
()
->
value_as_Transpose
();
parameter
->
op_parameter_
.
type_
=
primitive
->
Type
();
transpose_param
->
op_parameter_
.
type_
=
primitive
->
Type
();
auto
perm_vector_
=
param
->
perm
();
auto
perm_vector_
=
param
->
perm
();
int
i
=
0
;
int
i
=
0
;
for
(
auto
iter
=
perm_vector_
->
begin
();
iter
!=
perm_vector_
->
end
();
iter
++
)
{
for
(
auto
iter
=
perm_vector_
->
begin
();
iter
!=
perm_vector_
->
end
();
iter
++
)
{
parameter
->
perm_
[
i
++
]
=
*
iter
;
transpose_param
->
perm_
[
i
++
]
=
*
iter
;
}
}
parameter
->
num_axes_
=
i
;
transpose_param
->
num_axes_
=
i
;
parameter
->
conjugate_
=
param
->
conjugate
();
transpose_param
->
conjugate_
=
param
->
conjugate
();
return
parameter
;
return
reinterpret_cast
<
OpParameter
*>
(
transpose_param
)
;
}
}
Split
Parameter
*
PopulateSplitParameter
(
const
lite
::
Primitive
*
primitive
)
{
Op
Parameter
*
PopulateSplitParameter
(
const
lite
::
Primitive
*
primitive
)
{
SplitParameter
*
parameter
=
new
(
std
::
nothrow
)
SplitParameter
();
SplitParameter
*
split_param
=
new
(
std
::
nothrow
)
SplitParameter
();
if
(
parameter
==
nullptr
)
{
if
(
split_param
==
nullptr
)
{
MS_LOG
(
ERROR
)
<<
"new SplitParameter failed."
;
MS_LOG
(
ERROR
)
<<
"new SplitParameter failed."
;
return
nullptr
;
return
nullptr
;
}
}
auto
param
=
primitive
->
Value
()
->
value_as_Split
();
auto
param
=
primitive
->
Value
()
->
value_as_Split
();
parameter
->
op_parameter_
.
type_
=
primitive
->
Type
();
split_param
->
op_parameter_
.
type_
=
primitive
->
Type
();
parameter
->
num_split_
=
param
->
numberSplit
();
split_param
->
num_split_
=
param
->
numberSplit
();
auto
split_sizes_vector_
=
param
->
sizeSplits
();
auto
split_sizes_vector_
=
param
->
sizeSplits
();
int
i
=
0
;
int
i
=
0
;
for
(
auto
iter
=
split_sizes_vector_
->
begin
();
iter
!=
split_sizes_vector_
->
end
();
iter
++
)
{
for
(
auto
iter
=
split_sizes_vector_
->
begin
();
iter
!=
split_sizes_vector_
->
end
();
iter
++
)
{
parameter
->
split_sizes_
[
i
++
]
=
*
iter
;
split_param
->
split_sizes_
[
i
++
]
=
*
iter
;
}
}
parameter
->
split_dim_
=
param
->
splitDim
();
split_param
->
split_dim_
=
param
->
splitDim
();
parameter
->
num_split_
=
param
->
numberSplit
();
split_param
->
num_split_
=
param
->
numberSplit
();
return
parameter
;
return
reinterpret_cast
<
OpParameter
*>
(
split_param
)
;
}
}
Squeeze
Parameter
*
PopulateSqueezeParameter
(
const
lite
::
Primitive
*
primitive
)
{
Op
Parameter
*
PopulateSqueezeParameter
(
const
lite
::
Primitive
*
primitive
)
{
SqueezeParameter
*
parameter
=
new
(
std
::
nothrow
)
SqueezeParameter
();
SqueezeParameter
*
squeeze_param
=
new
(
std
::
nothrow
)
SqueezeParameter
();
if
(
parameter
==
nullptr
)
{
if
(
squeeze_param
==
nullptr
)
{
MS_LOG
(
ERROR
)
<<
"new SqueezeParameter failed."
;
MS_LOG
(
ERROR
)
<<
"new SqueezeParameter failed."
;
return
nullptr
;
return
nullptr
;
}
}
parameter
->
op_parameter_
.
type_
=
primitive
->
Type
();
squeeze_param
->
op_parameter_
.
type_
=
primitive
->
Type
();
return
parameter
;
return
reinterpret_cast
<
OpParameter
*>
(
squeeze_param
)
;
}
}
Scale
Parameter
*
PopulateScaleParameter
(
const
lite
::
Primitive
*
primitive
)
{
Op
Parameter
*
PopulateScaleParameter
(
const
lite
::
Primitive
*
primitive
)
{
if
(
primitive
==
nullptr
)
{
if
(
primitive
==
nullptr
)
{
MS_LOG
(
ERROR
)
<<
"input primitive is nullptr"
;
MS_LOG
(
ERROR
)
<<
"input primitive is nullptr"
;
return
nullptr
;
return
nullptr
;
}
}
ScaleParameter
*
parameter
=
new
(
std
::
nothrow
)
ScaleParameter
();
ScaleParameter
*
scale_param
=
new
(
std
::
nothrow
)
ScaleParameter
();
if
(
parameter
==
nullptr
)
{
if
(
scale_param
==
nullptr
)
{
MS_LOG
(
ERROR
)
<<
"new ScaleParameter failed."
;
MS_LOG
(
ERROR
)
<<
"new ScaleParameter failed."
;
return
nullptr
;
return
nullptr
;
}
}
parameter
->
op_parameter_
.
type_
=
primitive
->
Type
();
scale_param
->
op_parameter_
.
type_
=
primitive
->
Type
();
auto
param
=
primitive
->
Value
()
->
value_as_Scale
();
auto
param
=
primitive
->
Value
()
->
value_as_Scale
();
if
(
param
==
nullptr
)
{
if
(
param
==
nullptr
)
{
MS_LOG
(
ERROR
)
<<
"value_as_Scale return nullptr"
;
MS_LOG
(
ERROR
)
<<
"value_as_Scale return nullptr"
;
...
@@ -701,219 +714,253 @@ ScaleParameter *PopulateScaleParameter(const lite::Primitive *primitive) {
...
@@ -701,219 +714,253 @@ ScaleParameter *PopulateScaleParameter(const lite::Primitive *primitive) {
}
}
// NCHW todo use enum
// NCHW todo use enum
if
(
param
->
format
()
==
schema
::
Format_NCHW
)
{
if
(
param
->
format
()
==
schema
::
Format_NCHW
)
{
parameter
->
axis_
=
1
;
scale_param
->
axis_
=
1
;
parameter
->
num_axis_
=
1
;
scale_param
->
num_axis_
=
1
;
}
else
if
(
param
->
format
()
==
schema
::
Format_NHWC
)
{
}
else
if
(
param
->
format
()
==
schema
::
Format_NHWC
)
{
parameter
->
axis_
=
3
;
scale_param
->
axis_
=
3
;
parameter
->
num_axis_
=
1
;
scale_param
->
num_axis_
=
1
;
}
}
return
parameter
;
return
reinterpret_cast
<
OpParameter
*>
(
scale_param
)
;
}
}
Gather
Parameter
*
PopulateGatherParameter
(
const
lite
::
Primitive
*
primitive
)
{
Op
Parameter
*
PopulateGatherParameter
(
const
lite
::
Primitive
*
primitive
)
{
auto
gather_attr
=
primitive
->
Value
()
->
value_as_Gather
();
auto
gather_attr
=
primitive
->
Value
()
->
value_as_Gather
();
GatherParameter
*
parameter
=
new
(
std
::
nothrow
)
GatherParameter
();
GatherParameter
*
gather_param
=
new
(
std
::
nothrow
)
GatherParameter
();
if
(
parameter
==
nullptr
)
{
if
(
gather_param
==
nullptr
)
{
MS_LOG
(
ERROR
)
<<
"new GatherParameter failed."
;
MS_LOG
(
ERROR
)
<<
"new GatherParameter failed."
;
return
nullptr
;
return
nullptr
;
}
}
parameter
->
axis_
=
gather_attr
->
axis
();
gather_param
->
op_parameter_
.
type_
=
primitive
->
Type
();
parameter
->
batchDims_
=
gather_attr
->
batchDims
();
gather_param
->
axis_
=
gather_attr
->
axis
();
return
parameter
;
gather_param
->
batchDims_
=
gather_attr
->
batchDims
();
return
reinterpret_cast
<
OpParameter
*>
(
gather_param
);
}
}
GatherNdParameter
*
PopulateGatherNdParameter
(
const
lite
::
Primitive
*
primitive
)
{
OpParameter
*
PopulateGatherNdParameter
(
const
lite
::
Primitive
*
primitive
)
{
GatherNdParameter
*
parameter
=
new
(
std
::
nothrow
)
GatherNdParameter
();
GatherNdParameter
*
gather_nd_param
=
new
(
std
::
nothrow
)
GatherNdParameter
();
MS_ASSERT
(
paramter
!=
nullptr
);
if
(
gather_nd_param
==
nullptr
)
{
MS_LOG
(
ERROR
)
<<
"new GatherNDParameter failed."
;
return
nullptr
;
}
gather_nd_param
->
op_parameter_
.
type_
=
primitive
->
Type
();
auto
gatherNd_attr
=
primitive
->
Value
()
->
value_as_GatherNd
();
auto
gatherNd_attr
=
primitive
->
Value
()
->
value_as_GatherNd
();
parameter
->
batchDims_
=
gatherNd_attr
->
batchDims
();
gather_nd_param
->
batchDims_
=
gatherNd_attr
->
batchDims
();
return
parameter
;
return
reinterpret_cast
<
OpParameter
*>
(
gather_nd_param
)
;
}
}
ScatterNDParameter
*
PopulateScatterNDParameter
(
const
lite
::
Primitive
*
primitive
)
{
OpParameter
*
PopulateScatterNDParameter
(
const
lite
::
Primitive
*
primitive
)
{
ScatterNDParameter
*
parameter
=
new
(
std
::
nothrow
)
ScatterNDParameter
();
ScatterNDParameter
*
scatter_nd_param
=
new
(
std
::
nothrow
)
ScatterNDParameter
();
if
(
scatter_nd_param
==
nullptr
)
{
MS_LOG
(
ERROR
)
<<
"new ScatterNDParameter failed."
;
return
nullptr
;
}
scatter_nd_param
->
op_parameter_
.
type_
=
primitive
->
Type
();
MS_ASSERT
(
paramter
!=
nullptr
);
MS_ASSERT
(
paramter
!=
nullptr
);
return
parameter
;
return
reinterpret_cast
<
OpParameter
*>
(
scatter_nd_param
)
;
}
}
SliceParameter
*
PopulateSliceParam
(
const
lite
::
Primitive
*
primitive
)
{
OpParameter
*
PopulateSliceParameter
(
const
lite
::
Primitive
*
primitive
)
{
SliceParameter
*
parameter
=
new
(
std
::
nothrow
)
SliceParameter
();
SliceParameter
*
slice_param
=
new
(
std
::
nothrow
)
SliceParameter
();
if
(
parameter
==
nullptr
)
{
if
(
slice_param
==
nullptr
)
{
MS_LOG
(
ERROR
)
<<
"new SliceParameter failed."
;
MS_LOG
(
ERROR
)
<<
"new SliceParameter failed."
;
return
nullptr
;
return
nullptr
;
}
}
auto
param
=
primitive
->
Value
()
->
value_as_Slice
();
auto
param
=
primitive
->
Value
()
->
value_as_Slice
();
parameter
->
op_parameter_
.
type_
=
primitive
->
Type
();
slice_param
->
op_parameter_
.
type_
=
primitive
->
Type
();
auto
param_begin
=
param
->
begin
();
auto
param_begin
=
param
->
begin
();
auto
param_size
=
param
->
size
();
auto
param_size
=
param
->
size
();
if
(
param_begin
->
size
()
!=
param_size
->
size
())
{
if
(
param_begin
->
size
()
!=
param_size
->
size
())
{
delete
parameter
;
delete
slice_param
;
return
nullptr
;
return
nullptr
;
}
}
parameter
->
param_length_
=
static_cast
<
int32_t
>
(
param_begin
->
size
());
slice_param
->
param_length_
=
static_cast
<
int32_t
>
(
param_begin
->
size
());
for
(
int32_t
i
=
0
;
i
<
parameter
->
param_length_
;
++
i
)
{
for
(
int32_t
i
=
0
;
i
<
slice_param
->
param_length_
;
++
i
)
{
parameter
->
begin_
[
i
]
=
param_begin
->
Get
(
i
);
slice_param
->
begin_
[
i
]
=
param_begin
->
Get
(
i
);
parameter
->
size_
[
i
]
=
param_size
->
Get
(
i
);
slice_param
->
size_
[
i
]
=
param_size
->
Get
(
i
);
}
}
return
parameter
;
return
reinterpret_cast
<
OpParameter
*>
(
slice_param
)
;
}
}
BroadcastToParameter
*
PopulateBroadcastToParam
(
const
lite
::
Primitive
*
primitive
)
{
OpParameter
*
PopulateBroadcastToParameter
(
const
lite
::
Primitive
*
primitive
)
{
BroadcastToParameter
*
parameter
=
new
(
std
::
nothrow
)
BroadcastToParameter
();
BroadcastToParameter
*
broadcast_param
=
new
(
std
::
nothrow
)
BroadcastToParameter
();
if
(
parameter
==
nullptr
)
{
if
(
broadcast_param
==
nullptr
)
{
MS_LOG
(
ERROR
)
<<
"new BroadcastToParameter failed."
;
MS_LOG
(
ERROR
)
<<
"new BroadcastToParameter failed."
;
return
nullptr
;
return
nullptr
;
}
}
auto
param
=
primitive
->
Value
()
->
value_as_BroadcastTo
();
auto
param
=
primitive
->
Value
()
->
value_as_BroadcastTo
();
parameter
->
op_parameter_
.
type_
=
primitive
->
Type
();
broadcast_param
->
op_parameter_
.
type_
=
primitive
->
Type
();
auto
dst_shape
=
param
->
dst_shape
();
auto
dst_shape
=
param
->
dst_shape
();
parameter
->
shape_size_
=
dst_shape
->
size
();
broadcast_param
->
shape_size_
=
dst_shape
->
size
();
for
(
size_t
i
=
0
;
i
<
parameter
->
shape_size_
;
++
i
)
{
for
(
size_t
i
=
0
;
i
<
broadcast_param
->
shape_size_
;
++
i
)
{
parameter
->
shape_
[
i
]
=
dst_shape
->
Get
(
i
);
broadcast_param
->
shape_
[
i
]
=
dst_shape
->
Get
(
i
);
}
}
return
parameter
;
return
reinterpret_cast
<
OpParameter
*>
(
broadcast_param
)
;
}
}
ReshapeParameter
*
PopulateReshapeParam
(
const
lite
::
Primitive
*
primitive
)
{
OpParameter
*
PopulateReshapeParameter
(
const
lite
::
Primitive
*
primitive
)
{
ReshapeParameter
*
parameter
=
new
(
std
::
nothrow
)
ReshapeParameter
();
ReshapeParameter
*
reshape_param
=
new
(
std
::
nothrow
)
ReshapeParameter
();
if
(
parameter
==
nullptr
)
{
if
(
reshape_param
==
nullptr
)
{
MS_LOG
(
ERROR
)
<<
"new ReshapeParameter failed."
;
MS_LOG
(
ERROR
)
<<
"new ReshapeParameter failed."
;
return
nullptr
;
return
nullptr
;
}
}
parameter
->
op_parameter_
.
type_
=
primitive
->
Type
();
reshape_param
->
op_parameter_
.
type_
=
primitive
->
Type
();
return
parameter
;
return
reinterpret_cast
<
OpParameter
*>
(
reshape_param
)
;
}
}
Reverse
Parameter
*
PopulateReverseParameter
(
const
lite
::
Primitive
*
primitive
)
{
Op
Parameter
*
PopulateReverseParameter
(
const
lite
::
Primitive
*
primitive
)
{
auto
reverse_attr
=
primitive
->
Value
()
->
value_as_Reverse
();
auto
reverse_attr
=
primitive
->
Value
()
->
value_as_Reverse
();
ReverseParameter
*
parameter
=
new
(
std
::
nothrow
)
ReverseParameter
();
ReverseParameter
*
reverse_param
=
new
(
std
::
nothrow
)
ReverseParameter
();
if
(
parameter
==
nullptr
)
{
if
(
reverse_param
==
nullptr
)
{
MS_LOG
(
ERROR
)
<<
"new ReverseParameter failed."
;
MS_LOG
(
ERROR
)
<<
"new ReverseParameter failed."
;
return
nullptr
;
return
nullptr
;
}
}
reverse_param
->
op_parameter_
.
type_
=
primitive
->
Type
();
auto
flatAxis
=
reverse_attr
->
axis
();
auto
flatAxis
=
reverse_attr
->
axis
();
parameter
->
num_axis_
=
flatAxis
->
size
();
reverse_param
->
num_axis_
=
flatAxis
->
size
();
int
i
=
0
;
int
i
=
0
;
for
(
auto
iter
=
flatAxis
->
begin
();
iter
!=
flatAxis
->
end
();
iter
++
)
{
for
(
auto
iter
=
flatAxis
->
begin
();
iter
!=
flatAxis
->
end
();
iter
++
)
{
parameter
->
axis_
[
i
++
]
=
*
iter
;
reverse_param
->
axis_
[
i
++
]
=
*
iter
;
}
}
return
parameter
;
return
reinterpret_cast
<
OpParameter
*>
(
reverse_param
)
;
}
}
Unsqueeze
Parameter
*
PopulateUnsqueezeParameter
(
const
lite
::
Primitive
*
primitive
)
{
Op
Parameter
*
PopulateUnsqueezeParameter
(
const
lite
::
Primitive
*
primitive
)
{
auto
unsqueeze_attr
=
primitive
->
Value
()
->
value_as_Unsqueeze
();
auto
unsqueeze_attr
=
primitive
->
Value
()
->
value_as_Unsqueeze
();
UnsqueezeParameter
*
parameter
=
new
(
std
::
nothrow
)
UnsqueezeParameter
();
UnsqueezeParameter
*
unsqueeze_param
=
new
(
std
::
nothrow
)
UnsqueezeParameter
();
if
(
parameter
==
nullptr
)
{
if
(
unsqueeze_param
==
nullptr
)
{
MS_LOG
(
ERROR
)
<<
"new ReverseParameter failed."
;
MS_LOG
(
ERROR
)
<<
"new ReverseParameter failed."
;
return
nullptr
;
return
nullptr
;
}
}
unsqueeze_param
->
op_parameter_
.
type_
=
primitive
->
Type
();
auto
flatAxis
=
unsqueeze_attr
->
axis
();
auto
flatAxis
=
unsqueeze_attr
->
axis
();
parameter
->
num_dim_
=
flatAxis
->
size
();
unsqueeze_param
->
num_dim_
=
flatAxis
->
size
();
int
i
=
0
;
int
i
=
0
;
for
(
auto
iter
=
flatAxis
->
begin
();
iter
!=
flatAxis
->
end
();
iter
++
)
{
for
(
auto
iter
=
flatAxis
->
begin
();
iter
!=
flatAxis
->
end
();
iter
++
)
{
parameter
->
dims_
[
i
++
]
=
*
iter
;
unsqueeze_param
->
dims_
[
i
++
]
=
*
iter
;
}
}
return
parameter
;
return
reinterpret_cast
<
OpParameter
*>
(
unsqueeze_param
)
;
}
}
StackParameter
*
PopulateStackParam
(
const
lite
::
Primitive
*
primitive
)
{
OpParameter
*
PopulateStackParameter
(
const
lite
::
Primitive
*
primitive
)
{
StackParameter
*
parameter
=
new
(
std
::
nothrow
)
StackParameter
();
StackParameter
*
stack_param
=
new
(
std
::
nothrow
)
StackParameter
();
if
(
parameter
==
nullptr
)
{
if
(
stack_param
==
nullptr
)
{
MS_LOG
(
ERROR
)
<<
"new StackParameter failed."
;
MS_LOG
(
ERROR
)
<<
"new StackParameter failed."
;
return
nullptr
;
return
nullptr
;
}
}
auto
param
=
primitive
->
Value
()
->
value_as_Stack
();
auto
param
=
primitive
->
Value
()
->
value_as_Stack
();
parameter
->
op_parameter_
.
type_
=
primitive
->
Type
();
stack_param
->
op_parameter_
.
type_
=
primitive
->
Type
();
parameter
->
axis_
=
param
->
axis
();
stack_param
->
axis_
=
param
->
axis
();
return
parameter
;
return
reinterpret_cast
<
OpParameter
*>
(
stack_param
)
;
}
}
UnstackParameter
*
PopulateUnstackParam
(
const
lite
::
Primitive
*
primitive
)
{
OpParameter
*
PopulateUnstackParameter
(
const
lite
::
Primitive
*
primitive
)
{
UnstackParameter
*
parameter
=
new
(
std
::
nothrow
)
UnstackParameter
();
UnstackParameter
*
unstack_param
=
new
(
std
::
nothrow
)
UnstackParameter
();
if
(
parameter
==
nullptr
)
{
if
(
unstack_param
==
nullptr
)
{
MS_LOG
(
ERROR
)
<<
"new UnstackParameter failed."
;
MS_LOG
(
ERROR
)
<<
"new UnstackParameter failed."
;
return
nullptr
;
return
nullptr
;
}
}
auto
param
=
primitive
->
Value
()
->
value_as_Unstack
();
auto
param
=
primitive
->
Value
()
->
value_as_Unstack
();
parameter
->
op_parameter_
.
type_
=
primitive
->
Type
();
unstack_param
->
op_parameter_
.
type_
=
primitive
->
Type
();
parameter
->
num_
=
param
->
num
();
unstack_param
->
num_
=
param
->
num
();
parameter
->
axis_
=
param
->
axis
();
unstack_param
->
axis_
=
param
->
axis
();
return
parameter
;
return
reinterpret_cast
<
OpParameter
*>
(
unstack_param
)
;
}
}
ReverseSequenceParameter
*
PopulateReverseSequenceParam
(
const
lite
::
Primitive
*
primitive
)
{
OpParameter
*
PopulateReverseSequenceParameter
(
const
lite
::
Primitive
*
primitive
)
{
ReverseSequenceParameter
*
parameter
=
new
(
std
::
nothrow
)
ReverseSequenceParameter
();
ReverseSequenceParameter
*
reverse_sequence_param
=
new
(
std
::
nothrow
)
ReverseSequenceParameter
();
if
(
parameter
==
nullptr
)
{
if
(
reverse_sequence_param
==
nullptr
)
{
MS_LOG
(
ERROR
)
<<
"new ReverseSequenceParameter failed."
;
MS_LOG
(
ERROR
)
<<
"new ReverseSequenceParameter failed."
;
return
nullptr
;
return
nullptr
;
}
}
auto
param
=
primitive
->
Value
()
->
value_as_ReverseSequence
();
auto
param
=
primitive
->
Value
()
->
value_as_ReverseSequence
();
parameter
->
op_parameter_
.
type_
=
primitive
->
Type
();
reverse_sequence_param
->
op_parameter_
.
type_
=
primitive
->
Type
();
parameter
->
seq_axis_
=
param
->
seqAxis
();
reverse_sequence_param
->
seq_axis_
=
param
->
seqAxis
();
parameter
->
batch_axis_
=
param
->
batchAxis
();
reverse_sequence_param
->
batch_axis_
=
param
->
batchAxis
();
return
parameter
;
return
reinterpret_cast
<
OpParameter
*>
(
reverse_sequence_param
)
;
}
}
UniqueParameter
*
PopulateUniqueParam
(
const
lite
::
Primitive
*
primitive
)
{
OpParameter
*
PopulateUniqueParameter
(
const
lite
::
Primitive
*
primitive
)
{
UniqueParameter
*
parameter
=
new
(
std
::
nothrow
)
UniqueParameter
();
UniqueParameter
*
unique_param
=
new
(
std
::
nothrow
)
UniqueParameter
();
if
(
parameter
==
nullptr
)
{
if
(
unique_param
==
nullptr
)
{
MS_LOG
(
ERROR
)
<<
"new PopulateUniqueParam failed."
;
MS_LOG
(
ERROR
)
<<
"new PopulateUniqueParam failed."
;
return
nullptr
;
return
nullptr
;
}
}
parameter
->
op_parameter_
.
type_
=
primitive
->
Type
();
unique_param
->
op_parameter_
.
type_
=
primitive
->
Type
();
return
parameter
;
return
reinterpret_cast
<
OpParameter
*>
(
unique_param
)
;
}
}
DepthToSpaceParameter
*
PopulateDepthToSpaceParam
(
const
lite
::
Primitive
*
primitive
)
{
OpParameter
*
PopulateDepthToSpaceParameter
(
const
lite
::
Primitive
*
primitive
)
{
DepthToSpaceParameter
*
parameter
=
new
(
std
::
nothrow
)
DepthToSpaceParameter
();
DepthToSpaceParameter
*
depth_space_param
=
new
(
std
::
nothrow
)
DepthToSpaceParameter
();
if
(
parameter
==
nullptr
)
{
if
(
depth_space_param
==
nullptr
)
{
MS_LOG
(
ERROR
)
<<
"new DepthToSpaceParameter failed."
;
MS_LOG
(
ERROR
)
<<
"new DepthToSpaceParameter failed."
;
return
nullptr
;
return
nullptr
;
}
}
auto
param
=
primitive
->
Value
()
->
value_as_DepthToSpace
();
auto
param
=
primitive
->
Value
()
->
value_as_DepthToSpace
();
parameter
->
op_parameter_
.
type_
=
primitive
->
Type
();
depth_space_param
->
op_parameter_
.
type_
=
primitive
->
Type
();
parameter
->
block_size_
=
param
->
blockSize
();
depth_space_param
->
block_size_
=
param
->
blockSize
();
return
parameter
;
return
reinterpret_cast
<
OpParameter
*>
(
depth_space_param
)
;
}
}
SpaceToDepthParameter
*
PopulateSpaceToDepthParam
(
const
lite
::
Primitive
*
primitive
)
{
OpParameter
*
PopulateSpaceToDepthParameter
(
const
lite
::
Primitive
*
primitive
)
{
SpaceToDepthParameter
*
parameter
=
new
(
std
::
nothrow
)
SpaceToDepthParameter
();
SpaceToDepthParameter
*
space_depth_param
=
new
(
std
::
nothrow
)
SpaceToDepthParameter
();
if
(
parameter
==
nullptr
)
{
if
(
space_depth_param
==
nullptr
)
{
MS_LOG
(
ERROR
)
<<
"new SpaceToDepth
Parameter
failed."
;
MS_LOG
(
ERROR
)
<<
"new SpaceToDepth
space_depth_param
failed."
;
return
nullptr
;
return
nullptr
;
}
}
space_depth_param
->
op_parameter_
.
type_
=
primitive
->
Type
();
auto
param
=
primitive
->
Value
()
->
value_as_DepthToSpace
();
auto
param
=
primitive
->
Value
()
->
value_as_DepthToSpace
();
parameter
->
op_parameter_
.
type_
=
primitive
->
Type
();
space_depth_param
->
op_parameter_
.
type_
=
primitive
->
Type
();
parameter
->
block_size_
=
param
->
blockSize
();
space_depth_param
->
block_size_
=
param
->
blockSize
();
if
(
param
->
format
()
!=
schema
::
Format_NHWC
)
{
if
(
param
->
format
()
!=
schema
::
Format_NHWC
)
{
MS_LOG
(
ERROR
)
<<
"Currently only NHWC format is supported."
;
MS_LOG
(
ERROR
)
<<
"Currently only NHWC format is supported."
;
return
nullptr
;
return
nullptr
;
}
}
return
parameter
;
return
reinterpret_cast
<
OpParameter
*>
(
space_depth_param
)
;
}
}
ResizeParameter
*
PopulateResizeParameter
(
const
lite
::
Primitive
*
primitive
)
{
OpParameter
*
PopulateSpaceToBatchParameter
(
const
lite
::
Primitive
*
primitive
)
{
ResizeParameter
*
parameter
=
new
(
std
::
nothrow
)
ResizeParameter
();
SpaceToBatchParameter
*
space_batch_param
=
new
(
std
::
nothrow
)
SpaceToBatchParameter
();
if
(
parameter
==
nullptr
)
{
if
(
space_batch_param
==
nullptr
)
{
MS_LOG
(
ERROR
)
<<
"new SpaceToBatchParameter failed."
;
return
nullptr
;
}
space_batch_param
->
op_parameter_
.
type_
=
primitive
->
Type
();
space_batch_param
->
op_parameter_
.
type_
=
primitive
->
Type
();
auto
block_sizes
=
((
lite
::
SpaceToBatch
*
)
primitive
)
->
BlockSizes
();
(
void
)
memcpy
(
space_batch_param
->
block_sizes_
,
(
block_sizes
.
data
()),
block_sizes
.
size
()
*
sizeof
(
int
));
auto
paddings
=
((
lite
::
SpaceToBatch
*
)
primitive
)
->
Paddings
();
(
void
)
memcpy
(
space_batch_param
->
paddings_
,
(
paddings
.
data
()),
paddings
.
size
()
*
sizeof
(
int
));
auto
in_shape
=
((
lite
::
SpaceToBatch
*
)
primitive
)
->
InShape
();
(
void
)
memcpy
(
space_batch_param
->
in_shape_
,
(
in_shape
.
data
()),
in_shape
.
size
()
*
sizeof
(
int
));
auto
padded_in_shape
=
((
lite
::
SpaceToBatch
*
)
primitive
)
->
PaddedInShape
();
(
void
)
memcpy
(
space_batch_param
->
padded_in_shape_
,
(
padded_in_shape
.
data
()),
padded_in_shape
.
size
()
*
sizeof
(
int
));
return
reinterpret_cast
<
OpParameter
*>
(
space_batch_param
);
}
OpParameter
*
PopulateResizeParameter
(
const
lite
::
Primitive
*
primitive
)
{
ResizeParameter
*
resize_param
=
new
(
std
::
nothrow
)
ResizeParameter
();
if
(
resize_param
==
nullptr
)
{
MS_LOG
(
ERROR
)
<<
"new ResizeParameter failed."
;
MS_LOG
(
ERROR
)
<<
"new ResizeParameter failed."
;
return
nullptr
;
return
nullptr
;
}
}
resize_param
->
op_parameter_
.
type_
=
primitive
->
Type
();
auto
param
=
primitive
->
Value
()
->
value_as_Resize
();
auto
param
=
primitive
->
Value
()
->
value_as_Resize
();
parameter
->
method_
=
param
->
method
();
resize_param
->
method_
=
param
->
method
();
parameter
->
new_height_
=
param
->
newHeight
();
resize_param
->
new_height_
=
param
->
newHeight
();
parameter
->
new_width_
=
param
->
newWidth
();
resize_param
->
new_width_
=
param
->
newWidth
();
parameter
->
align_corners_
=
param
->
alignCorners
();
resize_param
->
align_corners_
=
param
->
alignCorners
();
parameter
->
preserve_aspect_ratio_
=
param
->
preserveAspectRatio
();
resize_param
->
preserve_aspect_ratio_
=
param
->
preserveAspectRatio
();
return
parameter
;
return
reinterpret_cast
<
OpParameter
*>
(
resize_param
)
;
}
}
BatchToSpace
Parameter
*
PopulateBatchToSpaceParameter
(
const
lite
::
Primitive
*
primitive
)
{
Op
Parameter
*
PopulateBatchToSpaceParameter
(
const
lite
::
Primitive
*
primitive
)
{
BatchToSpaceParameter
*
parameter
=
new
(
std
::
nothrow
)
BatchToSpaceParameter
();
BatchToSpaceParameter
*
batch_space_param
=
new
(
std
::
nothrow
)
BatchToSpaceParameter
();
if
(
parameter
==
nullptr
)
{
if
(
batch_space_param
==
nullptr
)
{
MS_LOG
(
ERROR
)
<<
"New BatchToSpaceParameter fail!"
;
MS_LOG
(
ERROR
)
<<
"New BatchToSpaceParameter fail!"
;
return
nullptr
;
return
nullptr
;
}
}
batch_space_param
->
op_parameter_
.
type_
=
primitive
->
Type
();
auto
param
=
primitive
->
Value
()
->
value_as_BatchToSpace
();
auto
param
=
primitive
->
Value
()
->
value_as_BatchToSpace
();
auto
block_shape
=
param
->
blockShape
();
auto
block_shape
=
param
->
blockShape
();
if
(
block_shape
->
size
()
!=
BATCH_TO_SPACE_BLOCK_SHAPE_SIZE
)
{
if
(
block_shape
->
size
()
!=
BATCH_TO_SPACE_BLOCK_SHAPE_SIZE
)
{
...
@@ -928,308 +975,271 @@ BatchToSpaceParameter *PopulateBatchToSpaceParameter(const lite::Primitive *prim
...
@@ -928,308 +975,271 @@ BatchToSpaceParameter *PopulateBatchToSpaceParameter(const lite::Primitive *prim
}
}
for
(
int
i
=
0
;
i
<
BATCH_TO_SPACE_BLOCK_SHAPE_SIZE
;
++
i
)
{
for
(
int
i
=
0
;
i
<
BATCH_TO_SPACE_BLOCK_SHAPE_SIZE
;
++
i
)
{
parameter
->
block_shape_
[
i
]
=
block_shape
->
Get
(
i
);
batch_space_param
->
block_shape_
[
i
]
=
block_shape
->
Get
(
i
);
}
}
for
(
int
i
=
0
;
i
<
BATCH_TO_SPACE_CROPS_SIZE
;
++
i
)
{
for
(
int
i
=
0
;
i
<
BATCH_TO_SPACE_CROPS_SIZE
;
++
i
)
{
parameter
->
crops_
[
i
]
=
crops
->
Get
(
i
);
batch_space_param
->
crops_
[
i
]
=
crops
->
Get
(
i
);
}
}
return
parameter
;
return
reinterpret_cast
<
OpParameter
*>
(
batch_space_param
)
;
}
}
Cro
pParameter
*
PopulateCropParameter
(
const
lite
::
Primitive
*
primitive
)
{
O
pParameter
*
PopulateCropParameter
(
const
lite
::
Primitive
*
primitive
)
{
auto
param
=
primitive
->
Value
()
->
value_as_Crop
();
auto
param
=
primitive
->
Value
()
->
value_as_Crop
();
auto
param_offset
=
param
->
offsets
();
auto
param_offset
=
param
->
offsets
();
if
(
param_offset
->
size
()
>
CROP_OFFSET_MAX_SIZE
)
{
if
(
param_offset
->
size
()
>
CROP_OFFSET_MAX_SIZE
)
{
MS_LOG
(
ERROR
)
<<
"
parameter
offset size("
<<
param_offset
->
size
()
<<
") should <= "
<<
CROP_OFFSET_MAX_SIZE
;
MS_LOG
(
ERROR
)
<<
"
crop_param
offset size("
<<
param_offset
->
size
()
<<
") should <= "
<<
CROP_OFFSET_MAX_SIZE
;
return
nullptr
;
return
nullptr
;
}
}
CropParameter
*
parameter
=
new
(
std
::
nothrow
)
CropParameter
();
CropParameter
*
crop_param
=
new
(
std
::
nothrow
)
CropParameter
();
if
(
parameter
==
nullptr
)
{
if
(
crop_param
==
nullptr
)
{
MS_LOG
(
ERROR
)
<<
"new CropParameter fail!"
;
MS_LOG
(
ERROR
)
<<
"new CropParameter fail!"
;
return
nullptr
;
return
nullptr
;
}
}
parameter
->
axis_
=
param
->
axis
();
crop_param
->
op_parameter_
.
type_
=
primitive
->
Type
();
parameter
->
offset_size_
=
param_offset
->
size
();
crop_param
->
axis_
=
param
->
axis
();
crop_param
->
offset_size_
=
param_offset
->
size
();
for
(
int
i
=
0
;
i
<
param_offset
->
size
();
++
i
)
{
for
(
int
i
=
0
;
i
<
param_offset
->
size
();
++
i
)
{
parameter
->
offset_
[
i
]
=
param_offset
->
Get
(
i
);
crop_param
->
offset_
[
i
]
=
param_offset
->
Get
(
i
);
}
}
return
parameter
;
return
reinterpret_cast
<
OpParameter
*>
(
crop_param
)
;
}
}
O
neHot
Parameter
*
PopulateOneHotParameter
(
const
lite
::
Primitive
*
primitive
)
{
O
p
Parameter
*
PopulateOneHotParameter
(
const
lite
::
Primitive
*
primitive
)
{
OneHotParameter
*
parameter
=
new
(
std
::
nothrow
)
OneHotParameter
();
OneHotParameter
*
one_hot_param
=
new
(
std
::
nothrow
)
OneHotParameter
();
if
(
parameter
==
nullptr
)
{
if
(
one_hot_param
==
nullptr
)
{
MS_LOG
(
ERROR
)
<<
"new OneHotParameter fail!"
;
MS_LOG
(
ERROR
)
<<
"new OneHotParameter fail!"
;
return
nullptr
;
return
nullptr
;
}
}
one_hot_param
->
op_parameter_
.
type_
=
primitive
->
Type
();
auto
param
=
primitive
->
Value
()
->
value_as_OneHot
();
auto
param
=
primitive
->
Value
()
->
value_as_OneHot
();
if
(
param
==
nullptr
)
{
if
(
param
==
nullptr
)
{
delete
(
parameter
);
delete
(
one_hot_param
);
MS_LOG
(
ERROR
)
<<
"get OneHot param nullptr."
;
MS_LOG
(
ERROR
)
<<
"get OneHot param nullptr."
;
return
nullptr
;
return
nullptr
;
}
}
parameter
->
axis_
=
param
->
axis
();
one_hot_param
->
axis_
=
param
->
axis
();
return
parameter
;
return
reinterpret_cast
<
OpParameter
*>
(
one_hot_param
)
;
}
}
Flatten
Parameter
*
PopulateFlattenParameter
(
const
lite
::
Primitive
*
primitive
)
{
Op
Parameter
*
PopulateFlattenParameter
(
const
lite
::
Primitive
*
primitive
)
{
FlattenParameter
*
parameter
=
new
(
std
::
nothrow
)
FlattenParameter
();
FlattenParameter
*
flatten_param
=
new
(
std
::
nothrow
)
FlattenParameter
();
if
(
parameter
==
nullptr
)
{
if
(
flatten_param
==
nullptr
)
{
MS_LOG
(
ERROR
)
<<
"new FlattenParameter fail!"
;
MS_LOG
(
ERROR
)
<<
"new FlattenParameter fail!"
;
return
nullptr
;
return
nullptr
;
}
}
return
parameter
;
flatten_param
->
op_parameter_
.
type_
=
primitive
->
Type
();
return
reinterpret_cast
<
OpParameter
*>
(
flatten_param
);
}
}
Dequantize
Parameter
*
PopulateDequantizeParameter
(
const
lite
::
Primitive
*
primitive
)
{
Op
Parameter
*
PopulateDequantizeParameter
(
const
lite
::
Primitive
*
primitive
)
{
DequantizeParameter
*
parameter
=
new
(
std
::
nothrow
)
DequantizeParameter
();
DequantizeParameter
*
dequantize_
parameter
=
new
(
std
::
nothrow
)
DequantizeParameter
();
if
(
parameter
==
nullptr
)
{
if
(
dequantize_
parameter
==
nullptr
)
{
MS_LOG
(
ERROR
)
<<
"new DequantizeParameter fail!"
;
MS_LOG
(
ERROR
)
<<
"new DequantizeParameter fail!"
;
return
nullptr
;
return
nullptr
;
}
}
return
parameter
;
dequantize_parameter
->
op_parameter_
.
type_
=
primitive
->
Type
();
return
reinterpret_cast
<
OpParameter
*>
(
dequantize_parameter
);
}
}
Quantize
Parameter
*
PopulateQuantizeParameter
(
const
lite
::
Primitive
*
primitive
)
{
Op
Parameter
*
PopulateQuantizeParameter
(
const
lite
::
Primitive
*
primitive
)
{
QuantizeParameter
*
parameter
=
new
(
std
::
nothrow
)
QuantizeParameter
();
QuantizeParameter
*
quantize_
parameter
=
new
(
std
::
nothrow
)
QuantizeParameter
();
if
(
parameter
==
nullptr
)
{
if
(
quantize_
parameter
==
nullptr
)
{
MS_LOG
(
ERROR
)
<<
"new QuantizeParameter fail!"
;
MS_LOG
(
ERROR
)
<<
"new QuantizeParameter fail!"
;
return
nullptr
;
return
nullptr
;
}
}
return
parameter
;
quantize_parameter
->
op_parameter_
.
type_
=
primitive
->
Type
();
return
reinterpret_cast
<
OpParameter
*>
(
quantize_parameter
);
}
}
StridedSliceParameter
*
PopulateStridedSliceParam
(
const
lite
::
Primitive
*
primitive
)
{
OpParameter
*
PopulateStridedSliceParameter
(
const
lite
::
Primitive
*
primitive
)
{
StridedSliceParameter
*
parameter
=
new
(
std
::
nothrow
)
StridedSliceParameter
();
StridedSliceParameter
*
strided_slice_param
=
new
(
std
::
nothrow
)
StridedSliceParameter
();
if
(
parameter
==
nullptr
)
{
if
(
strided_slice_param
==
nullptr
)
{
MS_LOG
(
ERROR
)
<<
"new StridedSliceParameter failed."
;
MS_LOG
(
ERROR
)
<<
"new StridedSliceParameter failed."
;
return
nullptr
;
return
nullptr
;
}
}
parameter
->
op_parameter_
.
type_
=
primitive
->
Type
();
strided_slice_param
->
op_parameter_
.
type_
=
primitive
->
Type
();
auto
n_dims
=
((
lite
::
StridedSlice
*
)
primitive
)
->
NDims
();
auto
n_dims
=
((
lite
::
StridedSlice
*
)
primitive
)
->
NDims
();
parameter
->
num_axes_
=
n_dims
;
strided_slice_param
->
num_axes_
=
n_dims
;
auto
begin
=
((
lite
::
StridedSlice
*
)
primitive
)
->
UpdatedBegins
();
auto
begin
=
((
lite
::
StridedSlice
*
)
primitive
)
->
UpdatedBegins
();
(
void
)
memcpy
(
parameter
->
begins_
,
(
begin
.
data
()),
begin
.
size
()
*
sizeof
(
int
));
(
void
)
memcpy
(
strided_slice_param
->
begins_
,
(
begin
.
data
()),
begin
.
size
()
*
sizeof
(
int
));
auto
end
=
((
lite
::
StridedSlice
*
)
primitive
)
->
UpdatedEnds
();
auto
end
=
((
lite
::
StridedSlice
*
)
primitive
)
->
UpdatedEnds
();
(
void
)
memcpy
(
parameter
->
ends_
,
(
end
.
data
()),
end
.
size
()
*
sizeof
(
int
));
(
void
)
memcpy
(
strided_slice_param
->
ends_
,
(
end
.
data
()),
end
.
size
()
*
sizeof
(
int
));
auto
stride
=
((
lite
::
StridedSlice
*
)
primitive
)
->
UpdatedStrides
();
auto
stride
=
((
lite
::
StridedSlice
*
)
primitive
)
->
UpdatedStrides
();
(
void
)
memcpy
(
parameter
->
strides_
,
(
stride
.
data
()),
stride
.
size
()
*
sizeof
(
int
));
(
void
)
memcpy
(
strided_slice_param
->
strides_
,
(
stride
.
data
()),
stride
.
size
()
*
sizeof
(
int
));
auto
in_shape
=
((
lite
::
StridedSlice
*
)
primitive
)
->
UpdatedInShape
();
auto
in_shape
=
((
lite
::
StridedSlice
*
)
primitive
)
->
UpdatedInShape
();
(
void
)
memcpy
(
parameter
->
in_shape_
,
(
in_shape
.
data
()),
in_shape
.
size
()
*
sizeof
(
int
));
(
void
)
memcpy
(
strided_slice_param
->
in_shape_
,
(
in_shape
.
data
()),
in_shape
.
size
()
*
sizeof
(
int
));
return
parameter
;
return
reinterpret_cast
<
OpParameter
*>
(
strided_slice_param
)
;
}
}
OpParameter
*
PopulateAddNParam
(
const
lite
::
Primitive
*
primitive
)
{
OpParameter
*
PopulateAddNParam
eter
(
const
lite
::
Primitive
*
primitive
)
{
auto
parameter
=
new
(
std
::
nothrow
)
OpParameter
();
auto
addn_param
=
new
(
std
::
nothrow
)
OpParameter
();
if
(
parameter
==
nullptr
)
{
if
(
addn_param
==
nullptr
)
{
MS_LOG
(
ERROR
)
<<
"new OpParameter fail!"
;
MS_LOG
(
ERROR
)
<<
"new OpParameter fail!"
;
return
nullptr
;
return
nullptr
;
}
}
parameter
->
type_
=
primitive
->
Type
();
addn_param
->
type_
=
primitive
->
Type
();
return
parameter
;
return
reinterpret_cast
<
OpParameter
*>
(
addn_param
)
;
}
}
PriorBox
Parameter
*
PopulatePriorBoxParameter
(
const
lite
::
Primitive
*
primitive
)
{
Op
Parameter
*
PopulatePriorBoxParameter
(
const
lite
::
Primitive
*
primitive
)
{
PriorBoxParameter
*
param
=
new
(
std
::
nothrow
)
PriorBoxParameter
();
PriorBoxParameter
*
p
rior_box_p
aram
=
new
(
std
::
nothrow
)
PriorBoxParameter
();
if
(
param
==
nullptr
)
{
if
(
p
rior_box_p
aram
==
nullptr
)
{
MS_LOG
(
ERROR
)
<<
"new PriorBoxParameter failed."
;
MS_LOG
(
ERROR
)
<<
"new PriorBoxParameter failed."
;
return
nullptr
;
return
nullptr
;
}
}
param
->
op_parameter_
.
type_
=
primitive
->
Type
();
p
rior_box_p
aram
->
op_parameter_
.
type_
=
primitive
->
Type
();
auto
prior_box_
param
=
primitive
->
Value
()
->
value_as_PriorBox
();
auto
prior_box_
attr
=
primitive
->
Value
()
->
value_as_PriorBox
();
if
(
prior_box_
param
->
min_sizes
()
->
size
()
>
PRIOR_BOX_MAX_NUM
)
{
if
(
prior_box_
attr
->
min_sizes
()
->
size
()
>
PRIOR_BOX_MAX_NUM
)
{
MS_LOG
(
ERROR
)
<<
"PriorBox min_sizes size exceeds max num "
<<
PRIOR_BOX_MAX_NUM
<<
", got "
MS_LOG
(
ERROR
)
<<
"PriorBox min_sizes size exceeds max num "
<<
PRIOR_BOX_MAX_NUM
<<
", got "
<<
prior_box_
param
->
min_sizes
();
<<
prior_box_
attr
->
min_sizes
();
delete
(
param
);
delete
(
p
rior_box_p
aram
);
return
nullptr
;
return
nullptr
;
}
}
p
aram
->
min_sizes_size
=
prior_box_param
->
min_sizes
()
->
size
();
p
rior_box_param
->
min_sizes_size
=
prior_box_attr
->
min_sizes
()
->
size
();
if
(
prior_box_
param
->
max_sizes
()
->
size
()
>
PRIOR_BOX_MAX_NUM
)
{
if
(
prior_box_
attr
->
max_sizes
()
->
size
()
>
PRIOR_BOX_MAX_NUM
)
{
MS_LOG
(
ERROR
)
<<
"PriorBox max_sizes size exceeds max num "
<<
PRIOR_BOX_MAX_NUM
<<
", got "
MS_LOG
(
ERROR
)
<<
"PriorBox max_sizes size exceeds max num "
<<
PRIOR_BOX_MAX_NUM
<<
", got "
<<
prior_box_
param
->
max_sizes
();
<<
prior_box_
attr
->
max_sizes
();
delete
(
param
);
delete
(
p
rior_box_p
aram
);
return
nullptr
;
return
nullptr
;
}
}
p
aram
->
max_sizes_size
=
prior_box_param
->
max_sizes
()
->
size
();
p
rior_box_param
->
max_sizes_size
=
prior_box_attr
->
max_sizes
()
->
size
();
(
void
)
memcpy
(
p
aram
->
max_sizes
,
prior_box_param
->
max_sizes
()
->
data
(),
(
void
)
memcpy
(
p
rior_box_param
->
max_sizes
,
prior_box_attr
->
max_sizes
()
->
data
(),
prior_box_
param
->
max_sizes
()
->
size
()
*
sizeof
(
int32_t
));
prior_box_
attr
->
max_sizes
()
->
size
()
*
sizeof
(
int32_t
));
(
void
)
memcpy
(
p
aram
->
min_sizes
,
prior_box_param
->
min_sizes
()
->
data
(),
(
void
)
memcpy
(
p
rior_box_param
->
min_sizes
,
prior_box_attr
->
min_sizes
()
->
data
(),
prior_box_
param
->
min_sizes
()
->
size
()
*
sizeof
(
int32_t
));
prior_box_
attr
->
min_sizes
()
->
size
()
*
sizeof
(
int32_t
));
if
(
prior_box_
param
->
aspect_ratios
()
->
size
()
>
PRIOR_BOX_MAX_NUM
)
{
if
(
prior_box_
attr
->
aspect_ratios
()
->
size
()
>
PRIOR_BOX_MAX_NUM
)
{
MS_LOG
(
ERROR
)
<<
"PriorBox aspect_ratios size exceeds max num "
<<
PRIOR_BOX_MAX_NUM
<<
", got "
MS_LOG
(
ERROR
)
<<
"PriorBox aspect_ratios size exceeds max num "
<<
PRIOR_BOX_MAX_NUM
<<
", got "
<<
prior_box_
param
->
aspect_ratios
();
<<
prior_box_
attr
->
aspect_ratios
();
delete
(
param
);
delete
(
p
rior_box_p
aram
);
return
nullptr
;
return
nullptr
;
}
}
p
aram
->
aspect_ratios_size
=
prior_box_param
->
aspect_ratios
()
->
size
();
p
rior_box_param
->
aspect_ratios_size
=
prior_box_attr
->
aspect_ratios
()
->
size
();
(
void
)
memcpy
(
p
aram
->
aspect_ratios
,
prior_box_param
->
aspect_ratios
()
->
data
(),
(
void
)
memcpy
(
p
rior_box_param
->
aspect_ratios
,
prior_box_attr
->
aspect_ratios
()
->
data
(),
prior_box_
param
->
aspect_ratios
()
->
size
()
*
sizeof
(
float
));
prior_box_
attr
->
aspect_ratios
()
->
size
()
*
sizeof
(
float
));
if
(
prior_box_
param
->
variances
()
->
size
()
!=
PRIOR_BOX_VAR_NUM
)
{
if
(
prior_box_
attr
->
variances
()
->
size
()
!=
PRIOR_BOX_VAR_NUM
)
{
MS_LOG
(
ERROR
)
<<
"PriorBox variances size should be "
<<
PRIOR_BOX_VAR_NUM
<<
", got "
MS_LOG
(
ERROR
)
<<
"PriorBox variances size should be "
<<
PRIOR_BOX_VAR_NUM
<<
", got "
<<
prior_box_
param
->
variances
()
->
size
();
<<
prior_box_
attr
->
variances
()
->
size
();
delete
(
param
);
delete
(
p
rior_box_p
aram
);
return
nullptr
;
return
nullptr
;
}
}
(
void
)
memcpy
(
p
aram
->
variances
,
prior_box_param
->
variances
()
->
data
(),
PRIOR_BOX_VAR_NUM
*
sizeof
(
float
));
(
void
)
memcpy
(
p
rior_box_param
->
variances
,
prior_box_attr
->
variances
()
->
data
(),
PRIOR_BOX_VAR_NUM
*
sizeof
(
float
));
p
aram
->
flip
=
prior_box_param
->
flip
();
p
rior_box_param
->
flip
=
prior_box_attr
->
flip
();
p
aram
->
clip
=
prior_box_param
->
clip
();
p
rior_box_param
->
clip
=
prior_box_attr
->
clip
();
p
aram
->
offset
=
prior_box_param
->
offset
();
p
rior_box_param
->
offset
=
prior_box_attr
->
offset
();
p
aram
->
image_size_h
=
prior_box_param
->
image_size_h
();
p
rior_box_param
->
image_size_h
=
prior_box_attr
->
image_size_h
();
p
aram
->
image_size_w
=
prior_box_param
->
image_size_w
();
p
rior_box_param
->
image_size_w
=
prior_box_attr
->
image_size_w
();
p
aram
->
step_h
=
prior_box_param
->
step_h
();
p
rior_box_param
->
step_h
=
prior_box_attr
->
step_h
();
p
aram
->
step_w
=
prior_box_param
->
step_w
();
p
rior_box_param
->
step_w
=
prior_box_attr
->
step_w
();
return
param
;
return
reinterpret_cast
<
OpParameter
*>
(
prior_box_param
)
;
}
}
SpaceToBatchParameter
*
PopulateSpaceToBatchParam
(
const
lite
::
Primitive
*
primitive
)
{
PopulateParameterRegistry
::
PopulateParameterRegistry
()
{
SpaceToBatchParameter
*
parameter
=
new
(
std
::
nothrow
)
SpaceToBatchParameter
();
populate_parameter_funcs_
[
schema
::
PrimitiveType_SoftMax
]
=
PopulateSoftmaxParameter
;
if
(
parameter
==
nullptr
)
{
populate_parameter_funcs_
[
schema
::
PrimitiveType_Activation
]
=
PopulateActivationParameter
;
MS_LOG
(
ERROR
)
<<
"new SpaceToBatchParameter failed."
;
populate_parameter_funcs_
[
schema
::
PrimitiveType_Conv2D
]
=
PopulateConvParameter
;
return
nullptr
;
populate_parameter_funcs_
[
schema
::
PrimitiveType_Reduce
]
=
PopulateReduceParameter
;
}
populate_parameter_funcs_
[
schema
::
PrimitiveType_Pooling
]
=
PopulatePoolingParameter
;
parameter
->
op_parameter_
.
type_
=
primitive
->
Type
();
populate_parameter_funcs_
[
schema
::
PrimitiveType_DepthwiseConv2D
]
=
PopulateConvDwParameter
;
auto
block_sizes
=
((
lite
::
SpaceToBatch
*
)
primitive
)
->
BlockSizes
();
populate_parameter_funcs_
[
schema
::
PrimitiveType_DeDepthwiseConv2D
]
=
PopulateDeconvDwParameter
;
(
void
)
memcpy
(
parameter
->
block_sizes_
,
(
block_sizes
.
data
()),
block_sizes
.
size
()
*
sizeof
(
int
));
populate_parameter_funcs_
[
schema
::
PrimitiveType_DeConv2D
]
=
PopulateDeconvParameter
;
auto
paddings
=
((
lite
::
SpaceToBatch
*
)
primitive
)
->
Paddings
();
populate_parameter_funcs_
[
schema
::
PrimitiveType_FusedBatchNorm
]
=
PopulateFusedBatchNorm
;
(
void
)
memcpy
(
parameter
->
paddings_
,
(
paddings
.
data
()),
paddings
.
size
()
*
sizeof
(
int
));
populate_parameter_funcs_
[
schema
::
PrimitiveType_FullConnection
]
=
PopulateFullconnectionParameter
;
auto
in_shape
=
((
lite
::
SpaceToBatch
*
)
primitive
)
->
InShape
();
populate_parameter_funcs_
[
schema
::
PrimitiveType_Power
]
=
PopulatePowerParameter
;
(
void
)
memcpy
(
parameter
->
in_shape_
,
(
in_shape
.
data
()),
in_shape
.
size
()
*
sizeof
(
int
));
populate_parameter_funcs_
[
schema
::
PrimitiveType_LocalResponseNormalization
]
=
PopulateLocalResponseNormParameter
;
auto
padded_in_shape
=
((
lite
::
SpaceToBatch
*
)
primitive
)
->
PaddedInShape
();
populate_parameter_funcs_
[
schema
::
PrimitiveType_Range
]
=
PopulateRangeParameter
;
(
void
)
memcpy
(
parameter
->
padded_in_shape_
,
(
padded_in_shape
.
data
()),
padded_in_shape
.
size
()
*
sizeof
(
int
));
populate_parameter_funcs_
[
schema
::
PrimitiveType_Transpose
]
=
PopulateTransposeParameter
;
return
parameter
;
populate_parameter_funcs_
[
schema
::
PrimitiveType_Mul
]
=
PopulateArithmetic
;
populate_parameter_funcs_
[
schema
::
PrimitiveType_Add
]
=
PopulateArithmetic
;
populate_parameter_funcs_
[
schema
::
PrimitiveType_Sub
]
=
PopulateArithmetic
;
populate_parameter_funcs_
[
schema
::
PrimitiveType_Div
]
=
PopulateArithmetic
;
populate_parameter_funcs_
[
schema
::
PrimitiveType_FloorDiv
]
=
PopulateArithmetic
;
populate_parameter_funcs_
[
schema
::
PrimitiveType_FloorMod
]
=
PopulateArithmetic
;
populate_parameter_funcs_
[
schema
::
PrimitiveType_SquaredDifference
]
=
PopulateArithmetic
;
populate_parameter_funcs_
[
schema
::
PrimitiveType_BiasAdd
]
=
PopulateArithmetic
;
populate_parameter_funcs_
[
schema
::
PrimitiveType_Eltwise
]
=
PopulateEltwiseParameter
;
populate_parameter_funcs_
[
schema
::
PrimitiveType_ExpandDims
]
=
PopulateExpandDimsParameter
;
populate_parameter_funcs_
[
schema
::
PrimitiveType_Abs
]
=
PopulateArithmeticSelf
;
populate_parameter_funcs_
[
schema
::
PrimitiveType_Cos
]
=
PopulateArithmeticSelf
;
populate_parameter_funcs_
[
schema
::
PrimitiveType_Sin
]
=
PopulateArithmeticSelf
;
populate_parameter_funcs_
[
schema
::
PrimitiveType_Exp
]
=
PopulateArithmeticSelf
;
populate_parameter_funcs_
[
schema
::
PrimitiveType_Log
]
=
PopulateArithmeticSelf
;
populate_parameter_funcs_
[
schema
::
PrimitiveType_Square
]
=
PopulateArithmeticSelf
;
populate_parameter_funcs_
[
schema
::
PrimitiveType_Sqrt
]
=
PopulateArithmeticSelf
;
populate_parameter_funcs_
[
schema
::
PrimitiveType_Rsqrt
]
=
PopulateArithmeticSelf
;
populate_parameter_funcs_
[
schema
::
PrimitiveType_LogicalNot
]
=
PopulateArithmeticSelf
;
populate_parameter_funcs_
[
schema
::
PrimitiveType_Floor
]
=
PopulateArithmeticSelf
;
populate_parameter_funcs_
[
schema
::
PrimitiveType_Ceil
]
=
PopulateArithmeticSelf
;
populate_parameter_funcs_
[
schema
::
PrimitiveType_ArgMax
]
=
PopulateArgMaxParameter
;
populate_parameter_funcs_
[
schema
::
PrimitiveType_ArgMin
]
=
PopulateArgMinParameter
;
populate_parameter_funcs_
[
schema
::
PrimitiveType_Cast
]
=
PopulateCastParameter
;
populate_parameter_funcs_
[
schema
::
PrimitiveType_Scale
]
=
PopulateScaleParameter
;
populate_parameter_funcs_
[
schema
::
PrimitiveType_Reshape
]
=
PopulateReshapeParameter
;
populate_parameter_funcs_
[
schema
::
PrimitiveType_Concat
]
=
PopulateConcatParameter
;
populate_parameter_funcs_
[
schema
::
PrimitiveType_Tile
]
=
PopulateTileParameter
;
populate_parameter_funcs_
[
schema
::
PrimitiveType_TopK
]
=
PopulateTopKParameter
;
populate_parameter_funcs_
[
schema
::
PrimitiveType_Fill
]
=
PopulateFillParameter
;
populate_parameter_funcs_
[
schema
::
PrimitiveType_Gather
]
=
PopulateGatherParameter
;
populate_parameter_funcs_
[
schema
::
PrimitiveType_GatherNd
]
=
PopulateGatherNdParameter
;
populate_parameter_funcs_
[
schema
::
PrimitiveType_Slice
]
=
PopulateSliceParameter
;
populate_parameter_funcs_
[
schema
::
PrimitiveType_BroadcastTo
]
=
PopulateBroadcastToParameter
;
populate_parameter_funcs_
[
schema
::
PrimitiveType_Reverse
]
=
PopulateReverseParameter
;
populate_parameter_funcs_
[
schema
::
PrimitiveType_Stack
]
=
PopulateStackParameter
;
populate_parameter_funcs_
[
schema
::
PrimitiveType_Unstack
]
=
PopulateUnstackParameter
;
populate_parameter_funcs_
[
schema
::
PrimitiveType_ReverseSequence
]
=
PopulateReverseSequenceParameter
;
populate_parameter_funcs_
[
schema
::
PrimitiveType_Unique
]
=
PopulateUniqueParameter
;
populate_parameter_funcs_
[
schema
::
PrimitiveType_DepthToSpace
]
=
PopulateDepthToSpaceParameter
;
populate_parameter_funcs_
[
schema
::
PrimitiveType_Nchw2Nhwc
]
=
PopulateNchw2NhwcParameter
;
populate_parameter_funcs_
[
schema
::
PrimitiveType_Nhwc2Nchw
]
=
PopulateNhwc2NchwParameter
;
populate_parameter_funcs_
[
schema
::
PrimitiveType_Pad
]
=
PopulatePadParameter
;
populate_parameter_funcs_
[
schema
::
PrimitiveType_Resize
]
=
PopulateResizeParameter
;
populate_parameter_funcs_
[
schema
::
PrimitiveType_BatchToSpace
]
=
PopulateBatchToSpaceParameter
;
populate_parameter_funcs_
[
schema
::
PrimitiveType_SpaceToDepth
]
=
PopulateSpaceToDepthParameter
;
populate_parameter_funcs_
[
schema
::
PrimitiveType_SpaceToBatch
]
=
PopulateSpaceToBatchParameter
;
populate_parameter_funcs_
[
schema
::
PrimitiveType_Crop
]
=
PopulateCropParameter
;
populate_parameter_funcs_
[
schema
::
PrimitiveType_Unsqueeze
]
=
PopulateUnsqueezeParameter
;
populate_parameter_funcs_
[
schema
::
PrimitiveType_Flatten
]
=
PopulateFlattenParameter
;
populate_parameter_funcs_
[
schema
::
PrimitiveType_MatMul
]
=
PopulateMatMulParameter
;
populate_parameter_funcs_
[
schema
::
PrimitiveType_OneHot
]
=
PopulateOneHotParameter
;
populate_parameter_funcs_
[
schema
::
PrimitiveType_AddN
]
=
PopulateAddNParameter
;
populate_parameter_funcs_
[
schema
::
PrimitiveType_StridedSlice
]
=
PopulateStridedSliceParameter
;
populate_parameter_funcs_
[
schema
::
PrimitiveType_ScatterND
]
=
PopulateScatterNDParameter
;
populate_parameter_funcs_
[
schema
::
PrimitiveType_Square
]
=
PopulateSqueezeParameter
;
populate_parameter_funcs_
[
schema
::
PrimitiveType_Split
]
=
PopulateSplitParameter
;
populate_parameter_funcs_
[
schema
::
PrimitiveType_PriorBox
]
=
PopulatePriorBoxParameter
;
populate_parameter_funcs_
[
schema
::
PrimitiveType_OnnxInt8Dequantize
]
=
PopulateDequantizeParameter
;
populate_parameter_funcs_
[
schema
::
PrimitiveType_OnnxInt8Quantize
]
=
PopulateQuantizeParameter
;
}
PopulateParameterRegistry
*
PopulateParameterRegistry
::
GetInstance
()
{
static
PopulateParameterRegistry
populate_parameter_instance
;
return
&
populate_parameter_instance
;
}
PopulateParameterFunc
PopulateParameterRegistry
::
GetParameterFunc
(
const
schema
::
PrimitiveType
&
type
)
{
return
populate_parameter_funcs_
[
type
];
}
}
OpParameter
*
PopulateParameter
(
const
lite
::
Primitive
*
primitive
)
{
OpParameter
*
PopulateParameter
(
const
lite
::
Primitive
*
primitive
)
{
MS_EXCEPTION_IF_NULL
(
primitive
);
if
(
primitive
==
nullptr
)
{
MS_LOG
(
ERROR
)
<<
"Primitive is nullptr when populating parameter for op."
;
return
nullptr
;
}
auto
op_type
=
primitive
->
Type
();
auto
op_type
=
primitive
->
Type
();
switch
(
op_type
)
{
auto
func
=
PopulateParameterRegistry
::
GetInstance
()
->
GetParameterFunc
(
op_type
);
case
schema
::
PrimitiveType_SoftMax
:
if
(
func
==
nullptr
)
{
return
reinterpret_cast
<
OpParameter
*>
(
PopulateSoftmaxParameter
(
primitive
));
MS_LOG
(
ERROR
)
<<
"Get nullptr for Op Parameter Func."
;
case
schema
::
PrimitiveType_Activation
:
return
nullptr
;
return
reinterpret_cast
<
OpParameter
*>
(
PopulateActivationParameter
(
primitive
));
case
schema
::
PrimitiveType_Conv2D
:
return
reinterpret_cast
<
OpParameter
*>
(
PopulateConvParameter
(
primitive
));
case
schema
::
PrimitiveType_Reduce
:
return
reinterpret_cast
<
OpParameter
*>
(
PopulateReduceParameter
(
primitive
));
case
schema
::
PrimitiveType_Pooling
:
return
reinterpret_cast
<
OpParameter
*>
(
PopulatePoolingParam
(
primitive
));
case
schema
::
PrimitiveType_DepthwiseConv2D
:
return
reinterpret_cast
<
OpParameter
*>
(
PopulateConvDwParameter
(
primitive
));
case
schema
::
PrimitiveType_DeDepthwiseConv2D
:
return
reinterpret_cast
<
OpParameter
*>
(
PopulateDeconvDwParameter
(
primitive
));
case
schema
::
PrimitiveType_DeConv2D
:
return
reinterpret_cast
<
OpParameter
*>
(
PopulateDeconvParameter
(
primitive
));
case
schema
::
PrimitiveType_FusedBatchNorm
:
return
reinterpret_cast
<
OpParameter
*>
(
PopulateFusedBatchNorm
(
primitive
));
case
schema
::
PrimitiveType_FullConnection
:
return
reinterpret_cast
<
OpParameter
*>
(
PopulateFullconnectionParameter
(
primitive
));
case
schema
::
PrimitiveType_Power
:
return
reinterpret_cast
<
OpParameter
*>
(
PopulatePowerParameter
(
primitive
));
case
schema
::
PrimitiveType_LocalResponseNormalization
:
return
reinterpret_cast
<
OpParameter
*>
(
PopulateLocalResponseNormParameter
(
primitive
));
case
schema
::
PrimitiveType_Range
:
return
reinterpret_cast
<
OpParameter
*>
(
PopulateRangeParameter
(
primitive
));
case
schema
::
PrimitiveType_Transpose
:
return
reinterpret_cast
<
OpParameter
*>
(
PopulateTransposeParameter
(
primitive
));
case
schema
::
PrimitiveType_Mul
:
case
schema
::
PrimitiveType_Add
:
case
schema
::
PrimitiveType_Sub
:
case
schema
::
PrimitiveType_Div
:
case
schema
::
PrimitiveType_FloorDiv
:
case
schema
::
PrimitiveType_FloorMod
:
case
schema
::
PrimitiveType_SquaredDifference
:
return
reinterpret_cast
<
OpParameter
*>
(
PopulateArithmetic
(
primitive
));
case
schema
::
PrimitiveType_BiasAdd
:
return
reinterpret_cast
<
OpParameter
*>
(
new
ArithmeticParameter
());
case
schema
::
PrimitiveType_Eltwise
:
return
reinterpret_cast
<
OpParameter
*>
(
PopulateEltwiseParam
(
primitive
));
case
schema
::
PrimitiveType_ExpandDims
:
return
reinterpret_cast
<
OpParameter
*>
(
PopulateExpandDimsParam
(
primitive
));
case
schema
::
PrimitiveType_Abs
:
case
schema
::
PrimitiveType_Cos
:
case
schema
::
PrimitiveType_Sin
:
case
schema
::
PrimitiveType_Exp
:
case
schema
::
PrimitiveType_Log
:
case
schema
::
PrimitiveType_Square
:
case
schema
::
PrimitiveType_Sqrt
:
case
schema
::
PrimitiveType_Rsqrt
:
case
schema
::
PrimitiveType_LogicalNot
:
case
schema
::
PrimitiveType_Floor
:
return
reinterpret_cast
<
OpParameter
*>
(
PopulateArithmeticSelf
(
primitive
));
case
schema
::
PrimitiveType_ArgMax
:
return
reinterpret_cast
<
OpParameter
*>
(
PopulateArgMaxParam
(
primitive
));
case
schema
::
PrimitiveType_ArgMin
:
return
reinterpret_cast
<
OpParameter
*>
(
PopulateArgMinParam
(
primitive
));
case
schema
::
PrimitiveType_Cast
:
return
reinterpret_cast
<
OpParameter
*>
(
PopulateCastParam
(
primitive
));
case
schema
::
PrimitiveType_Ceil
:
return
reinterpret_cast
<
OpParameter
*>
(
PopulateCeilParameter
(
primitive
));
case
schema
::
PrimitiveType_Scale
:
return
reinterpret_cast
<
OpParameter
*>
(
PopulateScaleParameter
(
primitive
));
case
schema
::
PrimitiveType_Reshape
:
return
reinterpret_cast
<
OpParameter
*>
(
PopulateReshapeParam
(
primitive
));
case
schema
::
PrimitiveType_Concat
:
return
reinterpret_cast
<
OpParameter
*>
(
PopulateConcatParameter
(
primitive
));
case
schema
::
PrimitiveType_Tile
:
return
reinterpret_cast
<
OpParameter
*>
(
PopulateTileParameter
(
primitive
));
case
schema
::
PrimitiveType_TopK
:
return
reinterpret_cast
<
OpParameter
*>
(
PopulateTopKParameter
(
primitive
));
case
schema
::
PrimitiveType_Fill
:
return
reinterpret_cast
<
OpParameter
*>
(
PopulateFillParam
(
primitive
));
case
schema
::
PrimitiveType_Gather
:
return
reinterpret_cast
<
OpParameter
*>
(
PopulateGatherParameter
(
primitive
));
case
schema
::
PrimitiveType_GatherNd
:
return
reinterpret_cast
<
OpParameter
*>
(
PopulateGatherNdParameter
(
primitive
));
case
schema
::
PrimitiveType_Slice
:
return
reinterpret_cast
<
OpParameter
*>
(
PopulateSliceParam
(
primitive
));
case
schema
::
PrimitiveType_BroadcastTo
:
return
reinterpret_cast
<
OpParameter
*>
(
PopulateBroadcastToParam
(
primitive
));
case
schema
::
PrimitiveType_Reverse
:
return
reinterpret_cast
<
OpParameter
*>
(
PopulateReverseParameter
(
primitive
));
case
schema
::
PrimitiveType_Stack
:
return
reinterpret_cast
<
OpParameter
*>
(
PopulateStackParam
(
primitive
));
case
schema
::
PrimitiveType_Unstack
:
return
reinterpret_cast
<
OpParameter
*>
(
PopulateUnstackParam
(
primitive
));
case
schema
::
PrimitiveType_ReverseSequence
:
return
reinterpret_cast
<
OpParameter
*>
(
PopulateReverseSequenceParam
(
primitive
));
case
schema
::
PrimitiveType_Unique
:
return
reinterpret_cast
<
OpParameter
*>
(
PopulateUniqueParam
(
primitive
));
case
schema
::
PrimitiveType_DepthToSpace
:
return
reinterpret_cast
<
OpParameter
*>
(
PopulateDepthToSpaceParam
(
primitive
));
case
schema
::
PrimitiveType_Nchw2Nhwc
:
return
reinterpret_cast
<
OpParameter
*>
(
PopulateNchw2NhwcParameter
(
primitive
));
case
schema
::
PrimitiveType_Nhwc2Nchw
:
return
reinterpret_cast
<
OpParameter
*>
(
PopulateNhwc2NchwParameter
(
primitive
));
case
schema
::
PrimitiveType_Pad
:
return
reinterpret_cast
<
OpParameter
*>
(
PopulatePadParameter
(
primitive
));
case
schema
::
PrimitiveType_Resize
:
return
reinterpret_cast
<
OpParameter
*>
(
PopulateResizeParameter
(
primitive
));
case
schema
::
PrimitiveType_BatchToSpace
:
return
reinterpret_cast
<
OpParameter
*>
(
PopulateBatchToSpaceParameter
(
primitive
));
case
schema
::
PrimitiveType_Crop
:
return
reinterpret_cast
<
OpParameter
*>
(
PopulateCropParameter
(
primitive
));
case
schema
::
PrimitiveType_Unsqueeze
:
return
reinterpret_cast
<
OpParameter
*>
(
PopulateUnsqueezeParameter
(
primitive
));
case
schema
::
PrimitiveType_Flatten
:
return
reinterpret_cast
<
OpParameter
*>
(
PopulateFlattenParameter
(
primitive
));
case
schema
::
PrimitiveType_MatMul
:
return
reinterpret_cast
<
OpParameter
*>
(
PopulateMatMulParameter
(
primitive
));
case
schema
::
PrimitiveType_OneHot
:
return
reinterpret_cast
<
OpParameter
*>
(
PopulateOneHotParameter
(
primitive
));
case
schema
::
PrimitiveType_AddN
:
return
reinterpret_cast
<
OpParameter
*>
(
PopulateAddNParam
(
primitive
));
case
schema
::
PrimitiveType_PriorBox
:
return
reinterpret_cast
<
OpParameter
*>
(
PopulatePriorBoxParameter
(
primitive
));
case
schema
::
PrimitiveType_OnnxInt8Dequantize
:
return
reinterpret_cast
<
OpParameter
*>
(
PopulateDequantizeParameter
(
primitive
));
case
schema
::
PrimitiveType_OnnxInt8Quantize
:
return
reinterpret_cast
<
OpParameter
*>
(
PopulateQuantizeParameter
(
primitive
));
default:
break
;
}
}
auto
*
parameter
=
func
(
primitive
);
if
(
parameter
==
nullptr
)
{
MS_LOG
(
ERROR
)
<<
"Get nullptr for Op Parameter."
;
return
nullptr
;
return
nullptr
;
}
return
parameter
;
}
}
}
// namespace mindspore::kernel
}
// namespace mindspore::kernel
mindspore/lite/src/populate_parameter.h
浏览文件 @
ecb87385
...
@@ -22,7 +22,20 @@
...
@@ -22,7 +22,20 @@
#include "src/runtime/kernel/arm/opclib/op_base.h"
#include "src/runtime/kernel/arm/opclib/op_base.h"
namespace
mindspore
::
kernel
{
namespace
mindspore
::
kernel
{
typedef
OpParameter
*
(
*
PopulateParameterFunc
)(
const
lite
::
Primitive
*
);
class
PopulateParameterRegistry
{
public:
PopulateParameterRegistry
();
~
PopulateParameterRegistry
()
=
default
;
static
PopulateParameterRegistry
*
GetInstance
();
PopulateParameterFunc
GetParameterFunc
(
const
schema
::
PrimitiveType
&
type
);
protected:
PopulateParameterFunc
populate_parameter_funcs_
[
schema
::
PrimitiveType_MAX
+
1
];
};
OpParameter
*
PopulateParameter
(
const
lite
::
Primitive
*
primitive
);
OpParameter
*
PopulateParameter
(
const
lite
::
Primitive
*
primitive
);
}
// namespace mindspore::kernel
}
// namespace mindspore::kernel
#endif // MINDSPORE_LITE_SRC_POPULATE_PARAMETER_H_
#endif // MINDSPORE_LITE_SRC_POPULATE_PARAMETER_H_
mindspore/lite/src/runtime/kernel/arm/opclib/arithmetic_common.h
浏览文件 @
ecb87385
...
@@ -24,7 +24,7 @@
...
@@ -24,7 +24,7 @@
#include "src/runtime/kernel/arm/opclib/arithmetic_common.h"
#include "src/runtime/kernel/arm/opclib/arithmetic_common.h"
struct
ArithmeticParameter
{
struct
ArithmeticParameter
{
OpParameter
op_parameter
;
OpParameter
op_parameter
_
;
bool
broadcasting_
;
bool
broadcasting_
;
size_t
ndim_
;
size_t
ndim_
;
int
in_shape0_
[
5
];
int
in_shape0_
[
5
];
...
...
mindspore/lite/src/runtime/kernel/arm/opclib/fp32/softmax.h
浏览文件 @
ecb87385
...
@@ -20,7 +20,7 @@
...
@@ -20,7 +20,7 @@
#include "src/runtime/kernel/arm/opclib/op_base.h"
#include "src/runtime/kernel/arm/opclib/op_base.h"
struct
SoftmaxParameter
{
struct
SoftmaxParameter
{
OpParameter
op_parameter
;
OpParameter
op_parameter
_
;
int32_t
axis_
;
int32_t
axis_
;
int
element_size_
;
int
element_size_
;
int
n_dim_
;
int
n_dim_
;
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录