Commit 8137d199
Authored Nov 28, 2018 by qnqinan

Merge remote-tracking branch 'upstream/develop' into develop

Parents: c6c8d605, 6249afe0

Showing 42 changed files with 4171 additions and 325 deletions (+4171 −325)
CMakeLists.txt                                                      +1 −4
src/common/types.cpp                                                +3 −0
src/common/types.h                                                  +1 −0
src/framework/executor.cpp                                          +1 −2
src/framework/load_ops.h                                            +4 −0
src/framework/operator.h                                            +0 −5
src/framework/tensor.h                                              +0 −16
src/operators/dequantize_op.cpp                                     +1 −1
src/operators/fusion_dequant_add_bn_relu_op.cpp                     +40 −0
src/operators/fusion_dequant_add_bn_relu_op.h                       +76 −0
src/operators/kernel/arm/conv_kernel.cpp                            +65 −1
src/operators/kernel/arm/dequant_add_bn_relu_kernel.cpp             +116 −0
src/operators/kernel/arm/dequantize_kernel.cpp                      +15 −11
src/operators/kernel/arm/quantize_kernel.cpp                        +518 −12
src/operators/kernel/central-arm-func/conv_add_arm_func.h           +1 −1
src/operators/kernel/central-arm-func/conv_add_bn_relu_arm_func.h   +1 −1
src/operators/kernel/central-arm-func/conv_arm_func.h               +83 −24
src/operators/kernel/central-arm-func/conv_bn_add_relu_arm_func.h   +1 −1
src/operators/kernel/central-arm-func/conv_bn_relu_arm_func.h       +3 −1
src/operators/kernel/central-arm-func/depthwise_conv_arm_func.h     +2 −3
src/operators/kernel/central-arm-func/dwconv_bn_relu_arm_func.h     +3 −1
src/operators/kernel/conv_add_kernel.h                              +1 −1
src/operators/kernel/dequant_add_bn_relu_kernel.h                   +37 −0
src/operators/math/depthwise_conv3x3.cpp                            +52 −33
src/operators/math/depthwise_conv3x3.h                              +87 −0
src/operators/math/depthwise_conv3x3_int8.cpp                       +1207 −0
src/operators/math/gemm.cpp                                         +0 −73
src/operators/math/im2col.cpp                                       +96 −25
src/operators/math/pad.cpp                                          +7 −5
src/operators/math/pad.h                                            +3 −2
src/operators/math/winograd/winograd_transform.h                    +42 −0
src/operators/math/winograd/winograd_transform_f6k3.cpp             +1366 −0
src/operators/op_param.h                                            +73 −12
src/operators/quantize_op.cpp                                       +5 −2
test/CMakeLists.txt                                                 +5 −5
test/framework/test_load_memory.cpp                                 +2 −1
test/net/test_benchmark.cpp                                         +64 −0
test/net/test_googlenet.cpp                                         +3 −5
test/operators/test_conv_op.cpp                                     +69 −41
test/operators/test_quantize_op.cpp                                 +111 −34
tools/build.sh                                                      +1 −1
tools/op.cmake                                                      +5 −1
CMakeLists.txt

@@ -5,6 +5,7 @@ option(DEBUGING "enable debug mode" ON)
 option(USE_EXCEPTION "use std exception" ON)
 option(SYMBOL_HIDDEN "symbol hidden" OFF)  # on when use jni or ios io
 option(LOG_PROFILE "log profile" OFF)
 # select the platform to build
 option(CPU "armv7 with neon" ON)
 option(GPU_MALI "mali gpu" OFF)
@@ -15,7 +16,6 @@ if(FPGA)
 option(FPGAV2 "fpga v2" OFF)
 endif()

 project(paddle-mobile)
 file(GLOB_RECURSE PADDLE_MOBILE_CC src/*.cc src/*.cpp src/*.c src/*.mm)
@@ -247,6 +247,3 @@ elseif(FPGA)
 add_subdirectory(test)
 endif()
src/common/types.cpp

@@ -71,6 +71,8 @@ const char *G_OP_TYPE_SUM = "sum";
 const char *G_OP_TYPE_QUANTIZE = "quantize";
 const char *G_OP_TYPE_DEQUANTIZE = "dequantize";
+const char *G_OP_TYPE_FUSION_DEQUANT_ADD_BN_RELU = "fusion_dequant_add_bn_relu";
 const char *G_OP_TYPE_TANH = "tanh";
 const char *G_OP_TYPE_FUSION_DECONV_RELU = "fusion_deconv_relu";
 const char *G_OP_TYPE_FUSION_DECONV_ADD = "fusion_deconv_add";
@@ -134,6 +136,7 @@ std::unordered_map<
     {G_OP_TYPE_ELEMENTWISE_MUL, {{"X", "Y"}, {"Out"}}},
     {G_OP_TYPE_QUANTIZE, {{"X"}, {"Out", "OutScale"}}},
     {G_OP_TYPE_DEQUANTIZE, {{"X", "Scale"}, {"Out"}}},
+    {G_OP_TYPE_FUSION_DEQUANT_ADD_BN_RELU, {{"X", "Scale"}, {"Out"}}},
     {G_OP_TYPE_TANH, {{"X"}, {"Out"}}},
     {G_OP_TYPE_FUSION_DECONV_RELU, {{"Input"}, {"Out"}}},
     {G_OP_TYPE_FUSION_DECONV_ADD, {{"Input"}, {"Out"}}},
src/common/types.h

@@ -138,6 +138,7 @@ extern const char *G_OP_TYPE_ELEMENTWISE_MUL;
 extern const char *G_OP_TYPE_QUANTIZE;
 extern const char *G_OP_TYPE_DEQUANTIZE;
+extern const char *G_OP_TYPE_FUSION_DEQUANT_ADD_BN_RELU;
 extern const char *G_OP_TYPE_TANH;
 extern const char *G_OP_TYPE_FUSION_DECONV_RELU;
src/framework/executor.cpp

@@ -30,7 +30,6 @@ limitations under the License. */
 #ifdef PADDLE_EXECUTOR_MULTITHREAD
 #include <queue>
-#include <utility>
 #include "common/threadpool.h"
 #endif
@@ -73,7 +72,7 @@ Executor<Dtype, P>::Executor(const framework::Program<Dtype> p, int batch_size,
         op->Type(), op->GetInputs(), op->GetOutputs(), op->GetAttrMap(),
         program_.scope);
     // infer shape to reshape tensor before predict,
-    // but for lod tensor, it will need to reshape in runtime
+    // but for lod tensor, it will still need to reshape in runtime
     if (!loddable_) {
       op_base->InferShape();
     }
src/framework/load_ops.h

@@ -233,3 +233,7 @@ LOAD_OP1(quantize, CPU);
 #ifdef DEQUANT_OP
 LOAD_OP1(dequantize, CPU);
 #endif
+#ifdef FUSION_DEQUANT_ADD_BN_RELU_OP
+LOAD_OP1(fusion_dequant_add_bn_relu, CPU);
+LOAD_FUSION_MATCHER(fusion_dequant_add_bn_relu);
+#endif
src/framework/operator.h

@@ -127,11 +127,6 @@ class OperatorWithKernel : public OperatorBase<Dtype> {
   virtual void InferShape() const = 0;

   void Init() {
-    // for (auto i : this->inputs_) {
-    //   DLOG << i.first;
-    //   DLOG << i.second;
-    // }
     PADDLE_MOBILE_ENFORCE(kernel_.Init(&param_), " %s kernel init failed",
                           this->type_.c_str());
   }
src/framework/tensor.h

@@ -54,22 +54,6 @@ class Tensor : public TensorBase {
     this->offset_ = inTensor.offset_;
   }

-#ifdef PADDLE_MOBILE_DEBUG
-  template <typename T>
-  inline void dump(std::string filename) const {
-    const T *dataptr = data<T>();
-    std::ofstream out(filename.c_str());
-    for (int i = 0; i < numel(); ++i) {
-      out << dataptr[i] << " ";
-    }
-    out << "形状:";
-    for (int j = 0; j < dims_.size(); ++j) {
-      out << dims_[j] << " ";
-    }
-    out.close();
-  }
-#endif

   /*! Resize the dimensions of the memory block. */
   inline Tensor &Resize(const DDim &dims) {
     dims_ = dims;
src/operators/dequantize_op.cpp

@@ -22,7 +22,7 @@ namespace operators {
 template <typename DeviceType, typename T>
 void DequantizeOp<DeviceType, T>::InferShape() const {
   const auto &input_dims = this->param_.input_->dims();
-  this->param_.out_->Resize(input_dims);
+  this->param_.output_->Resize(input_dims);
 }

 }  // namespace operators
test/operators/test_cov_op.cpp → src/operators/fusion_dequant_add_bn_relu_op.cpp

@@ -12,33 +12,29 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License. */

-#include "../test_include.h"
-#include "operators/conv_op.h"
+#ifdef FUSION_DEQUANT_ADD_BN_RELU_OP
+#include "operators/fusion_dequant_add_bn_relu_op.h"

-int main() {
-  paddle_mobile::framework::Loader<paddle_mobile::GPU_MALI> loader;
-  // ../models/image_classification_resnet.inference.model
-  auto program = loader.Load(g_googlenet);
-  PADDLE_MOBILE_ENFORCE(program.originProgram != nullptr,
-                        "program file read fail");
-  Executor4Test<paddle_mobile::GPU_MALI,
-                paddle_mobile::operators::ConvOp<paddle_mobile::GPU_MALI, float>>
-      executor(program, "conv2d");
-  paddle_mobile::framework::Tensor input;
-  GetInput<float>(g_test_image_1x3x224x224, &input, {1, 3, 224, 224});
-  // // use SetupTensor if not has local input image .
-  // SetupTensor<float>(&input, {1, 3, 224, 224}, static_cast<float>(0),
-  //                    static_cast<float>(1));
-  auto out_ddim = paddle_mobile::framework::make_ddim({1, 64, 112, 112});
-  auto output = executor.Predict(input, "data", "conv2d_0.tmp_0", out_ddim);
-  auto output_ptr = output->data<float>();
-  for (int j = 0; j < 20; ++j) {
-    DLOG << " value of output: " << output_ptr[j];
-  }
-  return 0;
-}
+namespace paddle_mobile {
+namespace operators {
+
+template <typename Dtype, typename T>
+void FusionDequantAddBNReluOp<Dtype, T>::InferShape() const {
+  const auto &input_dims = this->param_.input_->dims();
+  this->param_.output_->Resize(input_dims);
+}
+
+}  // namespace operators
+}  // namespace paddle_mobile
+
+namespace ops = paddle_mobile::operators;
+REGISTER_FUSION_MATCHER(fusion_dequant_add_bn_relu,
+                        ops::FusionDequantAddBNReluMatcher);
+#ifdef PADDLE_MOBILE_CPU
+REGISTER_OPERATOR_CPU(fusion_dequant_add_bn_relu,
+                      ops::FusionDequantAddBNReluOp);
+#endif
+
+#endif
src/operators/fusion_dequant_add_bn_relu_op.h (new file, 0 → 100644)

/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#ifdef FUSION_DEQUANT_ADD_BN_RELU_OP

#pragma once

#include <string>
#include <vector>
#include "framework/operator.h"
#include "framework/program/program-optimize/fusion_op_register.h"
#include "operators/kernel/dequant_add_bn_relu_kernel.h"
#include "operators/op_param.h"

namespace paddle_mobile {
namespace operators {

class FusionDequantAddBNReluMatcher : public framework::FusionOpMatcher {
 public:
  FusionDequantAddBNReluMatcher() {
    node_ = framework::Node(G_OP_TYPE_DEQUANTIZE);
    node_ > std::make_shared<framework::Node>(G_OP_TYPE_ELEMENTWISE_ADD) >
        std::make_shared<framework::Node>(G_OP_TYPE_BATCHNORM) >
        std::make_shared<framework::Node>(G_OP_TYPE_RELU);
  }

  void FolderNodes(
      framework::Node *node,
      std::vector<std::shared_ptr<framework::Node>> *removed_nodes) {
    node->Folder(node_.Depth(), Type(),
                 {{G_OP_TYPE_ELEMENTWISE_ADD, {{"Y", "Y"}}},
                  {G_OP_TYPE_BATCHNORM,
                   {{"Scale", "BNScale"},
                    {"Mean", "BNMean"},
                    {"Bias", "BNBias"},
                    {"Variance", "BNVariance"}}}},
                 removed_nodes);
  }

  std::string Type() { return G_OP_TYPE_FUSION_DEQUANT_ADD_BN_RELU; }
};

template <typename DeviceType, typename T>
class FusionDequantAddBNReluOp
    : public framework::OperatorWithKernel<
          DeviceType, FusionDequantAddBNReluParam<DeviceType>,
          operators::FusionDequantAddBNReluKernel<DeviceType, T>> {
 public:
  FusionDequantAddBNReluOp(const std::string &type,
                           const VariableNameMap &inputs,
                           const VariableNameMap &outputs,
                           const framework::AttributeMap &attrs,
                           std::shared_ptr<framework::Scope> scope)
      : framework::OperatorWithKernel<
            DeviceType, FusionDequantAddBNReluParam<DeviceType>,
            operators::FusionDequantAddBNReluKernel<DeviceType, T>>(
            type, inputs, outputs, attrs, scope) {}

  // inference output shape
  void InferShape() const override;
};

}  // namespace operators
}  // namespace paddle_mobile

#endif
src/operators/kernel/arm/conv_kernel.cpp

@@ -22,12 +22,76 @@ namespace operators {
 template <>
 bool ConvKernel<CPU, float>::Init(ConvParam<CPU> *param) {
+  if (param->Filter()->type() == typeid(int8_t)) {
+    if (param->Groups() == param->Input()->dims()[1] &&
+        param->Input()->dims()[1] == param->Output()->dims()[1] &&
+        param->Filter()->dims()[2] == param->Filter()->dims()[3] &&
+        param->Filter()->dims()[2] == 3 && param->Strides()[0] < 3 &&
+        param->Strides()[0] == param->Strides()[1]) {
+      param->ExecMode() = ConvParam<CPU>::EXEC_DEPTHWISE3x3_INT8;
+    } else {
+      param->ExecMode() = ConvParam<CPU>::EXEC_GEMM_INT8;
+    }
+  } else {
+    if (param->Groups() == param->Input()->dims()[1] &&
+        param->Input()->dims()[1] == param->Output()->dims()[1] &&
+        param->Filter()->dims()[2] == param->Filter()->dims()[3] &&
+        param->Filter()->dims()[2] == 3 && param->Strides()[0] == 1) {
+      param->ExecMode() = ConvParam<CPU>::EXEC_DEPTHWISE3x3S1P1_FLOAT;
+    } else if (param->Groups() == param->Input()->dims()[1] &&
+               param->Input()->dims()[1] == param->Output()->dims()[1] &&
+               param->Filter()->dims()[2] == param->Filter()->dims()[3] &&
+               param->Filter()->dims()[2] == 3) {
+      param->ExecMode() = ConvParam<CPU>::EXEC_DEPTHWISE3x3_FLOAT;
+#ifndef __aarch64__
+    } else if (param->Filter()->dims()[2] == param->Filter()->dims()[3] &&
+               param->Strides()[0] == param->Strides()[1] &&
+               param->Dilations()[0] == param->Dilations()[1] &&
+               param->Filter()->dims()[2] == 3 && param->Strides()[0] == 1 &&
+               param->Dilations()[0] == 1 && param->Output()->dims()[1] >= 16 &&
+               param->Input()->dims()[1] >= 16 &&
+               param->Input()->dims()[2] <= 140 /* refered from ncnn */) {
+      param->ExecMode() = ConvParam<CPU>::EXEC_WINOGRAD3X3_FLOAT;
+      // transform weight
+      framework::Tensor *transformed_weight = new framework::Tensor;
+      operators::math::winograd_transform_weight<8, 3>(*param->Filter(),
+                                                       transformed_weight);
+      param->Filter() = transformed_weight;
+#endif
+    } else {
+      param->ExecMode() = ConvParam<CPU>::EXEC_GEMM_FLOAT;
+    }
+  }
   return true;
 }

 template <>
 void ConvKernel<CPU, float>::Compute(const ConvParam<CPU> &param) {
-  ConvCompute<float>(param);
+  switch (param.ExecMode()) {
+    case ConvParam<CPU>::EXEC_GEMM_INT8:
+      GemmConv<int8_t, int32_t>(param);
+      break;
+    case ConvParam<CPU>::EXEC_DEPTHWISE3x3_INT8:
+      DepthwiseConv3x3<int8_t, int32_t>(param);
+      break;
+    case ConvParam<CPU>::EXEC_DEPTHWISE3x3S1P1_FLOAT:
+      math::DepthwiseConv3x3s1p1(param.Input(), param.Filter(), param.Output(),
+                                 nullptr, false);
+      break;
+    case ConvParam<CPU>::EXEC_DEPTHWISE3x3_FLOAT:
+      math::DepthwiseConv3x3(param.Input(), param.Strides(), param.Paddings(),
+                             param.Filter(), nullptr, param.Output(), false);
+      break;
+    case ConvParam<CPU>::EXEC_WINOGRAD3X3_FLOAT:
+      WinogradConv3x3<8, 3>(param);
+      break;
+    case ConvParam<CPU>::EXEC_GEMM_FLOAT:
+      GemmConv<float, float>(param);
+      break;
+    default:
+      PADDLE_MOBILE_THROW_EXCEPTION("Invalid convolution execute mode %d",
+                                    param.ExecMode());
+  }
 }

 template class ConvKernel<CPU, float>;
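For readers following the dispatch above: Init() inspects the filter type and shapes once and caches the decision in ExecMode(), so Compute() reduces to a switch. A minimal standalone sketch of the same decision table, with a hypothetical plain struct standing in for ConvParam<CPU> (all names here are illustrative, not the library's API):

#include <cstdint>

// Hypothetical mirror of the shape checks in ConvKernel<CPU, float>::Init.
enum class ExecMode { GEMM_INT8, DEPTHWISE3x3_INT8, DEPTHWISE3x3S1P1_FLOAT,
                      DEPTHWISE3x3_FLOAT, WINOGRAD3X3_FLOAT, GEMM_FLOAT };

struct ConvShape {  // illustrative stand-in for ConvParam<CPU>
  bool int8_filter;
  int groups, in_channels, out_channels;
  int kernel_h, kernel_w, stride_h, stride_w, dilation;
  int in_height;
};

ExecMode SelectConvMode(const ConvShape &s) {
  bool depthwise = s.groups == s.in_channels && s.in_channels == s.out_channels;
  bool k3x3 = s.kernel_h == s.kernel_w && s.kernel_h == 3;
  if (s.int8_filter) {
    return (depthwise && k3x3 && s.stride_h < 3 && s.stride_h == s.stride_w)
               ? ExecMode::DEPTHWISE3x3_INT8
               : ExecMode::GEMM_INT8;
  }
  if (depthwise && k3x3 && s.stride_h == 1)
    return ExecMode::DEPTHWISE3x3S1P1_FLOAT;
  if (depthwise && k3x3) return ExecMode::DEPTHWISE3x3_FLOAT;
  // Winograd eligibility window mirrors the thresholds above ("refered from ncnn").
  if (k3x3 && s.stride_h == 1 && s.stride_h == s.stride_w && s.dilation == 1 &&
      s.out_channels >= 16 && s.in_channels >= 16 && s.in_height <= 140)
    return ExecMode::WINOGRAD3X3_FLOAT;
  return ExecMode::GEMM_FLOAT;
}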
src/operators/kernel/arm/dequant_add_bn_relu_kernel.cpp (new file, 0 → 100644)

/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#ifdef FUSION_DEQUANT_ADD_BN_RELU_OP

#include "operators/kernel/dequant_add_bn_relu_kernel.h"
#include <cmath>
#if defined(__ARM_NEON__) || defined(__ARM_NEON)
#include <arm_neon.h>
#endif

namespace paddle_mobile {
namespace operators {

template <>
bool FusionDequantAddBNReluKernel<CPU, float>::Init(
    FusionDequantAddBNReluParam<CPU> *param) {
  // elementwise add params
  const Tensor *bias = param->bias_;
  // batch norm params
  const Tensor *bn_mean = param->bn_mean_;
  const Tensor *bn_variance = param->bn_variance_;
  Tensor *bn_scale = param->bn_scale_;
  Tensor *bn_bias = param->bn_bias_;
  const float epsilon = param->epsilon_;

  const float *bias_ptr = bias->data<float>();
  const float *mean_ptr = bn_mean->data<float>();
  const float *var_ptr = bn_variance->data<float>();
  float *bn_scale_ptr = bn_scale->mutable_data<float>();
  float *bn_bias_ptr = bn_bias->mutable_data<float>();
  for (int c = 0; c < bn_scale->numel(); ++c) {
    float inv_scale = bn_scale_ptr[c] / (std::sqrt(var_ptr[c] + epsilon));
    bn_scale_ptr[c] = inv_scale;
    bn_bias_ptr[c] = inv_scale * (bias_ptr[c] - mean_ptr[c]) + bn_bias_ptr[c];
  }
  return true;
}

template <>
void FusionDequantAddBNReluKernel<CPU, float>::Compute(
    const FusionDequantAddBNReluParam<CPU> &param) {
  const int32_t *input = param.input_->data<int32_t>();
  const float *bn_scale = param.bn_scale_->data<float>();
  const float *bn_bias = param.bn_bias_->data<float>();
  // dequantize params
  const float activation_scale = param.activation_scale_->data<float>()[0];
  const float weight_scale = param.weight_scale_;
  const float dequant_scale = activation_scale / weight_scale;

  float *output = param.output_->mutable_data<float>();
  int batch_size = param.input_->dims()[0];
  int channels = param.input_->dims()[1];
  size_t spatial_size = param.input_->dims()[2] * param.input_->dims()[3];

#pragma omp parallel for collapse(2)
  for (int batch = 0; batch < batch_size; ++batch) {
    for (int c = 0; c < channels; ++c) {
      float scale = bn_scale[c] * dequant_scale;
      float bias = bn_bias[c];
      size_t offset = (batch * channels + c) * spatial_size;
      const int32_t *x = input + offset;
      float *y = output + offset;
      size_t remain = spatial_size;
#if defined(__ARM_NEON__) || defined(__ARM_NEON)
      int loop = spatial_size >> 4;
      remain = spatial_size & 0xF;
      float32x4_t __scale = vdupq_n_f32(scale);
      float32x4_t __bias = vdupq_n_f32(bias);
      float32x4_t __zero = vdupq_n_f32(0.f);
      for (int k = 0; k < loop; ++k, x += 16, y += 16) {
        int32x4_t r0 = vld1q_s32(x);
        int32x4_t r1 = vld1q_s32(x + 4);
        int32x4_t r2 = vld1q_s32(x + 8);
        int32x4_t r3 = vld1q_s32(x + 12);
        float32x4_t f0 = vcvtq_f32_s32(r0);
        float32x4_t f1 = vcvtq_f32_s32(r1);
        float32x4_t f2 = vcvtq_f32_s32(r2);
        float32x4_t f3 = vcvtq_f32_s32(r3);
        f0 = vmlaq_f32(__bias, __scale, f0);
        f1 = vmlaq_f32(__bias, __scale, f1);
        f2 = vmlaq_f32(__bias, __scale, f2);
        f3 = vmlaq_f32(__bias, __scale, f3);
        f0 = vmaxq_f32(__zero, f0);
        f1 = vmaxq_f32(__zero, f1);
        f2 = vmaxq_f32(__zero, f2);
        f3 = vmaxq_f32(__zero, f3);
        vst1q_f32(y, f0);
        vst1q_f32(y + 4, f1);
        vst1q_f32(y + 8, f2);
        vst1q_f32(y + 12, f3);
      }
#endif  // __ARM_NEON__
      for (int k = 0; k < remain; ++k) {
        y[k] = std::max(scale * x[k] + bias, 0.f);
      }
    }
  }
}

}  // namespace operators
}  // namespace paddle_mobile

#endif  // FUSION_DEQUANT_ADD_BN_RELU_OP
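The Init/Compute pair above folds the elementwise bias, the batch-norm statistics, and the dequantization scale into a single per-channel scale and bias, so the hot loop is one fused multiply-add plus ReLU. A scalar sketch of the same arithmetic, with plain arrays standing in for the framework tensors (names are illustrative):

#include <algorithm>
#include <cmath>
#include <cstddef>
#include <cstdint>

// Fold elementwise-add bias and batch-norm stats into per-channel
// scale/bias, mirroring FusionDequantAddBNReluKernel::Init above.
void FoldParams(const float *bias, const float *mean, const float *var,
                float *bn_scale, float *bn_bias, float epsilon, int channels) {
  for (int c = 0; c < channels; ++c) {
    float inv_scale = bn_scale[c] / std::sqrt(var[c] + epsilon);
    bn_scale[c] = inv_scale;
    bn_bias[c] = inv_scale * (bias[c] - mean[c]) + bn_bias[c];
  }
}

// Scalar equivalent of the NEON loop in Compute: dequantize the int32
// convolution accumulator, apply the folded BN, then ReLU.
void FusedDequantAddBNRelu(const int32_t *x, float *y, const float *bn_scale,
                           const float *bn_bias, float activation_scale,
                           float weight_scale, int batch, int channels,
                           size_t spatial_size) {
  float dequant_scale = activation_scale / weight_scale;
  for (int b = 0; b < batch; ++b) {
    for (int c = 0; c < channels; ++c) {
      float scale = bn_scale[c] * dequant_scale;
      float bias = bn_bias[c];
      size_t offset = (b * channels + c) * spatial_size;
      for (size_t k = 0; k < spatial_size; ++k) {
        y[offset + k] = std::max(scale * x[offset + k] + bias, 0.f);
      }
    }
  }
}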
src/operators/kernel/arm/dequantize_kernel.cpp

@@ -31,7 +31,7 @@ bool DequantizeKernel<CPU, float>::Init(DequantizeParam<CPU> *param) {
 template <>
 void DequantizeKernel<CPU, float>::Compute(const DequantizeParam<CPU> &param) {
   const Tensor *input = param.input_;
-  Tensor *output = param.out_;
+  Tensor *output = param.output_;
   float activation_scale = param.activation_scale_->data<float>()[0];
   float weight_scale = param.weight_scale_;
   const int32_t *x = input->data<const int32_t>();
@@ -43,11 +43,15 @@ void DequantizeKernel<CPU, float>::Compute(const DequantizeParam<CPU> &param) {
   size_t loop = size >> 4;
   size_t remain = size & 0xF;
   float32x4_t s = vdupq_n_f32(scale);
+  #pragma omp parallel for
   for (size_t i = 0; i < loop; ++i) {
-    int32x4_t r0 = vld1q_s32(x);
-    int32x4_t r1 = vld1q_s32(x + 4);
-    int32x4_t r2 = vld1q_s32(x + 8);
-    int32x4_t r3 = vld1q_s32(x + 12);
+    const int32_t *local_x = x + (i << 4);
+    float *local_y = y + (i << 4);
+    int32x4_t r0 = vld1q_s32(local_x);
+    int32x4_t r1 = vld1q_s32(local_x + 4);
+    int32x4_t r2 = vld1q_s32(local_x + 8);
+    int32x4_t r3 = vld1q_s32(local_x + 12);
     float32x4_t f0 = vcvtq_f32_s32(r0);
     float32x4_t f1 = vcvtq_f32_s32(r1);
     float32x4_t f2 = vcvtq_f32_s32(r2);
@@ -56,14 +60,14 @@ void DequantizeKernel<CPU, float>::Compute(const DequantizeParam<CPU> &param) {
     f1 = vmulq_f32(f1, s);
     f2 = vmulq_f32(f2, s);
     f3 = vmulq_f32(f3, s);
-    vst1q_f32(y, f0);
-    vst1q_f32(y + 4, f1);
-    vst1q_f32(y + 8, f2);
-    vst1q_f32(y + 12, f3);
-    x += 16;
-    y += 16;
+    vst1q_f32(local_y, f0);
+    vst1q_f32(local_y + 4, f1);
+    vst1q_f32(local_y + 8, f2);
+    vst1q_f32(local_y + 12, f3);
   }
   size = remain;
+  x += (loop << 4);
+  y += (loop << 4);
 #endif
   for (size_t i = 0; i < size; ++i) {
     y[i] = x[i] * scale;
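This rewrite is what makes the added `#pragma omp parallel for` legal: the old loop advanced the shared pointers x and y as side effects, so iterations could not run concurrently, while the new body derives local_x and local_y from the loop index alone. A minimal sketch of the pattern, with a scalar inner body standing in for the NEON block (names are illustrative):

#include <cstddef>
#include <cstdint>

// Each iteration computes its own base from i, so iterations are
// independent and can be distributed across threads; the scalar tail
// handles the last size & 0xF elements.
void DequantizeBlocks(const int32_t *x, float *y, float scale, size_t size) {
  size_t loop = size >> 4;
  #pragma omp parallel for
  for (size_t i = 0; i < loop; ++i) {
    const int32_t *local_x = x + (i << 4);
    float *local_y = y + (i << 4);
    for (int k = 0; k < 16; ++k) local_y[k] = local_x[k] * scale;
  }
  for (size_t i = loop << 4; i < size; ++i) y[i] = x[i] * scale;
}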
src/operators/kernel/arm/quantize_kernel.cpp

@@ -21,15 +21,15 @@ limitations under the License. */
 #include <arm_neon.h>
 #ifndef __aarch64__
-float32_t vmaxvq_f32(float32x4_t r) {
+inline float32_t vmaxvq_f32(float32x4_t r) {
   float32x2_t v = vmax_f32(vget_high_f32(r), vget_low_f32(r));
   return vget_lane_f32(vpmax_f32(v, v), 0);
 }
 #endif

-int32x4_t vrnd_towards_zero(float32x4_t r) { return vcvtq_s32_f32(r); }
+inline int32x4_t vrnd_towards_zero(float32x4_t r) { return vcvtq_s32_f32(r); }

-int32x4_t vrnd_away_zero(float32x4_t r) {
+inline int32x4_t vrnd_away_zero(float32x4_t r) {
   float32x4_t plus = vdupq_n_f32(0.5);
   float32x4_t minus = vdupq_n_f32(-0.5);
   float32x4_t zero = vdupq_n_f32(0);
@@ -40,7 +40,7 @@ int32x4_t vrnd_away_zero(float32x4_t r) {
   return ret;
 }

-int32x4_t vrnd_to_even(float32x4_t r) {
+inline int32x4_t vrnd_to_even(float32x4_t r) {
 #if 0
   int32x4_t ret;
   float value[4];
@@ -84,7 +84,6 @@ int32x4_t vrnd_to_even(float32x4_t r) {
   return rnd;
 #endif
 }
-#endif

 namespace paddle_mobile {
 namespace operators {
@@ -127,6 +126,7 @@ static float find_abs_max(const Tensor *input) {
   return max_abs;
 }

+#ifdef __aarch64__
 static void quantize_round_to_even(const Tensor *input, const float scale,
                                    Tensor *output) {
   const float *x = input->data<const float>();
@@ -188,7 +188,7 @@ static void quantize_round_to_zero(const Tensor *input, const float scale,
   const float *x = input->data<const float>();
   int8_t *y = output->mutable_data<int8_t>();
   size_t size = input->numel();
-#ifdef defined(__ARM_NEON__) || defined(__ARM_NEON)
+#if defined(__ARM_NEON__) || defined(__ARM_NEON)
   size_t loop = size >> 4;
   size_t remain = size & 0xF;
@@ -224,7 +224,7 @@ static void quantize_round_to_zero(const Tensor *input, const float scale,
   y += (loop << 4);
 #endif
   for (size_t i = 0; i < size; ++i) {
-    y[i] = trunc(x[i] * scale);
+    y[i] = static_cast<int8_t>(x[i] * scale);
   }
 }
@@ -272,6 +272,508 @@ static void quantize_round_to_nearest(const Tensor *input, const float scale,
     y[i] = round(x[i] * scale);
   }
 }
+#else  // __aarch64__
+
+static void quantize_round_to_even(const Tensor *input, const float scale,
+                                   const std::vector<int> &paddings,
+                                   const int8_t padding_val, Tensor *output) {}
+
+static void quantize_round_to_nearest(const Tensor *input, const float scale,
+                                      const std::vector<int> &paddings,
+                                      const int8_t padding_val,
+                                      Tensor *output) {}
+
+static void quantize_round_to_zero(const Tensor *input, const float scale,
+                                   const std::vector<int> &paddings,
+                                   const int8_t padding_val, Tensor *output) {
+  int channels = input->dims()[1];
+  int input_h = input->dims()[2];
+  int input_w = input->dims()[3];
+  int output_h = output->dims()[2];
+  int output_w = output->dims()[3];
+  int input_spatial_size = input_h * input_w;
+  int output_spatial_size = output_h * output_w;
+  const float *x = input->data<float>();
+  int8_t *y = output->mutable_data<int8_t>();
+  // valid area start
+  int start = paddings[0] * output_w + paddings[1];
+  for (int batch = 0; batch < input->dims()[0]; ++batch) {
+#pragma omp parallel for
+    for (int c = 0; c < channels - 3; c += 4) {
+      const float *input0 = x + (batch * channels + c) * input_spatial_size;
+      const float *input1 = input0 + input_spatial_size;
+      const float *input2 = input1 + input_spatial_size;
+      const float *input3 = input2 + input_spatial_size;
+      size_t offset = (batch * channels + c) * output_spatial_size;
+      for (int h = 0; h < 2; ++h) {
+        int8_t *y0 = y + offset +
+                     h * ((input_h + paddings[0]) * output_w - paddings[1]);
+        int8_t *y1 = y0 + output_spatial_size;
+        int8_t *y2 = y1 + output_spatial_size;
+        int8_t *y3 = y2 + output_spatial_size;
+        int loop = start >> 4;
+        int remain = start & 0xF;
+        asm volatile(
+            "vdup.s8 q0, %[val]          \n"
+            "cmp %[loop], #0             \n"
+            "ble start_remain_%=         \n"
+            "store_16w_%=:               \n"
+            "vst1.32 {q0}, [%[y0]]!      \n"
+            "vst1.32 {q0}, [%[y1]]!      \n"
+            "vst1.32 {q0}, [%[y2]]!      \n"
+            "vst1.32 {q0}, [%[y3]]!      \n"
+            "subs %[loop], #1            \n"
+            "bne store_16w_%=            \n"
+            "start_remain_%=:            \n"
+            "cmp %[remain], #8           \n"
+            "blt store_4w_%=             \n"
+            "vst1.32 {d0}, [%[y0]]!      \n"
+            "vst1.32 {d0}, [%[y1]]!      \n"
+            "vst1.32 {d0}, [%[y2]]!      \n"
+            "vst1.32 {d0}, [%[y3]]!      \n"
+            "sub %[remain], #8           \n"
+            "store_4w_%=:                \n"
+            "cmp %[remain], #4           \n"
+            "blt store_2w_%=             \n"
+            "vst1.32 {d0[0]}, [%[y0]]!   \n"
+            "vst1.32 {d0[0]}, [%[y1]]!   \n"
+            "vst1.32 {d0[0]}, [%[y2]]!   \n"
+            "vst1.32 {d0[0]}, [%[y3]]!   \n"
+            "sub %[remain], #4           \n"
+            "store_2w_%=:                \n"
+            "cmp %[remain], #4           \n"
+            "blt store_1w_%=             \n"
+            "vst1.16 {d0[0]}, [%[y0]]!   \n"
+            "vst1.16 {d0[0]}, [%[y1]]!   \n"
+            "vst1.16 {d0[0]}, [%[y2]]!   \n"
+            "vst1.16 {d0[0]}, [%[y3]]!   \n"
+            "sub %[remain], #2           \n"
+            "store_1w_%=:                \n"
+            "cmp %[remain], #1           \n"
+            "blt end_%=                  \n"
+            "vst1.8 {d0[0]}, [%[y0]]!    \n"
+            "vst1.8 {d0[0]}, [%[y1]]!    \n"
+            "vst1.8 {d0[0]}, [%[y2]]!    \n"
+            "vst1.8 {d0[0]}, [%[y3]]!    \n"
+            "end_%=:                     \n"
+            : [y0] "+r"(y0), [y1] "+r"(y1), [y2] "+r"(y2), [y3] "+r"(y3),
+              [loop] "+r"(loop), [remain] "+r"(remain)
+            : [val] "r"(padding_val)
+            : "cc", "memory", "q0");
+      }
+      // quantize valid area
+      int8_t *y0 = y + offset + start;
+      int8_t *y1 = y0 + output_spatial_size;
+      int8_t *y2 = y1 + output_spatial_size;
+      int8_t *y3 = y2 + output_spatial_size;
+      for (int h = 0; h < input_h; ++h) {
+        const float *x0 = input0 + h * input_w;
+        const float *x1 = input1 + h * input_w;
+        const float *x2 = input2 + h * input_w;
+        const float *x3 = input3 + h * input_w;
+        int loop = input_w >> 4;
+        int remain = input_w & 0xF;
+        int pad_loop = paddings[1] >> 1;  // (paddings[1] << 1) >> 2
+        int pad_remain = (paddings[1] << 1) & 0x3;
+        int remain_steps = remain;
+        asm volatile(
+            "vdup.f32 q0, %[scale]       \n"
+            "cmp %[loop], #0             \n"
+            "ble quantize_remain_%=      \n"
+            "loop_quantize_%=:           \n"
+            "vld1.32 {q1, q2}, [%[x0]]!  \n"
+            "vld1.32 {q3, q4}, [%[x1]]!  \n"
+            "vld1.32 {q5, q6}, [%[x2]]!  \n"
+            "vld1.32 {q7, q8}, [%[x3]]!  \n"
+            "vmul.f32 q1, q1, q0         \n"
+            "vmul.f32 q2, q2, q0         \n"
+            "vmul.f32 q3, q3, q0         \n"
+            "vmul.f32 q4, q4, q0         \n"
+            "vmul.f32 q5, q5, q0         \n"
+            "vmul.f32 q6, q6, q0         \n"
+            "vmul.f32 q7, q7, q0         \n"
+            "vmul.f32 q8, q8, q0         \n"
+            "vcvt.s32.f32 q1, q1         \n"
+            "vcvt.s32.f32 q2, q2         \n"
+            "vcvt.s32.f32 q3, q3         \n"
+            "vcvt.s32.f32 q4, q4         \n"
+            "vcvt.s32.f32 q5, q5         \n"
+            "vcvt.s32.f32 q6, q6         \n"
+            "vcvt.s32.f32 q7, q7         \n"
+            "vcvt.s32.f32 q8, q8         \n"
+            "vmovn.s32 d2, q1            \n"
+            "vmovn.s32 d3, q2            \n"
+            "vmovn.s32 d4, q3            \n"
+            "vmovn.s32 d5, q4            \n"
+            "vmovn.s32 d6, q5            \n"
+            "vmovn.s32 d7, q6            \n"
+            "vmovn.s32 d8, q7            \n"
+            "vmovn.s32 d9, q8            \n"
+            "vmovn.s16 d18, q1           \n"
+            "vmovn.s16 d20, q2           \n"
+            "vmovn.s16 d22, q3           \n"
+            "vmovn.s16 d24, q4           \n"
+            "vld1.32 {q1, q2}, [%[x0]]!  \n"
+            "vld1.32 {q3, q4}, [%[x1]]!  \n"
+            "vld1.32 {q5, q6}, [%[x2]]!  \n"
+            "vld1.32 {q7, q8}, [%[x3]]!  \n"
+            "vmul.f32 q1, q1, q0         \n"
+            "vmul.f32 q2, q2, q0         \n"
+            "vmul.f32 q3, q3, q0         \n"
+            "vmul.f32 q4, q4, q0         \n"
+            "vmul.f32 q5, q5, q0         \n"
+            "vmul.f32 q6, q6, q0         \n"
+            "vmul.f32 q7, q7, q0         \n"
+            "vmul.f32 q8, q8, q0         \n"
+            "vcvt.s32.f32 q1, q1         \n"
+            "vcvt.s32.f32 q2, q2         \n"
+            "vcvt.s32.f32 q3, q3         \n"
+            "vcvt.s32.f32 q4, q4         \n"
+            "vcvt.s32.f32 q5, q5         \n"
+            "vcvt.s32.f32 q6, q6         \n"
+            "vcvt.s32.f32 q7, q7         \n"
+            "vcvt.s32.f32 q8, q8         \n"
+            "vmovn.s32 d2, q1            \n"
+            "vmovn.s32 d3, q2            \n"
+            "vmovn.s32 d4, q3            \n"
+            "vmovn.s32 d5, q4            \n"
+            "vmovn.s32 d6, q5            \n"
+            "vmovn.s32 d7, q6            \n"
+            "vmovn.s32 d8, q7            \n"
+            "vmovn.s32 d9, q8            \n"
+            "vmovn.s16 d19, q1           \n"
+            "vmovn.s16 d21, q2           \n"
+            "vmovn.s16 d23, q3           \n"
+            "vmovn.s16 d25, q4           \n"
+            "vst1.32 {q9}, [%[y0]]!      \n"
+            "vst1.32 {q10}, [%[y1]]!     \n"
+            "vst1.32 {q11}, [%[y2]]!     \n"
+            "vst1.32 {q12}, [%[y3]]!     \n"
+            "subs %[loop], #1            \n"
+            "bne loop_quantize_%=        \n"
+            "quantize_remain_%=:         \n"
+            "cmp %[remain], #0           \n"
+            "ble end_%=                  \n"
+            "vld1.32 {q1, q2}, [%[x0]]!  \n"
+            "vld1.32 {q3, q4}, [%[x1]]!  \n"
+            "vld1.32 {q5, q6}, [%[x2]]!  \n"
+            "vld1.32 {q7, q8}, [%[x3]]!  \n"
+            "vmul.f32 q1, q1, q0         \n"
+            "vmul.f32 q2, q2, q0         \n"
+            "vmul.f32 q3, q3, q0         \n"
+            "vmul.f32 q4, q4, q0         \n"
+            "vmul.f32 q5, q5, q0         \n"
+            "vmul.f32 q6, q6, q0         \n"
+            "vmul.f32 q7, q7, q0         \n"
+            "vmul.f32 q8, q8, q0         \n"
+            "vcvt.s32.f32 q1, q1         \n"
+            "vcvt.s32.f32 q2, q2         \n"
+            "vcvt.s32.f32 q3, q3         \n"
+            "vcvt.s32.f32 q4, q4         \n"
+            "vcvt.s32.f32 q5, q5         \n"
+            "vcvt.s32.f32 q6, q6         \n"
+            "vcvt.s32.f32 q7, q7         \n"
+            "vcvt.s32.f32 q8, q8         \n"
+            "vmovn.s32 d2, q1            \n"
+            "vmovn.s32 d3, q2            \n"
+            "vmovn.s32 d4, q3            \n"
+            "vmovn.s32 d5, q4            \n"
+            "vmovn.s32 d6, q5            \n"
+            "vmovn.s32 d7, q6            \n"
+            "vmovn.s32 d8, q7            \n"
+            "vmovn.s32 d9, q8            \n"
+            "vmovn.s16 d18, q1           \n"
+            "vmovn.s16 d20, q2           \n"
+            "vmovn.s16 d22, q3           \n"
+            "vmovn.s16 d24, q4           \n"
+            "vld1.32 {q1, q2}, [%[x0]]   \n"
+            "vld1.32 {q3, q4}, [%[x1]]   \n"
+            "vld1.32 {q5, q6}, [%[x2]]   \n"
+            "vld1.32 {q7, q8}, [%[x3]]   \n"
+            "vmul.f32 q1, q1, q0         \n"
+            "vmul.f32 q2, q2, q0         \n"
+            "vmul.f32 q3, q3, q0         \n"
+            "vmul.f32 q4, q4, q0         \n"
+            "vmul.f32 q5, q5, q0         \n"
+            "vmul.f32 q6, q6, q0         \n"
+            "vmul.f32 q7, q7, q0         \n"
+            "vmul.f32 q8, q8, q0         \n"
+            "vcvt.s32.f32 q1, q1         \n"
+            "vcvt.s32.f32 q2, q2         \n"
+            "vcvt.s32.f32 q3, q3         \n"
+            "vcvt.s32.f32 q4, q4         \n"
+            "vcvt.s32.f32 q5, q5         \n"
+            "vcvt.s32.f32 q6, q6         \n"
+            "vcvt.s32.f32 q7, q7         \n"
+            "vcvt.s32.f32 q8, q8         \n"
+            "vmovn.s32 d2, q1            \n"
+            "vmovn.s32 d3, q2            \n"
+            "vmovn.s32 d4, q3            \n"
+            "vmovn.s32 d5, q4            \n"
+            "vmovn.s32 d6, q5            \n"
+            "vmovn.s32 d7, q6            \n"
+            "vmovn.s32 d8, q7            \n"
+            "vmovn.s32 d9, q8            \n"
+            "vmovn.s16 d19, q1           \n"
+            "vmovn.s16 d21, q2           \n"
+            "vmovn.s16 d23, q3           \n"
+            "vmovn.s16 d25, q4           \n"
+            "cmp %[remain], #8           \n"
+            "blt store_4w_%=             \n"
+            "vst1.32 {d18}, [%[y0]]!     \n"
+            "vst1.32 {d20}, [%[y1]]!     \n"
+            "vst1.32 {d22}, [%[y2]]!     \n"
+            "vst1.32 {d24}, [%[y3]]!     \n"
+            "vmov.32 d18, d19            \n"
+            "vmov.32 d20, d21            \n"
+            "vmov.32 d22, d23            \n"
+            "vmov.32 d24, d25            \n"
+            "sub %[remain], #8           \n"
+            "store_4w_%=:                \n"
+            "cmp %[remain], #4           \n"
+            "blt store_2w_%=             \n"
+            "vst1.32 {d18[0]}, [%[y0]]!  \n"
+            "vst1.32 {d20[0]}, [%[y1]]!  \n"
+            "vst1.32 {d22[0]}, [%[y2]]!  \n"
+            "vst1.32 {d24[0]}, [%[y3]]!  \n"
+            "vext.32 d18, d18, d18, #1   \n"
+            "vext.32 d20, d20, d20, #1   \n"
+            "vext.32 d22, d22, d22, #1   \n"
+            "vext.32 d24, d24, d24, #1   \n"
+            "sub %[remain], #4           \n"
+            "store_2w_%=:                \n"
+            "cmp %[remain], #2           \n"
+            "blt store_1w_%=             \n"
+            "vst1.16 {d18[0]}, [%[y0]]!  \n"
+            "vst1.16 {d20[0]}, [%[y1]]!  \n"
+            "vst1.16 {d22[0]}, [%[y2]]!  \n"
+            "vst1.16 {d24[0]}, [%[y3]]!  \n"
+            "vext.16 d18, d18, d18, #1   \n"
+            "vext.16 d20, d20, d20, #1   \n"
+            "vext.16 d22, d22, d22, #1   \n"
+            "vext.16 d24, d24, d24, #1   \n"
+            "sub %[remain], #2           \n"
+            "store_1w_%=:"
+            "cmp %[remain], #1           \n"
+            "blt end_%=                  \n"
+            "vst1.8 {d18[0]}, [%[y0]]!   \n"
+            "vst1.8 {d20[0]}, [%[y1]]!   \n"
+            "vst1.8 {d22[0]}, [%[y2]]!   \n"
+            "vst1.8 {d24[0]}, [%[y3]]!   \n"
+            "end_%=:                     \n"
+            : [x0] "+r"(x0), [x1] "+r"(x1), [x2] "+r"(x2), [x3] "+r"(x3),
+              [y0] "+r"(y0), [y1] "+r"(y1), [y2] "+r"(y2), [y3] "+r"(y3),
+              [loop] "+r"(loop), [remain] "+r"(remain)
+            : [scale] "r"(scale)
+            : "cc", "memory", "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7",
+              "q8", "q9", "q10", "q11", "q12");
+        asm volatile(
+            "vdup.s8 d0, %[val]          \n"
+            "cmp %[pad_loop], #0         \n"
+            "ble store_pad_2w_%=         \n"
+            "loop_pad_4w_%=:             \n"
+            "vst1.32 {d0[0]}, [%[y0]]!   \n"
+            "vst1.32 {d0[0]}, [%[y1]]!   \n"
+            "vst1.32 {d0[0]}, [%[y2]]!   \n"
+            "vst1.32 {d0[0]}, [%[y3]]!   \n"
+            "subs %[pad_loop], #1        \n"
+            "bne loop_pad_4w_%=          \n"
+            "store_pad_2w_%=:            \n"
+            "cmp %[pad_remain], #2       \n"
+            "blt store_pad_1w_%=         \n"
+            "vst1.16 {d0[0]}, [%[y0]]!   \n"
+            "vst1.16 {d0[0]}, [%[y1]]!   \n"
+            "vst1.16 {d0[0]}, [%[y2]]!   \n"
+            "vst1.16 {d0[0]}, [%[y3]]!   \n"
+            "sub %[pad_remain], #2       \n"
+            "store_pad_1w_%=:            \n"
+            "cmp %[pad_remain], #1       \n"
+            "blt end_%=                  \n"
+            "vst1.8 {d0[0]}, [%[y0]]!    \n"
+            "vst1.8 {d0[0]}, [%[y1]]!    \n"
+            "vst1.8 {d0[0]}, [%[y2]]!    \n"
+            "vst1.8 {d0[0]}, [%[y3]]!    \n"
+            "end_%=:                     \n"
+            : [y0] "+r"(y0), [y1] "+r"(y1), [y2] "+r"(y2), [y3] "+r"(y3),
+              [pad_loop] "+r"(pad_loop), [pad_remain] "+r"(pad_remain)
+            : [val] "r"(padding_val)
+            : "cc", "memory", "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7",
+              "q8", "q9", "q10", "q11", "q12");
+      }
+    }
+    for (int c = (channels & 0xFFFC); c < channels; ++c) {
+      const float *input0 = x + (batch * channels + c) * input_spatial_size;
+      size_t offset = (batch * channels + c) * output_spatial_size;
+      for (int h = 0; h < 2; ++h) {
+        int8_t *y0 = y + offset +
+                     h * ((input_h + paddings[0]) * output_w - paddings[1]);
+        int loop = start >> 4;
+        int remain = start & 0xF;
+        asm volatile(
+            "vdup.s8 q0, %[val]          \n"
+            "cmp %[loop], #0             \n"
+            "ble start_remain_%=         \n"
+            "store_16w_%=:               \n"
+            "vst1.32 {q0}, [%[y0]]!      \n"
+            "subs %[loop], #1            \n"
+            "bne store_16w_%=            \n"
+            "start_remain_%=:            \n"
+            "cmp %[remain], #8           \n"
+            "blt store_4w_%=             \n"
+            "vst1.32 {d0}, [%[y0]]!      \n"
+            "sub %[remain], #8           \n"
+            "store_4w_%=:                \n"
+            "cmp %[remain], #4           \n"
+            "blt store_2w_%=             \n"
+            "vst1.32 {d0[0]}, [%[y0]]!   \n"
+            "sub %[remain], #4           \n"
+            "store_2w_%=:                \n"
+            "cmp %[remain], #4           \n"
+            "blt store_1w_%=             \n"
+            "vst1.16 {d0[0]}, [%[y0]]!   \n"
+            "sub %[remain], #2           \n"
+            "store_1w_%=:                \n"
+            "cmp %[remain], #1           \n"
+            "blt end_%=                  \n"
+            "vst1.8 {d0[0]}, [%[y0]]!    \n"
+            "end_%=:                     \n"
+            : [y0] "+r"(y0), [loop] "+r"(loop), [remain] "+r"(remain)
+            : [val] "r"(padding_val)
+            : "cc", "memory", "q0");
+      }
+      // quantize valid area
+      int8_t *y0 = y + offset + start;
+      for (int h = 0; h < input_h; ++h) {
+        const float *x0 = input0 + h * input_w;
+        int loop = input_w >> 4;
+        int remain = input_w & 0xF;
+        int pad_loop = paddings[1] >> 1;  // (paddings[1] << 1) >> 2
+        int pad_remain = (paddings[1] << 1) & 0x3;
+        asm volatile(
+            "vdup.f32 q0, %[scale]       \n"
+            "cmp %[loop], #0             \n"
+            "ble quantize_remain_%=      \n"
+            "loop_quantize_%=:           \n"
+            "vld1.32 {q1, q2}, [%[x0]]!  \n"
+            "vmul.f32 q1, q1, q0         \n"
+            "vmul.f32 q2, q2, q0         \n"
+            "vcvt.s32.f32 q1, q1         \n"
+            "vcvt.s32.f32 q2, q2         \n"
+            "vmovn.s32 d2, q1            \n"
+            "vmovn.s32 d3, q2            \n"
+            "vmovn.s16 d18, q1           \n"
+            "vld1.32 {q1, q2}, [%[x0]]!  \n"
+            "vmul.f32 q1, q1, q0         \n"
+            "vmul.f32 q2, q2, q0         \n"
+            "vcvt.s32.f32 q1, q1         \n"
+            "vcvt.s32.f32 q2, q2         \n"
+            "vmovn.s32 d2, q1            \n"
+            "vmovn.s32 d3, q2            \n"
+            "vmovn.s16 d19, q1           \n"
+            "vst1.32 {q9}, [%[y0]]!      \n"
+            "subs %[loop], #1            \n"
+            "bne loop_quantize_%=        \n"
+            "quantize_remain_%=:         \n"
+            "cmp %[remain], #0           \n"
+            "ble start_pad_%=            \n"
+            "vldm %[x0], {d2-d9}         \n"
+            "vmul.f32 q1, q1, q0         \n"
+            "vmul.f32 q2, q2, q0         \n"
+            "vcvt.s32.f32 q1, q1         \n"
+            "vcvt.s32.f32 q2, q2         \n"
+            "vmovn.s32 d2, q1            \n"
+            "vmovn.s32 d3, q2            \n"
+            "vmovn.s16 d18, q1           \n"
+            "vmul.f32 q3, q3, q0         \n"
+            "vmul.f32 q4, q4, q0         \n"
+            "vcvt.s32.f32 q1, q3         \n"
+            "vcvt.s32.f32 q2, q4         \n"
+            "vmovn.s32 d2, q1            \n"
+            "vmovn.s32 d3, q2            \n"
+            "vmovn.s16 d19, q1           \n"
+            "cmp %[remain], #8           \n"
+            "blt store_4w_%=             \n"
+            "vst1.32 {d18}, [%[y0]]!     \n"
+            "vmov.32 d18, d19            \n"
+            "sub %[remain], #8           \n"
+            "store_4w_%=:                \n"
+            "cmp %[remain], #4           \n"
+            "blt store_2w_%=             \n"
+            "vst1.32 {d18[0]}, [%[y0]]!  \n"
+            "vext.32 d18, d18, d18, #1   \n"
+            "sub %[remain], #4           \n"
+            "store_2w_%=:                \n"
+            "cmp %[remain], #2           \n"
+            "blt store_1w_%=             \n"
+            "vst1.16 {d18[0]}, [%[y0]]!  \n"
+            "vext.16 d18, d18, d18, #1   \n"
+            "sub %[remain], #2           \n"
+            "store_1w_%=:"
+            "cmp %[remain], #1           \n"
+            "blt start_pad_%=            \n"
+            "vst1.8 {d18[0]}, [%[y0]]!   \n"
+            "start_pad_%=:               \n"
+            "vdup.s8 d0, %[val]          \n"
+            "cmp %[pad_loop], #0         \n"
+            "ble pad_remain_%=           \n"
+            "loop_pad_4w_%=:             \n"
+            "vst1.32 {d0[0]}, [%[y0]]!   \n"
+            "subs %[pad_loop], #1        \n"
+            "bne loop_pad_4w_%=          \n"
+            "pad_remain_%=:              \n"
+            "cmp %[pad_remain], #2       \n"
+            "blt store_pad_1w_%=         \n"
+            "vst1.16 {d0[0]}, [%[y0]]!   \n"
+            "sub %[pad_remain], #2       \n"
+            "store_pad_1w_%=:            \n"
+            "cmp %[pad_remain], #1       \n"
+            "blt end_%=                  \n"
+            "vst1.8 {d0[0]}, [%[y0]]!    \n"
+            "end_%=:                     \n"
+            : [x0] "+r"(x0), [y0] "+r"(y0), [loop] "+r"(loop),
+              [remain] "+r"(remain), [pad_loop] "+r"(pad_loop),
+              [pad_remain] "+r"(pad_remain)
+            : [scale] "r"(scale), [val] "r"(padding_val)
+            : "cc", "memory", "q0", "q1", "q2", "q3", "q4", "q9");
+      }
+    }
+  }
+}
+#endif  // __aarch64__
 #endif  // ARM_NEON
@@ -280,10 +782,10 @@ bool QuantizeKernel<CPU, float>::Init(QuantizeParam<CPU> *param) {
 template <>
 void QuantizeKernel<CPU, float>::Compute(const QuantizeParam<CPU> &param) {
-  float max_abs = 0.f;
   const Tensor *input = param.input_;
-  Tensor *output = param.out_;
+  Tensor *output = param.output_;
   Tensor *output_scale = param.online_scale_;
+  float max_abs = 0.f;
   if (param.is_static_) {
     max_abs = param.static_scale_;
   } else {
@@ -293,15 +795,19 @@ void QuantizeKernel<CPU, float>::Compute(const QuantizeParam<CPU> &param) {
   // only support int8 currently
   float scale = 127 / max_abs;
   param.online_scale_->mutable_data<float>()[0] = max_abs;
+  const auto &paddings = param.paddings_;
+  // std::vector<int> paddings = {0, 0};
+  // const auto padding_val = param.padding_val_;
+  int8_t padding_val = 0;
   switch (param.round_type_) {
     case ROUND_NEAREST_TO_EVEN:
-      quantize_round_to_even(input, scale, output);
+      quantize_round_to_even(input, scale, paddings, padding_val, output);
       break;
     case ROUND_NEAREST_TOWARDS_ZERO:
-      quantize_round_to_zero(input, scale, output);
+      quantize_round_to_zero(input, scale, paddings, padding_val, output);
      break;
     case ROUND_NEAREST_AWAY_ZERO:
-      quantize_round_to_nearest(input, scale, output);
+      quantize_round_to_nearest(input, scale, paddings, padding_val, output);
       break;
     default:
      LOG(kLOG_ERROR) << "round type is not supported.";
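The long assembly above writes an NCHW int8 tensor whose spatial planes are pre-padded: each output plane is output_h × output_w, the first start = paddings[0] * output_w + paddings[1] bytes of a plane hold padding_val, and the border between consecutive valid rows is likewise filled with padding_val. A scalar reference of the same output layout (a sketch with plain arrays, not the library's API; round-towards-zero via the int8_t cast, matching the ARMv7 path):

#include <cstdint>
#include <vector>

// Scalar reference for quantize_round_to_zero with padding folded in:
// each output plane is (input_h + 2*pad_h) x (input_w + 2*pad_w).
void QuantizeWithPadding(const float *x, int8_t *y, float scale,
                         int channels, int input_h, int input_w,
                         const std::vector<int> &paddings, int8_t padding_val) {
  int output_h = input_h + 2 * paddings[0];
  int output_w = input_w + 2 * paddings[1];
  for (int c = 0; c < channels; ++c) {
    const float *in = x + c * input_h * input_w;
    int8_t *out = y + c * output_h * output_w;
    for (int h = 0; h < output_h; ++h) {
      for (int w = 0; w < output_w; ++w) {
        int ih = h - paddings[0], iw = w - paddings[1];
        bool inside = ih >= 0 && ih < input_h && iw >= 0 && iw < input_w;
        out[h * output_w + w] =
            inside ? static_cast<int8_t>(in[ih * input_w + iw] * scale)
                   : padding_val;
      }
    }
  }
}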
src/operators/kernel/central-arm-func/conv_add_arm_func.h

@@ -17,7 +17,7 @@ limitations under the License. */
 #include <vector>
 #include "operators/math/conv_func.h"
-#include "operators/math/depthwise_conv_3x3.h"
+#include "operators/math/depthwise_conv3x3.h"
 #include "operators/math/im2col.h"
 #include "operators/math/math_function.h"
 #include "operators/math/vol2col.h"
src/operators/kernel/central-arm-func/conv_add_bn_relu_arm_func.h

@@ -17,7 +17,7 @@ limitations under the License. */
 #pragma once

 #include <vector>
-#include "operators/math/depthwise_conv_3x3.h"
+#include "operators/math/depthwise_conv3x3.h"
 #include "operators/math/im2col.h"
 #include "operators/math/math_function.h"
 #include "operators/math/vol2col.h"
src/operators/kernel/central-arm-func/conv_arm_func.h

@@ -17,18 +17,19 @@ limitations under the License. */
 #pragma once
 #include <vector>
 #include "operators/math/conv_func.h"
-#include "operators/math/depthwise_conv_3x3.h"
+#include "operators/math/depthwise_conv3x3.h"
 #include "operators/math/im2col.h"
 #include "operators/math/math_function.h"
 #include "operators/math/pad.h"
 #include "operators/math/vol2col.h"
+#include "operators/math/winograd/winograd_transform.h"
 #include "operators/op_param.h"

 namespace paddle_mobile {
 namespace operators {

 template <typename Itype, typename Otype>
-inline void ConvBasic(const ConvParam<CPU> &param) {
+inline void GemmConv(const ConvParam<CPU> &param) {
   const Tensor *input = param.Input();
   Tensor filter = *param.Filter();
   Tensor *output = param.Output();
@@ -38,10 +39,7 @@ inline void GemmConv(const ConvParam<CPU> &param) {
   const std::vector<int> paddings = param.Paddings();
   const std::vector<int> dilations = param.Dilations();
-  const int batch_size = static_cast<int>(input->dims()[0]);
   std::vector<int64_t> filter_shape_vec(framework::vectorize(filter.dims()));
   std::vector<int64_t> output_shape_vec(framework::vectorize(output->dims()));
   size_t data_dim = filter_shape_vec.size() - 2;
   std::vector<int64_t> col_shape_vec(1 + 2 * data_dim);
@@ -82,6 +80,7 @@ inline void GemmConv(const ConvParam<CPU> &param) {
   math::Vol2ColFunctor<CPU, Itype> vol2col;
   math::Im2ColFunctor<math::ColFormat::kCFO, CPU, Itype> im2col;
+  const int batch_size = static_cast<int>(input->dims()[0]);

   for (int i = 0; i < batch_size; i++) {
     Tensor in_batch = input->Slice(i, i + 1).Resize(input_shape);
     Tensor out_batch = output->Slice(i, i + 1).Resize(output_matrix_shape);
@@ -99,7 +98,6 @@ inline void GemmConv(const ConvParam<CPU> &param) {
                std::vector<int>{paddings[0], paddings[1], paddings[0],
                                 paddings[1]},
                &col);
     } else if (data_dim == 3U) {
       // vol2col
       vol2col(in_slice, dilations, strides, paddings, &col);
@@ -116,25 +114,86 @@ inline void GemmConv(const ConvParam<CPU> &param) {
   }
 }

-template <typename P>
-void ConvCompute(const ConvParam<CPU> &param) {
-  if (param.Input()->type() == typeid(int8_t)) {
-    ConvBasic<int8_t, int32_t>(param);
-  } else {
-    if (param.Groups() == param.Input()->dims()[1] &&
-        param.Input()->dims()[1] == param.Output()->dims()[1] &&
-        param.Filter()->dims()[2] == param.Filter()->dims()[3] &&
-        param.Filter()->dims()[2] == 3 && param.Strides()[0] == 1) {
-      math::DepthwiseConv3x3s1p1(param.Input(), param.Filter(), param.Output(),
-                                 nullptr, false);
-    } else if (param.Groups() == param.Input()->dims()[1] &&
-               param.Input()->dims()[1] == param.Output()->dims()[1] &&
-               param.Filter()->dims()[2] == param.Filter()->dims()[3] &&
-               param.Filter()->dims()[2] == 3) {
-      math::DepthwiseConv3x3(param.Input(), param.Strides(), param.Paddings(),
-                             param.Filter(), nullptr, param.Output(), false);
-    } else {
-      ConvBasic<float, float>(param);
-    }
-  }
-}
+template <int tile, int kernel>
+inline void WinogradConv3x3(const ConvParam<CPU> &param) {
+  const Tensor *input = param.Input();
+  const Tensor *filter = param.Filter();
+  Tensor *output = param.Output();
+  output->mutable_data<float>();
+  int batch_size = input->dims()[0];
+  int groups = param.Groups();
+  const std::vector<int> &paddings = param.Paddings();
+
+  auto winograd_pad = [&](int width, int pad) {
+    int output_tile = tile - kernel + 1;
+    // int tiles = (width + pad - kernel) / output_tile + 1;
+    // return (tiles - 1) * output_tile + tile - width;
+    int pad_width = (width + 2 * pad - kernel) / output_tile * output_tile;
+    return pad_width + tile - width;
+  };
+
+  math::PadFunctor<CPU, float> pad;
+  Tensor input_pad;
+  framework::Tensor transformed_input;
+  for (int i = 0; i < batch_size; ++i) {
+    Tensor in_batch = input->Slice(i, i + 1);
+    Tensor out_batch = output->Slice(i, i + 1);
+    // int pad_bottom = winograd_pad(in_batch.dims()[2], paddings[0]);
+    // int pad_right = winograd_pad(in_batch.dims()[3], paddings[1]);
+    int pad_bottom = paddings[0];
+    int pad_right = paddings[1];
+    if (paddings[0] || paddings[1] || pad_bottom || pad_right) {
+      framework::DDim pad_shape = in_batch.dims();
+      pad_shape[2] += paddings[0] + pad_bottom;
+      pad_shape[3] += paddings[1] + pad_right;
+      input_pad.mutable_data<float>(pad_shape);
+      pad(in_batch, paddings[0], pad_bottom, paddings[1], pad_right,
+          &input_pad);
+    } else {
+      input_pad = in_batch;
+    }
+    // tile input and transform
+    math::winograd_transform_input<tile, kernel>(input_pad,
+                                                 &transformed_input);
+    // caculate output
+    math::winograd_transform_output<tile, kernel>(transformed_input, *filter,
+                                                  output);
+  }
+}
+
+template <typename Itype, typename Otype>
+inline void DepthwiseConv3x3(const ConvParam<CPU> &param) {
+  const Tensor *input = param.Input();
+  const Tensor *filter = param.Filter();
+  Tensor *output = param.Output();
+  output->mutable_data<Otype>();
+  const std::vector<int> &paddings = param.Paddings();
+  const std::vector<int> &strides = param.Strides();
+  const int batch_size = static_cast<int>(input->dims()[0]);
+  Tensor input_pad;
+  math::PadFunctor<CPU, Itype> pad;
+  for (int i = 0; i < batch_size; i++) {
+    Tensor in_batch = input->Slice(i, i + 1);
+    Tensor out_batch = output->Slice(i, i + 1);
+    if (paddings[0] || paddings[1]) {
+      framework::DDim pad_shape = in_batch.dims();
+      pad_shape[2] += 2 * paddings[0];
+      pad_shape[3] += 2 * paddings[1];
+      input_pad.mutable_data<float>(pad_shape);
+      pad(in_batch, paddings[0], paddings[0], paddings[1], paddings[1],
+          &input_pad);
+    } else {
+      input_pad = in_batch;
+    }
+    if (strides[0] == 1) {
+      math::DepthwiseConv3x3s1<Itype, Otype>(input_pad, *filter, &out_batch);
+    } else if (strides[0] == 2) {
+      math::DepthwiseConv3x3s2<Itype, Otype>(input_pad, *filter, &out_batch);
+    } else {
+      // math::DepthwiseConv3x3<Itype, Otype>(input_pad, *filter,
+      //                                      &out_batch);
+      PADDLE_MOBILE_THROW_EXCEPTION(
+          "Depthwise conv with generic strides has not been implemented.");
+    }
+  }
+}
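One detail worth noting in WinogradConv3x3 above: for F(6,3) (tile = 8, kernel = 3) each tile produces output_tile = tile − kernel + 1 = 6 outputs, so a fully tiled input width must be a multiple of 6 plus the 8-wide read window. A small sketch of the commented-out winograd_pad arithmetic (hedged; the committed code currently just reuses paddings[0]/paddings[1] for the bottom/right padding):

// For tile = 8, kernel = 3: output_tile = 6. Computes how many extra
// columns (or rows) must be appended so the last 8-wide input window
// stays in bounds after tiling.
template <int tile, int kernel>
int WinogradPad(int width, int pad) {
  int output_tile = tile - kernel + 1;
  int pad_width = (width + 2 * pad - kernel) / output_tile * output_tile;
  return pad_width + tile - width;  // e.g. WinogradPad<8, 3>(224, 1) == 6
}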
src/operators/kernel/central-arm-func/conv_bn_add_relu_arm_func.h

@@ -17,7 +17,7 @@ limitations under the License. */
 #pragma once

 #include <vector>
-#include "operators/math/depthwise_conv_3x3.h"
+#include "operators/math/depthwise_conv3x3.h"
 #include "operators/math/im2col.h"
 #include "operators/math/math_function.h"
 #include "operators/math/vol2col.h"
src/operators/kernel/central-arm-func/conv_bn_relu_arm_func.h

@@ -16,13 +16,15 @@ limitations under the License. */
 #pragma once

 #include <vector>
-#include "operators/math/depthwise_conv_3x3.h"
+#include "operators/math/depthwise_conv3x3.h"
 #include "operators/math/im2col.h"
 #include "operators/math/math_function.h"
 #include "operators/math/vol2col.h"
 #include "operators/op_param.h"

 namespace paddle_mobile {
 namespace operators {
 void ConvBNReluBasic(const FusionConvBNReluParam<CPU> &param) {
   const Tensor *input = param.Input();
   Tensor filter = *param.Filter();
src/operators/kernel/central-arm-func/depthwise_conv_arm_func.h

@@ -15,10 +15,9 @@ limitations under the License. */
 #ifdef DEPTHWISECONV_OP

 #pragma once
-#include <operators/math/depthwise_conv_3x3.h>
 #include <vector>
 #include "operators/kernel/central-arm-func/conv_arm_func.h"
+#include "operators/math/depthwise_conv3x3.h"
 #include "operators/op_param.h"

 namespace paddle_mobile {

@@ -44,7 +43,7 @@ void DepthwiseConvCompute(const ConvParam<CPU> &param) {
                            Bias, false);
   } else {
-    ConvBasic<float, float>(param);
+    GemmConv<float, float>(param);
   }
 }
src/operators/kernel/central-arm-func/dwconv_bn_relu_arm_func.h

@@ -16,13 +16,15 @@ limitations under the License. */
 #pragma once
 #include <vector>
-#include "operators/math/depthwise_conv_3x3.h"
+#include "operators/math/depthwise_conv3x3.h"
 #include "operators/math/im2col.h"
 #include "operators/math/math_function.h"
 #include "operators/math/vol2col.h"
 #include "operators/op_param.h"

 namespace paddle_mobile {
 namespace operators {

 void DWConvBNReluBasic(const FusionDWConvBNReluParam<CPU> &param) {
   const Tensor *input = param.Input();
   Tensor filter = *param.Filter();
src/operators/kernel/conv_add_kernel.h

@@ -24,7 +24,7 @@ limitations under the License. */
 #include "framework/ddim.h"
 #include "framework/operator.h"
 #include "operators/math/conv_func.h"
-#include "operators/math/depthwise_conv_3x3.h"
+#include "operators/math/depthwise_conv3x3.h"
 #include "operators/math/im2col.h"
 #include "operators/math/math_function.h"
 #include "operators/math/vol2col.h"
src/operators/kernel/dequant_add_bn_relu_kernel.h (new file, mode 100644)

/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once

#ifdef FUSION_DEQUANT_ADD_BN_RELU_OP

#include "framework/operator.h"
#include "operators/op_param.h"

namespace paddle_mobile {
namespace operators {

template <typename DeviceType, typename T>
class FusionDequantAddBNReluKernel
    : public framework::OpKernelBase<DeviceType,
                                     FusionDequantAddBNReluParam<DeviceType>> {
 public:
  void Compute(const FusionDequantAddBNReluParam<DeviceType> &param);
  bool Init(FusionDequantAddBNReluParam<DeviceType> *param);
};

}  // namespace operators
}  // namespace paddle_mobile

#endif
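The matching CPU kernel lands in src/operators/kernel/arm/dequant_add_bn_relu_kernel.cpp (see the file list for this commit). Per element, the fusion plausibly amounts to dequantizing the int32 accumulator, adding the conv bias, applying the folded batch-norm scale/bias, and clamping at zero; a scalar sketch under that assumption (names are illustrative, not from this patch):

// Hypothetical per-element math for the fused dequantize + add + BN + ReLU.
inline float DequantAddBNRelu(int32_t acc, float dequant_scale, float bias,
                              float bn_scale, float bn_bias) {
  float x = static_cast<float>(acc) * dequant_scale + bias;  // dequant + add
  float y = bn_scale * x + bn_bias;                          // folded BN
  return y > 0.f ? y : 0.f;                                  // ReLU
}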
src/operators/math/depthwise_conv_3x3.cpp → src/operators/math/depthwise_conv3x3.cpp

@@ -11,18 +11,22 @@ distributed under the License is distributed on an "AS IS" BASIS,
 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License. */

-#include "operators/math/depthwise_conv_3x3.h"
-#include <vector>
+#include "operators/math/depthwise_conv3x3.h"
 #if __ARM_NEON
 #include <arm_neon.h>
 #endif
+#include <vector>

 namespace paddle_mobile {
 namespace operators {
 namespace math {

-void DepthwiseConv3x3(const Tensor *input, vector<int> strides,
-                      vector<int> paddings, const Tensor *filter, Tensor *bias,
-                      Tensor *output, bool if_bias) {
+void DepthwiseConv3x3(const framework::Tensor *input,
+                      const std::vector<int> &strides,
+                      const std::vector<int> &paddings,
+                      const framework::Tensor *filter, framework::Tensor *bias,
+                      framework::Tensor *output, bool if_bias) {
   const int batch_size = input->dims()[0];
   const int input_height = input->dims()[2];

@@ -67,12 +71,12 @@ void DepthwiseConv3x3(const Tensor *input, vector<int> strides,
     for (int pw = 0; pw < output_width; pw++) {
       hstart = ph * stride_height - padding_height;
       wstart = pw * stride_width - padding_width;
-      hend = min(hstart + _kernel_size, input_height + padding_height);
-      wend = min(wstart + _kernel_size, input_width + padding_width);
-      hstart = max(hstart, 0);
-      wstart = max(wstart, 0);
-      hend = min(hend, input_height);
-      wend = min(wend, input_width);
+      hend = std::min(hstart + _kernel_size, input_height + padding_height);
+      wend = std::min(wstart + _kernel_size, input_width + padding_width);
+      hstart = std::max(hstart, 0);
+      wstart = std::max(wstart, 0);
+      hend = std::min(hend, input_height);
+      wend = std::min(wend, input_width);
       pos1 = input_data + hstart * input_width + wstart;
       pos2 = input_data + (hstart + 1) * input_width + wstart;
       pos3 = input_data + (hstart + 2) * input_width + wstart;

@@ -244,12 +248,14 @@ void DepthwiseConv3x3(const Tensor *input, vector<int> strides,
   }
 }

-void DepthwiseConv3x3s1p1(const Tensor *input, const Tensor *filter,
-                          Tensor *output, Tensor *bias, bool if_bias) {
+void DepthwiseConv3x3s1p1(const framework::Tensor *input,
+                          const framework::Tensor *filter,
+                          framework::Tensor *output, framework::Tensor *bias,
+                          bool if_bias) {
 #if __ARM_NEON
   const float *input_data = input->data<float>();
   const float *filter_data = filter->data<float>();
-  float *output_data = output->data<float>();
+  float *output_data = output->mutable_data<float>();
   const float *bias_data;
   if (if_bias) {
     bias_data = bias->data<float>();

@@ -517,9 +523,12 @@ void DepthwiseConv3x3s1p1(const Tensor *input, const Tensor *filter,
 #endif
 }

-void DepthwiseConvAddBNRelu3x3s1p1(const Tensor *input, const Tensor *filter,
-                                   Tensor *output, const Tensor *new_scale,
-                                   const Tensor *new_bias, bool if_relu) {
+void DepthwiseConvAddBNRelu3x3s1p1(const framework::Tensor *input,
+                                   const framework::Tensor *filter,
+                                   framework::Tensor *output,
+                                   const framework::Tensor *new_scale,
+                                   const framework::Tensor *new_bias,
+                                   bool if_relu) {
 #if __ARM_NEON
   const float *input_data = input->data<float>();
   const float *filter_data = filter->data<float>();

@@ -1059,9 +1068,12 @@ void DepthwiseConvAddBNRelu3x3s1p1(const Tensor *input, const Tensor *filter,
 }

 /// w != h not fixed yet
-void DepthwiseConvAddBNRelu3x3s2p1(const Tensor *input, const Tensor *filter,
-                                   Tensor *output, const Tensor *new_scale,
-                                   const Tensor *new_bias, bool if_relu) {
+void DepthwiseConvAddBNRelu3x3s2p1(const framework::Tensor *input,
+                                   const framework::Tensor *filter,
+                                   framework::Tensor *output,
+                                   const framework::Tensor *new_scale,
+                                   const framework::Tensor *new_bias,
+                                   bool if_relu) {
 #if __ARM_NEON
   const int batch_size = input->dims()[0];

@@ -1107,12 +1119,12 @@ void DepthwiseConvAddBNRelu3x3s2p1(const Tensor *input, const Tensor *filter,
     for (int pw = 0; pw < output_width; pw++) {
       hstart = ph * stride_height - padding_height;
       wstart = pw * stride_width - padding_width;
-      hend = min(hstart + _kernel_size, input_height + padding_height);
-      wend = min(wstart + _kernel_size, input_width + padding_width);
-      hstart = max(hstart, 0);
-      wstart = max(wstart, 0);
-      hend = min(hend, input_height);
-      wend = min(wend, input_width);
+      hend = std::min(hstart + _kernel_size, input_height + padding_height);
+      wend = std::min(wstart + _kernel_size, input_width + padding_width);
+      hstart = std::max(hstart, 0);
+      wstart = std::max(wstart, 0);
+      hend = std::min(hend, input_height);
+      wend = std::min(wend, input_width);
       pos1 = input_data + hstart * input_width + wstart;
       pos2 = input_data + (hstart + 1) * input_width + wstart;
       pos3 = input_data + (hstart + 2) * input_width + wstart;

@@ -1258,8 +1270,10 @@ void DepthwiseConvAddBNRelu3x3s2p1(const Tensor *input, const Tensor *filter,
 #endif
 }

-void DepthwiseConv3x3s2p1v2(const Tensor *input, const Tensor *filter,
-                            Tensor *output, Tensor bias, bool if_bias) {
+void DepthwiseConv3x3s2p1v2(const framework::Tensor *input,
+                            const framework::Tensor *filter,
+                            framework::Tensor *output, framework::Tensor bias,
+                            bool if_bias) {
 #if __ARM_NEON
   const float *input_data = input->data<float>();
   const float *filter_data = filter->data<float>();

@@ -1463,9 +1477,12 @@ void DepthwiseConv3x3s2p1v2(const Tensor *input, const Tensor *filter,
 #endif
 }

-void DepthwiseConvAddBNRelu3x3s2p1v2(const Tensor *input, const Tensor *filter,
-                                     Tensor *output, const Tensor *new_scale,
-                                     const Tensor *new_bias, bool if_relu) {
+void DepthwiseConvAddBNRelu3x3s2p1v2(const framework::Tensor *input,
+                                     const framework::Tensor *filter,
+                                     framework::Tensor *output,
+                                     const framework::Tensor *new_scale,
+                                     const framework::Tensor *new_bias,
+                                     bool if_relu) {
 #if __ARM_NEON
 // #ifdef _OPENMP
 // const float *newscale_data = new_scale->data<float>();

@@ -1886,8 +1903,10 @@ void DepthwiseConvAddBNRelu3x3s2p1v2(const Tensor *input, const Tensor *filter,
 #endif
 }

-void DepthwiseConv3x3s2p0(const Tensor *input, const Tensor *filter,
-                          Tensor *output, Tensor bias, bool if_bias) {
+void DepthwiseConv3x3s2p0(const framework::Tensor *input,
+                          const framework::Tensor *filter,
+                          framework::Tensor *output, framework::Tensor bias,
+                          bool if_bias) {
 #if __ARM_NEON
   const int batch_size = static_cast<int>(input->dims()[0]);
src/operators/math/depthwise_conv3x3.h (new file, mode 100644)

/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once

#include <algorithm>
#include <vector>
#include "framework/tensor.h"
#include "operators/math/conv_func.h"

namespace paddle_mobile {
namespace operators {
namespace math {

void DepthwiseConv3x3(const framework::Tensor *input,
                      const std::vector<int> &strides,
                      const std::vector<int> &paddings,
                      const framework::Tensor *filter, framework::Tensor *bias,
                      framework::Tensor *output, bool if_bias);

void DepthwiseConv3x3s1p1(const framework::Tensor *input,
                          const framework::Tensor *filter,
                          framework::Tensor *output, framework::Tensor *bias,
                          bool if_bias);

void DepthwiseConvAddBNRelu3x3s1p1(const framework::Tensor *input,
                                   const framework::Tensor *filter,
                                   framework::Tensor *output,
                                   const framework::Tensor *new_scale,
                                   const framework::Tensor *new_bias,
                                   bool if_relu);

void DepthwiseConvAddBNRelu3x3s2p1(const framework::Tensor *input,
                                   const framework::Tensor *filter,
                                   framework::Tensor *output,
                                   const framework::Tensor *new_scale,
                                   const framework::Tensor *new_bias,
                                   bool if_relu);

void DepthwiseConv3x3s2p1v2(const framework::Tensor *input,
                            const framework::Tensor *filter,
                            framework::Tensor *output, framework::Tensor bias,
                            bool if_bias);

void DepthwiseConvAddBNRelu3x3s2p1v2(const framework::Tensor *input,
                                     const framework::Tensor *filter,
                                     framework::Tensor *output,
                                     const framework::Tensor *new_scale,
                                     const framework::Tensor *new_bias,
                                     bool if_relu);

void DepthwiseConv3x3s2p0(const framework::Tensor *input,
                          const framework::Tensor *filter,
                          framework::Tensor *output, framework::Tensor bias,
                          bool if_bias);

// TODO(hjchen2) need to be implemented
// template<typename Itype, typename Otype>
// void DepthwiseConv3x3(const framework::Tensor *input,
//                       const framework::Tensor *filter,
//                       const std::vector<int> &strides,
//                       framework::Tensor *output);

template <typename Itype, typename Otype>
void DepthwiseConv3x3s1(const framework::Tensor &input,
                        const framework::Tensor &filter,
                        framework::Tensor *output);

template <typename Itype, typename Otype>
void DepthwiseConv3x3s2(const framework::Tensor &input,
                        const framework::Tensor &filter,
                        framework::Tensor *output);

}  // namespace math
}  // namespace operators
}  // namespace paddle_mobile
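A minimal usage sketch for the new templated int8 entry points (the shapes and make_ddim setup here are illustrative, not from the patch; the kernels assume batch size 1 and a pre-padded input):

framework::Tensor input, filter, output;
input.mutable_data<int8_t>(framework::make_ddim({1, 32, 64, 64}));
filter.mutable_data<int8_t>(framework::make_ddim({32, 1, 3, 3}));
// stride 1, no extra padding: output spatial size is 64 - 3 + 1 = 62
output.mutable_data<int32_t>(framework::make_ddim({1, 32, 62, 62}));
math::DepthwiseConv3x3s1<int8_t, int32_t>(input, filter, &output);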
src/operators/math/depthwise_conv3x3_int8.cpp (new file, mode 100644)

/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "operators/math/depthwise_conv3x3.h"

namespace paddle_mobile {
namespace operators {
namespace math {

// template<>
// void DepthwiseConv3x3<int8_t, int32_t>(
//     const framework::Tensor *input, const framework::Tensor *filter,
//     const std::vector<int> &strides, framework::Tensor *output) {
//   PADDLE_MOBILE_THROW_EXCEPTION(
//       "Depthwise conv with generic strides has not been implemented.");
// }

template <>
void DepthwiseConv3x3s1<int8_t, int32_t>(const framework::Tensor &input,
                                         const framework::Tensor &filter,
                                         framework::Tensor *output) {
  const int8_t *input_data = input.data<int8_t>();
  const int8_t *filter_data = filter.data<int8_t>();
  int32_t *out_data = output->mutable_data<int32_t>();
  // make sure that batch size is 1
  int input_c = input.dims()[1];
  int input_h = input.dims()[2];
  int input_w = input.dims()[3];
  int output_c = output->dims()[1];
  int output_h = output->dims()[2];
  int output_w = output->dims()[3];
  int image_size = input_h * input_w;
  int out_image_size = output_h * output_w;
#if __aarch64__
  // TODO(hjchen2)
#else
#pragma omp parallel for
  for (int g = 0; g < input_c; ++g) {
    const int8_t *input_ptr = input_data + g * image_size;
    const int8_t *filter_ptr = filter_data + g * 9;
    int32_t *output_ptr = out_data + g * out_image_size;
    int loops = (input_w - 2) / 6;
    int remain = input_w - 2 - loops * 6;
    for (int h = 0; h < input_h - 5 /*(input_h - 2) - 3*/; h += 4) {
      const int8_t *input_ptr0 = input_ptr + h * input_w;
      const int8_t *input_ptr1 = input_ptr0 + input_w;
      const int8_t *input_ptr2 = input_ptr1 + input_w;
      const int8_t *input_ptr3 = input_ptr2 + input_w;
      const int8_t *input_ptr4 = input_ptr3 + input_w;
      const int8_t *input_ptr5 = input_ptr4 + input_w;
      int32_t *output_ptr0 = output_ptr + h * output_w;
      int32_t *output_ptr1 = output_ptr0 + output_w;
      int32_t *output_ptr2 = output_ptr1 + output_w;
      int32_t *output_ptr3 = output_ptr2 + output_w;
      int loop = loops;
      // broadcast the nine int8 filter taps, widened to int16, into d0-d8
      asm volatile(
          "vld1.32 {q0}, [%[filter_ptr]]\n"
          "vmovl.s8 q14, d0\n"
          "vmovl.s8 q15, d1\n"
          "vdup.s16 d0, d28[0]\n"
          "vdup.s16 d1, d28[1]\n"
          "vdup.s16 d2, d28[2]\n"
          "vdup.s16 d3, d28[3]\n"
          "vdup.s16 d4, d29[0]\n"
          "vdup.s16 d5, d29[1]\n"
          "vdup.s16 d6, d29[2]\n"
          "vdup.s16 d7, d29[3]\n"
          "vdup.s16 d8, d30[0]\n"
          :
          : [filter_ptr] "r"(filter_ptr)
          : "cc", "memory", "q0", "q1", "q2", "q3", "q4", "q14", "q15");
      asm volatile(
          "mov r0, #6\n"
          "cmp %[loop], #0\n"
          "ble start_remain_%=\n"
          // loop 6 widths
          "loop_4h6w_%=:\n"
          "vld1.32 {d9}, [%[input_ptr0]], r0\n"
          "vld1.32 {d10}, [%[input_ptr1]], r0\n"
          "vld1.32 {d11}, [%[input_ptr2]], r0\n"
          "vext.s8 d12, d9, d9, #1\n"
          "vext.s8 d13, d9, d9, #2\n"
          "vmovl.s8 q7, d9\n"
          "vmovl.s8 q8, d12\n"
          "vmovl.s8 q9, d13\n"
          "vmull.s16 q10, d14, d0\n"
          "vmlal.s16 q10, d16, d1\n"
          "vmlal.s16 q10, d18, d2\n"
          "vmull.s16 q11, d15, d0\n"
          "vmlal.s16 q11, d17, d1\n"
          "vmlal.s16 q11, d19, d2\n"
          "vext.s8 d12, d10, d10, #1\n"
          "vext.s8 d13, d10, d10, #2\n"
          "vmovl.s8 q7, d10\n"
          "vmovl.s8 q8, d12\n"
          "vmovl.s8 q9, d13\n"
          "vmlal.s16 q10, d14, d3\n"
          "vmlal.s16 q10, d16, d4\n"
          "vmlal.s16 q10, d18, d5\n"
          "vmlal.s16 q11, d15, d3\n"
          "vmlal.s16 q11, d17, d4\n"
          "vmlal.s16 q11, d19, d5\n"
          "vmull.s16 q12, d14, d0\n"
          "vmlal.s16 q12, d16, d1\n"
          "vmlal.s16 q12, d18, d2\n"
          "vmull.s16 q13, d15, d0\n"
          "vmlal.s16 q13, d17, d1\n"
          "vmlal.s16 q13, d19, d2\n"
          "vext.s8 d12, d11, d11, #1\n"
          "vext.s8 d13, d11, d11, #2\n"
          "vmovl.s8 q7, d11\n"
          "vmovl.s8 q8, d12\n"
          "vmovl.s8 q9, d13\n"
          "vmlal.s16 q10, d14, d6\n"
          "vmlal.s16 q10, d16, d7\n"
          "vmlal.s16 q10, d18, d8\n"
          "vmlal.s16 q11, d15, d6\n"
          "vmlal.s16 q11, d17, d7\n"
          "vmlal.s16 q11, d19, d8\n"
          // store row 0, reuse q10/q11
          "vst1.32 {d20-d22}, [%[output_ptr0]]!\n"
          "vmlal.s16 q12, d14, d3\n"
          "vmlal.s16 q12, d16, d4\n"
          "vmlal.s16 q12, d18, d5\n"
          "vmlal.s16 q13, d15, d3\n"
          "vmlal.s16 q13, d17, d4\n"
          "vmlal.s16 q13, d19, d5\n"
          "vmull.s16 q14, d14, d0\n"
          "vmlal.s16 q14, d16, d1\n"
          "vmlal.s16 q14, d18, d2\n"
          "vmull.s16 q15, d15, d0\n"
          "vmlal.s16 q15, d17, d1\n"
          "vmlal.s16 q15, d19, d2\n"
          "vld1.32 {d9}, [%[input_ptr3]], r0\n"
          "vld1.32 {d10}, [%[input_ptr4]], r0\n"
          "vld1.32 {d11}, [%[input_ptr5]], r0\n"
          "vext.s8 d12, d9, d9, #1\n"
          "vext.s8 d13, d9, d9, #2\n"
          "vmovl.s8 q7, d9\n"
          "vmovl.s8 q8, d12\n"
          "vmovl.s8 q9, d13\n"
          "vmlal.s16 q12, d14, d6\n"
          "vmlal.s16 q12, d16, d7\n"
          "vmlal.s16 q12, d18, d8\n"
          "vmlal.s16 q13, d15, d6\n"
          "vmlal.s16 q13, d17, d7\n"
          "vmlal.s16 q13, d19, d8\n"
          // store row 1
          "vst1.32 {d24-d26}, [%[output_ptr1]]!\n"
          "vmlal.s16 q14, d14, d3\n"
          "vmlal.s16 q14, d16, d4\n"
          "vmlal.s16 q14, d18, d5\n"
          "vmlal.s16 q15, d15, d3\n"
          "vmlal.s16 q15, d17, d4\n"
          "vmlal.s16 q15, d19, d5\n"
          "vmull.s16 q10, d14, d0\n"
          "vmlal.s16 q10, d16, d1\n"
          "vmlal.s16 q10, d18, d2\n"
          "vmull.s16 q11, d15, d0\n"
          "vmlal.s16 q11, d17, d1\n"
          "vmlal.s16 q11, d19, d2\n"
          "vext.s8 d12, d10, d10, #1\n"
          "vext.s8 d13, d10, d10, #2\n"
          "vmovl.s8 q7, d10\n"
          "vmovl.s8 q8, d12\n"
          "vmovl.s8 q9, d13\n"
          "vmlal.s16 q14, d14, d6\n"
          "vmlal.s16 q14, d16, d7\n"
          "vmlal.s16 q14, d18, d8\n"
          "vmlal.s16 q15, d15, d6\n"
          "vmlal.s16 q15, d17, d7\n"
          "vmlal.s16 q15, d19, d8\n"
          // store row 2
          "vst1.32 {d28-d30}, [%[output_ptr2]]!\n"
          "vmlal.s16 q10, d14, d3\n"
          "vmlal.s16 q10, d16, d4\n"
          "vmlal.s16 q10, d18, d5\n"
          "vmlal.s16 q11, d15, d3\n"
          "vmlal.s16 q11, d17, d4\n"
          "vmlal.s16 q11, d19, d5\n"
          "vext.s8 d12, d11, d11, #1\n"
          "vext.s8 d13, d11, d11, #2\n"
          "vmovl.s8 q7, d11\n"
          "vmovl.s8 q8, d12\n"
          "vmovl.s8 q9, d13\n"
          "vmlal.s16 q10, d14, d6\n"
          "vmlal.s16 q10, d16, d7\n"
          "vmlal.s16 q10, d18, d8\n"
          "vmlal.s16 q11, d15, d6\n"
          "vmlal.s16 q11, d17, d7\n"
          "vmlal.s16 q11, d19, d8\n"
          // store row 3
          "vst1.32 {d20-d22}, [%[output_ptr3]]!\n"
          "subs %[loop], #1\n"
          "bne loop_4h6w_%=\n"

          "start_remain_%=:\n"
          "cmp %[remain], #0\n"
          "ble end_%=\n"
          "vld1.32 {d9}, [%[input_ptr0]]\n"
          "vmovl.s8 q7, d9\n"
          "vext.s8 d9, d9, d9, #1\n"
          "vmovl.s8 q8, d9\n"
          "vext.s8 d9, d9, d9, #1\n"
          "vmovl.s8 q9, d9\n"
          "vmull.s16 q10, d14, d0\n"
          "vmlal.s16 q10, d16, d1\n"
          "vmlal.s16 q10, d18, d2\n"
          "vld1.32 {d9}, [%[input_ptr1]]\n"
          "vmull.s16 q11, d15, d0\n"
          "vmlal.s16 q11, d17, d1\n"
          "vmlal.s16 q11, d19, d2\n"
          "vmovl.s8 q7, d9\n"
          "vext.s8 d9, d9, d9, #1\n"
          "vmovl.s8 q8, d9\n"
          "vext.s8 d9, d9, d9, #1\n"
          "vmovl.s8 q9, d9\n"
          "vmlal.s16 q10, d14, d3\n"
          "vmlal.s16 q10, d16, d4\n"
          "vmlal.s16 q10, d18, d5\n"
          "vmlal.s16 q11, d15, d3\n"
          "vmlal.s16 q11, d17, d4\n"
          "vmlal.s16 q11, d19, d5\n"
          "vmull.s16 q12, d14, d0\n"
          "vmlal.s16 q12, d16, d1\n"
          "vmlal.s16 q12, d18, d2\n"
          "vld1.32 {d9}, [%[input_ptr2]]\n"
          "vmull.s16 q13, d15, d0\n"
          "vmlal.s16 q13, d17, d1\n"
          "vmlal.s16 q13, d19, d2\n"
          "vmovl.s8 q7, d9\n"
          "vext.s8 d9, d9, d9, #1\n"
          "vmovl.s8 q8, d9\n"
          "vext.s8 d9, d9, d9, #1\n"
          "vmovl.s8 q9, d9\n"
          "vmlal.s16 q10, d14, d6\n"
          "vmlal.s16 q10, d16, d7\n"
          "vmlal.s16 q10, d18, d8\n"
          "vmlal.s16 q11, d15, d6\n"
          "vmlal.s16 q11, d17, d7\n"
          "vmlal.s16 q11, d19, d8\n"
          "vmlal.s16 q12, d14, d3\n"
          "vmlal.s16 q12, d16, d4\n"
          "vmlal.s16 q12, d18, d5\n"
          "vmlal.s16 q13, d15, d3\n"
          "vmlal.s16 q13, d17, d4\n"
          "vmlal.s16 q13, d19, d5\n"
          "vmull.s16 q14, d14, d0\n"
          "vmlal.s16 q14, d16, d1\n"
          "vmlal.s16 q14, d18, d2\n"
          "vld1.32 {d9}, [%[input_ptr3]]\n"
          "vmull.s16 q15, d15, d0\n"
          "vmlal.s16 q15, d17, d1\n"
          "vmlal.s16 q15, d19, d2\n"
          "vmovl.s8 q7, d9\n"
          "vext.s8 d9, d9, d9, #1\n"
          "vmovl.s8 q8, d9\n"
          "vext.s8 d9, d9, d9, #1\n"
          "vmovl.s8 q9, d9\n"
          "vmlal.s16 q12, d14, d6\n"
          "vmlal.s16 q12, d16, d7\n"
          "vmlal.s16 q12, d18, d8\n"
          "vmlal.s16 q13, d15, d6\n"
          "vmlal.s16 q13, d17, d7\n"
          "vmlal.s16 q13, d19, d8\n"
          "vmlal.s16 q14, d14, d3\n"
          "vmlal.s16 q14, d16, d4\n"
          "vmlal.s16 q14, d18, d5\n"
          "vmlal.s16 q15, d15, d3\n"
          "vmlal.s16 q15, d17, d4\n"
          "vmlal.s16 q15, d19, d5\n"
          "vmull.s16 q5, d14, d0\n"
          "vmlal.s16 q5, d16, d1\n"
          "vmlal.s16 q5, d18, d2\n"
          "vld1.32 {d9}, [%[input_ptr4]]\n"
          "vmull.s16 q6, d15, d0\n"
          "vmlal.s16 q6, d17, d1\n"
          "vmlal.s16 q6, d19, d2\n"
          "vmovl.s8 q7, d9\n"
          "vext.s8 d9, d9, d9, #1\n"
          "vmovl.s8 q8, d9\n"
          "vext.s8 d9, d9, d9, #1\n"
          "vmovl.s8 q9, d9\n"
          "vmlal.s16 q14, d14, d6\n"
          "vmlal.s16 q14, d16, d7\n"
          "vmlal.s16 q14, d18, d8\n"
          "vmlal.s16 q15, d15, d6\n"
          "vmlal.s16 q15, d17, d7\n"
          "vmlal.s16 q15, d19, d8\n"
          "vmlal.s16 q5, d14, d3\n"
          "vmlal.s16 q5, d16, d4\n"
          "vmlal.s16 q5, d18, d5\n"
          "vld1.32 {d9}, [%[input_ptr5]]\n"
          "vmlal.s16 q6, d15, d3\n"
          "vmlal.s16 q6, d17, d4\n"
          "vmlal.s16 q6, d19, d5\n"
          "vmovl.s8 q7, d9\n"
          "vext.s8 d9, d9, d9, #1\n"
          "vmovl.s8 q8, d9\n"
          "vext.s8 d9, d9, d9, #1\n"
          "vmovl.s8 q9, d9\n"
          "vmlal.s16 q5, d14, d6\n"
          "vmlal.s16 q5, d16, d7\n"
          "vmlal.s16 q5, d18, d8\n"
          "vmlal.s16 q6, d15, d6\n"
          "vmlal.s16 q6, d17, d7\n"
          "vmlal.s16 q6, d19, d8\n"
          "cmp %[remain], #4\n"
          "blt store_4h2w_%=\n"
          "vst1.32 {q10}, [%[output_ptr0]]!\n"
          "vst1.32 {q12}, [%[output_ptr1]]!\n"
          "vst1.32 {q14}, [%[output_ptr2]]!\n"
          "vst1.32 {q5}, [%[output_ptr3]]!\n"
          "cmp %[remain], #5\n"
          "blt end_%=\n"
          "vst1.32 {d22[0]}, [%[output_ptr0]]!\n"
          "vst1.32 {d26[0]}, [%[output_ptr1]]!\n"
          "vst1.32 {d30[0]}, [%[output_ptr2]]!\n"
          "vst1.32 {d12[0]}, [%[output_ptr3]]!\n"
          "b end_%=\n"
          "store_4h2w_%=:\n"
          "cmp %[remain], #2\n"
          "blt store_4h1w_%=\n"
          "vst1.32 {d20}, [%[output_ptr0]]!\n"
          "vst1.32 {d24}, [%[output_ptr1]]!\n"
          "vst1.32 {d28}, [%[output_ptr2]]!\n"
          "vst1.32 {d10}, [%[output_ptr3]]!\n"
          "cmp %[remain], #3\n"
          "blt end_%=\n"
          "vst1.32 {d21[0]}, [%[output_ptr0]]!\n"
          "vst1.32 {d25[0]}, [%[output_ptr1]]!\n"
          "vst1.32 {d29[0]}, [%[output_ptr2]]!\n"
          "vst1.32 {d11[0]}, [%[output_ptr3]]!\n"
          "b end_%=\n"
          "store_4h1w_%=:\n"
          "cmp %[remain], #1\n"
          "blt end_%=\n"
          "vst1.32 {d20[0]}, [%[output_ptr0]]!\n"
          "vst1.32 {d24[0]}, [%[output_ptr1]]!\n"
          "vst1.32 {d28[0]}, [%[output_ptr2]]!\n"
          "vst1.32 {d10[0]}, [%[output_ptr3]]!\n"
          "end_%=:\n"
          : [output_ptr0] "+r"(output_ptr0), [output_ptr1] "+r"(output_ptr1),
            [output_ptr2] "+r"(output_ptr2), [output_ptr3] "+r"(output_ptr3),
            [input_ptr0] "+r"(input_ptr0), [input_ptr1] "+r"(input_ptr1),
            [input_ptr2] "+r"(input_ptr2), [input_ptr3] "+r"(input_ptr3),
            [input_ptr4] "+r"(input_ptr4), [input_ptr5] "+r"(input_ptr5),
            [loop] "+r"(loop)
          : [remain] "r"(remain)
          : "cc", "memory", "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7",
            "q8", "q9", "q10", "q11", "q12", "q13", "q14", "q15", "r0");
    }
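    // Note on the main loop above: each iteration widens six int8 pixels per
    // input row to int16 (vmovl.s8), shifts them by one and two columns
    // (vext.s8), and multiply-accumulates against the pre-broadcast taps
    // d0-d8 into int32 accumulators, yielding a 4-row by 6-column output
    // tile; rows are pipelined so each loaded input row feeds the up to
    // three output rows that overlap it.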
    // remain height
    int start_h = (input_h - 2) & 0xFFFC;
    for (int h = start_h; h < input_h - 3 /*(input_h - 2) - 1*/; h += 2) {
      const int8_t *input_ptr0 = input_ptr + h * input_w;
      const int8_t *input_ptr1 = input_ptr0 + input_w;
      const int8_t *input_ptr2 = input_ptr1 + input_w;
      const int8_t *input_ptr3 = input_ptr2 + input_w;
      int32_t *output_ptr0 = output_ptr + h * output_w;
      int32_t *output_ptr1 = output_ptr0 + output_w;
      int loop = loops;
      asm volatile(
          "vld1.32 {q0}, [%[filter_ptr]]\n"
          "vmovl.s8 q14, d0\n"
          "vmovl.s8 q15, d1\n"
          "vdup.s16 d0, d28[0]\n"
          "vdup.s16 d1, d28[1]\n"
          "vdup.s16 d2, d28[2]\n"
          "vdup.s16 d3, d28[3]\n"
          "vdup.s16 d4, d29[0]\n"
          "vdup.s16 d5, d29[1]\n"
          "vdup.s16 d6, d29[2]\n"
          "vdup.s16 d7, d29[3]\n"
          "vdup.s16 d8, d30[0]\n"
          :
          : [filter_ptr] "r"(filter_ptr)
          : "cc", "memory", "q0", "q1", "q2", "q3", "q4", "q14", "q15");
      asm volatile(
          "mov r0, #6\n"
          "cmp %[loop], #0\n"
          "ble start_remain_%=\n"
          // loop 6 widths
          "loop_2h6w_%=:\n"
          "vld1.32 {d9}, [%[input_ptr0]], r0\n"
          "vld1.32 {d10}, [%[input_ptr1]], r0\n"
          "vld1.32 {d11}, [%[input_ptr2]], r0\n"
          "vext.s8 d12, d9, d9, #1\n"
          "vext.s8 d13, d9, d9, #2\n"
          "vmovl.s8 q7, d9\n"
          "vmovl.s8 q8, d12\n"
          "vmovl.s8 q9, d13\n"
          "vmull.s16 q10, d14, d0\n"
          "vmlal.s16 q10, d16, d1\n"
          "vmlal.s16 q10, d18, d2\n"
          "vmull.s16 q11, d15, d0\n"
          "vmlal.s16 q11, d17, d1\n"
          "vmlal.s16 q11, d19, d2\n"
          "vext.s8 d12, d10, d10, #1\n"
          "vext.s8 d13, d10, d10, #2\n"
          "vmovl.s8 q7, d10\n"
          "vmovl.s8 q8, d12\n"
          "vmovl.s8 q9, d13\n"
          "vmlal.s16 q10, d14, d3\n"
          "vmlal.s16 q10, d16, d4\n"
          "vmlal.s16 q10, d18, d5\n"
          "vmlal.s16 q11, d15, d3\n"
          "vmlal.s16 q11, d17, d4\n"
          "vmlal.s16 q11, d19, d5\n"
          "vmull.s16 q12, d14, d0\n"
          "vmlal.s16 q12, d16, d1\n"
          "vmlal.s16 q12, d18, d2\n"
          "vmull.s16 q13, d15, d0\n"
          "vmlal.s16 q13, d17, d1\n"
          "vmlal.s16 q13, d19, d2\n"
          "vext.s8 d12, d11, d11, #1\n"
          "vext.s8 d13, d11, d11, #2\n"
          "vmovl.s8 q7, d11\n"
          "vmovl.s8 q8, d12\n"
          "vmovl.s8 q9, d13\n"
          "vmlal.s16 q10, d14, d6\n"
          "vmlal.s16 q10, d16, d7\n"
          "vmlal.s16 q10, d18, d8\n"
          "vmlal.s16 q11, d15, d6\n"
          "vmlal.s16 q11, d17, d7\n"
          "vmlal.s16 q11, d19, d8\n"
          // store row 0, reuse q10/q11
          "vst1.32 {d20-d22}, [%[output_ptr0]]!\n"
          "vmlal.s16 q12, d14, d3\n"
          "vmlal.s16 q12, d16, d4\n"
          "vmlal.s16 q12, d18, d5\n"
          "vmlal.s16 q13, d15, d3\n"
          "vmlal.s16 q13, d17, d4\n"
          "vmlal.s16 q13, d19, d5\n"
          "vld1.32 {d9}, [%[input_ptr3]], r0\n"
          "vext.s8 d12, d9, d9, #1\n"
          "vext.s8 d13, d9, d9, #2\n"
          "vmovl.s8 q7, d9\n"
          "vmovl.s8 q8, d12\n"
          "vmovl.s8 q9, d13\n"
          "vmlal.s16 q12, d14, d6\n"
          "vmlal.s16 q12, d16, d7\n"
          "vmlal.s16 q12, d18, d8\n"
          "vmlal.s16 q13, d15, d6\n"
          "vmlal.s16 q13, d17, d7\n"
          "vmlal.s16 q13, d19, d8\n"
          // store row 1
          "vst1.32 {d24-d26}, [%[output_ptr1]]!\n"
          "subs %[loop], #1\n"
          "bne loop_2h6w_%=\n"

          "start_remain_%=:\n"
          "cmp %[remain], #0\n"
          "ble end_%=\n"
          "vld1.32 {d9}, [%[input_ptr0]]\n"
          "vld1.32 {d10}, [%[input_ptr1]]\n"
          "vld1.32 {d11}, [%[input_ptr2]]\n"
          "vext.s8 d12, d9, d9, #1\n"
          "vext.s8 d13, d9, d9, #2\n"
          "vmovl.s8 q7, d9\n"
          "vmovl.s8 q8, d12\n"
          "vmovl.s8 q9, d13\n"
          "vmull.s16 q10, d14, d0\n"
          "vmlal.s16 q10, d16, d1\n"
          "vmlal.s16 q10, d18, d2\n"
          "vmull.s16 q11, d15, d0\n"
          "vmlal.s16 q11, d17, d1\n"
          "vmlal.s16 q11, d19, d2\n"
          "vext.s8 d12, d10, d10, #1\n"
          "vext.s8 d13, d10, d10, #2\n"
          "vmovl.s8 q7, d10\n"
          "vmovl.s8 q8, d12\n"
          "vmovl.s8 q9, d13\n"
          "vmlal.s16 q10, d14, d3\n"
          "vmlal.s16 q10, d16, d4\n"
          "vmlal.s16 q10, d18, d5\n"
          "vmlal.s16 q11, d15, d3\n"
          "vmlal.s16 q11, d17, d4\n"
          "vmlal.s16 q11, d19, d5\n"
          "vmull.s16 q12, d14, d0\n"
          "vmlal.s16 q12, d16, d1\n"
          "vmlal.s16 q12, d18, d2\n"
          "vmull.s16 q13, d15, d0\n"
          "vmlal.s16 q13, d17, d1\n"
          "vmlal.s16 q13, d19, d2\n"
          "vext.s8 d12, d11, d11, #1\n"
          "vext.s8 d13, d11, d11, #2\n"
          "vmovl.s8 q7, d11\n"
          "vmovl.s8 q8, d12\n"
          "vmovl.s8 q9, d13\n"
          "vmlal.s16 q10, d14, d6\n"
          "vmlal.s16 q10, d16, d7\n"
          "vmlal.s16 q10, d18, d8\n"
          "vmlal.s16 q11, d15, d6\n"
          "vmlal.s16 q11, d17, d7\n"
          "vmlal.s16 q11, d19, d8\n"
          "vmlal.s16 q12, d14, d3\n"
          "vmlal.s16 q12, d16, d4\n"
          "vmlal.s16 q12, d18, d5\n"
          "vmlal.s16 q13, d15, d3\n"
          "vmlal.s16 q13, d17, d4\n"
          "vmlal.s16 q13, d19, d5\n"
          "vld1.32 {d9}, [%[input_ptr3]]\n"
          "vext.s8 d12, d9, d9, #1\n"
          "vext.s8 d13, d9, d9, #2\n"
          "vmovl.s8 q7, d9\n"
          "vmovl.s8 q8, d12\n"
          "vmovl.s8 q9, d13\n"
          "vmlal.s16 q12, d14, d6\n"
          "vmlal.s16 q12, d16, d7\n"
          "vmlal.s16 q12, d18, d8\n"
          "vmlal.s16 q13, d15, d6\n"
          "vmlal.s16 q13, d17, d7\n"
          "vmlal.s16 q13, d19, d8\n"
          "cmp %[remain], #4\n"
          "blt store_2h2w_%=\n"
          "vst1.32 {q10}, [%[output_ptr0]]!\n"
          "vst1.32 {q12}, [%[output_ptr1]]!\n"
          "cmp %[remain], #5\n"
          "blt end_%=\n"
          "vst1.32 {d22[0]}, [%[output_ptr0]]!\n"
          "vst1.32 {d26[0]}, [%[output_ptr1]]!\n"
          "b end_%=\n"
          "store_2h2w_%=:\n"
          "cmp %[remain], #2\n"
          "blt store_2h1w_%=\n"
          "vst1.32 {d20}, [%[output_ptr0]]!\n"
          "vst1.32 {d24}, [%[output_ptr1]]!\n"
          "cmp %[remain], #3\n"
          "blt end_%=\n"
          "vst1.32 {d21[0]}, [%[output_ptr0]]!\n"
          "vst1.32 {d25[0]}, [%[output_ptr1]]!\n"
          "b end_%=\n"
          "store_2h1w_%=:\n"
          "cmp %[remain], #1\n"
          "blt end_%=\n"
          "vst1.32 {d20[0]}, [%[output_ptr0]]!\n"
          "vst1.32 {d24[0]}, [%[output_ptr1]]!\n"
          "end_%=:\n"
          : [output_ptr0] "+r"(output_ptr0), [output_ptr1] "+r"(output_ptr1),
            [input_ptr0] "+r"(input_ptr0), [input_ptr1] "+r"(input_ptr1),
            [input_ptr2] "+r"(input_ptr2), [input_ptr3] "+r"(input_ptr3),
            [loop] "+r"(loop)
          : [remain] "r"(remain)
          : "cc", "memory", "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7",
            "q8", "q9", "q10", "q11", "q12", "q13", "r0");
    }
    start_h = (input_h - 2) & 0xFFFE;
    if (start_h < input_h - 2) {
      const int8_t *input_ptr0 = input_ptr + start_h * input_w;
      const int8_t *input_ptr1 = input_ptr0 + input_w;
      const int8_t *input_ptr2 = input_ptr1 + input_w;
      int32_t *output_ptr0 = output_ptr + start_h * output_w;
      int loop = loops;
      asm volatile(
          "vld1.32 {q0}, [%[filter_ptr]]\n"
          "vmovl.s8 q14, d0\n"
          "vmovl.s8 q15, d1\n"
          "vdup.s16 d0, d28[0]\n"
          "vdup.s16 d1, d28[1]\n"
          "vdup.s16 d2, d28[2]\n"
          "vdup.s16 d3, d28[3]\n"
          "vdup.s16 d4, d29[0]\n"
          "vdup.s16 d5, d29[1]\n"
          "vdup.s16 d6, d29[2]\n"
          "vdup.s16 d7, d29[3]\n"
          "vdup.s16 d8, d30[0]\n"
          :
          : [filter_ptr] "r"(filter_ptr)
          : "memory", "q0", "q1", "q2", "q3", "q4", "q14", "q15");
      asm volatile(
          "mov r0, #6\n"
          "cmp %[loop], #0\n"
          "ble start_remain_%=\n"
          // loop 6 widths
          "loop_1h6w_%=:\n"
          "vld1.32 {d9}, [%[input_ptr0]], r0\n"
          "vld1.32 {d10}, [%[input_ptr1]], r0\n"
          "vld1.32 {d11}, [%[input_ptr2]], r0\n"
          "vext.s8 d12, d9, d9, #1\n"
          "vext.s8 d13, d9, d9, #2\n"
          "vmovl.s8 q7, d9\n"
          "vmovl.s8 q8, d12\n"
          "vmovl.s8 q9, d13\n"
          "vmull.s16 q10, d14, d0\n"
          "vmlal.s16 q10, d16, d1\n"
          "vmlal.s16 q10, d18, d2\n"
          "vmull.s16 q11, d15, d0\n"
          "vmlal.s16 q11, d17, d1\n"
          "vmlal.s16 q11, d19, d2\n"
          "vext.s8 d12, d10, d10, #1\n"
          "vext.s8 d13, d10, d10, #2\n"
          "vmovl.s8 q7, d10\n"
          "vmovl.s8 q8, d12\n"
          "vmovl.s8 q9, d13\n"
          "vmlal.s16 q10, d14, d3\n"
          "vmlal.s16 q10, d16, d4\n"
          "vmlal.s16 q10, d18, d5\n"
          "vmlal.s16 q11, d15, d3\n"
          "vmlal.s16 q11, d17, d4\n"
          "vmlal.s16 q11, d19, d5\n"
          "vext.s8 d12, d11, d11, #1\n"
          "vext.s8 d13, d11, d11, #2\n"
          "vmovl.s8 q7, d11\n"
          "vmovl.s8 q8, d12\n"
          "vmovl.s8 q9, d13\n"
          "vmlal.s16 q10, d14, d6\n"
          "vmlal.s16 q10, d16, d7\n"
          "vmlal.s16 q10, d18, d8\n"
          "vmlal.s16 q11, d15, d6\n"
          "vmlal.s16 q11, d17, d7\n"
          "vmlal.s16 q11, d19, d8\n"
          // store row 0, reuse q10/q11
          "vst1.32 {d20-d22}, [%[output_ptr0]]!\n"
          "subs %[loop], #1\n"
          "bne loop_1h6w_%=\n"

          "start_remain_%=:\n"
          "cmp %[remain], #0\n"
          "ble end_%=\n"
          "vld1.32 {d9}, [%[input_ptr0]]\n"
          "vld1.32 {d10}, [%[input_ptr1]]\n"
          "vld1.32 {d11}, [%[input_ptr2]]\n"
          "vext.s8 d12, d9, d9, #1\n"
          "vext.s8 d13, d9, d9, #2\n"
          "vmovl.s8 q7, d9\n"
          "vmovl.s8 q8, d12\n"
          "vmovl.s8 q9, d13\n"
          "vmull.s16 q10, d14, d0\n"
          "vmlal.s16 q10, d16, d1\n"
          "vmlal.s16 q10, d18, d2\n"
          "vmull.s16 q11, d15, d0\n"
          "vmlal.s16 q11, d17, d1\n"
          "vmlal.s16 q11, d19, d2\n"
          "vext.s8 d12, d10, d10, #1\n"
          "vext.s8 d13, d10, d10, #2\n"
          "vmovl.s8 q7, d10\n"
          "vmovl.s8 q8, d12\n"
          "vmovl.s8 q9, d13\n"
          "vmlal.s16 q10, d14, d3\n"
          "vmlal.s16 q10, d16, d4\n"
          "vmlal.s16 q10, d18, d5\n"
          "vmlal.s16 q11, d15, d3\n"
          "vmlal.s16 q11, d17, d4\n"
          "vmlal.s16 q11, d19, d5\n"
          "vext.s8 d12, d11, d11, #1\n"
          "vext.s8 d13, d11, d11, #2\n"
          "vmovl.s8 q7, d11\n"
          "vmovl.s8 q8, d12\n"
          "vmovl.s8 q9, d13\n"
          "vmlal.s16 q10, d14, d6\n"
          "vmlal.s16 q10, d16, d7\n"
          "vmlal.s16 q10, d18, d8\n"
          "vmlal.s16 q11, d15, d6\n"
          "vmlal.s16 q11, d17, d7\n"
          "vmlal.s16 q11, d19, d8\n"
          "cmp %[remain], #4\n"
          "blt store_1h2w_%=\n"
          "vst1.32 {q10}, [%[output_ptr0]]!\n"
          "cmp %[remain], #5\n"
          "blt end_%=\n"
          "vst1.32 {d22[0]}, [%[output_ptr0]]!\n"
          "b end_%=\n"
          "store_1h2w_%=:\n"
          "cmp %[remain], #2\n"
          "blt store_1h1w_%=\n"
          "vst1.32 {d20}, [%[output_ptr0]]!\n"
          "cmp %[remain], #3\n"
          "blt end_%=\n"
          "vst1.32 {d21[0]}, [%[output_ptr0]]!\n"
          "b end_%=\n"
          "store_1h1w_%=:\n"
          "cmp %[remain], #1\n"
          "blt end_%=\n"
          "vst1.32 {d20[0]}, [%[output_ptr0]]!\n"
          "end_%=:\n"
          : [output_ptr0] "+r"(output_ptr0), [input_ptr0] "+r"(input_ptr0),
            [input_ptr1] "+r"(input_ptr1), [input_ptr2] "+r"(input_ptr2),
            [loop] "+r"(loop)
          : [remain] "r"(remain)
          : "cc", "memory", "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7",
            "q8", "q9", "q10", "q11", "r0");
    }
  }
#endif  // __aarch64__
}
template <>
void DepthwiseConv3x3s2<int8_t, int32_t>(const framework::Tensor &input,
                                         const framework::Tensor &filter,
                                         framework::Tensor *output) {
  const int8_t *input_data = input.data<int8_t>();
  const int8_t *filter_data = filter.data<int8_t>();
  int32_t *out_data = output->mutable_data<int32_t>();
  // make sure that batch size is 1
  int input_c = input.dims()[1];
  int input_h = input.dims()[2];
  int input_w = input.dims()[3];
  int output_c = output->dims()[1];
  int output_h = output->dims()[2];
  int output_w = output->dims()[3];
  int image_size = input_h * input_w;
  int out_image_size = output_h * output_w;
#if __aarch64__
  // TODO(hjchen2)
#else
#pragma omp parallel for
  for (int g = 0; g < input_c; ++g) {
    const int8_t *input_ptr = input_data + g * image_size;
    const int8_t *filter_ptr = filter_data + g * 9;
    int32_t *output_ptr = out_data + g * out_image_size;
    int loops = output_w / 6;
    int remain = output_w - loops * 6;
    for (int h = 0; h < input_h - 6 /*(input_h - 1) - 5*/; h += 6) {
      const int8_t *input_ptr0 = input_ptr + h * input_w;
      const int8_t *input_ptr1 = input_ptr0 + input_w;
      const int8_t *input_ptr2 = input_ptr1 + input_w;
      const int8_t *input_ptr3 = input_ptr2 + input_w;
      const int8_t *input_ptr4 = input_ptr3 + input_w;
      const int8_t *input_ptr5 = input_ptr4 + input_w;
      const int8_t *input_ptr6 = input_ptr5 + input_w;
      int32_t *output_ptr0 = output_ptr + (h >> 1) * output_w;
      int32_t *output_ptr1 = output_ptr0 + output_w;
      int32_t *output_ptr2 = output_ptr1 + output_w;
      int loop = loops;
      // broadcast the nine int8 filter taps, widened to int16, into d0-d8
      asm volatile(
          "vld1.32 {q0}, [%[filter_ptr]]\n"
          "vmovl.s8 q14, d0\n"
          "vmovl.s8 q15, d1\n"
          "vdup.s16 d0, d28[0]\n"
          "vdup.s16 d1, d28[1]\n"
          "vdup.s16 d2, d28[2]\n"
          "vdup.s16 d3, d28[3]\n"
          "vdup.s16 d4, d29[0]\n"
          "vdup.s16 d5, d29[1]\n"
          "vdup.s16 d6, d29[2]\n"
          "vdup.s16 d7, d29[3]\n"
          "vdup.s16 d8, d30[0]\n"
          :
          : [filter_ptr] "r"(filter_ptr)
          : "memory", "q0", "q1", "q2", "q3", "q4", "q14", "q15");
      asm volatile(
          "mov r0, #12\n"
          "cmp %[loop], #0\n"
          "ble start_remain_%=\n"
          // loop 6 widths
          "loop_3h6w_%=:\n"
          "vld2.8 {d10, d11}, [%[input_ptr0]], r0\n"
          "vld2.8 {d12, d13}, [%[input_ptr1]], r0\n"
          "vld2.8 {d14, d15}, [%[input_ptr2]], r0\n"
          "vext.s8 d9, d10, d10, #1\n"
          "vmovl.s8 q10, d9\n"
          "vmovl.s8 q8, d10\n"
          "vmovl.s8 q9, d11\n"
          "vmull.s16 q11, d16, d0\n"
          "vmlal.s16 q11, d18, d1\n"
          "vmlal.s16 q11, d20, d2\n"
          "vmull.s16 q12, d17, d0\n"
          "vmlal.s16 q12, d19, d1\n"
          "vmlal.s16 q12, d21, d2\n"
          "vext.s8 d9, d12, d12, #1\n"
          "vmovl.s8 q10, d9\n"
          "vmovl.s8 q8, d12\n"
          "vmovl.s8 q9, d13\n"
          "vmlal.s16 q11, d16, d3\n"
          "vmlal.s16 q11, d18, d4\n"
          "vmlal.s16 q11, d20, d5\n"
          "vmlal.s16 q12, d17, d3\n"
          "vmlal.s16 q12, d19, d4\n"
          "vmlal.s16 q12, d21, d5\n"
          "vext.s8 d9, d14, d14, #1\n"
          "vmovl.s8 q10, d9\n"
          "vmovl.s8 q8, d14\n"
          "vmovl.s8 q9, d15\n"
          "vmlal.s16 q11, d16, d6\n"
          "vmlal.s16 q11, d18, d7\n"
          "vmlal.s16 q11, d20, d8\n"
          "vmlal.s16 q12, d17, d6\n"
          "vmlal.s16 q12, d19, d7\n"
          "vmlal.s16 q12, d21, d8\n"
          // store row 0, reuse q11/q12
          "vst1.32 {d22-d24}, [%[output_ptr0]]!\n"
          "vmull.s16 q13, d16, d0\n"
          "vmlal.s16 q13, d18, d1\n"
          "vmlal.s16 q13, d20, d2\n"
          "vmull.s16 q14, d17, d0\n"
          "vmlal.s16 q14, d19, d1\n"
          "vmlal.s16 q14, d21, d2\n"
          "vld2.8 {d10, d11}, [%[input_ptr3]], r0\n"
          "vld2.8 {d12, d13}, [%[input_ptr4]], r0\n"
          "vld2.8 {d14, d15}, [%[input_ptr5]], r0\n"
          "vext.s8 d9, d10, d10, #1\n"
          "vmovl.s8 q10, d9\n"
          "vmovl.s8 q8, d10\n"
          "vmovl.s8 q9, d11\n"
          "vmlal.s16 q13, d16, d3\n"
          "vmlal.s16 q13, d18, d4\n"
          "vmlal.s16 q13, d20, d5\n"
          "vmlal.s16 q14, d17, d3\n"
          "vmlal.s16 q14, d19, d4\n"
          "vmlal.s16 q14, d21, d5\n"
          "vext.s8 d9, d12, d12, #1\n"
          "vmovl.s8 q10, d9\n"
          "vmovl.s8 q8, d12\n"
          "vmovl.s8 q9, d13\n"
          "vmlal.s16 q13, d16, d6\n"
          "vmlal.s16 q13, d18, d7\n"
          "vmlal.s16 q13, d20, d8\n"
          "vmlal.s16 q14, d17, d6\n"
          "vmlal.s16 q14, d19, d7\n"
          "vmlal.s16 q14, d21, d8\n"
          // store row 1
          "vst1.32 {d26-d28}, [%[output_ptr1]]!\n"
          "vmull.s16 q11, d16, d0\n"
          "vmlal.s16 q11, d18, d1\n"
          "vmlal.s16 q11, d20, d2\n"
          "vmull.s16 q12, d17, d0\n"
          "vmlal.s16 q12, d19, d1\n"
          "vmlal.s16 q12, d21, d2\n"
          "vext.s8 d9, d14, d14, #1\n"
          "vmovl.s8 q10, d9\n"
          "vmovl.s8 q8, d14\n"
          "vmovl.s8 q9, d15\n"
          "vmlal.s16 q11, d16, d3\n"
          "vmlal.s16 q11, d18, d4\n"
          "vmlal.s16 q11, d20, d5\n"
          "vmlal.s16 q12, d17, d3\n"
          "vmlal.s16 q12, d19, d4\n"
          "vmlal.s16 q12, d21, d5\n"
          "vld2.8 {d10, d11}, [%[input_ptr6]], r0\n"
          "vext.s8 d9, d10, d10, #1\n"
          "vmovl.s8 q10, d9\n"
          "vmovl.s8 q8, d10\n"
          "vmovl.s8 q9, d11\n"
          "vmlal.s16 q11, d16, d6\n"
          "vmlal.s16 q11, d18, d7\n"
          "vmlal.s16 q11, d20, d8\n"
          "vmlal.s16 q12, d17, d6\n"
          "vmlal.s16 q12, d19, d7\n"
          "vmlal.s16 q12, d21, d8\n"
          // store row 2
          "vst1.32 {d22-d24}, [%[output_ptr2]]!\n"
          "subs %[loop], #1\n"
          "bne loop_3h6w_%=\n"

          "start_remain_%=:\n"
          "cmp %[remain], #0\n"
          "ble end_%=\n"
          "vld2.8 {d10, d11}, [%[input_ptr0]]\n"
          "vld2.8 {d12, d13}, [%[input_ptr1]]\n"
          "vext.s8 d9, d10, d10, #1\n"
          "vmovl.s8 q9, d9\n"
          "vmovl.s8 q7, d10\n"
          "vmovl.s8 q8, d11\n"
          "vmull.s16 q10, d14, d0\n"
          "vmlal.s16 q10, d16, d1\n"
          "vmlal.s16 q10, d18, d2\n"
          "vmull.s16 q11, d15, d0\n"
          "vmlal.s16 q11, d17, d1\n"
          "vmlal.s16 q11, d19, d2\n"
          "vext.s8 d9, d12, d12, #1\n"
          "vmovl.s8 q9, d9\n"
          "vmovl.s8 q7, d12\n"
          "vmovl.s8 q8, d13\n"
          "vmlal.s16 q10, d14, d3\n"
          "vmlal.s16 q10, d16, d4\n"
          "vmlal.s16 q10, d18, d5\n"
          "vmlal.s16 q11, d15, d3\n"
          "vmlal.s16 q11, d17, d4\n"
          "vmlal.s16 q11, d19, d5\n"
          "vld2.8 {d10, d11}, [%[input_ptr2]]\n"
          "vld2.8 {d12, d13}, [%[input_ptr3]]\n"
          "vext.s8 d9, d10, d10, #1\n"
          "vmovl.s8 q9, d9\n"
          "vmovl.s8 q7, d10\n"
          "vmovl.s8 q8, d11\n"
          "vmlal.s16 q10, d14, d6\n"
          "vmlal.s16 q10, d16, d7\n"
          "vmlal.s16 q10, d18, d8\n"
          "vmlal.s16 q11, d15, d6\n"
          "vmlal.s16 q11, d17, d7\n"
          "vmlal.s16 q11, d19, d8\n"
          "vmull.s16 q12, d14, d0\n"
          "vmlal.s16 q12, d16, d1\n"
          "vmlal.s16 q12, d18, d2\n"
          "vmull.s16 q13, d15, d0\n"
          "vmlal.s16 q13, d17, d1\n"
          "vmlal.s16 q13, d19, d2\n"
          "vext.s8 d9, d12, d12, #1\n"
          "vmovl.s8 q9, d9\n"
          "vmovl.s8 q7, d12\n"
          "vmovl.s8 q8, d13\n"
          "vmlal.s16 q12, d14, d3\n"
          "vmlal.s16 q12, d16, d4\n"
          "vmlal.s16 q12, d18, d5\n"
          "vmlal.s16 q13, d15, d3\n"
          "vmlal.s16 q13, d17, d4\n"
          "vmlal.s16 q13, d19, d5\n"
          "vld2.8 {d10, d11}, [%[input_ptr4]]\n"
          "vld2.8 {d12, d13}, [%[input_ptr5]]\n"
          "vext.s8 d9, d10, d10, #1\n"
          "vmovl.s8 q9, d9\n"
          "vmovl.s8 q7, d10\n"
          "vmovl.s8 q8, d11\n"
          "vmlal.s16 q12, d14, d6\n"
          "vmlal.s16 q12, d16, d7\n"
          "vmlal.s16 q12, d18, d8\n"
          "vmlal.s16 q13, d15, d6\n"
          "vmlal.s16 q13, d17, d7\n"
          "vmlal.s16 q13, d19, d8\n"
          "vmull.s16 q14, d14, d0\n"
          "vmlal.s16 q14, d16, d1\n"
          "vmlal.s16 q14, d18, d2\n"
          "vmull.s16 q15, d15, d0\n"
          "vmlal.s16 q15, d17, d1\n"
          "vmlal.s16 q15, d19, d2\n"
          "vext.s8 d9, d12, d12, #1\n"
          "vmovl.s8 q9, d9\n"
          "vmovl.s8 q7, d12\n"
          "vmovl.s8 q8, d13\n"
          "vmlal.s16 q14, d14, d3\n"
          "vmlal.s16 q14, d16, d4\n"
          "vmlal.s16 q14, d18, d5\n"
          "vmlal.s16 q15, d15, d3\n"
          "vmlal.s16 q15, d17, d4\n"
          "vmlal.s16 q15, d19, d5\n"
          "vld2.8 {d10, d11}, [%[input_ptr6]]\n"
          "vext.s8 d9, d10, d10, #1\n"
          "vmovl.s8 q9, d9\n"
          "vmovl.s8 q7, d10\n"
          "vmovl.s8 q8, d11\n"
          "vmlal.s16 q14, d14, d6\n"
          "vmlal.s16 q14, d16, d7\n"
          "vmlal.s16 q14, d18, d8\n"
          "vmlal.s16 q15, d15, d6\n"
          "vmlal.s16 q15, d17, d7\n"
          "vmlal.s16 q15, d19, d8\n"
          "cmp %[remain], #4\n"
          "blt store_3h2w_%=\n"
          "vst1.32 {q10}, [%[output_ptr0]]!\n"
          "vst1.32 {q12}, [%[output_ptr1]]!\n"
          "vst1.32 {q14}, [%[output_ptr2]]!\n"
          "cmp %[remain], #5\n"
          "blt end_%=\n"
          "vst1.32 {d22[0]}, [%[output_ptr0]]!\n"
          "vst1.32 {d26[0]}, [%[output_ptr1]]!\n"
          "vst1.32 {d30[0]}, [%[output_ptr2]]!\n"
          "b end_%=\n"
          "store_3h2w_%=:\n"
          "cmp %[remain], #2\n"
          "blt store_3h1w_%=\n"
          "vst1.32 {d20}, [%[output_ptr0]]!\n"
          "vst1.32 {d24}, [%[output_ptr1]]!\n"
          "vst1.32 {d28}, [%[output_ptr2]]!\n"
          "cmp %[remain], #3\n"
          "blt end_%=\n"
          "vst1.32 {d21[0]}, [%[output_ptr0]]!\n"
          "vst1.32 {d25[0]}, [%[output_ptr1]]!\n"
          "vst1.32 {d29[0]}, [%[output_ptr2]]!\n"
          "b end_%=\n"
          "store_3h1w_%=:\n"
          "cmp %[remain], #1\n"
          "blt end_%=\n"
          "vst1.32 {d20[0]}, [%[output_ptr0]]!\n"
          "vst1.32 {d24[0]}, [%[output_ptr1]]!\n"
          "vst1.32 {d28[0]}, [%[output_ptr2]]!\n"
          "end_%=:\n"
          : [output_ptr0] "+r"(output_ptr0), [output_ptr1] "+r"(output_ptr1),
            [output_ptr2] "+r"(output_ptr2), [input_ptr6] "+r"(input_ptr6),
            [input_ptr0] "+r"(input_ptr0), [input_ptr1] "+r"(input_ptr1),
            [input_ptr2] "+r"(input_ptr2), [input_ptr3] "+r"(input_ptr3),
            [input_ptr4] "+r"(input_ptr4), [input_ptr5] "+r"(input_ptr5),
            [loop] "+r"(loop)
          : [remain] "r"(remain)
          : "cc", "memory", "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7",
            "q8", "q9", "q10", "q11", "q12", "q13", "q14", "q15", "r0");
    }
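    // Note on the stride-2 loop above: vld2.8 deinterleaves 16 consecutive
    // int8 pixels into even (d10/d12/d14) and odd (d11/d13/d15) lanes, so
    // the three taps of a filter row line up as even lanes, odd lanes, and
    // the even lanes shifted by one (vext.s8); each iteration thus yields
    // 6 output columns while consuming 12 input columns (r0 advances the
    // input pointers by 12).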
    int start_h = (output_h / 3) * 6;
    for (int h = start_h; h < input_h - 2 /*(input_h - 1) - 1*/; h += 2) {
      const int8_t *input_ptr0 = input_ptr + h * input_w;
      const int8_t *input_ptr1 = input_ptr0 + input_w;
      const int8_t *input_ptr2 = input_ptr1 + input_w;
      int32_t *output_ptr0 = output_ptr + (h >> 1) * output_w;
      int loop = loops;
      asm volatile(
          "vld1.32 {q0}, [%[filter_ptr]]\n"
          "vmovl.s8 q14, d0\n"
          "vmovl.s8 q15, d1\n"
          "vdup.s16 d0, d28[0]\n"
          "vdup.s16 d1, d28[1]\n"
          "vdup.s16 d2, d28[2]\n"
          "vdup.s16 d3, d28[3]\n"
          "vdup.s16 d4, d29[0]\n"
          "vdup.s16 d5, d29[1]\n"
          "vdup.s16 d6, d29[2]\n"
          "vdup.s16 d7, d29[3]\n"
          "vdup.s16 d8, d30[0]\n"
          :
          : [filter_ptr] "r"(filter_ptr)
          : "memory", "q0", "q1", "q2", "q3", "q4", "q14", "q15");
      asm volatile(
          "mov r0, #12\n"
          "cmp %[loop], #0\n"
          "ble start_remain_%=\n"
          // loop 6 widths
          "loop_1h6w_%=:\n"
          "vld2.8 {d10, d11}, [%[input_ptr0]], r0\n"
          "vld2.8 {d12, d13}, [%[input_ptr1]], r0\n"
          "vld2.8 {d14, d15}, [%[input_ptr2]], r0\n"
          "vext.s8 d9, d10, d10, #1\n"
          "vmovl.s8 q10, d9\n"
          "vmovl.s8 q8, d10\n"
          "vmovl.s8 q9, d11\n"
          "vmull.s16 q11, d16, d0\n"
          "vmlal.s16 q11, d18, d1\n"
          "vmlal.s16 q11, d20, d2\n"
          "vmull.s16 q12, d17, d0\n"
          "vmlal.s16 q12, d19, d1\n"
          "vmlal.s16 q12, d21, d2\n"
          "vext.s8 d9, d12, d12, #1\n"
          "vmovl.s8 q10, d9\n"
          "vmovl.s8 q8, d12\n"
          "vmovl.s8 q9, d13\n"
          "vmlal.s16 q11, d16, d3\n"
          "vmlal.s16 q11, d18, d4\n"
          "vmlal.s16 q11, d20, d5\n"
          "vmlal.s16 q12, d17, d3\n"
          "vmlal.s16 q12, d19, d4\n"
          "vmlal.s16 q12, d21, d5\n"
          "vext.s8 d9, d14, d14, #1\n"
          "vmovl.s8 q10, d9\n"
          "vmovl.s8 q8, d14\n"
          "vmovl.s8 q9, d15\n"
          "vmlal.s16 q11, d16, d6\n"
          "vmlal.s16 q11, d18, d7\n"
          "vmlal.s16 q11, d20, d8\n"
          "vmlal.s16 q12, d17, d6\n"
          "vmlal.s16 q12, d19, d7\n"
          "vmlal.s16 q12, d21, d8\n"
          // store row 0
          "vst1.32 {d22-d24}, [%[output_ptr0]]!\n"
          "subs %[loop], #1\n"
          "bne loop_1h6w_%=\n"

          "start_remain_%=:\n"
          "cmp %[remain], #0\n"
          "ble end_%=\n"
          "vld2.8 {d10, d11}, [%[input_ptr0]]\n"
          "vld2.8 {d12, d13}, [%[input_ptr1]]\n"
          "vld2.8 {d14, d15}, [%[input_ptr2]]\n"
          "vext.s8 d9, d10, d10, #1\n"
          "vmovl.s8 q10, d9\n"
          "vmovl.s8 q8, d10\n"
          "vmovl.s8 q9, d11\n"
          "vmull.s16 q11, d16, d0\n"
          "vmlal.s16 q11, d18, d1\n"
          "vmlal.s16 q11, d20, d2\n"
          "vmull.s16 q12, d17, d0\n"
          "vmlal.s16 q12, d19, d1\n"
          "vmlal.s16 q12, d21, d2\n"
          "vext.s8 d9, d12, d12, #1\n"
          "vmovl.s8 q10, d9\n"
          "vmovl.s8 q8, d12\n"
          "vmovl.s8 q9, d13\n"
          "vmlal.s16 q11, d16, d3\n"
          "vmlal.s16 q11, d18, d4\n"
          "vmlal.s16 q11, d20, d5\n"
          "vmlal.s16 q12, d17, d3\n"
          "vmlal.s16 q12, d19, d4\n"
          "vmlal.s16 q12, d21, d5\n"
          "vext.s8 d9, d14, d14, #1\n"
          "vmovl.s8 q10, d9\n"
          "vmovl.s8 q8, d14\n"
          "vmovl.s8 q9, d15\n"
          "vmlal.s16 q11, d16, d6\n"
          "vmlal.s16 q11, d18, d7\n"
          "vmlal.s16 q11, d20, d8\n"
          "vmlal.s16 q12, d17, d6\n"
          "vmlal.s16 q12, d19, d7\n"
          "vmlal.s16 q12, d21, d8\n"
          "cmp %[remain], #4\n"
          "blt store_1h2w_%=\n"
          "vst1.32 {q11}, [%[output_ptr0]]!\n"
          "cmp %[remain], #5\n"
          "blt end_%=\n"
          "vst1.32 {d24[0]}, [%[output_ptr0]]!\n"
          "b end_%=\n"
          "store_1h2w_%=:\n"
          "cmp %[remain], #2\n"
          "blt store_1h1w_%=\n"
          "vst1.32 {d22}, [%[output_ptr0]]!\n"
          "cmp %[remain], #3\n"
          "blt end_%=\n"
          "vst1.32 {d23[0]}, [%[output_ptr0]]!\n"
          "b end_%=\n"
          "store_1h1w_%=:\n"
          "cmp %[remain], #1\n"
          "blt end_%=\n"
          "vst1.32 {d22[0]}, [%[output_ptr0]]!\n"
          "end_%=:\n"
          : [output_ptr0] "+r"(output_ptr0), [input_ptr0] "+r"(input_ptr0),
            [input_ptr1] "+r"(input_ptr1), [input_ptr2] "+r"(input_ptr2),
            [loop] "+r"(loop)
          : [remain] "r"(remain)
          : "cc", "memory", "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7",
            "q8", "q9", "q10", "q11", "q12", "r0");
    }
  }
#endif  // __aarch64__
}

}  // namespace math
}  // namespace operators
}  // namespace paddle_mobile
src/operators/math/gemm.cpp

@@ -26,79 +26,6 @@ limitations under the License. */
 namespace paddle_mobile {
 namespace operators {
 namespace math {
-/*int MC = 0;
-int KC = 0;
-int NC = 0;
-float *packedA;
-float *packedB;
-float *packedC;
-float *zero;
-typedef void (*FnPack)(int, int, int, const float *, int, float *);
-typedef void (*FnAddDot)(int, const float *, const float *, float *, int);
-FnPack procPackA;
-FnPack procPackB;
-FnAddDot procAddDot;*/
-
-/*
-// Copy matrix A block-wise into contiguous memory (ColMajor)
-void PackMatrixA(int m, int k, int m_tail, const float *A, int lda,
-                 float *buffer) {
-  int i, j;
-  const float *Aij;
-  for (i = 0; i < m - m_tail; i += MR) {
-    for (j = 0; j < k; ++j) {
-      Aij = &A(i, j);
-      *buffer++ = *Aij;
-      *buffer++ = *(Aij + 1);
-      *buffer++ = *(Aij + 2);
-      *buffer++ = *(Aij + 3);
-    }
-  }
-  if (m_tail != 0) {
-    for (j = 0; j < k; ++j) {
-      Aij = &A(m - m_tail, j);
-      for (i = 0; i < m_tail; ++i) {
-        *buffer++ = *(Aij + i);
-      }
-      for (i = m_tail; i < MR; ++i) {
-        *buffer++ = 0;
-      }
-    }
-  }
-}
-
-// Copy matrix B block-wise into contiguous memory (ColMajor)
-void PackMatrixB(int k, int n, int n_tail, const float *B, int ldb,
-                 float *buffer) {
-  int i, j;
-  const float *Bj, *Bj1, *Bj2, *Bj3;
-  for (j = 0; j < n - n_tail; j += NR) {
-    Bj = &B(0, j);
-    Bj1 = &B(0, j + 1);
-    Bj2 = &B(0, j + 2);
-    Bj3 = &B(0, j + 3);
-    for (i = 0; i < k; ++i) {
-      *buffer++ = *Bj++;
-      *buffer++ = *Bj1++;
-      *buffer++ = *Bj2++;
-      *buffer++ = *Bj3++;
-    }
-  }
-  if (n_tail != 0) {
-    for (i = 0; i < k; ++i) {
-      for (int j = n - n_tail; j < n; ++j) {
-        *buffer++ = B(i, j);
-      }
-      for (int j = n; j < n + (NR - n_tail); ++j) {
-        *buffer++ = 0;
-      }
-    }
-  }
-}
-*/
 // Copy matrix A block-wise into contiguous memory (RowMajor)
 void Gemm::PackMatrixA_4r(int m, int k, int m_tail, const float *A, int lda,
src/operators/math/im2col.cpp

@@ -22,6 +22,70 @@ namespace paddle_mobile {
 namespace operators {
 namespace math {

void ExtractToImg(const float *im_data, float *col_data, const int im_height,
                  const int im_width, const int col_height,
                  const int col_width, const int padding_h,
                  const int padding_w, const int stride_h, const int stride_w,
                  const int kh, const int kw) {
  int h = padding_h - kh;
  int w = padding_w - kw;
  int col_start_height = h > 0 ? (h + stride_h - 1) / stride_h : 0;
  int col_start_width = w > 0 ? (w + stride_w - 1) / stride_w : 0;
  int start_height = kh + col_start_height * stride_h - padding_h;
  int start_width = kw + col_start_width * stride_w - padding_w;

  int end_height = (col_height - col_start_height) * stride_h + start_height;
  end_height = end_height > im_height ? im_height : end_height;
  int end_width = (col_width - col_start_width) * stride_w + start_width;
  end_width = end_width > im_width ? im_width : end_width;
  int extract = (end_width - start_width + stride_w - 1) / stride_w;

  im_data += start_height * im_width + start_width;
  col_data += col_start_height * col_width + col_start_width;
  for (int i = start_height; i < end_height; i += stride_h) {
    if (stride_w == 1) {
      memcpy(col_data, im_data, extract * sizeof(float));
    } else if (stride_w == 2) {
      int s = 0;
#if __ARM_NEON
      for (; s < extract - 3; s += 4) {
        float32x4x2_t img = vld2q_f32(im_data + s * 2);
        vst1q_f32(col_data + s, img.val[0]);
      }
#endif
      for (; s < extract; ++s) {
        col_data[s] = im_data[s * 2];
      }
    } else if (stride_w == 3) {
      int s = 0;
#if __ARM_NEON
      for (; s < extract - 3; s += 4) {
        float32x4x3_t img = vld3q_f32(im_data + s * 3);
        vst1q_f32(col_data + s, img.val[0]);
      }
#endif
      for (; s < extract; ++s) {
        col_data[s] = im_data[s * 3];
      }
    } else if (stride_w == 4) {
      int s = 0;
#if __ARM_NEON
      for (; s < extract - 3; s += 4) {
        float32x4x4_t img = vld4q_f32(im_data + s * 4);
        vst1q_f32(col_data + s, img.val[0]);
      }
#endif
      for (; s < extract; ++s) {
        col_data[s] = im_data[s * 4];
      }
    } else {
      PADDLE_MOBILE_THROW_EXCEPTION("stride_w must be one of 1, 2, 3 and 4.");
    }
    im_data += im_width * stride_h;
    col_data += col_width;
  }
}
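// Note on ExtractToImg: for one (kh, kw) filter offset it first computes the
// range of output rows/columns whose sampled pixel falls inside the image,
// then copies each valid row with memcpy (stride 1) or a strided NEON gather
// (vld2q/vld3q/vld4q keep lane 0 of every group, i.e. every stride_w-th
// pixel); out-of-image positions were already zeroed by the caller's memset.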
/*
/*
* im = [input_channels, input_height, input_width]
* im = [input_channels, input_height, input_width]
* col =
* col =
...
@@ -363,7 +427,27 @@ void Im2ColFunctor<ColFormat::kCFO, CPU, float>::operator()(
...
@@ -363,7 +427,27 @@ void Im2ColFunctor<ColFormat::kCFO, CPU, float>::operator()(
col_data
+=
9
*
oosize
;
col_data
+=
9
*
oosize
;
im_data
+=
isize
*
isize
;
im_data
+=
isize
*
isize
;
}
}
}
else
if
(
stride
[
0
]
<=
4
&&
dilation
[
0
]
==
1
&&
dilation
[
0
]
==
dilation
[
1
])
{
int
im_spatial_size
=
im_height
*
im_width
;
int
col_spatial_size
=
col_height
*
col_width
;
// pad 0
memset
(
col_data
,
0
,
col
->
numel
()
*
sizeof
(
float
));
#pragma omp parallel for
for
(
int
ic
=
0
;
ic
<
im_channels
;
++
ic
)
{
const
float
*
local_im_data
=
im_data
+
ic
*
im_spatial_size
;
float
*
local_col_data
=
col_data
+
ic
*
filter_height
*
filter_width
*
col_spatial_size
;
for
(
int
kh
=
0
;
kh
<
filter_height
;
++
kh
)
{
for
(
int
kw
=
0
;
kw
<
filter_width
;
++
kw
)
{
ExtractToImg
(
local_im_data
,
local_col_data
,
im_height
,
im_width
,
col_height
,
col_width
,
padding
[
0
],
padding
[
1
],
stride
[
0
],
stride
[
1
],
kh
,
kw
);
local_col_data
+=
col_spatial_size
;
}
}
}
}
else
{
}
else
{
#endif
    for (int c = 0; c < channels_col; ++c) {
      int w_offset = c % filter_width;
      int h_offset = (c / filter_width) % filter_height;
...
@@ -382,25 +466,7 @@ void Im2ColFunctor<ColFormat::kCFO, CPU, float>::operator()(
        }
      }
    }
  }
#if __ARM_NEON
#else
  for (int c = 0; c < channels_col; ++c) {
    int w_offset = c % filter_width;
    int h_offset = (c / filter_width) % filter_height;
    int c_im = c / (filter_width * filter_height);
    for (int h = 0; h < col_height; ++h) {
      int im_row_idx = h * stride[0] - padding[0] + h_offset * dilation[0];
      for (int w = 0; w < col_width; ++w) {
        int im_col_idx = w * stride[1] - padding[1] + w_offset * dilation[1];
        int col_idx = (c * col_height + h) * col_width + w;
        int im_idx = (im_row_idx + c_im * im_height) * im_width + im_col_idx;
        col_data[col_idx] = (im_row_idx < 0 || im_row_idx >= im_height ||
                             im_col_idx < 0 || im_col_idx >= im_width)
                                ? static_cast<float>(0)
                                : im_data[im_idx];
      }
    }
  }
#endif
}
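Both branches above only populate the im2col buffer; the convolution itself then reduces to one matrix multiply against the reshaped filter. A standalone sketch of that equivalence (shapes here are illustrative assumptions, not the Paddle-Mobile GEMM):

#include <vector>

// col is laid out [K*K, col_h*col_w] for a single input channel; filter is
// row-major [out_c, K*K]. The product is the convolution output.
std::vector<float> ConvByGemm(const std::vector<float> &col,
                              const std::vector<float> &filter, int out_c,
                              int k2, int spatial) {
  std::vector<float> out(out_c * spatial, 0.f);
  for (int oc = 0; oc < out_c; ++oc)
    for (int k = 0; k < k2; ++k)
      for (int s = 0; s < spatial; ++s)
        out[oc * spatial + s] += filter[oc * k2 + k] * col[k * spatial + s];
  return out;
}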
...
@@ -489,21 +555,26 @@ void Im2ColFunctor<ColFormat::kCFO, CPU, int8_t>::operator()(
   int channels_col = im_channels * filter_height * filter_width;
   const int8_t *im_data = im.data<int8_t>();
-  int8_t *col_data = col->data<int8_t>();
+  int8_t *col_data = col->mutable_data<int8_t>();
 #if defined(__ARM_NEON__) || defined(__ARM_NEON)
   if (stride[0] <= 4 && dilation[0] == 1 && dilation[0] == dilation[1]) {
+    int im_spatial_size = im_height * im_width;
+    int col_spatial_size = col_height * col_width;
     // pad 0
     memset(col_data, 0, col->numel() * sizeof(int8_t));
+#pragma omp parallel for
     for (int ic = 0; ic < im_channels; ++ic) {
+      const int8_t *local_im_data = im_data + ic * im_spatial_size;
+      int8_t *local_col_data =
+          col_data + ic * filter_height * filter_width * col_spatial_size;
       for (int kh = 0; kh < filter_height; ++kh) {
         for (int kw = 0; kw < filter_width; ++kw) {
-          ExtractToImg(im_data, col_data, im_height, im_width, col_height,
-                       col_width, padding[0], padding[1], stride[0],
-                       stride[1], kh, kw);
-          col_data += col_height * col_width;
+          ExtractToImg(local_im_data, local_col_data, im_height, im_width,
+                       col_height, col_width, padding[0], padding[1],
+                       stride[0], stride[1], kh, kw);
+          local_col_data += col_spatial_size;
         }
       }
-      im_data += im_height * im_width;
     }
   } else {
 #endif
...
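The int8 path gains the same restructuring as the float path: `mutable_data` allocates the output if needed, and each OpenMP iteration now derives `local_im_data`/`local_col_data` from the channel index instead of advancing shared pointers, which is what makes the loop safe to parallelize. A hypothetical standalone illustration of that pattern (names and the copy body are for illustration only):

#include <cstdint>

// Each iteration computes its own offsets from ic, so no state is shared
// across iterations (the pre-patch code mutated im_data/col_data in place,
// which cannot be split across threads).
void CopyPerChannel(const int8_t *base_src, int8_t *base_dst, int channels,
                    int spatial, int planes) {
#pragma omp parallel for
  for (int ic = 0; ic < channels; ++ic) {
    const int8_t *src = base_src + ic * spatial;
    int8_t *dst = base_dst + ic * planes * spatial;
    for (int p = 0; p < planes; ++p) {
      for (int s = 0; s < spatial; ++s) dst[p * spatial + s] = src[s];
    }
  }
}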
src/operators/math/pad.cpp
...
@@ -21,10 +21,12 @@ namespace math {
 template <typename T>
 class PadFunctor<CPU, T> {
  public:
-  void operator()(const framework::Tensor &input, const int pad_h,
-                  const int pad_w, framework::Tensor *output) {
+  void operator()(const framework::Tensor &input, const int pad_top,
+                  const int pad_bottom, const int pad_left,
+                  const int pad_right, framework::Tensor *output) {
     const T *in_data = input.data<T>();
     T *out_data = output->mutable_data<T>();
+    // should check output shape is valid for such pad parameters
     const framework::DDim &input_shape = input.dims();
     const framework::DDim &output_shape = output->dims();
     // fill output with 0
...
@@ -32,13 +34,13 @@ class PadFunctor<CPU, T> {
     // should make sure the shape of output is match with input
     for (int i = 0; i < input_shape[0]; ++i) {
       for (int c = 0; c < input_shape[1]; ++c) {
-        out_data += pad_h * output_shape[3];
+        out_data += pad_top * output_shape[3];
         for (int h = 0; h < input_shape[2]; ++h) {
-          memcpy(out_data + pad_w, in_data, sizeof(T) * input_shape[3]);
+          memcpy(out_data + pad_left, in_data, sizeof(T) * input_shape[3]);
           out_data += output_shape[3];
           in_data += input_shape[3];
         }
-        out_data += pad_h * output_shape[3];
+        out_data += pad_bottom * output_shape[3];
       }
     }
   }
...
src/operators/math/pad.h
...
@@ -22,8 +22,9 @@ namespace math {
 template <typename DeviceType, typename T>
 class PadFunctor {
  public:
-  void operator()(const framework::Tensor &input, const int pad_h,
-                  const int pad_w, framework::Tensor *output);
+  void operator()(const framework::Tensor &input, const int pad_top,
+                  const int pad_bottom, const int pad_left,
+                  const int pad_right, framework::Tensor *output);
 };
 }  // namespace math
...
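Splitting `pad_h`/`pad_w` into four explicit offsets allows asymmetric padding, which the Winograd input transform below relies on to pad only the bottom and right edges. A hedged usage sketch mirroring that call (the wrapper name and padded sizes are assumptions):

#include <vector>
#include "operators/math/pad.h"

// Pad an NCHW float tensor only on its bottom/right edges so a tile grid
// divides it evenly; padded_h/padded_w are chosen by the caller.
void PadBottomRight(const framework::Tensor &input, int padded_h,
                    int padded_w, framework::Tensor *out) {
  framework::DDim shape = framework::make_ddim(std::vector<int>{
      1, static_cast<int>(input.dims()[1]), padded_h, padded_w});
  out->mutable_data<float>(shape);
  paddle_mobile::operators::math::PadFunctor<CPU, float> pad;
  pad(input, /*pad_top=*/0, padded_h - static_cast<int>(input.dims()[2]),
      /*pad_left=*/0, padded_w - static_cast<int>(input.dims()[3]), out);
}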
src/operators/math/depthwise_conv_3x3.h → src/operators/math/winograd/winograd_transform.h
...
@@ -12,40 +12,31 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License. */

-#ifdef CONV_OP
 #pragma once

-#include <algorithm>
-#include <vector>
 #include "framework/tensor.h"
-#include "operators/math/conv_func.h"

 namespace paddle_mobile {
 namespace operators {
 namespace math {

-using framework::Tensor;
-using std::max;
-using std::min;
-using std::vector;
-
-void DepthwiseConv3x3(const Tensor *input, vector<int> strides,
-                      vector<int> paddings, const Tensor *filter,
-                      Tensor *bias, Tensor *output, bool if_bias);
-void DepthwiseConv3x3s1p1(const Tensor *input, const Tensor *filter,
-                          Tensor *output, Tensor *bias, bool if_bias);
-void DepthwiseConvAddBNRelu3x3s1p1(const Tensor *input, const Tensor *filter,
-                                   Tensor *output, const Tensor *new_scale,
-                                   const Tensor *new_bias, bool if_relu);
-void DepthwiseConvAddBNRelu3x3s2p1(const Tensor *input, const Tensor *filter,
-                                   Tensor *output, const Tensor *new_scale,
-                                   const Tensor *new_bias, bool if_relu);
-void DepthwiseConv3x3s2p1v2(const Tensor *input, const Tensor *filter,
-                            Tensor *output, Tensor bias, bool if_bias);
-void DepthwiseConvAddBNRelu3x3s2p1v2(const Tensor *input, const Tensor *filter,
-                                     Tensor *output, const Tensor *new_scale,
-                                     const Tensor *new_bias, bool if_relu);
-void DepthwiseConv3x3s2p0(const Tensor *input, const Tensor *filter,
-                          Tensor *output, Tensor bias, bool if_bias);
+template <int tile, int kernel>
+void winograd_transform_weight(const framework::Tensor &weight,
+                               framework::Tensor *output);
+
+template <int tile, int kernel>
+void winograd_transform_input(const framework::Tensor &input,
+                              framework::Tensor *output);
+
+template <int tile, int kernel>
+void winograd_transform_output(const framework::Tensor &input,
+                               const framework::Tensor &weight,
+                               framework::Tensor *output);

 }  // namespace math
 }  // namespace operators
 }  // namespace paddle_mobile
-#endif
src/operators/math/winograd/winograd_transform_f6k3.cpp (new file, 0 → 100644)

/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

// Inspired by https://arxiv.org/abs/1509.09308 and referred from the nnpack
// and ncnn projects.

#ifdef CONV_OP
#ifndef __aarch64__

#include "operators/math/pad.h"
#include "operators/math/winograd/winograd_transform.h"

namespace paddle_mobile {
namespace operators {
namespace math {

template <>
void winograd_transform_weight<8, 3>(const framework::Tensor &weight,
                                     framework::Tensor *output) {
  /*
   * w0 = g0
   * w1 = ((g0 + g2) + g1) * (-2.0 / 9)
   * w2 = ((g0 + g2) - g1) * (-2.0 / 9)
   * w3 = ((g0 + 4 * g2) + 2 * g1) * (1.0 / 90)
   * w4 = ((g0 + 4 * g2) - 2 * g1) * (1.0 / 90)
   * w5 = ((g2 + 4 * g0) + 2 * g1) * (1.0 / 180)
   * w6 = ((g2 + 4 * g0) - 2 * g1) * (1.0 / 180)
   * w7 = g2
   */
  // weight shape is [out_channel, in_channel, kernel_h, kernel_w]
  // package weight into [roundup(out_channel/4), 64, in_channel, 4] tiles
  int out_channel = weight.dims()[0];
  int in_channel = weight.dims()[1];
  // reshape and alloc transformed weight
  framework::DDim transformed_shape = framework::make_ddim(
      std::vector<int>{(out_channel + 3) / 4, 64, in_channel, 4});
  float *trans_outptr = output->mutable_data<float>(transformed_shape);
  memset(trans_outptr, 0, output->numel() * sizeof(float));

  const float transform_matrix[8] = {2.f, -2.f / 9, 1.f / 90, 1.f / 180};
  const float *inptr = weight.data<float>();
  int remain_start = out_channel & 0xFFFC;
#if 0
  remain_start = 0;
#else
#pragma omp parallel for
  for (int oc = 0; oc < out_channel - 3; oc += 4) {
    float gw[96];  // gw[3][8][4]
    const float *inptr0 = inptr + oc * in_channel * 9;
    const float *inptr1 = inptr + (oc + 1) * in_channel * 9;
    const float *inptr2 = inptr + (oc + 2) * in_channel * 9;
    const float *inptr3 = inptr + (oc + 3) * in_channel * 9;
    // oc * 64 * in_channel
    float *outptr = trans_outptr + ((oc * in_channel) << 6);
    for (int ic = 0; ic < in_channel; ++ic) {
      float *gw_ptr = gw;
      asm volatile(
          "vld1.32 {d0-d1}, [%[tm_ptr]]\n"
          "mov r0, #24\n"
          "vld1.32 {d2-d5}, [%[inptr0]], r0\n"
          "vld1.32 {d6-d9}, [%[inptr1]], r0\n"
          "vld1.32 {d10-d13}, [%[inptr2]], r0\n"
          "vld1.32 {d14-d17}, [%[inptr3]], r0\n"
          "vtrn.32 q1, q3\n"
          "vtrn.32 q2, q4\n"
          "vtrn.32 q5, q7\n"
          "vtrn.32 q6, q8\n"
          "vswp.32 d3, d10\n"
          "vswp.32 d7, d14\n"
          "vswp.32 d5, d12\n"
          "vswp.32 d9, d16\n"
          // q1: g0, q3: g1, q5: g2
          "vst1.32 {d2-d3}, [%[gw_ptr]]!\n"
          "vadd.f32 q9, q1, q5\n"
          "vadd.f32 q10, q9, q3\n"
          "vsub.f32 q11, q9, q3\n"
          "vmul.f32 q10, q10, d0[1]\n"
          "vst1.32 {d20-d21}, [%[gw_ptr]]!\n"
          "vmul.f32 q11, q11, d0[1]\n"
          "vst1.32 {d22-d23}, [%[gw_ptr]]!\n"
          "vmul.f32 q9, q1, d0[0]\n"
          "vmul.f32 q9, q9, d0[0]\n"   // 4 * g0
          "vmul.f32 q10, q3, d0[0]\n"  // 2 * g1
          "vmul.f32 q11, q5, d0[0]\n"
          "vmul.f32 q11, q11, d0[0]\n"  // 4 * g2
          "vadd.f32 q12, q1, q11\n"
          "vadd.f32 q13, q12, q10\n"
          "vmul.f32 q13, q13, d1[0]\n"
          "vst1.32 {d26-d27}, [%[gw_ptr]]!\n"
          "vsub.f32 q13, q12, q10\n"
          "vmul.f32 q13, q13, d1[0]\n"
          "vst1.32 {d26-d27}, [%[gw_ptr]]!\n"
          "vadd.f32 q12, q5, q9\n"
          "vadd.f32 q13, q12, q10\n"
          "vmul.f32 q13, q13, d1[1]\n"
          "vst1.32 {d26-d27}, [%[gw_ptr]]!\n"
          "vsub.f32 q13, q12, q10\n"
          "vmul.f32 q13, q13, d1[1]\n"
          "vst1.32 {d26-d27}, [%[gw_ptr]]!\n"
          "vst1.32 {d10-d11}, [%[gw_ptr]]!\n"
          // q7: g0, q2: g1, q4: g2
          "vst1.32 {d14-d15}, [%[gw_ptr]]!\n"
          "vadd.f32 q9, q7, q4\n"
          "vadd.f32 q10, q9, q2\n"
          "vsub.f32 q11, q9, q2\n"
          "vmul.f32 q10, q10, d0[1]\n"
          "vst1.32 {d20-d21}, [%[gw_ptr]]!\n"
          "vmul.f32 q11, q11, d0[1]\n"
          "vst1.32 {d22-d23}, [%[gw_ptr]]!\n"
          "vmul.f32 q9, q7, d0[0]\n"
          "vmul.f32 q9, q9, d0[0]\n"   // 4 * g0
          "vmul.f32 q10, q2, d0[0]\n"  // 2 * g1
          "vmul.f32 q11, q4, d0[0]\n"
          "vmul.f32 q11, q11, d0[0]\n"  // 4 * g2
          "vadd.f32 q12, q7, q11\n"
          "vadd.f32 q13, q12, q10\n"
          "vmul.f32 q13, q13, d1[0]\n"
          "vst1.32 {d26-d27}, [%[gw_ptr]]!\n"
          "vsub.f32 q13, q12, q10\n"
          "vmul.f32 q13, q13, d1[0]\n"
          "vst1.32 {d26-d27}, [%[gw_ptr]]!\n"
          "vadd.f32 q12, q4, q9\n"
          "vadd.f32 q13, q12, q10\n"
          "vmul.f32 q13, q13, d1[1]\n"
          "vst1.32 {d26-d27}, [%[gw_ptr]]!\n"
          "vsub.f32 q13, q12, q10\n"
          "vmul.f32 q13, q13, d1[1]\n"
          "vst1.32 {d26-d27}, [%[gw_ptr]]!\n"
          "vst1.32 {d8-d9}, [%[gw_ptr]]!\n"
          "mov r0, #12\n"
          "vld1.32 {d2-d3}, [%[inptr0]], r0\n"
          "vld1.32 {d6-d7}, [%[inptr1]], r0\n"
          "vld1.32 {d10-d11}, [%[inptr2]], r0\n"
          "vld1.32 {d14-d15}, [%[inptr3]], r0\n"
          "vtrn.32 q1, q3\n"
          "vtrn.32 q5, q7\n"
          "vswp.32 d3, d10\n"
          "vswp.32 d7, d14\n"
          // q1: g0, q3: g1, q5: g2
          "vst1.32 {d2-d3}, [%[gw_ptr]]!\n"
          "vadd.f32 q9, q1, q5\n"
          "vadd.f32 q10, q9, q3\n"
          "vsub.f32 q11, q9, q3\n"
          "vmul.f32 q10, q10, d0[1]\n"
          "vst1.32 {d20-d21}, [%[gw_ptr]]!\n"
          "vmul.f32 q11, q11, d0[1]\n"
          "vst1.32 {d22-d23}, [%[gw_ptr]]!\n"
          "vmul.f32 q9, q1, d0[0]\n"
          "vmul.f32 q9, q9, d0[0]\n"   // 4 * g0
          "vmul.f32 q10, q3, d0[0]\n"  // 2 * g1
          "vmul.f32 q11, q5, d0[0]\n"
          "vmul.f32 q11, q11, d0[0]\n"  // 4 * g2
          "vadd.f32 q12, q1, q11\n"
          "vadd.f32 q13, q12, q10\n"
          "vmul.f32 q13, q13, d1[0]\n"
          "vst1.32 {d26-d27}, [%[gw_ptr]]!\n"
          "vsub.f32 q13, q12, q10\n"
          "vmul.f32 q13, q13, d1[0]\n"
          "vst1.32 {d26-d27}, [%[gw_ptr]]!\n"
          "vadd.f32 q12, q5, q9\n"
          "vadd.f32 q13, q12, q10\n"
          "vmul.f32 q13, q13, d1[1]\n"
          "vst1.32 {d26-d27}, [%[gw_ptr]]!\n"
          "vsub.f32 q13, q12, q10\n"
          "vmul.f32 q13, q13, d1[1]\n"
          "vst1.32 {d26-d27}, [%[gw_ptr]]!\n"
          "vst1.32 {d10-d11}, [%[gw_ptr]]!\n"
          : [gw_ptr] "+r"(gw_ptr), [inptr0] "+r"(inptr0),
            [inptr1] "+r"(inptr1), [inptr2] "+r"(inptr2),
            [inptr3] "+r"(inptr3)
          : [tm_ptr] "r"((float *)transform_matrix)
          : "cc", "memory", "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7",
            "q8", "q9", "q10", "q11", "q12", "q13", "r0");
      float *gw_ptr0 = gw;
      float *gw_ptr1 = gw + 32;
      float *gw_ptr2 = gw + 64;
      float *outptr0 = outptr + (ic << 2);            // ic * 4
      int steps = (in_channel << 2) * sizeof(float);  // in_channel * 4
      asm volatile(
          "vld1.32 {d0-d1}, [%[tm_ptr]]\n"
          "mov r0, #8\n"
          "loop_8_%=:\n"
          "vld1.32 {d2-d3}, [%[gw_ptr0]]!\n"
          "vld1.32 {d4-d5}, [%[gw_ptr1]]!\n"
          "vld1.32 {d6-d7}, [%[gw_ptr2]]!\n"
          // q1: g0, q2: g1, q3: g2
          "vst1.32 {d2-d3}, [%[outptr0]], %[steps]\n"
          "vadd.f32 q9, q1, q3\n"
          "vadd.f32 q10, q9, q2\n"
          "vsub.f32 q11, q9, q2\n"
          "vmul.f32 q10, q10, d0[1]\n"
          "vst1.32 {d20-d21}, [%[outptr0]], %[steps]\n"
          "vmul.f32 q11, q11, d0[1]\n"
          "vst1.32 {d22-d23}, [%[outptr0]], %[steps]\n"
          "vmul.f32 q9, q1, d0[0]\n"
          "vmul.f32 q9, q9, d0[0]\n"   // 4 * g0
          "vmul.f32 q10, q2, d0[0]\n"  // 2 * g1
          "vmul.f32 q11, q3, d0[0]\n"
          "vmul.f32 q11, q11, d0[0]\n"  // 4 * g2
          "vadd.f32 q12, q1, q11\n"
          "vadd.f32 q13, q12, q10\n"
          "vmul.f32 q13, q13, d1[0]\n"
          "vst1.32 {d26-d27}, [%[outptr0]], %[steps]\n"
          "vsub.f32 q13, q12, q10\n"
          "vmul.f32 q13, q13, d1[0]\n"
          "vst1.32 {d26-d27}, [%[outptr0]], %[steps]\n"
          // w5 = ((g2 + 4 * g0) + 2 * g1) * (1.0 / 180)
          "vadd.f32 q12, q3, q9\n"
          "vadd.f32 q13, q12, q10\n"
          "vmul.f32 q13, q13, d1[1]\n"
          "vst1.32 {d26-d27}, [%[outptr0]], %[steps]\n"
          "vsub.f32 q13, q12, q10\n"
          "vmul.f32 q13, q13, d1[1]\n"
          "vst1.32 {d26-d27}, [%[outptr0]], %[steps]\n"
          "vst1.32 {d6-d7}, [%[outptr0]], %[steps]\n"
          "subs r0, #1\n"
          "bne loop_8_%=\n"
          : [outptr0] "+r"(outptr0), [gw_ptr0] "+r"(gw_ptr0),
            [gw_ptr1] "+r"(gw_ptr1), [gw_ptr2] "+r"(gw_ptr2)
          : [tm_ptr] "r"((float *)transform_matrix), [steps] "r"(steps)
          : "cc", "memory", "q0", "q1", "q2", "q3", "q9", "q10", "q11",
            "q12", "q13", "r0");
    }
  }
#endif
  // remain output channel
#pragma omp parallel for
  for (int oc = remain_start; oc < out_channel; ++oc) {
    float gw[3][8];  // gw[3][8]
    const float *inptr0 = inptr + oc * in_channel * 9;
    // (oc / 4) * 64 * in_channel * 4 + oc % 4
    int offset = ((oc & 0xFFFC) << 6) * in_channel + (oc & 0x3);
    int steps = (in_channel << 2);  // in_channel * 4
    float *outptr = trans_outptr + offset;
    for (int ic = 0; ic < in_channel; ++ic) {
      for (int i = 0; i < 3; ++i, inptr0 += 3) {
        float g0 = inptr0[0];
        float g1 = inptr0[1];
        float g2 = inptr0[2];
        float d0 = g0 + g2;
        float d1 = g0 + 4 * g2;
        float d2 = g2 + 4 * g0;
        float d3 = 2 * g1;
        gw[i][0] = g0;
        gw[i][1] = -2.f / 9 * (d0 + g1);   // -2.f/9 * (g0 + g1 + g2)
        gw[i][2] = -2.f / 9 * (d0 - g1);   // -2.f/9 * (g0 - g1 + g2)
        gw[i][3] = 1.f / 90 * (d1 + d3);   // 1.f/90 * (g0 + 2 * g1 + 4 * g2)
        gw[i][4] = 1.f / 90 * (d1 - d3);   // 1.f/90 * (g0 - 2 * g1 + 4 * g2)
        gw[i][5] = 1.f / 180 * (d2 + d3);  // 1.f/180 * (4 * g0 + 2 * g1 + g2)
        gw[i][6] = 1.f / 180 * (d2 - d3);  // 1.f/180 * (4 * g0 - 2 * g1 + g2)
        gw[i][7] = g2;
      }
      for (int i = 0; i < 8; ++i) {
        float g0 = gw[0][i];
        float g1 = gw[1][i];
        float g2 = gw[2][i];
        float d0 = g0 + g2;
        float d1 = g0 + 4 * g2;
        float d2 = g2 + 4 * g0;
        float d3 = 2 * g1;
        int offset = i * 8 * steps;
        outptr[offset] = g0;
        outptr[offset + 1 * steps] = -2.f / 9 * (d0 + g1);
        outptr[offset + 2 * steps] = -2.f / 9 * (d0 - g1);
        outptr[offset + 3 * steps] = 1.f / 90 * (d1 + d3);
        outptr[offset + 4 * steps] = 1.f / 90 * (d1 - d3);
        outptr[offset + 5 * steps] = 1.f / 180 * (d2 + d3);
        outptr[offset + 6 * steps] = 1.f / 180 * (d2 - d3);
        outptr[offset + 7 * steps] = g2;
      }
      outptr += 4;
    }
  }
}
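In matrix form, the per-kernel transform implemented above (both the NEON four-channel path and the scalar remainder loop) is $U = G\,g\,G^{\mathsf T}$, where $g$ is one 3x3 kernel and, reading the coefficients off the w0..w7 formulas:

$$
G = \begin{bmatrix}
1 & 0 & 0\\
-\tfrac{2}{9} & -\tfrac{2}{9} & -\tfrac{2}{9}\\
-\tfrac{2}{9} & \tfrac{2}{9} & -\tfrac{2}{9}\\
\tfrac{1}{90} & \tfrac{1}{45} & \tfrac{2}{45}\\
\tfrac{1}{90} & -\tfrac{1}{45} & \tfrac{2}{45}\\
\tfrac{1}{45} & \tfrac{1}{90} & \tfrac{1}{180}\\
\tfrac{1}{45} & -\tfrac{1}{90} & \tfrac{1}{180}\\
0 & 0 & 1
\end{bmatrix}
$$

Each 3x3 kernel therefore becomes 8x8 = 64 coefficients, which is why the packed layout is [roundup(out_channel/4), 64, in_channel, 4]: the 64 transform positions are the middle axis, and four output channels are interleaved innermost for the NEON multiply later.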
template <>
void winograd_transform_input<8, 3>(const framework::Tensor &input,
                                    framework::Tensor *output) {
  /*
   * x0 = (d0 - d6) + (d4 - d2) * 5.25
   * x1 = (d2 + d6) - 4.25 * (d4 + d3) + (d1 + d5)
   * x2 = (d2 + d6) - 4.25 * (d4 - d3) - (d1 + d5)
   * x3 = (0.25 * d2 - 1.25 * d4 + d6) + (0.5 * d1 - 2.5 * d3 + 2 * d5)
   * x4 = (0.25 * d2 - 1.25 * d4 + d6) - (0.5 * d1 - 2.5 * d3 + 2 * d5)
   * x5 = (4 * d2 - 5 * d4 + d6) + (2 * d1 - 2.5 * d3 + 0.5 * d5)
   * x6 = (4 * d2 - 5 * d4 + d6) - (2 * d1 - 2.5 * d3 + 0.5 * d5)
   * x7 = (d7 - d1) + (d3 - d5) * 5.25
   */
  // package input into [roundup(tiles/8), 64, channel, 8] tiles
  int channel = input.dims()[1];
  int height = input.dims()[2];
  int width = input.dims()[3];
  int h_tiles = (height + 3) / 6;  // (height - 8 + 5 + 6) / 6
  int w_tiles = (width + 3) / 6;   // (width - 8 + 5 + 6) / 6
  int tiles = (h_tiles * w_tiles + 7) / 8;
  framework::DDim transformed_shape =
      framework::make_ddim(std::vector<int>{tiles, 64, channel, 8});
  float *outptr = output->mutable_data<float>(transformed_shape);
  memset(outptr, 0, output->numel() * sizeof(float));

  const float *inptr = input.data<float>();
  int inter_h = (height - 2) / 6;
  int inter_w = (width - 2) / 6;
  int remain_h = height - (inter_h * 6);
  int remain_w = width - (inter_w * 6);
  framework::Tensor input_pad;
  if (remain_h > 2 || remain_w > 2) {
    inter_h += (remain_h > 2);
    inter_w += (remain_w > 2);
    height = (inter_h - 1) * 6 + 8;
    width = (inter_w - 1) * 6 + 8;
    framework::DDim input_shape =
        framework::make_ddim(std::vector<int>{1, channel, height, width});
    PadFunctor<CPU, float> pad;
    inptr = input_pad.mutable_data<float>(input_shape);
    pad(input, 0, height - input.dims()[2], 0, width - input.dims()[3],
        &input_pad);
  }
  size_t image_size = height * width;
  const float transform_matrix[8] = {5.25f, -5.f,   -4.25f, -2.5f,
                                     2.f,   -1.25f, 0.5f,   0.25f};
  int remain_c_start = channel & 0xFFFC;
#if 1
  remain_c_start = 0;
#else
#pragma omp parallel for
  for (int c = 0; c < channel - 3; c += 4) {
    const float *in = inptr + c * image_size;
    float d_bt[64 * 4];  // d * B_t
    for (int h = 0; h < h_tiles; ++h) {
      for (int w = 0; w < w_tiles; ++w) {
        const float *in0 = in + (h * width + w) * 6;
        const float *in1 = in0 + image_size;
        const float *in2 = in1 + image_size;
        const float *in3 = in2 + image_size;
        int steps = width * sizeof(float);
        float *d_bt_ptr = d_bt;
        asm volatile(
            "mov r0, #8\n"
            "vld1.32 {d0-d3}, [%[tm_ptr]]\n"
            // row loop
            "loop_r_%=:\n"
            "vld1.32 {d4-d7}, [%[in0]], %[steps]\n"
            "vld1.32 {d8-d11}, [%[in1]], %[steps]\n"
            "vld1.32 {d12-d15}, [%[in2]], %[steps]\n"
            "vld1.32 {d16-d19}, [%[in3]], %[steps]\n"
            "vtrn.32 q2, q4\n"    // d0: q2
            "vtrn.32 q3, q5\n"    // d1: q4
            "vtrn.32 q6, q8\n"    // d2: q6
            "vtrn.32 q7, q9\n"    // d3: q8
            "vswp.32 d5, d12\n"   // d4: q3
            "vswp.32 d9, d16\n"   // d5: q5
            "vswp.32 d7, d14\n"   // d6: q7
            "vswp.32 d11, d18\n"  // d7: q9
            "vsub.f32 q10, q2, q7\n"
            "vsub.f32 q11, q3, q6\n"
            "vmla.f32 q10, q11, d0[0]\n"  // d0 - d6 + (d4 - d2) * 5.25
            "vst1.32 {d20-d21}, [%[d_bt]]!\n"
            "vadd.f32 q10, q6, q7\n"
            "vadd.f32 q11, q4, q5\n"
            "vmla.f32 q10, q3, d1[0]\n"  // d2 - 4.25 * d4 + d6
            "vmla.f32 q11, q8, d1[0]\n"  // d1 - 4.25 * d3 + d5
            "vadd.f32 q12, q10, q11\n"
            "vsub.f32 q13, q10, q11\n"
            "vst1.32 {d24-d27}, [%[d_bt]]!\n"
            "vmul.f32 q10, q6, d3[1]\n"  // 0.25 * d2
            "vmul.f32 q11, q4, d3[0]\n"  // 0.5 * d1
            "vadd.f32 q10, q10, q7\n"    // 0.25 * d2 + d6
            "vmla.f32 q11, q5, d2[0]\n"  // 0.5 * d1 + 2 * d5
            "vmla.f32 q10, q3, d2[1]\n"  // 0.25 * d2 + d6 - 1.25 * d4
            "vmla.f32 q11, q8, d1[1]\n"  // 0.5 * d1 + 2 * d5 - 2.5 * d3
            "vadd.f32 q12, q10, q11\n"
            "vsub.f32 q13, q10, q11\n"
            "vst1.32 {d24-d27}, [%[d_bt]]!\n"
            "vmul.f32 q10, q6, d2[0]\n"   // 2 * d2
            "vmul.f32 q11, q4, d2[0]\n"   // 2 * d1
            "vmla.f32 q10, q3, d1[1]\n"   // 2 * d2 - 2.5 * d4
            "vmla.f32 q11, q8, d1[1]\n"   // 2 * d1 - 2.5 * d3
            "vmla.f32 q10, q7, d3[0]\n"   // 2 * d1 - 2.5 * d3 + 0.5 * d6
            "vmla.f32 q11, q5, d3[0]\n"   // 2 * d2 - 2.5 * d4 + 0.5 * d5
            "vmul.f32 q10, q10, d2[0]\n"  // 4 * d1 - 5 * d3 + d6
            "vadd.f32 q12, q10, q11\n"
            "vsub.f32 q13, q10, q11\n"
            "vst1.32 {d24-d27}, [%[d_bt]]!\n"
            "vsub.f32 q10, q9, q4\n"
            "vsub.f32 q11, q8, q5\n"
            "vmla.f32 q10, q11, d0[0]\n"
            "vst1.32 {d20-d21}, [%[d_bt]]!\n"
            "subs r0, #1\n"
            "bne loop_r_%=\n"
            : [d_bt] "+r"(d_bt_ptr), [in0] "+r"(in0), [in1] "+r"(in1),
              [in2] "+r"(in2), [in3] "+r"(in3)
            : [tm_ptr] "r"((float *)transform_matrix), [steps] "r"(steps)
            : "cc", "memory", "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7",
              "q8", "q9", "q10", "q11", "q12", "q13", "r0");
        float *ptr0 = d_bt;
        float *ptr1 = ptr0 + 32;
        float *ptr2 = ptr1 + 32;
        float *ptr3 = ptr2 + 32;
        float *ptr4 = ptr3 + 32;
        float *ptr5 = ptr4 + 32;
        float *ptr6 = ptr5 + 32;
        float *ptr7 = ptr6 + 32;
        int tile_indics = h * w_tiles + w;
        int tile_block = tile_indics >> 3;
        int block_indics = tile_indics & 0x7;
        // (tiles / 8, 64, channel, 8)
        float *out0 =
            outptr + (tile_block * 64 * channel + c) * 8 + block_indics;
        steps = (channel - 3) * 8 * sizeof(float);
        asm volatile(
            "vld1.32 {d0-d3}, [%[tm_ptr]]\n"
            "mov r0, 4\n"
            "mov r1, 32\n"
            "loop_col_%=:\n"
            // col 0:
            "vld1.32 {d4-d5}, [%[ptr0]]!\n"    // q2: d0
            "vld1.32 {d6-d7}, [%[ptr1]]!\n"    // q3: d1
            "vld1.32 {d8-d9}, [%[ptr2]]!\n"    // q4: d2
            "vld1.32 {d10-d11}, [%[ptr3]]!\n"  // q5: d3
            "vld1.32 {d12-d13}, [%[ptr4]]!\n"  // q6: d4
            "vld1.32 {d14-d15}, [%[ptr5]]!\n"  // q7: d5
            "vld1.32 {d16-d17}, [%[ptr6]]!\n"  // q8: d6
            "vld1.32 {d18-d19}, [%[ptr7]]!\n"  // q9: d7
            "vsub.f32 q10, q2, q8\n"           // d0 - d6
            "vsub.f32 q11, q6, q4\n"           // d4 - d2
            "vmla.f32 q10, q11, d0[0]\n"       // d0 - d6 + (d4 - d2) * 5.25
            "vst1.32 {d20[0]}, [%[out0]], r1\n"
            "vst1.32 {d20[1]}, [%[out0]], r1\n"
            "vst1.32 {d21[0]}, [%[out0]], r1\n"
            "vst1.32 {d21[1]}, [%[out0]], %[steps]\n"
            "vadd.f32 q10, q4, q8\n"
            "vadd.f32 q11, q3, q7\n"
            "vmla.f32 q10, q6, d1[0]\n"  // d2 - 4.25 * d4 + d6
            "vmla.f32 q11, q5, d1[0]\n"  // d1 - 4.25 * d3 + d5
            "vadd.f32 q12, q10, q11\n"
            "vst1.32 {d24[0]}, [%[out0]], r1\n"
            "vst1.32 {d24[1]}, [%[out0]], r1\n"
            "vst1.32 {d25[0]}, [%[out0]], r1\n"
            "vst1.32 {d25[1]}, [%[out0]], %[steps]\n"
            "vsub.f32 q12, q10, q11\n"
            "vst1.32 {d24[0]}, [%[out0]], r1\n"
            "vst1.32 {d24[1]}, [%[out0]], r1\n"
            "vst1.32 {d25[0]}, [%[out0]], r1\n"
            "vst1.32 {d25[1]}, [%[out0]], %[steps]\n"
            "vmul.f32 q10, q4, d3[1]\n"  // 0.25 * d2
            "vmul.f32 q11, q3, d3[0]\n"  // 0.5 * d1
            "vadd.f32 q10, q10, q8\n"    // 0.25 * d2 + d6
            "vmla.f32 q11, q7, d2[0]\n"  // 0.5 * d1 + 2 * d5
            "vmla.f32 q10, q6, d2[1]\n"  // 0.25 * d2 + d6 - 1.25 * d4
            "vmla.f32 q11, q5, d1[1]\n"  // 0.5 * d1 + 2 * d5 - 2.5 * d3
            "vadd.f32 q12, q10, q11\n"
            "vst1.32 {d24[0]}, [%[out0]], r1\n"
            "vst1.32 {d24[1]}, [%[out0]], r1\n"
            "vst1.32 {d25[0]}, [%[out0]], r1\n"
            "vst1.32 {d25[1]}, [%[out0]], %[steps]\n"
            "vsub.f32 q12, q10, q11\n"
            "vst1.32 {d24[0]}, [%[out0]], r1\n"
            "vst1.32 {d24[1]}, [%[out0]], r1\n"
            "vst1.32 {d25[0]}, [%[out0]], r1\n"
            "vst1.32 {d25[1]}, [%[out0]], %[steps]\n"
            "vmul.f32 q10, q4, d2[0]\n"   // 2 * d2
            "vmul.f32 q11, q3, d2[0]\n"   // 2 * d1
            "vmla.f32 q10, q6, d1[1]\n"   // 2 * d2 - 2.5 * d4
            "vmla.f32 q11, q5, d1[1]\n"   // 2 * d1 - 2.5 * d3
            "vmla.f32 q10, q8, d3[0]\n"   // 2 * d1 - 2.5 * d3 + 0.5 * d6
            "vmla.f32 q11, q7, d3[0]\n"   // 2 * d2 - 2.5 * d4 + 0.5 * d5
            "vmul.f32 q10, q10, d2[0]\n"  // 4 * d1 - 5 * d3 + d6
            "vadd.f32 q12, q10, q11\n"
            "vst1.32 {d24[0]}, [%[out0]], r1\n"
            "vst1.32 {d24[1]}, [%[out0]], r1\n"
            "vst1.32 {d25[0]}, [%[out0]], r1\n"
            "vst1.32 {d25[1]}, [%[out0]], %[steps]\n"
            "vsub.f32 q12, q10, q11\n"
            "vst1.32 {d24[0]}, [%[out0]], r1\n"
            "vst1.32 {d24[1]}, [%[out0]], r1\n"
            "vst1.32 {d25[0]}, [%[out0]], r1\n"
            "vst1.32 {d25[1]}, [%[out0]], %[steps]\n"
            "vsub.f32 q10, q9, q3\n"
            "vsub.f32 q11, q5, q7\n"
            "vmla.f32 q10, q11, d0[0]\n"
            "vst1.32 {d20[0]}, [%[out0]], r1\n"
            "vst1.32 {d20[1]}, [%[out0]], r1\n"
            "vst1.32 {d21[0]}, [%[out0]], r1\n"
            "vst1.32 {d21[1]}, [%[out0]], %[steps]\n"
            // col 1:
            "vld1.32 {d4-d5}, [%[ptr0]]!\n"    // q2: d0
            "vld1.32 {d6-d7}, [%[ptr1]]!\n"    // q3: d1
            "vld1.32 {d8-d9}, [%[ptr2]]!\n"    // q4: d2
            "vld1.32 {d10-d11}, [%[ptr3]]!\n"  // q5: d3
            "vld1.32 {d12-d13}, [%[ptr4]]!\n"  // q6: d4
            "vld1.32 {d14-d15}, [%[ptr5]]!\n"  // q7: d5
            "vld1.32 {d16-d17}, [%[ptr6]]!\n"  // q8: d6
            "vld1.32 {d18-d19}, [%[ptr7]]!\n"  // q9: d7
            "vsub.f32 q10, q2, q8\n"           // d0 - d6
            "vsub.f32 q11, q6, q4\n"           // d4 - d2
            "vmla.f32 q10, q11, d0[0]\n"       // d0 - d6 + (d4 - d2) * 5.25
            "vst1.32 {d20[0]}, [%[out0]], r1\n"
            "vst1.32 {d20[1]}, [%[out0]], r1\n"
            "vst1.32 {d21[0]}, [%[out0]], r1\n"
            "vst1.32 {d21[1]}, [%[out0]], %[steps]\n"
            "vadd.f32 q10, q4, q8\n"
            "vadd.f32 q11, q3, q7\n"
            "vmla.f32 q10, q6, d1[0]\n"  // d2 - 4.25 * d4 + d6
            "vmla.f32 q11, q5, d1[0]\n"  // d1 - 4.25 * d3 + d5
            "vadd.f32 q12, q10, q11\n"
            "vst1.32 {d24[0]}, [%[out0]], r1\n"
            "vst1.32 {d24[1]}, [%[out0]], r1\n"
            "vst1.32 {d25[0]}, [%[out0]], r1\n"
            "vst1.32 {d25[1]}, [%[out0]], %[steps]\n"
            "vsub.f32 q12, q10, q11\n"
            "vst1.32 {d24[0]}, [%[out0]], r1\n"
            "vst1.32 {d24[1]}, [%[out0]], r1\n"
            "vst1.32 {d25[0]}, [%[out0]], r1\n"
            "vst1.32 {d25[1]}, [%[out0]], %[steps]\n"
            "vmul.f32 q10, q4, d3[1]\n"  // 0.25 * d2
            "vmul.f32 q11, q3, d3[0]\n"  // 0.5 * d1
            "vadd.f32 q10, q10, q8\n"    // 0.25 * d2 + d6
            "vmla.f32 q11, q7, d2[0]\n"  // 0.5 * d1 + 2 * d5
            "vmla.f32 q10, q6, d2[1]\n"  // 0.25 * d2 + d6 - 1.25 * d4
            "vmla.f32 q11, q5, d1[1]\n"  // 0.5 * d1 + 2 * d5 - 2.5 * d3
            "vadd.f32 q12, q10, q11\n"
            "vst1.32 {d24[0]}, [%[out0]], r1\n"
            "vst1.32 {d24[1]}, [%[out0]], r1\n"
            "vst1.32 {d25[0]}, [%[out0]], r1\n"
            "vst1.32 {d25[1]}, [%[out0]], %[steps]\n"
            "vsub.f32 q12, q10, q11\n"
            "vst1.32 {d24[0]}, [%[out0]], r1\n"
            "vst1.32 {d24[1]}, [%[out0]], r1\n"
            "vst1.32 {d25[0]}, [%[out0]], r1\n"
            "vst1.32 {d25[1]}, [%[out0]], %[steps]\n"
            "vmul.f32 q10, q4, d2[0]\n"   // 2 * d2
            "vmul.f32 q11, q3, d2[0]\n"   // 2 * d1
            "vmla.f32 q10, q6, d1[1]\n"   // 2 * d2 - 2.5 * d4
            "vmla.f32 q11, q5, d1[1]\n"   // 2 * d1 - 2.5 * d3
            "vmla.f32 q10, q8, d3[0]\n"   // 2 * d1 - 2.5 * d3 + 0.5 * d6
            "vmla.f32 q11, q7, d3[0]\n"   // 2 * d2 - 2.5 * d4 + 0.5 * d5
            "vmul.f32 q10, q10, d2[0]\n"  // 4 * d1 - 5 * d3 + d6
            "vadd.f32 q12, q10, q11\n"
            "vst1.32 {d24[0]}, [%[out0]], r1\n"
            "vst1.32 {d24[1]}, [%[out0]], r1\n"
            "vst1.32 {d25[0]}, [%[out0]], r1\n"
            "vst1.32 {d25[1]}, [%[out0]], %[steps]\n"
            "vsub.f32 q12, q10, q11\n"
            "vst1.32 {d24[0]}, [%[out0]], r1\n"
            "vst1.32 {d24[1]}, [%[out0]], r1\n"
            "vst1.32 {d25[0]}, [%[out0]], r1\n"
            "vst1.32 {d25[1]}, [%[out0]], %[steps]\n"
            "vsub.f32 q10, q9, q3\n"
            "vsub.f32 q11, q5, q7\n"
            "vmla.f32 q10, q11, d0[0]\n"
            "vst1.32 {d20[0]}, [%[out0]], r1\n"
            "vst1.32 {d20[1]}, [%[out0]], r1\n"
            "vst1.32 {d21[0]}, [%[out0]], r1\n"
            "vst1.32 {d21[1]}, [%[out0]], %[steps]\n"
            "subs r0, #1\n"
            "bne loop_col_%=\n"
            : [out0] "+r"(out0), [ptr0] "+r"(ptr0), [ptr1] "+r"(ptr1),
              [ptr2] "+r"(ptr2), [ptr3] "+r"(ptr3), [ptr4] "+r"(ptr4),
              [ptr5] "+r"(ptr5), [ptr6] "+r"(ptr6), [ptr7] "+r"(ptr7)
            : [tm_ptr] "r"((float *)transform_matrix), [steps] "r"(steps)
            : "cc", "memory", "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7",
              "q8", "q9", "q10", "q11", "q12", "q13", "r0", "r1");
      }
    }
  }
#endif
  // remainder channels
#pragma omp parallel for
  for (int c = remain_c_start; c < channel; ++c) {
    const float *in = inptr + c * image_size;
    float d_bt[64];  // d * B_t
    for (int h = 0; h < h_tiles; ++h) {
      for (int w = 0; w < w_tiles; ++w) {
        const float *in0 = in + (h * width + w) * 6;
        const float *in1 = in0 + width;
        const float *in2 = in1 + width;
        const float *in3 = in2 + width;
        float *d_bt_ptr = d_bt;
        int steps = 4 * width * sizeof(float);
        asm volatile(
            "vld1.32 {d0-d3}, [%[tm_ptr]]\n"
            "mov r0, #2\n"
            // row loop
            "loop_r_%=:\n"
            "vld1.32 {d4-d7}, [%[in0]], %[steps]\n"
            "vld1.32 {d8-d11}, [%[in1]], %[steps]\n"
            "vld1.32 {d12-d15}, [%[in2]], %[steps]\n"
            "vld1.32 {d16-d19}, [%[in3]], %[steps]\n"
            "vtrn.32 q2, q4\n"    // d0: q2
            "vtrn.32 q3, q5\n"    // d1: q4
            "vtrn.32 q6, q8\n"    // d2: q6
            "vtrn.32 q7, q9\n"    // d3: q8
            "vswp.32 d5, d12\n"   // d4: q3
            "vswp.32 d9, d16\n"   // d5: q5
            "vswp.32 d7, d14\n"   // d6: q7
            "vswp.32 d11, d18\n"  // d7: q9
            "vsub.f32 q10, q2, q7\n"
            "vsub.f32 q11, q3, q6\n"
            "vmla.f32 q10, q11, d0[0]\n"  // d0 - d6 + (d4 - d2) * 5.25
            "vst1.32 {d20-d21}, [%[d_bt]]!\n"
            "vadd.f32 q10, q6, q7\n"
            "vadd.f32 q11, q4, q5\n"
            "vmla.f32 q10, q3, d1[0]\n"  // d2 - 4.25 * d4 + d6
            "vmla.f32 q11, q8, d1[0]\n"  // d1 - 4.25 * d3 + d5
            "vadd.f32 q12, q10, q11\n"
            "vsub.f32 q13, q10, q11\n"
            "vst1.32 {d24-d27}, [%[d_bt]]!\n"
            "vmul.f32 q10, q6, d3[1]\n"  // 0.25 * d2
            "vmul.f32 q11, q4, d3[0]\n"  // 0.5 * d1
            "vadd.f32 q10, q10, q7\n"    // 0.25 * d2 + d6
            "vmla.f32 q11, q5, d2[0]\n"  // 0.5 * d1 + 2 * d5
            "vmla.f32 q10, q3, d2[1]\n"  // 0.25 * d2 + d6 - 1.25 * d4
            "vmla.f32 q11, q8, d1[1]\n"  // 0.5 * d1 + 2 * d5 - 2.5 * d3
            "vadd.f32 q12, q10, q11\n"
            "vsub.f32 q13, q10, q11\n"
            "vst1.32 {d24-d27}, [%[d_bt]]!\n"
            "vmul.f32 q10, q6, d2[0]\n"   // 2 * d2
            "vmul.f32 q11, q4, d2[0]\n"   // 2 * d1
            "vmla.f32 q10, q3, d1[1]\n"   // 2 * d2 - 2.5 * d4
            "vmla.f32 q11, q8, d1[1]\n"   // 2 * d1 - 2.5 * d3
            "vmla.f32 q10, q7, d3[0]\n"   // 2 * d1 - 2.5 * d3 + 0.5 * d6
            "vmla.f32 q11, q5, d3[0]\n"   // 2 * d2 - 2.5 * d4 + 0.5 * d5
            "vmul.f32 q10, q10, d2[0]\n"  // 4 * d1 - 5 * d3 + d6
            "vadd.f32 q12, q10, q11\n"
            "vsub.f32 q13, q10, q11\n"
            "vst1.32 {d24-d27}, [%[d_bt]]!\n"
            "vsub.f32 q10, q9, q4\n"
            "vsub.f32 q11, q8, q5\n"
            "vmla.f32 q10, q11, d0[0]\n"
            "vst1.32 {d20-d21}, [%[d_bt]]!\n"
            "subs r0, #1\n"
            "bne loop_r_%=\n"
            : [d_bt] "+r"(d_bt_ptr), [in0] "+r"(in0), [in1] "+r"(in1),
              [in2] "+r"(in2), [in3] "+r"(in3)
            : [tm_ptr] "r"((float *)transform_matrix), [steps] "r"(steps)
            : "cc", "memory", "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7",
              "q8", "q9", "q10", "q11", "q12", "q13", "r0");
        float *ptr0 = d_bt;
        float *ptr1 = ptr0 + 32;
        int tile_indics = h * w_tiles + w;
        int tile_block = tile_indics >> 3;
        int block_indics = tile_indics & 0x7;
        // (tiles / 8, 64, channel, 8)
        float *out0 =
            outptr + (tile_block * 64 * channel + c) * 8 + block_indics;
        float *out1 = out0 + channel * 8;
        float *out2 = out1 + channel * 8;
        float *out3 = out2 + channel * 8;
        float *out4 = out3 + channel * 8;
        float *out5 = out4 + channel * 8;
        float *out6 = out5 + channel * 8;
        float *out7 = out6 + channel * 8;
        steps = 8 * channel * 8 * sizeof(float);
        asm volatile(
            "mov r0, #2\n"
            "vld1.32 {d0-d3}, [%[tm_ptr]]\n"
            // row loop
            "loop_r_%=:\n"
            "vld1.32 {d4-d7}, [%[ptr0]]!\n"    // q2: d0, q3: d1
            "vld1.32 {d8-d11}, [%[ptr0]]!\n"   // q4: d2, q5: d3
            "vld1.32 {d12-d15}, [%[ptr1]]!\n"  // q6: d4, q7: d5
            "vld1.32 {d16-d19}, [%[ptr1]]!\n"  // q8: d6, q9: d7
            "vtrn.32 q2, q3\n"
            "vtrn.32 q4, q5\n"
            "vtrn.32 q6, q7\n"
            "vtrn.32 q8, q9\n"
            "vswp.32 d5, d8\n"
            "vswp.32 d7, d10\n"
            "vswp.32 d13, d16\n"
            "vswp.32 d15, d18\n"
            "vsub.f32 q10, q2, q8\n"      // d0 - d6
            "vsub.f32 q11, q6, q4\n"      // d4 - d2
            "vmla.f32 q10, q11, d0[0]\n"  // d0 - d6 + (d4 - d2) * 5.25
            "vst1.32 {d20[0]}, [%[out0]], %[steps]\n"
            "vst1.32 {d20[1]}, [%[out0]], %[steps]\n"
            "vst1.32 {d21[0]}, [%[out0]], %[steps]\n"
            "vst1.32 {d21[1]}, [%[out0]], %[steps]\n"
            "vadd.f32 q10, q4, q8\n"
            "vadd.f32 q11, q3, q7\n"
            "vmla.f32 q10, q6, d1[0]\n"  // d2 - 4.25 * d4 + d6
            "vmla.f32 q11, q5, d1[0]\n"  // d1 - 4.25 * d3 + d5
            "vadd.f32 q12, q10, q11\n"
            "vst1.32 {d24[0]}, [%[out1]], %[steps]\n"
            "vst1.32 {d24[1]}, [%[out1]], %[steps]\n"
            "vst1.32 {d25[0]}, [%[out1]], %[steps]\n"
            "vst1.32 {d25[1]}, [%[out1]], %[steps]\n"
            "vsub.f32 q12, q10, q11\n"
            "vst1.32 {d24[0]}, [%[out2]], %[steps]\n"
            "vst1.32 {d24[1]}, [%[out2]], %[steps]\n"
            "vst1.32 {d25[0]}, [%[out2]], %[steps]\n"
            "vst1.32 {d25[1]}, [%[out2]], %[steps]\n"
            "vmul.f32 q10, q4, d3[1]\n"  // 0.25 * d2
            "vmul.f32 q11, q3, d3[0]\n"  // 0.5 * d1
            "vadd.f32 q10, q10, q8\n"    // 0.25 * d2 + d6
            "vmla.f32 q11, q7, d2[0]\n"  // 0.5 * d1 + 2 * d5
            "vmla.f32 q10, q6, d2[1]\n"  // 0.25 * d2 + d6 - 1.25 * d4
            "vmla.f32 q11, q5, d1[1]\n"  // 0.5 * d1 + 2 * d5 - 2.5 * d3
            "vadd.f32 q12, q10, q11\n"
            "vst1.32 {d24[0]}, [%[out3]], %[steps]\n"
            "vst1.32 {d24[1]}, [%[out3]], %[steps]\n"
            "vst1.32 {d25[0]}, [%[out3]], %[steps]\n"
            "vst1.32 {d25[1]}, [%[out3]], %[steps]\n"
            "vsub.f32 q12, q10, q11\n"
            "vst1.32 {d24[0]}, [%[out4]], %[steps]\n"
            "vst1.32 {d24[1]}, [%[out4]], %[steps]\n"
            "vst1.32 {d25[0]}, [%[out4]], %[steps]\n"
            "vst1.32 {d25[1]}, [%[out4]], %[steps]\n"
            "vmul.f32 q10, q4, d2[0]\n"   // 2 * d2
            "vmul.f32 q11, q3, d2[0]\n"   // 2 * d1
            "vmla.f32 q10, q6, d1[1]\n"   // 2 * d2 - 2.5 * d4
            "vmla.f32 q11, q5, d1[1]\n"   // 2 * d1 - 2.5 * d3
            "vmla.f32 q10, q8, d3[0]\n"   // 2 * d1 - 2.5 * d3 + 0.5 * d6
            "vmla.f32 q11, q7, d3[0]\n"   // 2 * d2 - 2.5 * d4 + 0.5 * d5
            "vmul.f32 q10, q10, d2[0]\n"  // 4 * d1 - 5 * d3 + d6
            "vadd.f32 q12, q10, q11\n"
            "vst1.32 {d24[0]}, [%[out5]], %[steps]\n"
            "vst1.32 {d24[1]}, [%[out5]], %[steps]\n"
            "vst1.32 {d25[0]}, [%[out5]], %[steps]\n"
            "vst1.32 {d25[1]}, [%[out5]], %[steps]\n"
            "vsub.f32 q12, q10, q11\n"
            "vst1.32 {d24[0]}, [%[out6]], %[steps]\n"
            "vst1.32 {d24[1]}, [%[out6]], %[steps]\n"
            "vst1.32 {d25[0]}, [%[out6]], %[steps]\n"
            "vst1.32 {d25[1]}, [%[out6]], %[steps]\n"
            "vsub.f32 q10, q9, q3\n"
            "vsub.f32 q11, q5, q7\n"
            "vmla.f32 q10, q11, d0[0]\n"
            "vst1.32 {d20[0]}, [%[out7]], %[steps]\n"
            "vst1.32 {d20[1]}, [%[out7]], %[steps]\n"
            "vst1.32 {d21[0]}, [%[out7]], %[steps]\n"
            "vst1.32 {d21[1]}, [%[out7]], %[steps]\n"
            "subs r0, #1\n"
            "bne loop_r_%=\n"
            : [out0] "+r"(out0), [out1] "+r"(out1), [out2] "+r"(out2),
              [out3] "+r"(out3), [out4] "+r"(out4), [out5] "+r"(out5),
              [out6] "+r"(out6), [out7] "+r"(out7), [ptr0] "+r"(ptr0),
              [ptr1] "+r"(ptr1)
            : [tm_ptr] "r"((float *)transform_matrix), [steps] "r"(steps)
            : "cc", "memory", "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7",
              "q8", "q9", "q10", "q11", "q12", "q13", "r0");
      }
    }
  }
}
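The tiling arithmetic above is easiest to see with concrete numbers; a small standalone check (the input height is an assumption chosen for illustration):

#include <cstdio>

int main() {
  int height = 112;                // assumed padded-input height
  int h_tiles = (height + 3) / 6;  // overlapping 8x8 tiles, 6x6 outputs each
  int inter_h = (height - 2) / 6;
  int remain_h = height - inter_h * 6;
  if (remain_h > 2) inter_h += 1;      // last partial tile forces padding
  int padded = (inter_h - 1) * 6 + 8;  // rows after bottom zero padding
  std::printf("h_tiles=%d padded=%d\n", h_tiles, padded);  // 19, 116
  return 0;
}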
template <>
void winograd_transform_output<8, 3>(const framework::Tensor &input,
                                     const framework::Tensor &weight,
                                     framework::Tensor *output) {
  // weight shape is [out_channel/4, 64, in_channel, 4],
  // input shape is [hw/8, 64, in_channel, 8]
  int in_channel = input.dims()[2];
  int tiles = input.dims()[0];
  int out_channel = weight.dims()[0];
  // compute U*V first
  framework::Tensor uv_trans;
  framework::DDim shape =
      framework::make_ddim(std::vector<int>{out_channel, tiles, 64, 32});
  float *uv_trans_ptr = uv_trans.mutable_data<float>(shape);
  memset(uv_trans_ptr, 0, uv_trans.numel() * sizeof(float));
  const float *input_ptr = input.data<float>();
  const float *weight_ptr = weight.data<float>();
#pragma omp parallel for
  for (int i = 0; i < out_channel; ++i) {
    float *uv_ptr = uv_trans_ptr + (i * tiles * 64 * 32);
    for (int j = 0; j < tiles; ++j) {
      for (int k = 0; k < 64; ++k) {
        const float *w_ptr = weight_ptr + (i * 64 + k) * in_channel * 4;
        const float *in_ptr = input_ptr + (j * 64 + k) * in_channel * 8;
        int inter_channel = in_channel >> 1;
        int remain_channel = in_channel & 0x1;
        asm volatile(
            "veor q8, q8, q8\n"
            "veor q9, q9, q9\n"
            "veor q10, q10, q10\n"
            "veor q11, q11, q11\n"
            "veor q12, q12, q12\n"
            "veor q13, q13, q13\n"
            "veor q14, q14, q14\n"
            "veor q15, q15, q15\n"
            "b store_res_%=\n"
            // loop 2 channels
            "loop_2c_%=:\n"
            "vld1.32 {d0-d3}, [%[w_ptr]]!\n"
            "vld1.32 {d4-d7}, [%[in_ptr]]!\n"
            "vld1.32 {d8-d11}, [%[in_ptr]]!\n"
            "vmla.f32 q8, q2, d0[0]\n"
            "vmla.f32 q9, q3, d0[0]\n"
            "vmla.f32 q10, q2, d0[1]\n"
            "vmla.f32 q11, q3, d0[1]\n"
            "vmla.f32 q12, q2, d1[0]\n"
            "vmla.f32 q13, q3, d1[0]\n"
            "vmla.f32 q14, q2, d1[1]\n"
            "vmla.f32 q15, q3, d1[1]\n"
            "vmla.f32 q8, q4, d2[0]\n"
            "vmla.f32 q9, q5, d2[0]\n"
            "vmla.f32 q10, q4, d2[1]\n"
            "vmla.f32 q11, q5, d2[1]\n"
            "vmla.f32 q12, q4, d3[0]\n"
            "vmla.f32 q13, q5, d3[0]\n"
            "vmla.f32 q14, q4, d3[1]\n"
            "vmla.f32 q15, q5, d3[1]\n"
            "subs %[inter_channel], #1\n"
            "bne loop_2c_%=\n"
            "mov pc, lr\n"
            // loop 1 channel
            "loop_c_%=:\n"
            "vld1.32 {d0-d1}, [%[w_ptr]]!\n"
            "vld1.32 {d4-d7}, [%[in_ptr]]!\n"
            "vmla.f32 q8, q2, d0[0]\n"
            "vmla.f32 q9, q3, d0[0]\n"
            "vmla.f32 q10, q2, d0[1]\n"
            "vmla.f32 q11, q3, d0[1]\n"
            "vmla.f32 q12, q2, d1[0]\n"
            "vmla.f32 q13, q3, d1[0]\n"
            "vmla.f32 q14, q2, d1[1]\n"
            "vmla.f32 q15, q3, d1[1]\n"
            "subs %[remain_channel], #1\n"
            "bne loop_c_%=\n"
            "mov pc, lr\n"
            "store_res_%=:\n"
            "cmp %[inter_channel], #0\n"
            "it gt\n"
            "blgt loop_2c_%=\n"
            "cmp %[remain_channel], #0\n"
            "it gt\n"
            "blgt loop_c_%=\n"
            "vst1.32 {d16-d19}, [%[uv_ptr]]!\n"
            "vst1.32 {d20-d23}, [%[uv_ptr]]!\n"
            "vst1.32 {d24-d27}, [%[uv_ptr]]!\n"
            "vst1.32 {d28-d31}, [%[uv_ptr]]!\n"
            : [w_ptr] "+r"(w_ptr), [in_ptr] "+r"(in_ptr),
              [uv_ptr] "+r"(uv_ptr), [remain_channel] "+r"(remain_channel),
              [inter_channel] "+r"(inter_channel)
            :
            : "cc", "memory", "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7",
              "q8", "q9", "q10", "q11", "q12", "q13", "q14", "q15", "pc",
              "lr");
      }
    }
  }
  /*
   * s0 = m0 + (m1 + m2) + (m3 + m4) + 32 * (m5 + m6)
   * s1 = (m1 - m2) + 2 * (m3 - m4) + 16 * (m5 - m6)
   * s2 = (m1 + m2) + 4 * (m3 + m4) + 8 * (m5 + m6)
   * s3 = (m1 - m2) + 8 * (m3 - m4) + 4 * (m5 - m6)
   * s4 = (m1 + m2) + 16 * (m3 + m4) + 2 * (m5 + m6)
   * s5 = (m1 - m2) + 32 * (m3 - m4) + (m5 - m6) + m7
   */
  int out_h = output->dims()[2];
  int out_w = output->dims()[3];
  int h_tiles = (out_h + 5) / 6;
  int w_tiles = (out_w + 5) / 6;
  int remain_h = out_h - out_h / 6 * 6;
  int remain_w = out_w - out_w / 6 * 6;
  float *output_ptr = output->mutable_data<float>();
  float transform_matrix[8] = {2.f, 4.f, 8.f, 16.f};
#pragma omp parallel for
  for (int oc = 0; oc < output->dims()[1]; ++oc) {
    float at_m[48];        // [6][8]
    float output_tmp[36];  // [6][6], temporarily stores results
    // (oc / 4) * tiles * 64 * 32 + (oc & 0x3) * 8
    const float *uv_ptr =
        uv_trans_ptr + (oc >> 2) * tiles * 64 * 32 + (oc & 0x3) * 8;
    for (int tile_h = 0; tile_h < h_tiles; ++tile_h) {
      for (int tile_w = 0; tile_w < w_tiles; ++tile_w) {
        float *at_m_ptr = at_m;
        int tile_indics = tile_h * w_tiles + tile_w;
        int tile_block = tile_indics >> 3;
        int block_indics = tile_indics & 0x7;
        const float *uv_ptr0 = uv_ptr + tile_block * 64 * 32 + block_indics;
        int steps = 32 * sizeof(float);
        asm volatile(
            "vld1.32 {d0-d1}, [%[tm_ptr]]\n"
            "mov r0, #2\n"
            "loop_%=:\n"
            "vld1.32 {d2[0]}, [%[uv_ptr0]], %[steps]\n"
            "vld1.32 {d6[0]}, [%[uv_ptr0]], %[steps]\n"
            "vld1.32 {d10[0]}, [%[uv_ptr0]], %[steps]\n"
            "vld1.32 {d14[0]}, [%[uv_ptr0]], %[steps]\n"
            "vld1.32 {d4[0]}, [%[uv_ptr0]], %[steps]\n"
            "vld1.32 {d8[0]}, [%[uv_ptr0]], %[steps]\n"
            "vld1.32 {d12[0]}, [%[uv_ptr0]], %[steps]\n"
            "vld1.32 {d16[0]}, [%[uv_ptr0]], %[steps]\n"
            "vld1.32 {d2[1]}, [%[uv_ptr0]], %[steps]\n"
            "vld1.32 {d6[1]}, [%[uv_ptr0]], %[steps]\n"
            "vld1.32 {d10[1]}, [%[uv_ptr0]], %[steps]\n"
            "vld1.32 {d14[1]}, [%[uv_ptr0]], %[steps]\n"
            "vld1.32 {d4[1]}, [%[uv_ptr0]], %[steps]\n"
            "vld1.32 {d8[1]}, [%[uv_ptr0]], %[steps]\n"
            "vld1.32 {d12[1]}, [%[uv_ptr0]], %[steps]\n"
            "vld1.32 {d16[1]}, [%[uv_ptr0]], %[steps]\n"
            "vld1.32 {d3[0]}, [%[uv_ptr0]], %[steps]\n"
            "vld1.32 {d7[0]}, [%[uv_ptr0]], %[steps]\n"
            "vld1.32 {d11[0]}, [%[uv_ptr0]], %[steps]\n"
            "vld1.32 {d15[0]}, [%[uv_ptr0]], %[steps]\n"
            "vld1.32 {d5[0]}, [%[uv_ptr0]], %[steps]\n"
            "vld1.32 {d9[0]}, [%[uv_ptr0]], %[steps]\n"
            "vld1.32 {d13[0]}, [%[uv_ptr0]], %[steps]\n"
            "vld1.32 {d17[0]}, [%[uv_ptr0]], %[steps]\n"
            "vld1.32 {d3[1]}, [%[uv_ptr0]], %[steps]\n"
            "vld1.32 {d7[1]}, [%[uv_ptr0]], %[steps]\n"
            "vld1.32 {d11[1]}, [%[uv_ptr0]], %[steps]\n"
            "vld1.32 {d15[1]}, [%[uv_ptr0]], %[steps]\n"
            "vld1.32 {d5[1]}, [%[uv_ptr0]], %[steps]\n"
            "vld1.32 {d9[1]}, [%[uv_ptr0]], %[steps]\n"
            "vld1.32 {d13[1]}, [%[uv_ptr0]], %[steps]\n"
            "vld1.32 {d17[1]}, [%[uv_ptr0]], %[steps]\n"
            "vadd.f32 q9, q3, q5\n"      // m1 + m2
            "vadd.f32 q10, q7, q2\n"     // m3 + m4
            "vadd.f32 q11, q4, q6\n"     // m5 + m6
            "vsub.f32 q12, q3, q5\n"     // m1 - m2
            "vsub.f32 q13, q7, q2\n"     // m3 - m4
            "vsub.f32 q14, q4, q6\n"     // m5 - m6
            "vmul.f32 q2, q13, d0[0]\n"  // 2 * (m3 - m4)
            "vmul.f32 q3, q11, d0[0]\n"  // 2 * (m5 + m6)
            "vadd.f32 q15, q1, q9\n"
            "vadd.f32 q15, q15, q10\n"
            "vmla.f32 q15, q3, d1[1]\n"
            "vst1.32 {d30-d31}, [%[at_m_ptr]]!\n"
            "vadd.f32 q15, q12, q2\n"
            "vmla.f32 q15, q14, d1[1]\n"
            "vst1.32 {d30-d31}, [%[at_m_ptr]]!\n"
            "vmov.32 q15, q9\n"
            "vmla.f32 q15, q10, d0[1]\n"
            "vmla.f32 q15, q11, d1[0]\n"
            "vst1.32 {d30-d31}, [%[at_m_ptr]]!\n"
            "vmov.32 q15, q12\n"
            "vmla.f32 q15, q13, d1[0]\n"
            "vmla.f32 q15, q14, d0[1]\n"
            "vst1.32 {d30-d31}, [%[at_m_ptr]]!\n"
            "vadd.f32 q15, q9, q3\n"
            "vmla.f32 q15, q10, d1[1]\n"
            "vst1.32 {d30-d31}, [%[at_m_ptr]]!\n"
            "vadd.f32 q15, q12, q8\n"
            "vadd.f32 q15, q15, q14\n"
            "vmla.f32 q15, q2, d1[1]\n"
            "vst1.32 {d30-d31}, [%[at_m_ptr]]!\n"
            "subs r0, #1\n"
            "bne loop_%=\n"
            : [uv_ptr0] "+r"(uv_ptr0), [at_m_ptr] "+r"(at_m_ptr)
            : [tm_ptr] "r"((float *)transform_matrix), [steps] "r"(steps)
            : "cc", "memory", "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7",
              "q8", "q9", "q10", "q11", "q12", "q13", "q14", "q15", "r0");
        float *at_m_ptr0 = at_m;
        float *at_m_ptr1 = at_m + 24;
        if ((remain_w > 0 && tile_w == w_tiles - 1) ||
            (remain_h > 0 && tile_h == h_tiles - 1)) {
          float *out_ptr0 = output_tmp;
          float *out_ptr1 = output_tmp + 6;
          float *out_ptr2 = output_tmp + 12;
          float *out_ptr3 = output_tmp + 18;
          float *out_ptr4 = output_tmp + 24;
          float *out_ptr5 = output_tmp + 30;
          asm volatile(
              "vld1.32 {d0-d1}, [%[tm_ptr]]\n"
              // process 4 rows
              "vld1.32 {d2-d5}, [%[at_m_ptr0]]!\n"    // q1: m0, q2: m1
              "vld1.32 {d6-d9}, [%[at_m_ptr0]]!\n"    // q3: m2, q4: m3
              "vld1.32 {d10-d13}, [%[at_m_ptr1]]!\n"  // q5: m4, q6: m5
              "vld1.32 {d14-d17}, [%[at_m_ptr1]]!\n"  // q7: m6, q8: m7
              "vtrn.32 q1, q2\n"
              "vtrn.32 q3, q4\n"
              "vtrn.32 q5, q6\n"
              "vtrn.32 q7, q8\n"
              "vswp.32 d3, d6\n"
              "vswp.32 d5, d8\n"
              "vswp.32 d11, d14\n"
              "vswp.32 d13, d16\n"
              "vadd.f32 q9, q2, q3\n"      // m1 + m2
              "vadd.f32 q10, q4, q5\n"     // m3 + m4
              "vadd.f32 q11, q6, q7\n"     // m5 + m6
              "vsub.f32 q12, q2, q3\n"     // m1 - m2
              "vsub.f32 q13, q4, q5\n"     // m3 - m4
              "vsub.f32 q14, q6, q7\n"     // m5 - m6
              "vmul.f32 q6, q13, d0[0]\n"  // 2 * (m3 - m4)
              "vmul.f32 q7, q11, d0[0]\n"  // 2 * (m5 + m6)
              "vadd.f32 q1, q1, q9\n"
              "vadd.f32 q1, q1, q10\n"
              "vmla.f32 q1, q7, d1[1]\n"
              "vadd.f32 q2, q12, q6\n"
              "vmla.f32 q2, q14, d1[1]\n"
              "vmov.32 q3, q9\n"
              "vmla.f32 q3, q10, d0[1]\n"
              "vmla.f32 q3, q11, d1[0]\n"
              "vmov.32 q4, q12\n"
              "vmla.f32 q4, q13, d1[0]\n"
              "vmla.f32 q4, q14, d0[1]\n"
              "vtrn.32 q1, q2\n"
              "vtrn.32 q3, q4\n"
              "vswp.32 d3, d6\n"
              "vswp.32 d5, d8\n"
              "vst1.32 {d2-d3}, [%[out_ptr0]]!\n"
              "vst1.32 {d4-d5}, [%[out_ptr1]]!\n"
              "vst1.32 {d6-d7}, [%[out_ptr2]]!\n"
              "vst1.32 {d8-d9}, [%[out_ptr3]]!\n"
              "vadd.f32 q1, q9, q7\n"
              "vmla.f32 q1, q10, d1[1]\n"
              "vadd.f32 q2, q12, q8\n"
              "vadd.f32 q2, q2, q14\n"
              "vmla.f32 q2, q6, d1[1]\n"
              "vtrn.32 q1, q2\n"
              "vst1.32 {d2}, [%[out_ptr0]]!\n"
              "vst1.32 {d4}, [%[out_ptr1]]!\n"
              "vst1.32 {d3}, [%[out_ptr2]]!\n"
              "vst1.32 {d5}, [%[out_ptr3]]!\n"
              // remain 2 rows
              "vld1.32 {d2-d5}, [%[at_m_ptr0]]!\n"  // d2: m0, d3: m2, d4: m1, d5: m3
              "vld1.32 {d6-d9}, [%[at_m_ptr1]]!\n"  // d6: m4, d7: m6, d8: m5, d9: m7
              "vtrn.32 q1, q2\n"
              "vtrn.32 q3, q4\n"
              "vadd.f32 d10, d4, d3\n"      // m1 + m2
              "vadd.f32 d11, d5, d6\n"      // m3 + m4
              "vadd.f32 d12, d8, d7\n"      // m5 + m6
              "vsub.f32 d13, d4, d3\n"      // m1 - m2
              "vsub.f32 d14, d5, d6\n"      // m3 - m4
              "vsub.f32 d15, d8, d7\n"      // m5 - m6
              "vmul.f32 d16, d14, d0[0]\n"  // 2 * (m3 - m4)
              "vmul.f32 d17, d12, d0[0]\n"  // 2 * (m5 + m6)
              "vadd.f32 d18, d2, d10\n"
              "vadd.f32 d18, d18, d11\n"
              "vmla.f32 d18, d17, d1[1]\n"
              "vadd.f32 d20, d13, d16\n"
              "vmla.f32 d20, d15, d1[1]\n"
              "vmov.32 d19, d10\n"
              "vmla.f32 d19, d11, d0[1]\n"
              "vmla.f32 d19, d12, d1[0]\n"
              "vmov.32 d21, d13\n"
              "vmla.f32 d21, d14, d1[0]\n"
              "vmla.f32 d21, d15, d0[1]\n"
              "vtrn.32 d18, d20\n"
              "vtrn.32 d19, d21\n"
              "vst1.32 {d18-d19}, [%[out_ptr4]]!\n"
              "vst1.32 {d20-d21}, [%[out_ptr5]]!\n"
              "vadd.f32 d18, d10, d17\n"
              "vmla.f32 d18, d11, d1[1]\n"
              "vadd.f32 d19, d13, d9\n"
              "vadd.f32 d19, d19, d15\n"
              "vmla.f32 d19, d16, d1[1]\n"
              "vtrn.32 d18, d19\n"
              "vst1.32 {d18}, [%[out_ptr4]]!\n"
              "vst1.32 {d19}, [%[out_ptr5]]!\n"
              : [out_ptr0] "+r"(out_ptr0), [out_ptr1] "+r"(out_ptr1),
                [out_ptr2] "+r"(out_ptr2), [out_ptr3] "+r"(out_ptr3),
                [out_ptr4] "+r"(out_ptr4), [out_ptr5] "+r"(out_ptr5),
                [at_m_ptr0] "+r"(at_m_ptr0), [at_m_ptr1] "+r"(at_m_ptr1)
              : [tm_ptr] "r"((float *)transform_matrix)
              : "cc", "memory", "q0", "q1", "q2", "q3", "q4", "q5", "q6",
                "q7", "q8", "q9", "q10", "q11", "q12", "q13", "q14", "q15");
          size_t offset = (oc * out_h + 6 * tile_h) * out_w + 6 * tile_w;
          float *out_ptr = output_ptr + offset;
          int remain_row = (tile_h < h_tiles - 1) ? 6 : remain_h;
          int remain_col = (tile_w < w_tiles - 1) ? 6 : remain_w;
          for (int i = 0; i < remain_row; ++i, out_ptr += out_w) {
            memcpy(out_ptr, output_tmp + i * 6, remain_col * sizeof(float));
          }
        } else {
          size_t offset = (oc * out_h + 6 * tile_h) * out_w + 6 * tile_w;
          float *out_ptr0 = output_ptr + offset;
          float *out_ptr1 = out_ptr0 + out_w;
          float *out_ptr2 = out_ptr1 + out_w;
          float *out_ptr3 = out_ptr2 + out_w;
          float *out_ptr4 = out_ptr3 + out_w;
          float *out_ptr5 = out_ptr4 + out_w;
          asm volatile(
              "vld1.32 {d0-d1}, [%[tm_ptr]]\n"
              // process 4 rows
              "vld1.32 {d2-d5}, [%[at_m_ptr0]]!\n"    // q1: m0, q2: m1
              "vld1.32 {d6-d9}, [%[at_m_ptr0]]!\n"    // q3: m2, q4: m3
              "vld1.32 {d10-d13}, [%[at_m_ptr1]]!\n"  // q5: m4, q6: m5
              "vld1.32 {d14-d17}, [%[at_m_ptr1]]!\n"  // q7: m6, q8: m7
              "vtrn.32 q1, q2\n"
              "vtrn.32 q3, q4\n"
              "vtrn.32 q5, q6\n"
              "vtrn.32 q7, q8\n"
              "vswp.32 d3, d6\n"
              "vswp.32 d5, d8\n"
              "vswp.32 d11, d14\n"
              "vswp.32 d13, d16\n"
              "vadd.f32 q9, q2, q3\n"      // m1 + m2
              "vadd.f32 q10, q4, q5\n"     // m3 + m4
              "vadd.f32 q11, q6, q7\n"     // m5 + m6
              "vsub.f32 q12, q2, q3\n"     // m1 - m2
              "vsub.f32 q13, q4, q5\n"     // m3 - m4
              "vsub.f32 q14, q6, q7\n"     // m5 - m6
              "vmul.f32 q6, q13, d0[0]\n"  // 2 * (m3 - m4)
              "vmul.f32 q7, q11, d0[0]\n"  // 2 * (m5 + m6)
              "vadd.f32 q1, q1, q9\n"
              "vadd.f32 q1, q1, q10\n"
              "vmla.f32 q1, q7, d1[1]\n"
              "vadd.f32 q2, q12, q6\n"
              "vmla.f32 q2, q14, d1[1]\n"
              "vmov.32 q3, q9\n"
              "vmla.f32 q3, q10, d0[1]\n"
              "vmla.f32 q3, q11, d1[0]\n"
              "vmov.32 q4, q12\n"
              "vmla.f32 q4, q13, d1[0]\n"
              "vmla.f32 q4, q14, d0[1]\n"
              "vtrn.32 q1, q2\n"
              "vtrn.32 q3, q4\n"
              "vswp.32 d3, d6\n"
              "vswp.32 d5, d8\n"
              "vst1.32 {d2-d3}, [%[out_ptr0]]!\n"
              "vst1.32 {d4-d5}, [%[out_ptr1]]!\n"
              "vst1.32 {d6-d7}, [%[out_ptr2]]!\n"
              "vst1.32 {d8-d9}, [%[out_ptr3]]!\n"
              "vadd.f32 q1, q9, q7\n"
              "vmla.f32 q1, q10, d1[1]\n"
              "vadd.f32 q2, q12, q8\n"
              "vadd.f32 q2, q2, q14\n"
              "vmla.f32 q2, q6, d1[1]\n"
              "vtrn.32 q1, q2\n"
              "vst1.32 {d2}, [%[out_ptr0]]!\n"
              "vst1.32 {d4}, [%[out_ptr1]]!\n"
              "vst1.32 {d3}, [%[out_ptr2]]!\n"
              "vst1.32 {d5}, [%[out_ptr3]]!\n"
              // remain 2 rows
              "vld1.32 {d2-d5}, [%[at_m_ptr0]]!\n"  // d2: m0, d3: m2, d4: m1, d5: m3
              "vld1.32 {d6-d9}, [%[at_m_ptr1]]!\n"  // d6: m4, d7: m6, d8: m5, d9: m7
              "vtrn.32 q1, q2\n"
              "vtrn.32 q3, q4\n"
              "vadd.f32 d10, d4, d3\n"      // m1 + m2
              "vadd.f32 d11, d5, d6\n"      // m3 + m4
              "vadd.f32 d12, d8, d7\n"      // m5 + m6
              "vsub.f32 d13, d4, d3\n"      // m1 - m2
              "vsub.f32 d14, d5, d6\n"      // m3 - m4
              "vsub.f32 d15, d8, d7\n"      // m5 - m6
              "vmul.f32 d16, d14, d0[0]\n"  // 2 * (m3 - m4)
              "vmul.f32 d17, d12, d0[0]\n"  // 2 * (m5 + m6)
              "vadd.f32 d18, d2, d10\n"
              "vadd.f32 d18, d18, d11\n"
              "vmla.f32 d18, d17, d1[1]\n"
              "vadd.f32 d20, d13, d16\n"
              "vmla.f32 d20, d15, d1[1]\n"
              "vmov.32 d19, d10\n"
              "vmla.f32 d19, d11, d0[1]\n"
              "vmla.f32 d19, d12, d1[0]\n"
              "vmov.32 d21, d13\n"
              "vmla.f32 d21, d14, d1[0]\n"
              "vmla.f32 d21, d15, d0[1]\n"
              "vtrn.32 d18, d20\n"
              "vtrn.32 d19, d21\n"
              "vst1.32 {d18-d19}, [%[out_ptr4]]!\n"
              "vst1.32 {d20-d21}, [%[out_ptr5]]!\n"
              "vadd.f32 d18, d10, d17\n"
              "vmla.f32 d18, d11, d1[1]\n"
              "vadd.f32 d19, d13, d9\n"
              "vadd.f32 d19, d19, d15\n"
              "vmla.f32 d19, d16, d1[1]\n"
              "vtrn.32 d18, d19\n"
              "vst1.32 {d18}, [%[out_ptr4]]!\n"
              "vst1.32 {d19}, [%[out_ptr5]]!\n"
              : [out_ptr0] "+r"(out_ptr0), [out_ptr1] "+r"(out_ptr1),
                [out_ptr2] "+r"(out_ptr2), [out_ptr3] "+r"(out_ptr3),
                [out_ptr4] "+r"(out_ptr4), [out_ptr5] "+r"(out_ptr5),
                [at_m_ptr0] "+r"(at_m_ptr0), [at_m_ptr1] "+r"(at_m_ptr1)
              : [tm_ptr] "r"((float *)transform_matrix)
              : "cc", "memory", "q0", "q1", "q2", "q3", "q4", "q5", "q6",
                "q7", "q8", "q9", "q10", "q11", "q12", "q13", "q14", "q15");
        }
      }
    }
  }
}

}  // namespace math
}  // namespace operators
}  // namespace paddle_mobile

#endif  // __aarch64__
#endif  // CONV_OP
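The file implements the 2-D F(6x6, 3x3) case with NEON assembly; the underlying identity is easier to see in its smallest 1-D form. A standalone illustration (not Paddle-Mobile API) of F(2, 3), which produces two outputs of a 3-tap filter with four multiplies instead of six:

#include <cassert>

int main() {
  float d[4] = {1.f, 2.f, 3.f, 4.f};  // input window
  float g[3] = {5.f, 6.f, 7.f};       // filter
  // four multiplies on transformed operands
  float m1 = (d[0] - d[2]) * g[0];
  float m2 = (d[1] + d[2]) * (g[0] + g[1] + g[2]) / 2.f;
  float m3 = (d[2] - d[1]) * (g[0] - g[1] + g[2]) / 2.f;
  float m4 = (d[1] - d[3]) * g[2];
  // output transform recovers the two direct-convolution results
  float y0 = m1 + m2 + m3;  // == d0*g0 + d1*g1 + d2*g2
  float y1 = m2 - m3 - m4;  // == d1*g0 + d2*g1 + d3*g2
  assert(y0 == d[0] * g[0] + d[1] * g[1] + d[2] * g[2]);
  assert(y1 == d[1] * g[0] + d[2] * g[1] + d[3] * g[2]);
  return 0;
}

The per-forward cost is three stages, exactly as laid out above: transform the input tiles, multiply against weights that were transformed once, then apply the output transform; the arithmetic saving grows with tile size, which is why this file uses F(6x6, 3x3) rather than the minimal form.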
src/operators/op_param.h
...
@@ -405,9 +405,9 @@ class ConvParam : public OpParam {
   const RType *Input() const { return input_; }
-  RType *Filter() const { return filter_; }
+  RType *&Filter() const { return filter_; }
-  RType *Output() const { return output_; }
+  RType *&Output() const { return output_; }
   const vector<int> &Strides() const { return strides_; }
...
@@ -415,6 +415,19 @@ class ConvParam : public OpParam {
   const vector<int> &Dilations() const { return dilations_; }
+
+  enum ExecMode {
+    EXEC_INVALID = 0,
+    EXEC_GEMM_FLOAT,
+    EXEC_DEPTHWISE3x3S1P1_FLOAT,
+    EXEC_DEPTHWISE3x3_FLOAT,
+    EXEC_WINOGRAD3X3_FLOAT,
+    EXEC_WINOGRAD5X5_FLOAT,
+    EXEC_GEMM_INT8,
+    EXEC_DEPTHWISE3x3_INT8,
+  };
+
+  ExecMode &ExecMode() const { return exec_mode_; }
+
   const int &Groups() const { return groups; }
 #ifdef PADDLE_MOBILE_CL
...
@@ -426,11 +439,12 @@ class ConvParam : public OpParam {
  private:
   RType *input_;
-  RType *output_;
-  RType *filter_;
+  mutable RType *output_;
+  mutable RType *filter_;
   vector<int> strides_;
   vector<int> paddings_;
   vector<int> dilations_;
+  mutable enum ExecMode exec_mode_;
   int groups;
 #ifdef PADDLE_MOBILE_CL
...
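The new `ExecMode` lets the conv kernel pick an implementation once at init time and branch cheaply on every forward pass. A hedged, standalone sketch of that dispatch shape (it mirrors the enum above; the real selection logic lives in the ARM conv kernel sources):

#include <cstdio>

enum ExecMode {
  EXEC_INVALID = 0,
  EXEC_GEMM_FLOAT,
  EXEC_DEPTHWISE3x3S1P1_FLOAT,
  EXEC_DEPTHWISE3x3_FLOAT,
  EXEC_WINOGRAD3X3_FLOAT,
  EXEC_WINOGRAD5X5_FLOAT,
  EXEC_GEMM_INT8,
  EXEC_DEPTHWISE3x3_INT8,
};

const char *Describe(ExecMode mode) {
  switch (mode) {
    case EXEC_WINOGRAD3X3_FLOAT:
      return "winograd F(6x6,3x3), weights pre-transformed";
    case EXEC_DEPTHWISE3x3_INT8:
      return "quantized depthwise 3x3";
    case EXEC_GEMM_FLOAT:
    case EXEC_GEMM_INT8:
      return "im2col + gemm fallback";
    default:
      return "other";
  }
}

int main() {
  std::printf("%s\n", Describe(EXEC_WINOGRAD3X3_FLOAT));
  return 0;
}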
@@ -2509,10 +2523,10 @@ class QuantizeParam : public OpParam {
...
@@ -2509,10 +2523,10 @@ class QuantizeParam : public OpParam {
QuantizeParam
(
const
VariableNameMap
&
inputs
,
const
VariableNameMap
&
outputs
,
QuantizeParam
(
const
VariableNameMap
&
inputs
,
const
VariableNameMap
&
outputs
,
const
AttributeMap
&
attrs
,
const
Scope
&
scope
)
{
const
AttributeMap
&
attrs
,
const
Scope
&
scope
)
{
input_
=
InputXFrom
<
GType
>
(
inputs
,
scope
);
input_
=
InputXFrom
<
GType
>
(
inputs
,
scope
);
out_
=
OutFrom
<
GType
>
(
outputs
,
scope
);
out
put
_
=
OutFrom
<
GType
>
(
outputs
,
scope
);
// online
// online
// scale = max(abs(x))
// scale = max(abs(x))
online_scale_
=
GetVarValue
<
GType
>
(
"OutScale"
,
outputs
,
scope
);
online_scale_
=
OpParam
::
GetVarValue
<
GType
>
(
"OutScale"
,
outputs
,
scope
);
// offline
// offline
if
(
HasAttr
(
"static_scale"
,
attrs
))
{
if
(
HasAttr
(
"static_scale"
,
attrs
))
{
is_static_
=
true
;
is_static_
=
true
;
...
@@ -2522,14 +2536,18 @@ class QuantizeParam : public OpParam {
...
@@ -2522,14 +2536,18 @@ class QuantizeParam : public OpParam {
if
(
HasAttr
(
"round_type"
,
attrs
))
{
if
(
HasAttr
(
"round_type"
,
attrs
))
{
round_type_
=
GetAttr
<
RoundType
>
(
"round_type"
,
attrs
);
round_type_
=
GetAttr
<
RoundType
>
(
"round_type"
,
attrs
);
}
}
// get paddings
paddings_
=
std
::
vector
<
int
>
({
0
,
0
});
if
(
HasAttr
(
"paddings"
,
attrs
))
{
paddings_
=
GetAttr
<
vector
<
int
>>
(
"paddings"
,
attrs
);
}
}
}
public:
public:
// op input
// op input
RType
*
input_
;
RType
*
input_
;
// op output
// op output
RType
*
out_
;
RType
*
output_
;
//
RType
*
online_scale_
;
RType
*
online_scale_
;
// if static scale or not
// if static scale or not
bool
is_static_
=
false
;
bool
is_static_
=
false
;
...
@@ -2537,7 +2555,11 @@ class QuantizeParam : public OpParam {
...
@@ -2537,7 +2555,11 @@ class QuantizeParam : public OpParam {
float
static_scale_
=
1.0
f
;
float
static_scale_
=
1.0
f
;
// round method type
// round method type
// nearest_zero and nearest_even is valid currently
// nearest_zero and nearest_even is valid currently
RoundType
round_type_
=
ROUND_NEAREST_AWAY_ZERO
;
// RoundType round_type_ = ROUND_NEAREST_AWAY_ZERO;
RoundType
round_type_
=
ROUND_NEAREST_TOWARDS_ZERO
;
// optional paddings
std
::
vector
<
int
>
paddings_
;
int8_t
padding_val_
;
};
};
#endif
#endif
@@ -2551,8 +2573,8 @@ class DequantizeParam : public OpParam {
  DequantizeParam(const VariableNameMap &inputs, const VariableNameMap &outputs,
                  const AttributeMap &attrs, const Scope &scope) {
    input_ = InputXFrom<GType>(inputs, scope);
-   out_ = OutFrom<GType>(outputs, scope);
+   output_ = OutFrom<GType>(outputs, scope);
-   activation_scale_ = GetVarValue<GType>("Scale", inputs, scope);
+   activation_scale_ = OpParam::GetVarValue<GType>("Scale", inputs, scope);
    // dequantization is performed as x = x / static_scale / online_scale
    if (HasAttr("weight_scale", attrs)) {
      weight_scale_ = GetAttr<float>("weight_scale", attrs);

@@ -2565,11 +2587,50 @@ class DequantizeParam : public OpParam {
  // op input
  RType *input_;
  // op output
- RType *out_;
+ RType *output_;
  RType *activation_scale_;
  float weight_scale_;
};
#endif
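Per the constructor comment, dequantization divides both scales back out of the int32 accumulator. A minimal sketch of that arithmetic (names are mine):

#include <cstdint>

// Recover a float value by dividing out the offline weight scale and the
// online activation scale.
float dequantize(int32_t q, float weight_scale, float activation_scale) {
  return static_cast<float>(q) / weight_scale / activation_scale;
}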
+#ifdef FUSION_DEQUANT_ADD_BN_RELU_OP
+template <typename Dtype>
+class FusionDequantAddBNReluParam : public DequantizeParam<Dtype> {
+  typedef typename DtypeTensorTrait<Dtype>::gtype GType;
+  typedef typename DtypeTensorTrait<Dtype>::rtype RType;
+
+ public:
+  FusionDequantAddBNReluParam(const VariableNameMap &inputs,
+                              const VariableNameMap &outputs,
+                              const AttributeMap &attrs, const Scope &scope)
+      : DequantizeParam<Dtype>(inputs, outputs, attrs, scope) {
+    // element wise add params
+    axis_ = OpParam::GetAttr<int>("axis", attrs);
+    bias_ = OpParam::InputYFrom<GType>(inputs, scope);
+    // batch norm params
+    bn_mean_ = OpParam::GetVarValue<GType>("BNMean", inputs, scope);
+    bn_variance_ = OpParam::GetVarValue<GType>("BNVariance", inputs, scope);
+    bn_scale_ = OpParam::GetVarValue<GType>("BNScale", inputs, scope);
+    bn_bias_ = OpParam::GetVarValue<GType>("BNBias", inputs, scope);
+    epsilon_ = OpParam::GetAttr<float>("epsilon", attrs);
+    // output
+    output_ = OpParam::OutFrom<GType>(outputs, scope);
+  }
+
+ public:
+  // elementwise add
+  int axis_;
+  RType *bias_;
+  // batch norm
+  RType *bn_mean_;
+  RType *bn_variance_;
+  RType *bn_scale_;
+  RType *bn_bias_;
+  float epsilon_;
+  // output
+  RType *output_;
+};
+#endif
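The fused op folds four steps into one pass over the tensor. A per-element sketch of the math these parameters describe (my reading of the op name and fields, not the kernel source):

#include <cmath>
#include <cstdint>

// dequantize -> elementwise add -> batch norm -> ReLU, one element at a time.
float dequant_add_bn_relu(int32_t q, float weight_scale, float act_scale,
                          float bias, float mean, float variance,
                          float bn_scale, float bn_bias, float epsilon) {
  float x = static_cast<float>(q) / weight_scale / act_scale;  // dequantize
  x += bias;                                                   // elementwise add
  x = bn_scale * (x - mean) / std::sqrt(variance + epsilon) + bn_bias;  // BN
  return x > 0.f ? x : 0.f;                                    // ReLU
}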
}  // namespace operators
}  // namespace paddle_mobile
src/operators/quantize_op.cpp

@@ -22,8 +22,11 @@ namespace operators {
 template <typename DeviceType, typename T>
 void QuantizeOp<DeviceType, T>::InferShape() const {
-  const auto &input_dims = this->param_.input_->dims();
-  this->param_.out_->Resize(input_dims);
+  auto input_dims = this->param_.input_->dims();
+  const std::vector<int> &paddings = this->param_.paddings_;
+  input_dims[2] += 2 * paddings[0];
+  input_dims[3] += 2 * paddings[1];
+  this->param_.output_->Resize(input_dims);
   auto scale_dims = framework::make_ddim(std::vector<int>{1});
   this->param_.online_scale_->Resize(scale_dims);
 }
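In other words, the quantize op now bakes its padding into the output shape: with paddings_ = {1, 1}, a 1x3x224x224 input yields a 1x3x226x226 int8 output, each spatial dimension growing by 2 * pad, presumably so the int8 convolution that consumes it no longer has to pad on the fly.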
test/CMakeLists.txt

@@ -155,7 +155,7 @@ if (NOT FOUND_MATCH)
     target_link_libraries(test-googlenet-quali paddle-mobile)

     # gen test
-    ADD_EXECUTABLE(test-conv-op operators/test_cov_op.cpp test_helper.h test_include.h executor_for_test.h)
+    ADD_EXECUTABLE(test-conv-op operators/test_conv_op.cpp test_helper.h test_include.h executor_for_test.h)
     target_link_libraries(test-conv-op paddle-mobile)

     # gen test

@@ -242,10 +242,6 @@ if (NOT FOUND_MATCH)
     ADD_EXECUTABLE(test-dequantize-op operators/test_dequantize_op.cpp test_helper.h test_include.h)
     target_link_libraries(test-dequantize-op paddle-mobile)

-    # test int8 conv op
-    ADD_EXECUTABLE(test-int8-conv-op operators/test_int8_conv_op.cpp test_helper.h test_include.h)
-    target_link_libraries(test-int8-conv-op paddle-mobile)
-
     # gen test log
     ADD_EXECUTABLE(test-log common/test_log.cpp)
     target_link_libraries(test-log paddle-mobile)

@@ -368,6 +364,10 @@ if (NOT FOUND_MATCH)
     ADD_EXECUTABLE(test-multi-process net/test_multi_inference_predict.cpp test_helper.h test_include.h)
     target_link_libraries(test-multi-process paddle-mobile)

+    # gen test benchmark
+    ADD_EXECUTABLE(test-benchmark net/test_benchmark.cpp)
+    target_link_libraries(test-benchmark paddle-mobile)
+
     # gen test
     ADD_EXECUTABLE(test-eng net/test_eng.cpp test_helper.h test_include.h)
     target_link_libraries(test-eng paddle-mobile)
test/framework/test_load_memory.cpp

@@ -12,10 +12,11 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License. */

+#include <iostream>
 #include <string>
 #include "../test_helper.h"
 #include "../test_include.h"

 static size_t ReadBuffer(const char *file_name, uint8_t **out) {
   FILE *fp;
   fp = fopen(file_name, "rb");
test/net/test_benchmark.cpp (new file, 0 → 100644)

/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include <iostream>
#include "../test_helper.h"
#include "../test_include.h"

int main(int argc, char *argv[]) {
  if (argc < 4) {
    std::cout << "Usage: " << std::endl
              << "./test_benchmark fluid_model feed_shape thread_num [use_fuse]"
              << std::endl;
    std::cout << "use_fuse: optional, bool, default is 1\n";
    return 1;
  }
  bool optimize = true;
  char *fluid_model = argv[1];
  char *feed_shape = argv[2];
  int thread_num = atoi(argv[3]);
  if (argc == 5) {
    optimize = atoi(argv[4]);
  }
  paddle_mobile::PaddleMobile<paddle_mobile::CPU> paddle_mobile;
  paddle_mobile.SetThreadNum(thread_num);
  auto time1 = time();
  if (paddle_mobile.Load(fluid_model, optimize)) {
    auto time2 = time();
    std::cout << "load cost :" << time_diff(time1, time2) << "ms\n";
    paddle_mobile::framework::Tensor input;
    std::shared_ptr<paddle_mobile::framework::Tensor> output;
    std::vector<int64_t> dims{1, 3, 224, 224};
    if (feed_shape) {
      sscanf(feed_shape, "%d,%d,%d,%d", &dims[0], &dims[1], &dims[2], &dims[3]);
    }
    std::cout << "feed shape: [" << dims[0] << ", " << dims[1] << ", "
              << dims[2] << ", " << dims[3] << "]\n";
    paddle_mobile::framework::DDim in_shape =
        paddle_mobile::framework::make_ddim(dims);
    SetupTensor<float>(&input, in_shape, 0.f, 255.f);
    // warmup
    for (int i = 0; i < 10; ++i) {
      output = paddle_mobile.Predict(input);
    }
    auto time3 = time();
    for (int i = 0; i < 10; ++i) {
      output = paddle_mobile.Predict(input);
    }
    auto time4 = time();
    std::cout << "predict cost :" << time_diff(time3, time4) / 10 << "ms\n";
  }
  return 0;
}
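Given the usage string and the CMake target above, a typical invocation would be `./test-benchmark googlenet_model 1,3,224,224 4` (model directory, feed shape, four threads; the optional fifth argument toggles fusion). One caveat worth flagging: `sscanf` writes `%d` values into `int64_t` elements of `dims`, which only works where the integer widths happen to line up; reading into `int` temporaries and assigning would be strictly portable.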
test/net/test_googlenet.cpp

@@ -20,12 +20,11 @@ int main() {
 #ifdef PADDLE_MOBILE_FPGA
   paddle_mobile::PaddleMobile<paddle_mobile::FPGA> paddle_mobile;
 #endif
 #ifdef PADDLE_MOBILE_CPU
   paddle_mobile::PaddleMobile<paddle_mobile::CPU> paddle_mobile;
 #endif
-  paddle_mobile.SetThreadNum(4);
+  paddle_mobile.SetThreadNum(1);
   bool optimize = true;
   auto time1 = time();
   if (paddle_mobile.Load(g_googlenet, optimize)) {

@@ -36,7 +35,7 @@ int main() {
   std::vector<float> output;
   std::vector<int64_t> dims{1, 3, 224, 224};
   GetInput<float>(g_test_image_1x3x224x224, &input, dims);

-  // warm up ten times
+  // warmup
   for (int i = 0; i < 10; ++i) {
     output = paddle_mobile.Predict(input, dims);
   }

@@ -46,8 +45,7 @@ int main() {
   }
   auto time4 = time();
-  std::cout << "predict cost :" << time_diff(time3, time4) / 10 << "ms"
-            << std::endl;
+  std::cout << "predict cost: " << time_diff(time3, time4) / 10 << "ms\n";
  }
  return 0;
}
test/operators/test_int8_conv_op.cpp → test/operators/test_conv_op.cpp

@@ -18,7 +18,7 @@ limitations under the License. */
 namespace paddle_mobile {

-// Reference convolution for checking results:
+// Reference convolution from Caffe for checking results.
 // accumulate through explicit loops over input, output, and filters.
 template <typename Itype, typename Otype>
 void conv2d(const framework::Tensor *input, const framework::Tensor *filter,

@@ -129,7 +129,7 @@ void conv2d(const framework::Tensor *input, const framework::Tensor *filter,
 }

 template <typename Itype, typename Otype, int Kernel, int Pad, int Stride>
-int TestConvOp() {
+int TestConvOp(int in_channels, int in_height, int in_width, int out_channels) {
   int kernel_h = Kernel;
   int kernel_w = Kernel;
   int pad_h = Pad;

@@ -140,10 +140,10 @@ int TestConvOp() {
   int dilation_w = 1;

   int batch_size = 1;
-  int input_c = 3;
-  int input_h = 100;
-  int input_w = 100;
-  int output_c = 10;
+  int input_c = in_channels;
+  int input_h = in_height;
+  int input_w = in_width;
+  int output_c = out_channels;
   framework::DDim input_shape =
       framework::make_ddim({batch_size, input_c, input_h, input_w});
   framework::DDim filter_shape =

@@ -158,7 +158,7 @@ int TestConvOp() {
   auto input_var = scope.get()->Var("input");
   auto input = input_var->template GetMutable<framework::LoDTensor>();
-  SetupTensor<Itype>(input, input_shape, -20, 20);
+  SetupTensor<Itype>(input, input_shape, -20.0, 20.0);

   auto filter_var = scope.get()->Var("filter");
   auto filter = filter_var->template GetMutable<framework::LoDTensor>();

@@ -174,8 +174,9 @@ int TestConvOp() {
   auto *op = new operators::ConvOp<CPU, float>("conv2d", inputs, outputs, attrs,
                                                scope);
-  // struct timespec ts_begin, ts_end;
   op->InferShape();
+  op->Init();
+  // struct timespec ts_begin, ts_end;
   // warmup
   // op->Run();
   // clock_gettime(CLOCK_MONOTONIC, &ts_begin);

@@ -202,9 +203,16 @@ int TestConvOp() {
   const Otype *output_data = output->data<Otype>();
   Otype *output_cmp_data = output_cmp.data<Otype>();
   for (int i = 0; i < output->numel(); ++i) {
-    PADDLE_MOBILE_ENFORCE(output_data[i] == output_cmp_data[i],
+    float gap = output_data[i] - output_cmp_data[i];
+    PADDLE_MOBILE_ENFORCE(std::abs(gap / (output_data[i] + 1e-5)) < 1e-3,
                           "output[%d] = %d, output_cmp[%d] = %d", i,
                           output_data[i], i, output_cmp_data[i]);
+    // if (std::abs(gap / (output_data[i] + 1e-5)) > 1e-3) {
+    //   LOG(kLOG_INFO) << "output_data[" << i << "] = " << output_data[i]
+    //                  << ", output_cmp_data[" << i << "] = " <<
+    //                  output_cmp_data[i];
+    //   return 1;
+    // }
   }
   delete op;
   return 0;
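Switching the check from exact equality to a relative tolerance makes sense once a float path runs alongside int8: reordered accumulation can perturb the low bits. For instance, a reference value of 1000.9 against an output of 1000.2 gives |gap| / (|output| + 1e-5) of roughly 7e-4, within the 1e-3 bound, while a 1% deviation still fails.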
@@ -212,68 +220,88 @@ int TestConvOp() {
 }  // namespace paddle_mobile

-int main() {
+int main(int argc, char *argv[]) {
+  if (argc < 5) {
+    LOG(paddle_mobile::kLOG_INFO)
+        << "Usage:\n"
+        << "  ./test-int8-conv-op in_channels in_height in_width out_channels\n"
+        << "  params:\n"
+        << "   -in_channels: int, input image's channels\n"
+        << "   -in_height: int, input image's height\n"
+        << "   -in_width: int, input image's width\n"
+        << "   -out_channels: int, conv output channels\n";
+    return 1;
+  }
+  int in_channels = atoi(argv[1]);
+  int in_height = atoi(argv[2]);
+  int in_width = atoi(argv[3]);
+  int out_channels = atoi(argv[4]);
+  // kernel = 3, pad = 1, stride = 1
+  LOG(paddle_mobile::kLOG_INFO) << "float, kernel=3, pad=1, stride=1";
+  paddle_mobile::TestConvOp<float, float, 3, 1, 1>(in_channels, in_height,
+                                                   in_width, out_channels);
   // kernel = 7, pad = 0, stride = 2
   LOG(paddle_mobile::kLOG_INFO) << "int8, kernel=7, pad=0, stride=2";
-  paddle_mobile::TestConvOp<int8_t, int32_t, 7, 0, 2>();
+  paddle_mobile::TestConvOp<int8_t, int32_t, 7, 0, 2>(in_channels, in_height,
+                                                      in_width, out_channels);
   // kernel = 7, pad = 1, stride = 2
   LOG(paddle_mobile::kLOG_INFO) << "int8, kernel=7, pad=1, stride=2";
-  paddle_mobile::TestConvOp<int8_t, int32_t, 7, 1, 2>();
+  paddle_mobile::TestConvOp<int8_t, int32_t, 7, 1, 2>(in_channels, in_height,
+                                                      in_width, out_channels);
   // kernel = 7, pad = 3, stride = 2
   LOG(paddle_mobile::kLOG_INFO) << "int8, kernel=7, pad=3, stride=2";
-  paddle_mobile::TestConvOp<int8_t, int32_t, 7, 3, 2>();
+  paddle_mobile::TestConvOp<int8_t, int32_t, 7, 3, 2>(in_channels, in_height,
+                                                      in_width, out_channels);
   // kernel = 7, pad = 0, stride = 1
   LOG(paddle_mobile::kLOG_INFO) << "int8, kernel=7, pad=0, stride=1";
-  paddle_mobile::TestConvOp<int8_t, int32_t, 7, 0, 1>();
+  paddle_mobile::TestConvOp<int8_t, int32_t, 7, 0, 1>(in_channels, in_height,
+                                                      in_width, out_channels);
   // kernel = 7, pad = 1, stride = 1
   LOG(paddle_mobile::kLOG_INFO) << "int8, kernel=7, pad=1, stride=1";
-  paddle_mobile::TestConvOp<int8_t, int32_t, 7, 1, 1>();
+  paddle_mobile::TestConvOp<int8_t, int32_t, 7, 1, 1>(in_channels, in_height,
+                                                      in_width, out_channels);
   // kernel = 7, pad = 3, stride = 1
   LOG(paddle_mobile::kLOG_INFO) << "int8, kernel=7, pad=3, stride=1";
-  paddle_mobile::TestConvOp<int8_t, int32_t, 7, 3, 1>();
+  paddle_mobile::TestConvOp<int8_t, int32_t, 7, 3, 1>(in_channels, in_height,
+                                                      in_width, out_channels);
   // kernel = 7, pad = 5, stride = 3
   LOG(paddle_mobile::kLOG_INFO) << "int8, kernel=7, pad=5, stride=3";
-  paddle_mobile::TestConvOp<int8_t, int32_t, 7, 5, 3>();
+  paddle_mobile::TestConvOp<int8_t, int32_t, 7, 5, 3>(in_channels, in_height,
+                                                      in_width, out_channels);
   // kernel = 7, pad = 3, stride = 4
   LOG(paddle_mobile::kLOG_INFO) << "int8, kernel=7, pad=3, stride=4";
-  paddle_mobile::TestConvOp<int8_t, int32_t, 7, 3, 4>();
+  paddle_mobile::TestConvOp<int8_t, int32_t, 7, 3, 4>(in_channels, in_height,
+                                                      in_width, out_channels);
+  LOG(paddle_mobile::kLOG_INFO) << "\n";
   // kernel = 3, pad = 0, stride = 1
   LOG(paddle_mobile::kLOG_INFO) << "int8, kernel=3, pad=0, stride=1";
-  paddle_mobile::TestConvOp<int8_t, int32_t, 3, 0, 1>();
+  paddle_mobile::TestConvOp<int8_t, int32_t, 3, 0, 1>(in_channels, in_height,
+                                                      in_width, out_channels);
   // kernel = 3, pad = 0, stride = 1
   LOG(paddle_mobile::kLOG_INFO) << "float, kernel=3, pad=0, stride=1";
-  paddle_mobile::TestConvOp<float, float, 3, 0, 1>();
+  paddle_mobile::TestConvOp<float, float, 3, 0, 1>(in_channels, in_height,
+                                                   in_width, out_channels);
+  LOG(paddle_mobile::kLOG_INFO) << "\n";
   // kernel = 3, pad = 1, stride = 1
   LOG(paddle_mobile::kLOG_INFO) << "int8, kernel=3, pad=1, stride=1";
-  paddle_mobile::TestConvOp<int8_t, int32_t, 3, 1, 1>();
+  paddle_mobile::TestConvOp<int8_t, int32_t, 3, 1, 1>(in_channels, in_height,
+                                                      in_width, out_channels);
   // kernel = 3, pad = 1, stride = 1
   LOG(paddle_mobile::kLOG_INFO) << "float, kernel=3, pad=1, stride=1";
-  paddle_mobile::TestConvOp<float, float, 3, 1, 1>();
+  paddle_mobile::TestConvOp<float, float, 3, 1, 1>(in_channels, in_height,
+                                                   in_width, out_channels);
+  LOG(paddle_mobile::kLOG_INFO) << "\n";
   // kernel = 5, pad = 0, stride = 1
   LOG(paddle_mobile::kLOG_INFO) << "int8, kernel=5, pad=0, stride=1";
-  paddle_mobile::TestConvOp<int8_t, int32_t, 5, 0, 1>();
+  paddle_mobile::TestConvOp<int8_t, int32_t, 5, 0, 1>(in_channels, in_height,
+                                                      in_width, out_channels);
   // kernel = 5, pad = 0, stride = 1
   LOG(paddle_mobile::kLOG_INFO) << "float, kernel=5, pad=0, stride=1";
-  paddle_mobile::TestConvOp<float, float, 5, 0, 1>();
+  paddle_mobile::TestConvOp<float, float, 5, 0, 1>(in_channels, in_height,
+                                                   in_width, out_channels);
+  LOG(paddle_mobile::kLOG_INFO) << "\n";
   // kernel = 5, pad = 2, stride = 1
   LOG(paddle_mobile::kLOG_INFO) << "int8, kernel=5, pad=2, stride=1";
-  paddle_mobile::TestConvOp<int8_t, int32_t, 5, 2, 1>();
+  paddle_mobile::TestConvOp<int8_t, int32_t, 5, 2, 1>(in_channels, in_height,
+                                                      in_width, out_channels);
   // kernel = 5, pad = 2, stride = 1
   LOG(paddle_mobile::kLOG_INFO) << "float, kernel=5, pad=2, stride=1";
-  paddle_mobile::TestConvOp<float, float, 5, 2, 1>();
+  paddle_mobile::TestConvOp<float, float, 5, 2, 1>(in_channels, in_height,
+                                                   in_width, out_channels);
 }
test/operators/test_quantize_op.cpp

@@ -12,58 +12,131 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License. */

+#include <iostream>
 #include "../test_helper.h"
 #include "../test_include.h"
 #include "operators/quantize_op.h"

 namespace paddle_mobile {

-static void quantize_round_to_even(const Tensor *input, const float scale,
-                                   Tensor *output) {
-  const float *x = input->data<const float>();
-  int8_t *y = output->mutable_data<int8_t>();
-  size_t size = input->numel();
-  for (size_t i = 0; i < size; ++i) {
-    float value = x[i] * scale;
-    float v = round(value);
-    int32_t q = (int32_t)v;
-    if (abs(abs(q - value) - 0.5) > 0) {
-      y[i] = q;
-    } else {
-      if (abs(q) % 2 == 0) {
-        y[i] = q;
-      } else {
-        y[i] = q + ((q > 0) ? -1 : 1);
-      }
-    }
-  }
-}
-
-static void quantize_round_to_nearest(const Tensor *input, const float scale,
-                                      Tensor *output) {
-  const float *x = input->data<const float>();
-  int8_t *y = output->mutable_data<int8_t>();
-  size_t size = input->numel();
-  for (size_t i = 0; i < size; ++i) {
-    y[i] = round(x[i] * scale);
-  }
-}
+namespace round {
+enum RoundType {
+  RoundToEven = 0,
+  RoundAwayZero = 1,
+  RoundTowardsZero = 2,
+};
+}
+
+template <round::RoundType T>
+struct Round {
+  int8_t operator()(float x);
+};
+
+template <>
+struct Round<round::RoundAwayZero> {
+  int8_t operator()(float x) { return std::round(x); }
+};
+
+template <>
+struct Round<round::RoundTowardsZero> {
+  int8_t operator()(float x) { return int8_t(x); }
+};
+
+template <>
+struct Round<round::RoundToEven> {
+  int8_t operator()(float x) {
+    int8_t ret = 0;
+    float v = std::round(x);
+    int32_t q = (int32_t)v;
+    if (abs(abs(q - x) - 0.5) > 0) {
+      ret = q;
+    } else {
+      if (abs(q) % 2 == 0) {
+        ret = q;
+      } else {
+        ret = q + ((q > 0) ? -1 : 1);
+      }
+    }
+    return ret;
+  }
+};
+
+template <round::RoundType T>
+static void quantize(const Tensor *input, const float scale, const int pad,
+                     const int8_t pad_val, Tensor *output) {
+  int batch_size = input->dims()[0];
+  int channels = input->dims()[1];
+  int input_h = input->dims()[2];
+  int input_w = input->dims()[3];
+  int output_h = output->dims()[2];
+  int output_w = output->dims()[3];
+  size_t input_spatial = input_h * input_w;
+  size_t output_spatial = output_h * output_w;
+  const float *x = input->data<const float>();
+  int8_t *y = output->mutable_data<int8_t>();
+
+  for (int nc = 0; nc < batch_size * channels; ++nc) {
+    const float *xh = x + nc * input_spatial;
+    int8_t *yh = y + nc * output_spatial;
+    // pad top
+    for (int h = 0; h < pad; ++h, yh += output_w) {
+      for (int w = 0; w < output_w; ++w) {
+        yh[w] = pad_val;
+      }
+    }
+    for (int h = 0; h < input_h; ++h, yh += output_w, xh += input_w) {
+      // pad left
+      for (int w = 0; w < pad; ++w) {
+        yh[w] = pad_val;
+      }
+      for (int w = 0; w < input_w; ++w) {
+        yh[w + pad] = Round<T>()(xh[w] * scale);
+      }
+      // pad right
+      for (int w = 0; w < pad; ++w) {
+        yh[pad + input_w + w] = pad_val;
+      }
+    }
+    // pad bottom
+    for (int h = 0; h < pad; ++h, yh += output_w) {
+      for (int w = 0; w < output_w; ++w) {
+        yh[w] = pad_val;
+      }
+    }
+  }
+}

 static float find_abs_max(const Tensor *input) {
   float max_abs = 0.f;
   const float *x = input->data<const float>();
   size_t size = input->numel();
   for (size_t i = 0; i < size; ++i) {
     float value = std::abs(x[i]);
     if (value > max_abs) {
       max_abs = value;
     }
   }
   return max_abs;
 }
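A quick sanity check of the three policies on tie and truncation cases (example values are mine; it assumes the definitions above are in scope):

#include <cassert>

void check_round_policies() {
  using namespace paddle_mobile;
  assert(Round<round::RoundToEven>()(2.5f) == 2);       // tie -> even neighbor
  assert(Round<round::RoundToEven>()(3.5f) == 4);       // tie -> even neighbor
  assert(Round<round::RoundAwayZero>()(2.5f) == 3);     // std::round semantics
  assert(Round<round::RoundTowardsZero>()(2.9f) == 2);  // plain truncation
}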
-int TestQuqntizeOp() {
-  framework::DDim dim = framework::make_ddim({1, 3, 224, 224});
+int TestQuqntizeOp(int argc, char *argv[]) {
+  if (argc < 5) {
+    std::cout << "Usage: ./test-quantize-op batch_size channel height width [pad]"
+              << std::endl;
+    return 1;
+  }
+  int pad = 0;
+  int batch_size = atoi(argv[1]);
+  int channel = atoi(argv[2]);
+  int height = atoi(argv[3]);
+  int width = atoi(argv[4]);
+  if (argc == 6) {
+    pad = atoi(argv[5]);
+  }
+  std::cout << "batch_size: " << batch_size << ", channel: " << channel
+            << ", height: " << height << ", width: " << width << std::endl;
+  framework::DDim dim =
+      framework::make_ddim({batch_size, channel, height, width});
   VariableNameMap inputs;
   VariableNameMap outputs;
@@ -80,6 +153,7 @@ int TestQuqntizeOp() {
   auto output_scale_var = scope.get()->Var("output_scale");

   framework::AttributeMap attrs;
+  attrs["paddings"].Set<vector<int>>(std::vector<int>({pad, pad}));
   auto *op = new operators::QuantizeOp<CPU, float>("quantize", inputs, outputs,
                                                    attrs, scope);
   op->InferShape();
@@ -96,10 +170,11 @@ int TestQuqntizeOp() {
                         output_scale_cmp, output_scale_data[0]);

   framework::Tensor output_cmp;
-  output_cmp.Resize(dim);
+  output_cmp.Resize(output->dims());
   float scale = 127 / output_scale_cmp;
-  // quantize_round_to_even(input, scale, &output_cmp);
-  quantize_round_to_nearest(input, scale, &output_cmp);
+  // quantize<round::RoundToEven>(input, scale, pad, 0, &output_cmp);
+  // quantize<round::RoundAwayZero>(input, scale, pad, 0, &output_cmp);
+  quantize<round::RoundTowardsZero>(input, scale, pad, 0, &output_cmp);
   int8_t *output_cmp_data = output_cmp.data<int8_t>();
   for (int i = 0; i < output->numel(); ++i) {
     PADDLE_MOBILE_ENFORCE(output_data[i] == output_cmp_data[i],
@@ -113,4 +188,6 @@ int TestQuqntizeOp() {
 }  // namespace paddle_mobile

-int main() { return paddle_mobile::TestQuqntizeOp(); }
+int main(int argc, char *argv[]) {
+  return paddle_mobile::TestQuqntizeOp(argc, argv);
+}
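Given the usage string, an invocation such as `./test-quantize-op 1 3 224 224 1` quantizes a 1x3x224x224 tensor with one pixel of zero padding on each spatial border and cross-checks the op against the templated reference above.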
tools/build.sh

@@ -212,4 +212,4 @@ else
     else
         build_error "$1"
     fi
-fi
\ No newline at end of file
+fi
tools/op.cmake

@@ -249,6 +249,7 @@ if(NOT FOUND_MATCH)
     set(SUM_OP ON)
    set(QUANT_OP ON)
    set(DEQUANT_OP ON)
+   set(FUSION_DEQUANT_ADD_BN_RELU ON)
 endif()

 # option(BATCHNORM_OP "" ON)

@@ -450,6 +451,9 @@ endif()
 if (DEQUANT_OP)
     add_definitions(-DDEQUANT_OP)
 endif()
+if (FUSION_DEQUANT_ADD_BN_RELU)
+    add_definitions(-DFUSION_DEQUANT_ADD_BN_RELU_OP)
+endif()
 if (TANH_OP)
     add_definitions(-DTANH_OP)

@@ -462,4 +466,4 @@ if (FUSION_DECONVADD_OP)
 endif()
 if (FUSION_DECONVADDRELU_OP)
     add_definitions(-DFUSION_DECONVADDRELU_OP)
-endif()
\ No newline at end of file
+endif()