diff --git a/src/framework/load_ops.h b/src/framework/load_ops.h index 982f1c0f3525afde8475866c0121343fafc9d5a0..8766c4b25588c86cccd927cd74c0a75808172314 100644 --- a/src/framework/load_ops.h +++ b/src/framework/load_ops.h @@ -233,3 +233,6 @@ LOAD_OP1(quantize, CPU); #ifdef DEQUANT_OP LOAD_OP1(dequantize, CPU); #endif +#ifdef PAD_OP +LOAD_OP1(pad, CPU); +#endif diff --git a/src/operators/dequantize_op.cpp b/src/operators/dequantize_op.cpp index 21cd96368c4938d309f08d036b172607a5afee8c..00d08d683997f45fe7447321efa092a5597921a2 100644 --- a/src/operators/dequantize_op.cpp +++ b/src/operators/dequantize_op.cpp @@ -22,7 +22,7 @@ namespace operators { template void DequantizeOp::InferShape() const { const auto& input_dims = this->param_.input_->dims(); - this->param_.out_->Resize(input_dims); + this->param_.output_->Resize(input_dims); } } // namespace operators diff --git a/src/operators/kernel/arm/conv_kernel.cpp b/src/operators/kernel/arm/conv_kernel.cpp index 9054dbdaadbb2f11356da1249b6ce6d8947f0d54..942765443e4176e2fef2c7115dd0b6329bce6622 100644 --- a/src/operators/kernel/arm/conv_kernel.cpp +++ b/src/operators/kernel/arm/conv_kernel.cpp @@ -15,6 +15,7 @@ limitations under the License. */ #ifdef CONV_OP #include "operators/kernel/conv_kernel.h" +#include #include "operators/kernel/central-arm-func/conv_arm_func.h" namespace paddle_mobile { @@ -22,8 +23,15 @@ namespace operators { template <> bool ConvKernel::Init(ConvParam *param) { - if (param->Input()->type() == typeid(int8_t)) { - param->ExecMode() = ConvParam::EXEC_GEMM_INT8; + if (param->Filter()->type() == typeid(int8_t)) { + if (param->Groups() == param->Input()->dims()[1] && + param->Input()->dims()[1] == param->Output()->dims()[1] && + param->Filter()->dims()[2] == param->Filter()->dims()[3] && + param->Filter()->dims()[2] == 3) { + param->ExecMode() = ConvParam::EXEC_DEPTHWISE3x3_INT8; + } else { + param->ExecMode() = ConvParam::EXEC_GEMM_INT8; + } } else { if (param->Groups() == param->Input()->dims()[1] && param->Input()->dims()[1] == param->Output()->dims()[1] && @@ -35,6 +43,7 @@ bool ConvKernel::Init(ConvParam *param) { param->Filter()->dims()[2] == param->Filter()->dims()[3] && param->Filter()->dims()[2] == 3) { param->ExecMode() = ConvParam::EXEC_DEPTHWISE3x3_FLOAT; +#ifndef __aarch64__ } else if (param->Filter()->dims()[2] == param->Filter()->dims()[3] && param->Strides()[0] == param->Strides()[1] && param->Dilations()[0] == param->Dilations()[1] && @@ -48,6 +57,7 @@ bool ConvKernel::Init(ConvParam *param) { operators::math::winograd_transform_weight<8, 3>(*param->Filter(), transformed_weight); param->Filter() = transformed_weight; +#endif } else { param->ExecMode() = ConvParam::EXEC_GEMM_FLOAT; } @@ -60,25 +70,36 @@ void ConvKernel::Compute(const ConvParam ¶m) { switch (param.ExecMode()) { case ConvParam::EXEC_GEMM_INT8: GemmConv(param); + std::cout << "EXEC_GEMM_INT8" << std::endl; + break; + case ConvParam::EXEC_DEPTHWISE3x3_INT8: + DepthwiseConv3x3(param); + std::cout << "EXEC_DEPTHWISE3x3_INT8" << std::endl; break; case ConvParam::EXEC_DEPTHWISE3x3S1P1_FLOAT: math::DepthwiseConv3x3s1p1(param.Input(), param.Filter(), param.Output(), nullptr, false); + std::cout << "EXEC_DEPTHWISE3x3S1P1_FLOAT" << std::endl; break; case ConvParam::EXEC_DEPTHWISE3x3_FLOAT: math::DepthwiseConv3x3(param.Input(), param.Strides(), param.Paddings(), param.Filter(), nullptr, param.Output(), false); + std::cout << "EXEC_DEPTHWISE3x3_FLOAT=" << param.Strides()[0] + << std::endl; break; case ConvParam::EXEC_WINOGRAD3X3_FLOAT: WinogradConv3x3<8, 
3>(param); + std::cout << "EXEC_WINOGRAD3X3_FLOAT" << std::endl; break; case ConvParam::EXEC_GEMM_FLOAT: GemmConv(param); + std::cout << "EXEC_GEMM_FLOAT" << std::endl; break; default: PADDLE_MOBILE_THROW_EXCEPTION("Invalid convolution execute mode %d", param.ExecMode()); } + std::cout << "exec here..." << std::endl; } template class ConvKernel; diff --git a/src/operators/kernel/arm/dequantize_kernel.cpp b/src/operators/kernel/arm/dequantize_kernel.cpp index 03122047f61c585c3955ca18243ab849fb498728..c5a249fe587d3be1952585ef6c7d5372cd2ea37c 100644 --- a/src/operators/kernel/arm/dequantize_kernel.cpp +++ b/src/operators/kernel/arm/dequantize_kernel.cpp @@ -15,6 +15,7 @@ limitations under the License. */ #ifdef DEQUANT_OP #include "operators/kernel/dequantize_kernel.h" +#include #if defined(__ARM_NEON__) || defined(__ARM_NEON) #include @@ -31,7 +32,7 @@ bool DequantizeKernel::Init(DequantizeParam *param) { template <> void DequantizeKernel::Compute(const DequantizeParam ¶m) { const Tensor *input = param.input_; - Tensor *output = param.out_; + Tensor *output = param.output_; float activation_scale = param.activation_scale_->data()[0]; float weight_scale = param.weight_scale_; const int32_t *x = input->data(); diff --git a/src/operators/kernel/arm/elementwise_add_kernel.cpp b/src/operators/kernel/arm/elementwise_add_kernel.cpp index 043d27e72f16ab4b38f31d6cff60bd2f4e89a649..88a9d5f088c7333fde90d8b1bea2bebef35bb94b 100644 --- a/src/operators/kernel/arm/elementwise_add_kernel.cpp +++ b/src/operators/kernel/arm/elementwise_add_kernel.cpp @@ -15,6 +15,7 @@ limitations under the License. */ #ifdef ELEMENTWISEADD_OP #include "operators/kernel/elementwise_add_kernel.h" +#include #include "operators/kernel/central-arm-func/elementwise_add_arm_func.h" namespace paddle_mobile { diff --git a/src/operators/kernel/arm/quantize_kernel.cpp b/src/operators/kernel/arm/quantize_kernel.cpp index 17f442abe4e03d936eb3b317d5b6f164ac0924e7..430fbaed799a94aa3784257d5e2b281acc1b167e 100644 --- a/src/operators/kernel/arm/quantize_kernel.cpp +++ b/src/operators/kernel/arm/quantize_kernel.cpp @@ -21,15 +21,15 @@ limitations under the License. 
*/ #include #ifndef __aarch64__ -float32_t vmaxvq_f32(float32x4_t r) { +inline float32_t vmaxvq_f32(float32x4_t r) { float32x2_t v = vmax_f32(vget_high_f32(r), vget_low_f32(r)); return vget_lane_f32(vpmax_f32(v, v), 0); } #endif -int32x4_t vrnd_towards_zero(float32x4_t r) { return vcvtq_s32_f32(r); } +inline int32x4_t vrnd_towards_zero(float32x4_t r) { return vcvtq_s32_f32(r); } -int32x4_t vrnd_away_zero(float32x4_t r) { +inline int32x4_t vrnd_away_zero(float32x4_t r) { float32x4_t plus = vdupq_n_f32(0.5); float32x4_t minus = vdupq_n_f32(-0.5); float32x4_t zero = vdupq_n_f32(0); @@ -40,7 +40,7 @@ int32x4_t vrnd_away_zero(float32x4_t r) { return ret; } -int32x4_t vrnd_to_even(float32x4_t r) { +inline int32x4_t vrnd_to_even(float32x4_t r) { #if 0 int32x4_t ret; float value[4]; @@ -84,7 +84,6 @@ int32x4_t vrnd_to_even(float32x4_t r) { return rnd; #endif } -#endif namespace paddle_mobile { namespace operators { @@ -127,6 +126,7 @@ static float find_abs_max(const Tensor *input) { return max_abs; } +#ifdef __aarch64__ static void quantize_round_to_even(const Tensor *input, const float scale, Tensor *output) { const float *x = input->data(); @@ -188,7 +188,7 @@ static void quantize_round_to_zero(const Tensor *input, const float scale, const float *x = input->data(); int8_t *y = output->mutable_data(); size_t size = input->numel(); -#ifdef defined(__ARM_NEON__) || defined(__ARM_NEON) +#if defined(__ARM_NEON__) || defined(__ARM_NEON) size_t loop = size >> 4; size_t remain = size & 0xF; @@ -224,7 +224,7 @@ static void quantize_round_to_zero(const Tensor *input, const float scale, y += (loop << 4); #endif for (size_t i = 0; i < size; ++i) { - y[i] = trunc(x[i] * scale); + y[i] = static_cast(x[i] * scale); } } @@ -272,6 +272,464 @@ static void quantize_round_to_nearest(const Tensor *input, const float scale, y[i] = round(x[i] * scale); } } +#else // __aarch64__ + +static void quantize_round_to_even(const Tensor *input, const float scale, + const std::vector &paddings, + const int8_t padding_val, Tensor *output) {} + +static void quantize_round_to_nearest(const Tensor *input, const float scale, + const std::vector &paddings, + const int8_t padding_val, + Tensor *output) {} + +static void quantize_round_to_zero(const Tensor *input, const float scale, + const std::vector &paddings, + const int8_t padding_val, Tensor *output) { + int channels = input->dims()[1]; + int input_h = input->dims()[2]; + int input_w = input->dims()[3]; + int output_h = output->dims()[2]; + int output_w = output->dims()[3]; + int input_spatial_size = input_h * input_w; + int output_spatial_size = output_h * output_w; + const float *x = input->data(); + int8_t *y = output->mutable_data(); + // valid area start + int start = paddings[0] * output_w + paddings[1]; + + for (int batch = 0; batch < input->dims()[0]; ++batch) { + for (int c = 0; c < channels - 3; c += 4) { + const float *x0 = x + c * input_spatial_size; + const float *x1 = x0 + input_spatial_size; + const float *x2 = x1 + input_spatial_size; + const float *x3 = x2 + input_spatial_size; + size_t offset = c * output_spatial_size; + for (int h = 0; h < 2; ++h) { + int8_t *y0 = + y + offset + h * ((input_h + paddings[0]) * output_w - paddings[1]); + int8_t *y1 = y0 + output_spatial_size; + int8_t *y2 = y1 + output_spatial_size; + int8_t *y3 = y2 + output_spatial_size; + int loop = start >> 4; + int remain = start & 0xFFF0; + asm volatile( + "vdup.s8 q0, %[val] \n" + "cmp %[loop], #0 \n" + "ble start_remain_%= \n" + + "store_16w_%=: \n" + "vst1.32 {q0}, [%[y0]]! 
\n" + "vst1.32 {q0}, [%[y1]]! \n" + "vst1.32 {q0}, [%[y2]]! \n" + "vst1.32 {q0}, [%[y3]]! \n" + "subs %[loop], #1 \n" + "bne store_16w_%= \n" + + "start_remain_%=: \n" + "cmp %[remain], #8 \n" + "blt store_4w_%= \n" + "vst1.32 {d0}, [%[y0]]! \n" + "vst1.32 {d0}, [%[y1]]! \n" + "vst1.32 {d0}, [%[y2]]! \n" + "vst1.32 {d0}, [%[y3]]! \n" + "sub %[remain], #8 \n" + + "store_4w_%=: \n" + "cmp %[remain], #4 \n" + "blt store_2w_%= \n" + "vst1.32 {d0[0]}, [%[y0]]! \n" + "vst1.32 {d0[0]}, [%[y1]]! \n" + "vst1.32 {d0[0]}, [%[y2]]! \n" + "vst1.32 {d0[0]}, [%[y3]]! \n" + "sub %[remain], #4 \n" + + "store_2w_%=: \n" + "cmp %[remain], #4 \n" + "blt store_1w_%= \n" + "vst1.16 {d0[0]}, [%[y0]]! \n" + "vst1.16 {d0[0]}, [%[y1]]! \n" + "vst1.16 {d0[0]}, [%[y2]]! \n" + "vst1.16 {d0[0]}, [%[y3]]! \n" + "sub %[remain], #2 \n" + + "store_1w_%=: \n" + "cmp %[remain], #1 \n" + "blt end_%= \n" + "vst1.8 {d0[0]}, [%[y0]]! \n" + "vst1.8 {d0[0]}, [%[y1]]! \n" + "vst1.8 {d0[0]}, [%[y2]]! \n" + "vst1.8 {d0[0]}, [%[y3]]! \n" + "end_%=: \n" + : [y0] "+r"(y0), [y1] "+r"(y1), [y2] "+r"(y2), [y3] "+r"(y3), + [loop] "+r"(loop), [remain] "+r"(remain) + : [val] "r"(padding_val) + : "cc", "memory", "q0"); + } + // quantize valid area + int8_t *y0 = y + offset + start; + int8_t *y1 = y0 + output_spatial_size; + int8_t *y2 = y1 + output_spatial_size; + int8_t *y3 = y2 + output_spatial_size; + for (int h = 0; h < input_h; ++h) { + int loop = input_w >> 4; + int remain = input_w & 0xFFF0; + int pad_loop = paddings[1] >> 1; + int pad_remain = paddings[1] & 0xFFFE; + asm volatile( + "vdup.f32 q0, %[scale] \n" + "cmp %[loop], #0 \n" + "ble quantize_remain_%= \n" + + "loop_quantize_%=: \n" + "vld1.32 {q1, q2}, [%[x0]]! \n" + "vld1.32 {q3, q4}, [%[x1]]! \n" + "vld1.32 {q5, q6}, [%[x2]]! \n" + "vld1.32 {q7, q8}, [%[x3]]! \n" + "vmul.f32 q1, q1, q0 \n" + "vmul.f32 q2, q2, q0 \n" + "vmul.f32 q3, q3, q0 \n" + "vmul.f32 q4, q4, q0 \n" + "vmul.f32 q5, q5, q0 \n" + "vmul.f32 q6, q6, q0 \n" + "vmul.f32 q7, q7, q0 \n" + "vmul.f32 q8, q8, q0 \n" + "vcvt.s32.f32 q1, q1 \n" + "vcvt.s32.f32 q2, q2 \n" + "vcvt.s32.f32 q3, q3 \n" + "vcvt.s32.f32 q4, q4 \n" + "vcvt.s32.f32 q5, q5 \n" + "vcvt.s32.f32 q6, q6 \n" + "vcvt.s32.f32 q7, q7 \n" + "vcvt.s32.f32 q8, q8 \n" + "vmovn.s32 d2, q1 \n" + "vmovn.s32 d3, q2 \n" + "vmovn.s32 d4, q3 \n" + "vmovn.s32 d5, q4 \n" + "vmovn.s32 d6, q5 \n" + "vmovn.s32 d7, q6 \n" + "vmovn.s32 d8, q7 \n" + "vmovn.s32 d9, q8 \n" + "vmovn.s16 d18, q1 \n" + "vmovn.s16 d20, q2 \n" + "vmovn.s16 d22, q3 \n" + "vmovn.s16 d24, q4 \n" + "vld1.32 {q1, q2}, [%[x0]]! \n" + "vld1.32 {q3, q4}, [%[x1]]! \n" + "vld1.32 {q5, q6}, [%[x2]]! \n" + "vld1.32 {q7, q8}, [%[x3]]! 
\n" + "vmul.f32 q1, q1, q0 \n" + "vmul.f32 q2, q2, q0 \n" + "vmul.f32 q3, q3, q0 \n" + "vmul.f32 q4, q4, q0 \n" + "vmul.f32 q5, q5, q0 \n" + "vmul.f32 q6, q6, q0 \n" + "vmul.f32 q7, q7, q0 \n" + "vmul.f32 q8, q8, q0 \n" + "vcvt.s32.f32 q1, q1 \n" + "vcvt.s32.f32 q2, q2 \n" + "vcvt.s32.f32 q3, q3 \n" + "vcvt.s32.f32 q4, q4 \n" + "vcvt.s32.f32 q5, q5 \n" + "vcvt.s32.f32 q6, q6 \n" + "vcvt.s32.f32 q7, q7 \n" + "vcvt.s32.f32 q8, q8 \n" + "vmovn.s32 d2, q1 \n" + "vmovn.s32 d3, q2 \n" + "vmovn.s32 d4, q3 \n" + "vmovn.s32 d5, q4 \n" + "vmovn.s32 d6, q5 \n" + "vmovn.s32 d7, q6 \n" + "vmovn.s32 d8, q7 \n" + "vmovn.s32 d9, q8 \n" + "vmovn.s16 d19, q1 \n" + "vmovn.s16 d21, q2 \n" + "vmovn.s16 d23, q3 \n" + "vmovn.s16 d25, q4 \n" + "vst1.32 {q9}, [%[y0]] \n" + "vst1.32 {q10}, [%[y0]] \n" + "vst1.32 {q11}, [%[y0]] \n" + "vst1.32 {q12}, [%[y0]] \n" + + "subs %[loop], #1 \n" + "bne loop_quantize_%= \n" + + "quantize_remain_%=: \n" + "cmp %[remain], #0 \n" + "ble end_%= \n" + + "vld1.32 {q1, q2}, [%[x0]] \n" + "vld1.32 {q3, q4}, [%[x1]] \n" + "vld1.32 {q5, q6}, [%[x2]] \n" + "vld1.32 {q7, q8}, [%[x3]] \n" + "vmul.f32 q1, q1, q0 \n" + "vmul.f32 q2, q2, q0 \n" + "vmul.f32 q3, q3, q0 \n" + "vmul.f32 q4, q4, q0 \n" + "vmul.f32 q5, q5, q0 \n" + "vmul.f32 q6, q6, q0 \n" + "vmul.f32 q7, q7, q0 \n" + "vmul.f32 q8, q8, q0 \n" + "vcvt.s32.f32 q1, q1 \n" + "vcvt.s32.f32 q2, q2 \n" + "vcvt.s32.f32 q3, q3 \n" + "vcvt.s32.f32 q4, q4 \n" + "vcvt.s32.f32 q5, q5 \n" + "vcvt.s32.f32 q6, q6 \n" + "vcvt.s32.f32 q7, q7 \n" + "vcvt.s32.f32 q8, q8 \n" + "vmovn.s32 d2, q1 \n" + "vmovn.s32 d3, q2 \n" + "vmovn.s32 d4, q3 \n" + "vmovn.s32 d5, q4 \n" + "vmovn.s32 d6, q5 \n" + "vmovn.s32 d7, q6 \n" + "vmovn.s32 d8, q7 \n" + "vmovn.s32 d9, q8 \n" + "vmovn.s16 d18, q1 \n" + "vmovn.s16 d20, q2 \n" + "vmovn.s16 d22, q3 \n" + "vmovn.s16 d24, q4 \n" + "vld1.32 {q1, q2}, [%[x0]]! \n" + "vld1.32 {q3, q4}, [%[x1]]! \n" + "vld1.32 {q5, q6}, [%[x2]]! \n" + "vld1.32 {q7, q8}, [%[x3]]! \n" + "vmul.f32 q1, q1, q0 \n" + "vmul.f32 q2, q2, q0 \n" + "vmul.f32 q3, q3, q0 \n" + "vmul.f32 q4, q4, q0 \n" + "vmul.f32 q5, q5, q0 \n" + "vmul.f32 q6, q6, q0 \n" + "vmul.f32 q7, q7, q0 \n" + "vmul.f32 q8, q8, q0 \n" + "vcvt.s32.f32 q1, q1 \n" + "vcvt.s32.f32 q2, q2 \n" + "vcvt.s32.f32 q3, q3 \n" + "vcvt.s32.f32 q4, q4 \n" + "vcvt.s32.f32 q5, q5 \n" + "vcvt.s32.f32 q6, q6 \n" + "vcvt.s32.f32 q7, q7 \n" + "vcvt.s32.f32 q8, q8 \n" + "vmovn.s32 d2, q1 \n" + "vmovn.s32 d3, q2 \n" + "vmovn.s32 d4, q3 \n" + "vmovn.s32 d5, q4 \n" + "vmovn.s32 d6, q5 \n" + "vmovn.s32 d7, q6 \n" + "vmovn.s32 d8, q7 \n" + "vmovn.s32 d9, q8 \n" + "vmovn.s16 d19, q1 \n" + "vmovn.s16 d21, q2 \n" + "vmovn.s16 d23, q3 \n" + "vmovn.s16 d25, q4 \n" + + "cmp %[remain], #8 \n" + "blt store_4w_%= \n" + "vst1.32 {d18}, [%[y0]]! \n" + "vst1.32 {d20}, [%[y1]]! \n" + "vst1.32 {d22}, [%[y2]]! \n" + "vst1.32 {d24}, [%[y3]]! \n" + "vmov.32 d18, d19 \n" + "vmov.32 d20, d21 \n" + "vmov.32 d22, d23 \n" + "vmov.32 d24, d25 \n" + "sub %[remain], #8 \n" + + "store_4w_%=: \n" + "cmp %[remain], #4 \n" + "blt store_2w_%= \n" + "vst1.32 {d18[0]}, [%[y0]]! \n" + "vst1.32 {d20[0]}, [%[y1]]! \n" + "vst1.32 {d22[0]}, [%[y2]]! \n" + "vst1.32 {d24[0]}, [%[y3]]! \n" + "vext.32 d18, d18, d18, #1 \n" + "vext.32 d20, d20, d20, #1 \n" + "vext.32 d22, d22, d22, #1 \n" + "vext.32 d24, d24, d24, #1 \n" + "sub %[remain], #4 \n" + + "store_2w_%=: \n" + "cmp %[remain], #2 \n" + "blt store_1w_%= \n" + "vst1.16 {d18[0]}, [%[y0]]! \n" + "vst1.16 {d20[0]}, [%[y1]]! \n" + "vst1.16 {d22[0]}, [%[y2]]! \n" + "vst1.16 {d24[0]}, [%[y3]]! 
\n" + "vext.16 d18, d18, d18, #1 \n" + "vext.16 d20, d20, d20, #1 \n" + "vext.16 d22, d22, d22, #1 \n" + "vext.16 d24, d24, d24, #1 \n" + "sub %[remain], #2 \n" + + "store_1w_%=:" + "cmp %[remain], #1 \n" + "blt end_%= \n" + "vst1.8 {d18[0]}, [%[y0]]! \n" + "vst1.8 {d20[0]}, [%[y1]]! \n" + "vst1.8 {d22[0]}, [%[y2]]! \n" + "vst1.8 {d24[0]}, [%[y3]]! \n" + + "end_%=: \n" + : [x0] "+r"(x0), [x1] "+r"(x1), [x2] "+r"(x2), [x3] "+r"(x3), + [y0] "+r"(y0), [y1] "+r"(y1), [y2] "+r"(y2), [y3] "+r"(y3), + [loop] "+r"(loop), [remain] "+r"(remain) + : [scale] "r"(scale) + : "memory", "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7", "q8", + "q9", "q10", "q11", "q12"); + asm volatile( + "vdup.s8 d0, %[val] \n" + "cmp %[pad_loop], #0 \n" + "ble store_pad_2w_%= \n" + "loop_pad_4w_%=: \n" + "vst1.32 {d0[0]}, [%[y0]]! \n" + "vst1.32 {d0[0]}, [%[y1]]! \n" + "vst1.32 {d0[0]}, [%[y2]]! \n" + "vst1.32 {d0[0]}, [%[y3]]! \n" + "subs %[pad_loop], #1 \n" + "bne loop_pad_4w_%= \n" + + "store_pad_2w_%=: \n" + "cmp %[pad_remain], #2 \n" + "ble store_pad_1w_%= \n" + "vst1.16 {d0[0]}, [%[y0]]! \n" + "vst1.16 {d0[0]}, [%[y1]]! \n" + "vst1.16 {d0[0]}, [%[y2]]! \n" + "vst1.16 {d0[0]}, [%[y3]]! \n" + "sub %[pad_remain], #2 \n" + + "store_pad_1w_%=: \n" + "cmp %[pad_remain], #1 \n" + "ble end_%= \n" + "vst1.8 {d0[0]}, [%[y0]]! \n" + "vst1.8 {d0[0]}, [%[y1]]! \n" + "vst1.8 {d0[0]}, [%[y2]]! \n" + "vst1.8 {d0[0]}, [%[y3]]! \n" + "end_%=: \n" + : [y0] "+r"(y0), [y1] "+r"(y1), [y2] "+r"(y2), [y3] "+r"(y3), + [pad_loop] "+r"(pad_loop), [pad_remain] "+r"(pad_remain) + : [val] "r"(padding_val) + : "memory", "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7", "q8", + "q9", "q10", "q11", "q12"); + + x0 += remain; + x1 += remain; + x2 += remain; + x3 += remain; + } + } + for (int c = (channels & 0xFFFC); c < channels; ++c) { + const float *x0 = x + c * input_spatial_size; + int8_t *y0 = y + c * output_spatial_size; + for (int h = 0; h < paddings[0]; ++h) { + int loop = input_w >> 4; + int remain = input_w & 0xFFF0; + int pad_loop = paddings[1] >> 1; + int pad_remain = paddings[1] & 0xFFFE; + asm volatile( + "vdup.f32 q0, %[scale] \n" + "cmp %[loop], #0 \n" + "ble quantize_remain_%= \n" + + "loop_quantize_%=: \n" + "vld1.32 {q1, q2}, [%[x0]]! \n" + "vmul.f32 q1, q1, q0 \n" + "vmul.f32 q2, q2, q0 \n" + "vcvt.s32.f32 q1, q1 \n" + "vcvt.s32.f32 q2, q2 \n" + "vmovn.s32 d2, q1 \n" + "vmovn.s32 d3, q2 \n" + "vmovn.s16 d18, q1 \n" + "vld1.32 {q1, q2}, [%[x0]]! \n" + "vmul.f32 q1, q1, q0 \n" + "vmul.f32 q2, q2, q0 \n" + "vcvt.s32.f32 q1, q1 \n" + "vcvt.s32.f32 q2, q2 \n" + "vmovn.s32 d2, q1 \n" + "vmovn.s32 d3, q2 \n" + "vmovn.s16 d19, q1 \n" + "vst1.32 {q9}, [%[y0]] \n" + + "subs %[loop], #1 \n" + "bne loop_quantize_%= \n" + + "quantize_remain_%=: \n" + "cmp %[remain], #0 \n" + "ble start_pad_%= \n" + + "vld1.32 {q1, q2}, [%[x0]] \n" + "vmul.f32 q1, q1, q0 \n" + "vmul.f32 q2, q2, q0 \n" + "vcvt.s32.f32 q1, q1 \n" + "vcvt.s32.f32 q2, q2 \n" + "vmovn.s32 d2, q1 \n" + "vmovn.s32 d3, q2 \n" + "vmovn.s16 d18, q1 \n" + "vld1.32 {q1, q2}, [%[x0]]! \n" + "vmul.f32 q1, q1, q0 \n" + "vmul.f32 q2, q2, q0 \n" + "vcvt.s32.f32 q1, q1 \n" + "vcvt.s32.f32 q2, q2 \n" + "vmovn.s32 d2, q1 \n" + "vmovn.s32 d3, q2 \n" + "vmovn.s16 d19, q1 \n" + + "cmp %[remain], #8 \n" + "blt store_4w_%= \n" + "vst1.32 {d18}, [%[y0]]! \n" + "vmov.32 d18, d19 \n" + "sub %[remain], #8 \n" + + "store_4w_%=: \n" + "cmp %[remain], #4 \n" + "blt store_2w_%= \n" + "vst1.32 {d18[0]}, [%[y0]]! 
\n" + "vext.32 d18, d18, d18, #1 \n" + "sub %[remain], #4 \n" + + "store_2w_%=: \n" + "cmp %[remain], #2 \n" + "blt store_1w_%= \n" + "vst1.16 {d18[0]}, [%[y0]]! \n" + "vext.16 d18, d18, d18, #1 \n" + "sub %[remain], #2 \n" + + "store_1w_%=:" + "cmp %[remain], #1 \n" + "blt start_pad_%= \n" + "vst1.8 {d18[0]}, [%[y0]]! \n" + + "start_pad_%=: \n" + "vdup.s8 d0, %[val] \n" + "cmp %[pad_loop], #0 \n" + "ble pad_remain_%= \n" + "loop_pad_4w_%=: \n" + "vst1.32 {d0[0]}, [%[y0]]! \n" + "subs %[pad_loop], #1 \n" + "bne loop_pad_4w_%= \n" + + "pad_remain_%=: \n" + "cmp %[pad_remain], #2 \n" + "ble store_pad_1w_%= \n" + "vst1.16 {d0[0]}, [%[y0]]! \n" + "sub %[pad_remain], #2 \n" + + "store_pad_1w_%=: \n" + "cmp %[pad_remain], #1 \n" + "ble end_%= \n" + "vst1.8 {d0[0]}, [%[y0]]! \n" + "end_%=: \n" + : [x0] "+r"(x0), [y0] "+r"(y0), [loop] "+r"(loop), + [remain] "+r"(remain), [pad_loop] "+r"(pad_loop), + [pad_remain] "+r"(pad_remain) + : [scale] "r"(scale), [val] "r"(padding_val) + : "memory", "q0", "q1", "q2", "q9"); + x0 += remain; + } + } + } +} +#endif // __aarch64__ +#endif // ARM_NEON template <> bool QuantizeKernel::Init(QuantizeParam *param) { @@ -280,10 +738,10 @@ bool QuantizeKernel::Init(QuantizeParam *param) { template <> void QuantizeKernel::Compute(const QuantizeParam ¶m) { - float max_abs = 0.f; const Tensor *input = param.input_; - Tensor *output = param.out_; + Tensor *output = param.output_; Tensor *output_scale = param.online_scale_; + float max_abs = 0.f; if (param.is_static_) { max_abs = param.static_scale_; } else { @@ -293,15 +751,19 @@ void QuantizeKernel::Compute(const QuantizeParam ¶m) { // only support int8 currently float scale = 127 / max_abs; param.online_scale_->mutable_data()[0] = max_abs; + // const auto &paddings = param.paddings_; + std::vector paddings = {0, 0}; + // const auto padding_val = param.padding_val_; + int8_t padding_val = 127; switch (param.round_type_) { case ROUND_NEAREST_TO_EVEN: - quantize_round_to_even(input, scale, output); + quantize_round_to_even(input, scale, paddings, padding_val, output); break; case ROUND_NEAREST_TOWARDS_ZERO: - quantize_round_to_zero(input, scale, output); + quantize_round_to_zero(input, scale, paddings, padding_val, output); break; case ROUND_NEAREST_AWAY_ZERO: - quantize_round_to_nearest(input, scale, output); + quantize_round_to_nearest(input, scale, paddings, padding_val, output); break; default: LOG(kLOG_ERROR) << "round type is not supported."; diff --git a/src/operators/kernel/central-arm-func/conv_add_arm_func.h b/src/operators/kernel/central-arm-func/conv_add_arm_func.h index bacaa866b12957cfc300049c56bb9648fd360770..3b5924ecbf886159d129212cc36c8630cb8cce2f 100644 --- a/src/operators/kernel/central-arm-func/conv_add_arm_func.h +++ b/src/operators/kernel/central-arm-func/conv_add_arm_func.h @@ -17,7 +17,7 @@ limitations under the License. 
*/ #include #include "operators/math/conv_func.h" -#include "operators/math/depthwise_conv_3x3.h" +#include "operators/math/depthwise_conv3x3.h" #include "operators/math/im2col.h" #include "operators/math/math_function.h" #include "operators/math/vol2col.h" diff --git a/src/operators/kernel/central-arm-func/conv_add_bn_relu_arm_func.h b/src/operators/kernel/central-arm-func/conv_add_bn_relu_arm_func.h index a7d14fbad1e4b72a8571d13898e55a6cad8bf9a8..5374eab51f315ee8baa4f4effe04fc97240aabff 100644 --- a/src/operators/kernel/central-arm-func/conv_add_bn_relu_arm_func.h +++ b/src/operators/kernel/central-arm-func/conv_add_bn_relu_arm_func.h @@ -17,7 +17,7 @@ limitations under the License. */ #pragma once #include -#include "operators/math/depthwise_conv_3x3.h" +#include "operators/math/depthwise_conv3x3.h" #include "operators/math/im2col.h" #include "operators/math/math_function.h" #include "operators/math/vol2col.h" diff --git a/src/operators/kernel/central-arm-func/conv_arm_func.h b/src/operators/kernel/central-arm-func/conv_arm_func.h index 956beb53c9a9e9d857d9c129d90443b09c0b3bb8..2fe55b371e8bb392b5c5f0d3792c76ae8592ecf0 100644 --- a/src/operators/kernel/central-arm-func/conv_arm_func.h +++ b/src/operators/kernel/central-arm-func/conv_arm_func.h @@ -17,7 +17,7 @@ limitations under the License. */ #pragma once #include #include "operators/math/conv_func.h" -#include "operators/math/depthwise_conv_3x3.h" +#include "operators/math/depthwise_conv3x3.h" #include "operators/math/im2col.h" #include "operators/math/math_function.h" #include "operators/math/pad.h" @@ -39,10 +39,7 @@ inline void GemmConv(const ConvParam ¶m) { const std::vector paddings = param.Paddings(); const std::vector dilations = param.Dilations(); - const int batch_size = static_cast(input->dims()[0]); - std::vector filter_shape_vec(framework::vectorize(filter.dims())); - std::vector output_shape_vec(framework::vectorize(output->dims())); size_t data_dim = filter_shape_vec.size() - 2; std::vector col_shape_vec(1 + 2 * data_dim); @@ -83,6 +80,7 @@ inline void GemmConv(const ConvParam ¶m) { math::Vol2ColFunctor vol2col; math::Im2ColFunctor im2col; + const int batch_size = static_cast(input->dims()[0]); for (int i = 0; i < batch_size; i++) { Tensor in_batch = input->Slice(i, i + 1).Resize(input_shape); Tensor out_batch = output->Slice(i, i + 1).Resize(output_matrix_shape); @@ -126,7 +124,6 @@ inline void WinogradConv3x3(const ConvParam ¶m) { int batch_size = input->dims()[0]; int groups = param.Groups(); const std::vector &paddings = param.Paddings(); - math::PadFunctor pad; auto winograd_pad = [&](int width, int pad) { int output_tile = tile - kernel + 1; @@ -136,6 +133,7 @@ inline void WinogradConv3x3(const ConvParam ¶m) { return pad_width + tile - width; }; + math::PadFunctor pad; Tensor input_pad; framework::Tensor transformed_input; for (int i = 0; i < batch_size; ++i) { @@ -155,15 +153,49 @@ inline void WinogradConv3x3(const ConvParam ¶m) { } else { input_pad = in_batch; } -#if __aarch64__ - // TODO(hjchen2) -#else // tile input and transform math::winograd_transform_input(input_pad, &transformed_input); // caculate output math::winograd_transform_output(transformed_input, *filter, output); -#endif + } +} + +template +inline void DepthwiseConv3x3(const ConvParam ¶m) { + const Tensor *input = param.Input(); + const Tensor *filter = param.Filter(); + Tensor *output = param.Output(); + output->mutable_data(); + + const std::vector &paddings = param.Paddings(); + const std::vector &strides = param.Strides(); + const int 
batch_size = static_cast(input->dims()[0]); + Tensor input_pad; + math::PadFunctor pad; + for (int i = 0; i < batch_size; i++) { + Tensor in_batch = input->Slice(i, i + 1); + Tensor out_batch = output->Slice(i, i + 1); + // if (paddings[0] || paddings[1]) { + // framework::DDim pad_shape = in_batch.dims(); + // pad_shape[2] += 2 * paddings[0]; + // pad_shape[3] += 2 * paddings[1]; + // input_pad.mutable_data(pad_shape); + // pad(in_batch, paddings[0], paddings[0], paddings[1], paddings[1], + // &input_pad); + // } else { + // input_pad = in_batch; + // } + // math::DepthwiseConv3x3s1(input_pad, *filter, + // &out_batch); + if (strides[0] == 1) { + math::DepthwiseConv3x3s1(in_batch, *filter, &out_batch); + } else if (strides[0] == 2) { + math::DepthwiseConv3x3s2(in_batch, *filter, &out_batch); + } else { + // math::DepthwiseConv3x3(in_batch, *filter, + // &out_batch); + } } } diff --git a/src/operators/kernel/central-arm-func/conv_bn_add_relu_arm_func.h b/src/operators/kernel/central-arm-func/conv_bn_add_relu_arm_func.h index 7c31eed19693d20084e25daa485a0553d5d795f2..e3fe37e19bd10ec5cbbfb59b556df5af9fecd09e 100644 --- a/src/operators/kernel/central-arm-func/conv_bn_add_relu_arm_func.h +++ b/src/operators/kernel/central-arm-func/conv_bn_add_relu_arm_func.h @@ -17,7 +17,7 @@ limitations under the License. */ #pragma once #include -#include "operators/math/depthwise_conv_3x3.h" +#include "operators/math/depthwise_conv3x3.h" #include "operators/math/im2col.h" #include "operators/math/math_function.h" #include "operators/math/vol2col.h" diff --git a/src/operators/kernel/central-arm-func/conv_bn_relu_arm_func.h b/src/operators/kernel/central-arm-func/conv_bn_relu_arm_func.h index c6300f96e1b999c45538417c7b513068697ad4dd..4c8cf393345d16e79799bc5ce9ecd1be1fc0a15a 100644 --- a/src/operators/kernel/central-arm-func/conv_bn_relu_arm_func.h +++ b/src/operators/kernel/central-arm-func/conv_bn_relu_arm_func.h @@ -16,13 +16,15 @@ limitations under the License. */ #pragma once #include -#include "operators/math/depthwise_conv_3x3.h" +#include "operators/math/depthwise_conv3x3.h" #include "operators/math/im2col.h" #include "operators/math/math_function.h" #include "operators/math/vol2col.h" #include "operators/op_param.h" + namespace paddle_mobile { namespace operators { + void ConvBNReluBasic(const FusionConvBNReluParam ¶m) { const Tensor *input = param.Input(); Tensor filter = *param.Filter(); diff --git a/src/operators/kernel/central-arm-func/depthwise_conv_arm_func.h b/src/operators/kernel/central-arm-func/depthwise_conv_arm_func.h index 9d25800c77f2fadc5495e58e20d26f0328bcdf3f..b48b03491bab9594f36cad0b21485ae72c8c3c31 100644 --- a/src/operators/kernel/central-arm-func/depthwise_conv_arm_func.h +++ b/src/operators/kernel/central-arm-func/depthwise_conv_arm_func.h @@ -15,10 +15,9 @@ limitations under the License. */ #ifdef DEPTHWISECONV_OP #pragma once -#include #include #include "operators/kernel/central-arm-func/conv_arm_func.h" - +#include "operators/math/depthwise_conv3x3.h" #include "operators/op_param.h" namespace paddle_mobile { diff --git a/src/operators/kernel/central-arm-func/dwconv_bn_relu_arm_func.h b/src/operators/kernel/central-arm-func/dwconv_bn_relu_arm_func.h index b60bf9b4d6df9d85cc2fbe378a3904c2d13e5e60..a5c08c26237345320fef89e8f0fdd148534dfc8a 100644 --- a/src/operators/kernel/central-arm-func/dwconv_bn_relu_arm_func.h +++ b/src/operators/kernel/central-arm-func/dwconv_bn_relu_arm_func.h @@ -16,13 +16,15 @@ limitations under the License. 
*/ #pragma once #include -#include "operators/math/depthwise_conv_3x3.h" +#include "operators/math/depthwise_conv3x3.h" #include "operators/math/im2col.h" #include "operators/math/math_function.h" #include "operators/math/vol2col.h" #include "operators/op_param.h" + namespace paddle_mobile { namespace operators { + void DWConvBNReluBasic(const FusionDWConvBNReluParam ¶m) { const Tensor *input = param.Input(); Tensor filter = *param.Filter(); diff --git a/src/operators/kernel/conv_add_kernel.h b/src/operators/kernel/conv_add_kernel.h index 4e9ff0853f1d502ebb4dc4ef3641d0a879f32b60..140d0475a8ee2f017a7c587c38429ccbb2edd387 100644 --- a/src/operators/kernel/conv_add_kernel.h +++ b/src/operators/kernel/conv_add_kernel.h @@ -24,7 +24,7 @@ limitations under the License. */ #include "framework/ddim.h" #include "framework/operator.h" #include "operators/math/conv_func.h" -#include "operators/math/depthwise_conv_3x3.h" +#include "operators/math/depthwise_conv3x3.h" #include "operators/math/im2col.h" #include "operators/math/math_function.h" #include "operators/math/vol2col.h" diff --git a/src/operators/math/depthwise_conv_3x3.cpp b/src/operators/math/depthwise_conv3x3.cpp similarity index 96% rename from src/operators/math/depthwise_conv_3x3.cpp rename to src/operators/math/depthwise_conv3x3.cpp index b213f82351e03ddebc47efa672f0d21513a3098f..39b9b8d3f1c5c2bf09a3db5de5216dd1a08b491a 100644 --- a/src/operators/math/depthwise_conv_3x3.cpp +++ b/src/operators/math/depthwise_conv3x3.cpp @@ -11,18 +11,22 @@ distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ -#include "operators/math/depthwise_conv_3x3.h" + +#include "operators/math/depthwise_conv3x3.h" +#include #if __ARM_NEON #include #endif -#include namespace paddle_mobile { namespace operators { namespace math { -void DepthwiseConv3x3(const Tensor *input, vector strides, - vector paddings, const Tensor *filter, Tensor *bias, - Tensor *output, bool if_bias) { + +void DepthwiseConv3x3(const framework::Tensor *input, + const std::vector &strides, + const std::vector &paddings, + const framework::Tensor *filter, framework::Tensor *bias, + framework::Tensor *output, bool if_bias) { const int batch_size = input->dims()[0]; const int input_height = input->dims()[2]; @@ -67,12 +71,12 @@ void DepthwiseConv3x3(const Tensor *input, vector strides, for (int pw = 0; pw < output_width; pw++) { hstart = ph * stride_height - padding_height; wstart = pw * stride_width - padding_width; - hend = min(hstart + _kernel_size, input_height + padding_height); - wend = min(wstart + _kernel_size, input_width + padding_width); - hstart = max(hstart, 0); - wstart = max(wstart, 0); - hend = min(hend, input_height); - wend = min(wend, input_width); + hend = std::min(hstart + _kernel_size, input_height + padding_height); + wend = std::min(wstart + _kernel_size, input_width + padding_width); + hstart = std::max(hstart, 0); + wstart = std::max(wstart, 0); + hend = std::min(hend, input_height); + wend = std::min(wend, input_width); pos1 = input_data + hstart * input_width + wstart; pos2 = input_data + (hstart + 1) * input_width + wstart; pos3 = input_data + (hstart + 2) * input_width + wstart; @@ -244,8 +248,10 @@ void DepthwiseConv3x3(const Tensor *input, vector strides, } } -void DepthwiseConv3x3s1p1(const Tensor *input, const Tensor *filter, - Tensor *output, Tensor *bias, bool if_bias) { +void 
DepthwiseConv3x3s1p1(const framework::Tensor *input, + const framework::Tensor *filter, + framework::Tensor *output, framework::Tensor *bias, + bool if_bias) { #if __ARM_NEON const float *input_data = input->data(); const float *filter_data = filter->data(); @@ -517,9 +523,12 @@ void DepthwiseConv3x3s1p1(const Tensor *input, const Tensor *filter, #endif } -void DepthwiseConvAddBNRelu3x3s1p1(const Tensor *input, const Tensor *filter, - Tensor *output, const Tensor *new_scale, - const Tensor *new_bias, bool if_relu) { +void DepthwiseConvAddBNRelu3x3s1p1(const framework::Tensor *input, + const framework::Tensor *filter, + framework::Tensor *output, + const framework::Tensor *new_scale, + const framework::Tensor *new_bias, + bool if_relu) { #if __ARM_NEON const float *input_data = input->data(); const float *filter_data = filter->data(); @@ -1059,9 +1068,12 @@ void DepthwiseConvAddBNRelu3x3s1p1(const Tensor *input, const Tensor *filter, } /// w!=h not fix -void DepthwiseConvAddBNRelu3x3s2p1(const Tensor *input, const Tensor *filter, - Tensor *output, const Tensor *new_scale, - const Tensor *new_bias, bool if_relu) { +void DepthwiseConvAddBNRelu3x3s2p1(const framework::Tensor *input, + const framework::Tensor *filter, + framework::Tensor *output, + const framework::Tensor *new_scale, + const framework::Tensor *new_bias, + bool if_relu) { #if __ARM_NEON const int batch_size = input->dims()[0]; @@ -1107,12 +1119,12 @@ void DepthwiseConvAddBNRelu3x3s2p1(const Tensor *input, const Tensor *filter, for (int pw = 0; pw < output_width; pw++) { hstart = ph * stride_height - padding_height; wstart = pw * stride_width - padding_width; - hend = min(hstart + _kernel_size, input_height + padding_height); - wend = min(wstart + _kernel_size, input_width + padding_width); - hstart = max(hstart, 0); - wstart = max(wstart, 0); - hend = min(hend, input_height); - wend = min(wend, input_width); + hend = std::min(hstart + _kernel_size, input_height + padding_height); + wend = std::min(wstart + _kernel_size, input_width + padding_width); + hstart = std::max(hstart, 0); + wstart = std::max(wstart, 0); + hend = std::min(hend, input_height); + wend = std::min(wend, input_width); pos1 = input_data + hstart * input_width + wstart; pos2 = input_data + (hstart + 1) * input_width + wstart; pos3 = input_data + (hstart + 2) * input_width + wstart; @@ -1258,8 +1270,10 @@ void DepthwiseConvAddBNRelu3x3s2p1(const Tensor *input, const Tensor *filter, #endif } -void DepthwiseConv3x3s2p1v2(const Tensor *input, const Tensor *filter, - Tensor *output, Tensor bias, bool if_bias) { +void DepthwiseConv3x3s2p1v2(const framework::Tensor *input, + const framework::Tensor *filter, + framework::Tensor *output, framework::Tensor bias, + bool if_bias) { #if __ARM_NEON const float *input_data = input->data(); const float *filter_data = filter->data(); @@ -1463,9 +1477,12 @@ void DepthwiseConv3x3s2p1v2(const Tensor *input, const Tensor *filter, #endif } -void DepthwiseConvAddBNRelu3x3s2p1v2(const Tensor *input, const Tensor *filter, - Tensor *output, const Tensor *new_scale, - const Tensor *new_bias, bool if_relu) { +void DepthwiseConvAddBNRelu3x3s2p1v2(const framework::Tensor *input, + const framework::Tensor *filter, + framework::Tensor *output, + const framework::Tensor *new_scale, + const framework::Tensor *new_bias, + bool if_relu) { #if __ARM_NEON // #ifdef _OPENMP // const float *newscale_data = new_scale->data(); @@ -1886,8 +1903,10 @@ void DepthwiseConvAddBNRelu3x3s2p1v2(const Tensor *input, const Tensor *filter, #endif } -void 
DepthwiseConv3x3s2p0(const Tensor *input, const Tensor *filter, - Tensor *output, Tensor bias, bool if_bias) { +void DepthwiseConv3x3s2p0(const framework::Tensor *input, + const framework::Tensor *filter, + framework::Tensor *output, framework::Tensor bias, + bool if_bias) { #if __ARM_NEON const int batch_size = static_cast(input->dims()[0]); diff --git a/src/operators/math/depthwise_conv3x3.h b/src/operators/math/depthwise_conv3x3.h new file mode 100644 index 0000000000000000000000000000000000000000..ecccd3d0feefbdddb75126f51c302b644786938b --- /dev/null +++ b/src/operators/math/depthwise_conv3x3.h @@ -0,0 +1,86 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +#include +#include +#include "framework/tensor.h" +#include "operators/math/conv_func.h" + +namespace paddle_mobile { +namespace operators { +namespace math { + +void DepthwiseConv3x3(const framework::Tensor *input, + const std::vector &strides, + const std::vector &paddings, + const framework::Tensor *filter, framework::Tensor *bias, + framework::Tensor *output, bool if_bias); + +void DepthwiseConv3x3s1p1(const framework::Tensor *input, + const framework::Tensor *filter, + framework::Tensor *output, framework::Tensor *bias, + bool if_bias); + +void DepthwiseConvAddBNRelu3x3s1p1(const framework::Tensor *input, + const framework::Tensor *filter, + framework::Tensor *output, + const framework::Tensor *new_scale, + const framework::Tensor *new_bias, + bool if_relu); + +void DepthwiseConvAddBNRelu3x3s2p1(const framework::Tensor *input, + const framework::Tensor *filter, + framework::Tensor *output, + const framework::Tensor *new_scale, + const framework::Tensor *new_bias, + bool if_relu); + +void DepthwiseConv3x3s2p1v2(const framework::Tensor *input, + const framework::Tensor *filter, + framework::Tensor *output, framework::Tensor bias, + bool if_bias); + +void DepthwiseConvAddBNRelu3x3s2p1v2(const framework::Tensor *input, + const framework::Tensor *filter, + framework::Tensor *output, + const framework::Tensor *new_scale, + const framework::Tensor *new_bias, + bool if_relu); + +void DepthwiseConv3x3s2p0(const framework::Tensor *input, + const framework::Tensor *filter, + framework::Tensor *output, framework::Tensor bias, + bool if_bias); + +// template +// void DepthwiseConv3x3(const framework::Tensor *input, +// const framework::Tensor *filter, +// const std::vector &strides, +// framework::Tensor *output); + +template +void DepthwiseConv3x3s1(const framework::Tensor &input, + const framework::Tensor &filter, + framework::Tensor *output); + +template +void DepthwiseConv3x3s2(const framework::Tensor &input, + const framework::Tensor &filter, + framework::Tensor *output); + +} // namespace math +} // namespace operators +} // namespace paddle_mobile diff --git a/src/operators/math/depthwise_conv3x3_int8.cpp b/src/operators/math/depthwise_conv3x3_int8.cpp index e8e0b653e7586cd9e9bbf48e5a145f7b8e5a207b..648e9e2a8fc1b5b5fdb5e779888a839da47d4936 100644 --- 
a/src/operators/math/depthwise_conv3x3_int8.cpp +++ b/src/operators/math/depthwise_conv3x3_int8.cpp @@ -12,23 +12,24 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ -#include "operators/math/depthwise_conv3x3_int8.h" +#include "operators/math/depthwise_conv3x3.h" namespace paddle_mobile { namespace operators { namespace math { -void DepthwiseConv3x3_int8(const framework::Tensor *input, - const framework::Tensor *filter, - const std::vector &strides, - framework::Tensor *output) { - PADDLE_MOBILE_THROW_EXCEPTION( - "Depthwise conv with generic strides has not been implemented."); -} +// template<> +// void DepthwiseConv3x3( +// const framework::Tensor *input, const framework::Tensor *filter, +// const std::vector &strides, framework::Tensor *output) { +// PADDLE_MOBILE_THROW_EXCEPTION( +// "Depthwise conv with generic strides has not been implemented."); +// } -void DepthwiseConv3x3s1_int8(const framework::Tensor &input, - const framework::Tensor &filter, - framework::Tensor *output) { +template <> +void DepthwiseConv3x3s1(const framework::Tensor &input, + const framework::Tensor &filter, + framework::Tensor *output) { const int8_t *input_data = input.data(); const int8_t *filter_data = filter.data(); int32_t *out_data = output->mutable_data(); @@ -41,26 +42,27 @@ void DepthwiseConv3x3s1_int8(const framework::Tensor &input, int output_w = output->dims()[3]; int image_size = input_h * input_w; int out_image_size = output_h * output_w; - memset(out_data, 0, output_c * out_image_size * sizeof(int32_t)); #if __aarch64__ // TODO(hjchen2) #else #pragma omp parallel for for (int g = 0; g < input_c; ++g) { - const int8_t* input_ptr0 = input_data + g * image_size; - const int8_t* input_ptr1 = input_ptr0 + input_w; - const int8_t* input_ptr2 = input_ptr1 + input_w; - const int8_t* input_ptr3 = input_ptr2 + input_w; - const int8_t* input_ptr4 = input_ptr3 + input_w; - const int8_t* input_ptr5 = input_ptr4 + input_w; + const int8_t* input_ptr = input_data + g * image_size; const int8_t* filter_ptr = filter_data + g * 9; - int32_t* output_ptr0 = out_data + g * out_image_size; - int32_t* output_ptr1 = output_ptr0 + output_w; - int32_t* output_ptr2 = output_ptr1 + output_w; - int32_t* output_ptr3 = output_ptr2 + output_w; + int32_t* output_ptr = out_data + g * out_image_size; + int loop = (input_w - 2) / 6; + int remain = input_w - 2 - loop * 6; for (int h = 0; h < input_h - 5 /*(input_h - 2) - 3*/; h += 4) { - int loop = (input_w - 2) / 6; - int remain = input_w - loop * 6; + const int8_t* input_ptr0 = input_ptr + h * input_w; + const int8_t* input_ptr1 = input_ptr0 + input_w; + const int8_t* input_ptr2 = input_ptr1 + input_w; + const int8_t* input_ptr3 = input_ptr2 + input_w; + const int8_t* input_ptr4 = input_ptr3 + input_w; + const int8_t* input_ptr5 = input_ptr4 + input_w; + int32_t* output_ptr0 = output_ptr + h * output_w; + int32_t* output_ptr1 = output_ptr0 + output_w; + int32_t* output_ptr2 = output_ptr1 + output_w; + int32_t* output_ptr3 = output_ptr2 + output_w; asm volatile( "vld1.32 {q0}, [%[filter_ptr]] \n" "vmovl.s8 q14, d0 \n" @@ -81,14 +83,13 @@ void DepthwiseConv3x3s1_int8(const framework::Tensor &input, "mov r0, #6 \n" "cmp %[loop], #0 \n" "ble start_remain_%= \n" - // loop 8 widths - "loop_4h8w_%=: \n" + // loop 6 widths + "loop_4h6w_%=: \n" "vld1.32 {d9}, [%[input_ptr0]], r0 \n" "vld1.32 {d10}, [%[input_ptr1]], r0 \n" "vld1.32 {d11}, [%[input_ptr2]], r0 \n" 
- - "vext.s8 d12, d9, #1 \n" - "vext.s8 d13, d9, #2 \n" + "vext.s8 d12, d9, d9, #1 \n" + "vext.s8 d13, d9, d9, #2 \n" "vmovl.s8 q7, d9 \n" "vmovl.s8 q8, d12 \n" "vmovl.s8 q9, d13 \n" @@ -99,11 +100,18 @@ void DepthwiseConv3x3s1_int8(const framework::Tensor &input, "vmlal.s16 q11, d17, d1 \n" "vmlal.s16 q11, d19, d2 \n" - "vext.s8 d12, d10, #1 \n" - "vext.s8 d13, d10, #2 \n" + "vext.s8 d12, d10, d10, #1 \n" + "vext.s8 d13, d10, d10, #2 \n" "vmovl.s8 q7, d10 \n" "vmovl.s8 q8, d12 \n" "vmovl.s8 q9, d13 \n" + "vmlal.s16 q10, d14, d3 \n" + "vmlal.s16 q10, d16, d4 \n" + "vmlal.s16 q10, d18, d5 \n" + "vmlal.s16 q11, d15, d3 \n" + "vmlal.s16 q11, d17, d4 \n" + "vmlal.s16 q11, d19, d5 \n" + "vmull.s16 q12, d14, d0 \n" "vmlal.s16 q12, d16, d1 \n" "vmlal.s16 q12, d18, d2 \n" @@ -111,57 +119,42 @@ void DepthwiseConv3x3s1_int8(const framework::Tensor &input, "vmlal.s16 q13, d17, d1 \n" "vmlal.s16 q13, d19, d2 \n" - "vmlal.s16 q10, d14, d3 \n" - "vmlal.s16 q10, d16, d4 \n" - "vmlal.s16 q10, d18, d5 \n" - "vmull.s16 q11, d15, d3 \n" - "vmlal.s16 q11, d17, d4 \n" - "vmlal.s16 q11, d19, d5 \n" - - "vext.s8 d12, d11, #1 \n" - "vext.s8 d13, d11, #2 \n" - "vmovl.s8 q7, d10 \n" + "vext.s8 d12, d11, d11, #1 \n" + "vext.s8 d13, d11, d11, #2 \n" + "vmovl.s8 q7, d11 \n" "vmovl.s8 q8, d12 \n" "vmovl.s8 q9, d13 \n" - "vmull.s16 q14, d14, d0 \n" - "vmlal.s16 q14, d16, d1 \n" - "vmlal.s16 q14, d18, d2 \n" - "vmull.s16 q15, d15, d0 \n" - "vmlal.s16 q15, d17, d1 \n" - "vmlal.s16 q15, d19, d2 \n" - - "vmlal.s16 q12, d14, d3 \n" - "vmlal.s16 q12, d16, d4 \n" - "vmlal.s16 q12, d18, d5 \n" - "vmull.s16 q13, d15, d3 \n" - "vmlal.s16 q13, d17, d4 \n" - "vmlal.s16 q13, d19, d5 \n" - "vmlal.s16 q10, d14, d6 \n" "vmlal.s16 q10, d16, d7 \n" "vmlal.s16 q10, d18, d8 \n" - "vmull.s16 q11, d15, d6 \n" + "vmlal.s16 q11, d15, d6 \n" "vmlal.s16 q11, d17, d7 \n" "vmlal.s16 q11, d19, d8 \n" // store row 0, reuse q10/q11 "vst1.32 {d20-d22}, [%[output_ptr0]]! 
\n" + "vmlal.s16 q12, d14, d3 \n" + "vmlal.s16 q12, d16, d4 \n" + "vmlal.s16 q12, d18, d5 \n" + "vmlal.s16 q13, d15, d3 \n" + "vmlal.s16 q13, d17, d4 \n" + "vmlal.s16 q13, d19, d5 \n" + + "vmull.s16 q14, d14, d0 \n" + "vmlal.s16 q14, d16, d1 \n" + "vmlal.s16 q14, d18, d2 \n" + "vmull.s16 q15, d15, d0 \n" + "vmlal.s16 q15, d17, d1 \n" + "vmlal.s16 q15, d19, d2 \n" + "vld1.32 {d9}, [%[input_ptr3]], r0 \n" "vld1.32 {d10}, [%[input_ptr4]], r0 \n" "vld1.32 {d11}, [%[input_ptr5]], r0 \n" - - "vext.s8 d12, d9, #1 \n" - "vext.s8 d13, d9, #2 \n" + "vext.s8 d12, d9, d9, #1 \n" + "vext.s8 d13, d9, d9, #2 \n" "vmovl.s8 q7, d9 \n" "vmovl.s8 q8, d12 \n" "vmovl.s8 q9, d13 \n" - "vmull.s16 q10, d14, d0 \n" - "vmlal.s16 q10, d16, d1 \n" - "vmlal.s16 q10, d18, d2 \n" - "vmull.s16 q11, d15, d0 \n" - "vmlal.s16 q11, d17, d1 \n" - "vmlal.s16 q11, d19, d2 \n" - "vmlal.s16 q12, d14, d6 \n" "vmlal.s16 q12, d16, d7 \n" "vmlal.s16 q12, d18, d8 \n" @@ -178,126 +171,121 @@ void DepthwiseConv3x3s1_int8(const framework::Tensor &input, "vmlal.s16 q15, d17, d4 \n" "vmlal.s16 q15, d19, d5 \n" - "vext.s8 d12, d10, #1 \n" - "vext.s8 d13, d10, #2 \n" + "vmull.s16 q10, d14, d0 \n" + "vmlal.s16 q10, d16, d1 \n" + "vmlal.s16 q10, d18, d2 \n" + "vmull.s16 q11, d15, d0 \n" + "vmlal.s16 q11, d17, d1 \n" + "vmlal.s16 q11, d19, d2 \n" + + "vext.s8 d12, d10, d10, #1 \n" + "vext.s8 d13, d10, d10, #2 \n" "vmovl.s8 q7, d10 \n" "vmovl.s8 q8, d12 \n" "vmovl.s8 q9, d13 \n" - "vmlal.s16 q10, d14, d3 \n" - "vmlal.s16 q10, d16, d4 \n" - "vmlal.s16 q10, d18, d5 \n" - "vmull.s16 q11, d15, d3 \n" - "vmlal.s16 q11, d17, d4 \n" - "vmlal.s16 q11, d19, d5 \n" - - "vmull.s16 q14, d14, d6 \n" + "vmlal.s16 q14, d14, d6 \n" "vmlal.s16 q14, d16, d7 \n" "vmlal.s16 q14, d18, d8 \n" - "vmull.s16 q15, d15, d6 \n" + "vmlal.s16 q15, d15, d6 \n" "vmlal.s16 q15, d17, d7 \n" "vmlal.s16 q15, d19, d8 \n" // store row 2 - "vst1.32 {d24-d26}, [%[output_ptr2]]! \n" + "vst1.32 {d28-d30}, [%[output_ptr2]]! \n" - "vext.s8 d12, d11, #1 \n" - "vext.s8 d13, d11, #2 \n" - "vmovl.s8 q7, d10 \n" + "vmlal.s16 q10, d14, d3 \n" + "vmlal.s16 q10, d16, d4 \n" + "vmlal.s16 q10, d18, d5 \n" + "vmlal.s16 q11, d15, d3 \n" + "vmlal.s16 q11, d17, d4 \n" + "vmlal.s16 q11, d19, d5 \n" + + "vext.s8 d12, d11, d11, #1 \n" + "vext.s8 d13, d11, d11, #2 \n" + "vmovl.s8 q7, d11 \n" "vmovl.s8 q8, d12 \n" "vmovl.s8 q9, d13 \n" - "vmull.s16 q10, d14, d6 \n" + "vmlal.s16 q10, d14, d6 \n" "vmlal.s16 q10, d16, d7 \n" "vmlal.s16 q10, d18, d8 \n" - "vmull.s16 q11, d15, d6 \n" + "vmlal.s16 q11, d15, d6 \n" "vmlal.s16 q11, d17, d7 \n" "vmlal.s16 q11, d19, d8 \n" // store row 3 "vst1.32 {d20-d22}, [%[output_ptr3]]! 
\n" - "subs %[loop], #1 \n" - "bne loop_4h8w_%= \n" + "subs %[loop], #1 \n" + "bne loop_4h6w_%= \n" "start_remain_%=: \n" "cmp %[remain], #0 \n" "ble end_%= \n" - "mov r0, %[remain] \n" - "add r0, #2 \n" - "vld1.32 {d9}, [%[input_ptr0]], r0 \n" - "vext.s8 d12, d9, #1 \n" - "vext.s8 d13, d9, #2 \n" + "vld1.32 {d9}, [%[input_ptr0]] \n" "vmovl.s8 q7, d9 \n" - "vmovl.s8 q8, d12 \n" - "vmovl.s8 q9, d13 \n" + "vext.s8 d9, d9, d9, #1 \n" + "vmovl.s8 q8, d9 \n" + "vext.s8 d9, d9, d9, #1 \n" + "vmovl.s8 q9, d9 \n" "vmull.s16 q10, d14, d0 \n" "vmlal.s16 q10, d16, d1 \n" "vmlal.s16 q10, d18, d2 \n" - - "vld1.32 {d9}, [%[input_ptr1]], r0 \n" + "vld1.32 {d9}, [%[input_ptr1]] \n" "vmull.s16 q11, d15, d0 \n" "vmlal.s16 q11, d17, d1 \n" "vmlal.s16 q11, d19, d2 \n" - "vext.s8 d12, d9, #1 \n" - "vext.s8 d13, d9, #2 \n" "vmovl.s8 q7, d9 \n" - "vmovl.s8 q8, d12 \n" - "vmovl.s8 q9, d13 \n" + "vext.s8 d9, d9, d9, #1 \n" + "vmovl.s8 q8, d9 \n" + "vext.s8 d9, d9, d9, #1 \n" + "vmovl.s8 q9, d9 \n" + "vmlal.s16 q10, d14, d3 \n" + "vmlal.s16 q10, d16, d4 \n" + "vmlal.s16 q10, d18, d5 \n" + "vmlal.s16 q11, d15, d3 \n" + "vmlal.s16 q11, d17, d4 \n" + "vmlal.s16 q11, d19, d5 \n" + "vmull.s16 q12, d14, d0 \n" "vmlal.s16 q12, d16, d1 \n" "vmlal.s16 q12, d18, d2 \n" + "vld1.32 {d9}, [%[input_ptr2]] \n" "vmull.s16 q13, d15, d0 \n" "vmlal.s16 q13, d17, d1 \n" "vmlal.s16 q13, d19, d2 \n" - "vmlal.s16 q10, d14, d3 \n" - "vmlal.s16 q10, d16, d4 \n" - "vmlal.s16 q10, d18, d5 \n" - - "vld1.32 {d9}, [%[input_ptr2]], r0 \n" - "vmull.s16 q11, d15, d3 \n" - "vmlal.s16 q11, d17, d4 \n" - "vmlal.s16 q11, d19, d5 \n" - - "vext.s8 d12, d9, #1 \n" - "vext.s8 d13, d9, #2 \n" "vmovl.s8 q7, d9 \n" - "vmovl.s8 q8, d12 \n" - "vmovl.s8 q9, d13 \n" - "vmull.s16 q14, d14, d0 \n" - "vmlal.s16 q14, d16, d1 \n" - "vmlal.s16 q14, d18, d2 \n" - "vmull.s16 q15, d15, d0 \n" - "vmlal.s16 q15, d17, d1 \n" - "vmlal.s16 q15, d19, d2 \n" + "vext.s8 d9, d9, d9, #1 \n" + "vmovl.s8 q8, d9 \n" + "vext.s8 d9, d9, d9, #1 \n" + "vmovl.s8 q9, d9 \n" + "vmlal.s16 q10, d14, d6 \n" + "vmlal.s16 q10, d16, d7 \n" + "vmlal.s16 q10, d18, d8 \n" + "vmlal.s16 q11, d15, d6 \n" + "vmlal.s16 q11, d17, d7 \n" + "vmlal.s16 q11, d19, d8 \n" "vmlal.s16 q12, d14, d3 \n" "vmlal.s16 q12, d16, d4 \n" "vmlal.s16 q12, d18, d5 \n" - "vmull.s16 q13, d15, d3 \n" + "vmlal.s16 q13, d15, d3 \n" "vmlal.s16 q13, d17, d4 \n" "vmlal.s16 q13, d19, d5 \n" - "vmlal.s16 q10, d14, d6 \n" - "vmlal.s16 q10, d16, d7 \n" - "vmlal.s16 q10, d18, d8 \n" - - "vld1.32 {d9}, [%[input_ptr3]], r0 \n" - "vmull.s16 q11, d15, d6 \n" - "vmlal.s16 q11, d17, d7 \n" - "vmlal.s16 q11, d19, d8 \n" + "vmull.s16 q14, d14, d0 \n" + "vmlal.s16 q14, d16, d1 \n" + "vmlal.s16 q14, d18, d2 \n" + "vld1.32 {d9}, [%[input_ptr3]] \n" + "vmull.s16 q15, d15, d0 \n" + "vmlal.s16 q15, d17, d1 \n" + "vmlal.s16 q15, d19, d2 \n" - "vext.s8 d12, d9, #1 \n" - "vext.s8 d13, d9, #2 \n" "vmovl.s8 q7, d9 \n" - "vmovl.s8 q8, d12 \n" - "vmovl.s8 q9, d13 \n" - "vmull.s16 q5, d14, d0 \n" - "vmlal.s16 q5, d16, d1 \n" - "vmlal.s16 q5, d18, d2 \n" - "vmull.s16 q6, d15, d0 \n" - "vmlal.s16 q6, d17, d1 \n" - "vmlal.s16 q6, d19, d2 \n" - + "vext.s8 d9, d9, d9, #1 \n" + "vmovl.s8 q8, d9 \n" + "vext.s8 d9, d9, d9, #1 \n" + "vmovl.s8 q9, d9 \n" "vmlal.s16 q12, d14, d6 \n" "vmlal.s16 q12, d16, d7 \n" "vmlal.s16 q12, d18, d8 \n" @@ -308,42 +296,47 @@ void DepthwiseConv3x3s1_int8(const framework::Tensor &input, "vmlal.s16 q14, d14, d3 \n" "vmlal.s16 q14, d16, d4 \n" "vmlal.s16 q14, d18, d5 \n" - - "vld1.32 {d9}, [%[input_ptr4]], r0 \n" "vmlal.s16 q15, d15, d3 
\n" "vmlal.s16 q15, d17, d4 \n" "vmlal.s16 q15, d19, d5 \n" - "vext.s8 d12, d9, #1 \n" - "vext.s8 d13, d9, #2 \n" + "vmull.s16 q5, d14, d0 \n" + "vmlal.s16 q5, d16, d1 \n" + "vmlal.s16 q5, d18, d2 \n" + "vld1.32 {d9}, [%[input_ptr4]] \n" + "vmull.s16 q6, d15, d0 \n" + "vmlal.s16 q6, d17, d1 \n" + "vmlal.s16 q6, d19, d2 \n" + "vmovl.s8 q7, d9 \n" - "vmovl.s8 q8, d12 \n" - "vmovl.s8 q9, d13 \n" + "vext.s8 d9, d9, d9, #1 \n" + "vmovl.s8 q8, d9 \n" + "vext.s8 d9, d9, d9, #1 \n" + "vmovl.s8 q9, d9 \n" + "vmlal.s16 q14, d14, d6 \n" + "vmlal.s16 q14, d16, d7 \n" + "vmlal.s16 q14, d18, d8 \n" + "vmlal.s16 q15, d15, d6 \n" + "vmlal.s16 q15, d17, d7 \n" + "vmlal.s16 q15, d19, d8 \n" + "vmlal.s16 q5, d14, d3 \n" "vmlal.s16 q5, d16, d4 \n" "vmlal.s16 q5, d18, d5 \n" - "vmull.s16 q6, d15, d3 \n" + "vld1.32 {d9}, [%[input_ptr5]] \n" + "vmlal.s16 q6, d15, d3 \n" "vmlal.s16 q6, d17, d4 \n" "vmlal.s16 q6, d19, d5 \n" - "vmull.s16 q14, d14, d6 \n" - "vmlal.s16 q14, d16, d7 \n" - "vmlal.s16 q14, d18, d8 \n" - - "vld1.32 {d9}, [%[input_ptr5]], r0 \n" - "vmull.s16 q15, d15, d6 \n" - "vmlal.s16 q15, d17, d7 \n" - "vmlal.s16 q15, d19, d8 \n" - - "vext.s8 d12, d9, #1 \n" - "vext.s8 d13, d9, #2 \n" "vmovl.s8 q7, d9 \n" - "vmovl.s8 q8, d12 \n" - "vmovl.s8 q9, d13 \n" - "vmull.s16 q5, d14, d6 \n" + "vext.s8 d9, d9, d9, #1 \n" + "vmovl.s8 q8, d9 \n" + "vext.s8 d9, d9, d9, #1 \n" + "vmovl.s8 q9, d9 \n" + "vmlal.s16 q5, d14, d6 \n" "vmlal.s16 q5, d16, d7 \n" "vmlal.s16 q5, d18, d8 \n" - "vmull.s16 q6, d15, d6 \n" + "vmlal.s16 q6, d15, d6 \n" "vmlal.s16 q6, d17, d7 \n" "vmlal.s16 q6, d19, d8 \n" @@ -372,7 +365,7 @@ void DepthwiseConv3x3s1_int8(const framework::Tensor &input, "blt end_%= \n" "vst1.32 {d21[0]}, [%[output_ptr0]]! \n" "vst1.32 {d25[0]}, [%[output_ptr1]]! \n" - "vst1.32 {d27[0]}, [%[output_ptr2]]! \n" + "vst1.32 {d29[0]}, [%[output_ptr2]]! \n" "vst1.32 {d11[0]}, [%[output_ptr3]]! 
\n" "b end_%= \n" @@ -395,8 +388,1071 @@ void DepthwiseConv3x3s1_int8(const framework::Tensor &input, } // remain height int start_h = (input_h - 2) & 0xFFFC; - for (int h = start_h; h < input_h; ++h) { - // TODO(hjchen2) + for (int h = start_h; h < input_h - 3 /*(input_h - 2) - 1*/; h += 2) { + const int8_t* input_ptr0 = input_ptr + h * input_w; + const int8_t* input_ptr1 = input_ptr0 + input_w; + const int8_t* input_ptr2 = input_ptr1 + input_w; + const int8_t* input_ptr3 = input_ptr2 + input_w; + int32_t* output_ptr0 = output_ptr + h * output_w; + int32_t* output_ptr1 = output_ptr0 + output_w; + asm volatile( + "vld1.32 {q0}, [%[filter_ptr]] \n" + "vmovl.s8 q14, d0 \n" + "vmovl.s8 q15, d1 \n" + "vdup.s16 d0, d28[0] \n" + "vdup.s16 d1, d28[1] \n" + "vdup.s16 d2, d28[2] \n" + "vdup.s16 d3, d28[3] \n" + "vdup.s16 d4, d29[0] \n" + "vdup.s16 d5, d29[1] \n" + "vdup.s16 d6, d29[2] \n" + "vdup.s16 d7, d29[3] \n" + "vdup.s16 d8, d30[0] \n" + : + : [filter_ptr] "r"(filter_ptr) + : "cc", "memory", "q0", "q1", "q2", "q3", "q4", "q14", "q15"); + asm volatile( + "mov r0, #6 \n" + "cmp %[loop], #0 \n" + "ble start_remain_%= \n" + // loop 6 widths + "loop_2h6w_%=: \n" + "vld1.32 {d9}, [%[input_ptr0]], r0 \n" + "vld1.32 {d10}, [%[input_ptr1]], r0 \n" + "vld1.32 {d11}, [%[input_ptr2]], r0 \n" + "vext.s8 d12, d9, d9, #1 \n" + "vext.s8 d13, d9, d9, #2 \n" + "vmovl.s8 q7, d9 \n" + "vmovl.s8 q8, d12 \n" + "vmovl.s8 q9, d13 \n" + "vmull.s16 q10, d14, d0 \n" + "vmlal.s16 q10, d16, d1 \n" + "vmlal.s16 q10, d18, d2 \n" + "vmull.s16 q11, d15, d0 \n" + "vmlal.s16 q11, d17, d1 \n" + "vmlal.s16 q11, d19, d2 \n" + + "vext.s8 d12, d10, d10, #1 \n" + "vext.s8 d13, d10, d10, #2 \n" + "vmovl.s8 q7, d10 \n" + "vmovl.s8 q8, d12 \n" + "vmovl.s8 q9, d13 \n" + "vmlal.s16 q10, d14, d3 \n" + "vmlal.s16 q10, d16, d4 \n" + "vmlal.s16 q10, d18, d5 \n" + "vmlal.s16 q11, d15, d3 \n" + "vmlal.s16 q11, d17, d4 \n" + "vmlal.s16 q11, d19, d5 \n" + + "vmull.s16 q12, d14, d0 \n" + "vmlal.s16 q12, d16, d1 \n" + "vmlal.s16 q12, d18, d2 \n" + "vmull.s16 q13, d15, d0 \n" + "vmlal.s16 q13, d17, d1 \n" + "vmlal.s16 q13, d19, d2 \n" + + "vext.s8 d12, d11, d11, #1 \n" + "vext.s8 d13, d11, d11, #2 \n" + "vmovl.s8 q7, d11 \n" + "vmovl.s8 q8, d12 \n" + "vmovl.s8 q9, d13 \n" + "vmlal.s16 q10, d14, d6 \n" + "vmlal.s16 q10, d16, d7 \n" + "vmlal.s16 q10, d18, d8 \n" + "vmlal.s16 q11, d15, d6 \n" + "vmlal.s16 q11, d17, d7 \n" + "vmlal.s16 q11, d19, d8 \n" + // store row 0, reuse q10/q11 + "vst1.32 {d20-d22}, [%[output_ptr0]]! \n" + + "vmlal.s16 q12, d14, d3 \n" + "vmlal.s16 q12, d16, d4 \n" + "vmlal.s16 q12, d18, d5 \n" + "vmlal.s16 q13, d15, d3 \n" + "vmlal.s16 q13, d17, d4 \n" + "vmlal.s16 q13, d19, d5 \n" + + "vld1.32 {d9}, [%[input_ptr3]], r0 \n" + "vext.s8 d12, d9, d9, #1 \n" + "vext.s8 d13, d9, d9, #2 \n" + "vmovl.s8 q7, d9 \n" + "vmovl.s8 q8, d12 \n" + "vmovl.s8 q9, d13 \n" + "vmlal.s16 q12, d14, d6 \n" + "vmlal.s16 q12, d16, d7 \n" + "vmlal.s16 q12, d18, d8 \n" + "vmlal.s16 q13, d15, d6 \n" + "vmlal.s16 q13, d17, d7 \n" + "vmlal.s16 q13, d19, d8 \n" + // store row 1 + "vst1.32 {d24-d26}, [%[output_ptr1]]! 
\n" + + "subs %[loop], #1 \n" + "bne loop_2h6w_%= \n" + + "start_remain_%=: \n" + "cmp %[remain], #0 \n" + "ble end_%= \n" + + "vld1.32 {d9}, [%[input_ptr0]] \n" + "vld1.32 {d10}, [%[input_ptr1]] \n" + "vld1.32 {d11}, [%[input_ptr2]] \n" + "vext.s8 d12, d9, d9, #1 \n" + "vext.s8 d13, d9, d9, #2 \n" + "vmovl.s8 q7, d9 \n" + "vmovl.s8 q8, d12 \n" + "vmovl.s8 q9, d13 \n" + "vmull.s16 q10, d14, d0 \n" + "vmlal.s16 q10, d16, d1 \n" + "vmlal.s16 q10, d18, d2 \n" + "vmull.s16 q11, d15, d0 \n" + "vmlal.s16 q11, d17, d1 \n" + "vmlal.s16 q11, d19, d2 \n" + + "vext.s8 d12, d10, d10, #1 \n" + "vext.s8 d13, d10, d10, #2 \n" + "vmovl.s8 q7, d10 \n" + "vmovl.s8 q8, d12 \n" + "vmovl.s8 q9, d13 \n" + "vmlal.s16 q10, d14, d3 \n" + "vmlal.s16 q10, d16, d4 \n" + "vmlal.s16 q10, d18, d5 \n" + "vmlal.s16 q11, d15, d3 \n" + "vmlal.s16 q11, d17, d4 \n" + "vmlal.s16 q11, d19, d5 \n" + + "vmull.s16 q12, d14, d0 \n" + "vmlal.s16 q12, d16, d1 \n" + "vmlal.s16 q12, d18, d2 \n" + "vmull.s16 q13, d15, d0 \n" + "vmlal.s16 q13, d17, d1 \n" + "vmlal.s16 q13, d19, d2 \n" + + "vext.s8 d12, d11, d11, #1 \n" + "vext.s8 d13, d11, d11, #2 \n" + "vmovl.s8 q7, d11 \n" + "vmovl.s8 q8, d12 \n" + "vmovl.s8 q9, d13 \n" + "vmlal.s16 q10, d14, d6 \n" + "vmlal.s16 q10, d16, d7 \n" + "vmlal.s16 q10, d18, d8 \n" + "vmlal.s16 q11, d15, d6 \n" + "vmlal.s16 q11, d17, d7 \n" + "vmlal.s16 q11, d19, d8 \n" + + "vmlal.s16 q12, d14, d3 \n" + "vmlal.s16 q12, d16, d4 \n" + "vmlal.s16 q12, d18, d5 \n" + "vmlal.s16 q13, d15, d3 \n" + "vmlal.s16 q13, d17, d4 \n" + "vmlal.s16 q13, d19, d5 \n" + + "vld1.32 {d9}, [%[input_ptr3]] \n" + "vext.s8 d12, d9, d9, #1 \n" + "vext.s8 d13, d9, d9, #2 \n" + "vmovl.s8 q7, d9 \n" + "vmovl.s8 q8, d12 \n" + "vmovl.s8 q9, d13 \n" + "vmlal.s16 q12, d14, d6 \n" + "vmlal.s16 q12, d16, d7 \n" + "vmlal.s16 q12, d18, d8 \n" + "vmlal.s16 q13, d15, d6 \n" + "vmlal.s16 q13, d17, d7 \n" + "vmlal.s16 q13, d19, d8 \n" + + "cmp %[remain], #4 \n" + "blt store_2h2w_%= \n" + "vst1.32 {q10}, [%[output_ptr0]]! \n" + "vst1.32 {q12}, [%[output_ptr1]]! \n" + "cmp %[remain], #5 \n" + "blt end_%= \n" + "vst1.32 {d22[0]}, [%[output_ptr0]]! \n" + "vst1.32 {d26[0]}, [%[output_ptr1]]! \n" + "b end_%= \n" + + "store_2h2w_%=: \n" + "cmp %[remain], #2 \n" + "blt store_2h1w_%= \n" + "vst1.32 {d20}, [%[output_ptr0]]! \n" + "vst1.32 {d24}, [%[output_ptr1]]! \n" + "cmp %[remain], #3 \n" + "blt end_%= \n" + "vst1.32 {d21[0]}, [%[output_ptr0]]! \n" + "vst1.32 {d25[0]}, [%[output_ptr1]]! \n" + "b end_%= \n" + + "store_2h1w_%=: \n" + "cmp %[remain], #1 \n" + "blt end_%= \n" + "vst1.32 {d20[0]}, [%[output_ptr0]]! \n" + "vst1.32 {d24[0]}, [%[output_ptr1]]! 
\n" + "end_%=: \n" + : [output_ptr0] "+r"(output_ptr0), [output_ptr1] "+r"(output_ptr1), + [input_ptr0] "+r"(input_ptr0), [input_ptr1] "+r"(input_ptr1), + [input_ptr2] "+r"(input_ptr2), [input_ptr3] "+r"(input_ptr3) + : [loop] "r"(loop), [remain] "r"(remain) + : "memory", "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7", "q8", + "q9", "q10", "q11", "q12", "q13", "r0"); + } + + start_h = (input_h - 2) & 0xFFFE; + if (start_h < input_h - 2) { + const int8_t* input_ptr0 = input_ptr + start_h * input_w; + const int8_t* input_ptr1 = input_ptr0 + input_w; + const int8_t* input_ptr2 = input_ptr1 + input_w; + int32_t* output_ptr0 = output_ptr + start_h * output_w; + asm volatile( + "vld1.32 {q0}, [%[filter_ptr]] \n" + "vmovl.s8 q14, d0 \n" + "vmovl.s8 q15, d1 \n" + "vdup.s16 d0, d28[0] \n" + "vdup.s16 d1, d28[1] \n" + "vdup.s16 d2, d28[2] \n" + "vdup.s16 d3, d28[3] \n" + "vdup.s16 d4, d29[0] \n" + "vdup.s16 d5, d29[1] \n" + "vdup.s16 d6, d29[2] \n" + "vdup.s16 d7, d29[3] \n" + "vdup.s16 d8, d30[0] \n" + : + : [filter_ptr] "r"(filter_ptr) + : "cc", "memory", "q0", "q1", "q2", "q3", "q4", "q14", "q15"); + asm volatile( + "mov r0, #6 \n" + "cmp %[loop], #0 \n" + "ble start_remain_%= \n" + // loop 6 widths + "loop_1h6w_%=: \n" + "vld1.32 {d9}, [%[input_ptr0]], r0 \n" + "vld1.32 {d10}, [%[input_ptr1]], r0 \n" + "vld1.32 {d11}, [%[input_ptr2]], r0 \n" + "vext.s8 d12, d9, d9, #1 \n" + "vext.s8 d13, d9, d9, #2 \n" + "vmovl.s8 q7, d9 \n" + "vmovl.s8 q8, d12 \n" + "vmovl.s8 q9, d13 \n" + "vmull.s16 q10, d14, d0 \n" + "vmlal.s16 q10, d16, d1 \n" + "vmlal.s16 q10, d18, d2 \n" + "vmull.s16 q11, d15, d0 \n" + "vmlal.s16 q11, d17, d1 \n" + "vmlal.s16 q11, d19, d2 \n" + + "vext.s8 d12, d10, d10, #1 \n" + "vext.s8 d13, d10, d10, #2 \n" + "vmovl.s8 q7, d10 \n" + "vmovl.s8 q8, d12 \n" + "vmovl.s8 q9, d13 \n" + "vmlal.s16 q10, d14, d3 \n" + "vmlal.s16 q10, d16, d4 \n" + "vmlal.s16 q10, d18, d5 \n" + "vmlal.s16 q11, d15, d3 \n" + "vmlal.s16 q11, d17, d4 \n" + "vmlal.s16 q11, d19, d5 \n" + + "vext.s8 d12, d11, d11, #1 \n" + "vext.s8 d13, d11, d11, #2 \n" + "vmovl.s8 q7, d11 \n" + "vmovl.s8 q8, d12 \n" + "vmovl.s8 q9, d13 \n" + "vmlal.s16 q10, d14, d6 \n" + "vmlal.s16 q10, d16, d7 \n" + "vmlal.s16 q10, d18, d8 \n" + "vmlal.s16 q11, d15, d6 \n" + "vmlal.s16 q11, d17, d7 \n" + "vmlal.s16 q11, d19, d8 \n" + // store row 0, reuse q10/q11 + "vst1.32 {d20-d22}, [%[output_ptr0]]! 
\n" + + "subs %[loop], #1 \n" + "bne loop_1h6w_%= \n" + + "start_remain_%=: \n" + "cmp %[remain], #0 \n" + "ble end_%= \n" + + "vld1.32 {d9}, [%[input_ptr0]] \n" + "vld1.32 {d10}, [%[input_ptr1]] \n" + "vld1.32 {d11}, [%[input_ptr2]] \n" + "vext.s8 d12, d9, d9, #1 \n" + "vext.s8 d13, d9, d9, #2 \n" + "vmovl.s8 q7, d9 \n" + "vmovl.s8 q8, d12 \n" + "vmovl.s8 q9, d13 \n" + "vmull.s16 q10, d14, d0 \n" + "vmlal.s16 q10, d16, d1 \n" + "vmlal.s16 q10, d18, d2 \n" + "vmull.s16 q11, d15, d0 \n" + "vmlal.s16 q11, d17, d1 \n" + "vmlal.s16 q11, d19, d2 \n" + + "vext.s8 d12, d10, d10, #1 \n" + "vext.s8 d13, d10, d10, #2 \n" + "vmovl.s8 q7, d10 \n" + "vmovl.s8 q8, d12 \n" + "vmovl.s8 q9, d13 \n" + "vmlal.s16 q10, d14, d3 \n" + "vmlal.s16 q10, d16, d4 \n" + "vmlal.s16 q10, d18, d5 \n" + "vmlal.s16 q11, d15, d3 \n" + "vmlal.s16 q11, d17, d4 \n" + "vmlal.s16 q11, d19, d5 \n" + + "vext.s8 d12, d11, d11, #1 \n" + "vext.s8 d13, d11, d11, #2 \n" + "vmovl.s8 q7, d11 \n" + "vmovl.s8 q8, d12 \n" + "vmovl.s8 q9, d13 \n" + "vmlal.s16 q10, d14, d6 \n" + "vmlal.s16 q10, d16, d7 \n" + "vmlal.s16 q10, d18, d8 \n" + "vmlal.s16 q11, d15, d6 \n" + "vmlal.s16 q11, d17, d7 \n" + "vmlal.s16 q11, d19, d8 \n" + + "cmp %[remain], #4 \n" + "blt store_1h2w_%= \n" + "vst1.32 {q10}, [%[output_ptr0]]! \n" + "cmp %[remain], #5 \n" + "blt end_%= \n" + "vst1.32 {d22[0]}, [%[output_ptr0]]! \n" + "b end_%= \n" + + "store_1h2w_%=: \n" + "cmp %[remain], #2 \n" + "blt store_1h1w_%= \n" + "vst1.32 {d20}, [%[output_ptr0]]! \n" + "cmp %[remain], #3 \n" + "blt end_%= \n" + "vst1.32 {d21[0]}, [%[output_ptr0]]! \n" + "b end_%= \n" + + "store_1h1w_%=: \n" + "cmp %[remain], #1 \n" + "blt end_%= \n" + "vst1.32 {d20[0]}, [%[output_ptr0]]! \n" + "end_%=: \n" + : [output_ptr0] "+r"(output_ptr0), [input_ptr0] "+r"(input_ptr0), + [input_ptr1] "+r"(input_ptr1), [input_ptr2] "+r"(input_ptr2) + : [loop] "r"(loop), [remain] "r"(remain) + : "memory", "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7", "q8", + "q9", "q10", "q11", "r0"); + } + } +#endif // __aarch64__ +} + +template <> +void DepthwiseConv3x3s2(const framework::Tensor &input, + const framework::Tensor &filter, + framework::Tensor *output) { + const int8_t *input_data = input.data(); + const int8_t *filter_data = filter.data(); + int32_t *out_data = output->mutable_data(); + // make sure that batch size is 1 + int input_c = input.dims()[1]; + int input_h = input.dims()[2]; + int input_w = input.dims()[3]; + int output_c = output->dims()[1]; + int output_h = output->dims()[2]; + int output_w = output->dims()[3]; + int image_size = input_h * input_w; + int out_image_size = output_h * output_w; +#if __aarch64__ + // TODO(hjchen2) +#else + #pragma omp parallel for + for (int g = 0; g < input_c; ++g) { + const int8_t* input_ptr = input_data + g * image_size; + const int8_t* filter_ptr = filter_data + g * 9; + int32_t* output_ptr = out_data + g * out_image_size; + int loop = (input_w - 2) / 6; + int remain = input_w - 2 - loop * 6; + for (int h = 0; h < input_h - 5 /*(input_h - 2) - 3*/; h += 4) { + const int8_t* input_ptr0 = input_ptr + h * input_w; + const int8_t* input_ptr1 = input_ptr0 + input_w; + const int8_t* input_ptr2 = input_ptr1 + input_w; + const int8_t* input_ptr3 = input_ptr2 + input_w; + const int8_t* input_ptr4 = input_ptr3 + input_w; + const int8_t* input_ptr5 = input_ptr4 + input_w; + int32_t* output_ptr0 = output_ptr + h * output_w; + int32_t* output_ptr1 = output_ptr0 + output_w; + int32_t* output_ptr2 = output_ptr1 + output_w; + int32_t* output_ptr3 = output_ptr2 + output_w; + asm 
volatile( + "vld1.32 {q0}, [%[filter_ptr]] \n" + "vmovl.s8 q14, d0 \n" + "vmovl.s8 q15, d1 \n" + "vdup.s16 d0, d28[0] \n" + "vdup.s16 d1, d28[1] \n" + "vdup.s16 d2, d28[2] \n" + "vdup.s16 d3, d28[3] \n" + "vdup.s16 d4, d29[0] \n" + "vdup.s16 d5, d29[1] \n" + "vdup.s16 d6, d29[2] \n" + "vdup.s16 d7, d29[3] \n" + "vdup.s16 d8, d30[0] \n" + : + : [filter_ptr] "r"(filter_ptr) + : "cc", "memory", "q0", "q1", "q2", "q3", "q4", "q14", "q15"); + asm volatile( + "mov r0, #6 \n" + "cmp %[loop], #0 \n" + "ble start_remain_%= \n" + // loop 6 widths + "loop_4h6w_%=: \n" + "vld1.32 {d9}, [%[input_ptr0]], r0 \n" + "vld1.32 {d10}, [%[input_ptr1]], r0 \n" + "vld1.32 {d11}, [%[input_ptr2]], r0 \n" + "vext.s8 d12, d9, d9, #1 \n" + "vext.s8 d13, d9, d9, #2 \n" + "vmovl.s8 q7, d9 \n" + "vmovl.s8 q8, d12 \n" + "vmovl.s8 q9, d13 \n" + "vmull.s16 q10, d14, d0 \n" + "vmlal.s16 q10, d16, d1 \n" + "vmlal.s16 q10, d18, d2 \n" + "vmull.s16 q11, d15, d0 \n" + "vmlal.s16 q11, d17, d1 \n" + "vmlal.s16 q11, d19, d2 \n" + + "vext.s8 d12, d10, d10, #1 \n" + "vext.s8 d13, d10, d10, #2 \n" + "vmovl.s8 q7, d10 \n" + "vmovl.s8 q8, d12 \n" + "vmovl.s8 q9, d13 \n" + "vmlal.s16 q10, d14, d3 \n" + "vmlal.s16 q10, d16, d4 \n" + "vmlal.s16 q10, d18, d5 \n" + "vmlal.s16 q11, d15, d3 \n" + "vmlal.s16 q11, d17, d4 \n" + "vmlal.s16 q11, d19, d5 \n" + + "vmull.s16 q12, d14, d0 \n" + "vmlal.s16 q12, d16, d1 \n" + "vmlal.s16 q12, d18, d2 \n" + "vmull.s16 q13, d15, d0 \n" + "vmlal.s16 q13, d17, d1 \n" + "vmlal.s16 q13, d19, d2 \n" + + "vext.s8 d12, d11, d11, #1 \n" + "vext.s8 d13, d11, d11, #2 \n" + "vmovl.s8 q7, d11 \n" + "vmovl.s8 q8, d12 \n" + "vmovl.s8 q9, d13 \n" + "vmlal.s16 q10, d14, d6 \n" + "vmlal.s16 q10, d16, d7 \n" + "vmlal.s16 q10, d18, d8 \n" + "vmlal.s16 q11, d15, d6 \n" + "vmlal.s16 q11, d17, d7 \n" + "vmlal.s16 q11, d19, d8 \n" + // store row 0, reuse q10/q11 + "vst1.32 {d20-d22}, [%[output_ptr0]]! \n" + + "vmlal.s16 q12, d14, d3 \n" + "vmlal.s16 q12, d16, d4 \n" + "vmlal.s16 q12, d18, d5 \n" + "vmlal.s16 q13, d15, d3 \n" + "vmlal.s16 q13, d17, d4 \n" + "vmlal.s16 q13, d19, d5 \n" + + "vmull.s16 q14, d14, d0 \n" + "vmlal.s16 q14, d16, d1 \n" + "vmlal.s16 q14, d18, d2 \n" + "vmull.s16 q15, d15, d0 \n" + "vmlal.s16 q15, d17, d1 \n" + "vmlal.s16 q15, d19, d2 \n" + + "vld1.32 {d9}, [%[input_ptr3]], r0 \n" + "vld1.32 {d10}, [%[input_ptr4]], r0 \n" + "vld1.32 {d11}, [%[input_ptr5]], r0 \n" + "vext.s8 d12, d9, d9, #1 \n" + "vext.s8 d13, d9, d9, #2 \n" + "vmovl.s8 q7, d9 \n" + "vmovl.s8 q8, d12 \n" + "vmovl.s8 q9, d13 \n" + "vmlal.s16 q12, d14, d6 \n" + "vmlal.s16 q12, d16, d7 \n" + "vmlal.s16 q12, d18, d8 \n" + "vmlal.s16 q13, d15, d6 \n" + "vmlal.s16 q13, d17, d7 \n" + "vmlal.s16 q13, d19, d8 \n" + // store row 1 + "vst1.32 {d24-d26}, [%[output_ptr1]]! \n" + + "vmlal.s16 q14, d14, d3 \n" + "vmlal.s16 q14, d16, d4 \n" + "vmlal.s16 q14, d18, d5 \n" + "vmlal.s16 q15, d15, d3 \n" + "vmlal.s16 q15, d17, d4 \n" + "vmlal.s16 q15, d19, d5 \n" + + "vmull.s16 q10, d14, d0 \n" + "vmlal.s16 q10, d16, d1 \n" + "vmlal.s16 q10, d18, d2 \n" + "vmull.s16 q11, d15, d0 \n" + "vmlal.s16 q11, d17, d1 \n" + "vmlal.s16 q11, d19, d2 \n" + + "vext.s8 d12, d10, d10, #1 \n" + "vext.s8 d13, d10, d10, #2 \n" + "vmovl.s8 q7, d10 \n" + "vmovl.s8 q8, d12 \n" + "vmovl.s8 q9, d13 \n" + "vmlal.s16 q14, d14, d6 \n" + "vmlal.s16 q14, d16, d7 \n" + "vmlal.s16 q14, d18, d8 \n" + "vmlal.s16 q15, d15, d6 \n" + "vmlal.s16 q15, d17, d7 \n" + "vmlal.s16 q15, d19, d8 \n" + // store row 2 + "vst1.32 {d28-d30}, [%[output_ptr2]]! 
\n" + + "vmlal.s16 q10, d14, d3 \n" + "vmlal.s16 q10, d16, d4 \n" + "vmlal.s16 q10, d18, d5 \n" + "vmlal.s16 q11, d15, d3 \n" + "vmlal.s16 q11, d17, d4 \n" + "vmlal.s16 q11, d19, d5 \n" + + "vext.s8 d12, d11, d11, #1 \n" + "vext.s8 d13, d11, d11, #2 \n" + "vmovl.s8 q7, d11 \n" + "vmovl.s8 q8, d12 \n" + "vmovl.s8 q9, d13 \n" + "vmlal.s16 q10, d14, d6 \n" + "vmlal.s16 q10, d16, d7 \n" + "vmlal.s16 q10, d18, d8 \n" + "vmlal.s16 q11, d15, d6 \n" + "vmlal.s16 q11, d17, d7 \n" + "vmlal.s16 q11, d19, d8 \n" + // store row 3 + "vst1.32 {d20-d22}, [%[output_ptr3]]! \n" + + "subs %[loop], #1 \n" + "bne loop_4h6w_%= \n" + + "start_remain_%=: \n" + "cmp %[remain], #0 \n" + "ble end_%= \n" + + "vld1.32 {d9}, [%[input_ptr0]] \n" + "vmovl.s8 q7, d9 \n" + "vext.s8 d9, d9, d9, #1 \n" + "vmovl.s8 q8, d9 \n" + "vext.s8 d9, d9, d9, #1 \n" + "vmovl.s8 q9, d9 \n" + "vmull.s16 q10, d14, d0 \n" + "vmlal.s16 q10, d16, d1 \n" + "vmlal.s16 q10, d18, d2 \n" + "vld1.32 {d9}, [%[input_ptr1]] \n" + "vmull.s16 q11, d15, d0 \n" + "vmlal.s16 q11, d17, d1 \n" + "vmlal.s16 q11, d19, d2 \n" + + "vmovl.s8 q7, d9 \n" + "vext.s8 d9, d9, d9, #1 \n" + "vmovl.s8 q8, d9 \n" + "vext.s8 d9, d9, d9, #1 \n" + "vmovl.s8 q9, d9 \n" + "vmlal.s16 q10, d14, d3 \n" + "vmlal.s16 q10, d16, d4 \n" + "vmlal.s16 q10, d18, d5 \n" + "vmlal.s16 q11, d15, d3 \n" + "vmlal.s16 q11, d17, d4 \n" + "vmlal.s16 q11, d19, d5 \n" + + "vmull.s16 q12, d14, d0 \n" + "vmlal.s16 q12, d16, d1 \n" + "vmlal.s16 q12, d18, d2 \n" + "vld1.32 {d9}, [%[input_ptr2]] \n" + "vmull.s16 q13, d15, d0 \n" + "vmlal.s16 q13, d17, d1 \n" + "vmlal.s16 q13, d19, d2 \n" + + "vmovl.s8 q7, d9 \n" + "vext.s8 d9, d9, d9, #1 \n" + "vmovl.s8 q8, d9 \n" + "vext.s8 d9, d9, d9, #1 \n" + "vmovl.s8 q9, d9 \n" + "vmlal.s16 q10, d14, d6 \n" + "vmlal.s16 q10, d16, d7 \n" + "vmlal.s16 q10, d18, d8 \n" + "vmlal.s16 q11, d15, d6 \n" + "vmlal.s16 q11, d17, d7 \n" + "vmlal.s16 q11, d19, d8 \n" + + "vmlal.s16 q12, d14, d3 \n" + "vmlal.s16 q12, d16, d4 \n" + "vmlal.s16 q12, d18, d5 \n" + "vmlal.s16 q13, d15, d3 \n" + "vmlal.s16 q13, d17, d4 \n" + "vmlal.s16 q13, d19, d5 \n" + + "vmull.s16 q14, d14, d0 \n" + "vmlal.s16 q14, d16, d1 \n" + "vmlal.s16 q14, d18, d2 \n" + "vld1.32 {d9}, [%[input_ptr3]] \n" + "vmull.s16 q15, d15, d0 \n" + "vmlal.s16 q15, d17, d1 \n" + "vmlal.s16 q15, d19, d2 \n" + + "vmovl.s8 q7, d9 \n" + "vext.s8 d9, d9, d9, #1 \n" + "vmovl.s8 q8, d9 \n" + "vext.s8 d9, d9, d9, #1 \n" + "vmovl.s8 q9, d9 \n" + "vmlal.s16 q12, d14, d6 \n" + "vmlal.s16 q12, d16, d7 \n" + "vmlal.s16 q12, d18, d8 \n" + "vmlal.s16 q13, d15, d6 \n" + "vmlal.s16 q13, d17, d7 \n" + "vmlal.s16 q13, d19, d8 \n" + + "vmlal.s16 q14, d14, d3 \n" + "vmlal.s16 q14, d16, d4 \n" + "vmlal.s16 q14, d18, d5 \n" + "vmlal.s16 q15, d15, d3 \n" + "vmlal.s16 q15, d17, d4 \n" + "vmlal.s16 q15, d19, d5 \n" + + "vmull.s16 q5, d14, d0 \n" + "vmlal.s16 q5, d16, d1 \n" + "vmlal.s16 q5, d18, d2 \n" + "vld1.32 {d9}, [%[input_ptr4]] \n" + "vmull.s16 q6, d15, d0 \n" + "vmlal.s16 q6, d17, d1 \n" + "vmlal.s16 q6, d19, d2 \n" + + "vmovl.s8 q7, d9 \n" + "vext.s8 d9, d9, d9, #1 \n" + "vmovl.s8 q8, d9 \n" + "vext.s8 d9, d9, d9, #1 \n" + "vmovl.s8 q9, d9 \n" + "vmlal.s16 q14, d14, d6 \n" + "vmlal.s16 q14, d16, d7 \n" + "vmlal.s16 q14, d18, d8 \n" + "vmlal.s16 q15, d15, d6 \n" + "vmlal.s16 q15, d17, d7 \n" + "vmlal.s16 q15, d19, d8 \n" + + "vmlal.s16 q5, d14, d3 \n" + "vmlal.s16 q5, d16, d4 \n" + "vmlal.s16 q5, d18, d5 \n" + "vld1.32 {d9}, [%[input_ptr5]] \n" + "vmlal.s16 q6, d15, d3 \n" + "vmlal.s16 q6, d17, d4 \n" + "vmlal.s16 q6, d19, d5 \n" + + 
"vmovl.s8 q7, d9 \n" + "vext.s8 d9, d9, d9, #1 \n" + "vmovl.s8 q8, d9 \n" + "vext.s8 d9, d9, d9, #1 \n" + "vmovl.s8 q9, d9 \n" + "vmlal.s16 q5, d14, d6 \n" + "vmlal.s16 q5, d16, d7 \n" + "vmlal.s16 q5, d18, d8 \n" + "vmlal.s16 q6, d15, d6 \n" + "vmlal.s16 q6, d17, d7 \n" + "vmlal.s16 q6, d19, d8 \n" + + "cmp %[remain], #4 \n" + "blt store_4h2w_%= \n" + "vst1.32 {q10}, [%[output_ptr0]]! \n" + "vst1.32 {q12}, [%[output_ptr1]]! \n" + "vst1.32 {q14}, [%[output_ptr2]]! \n" + "vst1.32 {q5}, [%[output_ptr3]]! \n" + "cmp %[remain], #5 \n" + "blt end_%= \n" + "vst1.32 {d22[0]}, [%[output_ptr0]]! \n" + "vst1.32 {d26[0]}, [%[output_ptr1]]! \n" + "vst1.32 {d30[0]}, [%[output_ptr2]]! \n" + "vst1.32 {d12[0]}, [%[output_ptr3]]! \n" + "b end_%= \n" + + "store_4h2w_%=: \n" + "cmp %[remain], #2 \n" + "blt store_4h1w_%= \n" + "vst1.32 {d20}, [%[output_ptr0]]! \n" + "vst1.32 {d24}, [%[output_ptr1]]! \n" + "vst1.32 {d28}, [%[output_ptr2]]! \n" + "vst1.32 {d10}, [%[output_ptr3]]! \n" + "cmp %[remain], #3 \n" + "blt end_%= \n" + "vst1.32 {d21[0]}, [%[output_ptr0]]! \n" + "vst1.32 {d25[0]}, [%[output_ptr1]]! \n" + "vst1.32 {d29[0]}, [%[output_ptr2]]! \n" + "vst1.32 {d11[0]}, [%[output_ptr3]]! \n" + "b end_%= \n" + + "store_4h1w_%=: \n" + "cmp %[remain], #1 \n" + "blt end_%= \n" + "vst1.32 {d20[0]}, [%[output_ptr0]]! \n" + "vst1.32 {d24[0]}, [%[output_ptr1]]! \n" + "vst1.32 {d28[0]}, [%[output_ptr2]]! \n" + "vst1.32 {d10[0]}, [%[output_ptr3]]! \n" + "end_%=: \n" + : [output_ptr0] "+r"(output_ptr0), [output_ptr1] "+r"(output_ptr1), + [output_ptr2] "+r"(output_ptr2), [output_ptr3] "+r"(output_ptr3), + [input_ptr0] "+r"(input_ptr0), [input_ptr1] "+r"(input_ptr1), + [input_ptr2] "+r"(input_ptr2), [input_ptr3] "+r"(input_ptr3), + [input_ptr4] "+r"(input_ptr4), [input_ptr5] "+r"(input_ptr5) + : [loop] "r"(loop), [remain] "r"(remain) + : "memory", "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7", "q8", + "q9", "q10", "q11", "q12", "q13", "q14", "q15", "r0"); + } + // remain height + int start_h = (input_h - 2) & 0xFFFC; + for (int h = start_h; h < input_h - 3 /*(input_h - 2) - 1*/; h += 2) { + const int8_t* input_ptr0 = input_ptr + h * input_w; + const int8_t* input_ptr1 = input_ptr0 + input_w; + const int8_t* input_ptr2 = input_ptr1 + input_w; + const int8_t* input_ptr3 = input_ptr2 + input_w; + int32_t* output_ptr0 = output_ptr + h * output_w; + int32_t* output_ptr1 = output_ptr0 + output_w; + asm volatile( + "vld1.32 {q0}, [%[filter_ptr]] \n" + "vmovl.s8 q14, d0 \n" + "vmovl.s8 q15, d1 \n" + "vdup.s16 d0, d28[0] \n" + "vdup.s16 d1, d28[1] \n" + "vdup.s16 d2, d28[2] \n" + "vdup.s16 d3, d28[3] \n" + "vdup.s16 d4, d29[0] \n" + "vdup.s16 d5, d29[1] \n" + "vdup.s16 d6, d29[2] \n" + "vdup.s16 d7, d29[3] \n" + "vdup.s16 d8, d30[0] \n" + : + : [filter_ptr] "r"(filter_ptr) + : "cc", "memory", "q0", "q1", "q2", "q3", "q4", "q14", "q15"); + asm volatile( + "mov r0, #6 \n" + "cmp %[loop], #0 \n" + "ble start_remain_%= \n" + // loop 6 widths + "loop_2h6w_%=: \n" + "vld1.32 {d9}, [%[input_ptr0]], r0 \n" + "vld1.32 {d10}, [%[input_ptr1]], r0 \n" + "vld1.32 {d11}, [%[input_ptr2]], r0 \n" + "vext.s8 d12, d9, d9, #1 \n" + "vext.s8 d13, d9, d9, #2 \n" + "vmovl.s8 q7, d9 \n" + "vmovl.s8 q8, d12 \n" + "vmovl.s8 q9, d13 \n" + "vmull.s16 q10, d14, d0 \n" + "vmlal.s16 q10, d16, d1 \n" + "vmlal.s16 q10, d18, d2 \n" + "vmull.s16 q11, d15, d0 \n" + "vmlal.s16 q11, d17, d1 \n" + "vmlal.s16 q11, d19, d2 \n" + + "vext.s8 d12, d10, d10, #1 \n" + "vext.s8 d13, d10, d10, #2 \n" + "vmovl.s8 q7, d10 \n" + "vmovl.s8 q8, d12 \n" + "vmovl.s8 q9, d13 \n" + 
"vmlal.s16 q10, d14, d3 \n" + "vmlal.s16 q10, d16, d4 \n" + "vmlal.s16 q10, d18, d5 \n" + "vmlal.s16 q11, d15, d3 \n" + "vmlal.s16 q11, d17, d4 \n" + "vmlal.s16 q11, d19, d5 \n" + + "vmull.s16 q12, d14, d0 \n" + "vmlal.s16 q12, d16, d1 \n" + "vmlal.s16 q12, d18, d2 \n" + "vmull.s16 q13, d15, d0 \n" + "vmlal.s16 q13, d17, d1 \n" + "vmlal.s16 q13, d19, d2 \n" + + "vext.s8 d12, d11, d11, #1 \n" + "vext.s8 d13, d11, d11, #2 \n" + "vmovl.s8 q7, d11 \n" + "vmovl.s8 q8, d12 \n" + "vmovl.s8 q9, d13 \n" + "vmlal.s16 q10, d14, d6 \n" + "vmlal.s16 q10, d16, d7 \n" + "vmlal.s16 q10, d18, d8 \n" + "vmlal.s16 q11, d15, d6 \n" + "vmlal.s16 q11, d17, d7 \n" + "vmlal.s16 q11, d19, d8 \n" + // store row 0, reuse q10/q11 + "vst1.32 {d20-d22}, [%[output_ptr0]]! \n" + + "vmlal.s16 q12, d14, d3 \n" + "vmlal.s16 q12, d16, d4 \n" + "vmlal.s16 q12, d18, d5 \n" + "vmlal.s16 q13, d15, d3 \n" + "vmlal.s16 q13, d17, d4 \n" + "vmlal.s16 q13, d19, d5 \n" + + "vld1.32 {d9}, [%[input_ptr3]], r0 \n" + "vext.s8 d12, d9, d9, #1 \n" + "vext.s8 d13, d9, d9, #2 \n" + "vmovl.s8 q7, d9 \n" + "vmovl.s8 q8, d12 \n" + "vmovl.s8 q9, d13 \n" + "vmlal.s16 q12, d14, d6 \n" + "vmlal.s16 q12, d16, d7 \n" + "vmlal.s16 q12, d18, d8 \n" + "vmlal.s16 q13, d15, d6 \n" + "vmlal.s16 q13, d17, d7 \n" + "vmlal.s16 q13, d19, d8 \n" + // store row 1 + "vst1.32 {d24-d26}, [%[output_ptr1]]! \n" + + "subs %[loop], #1 \n" + "bne loop_2h6w_%= \n" + + "start_remain_%=: \n" + "cmp %[remain], #0 \n" + "ble end_%= \n" + + "vld1.32 {d9}, [%[input_ptr0]] \n" + "vld1.32 {d10}, [%[input_ptr1]] \n" + "vld1.32 {d11}, [%[input_ptr2]] \n" + "vext.s8 d12, d9, d9, #1 \n" + "vext.s8 d13, d9, d9, #2 \n" + "vmovl.s8 q7, d9 \n" + "vmovl.s8 q8, d12 \n" + "vmovl.s8 q9, d13 \n" + "vmull.s16 q10, d14, d0 \n" + "vmlal.s16 q10, d16, d1 \n" + "vmlal.s16 q10, d18, d2 \n" + "vmull.s16 q11, d15, d0 \n" + "vmlal.s16 q11, d17, d1 \n" + "vmlal.s16 q11, d19, d2 \n" + + "vext.s8 d12, d10, d10, #1 \n" + "vext.s8 d13, d10, d10, #2 \n" + "vmovl.s8 q7, d10 \n" + "vmovl.s8 q8, d12 \n" + "vmovl.s8 q9, d13 \n" + "vmlal.s16 q10, d14, d3 \n" + "vmlal.s16 q10, d16, d4 \n" + "vmlal.s16 q10, d18, d5 \n" + "vmlal.s16 q11, d15, d3 \n" + "vmlal.s16 q11, d17, d4 \n" + "vmlal.s16 q11, d19, d5 \n" + + "vmull.s16 q12, d14, d0 \n" + "vmlal.s16 q12, d16, d1 \n" + "vmlal.s16 q12, d18, d2 \n" + "vmull.s16 q13, d15, d0 \n" + "vmlal.s16 q13, d17, d1 \n" + "vmlal.s16 q13, d19, d2 \n" + + "vext.s8 d12, d11, d11, #1 \n" + "vext.s8 d13, d11, d11, #2 \n" + "vmovl.s8 q7, d11 \n" + "vmovl.s8 q8, d12 \n" + "vmovl.s8 q9, d13 \n" + "vmlal.s16 q10, d14, d6 \n" + "vmlal.s16 q10, d16, d7 \n" + "vmlal.s16 q10, d18, d8 \n" + "vmlal.s16 q11, d15, d6 \n" + "vmlal.s16 q11, d17, d7 \n" + "vmlal.s16 q11, d19, d8 \n" + + "vmlal.s16 q12, d14, d3 \n" + "vmlal.s16 q12, d16, d4 \n" + "vmlal.s16 q12, d18, d5 \n" + "vmlal.s16 q13, d15, d3 \n" + "vmlal.s16 q13, d17, d4 \n" + "vmlal.s16 q13, d19, d5 \n" + + "vld1.32 {d9}, [%[input_ptr3]] \n" + "vext.s8 d12, d9, d9, #1 \n" + "vext.s8 d13, d9, d9, #2 \n" + "vmovl.s8 q7, d9 \n" + "vmovl.s8 q8, d12 \n" + "vmovl.s8 q9, d13 \n" + "vmlal.s16 q12, d14, d6 \n" + "vmlal.s16 q12, d16, d7 \n" + "vmlal.s16 q12, d18, d8 \n" + "vmlal.s16 q13, d15, d6 \n" + "vmlal.s16 q13, d17, d7 \n" + "vmlal.s16 q13, d19, d8 \n" + + "cmp %[remain], #4 \n" + "blt store_2h2w_%= \n" + "vst1.32 {q10}, [%[output_ptr0]]! \n" + "vst1.32 {q12}, [%[output_ptr1]]! \n" + "cmp %[remain], #5 \n" + "blt end_%= \n" + "vst1.32 {d22[0]}, [%[output_ptr0]]! \n" + "vst1.32 {d26[0]}, [%[output_ptr1]]! 
\n" + "b end_%= \n" + + "store_2h2w_%=: \n" + "cmp %[remain], #2 \n" + "blt store_2h1w_%= \n" + "vst1.32 {d20}, [%[output_ptr0]]! \n" + "vst1.32 {d24}, [%[output_ptr1]]! \n" + "cmp %[remain], #3 \n" + "blt end_%= \n" + "vst1.32 {d21[0]}, [%[output_ptr0]]! \n" + "vst1.32 {d25[0]}, [%[output_ptr1]]! \n" + "b end_%= \n" + + "store_2h1w_%=: \n" + "cmp %[remain], #1 \n" + "blt end_%= \n" + "vst1.32 {d20[0]}, [%[output_ptr0]]! \n" + "vst1.32 {d24[0]}, [%[output_ptr1]]! \n" + "end_%=: \n" + : [output_ptr0] "+r"(output_ptr0), [output_ptr1] "+r"(output_ptr1), + [input_ptr0] "+r"(input_ptr0), [input_ptr1] "+r"(input_ptr1), + [input_ptr2] "+r"(input_ptr2), [input_ptr3] "+r"(input_ptr3) + : [loop] "r"(loop), [remain] "r"(remain) + : "memory", "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7", "q8", + "q9", "q10", "q11", "q12", "q13", "r0"); + } + + start_h = (input_h - 2) & 0xFFFE; + if (start_h < input_h - 2) { + const int8_t* input_ptr0 = input_ptr + start_h * input_w; + const int8_t* input_ptr1 = input_ptr0 + input_w; + const int8_t* input_ptr2 = input_ptr1 + input_w; + int32_t* output_ptr0 = output_ptr + start_h * output_w; + asm volatile( + "vld1.32 {q0}, [%[filter_ptr]] \n" + "vmovl.s8 q14, d0 \n" + "vmovl.s8 q15, d1 \n" + "vdup.s16 d0, d28[0] \n" + "vdup.s16 d1, d28[1] \n" + "vdup.s16 d2, d28[2] \n" + "vdup.s16 d3, d28[3] \n" + "vdup.s16 d4, d29[0] \n" + "vdup.s16 d5, d29[1] \n" + "vdup.s16 d6, d29[2] \n" + "vdup.s16 d7, d29[3] \n" + "vdup.s16 d8, d30[0] \n" + : + : [filter_ptr] "r"(filter_ptr) + : "cc", "memory", "q0", "q1", "q2", "q3", "q4", "q14", "q15"); + asm volatile( + "mov r0, #6 \n" + "cmp %[loop], #0 \n" + "ble start_remain_%= \n" + // loop 6 widths + "loop_1h6w_%=: \n" + "vld1.32 {d9}, [%[input_ptr0]], r0 \n" + "vld1.32 {d10}, [%[input_ptr1]], r0 \n" + "vld1.32 {d11}, [%[input_ptr2]], r0 \n" + "vext.s8 d12, d9, d9, #1 \n" + "vext.s8 d13, d9, d9, #2 \n" + "vmovl.s8 q7, d9 \n" + "vmovl.s8 q8, d12 \n" + "vmovl.s8 q9, d13 \n" + "vmull.s16 q10, d14, d0 \n" + "vmlal.s16 q10, d16, d1 \n" + "vmlal.s16 q10, d18, d2 \n" + "vmull.s16 q11, d15, d0 \n" + "vmlal.s16 q11, d17, d1 \n" + "vmlal.s16 q11, d19, d2 \n" + + "vext.s8 d12, d10, d10, #1 \n" + "vext.s8 d13, d10, d10, #2 \n" + "vmovl.s8 q7, d10 \n" + "vmovl.s8 q8, d12 \n" + "vmovl.s8 q9, d13 \n" + "vmlal.s16 q10, d14, d3 \n" + "vmlal.s16 q10, d16, d4 \n" + "vmlal.s16 q10, d18, d5 \n" + "vmlal.s16 q11, d15, d3 \n" + "vmlal.s16 q11, d17, d4 \n" + "vmlal.s16 q11, d19, d5 \n" + + "vext.s8 d12, d11, d11, #1 \n" + "vext.s8 d13, d11, d11, #2 \n" + "vmovl.s8 q7, d11 \n" + "vmovl.s8 q8, d12 \n" + "vmovl.s8 q9, d13 \n" + "vmlal.s16 q10, d14, d6 \n" + "vmlal.s16 q10, d16, d7 \n" + "vmlal.s16 q10, d18, d8 \n" + "vmlal.s16 q11, d15, d6 \n" + "vmlal.s16 q11, d17, d7 \n" + "vmlal.s16 q11, d19, d8 \n" + // store row 0, reuse q10/q11 + "vst1.32 {d20-d22}, [%[output_ptr0]]! 
\n" + + "subs %[loop], #1 \n" + "bne loop_1h6w_%= \n" + + "start_remain_%=: \n" + "cmp %[remain], #0 \n" + "ble end_%= \n" + + "vld1.32 {d9}, [%[input_ptr0]] \n" + "vld1.32 {d10}, [%[input_ptr1]] \n" + "vld1.32 {d11}, [%[input_ptr2]] \n" + "vext.s8 d12, d9, d9, #1 \n" + "vext.s8 d13, d9, d9, #2 \n" + "vmovl.s8 q7, d9 \n" + "vmovl.s8 q8, d12 \n" + "vmovl.s8 q9, d13 \n" + "vmull.s16 q10, d14, d0 \n" + "vmlal.s16 q10, d16, d1 \n" + "vmlal.s16 q10, d18, d2 \n" + "vmull.s16 q11, d15, d0 \n" + "vmlal.s16 q11, d17, d1 \n" + "vmlal.s16 q11, d19, d2 \n" + + "vext.s8 d12, d10, d10, #1 \n" + "vext.s8 d13, d10, d10, #2 \n" + "vmovl.s8 q7, d10 \n" + "vmovl.s8 q8, d12 \n" + "vmovl.s8 q9, d13 \n" + "vmlal.s16 q10, d14, d3 \n" + "vmlal.s16 q10, d16, d4 \n" + "vmlal.s16 q10, d18, d5 \n" + "vmlal.s16 q11, d15, d3 \n" + "vmlal.s16 q11, d17, d4 \n" + "vmlal.s16 q11, d19, d5 \n" + + "vext.s8 d12, d11, d11, #1 \n" + "vext.s8 d13, d11, d11, #2 \n" + "vmovl.s8 q7, d11 \n" + "vmovl.s8 q8, d12 \n" + "vmovl.s8 q9, d13 \n" + "vmlal.s16 q10, d14, d6 \n" + "vmlal.s16 q10, d16, d7 \n" + "vmlal.s16 q10, d18, d8 \n" + "vmlal.s16 q11, d15, d6 \n" + "vmlal.s16 q11, d17, d7 \n" + "vmlal.s16 q11, d19, d8 \n" + + "cmp %[remain], #4 \n" + "blt store_1h2w_%= \n" + "vst1.32 {q10}, [%[output_ptr0]]! \n" + "cmp %[remain], #5 \n" + "blt end_%= \n" + "vst1.32 {d22[0]}, [%[output_ptr0]]! \n" + "b end_%= \n" + + "store_1h2w_%=: \n" + "cmp %[remain], #2 \n" + "blt store_1h1w_%= \n" + "vst1.32 {d20}, [%[output_ptr0]]! \n" + "cmp %[remain], #3 \n" + "blt end_%= \n" + "vst1.32 {d21[0]}, [%[output_ptr0]]! \n" + "b end_%= \n" + + "store_1h1w_%=: \n" + "cmp %[remain], #1 \n" + "blt end_%= \n" + "vst1.32 {d20[0]}, [%[output_ptr0]]! \n" + "end_%=: \n" + : [output_ptr0] "+r"(output_ptr0), [input_ptr0] "+r"(input_ptr0), + [input_ptr1] "+r"(input_ptr1), [input_ptr2] "+r"(input_ptr2) + : [loop] "r"(loop), [remain] "r"(remain) + : "memory", "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7", "q8", + "q9", "q10", "q11", "r0"); } } #endif // __aarch64__ diff --git a/src/operators/math/depthwise_conv3x3_int8.h b/src/operators/math/depthwise_conv3x3_int8.h deleted file mode 100644 index f77f7b9d7cbb50a0b5e0897b57433fce5206e815..0000000000000000000000000000000000000000 --- a/src/operators/math/depthwise_conv3x3_int8.h +++ /dev/null @@ -1,39 +0,0 @@ -/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
*/ - -#pragma once - -#include -#include "framework/tensor.h" - -namespace paddle_mobile { -namespace operators { -namespace math { - -void DepthwiseConv3x3_int8(const framework::Tensor *input, - const framework::Tensor *filter, - const std::vector &strides, - framework::Tensor *output); - -void DepthwiseConv3x3s1_int8(const framework::Tensor *input, - const framework::Tensor *filter, - framework::Tensor *output); - -void DepthwiseConv3x3s2_int8(const framework::Tensor *input, - const framework::Tensor *filter, - framework::Tensor *output); - -} // namespace math -} // namespace operators -} // namespace paddle_mobile diff --git a/src/operators/math/depthwise_conv_3x3.h b/src/operators/math/depthwise_conv_3x3.h deleted file mode 100644 index b146b88e737a07ea08250315fc94653f63d2ad05..0000000000000000000000000000000000000000 --- a/src/operators/math/depthwise_conv_3x3.h +++ /dev/null @@ -1,51 +0,0 @@ -/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ - -#pragma once - -#include -#include -#include "framework/tensor.h" -#include "operators/math/conv_func.h" - -namespace paddle_mobile { -namespace operators { -namespace math { -using framework::Tensor; -using std::max; -using std::min; -using std::vector; - -void DepthwiseConv3x3(const Tensor *input, vector strides, - vector paddings, const Tensor *filter, Tensor *bias, - Tensor *output, bool if_bias); -void DepthwiseConv3x3s1p1(const Tensor *input, const Tensor *filter, - Tensor *output, Tensor *bias, bool if_bias); -void DepthwiseConvAddBNRelu3x3s1p1(const Tensor *input, const Tensor *filter, - Tensor *output, const Tensor *new_scale, - const Tensor *new_bias, bool if_relu); -void DepthwiseConvAddBNRelu3x3s2p1(const Tensor *input, const Tensor *filter, - Tensor *output, const Tensor *new_scale, - const Tensor *new_bias, bool if_relu); -void DepthwiseConv3x3s2p1v2(const Tensor *input, const Tensor *filter, - Tensor *output, Tensor bias, bool if_bias); -void DepthwiseConvAddBNRelu3x3s2p1v2(const Tensor *input, const Tensor *filter, - Tensor *output, const Tensor *new_scale, - const Tensor *new_bias, bool if_relu); - -void DepthwiseConv3x3s2p0(const Tensor *input, const Tensor *filter, - Tensor *output, Tensor bias, bool if_bias); -} // namespace math -} // namespace operators -} // namespace paddle_mobile diff --git a/src/operators/math/gemm.cpp b/src/operators/math/gemm.cpp index d3e6de3134ff91f47c66c927194a5ba688e931b0..c17b2a5e4df0f0ca88da79a9ce55c2ecae0316b5 100644 --- a/src/operators/math/gemm.cpp +++ b/src/operators/math/gemm.cpp @@ -26,79 +26,6 @@ limitations under the License. 
*/ namespace paddle_mobile { namespace operators { namespace math { -/*int MC = 0; -int KC = 0; -int NC = 0; - -float *packedA; -float *packedB; -float *packedC; -float *zero; - -typedef void (*FnPack)(int, int, int, const float *, int, float *); -typedef void (*FnAddDot)(int, const float *, const float *, float *, int); - -FnPack procPackA; -FnPack procPackB; -FnAddDot procAddDot;*/ - -/* -// 将A矩阵分块复制到连续内存(ColMajor) -void PackMatrixA(int m, int k, int m_tail, const float *A, int lda, - float *buffer) { - int i, j; - const float *Aij; - for (i = 0; i < m - m_tail; i += MR) { - for (j = 0; j < k; ++j) { - Aij = &A(i, j); - *buffer++ = *Aij; - *buffer++ = *(Aij + 1); - *buffer++ = *(Aij + 2); - *buffer++ = *(Aij + 3); - } - } - if (m_tail != 0) { - for (j = 0; j < k; ++j) { - Aij = &A(m - m_tail, j); - for (i = 0; i < m_tail; ++i) { - *buffer++ = *(Aij + i); - } - for (i = m_tail; i < MR; ++i) { - *buffer++ = 0; - } - } - } -} - -// 将B矩阵分块复制到连续内存(ColMajor) -void PackMatrixB(int k, int n, int n_tail, const float *B, int ldb, - float *buffer) { - int i, j; - const float *Bj, *Bj1, *Bj2, *Bj3; - for (j = 0; j < n - n_tail; j += NR) { - Bj = &B(0, j); - Bj1 = &B(0, j + 1); - Bj2 = &B(0, j + 2); - Bj3 = &B(0, j + 3); - for (i = 0; i < k; ++i) { - *buffer++ = *Bj++; - *buffer++ = *Bj1++; - *buffer++ = *Bj2++; - *buffer++ = *Bj3++; - } - } - if (n_tail != 0) { - for (i = 0; i < k; ++i) { - for (int j = n - n_tail; j < n; ++j) { - *buffer++ = B(i, j); - } - for (int j = n; j < n + (NR - n_tail); ++j) { - *buffer++ = 0; - } - } - } -} -*/ // 将A矩阵分块复制到连续内存(RowMajor) void Gemm::PackMatrixA_4r(int m, int k, int m_tail, const float *A, int lda, diff --git a/src/operators/op_param.h b/src/operators/op_param.h index 121f8be09299d8c6ee8f452b289423b4aad8cf34..dedc2e48cc3bb6f53b017d072902c9865cc56782 100644 --- a/src/operators/op_param.h +++ b/src/operators/op_param.h @@ -423,6 +423,7 @@ class ConvParam : public OpParam { EXEC_WINOGRAD3X3_FLOAT, EXEC_WINOGRAD5X5_FLOAT, EXEC_GEMM_INT8, + EXEC_DEPTHWISE3x3_INT8, }; ExecMode &ExecMode() const { return exec_mode_; } @@ -2498,7 +2499,7 @@ class QuantizeParam : public OpParam { QuantizeParam(const VariableNameMap &inputs, const VariableNameMap &outputs, const AttributeMap &attrs, const Scope &scope) { input_ = InputXFrom(inputs, scope); - out_ = OutFrom(outputs, scope); + output_ = OutFrom(outputs, scope); // online // scale = max(abs(x)) online_scale_ = GetVarValue("OutScale", outputs, scope); @@ -2517,8 +2518,7 @@ class QuantizeParam : public OpParam { // op input RType *input_; // op output - RType *out_; - // + RType *output_; RType *online_scale_; // if static scale or not bool is_static_ = false; @@ -2526,7 +2526,11 @@ class QuantizeParam : public OpParam { float static_scale_ = 1.0f; // round method type // nearest_zero and nearest_even is valid currently - RoundType round_type_ = ROUND_NEAREST_AWAY_ZERO; + // RoundType round_type_ = ROUND_NEAREST_AWAY_ZERO; + RoundType round_type_ = ROUND_NEAREST_TOWARDS_ZERO; + // optional paddings + std::vector paddings_; + int8_t padding_val_; }; #endif @@ -2540,7 +2544,7 @@ class DequantizeParam : public OpParam { DequantizeParam(const VariableNameMap &inputs, const VariableNameMap &outputs, const AttributeMap &attrs, const Scope &scope) { input_ = InputXFrom(inputs, scope); - out_ = OutFrom(outputs, scope); + output_ = OutFrom(outputs, scope); activation_scale_ = GetVarValue("Scale", inputs, scope); // dequantization is performed as x = x / static_scale / online_scale if (HasAttr("weight_scale", attrs)) { @@ -2554,11 
+2558,35 @@ class DequantizeParam : public OpParam {
   // op input
   RType *input_;
   // op output
-  RType *out_;
+  RType *output_;
   RType *activation_scale_;
   float weight_scale_;
 };
 #endif
 
+#ifdef PAD_OP
+template <typename Dtype>
+class PadParam : public OpParam {
+  typedef typename DtypeTensorTrait<Dtype>::gtype GType;
+  typedef typename DtypeTensorTrait<Dtype>::rtype RType;
+
+ public:
+  PadParam(const VariableNameMap &inputs, const VariableNameMap &outputs,
+           const AttributeMap &attrs, const Scope &scope) {
+    input_ = InputXFrom<GType>(inputs, scope);
+    output_ = OutFrom<GType>(outputs, scope);
+    paddings_ = GetVarValue<std::vector<int>>("Paddings", inputs, scope);
+  }
+
+ public:
+  // op input
+  RType *input_;
+  // op output
+  RType *output_;
+  // paddings
+  std::vector<int> paddings_;
+};
+#endif
+
 }  // namespace operators
 }  // namespace paddle_mobile
diff --git a/src/operators/quantize_op.cpp b/src/operators/quantize_op.cpp
index 865539d7d26de41b319b4d82ed168b2ec74d722d..ee5e042610751999470aa32a7cb7f1fb8fbb3635 100644
--- a/src/operators/quantize_op.cpp
+++ b/src/operators/quantize_op.cpp
@@ -22,8 +22,12 @@ namespace operators {
 
 template <typename Dtype, typename T>
 void QuantizeOp<Dtype, T>::InferShape() const {
-  const auto& input_dims = this->param_.input_->dims();
-  this->param_.out_->Resize(input_dims);
+  auto input_dims = this->param_.input_->dims();
+  // const auto &paddings = this->param_.paddings_;
+  std::vector<int> paddings = {0, 0};
+  input_dims[2] += 2 * paddings[0];
+  input_dims[3] += 2 * paddings[1];
+  this->param_.output_->Resize(input_dims);
   auto scale_dims = framework::make_ddim(std::vector<int>{1});
   this->param_.online_scale_->Resize(scale_dims);
 }
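The NEON kernels added above process six output columns per inner-loop iteration and two or four output rows per outer iteration, with narrower store paths for the remaining columns. As a readability aid, the following is a minimal scalar sketch of the computation those int8 kernels appear to implement: a 3x3, stride-1 depthwise convolution over an NCHW int8 input with batch size 1, no padding, and int32 accumulation. The function and parameter names below are illustrative assumptions and are not part of the patch.

// Scalar reference sketch (illustrative only, not part of the patch):
// per-channel 3x3 depthwise convolution, stride 1, no padding,
// int8 input/filter, int32 accumulation, NCHW layout, batch size 1.
#include <cstdint>

void DepthwiseConv3x3S1Int8Reference(const int8_t *input, const int8_t *filter,
                                     int32_t *output, int channels,
                                     int input_h, int input_w) {
  const int output_h = input_h - 2;
  const int output_w = input_w - 2;
  for (int c = 0; c < channels; ++c) {
    const int8_t *in = input + c * input_h * input_w;
    const int8_t *wt = filter + c * 9;  // one 3x3 kernel per channel
    int32_t *out = output + c * output_h * output_w;
    for (int oh = 0; oh < output_h; ++oh) {
      for (int ow = 0; ow < output_w; ++ow) {
        int32_t acc = 0;
        for (int kh = 0; kh < 3; ++kh) {
          for (int kw = 0; kw < 3; ++kw) {
            acc += static_cast<int32_t>(in[(oh + kh) * input_w + (ow + kw)]) *
                   static_cast<int32_t>(wt[kh * 3 + kw]);
          }
        }
        out[oh * output_w + ow] = acc;
      }
    }
  }
}

A stride-2 variant of the same sketch would advance the input row and column indices by two per output element while keeping the identical 3x3 int32 accumulation.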