提交 b7e92db8 编写于 作者: H hjchen2

Optimize: fuse quantize and pad op

上级 b680fc96
...@@ -233,3 +233,6 @@ LOAD_OP1(quantize, CPU); ...@@ -233,3 +233,6 @@ LOAD_OP1(quantize, CPU);
#ifdef DEQUANT_OP #ifdef DEQUANT_OP
LOAD_OP1(dequantize, CPU); LOAD_OP1(dequantize, CPU);
#endif #endif
#ifdef PAD_OP
LOAD_OP1(pad, CPU);
#endif
...@@ -22,7 +22,7 @@ namespace operators { ...@@ -22,7 +22,7 @@ namespace operators {
template <typename DeviceType, typename T> template <typename DeviceType, typename T>
void DequantizeOp<DeviceType, T>::InferShape() const { void DequantizeOp<DeviceType, T>::InferShape() const {
const auto& input_dims = this->param_.input_->dims(); const auto& input_dims = this->param_.input_->dims();
this->param_.out_->Resize(input_dims); this->param_.output_->Resize(input_dims);
} }
} // namespace operators } // namespace operators
......
...@@ -15,6 +15,7 @@ limitations under the License. */ ...@@ -15,6 +15,7 @@ limitations under the License. */
#ifdef CONV_OP #ifdef CONV_OP
#include "operators/kernel/conv_kernel.h" #include "operators/kernel/conv_kernel.h"
#include <iostream>
#include "operators/kernel/central-arm-func/conv_arm_func.h" #include "operators/kernel/central-arm-func/conv_arm_func.h"
namespace paddle_mobile { namespace paddle_mobile {
...@@ -22,8 +23,15 @@ namespace operators { ...@@ -22,8 +23,15 @@ namespace operators {
template <> template <>
bool ConvKernel<CPU, float>::Init(ConvParam<CPU> *param) { bool ConvKernel<CPU, float>::Init(ConvParam<CPU> *param) {
if (param->Input()->type() == typeid(int8_t)) { if (param->Filter()->type() == typeid(int8_t)) {
param->ExecMode() = ConvParam<CPU>::EXEC_GEMM_INT8; if (param->Groups() == param->Input()->dims()[1] &&
param->Input()->dims()[1] == param->Output()->dims()[1] &&
param->Filter()->dims()[2] == param->Filter()->dims()[3] &&
param->Filter()->dims()[2] == 3) {
param->ExecMode() = ConvParam<CPU>::EXEC_DEPTHWISE3x3_INT8;
} else {
param->ExecMode() = ConvParam<CPU>::EXEC_GEMM_INT8;
}
} else { } else {
if (param->Groups() == param->Input()->dims()[1] && if (param->Groups() == param->Input()->dims()[1] &&
param->Input()->dims()[1] == param->Output()->dims()[1] && param->Input()->dims()[1] == param->Output()->dims()[1] &&
...@@ -35,6 +43,7 @@ bool ConvKernel<CPU, float>::Init(ConvParam<CPU> *param) { ...@@ -35,6 +43,7 @@ bool ConvKernel<CPU, float>::Init(ConvParam<CPU> *param) {
param->Filter()->dims()[2] == param->Filter()->dims()[3] && param->Filter()->dims()[2] == param->Filter()->dims()[3] &&
param->Filter()->dims()[2] == 3) { param->Filter()->dims()[2] == 3) {
param->ExecMode() = ConvParam<CPU>::EXEC_DEPTHWISE3x3_FLOAT; param->ExecMode() = ConvParam<CPU>::EXEC_DEPTHWISE3x3_FLOAT;
#ifndef __aarch64__
} else if (param->Filter()->dims()[2] == param->Filter()->dims()[3] && } else if (param->Filter()->dims()[2] == param->Filter()->dims()[3] &&
param->Strides()[0] == param->Strides()[1] && param->Strides()[0] == param->Strides()[1] &&
param->Dilations()[0] == param->Dilations()[1] && param->Dilations()[0] == param->Dilations()[1] &&
...@@ -48,6 +57,7 @@ bool ConvKernel<CPU, float>::Init(ConvParam<CPU> *param) { ...@@ -48,6 +57,7 @@ bool ConvKernel<CPU, float>::Init(ConvParam<CPU> *param) {
operators::math::winograd_transform_weight<8, 3>(*param->Filter(), operators::math::winograd_transform_weight<8, 3>(*param->Filter(),
transformed_weight); transformed_weight);
param->Filter() = transformed_weight; param->Filter() = transformed_weight;
#endif
} else { } else {
param->ExecMode() = ConvParam<CPU>::EXEC_GEMM_FLOAT; param->ExecMode() = ConvParam<CPU>::EXEC_GEMM_FLOAT;
} }
...@@ -60,25 +70,36 @@ void ConvKernel<CPU, float>::Compute(const ConvParam<CPU> &param) { ...@@ -60,25 +70,36 @@ void ConvKernel<CPU, float>::Compute(const ConvParam<CPU> &param) {
switch (param.ExecMode()) { switch (param.ExecMode()) {
case ConvParam<CPU>::EXEC_GEMM_INT8: case ConvParam<CPU>::EXEC_GEMM_INT8:
GemmConv<int8_t, int32_t>(param); GemmConv<int8_t, int32_t>(param);
std::cout << "EXEC_GEMM_INT8" << std::endl;
break;
case ConvParam<CPU>::EXEC_DEPTHWISE3x3_INT8:
DepthwiseConv3x3<int8_t, int32_t>(param);
std::cout << "EXEC_DEPTHWISE3x3_INT8" << std::endl;
break; break;
case ConvParam<CPU>::EXEC_DEPTHWISE3x3S1P1_FLOAT: case ConvParam<CPU>::EXEC_DEPTHWISE3x3S1P1_FLOAT:
math::DepthwiseConv3x3s1p1(param.Input(), param.Filter(), param.Output(), math::DepthwiseConv3x3s1p1(param.Input(), param.Filter(), param.Output(),
nullptr, false); nullptr, false);
std::cout << "EXEC_DEPTHWISE3x3S1P1_FLOAT" << std::endl;
break; break;
case ConvParam<CPU>::EXEC_DEPTHWISE3x3_FLOAT: case ConvParam<CPU>::EXEC_DEPTHWISE3x3_FLOAT:
math::DepthwiseConv3x3(param.Input(), param.Strides(), param.Paddings(), math::DepthwiseConv3x3(param.Input(), param.Strides(), param.Paddings(),
param.Filter(), nullptr, param.Output(), false); param.Filter(), nullptr, param.Output(), false);
std::cout << "EXEC_DEPTHWISE3x3_FLOAT=" << param.Strides()[0]
<< std::endl;
break; break;
case ConvParam<CPU>::EXEC_WINOGRAD3X3_FLOAT: case ConvParam<CPU>::EXEC_WINOGRAD3X3_FLOAT:
WinogradConv3x3<8, 3>(param); WinogradConv3x3<8, 3>(param);
std::cout << "EXEC_WINOGRAD3X3_FLOAT" << std::endl;
break; break;
case ConvParam<CPU>::EXEC_GEMM_FLOAT: case ConvParam<CPU>::EXEC_GEMM_FLOAT:
GemmConv<float, float>(param); GemmConv<float, float>(param);
std::cout << "EXEC_GEMM_FLOAT" << std::endl;
break; break;
default: default:
PADDLE_MOBILE_THROW_EXCEPTION("Invalid convolution execute mode %d", PADDLE_MOBILE_THROW_EXCEPTION("Invalid convolution execute mode %d",
param.ExecMode()); param.ExecMode());
} }
std::cout << "exec here..." << std::endl;
} }
template class ConvKernel<CPU, float>; template class ConvKernel<CPU, float>;
......
...@@ -15,6 +15,7 @@ limitations under the License. */ ...@@ -15,6 +15,7 @@ limitations under the License. */
#ifdef DEQUANT_OP #ifdef DEQUANT_OP
#include "operators/kernel/dequantize_kernel.h" #include "operators/kernel/dequantize_kernel.h"
#include <iostream>
#if defined(__ARM_NEON__) || defined(__ARM_NEON) #if defined(__ARM_NEON__) || defined(__ARM_NEON)
#include <arm_neon.h> #include <arm_neon.h>
...@@ -31,7 +32,7 @@ bool DequantizeKernel<CPU, float>::Init(DequantizeParam<CPU> *param) { ...@@ -31,7 +32,7 @@ bool DequantizeKernel<CPU, float>::Init(DequantizeParam<CPU> *param) {
template <> template <>
void DequantizeKernel<CPU, float>::Compute(const DequantizeParam<CPU> &param) { void DequantizeKernel<CPU, float>::Compute(const DequantizeParam<CPU> &param) {
const Tensor *input = param.input_; const Tensor *input = param.input_;
Tensor *output = param.out_; Tensor *output = param.output_;
float activation_scale = param.activation_scale_->data<float>()[0]; float activation_scale = param.activation_scale_->data<float>()[0];
float weight_scale = param.weight_scale_; float weight_scale = param.weight_scale_;
const int32_t *x = input->data<const int32_t>(); const int32_t *x = input->data<const int32_t>();
......
...@@ -15,6 +15,7 @@ limitations under the License. */ ...@@ -15,6 +15,7 @@ limitations under the License. */
#ifdef ELEMENTWISEADD_OP #ifdef ELEMENTWISEADD_OP
#include "operators/kernel/elementwise_add_kernel.h" #include "operators/kernel/elementwise_add_kernel.h"
#include <iostream>
#include "operators/kernel/central-arm-func/elementwise_add_arm_func.h" #include "operators/kernel/central-arm-func/elementwise_add_arm_func.h"
namespace paddle_mobile { namespace paddle_mobile {
......
...@@ -21,15 +21,15 @@ limitations under the License. */ ...@@ -21,15 +21,15 @@ limitations under the License. */
#include <arm_neon.h> #include <arm_neon.h>
#ifndef __aarch64__ #ifndef __aarch64__
float32_t vmaxvq_f32(float32x4_t r) { inline float32_t vmaxvq_f32(float32x4_t r) {
float32x2_t v = vmax_f32(vget_high_f32(r), vget_low_f32(r)); float32x2_t v = vmax_f32(vget_high_f32(r), vget_low_f32(r));
return vget_lane_f32(vpmax_f32(v, v), 0); return vget_lane_f32(vpmax_f32(v, v), 0);
} }
#endif #endif
int32x4_t vrnd_towards_zero(float32x4_t r) { return vcvtq_s32_f32(r); } inline int32x4_t vrnd_towards_zero(float32x4_t r) { return vcvtq_s32_f32(r); }
int32x4_t vrnd_away_zero(float32x4_t r) { inline int32x4_t vrnd_away_zero(float32x4_t r) {
float32x4_t plus = vdupq_n_f32(0.5); float32x4_t plus = vdupq_n_f32(0.5);
float32x4_t minus = vdupq_n_f32(-0.5); float32x4_t minus = vdupq_n_f32(-0.5);
float32x4_t zero = vdupq_n_f32(0); float32x4_t zero = vdupq_n_f32(0);
...@@ -40,7 +40,7 @@ int32x4_t vrnd_away_zero(float32x4_t r) { ...@@ -40,7 +40,7 @@ int32x4_t vrnd_away_zero(float32x4_t r) {
return ret; return ret;
} }
int32x4_t vrnd_to_even(float32x4_t r) { inline int32x4_t vrnd_to_even(float32x4_t r) {
#if 0 #if 0
int32x4_t ret; int32x4_t ret;
float value[4]; float value[4];
...@@ -84,7 +84,6 @@ int32x4_t vrnd_to_even(float32x4_t r) { ...@@ -84,7 +84,6 @@ int32x4_t vrnd_to_even(float32x4_t r) {
return rnd; return rnd;
#endif #endif
} }
#endif
namespace paddle_mobile { namespace paddle_mobile {
namespace operators { namespace operators {
...@@ -127,6 +126,7 @@ static float find_abs_max(const Tensor *input) { ...@@ -127,6 +126,7 @@ static float find_abs_max(const Tensor *input) {
return max_abs; return max_abs;
} }
#ifdef __aarch64__
static void quantize_round_to_even(const Tensor *input, const float scale, static void quantize_round_to_even(const Tensor *input, const float scale,
Tensor *output) { Tensor *output) {
const float *x = input->data<const float>(); const float *x = input->data<const float>();
...@@ -188,7 +188,7 @@ static void quantize_round_to_zero(const Tensor *input, const float scale, ...@@ -188,7 +188,7 @@ static void quantize_round_to_zero(const Tensor *input, const float scale,
const float *x = input->data<const float>(); const float *x = input->data<const float>();
int8_t *y = output->mutable_data<int8_t>(); int8_t *y = output->mutable_data<int8_t>();
size_t size = input->numel(); size_t size = input->numel();
#ifdef defined(__ARM_NEON__) || defined(__ARM_NEON) #if defined(__ARM_NEON__) || defined(__ARM_NEON)
size_t loop = size >> 4; size_t loop = size >> 4;
size_t remain = size & 0xF; size_t remain = size & 0xF;
...@@ -224,7 +224,7 @@ static void quantize_round_to_zero(const Tensor *input, const float scale, ...@@ -224,7 +224,7 @@ static void quantize_round_to_zero(const Tensor *input, const float scale,
y += (loop << 4); y += (loop << 4);
#endif #endif
for (size_t i = 0; i < size; ++i) { for (size_t i = 0; i < size; ++i) {
y[i] = trunc(x[i] * scale); y[i] = static_cast<int8_t>(x[i] * scale);
} }
} }
...@@ -272,6 +272,464 @@ static void quantize_round_to_nearest(const Tensor *input, const float scale, ...@@ -272,6 +272,464 @@ static void quantize_round_to_nearest(const Tensor *input, const float scale,
y[i] = round(x[i] * scale); y[i] = round(x[i] * scale);
} }
} }
#else // __aarch64__
static void quantize_round_to_even(const Tensor *input, const float scale,
const std::vector<int> &paddings,
const int8_t padding_val, Tensor *output) {}
static void quantize_round_to_nearest(const Tensor *input, const float scale,
const std::vector<int> &paddings,
const int8_t padding_val,
Tensor *output) {}
static void quantize_round_to_zero(const Tensor *input, const float scale,
const std::vector<int> &paddings,
const int8_t padding_val, Tensor *output) {
int channels = input->dims()[1];
int input_h = input->dims()[2];
int input_w = input->dims()[3];
int output_h = output->dims()[2];
int output_w = output->dims()[3];
int input_spatial_size = input_h * input_w;
int output_spatial_size = output_h * output_w;
const float *x = input->data<float>();
int8_t *y = output->mutable_data<int8_t>();
// valid area start
int start = paddings[0] * output_w + paddings[1];
for (int batch = 0; batch < input->dims()[0]; ++batch) {
for (int c = 0; c < channels - 3; c += 4) {
const float *x0 = x + c * input_spatial_size;
const float *x1 = x0 + input_spatial_size;
const float *x2 = x1 + input_spatial_size;
const float *x3 = x2 + input_spatial_size;
size_t offset = c * output_spatial_size;
for (int h = 0; h < 2; ++h) {
int8_t *y0 =
y + offset + h * ((input_h + paddings[0]) * output_w - paddings[1]);
int8_t *y1 = y0 + output_spatial_size;
int8_t *y2 = y1 + output_spatial_size;
int8_t *y3 = y2 + output_spatial_size;
int loop = start >> 4;
int remain = start & 0xFFF0;
asm volatile(
"vdup.s8 q0, %[val] \n"
"cmp %[loop], #0 \n"
"ble start_remain_%= \n"
"store_16w_%=: \n"
"vst1.32 {q0}, [%[y0]]! \n"
"vst1.32 {q0}, [%[y1]]! \n"
"vst1.32 {q0}, [%[y2]]! \n"
"vst1.32 {q0}, [%[y3]]! \n"
"subs %[loop], #1 \n"
"bne store_16w_%= \n"
"start_remain_%=: \n"
"cmp %[remain], #8 \n"
"blt store_4w_%= \n"
"vst1.32 {d0}, [%[y0]]! \n"
"vst1.32 {d0}, [%[y1]]! \n"
"vst1.32 {d0}, [%[y2]]! \n"
"vst1.32 {d0}, [%[y3]]! \n"
"sub %[remain], #8 \n"
"store_4w_%=: \n"
"cmp %[remain], #4 \n"
"blt store_2w_%= \n"
"vst1.32 {d0[0]}, [%[y0]]! \n"
"vst1.32 {d0[0]}, [%[y1]]! \n"
"vst1.32 {d0[0]}, [%[y2]]! \n"
"vst1.32 {d0[0]}, [%[y3]]! \n"
"sub %[remain], #4 \n"
"store_2w_%=: \n"
"cmp %[remain], #4 \n"
"blt store_1w_%= \n"
"vst1.16 {d0[0]}, [%[y0]]! \n"
"vst1.16 {d0[0]}, [%[y1]]! \n"
"vst1.16 {d0[0]}, [%[y2]]! \n"
"vst1.16 {d0[0]}, [%[y3]]! \n"
"sub %[remain], #2 \n"
"store_1w_%=: \n"
"cmp %[remain], #1 \n"
"blt end_%= \n"
"vst1.8 {d0[0]}, [%[y0]]! \n"
"vst1.8 {d0[0]}, [%[y1]]! \n"
"vst1.8 {d0[0]}, [%[y2]]! \n"
"vst1.8 {d0[0]}, [%[y3]]! \n"
"end_%=: \n"
: [y0] "+r"(y0), [y1] "+r"(y1), [y2] "+r"(y2), [y3] "+r"(y3),
[loop] "+r"(loop), [remain] "+r"(remain)
: [val] "r"(padding_val)
: "cc", "memory", "q0");
}
// quantize valid area
int8_t *y0 = y + offset + start;
int8_t *y1 = y0 + output_spatial_size;
int8_t *y2 = y1 + output_spatial_size;
int8_t *y3 = y2 + output_spatial_size;
for (int h = 0; h < input_h; ++h) {
int loop = input_w >> 4;
int remain = input_w & 0xFFF0;
int pad_loop = paddings[1] >> 1;
int pad_remain = paddings[1] & 0xFFFE;
asm volatile(
"vdup.f32 q0, %[scale] \n"
"cmp %[loop], #0 \n"
"ble quantize_remain_%= \n"
"loop_quantize_%=: \n"
"vld1.32 {q1, q2}, [%[x0]]! \n"
"vld1.32 {q3, q4}, [%[x1]]! \n"
"vld1.32 {q5, q6}, [%[x2]]! \n"
"vld1.32 {q7, q8}, [%[x3]]! \n"
"vmul.f32 q1, q1, q0 \n"
"vmul.f32 q2, q2, q0 \n"
"vmul.f32 q3, q3, q0 \n"
"vmul.f32 q4, q4, q0 \n"
"vmul.f32 q5, q5, q0 \n"
"vmul.f32 q6, q6, q0 \n"
"vmul.f32 q7, q7, q0 \n"
"vmul.f32 q8, q8, q0 \n"
"vcvt.s32.f32 q1, q1 \n"
"vcvt.s32.f32 q2, q2 \n"
"vcvt.s32.f32 q3, q3 \n"
"vcvt.s32.f32 q4, q4 \n"
"vcvt.s32.f32 q5, q5 \n"
"vcvt.s32.f32 q6, q6 \n"
"vcvt.s32.f32 q7, q7 \n"
"vcvt.s32.f32 q8, q8 \n"
"vmovn.s32 d2, q1 \n"
"vmovn.s32 d3, q2 \n"
"vmovn.s32 d4, q3 \n"
"vmovn.s32 d5, q4 \n"
"vmovn.s32 d6, q5 \n"
"vmovn.s32 d7, q6 \n"
"vmovn.s32 d8, q7 \n"
"vmovn.s32 d9, q8 \n"
"vmovn.s16 d18, q1 \n"
"vmovn.s16 d20, q2 \n"
"vmovn.s16 d22, q3 \n"
"vmovn.s16 d24, q4 \n"
"vld1.32 {q1, q2}, [%[x0]]! \n"
"vld1.32 {q3, q4}, [%[x1]]! \n"
"vld1.32 {q5, q6}, [%[x2]]! \n"
"vld1.32 {q7, q8}, [%[x3]]! \n"
"vmul.f32 q1, q1, q0 \n"
"vmul.f32 q2, q2, q0 \n"
"vmul.f32 q3, q3, q0 \n"
"vmul.f32 q4, q4, q0 \n"
"vmul.f32 q5, q5, q0 \n"
"vmul.f32 q6, q6, q0 \n"
"vmul.f32 q7, q7, q0 \n"
"vmul.f32 q8, q8, q0 \n"
"vcvt.s32.f32 q1, q1 \n"
"vcvt.s32.f32 q2, q2 \n"
"vcvt.s32.f32 q3, q3 \n"
"vcvt.s32.f32 q4, q4 \n"
"vcvt.s32.f32 q5, q5 \n"
"vcvt.s32.f32 q6, q6 \n"
"vcvt.s32.f32 q7, q7 \n"
"vcvt.s32.f32 q8, q8 \n"
"vmovn.s32 d2, q1 \n"
"vmovn.s32 d3, q2 \n"
"vmovn.s32 d4, q3 \n"
"vmovn.s32 d5, q4 \n"
"vmovn.s32 d6, q5 \n"
"vmovn.s32 d7, q6 \n"
"vmovn.s32 d8, q7 \n"
"vmovn.s32 d9, q8 \n"
"vmovn.s16 d19, q1 \n"
"vmovn.s16 d21, q2 \n"
"vmovn.s16 d23, q3 \n"
"vmovn.s16 d25, q4 \n"
"vst1.32 {q9}, [%[y0]] \n"
"vst1.32 {q10}, [%[y0]] \n"
"vst1.32 {q11}, [%[y0]] \n"
"vst1.32 {q12}, [%[y0]] \n"
"subs %[loop], #1 \n"
"bne loop_quantize_%= \n"
"quantize_remain_%=: \n"
"cmp %[remain], #0 \n"
"ble end_%= \n"
"vld1.32 {q1, q2}, [%[x0]] \n"
"vld1.32 {q3, q4}, [%[x1]] \n"
"vld1.32 {q5, q6}, [%[x2]] \n"
"vld1.32 {q7, q8}, [%[x3]] \n"
"vmul.f32 q1, q1, q0 \n"
"vmul.f32 q2, q2, q0 \n"
"vmul.f32 q3, q3, q0 \n"
"vmul.f32 q4, q4, q0 \n"
"vmul.f32 q5, q5, q0 \n"
"vmul.f32 q6, q6, q0 \n"
"vmul.f32 q7, q7, q0 \n"
"vmul.f32 q8, q8, q0 \n"
"vcvt.s32.f32 q1, q1 \n"
"vcvt.s32.f32 q2, q2 \n"
"vcvt.s32.f32 q3, q3 \n"
"vcvt.s32.f32 q4, q4 \n"
"vcvt.s32.f32 q5, q5 \n"
"vcvt.s32.f32 q6, q6 \n"
"vcvt.s32.f32 q7, q7 \n"
"vcvt.s32.f32 q8, q8 \n"
"vmovn.s32 d2, q1 \n"
"vmovn.s32 d3, q2 \n"
"vmovn.s32 d4, q3 \n"
"vmovn.s32 d5, q4 \n"
"vmovn.s32 d6, q5 \n"
"vmovn.s32 d7, q6 \n"
"vmovn.s32 d8, q7 \n"
"vmovn.s32 d9, q8 \n"
"vmovn.s16 d18, q1 \n"
"vmovn.s16 d20, q2 \n"
"vmovn.s16 d22, q3 \n"
"vmovn.s16 d24, q4 \n"
"vld1.32 {q1, q2}, [%[x0]]! \n"
"vld1.32 {q3, q4}, [%[x1]]! \n"
"vld1.32 {q5, q6}, [%[x2]]! \n"
"vld1.32 {q7, q8}, [%[x3]]! \n"
"vmul.f32 q1, q1, q0 \n"
"vmul.f32 q2, q2, q0 \n"
"vmul.f32 q3, q3, q0 \n"
"vmul.f32 q4, q4, q0 \n"
"vmul.f32 q5, q5, q0 \n"
"vmul.f32 q6, q6, q0 \n"
"vmul.f32 q7, q7, q0 \n"
"vmul.f32 q8, q8, q0 \n"
"vcvt.s32.f32 q1, q1 \n"
"vcvt.s32.f32 q2, q2 \n"
"vcvt.s32.f32 q3, q3 \n"
"vcvt.s32.f32 q4, q4 \n"
"vcvt.s32.f32 q5, q5 \n"
"vcvt.s32.f32 q6, q6 \n"
"vcvt.s32.f32 q7, q7 \n"
"vcvt.s32.f32 q8, q8 \n"
"vmovn.s32 d2, q1 \n"
"vmovn.s32 d3, q2 \n"
"vmovn.s32 d4, q3 \n"
"vmovn.s32 d5, q4 \n"
"vmovn.s32 d6, q5 \n"
"vmovn.s32 d7, q6 \n"
"vmovn.s32 d8, q7 \n"
"vmovn.s32 d9, q8 \n"
"vmovn.s16 d19, q1 \n"
"vmovn.s16 d21, q2 \n"
"vmovn.s16 d23, q3 \n"
"vmovn.s16 d25, q4 \n"
"cmp %[remain], #8 \n"
"blt store_4w_%= \n"
"vst1.32 {d18}, [%[y0]]! \n"
"vst1.32 {d20}, [%[y1]]! \n"
"vst1.32 {d22}, [%[y2]]! \n"
"vst1.32 {d24}, [%[y3]]! \n"
"vmov.32 d18, d19 \n"
"vmov.32 d20, d21 \n"
"vmov.32 d22, d23 \n"
"vmov.32 d24, d25 \n"
"sub %[remain], #8 \n"
"store_4w_%=: \n"
"cmp %[remain], #4 \n"
"blt store_2w_%= \n"
"vst1.32 {d18[0]}, [%[y0]]! \n"
"vst1.32 {d20[0]}, [%[y1]]! \n"
"vst1.32 {d22[0]}, [%[y2]]! \n"
"vst1.32 {d24[0]}, [%[y3]]! \n"
"vext.32 d18, d18, d18, #1 \n"
"vext.32 d20, d20, d20, #1 \n"
"vext.32 d22, d22, d22, #1 \n"
"vext.32 d24, d24, d24, #1 \n"
"sub %[remain], #4 \n"
"store_2w_%=: \n"
"cmp %[remain], #2 \n"
"blt store_1w_%= \n"
"vst1.16 {d18[0]}, [%[y0]]! \n"
"vst1.16 {d20[0]}, [%[y1]]! \n"
"vst1.16 {d22[0]}, [%[y2]]! \n"
"vst1.16 {d24[0]}, [%[y3]]! \n"
"vext.16 d18, d18, d18, #1 \n"
"vext.16 d20, d20, d20, #1 \n"
"vext.16 d22, d22, d22, #1 \n"
"vext.16 d24, d24, d24, #1 \n"
"sub %[remain], #2 \n"
"store_1w_%=:"
"cmp %[remain], #1 \n"
"blt end_%= \n"
"vst1.8 {d18[0]}, [%[y0]]! \n"
"vst1.8 {d20[0]}, [%[y1]]! \n"
"vst1.8 {d22[0]}, [%[y2]]! \n"
"vst1.8 {d24[0]}, [%[y3]]! \n"
"end_%=: \n"
: [x0] "+r"(x0), [x1] "+r"(x1), [x2] "+r"(x2), [x3] "+r"(x3),
[y0] "+r"(y0), [y1] "+r"(y1), [y2] "+r"(y2), [y3] "+r"(y3),
[loop] "+r"(loop), [remain] "+r"(remain)
: [scale] "r"(scale)
: "memory", "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7", "q8",
"q9", "q10", "q11", "q12");
asm volatile(
"vdup.s8 d0, %[val] \n"
"cmp %[pad_loop], #0 \n"
"ble store_pad_2w_%= \n"
"loop_pad_4w_%=: \n"
"vst1.32 {d0[0]}, [%[y0]]! \n"
"vst1.32 {d0[0]}, [%[y1]]! \n"
"vst1.32 {d0[0]}, [%[y2]]! \n"
"vst1.32 {d0[0]}, [%[y3]]! \n"
"subs %[pad_loop], #1 \n"
"bne loop_pad_4w_%= \n"
"store_pad_2w_%=: \n"
"cmp %[pad_remain], #2 \n"
"ble store_pad_1w_%= \n"
"vst1.16 {d0[0]}, [%[y0]]! \n"
"vst1.16 {d0[0]}, [%[y1]]! \n"
"vst1.16 {d0[0]}, [%[y2]]! \n"
"vst1.16 {d0[0]}, [%[y3]]! \n"
"sub %[pad_remain], #2 \n"
"store_pad_1w_%=: \n"
"cmp %[pad_remain], #1 \n"
"ble end_%= \n"
"vst1.8 {d0[0]}, [%[y0]]! \n"
"vst1.8 {d0[0]}, [%[y1]]! \n"
"vst1.8 {d0[0]}, [%[y2]]! \n"
"vst1.8 {d0[0]}, [%[y3]]! \n"
"end_%=: \n"
: [y0] "+r"(y0), [y1] "+r"(y1), [y2] "+r"(y2), [y3] "+r"(y3),
[pad_loop] "+r"(pad_loop), [pad_remain] "+r"(pad_remain)
: [val] "r"(padding_val)
: "memory", "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7", "q8",
"q9", "q10", "q11", "q12");
x0 += remain;
x1 += remain;
x2 += remain;
x3 += remain;
}
}
for (int c = (channels & 0xFFFC); c < channels; ++c) {
const float *x0 = x + c * input_spatial_size;
int8_t *y0 = y + c * output_spatial_size;
for (int h = 0; h < paddings[0]; ++h) {
int loop = input_w >> 4;
int remain = input_w & 0xFFF0;
int pad_loop = paddings[1] >> 1;
int pad_remain = paddings[1] & 0xFFFE;
asm volatile(
"vdup.f32 q0, %[scale] \n"
"cmp %[loop], #0 \n"
"ble quantize_remain_%= \n"
"loop_quantize_%=: \n"
"vld1.32 {q1, q2}, [%[x0]]! \n"
"vmul.f32 q1, q1, q0 \n"
"vmul.f32 q2, q2, q0 \n"
"vcvt.s32.f32 q1, q1 \n"
"vcvt.s32.f32 q2, q2 \n"
"vmovn.s32 d2, q1 \n"
"vmovn.s32 d3, q2 \n"
"vmovn.s16 d18, q1 \n"
"vld1.32 {q1, q2}, [%[x0]]! \n"
"vmul.f32 q1, q1, q0 \n"
"vmul.f32 q2, q2, q0 \n"
"vcvt.s32.f32 q1, q1 \n"
"vcvt.s32.f32 q2, q2 \n"
"vmovn.s32 d2, q1 \n"
"vmovn.s32 d3, q2 \n"
"vmovn.s16 d19, q1 \n"
"vst1.32 {q9}, [%[y0]] \n"
"subs %[loop], #1 \n"
"bne loop_quantize_%= \n"
"quantize_remain_%=: \n"
"cmp %[remain], #0 \n"
"ble start_pad_%= \n"
"vld1.32 {q1, q2}, [%[x0]] \n"
"vmul.f32 q1, q1, q0 \n"
"vmul.f32 q2, q2, q0 \n"
"vcvt.s32.f32 q1, q1 \n"
"vcvt.s32.f32 q2, q2 \n"
"vmovn.s32 d2, q1 \n"
"vmovn.s32 d3, q2 \n"
"vmovn.s16 d18, q1 \n"
"vld1.32 {q1, q2}, [%[x0]]! \n"
"vmul.f32 q1, q1, q0 \n"
"vmul.f32 q2, q2, q0 \n"
"vcvt.s32.f32 q1, q1 \n"
"vcvt.s32.f32 q2, q2 \n"
"vmovn.s32 d2, q1 \n"
"vmovn.s32 d3, q2 \n"
"vmovn.s16 d19, q1 \n"
"cmp %[remain], #8 \n"
"blt store_4w_%= \n"
"vst1.32 {d18}, [%[y0]]! \n"
"vmov.32 d18, d19 \n"
"sub %[remain], #8 \n"
"store_4w_%=: \n"
"cmp %[remain], #4 \n"
"blt store_2w_%= \n"
"vst1.32 {d18[0]}, [%[y0]]! \n"
"vext.32 d18, d18, d18, #1 \n"
"sub %[remain], #4 \n"
"store_2w_%=: \n"
"cmp %[remain], #2 \n"
"blt store_1w_%= \n"
"vst1.16 {d18[0]}, [%[y0]]! \n"
"vext.16 d18, d18, d18, #1 \n"
"sub %[remain], #2 \n"
"store_1w_%=:"
"cmp %[remain], #1 \n"
"blt start_pad_%= \n"
"vst1.8 {d18[0]}, [%[y0]]! \n"
"start_pad_%=: \n"
"vdup.s8 d0, %[val] \n"
"cmp %[pad_loop], #0 \n"
"ble pad_remain_%= \n"
"loop_pad_4w_%=: \n"
"vst1.32 {d0[0]}, [%[y0]]! \n"
"subs %[pad_loop], #1 \n"
"bne loop_pad_4w_%= \n"
"pad_remain_%=: \n"
"cmp %[pad_remain], #2 \n"
"ble store_pad_1w_%= \n"
"vst1.16 {d0[0]}, [%[y0]]! \n"
"sub %[pad_remain], #2 \n"
"store_pad_1w_%=: \n"
"cmp %[pad_remain], #1 \n"
"ble end_%= \n"
"vst1.8 {d0[0]}, [%[y0]]! \n"
"end_%=: \n"
: [x0] "+r"(x0), [y0] "+r"(y0), [loop] "+r"(loop),
[remain] "+r"(remain), [pad_loop] "+r"(pad_loop),
[pad_remain] "+r"(pad_remain)
: [scale] "r"(scale), [val] "r"(padding_val)
: "memory", "q0", "q1", "q2", "q9");
x0 += remain;
}
}
}
}
#endif // __aarch64__
#endif // ARM_NEON
template <> template <>
bool QuantizeKernel<CPU, float>::Init(QuantizeParam<CPU> *param) { bool QuantizeKernel<CPU, float>::Init(QuantizeParam<CPU> *param) {
...@@ -280,10 +738,10 @@ bool QuantizeKernel<CPU, float>::Init(QuantizeParam<CPU> *param) { ...@@ -280,10 +738,10 @@ bool QuantizeKernel<CPU, float>::Init(QuantizeParam<CPU> *param) {
template <> template <>
void QuantizeKernel<CPU, float>::Compute(const QuantizeParam<CPU> &param) { void QuantizeKernel<CPU, float>::Compute(const QuantizeParam<CPU> &param) {
float max_abs = 0.f;
const Tensor *input = param.input_; const Tensor *input = param.input_;
Tensor *output = param.out_; Tensor *output = param.output_;
Tensor *output_scale = param.online_scale_; Tensor *output_scale = param.online_scale_;
float max_abs = 0.f;
if (param.is_static_) { if (param.is_static_) {
max_abs = param.static_scale_; max_abs = param.static_scale_;
} else { } else {
...@@ -293,15 +751,19 @@ void QuantizeKernel<CPU, float>::Compute(const QuantizeParam<CPU> &param) { ...@@ -293,15 +751,19 @@ void QuantizeKernel<CPU, float>::Compute(const QuantizeParam<CPU> &param) {
// only support int8 currently // only support int8 currently
float scale = 127 / max_abs; float scale = 127 / max_abs;
param.online_scale_->mutable_data<float>()[0] = max_abs; param.online_scale_->mutable_data<float>()[0] = max_abs;
// const auto &paddings = param.paddings_;
std::vector<int> paddings = {0, 0};
// const auto padding_val = param.padding_val_;
int8_t padding_val = 127;
switch (param.round_type_) { switch (param.round_type_) {
case ROUND_NEAREST_TO_EVEN: case ROUND_NEAREST_TO_EVEN:
quantize_round_to_even(input, scale, output); quantize_round_to_even(input, scale, paddings, padding_val, output);
break; break;
case ROUND_NEAREST_TOWARDS_ZERO: case ROUND_NEAREST_TOWARDS_ZERO:
quantize_round_to_zero(input, scale, output); quantize_round_to_zero(input, scale, paddings, padding_val, output);
break; break;
case ROUND_NEAREST_AWAY_ZERO: case ROUND_NEAREST_AWAY_ZERO:
quantize_round_to_nearest(input, scale, output); quantize_round_to_nearest(input, scale, paddings, padding_val, output);
break; break;
default: default:
LOG(kLOG_ERROR) << "round type is not supported."; LOG(kLOG_ERROR) << "round type is not supported.";
......
...@@ -17,7 +17,7 @@ limitations under the License. */ ...@@ -17,7 +17,7 @@ limitations under the License. */
#include <vector> #include <vector>
#include "operators/math/conv_func.h" #include "operators/math/conv_func.h"
#include "operators/math/depthwise_conv_3x3.h" #include "operators/math/depthwise_conv3x3.h"
#include "operators/math/im2col.h" #include "operators/math/im2col.h"
#include "operators/math/math_function.h" #include "operators/math/math_function.h"
#include "operators/math/vol2col.h" #include "operators/math/vol2col.h"
......
...@@ -17,7 +17,7 @@ limitations under the License. */ ...@@ -17,7 +17,7 @@ limitations under the License. */
#pragma once #pragma once
#include <vector> #include <vector>
#include "operators/math/depthwise_conv_3x3.h" #include "operators/math/depthwise_conv3x3.h"
#include "operators/math/im2col.h" #include "operators/math/im2col.h"
#include "operators/math/math_function.h" #include "operators/math/math_function.h"
#include "operators/math/vol2col.h" #include "operators/math/vol2col.h"
......
...@@ -17,7 +17,7 @@ limitations under the License. */ ...@@ -17,7 +17,7 @@ limitations under the License. */
#pragma once #pragma once
#include <vector> #include <vector>
#include "operators/math/conv_func.h" #include "operators/math/conv_func.h"
#include "operators/math/depthwise_conv_3x3.h" #include "operators/math/depthwise_conv3x3.h"
#include "operators/math/im2col.h" #include "operators/math/im2col.h"
#include "operators/math/math_function.h" #include "operators/math/math_function.h"
#include "operators/math/pad.h" #include "operators/math/pad.h"
...@@ -39,10 +39,7 @@ inline void GemmConv(const ConvParam<CPU> &param) { ...@@ -39,10 +39,7 @@ inline void GemmConv(const ConvParam<CPU> &param) {
const std::vector<int> paddings = param.Paddings(); const std::vector<int> paddings = param.Paddings();
const std::vector<int> dilations = param.Dilations(); const std::vector<int> dilations = param.Dilations();
const int batch_size = static_cast<int>(input->dims()[0]);
std::vector<int64_t> filter_shape_vec(framework::vectorize(filter.dims())); std::vector<int64_t> filter_shape_vec(framework::vectorize(filter.dims()));
std::vector<int64_t> output_shape_vec(framework::vectorize(output->dims())); std::vector<int64_t> output_shape_vec(framework::vectorize(output->dims()));
size_t data_dim = filter_shape_vec.size() - 2; size_t data_dim = filter_shape_vec.size() - 2;
std::vector<int64_t> col_shape_vec(1 + 2 * data_dim); std::vector<int64_t> col_shape_vec(1 + 2 * data_dim);
...@@ -83,6 +80,7 @@ inline void GemmConv(const ConvParam<CPU> &param) { ...@@ -83,6 +80,7 @@ inline void GemmConv(const ConvParam<CPU> &param) {
math::Vol2ColFunctor<CPU, Itype> vol2col; math::Vol2ColFunctor<CPU, Itype> vol2col;
math::Im2ColFunctor<math::ColFormat::kCFO, CPU, Itype> im2col; math::Im2ColFunctor<math::ColFormat::kCFO, CPU, Itype> im2col;
const int batch_size = static_cast<int>(input->dims()[0]);
for (int i = 0; i < batch_size; i++) { for (int i = 0; i < batch_size; i++) {
Tensor in_batch = input->Slice(i, i + 1).Resize(input_shape); Tensor in_batch = input->Slice(i, i + 1).Resize(input_shape);
Tensor out_batch = output->Slice(i, i + 1).Resize(output_matrix_shape); Tensor out_batch = output->Slice(i, i + 1).Resize(output_matrix_shape);
...@@ -126,7 +124,6 @@ inline void WinogradConv3x3(const ConvParam<CPU> &param) { ...@@ -126,7 +124,6 @@ inline void WinogradConv3x3(const ConvParam<CPU> &param) {
int batch_size = input->dims()[0]; int batch_size = input->dims()[0];
int groups = param.Groups(); int groups = param.Groups();
const std::vector<int> &paddings = param.Paddings(); const std::vector<int> &paddings = param.Paddings();
math::PadFunctor<CPU, float> pad;
auto winograd_pad = [&](int width, int pad) { auto winograd_pad = [&](int width, int pad) {
int output_tile = tile - kernel + 1; int output_tile = tile - kernel + 1;
...@@ -136,6 +133,7 @@ inline void WinogradConv3x3(const ConvParam<CPU> &param) { ...@@ -136,6 +133,7 @@ inline void WinogradConv3x3(const ConvParam<CPU> &param) {
return pad_width + tile - width; return pad_width + tile - width;
}; };
math::PadFunctor<CPU, float> pad;
Tensor input_pad; Tensor input_pad;
framework::Tensor transformed_input; framework::Tensor transformed_input;
for (int i = 0; i < batch_size; ++i) { for (int i = 0; i < batch_size; ++i) {
...@@ -155,15 +153,49 @@ inline void WinogradConv3x3(const ConvParam<CPU> &param) { ...@@ -155,15 +153,49 @@ inline void WinogradConv3x3(const ConvParam<CPU> &param) {
} else { } else {
input_pad = in_batch; input_pad = in_batch;
} }
#if __aarch64__
// TODO(hjchen2)
#else
// tile input and transform // tile input and transform
math::winograd_transform_input<tile, kernel>(input_pad, &transformed_input); math::winograd_transform_input<tile, kernel>(input_pad, &transformed_input);
// caculate output // caculate output
math::winograd_transform_output<tile, kernel>(transformed_input, *filter, math::winograd_transform_output<tile, kernel>(transformed_input, *filter,
output); output);
#endif }
}
template <typename Itype, typename Otype>
inline void DepthwiseConv3x3(const ConvParam<CPU> &param) {
const Tensor *input = param.Input();
const Tensor *filter = param.Filter();
Tensor *output = param.Output();
output->mutable_data<Otype>();
const std::vector<int> &paddings = param.Paddings();
const std::vector<int> &strides = param.Strides();
const int batch_size = static_cast<int>(input->dims()[0]);
Tensor input_pad;
math::PadFunctor<CPU, Itype> pad;
for (int i = 0; i < batch_size; i++) {
Tensor in_batch = input->Slice(i, i + 1);
Tensor out_batch = output->Slice(i, i + 1);
// if (paddings[0] || paddings[1]) {
// framework::DDim pad_shape = in_batch.dims();
// pad_shape[2] += 2 * paddings[0];
// pad_shape[3] += 2 * paddings[1];
// input_pad.mutable_data<float>(pad_shape);
// pad(in_batch, paddings[0], paddings[0], paddings[1], paddings[1],
// &input_pad);
// } else {
// input_pad = in_batch;
// }
// math::DepthwiseConv3x3s1<Itype, Otype>(input_pad, *filter,
// &out_batch);
if (strides[0] == 1) {
math::DepthwiseConv3x3s1<Itype, Otype>(in_batch, *filter, &out_batch);
} else if (strides[0] == 2) {
math::DepthwiseConv3x3s2<Itype, Otype>(in_batch, *filter, &out_batch);
} else {
// math::DepthwiseConv3x3<Itype, Otype>(in_batch, *filter,
// &out_batch);
}
} }
} }
......
...@@ -17,7 +17,7 @@ limitations under the License. */ ...@@ -17,7 +17,7 @@ limitations under the License. */
#pragma once #pragma once
#include <vector> #include <vector>
#include "operators/math/depthwise_conv_3x3.h" #include "operators/math/depthwise_conv3x3.h"
#include "operators/math/im2col.h" #include "operators/math/im2col.h"
#include "operators/math/math_function.h" #include "operators/math/math_function.h"
#include "operators/math/vol2col.h" #include "operators/math/vol2col.h"
......
...@@ -16,13 +16,15 @@ limitations under the License. */ ...@@ -16,13 +16,15 @@ limitations under the License. */
#pragma once #pragma once
#include <vector> #include <vector>
#include "operators/math/depthwise_conv_3x3.h" #include "operators/math/depthwise_conv3x3.h"
#include "operators/math/im2col.h" #include "operators/math/im2col.h"
#include "operators/math/math_function.h" #include "operators/math/math_function.h"
#include "operators/math/vol2col.h" #include "operators/math/vol2col.h"
#include "operators/op_param.h" #include "operators/op_param.h"
namespace paddle_mobile { namespace paddle_mobile {
namespace operators { namespace operators {
void ConvBNReluBasic(const FusionConvBNReluParam<CPU> &param) { void ConvBNReluBasic(const FusionConvBNReluParam<CPU> &param) {
const Tensor *input = param.Input(); const Tensor *input = param.Input();
Tensor filter = *param.Filter(); Tensor filter = *param.Filter();
......
...@@ -15,10 +15,9 @@ limitations under the License. */ ...@@ -15,10 +15,9 @@ limitations under the License. */
#ifdef DEPTHWISECONV_OP #ifdef DEPTHWISECONV_OP
#pragma once #pragma once
#include <operators/math/depthwise_conv_3x3.h>
#include <vector> #include <vector>
#include "operators/kernel/central-arm-func/conv_arm_func.h" #include "operators/kernel/central-arm-func/conv_arm_func.h"
#include "operators/math/depthwise_conv3x3.h"
#include "operators/op_param.h" #include "operators/op_param.h"
namespace paddle_mobile { namespace paddle_mobile {
......
...@@ -16,13 +16,15 @@ limitations under the License. */ ...@@ -16,13 +16,15 @@ limitations under the License. */
#pragma once #pragma once
#include <vector> #include <vector>
#include "operators/math/depthwise_conv_3x3.h" #include "operators/math/depthwise_conv3x3.h"
#include "operators/math/im2col.h" #include "operators/math/im2col.h"
#include "operators/math/math_function.h" #include "operators/math/math_function.h"
#include "operators/math/vol2col.h" #include "operators/math/vol2col.h"
#include "operators/op_param.h" #include "operators/op_param.h"
namespace paddle_mobile { namespace paddle_mobile {
namespace operators { namespace operators {
void DWConvBNReluBasic(const FusionDWConvBNReluParam<CPU> &param) { void DWConvBNReluBasic(const FusionDWConvBNReluParam<CPU> &param) {
const Tensor *input = param.Input(); const Tensor *input = param.Input();
Tensor filter = *param.Filter(); Tensor filter = *param.Filter();
......
...@@ -24,7 +24,7 @@ limitations under the License. */ ...@@ -24,7 +24,7 @@ limitations under the License. */
#include "framework/ddim.h" #include "framework/ddim.h"
#include "framework/operator.h" #include "framework/operator.h"
#include "operators/math/conv_func.h" #include "operators/math/conv_func.h"
#include "operators/math/depthwise_conv_3x3.h" #include "operators/math/depthwise_conv3x3.h"
#include "operators/math/im2col.h" #include "operators/math/im2col.h"
#include "operators/math/math_function.h" #include "operators/math/math_function.h"
#include "operators/math/vol2col.h" #include "operators/math/vol2col.h"
......
...@@ -11,18 +11,22 @@ distributed under the License is distributed on an "AS IS" BASIS, ...@@ -11,18 +11,22 @@ distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include "operators/math/depthwise_conv_3x3.h"
#include "operators/math/depthwise_conv3x3.h"
#include <vector>
#if __ARM_NEON #if __ARM_NEON
#include <arm_neon.h> #include <arm_neon.h>
#endif #endif
#include <vector>
namespace paddle_mobile { namespace paddle_mobile {
namespace operators { namespace operators {
namespace math { namespace math {
void DepthwiseConv3x3(const Tensor *input, vector<int> strides,
vector<int> paddings, const Tensor *filter, Tensor *bias, void DepthwiseConv3x3(const framework::Tensor *input,
Tensor *output, bool if_bias) { const std::vector<int> &strides,
const std::vector<int> &paddings,
const framework::Tensor *filter, framework::Tensor *bias,
framework::Tensor *output, bool if_bias) {
const int batch_size = input->dims()[0]; const int batch_size = input->dims()[0];
const int input_height = input->dims()[2]; const int input_height = input->dims()[2];
...@@ -67,12 +71,12 @@ void DepthwiseConv3x3(const Tensor *input, vector<int> strides, ...@@ -67,12 +71,12 @@ void DepthwiseConv3x3(const Tensor *input, vector<int> strides,
for (int pw = 0; pw < output_width; pw++) { for (int pw = 0; pw < output_width; pw++) {
hstart = ph * stride_height - padding_height; hstart = ph * stride_height - padding_height;
wstart = pw * stride_width - padding_width; wstart = pw * stride_width - padding_width;
hend = min(hstart + _kernel_size, input_height + padding_height); hend = std::min(hstart + _kernel_size, input_height + padding_height);
wend = min(wstart + _kernel_size, input_width + padding_width); wend = std::min(wstart + _kernel_size, input_width + padding_width);
hstart = max(hstart, 0); hstart = std::max(hstart, 0);
wstart = max(wstart, 0); wstart = std::max(wstart, 0);
hend = min(hend, input_height); hend = std::min(hend, input_height);
wend = min(wend, input_width); wend = std::min(wend, input_width);
pos1 = input_data + hstart * input_width + wstart; pos1 = input_data + hstart * input_width + wstart;
pos2 = input_data + (hstart + 1) * input_width + wstart; pos2 = input_data + (hstart + 1) * input_width + wstart;
pos3 = input_data + (hstart + 2) * input_width + wstart; pos3 = input_data + (hstart + 2) * input_width + wstart;
...@@ -244,8 +248,10 @@ void DepthwiseConv3x3(const Tensor *input, vector<int> strides, ...@@ -244,8 +248,10 @@ void DepthwiseConv3x3(const Tensor *input, vector<int> strides,
} }
} }
void DepthwiseConv3x3s1p1(const Tensor *input, const Tensor *filter, void DepthwiseConv3x3s1p1(const framework::Tensor *input,
Tensor *output, Tensor *bias, bool if_bias) { const framework::Tensor *filter,
framework::Tensor *output, framework::Tensor *bias,
bool if_bias) {
#if __ARM_NEON #if __ARM_NEON
const float *input_data = input->data<float>(); const float *input_data = input->data<float>();
const float *filter_data = filter->data<float>(); const float *filter_data = filter->data<float>();
...@@ -517,9 +523,12 @@ void DepthwiseConv3x3s1p1(const Tensor *input, const Tensor *filter, ...@@ -517,9 +523,12 @@ void DepthwiseConv3x3s1p1(const Tensor *input, const Tensor *filter,
#endif #endif
} }
void DepthwiseConvAddBNRelu3x3s1p1(const Tensor *input, const Tensor *filter, void DepthwiseConvAddBNRelu3x3s1p1(const framework::Tensor *input,
Tensor *output, const Tensor *new_scale, const framework::Tensor *filter,
const Tensor *new_bias, bool if_relu) { framework::Tensor *output,
const framework::Tensor *new_scale,
const framework::Tensor *new_bias,
bool if_relu) {
#if __ARM_NEON #if __ARM_NEON
const float *input_data = input->data<float>(); const float *input_data = input->data<float>();
const float *filter_data = filter->data<float>(); const float *filter_data = filter->data<float>();
...@@ -1059,9 +1068,12 @@ void DepthwiseConvAddBNRelu3x3s1p1(const Tensor *input, const Tensor *filter, ...@@ -1059,9 +1068,12 @@ void DepthwiseConvAddBNRelu3x3s1p1(const Tensor *input, const Tensor *filter,
} }
/// w!=h not fix /// w!=h not fix
void DepthwiseConvAddBNRelu3x3s2p1(const Tensor *input, const Tensor *filter, void DepthwiseConvAddBNRelu3x3s2p1(const framework::Tensor *input,
Tensor *output, const Tensor *new_scale, const framework::Tensor *filter,
const Tensor *new_bias, bool if_relu) { framework::Tensor *output,
const framework::Tensor *new_scale,
const framework::Tensor *new_bias,
bool if_relu) {
#if __ARM_NEON #if __ARM_NEON
const int batch_size = input->dims()[0]; const int batch_size = input->dims()[0];
...@@ -1107,12 +1119,12 @@ void DepthwiseConvAddBNRelu3x3s2p1(const Tensor *input, const Tensor *filter, ...@@ -1107,12 +1119,12 @@ void DepthwiseConvAddBNRelu3x3s2p1(const Tensor *input, const Tensor *filter,
for (int pw = 0; pw < output_width; pw++) { for (int pw = 0; pw < output_width; pw++) {
hstart = ph * stride_height - padding_height; hstart = ph * stride_height - padding_height;
wstart = pw * stride_width - padding_width; wstart = pw * stride_width - padding_width;
hend = min(hstart + _kernel_size, input_height + padding_height); hend = std::min(hstart + _kernel_size, input_height + padding_height);
wend = min(wstart + _kernel_size, input_width + padding_width); wend = std::min(wstart + _kernel_size, input_width + padding_width);
hstart = max(hstart, 0); hstart = std::max(hstart, 0);
wstart = max(wstart, 0); wstart = std::max(wstart, 0);
hend = min(hend, input_height); hend = std::min(hend, input_height);
wend = min(wend, input_width); wend = std::min(wend, input_width);
pos1 = input_data + hstart * input_width + wstart; pos1 = input_data + hstart * input_width + wstart;
pos2 = input_data + (hstart + 1) * input_width + wstart; pos2 = input_data + (hstart + 1) * input_width + wstart;
pos3 = input_data + (hstart + 2) * input_width + wstart; pos3 = input_data + (hstart + 2) * input_width + wstart;
...@@ -1258,8 +1270,10 @@ void DepthwiseConvAddBNRelu3x3s2p1(const Tensor *input, const Tensor *filter, ...@@ -1258,8 +1270,10 @@ void DepthwiseConvAddBNRelu3x3s2p1(const Tensor *input, const Tensor *filter,
#endif #endif
} }
void DepthwiseConv3x3s2p1v2(const Tensor *input, const Tensor *filter, void DepthwiseConv3x3s2p1v2(const framework::Tensor *input,
Tensor *output, Tensor bias, bool if_bias) { const framework::Tensor *filter,
framework::Tensor *output, framework::Tensor bias,
bool if_bias) {
#if __ARM_NEON #if __ARM_NEON
const float *input_data = input->data<float>(); const float *input_data = input->data<float>();
const float *filter_data = filter->data<float>(); const float *filter_data = filter->data<float>();
...@@ -1463,9 +1477,12 @@ void DepthwiseConv3x3s2p1v2(const Tensor *input, const Tensor *filter, ...@@ -1463,9 +1477,12 @@ void DepthwiseConv3x3s2p1v2(const Tensor *input, const Tensor *filter,
#endif #endif
} }
void DepthwiseConvAddBNRelu3x3s2p1v2(const Tensor *input, const Tensor *filter, void DepthwiseConvAddBNRelu3x3s2p1v2(const framework::Tensor *input,
Tensor *output, const Tensor *new_scale, const framework::Tensor *filter,
const Tensor *new_bias, bool if_relu) { framework::Tensor *output,
const framework::Tensor *new_scale,
const framework::Tensor *new_bias,
bool if_relu) {
#if __ARM_NEON #if __ARM_NEON
// #ifdef _OPENMP // #ifdef _OPENMP
// const float *newscale_data = new_scale->data<float>(); // const float *newscale_data = new_scale->data<float>();
...@@ -1886,8 +1903,10 @@ void DepthwiseConvAddBNRelu3x3s2p1v2(const Tensor *input, const Tensor *filter, ...@@ -1886,8 +1903,10 @@ void DepthwiseConvAddBNRelu3x3s2p1v2(const Tensor *input, const Tensor *filter,
#endif #endif
} }
void DepthwiseConv3x3s2p0(const Tensor *input, const Tensor *filter, void DepthwiseConv3x3s2p0(const framework::Tensor *input,
Tensor *output, Tensor bias, bool if_bias) { const framework::Tensor *filter,
framework::Tensor *output, framework::Tensor bias,
bool if_bias) {
#if __ARM_NEON #if __ARM_NEON
const int batch_size = static_cast<int>(input->dims()[0]); const int batch_size = static_cast<int>(input->dims()[0]);
......
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <algorithm>
#include <vector>
#include "framework/tensor.h"
#include "operators/math/conv_func.h"
namespace paddle_mobile {
namespace operators {
namespace math {
void DepthwiseConv3x3(const framework::Tensor *input,
const std::vector<int> &strides,
const std::vector<int> &paddings,
const framework::Tensor *filter, framework::Tensor *bias,
framework::Tensor *output, bool if_bias);
void DepthwiseConv3x3s1p1(const framework::Tensor *input,
const framework::Tensor *filter,
framework::Tensor *output, framework::Tensor *bias,
bool if_bias);
void DepthwiseConvAddBNRelu3x3s1p1(const framework::Tensor *input,
const framework::Tensor *filter,
framework::Tensor *output,
const framework::Tensor *new_scale,
const framework::Tensor *new_bias,
bool if_relu);
void DepthwiseConvAddBNRelu3x3s2p1(const framework::Tensor *input,
const framework::Tensor *filter,
framework::Tensor *output,
const framework::Tensor *new_scale,
const framework::Tensor *new_bias,
bool if_relu);
void DepthwiseConv3x3s2p1v2(const framework::Tensor *input,
const framework::Tensor *filter,
framework::Tensor *output, framework::Tensor bias,
bool if_bias);
void DepthwiseConvAddBNRelu3x3s2p1v2(const framework::Tensor *input,
const framework::Tensor *filter,
framework::Tensor *output,
const framework::Tensor *new_scale,
const framework::Tensor *new_bias,
bool if_relu);
void DepthwiseConv3x3s2p0(const framework::Tensor *input,
const framework::Tensor *filter,
framework::Tensor *output, framework::Tensor bias,
bool if_bias);
// template<typename Itype, typename Otype>
// void DepthwiseConv3x3(const framework::Tensor *input,
// const framework::Tensor *filter,
// const std::vector<int> &strides,
// framework::Tensor *output);
template <typename Itype, typename Otype>
void DepthwiseConv3x3s1(const framework::Tensor &input,
const framework::Tensor &filter,
framework::Tensor *output);
template <typename Itype, typename Otype>
void DepthwiseConv3x3s2(const framework::Tensor &input,
const framework::Tensor &filter,
framework::Tensor *output);
} // namespace math
} // namespace operators
} // namespace paddle_mobile
...@@ -12,23 +12,24 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. ...@@ -12,23 +12,24 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include "operators/math/depthwise_conv3x3_int8.h" #include "operators/math/depthwise_conv3x3.h"
namespace paddle_mobile { namespace paddle_mobile {
namespace operators { namespace operators {
namespace math { namespace math {
void DepthwiseConv3x3_int8(const framework::Tensor *input, // template<>
const framework::Tensor *filter, // void DepthwiseConv3x3<int8_t, int32_t>(
const std::vector<int> &strides, // const framework::Tensor *input, const framework::Tensor *filter,
framework::Tensor *output) { // const std::vector<int> &strides, framework::Tensor *output) {
PADDLE_MOBILE_THROW_EXCEPTION( // PADDLE_MOBILE_THROW_EXCEPTION(
"Depthwise conv with generic strides has not been implemented."); // "Depthwise conv with generic strides has not been implemented.");
} // }
void DepthwiseConv3x3s1_int8(const framework::Tensor &input, template <>
const framework::Tensor &filter, void DepthwiseConv3x3s1<int8_t, int32_t>(const framework::Tensor &input,
framework::Tensor *output) { const framework::Tensor &filter,
framework::Tensor *output) {
const int8_t *input_data = input.data<int8_t>(); const int8_t *input_data = input.data<int8_t>();
const int8_t *filter_data = filter.data<int8_t>(); const int8_t *filter_data = filter.data<int8_t>();
int32_t *out_data = output->mutable_data<int32_t>(); int32_t *out_data = output->mutable_data<int32_t>();
...@@ -41,26 +42,27 @@ void DepthwiseConv3x3s1_int8(const framework::Tensor &input, ...@@ -41,26 +42,27 @@ void DepthwiseConv3x3s1_int8(const framework::Tensor &input,
int output_w = output->dims()[3]; int output_w = output->dims()[3];
int image_size = input_h * input_w; int image_size = input_h * input_w;
int out_image_size = output_h * output_w; int out_image_size = output_h * output_w;
memset(out_data, 0, output_c * out_image_size * sizeof(int32_t));
#if __aarch64__ #if __aarch64__
// TODO(hjchen2) // TODO(hjchen2)
#else #else
#pragma omp parallel for #pragma omp parallel for
for (int g = 0; g < input_c; ++g) { for (int g = 0; g < input_c; ++g) {
const int8_t* input_ptr0 = input_data + g * image_size; const int8_t* input_ptr = input_data + g * image_size;
const int8_t* input_ptr1 = input_ptr0 + input_w;
const int8_t* input_ptr2 = input_ptr1 + input_w;
const int8_t* input_ptr3 = input_ptr2 + input_w;
const int8_t* input_ptr4 = input_ptr3 + input_w;
const int8_t* input_ptr5 = input_ptr4 + input_w;
const int8_t* filter_ptr = filter_data + g * 9; const int8_t* filter_ptr = filter_data + g * 9;
int32_t* output_ptr0 = out_data + g * out_image_size; int32_t* output_ptr = out_data + g * out_image_size;
int32_t* output_ptr1 = output_ptr0 + output_w; int loop = (input_w - 2) / 6;
int32_t* output_ptr2 = output_ptr1 + output_w; int remain = input_w - 2 - loop * 6;
int32_t* output_ptr3 = output_ptr2 + output_w;
for (int h = 0; h < input_h - 5 /*(input_h - 2) - 3*/; h += 4) { for (int h = 0; h < input_h - 5 /*(input_h - 2) - 3*/; h += 4) {
int loop = (input_w - 2) / 6; const int8_t* input_ptr0 = input_ptr + h * input_w;
int remain = input_w - loop * 6; const int8_t* input_ptr1 = input_ptr0 + input_w;
const int8_t* input_ptr2 = input_ptr1 + input_w;
const int8_t* input_ptr3 = input_ptr2 + input_w;
const int8_t* input_ptr4 = input_ptr3 + input_w;
const int8_t* input_ptr5 = input_ptr4 + input_w;
int32_t* output_ptr0 = output_ptr + h * output_w;
int32_t* output_ptr1 = output_ptr0 + output_w;
int32_t* output_ptr2 = output_ptr1 + output_w;
int32_t* output_ptr3 = output_ptr2 + output_w;
asm volatile( asm volatile(
"vld1.32 {q0}, [%[filter_ptr]] \n" "vld1.32 {q0}, [%[filter_ptr]] \n"
"vmovl.s8 q14, d0 \n" "vmovl.s8 q14, d0 \n"
...@@ -81,14 +83,13 @@ void DepthwiseConv3x3s1_int8(const framework::Tensor &input, ...@@ -81,14 +83,13 @@ void DepthwiseConv3x3s1_int8(const framework::Tensor &input,
"mov r0, #6 \n" "mov r0, #6 \n"
"cmp %[loop], #0 \n" "cmp %[loop], #0 \n"
"ble start_remain_%= \n" "ble start_remain_%= \n"
// loop 8 widths // loop 6 widths
"loop_4h8w_%=: \n" "loop_4h6w_%=: \n"
"vld1.32 {d9}, [%[input_ptr0]], r0 \n" "vld1.32 {d9}, [%[input_ptr0]], r0 \n"
"vld1.32 {d10}, [%[input_ptr1]], r0 \n" "vld1.32 {d10}, [%[input_ptr1]], r0 \n"
"vld1.32 {d11}, [%[input_ptr2]], r0 \n" "vld1.32 {d11}, [%[input_ptr2]], r0 \n"
"vext.s8 d12, d9, d9, #1 \n"
"vext.s8 d12, d9, #1 \n" "vext.s8 d13, d9, d9, #2 \n"
"vext.s8 d13, d9, #2 \n"
"vmovl.s8 q7, d9 \n" "vmovl.s8 q7, d9 \n"
"vmovl.s8 q8, d12 \n" "vmovl.s8 q8, d12 \n"
"vmovl.s8 q9, d13 \n" "vmovl.s8 q9, d13 \n"
...@@ -99,11 +100,18 @@ void DepthwiseConv3x3s1_int8(const framework::Tensor &input, ...@@ -99,11 +100,18 @@ void DepthwiseConv3x3s1_int8(const framework::Tensor &input,
"vmlal.s16 q11, d17, d1 \n" "vmlal.s16 q11, d17, d1 \n"
"vmlal.s16 q11, d19, d2 \n" "vmlal.s16 q11, d19, d2 \n"
"vext.s8 d12, d10, #1 \n" "vext.s8 d12, d10, d10, #1 \n"
"vext.s8 d13, d10, #2 \n" "vext.s8 d13, d10, d10, #2 \n"
"vmovl.s8 q7, d10 \n" "vmovl.s8 q7, d10 \n"
"vmovl.s8 q8, d12 \n" "vmovl.s8 q8, d12 \n"
"vmovl.s8 q9, d13 \n" "vmovl.s8 q9, d13 \n"
"vmlal.s16 q10, d14, d3 \n"
"vmlal.s16 q10, d16, d4 \n"
"vmlal.s16 q10, d18, d5 \n"
"vmlal.s16 q11, d15, d3 \n"
"vmlal.s16 q11, d17, d4 \n"
"vmlal.s16 q11, d19, d5 \n"
"vmull.s16 q12, d14, d0 \n" "vmull.s16 q12, d14, d0 \n"
"vmlal.s16 q12, d16, d1 \n" "vmlal.s16 q12, d16, d1 \n"
"vmlal.s16 q12, d18, d2 \n" "vmlal.s16 q12, d18, d2 \n"
...@@ -111,57 +119,42 @@ void DepthwiseConv3x3s1_int8(const framework::Tensor &input, ...@@ -111,57 +119,42 @@ void DepthwiseConv3x3s1_int8(const framework::Tensor &input,
"vmlal.s16 q13, d17, d1 \n" "vmlal.s16 q13, d17, d1 \n"
"vmlal.s16 q13, d19, d2 \n" "vmlal.s16 q13, d19, d2 \n"
"vmlal.s16 q10, d14, d3 \n" "vext.s8 d12, d11, d11, #1 \n"
"vmlal.s16 q10, d16, d4 \n" "vext.s8 d13, d11, d11, #2 \n"
"vmlal.s16 q10, d18, d5 \n" "vmovl.s8 q7, d11 \n"
"vmull.s16 q11, d15, d3 \n"
"vmlal.s16 q11, d17, d4 \n"
"vmlal.s16 q11, d19, d5 \n"
"vext.s8 d12, d11, #1 \n"
"vext.s8 d13, d11, #2 \n"
"vmovl.s8 q7, d10 \n"
"vmovl.s8 q8, d12 \n" "vmovl.s8 q8, d12 \n"
"vmovl.s8 q9, d13 \n" "vmovl.s8 q9, d13 \n"
"vmull.s16 q14, d14, d0 \n"
"vmlal.s16 q14, d16, d1 \n"
"vmlal.s16 q14, d18, d2 \n"
"vmull.s16 q15, d15, d0 \n"
"vmlal.s16 q15, d17, d1 \n"
"vmlal.s16 q15, d19, d2 \n"
"vmlal.s16 q12, d14, d3 \n"
"vmlal.s16 q12, d16, d4 \n"
"vmlal.s16 q12, d18, d5 \n"
"vmull.s16 q13, d15, d3 \n"
"vmlal.s16 q13, d17, d4 \n"
"vmlal.s16 q13, d19, d5 \n"
"vmlal.s16 q10, d14, d6 \n" "vmlal.s16 q10, d14, d6 \n"
"vmlal.s16 q10, d16, d7 \n" "vmlal.s16 q10, d16, d7 \n"
"vmlal.s16 q10, d18, d8 \n" "vmlal.s16 q10, d18, d8 \n"
"vmull.s16 q11, d15, d6 \n" "vmlal.s16 q11, d15, d6 \n"
"vmlal.s16 q11, d17, d7 \n" "vmlal.s16 q11, d17, d7 \n"
"vmlal.s16 q11, d19, d8 \n" "vmlal.s16 q11, d19, d8 \n"
// store row 0, reuse q10/q11 // store row 0, reuse q10/q11
"vst1.32 {d20-d22}, [%[output_ptr0]]! \n" "vst1.32 {d20-d22}, [%[output_ptr0]]! \n"
"vmlal.s16 q12, d14, d3 \n"
"vmlal.s16 q12, d16, d4 \n"
"vmlal.s16 q12, d18, d5 \n"
"vmlal.s16 q13, d15, d3 \n"
"vmlal.s16 q13, d17, d4 \n"
"vmlal.s16 q13, d19, d5 \n"
"vmull.s16 q14, d14, d0 \n"
"vmlal.s16 q14, d16, d1 \n"
"vmlal.s16 q14, d18, d2 \n"
"vmull.s16 q15, d15, d0 \n"
"vmlal.s16 q15, d17, d1 \n"
"vmlal.s16 q15, d19, d2 \n"
"vld1.32 {d9}, [%[input_ptr3]], r0 \n" "vld1.32 {d9}, [%[input_ptr3]], r0 \n"
"vld1.32 {d10}, [%[input_ptr4]], r0 \n" "vld1.32 {d10}, [%[input_ptr4]], r0 \n"
"vld1.32 {d11}, [%[input_ptr5]], r0 \n" "vld1.32 {d11}, [%[input_ptr5]], r0 \n"
"vext.s8 d12, d9, d9, #1 \n"
"vext.s8 d12, d9, #1 \n" "vext.s8 d13, d9, d9, #2 \n"
"vext.s8 d13, d9, #2 \n"
"vmovl.s8 q7, d9 \n" "vmovl.s8 q7, d9 \n"
"vmovl.s8 q8, d12 \n" "vmovl.s8 q8, d12 \n"
"vmovl.s8 q9, d13 \n" "vmovl.s8 q9, d13 \n"
"vmull.s16 q10, d14, d0 \n"
"vmlal.s16 q10, d16, d1 \n"
"vmlal.s16 q10, d18, d2 \n"
"vmull.s16 q11, d15, d0 \n"
"vmlal.s16 q11, d17, d1 \n"
"vmlal.s16 q11, d19, d2 \n"
"vmlal.s16 q12, d14, d6 \n" "vmlal.s16 q12, d14, d6 \n"
"vmlal.s16 q12, d16, d7 \n" "vmlal.s16 q12, d16, d7 \n"
"vmlal.s16 q12, d18, d8 \n" "vmlal.s16 q12, d18, d8 \n"
...@@ -178,126 +171,121 @@ void DepthwiseConv3x3s1_int8(const framework::Tensor &input, ...@@ -178,126 +171,121 @@ void DepthwiseConv3x3s1_int8(const framework::Tensor &input,
"vmlal.s16 q15, d17, d4 \n" "vmlal.s16 q15, d17, d4 \n"
"vmlal.s16 q15, d19, d5 \n" "vmlal.s16 q15, d19, d5 \n"
"vext.s8 d12, d10, #1 \n" "vmull.s16 q10, d14, d0 \n"
"vext.s8 d13, d10, #2 \n" "vmlal.s16 q10, d16, d1 \n"
"vmlal.s16 q10, d18, d2 \n"
"vmull.s16 q11, d15, d0 \n"
"vmlal.s16 q11, d17, d1 \n"
"vmlal.s16 q11, d19, d2 \n"
"vext.s8 d12, d10, d10, #1 \n"
"vext.s8 d13, d10, d10, #2 \n"
"vmovl.s8 q7, d10 \n" "vmovl.s8 q7, d10 \n"
"vmovl.s8 q8, d12 \n" "vmovl.s8 q8, d12 \n"
"vmovl.s8 q9, d13 \n" "vmovl.s8 q9, d13 \n"
"vmlal.s16 q10, d14, d3 \n" "vmlal.s16 q14, d14, d6 \n"
"vmlal.s16 q10, d16, d4 \n"
"vmlal.s16 q10, d18, d5 \n"
"vmull.s16 q11, d15, d3 \n"
"vmlal.s16 q11, d17, d4 \n"
"vmlal.s16 q11, d19, d5 \n"
"vmull.s16 q14, d14, d6 \n"
"vmlal.s16 q14, d16, d7 \n" "vmlal.s16 q14, d16, d7 \n"
"vmlal.s16 q14, d18, d8 \n" "vmlal.s16 q14, d18, d8 \n"
"vmull.s16 q15, d15, d6 \n" "vmlal.s16 q15, d15, d6 \n"
"vmlal.s16 q15, d17, d7 \n" "vmlal.s16 q15, d17, d7 \n"
"vmlal.s16 q15, d19, d8 \n" "vmlal.s16 q15, d19, d8 \n"
// store row 2 // store row 2
"vst1.32 {d24-d26}, [%[output_ptr2]]! \n" "vst1.32 {d28-d30}, [%[output_ptr2]]! \n"
"vext.s8 d12, d11, #1 \n" "vmlal.s16 q10, d14, d3 \n"
"vext.s8 d13, d11, #2 \n" "vmlal.s16 q10, d16, d4 \n"
"vmovl.s8 q7, d10 \n" "vmlal.s16 q10, d18, d5 \n"
"vmlal.s16 q11, d15, d3 \n"
"vmlal.s16 q11, d17, d4 \n"
"vmlal.s16 q11, d19, d5 \n"
"vext.s8 d12, d11, d11, #1 \n"
"vext.s8 d13, d11, d11, #2 \n"
"vmovl.s8 q7, d11 \n"
"vmovl.s8 q8, d12 \n" "vmovl.s8 q8, d12 \n"
"vmovl.s8 q9, d13 \n" "vmovl.s8 q9, d13 \n"
"vmull.s16 q10, d14, d6 \n" "vmlal.s16 q10, d14, d6 \n"
"vmlal.s16 q10, d16, d7 \n" "vmlal.s16 q10, d16, d7 \n"
"vmlal.s16 q10, d18, d8 \n" "vmlal.s16 q10, d18, d8 \n"
"vmull.s16 q11, d15, d6 \n" "vmlal.s16 q11, d15, d6 \n"
"vmlal.s16 q11, d17, d7 \n" "vmlal.s16 q11, d17, d7 \n"
"vmlal.s16 q11, d19, d8 \n" "vmlal.s16 q11, d19, d8 \n"
// store row 3 // store row 3
"vst1.32 {d20-d22}, [%[output_ptr3]]! \n" "vst1.32 {d20-d22}, [%[output_ptr3]]! \n"
"subs %[loop], #1 \n" "subs %[loop], #1 \n"
"bne loop_4h8w_%= \n" "bne loop_4h6w_%= \n"
"start_remain_%=: \n" "start_remain_%=: \n"
"cmp %[remain], #0 \n" "cmp %[remain], #0 \n"
"ble end_%= \n" "ble end_%= \n"
"mov r0, %[remain] \n"
"add r0, #2 \n"
"vld1.32 {d9}, [%[input_ptr0]], r0 \n" "vld1.32 {d9}, [%[input_ptr0]] \n"
"vext.s8 d12, d9, #1 \n"
"vext.s8 d13, d9, #2 \n"
"vmovl.s8 q7, d9 \n" "vmovl.s8 q7, d9 \n"
"vmovl.s8 q8, d12 \n" "vext.s8 d9, d9, d9, #1 \n"
"vmovl.s8 q9, d13 \n" "vmovl.s8 q8, d9 \n"
"vext.s8 d9, d9, d9, #1 \n"
"vmovl.s8 q9, d9 \n"
"vmull.s16 q10, d14, d0 \n" "vmull.s16 q10, d14, d0 \n"
"vmlal.s16 q10, d16, d1 \n" "vmlal.s16 q10, d16, d1 \n"
"vmlal.s16 q10, d18, d2 \n" "vmlal.s16 q10, d18, d2 \n"
"vld1.32 {d9}, [%[input_ptr1]] \n"
"vld1.32 {d9}, [%[input_ptr1]], r0 \n"
"vmull.s16 q11, d15, d0 \n" "vmull.s16 q11, d15, d0 \n"
"vmlal.s16 q11, d17, d1 \n" "vmlal.s16 q11, d17, d1 \n"
"vmlal.s16 q11, d19, d2 \n" "vmlal.s16 q11, d19, d2 \n"
"vext.s8 d12, d9, #1 \n"
"vext.s8 d13, d9, #2 \n"
"vmovl.s8 q7, d9 \n" "vmovl.s8 q7, d9 \n"
"vmovl.s8 q8, d12 \n" "vext.s8 d9, d9, d9, #1 \n"
"vmovl.s8 q9, d13 \n" "vmovl.s8 q8, d9 \n"
"vext.s8 d9, d9, d9, #1 \n"
"vmovl.s8 q9, d9 \n"
"vmlal.s16 q10, d14, d3 \n"
"vmlal.s16 q10, d16, d4 \n"
"vmlal.s16 q10, d18, d5 \n"
"vmlal.s16 q11, d15, d3 \n"
"vmlal.s16 q11, d17, d4 \n"
"vmlal.s16 q11, d19, d5 \n"
"vmull.s16 q12, d14, d0 \n" "vmull.s16 q12, d14, d0 \n"
"vmlal.s16 q12, d16, d1 \n" "vmlal.s16 q12, d16, d1 \n"
"vmlal.s16 q12, d18, d2 \n" "vmlal.s16 q12, d18, d2 \n"
"vld1.32 {d9}, [%[input_ptr2]] \n"
"vmull.s16 q13, d15, d0 \n" "vmull.s16 q13, d15, d0 \n"
"vmlal.s16 q13, d17, d1 \n" "vmlal.s16 q13, d17, d1 \n"
"vmlal.s16 q13, d19, d2 \n" "vmlal.s16 q13, d19, d2 \n"
"vmlal.s16 q10, d14, d3 \n"
"vmlal.s16 q10, d16, d4 \n"
"vmlal.s16 q10, d18, d5 \n"
"vld1.32 {d9}, [%[input_ptr2]], r0 \n"
"vmull.s16 q11, d15, d3 \n"
"vmlal.s16 q11, d17, d4 \n"
"vmlal.s16 q11, d19, d5 \n"
"vext.s8 d12, d9, #1 \n"
"vext.s8 d13, d9, #2 \n"
"vmovl.s8 q7, d9 \n" "vmovl.s8 q7, d9 \n"
"vmovl.s8 q8, d12 \n" "vext.s8 d9, d9, d9, #1 \n"
"vmovl.s8 q9, d13 \n" "vmovl.s8 q8, d9 \n"
"vmull.s16 q14, d14, d0 \n" "vext.s8 d9, d9, d9, #1 \n"
"vmlal.s16 q14, d16, d1 \n" "vmovl.s8 q9, d9 \n"
"vmlal.s16 q14, d18, d2 \n" "vmlal.s16 q10, d14, d6 \n"
"vmull.s16 q15, d15, d0 \n" "vmlal.s16 q10, d16, d7 \n"
"vmlal.s16 q15, d17, d1 \n" "vmlal.s16 q10, d18, d8 \n"
"vmlal.s16 q15, d19, d2 \n" "vmlal.s16 q11, d15, d6 \n"
"vmlal.s16 q11, d17, d7 \n"
"vmlal.s16 q11, d19, d8 \n"
"vmlal.s16 q12, d14, d3 \n" "vmlal.s16 q12, d14, d3 \n"
"vmlal.s16 q12, d16, d4 \n" "vmlal.s16 q12, d16, d4 \n"
"vmlal.s16 q12, d18, d5 \n" "vmlal.s16 q12, d18, d5 \n"
"vmull.s16 q13, d15, d3 \n" "vmlal.s16 q13, d15, d3 \n"
"vmlal.s16 q13, d17, d4 \n" "vmlal.s16 q13, d17, d4 \n"
"vmlal.s16 q13, d19, d5 \n" "vmlal.s16 q13, d19, d5 \n"
"vmlal.s16 q10, d14, d6 \n" "vmull.s16 q14, d14, d0 \n"
"vmlal.s16 q10, d16, d7 \n" "vmlal.s16 q14, d16, d1 \n"
"vmlal.s16 q10, d18, d8 \n" "vmlal.s16 q14, d18, d2 \n"
"vld1.32 {d9}, [%[input_ptr3]] \n"
"vld1.32 {d9}, [%[input_ptr3]], r0 \n" "vmull.s16 q15, d15, d0 \n"
"vmull.s16 q11, d15, d6 \n" "vmlal.s16 q15, d17, d1 \n"
"vmlal.s16 q11, d17, d7 \n" "vmlal.s16 q15, d19, d2 \n"
"vmlal.s16 q11, d19, d8 \n"
"vext.s8 d12, d9, #1 \n"
"vext.s8 d13, d9, #2 \n"
"vmovl.s8 q7, d9 \n" "vmovl.s8 q7, d9 \n"
"vmovl.s8 q8, d12 \n" "vext.s8 d9, d9, d9, #1 \n"
"vmovl.s8 q9, d13 \n" "vmovl.s8 q8, d9 \n"
"vmull.s16 q5, d14, d0 \n" "vext.s8 d9, d9, d9, #1 \n"
"vmlal.s16 q5, d16, d1 \n" "vmovl.s8 q9, d9 \n"
"vmlal.s16 q5, d18, d2 \n"
"vmull.s16 q6, d15, d0 \n"
"vmlal.s16 q6, d17, d1 \n"
"vmlal.s16 q6, d19, d2 \n"
"vmlal.s16 q12, d14, d6 \n" "vmlal.s16 q12, d14, d6 \n"
"vmlal.s16 q12, d16, d7 \n" "vmlal.s16 q12, d16, d7 \n"
"vmlal.s16 q12, d18, d8 \n" "vmlal.s16 q12, d18, d8 \n"
...@@ -308,42 +296,47 @@ void DepthwiseConv3x3s1_int8(const framework::Tensor &input, ...@@ -308,42 +296,47 @@ void DepthwiseConv3x3s1_int8(const framework::Tensor &input,
"vmlal.s16 q14, d14, d3 \n" "vmlal.s16 q14, d14, d3 \n"
"vmlal.s16 q14, d16, d4 \n" "vmlal.s16 q14, d16, d4 \n"
"vmlal.s16 q14, d18, d5 \n" "vmlal.s16 q14, d18, d5 \n"
"vld1.32 {d9}, [%[input_ptr4]], r0 \n"
"vmlal.s16 q15, d15, d3 \n" "vmlal.s16 q15, d15, d3 \n"
"vmlal.s16 q15, d17, d4 \n" "vmlal.s16 q15, d17, d4 \n"
"vmlal.s16 q15, d19, d5 \n" "vmlal.s16 q15, d19, d5 \n"
"vext.s8 d12, d9, #1 \n" "vmull.s16 q5, d14, d0 \n"
"vext.s8 d13, d9, #2 \n" "vmlal.s16 q5, d16, d1 \n"
"vmlal.s16 q5, d18, d2 \n"
"vld1.32 {d9}, [%[input_ptr4]] \n"
"vmull.s16 q6, d15, d0 \n"
"vmlal.s16 q6, d17, d1 \n"
"vmlal.s16 q6, d19, d2 \n"
"vmovl.s8 q7, d9 \n" "vmovl.s8 q7, d9 \n"
"vmovl.s8 q8, d12 \n" "vext.s8 d9, d9, d9, #1 \n"
"vmovl.s8 q9, d13 \n" "vmovl.s8 q8, d9 \n"
"vext.s8 d9, d9, d9, #1 \n"
"vmovl.s8 q9, d9 \n"
"vmlal.s16 q14, d14, d6 \n"
"vmlal.s16 q14, d16, d7 \n"
"vmlal.s16 q14, d18, d8 \n"
"vmlal.s16 q15, d15, d6 \n"
"vmlal.s16 q15, d17, d7 \n"
"vmlal.s16 q15, d19, d8 \n"
"vmlal.s16 q5, d14, d3 \n" "vmlal.s16 q5, d14, d3 \n"
"vmlal.s16 q5, d16, d4 \n" "vmlal.s16 q5, d16, d4 \n"
"vmlal.s16 q5, d18, d5 \n" "vmlal.s16 q5, d18, d5 \n"
"vmull.s16 q6, d15, d3 \n" "vld1.32 {d9}, [%[input_ptr5]] \n"
"vmlal.s16 q6, d15, d3 \n"
"vmlal.s16 q6, d17, d4 \n" "vmlal.s16 q6, d17, d4 \n"
"vmlal.s16 q6, d19, d5 \n" "vmlal.s16 q6, d19, d5 \n"
"vmull.s16 q14, d14, d6 \n"
"vmlal.s16 q14, d16, d7 \n"
"vmlal.s16 q14, d18, d8 \n"
"vld1.32 {d9}, [%[input_ptr5]], r0 \n"
"vmull.s16 q15, d15, d6 \n"
"vmlal.s16 q15, d17, d7 \n"
"vmlal.s16 q15, d19, d8 \n"
"vext.s8 d12, d9, #1 \n"
"vext.s8 d13, d9, #2 \n"
"vmovl.s8 q7, d9 \n" "vmovl.s8 q7, d9 \n"
"vmovl.s8 q8, d12 \n" "vext.s8 d9, d9, d9, #1 \n"
"vmovl.s8 q9, d13 \n" "vmovl.s8 q8, d9 \n"
"vmull.s16 q5, d14, d6 \n" "vext.s8 d9, d9, d9, #1 \n"
"vmovl.s8 q9, d9 \n"
"vmlal.s16 q5, d14, d6 \n"
"vmlal.s16 q5, d16, d7 \n" "vmlal.s16 q5, d16, d7 \n"
"vmlal.s16 q5, d18, d8 \n" "vmlal.s16 q5, d18, d8 \n"
"vmull.s16 q6, d15, d6 \n" "vmlal.s16 q6, d15, d6 \n"
"vmlal.s16 q6, d17, d7 \n" "vmlal.s16 q6, d17, d7 \n"
"vmlal.s16 q6, d19, d8 \n" "vmlal.s16 q6, d19, d8 \n"
...@@ -372,7 +365,7 @@ void DepthwiseConv3x3s1_int8(const framework::Tensor &input, ...@@ -372,7 +365,7 @@ void DepthwiseConv3x3s1_int8(const framework::Tensor &input,
"blt end_%= \n" "blt end_%= \n"
"vst1.32 {d21[0]}, [%[output_ptr0]]! \n" "vst1.32 {d21[0]}, [%[output_ptr0]]! \n"
"vst1.32 {d25[0]}, [%[output_ptr1]]! \n" "vst1.32 {d25[0]}, [%[output_ptr1]]! \n"
"vst1.32 {d27[0]}, [%[output_ptr2]]! \n" "vst1.32 {d29[0]}, [%[output_ptr2]]! \n"
"vst1.32 {d11[0]}, [%[output_ptr3]]! \n" "vst1.32 {d11[0]}, [%[output_ptr3]]! \n"
"b end_%= \n" "b end_%= \n"
...@@ -395,8 +388,1071 @@ void DepthwiseConv3x3s1_int8(const framework::Tensor &input, ...@@ -395,8 +388,1071 @@ void DepthwiseConv3x3s1_int8(const framework::Tensor &input,
} }
// remain height // remain height
int start_h = (input_h - 2) & 0xFFFC; int start_h = (input_h - 2) & 0xFFFC;
for (int h = start_h; h < input_h; ++h) { for (int h = start_h; h < input_h - 3 /*(input_h - 2) - 1*/; h += 2) {
// TODO(hjchen2) const int8_t* input_ptr0 = input_ptr + h * input_w;
const int8_t* input_ptr1 = input_ptr0 + input_w;
const int8_t* input_ptr2 = input_ptr1 + input_w;
const int8_t* input_ptr3 = input_ptr2 + input_w;
int32_t* output_ptr0 = output_ptr + h * output_w;
int32_t* output_ptr1 = output_ptr0 + output_w;
asm volatile(
"vld1.32 {q0}, [%[filter_ptr]] \n"
"vmovl.s8 q14, d0 \n"
"vmovl.s8 q15, d1 \n"
"vdup.s16 d0, d28[0] \n"
"vdup.s16 d1, d28[1] \n"
"vdup.s16 d2, d28[2] \n"
"vdup.s16 d3, d28[3] \n"
"vdup.s16 d4, d29[0] \n"
"vdup.s16 d5, d29[1] \n"
"vdup.s16 d6, d29[2] \n"
"vdup.s16 d7, d29[3] \n"
"vdup.s16 d8, d30[0] \n"
:
: [filter_ptr] "r"(filter_ptr)
: "cc", "memory", "q0", "q1", "q2", "q3", "q4", "q14", "q15");
asm volatile(
"mov r0, #6 \n"
"cmp %[loop], #0 \n"
"ble start_remain_%= \n"
// loop 6 widths
"loop_2h6w_%=: \n"
"vld1.32 {d9}, [%[input_ptr0]], r0 \n"
"vld1.32 {d10}, [%[input_ptr1]], r0 \n"
"vld1.32 {d11}, [%[input_ptr2]], r0 \n"
"vext.s8 d12, d9, d9, #1 \n"
"vext.s8 d13, d9, d9, #2 \n"
"vmovl.s8 q7, d9 \n"
"vmovl.s8 q8, d12 \n"
"vmovl.s8 q9, d13 \n"
"vmull.s16 q10, d14, d0 \n"
"vmlal.s16 q10, d16, d1 \n"
"vmlal.s16 q10, d18, d2 \n"
"vmull.s16 q11, d15, d0 \n"
"vmlal.s16 q11, d17, d1 \n"
"vmlal.s16 q11, d19, d2 \n"
"vext.s8 d12, d10, d10, #1 \n"
"vext.s8 d13, d10, d10, #2 \n"
"vmovl.s8 q7, d10 \n"
"vmovl.s8 q8, d12 \n"
"vmovl.s8 q9, d13 \n"
"vmlal.s16 q10, d14, d3 \n"
"vmlal.s16 q10, d16, d4 \n"
"vmlal.s16 q10, d18, d5 \n"
"vmlal.s16 q11, d15, d3 \n"
"vmlal.s16 q11, d17, d4 \n"
"vmlal.s16 q11, d19, d5 \n"
"vmull.s16 q12, d14, d0 \n"
"vmlal.s16 q12, d16, d1 \n"
"vmlal.s16 q12, d18, d2 \n"
"vmull.s16 q13, d15, d0 \n"
"vmlal.s16 q13, d17, d1 \n"
"vmlal.s16 q13, d19, d2 \n"
"vext.s8 d12, d11, d11, #1 \n"
"vext.s8 d13, d11, d11, #2 \n"
"vmovl.s8 q7, d11 \n"
"vmovl.s8 q8, d12 \n"
"vmovl.s8 q9, d13 \n"
"vmlal.s16 q10, d14, d6 \n"
"vmlal.s16 q10, d16, d7 \n"
"vmlal.s16 q10, d18, d8 \n"
"vmlal.s16 q11, d15, d6 \n"
"vmlal.s16 q11, d17, d7 \n"
"vmlal.s16 q11, d19, d8 \n"
// store row 0, reuse q10/q11
"vst1.32 {d20-d22}, [%[output_ptr0]]! \n"
"vmlal.s16 q12, d14, d3 \n"
"vmlal.s16 q12, d16, d4 \n"
"vmlal.s16 q12, d18, d5 \n"
"vmlal.s16 q13, d15, d3 \n"
"vmlal.s16 q13, d17, d4 \n"
"vmlal.s16 q13, d19, d5 \n"
"vld1.32 {d9}, [%[input_ptr3]], r0 \n"
"vext.s8 d12, d9, d9, #1 \n"
"vext.s8 d13, d9, d9, #2 \n"
"vmovl.s8 q7, d9 \n"
"vmovl.s8 q8, d12 \n"
"vmovl.s8 q9, d13 \n"
"vmlal.s16 q12, d14, d6 \n"
"vmlal.s16 q12, d16, d7 \n"
"vmlal.s16 q12, d18, d8 \n"
"vmlal.s16 q13, d15, d6 \n"
"vmlal.s16 q13, d17, d7 \n"
"vmlal.s16 q13, d19, d8 \n"
// store row 1
"vst1.32 {d24-d26}, [%[output_ptr1]]! \n"
"subs %[loop], #1 \n"
"bne loop_2h6w_%= \n"
"start_remain_%=: \n"
"cmp %[remain], #0 \n"
"ble end_%= \n"
"vld1.32 {d9}, [%[input_ptr0]] \n"
"vld1.32 {d10}, [%[input_ptr1]] \n"
"vld1.32 {d11}, [%[input_ptr2]] \n"
"vext.s8 d12, d9, d9, #1 \n"
"vext.s8 d13, d9, d9, #2 \n"
"vmovl.s8 q7, d9 \n"
"vmovl.s8 q8, d12 \n"
"vmovl.s8 q9, d13 \n"
"vmull.s16 q10, d14, d0 \n"
"vmlal.s16 q10, d16, d1 \n"
"vmlal.s16 q10, d18, d2 \n"
"vmull.s16 q11, d15, d0 \n"
"vmlal.s16 q11, d17, d1 \n"
"vmlal.s16 q11, d19, d2 \n"
"vext.s8 d12, d10, d10, #1 \n"
"vext.s8 d13, d10, d10, #2 \n"
"vmovl.s8 q7, d10 \n"
"vmovl.s8 q8, d12 \n"
"vmovl.s8 q9, d13 \n"
"vmlal.s16 q10, d14, d3 \n"
"vmlal.s16 q10, d16, d4 \n"
"vmlal.s16 q10, d18, d5 \n"
"vmlal.s16 q11, d15, d3 \n"
"vmlal.s16 q11, d17, d4 \n"
"vmlal.s16 q11, d19, d5 \n"
"vmull.s16 q12, d14, d0 \n"
"vmlal.s16 q12, d16, d1 \n"
"vmlal.s16 q12, d18, d2 \n"
"vmull.s16 q13, d15, d0 \n"
"vmlal.s16 q13, d17, d1 \n"
"vmlal.s16 q13, d19, d2 \n"
"vext.s8 d12, d11, d11, #1 \n"
"vext.s8 d13, d11, d11, #2 \n"
"vmovl.s8 q7, d11 \n"
"vmovl.s8 q8, d12 \n"
"vmovl.s8 q9, d13 \n"
"vmlal.s16 q10, d14, d6 \n"
"vmlal.s16 q10, d16, d7 \n"
"vmlal.s16 q10, d18, d8 \n"
"vmlal.s16 q11, d15, d6 \n"
"vmlal.s16 q11, d17, d7 \n"
"vmlal.s16 q11, d19, d8 \n"
"vmlal.s16 q12, d14, d3 \n"
"vmlal.s16 q12, d16, d4 \n"
"vmlal.s16 q12, d18, d5 \n"
"vmlal.s16 q13, d15, d3 \n"
"vmlal.s16 q13, d17, d4 \n"
"vmlal.s16 q13, d19, d5 \n"
"vld1.32 {d9}, [%[input_ptr3]] \n"
"vext.s8 d12, d9, d9, #1 \n"
"vext.s8 d13, d9, d9, #2 \n"
"vmovl.s8 q7, d9 \n"
"vmovl.s8 q8, d12 \n"
"vmovl.s8 q9, d13 \n"
"vmlal.s16 q12, d14, d6 \n"
"vmlal.s16 q12, d16, d7 \n"
"vmlal.s16 q12, d18, d8 \n"
"vmlal.s16 q13, d15, d6 \n"
"vmlal.s16 q13, d17, d7 \n"
"vmlal.s16 q13, d19, d8 \n"
"cmp %[remain], #4 \n"
"blt store_2h2w_%= \n"
"vst1.32 {q10}, [%[output_ptr0]]! \n"
"vst1.32 {q12}, [%[output_ptr1]]! \n"
"cmp %[remain], #5 \n"
"blt end_%= \n"
"vst1.32 {d22[0]}, [%[output_ptr0]]! \n"
"vst1.32 {d26[0]}, [%[output_ptr1]]! \n"
"b end_%= \n"
"store_2h2w_%=: \n"
"cmp %[remain], #2 \n"
"blt store_2h1w_%= \n"
"vst1.32 {d20}, [%[output_ptr0]]! \n"
"vst1.32 {d24}, [%[output_ptr1]]! \n"
"cmp %[remain], #3 \n"
"blt end_%= \n"
"vst1.32 {d21[0]}, [%[output_ptr0]]! \n"
"vst1.32 {d25[0]}, [%[output_ptr1]]! \n"
"b end_%= \n"
"store_2h1w_%=: \n"
"cmp %[remain], #1 \n"
"blt end_%= \n"
"vst1.32 {d20[0]}, [%[output_ptr0]]! \n"
"vst1.32 {d24[0]}, [%[output_ptr1]]! \n"
"end_%=: \n"
: [output_ptr0] "+r"(output_ptr0), [output_ptr1] "+r"(output_ptr1),
[input_ptr0] "+r"(input_ptr0), [input_ptr1] "+r"(input_ptr1),
[input_ptr2] "+r"(input_ptr2), [input_ptr3] "+r"(input_ptr3)
: [loop] "r"(loop), [remain] "r"(remain)
: "memory", "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7", "q8",
"q9", "q10", "q11", "q12", "q13", "r0");
}
start_h = (input_h - 2) & 0xFFFE;
if (start_h < input_h - 2) {
const int8_t* input_ptr0 = input_ptr + start_h * input_w;
const int8_t* input_ptr1 = input_ptr0 + input_w;
const int8_t* input_ptr2 = input_ptr1 + input_w;
int32_t* output_ptr0 = output_ptr + start_h * output_w;
asm volatile(
"vld1.32 {q0}, [%[filter_ptr]] \n"
"vmovl.s8 q14, d0 \n"
"vmovl.s8 q15, d1 \n"
"vdup.s16 d0, d28[0] \n"
"vdup.s16 d1, d28[1] \n"
"vdup.s16 d2, d28[2] \n"
"vdup.s16 d3, d28[3] \n"
"vdup.s16 d4, d29[0] \n"
"vdup.s16 d5, d29[1] \n"
"vdup.s16 d6, d29[2] \n"
"vdup.s16 d7, d29[3] \n"
"vdup.s16 d8, d30[0] \n"
:
: [filter_ptr] "r"(filter_ptr)
: "cc", "memory", "q0", "q1", "q2", "q3", "q4", "q14", "q15");
asm volatile(
"mov r0, #6 \n"
"cmp %[loop], #0 \n"
"ble start_remain_%= \n"
// loop 6 widths
"loop_1h6w_%=: \n"
"vld1.32 {d9}, [%[input_ptr0]], r0 \n"
"vld1.32 {d10}, [%[input_ptr1]], r0 \n"
"vld1.32 {d11}, [%[input_ptr2]], r0 \n"
"vext.s8 d12, d9, d9, #1 \n"
"vext.s8 d13, d9, d9, #2 \n"
"vmovl.s8 q7, d9 \n"
"vmovl.s8 q8, d12 \n"
"vmovl.s8 q9, d13 \n"
"vmull.s16 q10, d14, d0 \n"
"vmlal.s16 q10, d16, d1 \n"
"vmlal.s16 q10, d18, d2 \n"
"vmull.s16 q11, d15, d0 \n"
"vmlal.s16 q11, d17, d1 \n"
"vmlal.s16 q11, d19, d2 \n"
"vext.s8 d12, d10, d10, #1 \n"
"vext.s8 d13, d10, d10, #2 \n"
"vmovl.s8 q7, d10 \n"
"vmovl.s8 q8, d12 \n"
"vmovl.s8 q9, d13 \n"
"vmlal.s16 q10, d14, d3 \n"
"vmlal.s16 q10, d16, d4 \n"
"vmlal.s16 q10, d18, d5 \n"
"vmlal.s16 q11, d15, d3 \n"
"vmlal.s16 q11, d17, d4 \n"
"vmlal.s16 q11, d19, d5 \n"
"vext.s8 d12, d11, d11, #1 \n"
"vext.s8 d13, d11, d11, #2 \n"
"vmovl.s8 q7, d11 \n"
"vmovl.s8 q8, d12 \n"
"vmovl.s8 q9, d13 \n"
"vmlal.s16 q10, d14, d6 \n"
"vmlal.s16 q10, d16, d7 \n"
"vmlal.s16 q10, d18, d8 \n"
"vmlal.s16 q11, d15, d6 \n"
"vmlal.s16 q11, d17, d7 \n"
"vmlal.s16 q11, d19, d8 \n"
// store row 0, reuse q10/q11
"vst1.32 {d20-d22}, [%[output_ptr0]]! \n"
"subs %[loop], #1 \n"
"bne loop_1h6w_%= \n"
"start_remain_%=: \n"
"cmp %[remain], #0 \n"
"ble end_%= \n"
"vld1.32 {d9}, [%[input_ptr0]] \n"
"vld1.32 {d10}, [%[input_ptr1]] \n"
"vld1.32 {d11}, [%[input_ptr2]] \n"
"vext.s8 d12, d9, d9, #1 \n"
"vext.s8 d13, d9, d9, #2 \n"
"vmovl.s8 q7, d9 \n"
"vmovl.s8 q8, d12 \n"
"vmovl.s8 q9, d13 \n"
"vmull.s16 q10, d14, d0 \n"
"vmlal.s16 q10, d16, d1 \n"
"vmlal.s16 q10, d18, d2 \n"
"vmull.s16 q11, d15, d0 \n"
"vmlal.s16 q11, d17, d1 \n"
"vmlal.s16 q11, d19, d2 \n"
"vext.s8 d12, d10, d10, #1 \n"
"vext.s8 d13, d10, d10, #2 \n"
"vmovl.s8 q7, d10 \n"
"vmovl.s8 q8, d12 \n"
"vmovl.s8 q9, d13 \n"
"vmlal.s16 q10, d14, d3 \n"
"vmlal.s16 q10, d16, d4 \n"
"vmlal.s16 q10, d18, d5 \n"
"vmlal.s16 q11, d15, d3 \n"
"vmlal.s16 q11, d17, d4 \n"
"vmlal.s16 q11, d19, d5 \n"
"vext.s8 d12, d11, d11, #1 \n"
"vext.s8 d13, d11, d11, #2 \n"
"vmovl.s8 q7, d11 \n"
"vmovl.s8 q8, d12 \n"
"vmovl.s8 q9, d13 \n"
"vmlal.s16 q10, d14, d6 \n"
"vmlal.s16 q10, d16, d7 \n"
"vmlal.s16 q10, d18, d8 \n"
"vmlal.s16 q11, d15, d6 \n"
"vmlal.s16 q11, d17, d7 \n"
"vmlal.s16 q11, d19, d8 \n"
"cmp %[remain], #4 \n"
"blt store_1h2w_%= \n"
"vst1.32 {q10}, [%[output_ptr0]]! \n"
"cmp %[remain], #5 \n"
"blt end_%= \n"
"vst1.32 {d22[0]}, [%[output_ptr0]]! \n"
"b end_%= \n"
"store_1h2w_%=: \n"
"cmp %[remain], #2 \n"
"blt store_1h1w_%= \n"
"vst1.32 {d20}, [%[output_ptr0]]! \n"
"cmp %[remain], #3 \n"
"blt end_%= \n"
"vst1.32 {d21[0]}, [%[output_ptr0]]! \n"
"b end_%= \n"
"store_1h1w_%=: \n"
"cmp %[remain], #1 \n"
"blt end_%= \n"
"vst1.32 {d20[0]}, [%[output_ptr0]]! \n"
"end_%=: \n"
: [output_ptr0] "+r"(output_ptr0), [input_ptr0] "+r"(input_ptr0),
[input_ptr1] "+r"(input_ptr1), [input_ptr2] "+r"(input_ptr2)
: [loop] "r"(loop), [remain] "r"(remain)
: "memory", "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7", "q8",
"q9", "q10", "q11", "r0");
}
}
#endif // __aarch64__
}
template <>
void DepthwiseConv3x3s2<int8_t, int32_t>(const framework::Tensor &input,
const framework::Tensor &filter,
framework::Tensor *output) {
const int8_t *input_data = input.data<int8_t>();
const int8_t *filter_data = filter.data<int8_t>();
int32_t *out_data = output->mutable_data<int32_t>();
// make sure that batch size is 1
int input_c = input.dims()[1];
int input_h = input.dims()[2];
int input_w = input.dims()[3];
int output_c = output->dims()[1];
int output_h = output->dims()[2];
int output_w = output->dims()[3];
int image_size = input_h * input_w;
int out_image_size = output_h * output_w;
#if __aarch64__
// TODO(hjchen2)
#else
#pragma omp parallel for
for (int g = 0; g < input_c; ++g) {
const int8_t* input_ptr = input_data + g * image_size;
const int8_t* filter_ptr = filter_data + g * 9;
int32_t* output_ptr = out_data + g * out_image_size;
int loop = (input_w - 2) / 6;
int remain = input_w - 2 - loop * 6;
for (int h = 0; h < input_h - 5 /*(input_h - 2) - 3*/; h += 4) {
const int8_t* input_ptr0 = input_ptr + h * input_w;
const int8_t* input_ptr1 = input_ptr0 + input_w;
const int8_t* input_ptr2 = input_ptr1 + input_w;
const int8_t* input_ptr3 = input_ptr2 + input_w;
const int8_t* input_ptr4 = input_ptr3 + input_w;
const int8_t* input_ptr5 = input_ptr4 + input_w;
int32_t* output_ptr0 = output_ptr + h * output_w;
int32_t* output_ptr1 = output_ptr0 + output_w;
int32_t* output_ptr2 = output_ptr1 + output_w;
int32_t* output_ptr3 = output_ptr2 + output_w;
asm volatile(
"vld1.32 {q0}, [%[filter_ptr]] \n"
"vmovl.s8 q14, d0 \n"
"vmovl.s8 q15, d1 \n"
"vdup.s16 d0, d28[0] \n"
"vdup.s16 d1, d28[1] \n"
"vdup.s16 d2, d28[2] \n"
"vdup.s16 d3, d28[3] \n"
"vdup.s16 d4, d29[0] \n"
"vdup.s16 d5, d29[1] \n"
"vdup.s16 d6, d29[2] \n"
"vdup.s16 d7, d29[3] \n"
"vdup.s16 d8, d30[0] \n"
:
: [filter_ptr] "r"(filter_ptr)
: "cc", "memory", "q0", "q1", "q2", "q3", "q4", "q14", "q15");
asm volatile(
"mov r0, #6 \n"
"cmp %[loop], #0 \n"
"ble start_remain_%= \n"
// loop 6 widths
"loop_4h6w_%=: \n"
"vld1.32 {d9}, [%[input_ptr0]], r0 \n"
"vld1.32 {d10}, [%[input_ptr1]], r0 \n"
"vld1.32 {d11}, [%[input_ptr2]], r0 \n"
"vext.s8 d12, d9, d9, #1 \n"
"vext.s8 d13, d9, d9, #2 \n"
"vmovl.s8 q7, d9 \n"
"vmovl.s8 q8, d12 \n"
"vmovl.s8 q9, d13 \n"
"vmull.s16 q10, d14, d0 \n"
"vmlal.s16 q10, d16, d1 \n"
"vmlal.s16 q10, d18, d2 \n"
"vmull.s16 q11, d15, d0 \n"
"vmlal.s16 q11, d17, d1 \n"
"vmlal.s16 q11, d19, d2 \n"
"vext.s8 d12, d10, d10, #1 \n"
"vext.s8 d13, d10, d10, #2 \n"
"vmovl.s8 q7, d10 \n"
"vmovl.s8 q8, d12 \n"
"vmovl.s8 q9, d13 \n"
"vmlal.s16 q10, d14, d3 \n"
"vmlal.s16 q10, d16, d4 \n"
"vmlal.s16 q10, d18, d5 \n"
"vmlal.s16 q11, d15, d3 \n"
"vmlal.s16 q11, d17, d4 \n"
"vmlal.s16 q11, d19, d5 \n"
"vmull.s16 q12, d14, d0 \n"
"vmlal.s16 q12, d16, d1 \n"
"vmlal.s16 q12, d18, d2 \n"
"vmull.s16 q13, d15, d0 \n"
"vmlal.s16 q13, d17, d1 \n"
"vmlal.s16 q13, d19, d2 \n"
"vext.s8 d12, d11, d11, #1 \n"
"vext.s8 d13, d11, d11, #2 \n"
"vmovl.s8 q7, d11 \n"
"vmovl.s8 q8, d12 \n"
"vmovl.s8 q9, d13 \n"
"vmlal.s16 q10, d14, d6 \n"
"vmlal.s16 q10, d16, d7 \n"
"vmlal.s16 q10, d18, d8 \n"
"vmlal.s16 q11, d15, d6 \n"
"vmlal.s16 q11, d17, d7 \n"
"vmlal.s16 q11, d19, d8 \n"
// store row 0, reuse q10/q11
"vst1.32 {d20-d22}, [%[output_ptr0]]! \n"
"vmlal.s16 q12, d14, d3 \n"
"vmlal.s16 q12, d16, d4 \n"
"vmlal.s16 q12, d18, d5 \n"
"vmlal.s16 q13, d15, d3 \n"
"vmlal.s16 q13, d17, d4 \n"
"vmlal.s16 q13, d19, d5 \n"
"vmull.s16 q14, d14, d0 \n"
"vmlal.s16 q14, d16, d1 \n"
"vmlal.s16 q14, d18, d2 \n"
"vmull.s16 q15, d15, d0 \n"
"vmlal.s16 q15, d17, d1 \n"
"vmlal.s16 q15, d19, d2 \n"
"vld1.32 {d9}, [%[input_ptr3]], r0 \n"
"vld1.32 {d10}, [%[input_ptr4]], r0 \n"
"vld1.32 {d11}, [%[input_ptr5]], r0 \n"
"vext.s8 d12, d9, d9, #1 \n"
"vext.s8 d13, d9, d9, #2 \n"
"vmovl.s8 q7, d9 \n"
"vmovl.s8 q8, d12 \n"
"vmovl.s8 q9, d13 \n"
"vmlal.s16 q12, d14, d6 \n"
"vmlal.s16 q12, d16, d7 \n"
"vmlal.s16 q12, d18, d8 \n"
"vmlal.s16 q13, d15, d6 \n"
"vmlal.s16 q13, d17, d7 \n"
"vmlal.s16 q13, d19, d8 \n"
// store row 1
"vst1.32 {d24-d26}, [%[output_ptr1]]! \n"
"vmlal.s16 q14, d14, d3 \n"
"vmlal.s16 q14, d16, d4 \n"
"vmlal.s16 q14, d18, d5 \n"
"vmlal.s16 q15, d15, d3 \n"
"vmlal.s16 q15, d17, d4 \n"
"vmlal.s16 q15, d19, d5 \n"
"vmull.s16 q10, d14, d0 \n"
"vmlal.s16 q10, d16, d1 \n"
"vmlal.s16 q10, d18, d2 \n"
"vmull.s16 q11, d15, d0 \n"
"vmlal.s16 q11, d17, d1 \n"
"vmlal.s16 q11, d19, d2 \n"
"vext.s8 d12, d10, d10, #1 \n"
"vext.s8 d13, d10, d10, #2 \n"
"vmovl.s8 q7, d10 \n"
"vmovl.s8 q8, d12 \n"
"vmovl.s8 q9, d13 \n"
"vmlal.s16 q14, d14, d6 \n"
"vmlal.s16 q14, d16, d7 \n"
"vmlal.s16 q14, d18, d8 \n"
"vmlal.s16 q15, d15, d6 \n"
"vmlal.s16 q15, d17, d7 \n"
"vmlal.s16 q15, d19, d8 \n"
// store row 2
"vst1.32 {d28-d30}, [%[output_ptr2]]! \n"
"vmlal.s16 q10, d14, d3 \n"
"vmlal.s16 q10, d16, d4 \n"
"vmlal.s16 q10, d18, d5 \n"
"vmlal.s16 q11, d15, d3 \n"
"vmlal.s16 q11, d17, d4 \n"
"vmlal.s16 q11, d19, d5 \n"
"vext.s8 d12, d11, d11, #1 \n"
"vext.s8 d13, d11, d11, #2 \n"
"vmovl.s8 q7, d11 \n"
"vmovl.s8 q8, d12 \n"
"vmovl.s8 q9, d13 \n"
"vmlal.s16 q10, d14, d6 \n"
"vmlal.s16 q10, d16, d7 \n"
"vmlal.s16 q10, d18, d8 \n"
"vmlal.s16 q11, d15, d6 \n"
"vmlal.s16 q11, d17, d7 \n"
"vmlal.s16 q11, d19, d8 \n"
// store row 3
"vst1.32 {d20-d22}, [%[output_ptr3]]! \n"
"subs %[loop], #1 \n"
"bne loop_4h6w_%= \n"
"start_remain_%=: \n"
"cmp %[remain], #0 \n"
"ble end_%= \n"
"vld1.32 {d9}, [%[input_ptr0]] \n"
"vmovl.s8 q7, d9 \n"
"vext.s8 d9, d9, d9, #1 \n"
"vmovl.s8 q8, d9 \n"
"vext.s8 d9, d9, d9, #1 \n"
"vmovl.s8 q9, d9 \n"
"vmull.s16 q10, d14, d0 \n"
"vmlal.s16 q10, d16, d1 \n"
"vmlal.s16 q10, d18, d2 \n"
"vld1.32 {d9}, [%[input_ptr1]] \n"
"vmull.s16 q11, d15, d0 \n"
"vmlal.s16 q11, d17, d1 \n"
"vmlal.s16 q11, d19, d2 \n"
"vmovl.s8 q7, d9 \n"
"vext.s8 d9, d9, d9, #1 \n"
"vmovl.s8 q8, d9 \n"
"vext.s8 d9, d9, d9, #1 \n"
"vmovl.s8 q9, d9 \n"
"vmlal.s16 q10, d14, d3 \n"
"vmlal.s16 q10, d16, d4 \n"
"vmlal.s16 q10, d18, d5 \n"
"vmlal.s16 q11, d15, d3 \n"
"vmlal.s16 q11, d17, d4 \n"
"vmlal.s16 q11, d19, d5 \n"
"vmull.s16 q12, d14, d0 \n"
"vmlal.s16 q12, d16, d1 \n"
"vmlal.s16 q12, d18, d2 \n"
"vld1.32 {d9}, [%[input_ptr2]] \n"
"vmull.s16 q13, d15, d0 \n"
"vmlal.s16 q13, d17, d1 \n"
"vmlal.s16 q13, d19, d2 \n"
"vmovl.s8 q7, d9 \n"
"vext.s8 d9, d9, d9, #1 \n"
"vmovl.s8 q8, d9 \n"
"vext.s8 d9, d9, d9, #1 \n"
"vmovl.s8 q9, d9 \n"
"vmlal.s16 q10, d14, d6 \n"
"vmlal.s16 q10, d16, d7 \n"
"vmlal.s16 q10, d18, d8 \n"
"vmlal.s16 q11, d15, d6 \n"
"vmlal.s16 q11, d17, d7 \n"
"vmlal.s16 q11, d19, d8 \n"
"vmlal.s16 q12, d14, d3 \n"
"vmlal.s16 q12, d16, d4 \n"
"vmlal.s16 q12, d18, d5 \n"
"vmlal.s16 q13, d15, d3 \n"
"vmlal.s16 q13, d17, d4 \n"
"vmlal.s16 q13, d19, d5 \n"
"vmull.s16 q14, d14, d0 \n"
"vmlal.s16 q14, d16, d1 \n"
"vmlal.s16 q14, d18, d2 \n"
"vld1.32 {d9}, [%[input_ptr3]] \n"
"vmull.s16 q15, d15, d0 \n"
"vmlal.s16 q15, d17, d1 \n"
"vmlal.s16 q15, d19, d2 \n"
"vmovl.s8 q7, d9 \n"
"vext.s8 d9, d9, d9, #1 \n"
"vmovl.s8 q8, d9 \n"
"vext.s8 d9, d9, d9, #1 \n"
"vmovl.s8 q9, d9 \n"
"vmlal.s16 q12, d14, d6 \n"
"vmlal.s16 q12, d16, d7 \n"
"vmlal.s16 q12, d18, d8 \n"
"vmlal.s16 q13, d15, d6 \n"
"vmlal.s16 q13, d17, d7 \n"
"vmlal.s16 q13, d19, d8 \n"
"vmlal.s16 q14, d14, d3 \n"
"vmlal.s16 q14, d16, d4 \n"
"vmlal.s16 q14, d18, d5 \n"
"vmlal.s16 q15, d15, d3 \n"
"vmlal.s16 q15, d17, d4 \n"
"vmlal.s16 q15, d19, d5 \n"
"vmull.s16 q5, d14, d0 \n"
"vmlal.s16 q5, d16, d1 \n"
"vmlal.s16 q5, d18, d2 \n"
"vld1.32 {d9}, [%[input_ptr4]] \n"
"vmull.s16 q6, d15, d0 \n"
"vmlal.s16 q6, d17, d1 \n"
"vmlal.s16 q6, d19, d2 \n"
"vmovl.s8 q7, d9 \n"
"vext.s8 d9, d9, d9, #1 \n"
"vmovl.s8 q8, d9 \n"
"vext.s8 d9, d9, d9, #1 \n"
"vmovl.s8 q9, d9 \n"
"vmlal.s16 q14, d14, d6 \n"
"vmlal.s16 q14, d16, d7 \n"
"vmlal.s16 q14, d18, d8 \n"
"vmlal.s16 q15, d15, d6 \n"
"vmlal.s16 q15, d17, d7 \n"
"vmlal.s16 q15, d19, d8 \n"
"vmlal.s16 q5, d14, d3 \n"
"vmlal.s16 q5, d16, d4 \n"
"vmlal.s16 q5, d18, d5 \n"
"vld1.32 {d9}, [%[input_ptr5]] \n"
"vmlal.s16 q6, d15, d3 \n"
"vmlal.s16 q6, d17, d4 \n"
"vmlal.s16 q6, d19, d5 \n"
"vmovl.s8 q7, d9 \n"
"vext.s8 d9, d9, d9, #1 \n"
"vmovl.s8 q8, d9 \n"
"vext.s8 d9, d9, d9, #1 \n"
"vmovl.s8 q9, d9 \n"
"vmlal.s16 q5, d14, d6 \n"
"vmlal.s16 q5, d16, d7 \n"
"vmlal.s16 q5, d18, d8 \n"
"vmlal.s16 q6, d15, d6 \n"
"vmlal.s16 q6, d17, d7 \n"
"vmlal.s16 q6, d19, d8 \n"
"cmp %[remain], #4 \n"
"blt store_4h2w_%= \n"
"vst1.32 {q10}, [%[output_ptr0]]! \n"
"vst1.32 {q12}, [%[output_ptr1]]! \n"
"vst1.32 {q14}, [%[output_ptr2]]! \n"
"vst1.32 {q5}, [%[output_ptr3]]! \n"
"cmp %[remain], #5 \n"
"blt end_%= \n"
"vst1.32 {d22[0]}, [%[output_ptr0]]! \n"
"vst1.32 {d26[0]}, [%[output_ptr1]]! \n"
"vst1.32 {d30[0]}, [%[output_ptr2]]! \n"
"vst1.32 {d12[0]}, [%[output_ptr3]]! \n"
"b end_%= \n"
"store_4h2w_%=: \n"
"cmp %[remain], #2 \n"
"blt store_4h1w_%= \n"
"vst1.32 {d20}, [%[output_ptr0]]! \n"
"vst1.32 {d24}, [%[output_ptr1]]! \n"
"vst1.32 {d28}, [%[output_ptr2]]! \n"
"vst1.32 {d10}, [%[output_ptr3]]! \n"
"cmp %[remain], #3 \n"
"blt end_%= \n"
"vst1.32 {d21[0]}, [%[output_ptr0]]! \n"
"vst1.32 {d25[0]}, [%[output_ptr1]]! \n"
"vst1.32 {d29[0]}, [%[output_ptr2]]! \n"
"vst1.32 {d11[0]}, [%[output_ptr3]]! \n"
"b end_%= \n"
"store_4h1w_%=: \n"
"cmp %[remain], #1 \n"
"blt end_%= \n"
"vst1.32 {d20[0]}, [%[output_ptr0]]! \n"
"vst1.32 {d24[0]}, [%[output_ptr1]]! \n"
"vst1.32 {d28[0]}, [%[output_ptr2]]! \n"
"vst1.32 {d10[0]}, [%[output_ptr3]]! \n"
"end_%=: \n"
: [output_ptr0] "+r"(output_ptr0), [output_ptr1] "+r"(output_ptr1),
[output_ptr2] "+r"(output_ptr2), [output_ptr3] "+r"(output_ptr3),
[input_ptr0] "+r"(input_ptr0), [input_ptr1] "+r"(input_ptr1),
[input_ptr2] "+r"(input_ptr2), [input_ptr3] "+r"(input_ptr3),
[input_ptr4] "+r"(input_ptr4), [input_ptr5] "+r"(input_ptr5)
: [loop] "r"(loop), [remain] "r"(remain)
: "memory", "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7", "q8",
"q9", "q10", "q11", "q12", "q13", "q14", "q15", "r0");
}
// remain height
int start_h = (input_h - 2) & 0xFFFC;
for (int h = start_h; h < input_h - 3 /*(input_h - 2) - 1*/; h += 2) {
const int8_t* input_ptr0 = input_ptr + h * input_w;
const int8_t* input_ptr1 = input_ptr0 + input_w;
const int8_t* input_ptr2 = input_ptr1 + input_w;
const int8_t* input_ptr3 = input_ptr2 + input_w;
int32_t* output_ptr0 = output_ptr + h * output_w;
int32_t* output_ptr1 = output_ptr0 + output_w;
asm volatile(
"vld1.32 {q0}, [%[filter_ptr]] \n"
"vmovl.s8 q14, d0 \n"
"vmovl.s8 q15, d1 \n"
"vdup.s16 d0, d28[0] \n"
"vdup.s16 d1, d28[1] \n"
"vdup.s16 d2, d28[2] \n"
"vdup.s16 d3, d28[3] \n"
"vdup.s16 d4, d29[0] \n"
"vdup.s16 d5, d29[1] \n"
"vdup.s16 d6, d29[2] \n"
"vdup.s16 d7, d29[3] \n"
"vdup.s16 d8, d30[0] \n"
:
: [filter_ptr] "r"(filter_ptr)
: "cc", "memory", "q0", "q1", "q2", "q3", "q4", "q14", "q15");
asm volatile(
"mov r0, #6 \n"
"cmp %[loop], #0 \n"
"ble start_remain_%= \n"
// loop 6 widths
"loop_2h6w_%=: \n"
"vld1.32 {d9}, [%[input_ptr0]], r0 \n"
"vld1.32 {d10}, [%[input_ptr1]], r0 \n"
"vld1.32 {d11}, [%[input_ptr2]], r0 \n"
"vext.s8 d12, d9, d9, #1 \n"
"vext.s8 d13, d9, d9, #2 \n"
"vmovl.s8 q7, d9 \n"
"vmovl.s8 q8, d12 \n"
"vmovl.s8 q9, d13 \n"
"vmull.s16 q10, d14, d0 \n"
"vmlal.s16 q10, d16, d1 \n"
"vmlal.s16 q10, d18, d2 \n"
"vmull.s16 q11, d15, d0 \n"
"vmlal.s16 q11, d17, d1 \n"
"vmlal.s16 q11, d19, d2 \n"
"vext.s8 d12, d10, d10, #1 \n"
"vext.s8 d13, d10, d10, #2 \n"
"vmovl.s8 q7, d10 \n"
"vmovl.s8 q8, d12 \n"
"vmovl.s8 q9, d13 \n"
"vmlal.s16 q10, d14, d3 \n"
"vmlal.s16 q10, d16, d4 \n"
"vmlal.s16 q10, d18, d5 \n"
"vmlal.s16 q11, d15, d3 \n"
"vmlal.s16 q11, d17, d4 \n"
"vmlal.s16 q11, d19, d5 \n"
"vmull.s16 q12, d14, d0 \n"
"vmlal.s16 q12, d16, d1 \n"
"vmlal.s16 q12, d18, d2 \n"
"vmull.s16 q13, d15, d0 \n"
"vmlal.s16 q13, d17, d1 \n"
"vmlal.s16 q13, d19, d2 \n"
"vext.s8 d12, d11, d11, #1 \n"
"vext.s8 d13, d11, d11, #2 \n"
"vmovl.s8 q7, d11 \n"
"vmovl.s8 q8, d12 \n"
"vmovl.s8 q9, d13 \n"
"vmlal.s16 q10, d14, d6 \n"
"vmlal.s16 q10, d16, d7 \n"
"vmlal.s16 q10, d18, d8 \n"
"vmlal.s16 q11, d15, d6 \n"
"vmlal.s16 q11, d17, d7 \n"
"vmlal.s16 q11, d19, d8 \n"
// store row 0, reuse q10/q11
"vst1.32 {d20-d22}, [%[output_ptr0]]! \n"
"vmlal.s16 q12, d14, d3 \n"
"vmlal.s16 q12, d16, d4 \n"
"vmlal.s16 q12, d18, d5 \n"
"vmlal.s16 q13, d15, d3 \n"
"vmlal.s16 q13, d17, d4 \n"
"vmlal.s16 q13, d19, d5 \n"
"vld1.32 {d9}, [%[input_ptr3]], r0 \n"
"vext.s8 d12, d9, d9, #1 \n"
"vext.s8 d13, d9, d9, #2 \n"
"vmovl.s8 q7, d9 \n"
"vmovl.s8 q8, d12 \n"
"vmovl.s8 q9, d13 \n"
"vmlal.s16 q12, d14, d6 \n"
"vmlal.s16 q12, d16, d7 \n"
"vmlal.s16 q12, d18, d8 \n"
"vmlal.s16 q13, d15, d6 \n"
"vmlal.s16 q13, d17, d7 \n"
"vmlal.s16 q13, d19, d8 \n"
// store row 1
"vst1.32 {d24-d26}, [%[output_ptr1]]! \n"
"subs %[loop], #1 \n"
"bne loop_2h6w_%= \n"
"start_remain_%=: \n"
"cmp %[remain], #0 \n"
"ble end_%= \n"
"vld1.32 {d9}, [%[input_ptr0]] \n"
"vld1.32 {d10}, [%[input_ptr1]] \n"
"vld1.32 {d11}, [%[input_ptr2]] \n"
"vext.s8 d12, d9, d9, #1 \n"
"vext.s8 d13, d9, d9, #2 \n"
"vmovl.s8 q7, d9 \n"
"vmovl.s8 q8, d12 \n"
"vmovl.s8 q9, d13 \n"
"vmull.s16 q10, d14, d0 \n"
"vmlal.s16 q10, d16, d1 \n"
"vmlal.s16 q10, d18, d2 \n"
"vmull.s16 q11, d15, d0 \n"
"vmlal.s16 q11, d17, d1 \n"
"vmlal.s16 q11, d19, d2 \n"
"vext.s8 d12, d10, d10, #1 \n"
"vext.s8 d13, d10, d10, #2 \n"
"vmovl.s8 q7, d10 \n"
"vmovl.s8 q8, d12 \n"
"vmovl.s8 q9, d13 \n"
"vmlal.s16 q10, d14, d3 \n"
"vmlal.s16 q10, d16, d4 \n"
"vmlal.s16 q10, d18, d5 \n"
"vmlal.s16 q11, d15, d3 \n"
"vmlal.s16 q11, d17, d4 \n"
"vmlal.s16 q11, d19, d5 \n"
"vmull.s16 q12, d14, d0 \n"
"vmlal.s16 q12, d16, d1 \n"
"vmlal.s16 q12, d18, d2 \n"
"vmull.s16 q13, d15, d0 \n"
"vmlal.s16 q13, d17, d1 \n"
"vmlal.s16 q13, d19, d2 \n"
"vext.s8 d12, d11, d11, #1 \n"
"vext.s8 d13, d11, d11, #2 \n"
"vmovl.s8 q7, d11 \n"
"vmovl.s8 q8, d12 \n"
"vmovl.s8 q9, d13 \n"
"vmlal.s16 q10, d14, d6 \n"
"vmlal.s16 q10, d16, d7 \n"
"vmlal.s16 q10, d18, d8 \n"
"vmlal.s16 q11, d15, d6 \n"
"vmlal.s16 q11, d17, d7 \n"
"vmlal.s16 q11, d19, d8 \n"
"vmlal.s16 q12, d14, d3 \n"
"vmlal.s16 q12, d16, d4 \n"
"vmlal.s16 q12, d18, d5 \n"
"vmlal.s16 q13, d15, d3 \n"
"vmlal.s16 q13, d17, d4 \n"
"vmlal.s16 q13, d19, d5 \n"
"vld1.32 {d9}, [%[input_ptr3]] \n"
"vext.s8 d12, d9, d9, #1 \n"
"vext.s8 d13, d9, d9, #2 \n"
"vmovl.s8 q7, d9 \n"
"vmovl.s8 q8, d12 \n"
"vmovl.s8 q9, d13 \n"
"vmlal.s16 q12, d14, d6 \n"
"vmlal.s16 q12, d16, d7 \n"
"vmlal.s16 q12, d18, d8 \n"
"vmlal.s16 q13, d15, d6 \n"
"vmlal.s16 q13, d17, d7 \n"
"vmlal.s16 q13, d19, d8 \n"
"cmp %[remain], #4 \n"
"blt store_2h2w_%= \n"
"vst1.32 {q10}, [%[output_ptr0]]! \n"
"vst1.32 {q12}, [%[output_ptr1]]! \n"
"cmp %[remain], #5 \n"
"blt end_%= \n"
"vst1.32 {d22[0]}, [%[output_ptr0]]! \n"
"vst1.32 {d26[0]}, [%[output_ptr1]]! \n"
"b end_%= \n"
"store_2h2w_%=: \n"
"cmp %[remain], #2 \n"
"blt store_2h1w_%= \n"
"vst1.32 {d20}, [%[output_ptr0]]! \n"
"vst1.32 {d24}, [%[output_ptr1]]! \n"
"cmp %[remain], #3 \n"
"blt end_%= \n"
"vst1.32 {d21[0]}, [%[output_ptr0]]! \n"
"vst1.32 {d25[0]}, [%[output_ptr1]]! \n"
"b end_%= \n"
"store_2h1w_%=: \n"
"cmp %[remain], #1 \n"
"blt end_%= \n"
"vst1.32 {d20[0]}, [%[output_ptr0]]! \n"
"vst1.32 {d24[0]}, [%[output_ptr1]]! \n"
"end_%=: \n"
: [output_ptr0] "+r"(output_ptr0), [output_ptr1] "+r"(output_ptr1),
[input_ptr0] "+r"(input_ptr0), [input_ptr1] "+r"(input_ptr1),
[input_ptr2] "+r"(input_ptr2), [input_ptr3] "+r"(input_ptr3)
: [loop] "r"(loop), [remain] "r"(remain)
: "memory", "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7", "q8",
"q9", "q10", "q11", "q12", "q13", "r0");
}
start_h = (input_h - 2) & 0xFFFE;
if (start_h < input_h - 2) {
const int8_t* input_ptr0 = input_ptr + start_h * input_w;
const int8_t* input_ptr1 = input_ptr0 + input_w;
const int8_t* input_ptr2 = input_ptr1 + input_w;
int32_t* output_ptr0 = output_ptr + start_h * output_w;
asm volatile(
"vld1.32 {q0}, [%[filter_ptr]] \n"
"vmovl.s8 q14, d0 \n"
"vmovl.s8 q15, d1 \n"
"vdup.s16 d0, d28[0] \n"
"vdup.s16 d1, d28[1] \n"
"vdup.s16 d2, d28[2] \n"
"vdup.s16 d3, d28[3] \n"
"vdup.s16 d4, d29[0] \n"
"vdup.s16 d5, d29[1] \n"
"vdup.s16 d6, d29[2] \n"
"vdup.s16 d7, d29[3] \n"
"vdup.s16 d8, d30[0] \n"
:
: [filter_ptr] "r"(filter_ptr)
: "cc", "memory", "q0", "q1", "q2", "q3", "q4", "q14", "q15");
asm volatile(
"mov r0, #6 \n"
"cmp %[loop], #0 \n"
"ble start_remain_%= \n"
// loop 6 widths
"loop_1h6w_%=: \n"
"vld1.32 {d9}, [%[input_ptr0]], r0 \n"
"vld1.32 {d10}, [%[input_ptr1]], r0 \n"
"vld1.32 {d11}, [%[input_ptr2]], r0 \n"
"vext.s8 d12, d9, d9, #1 \n"
"vext.s8 d13, d9, d9, #2 \n"
"vmovl.s8 q7, d9 \n"
"vmovl.s8 q8, d12 \n"
"vmovl.s8 q9, d13 \n"
"vmull.s16 q10, d14, d0 \n"
"vmlal.s16 q10, d16, d1 \n"
"vmlal.s16 q10, d18, d2 \n"
"vmull.s16 q11, d15, d0 \n"
"vmlal.s16 q11, d17, d1 \n"
"vmlal.s16 q11, d19, d2 \n"
"vext.s8 d12, d10, d10, #1 \n"
"vext.s8 d13, d10, d10, #2 \n"
"vmovl.s8 q7, d10 \n"
"vmovl.s8 q8, d12 \n"
"vmovl.s8 q9, d13 \n"
"vmlal.s16 q10, d14, d3 \n"
"vmlal.s16 q10, d16, d4 \n"
"vmlal.s16 q10, d18, d5 \n"
"vmlal.s16 q11, d15, d3 \n"
"vmlal.s16 q11, d17, d4 \n"
"vmlal.s16 q11, d19, d5 \n"
"vext.s8 d12, d11, d11, #1 \n"
"vext.s8 d13, d11, d11, #2 \n"
"vmovl.s8 q7, d11 \n"
"vmovl.s8 q8, d12 \n"
"vmovl.s8 q9, d13 \n"
"vmlal.s16 q10, d14, d6 \n"
"vmlal.s16 q10, d16, d7 \n"
"vmlal.s16 q10, d18, d8 \n"
"vmlal.s16 q11, d15, d6 \n"
"vmlal.s16 q11, d17, d7 \n"
"vmlal.s16 q11, d19, d8 \n"
// store row 0, reuse q10/q11
"vst1.32 {d20-d22}, [%[output_ptr0]]! \n"
"subs %[loop], #1 \n"
"bne loop_1h6w_%= \n"
"start_remain_%=: \n"
"cmp %[remain], #0 \n"
"ble end_%= \n"
"vld1.32 {d9}, [%[input_ptr0]] \n"
"vld1.32 {d10}, [%[input_ptr1]] \n"
"vld1.32 {d11}, [%[input_ptr2]] \n"
"vext.s8 d12, d9, d9, #1 \n"
"vext.s8 d13, d9, d9, #2 \n"
"vmovl.s8 q7, d9 \n"
"vmovl.s8 q8, d12 \n"
"vmovl.s8 q9, d13 \n"
"vmull.s16 q10, d14, d0 \n"
"vmlal.s16 q10, d16, d1 \n"
"vmlal.s16 q10, d18, d2 \n"
"vmull.s16 q11, d15, d0 \n"
"vmlal.s16 q11, d17, d1 \n"
"vmlal.s16 q11, d19, d2 \n"
"vext.s8 d12, d10, d10, #1 \n"
"vext.s8 d13, d10, d10, #2 \n"
"vmovl.s8 q7, d10 \n"
"vmovl.s8 q8, d12 \n"
"vmovl.s8 q9, d13 \n"
"vmlal.s16 q10, d14, d3 \n"
"vmlal.s16 q10, d16, d4 \n"
"vmlal.s16 q10, d18, d5 \n"
"vmlal.s16 q11, d15, d3 \n"
"vmlal.s16 q11, d17, d4 \n"
"vmlal.s16 q11, d19, d5 \n"
"vext.s8 d12, d11, d11, #1 \n"
"vext.s8 d13, d11, d11, #2 \n"
"vmovl.s8 q7, d11 \n"
"vmovl.s8 q8, d12 \n"
"vmovl.s8 q9, d13 \n"
"vmlal.s16 q10, d14, d6 \n"
"vmlal.s16 q10, d16, d7 \n"
"vmlal.s16 q10, d18, d8 \n"
"vmlal.s16 q11, d15, d6 \n"
"vmlal.s16 q11, d17, d7 \n"
"vmlal.s16 q11, d19, d8 \n"
"cmp %[remain], #4 \n"
"blt store_1h2w_%= \n"
"vst1.32 {q10}, [%[output_ptr0]]! \n"
"cmp %[remain], #5 \n"
"blt end_%= \n"
"vst1.32 {d22[0]}, [%[output_ptr0]]! \n"
"b end_%= \n"
"store_1h2w_%=: \n"
"cmp %[remain], #2 \n"
"blt store_1h1w_%= \n"
"vst1.32 {d20}, [%[output_ptr0]]! \n"
"cmp %[remain], #3 \n"
"blt end_%= \n"
"vst1.32 {d21[0]}, [%[output_ptr0]]! \n"
"b end_%= \n"
"store_1h1w_%=: \n"
"cmp %[remain], #1 \n"
"blt end_%= \n"
"vst1.32 {d20[0]}, [%[output_ptr0]]! \n"
"end_%=: \n"
: [output_ptr0] "+r"(output_ptr0), [input_ptr0] "+r"(input_ptr0),
[input_ptr1] "+r"(input_ptr1), [input_ptr2] "+r"(input_ptr2)
: [loop] "r"(loop), [remain] "r"(remain)
: "memory", "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7", "q8",
"q9", "q10", "q11", "r0");
} }
} }
#endif // __aarch64__ #endif // __aarch64__
......
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <vector>
#include "framework/tensor.h"
namespace paddle_mobile {
namespace operators {
namespace math {
void DepthwiseConv3x3_int8(const framework::Tensor *input,
const framework::Tensor *filter,
const std::vector<int> &strides,
framework::Tensor *output);
void DepthwiseConv3x3s1_int8(const framework::Tensor *input,
const framework::Tensor *filter,
framework::Tensor *output);
void DepthwiseConv3x3s2_int8(const framework::Tensor *input,
const framework::Tensor *filter,
framework::Tensor *output);
} // namespace math
} // namespace operators
} // namespace paddle_mobile
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <algorithm>
#include <vector>
#include "framework/tensor.h"
#include "operators/math/conv_func.h"
namespace paddle_mobile {
namespace operators {
namespace math {
using framework::Tensor;
using std::max;
using std::min;
using std::vector;
void DepthwiseConv3x3(const Tensor *input, vector<int> strides,
vector<int> paddings, const Tensor *filter, Tensor *bias,
Tensor *output, bool if_bias);
void DepthwiseConv3x3s1p1(const Tensor *input, const Tensor *filter,
Tensor *output, Tensor *bias, bool if_bias);
void DepthwiseConvAddBNRelu3x3s1p1(const Tensor *input, const Tensor *filter,
Tensor *output, const Tensor *new_scale,
const Tensor *new_bias, bool if_relu);
void DepthwiseConvAddBNRelu3x3s2p1(const Tensor *input, const Tensor *filter,
Tensor *output, const Tensor *new_scale,
const Tensor *new_bias, bool if_relu);
void DepthwiseConv3x3s2p1v2(const Tensor *input, const Tensor *filter,
Tensor *output, Tensor bias, bool if_bias);
void DepthwiseConvAddBNRelu3x3s2p1v2(const Tensor *input, const Tensor *filter,
Tensor *output, const Tensor *new_scale,
const Tensor *new_bias, bool if_relu);
void DepthwiseConv3x3s2p0(const Tensor *input, const Tensor *filter,
Tensor *output, Tensor bias, bool if_bias);
} // namespace math
} // namespace operators
} // namespace paddle_mobile
...@@ -26,79 +26,6 @@ limitations under the License. */ ...@@ -26,79 +26,6 @@ limitations under the License. */
namespace paddle_mobile { namespace paddle_mobile {
namespace operators { namespace operators {
namespace math { namespace math {
/*int MC = 0;
int KC = 0;
int NC = 0;
float *packedA;
float *packedB;
float *packedC;
float *zero;
typedef void (*FnPack)(int, int, int, const float *, int, float *);
typedef void (*FnAddDot)(int, const float *, const float *, float *, int);
FnPack procPackA;
FnPack procPackB;
FnAddDot procAddDot;*/
/*
// 将A矩阵分块复制到连续内存(ColMajor)
void PackMatrixA(int m, int k, int m_tail, const float *A, int lda,
float *buffer) {
int i, j;
const float *Aij;
for (i = 0; i < m - m_tail; i += MR) {
for (j = 0; j < k; ++j) {
Aij = &A(i, j);
*buffer++ = *Aij;
*buffer++ = *(Aij + 1);
*buffer++ = *(Aij + 2);
*buffer++ = *(Aij + 3);
}
}
if (m_tail != 0) {
for (j = 0; j < k; ++j) {
Aij = &A(m - m_tail, j);
for (i = 0; i < m_tail; ++i) {
*buffer++ = *(Aij + i);
}
for (i = m_tail; i < MR; ++i) {
*buffer++ = 0;
}
}
}
}
// 将B矩阵分块复制到连续内存(ColMajor)
void PackMatrixB(int k, int n, int n_tail, const float *B, int ldb,
float *buffer) {
int i, j;
const float *Bj, *Bj1, *Bj2, *Bj3;
for (j = 0; j < n - n_tail; j += NR) {
Bj = &B(0, j);
Bj1 = &B(0, j + 1);
Bj2 = &B(0, j + 2);
Bj3 = &B(0, j + 3);
for (i = 0; i < k; ++i) {
*buffer++ = *Bj++;
*buffer++ = *Bj1++;
*buffer++ = *Bj2++;
*buffer++ = *Bj3++;
}
}
if (n_tail != 0) {
for (i = 0; i < k; ++i) {
for (int j = n - n_tail; j < n; ++j) {
*buffer++ = B(i, j);
}
for (int j = n; j < n + (NR - n_tail); ++j) {
*buffer++ = 0;
}
}
}
}
*/
// 将A矩阵分块复制到连续内存(RowMajor) // 将A矩阵分块复制到连续内存(RowMajor)
void Gemm::PackMatrixA_4r(int m, int k, int m_tail, const float *A, int lda, void Gemm::PackMatrixA_4r(int m, int k, int m_tail, const float *A, int lda,
......
...@@ -423,6 +423,7 @@ class ConvParam : public OpParam { ...@@ -423,6 +423,7 @@ class ConvParam : public OpParam {
EXEC_WINOGRAD3X3_FLOAT, EXEC_WINOGRAD3X3_FLOAT,
EXEC_WINOGRAD5X5_FLOAT, EXEC_WINOGRAD5X5_FLOAT,
EXEC_GEMM_INT8, EXEC_GEMM_INT8,
EXEC_DEPTHWISE3x3_INT8,
}; };
ExecMode &ExecMode() const { return exec_mode_; } ExecMode &ExecMode() const { return exec_mode_; }
...@@ -2498,7 +2499,7 @@ class QuantizeParam : public OpParam { ...@@ -2498,7 +2499,7 @@ class QuantizeParam : public OpParam {
QuantizeParam(const VariableNameMap &inputs, const VariableNameMap &outputs, QuantizeParam(const VariableNameMap &inputs, const VariableNameMap &outputs,
const AttributeMap &attrs, const Scope &scope) { const AttributeMap &attrs, const Scope &scope) {
input_ = InputXFrom<GType>(inputs, scope); input_ = InputXFrom<GType>(inputs, scope);
out_ = OutFrom<GType>(outputs, scope); output_ = OutFrom<GType>(outputs, scope);
// online // online
// scale = max(abs(x)) // scale = max(abs(x))
online_scale_ = GetVarValue<GType>("OutScale", outputs, scope); online_scale_ = GetVarValue<GType>("OutScale", outputs, scope);
...@@ -2517,8 +2518,7 @@ class QuantizeParam : public OpParam { ...@@ -2517,8 +2518,7 @@ class QuantizeParam : public OpParam {
// op input // op input
RType *input_; RType *input_;
// op output // op output
RType *out_; RType *output_;
//
RType *online_scale_; RType *online_scale_;
// if static scale or not // if static scale or not
bool is_static_ = false; bool is_static_ = false;
...@@ -2526,7 +2526,11 @@ class QuantizeParam : public OpParam { ...@@ -2526,7 +2526,11 @@ class QuantizeParam : public OpParam {
float static_scale_ = 1.0f; float static_scale_ = 1.0f;
// round method type // round method type
// nearest_zero and nearest_even is valid currently // nearest_zero and nearest_even is valid currently
RoundType round_type_ = ROUND_NEAREST_AWAY_ZERO; // RoundType round_type_ = ROUND_NEAREST_AWAY_ZERO;
RoundType round_type_ = ROUND_NEAREST_TOWARDS_ZERO;
// optional paddings
std::vector<int> paddings_;
int8_t padding_val_;
}; };
#endif #endif
...@@ -2540,7 +2544,7 @@ class DequantizeParam : public OpParam { ...@@ -2540,7 +2544,7 @@ class DequantizeParam : public OpParam {
DequantizeParam(const VariableNameMap &inputs, const VariableNameMap &outputs, DequantizeParam(const VariableNameMap &inputs, const VariableNameMap &outputs,
const AttributeMap &attrs, const Scope &scope) { const AttributeMap &attrs, const Scope &scope) {
input_ = InputXFrom<GType>(inputs, scope); input_ = InputXFrom<GType>(inputs, scope);
out_ = OutFrom<GType>(outputs, scope); output_ = OutFrom<GType>(outputs, scope);
activation_scale_ = GetVarValue<GType>("Scale", inputs, scope); activation_scale_ = GetVarValue<GType>("Scale", inputs, scope);
// dequantization is performed as x = x / static_scale / online_scale // dequantization is performed as x = x / static_scale / online_scale
if (HasAttr("weight_scale", attrs)) { if (HasAttr("weight_scale", attrs)) {
...@@ -2554,11 +2558,32 @@ class DequantizeParam : public OpParam { ...@@ -2554,11 +2558,32 @@ class DequantizeParam : public OpParam {
// op input // op input
RType *input_; RType *input_;
// op output // op output
RType *out_; RType *output_;
RType *activation_scale_; RType *activation_scale_;
float weight_scale_; float weight_scale_;
}; };
#endif #endif
#ifdef PAD_OP
template <typename Dtype>
class PadParam : public OpParam {
typedef typename DtypeTensorTrait<Dtype>::gtype GType;
typedef typename DtypeTensorTrait<Dtype>::rtype RType;
public:
input_ = InputXFrom<GType>(inputs, scope);
output_ = OutFrom<GType>(outputs, scope);
paddings_ = GetVarValue<std::vector<int>>("Paddings", inputs, scope);
public:
// op input
RType *input_;
// op output
RType *output_;
// paddings
std::vector<int> paddings_;
};
#endif
} // namespace operators } // namespace operators
} // namespace paddle_mobile } // namespace paddle_mobile
...@@ -22,8 +22,12 @@ namespace operators { ...@@ -22,8 +22,12 @@ namespace operators {
template <typename DeviceType, typename T> template <typename DeviceType, typename T>
void QuantizeOp<DeviceType, T>::InferShape() const { void QuantizeOp<DeviceType, T>::InferShape() const {
const auto& input_dims = this->param_.input_->dims(); auto input_dims = this->param_.input_->dims();
this->param_.out_->Resize(input_dims); // const auto &paddings = this->param_.paddings_;
std::vector<int> paddings = {0, 0};
input_dims[2] += 2 * paddings[0];
input_dims[3] += 2 * paddings[1];
this->param_.output_->Resize(input_dims);
auto scale_dims = framework::make_ddim(std::vector<int>{1}); auto scale_dims = framework::make_ddim(std::vector<int>{1});
this->param_.online_scale_->Resize(scale_dims); this->param_.online_scale_->Resize(scale_dims);
} }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册