Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
magicwindyyd
mindspore
提交
f8e4ab86
M
mindspore
项目概览
magicwindyyd
/
mindspore
与 Fork 源项目一致
Fork自
MindSpore / mindspore
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
M
mindspore
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
f8e4ab86
编写于
8月 07, 2020
作者:
M
mindspore-ci-bot
提交者:
Gitee
8月 07, 2020
浏览文件
操作
浏览文件
下载
差异文件
!4113 Add fused_activation function for Sub, Add, Mul and Div op
Merge pull request !4113 from wangminggui/master
上级
a7185d7e
4327fd7e
变更
8
隐藏空白更改
内联
并排
Showing
8 changed file
with
359 addition
and
68 deletion
+359
-68
mindspore/lite/schema/ops.fbs
mindspore/lite/schema/ops.fbs
+4
-4
mindspore/lite/src/populate_parameter.cc
mindspore/lite/src/populate_parameter.cc
+17
-0
mindspore/lite/src/runtime/kernel/arm/fp32/arithmetic.cc
mindspore/lite/src/runtime/kernel/arm/fp32/arithmetic.cc
+18
-21
mindspore/lite/src/runtime/kernel/arm/fp32/arithmetic.h
mindspore/lite/src/runtime/kernel/arm/fp32/arithmetic.h
+45
-9
mindspore/lite/src/runtime/kernel/arm/nnacl/arithmetic_common.h
...ore/lite/src/runtime/kernel/arm/nnacl/arithmetic_common.h
+1
-1
mindspore/lite/src/runtime/kernel/arm/nnacl/fp32/activation.h
...spore/lite/src/runtime/kernel/arm/nnacl/fp32/activation.h
+1
-1
mindspore/lite/src/runtime/kernel/arm/nnacl/fp32/arithmetic.cc
...pore/lite/src/runtime/kernel/arm/nnacl/fp32/arithmetic.cc
+265
-32
mindspore/lite/src/runtime/kernel/arm/nnacl/fp32/arithmetic.h
...spore/lite/src/runtime/kernel/arm/nnacl/fp32/arithmetic.h
+8
-0
未找到文件。
mindspore/lite/schema/ops.fbs
浏览文件 @
f8e4ab86
...
...
@@ -384,19 +384,19 @@ table Eltwise {
}
table Add {
activationType
: ActivationType
;
activationType
: ActivationType = 0
;
}
table Sub {
activationType
: ActivationType
;
activationType
: ActivationType = 0
;
}
table Mul {
activationType
: ActivationType
;
activationType
: ActivationType = 0
;
}
table Div {
activationType
: ActivationType
;
activationType
: ActivationType = 0
;
}
table AddGrad {
...
...
mindspore/lite/src/populate_parameter.cc
浏览文件 @
f8e4ab86
...
...
@@ -510,6 +510,23 @@ OpParameter *PopulateArithmetic(const lite::Primitive *primitive) {
arithmetic_param
->
op_parameter_
.
type_
=
primitive
->
Type
();
arithmetic_param
->
broadcasting_
=
((
lite
::
Arithmetic
*
)
primitive
)
->
Broadcasting
();
arithmetic_param
->
ndim_
=
((
lite
::
Arithmetic
*
)
primitive
)
->
NDims
();
switch
(
primitive
->
Type
())
{
case
schema
::
PrimitiveType_Add
:
arithmetic_param
->
activation_type_
=
primitive
->
Value
()
->
value_as_Add
()
->
activationType
();
break
;
case
schema
::
PrimitiveType_Sub
:
arithmetic_param
->
activation_type_
=
primitive
->
Value
()
->
value_as_Sub
()
->
activationType
();
break
;
case
schema
::
PrimitiveType_Mul
:
arithmetic_param
->
activation_type_
=
primitive
->
Value
()
->
value_as_Mul
()
->
activationType
();
break
;
case
schema
::
PrimitiveType_Div
:
arithmetic_param
->
activation_type_
=
primitive
->
Value
()
->
value_as_Div
()
->
activationType
();
break
;
default:
arithmetic_param
->
activation_type_
=
0
;
break
;
}
auto
tmp_shape
=
((
lite
::
Arithmetic
*
)
primitive
)
->
InShape0
();
(
void
)
memcpy
(
arithmetic_param
->
in_shape0_
,
static_cast
<
void
*>
(
tmp_shape
.
data
()),
tmp_shape
.
size
()
*
sizeof
(
int
));
tmp_shape
=
((
lite
::
Arithmetic
*
)
primitive
)
->
InShape1
();
...
...
mindspore/lite/src/runtime/kernel/arm/fp32/arithmetic.cc
浏览文件 @
f8e4ab86
...
...
@@ -56,29 +56,26 @@ int ArithmeticCPUKernel::DoArithmetic(int task_id) {
auto
input1_data1
=
reinterpret_cast
<
float
*>
(
inputs_
[
1
]
->
Data
());
auto
output_data
=
reinterpret_cast
<
float
*>
(
outputs_
[
0
]
->
Data
());
auto
element_num
=
outputs_
[
0
]
->
ElementsNum
();
MS_ASSERT
(
thread_count_
!=
0
);
int
stride
=
UP_DIV
(
element_num
,
thread_count_
);
int
count
=
MSMIN
(
stride
,
element_num
-
stride
*
task_id
);
if
(
arithmetic_run_
==
nullptr
)
{
MS_LOG
(
ERROR
)
<<
"arithmetic_run function is nullptr!"
;
return
RET_ERROR
;
}
int
error_code
=
RET_OK
;
if
(
arithmeticParameter_
->
broadcasting_
)
{
if
(
arithmetic_broadcast_run_
==
nullptr
)
{
MS_LOG
(
ERROR
)
<<
"broadcasting_run function is nullptr!"
;
return
RET_ERROR
;
}
MS_ASSERT
(
thread_count_
!=
0
);
int
stride
=
UP_DIV
(
element_num
,
thread_count_
);
int
count
=
MSMIN
(
stride
,
element_num
-
stride
*
task_id
);
int
error_code
=
arithmetic_run_
(
tile_data0_
+
stride
*
task_id
,
tile_data1_
+
stride
*
task_id
,
output_data
+
stride
*
task_id
,
count
);
if
(
error_code
!=
RET_OK
)
{
return
RET_ERROR
;
}
}
else
if
(
arithmetic_run_
!=
nullptr
)
{
int
error_code
=
arithmetic_run_
(
input0_data
,
input1_data1
,
output_data
,
element_num
);
if
(
error_code
!=
RET_OK
)
{
return
RET_ERROR
;
}
error_code
=
arithmetic_run_
(
tile_data0_
+
stride
*
task_id
,
tile_data1_
+
stride
*
task_id
,
output_data
+
stride
*
task_id
,
count
);
}
else
{
MS_LOG
(
ERROR
)
<<
"arithmetic_run function is nullptr!"
;
error_code
=
arithmetic_run_
(
input0_data
+
stride
*
task_id
,
input1_data1
+
stride
*
task_id
,
output_data
+
stride
*
task_id
,
count
);
}
if
(
error_code
!=
RET_OK
)
{
return
RET_ERROR
;
}
return
RET_OK
;
...
...
mindspore/lite/src/runtime/kernel/arm/fp32/arithmetic.h
浏览文件 @
f8e4ab86
...
...
@@ -50,22 +50,59 @@ class ArithmeticCPUKernel : public LiteKernel {
ArithmeticCPUKernel
(
OpParameter
*
parameter
,
const
std
::
vector
<
lite
::
tensor
::
Tensor
*>
&
inputs
,
const
std
::
vector
<
lite
::
tensor
::
Tensor
*>
&
outputs
,
const
lite
::
Context
*
ctx
)
:
LiteKernel
(
parameter
,
inputs
,
outputs
),
thread_count_
(
ctx
->
thread_num_
)
{
arithmeticParameter_
=
reinterpret_cast
<
ArithmeticParameter
*>
(
parameter
);
switch
(
parameter
->
type_
)
{
case
PrimitiveType_Mul
:
arithmetic_run_
=
ElementMul
;
arithmetic_broadcast_run_
=
BroadcastMul
;
switch
(
arithmeticParameter_
->
activation_type_
)
{
case
schema
::
ActivationType_RELU
:
arithmetic_run_
=
ElementMulRelu
;
break
;
case
schema
::
ActivationType_RELU6
:
arithmetic_run_
=
ElementMulRelu6
;
break
;
default:
arithmetic_run_
=
ElementMul
;
break
;
}
break
;
case
PrimitiveType_Add
:
arithmetic_run_
=
ElementAdd
;
arithmetic_broadcast_run_
=
BroadcastAdd
;
switch
(
arithmeticParameter_
->
activation_type_
)
{
case
schema
::
ActivationType_RELU
:
arithmetic_run_
=
ElementAddRelu
;
break
;
case
schema
::
ActivationType_RELU6
:
arithmetic_run_
=
ElementAddRelu6
;
break
;
default:
arithmetic_run_
=
ElementAdd
;
break
;
}
break
;
case
PrimitiveType_Sub
:
arithmetic_run_
=
ElementSub
;
arithmetic_broadcast_run_
=
BroadcastSub
;
switch
(
arithmeticParameter_
->
activation_type_
)
{
case
schema
::
ActivationType_RELU
:
arithmetic_run_
=
ElementSubRelu
;
break
;
case
schema
::
ActivationType_RELU6
:
arithmetic_run_
=
ElementSubRelu6
;
break
;
default:
arithmetic_run_
=
ElementSub
;
break
;
}
break
;
case
PrimitiveType_Div
:
arithmetic_run_
=
ElementDiv
;
arithmetic_broadcast_run_
=
BroadcastDiv
;
switch
(
arithmeticParameter_
->
activation_type_
)
{
case
schema
::
ActivationType_RELU
:
arithmetic_run_
=
ElementDivRelu
;
break
;
case
schema
::
ActivationType_RELU6
:
arithmetic_run_
=
ElementDivRelu6
;
break
;
default:
arithmetic_run_
=
ElementDiv
;
break
;
}
break
;
case
PrimitiveType_LogicalAnd
:
arithmetic_run_
=
ElementLogicalAnd
;
...
...
@@ -125,7 +162,6 @@ class ArithmeticCPUKernel : public LiteKernel {
arithmetic_broadcast_run_
=
nullptr
;
break
;
}
arithmeticParameter_
=
reinterpret_cast
<
ArithmeticParameter
*>
(
parameter
);
}
~
ArithmeticCPUKernel
()
override
;
...
...
mindspore/lite/src/runtime/kernel/arm/nnacl/arithmetic_common.h
浏览文件 @
f8e4ab86
...
...
@@ -27,6 +27,7 @@ struct ArithmeticParameter {
OpParameter
op_parameter_
;
bool
broadcasting_
;
size_t
ndim_
;
int
activation_type_
;
int
in_shape0_
[
5
];
int
in_shape1_
[
5
];
int
out_shape_
[
5
];
...
...
@@ -49,4 +50,3 @@ void TileDimensionsInt8(int8_t *data0, int8_t *data1, int8_t *tile_data0, int8_t
ArithmeticParameter
*
param
);
#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_ARITHMETIC_COMMON_H_
mindspore/lite/src/runtime/kernel/arm/nnacl/fp32/activation.h
浏览文件 @
f8e4ab86
...
...
@@ -47,7 +47,7 @@ inline int Relu6(const float *src, int length, float *dst) {
inline
int
LRelu
(
const
float
*
src
,
int
length
,
float
*
dst
,
float
alpha
)
{
for
(
int
i
=
0
;
i
<
length
;
++
i
)
{
dst
[
i
]
=
src
[
i
]
>
(
src
[
i
]
*
alpha
)
?
src
[
i
]
:
(
src
[
i
]
*
alpha
);
dst
[
i
]
=
src
[
i
]
>
0
?
src
[
i
]
:
(
src
[
i
]
*
alpha
);
}
return
NNACL_OK
;
}
...
...
mindspore/lite/src/runtime/kernel/arm/nnacl/fp32/arithmetic.cc
浏览文件 @
f8e4ab86
...
...
@@ -21,7 +21,7 @@ int ElementMul(float *input0, float *input1, float *output, int element_size) {
int
block_c4
=
element_size
-
block_mod
;
for
(
int
index
=
0
;
index
<
block_c4
;
index
+=
C4NUM
)
{
#ifdef
US
E_NEON
#ifdef
ENABL
E_NEON
float32x4_t
vin0
=
vld1q_f32
(
input0
);
float32x4_t
vin1
=
vld1q_f32
(
input1
);
float32x4_t
vout
=
vmulq_f32
(
vin0
,
vin1
);
...
...
@@ -43,6 +43,73 @@ int ElementMul(float *input0, float *input1, float *output, int element_size) {
return
NNACL_OK
;
}
int
ElementMulRelu
(
float
*
input0
,
float
*
input1
,
float
*
output
,
int
element_size
)
{
int
block_mod
=
element_size
%
C4NUM
;
int
block_c4
=
element_size
-
block_mod
;
#ifdef ENABLE_NEON
float32x4_t
zeros
=
{
0
,
0
,
0
,
0
};
#endif
for
(
int
index
=
0
;
index
<
block_c4
;
index
+=
C4NUM
)
{
#ifdef ENABLE_NEON
float32x4_t
vin0
=
vld1q_f32
(
input0
);
float32x4_t
vin1
=
vld1q_f32
(
input1
);
float32x4_t
vout
=
vmulq_f32
(
vin0
,
vin1
);
vout
=
vbslq_f32
(
vcgtq_f32
(
vout
,
zeros
),
vout
,
zeros
);
vst1q_f32
(
output
,
vout
);
#else
float
res
=
input0
[
0
]
*
input1
[
0
];
output
[
0
]
=
res
>
0
?
res
:
0
;
res
=
input0
[
1
]
*
input1
[
1
];
output
[
1
]
=
res
>
0
?
res
:
0
;
res
=
input0
[
2
]
*
input1
[
2
];
output
[
2
]
=
res
>
0
?
res
:
0
;
res
=
input0
[
3
]
*
input1
[
3
];
output
[
3
]
=
res
>
0
?
res
:
0
;
#endif
input0
+=
C4NUM
;
input1
+=
C4NUM
;
output
+=
C4NUM
;
}
for
(
int
index
=
0
;
index
<
block_mod
;
++
index
)
{
float
res
=
input0
[
index
]
*
input1
[
index
];
output
[
index
]
=
res
>
0
?
res
:
0
;
}
return
NNACL_OK
;
}
int
ElementMulRelu6
(
float
*
input0
,
float
*
input1
,
float
*
output
,
int
element_size
)
{
int
block_mod
=
element_size
%
C4NUM
;
int
block_c4
=
element_size
-
block_mod
;
#ifdef ENABLE_NEON
float32x4_t
zeros
=
{
0
,
0
,
0
,
0
};
float32x4_t
bounds
=
{
6
,
6
,
6
,
6
};
#endif
for
(
int
index
=
0
;
index
<
block_c4
;
index
+=
C4NUM
)
{
#ifdef ENABLE_NEON
float32x4_t
vin0
=
vld1q_f32
(
input0
);
float32x4_t
vin1
=
vld1q_f32
(
input1
);
float32x4_t
vout
=
vminq_f32
(
vmaxq_f32
(
vmulq_f32
(
vin0
,
vin1
),
zeros
),
bounds
);
vst1q_f32
(
output
,
vout
);
#else
output
[
0
]
=
MSMIN
(
MSMAX
(
input0
[
0
]
*
input1
[
0
],
0
),
6
);
output
[
1
]
=
MSMIN
(
MSMAX
(
input0
[
1
]
*
input1
[
1
],
0
),
6
);
output
[
2
]
=
MSMIN
(
MSMAX
(
input0
[
2
]
*
input1
[
2
],
0
),
6
);
output
[
3
]
=
MSMIN
(
MSMAX
(
input0
[
3
]
*
input1
[
3
],
0
),
6
);
#endif
input0
+=
C4NUM
;
input1
+=
C4NUM
;
output
+=
C4NUM
;
}
for
(
int
index
=
0
;
index
<
block_mod
;
++
index
)
{
output
[
index
]
=
MSMIN
(
MSMAX
(
input0
[
index
]
*
input1
[
index
],
0
),
6
);
}
return
NNACL_OK
;
}
int
BroadcastMul
(
float
*
input0
,
float
*
input1
,
float
*
tile_input0
,
float
*
tile_input1
,
float
*
output
,
int
element_size
,
ArithmeticParameter
*
param
)
{
TileDimensions
(
input0
,
input1
,
tile_input0
,
tile_input1
,
param
);
...
...
@@ -54,7 +121,7 @@ int ElementAdd(float *input0, float *input1, float *output, int element_size) {
int
block_c4
=
element_size
-
block_mod
;
for
(
int
index
=
0
;
index
<
block_c4
;
index
+=
C4NUM
)
{
#ifdef
US
E_NEON
#ifdef
ENABL
E_NEON
float32x4_t
vin0
=
vld1q_f32
(
input0
);
float32x4_t
vin1
=
vld1q_f32
(
input1
);
float32x4_t
vout
=
vaddq_f32
(
vin0
,
vin1
);
...
...
@@ -75,6 +142,72 @@ int ElementAdd(float *input0, float *input1, float *output, int element_size) {
return
NNACL_OK
;
}
int
ElementAddRelu
(
float
*
input0
,
float
*
input1
,
float
*
output
,
int
element_size
)
{
int
block_mod
=
element_size
%
C4NUM
;
int
block_c4
=
element_size
-
block_mod
;
#ifdef ENABLE_NEON
float32x4_t
zeros
=
{
0
,
0
,
0
,
0
};
#endif
for
(
int
index
=
0
;
index
<
block_c4
;
index
+=
C4NUM
)
{
#ifdef ENABLE_NEON
float32x4_t
vin0
=
vld1q_f32
(
input0
);
float32x4_t
vin1
=
vld1q_f32
(
input1
);
float32x4_t
vout
=
vaddq_f32
(
vin0
,
vin1
);
vout
=
vbslq_f32
(
vcgtq_f32
(
vout
,
zeros
),
vout
,
zeros
);
vst1q_f32
(
output
,
vout
);
#else
float
res
=
input0
[
0
]
+
input1
[
0
];
output
[
0
]
=
res
>
0
?
res
:
0
;
res
=
input0
[
1
]
+
input1
[
1
];
output
[
1
]
=
res
>
0
?
res
:
0
;
res
=
input0
[
2
]
+
input1
[
2
];
output
[
2
]
=
res
>
0
?
res
:
0
;
res
=
input0
[
3
]
+
input1
[
3
];
output
[
3
]
=
res
>
0
?
res
:
0
;
#endif
input0
+=
C4NUM
;
input1
+=
C4NUM
;
output
+=
C4NUM
;
}
for
(
int
index
=
0
;
index
<
block_mod
;
++
index
)
{
float
res
=
input0
[
index
]
+
input1
[
index
];
output
[
index
]
=
res
>
0
?
res
:
0
;
}
return
NNACL_OK
;
}
int
ElementAddRelu6
(
float
*
input0
,
float
*
input1
,
float
*
output
,
int
element_size
)
{
int
block_mod
=
element_size
%
C4NUM
;
int
block_c4
=
element_size
-
block_mod
;
#ifdef ENABLE_NEON
float32x4_t
zeros
=
{
0
,
0
,
0
,
0
};
float32x4_t
bounds
=
{
6
,
6
,
6
,
6
};
#endif
for
(
int
index
=
0
;
index
<
block_c4
;
index
+=
C4NUM
)
{
#ifdef ENABLE_NEON
float32x4_t
vin0
=
vld1q_f32
(
input0
);
float32x4_t
vin1
=
vld1q_f32
(
input1
);
float32x4_t
vout
=
vminq_f32
(
vmaxq_f32
(
vaddq_f32
(
vin0
,
vin1
),
zeros
),
bounds
);
vst1q_f32
(
output
,
vout
);
#else
output
[
0
]
=
MSMIN
(
MSMAX
(
input0
[
0
]
+
input1
[
0
],
0
),
6
);
output
[
1
]
=
MSMIN
(
MSMAX
(
input0
[
1
]
+
input1
[
1
],
0
),
6
);
output
[
2
]
=
MSMIN
(
MSMAX
(
input0
[
2
]
+
input1
[
2
],
0
),
6
);
output
[
3
]
=
MSMIN
(
MSMAX
(
input0
[
3
]
+
input1
[
3
],
0
),
6
);
#endif
input0
+=
C4NUM
;
input1
+=
C4NUM
;
output
+=
C4NUM
;
}
for
(
int
index
=
0
;
index
<
block_mod
;
++
index
)
{
output
[
index
]
=
MSMIN
(
MSMAX
(
input0
[
index
]
+
input1
[
index
],
0
),
6
);
}
return
NNACL_OK
;
}
int
ElementAddInt8
(
int8_t
*
input0
,
int8_t
*
input1
,
int8_t
*
output
,
int
element_size
)
{
for
(
int
i
=
0
;
i
<
element_size
;
i
++
)
{
output
[
i
]
=
input0
[
i
]
+
input1
[
i
];
...
...
@@ -99,7 +232,7 @@ int ElementSub(float *input0, float *input1, float *output, int element_size) {
int
block_c4
=
element_size
-
block_mod
;
for
(
int
index
=
0
;
index
<
block_c4
;
index
+=
C4NUM
)
{
#ifdef
US
E_NEON
#ifdef
ENABL
E_NEON
float32x4_t
vin0
=
vld1q_f32
(
input0
);
float32x4_t
vin1
=
vld1q_f32
(
input1
);
float32x4_t
vout
=
vsubq_f32
(
vin0
,
vin1
);
...
...
@@ -120,6 +253,72 @@ int ElementSub(float *input0, float *input1, float *output, int element_size) {
return
NNACL_OK
;
}
int
ElementSubRelu
(
float
*
input0
,
float
*
input1
,
float
*
output
,
int
element_size
)
{
int
block_mod
=
element_size
%
C4NUM
;
int
block_c4
=
element_size
-
block_mod
;
#ifdef ENABLE_NEON
float32x4_t
zeros
=
{
0
,
0
,
0
,
0
};
#endif
for
(
int
index
=
0
;
index
<
block_c4
;
index
+=
C4NUM
)
{
#ifdef ENABLE_NEON
float32x4_t
vin0
=
vld1q_f32
(
input0
);
float32x4_t
vin1
=
vld1q_f32
(
input1
);
float32x4_t
vout
=
vsubq_f32
(
vin0
,
vin1
);
vout
=
vbslq_f32
(
vcgtq_f32
(
vout
,
zeros
),
vout
,
zeros
);
vst1q_f32
(
output
,
vout
);
#else
float
res
=
input0
[
0
]
-
input1
[
0
];
output
[
0
]
=
res
>
0
?
res
:
0
;
res
=
input0
[
1
]
-
input1
[
1
];
output
[
1
]
=
res
>
0
?
res
:
0
;
res
=
input0
[
2
]
-
input1
[
2
];
output
[
2
]
=
res
>
0
?
res
:
0
;
res
=
input0
[
3
]
-
input1
[
3
];
output
[
3
]
=
res
>
0
?
res
:
0
;
#endif
input0
+=
C4NUM
;
input1
+=
C4NUM
;
output
+=
C4NUM
;
}
for
(
int
index
=
0
;
index
<
block_mod
;
++
index
)
{
float
res
=
input0
[
index
]
-
input1
[
index
];
output
[
index
]
=
res
>
0
?
res
:
0
;
}
return
NNACL_OK
;
}
int
ElementSubRelu6
(
float
*
input0
,
float
*
input1
,
float
*
output
,
int
element_size
)
{
int
block_mod
=
element_size
%
C4NUM
;
int
block_c4
=
element_size
-
block_mod
;
#ifdef ENABLE_NEON
float32x4_t
zeros
=
{
0
,
0
,
0
,
0
};
float32x4_t
bounds
=
{
6
,
6
,
6
,
6
};
#endif
for
(
int
index
=
0
;
index
<
block_c4
;
index
+=
C4NUM
)
{
#ifdef ENABLE_NEON
float32x4_t
vin0
=
vld1q_f32
(
input0
);
float32x4_t
vin1
=
vld1q_f32
(
input1
);
float32x4_t
vout
=
vminq_f32
(
vmaxq_f32
(
vsubq_f32
(
vin0
,
vin1
),
zeros
),
bounds
);
vst1q_f32
(
output
,
vout
);
#else
output
[
0
]
=
MSMIN
(
MSMAX
(
input0
[
0
]
-
input1
[
0
],
0
),
6
);
output
[
1
]
=
MSMIN
(
MSMAX
(
input0
[
1
]
-
input1
[
1
],
0
),
6
);
output
[
2
]
=
MSMIN
(
MSMAX
(
input0
[
2
]
-
input1
[
2
],
0
),
6
);
output
[
3
]
=
MSMIN
(
MSMAX
(
input0
[
3
]
-
input1
[
3
],
0
),
6
);
#endif
input0
+=
C4NUM
;
input1
+=
C4NUM
;
output
+=
C4NUM
;
}
for
(
int
index
=
0
;
index
<
block_mod
;
++
index
)
{
output
[
index
]
=
MSMIN
(
MSMAX
(
input0
[
index
]
-
input1
[
index
],
0
),
6
);
}
return
NNACL_OK
;
}
int
BroadcastSub
(
float
*
input0
,
float
*
input1
,
float
*
tile_input0
,
float
*
tile_input1
,
float
*
output
,
int
element_size
,
ArithmeticParameter
*
param
)
{
TileDimensions
(
input0
,
input1
,
tile_input0
,
tile_input1
,
param
);
...
...
@@ -137,6 +336,27 @@ int ElementDiv(float *input0, float *input1, float *output, int element_size) {
return
NNACL_OK
;
}
int
ElementDivRelu
(
float
*
input0
,
float
*
input1
,
float
*
output
,
int
element_size
)
{
for
(
int
i
=
0
;
i
<
element_size
;
i
++
)
{
if
(
input1
[
i
]
==
0
)
{
return
NNACL_ERRCODE_DIVISOR_ZERO
;
}
float
res
=
input0
[
i
]
/
input1
[
i
];
output
[
i
]
=
res
>
0
?
res
:
0
;
}
return
NNACL_OK
;
}
int
ElementDivRelu6
(
float
*
input0
,
float
*
input1
,
float
*
output
,
int
element_size
)
{
for
(
int
i
=
0
;
i
<
element_size
;
i
++
)
{
if
(
input1
[
i
]
==
0
)
{
return
NNACL_ERRCODE_DIVISOR_ZERO
;
}
output
[
i
]
=
MSMIN
(
MSMAX
(
input0
[
i
]
/
input1
[
i
],
0
),
6
);
}
return
NNACL_OK
;
}
int
BroadcastDiv
(
float
*
input0
,
float
*
input1
,
float
*
tile_input0
,
float
*
tile_input1
,
float
*
output
,
int
element_size
,
ArithmeticParameter
*
param
)
{
TileDimensions
(
input0
,
input1
,
tile_input0
,
tile_input1
,
param
);
...
...
@@ -179,11 +399,18 @@ int ElementLogicalAnd(float *input0, float *input1, float *output, int element_s
int
block_mod
=
element_size
%
C4NUM
;
int
block_c4
=
element_size
-
block_mod
;
#ifdef ENABLE_NEON
float32x4_t
vtrue
=
{
1
,
1
,
1
,
1
};
float32x4_t
vfalse
=
{
0
,
0
,
0
,
0
};
uint32x4_t
mask
=
vmovq_n_u32
((
uint32_t
(
1u
<<
31
)
-
1
));
uint32x4_t
zeros
=
{
0
,
0
,
0
,
0
};
#endif
for
(
int
index
=
0
;
index
<
block_c4
;
index
+=
C4NUM
)
{
#ifdef
US
E_NEON
float32x4_t
vin0
=
vld1q_f32
(
input0
);
float32x4_t
vin1
=
vld1q_f32
(
input1
);
float32x4_t
vout
=
v
andq_f32
(
vin0
,
vin1
);
#ifdef
ENABL
E_NEON
uint32x4_t
vin0
=
vandq_u32
(
vreinterpretq_s32_f32
(
vld1q_f32
(
input0
)),
mask
);
uint32x4_t
vin1
=
vandq_u32
(
vreinterpretq_s32_f32
(
vld1q_f32
(
input1
)),
mask
);
float32x4_t
vout
=
v
bslq_f32
(
vceqq_u32
(
vandq_u32
(
vin0
,
vin1
),
zeros
),
vfalse
,
vtrue
);
vst1q_f32
(
output
,
vout
);
#else
output
[
0
]
=
(
float
)((
bool
)(
input0
[
0
])
&
(
bool
)(
input1
[
0
]));
...
...
@@ -222,11 +449,18 @@ int ElementLogicalOr(float *input0, float *input1, float *output, int element_si
int
block_mod
=
element_size
%
C4NUM
;
int
block_c4
=
element_size
-
block_mod
;
#ifdef ENABLE_NEON
float32x4_t
vtrue
=
{
1
,
1
,
1
,
1
};
float32x4_t
vfalse
=
{
0
,
0
,
0
,
0
};
uint32x4_t
mask
=
vmovq_n_u32
((
uint32_t
(
1u
<<
31
)
-
1
));
uint32x4_t
zeros
=
{
0
,
0
,
0
,
0
};
#endif
for
(
int
index
=
0
;
index
<
block_c4
;
index
+=
C4NUM
)
{
#ifdef
US
E_NEON
float32x4_t
vin0
=
vld1q_f32
(
input0
);
float32x4_t
vin1
=
vld1q_f32
(
input1
);
float32x4_t
vout
=
v
orrq_f32
(
vin0
,
vin1
);
#ifdef
ENABL
E_NEON
uint32x4_t
vin0
=
vandq_u32
(
vreinterpretq_s32_f32
(
vld1q_f32
(
input0
)),
mask
);
uint32x4_t
vin1
=
vandq_u32
(
vreinterpretq_s32_f32
(
vld1q_f32
(
input1
)),
mask
);
float32x4_t
vout
=
v
bslq_f32
(
vceqq_u32
(
vorrq_u32
(
vin0
,
vin1
),
zeros
),
vfalse
,
vtrue
);
vst1q_f32
(
output
,
vout
);
#else
output
[
0
]
=
(
float
)((
bool
)(
input0
[
0
])
|
(
bool
)(
input1
[
0
]));
...
...
@@ -255,7 +489,7 @@ int ElementMaximum(float *input0, float *input1, float *output, int element_size
int
block_c4
=
element_size
-
block_mod
;
for
(
int
index
=
0
;
index
<
block_c4
;
index
+=
C4NUM
)
{
#ifdef
US
E_NEON
#ifdef
ENABL
E_NEON
float32x4_t
vin0
=
vld1q_f32
(
input0
);
float32x4_t
vin1
=
vld1q_f32
(
input1
);
float32x4_t
vout
=
vmaxq_f32
(
vin0
,
vin1
);
...
...
@@ -287,7 +521,7 @@ int ElementMinimum(float *input0, float *input1, float *output, int element_size
int
block_c4
=
element_size
-
block_mod
;
for
(
int
index
=
0
;
index
<
block_c4
;
index
+=
C4NUM
)
{
#ifdef
US
E_NEON
#ifdef
ENABL
E_NEON
float32x4_t
vin0
=
vld1q_f32
(
input0
);
float32x4_t
vin1
=
vld1q_f32
(
input1
);
float32x4_t
vout
=
vminq_f32
(
vin0
,
vin1
);
...
...
@@ -317,15 +551,15 @@ int BroadcastMinimum(float *input0, float *input1, float *tile_input0, float *ti
int
ElementNotEqual
(
float
*
input0
,
float
*
input1
,
float
*
output
,
int
element_size
)
{
int
block_mod
=
element_size
%
C4NUM
;
int
block_c4
=
element_size
-
block_mod
;
#ifdef
US
E_NEON
#ifdef
ENABL
E_NEON
float32x4_t
vtrue
=
{
1
,
1
,
1
,
1
};
float32x4_t
vfalse
=
{
0
,
0
,
0
,
0
};
#endif
for
(
int
index
=
0
;
index
<
block_c4
;
index
+=
C4NUM
)
{
#ifdef
US
E_NEON
#ifdef
ENABL
E_NEON
float32x4_t
vin0
=
vld1q_f32
(
input0
);
float32x4_t
vin1
=
vld1q_f32
(
input1
);
float32x4_t
vout
=
vbslq_f32
(
vceqq_f
p
32
(
vin0
,
vin1
),
vfalse
,
vtrue
);
float32x4_t
vout
=
vbslq_f32
(
vceqq_f32
(
vin0
,
vin1
),
vfalse
,
vtrue
);
vst1q_f32
(
output
,
vout
);
#else
output
[
0
]
=
(
float
)(
input0
[
0
]
!=
input1
[
0
]);
...
...
@@ -352,15 +586,15 @@ int BroadcastNotEqual(float *input0, float *input1, float *tile_input0, float *t
int
ElementEqual
(
float
*
input0
,
float
*
input1
,
float
*
output
,
int
element_size
)
{
int
block_mod
=
element_size
%
C4NUM
;
int
block_c4
=
element_size
-
block_mod
;
#ifdef
US
E_NEON
#ifdef
ENABL
E_NEON
float32x4_t
vtrue
=
{
1
,
1
,
1
,
1
};
float32x4_t
vfalse
=
{
0
,
0
,
0
,
0
};
#endif
for
(
int
index
=
0
;
index
<
block_c4
;
index
+=
C4NUM
)
{
#ifdef
US
E_NEON
#ifdef
ENABL
E_NEON
float32x4_t
vin0
=
vld1q_f32
(
input0
);
float32x4_t
vin1
=
vld1q_f32
(
input1
);
float32x4_t
vout
=
vbslq_f32
(
vceqq_f
p
32
(
vin0
,
vin1
),
vtrue
,
vfalse
);
float32x4_t
vout
=
vbslq_f32
(
vceqq_f32
(
vin0
,
vin1
),
vtrue
,
vfalse
);
vst1q_f32
(
output
,
vout
);
#else
output
[
0
]
=
(
float
)(
input0
[
0
]
==
input1
[
0
]);
...
...
@@ -387,15 +621,15 @@ int BroadcastEqual(float *input0, float *input1, float *tile_input0, float *tile
int
ElementLess
(
float
*
input0
,
float
*
input1
,
float
*
output
,
int
element_size
)
{
int
block_mod
=
element_size
%
C4NUM
;
int
block_c4
=
element_size
-
block_mod
;
#ifdef
US
E_NEON
#ifdef
ENABL
E_NEON
float32x4_t
vtrue
=
{
1
,
1
,
1
,
1
};
float32x4_t
vfalse
=
{
0
,
0
,
0
,
0
};
#endif
for
(
int
index
=
0
;
index
<
block_c4
;
index
+=
C4NUM
)
{
#ifdef
US
E_NEON
#ifdef
ENABL
E_NEON
float32x4_t
vin0
=
vld1q_f32
(
input0
);
float32x4_t
vin1
=
vld1q_f32
(
input1
);
float32x4_t
vout
=
vbslq_f32
(
vcltq_f
p
32
(
vin0
,
vin1
),
vtrue
,
vfalse
);
float32x4_t
vout
=
vbslq_f32
(
vcltq_f32
(
vin0
,
vin1
),
vtrue
,
vfalse
);
vst1q_f32
(
output
,
vout
);
#else
output
[
0
]
=
(
float
)(
input0
[
0
]
<
input1
[
0
]);
...
...
@@ -422,15 +656,15 @@ int BroadcastLess(float *input0, float *input1, float *tile_input0, float *tile_
int
ElementLessEqual
(
float
*
input0
,
float
*
input1
,
float
*
output
,
int
element_size
)
{
int
block_mod
=
element_size
%
C4NUM
;
int
block_c4
=
element_size
-
block_mod
;
#ifdef
US
E_NEON
#ifdef
ENABL
E_NEON
float32x4_t
vtrue
=
{
1
,
1
,
1
,
1
};
float32x4_t
vfalse
=
{
0
,
0
,
0
,
0
};
#endif
for
(
int
index
=
0
;
index
<
block_c4
;
index
+=
C4NUM
)
{
#ifdef
US
E_NEON
#ifdef
ENABL
E_NEON
float32x4_t
vin0
=
vld1q_f32
(
input0
);
float32x4_t
vin1
=
vld1q_f32
(
input1
);
float32x4_t
vout
=
vbslq_f32
(
vcleq_f
p
32
(
vin0
,
vin1
),
vtrue
,
vfalse
);
float32x4_t
vout
=
vbslq_f32
(
vcleq_f32
(
vin0
,
vin1
),
vtrue
,
vfalse
);
vst1q_f32
(
output
,
vout
);
#else
output
[
0
]
=
(
float
)(
input0
[
0
]
<=
input1
[
0
]);
...
...
@@ -457,15 +691,15 @@ int BroadcastLessEqual(float *input0, float *input1, float *tile_input0, float *
int
ElementGreater
(
float
*
input0
,
float
*
input1
,
float
*
output
,
int
element_size
)
{
int
block_mod
=
element_size
%
C4NUM
;
int
block_c4
=
element_size
-
block_mod
;
#ifdef
US
E_NEON
#ifdef
ENABL
E_NEON
float32x4_t
vtrue
=
{
1
,
1
,
1
,
1
};
float32x4_t
vfalse
=
{
0
,
0
,
0
,
0
};
#endif
for
(
int
index
=
0
;
index
<
block_c4
;
index
+=
C4NUM
)
{
#ifdef
US
E_NEON
#ifdef
ENABL
E_NEON
float32x4_t
vin0
=
vld1q_f32
(
input0
);
float32x4_t
vin1
=
vld1q_f32
(
input1
);
float32x4_t
vout
=
vbslq_f32
(
vcgtq_f
p
32
(
vin0
,
vin1
),
vtrue
,
vfalse
);
float32x4_t
vout
=
vbslq_f32
(
vcgtq_f32
(
vin0
,
vin1
),
vtrue
,
vfalse
);
vst1q_f32
(
output
,
vout
);
#else
output
[
0
]
=
(
float
)(
input0
[
0
]
>
input1
[
0
]);
...
...
@@ -492,15 +726,15 @@ int BroadcastGreater(float *input0, float *input1, float *tile_input0, float *ti
int
ElementGreaterEqual
(
float
*
input0
,
float
*
input1
,
float
*
output
,
int
element_size
)
{
int
block_mod
=
element_size
%
C4NUM
;
int
block_c4
=
element_size
-
block_mod
;
#ifdef
US
E_NEON
#ifdef
ENABL
E_NEON
float32x4_t
vtrue
=
{
1
,
1
,
1
,
1
};
float32x4_t
vfalse
=
{
0
,
0
,
0
,
0
};
#endif
for
(
int
index
=
0
;
index
<
block_c4
;
index
+=
C4NUM
)
{
#ifdef
US
E_NEON
#ifdef
ENABL
E_NEON
float32x4_t
vin0
=
vld1q_f32
(
input0
);
float32x4_t
vin1
=
vld1q_f32
(
input1
);
float32x4_t
vout
=
vbslq_f32
(
vcgeq_f
p
32
(
vin0
,
vin1
),
vtrue
,
vfalse
);
float32x4_t
vout
=
vbslq_f32
(
vcgeq_f32
(
vin0
,
vin1
),
vtrue
,
vfalse
);
vst1q_f32
(
output
,
vout
);
#else
output
[
0
]
=
(
float
)(
input0
[
0
]
>=
input1
[
0
]);
...
...
@@ -523,4 +757,3 @@ int BroadcastGreaterEqual(float *input0, float *input1, float *tile_input0, floa
TileDimensions
(
input0
,
input1
,
tile_input0
,
tile_input1
,
param
);
return
ElementGreaterEqual
(
tile_input0
,
tile_input1
,
output
,
element_size
);
}
mindspore/lite/src/runtime/kernel/arm/nnacl/fp32/arithmetic.h
浏览文件 @
f8e4ab86
...
...
@@ -24,20 +24,28 @@
#include "src/runtime/kernel/arm/nnacl/errorcode.h"
int
ElementMul
(
float
*
input0
,
float
*
input1
,
float
*
output
,
int
element_size
);
int
ElementMulRelu
(
float
*
input0
,
float
*
input1
,
float
*
output
,
int
element_size
);
int
ElementMulRelu6
(
float
*
input0
,
float
*
input1
,
float
*
output
,
int
element_size
);
int
BroadcastMul
(
float
*
input0
,
float
*
input1
,
float
*
tile_input0
,
float
*
tile_input1
,
float
*
output
,
int
element_size
,
ArithmeticParameter
*
param
);
int
ElementAdd
(
float
*
input0
,
float
*
input1
,
float
*
output
,
int
element_size
);
int
ElementAddRelu
(
float
*
input0
,
float
*
input1
,
float
*
output
,
int
element_size
);
int
ElementAddRelu6
(
float
*
input0
,
float
*
input1
,
float
*
output
,
int
element_size
);
int
BroadcastAdd
(
float
*
input0
,
float
*
input1
,
float
*
tile_input0
,
float
*
tile_input1
,
float
*
output
,
int
element_size
,
ArithmeticParameter
*
param
);
int
BroadcastAddInt8
(
int8_t
*
input0
,
int8_t
*
input1
,
int8_t
*
tile_input0
,
int8_t
*
tile_input1
,
int8_t
*
output
,
int
element_size
,
ArithmeticParameter
*
param
);
int
ElementSub
(
float
*
input0
,
float
*
input1
,
float
*
output
,
int
element_size
);
int
ElementSubRelu
(
float
*
input0
,
float
*
input1
,
float
*
output
,
int
element_size
);
int
ElementSubRelu6
(
float
*
input0
,
float
*
input1
,
float
*
output
,
int
element_size
);
int
BroadcastSub
(
float
*
input0
,
float
*
input1
,
float
*
tile_input0
,
float
*
tile_input1
,
float
*
output
,
int
element_size
,
ArithmeticParameter
*
param
);
int
ElementDiv
(
float
*
input0
,
float
*
input1
,
float
*
output
,
int
element_size
);
int
ElementDivRelu
(
float
*
input0
,
float
*
input1
,
float
*
output
,
int
element_size
);
int
ElementDivRelu6
(
float
*
input0
,
float
*
input1
,
float
*
output
,
int
element_size
);
int
BroadcastDiv
(
float
*
input0
,
float
*
input1
,
float
*
tile_input0
,
float
*
tile_input1
,
float
*
output
,
int
element_size
,
ArithmeticParameter
*
param
);
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录