提交 91b8d2be 编写于 作者: H hjchen2

Optimize int8 5x5 depthwise conv, add aarch64 macros to make compilation no problem

上级 b901235e
......@@ -31,12 +31,19 @@ bool ConvKernel<CPU, float>::Init(ConvParam<CPU> *param) {
bool depth5x5 = conv5x5 && param->Groups() == param->Input()->dims()[1] &&
param->Input()->dims()[1] == param->Output()->dims()[1];
if (param->Filter()->type() == typeid(int8_t)) {
#ifndef __aarch64__
if (depth3x3 && param->Strides()[0] < 3 &&
param->Strides()[0] == param->Strides()[1]) {
param->ExecMode() = ConvParam<CPU>::EXEC_DEPTHWISE3x3_INT8;
} else if (depth5x5 && param->Strides()[0] < 2 &&
param->Strides()[0] == param->Strides()[1]) {
param->ExecMode() = ConvParam<CPU>::EXEC_DEPTHWISE5x5_INT8;
} else {
#endif // __aarch64__
param->ExecMode() = ConvParam<CPU>::EXEC_GEMM_INT8;
#ifndef __aarch64__
}
#endif // __aarch64__
} else {
if (depth3x3 && param->Strides()[0] == param->Strides()[1] &&
param->Strides()[0] == 1 && param->Paddings()[0] == 1 &&
......@@ -50,10 +57,10 @@ bool ConvKernel<CPU, float>::Init(ConvParam<CPU> *param) {
param->Strides()[0] == 2 && param->Paddings()[0] == 1 &&
param->Paddings()[0] == param->Paddings()[1]) {
param->ExecMode() = ConvParam<CPU>::EXEC_DEPTHWISE3x3S2P1_FLOAT;
#ifndef __aarch64__
} else if (depth5x5 && param->Strides()[0] == param->Strides()[1] &&
param->Strides()[0] == 1) {
param->ExecMode() = ConvParam<CPU>::EXEC_DEPTHWISE5x5S1_FLOAT;
#ifndef __aarch64__
param->ExecMode() = ConvParam<CPU>::EXEC_DEPTHWISE5x5_FLOAT;
} else if (conv3x3 && param->Strides()[0] == param->Strides()[1] &&
param->Dilations()[0] == param->Dilations()[1] &&
param->Strides()[0] == 1 && param->Dilations()[0] == 1 &&
......@@ -79,9 +86,14 @@ void ConvKernel<CPU, float>::Compute(const ConvParam<CPU> &param) {
case ConvParam<CPU>::EXEC_GEMM_INT8:
GemmConv<int8_t, int32_t>(param);
break;
#ifndef __aarch64__
case ConvParam<CPU>::EXEC_DEPTHWISE3x3_INT8:
DepthwiseConv3x3<int8_t, int32_t>(param);
break;
case ConvParam<CPU>::EXEC_DEPTHWISE5x5_INT8:
DepthwiseConv5x5<int8_t, int32_t>(param);
break;
#endif // __aarch64__
case ConvParam<CPU>::EXEC_DEPTHWISE3x3S1P1_FLOAT:
math::DepthwiseConv3x3s1p1(param.Input(), param.Filter(), param.Output(),
nullptr, false);
......@@ -94,13 +106,14 @@ void ConvKernel<CPU, float>::Compute(const ConvParam<CPU> &param) {
math::DepthwiseConv3x3s2p0(param.Input(), param.Filter(), param.Output(),
nullptr, false);
break;
case ConvParam<CPU>::EXEC_DEPTHWISE5x5S1_FLOAT:
math::DepthwiseConv5x5S1<float, float>(*param.Input(), *param.Filter(),
param.Paddings(), param.Output());
#ifndef __aarch64__
case ConvParam<CPU>::EXEC_DEPTHWISE5x5_FLOAT:
DepthwiseConv5x5<float, float>(param);
break;
case ConvParam<CPU>::EXEC_WINOGRAD3X3_FLOAT:
WinogradConv3x3<8, 3>(param);
break;
#endif // __aarch64__
case ConvParam<CPU>::EXEC_GEMM_FLOAT:
GemmConv<float, float>(param);
break;
......
......@@ -161,6 +161,7 @@ inline void WinogradConv3x3(const ConvParam<CPU> &param) {
}
}
#ifndef __aarch64__
template <typename Itype, typename Otype>
inline void DepthwiseConv3x3(const ConvParam<CPU> &param) {
const Tensor *input = param.Input();
......@@ -181,13 +182,33 @@ inline void DepthwiseConv3x3(const ConvParam<CPU> &param) {
math::DepthwiseConv3x3S2<Itype, Otype>(in_batch, *filter, paddings,
&out_batch);
} else {
// math::DepthwiseConv3x3<Itype, Otype>(input_pad, *filter,
// &out_batch);
PADDLE_MOBILE_THROW_EXCEPTION(
"Depthwise conv with generic strides has not been implemented.");
GemmConv<Itype, Otype>(param);
}
}
}
#endif // __aarch64__
template <typename Itype, typename Otype>
inline void DepthwiseConv5x5(const ConvParam<CPU> &param) {
const Tensor *input = param.Input();
const Tensor *filter = param.Filter();
const std::vector<int> &paddings = param.Paddings();
const std::vector<int> &strides = param.Strides();
const int batch_size = input->dims()[0];
Tensor *output = param.Output();
output->mutable_data<Otype>();
if (strides[0] == 1) {
for (int i = 0; i < batch_size; i++) {
Tensor in_batch = input->Slice(i, i + 1);
Tensor out_batch = output->Slice(i, i + 1);
math::DepthwiseConv5x5S1<Itype, Otype>(in_batch, *filter, paddings,
&out_batch);
}
} else {
GemmConv<Itype, Otype>(param);
}
}
} // namespace operators
} // namespace paddle_mobile
......
......@@ -16,7 +16,6 @@ limitations under the License. */
#include "operators/math/depthwise_conv5x5.h"
#include <arm_neon.h>
#include <iostream>
namespace paddle_mobile {
namespace operators {
......
此差异已折叠。
......@@ -3150,9 +3150,11 @@ void Gemm::SgemmWithPRelu(int m, int n, int k, const float *A, int lda,
void Gemm::Sgemm_omp(int m, int n, int k, float alpha, const float *A, int lda,
const float *B, int ldb, float beta, float *C, int ldc,
bool relu, float *bias) {
#ifndef __aarch64__
if (m == 1 && bias == nullptr) {
return VectorKernel(m, n, k, alpha, A, lda, B, ldb, beta, C, ldc, relu);
}
#endif // __aarch64__
#ifdef _OPENMP
int max_threads = omp_get_max_threads();
#else
......
......@@ -19,6 +19,8 @@ limitations under the License. */
#include <arm_neon.h>
#include "operators/math/pooling.h"
// TODO(hjchen2): Optimize Pooling2x2NormalRow and use inline assembly
namespace paddle_mobile {
namespace operators {
namespace math {
......@@ -60,7 +62,6 @@ struct Pooling2x2NormalRowLoadInput<P, 2> {
}
};
// TODO(hjchen2): To optimize Pooling2x2NormalRow
template <PoolingType P, int Stride>
inline void Pooling2x2NormalRow(const float *input, const int h_output,
const int input_h, const int input_w,
......
......@@ -424,10 +424,10 @@ class ConvParam : public OpParam {
EXEC_DEPTHWISE3x3_FLOAT,
EXEC_WINOGRAD3X3_FLOAT,
EXEC_WINOGRAD5X5_FLOAT,
EXEC_DEPTHWISE5x5S1_FLOAT,
EXEC_DEPTHWISE5x5S2_FLOAT,
EXEC_DEPTHWISE5x5_FLOAT,
EXEC_GEMM_INT8,
EXEC_DEPTHWISE3x3_INT8,
EXEC_DEPTHWISE5x5_INT8,
};
ExecMode &ExecMode() const { return exec_mode_; }
......
......@@ -165,14 +165,12 @@ int TestConvOp(int in_channels, int in_height, int in_width, int out_channels,
auto filter = filter_var->template GetMutable<framework::LoDTensor>();
SetupTensor<Itype>(filter, filter_shape, -20, 20);
for (int i = 0; i < input->numel(); ++i) {
DLOG << "input[" << i
<< "] = " << static_cast<int>(input->data<int8_t>()[i]);
}
for (int i = 0; i < filter->numel(); ++i) {
DLOG << "filter[" << i
<< "] = " << static_cast<int>(filter->data<int8_t>()[i]);
}
// for (int i = 0; i < input->numel(); ++i) {
// DLOG << "input[" << i << "] = " << float(input->data<Itype>()[i]);
// }
// for (int i = 0; i < filter->numel(); ++i) {
// DLOG << "filter[" << i << "] = " << float(filter->data<Itype>()[i]);
// }
auto output_var = scope.get()->Var("output");
framework::AttributeMap attrs;
......@@ -198,18 +196,12 @@ int TestConvOp(int in_channels, int in_height, int in_width, int out_channels,
// (ts_end.tv_nsec - ts_begin.tv_nsec) / 1e6;
// LOG(kLOG_INFO) << "elapsed: " << elapsed / 10.0 << " ms";
int kernel_extent_h = dilation_h * (kernel_h - 1) + 1;
int kernel_extent_w = dilation_w * (kernel_w - 1) + 1;
int output_h = (input_h + 2 * pad_h - kernel_extent_h) / stride_h + 1;
int output_w = (input_w + 2 * pad_w - kernel_extent_w) / stride_w + 1;
auto output_shape = framework::make_ddim(
std::vector<int>({batch_size, output_c, output_h, output_w}));
// compare results
auto *output = output_var->template Get<framework::LoDTensor>();
framework::Tensor output_cmp;
output_cmp.mutable_data<Otype>(output_shape);
output_cmp.mutable_data<Otype>(output->dims());
conv2d<Itype, Otype>(input, filter, attrs, &output_cmp);
// compare results
auto output = output_var->template Get<framework::LoDTensor>();
const Otype *output_data = output->data<Otype>();
Otype *output_cmp_data = output_cmp.data<Otype>();
for (int i = 0; i < output->numel(); ++i) {
......@@ -285,96 +277,39 @@ int main(int argc, char *argv[]) {
paddle_mobile::TestConvOp<int8_t, int32_t, 3, 5, 2>(
in_channels, in_height, in_width, out_channels, groups);
// // kernel = 7, pad = 0, stride = 2
// LOG(paddle_mobile::kLOG_INFO) << "int8, kernel=7, pad=0, stride=2";
// paddle_mobile::TestConvOp<int8_t, int32_t, 7, 0, 2>(in_channels,
// in_height,
// in_width,
// out_channels, groups);
// // kernel = 7, pad = 1, stride = 2
// LOG(paddle_mobile::kLOG_INFO) << "int8, kernel=7, pad=1, stride=2";
// paddle_mobile::TestConvOp<int8_t, int32_t, 7, 1, 2>(in_channels,
// in_height,
// in_width,
// out_channels, groups);
// // kernel = 7, pad = 3, stride = 2
// LOG(paddle_mobile::kLOG_INFO) << "int8, kernel=7, pad=3, stride=2";
// paddle_mobile::TestConvOp<int8_t, int32_t, 7, 3, 2>(in_channels,
// in_height,
// in_width,
// out_channels, groups);
// // kernel = 7, pad = 0, stride = 1
// LOG(paddle_mobile::kLOG_INFO) << "int8, kernel=7, pad=0, stride=1";
// paddle_mobile::TestConvOp<int8_t, int32_t, 7, 0, 1>(in_channels,
// in_height,
// in_width,
// out_channels, groups);
// // kernel = 7, pad = 1, stride = 1
// LOG(paddle_mobile::kLOG_INFO) << "int8, kernel=7, pad=1, stride=1";
// paddle_mobile::TestConvOp<int8_t, int32_t, 7, 1, 1>(in_channels,
// in_height,
// in_width,
// out_channels, groups);
// // kernel = 7, pad = 3, stride = 1
// LOG(paddle_mobile::kLOG_INFO) << "int8, kernel=7, pad=3, stride=1";
// paddle_mobile::TestConvOp<int8_t, int32_t, 7, 3, 1>(in_channels,
// in_height,
// in_width,
// out_channels, groups);
// // kernel = 7, pad = 5, stride = 3
// LOG(paddle_mobile::kLOG_INFO) << "int8, kernel=7, pad=5, stride=3";
// paddle_mobile::TestConvOp<int8_t, int32_t, 7, 5, 3>(in_channels,
// in_height,
// in_width,
// out_channels, groups);
// // kernel = 7, pad = 3, stride = 4
// LOG(paddle_mobile::kLOG_INFO) << "int8, kernel=7, pad=3, stride=4";
// paddle_mobile::TestConvOp<int8_t, int32_t, 7, 3, 4>(in_channels,
// in_height,
// in_width,
// out_channels, groups);
// // kernel = 3, pad = 0, stride = 1
// LOG(paddle_mobile::kLOG_INFO) << "int8, kernel=3, pad=0, stride=1";
// paddle_mobile::TestConvOp<int8_t, int32_t, 3, 0, 1>(in_channels,
// in_height,
// in_width,
// out_channels, groups);
// // kernel = 3, pad = 0, stride = 1
// LOG(paddle_mobile::kLOG_INFO) << "float, kernel=3, pad=0, stride=1";
// paddle_mobile::TestConvOp<float, float, 3, 0, 1>(in_channels, in_height,
// in_width, out_channels,
// groups);
// // kernel = 3, pad = 1, stride = 1
// LOG(paddle_mobile::kLOG_INFO) << "int8, kernel=3, pad=1, stride=1";
// paddle_mobile::TestConvOp<int8_t, int32_t, 3, 1, 1>(in_channels,
// in_height,
// in_width,
// out_channels, groups);
// // kernel = 3, pad = 1, stride = 1
// LOG(paddle_mobile::kLOG_INFO) << "float, kernel=3, pad=1, stride=1";
// paddle_mobile::TestConvOp<float, float, 3, 1, 1>(in_channels, in_height,
// in_width, out_channels,
// groups);
// // kernel = 5, pad = 0, stride = 1
// LOG(paddle_mobile::kLOG_INFO) << "int8, kernel=5, pad=0, stride=1";
// paddle_mobile::TestConvOp<int8_t, int32_t, 5, 0, 1>(in_channels,
// in_height,
// in_width,
// out_channels, groups);
// // kernel = 5, pad = 0, stride = 1
// LOG(paddle_mobile::kLOG_INFO) << "float, kernel=5, pad=0, stride=1";
// paddle_mobile::TestConvOp<float, float, 5, 0, 1>(in_channels, in_height,
// in_width, out_channels,
// groups);
// // kernel = 5, pad = 2, stride = 1
// LOG(paddle_mobile::kLOG_INFO) << "int8, kernel=5, pad=2, stride=1";
// paddle_mobile::TestConvOp<int8_t, int32_t, 5, 2, 1>(in_channels,
// in_height,
// in_width,
// out_channels, groups);
// // kernel = 5, pad = 2, stride = 1
// LOG(paddle_mobile::kLOG_INFO) << "float, kernel=5, pad=2, stride=1";
// paddle_mobile::TestConvOp<float, float, 5, 2, 1>(in_channels, in_height,
// in_width, out_channels,
// groups);
// kernel = 5, pad = 0, stride = 1
LOG(paddle_mobile::kLOG_INFO) << "float, kernel=5, pad=0, stride=1";
paddle_mobile::TestConvOp<float, float, 5, 0, 1>(
in_channels, in_height, in_width, out_channels, groups);
// kernel = 5, pad = 1, stride = 1
LOG(paddle_mobile::kLOG_INFO) << "float, kernel=5, pad=1, stride=1";
paddle_mobile::TestConvOp<float, float, 5, 1, 1>(
in_channels, in_height, in_width, out_channels, groups);
// kernel = 5, pad = 2, stride = 1
LOG(paddle_mobile::kLOG_INFO) << "float, kernel=5, pad=2, stride=1";
paddle_mobile::TestConvOp<float, float, 5, 2, 1>(
in_channels, in_height, in_width, out_channels, groups);
// kernel = 5, pad = 5, stride = 1
LOG(paddle_mobile::kLOG_INFO) << "float, kernel=5, pad=5, stride=1";
paddle_mobile::TestConvOp<float, float, 5, 5, 1>(
in_channels, in_height, in_width, out_channels, groups);
// kernel = 5, pad = 0, stride = 1
LOG(paddle_mobile::kLOG_INFO) << "int8, kernel=5, pad=0, stride=1";
paddle_mobile::TestConvOp<int8_t, int32_t, 5, 0, 1>(
in_channels, in_height, in_width, out_channels, groups);
// kernel = 5, pad = 1, stride = 1
LOG(paddle_mobile::kLOG_INFO) << "int8, kernel=5, pad=1, stride=1";
paddle_mobile::TestConvOp<int8_t, int32_t, 5, 1, 1>(
in_channels, in_height, in_width, out_channels, groups);
// kernel = 5, pad = 2, stride = 1
LOG(paddle_mobile::kLOG_INFO) << "int8, kernel=5, pad=2, stride=1";
paddle_mobile::TestConvOp<int8_t, int32_t, 5, 2, 1>(
in_channels, in_height, in_width, out_channels, groups);
// kernel = 5, pad = 5, stride = 1
LOG(paddle_mobile::kLOG_INFO) << "int8, kernel=5, pad=5, stride=1";
paddle_mobile::TestConvOp<int8_t, int32_t, 5, 5, 1>(
in_channels, in_height, in_width, out_channels, groups);
return 0;
}
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册