diff --git a/CMakeLists.txt b/CMakeLists.txt index bbf2be8fd7362b876c80f925e057b63443202ce7..f5d68712a64b5a47657a7af9c0e6b47604893e23 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -5,6 +5,7 @@ option(DEBUGING "enable debug mode" ON) option(USE_EXCEPTION "use std exception" ON) option(SYMBOL_HIDDEN "symbol hidden" OFF) # on when use jni or ios io option(LOG_PROFILE "log profile" OFF) + # select the platform to build option(CPU "armv7 with neon" ON) option(GPU_MALI "mali gpu" OFF) @@ -15,7 +16,6 @@ if(FPGA) option(FPGAV2 "fpga v2" OFF) endif() - project(paddle-mobile) file(GLOB_RECURSE PADDLE_MOBILE_CC src/*.cc src/*.cpp src/*.c src/*.mm) @@ -247,6 +247,3 @@ elseif(FPGA) add_subdirectory(test) endif() - - - diff --git a/src/common/types.cpp b/src/common/types.cpp index ba00f639d76ae7c928f5b7484c08cec0b0926904..36c93046c1c09d0ec5043ef9a7514dedf212e738 100644 --- a/src/common/types.cpp +++ b/src/common/types.cpp @@ -71,6 +71,8 @@ const char *G_OP_TYPE_SUM = "sum"; const char *G_OP_TYPE_QUANTIZE = "quantize"; const char *G_OP_TYPE_DEQUANTIZE = "dequantize"; +const char *G_OP_TYPE_FUSION_DEQUANT_ADD_BN_RELU = "fusion_dequant_add_bn_relu"; + const char *G_OP_TYPE_TANH = "tanh"; const char *G_OP_TYPE_FUSION_DECONV_RELU = "fusion_deconv_relu"; const char *G_OP_TYPE_FUSION_DECONV_ADD = "fusion_deconv_add"; @@ -134,6 +136,7 @@ std::unordered_map< {G_OP_TYPE_ELEMENTWISE_MUL, {{"X", "Y"}, {"Out"}}}, {G_OP_TYPE_QUANTIZE, {{"X"}, {"Out", "OutScale"}}}, {G_OP_TYPE_DEQUANTIZE, {{"X", "Scale"}, {"Out"}}}, + {G_OP_TYPE_FUSION_DEQUANT_ADD_BN_RELU, {{"X", "Scale"}, {"Out"}}}, {G_OP_TYPE_TANH, {{"X"}, {"Out"}}}, {G_OP_TYPE_FUSION_DECONV_RELU, {{"Input"}, {"Out"}}}, {G_OP_TYPE_FUSION_DECONV_ADD, {{"Input"}, {"Out"}}}, diff --git a/src/common/types.h b/src/common/types.h index e9c0f81232dab7583c57fb036b58601aa26ec3c9..5704618c9e475781b22df4ae3a0ac3a994eb8c90 100644 --- a/src/common/types.h +++ b/src/common/types.h @@ -138,6 +138,7 @@ extern const char *G_OP_TYPE_ELEMENTWISE_MUL; extern const char *G_OP_TYPE_QUANTIZE; extern const char *G_OP_TYPE_DEQUANTIZE; +extern const char *G_OP_TYPE_FUSION_DEQUANT_ADD_BN_RELU; extern const char *G_OP_TYPE_TANH; extern const char *G_OP_TYPE_FUSION_DECONV_RELU; diff --git a/src/framework/executor.cpp b/src/framework/executor.cpp index 80a990d5550ded3a5cc049fef366ba7e90938c00..0e00585f9042124e1a62e6ad8ce01ebfbfd541a0 100644 --- a/src/framework/executor.cpp +++ b/src/framework/executor.cpp @@ -30,7 +30,6 @@ limitations under the License. */ #ifdef PADDLE_EXECUTOR_MULTITHREAD #include -#include #include "common/threadpool.h" #endif @@ -73,7 +72,7 @@ Executor::Executor(const framework::Program p, int batch_size, op->Type(), op->GetInputs(), op->GetOutputs(), op->GetAttrMap(), program_.scope); // infer shape to reshape tensor before predict, - // but for lod tensor, it will need to reshape in runtime + // but for lod tensor, it will still need to reshape in runtime if (!loddable_) { op_base->InferShape(); } diff --git a/src/framework/load_ops.h b/src/framework/load_ops.h index 982f1c0f3525afde8475866c0121343fafc9d5a0..135ef9083e42271fe63cdc29ee53e876f532c287 100644 --- a/src/framework/load_ops.h +++ b/src/framework/load_ops.h @@ -233,3 +233,7 @@ LOAD_OP1(quantize, CPU); #ifdef DEQUANT_OP LOAD_OP1(dequantize, CPU); #endif +#ifdef FUSION_DEQUANT_ADD_BN_RELU_OP +LOAD_OP1(fusion_dequant_add_bn_relu, CPU); +LOAD_FUSION_MATCHER(fusion_dequant_add_bn_relu); +#endif diff --git a/src/framework/operator.h b/src/framework/operator.h index fa7417a2975e224d9cac9bfdd4e28d73a34e019e..464910b613322451d05adcc772825079d0d8f677 100644 --- a/src/framework/operator.h +++ b/src/framework/operator.h @@ -127,11 +127,6 @@ class OperatorWithKernel : public OperatorBase { virtual void InferShape() const = 0; void Init() { - // for (auto i : this->inputs_) { - // DLOG << i.first; - // DLOG << i.second; - // } - PADDLE_MOBILE_ENFORCE(kernel_.Init(¶m_), " %s kernel init failed", this->type_.c_str()); } diff --git a/src/framework/tensor.h b/src/framework/tensor.h index 99d642919b9c34378f7bb90f0b7aacd61aa75d0e..9e6ae7288b755d40973264f8744c7c54f73193bd 100644 --- a/src/framework/tensor.h +++ b/src/framework/tensor.h @@ -54,22 +54,6 @@ class Tensor : public TensorBase { this->offset_ = inTensor.offset_; } -#ifdef PADDLE_MOBILE_DEBUG - template - inline void dump(std::string filename) const { - const T *dataptr = data(); - std::ofstream out(filename.c_str()); - for (int i = 0; i < numel(); ++i) { - out << dataptr[i] << " "; - } - out << "形状:"; - for (int j = 0; j < dims_.size(); ++j) { - out << dims_[j] << " "; - } - out.close(); - } -#endif - /*! Resize the dimensions of the memory block. */ inline Tensor &Resize(const DDim &dims) { dims_ = dims; diff --git a/src/operators/dequantize_op.cpp b/src/operators/dequantize_op.cpp index 21cd96368c4938d309f08d036b172607a5afee8c..00d08d683997f45fe7447321efa092a5597921a2 100644 --- a/src/operators/dequantize_op.cpp +++ b/src/operators/dequantize_op.cpp @@ -22,7 +22,7 @@ namespace operators { template void DequantizeOp::InferShape() const { const auto& input_dims = this->param_.input_->dims(); - this->param_.out_->Resize(input_dims); + this->param_.output_->Resize(input_dims); } } // namespace operators diff --git a/src/operators/fusion_dequant_add_bn_relu_op.cpp b/src/operators/fusion_dequant_add_bn_relu_op.cpp new file mode 100644 index 0000000000000000000000000000000000000000..80d9040afb29b7a42c742b821e9d7522c1a12827 --- /dev/null +++ b/src/operators/fusion_dequant_add_bn_relu_op.cpp @@ -0,0 +1,40 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#ifdef FUSION_DEQUANT_ADD_BN_RELU_OP + +#include "operators/fusion_dequant_add_bn_relu_op.h" + +namespace paddle_mobile { +namespace operators { + +template +void FusionDequantAddBNReluOp::InferShape() const { + const auto& input_dims = this->param_.input_->dims(); + this->param_.output_->Resize(input_dims); +} + +} // namespace operators +} // namespace paddle_mobile + +namespace ops = paddle_mobile::operators; +REGISTER_FUSION_MATCHER(fusion_dequant_add_bn_relu, + ops::FusionDequantAddBNReluMatcher); + +#ifdef PADDLE_MOBILE_CPU +REGISTER_OPERATOR_CPU(fusion_dequant_add_bn_relu, + ops::FusionDequantAddBNReluOp); +#endif + +#endif diff --git a/src/operators/fusion_dequant_add_bn_relu_op.h b/src/operators/fusion_dequant_add_bn_relu_op.h new file mode 100644 index 0000000000000000000000000000000000000000..dbd9ad0de2ece751ffd4da05cb09f0091a5755aa --- /dev/null +++ b/src/operators/fusion_dequant_add_bn_relu_op.h @@ -0,0 +1,76 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#ifdef FUSION_DEQUANT_ADD_BN_RELU_OP + +#pragma once + +#include +#include +#include "framework/operator.h" +#include "framework/program/program-optimize/fusion_op_register.h" +#include "operators/kernel/dequant_add_bn_relu_kernel.h" +#include "operators/op_param.h" + +namespace paddle_mobile { +namespace operators { + +class FusionDequantAddBNReluMatcher : public framework::FusionOpMatcher { + public: + FusionDequantAddBNReluMatcher() { + node_ = framework::Node(G_OP_TYPE_DEQUANTIZE); + node_ > std::make_shared(G_OP_TYPE_ELEMENTWISE_ADD) > + std::make_shared(G_OP_TYPE_BATCHNORM) > + std::make_shared(G_OP_TYPE_RELU); + } + + void FolderNodes( + framework::Node *node, + std::vector> *removed_nodes) { + node->Folder(node_.Depth(), Type(), + {{G_OP_TYPE_ELEMENTWISE_ADD, {{"Y", "Y"}}}, + {G_OP_TYPE_BATCHNORM, + {{"Scale", "BNScale"}, + {"Mean", "BNMean"}, + {"Bias", "BNBias"}, + {"Variance", "BNVariance"}}}}, + removed_nodes); + } + + std::string Type() { return G_OP_TYPE_FUSION_DEQUANT_ADD_BN_RELU; } +}; + +template +class FusionDequantAddBNReluOp + : public framework::OperatorWithKernel< + DeviceType, FusionDequantAddBNReluParam, + operators::FusionDequantAddBNReluKernel> { + public: + FusionDequantAddBNReluOp(const std::string &type, + const VariableNameMap &inputs, + const VariableNameMap &outputs, + const framework::AttributeMap &attrs, + std::shared_ptr scope) + : framework::OperatorWithKernel< + DeviceType, FusionDequantAddBNReluParam, + operators::FusionDequantAddBNReluKernel>( + type, inputs, outputs, attrs, scope) {} + // inference output shape + void InferShape() const override; +}; + +} // namespace operators +} // namespace paddle_mobile + +#endif diff --git a/src/operators/kernel/arm/conv_kernel.cpp b/src/operators/kernel/arm/conv_kernel.cpp index 0d67bdc656f2ba9ad674c18c9cefbd7d9cd711df..840be6c67d2e350c914a7d8aa8e9a32acdd00fb1 100644 --- a/src/operators/kernel/arm/conv_kernel.cpp +++ b/src/operators/kernel/arm/conv_kernel.cpp @@ -22,12 +22,76 @@ namespace operators { template <> bool ConvKernel::Init(ConvParam *param) { + if (param->Filter()->type() == typeid(int8_t)) { + if (param->Groups() == param->Input()->dims()[1] && + param->Input()->dims()[1] == param->Output()->dims()[1] && + param->Filter()->dims()[2] == param->Filter()->dims()[3] && + param->Filter()->dims()[2] == 3 && param->Strides()[0] < 3 && + param->Strides()[0] == param->Strides()[1]) { + param->ExecMode() = ConvParam::EXEC_DEPTHWISE3x3_INT8; + } else { + param->ExecMode() = ConvParam::EXEC_GEMM_INT8; + } + } else { + if (param->Groups() == param->Input()->dims()[1] && + param->Input()->dims()[1] == param->Output()->dims()[1] && + param->Filter()->dims()[2] == param->Filter()->dims()[3] && + param->Filter()->dims()[2] == 3 && param->Strides()[0] == 1) { + param->ExecMode() = ConvParam::EXEC_DEPTHWISE3x3S1P1_FLOAT; + } else if (param->Groups() == param->Input()->dims()[1] && + param->Input()->dims()[1] == param->Output()->dims()[1] && + param->Filter()->dims()[2] == param->Filter()->dims()[3] && + param->Filter()->dims()[2] == 3) { + param->ExecMode() = ConvParam::EXEC_DEPTHWISE3x3_FLOAT; +#ifndef __aarch64__ + } else if (param->Filter()->dims()[2] == param->Filter()->dims()[3] && + param->Strides()[0] == param->Strides()[1] && + param->Dilations()[0] == param->Dilations()[1] && + param->Filter()->dims()[2] == 3 && param->Strides()[0] == 1 && + param->Dilations()[0] == 1 && param->Output()->dims()[1] >= 16 && + param->Input()->dims()[1] >= 16 && + param->Input()->dims()[2] <= 140 /* refered from ncnn */) { + param->ExecMode() = ConvParam::EXEC_WINOGRAD3X3_FLOAT; + // transform weight + framework::Tensor *transformed_weight = new framework::Tensor; + operators::math::winograd_transform_weight<8, 3>(*param->Filter(), + transformed_weight); + param->Filter() = transformed_weight; +#endif + } else { + param->ExecMode() = ConvParam::EXEC_GEMM_FLOAT; + } + } return true; } template <> void ConvKernel::Compute(const ConvParam ¶m) { - ConvCompute(param); + switch (param.ExecMode()) { + case ConvParam::EXEC_GEMM_INT8: + GemmConv(param); + break; + case ConvParam::EXEC_DEPTHWISE3x3_INT8: + DepthwiseConv3x3(param); + break; + case ConvParam::EXEC_DEPTHWISE3x3S1P1_FLOAT: + math::DepthwiseConv3x3s1p1(param.Input(), param.Filter(), param.Output(), + nullptr, false); + break; + case ConvParam::EXEC_DEPTHWISE3x3_FLOAT: + math::DepthwiseConv3x3(param.Input(), param.Strides(), param.Paddings(), + param.Filter(), nullptr, param.Output(), false); + break; + case ConvParam::EXEC_WINOGRAD3X3_FLOAT: + WinogradConv3x3<8, 3>(param); + break; + case ConvParam::EXEC_GEMM_FLOAT: + GemmConv(param); + break; + default: + PADDLE_MOBILE_THROW_EXCEPTION("Invalid convolution execute mode %d", + param.ExecMode()); + } } template class ConvKernel; diff --git a/src/operators/kernel/arm/dequant_add_bn_relu_kernel.cpp b/src/operators/kernel/arm/dequant_add_bn_relu_kernel.cpp new file mode 100644 index 0000000000000000000000000000000000000000..bfe1935c216f94d660997b1bfa42f18e63295992 --- /dev/null +++ b/src/operators/kernel/arm/dequant_add_bn_relu_kernel.cpp @@ -0,0 +1,116 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#ifdef FUSION_DEQUANT_ADD_BN_RELU_OP + +#include "operators/kernel/dequant_add_bn_relu_kernel.h" +#include +#if defined(__ARM_NEON__) || defined(__ARM_NEON) +#include +#endif + +namespace paddle_mobile { +namespace operators { + +template <> +bool FusionDequantAddBNReluKernel::Init( + FusionDequantAddBNReluParam *param) { + // elementwise add params + const Tensor *bias = param->bias_; + // batch norm params + const Tensor *bn_mean = param->bn_mean_; + const Tensor *bn_variance = param->bn_variance_; + Tensor *bn_scale = param->bn_scale_; + Tensor *bn_bias = param->bn_bias_; + const float epsilon = param->epsilon_; + + const float *bias_ptr = bias->data(); + const float *mean_ptr = bn_mean->data(); + const float *var_ptr = bn_variance->data(); + float *bn_scale_ptr = bn_scale->mutable_data(); + float *bn_bias_ptr = bn_bias->mutable_data(); + for (int c = 0; c < bn_scale->numel(); ++c) { + float inv_scale = bn_scale_ptr[c] / (std::sqrt(var_ptr[c] + epsilon)); + bn_scale_ptr[c] = inv_scale; + bn_bias_ptr[c] = inv_scale * (bias_ptr[c] - mean_ptr[c]) + bn_bias_ptr[c]; + } + return true; +} + +template <> +void FusionDequantAddBNReluKernel::Compute( + const FusionDequantAddBNReluParam ¶m) { + const int32_t *input = param.input_->data(); + const float *bn_scale = param.bn_scale_->data(); + const float *bn_bias = param.bn_bias_->data(); + // dequantize params + const float activation_scale = param.activation_scale_->data()[0]; + const float weight_scale = param.weight_scale_; + const float dequant_scale = activation_scale / weight_scale; + + float *output = param.output_->mutable_data(); + int batch_size = param.input_->dims()[0]; + int channels = param.input_->dims()[1]; + size_t spatial_size = param.input_->dims()[2] * param.input_->dims()[3]; + + #pragma omp parallel for collapse(2) + for (int batch = 0; batch < batch_size; ++batch) { + for (int c = 0; c < channels; ++c) { + float scale = bn_scale[c] * dequant_scale; + float bias = bn_bias[c]; + size_t offset = (batch * channels + c) * spatial_size; + const int32_t *x = input + offset; + float *y = output + offset; + size_t remain = spatial_size; +#if defined(__ARM_NEON__) || defined(__ARM_NEON) + int loop = spatial_size >> 4; + remain = spatial_size & 0xF; + float32x4_t __scale = vdupq_n_f32(scale); + float32x4_t __bias = vdupq_n_f32(bias); + float32x4_t __zero = vdupq_n_f32(0.f); + + for (int k = 0; k < loop; ++k, x += 16, y += 16) { + int32x4_t r0 = vld1q_s32(x); + int32x4_t r1 = vld1q_s32(x + 4); + int32x4_t r2 = vld1q_s32(x + 8); + int32x4_t r3 = vld1q_s32(x + 12); + float32x4_t f0 = vcvtq_f32_s32(r0); + float32x4_t f1 = vcvtq_f32_s32(r1); + float32x4_t f2 = vcvtq_f32_s32(r2); + float32x4_t f3 = vcvtq_f32_s32(r3); + f0 = vmlaq_f32(__bias, __scale, f0); + f1 = vmlaq_f32(__bias, __scale, f1); + f2 = vmlaq_f32(__bias, __scale, f2); + f3 = vmlaq_f32(__bias, __scale, f3); + f0 = vmaxq_f32(__zero, f0); + f1 = vmaxq_f32(__zero, f1); + f2 = vmaxq_f32(__zero, f2); + f3 = vmaxq_f32(__zero, f3); + vst1q_f32(y, f0); + vst1q_f32(y + 4, f1); + vst1q_f32(y + 8, f2); + vst1q_f32(y + 12, f3); + } +#endif // __ARM_NEON__ + for (int k = 0; k < remain; ++k) { + y[k] = std::max(scale * x[k] + bias, 0.f); + } + } + } +} + +} // namespace operators +} // namespace paddle_mobile + +#endif // FUSION_DEQUANT_ADD_BN_RELU_OP diff --git a/src/operators/kernel/arm/dequantize_kernel.cpp b/src/operators/kernel/arm/dequantize_kernel.cpp index ea893730c1148158f574fb6c467265b334ba2f45..2c13cac1a673a37581cc1748037f4d879fcd7b56 100644 --- a/src/operators/kernel/arm/dequantize_kernel.cpp +++ b/src/operators/kernel/arm/dequantize_kernel.cpp @@ -31,7 +31,7 @@ bool DequantizeKernel::Init(DequantizeParam *param) { template <> void DequantizeKernel::Compute(const DequantizeParam ¶m) { const Tensor *input = param.input_; - Tensor *output = param.out_; + Tensor *output = param.output_; float activation_scale = param.activation_scale_->data()[0]; float weight_scale = param.weight_scale_; const int32_t *x = input->data(); @@ -43,11 +43,15 @@ void DequantizeKernel::Compute(const DequantizeParam ¶m) { size_t loop = size >> 4; size_t remain = size & 0xF; float32x4_t s = vdupq_n_f32(scale); + + #pragma omp parallel for for (size_t i = 0; i < loop; ++i) { - int32x4_t r0 = vld1q_s32(x); - int32x4_t r1 = vld1q_s32(x + 4); - int32x4_t r2 = vld1q_s32(x + 8); - int32x4_t r3 = vld1q_s32(x + 12); + const int32_t *local_x = x + (i << 4); + float *local_y = y + (i << 4); + int32x4_t r0 = vld1q_s32(local_x); + int32x4_t r1 = vld1q_s32(local_x + 4); + int32x4_t r2 = vld1q_s32(local_x + 8); + int32x4_t r3 = vld1q_s32(local_x + 12); float32x4_t f0 = vcvtq_f32_s32(r0); float32x4_t f1 = vcvtq_f32_s32(r1); float32x4_t f2 = vcvtq_f32_s32(r2); @@ -56,14 +60,14 @@ void DequantizeKernel::Compute(const DequantizeParam ¶m) { f1 = vmulq_f32(f1, s); f2 = vmulq_f32(f2, s); f3 = vmulq_f32(f3, s); - vst1q_f32(y, f0); - vst1q_f32(y + 4, f1); - vst1q_f32(y + 8, f2); - vst1q_f32(y + 12, f3); - x += 16; - y += 16; + vst1q_f32(local_y, f0); + vst1q_f32(local_y + 4, f1); + vst1q_f32(local_y + 8, f2); + vst1q_f32(local_y + 12, f3); } size = remain; + x += (loop << 4); + y += (loop << 4); #endif for (size_t i = 0; i < size; ++i) { y[i] = x[i] * scale; diff --git a/src/operators/kernel/arm/quantize_kernel.cpp b/src/operators/kernel/arm/quantize_kernel.cpp index 17f442abe4e03d936eb3b317d5b6f164ac0924e7..1e7623436a1a73644aca61e4634a7cd405bd64ad 100644 --- a/src/operators/kernel/arm/quantize_kernel.cpp +++ b/src/operators/kernel/arm/quantize_kernel.cpp @@ -21,15 +21,15 @@ limitations under the License. */ #include #ifndef __aarch64__ -float32_t vmaxvq_f32(float32x4_t r) { +inline float32_t vmaxvq_f32(float32x4_t r) { float32x2_t v = vmax_f32(vget_high_f32(r), vget_low_f32(r)); return vget_lane_f32(vpmax_f32(v, v), 0); } #endif -int32x4_t vrnd_towards_zero(float32x4_t r) { return vcvtq_s32_f32(r); } +inline int32x4_t vrnd_towards_zero(float32x4_t r) { return vcvtq_s32_f32(r); } -int32x4_t vrnd_away_zero(float32x4_t r) { +inline int32x4_t vrnd_away_zero(float32x4_t r) { float32x4_t plus = vdupq_n_f32(0.5); float32x4_t minus = vdupq_n_f32(-0.5); float32x4_t zero = vdupq_n_f32(0); @@ -40,7 +40,7 @@ int32x4_t vrnd_away_zero(float32x4_t r) { return ret; } -int32x4_t vrnd_to_even(float32x4_t r) { +inline int32x4_t vrnd_to_even(float32x4_t r) { #if 0 int32x4_t ret; float value[4]; @@ -84,7 +84,6 @@ int32x4_t vrnd_to_even(float32x4_t r) { return rnd; #endif } -#endif namespace paddle_mobile { namespace operators { @@ -127,6 +126,7 @@ static float find_abs_max(const Tensor *input) { return max_abs; } +#ifdef __aarch64__ static void quantize_round_to_even(const Tensor *input, const float scale, Tensor *output) { const float *x = input->data(); @@ -188,7 +188,7 @@ static void quantize_round_to_zero(const Tensor *input, const float scale, const float *x = input->data(); int8_t *y = output->mutable_data(); size_t size = input->numel(); -#ifdef defined(__ARM_NEON__) || defined(__ARM_NEON) +#if defined(__ARM_NEON__) || defined(__ARM_NEON) size_t loop = size >> 4; size_t remain = size & 0xF; @@ -224,7 +224,7 @@ static void quantize_round_to_zero(const Tensor *input, const float scale, y += (loop << 4); #endif for (size_t i = 0; i < size; ++i) { - y[i] = trunc(x[i] * scale); + y[i] = static_cast(x[i] * scale); } } @@ -272,6 +272,508 @@ static void quantize_round_to_nearest(const Tensor *input, const float scale, y[i] = round(x[i] * scale); } } +#else // __aarch64__ + +static void quantize_round_to_even(const Tensor *input, const float scale, + const std::vector &paddings, + const int8_t padding_val, Tensor *output) {} + +static void quantize_round_to_nearest(const Tensor *input, const float scale, + const std::vector &paddings, + const int8_t padding_val, + Tensor *output) {} + +static void quantize_round_to_zero(const Tensor *input, const float scale, + const std::vector &paddings, + const int8_t padding_val, Tensor *output) { + int channels = input->dims()[1]; + int input_h = input->dims()[2]; + int input_w = input->dims()[3]; + int output_h = output->dims()[2]; + int output_w = output->dims()[3]; + int input_spatial_size = input_h * input_w; + int output_spatial_size = output_h * output_w; + const float *x = input->data(); + int8_t *y = output->mutable_data(); + // valid area start + int start = paddings[0] * output_w + paddings[1]; + + for (int batch = 0; batch < input->dims()[0]; ++batch) { + #pragma omp parallel for + for (int c = 0; c < channels - 3; c += 4) { + const float *input0 = x + (batch * channels + c) * input_spatial_size; + const float *input1 = input0 + input_spatial_size; + const float *input2 = input1 + input_spatial_size; + const float *input3 = input2 + input_spatial_size; + size_t offset = (batch * channels + c) * output_spatial_size; + for (int h = 0; h < 2; ++h) { + int8_t *y0 = + y + offset + h * ((input_h + paddings[0]) * output_w - paddings[1]); + int8_t *y1 = y0 + output_spatial_size; + int8_t *y2 = y1 + output_spatial_size; + int8_t *y3 = y2 + output_spatial_size; + int loop = start >> 4; + int remain = start & 0xF; + asm volatile( + "vdup.s8 q0, %[val] \n" + "cmp %[loop], #0 \n" + "ble start_remain_%= \n" + + "store_16w_%=: \n" + "vst1.32 {q0}, [%[y0]]! \n" + "vst1.32 {q0}, [%[y1]]! \n" + "vst1.32 {q0}, [%[y2]]! \n" + "vst1.32 {q0}, [%[y3]]! \n" + "subs %[loop], #1 \n" + "bne store_16w_%= \n" + + "start_remain_%=: \n" + "cmp %[remain], #8 \n" + "blt store_4w_%= \n" + "vst1.32 {d0}, [%[y0]]! \n" + "vst1.32 {d0}, [%[y1]]! \n" + "vst1.32 {d0}, [%[y2]]! \n" + "vst1.32 {d0}, [%[y3]]! \n" + "sub %[remain], #8 \n" + + "store_4w_%=: \n" + "cmp %[remain], #4 \n" + "blt store_2w_%= \n" + "vst1.32 {d0[0]}, [%[y0]]! \n" + "vst1.32 {d0[0]}, [%[y1]]! \n" + "vst1.32 {d0[0]}, [%[y2]]! \n" + "vst1.32 {d0[0]}, [%[y3]]! \n" + "sub %[remain], #4 \n" + + "store_2w_%=: \n" + "cmp %[remain], #4 \n" + "blt store_1w_%= \n" + "vst1.16 {d0[0]}, [%[y0]]! \n" + "vst1.16 {d0[0]}, [%[y1]]! \n" + "vst1.16 {d0[0]}, [%[y2]]! \n" + "vst1.16 {d0[0]}, [%[y3]]! \n" + "sub %[remain], #2 \n" + + "store_1w_%=: \n" + "cmp %[remain], #1 \n" + "blt end_%= \n" + "vst1.8 {d0[0]}, [%[y0]]! \n" + "vst1.8 {d0[0]}, [%[y1]]! \n" + "vst1.8 {d0[0]}, [%[y2]]! \n" + "vst1.8 {d0[0]}, [%[y3]]! \n" + "end_%=: \n" + : [y0] "+r"(y0), [y1] "+r"(y1), [y2] "+r"(y2), [y3] "+r"(y3), + [loop] "+r"(loop), [remain] "+r"(remain) + : [val] "r"(padding_val) + : "cc", "memory", "q0"); + } + // quantize valid area + int8_t *y0 = y + offset + start; + int8_t *y1 = y0 + output_spatial_size; + int8_t *y2 = y1 + output_spatial_size; + int8_t *y3 = y2 + output_spatial_size; + for (int h = 0; h < input_h; ++h) { + const float *x0 = input0 + h * input_w; + const float *x1 = input1 + h * input_w; + const float *x2 = input2 + h * input_w; + const float *x3 = input3 + h * input_w; + int loop = input_w >> 4; + int remain = input_w & 0xF; + int pad_loop = paddings[1] >> 1; // (paddings[1] << 1) >> 2 + int pad_remain = (paddings[1] << 1) & 0x3; + int remain_steps = remain; + asm volatile( + "vdup.f32 q0, %[scale] \n" + "cmp %[loop], #0 \n" + "ble quantize_remain_%= \n" + + "loop_quantize_%=: \n" + "vld1.32 {q1, q2}, [%[x0]]! \n" + "vld1.32 {q3, q4}, [%[x1]]! \n" + "vld1.32 {q5, q6}, [%[x2]]! \n" + "vld1.32 {q7, q8}, [%[x3]]! \n" + "vmul.f32 q1, q1, q0 \n" + "vmul.f32 q2, q2, q0 \n" + "vmul.f32 q3, q3, q0 \n" + "vmul.f32 q4, q4, q0 \n" + "vmul.f32 q5, q5, q0 \n" + "vmul.f32 q6, q6, q0 \n" + "vmul.f32 q7, q7, q0 \n" + "vmul.f32 q8, q8, q0 \n" + "vcvt.s32.f32 q1, q1 \n" + "vcvt.s32.f32 q2, q2 \n" + "vcvt.s32.f32 q3, q3 \n" + "vcvt.s32.f32 q4, q4 \n" + "vcvt.s32.f32 q5, q5 \n" + "vcvt.s32.f32 q6, q6 \n" + "vcvt.s32.f32 q7, q7 \n" + "vcvt.s32.f32 q8, q8 \n" + "vmovn.s32 d2, q1 \n" + "vmovn.s32 d3, q2 \n" + "vmovn.s32 d4, q3 \n" + "vmovn.s32 d5, q4 \n" + "vmovn.s32 d6, q5 \n" + "vmovn.s32 d7, q6 \n" + "vmovn.s32 d8, q7 \n" + "vmovn.s32 d9, q8 \n" + "vmovn.s16 d18, q1 \n" + "vmovn.s16 d20, q2 \n" + "vmovn.s16 d22, q3 \n" + "vmovn.s16 d24, q4 \n" + "vld1.32 {q1, q2}, [%[x0]]! \n" + "vld1.32 {q3, q4}, [%[x1]]! \n" + "vld1.32 {q5, q6}, [%[x2]]! \n" + "vld1.32 {q7, q8}, [%[x3]]! \n" + "vmul.f32 q1, q1, q0 \n" + "vmul.f32 q2, q2, q0 \n" + "vmul.f32 q3, q3, q0 \n" + "vmul.f32 q4, q4, q0 \n" + "vmul.f32 q5, q5, q0 \n" + "vmul.f32 q6, q6, q0 \n" + "vmul.f32 q7, q7, q0 \n" + "vmul.f32 q8, q8, q0 \n" + "vcvt.s32.f32 q1, q1 \n" + "vcvt.s32.f32 q2, q2 \n" + "vcvt.s32.f32 q3, q3 \n" + "vcvt.s32.f32 q4, q4 \n" + "vcvt.s32.f32 q5, q5 \n" + "vcvt.s32.f32 q6, q6 \n" + "vcvt.s32.f32 q7, q7 \n" + "vcvt.s32.f32 q8, q8 \n" + "vmovn.s32 d2, q1 \n" + "vmovn.s32 d3, q2 \n" + "vmovn.s32 d4, q3 \n" + "vmovn.s32 d5, q4 \n" + "vmovn.s32 d6, q5 \n" + "vmovn.s32 d7, q6 \n" + "vmovn.s32 d8, q7 \n" + "vmovn.s32 d9, q8 \n" + "vmovn.s16 d19, q1 \n" + "vmovn.s16 d21, q2 \n" + "vmovn.s16 d23, q3 \n" + "vmovn.s16 d25, q4 \n" + "vst1.32 {q9}, [%[y0]]! \n" + "vst1.32 {q10}, [%[y1]]! \n" + "vst1.32 {q11}, [%[y2]]! \n" + "vst1.32 {q12}, [%[y3]]! \n" + + "subs %[loop], #1 \n" + "bne loop_quantize_%= \n" + + "quantize_remain_%=: \n" + "cmp %[remain], #0 \n" + "ble end_%= \n" + + "vld1.32 {q1, q2}, [%[x0]]! \n" + "vld1.32 {q3, q4}, [%[x1]]! \n" + "vld1.32 {q5, q6}, [%[x2]]! \n" + "vld1.32 {q7, q8}, [%[x3]]! \n" + "vmul.f32 q1, q1, q0 \n" + "vmul.f32 q2, q2, q0 \n" + "vmul.f32 q3, q3, q0 \n" + "vmul.f32 q4, q4, q0 \n" + "vmul.f32 q5, q5, q0 \n" + "vmul.f32 q6, q6, q0 \n" + "vmul.f32 q7, q7, q0 \n" + "vmul.f32 q8, q8, q0 \n" + "vcvt.s32.f32 q1, q1 \n" + "vcvt.s32.f32 q2, q2 \n" + "vcvt.s32.f32 q3, q3 \n" + "vcvt.s32.f32 q4, q4 \n" + "vcvt.s32.f32 q5, q5 \n" + "vcvt.s32.f32 q6, q6 \n" + "vcvt.s32.f32 q7, q7 \n" + "vcvt.s32.f32 q8, q8 \n" + "vmovn.s32 d2, q1 \n" + "vmovn.s32 d3, q2 \n" + "vmovn.s32 d4, q3 \n" + "vmovn.s32 d5, q4 \n" + "vmovn.s32 d6, q5 \n" + "vmovn.s32 d7, q6 \n" + "vmovn.s32 d8, q7 \n" + "vmovn.s32 d9, q8 \n" + "vmovn.s16 d18, q1 \n" + "vmovn.s16 d20, q2 \n" + "vmovn.s16 d22, q3 \n" + "vmovn.s16 d24, q4 \n" + "vld1.32 {q1, q2}, [%[x0]] \n" + "vld1.32 {q3, q4}, [%[x1]] \n" + "vld1.32 {q5, q6}, [%[x2]] \n" + "vld1.32 {q7, q8}, [%[x3]] \n" + "vmul.f32 q1, q1, q0 \n" + "vmul.f32 q2, q2, q0 \n" + "vmul.f32 q3, q3, q0 \n" + "vmul.f32 q4, q4, q0 \n" + "vmul.f32 q5, q5, q0 \n" + "vmul.f32 q6, q6, q0 \n" + "vmul.f32 q7, q7, q0 \n" + "vmul.f32 q8, q8, q0 \n" + "vcvt.s32.f32 q1, q1 \n" + "vcvt.s32.f32 q2, q2 \n" + "vcvt.s32.f32 q3, q3 \n" + "vcvt.s32.f32 q4, q4 \n" + "vcvt.s32.f32 q5, q5 \n" + "vcvt.s32.f32 q6, q6 \n" + "vcvt.s32.f32 q7, q7 \n" + "vcvt.s32.f32 q8, q8 \n" + "vmovn.s32 d2, q1 \n" + "vmovn.s32 d3, q2 \n" + "vmovn.s32 d4, q3 \n" + "vmovn.s32 d5, q4 \n" + "vmovn.s32 d6, q5 \n" + "vmovn.s32 d7, q6 \n" + "vmovn.s32 d8, q7 \n" + "vmovn.s32 d9, q8 \n" + "vmovn.s16 d19, q1 \n" + "vmovn.s16 d21, q2 \n" + "vmovn.s16 d23, q3 \n" + "vmovn.s16 d25, q4 \n" + + "cmp %[remain], #8 \n" + "blt store_4w_%= \n" + "vst1.32 {d18}, [%[y0]]! \n" + "vst1.32 {d20}, [%[y1]]! \n" + "vst1.32 {d22}, [%[y2]]! \n" + "vst1.32 {d24}, [%[y3]]! \n" + "vmov.32 d18, d19 \n" + "vmov.32 d20, d21 \n" + "vmov.32 d22, d23 \n" + "vmov.32 d24, d25 \n" + "sub %[remain], #8 \n" + + "store_4w_%=: \n" + "cmp %[remain], #4 \n" + "blt store_2w_%= \n" + "vst1.32 {d18[0]}, [%[y0]]! \n" + "vst1.32 {d20[0]}, [%[y1]]! \n" + "vst1.32 {d22[0]}, [%[y2]]! \n" + "vst1.32 {d24[0]}, [%[y3]]! \n" + "vext.32 d18, d18, d18, #1 \n" + "vext.32 d20, d20, d20, #1 \n" + "vext.32 d22, d22, d22, #1 \n" + "vext.32 d24, d24, d24, #1 \n" + "sub %[remain], #4 \n" + + "store_2w_%=: \n" + "cmp %[remain], #2 \n" + "blt store_1w_%= \n" + "vst1.16 {d18[0]}, [%[y0]]! \n" + "vst1.16 {d20[0]}, [%[y1]]! \n" + "vst1.16 {d22[0]}, [%[y2]]! \n" + "vst1.16 {d24[0]}, [%[y3]]! \n" + "vext.16 d18, d18, d18, #1 \n" + "vext.16 d20, d20, d20, #1 \n" + "vext.16 d22, d22, d22, #1 \n" + "vext.16 d24, d24, d24, #1 \n" + "sub %[remain], #2 \n" + + "store_1w_%=:" + "cmp %[remain], #1 \n" + "blt end_%= \n" + "vst1.8 {d18[0]}, [%[y0]]! \n" + "vst1.8 {d20[0]}, [%[y1]]! \n" + "vst1.8 {d22[0]}, [%[y2]]! \n" + "vst1.8 {d24[0]}, [%[y3]]! \n" + + "end_%=: \n" + : [x0] "+r"(x0), [x1] "+r"(x1), [x2] "+r"(x2), [x3] "+r"(x3), + [y0] "+r"(y0), [y1] "+r"(y1), [y2] "+r"(y2), [y3] "+r"(y3), + [loop] "+r"(loop), [remain] "+r"(remain) + : [scale] "r"(scale) + : "cc", "memory", "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7", + "q8", "q9", "q10", "q11", "q12"); + asm volatile( + "vdup.s8 d0, %[val] \n" + "cmp %[pad_loop], #0 \n" + "ble store_pad_2w_%= \n" + "loop_pad_4w_%=: \n" + "vst1.32 {d0[0]}, [%[y0]]! \n" + "vst1.32 {d0[0]}, [%[y1]]! \n" + "vst1.32 {d0[0]}, [%[y2]]! \n" + "vst1.32 {d0[0]}, [%[y3]]! \n" + "subs %[pad_loop], #1 \n" + "bne loop_pad_4w_%= \n" + + "store_pad_2w_%=: \n" + "cmp %[pad_remain], #2 \n" + "blt store_pad_1w_%= \n" + "vst1.16 {d0[0]}, [%[y0]]! \n" + "vst1.16 {d0[0]}, [%[y1]]! \n" + "vst1.16 {d0[0]}, [%[y2]]! \n" + "vst1.16 {d0[0]}, [%[y3]]! \n" + "sub %[pad_remain], #2 \n" + + "store_pad_1w_%=: \n" + "cmp %[pad_remain], #1 \n" + "blt end_%= \n" + "vst1.8 {d0[0]}, [%[y0]]! \n" + "vst1.8 {d0[0]}, [%[y1]]! \n" + "vst1.8 {d0[0]}, [%[y2]]! \n" + "vst1.8 {d0[0]}, [%[y3]]! \n" + "end_%=: \n" + : [y0] "+r"(y0), [y1] "+r"(y1), [y2] "+r"(y2), [y3] "+r"(y3), + [pad_loop] "+r"(pad_loop), [pad_remain] "+r"(pad_remain) + : [val] "r"(padding_val) + : "cc", "memory", "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7", + "q8", "q9", "q10", "q11", "q12"); + } + } + for (int c = (channels & 0xFFFC); c < channels; ++c) { + const float *input0 = x + (batch * channels + c) * input_spatial_size; + size_t offset = (batch * channels + c) * output_spatial_size; + for (int h = 0; h < 2; ++h) { + int8_t *y0 = + y + offset + h * ((input_h + paddings[0]) * output_w - paddings[1]); + int loop = start >> 4; + int remain = start & 0xF; + asm volatile( + "vdup.s8 q0, %[val] \n" + "cmp %[loop], #0 \n" + "ble start_remain_%= \n" + + "store_16w_%=: \n" + "vst1.32 {q0}, [%[y0]]! \n" + "subs %[loop], #1 \n" + "bne store_16w_%= \n" + + "start_remain_%=: \n" + "cmp %[remain], #8 \n" + "blt store_4w_%= \n" + "vst1.32 {d0}, [%[y0]]! \n" + "sub %[remain], #8 \n" + + "store_4w_%=: \n" + "cmp %[remain], #4 \n" + "blt store_2w_%= \n" + "vst1.32 {d0[0]}, [%[y0]]! \n" + "sub %[remain], #4 \n" + + "store_2w_%=: \n" + "cmp %[remain], #4 \n" + "blt store_1w_%= \n" + "vst1.16 {d0[0]}, [%[y0]]! \n" + "sub %[remain], #2 \n" + + "store_1w_%=: \n" + "cmp %[remain], #1 \n" + "blt end_%= \n" + "vst1.8 {d0[0]}, [%[y0]]! \n" + "end_%=: \n" + : [y0] "+r"(y0), [loop] "+r"(loop), [remain] "+r"(remain) + : [val] "r"(padding_val) + : "cc", "memory", "q0"); + } + // quantize valid area + int8_t *y0 = y + offset + start; + for (int h = 0; h < input_h; ++h) { + const float *x0 = input0 + h * input_w; + int loop = input_w >> 4; + int remain = input_w & 0xF; + int pad_loop = paddings[1] >> 1; // (paddings[1] << 1) >> 2 + int pad_remain = (paddings[1] << 1) & 0x3; + asm volatile( + "vdup.f32 q0, %[scale] \n" + "cmp %[loop], #0 \n" + "ble quantize_remain_%= \n" + + "loop_quantize_%=: \n" + "vld1.32 {q1, q2}, [%[x0]]! \n" + "vmul.f32 q1, q1, q0 \n" + "vmul.f32 q2, q2, q0 \n" + "vcvt.s32.f32 q1, q1 \n" + "vcvt.s32.f32 q2, q2 \n" + "vmovn.s32 d2, q1 \n" + "vmovn.s32 d3, q2 \n" + "vmovn.s16 d18, q1 \n" + "vld1.32 {q1, q2}, [%[x0]]! \n" + "vmul.f32 q1, q1, q0 \n" + "vmul.f32 q2, q2, q0 \n" + "vcvt.s32.f32 q1, q1 \n" + "vcvt.s32.f32 q2, q2 \n" + "vmovn.s32 d2, q1 \n" + "vmovn.s32 d3, q2 \n" + "vmovn.s16 d19, q1 \n" + "vst1.32 {q9}, [%[y0]]! \n" + + "subs %[loop], #1 \n" + "bne loop_quantize_%= \n" + + "quantize_remain_%=: \n" + "cmp %[remain], #0 \n" + "ble start_pad_%= \n" + + "vldm %[x0], {d2-d9} \n" + "vmul.f32 q1, q1, q0 \n" + "vmul.f32 q2, q2, q0 \n" + "vcvt.s32.f32 q1, q1 \n" + "vcvt.s32.f32 q2, q2 \n" + "vmovn.s32 d2, q1 \n" + "vmovn.s32 d3, q2 \n" + "vmovn.s16 d18, q1 \n" + "vmul.f32 q3, q3, q0 \n" + "vmul.f32 q4, q4, q0 \n" + "vcvt.s32.f32 q1, q3 \n" + "vcvt.s32.f32 q2, q4 \n" + "vmovn.s32 d2, q1 \n" + "vmovn.s32 d3, q2 \n" + "vmovn.s16 d19, q1 \n" + + "cmp %[remain], #8 \n" + "blt store_4w_%= \n" + "vst1.32 {d18}, [%[y0]]! \n" + "vmov.32 d18, d19 \n" + "sub %[remain], #8 \n" + + "store_4w_%=: \n" + "cmp %[remain], #4 \n" + "blt store_2w_%= \n" + "vst1.32 {d18[0]}, [%[y0]]! \n" + "vext.32 d18, d18, d18, #1 \n" + "sub %[remain], #4 \n" + + "store_2w_%=: \n" + "cmp %[remain], #2 \n" + "blt store_1w_%= \n" + "vst1.16 {d18[0]}, [%[y0]]! \n" + "vext.16 d18, d18, d18, #1 \n" + "sub %[remain], #2 \n" + + "store_1w_%=:" + "cmp %[remain], #1 \n" + "blt start_pad_%= \n" + "vst1.8 {d18[0]}, [%[y0]]! \n" + + "start_pad_%=: \n" + "vdup.s8 d0, %[val] \n" + "cmp %[pad_loop], #0 \n" + "ble pad_remain_%= \n" + "loop_pad_4w_%=: \n" + "vst1.32 {d0[0]}, [%[y0]]! \n" + "subs %[pad_loop], #1 \n" + "bne loop_pad_4w_%= \n" + + "pad_remain_%=: \n" + "cmp %[pad_remain], #2 \n" + "blt store_pad_1w_%= \n" + "vst1.16 {d0[0]}, [%[y0]]! \n" + "sub %[pad_remain], #2 \n" + + "store_pad_1w_%=: \n" + "cmp %[pad_remain], #1 \n" + "blt end_%= \n" + "vst1.8 {d0[0]}, [%[y0]]! \n" + "end_%=: \n" + : [x0] "+r"(x0), [y0] "+r"(y0), [loop] "+r"(loop), + [remain] "+r"(remain), [pad_loop] "+r"(pad_loop), + [pad_remain] "+r"(pad_remain) + : [scale] "r"(scale), [val] "r"(padding_val) + : "cc", "memory", "q0", "q1", "q2", "q3", "q4", "q9"); + } + } + } +} +#endif // __aarch64__ +#endif // ARM_NEON template <> bool QuantizeKernel::Init(QuantizeParam *param) { @@ -280,10 +782,10 @@ bool QuantizeKernel::Init(QuantizeParam *param) { template <> void QuantizeKernel::Compute(const QuantizeParam ¶m) { - float max_abs = 0.f; const Tensor *input = param.input_; - Tensor *output = param.out_; + Tensor *output = param.output_; Tensor *output_scale = param.online_scale_; + float max_abs = 0.f; if (param.is_static_) { max_abs = param.static_scale_; } else { @@ -293,15 +795,19 @@ void QuantizeKernel::Compute(const QuantizeParam ¶m) { // only support int8 currently float scale = 127 / max_abs; param.online_scale_->mutable_data()[0] = max_abs; + const auto &paddings = param.paddings_; + // std::vector paddings = {0, 0}; + // const auto padding_val = param.padding_val_; + int8_t padding_val = 0; switch (param.round_type_) { case ROUND_NEAREST_TO_EVEN: - quantize_round_to_even(input, scale, output); + quantize_round_to_even(input, scale, paddings, padding_val, output); break; case ROUND_NEAREST_TOWARDS_ZERO: - quantize_round_to_zero(input, scale, output); + quantize_round_to_zero(input, scale, paddings, padding_val, output); break; case ROUND_NEAREST_AWAY_ZERO: - quantize_round_to_nearest(input, scale, output); + quantize_round_to_nearest(input, scale, paddings, padding_val, output); break; default: LOG(kLOG_ERROR) << "round type is not supported."; diff --git a/src/operators/kernel/central-arm-func/conv_add_arm_func.h b/src/operators/kernel/central-arm-func/conv_add_arm_func.h index bacaa866b12957cfc300049c56bb9648fd360770..3b5924ecbf886159d129212cc36c8630cb8cce2f 100644 --- a/src/operators/kernel/central-arm-func/conv_add_arm_func.h +++ b/src/operators/kernel/central-arm-func/conv_add_arm_func.h @@ -17,7 +17,7 @@ limitations under the License. */ #include #include "operators/math/conv_func.h" -#include "operators/math/depthwise_conv_3x3.h" +#include "operators/math/depthwise_conv3x3.h" #include "operators/math/im2col.h" #include "operators/math/math_function.h" #include "operators/math/vol2col.h" diff --git a/src/operators/kernel/central-arm-func/conv_add_bn_relu_arm_func.h b/src/operators/kernel/central-arm-func/conv_add_bn_relu_arm_func.h index a7d14fbad1e4b72a8571d13898e55a6cad8bf9a8..5374eab51f315ee8baa4f4effe04fc97240aabff 100644 --- a/src/operators/kernel/central-arm-func/conv_add_bn_relu_arm_func.h +++ b/src/operators/kernel/central-arm-func/conv_add_bn_relu_arm_func.h @@ -17,7 +17,7 @@ limitations under the License. */ #pragma once #include -#include "operators/math/depthwise_conv_3x3.h" +#include "operators/math/depthwise_conv3x3.h" #include "operators/math/im2col.h" #include "operators/math/math_function.h" #include "operators/math/vol2col.h" diff --git a/src/operators/kernel/central-arm-func/conv_arm_func.h b/src/operators/kernel/central-arm-func/conv_arm_func.h index e7a8c7f52db327f3ff5871566c3557c484ba4d13..b01a654c713f2328d62714f23af68d606380d203 100644 --- a/src/operators/kernel/central-arm-func/conv_arm_func.h +++ b/src/operators/kernel/central-arm-func/conv_arm_func.h @@ -17,18 +17,19 @@ limitations under the License. */ #pragma once #include #include "operators/math/conv_func.h" -#include "operators/math/depthwise_conv_3x3.h" +#include "operators/math/depthwise_conv3x3.h" #include "operators/math/im2col.h" #include "operators/math/math_function.h" #include "operators/math/pad.h" #include "operators/math/vol2col.h" +#include "operators/math/winograd/winograd_transform.h" #include "operators/op_param.h" namespace paddle_mobile { namespace operators { template -inline void ConvBasic(const ConvParam ¶m) { +inline void GemmConv(const ConvParam ¶m) { const Tensor *input = param.Input(); Tensor filter = *param.Filter(); Tensor *output = param.Output(); @@ -38,10 +39,7 @@ inline void ConvBasic(const ConvParam ¶m) { const std::vector paddings = param.Paddings(); const std::vector dilations = param.Dilations(); - const int batch_size = static_cast(input->dims()[0]); - std::vector filter_shape_vec(framework::vectorize(filter.dims())); - std::vector output_shape_vec(framework::vectorize(output->dims())); size_t data_dim = filter_shape_vec.size() - 2; std::vector col_shape_vec(1 + 2 * data_dim); @@ -82,6 +80,7 @@ inline void ConvBasic(const ConvParam ¶m) { math::Vol2ColFunctor vol2col; math::Im2ColFunctor im2col; + const int batch_size = static_cast(input->dims()[0]); for (int i = 0; i < batch_size; i++) { Tensor in_batch = input->Slice(i, i + 1).Resize(input_shape); Tensor out_batch = output->Slice(i, i + 1).Resize(output_matrix_shape); @@ -99,7 +98,6 @@ inline void ConvBasic(const ConvParam ¶m) { std::vector{paddings[0], paddings[1], paddings[0], paddings[1]}, &col); - } else if (data_dim == 3U) { // vol2col vol2col(in_slice, dilations, strides, paddings, &col); @@ -116,25 +114,86 @@ inline void ConvBasic(const ConvParam ¶m) { } } -template -void ConvCompute(const ConvParam ¶m) { - if (param.Input()->type() == typeid(int8_t)) { - ConvBasic(param); - } else { - if (param.Groups() == param.Input()->dims()[1] && - param.Input()->dims()[1] == param.Output()->dims()[1] && - param.Filter()->dims()[2] == param.Filter()->dims()[3] && - param.Filter()->dims()[2] == 3 && param.Strides()[0] == 1) { - math::DepthwiseConv3x3s1p1(param.Input(), param.Filter(), param.Output(), - nullptr, false); - } else if (param.Groups() == param.Input()->dims()[1] && - param.Input()->dims()[1] == param.Output()->dims()[1] && - param.Filter()->dims()[2] == param.Filter()->dims()[3] && - param.Filter()->dims()[2] == 3) { - math::DepthwiseConv3x3(param.Input(), param.Strides(), param.Paddings(), - param.Filter(), nullptr, param.Output(), false); +template +inline void WinogradConv3x3(const ConvParam ¶m) { + const Tensor *input = param.Input(); + const Tensor *filter = param.Filter(); + Tensor *output = param.Output(); + output->mutable_data(); + int batch_size = input->dims()[0]; + int groups = param.Groups(); + const std::vector &paddings = param.Paddings(); + + auto winograd_pad = [&](int width, int pad) { + int output_tile = tile - kernel + 1; + // int tiles = (width + pad - kernel) / output_tile + 1; + // return (tiles - 1) * output_tile + tile - width; + int pad_width = (width + 2 * pad - kernel) / output_tile * output_tile; + return pad_width + tile - width; + }; + + math::PadFunctor pad; + Tensor input_pad; + framework::Tensor transformed_input; + for (int i = 0; i < batch_size; ++i) { + Tensor in_batch = input->Slice(i, i + 1); + Tensor out_batch = output->Slice(i, i + 1); + // int pad_bottom = winograd_pad(in_batch.dims()[2], paddings[0]); + // int pad_right = winograd_pad(in_batch.dims()[3], paddings[1]); + int pad_bottom = paddings[0]; + int pad_right = paddings[1]; + if (paddings[0] || paddings[1] || pad_bottom || pad_right) { + framework::DDim pad_shape = in_batch.dims(); + pad_shape[2] += paddings[0] + pad_bottom; + pad_shape[3] += paddings[1] + pad_right; + input_pad.mutable_data(pad_shape); + pad(in_batch, paddings[0], pad_bottom, paddings[1], pad_right, + &input_pad); + } else { + input_pad = in_batch; + } + // tile input and transform + math::winograd_transform_input(input_pad, &transformed_input); + // caculate output + math::winograd_transform_output(transformed_input, *filter, + output); + } +} + +template +inline void DepthwiseConv3x3(const ConvParam ¶m) { + const Tensor *input = param.Input(); + const Tensor *filter = param.Filter(); + Tensor *output = param.Output(); + output->mutable_data(); + + const std::vector &paddings = param.Paddings(); + const std::vector &strides = param.Strides(); + const int batch_size = static_cast(input->dims()[0]); + Tensor input_pad; + math::PadFunctor pad; + for (int i = 0; i < batch_size; i++) { + Tensor in_batch = input->Slice(i, i + 1); + Tensor out_batch = output->Slice(i, i + 1); + if (paddings[0] || paddings[1]) { + framework::DDim pad_shape = in_batch.dims(); + pad_shape[2] += 2 * paddings[0]; + pad_shape[3] += 2 * paddings[1]; + input_pad.mutable_data(pad_shape); + pad(in_batch, paddings[0], paddings[0], paddings[1], paddings[1], + &input_pad); + } else { + input_pad = in_batch; + } + if (strides[0] == 1) { + math::DepthwiseConv3x3s1(input_pad, *filter, &out_batch); + } else if (strides[0] == 2) { + math::DepthwiseConv3x3s2(input_pad, *filter, &out_batch); } else { - ConvBasic(param); + // math::DepthwiseConv3x3(input_pad, *filter, + // &out_batch); + PADDLE_MOBILE_THROW_EXCEPTION( + "Depthwise conv with generic strides has not been implemented."); } } } diff --git a/src/operators/kernel/central-arm-func/conv_bn_add_relu_arm_func.h b/src/operators/kernel/central-arm-func/conv_bn_add_relu_arm_func.h index 7c31eed19693d20084e25daa485a0553d5d795f2..e3fe37e19bd10ec5cbbfb59b556df5af9fecd09e 100644 --- a/src/operators/kernel/central-arm-func/conv_bn_add_relu_arm_func.h +++ b/src/operators/kernel/central-arm-func/conv_bn_add_relu_arm_func.h @@ -17,7 +17,7 @@ limitations under the License. */ #pragma once #include -#include "operators/math/depthwise_conv_3x3.h" +#include "operators/math/depthwise_conv3x3.h" #include "operators/math/im2col.h" #include "operators/math/math_function.h" #include "operators/math/vol2col.h" diff --git a/src/operators/kernel/central-arm-func/conv_bn_relu_arm_func.h b/src/operators/kernel/central-arm-func/conv_bn_relu_arm_func.h index c6300f96e1b999c45538417c7b513068697ad4dd..4c8cf393345d16e79799bc5ce9ecd1be1fc0a15a 100644 --- a/src/operators/kernel/central-arm-func/conv_bn_relu_arm_func.h +++ b/src/operators/kernel/central-arm-func/conv_bn_relu_arm_func.h @@ -16,13 +16,15 @@ limitations under the License. */ #pragma once #include -#include "operators/math/depthwise_conv_3x3.h" +#include "operators/math/depthwise_conv3x3.h" #include "operators/math/im2col.h" #include "operators/math/math_function.h" #include "operators/math/vol2col.h" #include "operators/op_param.h" + namespace paddle_mobile { namespace operators { + void ConvBNReluBasic(const FusionConvBNReluParam ¶m) { const Tensor *input = param.Input(); Tensor filter = *param.Filter(); diff --git a/src/operators/kernel/central-arm-func/depthwise_conv_arm_func.h b/src/operators/kernel/central-arm-func/depthwise_conv_arm_func.h index 73170bdab922a46831334307aebc8af210ddfb73..b48b03491bab9594f36cad0b21485ae72c8c3c31 100644 --- a/src/operators/kernel/central-arm-func/depthwise_conv_arm_func.h +++ b/src/operators/kernel/central-arm-func/depthwise_conv_arm_func.h @@ -15,10 +15,9 @@ limitations under the License. */ #ifdef DEPTHWISECONV_OP #pragma once -#include #include #include "operators/kernel/central-arm-func/conv_arm_func.h" - +#include "operators/math/depthwise_conv3x3.h" #include "operators/op_param.h" namespace paddle_mobile { @@ -44,7 +43,7 @@ void DepthwiseConvCompute(const ConvParam ¶m) { Bias, false); } else { - ConvBasic(param); + GemmConv(param); } } diff --git a/src/operators/kernel/central-arm-func/dwconv_bn_relu_arm_func.h b/src/operators/kernel/central-arm-func/dwconv_bn_relu_arm_func.h index b60bf9b4d6df9d85cc2fbe378a3904c2d13e5e60..a5c08c26237345320fef89e8f0fdd148534dfc8a 100644 --- a/src/operators/kernel/central-arm-func/dwconv_bn_relu_arm_func.h +++ b/src/operators/kernel/central-arm-func/dwconv_bn_relu_arm_func.h @@ -16,13 +16,15 @@ limitations under the License. */ #pragma once #include -#include "operators/math/depthwise_conv_3x3.h" +#include "operators/math/depthwise_conv3x3.h" #include "operators/math/im2col.h" #include "operators/math/math_function.h" #include "operators/math/vol2col.h" #include "operators/op_param.h" + namespace paddle_mobile { namespace operators { + void DWConvBNReluBasic(const FusionDWConvBNReluParam ¶m) { const Tensor *input = param.Input(); Tensor filter = *param.Filter(); diff --git a/src/operators/kernel/conv_add_kernel.h b/src/operators/kernel/conv_add_kernel.h index 4e9ff0853f1d502ebb4dc4ef3641d0a879f32b60..140d0475a8ee2f017a7c587c38429ccbb2edd387 100644 --- a/src/operators/kernel/conv_add_kernel.h +++ b/src/operators/kernel/conv_add_kernel.h @@ -24,7 +24,7 @@ limitations under the License. */ #include "framework/ddim.h" #include "framework/operator.h" #include "operators/math/conv_func.h" -#include "operators/math/depthwise_conv_3x3.h" +#include "operators/math/depthwise_conv3x3.h" #include "operators/math/im2col.h" #include "operators/math/math_function.h" #include "operators/math/vol2col.h" diff --git a/src/operators/kernel/dequant_add_bn_relu_kernel.h b/src/operators/kernel/dequant_add_bn_relu_kernel.h new file mode 100644 index 0000000000000000000000000000000000000000..7138e5c415caca6766913f9959bd41def0943d34 --- /dev/null +++ b/src/operators/kernel/dequant_add_bn_relu_kernel.h @@ -0,0 +1,37 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +#ifdef FUSION_DEQUANT_ADD_BN_RELU_OP + +#include "framework/operator.h" +#include "operators/op_param.h" + +namespace paddle_mobile { +namespace operators { + +template +class FusionDequantAddBNReluKernel + : public framework::OpKernelBase> { + public: + void Compute(const FusionDequantAddBNReluParam ¶m); + bool Init(FusionDequantAddBNReluParam *param); +}; + +} // namespace operators +} // namespace paddle_mobile + +#endif diff --git a/src/operators/math/depthwise_conv_3x3.cpp b/src/operators/math/depthwise_conv3x3.cpp similarity index 96% rename from src/operators/math/depthwise_conv_3x3.cpp rename to src/operators/math/depthwise_conv3x3.cpp index f5bcf1202391911e2bf4b891032576a4e4ded064..39b9b8d3f1c5c2bf09a3db5de5216dd1a08b491a 100644 --- a/src/operators/math/depthwise_conv_3x3.cpp +++ b/src/operators/math/depthwise_conv3x3.cpp @@ -11,18 +11,22 @@ distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ -#include "operators/math/depthwise_conv_3x3.h" + +#include "operators/math/depthwise_conv3x3.h" +#include #if __ARM_NEON #include #endif -#include namespace paddle_mobile { namespace operators { namespace math { -void DepthwiseConv3x3(const Tensor *input, vector strides, - vector paddings, const Tensor *filter, Tensor *bias, - Tensor *output, bool if_bias) { + +void DepthwiseConv3x3(const framework::Tensor *input, + const std::vector &strides, + const std::vector &paddings, + const framework::Tensor *filter, framework::Tensor *bias, + framework::Tensor *output, bool if_bias) { const int batch_size = input->dims()[0]; const int input_height = input->dims()[2]; @@ -67,12 +71,12 @@ void DepthwiseConv3x3(const Tensor *input, vector strides, for (int pw = 0; pw < output_width; pw++) { hstart = ph * stride_height - padding_height; wstart = pw * stride_width - padding_width; - hend = min(hstart + _kernel_size, input_height + padding_height); - wend = min(wstart + _kernel_size, input_width + padding_width); - hstart = max(hstart, 0); - wstart = max(wstart, 0); - hend = min(hend, input_height); - wend = min(wend, input_width); + hend = std::min(hstart + _kernel_size, input_height + padding_height); + wend = std::min(wstart + _kernel_size, input_width + padding_width); + hstart = std::max(hstart, 0); + wstart = std::max(wstart, 0); + hend = std::min(hend, input_height); + wend = std::min(wend, input_width); pos1 = input_data + hstart * input_width + wstart; pos2 = input_data + (hstart + 1) * input_width + wstart; pos3 = input_data + (hstart + 2) * input_width + wstart; @@ -244,12 +248,14 @@ void DepthwiseConv3x3(const Tensor *input, vector strides, } } -void DepthwiseConv3x3s1p1(const Tensor *input, const Tensor *filter, - Tensor *output, Tensor *bias, bool if_bias) { +void DepthwiseConv3x3s1p1(const framework::Tensor *input, + const framework::Tensor *filter, + framework::Tensor *output, framework::Tensor *bias, + bool if_bias) { #if __ARM_NEON const float *input_data = input->data(); const float *filter_data = filter->data(); - float *output_data = output->data(); + float *output_data = output->mutable_data(); const float *bias_data; if (if_bias) { bias_data = bias->data(); @@ -517,9 +523,12 @@ void DepthwiseConv3x3s1p1(const Tensor *input, const Tensor *filter, #endif } -void DepthwiseConvAddBNRelu3x3s1p1(const Tensor *input, const Tensor *filter, - Tensor *output, const Tensor *new_scale, - const Tensor *new_bias, bool if_relu) { +void DepthwiseConvAddBNRelu3x3s1p1(const framework::Tensor *input, + const framework::Tensor *filter, + framework::Tensor *output, + const framework::Tensor *new_scale, + const framework::Tensor *new_bias, + bool if_relu) { #if __ARM_NEON const float *input_data = input->data(); const float *filter_data = filter->data(); @@ -1059,9 +1068,12 @@ void DepthwiseConvAddBNRelu3x3s1p1(const Tensor *input, const Tensor *filter, } /// w!=h not fix -void DepthwiseConvAddBNRelu3x3s2p1(const Tensor *input, const Tensor *filter, - Tensor *output, const Tensor *new_scale, - const Tensor *new_bias, bool if_relu) { +void DepthwiseConvAddBNRelu3x3s2p1(const framework::Tensor *input, + const framework::Tensor *filter, + framework::Tensor *output, + const framework::Tensor *new_scale, + const framework::Tensor *new_bias, + bool if_relu) { #if __ARM_NEON const int batch_size = input->dims()[0]; @@ -1107,12 +1119,12 @@ void DepthwiseConvAddBNRelu3x3s2p1(const Tensor *input, const Tensor *filter, for (int pw = 0; pw < output_width; pw++) { hstart = ph * stride_height - padding_height; wstart = pw * stride_width - padding_width; - hend = min(hstart + _kernel_size, input_height + padding_height); - wend = min(wstart + _kernel_size, input_width + padding_width); - hstart = max(hstart, 0); - wstart = max(wstart, 0); - hend = min(hend, input_height); - wend = min(wend, input_width); + hend = std::min(hstart + _kernel_size, input_height + padding_height); + wend = std::min(wstart + _kernel_size, input_width + padding_width); + hstart = std::max(hstart, 0); + wstart = std::max(wstart, 0); + hend = std::min(hend, input_height); + wend = std::min(wend, input_width); pos1 = input_data + hstart * input_width + wstart; pos2 = input_data + (hstart + 1) * input_width + wstart; pos3 = input_data + (hstart + 2) * input_width + wstart; @@ -1258,8 +1270,10 @@ void DepthwiseConvAddBNRelu3x3s2p1(const Tensor *input, const Tensor *filter, #endif } -void DepthwiseConv3x3s2p1v2(const Tensor *input, const Tensor *filter, - Tensor *output, Tensor bias, bool if_bias) { +void DepthwiseConv3x3s2p1v2(const framework::Tensor *input, + const framework::Tensor *filter, + framework::Tensor *output, framework::Tensor bias, + bool if_bias) { #if __ARM_NEON const float *input_data = input->data(); const float *filter_data = filter->data(); @@ -1463,9 +1477,12 @@ void DepthwiseConv3x3s2p1v2(const Tensor *input, const Tensor *filter, #endif } -void DepthwiseConvAddBNRelu3x3s2p1v2(const Tensor *input, const Tensor *filter, - Tensor *output, const Tensor *new_scale, - const Tensor *new_bias, bool if_relu) { +void DepthwiseConvAddBNRelu3x3s2p1v2(const framework::Tensor *input, + const framework::Tensor *filter, + framework::Tensor *output, + const framework::Tensor *new_scale, + const framework::Tensor *new_bias, + bool if_relu) { #if __ARM_NEON // #ifdef _OPENMP // const float *newscale_data = new_scale->data(); @@ -1886,8 +1903,10 @@ void DepthwiseConvAddBNRelu3x3s2p1v2(const Tensor *input, const Tensor *filter, #endif } -void DepthwiseConv3x3s2p0(const Tensor *input, const Tensor *filter, - Tensor *output, Tensor bias, bool if_bias) { +void DepthwiseConv3x3s2p0(const framework::Tensor *input, + const framework::Tensor *filter, + framework::Tensor *output, framework::Tensor bias, + bool if_bias) { #if __ARM_NEON const int batch_size = static_cast(input->dims()[0]); diff --git a/src/operators/math/depthwise_conv3x3.h b/src/operators/math/depthwise_conv3x3.h new file mode 100644 index 0000000000000000000000000000000000000000..72cadaf21553a428e1479d5548d2aa5f4fcdf90c --- /dev/null +++ b/src/operators/math/depthwise_conv3x3.h @@ -0,0 +1,87 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +#include +#include +#include "framework/tensor.h" +#include "operators/math/conv_func.h" + +namespace paddle_mobile { +namespace operators { +namespace math { + +void DepthwiseConv3x3(const framework::Tensor *input, + const std::vector &strides, + const std::vector &paddings, + const framework::Tensor *filter, framework::Tensor *bias, + framework::Tensor *output, bool if_bias); + +void DepthwiseConv3x3s1p1(const framework::Tensor *input, + const framework::Tensor *filter, + framework::Tensor *output, framework::Tensor *bias, + bool if_bias); + +void DepthwiseConvAddBNRelu3x3s1p1(const framework::Tensor *input, + const framework::Tensor *filter, + framework::Tensor *output, + const framework::Tensor *new_scale, + const framework::Tensor *new_bias, + bool if_relu); + +void DepthwiseConvAddBNRelu3x3s2p1(const framework::Tensor *input, + const framework::Tensor *filter, + framework::Tensor *output, + const framework::Tensor *new_scale, + const framework::Tensor *new_bias, + bool if_relu); + +void DepthwiseConv3x3s2p1v2(const framework::Tensor *input, + const framework::Tensor *filter, + framework::Tensor *output, framework::Tensor bias, + bool if_bias); + +void DepthwiseConvAddBNRelu3x3s2p1v2(const framework::Tensor *input, + const framework::Tensor *filter, + framework::Tensor *output, + const framework::Tensor *new_scale, + const framework::Tensor *new_bias, + bool if_relu); + +void DepthwiseConv3x3s2p0(const framework::Tensor *input, + const framework::Tensor *filter, + framework::Tensor *output, framework::Tensor bias, + bool if_bias); + +// TODO(hjchen2) need to be implemented +// template +// void DepthwiseConv3x3(const framework::Tensor *input, +// const framework::Tensor *filter, +// const std::vector &strides, +// framework::Tensor *output); + +template +void DepthwiseConv3x3s1(const framework::Tensor &input, + const framework::Tensor &filter, + framework::Tensor *output); + +template +void DepthwiseConv3x3s2(const framework::Tensor &input, + const framework::Tensor &filter, + framework::Tensor *output); + +} // namespace math +} // namespace operators +} // namespace paddle_mobile diff --git a/src/operators/math/depthwise_conv3x3_int8.cpp b/src/operators/math/depthwise_conv3x3_int8.cpp new file mode 100644 index 0000000000000000000000000000000000000000..ddd8f79f7ce350e048585917f96d82639d4ea951 --- /dev/null +++ b/src/operators/math/depthwise_conv3x3_int8.cpp @@ -0,0 +1,1207 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "operators/math/depthwise_conv3x3.h" + +namespace paddle_mobile { +namespace operators { +namespace math { + +// template<> +// void DepthwiseConv3x3( +// const framework::Tensor *input, const framework::Tensor *filter, +// const std::vector &strides, framework::Tensor *output) { +// PADDLE_MOBILE_THROW_EXCEPTION( +// "Depthwise conv with generic strides has not been implemented."); +// } + +template <> +void DepthwiseConv3x3s1(const framework::Tensor &input, + const framework::Tensor &filter, + framework::Tensor *output) { + const int8_t *input_data = input.data(); + const int8_t *filter_data = filter.data(); + int32_t *out_data = output->mutable_data(); + // make sure that batch size is 1 + int input_c = input.dims()[1]; + int input_h = input.dims()[2]; + int input_w = input.dims()[3]; + int output_c = output->dims()[1]; + int output_h = output->dims()[2]; + int output_w = output->dims()[3]; + int image_size = input_h * input_w; + int out_image_size = output_h * output_w; +#if __aarch64__ + // TODO(hjchen2) +#else + #pragma omp parallel for + for (int g = 0; g < input_c; ++g) { + const int8_t* input_ptr = input_data + g * image_size; + const int8_t* filter_ptr = filter_data + g * 9; + int32_t* output_ptr = out_data + g * out_image_size; + int loops = (input_w - 2) / 6; + int remain = input_w - 2 - loops * 6; + for (int h = 0; h < input_h - 5 /*(input_h - 2) - 3*/; h += 4) { + const int8_t* input_ptr0 = input_ptr + h * input_w; + const int8_t* input_ptr1 = input_ptr0 + input_w; + const int8_t* input_ptr2 = input_ptr1 + input_w; + const int8_t* input_ptr3 = input_ptr2 + input_w; + const int8_t* input_ptr4 = input_ptr3 + input_w; + const int8_t* input_ptr5 = input_ptr4 + input_w; + int32_t* output_ptr0 = output_ptr + h * output_w; + int32_t* output_ptr1 = output_ptr0 + output_w; + int32_t* output_ptr2 = output_ptr1 + output_w; + int32_t* output_ptr3 = output_ptr2 + output_w; + int loop = loops; + asm volatile( + "vld1.32 {q0}, [%[filter_ptr]] \n" + "vmovl.s8 q14, d0 \n" + "vmovl.s8 q15, d1 \n" + "vdup.s16 d0, d28[0] \n" + "vdup.s16 d1, d28[1] \n" + "vdup.s16 d2, d28[2] \n" + "vdup.s16 d3, d28[3] \n" + "vdup.s16 d4, d29[0] \n" + "vdup.s16 d5, d29[1] \n" + "vdup.s16 d6, d29[2] \n" + "vdup.s16 d7, d29[3] \n" + "vdup.s16 d8, d30[0] \n" + : + : [filter_ptr] "r"(filter_ptr) + : "cc", "memory", "q0", "q1", "q2", "q3", "q4", "q14", "q15"); + asm volatile( + "mov r0, #6 \n" + "cmp %[loop], #0 \n" + "ble start_remain_%= \n" + // loop 6 widths + "loop_4h6w_%=: \n" + "vld1.32 {d9}, [%[input_ptr0]], r0 \n" + "vld1.32 {d10}, [%[input_ptr1]], r0 \n" + "vld1.32 {d11}, [%[input_ptr2]], r0 \n" + "vext.s8 d12, d9, d9, #1 \n" + "vext.s8 d13, d9, d9, #2 \n" + "vmovl.s8 q7, d9 \n" + "vmovl.s8 q8, d12 \n" + "vmovl.s8 q9, d13 \n" + "vmull.s16 q10, d14, d0 \n" + "vmlal.s16 q10, d16, d1 \n" + "vmlal.s16 q10, d18, d2 \n" + "vmull.s16 q11, d15, d0 \n" + "vmlal.s16 q11, d17, d1 \n" + "vmlal.s16 q11, d19, d2 \n" + + "vext.s8 d12, d10, d10, #1 \n" + "vext.s8 d13, d10, d10, #2 \n" + "vmovl.s8 q7, d10 \n" + "vmovl.s8 q8, d12 \n" + "vmovl.s8 q9, d13 \n" + "vmlal.s16 q10, d14, d3 \n" + "vmlal.s16 q10, d16, d4 \n" + "vmlal.s16 q10, d18, d5 \n" + "vmlal.s16 q11, d15, d3 \n" + "vmlal.s16 q11, d17, d4 \n" + "vmlal.s16 q11, d19, d5 \n" + + "vmull.s16 q12, d14, d0 \n" + "vmlal.s16 q12, d16, d1 \n" + "vmlal.s16 q12, d18, d2 \n" + "vmull.s16 q13, d15, d0 \n" + "vmlal.s16 q13, d17, d1 \n" + "vmlal.s16 q13, d19, d2 \n" + + "vext.s8 d12, d11, d11, #1 \n" + "vext.s8 d13, d11, d11, #2 \n" + "vmovl.s8 q7, d11 \n" + "vmovl.s8 q8, d12 \n" + "vmovl.s8 q9, d13 \n" + "vmlal.s16 q10, d14, d6 \n" + "vmlal.s16 q10, d16, d7 \n" + "vmlal.s16 q10, d18, d8 \n" + "vmlal.s16 q11, d15, d6 \n" + "vmlal.s16 q11, d17, d7 \n" + "vmlal.s16 q11, d19, d8 \n" + // store row 0, reuse q10/q11 + "vst1.32 {d20-d22}, [%[output_ptr0]]! \n" + + "vmlal.s16 q12, d14, d3 \n" + "vmlal.s16 q12, d16, d4 \n" + "vmlal.s16 q12, d18, d5 \n" + "vmlal.s16 q13, d15, d3 \n" + "vmlal.s16 q13, d17, d4 \n" + "vmlal.s16 q13, d19, d5 \n" + + "vmull.s16 q14, d14, d0 \n" + "vmlal.s16 q14, d16, d1 \n" + "vmlal.s16 q14, d18, d2 \n" + "vmull.s16 q15, d15, d0 \n" + "vmlal.s16 q15, d17, d1 \n" + "vmlal.s16 q15, d19, d2 \n" + + "vld1.32 {d9}, [%[input_ptr3]], r0 \n" + "vld1.32 {d10}, [%[input_ptr4]], r0 \n" + "vld1.32 {d11}, [%[input_ptr5]], r0 \n" + "vext.s8 d12, d9, d9, #1 \n" + "vext.s8 d13, d9, d9, #2 \n" + "vmovl.s8 q7, d9 \n" + "vmovl.s8 q8, d12 \n" + "vmovl.s8 q9, d13 \n" + "vmlal.s16 q12, d14, d6 \n" + "vmlal.s16 q12, d16, d7 \n" + "vmlal.s16 q12, d18, d8 \n" + "vmlal.s16 q13, d15, d6 \n" + "vmlal.s16 q13, d17, d7 \n" + "vmlal.s16 q13, d19, d8 \n" + // store row 1 + "vst1.32 {d24-d26}, [%[output_ptr1]]! \n" + + "vmlal.s16 q14, d14, d3 \n" + "vmlal.s16 q14, d16, d4 \n" + "vmlal.s16 q14, d18, d5 \n" + "vmlal.s16 q15, d15, d3 \n" + "vmlal.s16 q15, d17, d4 \n" + "vmlal.s16 q15, d19, d5 \n" + + "vmull.s16 q10, d14, d0 \n" + "vmlal.s16 q10, d16, d1 \n" + "vmlal.s16 q10, d18, d2 \n" + "vmull.s16 q11, d15, d0 \n" + "vmlal.s16 q11, d17, d1 \n" + "vmlal.s16 q11, d19, d2 \n" + + "vext.s8 d12, d10, d10, #1 \n" + "vext.s8 d13, d10, d10, #2 \n" + "vmovl.s8 q7, d10 \n" + "vmovl.s8 q8, d12 \n" + "vmovl.s8 q9, d13 \n" + "vmlal.s16 q14, d14, d6 \n" + "vmlal.s16 q14, d16, d7 \n" + "vmlal.s16 q14, d18, d8 \n" + "vmlal.s16 q15, d15, d6 \n" + "vmlal.s16 q15, d17, d7 \n" + "vmlal.s16 q15, d19, d8 \n" + // store row 2 + "vst1.32 {d28-d30}, [%[output_ptr2]]! \n" + + "vmlal.s16 q10, d14, d3 \n" + "vmlal.s16 q10, d16, d4 \n" + "vmlal.s16 q10, d18, d5 \n" + "vmlal.s16 q11, d15, d3 \n" + "vmlal.s16 q11, d17, d4 \n" + "vmlal.s16 q11, d19, d5 \n" + + "vext.s8 d12, d11, d11, #1 \n" + "vext.s8 d13, d11, d11, #2 \n" + "vmovl.s8 q7, d11 \n" + "vmovl.s8 q8, d12 \n" + "vmovl.s8 q9, d13 \n" + "vmlal.s16 q10, d14, d6 \n" + "vmlal.s16 q10, d16, d7 \n" + "vmlal.s16 q10, d18, d8 \n" + "vmlal.s16 q11, d15, d6 \n" + "vmlal.s16 q11, d17, d7 \n" + "vmlal.s16 q11, d19, d8 \n" + // store row 3 + "vst1.32 {d20-d22}, [%[output_ptr3]]! \n" + + "subs %[loop], #1 \n" + "bne loop_4h6w_%= \n" + + "start_remain_%=: \n" + "cmp %[remain], #0 \n" + "ble end_%= \n" + + "vld1.32 {d9}, [%[input_ptr0]] \n" + "vmovl.s8 q7, d9 \n" + "vext.s8 d9, d9, d9, #1 \n" + "vmovl.s8 q8, d9 \n" + "vext.s8 d9, d9, d9, #1 \n" + "vmovl.s8 q9, d9 \n" + "vmull.s16 q10, d14, d0 \n" + "vmlal.s16 q10, d16, d1 \n" + "vmlal.s16 q10, d18, d2 \n" + "vld1.32 {d9}, [%[input_ptr1]] \n" + "vmull.s16 q11, d15, d0 \n" + "vmlal.s16 q11, d17, d1 \n" + "vmlal.s16 q11, d19, d2 \n" + + "vmovl.s8 q7, d9 \n" + "vext.s8 d9, d9, d9, #1 \n" + "vmovl.s8 q8, d9 \n" + "vext.s8 d9, d9, d9, #1 \n" + "vmovl.s8 q9, d9 \n" + "vmlal.s16 q10, d14, d3 \n" + "vmlal.s16 q10, d16, d4 \n" + "vmlal.s16 q10, d18, d5 \n" + "vmlal.s16 q11, d15, d3 \n" + "vmlal.s16 q11, d17, d4 \n" + "vmlal.s16 q11, d19, d5 \n" + + "vmull.s16 q12, d14, d0 \n" + "vmlal.s16 q12, d16, d1 \n" + "vmlal.s16 q12, d18, d2 \n" + "vld1.32 {d9}, [%[input_ptr2]] \n" + "vmull.s16 q13, d15, d0 \n" + "vmlal.s16 q13, d17, d1 \n" + "vmlal.s16 q13, d19, d2 \n" + + "vmovl.s8 q7, d9 \n" + "vext.s8 d9, d9, d9, #1 \n" + "vmovl.s8 q8, d9 \n" + "vext.s8 d9, d9, d9, #1 \n" + "vmovl.s8 q9, d9 \n" + "vmlal.s16 q10, d14, d6 \n" + "vmlal.s16 q10, d16, d7 \n" + "vmlal.s16 q10, d18, d8 \n" + "vmlal.s16 q11, d15, d6 \n" + "vmlal.s16 q11, d17, d7 \n" + "vmlal.s16 q11, d19, d8 \n" + + "vmlal.s16 q12, d14, d3 \n" + "vmlal.s16 q12, d16, d4 \n" + "vmlal.s16 q12, d18, d5 \n" + "vmlal.s16 q13, d15, d3 \n" + "vmlal.s16 q13, d17, d4 \n" + "vmlal.s16 q13, d19, d5 \n" + + "vmull.s16 q14, d14, d0 \n" + "vmlal.s16 q14, d16, d1 \n" + "vmlal.s16 q14, d18, d2 \n" + "vld1.32 {d9}, [%[input_ptr3]] \n" + "vmull.s16 q15, d15, d0 \n" + "vmlal.s16 q15, d17, d1 \n" + "vmlal.s16 q15, d19, d2 \n" + + "vmovl.s8 q7, d9 \n" + "vext.s8 d9, d9, d9, #1 \n" + "vmovl.s8 q8, d9 \n" + "vext.s8 d9, d9, d9, #1 \n" + "vmovl.s8 q9, d9 \n" + "vmlal.s16 q12, d14, d6 \n" + "vmlal.s16 q12, d16, d7 \n" + "vmlal.s16 q12, d18, d8 \n" + "vmlal.s16 q13, d15, d6 \n" + "vmlal.s16 q13, d17, d7 \n" + "vmlal.s16 q13, d19, d8 \n" + + "vmlal.s16 q14, d14, d3 \n" + "vmlal.s16 q14, d16, d4 \n" + "vmlal.s16 q14, d18, d5 \n" + "vmlal.s16 q15, d15, d3 \n" + "vmlal.s16 q15, d17, d4 \n" + "vmlal.s16 q15, d19, d5 \n" + + "vmull.s16 q5, d14, d0 \n" + "vmlal.s16 q5, d16, d1 \n" + "vmlal.s16 q5, d18, d2 \n" + "vld1.32 {d9}, [%[input_ptr4]] \n" + "vmull.s16 q6, d15, d0 \n" + "vmlal.s16 q6, d17, d1 \n" + "vmlal.s16 q6, d19, d2 \n" + + "vmovl.s8 q7, d9 \n" + "vext.s8 d9, d9, d9, #1 \n" + "vmovl.s8 q8, d9 \n" + "vext.s8 d9, d9, d9, #1 \n" + "vmovl.s8 q9, d9 \n" + "vmlal.s16 q14, d14, d6 \n" + "vmlal.s16 q14, d16, d7 \n" + "vmlal.s16 q14, d18, d8 \n" + "vmlal.s16 q15, d15, d6 \n" + "vmlal.s16 q15, d17, d7 \n" + "vmlal.s16 q15, d19, d8 \n" + + "vmlal.s16 q5, d14, d3 \n" + "vmlal.s16 q5, d16, d4 \n" + "vmlal.s16 q5, d18, d5 \n" + "vld1.32 {d9}, [%[input_ptr5]] \n" + "vmlal.s16 q6, d15, d3 \n" + "vmlal.s16 q6, d17, d4 \n" + "vmlal.s16 q6, d19, d5 \n" + + "vmovl.s8 q7, d9 \n" + "vext.s8 d9, d9, d9, #1 \n" + "vmovl.s8 q8, d9 \n" + "vext.s8 d9, d9, d9, #1 \n" + "vmovl.s8 q9, d9 \n" + "vmlal.s16 q5, d14, d6 \n" + "vmlal.s16 q5, d16, d7 \n" + "vmlal.s16 q5, d18, d8 \n" + "vmlal.s16 q6, d15, d6 \n" + "vmlal.s16 q6, d17, d7 \n" + "vmlal.s16 q6, d19, d8 \n" + + "cmp %[remain], #4 \n" + "blt store_4h2w_%= \n" + "vst1.32 {q10}, [%[output_ptr0]]! \n" + "vst1.32 {q12}, [%[output_ptr1]]! \n" + "vst1.32 {q14}, [%[output_ptr2]]! \n" + "vst1.32 {q5}, [%[output_ptr3]]! \n" + "cmp %[remain], #5 \n" + "blt end_%= \n" + "vst1.32 {d22[0]}, [%[output_ptr0]]! \n" + "vst1.32 {d26[0]}, [%[output_ptr1]]! \n" + "vst1.32 {d30[0]}, [%[output_ptr2]]! \n" + "vst1.32 {d12[0]}, [%[output_ptr3]]! \n" + "b end_%= \n" + + "store_4h2w_%=: \n" + "cmp %[remain], #2 \n" + "blt store_4h1w_%= \n" + "vst1.32 {d20}, [%[output_ptr0]]! \n" + "vst1.32 {d24}, [%[output_ptr1]]! \n" + "vst1.32 {d28}, [%[output_ptr2]]! \n" + "vst1.32 {d10}, [%[output_ptr3]]! \n" + "cmp %[remain], #3 \n" + "blt end_%= \n" + "vst1.32 {d21[0]}, [%[output_ptr0]]! \n" + "vst1.32 {d25[0]}, [%[output_ptr1]]! \n" + "vst1.32 {d29[0]}, [%[output_ptr2]]! \n" + "vst1.32 {d11[0]}, [%[output_ptr3]]! \n" + "b end_%= \n" + + "store_4h1w_%=: \n" + "cmp %[remain], #1 \n" + "blt end_%= \n" + "vst1.32 {d20[0]}, [%[output_ptr0]]! \n" + "vst1.32 {d24[0]}, [%[output_ptr1]]! \n" + "vst1.32 {d28[0]}, [%[output_ptr2]]! \n" + "vst1.32 {d10[0]}, [%[output_ptr3]]! \n" + "end_%=: \n" + : [output_ptr0] "+r"(output_ptr0), [output_ptr1] "+r"(output_ptr1), + [output_ptr2] "+r"(output_ptr2), [output_ptr3] "+r"(output_ptr3), + [input_ptr0] "+r"(input_ptr0), [input_ptr1] "+r"(input_ptr1), + [input_ptr2] "+r"(input_ptr2), [input_ptr3] "+r"(input_ptr3), + [input_ptr4] "+r"(input_ptr4), [input_ptr5] "+r"(input_ptr5), + [loop] "+r"(loop) + : [remain] "r"(remain) + : "cc", "memory", "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7", + "q8", "q9", "q10", "q11", "q12", "q13", "q14", "q15", "r0"); + } + // remain height + int start_h = (input_h - 2) & 0xFFFC; + for (int h = start_h; h < input_h - 3 /*(input_h - 2) - 1*/; h += 2) { + const int8_t* input_ptr0 = input_ptr + h * input_w; + const int8_t* input_ptr1 = input_ptr0 + input_w; + const int8_t* input_ptr2 = input_ptr1 + input_w; + const int8_t* input_ptr3 = input_ptr2 + input_w; + int32_t* output_ptr0 = output_ptr + h * output_w; + int32_t* output_ptr1 = output_ptr0 + output_w; + int loop = loops; + asm volatile( + "vld1.32 {q0}, [%[filter_ptr]] \n" + "vmovl.s8 q14, d0 \n" + "vmovl.s8 q15, d1 \n" + "vdup.s16 d0, d28[0] \n" + "vdup.s16 d1, d28[1] \n" + "vdup.s16 d2, d28[2] \n" + "vdup.s16 d3, d28[3] \n" + "vdup.s16 d4, d29[0] \n" + "vdup.s16 d5, d29[1] \n" + "vdup.s16 d6, d29[2] \n" + "vdup.s16 d7, d29[3] \n" + "vdup.s16 d8, d30[0] \n" + : + : [filter_ptr] "r"(filter_ptr) + : "cc", "memory", "q0", "q1", "q2", "q3", "q4", "q14", "q15"); + asm volatile( + "mov r0, #6 \n" + "cmp %[loop], #0 \n" + "ble start_remain_%= \n" + // loop 6 widths + "loop_2h6w_%=: \n" + "vld1.32 {d9}, [%[input_ptr0]], r0 \n" + "vld1.32 {d10}, [%[input_ptr1]], r0 \n" + "vld1.32 {d11}, [%[input_ptr2]], r0 \n" + "vext.s8 d12, d9, d9, #1 \n" + "vext.s8 d13, d9, d9, #2 \n" + "vmovl.s8 q7, d9 \n" + "vmovl.s8 q8, d12 \n" + "vmovl.s8 q9, d13 \n" + "vmull.s16 q10, d14, d0 \n" + "vmlal.s16 q10, d16, d1 \n" + "vmlal.s16 q10, d18, d2 \n" + "vmull.s16 q11, d15, d0 \n" + "vmlal.s16 q11, d17, d1 \n" + "vmlal.s16 q11, d19, d2 \n" + + "vext.s8 d12, d10, d10, #1 \n" + "vext.s8 d13, d10, d10, #2 \n" + "vmovl.s8 q7, d10 \n" + "vmovl.s8 q8, d12 \n" + "vmovl.s8 q9, d13 \n" + "vmlal.s16 q10, d14, d3 \n" + "vmlal.s16 q10, d16, d4 \n" + "vmlal.s16 q10, d18, d5 \n" + "vmlal.s16 q11, d15, d3 \n" + "vmlal.s16 q11, d17, d4 \n" + "vmlal.s16 q11, d19, d5 \n" + + "vmull.s16 q12, d14, d0 \n" + "vmlal.s16 q12, d16, d1 \n" + "vmlal.s16 q12, d18, d2 \n" + "vmull.s16 q13, d15, d0 \n" + "vmlal.s16 q13, d17, d1 \n" + "vmlal.s16 q13, d19, d2 \n" + + "vext.s8 d12, d11, d11, #1 \n" + "vext.s8 d13, d11, d11, #2 \n" + "vmovl.s8 q7, d11 \n" + "vmovl.s8 q8, d12 \n" + "vmovl.s8 q9, d13 \n" + "vmlal.s16 q10, d14, d6 \n" + "vmlal.s16 q10, d16, d7 \n" + "vmlal.s16 q10, d18, d8 \n" + "vmlal.s16 q11, d15, d6 \n" + "vmlal.s16 q11, d17, d7 \n" + "vmlal.s16 q11, d19, d8 \n" + // store row 0, reuse q10/q11 + "vst1.32 {d20-d22}, [%[output_ptr0]]! \n" + + "vmlal.s16 q12, d14, d3 \n" + "vmlal.s16 q12, d16, d4 \n" + "vmlal.s16 q12, d18, d5 \n" + "vmlal.s16 q13, d15, d3 \n" + "vmlal.s16 q13, d17, d4 \n" + "vmlal.s16 q13, d19, d5 \n" + + "vld1.32 {d9}, [%[input_ptr3]], r0 \n" + "vext.s8 d12, d9, d9, #1 \n" + "vext.s8 d13, d9, d9, #2 \n" + "vmovl.s8 q7, d9 \n" + "vmovl.s8 q8, d12 \n" + "vmovl.s8 q9, d13 \n" + "vmlal.s16 q12, d14, d6 \n" + "vmlal.s16 q12, d16, d7 \n" + "vmlal.s16 q12, d18, d8 \n" + "vmlal.s16 q13, d15, d6 \n" + "vmlal.s16 q13, d17, d7 \n" + "vmlal.s16 q13, d19, d8 \n" + // store row 1 + "vst1.32 {d24-d26}, [%[output_ptr1]]! \n" + + "subs %[loop], #1 \n" + "bne loop_2h6w_%= \n" + + "start_remain_%=: \n" + "cmp %[remain], #0 \n" + "ble end_%= \n" + + "vld1.32 {d9}, [%[input_ptr0]] \n" + "vld1.32 {d10}, [%[input_ptr1]] \n" + "vld1.32 {d11}, [%[input_ptr2]] \n" + "vext.s8 d12, d9, d9, #1 \n" + "vext.s8 d13, d9, d9, #2 \n" + "vmovl.s8 q7, d9 \n" + "vmovl.s8 q8, d12 \n" + "vmovl.s8 q9, d13 \n" + "vmull.s16 q10, d14, d0 \n" + "vmlal.s16 q10, d16, d1 \n" + "vmlal.s16 q10, d18, d2 \n" + "vmull.s16 q11, d15, d0 \n" + "vmlal.s16 q11, d17, d1 \n" + "vmlal.s16 q11, d19, d2 \n" + + "vext.s8 d12, d10, d10, #1 \n" + "vext.s8 d13, d10, d10, #2 \n" + "vmovl.s8 q7, d10 \n" + "vmovl.s8 q8, d12 \n" + "vmovl.s8 q9, d13 \n" + "vmlal.s16 q10, d14, d3 \n" + "vmlal.s16 q10, d16, d4 \n" + "vmlal.s16 q10, d18, d5 \n" + "vmlal.s16 q11, d15, d3 \n" + "vmlal.s16 q11, d17, d4 \n" + "vmlal.s16 q11, d19, d5 \n" + + "vmull.s16 q12, d14, d0 \n" + "vmlal.s16 q12, d16, d1 \n" + "vmlal.s16 q12, d18, d2 \n" + "vmull.s16 q13, d15, d0 \n" + "vmlal.s16 q13, d17, d1 \n" + "vmlal.s16 q13, d19, d2 \n" + + "vext.s8 d12, d11, d11, #1 \n" + "vext.s8 d13, d11, d11, #2 \n" + "vmovl.s8 q7, d11 \n" + "vmovl.s8 q8, d12 \n" + "vmovl.s8 q9, d13 \n" + "vmlal.s16 q10, d14, d6 \n" + "vmlal.s16 q10, d16, d7 \n" + "vmlal.s16 q10, d18, d8 \n" + "vmlal.s16 q11, d15, d6 \n" + "vmlal.s16 q11, d17, d7 \n" + "vmlal.s16 q11, d19, d8 \n" + + "vmlal.s16 q12, d14, d3 \n" + "vmlal.s16 q12, d16, d4 \n" + "vmlal.s16 q12, d18, d5 \n" + "vmlal.s16 q13, d15, d3 \n" + "vmlal.s16 q13, d17, d4 \n" + "vmlal.s16 q13, d19, d5 \n" + + "vld1.32 {d9}, [%[input_ptr3]] \n" + "vext.s8 d12, d9, d9, #1 \n" + "vext.s8 d13, d9, d9, #2 \n" + "vmovl.s8 q7, d9 \n" + "vmovl.s8 q8, d12 \n" + "vmovl.s8 q9, d13 \n" + "vmlal.s16 q12, d14, d6 \n" + "vmlal.s16 q12, d16, d7 \n" + "vmlal.s16 q12, d18, d8 \n" + "vmlal.s16 q13, d15, d6 \n" + "vmlal.s16 q13, d17, d7 \n" + "vmlal.s16 q13, d19, d8 \n" + + "cmp %[remain], #4 \n" + "blt store_2h2w_%= \n" + "vst1.32 {q10}, [%[output_ptr0]]! \n" + "vst1.32 {q12}, [%[output_ptr1]]! \n" + "cmp %[remain], #5 \n" + "blt end_%= \n" + "vst1.32 {d22[0]}, [%[output_ptr0]]! \n" + "vst1.32 {d26[0]}, [%[output_ptr1]]! \n" + "b end_%= \n" + + "store_2h2w_%=: \n" + "cmp %[remain], #2 \n" + "blt store_2h1w_%= \n" + "vst1.32 {d20}, [%[output_ptr0]]! \n" + "vst1.32 {d24}, [%[output_ptr1]]! \n" + "cmp %[remain], #3 \n" + "blt end_%= \n" + "vst1.32 {d21[0]}, [%[output_ptr0]]! \n" + "vst1.32 {d25[0]}, [%[output_ptr1]]! \n" + "b end_%= \n" + + "store_2h1w_%=: \n" + "cmp %[remain], #1 \n" + "blt end_%= \n" + "vst1.32 {d20[0]}, [%[output_ptr0]]! \n" + "vst1.32 {d24[0]}, [%[output_ptr1]]! \n" + "end_%=: \n" + : [output_ptr0] "+r"(output_ptr0), [output_ptr1] "+r"(output_ptr1), + [input_ptr0] "+r"(input_ptr0), [input_ptr1] "+r"(input_ptr1), + [input_ptr2] "+r"(input_ptr2), [input_ptr3] "+r"(input_ptr3), + [loop] "+r"(loop) + : [remain] "r"(remain) + : "cc", "memory", "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7", + "q8", "q9", "q10", "q11", "q12", "q13", "r0"); + } + + start_h = (input_h - 2) & 0xFFFE; + if (start_h < input_h - 2) { + const int8_t* input_ptr0 = input_ptr + start_h * input_w; + const int8_t* input_ptr1 = input_ptr0 + input_w; + const int8_t* input_ptr2 = input_ptr1 + input_w; + int32_t* output_ptr0 = output_ptr + start_h * output_w; + int loop = loops; + asm volatile( + "vld1.32 {q0}, [%[filter_ptr]] \n" + "vmovl.s8 q14, d0 \n" + "vmovl.s8 q15, d1 \n" + "vdup.s16 d0, d28[0] \n" + "vdup.s16 d1, d28[1] \n" + "vdup.s16 d2, d28[2] \n" + "vdup.s16 d3, d28[3] \n" + "vdup.s16 d4, d29[0] \n" + "vdup.s16 d5, d29[1] \n" + "vdup.s16 d6, d29[2] \n" + "vdup.s16 d7, d29[3] \n" + "vdup.s16 d8, d30[0] \n" + : + : [filter_ptr] "r"(filter_ptr) + : "memory", "q0", "q1", "q2", "q3", "q4", "q14", "q15"); + asm volatile( + "mov r0, #6 \n" + "cmp %[loop], #0 \n" + "ble start_remain_%= \n" + // loop 6 widths + "loop_1h6w_%=: \n" + "vld1.32 {d9}, [%[input_ptr0]], r0 \n" + "vld1.32 {d10}, [%[input_ptr1]], r0 \n" + "vld1.32 {d11}, [%[input_ptr2]], r0 \n" + "vext.s8 d12, d9, d9, #1 \n" + "vext.s8 d13, d9, d9, #2 \n" + "vmovl.s8 q7, d9 \n" + "vmovl.s8 q8, d12 \n" + "vmovl.s8 q9, d13 \n" + "vmull.s16 q10, d14, d0 \n" + "vmlal.s16 q10, d16, d1 \n" + "vmlal.s16 q10, d18, d2 \n" + "vmull.s16 q11, d15, d0 \n" + "vmlal.s16 q11, d17, d1 \n" + "vmlal.s16 q11, d19, d2 \n" + + "vext.s8 d12, d10, d10, #1 \n" + "vext.s8 d13, d10, d10, #2 \n" + "vmovl.s8 q7, d10 \n" + "vmovl.s8 q8, d12 \n" + "vmovl.s8 q9, d13 \n" + "vmlal.s16 q10, d14, d3 \n" + "vmlal.s16 q10, d16, d4 \n" + "vmlal.s16 q10, d18, d5 \n" + "vmlal.s16 q11, d15, d3 \n" + "vmlal.s16 q11, d17, d4 \n" + "vmlal.s16 q11, d19, d5 \n" + + "vext.s8 d12, d11, d11, #1 \n" + "vext.s8 d13, d11, d11, #2 \n" + "vmovl.s8 q7, d11 \n" + "vmovl.s8 q8, d12 \n" + "vmovl.s8 q9, d13 \n" + "vmlal.s16 q10, d14, d6 \n" + "vmlal.s16 q10, d16, d7 \n" + "vmlal.s16 q10, d18, d8 \n" + "vmlal.s16 q11, d15, d6 \n" + "vmlal.s16 q11, d17, d7 \n" + "vmlal.s16 q11, d19, d8 \n" + // store row 0, reuse q10/q11 + "vst1.32 {d20-d22}, [%[output_ptr0]]! \n" + + "subs %[loop], #1 \n" + "bne loop_1h6w_%= \n" + + "start_remain_%=: \n" + "cmp %[remain], #0 \n" + "ble end_%= \n" + + "vld1.32 {d9}, [%[input_ptr0]] \n" + "vld1.32 {d10}, [%[input_ptr1]] \n" + "vld1.32 {d11}, [%[input_ptr2]] \n" + "vext.s8 d12, d9, d9, #1 \n" + "vext.s8 d13, d9, d9, #2 \n" + "vmovl.s8 q7, d9 \n" + "vmovl.s8 q8, d12 \n" + "vmovl.s8 q9, d13 \n" + "vmull.s16 q10, d14, d0 \n" + "vmlal.s16 q10, d16, d1 \n" + "vmlal.s16 q10, d18, d2 \n" + "vmull.s16 q11, d15, d0 \n" + "vmlal.s16 q11, d17, d1 \n" + "vmlal.s16 q11, d19, d2 \n" + + "vext.s8 d12, d10, d10, #1 \n" + "vext.s8 d13, d10, d10, #2 \n" + "vmovl.s8 q7, d10 \n" + "vmovl.s8 q8, d12 \n" + "vmovl.s8 q9, d13 \n" + "vmlal.s16 q10, d14, d3 \n" + "vmlal.s16 q10, d16, d4 \n" + "vmlal.s16 q10, d18, d5 \n" + "vmlal.s16 q11, d15, d3 \n" + "vmlal.s16 q11, d17, d4 \n" + "vmlal.s16 q11, d19, d5 \n" + + "vext.s8 d12, d11, d11, #1 \n" + "vext.s8 d13, d11, d11, #2 \n" + "vmovl.s8 q7, d11 \n" + "vmovl.s8 q8, d12 \n" + "vmovl.s8 q9, d13 \n" + "vmlal.s16 q10, d14, d6 \n" + "vmlal.s16 q10, d16, d7 \n" + "vmlal.s16 q10, d18, d8 \n" + "vmlal.s16 q11, d15, d6 \n" + "vmlal.s16 q11, d17, d7 \n" + "vmlal.s16 q11, d19, d8 \n" + + "cmp %[remain], #4 \n" + "blt store_1h2w_%= \n" + "vst1.32 {q10}, [%[output_ptr0]]! \n" + "cmp %[remain], #5 \n" + "blt end_%= \n" + "vst1.32 {d22[0]}, [%[output_ptr0]]! \n" + "b end_%= \n" + + "store_1h2w_%=: \n" + "cmp %[remain], #2 \n" + "blt store_1h1w_%= \n" + "vst1.32 {d20}, [%[output_ptr0]]! \n" + "cmp %[remain], #3 \n" + "blt end_%= \n" + "vst1.32 {d21[0]}, [%[output_ptr0]]! \n" + "b end_%= \n" + + "store_1h1w_%=: \n" + "cmp %[remain], #1 \n" + "blt end_%= \n" + "vst1.32 {d20[0]}, [%[output_ptr0]]! \n" + "end_%=: \n" + : [output_ptr0] "+r"(output_ptr0), [input_ptr0] "+r"(input_ptr0), + [input_ptr1] "+r"(input_ptr1), [input_ptr2] "+r"(input_ptr2), + [loop] "+r"(loop) + : [remain] "r"(remain) + : "cc", "memory", "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7", + "q8", "q9", "q10", "q11", "r0"); + } + } +#endif // __aarch64__ +} + +template <> +void DepthwiseConv3x3s2(const framework::Tensor &input, + const framework::Tensor &filter, + framework::Tensor *output) { + const int8_t *input_data = input.data(); + const int8_t *filter_data = filter.data(); + int32_t *out_data = output->mutable_data(); + // make sure that batch size is 1 + int input_c = input.dims()[1]; + int input_h = input.dims()[2]; + int input_w = input.dims()[3]; + int output_c = output->dims()[1]; + int output_h = output->dims()[2]; + int output_w = output->dims()[3]; + int image_size = input_h * input_w; + int out_image_size = output_h * output_w; +#if __aarch64__ + // TODO(hjchen2) +#else + #pragma omp parallel for + for (int g = 0; g < input_c; ++g) { + const int8_t* input_ptr = input_data + g * image_size; + const int8_t* filter_ptr = filter_data + g * 9; + int32_t* output_ptr = out_data + g * out_image_size; + int loops = output_w / 6; + int remain = output_w - loops * 6; + for (int h = 0; h < input_h - 6 /*(input_h - 1) - 5*/; h += 6) { + const int8_t* input_ptr0 = input_ptr + h * input_w; + const int8_t* input_ptr1 = input_ptr0 + input_w; + const int8_t* input_ptr2 = input_ptr1 + input_w; + const int8_t* input_ptr3 = input_ptr2 + input_w; + const int8_t* input_ptr4 = input_ptr3 + input_w; + const int8_t* input_ptr5 = input_ptr4 + input_w; + const int8_t* input_ptr6 = input_ptr5 + input_w; + int32_t* output_ptr0 = output_ptr + (h >> 1) * output_w; + int32_t* output_ptr1 = output_ptr0 + output_w; + int32_t* output_ptr2 = output_ptr1 + output_w; + int loop = loops; + asm volatile( + "vld1.32 {q0}, [%[filter_ptr]] \n" + "vmovl.s8 q14, d0 \n" + "vmovl.s8 q15, d1 \n" + "vdup.s16 d0, d28[0] \n" + "vdup.s16 d1, d28[1] \n" + "vdup.s16 d2, d28[2] \n" + "vdup.s16 d3, d28[3] \n" + "vdup.s16 d4, d29[0] \n" + "vdup.s16 d5, d29[1] \n" + "vdup.s16 d6, d29[2] \n" + "vdup.s16 d7, d29[3] \n" + "vdup.s16 d8, d30[0] \n" + : + : [filter_ptr] "r"(filter_ptr) + : "memory", "q0", "q1", "q2", "q3", "q4", "q14", "q15"); + asm volatile( + "mov r0, #12 \n" + "cmp %[loop], #0 \n" + "ble start_remain_%= \n" + // loop 6 widths + "loop_3h6w_%=: \n" + "vld2.8 {d10, d11}, [%[input_ptr0]], r0 \n" + "vld2.8 {d12, d13}, [%[input_ptr1]], r0 \n" + "vld2.8 {d14, d15}, [%[input_ptr2]], r0 \n" + "vext.s8 d9, d10, d10, #1 \n" + "vmovl.s8 q10, d9 \n" + "vmovl.s8 q8, d10 \n" + "vmovl.s8 q9, d11 \n" + "vmull.s16 q11, d16, d0 \n" + "vmlal.s16 q11, d18, d1 \n" + "vmlal.s16 q11, d20, d2 \n" + "vmull.s16 q12, d17, d0 \n" + "vmlal.s16 q12, d19, d1 \n" + "vmlal.s16 q12, d21, d2 \n" + + "vext.s8 d9, d12, d12, #1 \n" + "vmovl.s8 q10, d9 \n" + "vmovl.s8 q8, d12 \n" + "vmovl.s8 q9, d13 \n" + "vmlal.s16 q11, d16, d3 \n" + "vmlal.s16 q11, d18, d4 \n" + "vmlal.s16 q11, d20, d5 \n" + "vmlal.s16 q12, d17, d3 \n" + "vmlal.s16 q12, d19, d4 \n" + "vmlal.s16 q12, d21, d5 \n" + + "vext.s8 d9, d14, d14, #1 \n" + "vmovl.s8 q10, d9 \n" + "vmovl.s8 q8, d14 \n" + "vmovl.s8 q9, d15 \n" + "vmlal.s16 q11, d16, d6 \n" + "vmlal.s16 q11, d18, d7 \n" + "vmlal.s16 q11, d20, d8 \n" + "vmlal.s16 q12, d17, d6 \n" + "vmlal.s16 q12, d19, d7 \n" + "vmlal.s16 q12, d21, d8 \n" + // store row 0, reuse q11/q12 + "vst1.32 {d22-d24}, [%[output_ptr0]]! \n" + + "vmull.s16 q13, d16, d0 \n" + "vmlal.s16 q13, d18, d1 \n" + "vmlal.s16 q13, d20, d2 \n" + "vmull.s16 q14, d17, d0 \n" + "vmlal.s16 q14, d19, d1 \n" + "vmlal.s16 q14, d21, d2 \n" + + "vld2.8 {d10, d11}, [%[input_ptr3]], r0 \n" + "vld2.8 {d12, d13}, [%[input_ptr4]], r0 \n" + "vld2.8 {d14, d15}, [%[input_ptr5]], r0 \n" + "vext.s8 d9, d10, d10, #1 \n" + "vmovl.s8 q10, d9 \n" + "vmovl.s8 q8, d10 \n" + "vmovl.s8 q9, d11 \n" + "vmlal.s16 q13, d16, d3 \n" + "vmlal.s16 q13, d18, d4 \n" + "vmlal.s16 q13, d20, d5 \n" + "vmlal.s16 q14, d17, d3 \n" + "vmlal.s16 q14, d19, d4 \n" + "vmlal.s16 q14, d21, d5 \n" + + "vext.s8 d9, d12, d12, #1 \n" + "vmovl.s8 q10, d9 \n" + "vmovl.s8 q8, d12 \n" + "vmovl.s8 q9, d13 \n" + "vmlal.s16 q13, d16, d6 \n" + "vmlal.s16 q13, d18, d7 \n" + "vmlal.s16 q13, d20, d8 \n" + "vmlal.s16 q14, d17, d6 \n" + "vmlal.s16 q14, d19, d7 \n" + "vmlal.s16 q14, d21, d8 \n" + // store row 1 + "vst1.32 {d26-d28}, [%[output_ptr1]]! \n" + + "vmull.s16 q11, d16, d0 \n" + "vmlal.s16 q11, d18, d1 \n" + "vmlal.s16 q11, d20, d2 \n" + "vmull.s16 q12, d17, d0 \n" + "vmlal.s16 q12, d19, d1 \n" + "vmlal.s16 q12, d21, d2 \n" + + "vext.s8 d9, d14, d14, #1 \n" + "vmovl.s8 q10, d9 \n" + "vmovl.s8 q8, d14 \n" + "vmovl.s8 q9, d15 \n" + "vmlal.s16 q11, d16, d3 \n" + "vmlal.s16 q11, d18, d4 \n" + "vmlal.s16 q11, d20, d5 \n" + "vmlal.s16 q12, d17, d3 \n" + "vmlal.s16 q12, d19, d4 \n" + "vmlal.s16 q12, d21, d5 \n" + + "vld2.8 {d10, d11}, [%[input_ptr6]], r0 \n" + "vext.s8 d9, d10, d10, #1 \n" + "vmovl.s8 q10, d9 \n" + "vmovl.s8 q8, d10 \n" + "vmovl.s8 q9, d11 \n" + "vmlal.s16 q11, d16, d6 \n" + "vmlal.s16 q11, d18, d7 \n" + "vmlal.s16 q11, d20, d8 \n" + "vmlal.s16 q12, d17, d6 \n" + "vmlal.s16 q12, d19, d7 \n" + "vmlal.s16 q12, d21, d8 \n" + // store row 2 + "vst1.32 {d22-d24}, [%[output_ptr2]]! \n" + + "subs %[loop], #1 \n" + "bne loop_3h6w_%= \n" + + "start_remain_%=: \n" + "cmp %[remain], #0 \n" + "ble end_%= \n" + + "vld2.8 {d10, d11}, [%[input_ptr0]] \n" + "vld2.8 {d12, d13}, [%[input_ptr1]] \n" + "vext.s8 d9, d10, d10, #1 \n" + "vmovl.s8 q9, d9 \n" + "vmovl.s8 q7, d10 \n" + "vmovl.s8 q8, d11 \n" + "vmull.s16 q10, d14, d0 \n" + "vmlal.s16 q10, d16, d1 \n" + "vmlal.s16 q10, d18, d2 \n" + "vmull.s16 q11, d15, d0 \n" + "vmlal.s16 q11, d17, d1 \n" + "vmlal.s16 q11, d19, d2 \n" + + "vext.s8 d9, d12, d12, #1 \n" + "vmovl.s8 q9, d9 \n" + "vmovl.s8 q7, d12 \n" + "vmovl.s8 q8, d13 \n" + "vmlal.s16 q10, d14, d3 \n" + "vmlal.s16 q10, d16, d4 \n" + "vmlal.s16 q10, d18, d5 \n" + "vmlal.s16 q11, d15, d3 \n" + "vmlal.s16 q11, d17, d4 \n" + "vmlal.s16 q11, d19, d5 \n" + + "vld2.8 {d10, d11}, [%[input_ptr2]] \n" + "vld2.8 {d12, d13}, [%[input_ptr3]] \n" + "vext.s8 d9, d10, d10, #1 \n" + "vmovl.s8 q9, d9 \n" + "vmovl.s8 q7, d10 \n" + "vmovl.s8 q8, d11 \n" + "vmlal.s16 q10, d14, d6 \n" + "vmlal.s16 q10, d16, d7 \n" + "vmlal.s16 q10, d18, d8 \n" + "vmlal.s16 q11, d15, d6 \n" + "vmlal.s16 q11, d17, d7 \n" + "vmlal.s16 q11, d19, d8 \n" + + "vmull.s16 q12, d14, d0 \n" + "vmlal.s16 q12, d16, d1 \n" + "vmlal.s16 q12, d18, d2 \n" + "vmull.s16 q13, d15, d0 \n" + "vmlal.s16 q13, d17, d1 \n" + "vmlal.s16 q13, d19, d2 \n" + + "vext.s8 d9, d12, d12, #1 \n" + "vmovl.s8 q9, d9 \n" + "vmovl.s8 q7, d12 \n" + "vmovl.s8 q8, d13 \n" + "vmlal.s16 q12, d14, d3 \n" + "vmlal.s16 q12, d16, d4 \n" + "vmlal.s16 q12, d18, d5 \n" + "vmlal.s16 q13, d15, d3 \n" + "vmlal.s16 q13, d17, d4 \n" + "vmlal.s16 q13, d19, d5 \n" + + "vld2.8 {d10, d11}, [%[input_ptr4]] \n" + "vld2.8 {d12, d13}, [%[input_ptr5]] \n" + "vext.s8 d9, d10, d10, #1 \n" + "vmovl.s8 q9, d9 \n" + "vmovl.s8 q7, d10 \n" + "vmovl.s8 q8, d11 \n" + "vmlal.s16 q12, d14, d6 \n" + "vmlal.s16 q12, d16, d7 \n" + "vmlal.s16 q12, d18, d8 \n" + "vmlal.s16 q13, d15, d6 \n" + "vmlal.s16 q13, d17, d7 \n" + "vmlal.s16 q13, d19, d8 \n" + + "vmull.s16 q14, d14, d0 \n" + "vmlal.s16 q14, d16, d1 \n" + "vmlal.s16 q14, d18, d2 \n" + "vmull.s16 q15, d15, d0 \n" + "vmlal.s16 q15, d17, d1 \n" + "vmlal.s16 q15, d19, d2 \n" + + "vext.s8 d9, d12, d12, #1 \n" + "vmovl.s8 q9, d9 \n" + "vmovl.s8 q7, d12 \n" + "vmovl.s8 q8, d13 \n" + "vmlal.s16 q14, d14, d3 \n" + "vmlal.s16 q14, d16, d4 \n" + "vmlal.s16 q14, d18, d5 \n" + "vmlal.s16 q15, d15, d3 \n" + "vmlal.s16 q15, d17, d4 \n" + "vmlal.s16 q15, d19, d5 \n" + + "vld2.8 {d10, d11}, [%[input_ptr6]] \n" + "vext.s8 d9, d10, d10, #1 \n" + "vmovl.s8 q9, d9 \n" + "vmovl.s8 q7, d10 \n" + "vmovl.s8 q8, d11 \n" + "vmlal.s16 q14, d14, d6 \n" + "vmlal.s16 q14, d16, d7 \n" + "vmlal.s16 q14, d18, d8 \n" + "vmlal.s16 q15, d15, d6 \n" + "vmlal.s16 q15, d17, d7 \n" + "vmlal.s16 q15, d19, d8 \n" + + "cmp %[remain], #4 \n" + "blt store_3h2w_%= \n" + "vst1.32 {q10}, [%[output_ptr0]]! \n" + "vst1.32 {q12}, [%[output_ptr1]]! \n" + "vst1.32 {q14}, [%[output_ptr2]]! \n" + "cmp %[remain], #5 \n" + "blt end_%= \n" + "vst1.32 {d22[0]}, [%[output_ptr0]]! \n" + "vst1.32 {d26[0]}, [%[output_ptr1]]! \n" + "vst1.32 {d30[0]}, [%[output_ptr2]]! \n" + "b end_%= \n" + + "store_3h2w_%=: \n" + "cmp %[remain], #2 \n" + "blt store_3h1w_%= \n" + "vst1.32 {d20}, [%[output_ptr0]]! \n" + "vst1.32 {d24}, [%[output_ptr1]]! \n" + "vst1.32 {d28}, [%[output_ptr2]]! \n" + "cmp %[remain], #3 \n" + "blt end_%= \n" + "vst1.32 {d21[0]}, [%[output_ptr0]]! \n" + "vst1.32 {d25[0]}, [%[output_ptr1]]! \n" + "vst1.32 {d29[0]}, [%[output_ptr2]]! \n" + "b end_%= \n" + + "store_3h1w_%=: \n" + "cmp %[remain], #1 \n" + "blt end_%= \n" + "vst1.32 {d20[0]}, [%[output_ptr0]]! \n" + "vst1.32 {d24[0]}, [%[output_ptr1]]! \n" + "vst1.32 {d28[0]}, [%[output_ptr2]]! \n" + "end_%=: \n" + : [output_ptr0] "+r"(output_ptr0), [output_ptr1] "+r"(output_ptr1), + [output_ptr2] "+r"(output_ptr2), [input_ptr6] "+r"(input_ptr6), + [input_ptr0] "+r"(input_ptr0), [input_ptr1] "+r"(input_ptr1), + [input_ptr2] "+r"(input_ptr2), [input_ptr3] "+r"(input_ptr3), + [input_ptr4] "+r"(input_ptr4), [input_ptr5] "+r"(input_ptr5), + [loop] "+r"(loop) + : [remain] "r"(remain) + : "cc", "memory", "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7", + "q8", "q9", "q10", "q11", "q12", "q13", "q14", "q15", "r0"); + } + + int start_h = (output_h / 3) * 6; + for (int h = start_h; h < input_h - 2 /*(input_h - 1) - 1*/; h += 2) { + const int8_t* input_ptr0 = input_ptr + h * input_w; + const int8_t* input_ptr1 = input_ptr0 + input_w; + const int8_t* input_ptr2 = input_ptr1 + input_w; + int32_t* output_ptr0 = output_ptr + (h >> 1) * output_w; + int loop = loops; + asm volatile( + "vld1.32 {q0}, [%[filter_ptr]] \n" + "vmovl.s8 q14, d0 \n" + "vmovl.s8 q15, d1 \n" + "vdup.s16 d0, d28[0] \n" + "vdup.s16 d1, d28[1] \n" + "vdup.s16 d2, d28[2] \n" + "vdup.s16 d3, d28[3] \n" + "vdup.s16 d4, d29[0] \n" + "vdup.s16 d5, d29[1] \n" + "vdup.s16 d6, d29[2] \n" + "vdup.s16 d7, d29[3] \n" + "vdup.s16 d8, d30[0] \n" + : + : [filter_ptr] "r"(filter_ptr) + : "memory", "q0", "q1", "q2", "q3", "q4", "q14", "q15"); + asm volatile( + "mov r0, #12 \n" + "cmp %[loop], #0 \n" + "ble start_remain_%= \n" + // loop 6 widths + "loop_1h6w_%=: \n" + "vld2.8 {d10, d11}, [%[input_ptr0]], r0 \n" + "vld2.8 {d12, d13}, [%[input_ptr1]], r0 \n" + "vld2.8 {d14, d15}, [%[input_ptr2]], r0 \n" + "vext.s8 d9, d10, d10, #1 \n" + "vmovl.s8 q10, d9 \n" + "vmovl.s8 q8, d10 \n" + "vmovl.s8 q9, d11 \n" + "vmull.s16 q11, d16, d0 \n" + "vmlal.s16 q11, d18, d1 \n" + "vmlal.s16 q11, d20, d2 \n" + "vmull.s16 q12, d17, d0 \n" + "vmlal.s16 q12, d19, d1 \n" + "vmlal.s16 q12, d21, d2 \n" + + "vext.s8 d9, d12, d12, #1 \n" + "vmovl.s8 q10, d9 \n" + "vmovl.s8 q8, d12 \n" + "vmovl.s8 q9, d13 \n" + "vmlal.s16 q11, d16, d3 \n" + "vmlal.s16 q11, d18, d4 \n" + "vmlal.s16 q11, d20, d5 \n" + "vmlal.s16 q12, d17, d3 \n" + "vmlal.s16 q12, d19, d4 \n" + "vmlal.s16 q12, d21, d5 \n" + + "vext.s8 d9, d14, d14, #1 \n" + "vmovl.s8 q10, d9 \n" + "vmovl.s8 q8, d14 \n" + "vmovl.s8 q9, d15 \n" + "vmlal.s16 q11, d16, d6 \n" + "vmlal.s16 q11, d18, d7 \n" + "vmlal.s16 q11, d20, d8 \n" + "vmlal.s16 q12, d17, d6 \n" + "vmlal.s16 q12, d19, d7 \n" + "vmlal.s16 q12, d21, d8 \n" + // store row 0 + "vst1.32 {d22-d24}, [%[output_ptr0]]! \n" + + "subs %[loop], #1 \n" + "bne loop_1h6w_%= \n" + + "start_remain_%=: \n" + "cmp %[remain], #0 \n" + "ble end_%= \n" + "vld2.8 {d10, d11}, [%[input_ptr0]] \n" + "vld2.8 {d12, d13}, [%[input_ptr1]] \n" + "vld2.8 {d14, d15}, [%[input_ptr2]] \n" + "vext.s8 d9, d10, d10, #1 \n" + "vmovl.s8 q10, d9 \n" + "vmovl.s8 q8, d10 \n" + "vmovl.s8 q9, d11 \n" + "vmull.s16 q11, d16, d0 \n" + "vmlal.s16 q11, d18, d1 \n" + "vmlal.s16 q11, d20, d2 \n" + "vmull.s16 q12, d17, d0 \n" + "vmlal.s16 q12, d19, d1 \n" + "vmlal.s16 q12, d21, d2 \n" + + "vext.s8 d9, d12, d12, #1 \n" + "vmovl.s8 q10, d9 \n" + "vmovl.s8 q8, d12 \n" + "vmovl.s8 q9, d13 \n" + "vmlal.s16 q11, d16, d3 \n" + "vmlal.s16 q11, d18, d4 \n" + "vmlal.s16 q11, d20, d5 \n" + "vmlal.s16 q12, d17, d3 \n" + "vmlal.s16 q12, d19, d4 \n" + "vmlal.s16 q12, d21, d5 \n" + + "vext.s8 d9, d14, d14, #1 \n" + "vmovl.s8 q10, d9 \n" + "vmovl.s8 q8, d14 \n" + "vmovl.s8 q9, d15 \n" + "vmlal.s16 q11, d16, d6 \n" + "vmlal.s16 q11, d18, d7 \n" + "vmlal.s16 q11, d20, d8 \n" + "vmlal.s16 q12, d17, d6 \n" + "vmlal.s16 q12, d19, d7 \n" + "vmlal.s16 q12, d21, d8 \n" + + "cmp %[remain], #4 \n" + "blt store_1h2w_%= \n" + "vst1.32 {q11}, [%[output_ptr0]]! \n" + "cmp %[remain], #5 \n" + "blt end_%= \n" + "vst1.32 {d24[0]}, [%[output_ptr0]]! \n" + "b end_%= \n" + + "store_1h2w_%=: \n" + "cmp %[remain], #2 \n" + "blt store_1h1w_%= \n" + "vst1.32 {d22}, [%[output_ptr0]]! \n" + "cmp %[remain], #3 \n" + "blt end_%= \n" + "vst1.32 {d23[0]}, [%[output_ptr0]]! \n" + "b end_%= \n" + + "store_1h1w_%=: \n" + "cmp %[remain], #1 \n" + "blt end_%= \n" + "vst1.32 {d22[0]}, [%[output_ptr0]]! \n" + "end_%=: \n" + : [output_ptr0] "+r"(output_ptr0), [input_ptr0] "+r"(input_ptr0), + [input_ptr1] "+r"(input_ptr1), [input_ptr2] "+r"(input_ptr2), + [loop] "+r"(loop) + : [remain] "r"(remain) + : "cc", "memory", "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7", + "q8", "q9", "q10", "q11", "q12", "r0"); + } + } +#endif // __aarch64__ +} + +} // namespace math +} // namespace operators +} // namespace paddle_mobile diff --git a/src/operators/math/depthwise_conv_3x3.h b/src/operators/math/depthwise_conv_3x3.h deleted file mode 100644 index b146b88e737a07ea08250315fc94653f63d2ad05..0000000000000000000000000000000000000000 --- a/src/operators/math/depthwise_conv_3x3.h +++ /dev/null @@ -1,51 +0,0 @@ -/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ - -#pragma once - -#include -#include -#include "framework/tensor.h" -#include "operators/math/conv_func.h" - -namespace paddle_mobile { -namespace operators { -namespace math { -using framework::Tensor; -using std::max; -using std::min; -using std::vector; - -void DepthwiseConv3x3(const Tensor *input, vector strides, - vector paddings, const Tensor *filter, Tensor *bias, - Tensor *output, bool if_bias); -void DepthwiseConv3x3s1p1(const Tensor *input, const Tensor *filter, - Tensor *output, Tensor *bias, bool if_bias); -void DepthwiseConvAddBNRelu3x3s1p1(const Tensor *input, const Tensor *filter, - Tensor *output, const Tensor *new_scale, - const Tensor *new_bias, bool if_relu); -void DepthwiseConvAddBNRelu3x3s2p1(const Tensor *input, const Tensor *filter, - Tensor *output, const Tensor *new_scale, - const Tensor *new_bias, bool if_relu); -void DepthwiseConv3x3s2p1v2(const Tensor *input, const Tensor *filter, - Tensor *output, Tensor bias, bool if_bias); -void DepthwiseConvAddBNRelu3x3s2p1v2(const Tensor *input, const Tensor *filter, - Tensor *output, const Tensor *new_scale, - const Tensor *new_bias, bool if_relu); - -void DepthwiseConv3x3s2p0(const Tensor *input, const Tensor *filter, - Tensor *output, Tensor bias, bool if_bias); -} // namespace math -} // namespace operators -} // namespace paddle_mobile diff --git a/src/operators/math/gemm.cpp b/src/operators/math/gemm.cpp index d3e6de3134ff91f47c66c927194a5ba688e931b0..c17b2a5e4df0f0ca88da79a9ce55c2ecae0316b5 100644 --- a/src/operators/math/gemm.cpp +++ b/src/operators/math/gemm.cpp @@ -26,79 +26,6 @@ limitations under the License. */ namespace paddle_mobile { namespace operators { namespace math { -/*int MC = 0; -int KC = 0; -int NC = 0; - -float *packedA; -float *packedB; -float *packedC; -float *zero; - -typedef void (*FnPack)(int, int, int, const float *, int, float *); -typedef void (*FnAddDot)(int, const float *, const float *, float *, int); - -FnPack procPackA; -FnPack procPackB; -FnAddDot procAddDot;*/ - -/* -// 将A矩阵分块复制到连续内存(ColMajor) -void PackMatrixA(int m, int k, int m_tail, const float *A, int lda, - float *buffer) { - int i, j; - const float *Aij; - for (i = 0; i < m - m_tail; i += MR) { - for (j = 0; j < k; ++j) { - Aij = &A(i, j); - *buffer++ = *Aij; - *buffer++ = *(Aij + 1); - *buffer++ = *(Aij + 2); - *buffer++ = *(Aij + 3); - } - } - if (m_tail != 0) { - for (j = 0; j < k; ++j) { - Aij = &A(m - m_tail, j); - for (i = 0; i < m_tail; ++i) { - *buffer++ = *(Aij + i); - } - for (i = m_tail; i < MR; ++i) { - *buffer++ = 0; - } - } - } -} - -// 将B矩阵分块复制到连续内存(ColMajor) -void PackMatrixB(int k, int n, int n_tail, const float *B, int ldb, - float *buffer) { - int i, j; - const float *Bj, *Bj1, *Bj2, *Bj3; - for (j = 0; j < n - n_tail; j += NR) { - Bj = &B(0, j); - Bj1 = &B(0, j + 1); - Bj2 = &B(0, j + 2); - Bj3 = &B(0, j + 3); - for (i = 0; i < k; ++i) { - *buffer++ = *Bj++; - *buffer++ = *Bj1++; - *buffer++ = *Bj2++; - *buffer++ = *Bj3++; - } - } - if (n_tail != 0) { - for (i = 0; i < k; ++i) { - for (int j = n - n_tail; j < n; ++j) { - *buffer++ = B(i, j); - } - for (int j = n; j < n + (NR - n_tail); ++j) { - *buffer++ = 0; - } - } - } -} -*/ // 将A矩阵分块复制到连续内存(RowMajor) void Gemm::PackMatrixA_4r(int m, int k, int m_tail, const float *A, int lda, diff --git a/src/operators/math/im2col.cpp b/src/operators/math/im2col.cpp index 47055ec4f24e5b5b226c1f084bb2253d2ebb77c7..9449ad70819f2ea114fac8848f6ee023871d47f2 100644 --- a/src/operators/math/im2col.cpp +++ b/src/operators/math/im2col.cpp @@ -22,6 +22,70 @@ namespace paddle_mobile { namespace operators { namespace math { +void ExtractToImg(const float *im_data, float *col_data, const int im_height, + const int im_width, const int col_height, const int col_width, + const int padding_h, const int padding_w, const int stride_h, + const int stride_w, const int kh, const int kw) { + int h = padding_h - kh; + int w = padding_w - kw; + int col_start_height = h > 0 ? (h + stride_h - 1) / stride_h : 0; + int col_start_width = w > 0 ? (w + stride_w - 1) / stride_w : 0; + int start_height = kh + col_start_height * stride_h - padding_h; + int start_width = kw + col_start_width * stride_w - padding_w; + + int end_height = (col_height - col_start_height) * stride_h + start_height; + end_height = end_height > im_height ? im_height : end_height; + int end_width = (col_width - col_start_width) * stride_w + start_width; + end_width = end_width > im_width ? im_width : end_width; + int extract = (end_width - start_width + stride_w - 1) / stride_w; + + im_data += start_height * im_width + start_width; + col_data += col_start_height * col_width + col_start_width; + + for (int i = start_height; i < end_height; i += stride_h) { + if (stride_w == 1) { + memcpy(col_data, im_data, extract * sizeof(float)); + } else if (stride_w == 2) { + int s = 0; +#if __ARM_NEON + for (; s < extract - 3; s += 4) { + float32x4x2_t img = vld2q_f32(im_data + s * 2); + vst1q_f32(col_data + s, img.val[0]); + } +#endif + for (; s < extract; ++s) { + col_data[s] = im_data[s * 2]; + } + } else if (stride_w == 3) { + int s = 0; +#if __ARM_NEON + for (; s < extract - 3; s += 4) { + float32x4x3_t img = vld3q_f32(im_data + s * 3); + vst1q_f32(col_data + s, img.val[0]); + } +#endif + for (; s < extract; ++s) { + col_data[s] = im_data[s * 3]; + } + } else if (stride_w == 4) { + int s = 0; +#if __ARM_NEON + for (; s < extract - 3; s += 4) { + float32x4x4_t img = vld4q_f32(im_data + s * 4); + vst1q_f32(col_data + s, img.val[0]); + } +#endif + for (; s < extract; ++s) { + col_data[s] = im_data[s * 4]; + } + } else { + PADDLE_MOBILE_THROW_EXCEPTION("stride_w must be one of 1, 2, 3 and 4."); + } + im_data += im_width * stride_h; + col_data += col_width; + } +} + /* * im = [input_channels, input_height, input_width] * col = @@ -363,7 +427,27 @@ void Im2ColFunctor::operator()( col_data += 9 * oosize; im_data += isize * isize; } + } else if (stride[0] <= 4 && dilation[0] == 1 && dilation[0] == dilation[1]) { + int im_spatial_size = im_height * im_width; + int col_spatial_size = col_height * col_width; + // pad 0 + memset(col_data, 0, col->numel() * sizeof(float)); + #pragma omp parallel for + for (int ic = 0; ic < im_channels; ++ic) { + const float *local_im_data = im_data + ic * im_spatial_size; + float *local_col_data = + col_data + ic * filter_height * filter_width * col_spatial_size; + for (int kh = 0; kh < filter_height; ++kh) { + for (int kw = 0; kw < filter_width; ++kw) { + ExtractToImg(local_im_data, local_col_data, im_height, im_width, + col_height, col_width, padding[0], padding[1], stride[0], + stride[1], kh, kw); + local_col_data += col_spatial_size; + } + } + } } else { +#endif for (int c = 0; c < channels_col; ++c) { int w_offset = c % filter_width; int h_offset = (c / filter_width) % filter_height; @@ -382,25 +466,7 @@ void Im2ColFunctor::operator()( } } } - } -#else - for (int c = 0; c < channels_col; ++c) { - int w_offset = c % filter_width; - int h_offset = (c / filter_width) % filter_height; - int c_im = c / (filter_width * filter_height); - for (int h = 0; h < col_height; ++h) { - int im_row_idx = h * stride[0] - padding[0] + h_offset * dilation[0]; - for (int w = 0; w < col_width; ++w) { - int im_col_idx = w * stride[1] - padding[1] + w_offset * dilation[1]; - int col_idx = (c * col_height + h) * col_width + w; - int im_idx = (im_row_idx + c_im * im_height) * im_width + im_col_idx; - - col_data[col_idx] = (im_row_idx < 0 || im_row_idx >= im_height || - im_col_idx < 0 || im_col_idx >= im_width) - ? static_cast(0) - : im_data[im_idx]; - } - } +#if __ARM_NEON } #endif } @@ -489,21 +555,26 @@ void Im2ColFunctor::operator()( int channels_col = im_channels * filter_height * filter_width; const int8_t *im_data = im.data(); - int8_t *col_data = col->data(); + int8_t *col_data = col->mutable_data(); #if defined(__ARM_NEON__) || defined(__ARM_NEON) if (stride[0] <= 4 && dilation[0] == 1 && dilation[0] == dilation[1]) { + int im_spatial_size = im_height * im_width; + int col_spatial_size = col_height * col_width; // pad 0 memset(col_data, 0, col->numel() * sizeof(int8_t)); + #pragma omp parallel for for (int ic = 0; ic < im_channels; ++ic) { + const int8_t *local_im_data = im_data + ic * im_spatial_size; + int8_t *local_col_data = + col_data + ic * filter_height * filter_width * col_spatial_size; for (int kh = 0; kh < filter_height; ++kh) { for (int kw = 0; kw < filter_width; ++kw) { - ExtractToImg(im_data, col_data, im_height, im_width, col_height, - col_width, padding[0], padding[1], stride[0], stride[1], - kh, kw); - col_data += col_height * col_width; + ExtractToImg(local_im_data, local_col_data, im_height, im_width, + col_height, col_width, padding[0], padding[1], stride[0], + stride[1], kh, kw); + local_col_data += col_spatial_size; } } - im_data += im_height * im_width; } } else { #endif diff --git a/src/operators/math/pad.cpp b/src/operators/math/pad.cpp index d8153c445b007e8c5a902301e2724f22c8f6add1..49fede1eb30d8cabcabb4dd4828e43eb8900a2f9 100644 --- a/src/operators/math/pad.cpp +++ b/src/operators/math/pad.cpp @@ -21,10 +21,12 @@ namespace math { template class PadFunctor { public: - void operator()(const framework::Tensor &input, const int pad_h, - const int pad_w, framework::Tensor *output) { + void operator()(const framework::Tensor &input, const int pad_top, + const int pad_bottom, const int pad_left, const int pad_right, + framework::Tensor *output) { const T *in_data = input.data(); T *out_data = output->mutable_data(); + // should check output shape is valid for such pad parameters const framework::DDim &input_shape = input.dims(); const framework::DDim &output_shape = output->dims(); // fill output with 0 @@ -32,13 +34,13 @@ class PadFunctor { // should make sure the shape of output is match with input for (int i = 0; i < input_shape[0]; ++i) { for (int c = 0; c < input_shape[1]; ++c) { - out_data += pad_h * output_shape[3]; + out_data += pad_top * output_shape[3]; for (int h = 0; h < input_shape[2]; ++h) { - memcpy(out_data + pad_w, in_data, sizeof(T) * input_shape[3]); + memcpy(out_data + pad_left, in_data, sizeof(T) * input_shape[3]); out_data += output_shape[3]; in_data += input_shape[3]; } - out_data += pad_h * output_shape[3]; + out_data += pad_bottom * output_shape[3]; } } } diff --git a/src/operators/math/pad.h b/src/operators/math/pad.h index 0f5a4b89674f92746f75bb1e4f9364d5a16fdba2..9031caf36a872d091b333570320955e7fc30f78a 100644 --- a/src/operators/math/pad.h +++ b/src/operators/math/pad.h @@ -22,8 +22,9 @@ namespace math { template class PadFunctor { public: - void operator()(const framework::Tensor &input, const int pad_h, - const int pad_w, framework::Tensor *output); + void operator()(const framework::Tensor &input, const int pad_top, + const int pad_bottom, const int pad_left, const int pad_right, + framework::Tensor *output); }; } // namespace math diff --git a/src/operators/math/winograd/winograd_transform.h b/src/operators/math/winograd/winograd_transform.h new file mode 100644 index 0000000000000000000000000000000000000000..599a9b9233becdc0d7bdc7f8ef12b9d4cccd60d0 --- /dev/null +++ b/src/operators/math/winograd/winograd_transform.h @@ -0,0 +1,42 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#ifdef CONV_OP + +#pragma once + +#include "framework/tensor.h" + +namespace paddle_mobile { +namespace operators { +namespace math { + +template +void winograd_transform_weight(const framework::Tensor &weight, + framework::Tensor *output); + +template +void winograd_transform_input(const framework::Tensor &input, + framework::Tensor *output); + +template +void winograd_transform_output(const framework::Tensor &input, + const framework::Tensor &weight, + framework::Tensor *output); + +} // namespace math +} // namespace operators +} // namespace paddle_mobile + +#endif diff --git a/src/operators/math/winograd/winograd_transform_f6k3.cpp b/src/operators/math/winograd/winograd_transform_f6k3.cpp new file mode 100644 index 0000000000000000000000000000000000000000..d9a7cb3b51c7abcd029f316164f5d3d88ec24be0 --- /dev/null +++ b/src/operators/math/winograd/winograd_transform_f6k3.cpp @@ -0,0 +1,1366 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +// Inspired by https://arxiv.org/abs/1509.09308 and refered from nnpack and ncnn +// project. + +#ifdef CONV_OP + +#ifndef __aarch64__ + +#include "operators/math/pad.h" +#include "operators/math/winograd/winograd_transform.h" + +namespace paddle_mobile { +namespace operators { +namespace math { + +template <> +void winograd_transform_weight<8, 3>(const framework::Tensor &weight, + framework::Tensor *output) { + /* + * w0 = g0 + * w1 = ((g0 + g2) + g1) * (-2.0 / 9) + * w2 = ((g0 + g2) - g1) * (-2.0 / 9) + * w3 = ((g0 + 4 * g2) + 2 * g1) * (1.0 / 90) + * w4 = ((g0 + 4 * g2) - 2 * g1) * (1.0 / 90) + * w5 = ((g2 + 4 * g0) + 2 * g1) * (1.0 / 180) + * w6 = ((g2 + 4 * g0) - 2 * g1) * (1.0 / 180) + * w7 = g2 + */ + // weight shape is [out_channel, in_channel, kernel_h, kernel_w] + // package weight into [roundup(out_channel/4), 64, in_channel, 4] tiles + int out_channel = weight.dims()[0]; + int in_channel = weight.dims()[1]; + // reshape and alloc transformed weight + framework::DDim transformed_shape = framework::make_ddim( + std::vector{(out_channel + 3) / 4, 64, in_channel, 4}); + float *trans_outptr = output->mutable_data(transformed_shape); + memset(trans_outptr, 0, output->numel() * sizeof(float)); + + const float transform_matrix[8] = {2.f, -2.f / 9, 1.f / 90, 1.f / 180}; + const float *inptr = weight.data(); + int remain_start = out_channel & 0xFFFC; +#if 0 + remain_start = 0; +#else + #pragma omp parallel for + for (int oc = 0; oc < out_channel - 3; oc += 4) { + float gw[96]; // gw[3][8][4] + const float *inptr0 = inptr + oc * in_channel * 9; + const float *inptr1 = inptr + (oc + 1) * in_channel * 9; + const float *inptr2 = inptr + (oc + 2) * in_channel * 9; + const float *inptr3 = inptr + (oc + 3) * in_channel * 9; + // oc * 64 * in_channel + float *outptr = trans_outptr + ((oc * in_channel) << 6); + for (int ic = 0; ic < in_channel; ++ic) { + float *gw_ptr = gw; + asm volatile( + "vld1.32 {d0-d1}, [%[tm_ptr]] \n" + + "mov r0, #24 \n" + "vld1.32 {d2-d5}, [%[inptr0]], r0 \n" + "vld1.32 {d6-d9}, [%[inptr1]], r0 \n" + "vld1.32 {d10-d13}, [%[inptr2]], r0 \n" + "vld1.32 {d14-d17}, [%[inptr3]], r0 \n" + "vtrn.32 q1, q3 \n" + "vtrn.32 q2, q4 \n" + "vtrn.32 q5, q7 \n" + "vtrn.32 q6, q8 \n" + "vswp.32 d3, d10 \n" + "vswp.32 d7, d14 \n" + "vswp.32 d5, d12 \n" + "vswp.32 d9, d16 \n" + + // q1: g0, q3: g1, q5: g2 + "vst1.32 {d2-d3}, [%[gw_ptr]]! \n" + "vadd.f32 q9, q1, q5 \n" + "vadd.f32 q10, q9, q3 \n" + "vsub.f32 q11, q9, q3 \n" + "vmul.f32 q10, q10, d0[1] \n" + "vst1.32 {d20-d21}, [%[gw_ptr]]! \n" + "vmul.f32 q11, q11, d0[1] \n" + "vst1.32 {d22-d23}, [%[gw_ptr]]! \n" + + "vmul.f32 q9, q1, d0[0] \n" + "vmul.f32 q9, q9, d0[0] \n" // 4 * g0 + "vmul.f32 q10, q3, d0[0] \n" // 2 * g1 + "vmul.f32 q11, q5, d0[0] \n" + "vmul.f32 q11, q11, d0[0] \n" // 4 * g2 + + "vadd.f32 q12, q1, q11 \n" + "vadd.f32 q13, q12, q10 \n" + "vmul.f32 q13, q13, d1[0] \n" + "vst1.32 {d26-d27}, [%[gw_ptr]]! \n" + "vsub.f32 q13, q12, q10 \n" + "vmul.f32 q13, q13, d1[0] \n" + "vst1.32 {d26-d27}, [%[gw_ptr]]! \n" + + "vadd.f32 q12, q5, q9 \n" + "vadd.f32 q13, q12, q10 \n" + "vmul.f32 q13, q13, d1[1] \n" + "vst1.32 {d26-d27}, [%[gw_ptr]]! \n" + "vsub.f32 q13, q12, q10 \n" + "vmul.f32 q13, q13, d1[1] \n" + "vst1.32 {d26-d27}, [%[gw_ptr]]! \n" + + "vst1.32 {d10-d11}, [%[gw_ptr]]! \n" + + // q7: g0, q2: g1, q4: g2 + "vst1.32 {d14-d15}, [%[gw_ptr]]! \n" + "vadd.f32 q9, q7, q4 \n" + "vadd.f32 q10, q9, q2 \n" + "vsub.f32 q11, q9, q2 \n" + "vmul.f32 q10, q10, d0[1] \n" + "vst1.32 {d20-d21}, [%[gw_ptr]]! \n" + "vmul.f32 q11, q11, d0[1] \n" + "vst1.32 {d22-d23}, [%[gw_ptr]]! \n" + + "vmul.f32 q9, q7, d0[0] \n" + "vmul.f32 q9, q9, d0[0] \n" // 4 * g0 + "vmul.f32 q10, q2, d0[0] \n" // 2 * g1 + "vmul.f32 q11, q4, d0[0] \n" + "vmul.f32 q11, q11, d0[0] \n" // 4 * g2 + + "vadd.f32 q12, q7, q11 \n" + "vadd.f32 q13, q12, q10 \n" + "vmul.f32 q13, q13, d1[0] \n" + "vst1.32 {d26-d27}, [%[gw_ptr]]! \n" + "vsub.f32 q13, q12, q10 \n" + "vmul.f32 q13, q13, d1[0] \n" + "vst1.32 {d26-d27}, [%[gw_ptr]]! \n" + + "vadd.f32 q12, q4, q9 \n" + "vadd.f32 q13, q12, q10 \n" + "vmul.f32 q13, q13, d1[1] \n" + "vst1.32 {d26-d27}, [%[gw_ptr]]! \n" + "vsub.f32 q13, q12, q10 \n" + "vmul.f32 q13, q13, d1[1] \n" + "vst1.32 {d26-d27}, [%[gw_ptr]]! \n" + + "vst1.32 {d8-d9}, [%[gw_ptr]]! \n" + + "mov r0, #12 \n" + "vld1.32 {d2-d3}, [%[inptr0]], r0 \n" + "vld1.32 {d6-d7}, [%[inptr1]], r0 \n" + "vld1.32 {d10-d11}, [%[inptr2]], r0 \n" + "vld1.32 {d14-d15}, [%[inptr3]], r0 \n" + "vtrn.32 q1, q3 \n" + "vtrn.32 q5, q7 \n" + "vswp.32 d3, d10 \n" + "vswp.32 d7, d14 \n" + + // q1: g0, q3: g1, q5: g2 + "vst1.32 {d2-d3}, [%[gw_ptr]]! \n" + "vadd.f32 q9, q1, q5 \n" + "vadd.f32 q10, q9, q3 \n" + "vsub.f32 q11, q9, q3 \n" + "vmul.f32 q10, q10, d0[1] \n" + "vst1.32 {d20-d21}, [%[gw_ptr]]! \n" + "vmul.f32 q11, q11, d0[1] \n" + "vst1.32 {d22-d23}, [%[gw_ptr]]! \n" + + "vmul.f32 q9, q1, d0[0] \n" + "vmul.f32 q9, q9, d0[0] \n" // 4 * g0 + "vmul.f32 q10, q3, d0[0] \n" // 2 * g1 + "vmul.f32 q11, q5, d0[0] \n" + "vmul.f32 q11, q11, d0[0] \n" // 4 * g2 + + "vadd.f32 q12, q1, q11 \n" + "vadd.f32 q13, q12, q10 \n" + "vmul.f32 q13, q13, d1[0] \n" + "vst1.32 {d26-d27}, [%[gw_ptr]]! \n" + "vsub.f32 q13, q12, q10 \n" + "vmul.f32 q13, q13, d1[0] \n" + "vst1.32 {d26-d27}, [%[gw_ptr]]! \n" + + "vadd.f32 q12, q5, q9 \n" + "vadd.f32 q13, q12, q10 \n" + "vmul.f32 q13, q13, d1[1] \n" + "vst1.32 {d26-d27}, [%[gw_ptr]]! \n" + "vsub.f32 q13, q12, q10 \n" + "vmul.f32 q13, q13, d1[1] \n" + "vst1.32 {d26-d27}, [%[gw_ptr]]! \n" + + "vst1.32 {d10-d11}, [%[gw_ptr]]! \n" + : [gw_ptr] "+r"(gw_ptr), [inptr0] "+r"(inptr0), [inptr1] "+r"(inptr1), + [inptr2] "+r"(inptr2), [inptr3] "+r"(inptr3) + : [tm_ptr] "r"((float *)transform_matrix) + : "cc", "memory", "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7", + "q8", "q9", "q10", "q11", "q12", "q13", "r0"); + + float *gw_ptr0 = gw; + float *gw_ptr1 = gw + 32; + float *gw_ptr2 = gw + 64; + float *outptr0 = outptr + (ic << 2); // ic * 4 + int steps = (in_channel << 2) * sizeof(float); // in_channel * 4 + asm volatile( + "vld1.32 {d0-d1}, [%[tm_ptr]] \n" + "mov r0, #8 \n" + + "loop_8_%=: \n" + "vld1.32 {d2-d3}, [%[gw_ptr0]]! \n" + "vld1.32 {d4-d5}, [%[gw_ptr1]]! \n" + "vld1.32 {d6-d7}, [%[gw_ptr2]]! \n" + + // q1: g0, q2: g1, q3: g2 + "vst1.32 {d2-d3}, [%[outptr0]], %[steps] \n" + "vadd.f32 q9, q1, q3 \n" + "vadd.f32 q10, q9, q2 \n" + "vsub.f32 q11, q9, q2 \n" + "vmul.f32 q10, q10, d0[1] \n" + "vst1.32 {d20-d21}, [%[outptr0]], %[steps] \n" + "vmul.f32 q11, q11, d0[1] \n" + "vst1.32 {d22-d23}, [%[outptr0]], %[steps] \n" + + "vmul.f32 q9, q1, d0[0] \n" + "vmul.f32 q9, q9, d0[0] \n" // 4 * g0 + "vmul.f32 q10, q2, d0[0] \n" // 2 * g1 + "vmul.f32 q11, q3, d0[0] \n" + "vmul.f32 q11, q11, d0[0] \n" // 4 * g2 + + "vadd.f32 q12, q1, q11 \n" + "vadd.f32 q13, q12, q10 \n" + "vmul.f32 q13, q13, d1[0] \n" + "vst1.32 {d26-d27}, [%[outptr0]], %[steps] \n" + "vsub.f32 q13, q12, q10 \n" + "vmul.f32 q13, q13, d1[0] \n" + "vst1.32 {d26-d27}, [%[outptr0]], %[steps] \n" + + // w5 = ((g2 + 4 * g0) + 2 * g1) * (1.0 / 180) + "vadd.f32 q12, q3, q9 \n" + "vadd.f32 q13, q12, q10 \n" + "vmul.f32 q13, q13, d1[1] \n" + "vst1.32 {d26-d27}, [%[outptr0]], %[steps] \n" + "vsub.f32 q13, q12, q10 \n" + "vmul.f32 q13, q13, d1[1] \n" + "vst1.32 {d26-d27}, [%[outptr0]], %[steps] \n" + + "vst1.32 {d6-d7}, [%[outptr0]], %[steps] \n" + + "subs r0, #1 \n" + "bne loop_8_%= \n" + : [outptr0] "+r"(outptr0), [gw_ptr0] "+r"(gw_ptr0), + [gw_ptr1] "+r"(gw_ptr1), [gw_ptr2] "+r"(gw_ptr2) + : [tm_ptr] "r"((float *)transform_matrix), [steps] "r"(steps) + : "cc", "memory", "q0", "q1", "q2", "q3", "q9", "q10", "q11", "q12", + "q13", "r0"); + } + } +#endif + + // remain output channel + #pragma omp parallel for + for (int oc = remain_start; oc < out_channel; ++oc) { + float gw[3][8]; // gw[3][8] + const float *inptr0 = inptr + oc * in_channel * 9; // + // (oc / 4) * 64 * in_channel * 4 + oc % 4 + int offset = ((oc & 0xFFFC) << 6) * in_channel + (oc & 0x3); + int steps = (in_channel << 2); // in_channel * 4 + float *outptr = trans_outptr + offset; + for (int ic = 0; ic < in_channel; ++ic) { + for (int i = 0; i < 3; ++i, inptr0 += 3) { + float g0 = inptr0[0]; + float g1 = inptr0[1]; + float g2 = inptr0[2]; + float d0 = g0 + g2; + float d1 = g0 + 4 * g2; + float d2 = g2 + 4 * g0; + float d3 = 2 * g1; + gw[i][0] = g0; + gw[i][1] = -2.f / 9 * (d0 + g1); // -2.f/9 * (g0 + g1 + g2) + gw[i][2] = -2.f / 9 * (d0 - g1); // -2.f/9 * (g0 - g1 + g2) + gw[i][3] = 1.f / 90 * (d1 + d3); // 1.f/90 * (g0 + 2 * g1 + 4 * g2) + gw[i][4] = 1.f / 90 * (d1 - d3); // 1.f/90 * (g0 - 2 * g1 + 4 * g2) + gw[i][5] = 1.f / 180 * (d2 + d3); // 1.f/180 * (4 * g0 + 2 * g1 + g2) + gw[i][6] = 1.f / 180 * (d2 - d3); // 1.f/180 * (4 * g0 - 2 * g1 + g2) + gw[i][7] = g2; + } + for (int i = 0; i < 8; ++i) { + float g0 = gw[0][i]; + float g1 = gw[1][i]; + float g2 = gw[2][i]; + float d0 = g0 + g2; + float d1 = g0 + 4 * g2; + float d2 = g2 + 4 * g0; + float d3 = 2 * g1; + int offset = i * 8 * steps; + outptr[offset] = g0; + outptr[offset + 1 * steps] = -2.f / 9 * (d0 + g1); + outptr[offset + 2 * steps] = -2.f / 9 * (d0 - g1); + outptr[offset + 3 * steps] = 1.f / 90 * (d1 + d3); + outptr[offset + 4 * steps] = 1.f / 90 * (d1 - d3); + outptr[offset + 5 * steps] = 1.f / 180 * (d2 + d3); + outptr[offset + 6 * steps] = 1.f / 180 * (d2 - d3); + outptr[offset + 7 * steps] = g2; + } + outptr += 4; + } + } +} + +template <> +void winograd_transform_input<8, 3>(const framework::Tensor &input, + framework::Tensor *output) { + /* + * x0 = (d0 - d6) + (d4 - d2) * 5.25 + * x1 = (d2 + d6) - 4.25 * (d4 + d3) + (d1 + d5) + * x2 = (d2 + d6) - 4.25 * (d4 - d3) - (d1 + d5) + * x3 = (0.25 * d2 - 1.25 * d4 + d6) + (0.5 * d1 - 2.5 * d3 + 2 * d5) + * x4 = (0.25 * d2 - 1.25 * d4 + d6) - (0.5 * d1 - 2.5 * d3 + 2 * d5) + * x5 = (4 * d2 - 5 * d4 + d6) + (2 * d1 - 2.5 * d3 + 0.5 * d5) + * x6 = (4 * d2 - 5 * d4 + d6) - (2 * d1 - 2.5 * d3 + 0.5 * d5) + * x7 = (d7 - d1) + (d3 - d5) * 5.25 + */ + // package input into [roundup(tiles/8), 64, channel, 8] tiles + int channel = input.dims()[1]; + int height = input.dims()[2]; + int width = input.dims()[3]; + int h_tiles = (height + 3) / 6; // (height - 8 + 5 + 6) / 6 + int w_tiles = (width + 3) / 6; // (width - 8 + 5 + 6) / 6 + int tiles = (h_tiles * w_tiles + 7) / 8; + framework::DDim transformed_shape = + framework::make_ddim(std::vector{tiles, 64, channel, 8}); + float *outptr = output->mutable_data(transformed_shape); + memset(outptr, 0, output->numel() * sizeof(float)); + + const float *inptr = input.data(); + int inter_h = (height - 2) / 6; + int inter_w = (width - 2) / 6; + int remain_h = height - (inter_h * 6); + int remain_w = width - (inter_w * 6); + framework::Tensor input_pad; + if (remain_h > 2 || remain_w > 2) { + inter_h += (remain_h > 2); + inter_w += (remain_w > 2); + height = (inter_h - 1) * 6 + 8; + width = (inter_w - 1) * 6 + 8; + framework::DDim input_shape = + framework::make_ddim(std::vector{1, channel, height, width}); + PadFunctor pad; + inptr = input_pad.mutable_data(input_shape); + pad(input, 0, height - input.dims()[2], 0, width - input.dims()[3], + &input_pad); + } + size_t image_size = height * width; + const float transform_matrix[8] = {5.25f, -5.f, -4.25f, -2.5f, + 2.f, -1.25f, 0.5f, 0.25f}; + int remain_c_start = channel & 0xFFFC; +#if 1 + remain_c_start = 0; +#else + #pragma omp parallel for + for (int c = 0; c < channel - 3; c += 4) { + const float *in = inptr + c * image_size; + float d_bt[64 * 4]; // d * B_t + for (int h = 0; h < h_tiles; ++h) { + for (int w = 0; w < w_tiles; ++w) { + const float *in0 = in + (h * width + w) * 6; + const float *in1 = in0 + image_size; + const float *in2 = in1 + image_size; + const float *in3 = in2 + image_size; + int steps = width * sizeof(float); + float *d_bt_ptr = d_bt; + asm volatile( + "mov r0, #8 \n" + "vld1.32 {d0-d3}, [%[tm_ptr]] \n" + // row loop + "loop_r_%=: \n" + "vld1.32 {d4-d7}, [%[in0]], %[steps] \n" + "vld1.32 {d8-d11}, [%[in1]], %[steps] \n" + "vld1.32 {d12-d15}, [%[in2]], %[steps] \n" + "vld1.32 {d16-d19}, [%[in3]], %[steps] \n" + "vtrn.32 q2, q4 \n" // d0: q2 + "vtrn.32 q3, q5 \n" // d1: q4 + "vtrn.32 q6, q8 \n" // d2: q6 + "vtrn.32 q7, q9 \n" // d3: q8 + "vswp.32 d5, d12 \n" // d4: q3 + "vswp.32 d9, d16 \n" // d5: q5 + "vswp.32 d7, d14 \n" // d6: q7 + "vswp.32 d11, d18 \n" // d7: q9 + + "vsub.f32 q10, q2, q7 \n" + "vsub.f32 q11, q3, q6 \n" + "vmla.f32 q10, q11, d0[0] \n" // d0 - d6 + (d4 - + // d2) * 5.25 + "vst1.32 {d20-d21}, [%[d_bt]]! \n" + + "vadd.f32 q10, q6, q7 \n" + "vadd.f32 q11, q4, q5 \n" + "vmla.f32 q10, q3, d1[0] \n" // d2 - 4.25 * d4 + + // d6 + "vmla.f32 q11, q8, d1[0] \n" // d1 - 4.25 * d3 + + // d5 + "vadd.f32 q12, q10, q11 \n" + "vsub.f32 q13, q10, q11 \n" + "vst1.32 {d24-d27}, [%[d_bt]]! \n" + + "vmul.f32 q10, q6, d3[1] \n" // 0.25 * d2 + "vmul.f32 q11, q4, d3[0] \n" // 0.5 * d1 + "vadd.f32 q10, q10, q7 \n" // 0.25 * d2 + d6 + "vmla.f32 q11, q5, d2[0] \n" // 0.5 * d1 + 2 * + // d5 + "vmla.f32 q10, q3, d2[1] \n" // 0.25 * d2 + d6 + // - 1.25 * d4 + "vmla.f32 q11, q8, d1[1] \n" // 0.5 * d1 + 2 * + // d5 - 2.5 * d3 + "vadd.f32 q12, q10, q11 \n" + "vsub.f32 q13, q10, q11 \n" + "vst1.32 {d24-d27}, [%[d_bt]]! \n" + + "vmul.f32 q10, q6, d2[0] \n" // 2 * d2 + "vmul.f32 q11, q4, d2[0] \n" // 2 * d1 + "vmla.f32 q10, q3, d1[1] \n" // 2 * d2 - 2.5 * + // d4 + "vmla.f32 q11, q8, d1[1] \n" // 2 * d1 - 2.5 * + // d3 + "vmla.f32 q10, q7, d3[0] \n" // 2 * d1 - 2.5 * + // d3 + 0.5 * d6 + "vmla.f32 q11, q5, d3[0] \n" // 2 * d2 - 2.5 * + // d4 + 0.5 * d5 + "vmul.f32 q10, q10, d2[0] \n" // 4 * d1 - 5 * d3 + // + d6 + "vadd.f32 q12, q10, q11 \n" + "vsub.f32 q13, q10, q11 \n" + "vst1.32 {d24-d27}, [%[d_bt]]! \n" + + "vsub.f32 q10, q9, q4 \n" + "vsub.f32 q11, q8, q5 \n" + "vmla.f32 q10, q11, d0[0] \n" + "vst1.32 {d20-d21}, [%[d_bt]]! \n" + + "subs r0, #1 \n" + "bne loop_r_%= \n" + : [d_bt] "+r"(d_bt_ptr), [in0] "+r"(in0), [in1] "+r"(in1), + [in2] "+r"(in2), [in3] "+r"(in3) + : [tm_ptr] "r"((float *)transform_matrix), [steps] "r"(steps) + : "cc", "memory", "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7", + "q8", "q9", "q10", "q11", "q12", "q13", "r0"); + + float *ptr0 = d_bt; + float *ptr1 = ptr0 + 32; + float *ptr2 = ptr1 + 32; + float *ptr3 = ptr2 + 32; + float *ptr4 = ptr3 + 32; + float *ptr5 = ptr4 + 32; + float *ptr6 = ptr5 + 32; + float *ptr7 = ptr6 + 32; + int tile_indics = h * w_tiles + w; + int tile_block = tile_indics >> 3; + int block_indics = tile_indics & 0x7; + // (tiles / 8, 64, channel, 8) + float *out0 = + outptr + (tile_block * 64 * channel + c) * 8 + block_indics; + steps = (channel - 3) * 8 * sizeof(float); + asm volatile( + "vld1.32 {d0-d3}, [%[tm_ptr]] \n" + "mov r0, 4 \n" + "mov r1, 32 \n" + "loop_col_%=: \n" + // col 0: + "vld1.32 {d4-d5}, [%[ptr0]]! \n" // q2: d0 + "vld1.32 {d6-d7}, [%[ptr1]]! \n" // q3: d1 + "vld1.32 {d8-d9}, [%[ptr2]]! \n" // q4: d2 + "vld1.32 {d10-d11}, [%[ptr3]]! \n" // q5: d3 + "vld1.32 {d12-d13}, [%[ptr4]]! \n" // q6: d4 + "vld1.32 {d14-d15}, [%[ptr5]]! \n" // q7: d5 + "vld1.32 {d16-d17}, [%[ptr6]]! \n" // q8: d6 + "vld1.32 {d18-d19}, [%[ptr7]]! \n" // q9: d7 + + "vsub.f32 q10, q2, q8 \n" // d0 - d6 + "vsub.f32 q11, q6, q4 \n" // d4 - d2 + "vmla.f32 q10, q11, d0[0] \n" // d0 - d6 + (d4 - + // d2) * 5.25 + "vst1.32 {d20[0]}, [%[out0]], r1 \n" + "vst1.32 {d20[1]}, [%[out0]], r1 \n" + "vst1.32 {d21[0]}, [%[out0]], r1 \n" + "vst1.32 {d21[1]}, [%[out0]], %[steps] \n" + + "vadd.f32 q10, q4, q8 \n" + "vadd.f32 q11, q3, q7 \n" + "vmla.f32 q10, q6, d1[0] \n" // d2 - 4.25 * d4 + + // d6 + "vmla.f32 q11, q5, d1[0] \n" // d1 - 4.25 * d3 + + // d5 + "vadd.f32 q12, q10, q11 \n" + "vst1.32 {d24[0]}, [%[out0]], r1 \n" + "vst1.32 {d24[1]}, [%[out0]], r1 \n" + "vst1.32 {d25[0]}, [%[out0]], r1 \n" + "vst1.32 {d25[1]}, [%[out0]], %[steps] \n" + "vsub.f32 q12, q10, q11 \n" + "vst1.32 {d24[0]}, [%[out0]], r1 \n" + "vst1.32 {d24[1]}, [%[out0]], r1 \n" + "vst1.32 {d25[0]}, [%[out0]], r1 \n" + "vst1.32 {d25[1]}, [%[out0]], %[steps] \n" + + "vmul.f32 q10, q4, d3[1] \n" // 0.25 * d2 + "vmul.f32 q11, q3, d3[0] \n" // 0.5 * d1 + "vadd.f32 q10, q10, q8 \n" // 0.25 * d2 + d6 + "vmla.f32 q11, q7, d2[0] \n" // 0.5 * d1 + 2 * + // d5 + "vmla.f32 q10, q6, d2[1] \n" // 0.25 * d2 + d6 + // - 1.25 * d4 + "vmla.f32 q11, q5, d1[1] \n" // 0.5 * d1 + 2 * + // d5 - 2.5 * d3 + "vadd.f32 q12, q10, q11 \n" + "vst1.32 {d24[0]}, [%[out0]], r1 \n" + "vst1.32 {d24[1]}, [%[out0]], r1 \n" + "vst1.32 {d25[0]}, [%[out0]], r1 \n" + "vst1.32 {d25[1]}, [%[out0]], %[steps] \n" + "vsub.f32 q12, q10, q11 \n" + "vst1.32 {d24[0]}, [%[out0]], r1 \n" + "vst1.32 {d24[1]}, [%[out0]], r1 \n" + "vst1.32 {d25[0]}, [%[out0]], r1 \n" + "vst1.32 {d25[1]}, [%[out0]], %[steps] \n" + + "vmul.f32 q10, q4, d2[0] \n" // 2 * d2 + "vmul.f32 q11, q3, d2[0] \n" // 2 * d1 + "vmla.f32 q10, q6, d1[1] \n" // 2 * d2 - 2.5 * + // d4 + "vmla.f32 q11, q5, d1[1] \n" // 2 * d1 - 2.5 * + // d3 + "vmla.f32 q10, q8, d3[0] \n" // 2 * d1 - 2.5 * + // d3 + 0.5 * d6 + "vmla.f32 q11, q7, d3[0] \n" // 2 * d2 - 2.5 * + // d4 + 0.5 * d5 + "vmul.f32 q10, q10, d2[0] \n" // 4 * d1 - 5 * d3 + // + d6 + "vadd.f32 q12, q10, q11 \n" + "vst1.32 {d24[0]}, [%[out0]], r1 \n" + "vst1.32 {d24[1]}, [%[out0]], r1 \n" + "vst1.32 {d25[0]}, [%[out0]], r1 \n" + "vst1.32 {d25[1]}, [%[out0]], %[steps] \n" + "vsub.f32 q12, q10, q11 \n" + "vst1.32 {d24[0]}, [%[out0]], r1 \n" + "vst1.32 {d24[1]}, [%[out0]], r1 \n" + "vst1.32 {d25[0]}, [%[out0]], r1 \n" + "vst1.32 {d25[1]}, [%[out0]], %[steps] \n" + + "vsub.f32 q10, q9, q3 \n" + "vsub.f32 q11, q5, q7 \n" + "vmla.f32 q10, q11, d0[0] \n" + "vst1.32 {d20[0]}, [%[out0]], r1 \n" + "vst1.32 {d20[1]}, [%[out0]], r1 \n" + "vst1.32 {d21[0]}, [%[out0]], r1 \n" + "vst1.32 {d21[1]}, [%[out0]], %[steps] \n" + + // col 1: + "vld1.32 {d4-d5}, [%[ptr0]]! \n" // q2: d0 + "vld1.32 {d6-d7}, [%[ptr1]]! \n" // q3: d1 + "vld1.32 {d8-d9}, [%[ptr2]]! \n" // q4: d2 + "vld1.32 {d10-d11}, [%[ptr3]]! \n" // q5: d3 + "vld1.32 {d12-d13}, [%[ptr4]]! \n" // q6: d4 + "vld1.32 {d14-d15}, [%[ptr5]]! \n" // q7: d5 + "vld1.32 {d16-d17}, [%[ptr6]]! \n" // q8: d6 + "vld1.32 {d18-d19}, [%[ptr7]]! \n" // q9: d7 + + "vsub.f32 q10, q2, q8 \n" // d0 - d6 + "vsub.f32 q11, q6, q4 \n" // d4 - d2 + "vmla.f32 q10, q11, d0[0] \n" // d0 - d6 + (d4 - + // d2) * 5.25 + "vst1.32 {d20[0]}, [%[out0]], r1 \n" + "vst1.32 {d20[1]}, [%[out0]], r1 \n" + "vst1.32 {d21[0]}, [%[out0]], r1 \n" + "vst1.32 {d21[1]}, [%[out0]], %[steps] \n" + + "vadd.f32 q10, q4, q8 \n" + "vadd.f32 q11, q3, q7 \n" + "vmla.f32 q10, q6, d1[0] \n" // d2 - 4.25 * d4 + + // d6 + "vmla.f32 q11, q5, d1[0] \n" // d1 - 4.25 * d3 + + // d5 + "vadd.f32 q12, q10, q11 \n" + "vst1.32 {d24[0]}, [%[out0]], r1 \n" + "vst1.32 {d24[1]}, [%[out0]], r1 \n" + "vst1.32 {d25[0]}, [%[out0]], r1 \n" + "vst1.32 {d25[1]}, [%[out0]], %[steps] \n" + + "vsub.f32 q12, q10, q11 \n" + "vst1.32 {d24[0]}, [%[out0]], r1 \n" + "vst1.32 {d24[1]}, [%[out0]], r1 \n" + "vst1.32 {d25[0]}, [%[out0]], r1 \n" + "vst1.32 {d25[1]}, [%[out0]], %[steps] \n" + + "vmul.f32 q10, q4, d3[1] \n" // 0.25 * d2 + "vmul.f32 q11, q3, d3[0] \n" // 0.5 * d1 + "vadd.f32 q10, q10, q8 \n" // 0.25 * d2 + d6 + "vmla.f32 q11, q7, d2[0] \n" // 0.5 * d1 + 2 * + // d5 + "vmla.f32 q10, q6, d2[1] \n" // 0.25 * d2 + d6 + // - 1.25 * d4 + "vmla.f32 q11, q5, d1[1] \n" // 0.5 * d1 + 2 * + // d5 - 2.5 * d3 + "vadd.f32 q12, q10, q11 \n" + "vst1.32 {d24[0]}, [%[out0]], r1 \n" + "vst1.32 {d24[1]}, [%[out0]], r1 \n" + "vst1.32 {d25[0]}, [%[out0]], r1 \n" + "vst1.32 {d25[1]}, [%[out0]], %[steps] \n" + + "vsub.f32 q12, q10, q11 \n" + "vst1.32 {d24[0]}, [%[out0]], r1 \n" + "vst1.32 {d24[1]}, [%[out0]], r1 \n" + "vst1.32 {d25[0]}, [%[out0]], r1 \n" + "vst1.32 {d25[1]}, [%[out0]], %[steps] \n" + + "vmul.f32 q10, q4, d2[0] \n" // 2 * d2 + "vmul.f32 q11, q3, d2[0] \n" // 2 * d1 + "vmla.f32 q10, q6, d1[1] \n" // 2 * d2 - 2.5 * + // d4 + "vmla.f32 q11, q5, d1[1] \n" // 2 * d1 - 2.5 * + // d3 + "vmla.f32 q10, q8, d3[0] \n" // 2 * d1 - 2.5 * + // d3 + 0.5 * d6 + "vmla.f32 q11, q7, d3[0] \n" // 2 * d2 - 2.5 * + // d4 + 0.5 * d5 + "vmul.f32 q10, q10, d2[0] \n" // 4 * d1 - 5 * d3 + // + d6 + "vadd.f32 q12, q10, q11 \n" + "vst1.32 {d24[0]}, [%[out0]], r1 \n" + "vst1.32 {d24[1]}, [%[out0]], r1 \n" + "vst1.32 {d25[0]}, [%[out0]], r1 \n" + "vst1.32 {d25[1]}, [%[out0]], %[steps] \n" + + "vsub.f32 q12, q10, q11 \n" + "vst1.32 {d24[0]}, [%[out0]], r1 \n" + "vst1.32 {d24[1]}, [%[out0]], r1 \n" + "vst1.32 {d25[0]}, [%[out0]], r1 \n" + "vst1.32 {d25[1]}, [%[out0]], %[steps] \n" + + "vsub.f32 q10, q9, q3 \n" + "vsub.f32 q11, q5, q7 \n" + "vmla.f32 q10, q11, d0[0] \n" + "vst1.32 {d20[0]}, [%[out0]], r1 \n" + "vst1.32 {d20[1]}, [%[out0]], r1 \n" + "vst1.32 {d21[0]}, [%[out0]], r1 \n" + "vst1.32 {d21[1]}, [%[out0]], %[steps] \n" + + "subs r0, #1 \n" + "bne loop_col_%= \n" + : [out0] "+r"(out0), [ptr0] "+r"(ptr0), [ptr1] "+r"(ptr1), + [ptr2] "+r"(ptr2), [ptr3] "+r"(ptr3), [ptr4] "+r"(ptr4), + [ptr5] "+r"(ptr5), [ptr6] "+r"(ptr6), [ptr7] "+r"(ptr7) + : [tm_ptr] "r"((float *)transform_matrix), [steps] "r"(steps) + : "cc", "memory", "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7", + "q8", "q9", "q10", "q11", "q12", "q13", "r0", "r1"); + } + } + } +#endif + + // remainer channels + #pragma omp parallel for + for (int c = remain_c_start; c < channel; ++c) { + const float *in = inptr + c * image_size; + float d_bt[64]; // d * B_t + for (int h = 0; h < h_tiles; ++h) { + for (int w = 0; w < w_tiles; ++w) { + const float *in0 = in + (h * width + w) * 6; + const float *in1 = in0 + width; + const float *in2 = in1 + width; + const float *in3 = in2 + width; + float *d_bt_ptr = d_bt; + int steps = 4 * width * sizeof(float); + asm volatile( + "vld1.32 {d0-d3}, [%[tm_ptr]] \n" + "mov r0, #2 \n" + // row loop + "loop_r_%=: \n" + "vld1.32 {d4-d7}, [%[in0]], %[steps] \n" + "vld1.32 {d8-d11}, [%[in1]], %[steps] \n" + "vld1.32 {d12-d15}, [%[in2]], %[steps] \n" + "vld1.32 {d16-d19}, [%[in3]], %[steps] \n" + "vtrn.32 q2, q4 \n" // d0: q2 + "vtrn.32 q3, q5 \n" // d1: q4 + "vtrn.32 q6, q8 \n" // d2: q6 + "vtrn.32 q7, q9 \n" // d3: q8 + "vswp.32 d5, d12 \n" // d4: q3 + "vswp.32 d9, d16 \n" // d5: q5 + "vswp.32 d7, d14 \n" // d6: q7 + "vswp.32 d11, d18 \n" // d7: q9 + + "vsub.f32 q10, q2, q7 \n" + "vsub.f32 q11, q3, q6 \n" + "vmla.f32 q10, q11, d0[0] \n" // d0 - d6 + (d4 - + // d2) * 5.25" + "vst1.32 {d20-d21}, [%[d_bt]]! \n" + + "vadd.f32 q10, q6, q7 \n" + "vadd.f32 q11, q4, q5 \n" + "vmla.f32 q10, q3, d1[0] \n" // d2 - 4.25 * d4 + + // d6 + "vmla.f32 q11, q8, d1[0] \n" // d1 - 4.25 * d3 + + // d5 + "vadd.f32 q12, q10, q11 \n" + "vsub.f32 q13, q10, q11 \n" + "vst1.32 {d24-d27}, [%[d_bt]]! \n" + + "vmul.f32 q10, q6, d3[1] \n" // 0.25 * d2 + "vmul.f32 q11, q4, d3[0] \n" // 0.5 * d1 + "vadd.f32 q10, q10, q7 \n" // 0.25 * d2 + d6 + "vmla.f32 q11, q5, d2[0] \n" // 0.5 * d1 + 2 * + // d5 + "vmla.f32 q10, q3, d2[1] \n" // 0.25 * d2 + d6 + // - 1.25 * d4 + "vmla.f32 q11, q8, d1[1] \n" // 0.5 * d1 + 2 * + // d5 - 2.5 * d3 + "vadd.f32 q12, q10, q11 \n" + "vsub.f32 q13, q10, q11 \n" + "vst1.32 {d24-d27}, [%[d_bt]]! \n" + + "vmul.f32 q10, q6, d2[0] \n" // 2 * d2 + "vmul.f32 q11, q4, d2[0] \n" // 2 * d1 + "vmla.f32 q10, q3, d1[1] \n" // 2 * d2 - 2.5 * + // d4 + "vmla.f32 q11, q8, d1[1] \n" // 2 * d1 - 2.5 * + // d3 + "vmla.f32 q10, q7, d3[0] \n" // 2 * d1 - 2.5 * + // d3 + 0.5 * d6 + "vmla.f32 q11, q5, d3[0] \n" // 2 * d2 - 2.5 * + // d4 + 0.5 * d5 + "vmul.f32 q10, q10, d2[0] \n" // 4 * d1 - 5 * d3 + // + d6 + "vadd.f32 q12, q10, q11 \n" + "vsub.f32 q13, q10, q11 \n" + "vst1.32 {d24-d27}, [%[d_bt]]! \n" + + "vsub.f32 q10, q9, q4 \n" + "vsub.f32 q11, q8, q5 \n" + "vmla.f32 q10, q11, d0[0] \n" + "vst1.32 {d20-d21}, [%[d_bt]]! \n" + + "subs r0, #1 \n" + "bne loop_r_%= \n" + : [d_bt] "+r"(d_bt_ptr), [in0] "+r"(in0), [in1] "+r"(in1), + [in2] "+r"(in2), [in3] "+r"(in3) + : [tm_ptr] "r"((float *)transform_matrix), [steps] "r"(steps) + : "cc", "memory", "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7", + "q8", "q9", "q10", "q11", "q12", "q13", "r0"); + + float *ptr0 = d_bt; + float *ptr1 = ptr0 + 32; + int tile_indics = h * w_tiles + w; + int tile_block = tile_indics >> 3; + int block_indics = tile_indics & 0x7; + // (tiles / 8, 64, channel, 8) + float *out0 = + outptr + (tile_block * 64 * channel + c) * 8 + block_indics; + float *out1 = out0 + channel * 8; + float *out2 = out1 + channel * 8; + float *out3 = out2 + channel * 8; + float *out4 = out3 + channel * 8; + float *out5 = out4 + channel * 8; + float *out6 = out5 + channel * 8; + float *out7 = out6 + channel * 8; + steps = 8 * channel * 8 * sizeof(float); + asm volatile( + "mov r0, #2 \n" + "vld1.32 {d0-d3}, [%[tm_ptr]] \n" + // row loop + "loop_r_%=: \n" + "vld1.32 {d4-d7}, [%[ptr0]]! \n" // q2: d0, q3: d1 + "vld1.32 {d8-d11}, [%[ptr0]]! \n" // q4: d2, q5: d3 + "vld1.32 {d12-d15}, [%[ptr1]]! \n" // q6: d4, q7: d5 + "vld1.32 {d16-d19}, [%[ptr1]]! \n" // q8: d6, q9: d7 + "vtrn.32 q2, q3 \n" + "vtrn.32 q4, q5 \n" + "vtrn.32 q6, q7 \n" + "vtrn.32 q8, q9 \n" + "vswp.32 d5, d8 \n" + "vswp.32 d7, d10 \n" + "vswp.32 d13, d16 \n" + "vswp.32 d15, d18 \n" + + "vsub.f32 q10, q2, q8 \n" // d0 - d6 + "vsub.f32 q11, q6, q4 \n" // d4 - d2 + "vmla.f32 q10, q11, d0[0] \n" // d0 - d6 + (d4 - + // d2) * 5.25 + "vst1.32 {d20[0]}, [%[out0]], %[steps] \n" + "vst1.32 {d20[1]}, [%[out0]], %[steps] \n" + "vst1.32 {d21[0]}, [%[out0]], %[steps] \n" + "vst1.32 {d21[1]}, [%[out0]], %[steps] \n" + + "vadd.f32 q10, q4, q8 \n" + "vadd.f32 q11, q3, q7 \n" + "vmla.f32 q10, q6, d1[0] \n" // d2 - 4.25 * d4 + + // d6 + "vmla.f32 q11, q5, d1[0] \n" // d1 - 4.25 * d3 + + // d5 + "vadd.f32 q12, q10, q11 \n" + "vst1.32 {d24[0]}, [%[out1]], %[steps] \n" + "vst1.32 {d24[1]}, [%[out1]], %[steps] \n" + "vst1.32 {d25[0]}, [%[out1]], %[steps] \n" + "vst1.32 {d25[1]}, [%[out1]], %[steps] \n" + "vsub.f32 q12, q10, q11 \n" + "vst1.32 {d24[0]}, [%[out2]], %[steps] \n" + "vst1.32 {d24[1]}, [%[out2]], %[steps] \n" + "vst1.32 {d25[0]}, [%[out2]], %[steps] \n" + "vst1.32 {d25[1]}, [%[out2]], %[steps] \n" + + "vmul.f32 q10, q4, d3[1] \n" // 0.25 * d2 + "vmul.f32 q11, q3, d3[0] \n" // 0.5 * d1 + "vadd.f32 q10, q10, q8 \n" // 0.25 * d2 + d6 + "vmla.f32 q11, q7, d2[0] \n" // 0.5 * d1 + 2 * + // d5 + "vmla.f32 q10, q6, d2[1] \n" // 0.25 * d2 + d6 + // - 1.25 * d4 + "vmla.f32 q11, q5, d1[1] \n" // 0.5 * d1 + 2 * + // d5 - 2.5 * d3 + "vadd.f32 q12, q10, q11 \n" + "vst1.32 {d24[0]}, [%[out3]], %[steps] \n" + "vst1.32 {d24[1]}, [%[out3]], %[steps] \n" + "vst1.32 {d25[0]}, [%[out3]], %[steps] \n" + "vst1.32 {d25[1]}, [%[out3]], %[steps] \n" + "vsub.f32 q12, q10, q11 \n" + "vst1.32 {d24[0]}, [%[out4]], %[steps] \n" + "vst1.32 {d24[1]}, [%[out4]], %[steps] \n" + "vst1.32 {d25[0]}, [%[out4]], %[steps] \n" + "vst1.32 {d25[1]}, [%[out4]], %[steps] \n" + + "vmul.f32 q10, q4, d2[0] \n" // 2 * d2 + "vmul.f32 q11, q3, d2[0] \n" // 2 * d1 + "vmla.f32 q10, q6, d1[1] \n" // 2 * d2 - 2.5 * + // d4 + "vmla.f32 q11, q5, d1[1] \n" // 2 * d1 - 2.5 * + // d3 + "vmla.f32 q10, q8, d3[0] \n" // 2 * d1 - 2.5 * + // d3 + 0.5 * d6 + "vmla.f32 q11, q7, d3[0] \n" // 2 * d2 - 2.5 * + // d4 + 0.5 * d5 + "vmul.f32 q10, q10, d2[0] \n" // 4 * d1 - 5 * d3 + // + d6 + "vadd.f32 q12, q10, q11 \n" + "vst1.32 {d24[0]}, [%[out5]], %[steps] \n" + "vst1.32 {d24[1]}, [%[out5]], %[steps] \n" + "vst1.32 {d25[0]}, [%[out5]], %[steps] \n" + "vst1.32 {d25[1]}, [%[out5]], %[steps] \n" + "vsub.f32 q12, q10, q11 \n" + "vst1.32 {d24[0]}, [%[out6]], %[steps] \n" + "vst1.32 {d24[1]}, [%[out6]], %[steps] \n" + "vst1.32 {d25[0]}, [%[out6]], %[steps] \n" + "vst1.32 {d25[1]}, [%[out6]], %[steps] \n" + + "vsub.f32 q10, q9, q3 \n" + "vsub.f32 q11, q5, q7 \n" + "vmla.f32 q10, q11, d0[0] \n" + "vst1.32 {d20[0]}, [%[out7]], %[steps] \n" + "vst1.32 {d20[1]}, [%[out7]], %[steps] \n" + "vst1.32 {d21[0]}, [%[out7]], %[steps] \n" + "vst1.32 {d21[1]}, [%[out7]], %[steps] \n" + + "subs r0, #1 \n" + "bne loop_r_%= \n" + : [out0] "+r"(out0), [out1] "+r"(out1), [out2] "+r"(out2), + [out3] "+r"(out3), [out4] "+r"(out4), [out5] "+r"(out5), + [out6] "+r"(out6), [out7] "+r"(out7), [ptr0] "+r"(ptr0), + [ptr1] "+r"(ptr1) + : [tm_ptr] "r"((float *)transform_matrix), [steps] "r"(steps) + : "cc", "memory", "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7", + "q8", "q9", "q10", "q11", "q12", "q13", "r0"); + } + } + } +} + +template <> +void winograd_transform_output<8, 3>(const framework::Tensor &input, + const framework::Tensor &weight, + framework::Tensor *output) { + // weight shape is [out_channel/4, 64, in_channel, 4], + // input shape is [hw/8, 64, in_channel, 8] + int in_channel = input.dims()[2]; + int tiles = input.dims()[0]; + int out_channel = weight.dims()[0]; + + // compute U*V first + framework::Tensor uv_trans; + framework::DDim shape = + framework::make_ddim(std::vector{out_channel, tiles, 64, 32}); + float *uv_trans_ptr = uv_trans.mutable_data(shape); + memset(uv_trans_ptr, 0, uv_trans.numel() * sizeof(float)); + const float *input_ptr = input.data(); + const float *weight_ptr = weight.data(); + + #pragma omp parallel for + for (int i = 0; i < out_channel; ++i) { + float *uv_ptr = uv_trans_ptr + (i * tiles * 64 * 32); + for (int j = 0; j < tiles; ++j) { + for (int k = 0; k < 64; ++k) { + const float *w_ptr = weight_ptr + (i * 64 + k) * in_channel * 4; + const float *in_ptr = input_ptr + (j * 64 + k) * in_channel * 8; + int inter_channel = in_channel >> 1; + int remain_channel = in_channel & 0x1; + asm volatile( + "veor q8, q8, q8 \n" + "veor q9, q9, q9 \n" + "veor q10, q10, q10 \n" + "veor q11, q11, q11 \n" + "veor q12, q12, q12 \n" + "veor q13, q13, q13 \n" + "veor q14, q14, q14 \n" + "veor q15, q15, q15 \n" + + "b store_res_%= \n" + // loop 2 channels + "loop_2c_%=: \n" + "vld1.32 {d0-d3}, [%[w_ptr]]! \n" + "vld1.32 {d4-d7}, [%[in_ptr]]! \n" + "vld1.32 {d8-d11}, [%[in_ptr]]! \n" + "vmla.f32 q8, q2, d0[0] \n" + "vmla.f32 q9, q3, d0[0] \n" + "vmla.f32 q10, q2, d0[1] \n" + "vmla.f32 q11, q3, d0[1] \n" + "vmla.f32 q12, q2, d1[0] \n" + "vmla.f32 q13, q3, d1[0] \n" + "vmla.f32 q14, q2, d1[1] \n" + "vmla.f32 q15, q3, d1[1] \n" + + "vmla.f32 q8, q4, d2[0] \n" + "vmla.f32 q9, q5, d2[0] \n" + "vmla.f32 q10, q4, d2[1] \n" + "vmla.f32 q11, q5, d2[1] \n" + "vmla.f32 q12, q4, d3[0] \n" + "vmla.f32 q13, q5, d3[0] \n" + "vmla.f32 q14, q4, d3[1] \n" + "vmla.f32 q15, q5, d3[1] \n" + + "subs %[inter_channel], #1 \n" + "bne loop_2c_%= \n" + "mov pc, lr \n" + + // loop 1 channel + "loop_c_%=: \n" + "vld1.32 {d0-d1}, [%[w_ptr]]! \n" + "vld1.32 {d4-d7}, [%[in_ptr]]! \n" + + "vmla.f32 q8, q2, d0[0] \n" + "vmla.f32 q9, q3, d0[0] \n" + "vmla.f32 q10, q2, d0[1] \n" + "vmla.f32 q11, q3, d0[1] \n" + "vmla.f32 q12, q2, d1[0] \n" + "vmla.f32 q13, q3, d1[0] \n" + "vmla.f32 q14, q2, d1[1] \n" + "vmla.f32 q15, q3, d1[1] \n" + + "subs %[remain_channel], #1 \n" + "bne loop_c_%= \n" + "mov pc, lr \n" + + "store_res_%=: \n" + "cmp %[inter_channel], #0 \n" + "it gt \n" + "blgt loop_2c_%= \n" + "cmp %[remain_channel], #0 \n" + "it gt \n" + "blgt loop_c_%= \n" + + "vst1.32 {d16-d19}, [%[uv_ptr]]! \n" + "vst1.32 {d20-d23}, [%[uv_ptr]]! \n" + "vst1.32 {d24-d27}, [%[uv_ptr]]! \n" + "vst1.32 {d28-d31}, [%[uv_ptr]]! \n" + : [w_ptr] "+r"(w_ptr), [in_ptr] "+r"(in_ptr), [uv_ptr] "+r"(uv_ptr), + [remain_channel] "+r"(remain_channel), + [inter_channel] "+r"(inter_channel) + : + : "cc", "memory", "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7", + "q8", "q9", "q10", "q11", "q12", "q13", "q14", "q15", "pc", "lr"); + } + } + } + + /* + * s0 = m0 + (m1 + m2) + (m3 + m4) + 32 * (m5 + m6) + * s1 = (m1 - m2) + 2 * (m3 - m4) + 16 * (m5 - m6) + * s2 = (m1 + m2) + 4 * (m3 + m4) + 8 * (m5 + m6) + * s3 = (m1 - m2) + 8 * (m3 - m4) + 4 * (m5 - m6) + * s4 = (m1 + m2) + 16 * (m3 + m4) + 2 * (m5 + m6) + * s5 = (m1 - m2) + 32 * (m3 - m4) + (m5 - m6) + m7 + */ + int out_h = output->dims()[2]; + int out_w = output->dims()[3]; + int h_tiles = (out_h + 5) / 6; + int w_tiles = (out_w + 5) / 6; + int remain_h = out_h - out_h / 6 * 6; + int remain_w = out_w - out_w / 6 * 6; + float *output_ptr = output->mutable_data(); + float transform_matrix[8] = {2.f, 4.f, 8.f, 16.f}; + + #pragma omp parallel for + for (int oc = 0; oc < output->dims()[1]; ++oc) { + float at_m[48]; // [6][8] + float output_tmp[36]; // [6][6], temporarily restore results + // (oc / 4) * tiles * 64 * 32 + (oc & 0x3) * 8 + const float *uv_ptr = + uv_trans_ptr + (oc >> 2) * tiles * 64 * 32 + (oc & 0x3) * 8; + for (int tile_h = 0; tile_h < h_tiles; ++tile_h) { + for (int tile_w = 0; tile_w < w_tiles; ++tile_w) { + float *at_m_ptr = at_m; + int tile_indics = tile_h * w_tiles + tile_w; + int tile_block = tile_indics >> 3; + int block_indics = tile_indics & 0x7; + const float *uv_ptr0 = uv_ptr + tile_block * 64 * 32 + block_indics; + int steps = 32 * sizeof(float); + asm volatile( + "vld1.32 {d0-d1}, [%[tm_ptr]] \n" + "mov r0, #2 \n" + + "loop_%=: \n" + "vld1.32 {d2[0]}, [%[uv_ptr0]], %[steps] \n" + "vld1.32 {d6[0]}, [%[uv_ptr0]], %[steps] \n" + "vld1.32 {d10[0]}, [%[uv_ptr0]], %[steps] \n" + "vld1.32 {d14[0]}, [%[uv_ptr0]], %[steps] \n" + "vld1.32 {d4[0]}, [%[uv_ptr0]], %[steps] \n" + "vld1.32 {d8[0]}, [%[uv_ptr0]], %[steps] \n" + "vld1.32 {d12[0]}, [%[uv_ptr0]], %[steps] \n" + "vld1.32 {d16[0]}, [%[uv_ptr0]], %[steps] \n" + + "vld1.32 {d2[1]}, [%[uv_ptr0]], %[steps] \n" + "vld1.32 {d6[1]}, [%[uv_ptr0]], %[steps] \n" + "vld1.32 {d10[1]}, [%[uv_ptr0]], %[steps] \n" + "vld1.32 {d14[1]}, [%[uv_ptr0]], %[steps] \n" + "vld1.32 {d4[1]}, [%[uv_ptr0]], %[steps] \n" + "vld1.32 {d8[1]}, [%[uv_ptr0]], %[steps] \n" + "vld1.32 {d12[1]}, [%[uv_ptr0]], %[steps] \n" + "vld1.32 {d16[1]}, [%[uv_ptr0]], %[steps] \n" + + "vld1.32 {d3[0]}, [%[uv_ptr0]], %[steps] \n" + "vld1.32 {d7[0]}, [%[uv_ptr0]], %[steps] \n" + "vld1.32 {d11[0]}, [%[uv_ptr0]], %[steps] \n" + "vld1.32 {d15[0]}, [%[uv_ptr0]], %[steps] \n" + "vld1.32 {d5[0]}, [%[uv_ptr0]], %[steps] \n" + "vld1.32 {d9[0]}, [%[uv_ptr0]], %[steps] \n" + "vld1.32 {d13[0]}, [%[uv_ptr0]], %[steps] \n" + "vld1.32 {d17[0]}, [%[uv_ptr0]], %[steps] \n" + + "vld1.32 {d3[1]}, [%[uv_ptr0]], %[steps] \n" + "vld1.32 {d7[1]}, [%[uv_ptr0]], %[steps] \n" + "vld1.32 {d11[1]}, [%[uv_ptr0]], %[steps] \n" + "vld1.32 {d15[1]}, [%[uv_ptr0]], %[steps] \n" + "vld1.32 {d5[1]}, [%[uv_ptr0]], %[steps] \n" + "vld1.32 {d9[1]}, [%[uv_ptr0]], %[steps] \n" + "vld1.32 {d13[1]}, [%[uv_ptr0]], %[steps] \n" + "vld1.32 {d17[1]}, [%[uv_ptr0]], %[steps] \n" + + "vadd.f32 q9, q3, q5 \n" // m1 + m2 + "vadd.f32 q10, q7, q2 \n" // m3 + m4 + "vadd.f32 q11, q4, q6 \n" // m5 + m6 + "vsub.f32 q12, q3, q5 \n" // m1 - m2 + "vsub.f32 q13, q7, q2 \n" // m3 - m4 + "vsub.f32 q14, q4, q6 \n" // m5 - m6 + "vmul.f32 q2, q13, d0[0] \n" // 2 * (m3 - m4) + "vmul.f32 q3, q11, d0[0] \n" // 2 * (m5 + m6) + + "vadd.f32 q15, q1, q9 \n" + "vadd.f32 q15, q15, q10 \n" + "vmla.f32 q15, q3, d1[1] \n" + "vst1.32 {d30-d31}, [%[at_m_ptr]]! \n" + + "vadd.f32 q15, q12, q2 \n" + "vmla.f32 q15, q14, d1[1] \n" + "vst1.32 {d30-d31}, [%[at_m_ptr]]! \n" + + "vmov.32 q15, q9 \n" + "vmla.f32 q15, q10, d0[1] \n" + "vmla.f32 q15, q11, d1[0] \n" + "vst1.32 {d30-d31}, [%[at_m_ptr]]! \n" + + "vmov.32 q15, q12 \n" + "vmla.f32 q15, q13, d1[0] \n" + "vmla.f32 q15, q14, d0[1] \n" + "vst1.32 {d30-d31}, [%[at_m_ptr]]! \n" + + "vadd.f32 q15, q9, q3 \n" + "vmla.f32 q15, q10, d1[1] \n" + "vst1.32 {d30-d31}, [%[at_m_ptr]]! \n" + + "vadd.f32 q15, q12, q8 \n" + "vadd.f32 q15, q15, q14 \n" + "vmla.f32 q15, q2, d1[1] \n" + "vst1.32 {d30-d31}, [%[at_m_ptr]]! \n" + + "subs r0, #1 \n" + "bne loop_%= \n" + : [uv_ptr0] "+r"(uv_ptr0), [at_m_ptr] "+r"(at_m_ptr) + : [tm_ptr] "r"((float *)transform_matrix), [steps] "r"(steps) + : "cc", "memory", "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7", + "q8", "q9", "q10", "q11", "q12", "q13", "q14", "q15", "r0"); + + float *at_m_ptr0 = at_m; + float *at_m_ptr1 = at_m + 24; + if ((remain_w > 0 && tile_w == w_tiles - 1) || + (remain_h > 0 && tile_h == h_tiles - 1)) { + float *out_ptr0 = output_tmp; + float *out_ptr1 = output_tmp + 6; + float *out_ptr2 = output_tmp + 12; + float *out_ptr3 = output_tmp + 18; + float *out_ptr4 = output_tmp + 24; + float *out_ptr5 = output_tmp + 30; + asm volatile( + "vld1.32 {d0-d1}, [%[tm_ptr]] \n" + // process 4 rows + "vld1.32 {d2-d5}, [%[at_m_ptr0]]! \n" // q1: m0, q2: m1 + "vld1.32 {d6-d9}, [%[at_m_ptr0]]! \n" // q3: m2, q4: m3 + "vld1.32 {d10-d13}, [%[at_m_ptr1]]! \n" // q5: m4, q6: m5 + "vld1.32 {d14-d17}, [%[at_m_ptr1]]! \n" // q7: m6, q8: m7 + "vtrn.32 q1, q2 \n" + "vtrn.32 q3, q4 \n" + "vtrn.32 q5, q6 \n" + "vtrn.32 q7, q8 \n" + "vswp.32 d3, d6 \n" + "vswp.32 d5, d8 \n" + "vswp.32 d11, d14 \n" + "vswp.32 d13, d16 \n" + + "vadd.f32 q9, q2, q3 \n" // m1 + m2 + "vadd.f32 q10, q4, q5 \n" // m3 + m4 + "vadd.f32 q11, q6, q7 \n" // m5 + m6 + "vsub.f32 q12, q2, q3 \n" // m1 - m2 + "vsub.f32 q13, q4, q5 \n" // m3 - m4 + "vsub.f32 q14, q6, q7 \n" // m5 - m6 + "vmul.f32 q6, q13, d0[0] \n" // 2 * (m3 - m4) + "vmul.f32 q7, q11, d0[0] \n" // 2 * (m5 + m6) + + "vadd.f32 q1, q1, q9 \n" + "vadd.f32 q1, q1, q10 \n" + "vmla.f32 q1, q7, d1[1] \n" + + "vadd.f32 q2, q12, q6 \n" + "vmla.f32 q2, q14, d1[1] \n" + + "vmov.32 q3, q9 \n" + "vmla.f32 q3, q10, d0[1] \n" + "vmla.f32 q3, q11, d1[0] \n" + + "vmov.32 q4, q12 \n" + "vmla.f32 q4, q13, d1[0] \n" + "vmla.f32 q4, q14, d0[1] \n" + + "vtrn.32 q1, q2 \n" + "vtrn.32 q3, q4 \n" + "vswp.32 d3, d6 \n" + "vswp.32 d5, d8 \n" + "vst1.32 {d2-d3}, [%[out_ptr0]]! \n" + "vst1.32 {d4-d5}, [%[out_ptr1]]! \n" + "vst1.32 {d6-d7}, [%[out_ptr2]]! \n" + "vst1.32 {d8-d9}, [%[out_ptr3]]! \n" + + "vadd.f32 q1, q9, q7 \n" + "vmla.f32 q1, q10, d1[1] \n" + + "vadd.f32 q2, q12, q8 \n" + "vadd.f32 q2, q2, q14 \n" + "vmla.f32 q2, q6, d1[1] \n" + + "vtrn.32 q1, q2 \n" + "vst1.32 {d2}, [%[out_ptr0]]! \n" + "vst1.32 {d4}, [%[out_ptr1]]! \n" + "vst1.32 {d3}, [%[out_ptr2]]! \n" + "vst1.32 {d5}, [%[out_ptr3]]! \n" + + // remain 2 rows + "vld1.32 {d2-d5}, [%[at_m_ptr0]]! \n" // d2: m0, d3: m2, + // d4: m1, d5: m3 + "vld1.32 {d6-d9}, [%[at_m_ptr1]]! \n" // d6: m4, d7: m6, + // d8: m5, d9: m7 + "vtrn.32 q1, q2 \n" + "vtrn.32 q3, q4 \n" + + "vadd.f32 d10, d4, d3 \n" // m1 + m2 + "vadd.f32 d11, d5, d6 \n" // m3 + m4 + "vadd.f32 d12, d8, d7 \n" // m5 + m6 + "vsub.f32 d13, d4, d3 \n" // m1 - m2 + "vsub.f32 d14, d5, d6 \n" // m3 - m4 + "vsub.f32 d15, d8, d7 \n" // m5 - m6 + "vmul.f32 d16, d14, d0[0] \n" // 2 * (m3 - m4) + "vmul.f32 d17, d12, d0[0] \n" // 2 * (m5 + m6) + + "vadd.f32 d18, d2, d10 \n" + "vadd.f32 d18, d18, d11 \n" + "vmla.f32 d18, d17, d1[1] \n" + + "vadd.f32 d20, d13, d16 \n" + "vmla.f32 d20, d15, d1[1] \n" + + "vmov.32 d19, d10 \n" + "vmla.f32 d19, d11, d0[1] \n" + "vmla.f32 d19, d12, d1[0] \n" + + "vmov.32 d21, d13 \n" + "vmla.f32 d21, d14, d1[0] \n" + "vmla.f32 d21, d15, d0[1] \n" + + "vtrn.32 d18, d20 \n" + "vtrn.32 d19, d21 \n" + "vst1.32 {d18-d19}, [%[out_ptr4]]! \n" + "vst1.32 {d20-d21}, [%[out_ptr5]]! \n" + + "vadd.f32 d18, d10, d17 \n" + "vmla.f32 d18, d11, d1[1] \n" + + "vadd.f32 d19, d13, d9 \n" + "vadd.f32 d19, d19, d15 \n" + "vmla.f32 d19, d16, d1[1] \n" + + "vtrn.32 d18, d19 \n" + "vst1.32 {d18}, [%[out_ptr4]]! \n" + "vst1.32 {d19}, [%[out_ptr5]]! \n" + : [out_ptr0] "+r"(out_ptr0), [out_ptr1] "+r"(out_ptr1), + [out_ptr2] "+r"(out_ptr2), [out_ptr3] "+r"(out_ptr3), + [out_ptr4] "+r"(out_ptr4), [out_ptr5] "+r"(out_ptr5), + [at_m_ptr0] "+r"(at_m_ptr0), [at_m_ptr1] "+r"(at_m_ptr1) + : [tm_ptr] "r"((float *)transform_matrix) + : "cc", "memory", "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7", + "q8", "q9", "q10", "q11", "q12", "q13", "q14", "q15"); + size_t offset = (oc * out_h + 6 * tile_h) * out_w + 6 * tile_w; + float *out_ptr = output_ptr + offset; + int remain_row = (tile_h < h_tiles - 1) ? 6 : remain_h; + int remain_col = (tile_w < w_tiles - 1) ? 6 : remain_w; + for (int i = 0; i < remain_row; ++i, out_ptr += out_w) { + memcpy(out_ptr, output_tmp + i * 6, remain_col * sizeof(float)); + } + } else { + size_t offset = (oc * out_h + 6 * tile_h) * out_w + 6 * tile_w; + float *out_ptr0 = output_ptr + offset; + float *out_ptr1 = out_ptr0 + out_w; + float *out_ptr2 = out_ptr1 + out_w; + float *out_ptr3 = out_ptr2 + out_w; + float *out_ptr4 = out_ptr3 + out_w; + float *out_ptr5 = out_ptr4 + out_w; + asm volatile( + "vld1.32 {d0-d1}, [%[tm_ptr]] \n" + // process 4 rows + "vld1.32 {d2-d5}, [%[at_m_ptr0]]! \n" // q1: m0, q2: m1 + "vld1.32 {d6-d9}, [%[at_m_ptr0]]! \n" // q3: m2, q4: m3 + "vld1.32 {d10-d13}, [%[at_m_ptr1]]! \n" // q5: m4, q6: m5 + "vld1.32 {d14-d17}, [%[at_m_ptr1]]! \n" // q7: m6, q8: m7 + "vtrn.32 q1, q2 \n" + "vtrn.32 q3, q4 \n" + "vtrn.32 q5, q6 \n" + "vtrn.32 q7, q8 \n" + "vswp.32 d3, d6 \n" + "vswp.32 d5, d8 \n" + "vswp.32 d11, d14 \n" + "vswp.32 d13, d16 \n" + + "vadd.f32 q9, q2, q3 \n" // m1 + m2 + "vadd.f32 q10, q4, q5 \n" // m3 + m4 + "vadd.f32 q11, q6, q7 \n" // m5 + m6 + "vsub.f32 q12, q2, q3 \n" // m1 - m2 + "vsub.f32 q13, q4, q5 \n" // m3 - m4 + "vsub.f32 q14, q6, q7 \n" // m5 - m6 + "vmul.f32 q6, q13, d0[0] \n" // 2 * (m3 - m4) + "vmul.f32 q7, q11, d0[0] \n" // 2 * (m5 + m6) + + "vadd.f32 q1, q1, q9 \n" + "vadd.f32 q1, q1, q10 \n" + "vmla.f32 q1, q7, d1[1] \n" + + "vadd.f32 q2, q12, q6 \n" + "vmla.f32 q2, q14, d1[1] \n" + + "vmov.32 q3, q9 \n" + "vmla.f32 q3, q10, d0[1] \n" + "vmla.f32 q3, q11, d1[0] \n" + + "vmov.32 q4, q12 \n" + "vmla.f32 q4, q13, d1[0] \n" + "vmla.f32 q4, q14, d0[1] \n" + + "vtrn.32 q1, q2 \n" + "vtrn.32 q3, q4 \n" + "vswp.32 d3, d6 \n" + "vswp.32 d5, d8 \n" + "vst1.32 {d2-d3}, [%[out_ptr0]]! \n" + "vst1.32 {d4-d5}, [%[out_ptr1]]! \n" + "vst1.32 {d6-d7}, [%[out_ptr2]]! \n" + "vst1.32 {d8-d9}, [%[out_ptr3]]! \n" + + "vadd.f32 q1, q9, q7 \n" + "vmla.f32 q1, q10, d1[1] \n" + + "vadd.f32 q2, q12, q8 \n" + "vadd.f32 q2, q2, q14 \n" + "vmla.f32 q2, q6, d1[1] \n" + + "vtrn.32 q1, q2 \n" + "vst1.32 {d2}, [%[out_ptr0]]! \n" + "vst1.32 {d4}, [%[out_ptr1]]! \n" + "vst1.32 {d3}, [%[out_ptr2]]! \n" + "vst1.32 {d5}, [%[out_ptr3]]! \n" + + // remain 2 rows + "vld1.32 {d2-d5}, [%[at_m_ptr0]]! \n" // d2: m0, d3: m2, + // d4: m1, d5: m3 + "vld1.32 {d6-d9}, [%[at_m_ptr1]]! \n" // d6: m4, d7: m6, + // d8: m5, d9: m7 + "vtrn.32 q1, q2 \n" + "vtrn.32 q3, q4 \n" + + "vadd.f32 d10, d4, d3 \n" // m1 + m2 + "vadd.f32 d11, d5, d6 \n" // m3 + m4 + "vadd.f32 d12, d8, d7 \n" // m5 + m6 + "vsub.f32 d13, d4, d3 \n" // m1 - m2 + "vsub.f32 d14, d5, d6 \n" // m3 - m4 + "vsub.f32 d15, d8, d7 \n" // m5 - m6 + "vmul.f32 d16, d14, d0[0] \n" // 2 * (m3 - m4) + "vmul.f32 d17, d12, d0[0] \n" // 2 * (m5 + m6) + + "vadd.f32 d18, d2, d10 \n" + "vadd.f32 d18, d18, d11 \n" + "vmla.f32 d18, d17, d1[1] \n" + + "vadd.f32 d20, d13, d16 \n" + "vmla.f32 d20, d15, d1[1] \n" + + "vmov.32 d19, d10 \n" + "vmla.f32 d19, d11, d0[1] \n" + "vmla.f32 d19, d12, d1[0] \n" + + "vmov.32 d21, d13 \n" + "vmla.f32 d21, d14, d1[0] \n" + "vmla.f32 d21, d15, d0[1] \n" + + "vtrn.32 d18, d20 \n" + "vtrn.32 d19, d21 \n" + "vst1.32 {d18-d19}, [%[out_ptr4]]! \n" + "vst1.32 {d20-d21}, [%[out_ptr5]]! \n" + + "vadd.f32 d18, d10, d17 \n" + "vmla.f32 d18, d11, d1[1] \n" + + "vadd.f32 d19, d13, d9 \n" + "vadd.f32 d19, d19, d15 \n" + "vmla.f32 d19, d16, d1[1] \n" + + "vtrn.32 d18, d19 \n" + "vst1.32 {d18}, [%[out_ptr4]]! \n" + "vst1.32 {d19}, [%[out_ptr5]]! \n" + : [out_ptr0] "+r"(out_ptr0), [out_ptr1] "+r"(out_ptr1), + [out_ptr2] "+r"(out_ptr2), [out_ptr3] "+r"(out_ptr3), + [out_ptr4] "+r"(out_ptr4), [out_ptr5] "+r"(out_ptr5), + [at_m_ptr0] "+r"(at_m_ptr0), [at_m_ptr1] "+r"(at_m_ptr1) + : [tm_ptr] "r"((float *)transform_matrix) + : "cc", "memory", "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7", + "q8", "q9", "q10", "q11", "q12", "q13", "q14", "q15"); + } + } + } + } +} + +} // namespace math +} // namespace operators +} // namespace paddle_mobile + +#endif // __aarch64__ +#endif // CONV_OP diff --git a/src/operators/op_param.h b/src/operators/op_param.h index 24af798a84c89ba6434178d5c4392a03c0ba5d87..3593ecc9831f6bf627273b0abb5e75cf8a168dbf 100644 --- a/src/operators/op_param.h +++ b/src/operators/op_param.h @@ -405,9 +405,9 @@ class ConvParam : public OpParam { const RType *Input() const { return input_; } - RType *Filter() const { return filter_; } + RType *&Filter() const { return filter_; } - RType *Output() const { return output_; } + RType *&Output() const { return output_; } const vector &Strides() const { return strides_; } @@ -415,6 +415,19 @@ class ConvParam : public OpParam { const vector &Dilations() const { return dilations_; } + enum ExecMode { + EXEC_INVALID = 0, + EXEC_GEMM_FLOAT, + EXEC_DEPTHWISE3x3S1P1_FLOAT, + EXEC_DEPTHWISE3x3_FLOAT, + EXEC_WINOGRAD3X3_FLOAT, + EXEC_WINOGRAD5X5_FLOAT, + EXEC_GEMM_INT8, + EXEC_DEPTHWISE3x3_INT8, + }; + + ExecMode &ExecMode() const { return exec_mode_; } + const int &Groups() const { return groups; } #ifdef PADDLE_MOBILE_CL @@ -426,11 +439,12 @@ class ConvParam : public OpParam { private: RType *input_; - RType *output_; - RType *filter_; + mutable RType *output_; + mutable RType *filter_; vector strides_; vector paddings_; vector dilations_; + mutable enum ExecMode exec_mode_; int groups; #ifdef PADDLE_MOBILE_CL @@ -2509,10 +2523,10 @@ class QuantizeParam : public OpParam { QuantizeParam(const VariableNameMap &inputs, const VariableNameMap &outputs, const AttributeMap &attrs, const Scope &scope) { input_ = InputXFrom(inputs, scope); - out_ = OutFrom(outputs, scope); + output_ = OutFrom(outputs, scope); // online // scale = max(abs(x)) - online_scale_ = GetVarValue("OutScale", outputs, scope); + online_scale_ = OpParam::GetVarValue("OutScale", outputs, scope); // offline if (HasAttr("static_scale", attrs)) { is_static_ = true; @@ -2522,14 +2536,18 @@ class QuantizeParam : public OpParam { if (HasAttr("round_type", attrs)) { round_type_ = GetAttr("round_type", attrs); } + // get paddings + paddings_ = std::vector({0, 0}); + if (HasAttr("paddings", attrs)) { + paddings_ = GetAttr>("paddings", attrs); + } } public: // op input RType *input_; // op output - RType *out_; - // + RType *output_; RType *online_scale_; // if static scale or not bool is_static_ = false; @@ -2537,7 +2555,11 @@ class QuantizeParam : public OpParam { float static_scale_ = 1.0f; // round method type // nearest_zero and nearest_even is valid currently - RoundType round_type_ = ROUND_NEAREST_AWAY_ZERO; + // RoundType round_type_ = ROUND_NEAREST_AWAY_ZERO; + RoundType round_type_ = ROUND_NEAREST_TOWARDS_ZERO; + // optional paddings + std::vector paddings_; + int8_t padding_val_; }; #endif @@ -2551,8 +2573,8 @@ class DequantizeParam : public OpParam { DequantizeParam(const VariableNameMap &inputs, const VariableNameMap &outputs, const AttributeMap &attrs, const Scope &scope) { input_ = InputXFrom(inputs, scope); - out_ = OutFrom(outputs, scope); - activation_scale_ = GetVarValue("Scale", inputs, scope); + output_ = OutFrom(outputs, scope); + activation_scale_ = OpParam::GetVarValue("Scale", inputs, scope); // dequantization is performed as x = x / static_scale / online_scale if (HasAttr("weight_scale", attrs)) { weight_scale_ = GetAttr("weight_scale", attrs); @@ -2565,11 +2587,50 @@ class DequantizeParam : public OpParam { // op input RType *input_; // op output - RType *out_; + RType *output_; RType *activation_scale_; float weight_scale_; }; #endif +#ifdef FUSION_DEQUANT_ADD_BN_RELU_OP +template +class FusionDequantAddBNReluParam : public DequantizeParam { + typedef typename DtypeTensorTrait::gtype GType; + typedef typename DtypeTensorTrait::rtype RType; + + public: + FusionDequantAddBNReluParam(const VariableNameMap &inputs, + const VariableNameMap &outputs, + const AttributeMap &attrs, const Scope &scope) + : DequantizeParam(inputs, outputs, attrs, scope) { + // element wise add params + axis_ = OpParam::GetAttr("axis", attrs); + bias_ = OpParam::InputYFrom(inputs, scope); + // batch norm params + bn_mean_ = OpParam::GetVarValue("BNMean", inputs, scope); + bn_variance_ = OpParam::GetVarValue("BNVariance", inputs, scope); + bn_scale_ = OpParam::GetVarValue("BNScale", inputs, scope); + bn_bias_ = OpParam::GetVarValue("BNBias", inputs, scope); + epsilon_ = OpParam::GetAttr("epsilon", attrs); + // output + output_ = OpParam::OutFrom(outputs, scope); + } + + public: + // elementwise add + int axis_; + RType *bias_; + // batch norm + RType *bn_mean_; + RType *bn_variance_; + RType *bn_scale_; + RType *bn_bias_; + float epsilon_; + // output + RType *output_; +}; +#endif + } // namespace operators } // namespace paddle_mobile diff --git a/src/operators/quantize_op.cpp b/src/operators/quantize_op.cpp index 865539d7d26de41b319b4d82ed168b2ec74d722d..6dd9d75af463753008b273b93253cb986eb90e80 100644 --- a/src/operators/quantize_op.cpp +++ b/src/operators/quantize_op.cpp @@ -22,8 +22,11 @@ namespace operators { template void QuantizeOp::InferShape() const { - const auto& input_dims = this->param_.input_->dims(); - this->param_.out_->Resize(input_dims); + auto input_dims = this->param_.input_->dims(); + const std::vector &paddings = this->param_.paddings_; + input_dims[2] += 2 * paddings[0]; + input_dims[3] += 2 * paddings[1]; + this->param_.output_->Resize(input_dims); auto scale_dims = framework::make_ddim(std::vector{1}); this->param_.online_scale_->Resize(scale_dims); } diff --git a/test/CMakeLists.txt b/test/CMakeLists.txt index edd63a7c684eeb976d4673ef6cc9d3510c287c42..bfd125ce5b75091cfac1a2a4e2f2f025da0178dc 100644 --- a/test/CMakeLists.txt +++ b/test/CMakeLists.txt @@ -155,7 +155,7 @@ if (NOT FOUND_MATCH) target_link_libraries(test-googlenet-quali paddle-mobile) # gen test - ADD_EXECUTABLE(test-conv-op operators/test_cov_op.cpp test_helper.h test_include.h executor_for_test.h) + ADD_EXECUTABLE(test-conv-op operators/test_conv_op.cpp test_helper.h test_include.h executor_for_test.h) target_link_libraries(test-conv-op paddle-mobile) # gen test @@ -242,10 +242,6 @@ if (NOT FOUND_MATCH) ADD_EXECUTABLE(test-dequantize-op operators/test_dequantize_op.cpp test_helper.h test_include.h) target_link_libraries(test-dequantize-op paddle-mobile) - # test int8 conv op - ADD_EXECUTABLE(test-int8-conv-op operators/test_int8_conv_op.cpp test_helper.h test_include.h) - target_link_libraries(test-int8-conv-op paddle-mobile) - # gen test log ADD_EXECUTABLE(test-log common/test_log.cpp) target_link_libraries(test-log paddle-mobile) @@ -368,6 +364,10 @@ if (NOT FOUND_MATCH) ADD_EXECUTABLE(test-multi-process net/test_multi_inference_predict.cpp test_helper.h test_include.h) target_link_libraries(test-multi-process paddle-mobile) + # gen test benchmark + ADD_EXECUTABLE(test-benchmark net/test_benchmark.cpp) + target_link_libraries(test-benchmark paddle-mobile) + # gen test ADD_EXECUTABLE(test-eng net/test_eng.cpp test_helper.h test_include.h) target_link_libraries(test-eng paddle-mobile) diff --git a/test/framework/test_load_memory.cpp b/test/framework/test_load_memory.cpp index 162dba372774578952e4c306bb20a6a95c655c94..afab17d5e7e01d4060cbe92ea3228eb267d2bf32 100644 --- a/test/framework/test_load_memory.cpp +++ b/test/framework/test_load_memory.cpp @@ -12,10 +12,11 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ +#include #include - #include "../test_helper.h" #include "../test_include.h" + static size_t ReadBuffer(const char *file_name, uint8_t **out) { FILE *fp; fp = fopen(file_name, "rb"); diff --git a/test/net/test_benchmark.cpp b/test/net/test_benchmark.cpp new file mode 100644 index 0000000000000000000000000000000000000000..3378229d0fb95745fb7b779f3ce043198d77681b --- /dev/null +++ b/test/net/test_benchmark.cpp @@ -0,0 +1,64 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include +#include "../test_helper.h" +#include "../test_include.h" + +int main(int argc, char* argv[]) { + if (argc < 4) { + std::cout << "Usage: " << std::endl + << "./test_benchmark fluid_model feed_shape thread_num [use_fuse]" + << std::endl; + std::cout << "use_fuse: optional, bool, default is 1\n"; + return 1; + } + bool optimize = true; + char* fluid_model = argv[1]; + char* feed_shape = argv[2]; + int thread_num = atoi(argv[3]); + if (argc == 5) { + optimize = atoi(argv[4]); + } + + paddle_mobile::PaddleMobile paddle_mobile; + paddle_mobile.SetThreadNum(thread_num); + auto time1 = time(); + if (paddle_mobile.Load(fluid_model, optimize)) { + auto time2 = time(); + std::cout << "load cost :" << time_diff(time1, time2) << "ms\n"; + paddle_mobile::framework::Tensor input; + std::shared_ptr output; + std::vector dims{1, 3, 224, 224}; + if (feed_shape) { + sscanf(feed_shape, "%d,%d,%d,%d", &dims[0], &dims[1], &dims[2], &dims[3]); + } + std::cout << "feed shape: [" << dims[0] << ", " << dims[1] << ", " + << dims[2] << ", " << dims[3] << "]\n"; + paddle_mobile::framework::DDim in_shape = + paddle_mobile::framework::make_ddim(dims); + SetupTensor(&input, in_shape, 0.f, 255.f); + // warmup + for (int i = 0; i < 10; ++i) { + output = paddle_mobile.Predict(input); + } + auto time3 = time(); + for (int i = 0; i < 10; ++i) { + output = paddle_mobile.Predict(input); + } + auto time4 = time(); + std::cout << "predict cost :" << time_diff(time3, time4) / 10 << "ms\n"; + } + return 0; +} diff --git a/test/net/test_googlenet.cpp b/test/net/test_googlenet.cpp index 527f2067496eac1df1e0fb10d1dfd2ca66fe4cfd..c3379df609fc1e18b8c3545e25849f8a7ff0461b 100644 --- a/test/net/test_googlenet.cpp +++ b/test/net/test_googlenet.cpp @@ -20,12 +20,11 @@ int main() { #ifdef PADDLE_MOBILE_FPGA paddle_mobile::PaddleMobile paddle_mobile; #endif - #ifdef PADDLE_MOBILE_CPU paddle_mobile::PaddleMobile paddle_mobile; #endif - paddle_mobile.SetThreadNum(4); + paddle_mobile.SetThreadNum(1); bool optimize = true; auto time1 = time(); if (paddle_mobile.Load(g_googlenet, optimize)) { @@ -36,7 +35,7 @@ int main() { std::vector output; std::vector dims{1, 3, 224, 224}; GetInput(g_test_image_1x3x224x224, &input, dims); - // 预热十次 + // warmup for (int i = 0; i < 10; ++i) { output = paddle_mobile.Predict(input, dims); } @@ -46,8 +45,7 @@ int main() { } auto time4 = time(); - std::cout << "predict cost :" << time_diff(time3, time4) / 10 << "ms" - << std::endl; + std::cout << "predict cost: " << time_diff(time3, time4) / 10 << "ms\n"; } return 0; } diff --git a/test/operators/test_int8_conv_op.cpp b/test/operators/test_conv_op.cpp similarity index 70% rename from test/operators/test_int8_conv_op.cpp rename to test/operators/test_conv_op.cpp index 2ab40ba5833939e4456bb13bf4d5f9819a332693..bd0fbdad4d4cfba89c7ad4fac3e7a1227c9794a0 100644 --- a/test/operators/test_int8_conv_op.cpp +++ b/test/operators/test_conv_op.cpp @@ -18,7 +18,7 @@ limitations under the License. */ namespace paddle_mobile { -// Reference convolution for checking results: +// Reference convolution from Caffe for checking results. // accumulate through explicit loops over input, output, and filters. template void conv2d(const framework::Tensor *input, const framework::Tensor *filter, @@ -129,7 +129,7 @@ void conv2d(const framework::Tensor *input, const framework::Tensor *filter, } template -int TestConvOp() { +int TestConvOp(int in_channels, int in_height, int in_width, int out_channels) { int kernel_h = Kernel; int kernel_w = Kernel; int pad_h = Pad; @@ -140,10 +140,10 @@ int TestConvOp() { int dilation_w = 1; int batch_size = 1; - int input_c = 3; - int input_h = 100; - int input_w = 100; - int output_c = 10; + int input_c = in_channels; + int input_h = in_height; + int input_w = in_width; + int output_c = out_channels; framework::DDim input_shape = framework::make_ddim({batch_size, input_c, input_h, input_w}); framework::DDim filter_shape = @@ -158,7 +158,7 @@ int TestConvOp() { auto input_var = scope.get()->Var("input"); auto input = input_var->template GetMutable(); - SetupTensor(input, input_shape, -20, 20); + SetupTensor(input, input_shape, -20.0, 20.0); auto filter_var = scope.get()->Var("filter"); auto filter = filter_var->template GetMutable(); @@ -174,8 +174,9 @@ int TestConvOp() { auto *op = new operators::ConvOp("conv2d", inputs, outputs, attrs, scope); - // struct timespec ts_begin, ts_end; op->InferShape(); + op->Init(); + // struct timespec ts_begin, ts_end; // warmup // op->Run(); // clock_gettime(CLOCK_MONOTONIC, &ts_begin); @@ -202,9 +203,16 @@ int TestConvOp() { const Otype *output_data = output->data(); Otype *output_cmp_data = output_cmp.data(); for (int i = 0; i < output->numel(); ++i) { - PADDLE_MOBILE_ENFORCE(output_data[i] == output_cmp_data[i], + float gap = output_data[i] - output_cmp_data[i]; + PADDLE_MOBILE_ENFORCE(std::abs(gap / (output_data[i] + 1e-5)) < 1e-3, "output[%d] = %d, output_cmp[%d] = %d", i, output_data[i], i, output_cmp_data[i]); + // if (std::abs(gap / (output_data[i] + 1e-5)) > 1e-3) { + // LOG(kLOG_INFO) << "output_data[" << i << "] = " << output_data[i] + // << ", output_cmp_data[" << i << "] = " << + // output_cmp_data[i]; + // return 1; + // } } delete op; return 0; @@ -212,68 +220,88 @@ int TestConvOp() { } // namespace paddle_mobile -int main() { +int main(int argc, char *argv[]) { + if (argc < 5) { + LOG(paddle_mobile::kLOG_INFO) + << "Usage:\n" + << " ./test-int8-conv-op in_channels in_height in_width out_channels\n" + << " params:\n" + << " -in_channels: int, input image's channels\n" + << " -in_height: int, input image's height\n" + << " -in_width: int, input image's width\n" + << " -out_channels: int, conv output channels\n"; + return 1; + } + int in_channels = atoi(argv[1]); + int in_height = atoi(argv[2]); + int in_width = atoi(argv[3]); + int out_channels = atoi(argv[4]); + // kernel = 3, pad = 1, stride = 1 + LOG(paddle_mobile::kLOG_INFO) << "float, kernel=3, pad=1, stride=1"; + paddle_mobile::TestConvOp(in_channels, in_height, + in_width, out_channels); // kernel = 7, pad = 0, stride = 2 LOG(paddle_mobile::kLOG_INFO) << "int8, kernel=7, pad=0, stride=2"; - paddle_mobile::TestConvOp(); - + paddle_mobile::TestConvOp(in_channels, in_height, + in_width, out_channels); // kernel = 7, pad = 1, stride = 2 LOG(paddle_mobile::kLOG_INFO) << "int8, kernel=7, pad=1, stride=2"; - paddle_mobile::TestConvOp(); - + paddle_mobile::TestConvOp(in_channels, in_height, + in_width, out_channels); // kernel = 7, pad = 3, stride = 2 LOG(paddle_mobile::kLOG_INFO) << "int8, kernel=7, pad=3, stride=2"; - paddle_mobile::TestConvOp(); - + paddle_mobile::TestConvOp(in_channels, in_height, + in_width, out_channels); // kernel = 7, pad = 0, stride = 1 LOG(paddle_mobile::kLOG_INFO) << "int8, kernel=7, pad=0, stride=1"; - paddle_mobile::TestConvOp(); - + paddle_mobile::TestConvOp(in_channels, in_height, + in_width, out_channels); // kernel = 7, pad = 1, stride = 1 LOG(paddle_mobile::kLOG_INFO) << "int8, kernel=7, pad=1, stride=1"; - paddle_mobile::TestConvOp(); - + paddle_mobile::TestConvOp(in_channels, in_height, + in_width, out_channels); // kernel = 7, pad = 3, stride = 1 LOG(paddle_mobile::kLOG_INFO) << "int8, kernel=7, pad=3, stride=1"; - paddle_mobile::TestConvOp(); - + paddle_mobile::TestConvOp(in_channels, in_height, + in_width, out_channels); // kernel = 7, pad = 5, stride = 3 LOG(paddle_mobile::kLOG_INFO) << "int8, kernel=7, pad=5, stride=3"; - paddle_mobile::TestConvOp(); - + paddle_mobile::TestConvOp(in_channels, in_height, + in_width, out_channels); // kernel = 7, pad = 3, stride = 4 LOG(paddle_mobile::kLOG_INFO) << "int8, kernel=7, pad=3, stride=4"; - paddle_mobile::TestConvOp(); - LOG(paddle_mobile::kLOG_INFO) << "\n"; - + paddle_mobile::TestConvOp(in_channels, in_height, + in_width, out_channels); // kernel = 3, pad = 0, stride = 1 LOG(paddle_mobile::kLOG_INFO) << "int8, kernel=3, pad=0, stride=1"; - paddle_mobile::TestConvOp(); + paddle_mobile::TestConvOp(in_channels, in_height, + in_width, out_channels); // kernel = 3, pad = 0, stride = 1 LOG(paddle_mobile::kLOG_INFO) << "float, kernel=3, pad=0, stride=1"; - paddle_mobile::TestConvOp(); - LOG(paddle_mobile::kLOG_INFO) << "\n"; - + paddle_mobile::TestConvOp(in_channels, in_height, + in_width, out_channels); // kernel = 3, pad = 1, stride = 1 LOG(paddle_mobile::kLOG_INFO) << "int8, kernel=3, pad=1, stride=1"; - paddle_mobile::TestConvOp(); + paddle_mobile::TestConvOp(in_channels, in_height, + in_width, out_channels); // kernel = 3, pad = 1, stride = 1 LOG(paddle_mobile::kLOG_INFO) << "float, kernel=3, pad=1, stride=1"; - paddle_mobile::TestConvOp(); - LOG(paddle_mobile::kLOG_INFO) << "\n"; - + paddle_mobile::TestConvOp(in_channels, in_height, + in_width, out_channels); // kernel = 5, pad = 0, stride = 1 LOG(paddle_mobile::kLOG_INFO) << "int8, kernel=5, pad=0, stride=1"; - paddle_mobile::TestConvOp(); + paddle_mobile::TestConvOp(in_channels, in_height, + in_width, out_channels); // kernel = 5, pad = 0, stride = 1 LOG(paddle_mobile::kLOG_INFO) << "float, kernel=5, pad=0, stride=1"; - paddle_mobile::TestConvOp(); - LOG(paddle_mobile::kLOG_INFO) << "\n"; - + paddle_mobile::TestConvOp(in_channels, in_height, + in_width, out_channels); // kernel = 5, pad = 2, stride = 1 LOG(paddle_mobile::kLOG_INFO) << "int8, kernel=5, pad=2, stride=1"; - paddle_mobile::TestConvOp(); + paddle_mobile::TestConvOp(in_channels, in_height, + in_width, out_channels); // kernel = 5, pad = 2, stride = 1 LOG(paddle_mobile::kLOG_INFO) << "float, kernel=5, pad=2, stride=1"; - paddle_mobile::TestConvOp(); + paddle_mobile::TestConvOp(in_channels, in_height, + in_width, out_channels); } diff --git a/test/operators/test_cov_op.cpp b/test/operators/test_cov_op.cpp deleted file mode 100644 index 535d82c4be6cedcc77e9e9cf97a9a813f4ca518d..0000000000000000000000000000000000000000 --- a/test/operators/test_cov_op.cpp +++ /dev/null @@ -1,44 +0,0 @@ -/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ - -#include "../test_include.h" -#include "operators/conv_op.h" - -int main() { - paddle_mobile::framework::Loader loader; - // ../models/image_classification_resnet.inference.model - auto program = loader.Load(g_googlenet); - - PADDLE_MOBILE_ENFORCE(program.originProgram != nullptr, - "program file read fail"); - - Executor4Test> - executor(program, "conv2d"); - - paddle_mobile::framework::Tensor input; - GetInput(g_test_image_1x3x224x224, &input, {1, 3, 224, 224}); - // // use SetupTensor if not has local input image . - // SetupTensor(&input, {1, 3, 224, 224}, static_cast(0), - // static_cast(1)); - - auto out_ddim = paddle_mobile::framework::make_ddim({1, 64, 112, 112}); - auto output = executor.Predict(input, "data", "conv2d_0.tmp_0", out_ddim); - - auto output_ptr = output->data(); - for (int j = 0; j < 20; ++j) { - DLOG << " value of output: " << output_ptr[j]; - } - return 0; -} diff --git a/test/operators/test_quantize_op.cpp b/test/operators/test_quantize_op.cpp index 5b1f276bebb0b956a7907a500645612c5aeaf8f9..9988661bcb898daa5e79b6d22d65d90cfa03c668 100644 --- a/test/operators/test_quantize_op.cpp +++ b/test/operators/test_quantize_op.cpp @@ -12,58 +12,131 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ +#include #include "../test_helper.h" #include "../test_include.h" #include "operators/quantize_op.h" namespace paddle_mobile { - -static float find_abs_max(const Tensor *input) { - float max_abs = 0.f; - const float *x = input->data(); - size_t size = input->numel(); - for (size_t i = 0; i < size; ++i) { - float value = std::abs(x[i]); - if (value > max_abs) { - max_abs = value; - } - } - return max_abs; +namespace round { +enum RoundType { + RoundToEven = 0, + RoundAwayZero = 1, + RoundTowardsZero = 2, +}; } -static void quantize_round_to_even(const Tensor *input, const float scale, - Tensor *output) { - const float *x = input->data(); - int8_t *y = output->mutable_data(); - size_t size = input->numel(); - for (size_t i = 0; i < size; ++i) { - float value = x[i] * scale; - float v = round(value); +template +struct Round { + int8_t operator()(float x); +}; + +template <> +struct Round { + int8_t operator()(float x) { return std::round(x); } +}; + +template <> +struct Round { + int8_t operator()(float x) { return int8_t(x); } +}; + +template <> +struct Round { + int8_t operator()(float x) { + int8_t ret = 0; + float v = std::round(x); int32_t q = (int32_t)v; - if (abs(abs(q - value) - 0.5) > 0) { - y[i] = q; + if (abs(abs(q - x) - 0.5) > 0) { + ret = q; } else { if (abs(q) % 2 == 0) { - y[i] = q; + ret = q; } else { - y[i] = q + ((q > 0) ? -1 : 1); + ret = q + ((q > 0) ? -1 : 1); + } + } + return ret; + } +}; + +template +static void quantize(const Tensor *input, const float scale, const int pad, + const int8_t pad_val, Tensor *output) { + int batch_size = input->dims()[0]; + int channels = input->dims()[1]; + int input_h = input->dims()[2]; + int input_w = input->dims()[3]; + int output_h = output->dims()[2]; + int output_w = output->dims()[3]; + size_t input_spatial = input_h * input_w; + size_t output_spatial = output_h * output_w; + const float *x = input->data(); + int8_t *y = output->mutable_data(); + + for (int nc = 0; nc < batch_size * channels; ++nc) { + const float *xh = x + nc * input_spatial; + int8_t *yh = y + nc * output_spatial; + // pad top + for (int h = 0; h < pad; ++h, yh += output_w) { + for (int w = 0; w < output_w; ++w) { + yh[w] = pad_val; + } + } + for (int h = 0; h < input_h; ++h, yh += output_w, xh += input_w) { + // pad left + for (int w = 0; w < pad; ++w) { + yh[w] = pad_val; + } + for (int w = 0; w < input_w; ++w) { + yh[w + pad] = Round()(xh[w] * scale); + } + // pad right + for (int w = 0; w < pad; ++w) { + yh[pad + input_w + w] = pad_val; + } + } + // pad bottom + for (int h = 0; h < pad; ++h, yh += output_w) { + for (int w = 0; w < output_w; ++w) { + yh[w] = pad_val; } } } } -static void quantize_round_to_nearest(const Tensor *input, const float scale, - Tensor *output) { +static float find_abs_max(const Tensor *input) { + float max_abs = 0.f; const float *x = input->data(); - int8_t *y = output->mutable_data(); size_t size = input->numel(); for (size_t i = 0; i < size; ++i) { - y[i] = round(x[i] * scale); + float value = std::abs(x[i]); + if (value > max_abs) { + max_abs = value; + } } + return max_abs; } -int TestQuqntizeOp() { - framework::DDim dim = framework::make_ddim({1, 3, 224, 224}); +int TestQuqntizeOp(int argc, char *argv[]) { + if (argc < 5) { + std::cout + << "Usage: ./test-quantize-op batch_size channel height width [pad]" + << std::endl; + return 1; + } + int pad = 0; + int batch_size = atoi(argv[1]); + int channel = atoi(argv[2]); + int height = atoi(argv[3]); + int width = atoi(argv[4]); + if (argc == 6) { + pad = atoi(argv[5]); + } + std::cout << "batch_size: " << batch_size << ", channel: " << channel + << ", height: " << height << ", width: " << width << std::endl; + framework::DDim dim = + framework::make_ddim({batch_size, channel, height, width}); VariableNameMap inputs; VariableNameMap outputs; @@ -80,6 +153,7 @@ int TestQuqntizeOp() { auto output_scale_var = scope.get()->Var("output_scale"); framework::AttributeMap attrs; + attrs["paddings"].Set>(std::vector({pad, pad})); auto *op = new operators::QuantizeOp("quantize", inputs, outputs, attrs, scope); op->InferShape(); @@ -96,10 +170,11 @@ int TestQuqntizeOp() { output_scale_cmp, output_scale_data[0]); framework::Tensor output_cmp; - output_cmp.Resize(dim); + output_cmp.Resize(output->dims()); float scale = 127 / output_scale_cmp; - // quantize_round_to_even(input, scale, &output_cmp); - quantize_round_to_nearest(input, scale, &output_cmp); + // quantize(input, scale, pad, 0, &output_cmp); + // quantize(input, scale, pad, 0, &output_cmp); + quantize(input, scale, pad, 0, &output_cmp); int8_t *output_cmp_data = output_cmp.data(); for (int i = 0; i < output->numel(); ++i) { PADDLE_MOBILE_ENFORCE(output_data[i] == output_cmp_data[i], @@ -113,4 +188,6 @@ int TestQuqntizeOp() { } // namespace paddle_mobile -int main() { return paddle_mobile::TestQuqntizeOp(); } +int main(int argc, char *argv[]) { + return paddle_mobile::TestQuqntizeOp(argc, argv); +} diff --git a/tools/build.sh b/tools/build.sh index 3489ccd7397ee79ad16256519dba4e239a4c53a0..6e96404c3eac36de53b810d563720d485816f0f9 100755 --- a/tools/build.sh +++ b/tools/build.sh @@ -212,4 +212,4 @@ else else build_error "$1" fi -fi \ No newline at end of file +fi diff --git a/tools/op.cmake b/tools/op.cmake index 3a4a0597a44694c4edea8173af47627cb5680df2..98a5ce437ae6520a4cc27f9fceeadaeb30ba6e99 100644 --- a/tools/op.cmake +++ b/tools/op.cmake @@ -249,6 +249,7 @@ if(NOT FOUND_MATCH) set(SUM_OP ON) set(QUANT_OP ON) set(DEQUANT_OP ON) + set(FUSION_DEQUANT_ADD_BN_RELU ON) endif() # option(BATCHNORM_OP "" ON) @@ -450,6 +451,9 @@ endif() if (DEQUANT_OP) add_definitions(-DDEQUANT_OP) endif() +if (FUSION_DEQUANT_ADD_BN_RELU) + add_definitions(-DFUSION_DEQUANT_ADD_BN_RELU_OP) +endif() if (TANH_OP) add_definitions(-DTANH_OP) @@ -462,4 +466,4 @@ if (FUSION_DECONVADD_OP) endif() if (FUSION_DECONVADDRELU_OP) add_definitions(-DFUSION_DECONVADDRELU_OP) -endif() \ No newline at end of file +endif()