diff --git a/src/framework/operator.cpp b/src/framework/operator.cpp
index e0b40cebf7f14e0b927e4666d63e740213918333..611b134eaab7376dfba9ef3294f8fa14bf13e0da 100644
--- a/src/framework/operator.cpp
+++ b/src/framework/operator.cpp
@@ -64,9 +64,9 @@ void OperatorBase::Run() {
   for (const auto key : input_keys) {
     auto var_vec_in = inputs_.at(key);
     for (int i = 0; i < var_vec_in.size(); ++i) {
-      auto vari = scope_->FindVar(var_vec_in[i]);
+      auto vari = this->scope_->FindVar(var_vec_in[i]);
       if (vari->IsInitialized()) {
-        Tensor *tensor = vari->template GetMutable();
+        const Tensor *tensor = vari->template Get();
         if (tensor) DLOG << type_ << " input- " << key << "=" << *tensor;
       }
     }
@@ -76,7 +76,7 @@
     for (int i = 0; i < var_vec_out.size(); ++i) {
       auto vari = scope_->FindVar(var_vec_out[i]);
       if (vari->IsInitialized()) {
-        Tensor *tensor = vari->template GetMutable();
+        const Tensor *tensor = vari->template Get();
         if (tensor) DLOG << type_ << " output- " << key << "=" << *tensor;
       }
     }
@@ -97,10 +97,10 @@ void OperatorBase::Run() {
       auto vari = scope_->FindVar(var_vec_in[i]);
       if (vari->IsInitialized()) {
         if (type_ == "feed") {
-          Tensor *tensor = vari->template GetMutable();
+          const Tensor *tensor = vari->template Get();
           if (tensor) DLOG << type_ << " input- " << key << "=" << *tensor;
         } else {
-          CLImage *cl_image = vari->template GetMutable();
+          const CLImage *cl_image = vari->template Get();
           if (cl_image) {
             DLOG << type_ << " input- " << key << "=" << *cl_image;
           }
@@ -114,12 +114,12 @@
       auto vari = scope_->FindVar(var_vec_out[i]);
       if (vari->IsInitialized()) {
         if (type_ == "fetch") {
-          Tensor *tensor = vari->template GetMutable();
+          const Tensor *tensor = vari->template Get();
           if (tensor) {
             DLOG << type_ << " output- " << key << "=" << *tensor;
           }
         } else {
-          CLImage *cl_image = vari->template GetMutable();
+          const CLImage *cl_image = vari->template Get();
           if (cl_image) {
             DLOG << type_ << " output- " << key << "=" << *cl_image;
           }
diff --git a/src/io/api_paddle_mobile.cc b/src/io/api_paddle_mobile.cc
index 74ae2ef9ee3b60713623f54577d52d7a58a441b4..dd3b1b7317ecbebc1f6c65da66db65b7368f23f1 100644
--- a/src/io/api_paddle_mobile.cc
+++ b/src/io/api_paddle_mobile.cc
@@ -14,6 +14,7 @@
 #include "io/api_paddle_mobile.h"
 #include
+#include "common/enforce.h"
 #include "framework/tensor.h"
 
 namespace paddle_mobile {
diff --git a/src/io/api_paddle_mobile.h b/src/io/api_paddle_mobile.h
index a552d9152d4eda18da8256ed3f1fde5aa6a29e4b..bca169a2ed7786ce5dbd58ddecf6d637e4c4854c 100644
--- a/src/io/api_paddle_mobile.h
+++ b/src/io/api_paddle_mobile.h
@@ -12,19 +12,11 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License. */
 
-/*
- * This file contains the implementation of inference API with Anakin engine
- * embeded, this API can only support Anakin models.
- */
-
 #pragma once
 
 #include
-#include "io/paddle_inference_api.h"
-
-// from paddle_mobile
-#include "common/enforce.h"
 #include "common/types.h"
+#include "io/paddle_inference_api.h"
 #include "io/paddle_mobile.h"
 
 namespace paddle_mobile {
diff --git a/src/io/paddle_inference_api.h b/src/io/paddle_inference_api.h
index 5326f864a4b5238c8498ee1fe9e5810ca0a657cf..afbd93dede6b5406f572c3b20b48a5904660e5e3 100644
--- a/src/io/paddle_inference_api.h
+++ b/src/io/paddle_inference_api.h
@@ -104,6 +104,8 @@ class PaddlePredictor {
   // The common configs for all the predictors.
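
A note on the operator.cpp hunks above: the debug-logging path switches from `GetMutable()` to a const `Get()`, so printing an input or output can no longer allocate or overwrite the variable it inspects. The sketch below illustrates that read-only pattern with a simplified, hypothetical `Variable` class; the real paddle-mobile `Variable`, `Tensor`, and the template arguments elided in this patch text all differ.

```cpp
// Minimal sketch of the const-read logging pattern, NOT the real
// paddle_mobile::framework::Variable. Names and signatures are assumed.
#include <iostream>

struct Tensor {
  int numel = 0;
};

class Variable {
 public:
  bool IsInitialized() const { return holder_ != nullptr; }
  template <typename T>
  const T *Get() const {  // read-only view, safe for logging
    return static_cast<const T *>(holder_);
  }
  template <typename T>
  T *GetMutable() {  // may create the payload as a side effect
    if (holder_ == nullptr) holder_ = new T();
    return static_cast<T *>(holder_);
  }

 private:
  void *holder_ = nullptr;
};

// Logging only observes the variable, mirroring the patched Run() code.
void LogTensor(const Variable &var) {
  if (!var.IsInitialized()) return;
  if (const Tensor *tensor = var.Get<Tensor>()) {
    std::cout << "numel=" << tensor->numel << "\n";
  }
}
```
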
struct Config { std::string model_dir; // path to the model directory. + std::string prog_file; + std::string param_file; }; protected: @@ -128,9 +130,8 @@ struct PaddleMobileConfig : public PaddlePredictor::Config { int batch_size = 1; bool optimize = true; bool quantification = false; + bool lod_mode = false; int thread_num = 1; - std::string prog_file; - std::string param_file; std::string cl_path; struct PaddleModelMemoryPack memory_pack; }; diff --git a/src/io/paddle_mobile.cpp b/src/io/paddle_mobile.cpp index addaefad1466e7157a553bedbc869377723a9213..995cd175baa51bdc7fedbb43d6345cc5ce6e6be2 100644 --- a/src/io/paddle_mobile.cpp +++ b/src/io/paddle_mobile.cpp @@ -15,6 +15,9 @@ limitations under the License. */ #include "io/paddle_mobile.h" #include #include "common/common.h" +#ifdef _OPENMP +#include +#endif // _OPENMP #ifdef PADDLE_MOBILE_CL #include #include "framework/cl/cl_tensor.h" @@ -33,7 +36,7 @@ void PaddleMobile::SetThreadNum(int num) { template PMStatus PaddleMobile::Load(const std::string &dirname, bool optimize, bool quantification, - int batch_size, bool loddable) { + int batch_size, bool lod_mode) { if (loader_.get() == nullptr) { loader_ = std::make_shared>(); } else { @@ -43,7 +46,7 @@ PMStatus PaddleMobile::Load(const std::string &dirname, if (executor_.get() == nullptr) { executor_ = std::make_shared>( loader_->Load(dirname, optimize, quantification), batch_size, optimize, - loddable); + lod_mode); } else { LOG(kLOG_INFO) << "executor inited"; } @@ -55,7 +58,7 @@ template PMStatus PaddleMobile::Load(const std::string &model_path, const std::string ¶_path, bool optimize, bool quantification, - int batch_size, bool loddable) { + int batch_size, bool lod_mode) { if (loader_.get() == nullptr) { loader_ = std::make_shared>(); } else { @@ -65,7 +68,7 @@ PMStatus PaddleMobile::Load(const std::string &model_path, if (executor_.get() == nullptr) { executor_ = std::make_shared>( loader_->Load(model_path, para_path, optimize, quantification), - batch_size, optimize, loddable); + batch_size, optimize, lod_mode); } else { LOG(kLOG_INFO) << "executor inited"; } @@ -73,11 +76,26 @@ PMStatus PaddleMobile::Load(const std::string &model_path, return PMSuccess; } +template +PMStatus PaddleMobile::Load(const PaddleMobileConfig &config) { + if (!config.model_dir.empty()) { + return this->Load(config.model_dir, config.optimize, config.quantification, + config.batch_size, config.lod_mode); + } else if (!config.prog_file.empty() && !config.param_file.empty()) { + return this->Load(config.prog_file, config.param_file, config.optimize, + config.quantification, config.batch_size, + config.lod_mode); + } else { + LOG(kLOG_ERROR) << "Failed to load inference model"; + return PMNotInitialized; + } +} + template bool PaddleMobile::LoadCombinedMemory( size_t model_len, const uint8_t *model_buf, size_t combined_params_len, uint8_t *combined_params_buf, bool optimize, bool quantification, - int batch_size, bool loddable) { + int batch_size, bool lod_mode) { if (loader_.get() == nullptr) { loader_ = std::make_shared>(); } else { @@ -88,7 +106,7 @@ bool PaddleMobile::LoadCombinedMemory( loader_->LoadCombinedMemory(model_len, model_buf, combined_params_len, combined_params_buf, optimize, quantification), - batch_size, optimize, loddable); + batch_size, optimize, lod_mode); } else { LOG(kLOG_INFO) << "executor inited"; } diff --git a/src/io/paddle_mobile.h b/src/io/paddle_mobile.h index b98da215eb4dac5af4e424461f6a233ccf33a612..22ffd6ce9c9c77ec119511f20b13f77ca2e706ac 100644 --- a/src/io/paddle_mobile.h 
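
For the new `PaddleMobile::Load(const PaddleMobileConfig &)` overload added above, which prefers `model_dir` and falls back to the combined `prog_file`/`param_file` pair, a hedged usage sketch follows. The config field names and `PMSuccess`/`PMNotInitialized` come from this patch; the template arguments, include path, and model paths are illustrative assumptions, since the patch text elides them.

```cpp
#include "io/paddle_mobile.h"  // include path as used in this patch

// Sketch: drive the new config-based Load(). Template arguments for
// PaddleMobile and the model paths below are illustrative assumptions.
bool LoadFromConfig() {
  paddle_mobile::PaddleMobileConfig config;
  config.prog_file = "./mobilenet/model";    // hypothetical combined model
  config.param_file = "./mobilenet/params";  // hypothetical params file
  config.optimize = true;
  config.quantification = false;
  config.batch_size = 1;
  config.lod_mode = false;  // flag introduced/renamed by this patch

  paddle_mobile::PaddleMobile<paddle_mobile::CPU, float> engine;
  // Load() dispatches on model_dir vs. prog_file + param_file; if neither
  // is set it logs an error and returns PMNotInitialized.
  return engine.Load(config) == paddle_mobile::PMSuccess;
}
```
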
+++ b/src/io/paddle_mobile.h @@ -18,15 +18,12 @@ limitations under the License. */ #include #include #include -#ifdef _OPENMP -#include -#endif // _OPENMP - #include "common/types.h" #include "framework/executor.h" #include "framework/load_ops.h" #include "framework/loader.h" #include "framework/tensor.h" +#include "io/paddle_inference_api.h" #ifdef PADDLE_MOBILE_CL #include "framework/cl/cl_engine.h" #endif @@ -46,10 +43,12 @@ class PaddleMobile { PMStatus Load(const std::string &dirname, const bool optimize = false, const bool quantification = false, const int batch_size = 1, - const bool lod = false); + const bool lod_mode = false); PMStatus Load(const std::string &model_path, const std::string ¶_path, const bool optimize = false, const bool quantification = false, - const int batch_size = 1, const bool lod = false); + const int batch_size = 1, const bool lod_mode = false); + + PMStatus Load(const PaddleMobileConfig &config); PMStatus Predict(const framework::Tensor &input); PMStatus Predict(const framework::LoDTensor &input); @@ -75,7 +74,7 @@ class PaddleMobile { size_t combined_params_len, uint8_t *combined_params_buf, bool optimize = false, bool quantification = false, int batch_size = 1, - bool loddable = false); + bool lod_mode = false); void SetThreadNum(int count); void Clear(); diff --git a/src/operators/kernel/arm/conv_kernel.cpp b/src/operators/kernel/arm/conv_kernel.cpp index 1eba4cd3304b945a09c2f48131abe24f5c07ab07..f7f55b790db9cf8151b792129cfa858075dbe9a0 100644 --- a/src/operators/kernel/arm/conv_kernel.cpp +++ b/src/operators/kernel/arm/conv_kernel.cpp @@ -24,15 +24,26 @@ template <> bool ConvKernel::Init(ConvParam *param) { bool conv3x3 = param->Filter()->dims()[2] == param->Filter()->dims()[3] && param->Filter()->dims()[2] == 3; + bool conv5x5 = param->Filter()->dims()[2] == param->Filter()->dims()[3] && + param->Filter()->dims()[2] == 5; bool depth3x3 = conv3x3 && param->Groups() == param->Input()->dims()[1] && param->Input()->dims()[1] == param->Output()->dims()[1]; + bool depth5x5 = conv5x5 && param->Groups() == param->Input()->dims()[1] && + param->Input()->dims()[1] == param->Output()->dims()[1]; if (param->Filter()->type() == typeid(int8_t)) { +#ifndef __aarch64__ if (depth3x3 && param->Strides()[0] < 3 && param->Strides()[0] == param->Strides()[1]) { param->ExecMode() = ConvParam::EXEC_DEPTHWISE3x3_INT8; + } else if (depth5x5 && param->Strides()[0] < 2 && + param->Strides()[0] == param->Strides()[1]) { + param->ExecMode() = ConvParam::EXEC_DEPTHWISE5x5_INT8; } else { +#endif // __aarch64__ param->ExecMode() = ConvParam::EXEC_GEMM_INT8; +#ifndef __aarch64__ } +#endif // __aarch64__ } else { if (depth3x3 && param->Strides()[0] == param->Strides()[1] && param->Strides()[0] == 1 && param->Paddings()[0] == 1 && @@ -47,6 +58,9 @@ bool ConvKernel::Init(ConvParam *param) { param->Paddings()[0] == param->Paddings()[1]) { param->ExecMode() = ConvParam::EXEC_DEPTHWISE3x3S2P1_FLOAT; #ifndef __aarch64__ + } else if (depth5x5 && param->Strides()[0] == param->Strides()[1] && + param->Strides()[0] == 1) { + param->ExecMode() = ConvParam::EXEC_DEPTHWISE5x5_FLOAT; } else if (conv3x3 && param->Strides()[0] == param->Strides()[1] && param->Dilations()[0] == param->Dilations()[1] && param->Strides()[0] == 1 && param->Dilations()[0] == 1 && @@ -72,9 +86,14 @@ void ConvKernel::Compute(const ConvParam ¶m) { case ConvParam::EXEC_GEMM_INT8: GemmConv(param); break; +#ifndef __aarch64__ case ConvParam::EXEC_DEPTHWISE3x3_INT8: DepthwiseConv3x3(param); break; + case 
ConvParam::EXEC_DEPTHWISE5x5_INT8: + DepthwiseConv5x5(param); + break; +#endif // __aarch64__ case ConvParam::EXEC_DEPTHWISE3x3S1P1_FLOAT: math::DepthwiseConv3x3s1p1(param.Input(), param.Filter(), param.Output(), nullptr, false, false); @@ -87,9 +106,14 @@ void ConvKernel::Compute(const ConvParam ¶m) { math::DepthwiseConv3x3s2p0(param.Input(), param.Filter(), param.Output(), nullptr, false, false); break; +#ifndef __aarch64__ + case ConvParam::EXEC_DEPTHWISE5x5_FLOAT: + DepthwiseConv5x5(param); + break; case ConvParam::EXEC_WINOGRAD3X3_FLOAT: WinogradConv3x3<8, 3>(param); break; +#endif // __aarch64__ case ConvParam::EXEC_GEMM_FLOAT: GemmConv(param); break; diff --git a/src/operators/kernel/arm/pool_kernel.cpp b/src/operators/kernel/arm/pool_kernel.cpp index 58d6359efa48b0db215269a631e7e4cb57c429d9..703a73d64bc9726c477952e7a2cbfcf59be6c5fb 100644 --- a/src/operators/kernel/arm/pool_kernel.cpp +++ b/src/operators/kernel/arm/pool_kernel.cpp @@ -15,7 +15,8 @@ limitations under the License. */ #ifdef POOL_OP #include "operators/kernel/pool_kernel.h" -#include "../central-arm-func/pool_arm_func.h" +#include "operators/kernel/central-arm-func/pool_arm_func.h" + namespace paddle_mobile { namespace operators { @@ -28,7 +29,8 @@ template <> void PoolKernel::Compute(const PoolParam ¶m) { PoolCompute(param); } + } // namespace operators } // namespace paddle_mobile -#endif +#endif // POOL_OP diff --git a/src/operators/kernel/central-arm-func/conv_arm_func.h b/src/operators/kernel/central-arm-func/conv_arm_func.h index 848f1b113b78d145f730255a33a82006b94e03f0..86a3c7a9694e8d686f41911ea4af784a33c2cd0a 100644 --- a/src/operators/kernel/central-arm-func/conv_arm_func.h +++ b/src/operators/kernel/central-arm-func/conv_arm_func.h @@ -18,6 +18,7 @@ limitations under the License. 
*/ #include #include "operators/math/conv_func.h" #include "operators/math/depthwise_conv3x3.h" +#include "operators/math/depthwise_conv5x5.h" #include "operators/math/im2col.h" #include "operators/math/math_function.h" #include "operators/math/pad.h" @@ -160,6 +161,7 @@ inline void WinogradConv3x3(const ConvParam ¶m) { } } +#ifndef __aarch64__ template inline void DepthwiseConv3x3(const ConvParam ¶m) { const Tensor *input = param.Input(); @@ -180,14 +182,34 @@ inline void DepthwiseConv3x3(const ConvParam ¶m) { math::DepthwiseConv3x3S2(in_batch, *filter, paddings, &out_batch); } else { - // math::DepthwiseConv3x3(input_pad, *filter, - // &out_batch); - PADDLE_MOBILE_THROW_EXCEPTION( - "Depthwise conv with generic strides has not been implemented."); + GemmConv(param); } } } +template +inline void DepthwiseConv5x5(const ConvParam ¶m) { + const Tensor *input = param.Input(); + const Tensor *filter = param.Filter(); + const std::vector &paddings = param.Paddings(); + const std::vector &strides = param.Strides(); + const int batch_size = input->dims()[0]; + Tensor *output = param.Output(); + output->mutable_data(); + + if (strides[0] == 1) { + for (int i = 0; i < batch_size; i++) { + Tensor in_batch = input->Slice(i, i + 1); + Tensor out_batch = output->Slice(i, i + 1); + math::DepthwiseConv5x5S1(in_batch, *filter, paddings, + &out_batch); + } + } else { + GemmConv(param); + } +} +#endif // __aarch64__ + } // namespace operators } // namespace paddle_mobile diff --git a/src/operators/kernel/central-arm-func/elementwise_add_arm_func.h b/src/operators/kernel/central-arm-func/elementwise_add_arm_func.h index 19561d6b84e82b96463abc553042aa56ae38036d..df78b96147b270b592eea68668550b0c55fde0bf 100644 --- a/src/operators/kernel/central-arm-func/elementwise_add_arm_func.h +++ b/src/operators/kernel/central-arm-func/elementwise_add_arm_func.h @@ -59,12 +59,11 @@ inline void ElementwiseAddCompute(const ElementwiseAddParam ¶m) { const float *input = input_data + offset; const float bias = bias_data[j]; float *output = output_data + offset; - int remain = elementwise_num; #if defined(__ARM_NEON__) || defined(__ARM_NEON) int loop = elementwise_num >> 0x4; - remain = elementwise_num & 0xF; + int remain = elementwise_num & 0xF; + float32x4_t rb = vdupq_n_f32(bias); for (int k = 0; k < loop; ++k) { - float32x4_t rb = vdupq_n_f32(bias); float32x4_t r0 = vld1q_f32(input); float32x4_t r1 = vld1q_f32(input + 4); float32x4_t r2 = vld1q_f32(input + 8); @@ -80,10 +79,46 @@ inline void ElementwiseAddCompute(const ElementwiseAddParam ¶m) { input += 16; output += 16; } -#endif - for (int k = 0; k < remain; ++k) { + if (remain >= 8) { + float32x4_t r0 = vld1q_f32(input); + float32x4_t r1 = vld1q_f32(input + 4); + r0 = vaddq_f32(r0, rb); + r1 = vaddq_f32(r1, rb); + vst1q_f32(output, r0); + vst1q_f32(output + 4, r1); + input += 8; + output += 8; + remain -= 8; + } + if (remain >= 4) { + float32x4_t r0 = vld1q_f32(input); + r0 = vaddq_f32(r0, rb); + vst1q_f32(output, r0); + input += 4; + output += 4; + remain -= 4; + } + if (remain > 0) { + float32x4_t r0 = vld1q_f32(input); + r0 = vaddq_f32(r0, rb); + switch (remain) { + case 1: + vst1q_lane_f32(output, r0, 0); + break; + case 2: + vst1_f32(output, vget_low_f32(r0)); + break; + case 3: + vst1_f32(output, vget_low_f32(r0)); + vst1q_lane_f32(output, r0, 2); + break; + } + } +#else + for (int k = 0; k < elementwise_num; ++k) { output[k] = input[k] + bias; } +#endif // __ARM_NEON__ } } } diff --git a/src/operators/kernel/central-arm-func/pool_arm_func.h 
b/src/operators/kernel/central-arm-func/pool_arm_func.h index 757d64480fa2fba46ba599a7f5cf9aaddfa5567a..b8086b4ecbc2592c3789a7d176eefb02bb02ada5 100644 --- a/src/operators/kernel/central-arm-func/pool_arm_func.h +++ b/src/operators/kernel/central-arm-func/pool_arm_func.h @@ -13,6 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. */ #ifdef POOL_OP + #pragma once #include @@ -54,8 +55,24 @@ void PoolCompute(const PoolParam ¶m) { } else { math::Pooling()(*input, ksize, strides, paddings, output); } - } else { - // Others + } + } else if (ksize[0] == 2 && ksize[0] == ksize[1]) { + if (pooling_type == "max" && strides[0] == strides[1]) { + if (strides[0] == 1) { + math::Pooling2x2()(*input, paddings, output); + } else if (strides[0] == 2) { + math::Pooling2x2()(*input, paddings, output); + } else { + math::Pooling()(*input, ksize, strides, paddings, output); + } + } else if (pooling_type == "avg" && strides[0] == strides[1]) { + if (strides[0] == 1) { + math::Pooling2x2()(*input, paddings, output); + } else if (strides[0] == 2) { + math::Pooling2x2()(*input, paddings, output); + } else { + math::Pooling()(*input, ksize, strides, paddings, output); + } } } else { if (pooling_type == "max") { diff --git a/src/operators/math/depthwise_conv3x3.cpp b/src/operators/math/depthwise_conv3x3.cpp index 90edc3111b3f54255af01c1e511bd74955e75d35..292801b5f581d9301973de3cf852141391104222 100644 --- a/src/operators/math/depthwise_conv3x3.cpp +++ b/src/operators/math/depthwise_conv3x3.cpp @@ -253,7 +253,6 @@ void DepthwiseConv3x3s1p1(const framework::Tensor *input, framework::Tensor *output, framework::Tensor *bias, bool if_bias, bool if_relu) { #if __ARM_NEON - const float *bias_data = bias->data(); const int batch_size = static_cast(input->dims()[0]); const int c = static_cast(input->dims()[1]); const int h = static_cast(input->dims()[2]); @@ -267,6 +266,11 @@ void DepthwiseConv3x3s1p1(const framework::Tensor *input, const int lb = (h - 1) * w; const int rb = h * w - 1; + const float *bias_data; + if (if_bias) { + bias_data = bias->data(); + } + float32x4_t zero = vdupq_n_f32(0.0); for (int b = 0; b < batch_size; ++b) { @@ -1966,7 +1970,6 @@ void DepthwiseConv3x3s2p0(const framework::Tensor *input, framework::Tensor *output, framework::Tensor *bias, bool if_bias, bool if_relu) { #if __ARM_NEON - const int batch_size = static_cast(input->dims()[0]); const int input_channel = static_cast(input->dims()[1]); @@ -1983,7 +1986,12 @@ void DepthwiseConv3x3s2p0(const framework::Tensor *input, for (int c = 0; c < input_channel; c++) { const float *filter_data = filter->data() + c * 9; const float *input_data = input->data() + c * inhxw; - const float *bias_data = bias->data() + c; + const float *bias_data; + float32x4_t biasv; + if (if_bias) { + bias_data = bias->data() + c; + biasv = vld1q_dup_f32(bias_data); + } float *output_data = output->data() + c * outhxw; float w00 = filter_data[0]; float w01 = filter_data[1]; @@ -1994,7 +2002,6 @@ void DepthwiseConv3x3s2p0(const framework::Tensor *input, float w20 = filter_data[6]; float w21 = filter_data[7]; float w22 = filter_data[8]; - float32x4_t biasv = vld1q_dup_f32(bias_data); for (int i = 0; i < output_height; i += 1) { for (int m = 0; m < output_width - 2; m += 3) { float *output_ptr = output_data + i * output_width + m; diff --git a/src/operators/math/depthwise_conv3x3_int8.cpp b/src/operators/math/depthwise_conv3x3_int8.cpp index 
9b4c6096ecdbd7adee27728ebaae47149392dad9..91e682c14590a10fc393aaefb5d37c015065fc0a 100644 --- a/src/operators/math/depthwise_conv3x3_int8.cpp +++ b/src/operators/math/depthwise_conv3x3_int8.cpp @@ -14,185 +14,13 @@ limitations under the License. */ #if defined(__ARM_NEON__) && !defined(__aarch64__) -#include "operators/math/depthwise_conv3x3.h" -#ifdef __ARM_NEON__ #include -#endif +#include "operators/math/depthwise_conv3x3.h" namespace paddle_mobile { namespace operators { namespace math { -template -inline void Depth3x3ValidColLoadInput(const int8_t *input, const int input_w, - const int valid_cols, int16x8_t *y0, - int16x8_t *y1, int16x8_t *y2) { - PADDLE_MOBILE_THROW_EXCEPTION("Stride %d is not supported.", Stride); -} - -template <> -inline void Depth3x3ValidColLoadInput<1>(const int8_t *input, const int input_w, - const int valid_cols, int16x8_t *y0, - int16x8_t *y1, int16x8_t *y2) { - int8_t fake_input[3][8]; - if (valid_cols == 1) { - for (int i = 0; i < 8; ++i, input += input_w) { - fake_input[0][i] = input[0]; - } - } else if (valid_cols == 2) { - for (int i = 0; i < 8; ++i, input += input_w) { - fake_input[0][i] = input[0]; - fake_input[1][i] = input[1]; - } - } else { - for (int i = 0; i < 8; ++i, input += input_w) { - fake_input[0][i] = input[0]; - fake_input[1][i] = input[1]; - fake_input[2][i] = input[2]; - } - } - int8x8_t input0 = vld1_s8(fake_input[0]); - int8x8_t input1 = vld1_s8(fake_input[1]); - int8x8_t input2 = vld1_s8(fake_input[2]); - y0[0] = vmovl_s8(input0); - y1[0] = vmovl_s8(input1); - y2[0] = vmovl_s8(input2); - y0[1] = vextq_s16(y0[0], y0[0], 1); - y0[2] = vextq_s16(y0[0], y0[0], 2); - y1[1] = vextq_s16(y1[0], y1[0], 1); - y1[2] = vextq_s16(y1[0], y1[0], 2); - y2[1] = vextq_s16(y2[0], y2[0], 1); - y2[2] = vextq_s16(y2[0], y2[0], 2); -} - -template <> -inline void Depth3x3ValidColLoadInput<2>(const int8_t *input, const int input_w, - const int valid_cols, int16x8_t *y0, - int16x8_t *y1, int16x8_t *y2) { - int8_t fake_input[3][13]; - if (valid_cols == 1) { - for (int i = 0; i < 13; ++i, input += input_w) { - fake_input[0][i] = input[0]; - } - } else if (valid_cols == 2) { - for (int i = 0; i < 13; ++i, input += input_w) { - fake_input[0][i] = input[0]; - fake_input[1][i] = input[1]; - } - } else { - for (int i = 0; i < 13; ++i, input += input_w) { - fake_input[0][i] = input[0]; - fake_input[1][i] = input[1]; - fake_input[2][i] = input[2]; - } - } - int8x8x2_t input0 = vld2_s8(fake_input[0]); - int8x8x2_t input1 = vld2_s8(fake_input[1]); - int8x8x2_t input2 = vld2_s8(fake_input[2]); - y0[0] = vmovl_s8(input0.val[0]); - y0[1] = vmovl_s8(input0.val[1]); - y0[2] = vextq_s16(y0[0], y0[0], 1); - y1[0] = vmovl_s8(input1.val[0]); - y1[1] = vmovl_s8(input1.val[1]); - y1[2] = vextq_s16(y1[0], y1[0], 1); - y2[0] = vmovl_s8(input2.val[0]); - y2[1] = vmovl_s8(input2.val[1]); - y2[2] = vextq_s16(y2[0], y2[0], 1); -} - -template -inline void DepthwiseConv3x3ValidCol(const int8_t *input, const int8_t *filter, - const int h_output, const int h_output_end, - const int w_output, const int input_h, - const int input_w, const int padding_h, - const int padding_w, const int output_w, - int32_t *output) { - const int w_in_start = -padding_w + w_output * Stride_w; - const int w_in_end = w_in_start + 3; - const int w_start = w_in_start > 0 ? w_in_start : 0; - const int w_end = w_in_end < input_w ? 
w_in_end : input_w; - int remain_start = h_output; - -#ifdef __ARM_NEON__ - int output_tiles = (h_output_end - h_output) / 6; - remain_start = h_output + output_tiles * 6; - int input_h_start = h_output * Stride_h - padding_h; - size_t input_offset = input_h_start * input_w + w_start; - size_t output_offset = h_output * output_w + w_output; - int16x8_t _input[3][3]; - int16x4_t _kernel[3]; - int32x4_t _sum0, _sum1; - const int8_t *filter_ptr = filter; - asm volatile( - "mov r0, #3 \n" - "vld1.s8 d10, [%[filter]], r0 \n" - "vld1.s8 d11, [%[filter]], r0 \n" - "vld1.s8 d12, [%[filter]] \n" - "vtrn.8 d10, d11 \n" - "vtrn.8 d12, d13 \n" - "vtrn.16 d10, d12 \n" - "vtrn.16 d11, d13 \n" - "vmovl.s8 q7, d10 \n" - "vmovl.s8 q8, d11 \n" - "vmovl.s8 q9, d12 \n" - "vmov.32 %[_kernel0], d14 \n" - "vmov.32 %[_kernel1], d16 \n" - "vmov.32 %[_kernel2], d18 \n" - : [_kernel0] "+w"(_kernel[0]), [_kernel1] "+w"(_kernel[1]), - [_kernel2] "+w"(_kernel[2]) - : [filter] "r"(filter_ptr) - : "memory", "q5", "q6", "q7", "q8", "q9", "r0"); - int valid_cols = w_end - w_start; - for (int h = 0; h < output_tiles * 6; h += 6) { - int32_t *output0 = output + output_offset; - int32_t *output1 = output0 + output_w; - int32_t *output2 = output1 + output_w; - int32_t *output3 = output2 + output_w; - int32_t *output4 = output3 + output_w; - int32_t *output5 = output4 + output_w; - Depth3x3ValidColLoadInput(input + input_offset, input_w, - valid_cols, _input[0], _input[1], - _input[2]); - _sum0 = veorq_s32(_sum0, _sum0); - _sum1 = veorq_s32(_sum1, _sum1); - for (int w_in = 0; w_in < valid_cols; ++w_in) { - int index = w_in + w_start - w_in_start; - _sum0 = vmlal_lane_s16(_sum0, vget_low_s16(_input[w_in][0]), - _kernel[index], 0); - _sum0 = vmlal_lane_s16(_sum0, vget_low_s16(_input[w_in][1]), - _kernel[index], 1); - _sum0 = vmlal_lane_s16(_sum0, vget_low_s16(_input[w_in][2]), - _kernel[index], 2); - _sum1 = vmlal_lane_s16(_sum1, vget_high_s16(_input[w_in][0]), - _kernel[index], 0); - _sum1 = vmlal_lane_s16(_sum1, vget_high_s16(_input[w_in][1]), - _kernel[index], 1); - _sum1 = vmlal_lane_s16(_sum1, vget_high_s16(_input[w_in][2]), - _kernel[index], 2); - } - vst1q_lane_s32(output0, _sum0, 0); - vst1q_lane_s32(output1, _sum0, 1); - vst1q_lane_s32(output2, _sum0, 2); - vst1q_lane_s32(output3, _sum0, 3); - vst1q_lane_s32(output4, _sum1, 0); - vst1q_lane_s32(output5, _sum1, 1); - input_offset += 6 * Stride_h * input_w; - output_offset += 6 * output_w; - } -#endif - for (int h = remain_start; h < h_output_end; ++h) { - int32_t value = 0; - const int h_in_start = -padding_h + h * Stride_h; - for (int i = 0; i < 3; ++i) { - for (int w_in = w_start; w_in < w_end; ++w_in) { - value += filter[i * 3 + (w_in - w_in_start)] * - input[(h_in_start + i) * input_w + w_in]; - } - } - output[h * output_w + w_output] = value; - } -} - #define DEPTHWISE_CONV_NORMAL_BORDER(start, end) \ for (int w = start; w < end; ++w) { \ const int w_in_start = -padding_w + w * Stride_w; \ @@ -209,34 +37,19 @@ inline void DepthwiseConv3x3ValidCol(const int8_t *input, const int8_t *filter, output_ptr[w] = value; \ } -template -inline void Depth3x3NormalRowLoadInput(const int8_t *input, - int16x8_t &y0, // NOLINT - int16x8_t &y1, // NOLINT - int16x8_t &y2) { // NOLINT - PADDLE_MOBILE_THROW_EXCEPTION("Stride %d is not supported.", Stride); -} - -template <> -inline void Depth3x3NormalRowLoadInput<1>(const int8_t *input, - int16x8_t &y0, // NOLINT - int16x8_t &y1, // NOLINT - int16x8_t &y2) { // NOLINT - int8x8_t x0 = vld1_s8(input); - y0 = vmovl_s8(x0); - y1 = 
vextq_s16(y0, y0, 1); - y2 = vextq_s16(y1, y1, 1); +template +inline void Depth3x3NormalRowLoadInput(const int8_t *input, int16x8_t *y) { + y[0] = vmovl_s8(vld1_s8(input)); + y[1] = vextq_s16(y[0], y[0], 1); + y[2] = vextq_s16(y[1], y[1], 1); } template <> -inline void Depth3x3NormalRowLoadInput<2>(const int8_t *input, - int16x8_t &y0, // NOLINT - int16x8_t &y1, // NOLINT - int16x8_t &y2) { // NOLINT +inline void Depth3x3NormalRowLoadInput<2>(const int8_t *input, int16x8_t *y) { int8x8x2_t x0 = vld2_s8(input); - y0 = vmovl_s8(x0.val[0]); - y1 = vmovl_s8(x0.val[1]); - y2 = vextq_s16(y0, y0, 1); + y[0] = vmovl_s8(x0.val[0]); + y[1] = vmovl_s8(x0.val[1]); + y[2] = vextq_s16(y[0], y[0], 1); } template @@ -244,15 +57,14 @@ inline void DepthwiseConv3x3NormalRow(const int8_t *input, const int8_t *filter, const int h_output, const int input_h, const int input_w, const int padding_h, const int padding_w, const int output_w, - int32_t *output) { + int32_t *output, int16x4_t *ker) { const int h_in_start = -padding_h + h_output * Stride_h; const int h_in_end = h_in_start + 3; const int h_start = h_in_start > 0 ? h_in_start : 0; const int h_end = h_in_end < input_h ? h_in_end : input_h; - int valid_w_start = (padding_w + Stride_w - 1) / Stride_w; - int valid_w_end = output_w - valid_w_start; - + const int valid_w_start = (padding_w + Stride_w - 1) / Stride_w; + const int valid_w_end = (input_w + padding_w - 3) / Stride_w + 1; int32_t *output_ptr = output + h_output * output_w; // border left DEPTHWISE_CONV_NORMAL_BORDER(0, valid_w_start) @@ -262,14 +74,7 @@ inline void DepthwiseConv3x3NormalRow(const int8_t *input, const int8_t *filter, int output_tiles = (valid_w_end - valid_w_start) / 6; remain_start = valid_w_start + output_tiles * 6; int32x4_t _sum0, _sum1; - int16x8_t y0, y1, y2; - int16x4_t _kernel[3]; - for (int h_in = h_start; h_in < h_end; ++h_in) { - int index = h_in - h_in_start; - int8x8_t w0 = vld1_s8(filter + index * 3); - int16x8_t w1 = vmovl_s8(w0); - _kernel[index] = vget_low_s16(w1); - } + int16x8_t _y[3]; for (int w = 0; w < output_tiles * 6; w += 6) { _sum0 = veorq_s32(_sum0, _sum0); _sum1 = veorq_s32(_sum1, _sum1); @@ -278,19 +83,18 @@ inline void DepthwiseConv3x3NormalRow(const int8_t *input, const int8_t *filter, for (int h_in = h_start; h_in < h_end; ++h_in) { int index = h_in - h_in_start; Depth3x3NormalRowLoadInput( - input + h_in * input_w + input_w_offset, y0, y1, y2); - _sum0 = vmlal_lane_s16(_sum0, vget_low_s16(y0), _kernel[index], 0); - _sum0 = vmlal_lane_s16(_sum0, vget_low_s16(y1), _kernel[index], 1); - _sum0 = vmlal_lane_s16(_sum0, vget_low_s16(y2), _kernel[index], 2); - _sum1 = vmlal_lane_s16(_sum1, vget_high_s16(y0), _kernel[index], 0); - _sum1 = vmlal_lane_s16(_sum1, vget_high_s16(y1), _kernel[index], 1); - _sum1 = vmlal_lane_s16(_sum1, vget_high_s16(y2), _kernel[index], 2); + input + h_in * input_w + input_w_offset, _y); + _sum0 = vmlal_lane_s16(_sum0, vget_low_s16(_y[0]), ker[index], 0); + _sum0 = vmlal_lane_s16(_sum0, vget_low_s16(_y[1]), ker[index], 1); + _sum0 = vmlal_lane_s16(_sum0, vget_low_s16(_y[2]), ker[index], 2); + _sum1 = vmlal_lane_s16(_sum1, vget_high_s16(_y[0]), ker[index], 0); + _sum1 = vmlal_lane_s16(_sum1, vget_high_s16(_y[1]), ker[index], 1); + _sum1 = vmlal_lane_s16(_sum1, vget_high_s16(_y[2]), ker[index], 2); } vst1q_s32(output_ptr + output_offset, _sum0); - vst1q_lane_s32(output_ptr + output_offset + 4, _sum1, 0); - vst1q_lane_s32(output_ptr + output_offset + 5, _sum1, 1); + vst1_s32(output_ptr + output_offset + 4, vget_low_s32(_sum1)); } 
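
One detail in `DepthwiseConv3x3NormalRow` above worth spelling out: `valid_w_end` changes from the symmetric `output_w - valid_w_start` to `(input_w + padding_w - 3) / Stride_w + 1`, which follows directly from requiring the 3-wide window of output column `w` to stay inside the input, i.e. `w * Stride_w - padding_w + 3 <= input_w`. A small worked check with illustrative numbers (not taken from the patch):

```cpp
#include <cassert>

// Sketch: one shape where the old and new bounds differ. Kernel width is 3.
int main() {
  const int input_w = 8, padding_w = 1, stride_w = 2;  // illustrative values
  const int output_w = (input_w + 2 * padding_w - 3) / stride_w + 1;     // 4
  const int valid_w_start = (padding_w + stride_w - 1) / stride_w;       // 1
  const int old_valid_w_end = output_w - valid_w_start;                  // 3
  const int new_valid_w_end = (input_w + padding_w - 3) / stride_w + 1;  // 4
  // Output column w = 3 reads input columns [3*2 - 1, 3*2 - 1 + 3) = [5, 8),
  // which lies fully inside the 8-wide input, so it needs no border handling;
  // the old symmetric bound would have routed it through the border path.
  assert(3 * stride_w - padding_w >= 0 && 3 * stride_w - padding_w + 3 <= input_w);
  assert(old_valid_w_end == 3 && new_valid_w_end == 4);
  return 0;
}
```
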
-#endif +#endif // __ARM_NEON__ for (int w = remain_start; w < valid_w_end; ++w) { int32_t value = 0; int input_start = -padding_w + w * Stride_w; @@ -306,14 +110,6 @@ inline void DepthwiseConv3x3NormalRow(const int8_t *input, const int8_t *filter, DEPTHWISE_CONV_NORMAL_BORDER(valid_w_end, output_w) } -// template<> -// void DepthwiseConv3x3( -// const framework::Tensor *input, const framework::Tensor *filter, -// const std::vector &strides, framework::Tensor *output) { -// PADDLE_MOBILE_THROW_EXCEPTION( -// "Depthwise conv with generic strides has not been implemented."); -// } - template <> void DepthwiseConv3x3S1(const framework::Tensor &input, const framework::Tensor &filter, @@ -342,29 +138,22 @@ void DepthwiseConv3x3S1(const framework::Tensor &input, const int8_t *input_ptr = input_data + g * image_size; const int8_t *filter_ptr = filter_data + g * 9; int32_t *output_ptr = out_data + g * out_image_size; + + const int8_t *filter_ptr0 = filter_ptr; + const int8_t *filter_ptr1 = filter_ptr0 + 3; + const int8_t *filter_ptr2 = filter_ptr1 + 3; + int16x4_t _k0 = vget_low_s16(vmovl_s8(vld1_s8(filter_ptr0))); + int16x4_t _k1 = vget_low_s16(vmovl_s8(vld1_s8(filter_ptr1))); + int16x4_t _k2 = vget_low_s16(vmovl_s8(vld1_s8(filter_ptr2))); + int16x8_t _ker0 = vcombine_s16(_k0, _k1); + int16x8_t _ker1 = vcombine_s16(_k2, _k2); + int16x4_t zero = vdup_n_s16(0); + int16x4_t _ker[3] = {_k0, _k1, _k2}; // top for (int h = 0; h < valid_h_start; ++h) { DepthwiseConv3x3NormalRow<1, 1>(input_ptr, filter_ptr, h, input_h, input_w, padding_h, padding_w, output_w, - output_ptr); - } - // left - for (int w = 0; w < valid_w_start; ++w) { - DepthwiseConv3x3ValidCol<1, 1>( - input_ptr, filter_ptr, valid_h_start, valid_h_end, w, input_h, - input_w, padding_h, padding_w, output_w, output_ptr); - } - // right - for (int w = valid_w_end; w < output_w; ++w) { - DepthwiseConv3x3ValidCol<1, 1>( - input_ptr, filter_ptr, valid_h_start, valid_h_end, w, input_h, - input_w, padding_h, padding_w, output_w, output_ptr); - } - // bottom - for (int h = valid_h_end; h < output_h; ++h) { - DepthwiseConv3x3NormalRow<1, 1>(input_ptr, filter_ptr, h, input_h, - input_w, padding_h, padding_w, output_w, - output_ptr); + output_ptr, _ker); } // valid int output_w_tiles = valid_w / 6; @@ -376,334 +165,419 @@ void DepthwiseConv3x3S1(const framework::Tensor &input, const int8_t *input_ptr3 = input_ptr2 + input_w; const int8_t *input_ptr4 = input_ptr3 + input_w; const int8_t *input_ptr5 = input_ptr4 + input_w; - int32_t *output_ptr0 = output_ptr + h * output_w + valid_w_start; + int32_t *output_ptr0 = output_ptr + h * output_w; int32_t *output_ptr1 = output_ptr0 + output_w; int32_t *output_ptr2 = output_ptr1 + output_w; int32_t *output_ptr3 = output_ptr2 + output_w; + // pad left + if (padding_w) { + int16x4_t row0 = vget_low_s16(vmovl_s8(vld1_s8(input_ptr0))); + int16x4_t row1 = vget_low_s16(vmovl_s8(vld1_s8(input_ptr1))); + int16x4_t row2 = vget_low_s16(vmovl_s8(vld1_s8(input_ptr2))); + int16x4_t row3 = vget_low_s16(vmovl_s8(vld1_s8(input_ptr3))); + int16x4_t row4 = vget_low_s16(vmovl_s8(vld1_s8(input_ptr4))); + int16x4_t row5 = vget_low_s16(vmovl_s8(vld1_s8(input_ptr5))); + int32x4_t acc; + for (int w = valid_w_start - 1; w >= 0; --w) { + int padding = padding_w - w; + if (padding >= 3) { + output_ptr0[w] = 0; + output_ptr1[w] = 0; + output_ptr2[w] = 0; + output_ptr3[w] = 0; + } else { + row0 = vext_s16(zero, row0, 3); + row1 = vext_s16(zero, row1, 3); + row2 = vext_s16(zero, row2, 3); + row3 = vext_s16(zero, row3, 3); + row4 = 
vext_s16(zero, row4, 3); + row5 = vext_s16(zero, row5, 3); + acc = vmull_s16(row0, _ker[0]); + acc = vmlal_s16(acc, row1, _ker[1]); + acc = vmlal_s16(acc, row2, _ker[2]); + output_ptr0[w] = vgetq_lane_s32(acc, 1) + vgetq_lane_s32(acc, 2); + acc = vmull_s16(row1, _ker[0]); + acc = vmlal_s16(acc, row2, _ker[1]); + acc = vmlal_s16(acc, row3, _ker[2]); + output_ptr1[w] = vgetq_lane_s32(acc, 1) + vgetq_lane_s32(acc, 2); + acc = vmull_s16(row2, _ker[0]); + acc = vmlal_s16(acc, row3, _ker[1]); + acc = vmlal_s16(acc, row4, _ker[2]); + output_ptr2[w] = vgetq_lane_s32(acc, 1) + vgetq_lane_s32(acc, 2); + acc = vmull_s16(row3, _ker[0]); + acc = vmlal_s16(acc, row4, _ker[1]); + acc = vmlal_s16(acc, row5, _ker[2]); + output_ptr3[w] = vgetq_lane_s32(acc, 1) + vgetq_lane_s32(acc, 2); + } + } + output_ptr0 += valid_w_start; + output_ptr1 += valid_w_start; + output_ptr2 += valid_w_start; + output_ptr3 += valid_w_start; + } + // valid int loop = output_w_tiles; asm volatile( - "vld1.32 {q0}, [%[filter_ptr]] \n" - "vmovl.s8 q14, d0 \n" - "vmovl.s8 q15, d1 \n" - "vdup.s16 d0, d28[0] \n" - "vdup.s16 d1, d28[1] \n" - "vdup.s16 d2, d28[2] \n" - "vdup.s16 d3, d28[3] \n" - "vdup.s16 d4, d29[0] \n" - "vdup.s16 d5, d29[1] \n" - "vdup.s16 d6, d29[2] \n" - "vdup.s16 d7, d29[3] \n" - "vdup.s16 d8, d30[0] \n" - : - : [filter_ptr] "r"(filter_ptr) - : "cc", "memory", "q0", "q1", "q2", "q3", "q4", "q14", "q15"); - asm volatile( - "mov r0, #6 \n" - "cmp %[loop], #0 \n" - "ble start_remain_%= \n" - // loop 6 widths - "loop_4h6w_%=: \n" - "vld1.32 {d9}, [%[input_ptr0]], r0 \n" - "vld1.32 {d10}, [%[input_ptr1]], r0 \n" - "vld1.32 {d11}, [%[input_ptr2]], r0 \n" - "vext.s8 d12, d9, d9, #1 \n" - "vext.s8 d13, d9, d9, #2 \n" - "vmovl.s8 q7, d9 \n" - "vmovl.s8 q8, d12 \n" - "vmovl.s8 q9, d13 \n" - "vmull.s16 q10, d14, d0 \n" - "vmlal.s16 q10, d16, d1 \n" - "vmlal.s16 q10, d18, d2 \n" - "vmull.s16 q11, d15, d0 \n" - "vmlal.s16 q11, d17, d1 \n" - "vmlal.s16 q11, d19, d2 \n" - - "vext.s8 d12, d10, d10, #1 \n" - "vext.s8 d13, d10, d10, #2 \n" - "vmovl.s8 q7, d10 \n" - "vmovl.s8 q8, d12 \n" - "vmovl.s8 q9, d13 \n" - "vmlal.s16 q10, d14, d3 \n" - "vmlal.s16 q10, d16, d4 \n" - "vmlal.s16 q10, d18, d5 \n" - "vmlal.s16 q11, d15, d3 \n" - "vmlal.s16 q11, d17, d4 \n" - "vmlal.s16 q11, d19, d5 \n" - - "vmull.s16 q12, d14, d0 \n" - "vmlal.s16 q12, d16, d1 \n" - "vmlal.s16 q12, d18, d2 \n" - "vmull.s16 q13, d15, d0 \n" - "vmlal.s16 q13, d17, d1 \n" - "vmlal.s16 q13, d19, d2 \n" - - "vext.s8 d12, d11, d11, #1 \n" - "vext.s8 d13, d11, d11, #2 \n" - "vmovl.s8 q7, d11 \n" - "vmovl.s8 q8, d12 \n" - "vmovl.s8 q9, d13 \n" - "vmlal.s16 q10, d14, d6 \n" - "vmlal.s16 q10, d16, d7 \n" - "vmlal.s16 q10, d18, d8 \n" - "vmlal.s16 q11, d15, d6 \n" - "vmlal.s16 q11, d17, d7 \n" - "vmlal.s16 q11, d19, d8 \n" + "cmp %[loop], #0 \n" + "ble start_remain_%= \n" + "mov r0, #6 \n" + // loop 6 width + "loop_4h6w_%=: \n" + "vld1.32 {d9}, [%[input_ptr0]], r0 \n" + "vld1.32 {d10}, [%[input_ptr1]], r0 \n" + "vld1.32 {d11}, [%[input_ptr2]], r0 \n" + "vext.s8 d12, d9, d9, #1 \n" + "vext.s8 d13, d9, d9, #2 \n" + "vmovl.s8 q7, d9 \n" + "vmovl.s8 q8, d12 \n" + "vmovl.s8 q9, d13 \n" + "vmull.s16 q10, d14, %e[ker0][0] \n" + "vmlal.s16 q10, d16, %e[ker0][1] \n" + "vmlal.s16 q10, d18, %e[ker0][2] \n" + "vmull.s16 q11, d15, %e[ker0][0] \n" + "vmlal.s16 q11, d17, %e[ker0][1] \n" + "vmlal.s16 q11, d19, %e[ker0][2] \n" + + "vext.s8 d12, d10, d10, #1 \n" + "vext.s8 d13, d10, d10, #2 \n" + "vmovl.s8 q7, d10 \n" + "vmovl.s8 q8, d12 \n" + "vmovl.s8 q9, d13 \n" + "vmlal.s16 q10, d14, 
%f[ker0][0] \n" + "vmlal.s16 q10, d16, %f[ker0][1] \n" + "vmlal.s16 q10, d18, %f[ker0][2] \n" + "vmlal.s16 q11, d15, %f[ker0][0] \n" + "vmlal.s16 q11, d17, %f[ker0][1] \n" + "vmlal.s16 q11, d19, %f[ker0][2] \n" + + "vmull.s16 q12, d14, %e[ker0][0] \n" + "vmlal.s16 q12, d16, %e[ker0][1] \n" + "vmlal.s16 q12, d18, %e[ker0][2] \n" + "vmull.s16 q13, d15, %e[ker0][0] \n" + "vmlal.s16 q13, d17, %e[ker0][1] \n" + "vmlal.s16 q13, d19, %e[ker0][2] \n" + + "vext.s8 d12, d11, d11, #1 \n" + "vext.s8 d13, d11, d11, #2 \n" + "vmovl.s8 q7, d11 \n" + "vmovl.s8 q8, d12 \n" + "vmovl.s8 q9, d13 \n" + "vmlal.s16 q10, d14, %e[ker1][0] \n" + "vmlal.s16 q10, d16, %e[ker1][1] \n" + "vmlal.s16 q10, d18, %e[ker1][2] \n" + "vmlal.s16 q11, d15, %e[ker1][0] \n" + "vmlal.s16 q11, d17, %e[ker1][1] \n" + "vmlal.s16 q11, d19, %e[ker1][2] \n" // store row 0, reuse q10/q11 "vst1.32 {d20-d22}, [%[output_ptr0]]! \n" - "vmlal.s16 q12, d14, d3 \n" - "vmlal.s16 q12, d16, d4 \n" - "vmlal.s16 q12, d18, d5 \n" - "vmlal.s16 q13, d15, d3 \n" - "vmlal.s16 q13, d17, d4 \n" - "vmlal.s16 q13, d19, d5 \n" + "vmlal.s16 q12, d14, %f[ker0][0] \n" + "vmlal.s16 q12, d16, %f[ker0][1] \n" + "vmlal.s16 q12, d18, %f[ker0][2] \n" + "vmlal.s16 q13, d15, %f[ker0][0] \n" + "vmlal.s16 q13, d17, %f[ker0][1] \n" + "vmlal.s16 q13, d19, %f[ker0][2] \n" - "vmull.s16 q14, d14, d0 \n" - "vmlal.s16 q14, d16, d1 \n" - "vmlal.s16 q14, d18, d2 \n" - "vmull.s16 q15, d15, d0 \n" - "vmlal.s16 q15, d17, d1 \n" - "vmlal.s16 q15, d19, d2 \n" + "vmull.s16 q14, d14, %e[ker0][0] \n" + "vmlal.s16 q14, d16, %e[ker0][1] \n" + "vmlal.s16 q14, d18, %e[ker0][2] \n" + "vmull.s16 q15, d15, %e[ker0][0] \n" + "vmlal.s16 q15, d17, %e[ker0][1] \n" + "vmlal.s16 q15, d19, %e[ker0][2] \n" "vld1.32 {d9}, [%[input_ptr3]], r0 \n" "vld1.32 {d10}, [%[input_ptr4]], r0 \n" "vld1.32 {d11}, [%[input_ptr5]], r0 \n" - "vext.s8 d12, d9, d9, #1 \n" - "vext.s8 d13, d9, d9, #2 \n" - "vmovl.s8 q7, d9 \n" - "vmovl.s8 q8, d12 \n" - "vmovl.s8 q9, d13 \n" - "vmlal.s16 q12, d14, d6 \n" - "vmlal.s16 q12, d16, d7 \n" - "vmlal.s16 q12, d18, d8 \n" - "vmlal.s16 q13, d15, d6 \n" - "vmlal.s16 q13, d17, d7 \n" - "vmlal.s16 q13, d19, d8 \n" + "vext.s8 d12, d9, d9, #1 \n" + "vext.s8 d13, d9, d9, #2 \n" + "vmovl.s8 q7, d9 \n" + "vmovl.s8 q8, d12 \n" + "vmovl.s8 q9, d13 \n" + "vmlal.s16 q12, d14, %e[ker1][0] \n" + "vmlal.s16 q12, d16, %e[ker1][1] \n" + "vmlal.s16 q12, d18, %e[ker1][2] \n" + "vmlal.s16 q13, d15, %e[ker1][0] \n" + "vmlal.s16 q13, d17, %e[ker1][1] \n" + "vmlal.s16 q13, d19, %e[ker1][2] \n" // store row 1 "vst1.32 {d24-d26}, [%[output_ptr1]]! 
\n" - "vmlal.s16 q14, d14, d3 \n" - "vmlal.s16 q14, d16, d4 \n" - "vmlal.s16 q14, d18, d5 \n" - "vmlal.s16 q15, d15, d3 \n" - "vmlal.s16 q15, d17, d4 \n" - "vmlal.s16 q15, d19, d5 \n" - - "vmull.s16 q10, d14, d0 \n" - "vmlal.s16 q10, d16, d1 \n" - "vmlal.s16 q10, d18, d2 \n" - "vmull.s16 q11, d15, d0 \n" - "vmlal.s16 q11, d17, d1 \n" - "vmlal.s16 q11, d19, d2 \n" - - "vext.s8 d12, d10, d10, #1 \n" - "vext.s8 d13, d10, d10, #2 \n" - "vmovl.s8 q7, d10 \n" - "vmovl.s8 q8, d12 \n" - "vmovl.s8 q9, d13 \n" - "vmlal.s16 q14, d14, d6 \n" - "vmlal.s16 q14, d16, d7 \n" - "vmlal.s16 q14, d18, d8 \n" - "vmlal.s16 q15, d15, d6 \n" - "vmlal.s16 q15, d17, d7 \n" - "vmlal.s16 q15, d19, d8 \n" + "vmlal.s16 q14, d14, %f[ker0][0] \n" + "vmlal.s16 q14, d16, %f[ker0][1] \n" + "vmlal.s16 q14, d18, %f[ker0][2] \n" + "vmlal.s16 q15, d15, %f[ker0][0] \n" + "vmlal.s16 q15, d17, %f[ker0][1] \n" + "vmlal.s16 q15, d19, %f[ker0][2] \n" + + "vmull.s16 q10, d14, %e[ker0][0] \n" + "vmlal.s16 q10, d16, %e[ker0][1] \n" + "vmlal.s16 q10, d18, %e[ker0][2] \n" + "vmull.s16 q11, d15, %e[ker0][0] \n" + "vmlal.s16 q11, d17, %e[ker0][1] \n" + "vmlal.s16 q11, d19, %e[ker0][2] \n" + + "vext.s8 d12, d10, d10, #1 \n" + "vext.s8 d13, d10, d10, #2 \n" + "vmovl.s8 q7, d10 \n" + "vmovl.s8 q8, d12 \n" + "vmovl.s8 q9, d13 \n" + "vmlal.s16 q14, d14, %e[ker1][0] \n" + "vmlal.s16 q14, d16, %e[ker1][1] \n" + "vmlal.s16 q14, d18, %e[ker1][2] \n" + "vmlal.s16 q15, d15, %e[ker1][0] \n" + "vmlal.s16 q15, d17, %e[ker1][1] \n" + "vmlal.s16 q15, d19, %e[ker1][2] \n" // store row 2 "vst1.32 {d28-d30}, [%[output_ptr2]]! \n" - "vmlal.s16 q10, d14, d3 \n" - "vmlal.s16 q10, d16, d4 \n" - "vmlal.s16 q10, d18, d5 \n" - "vmlal.s16 q11, d15, d3 \n" - "vmlal.s16 q11, d17, d4 \n" - "vmlal.s16 q11, d19, d5 \n" - - "vext.s8 d12, d11, d11, #1 \n" - "vext.s8 d13, d11, d11, #2 \n" - "vmovl.s8 q7, d11 \n" - "vmovl.s8 q8, d12 \n" - "vmovl.s8 q9, d13 \n" - "vmlal.s16 q10, d14, d6 \n" - "vmlal.s16 q10, d16, d7 \n" - "vmlal.s16 q10, d18, d8 \n" - "vmlal.s16 q11, d15, d6 \n" - "vmlal.s16 q11, d17, d7 \n" - "vmlal.s16 q11, d19, d8 \n" + "vmlal.s16 q10, d14, %f[ker0][0] \n" + "vmlal.s16 q10, d16, %f[ker0][1] \n" + "vmlal.s16 q10, d18, %f[ker0][2] \n" + "vmlal.s16 q11, d15, %f[ker0][0] \n" + "vmlal.s16 q11, d17, %f[ker0][1] \n" + "vmlal.s16 q11, d19, %f[ker0][2] \n" + + "vext.s8 d12, d11, d11, #1 \n" + "vext.s8 d13, d11, d11, #2 \n" + "vmovl.s8 q7, d11 \n" + "vmovl.s8 q8, d12 \n" + "vmovl.s8 q9, d13 \n" + "vmlal.s16 q10, d14, %e[ker1][0] \n" + "vmlal.s16 q10, d16, %e[ker1][1] \n" + "vmlal.s16 q10, d18, %e[ker1][2] \n" + "vmlal.s16 q11, d15, %e[ker1][0] \n" + "vmlal.s16 q11, d17, %e[ker1][1] \n" + "vmlal.s16 q11, d19, %e[ker1][2] \n" // store row 3 "vst1.32 {d20-d22}, [%[output_ptr3]]! 
\n" - "subs %[loop], #1 \n" - "bne loop_4h6w_%= \n" - - "start_remain_%=: \n" - "cmp %[remain], #0 \n" - "ble end_%= \n" - - "vld1.32 {d9}, [%[input_ptr0]] \n" - "vmovl.s8 q7, d9 \n" - "vext.s8 d9, d9, d9, #1 \n" - "vmovl.s8 q8, d9 \n" - "vext.s8 d9, d9, d9, #1 \n" - "vmovl.s8 q9, d9 \n" - "vmull.s16 q10, d14, d0 \n" - "vmlal.s16 q10, d16, d1 \n" - "vmlal.s16 q10, d18, d2 \n" - "vld1.32 {d9}, [%[input_ptr1]] \n" - "vmull.s16 q11, d15, d0 \n" - "vmlal.s16 q11, d17, d1 \n" - "vmlal.s16 q11, d19, d2 \n" - - "vmovl.s8 q7, d9 \n" - "vext.s8 d9, d9, d9, #1 \n" - "vmovl.s8 q8, d9 \n" - "vext.s8 d9, d9, d9, #1 \n" - "vmovl.s8 q9, d9 \n" - "vmlal.s16 q10, d14, d3 \n" - "vmlal.s16 q10, d16, d4 \n" - "vmlal.s16 q10, d18, d5 \n" - "vmlal.s16 q11, d15, d3 \n" - "vmlal.s16 q11, d17, d4 \n" - "vmlal.s16 q11, d19, d5 \n" - - "vmull.s16 q12, d14, d0 \n" - "vmlal.s16 q12, d16, d1 \n" - "vmlal.s16 q12, d18, d2 \n" - "vld1.32 {d9}, [%[input_ptr2]] \n" - "vmull.s16 q13, d15, d0 \n" - "vmlal.s16 q13, d17, d1 \n" - "vmlal.s16 q13, d19, d2 \n" - - "vmovl.s8 q7, d9 \n" - "vext.s8 d9, d9, d9, #1 \n" - "vmovl.s8 q8, d9 \n" - "vext.s8 d9, d9, d9, #1 \n" - "vmovl.s8 q9, d9 \n" - "vmlal.s16 q10, d14, d6 \n" - "vmlal.s16 q10, d16, d7 \n" - "vmlal.s16 q10, d18, d8 \n" - "vmlal.s16 q11, d15, d6 \n" - "vmlal.s16 q11, d17, d7 \n" - "vmlal.s16 q11, d19, d8 \n" - - "vmlal.s16 q12, d14, d3 \n" - "vmlal.s16 q12, d16, d4 \n" - "vmlal.s16 q12, d18, d5 \n" - "vmlal.s16 q13, d15, d3 \n" - "vmlal.s16 q13, d17, d4 \n" - "vmlal.s16 q13, d19, d5 \n" - - "vmull.s16 q14, d14, d0 \n" - "vmlal.s16 q14, d16, d1 \n" - "vmlal.s16 q14, d18, d2 \n" - "vld1.32 {d9}, [%[input_ptr3]] \n" - "vmull.s16 q15, d15, d0 \n" - "vmlal.s16 q15, d17, d1 \n" - "vmlal.s16 q15, d19, d2 \n" - - "vmovl.s8 q7, d9 \n" - "vext.s8 d9, d9, d9, #1 \n" - "vmovl.s8 q8, d9 \n" - "vext.s8 d9, d9, d9, #1 \n" - "vmovl.s8 q9, d9 \n" - "vmlal.s16 q12, d14, d6 \n" - "vmlal.s16 q12, d16, d7 \n" - "vmlal.s16 q12, d18, d8 \n" - "vmlal.s16 q13, d15, d6 \n" - "vmlal.s16 q13, d17, d7 \n" - "vmlal.s16 q13, d19, d8 \n" - - "vmlal.s16 q14, d14, d3 \n" - "vmlal.s16 q14, d16, d4 \n" - "vmlal.s16 q14, d18, d5 \n" - "vmlal.s16 q15, d15, d3 \n" - "vmlal.s16 q15, d17, d4 \n" - "vmlal.s16 q15, d19, d5 \n" - - "vmull.s16 q5, d14, d0 \n" - "vmlal.s16 q5, d16, d1 \n" - "vmlal.s16 q5, d18, d2 \n" - "vld1.32 {d9}, [%[input_ptr4]] \n" - "vmull.s16 q6, d15, d0 \n" - "vmlal.s16 q6, d17, d1 \n" - "vmlal.s16 q6, d19, d2 \n" - - "vmovl.s8 q7, d9 \n" - "vext.s8 d9, d9, d9, #1 \n" - "vmovl.s8 q8, d9 \n" - "vext.s8 d9, d9, d9, #1 \n" - "vmovl.s8 q9, d9 \n" - "vmlal.s16 q14, d14, d6 \n" - "vmlal.s16 q14, d16, d7 \n" - "vmlal.s16 q14, d18, d8 \n" - "vmlal.s16 q15, d15, d6 \n" - "vmlal.s16 q15, d17, d7 \n" - "vmlal.s16 q15, d19, d8 \n" - - "vmlal.s16 q5, d14, d3 \n" - "vmlal.s16 q5, d16, d4 \n" - "vmlal.s16 q5, d18, d5 \n" - "vld1.32 {d9}, [%[input_ptr5]] \n" - "vmlal.s16 q6, d15, d3 \n" - "vmlal.s16 q6, d17, d4 \n" - "vmlal.s16 q6, d19, d5 \n" - - "vmovl.s8 q7, d9 \n" - "vext.s8 d9, d9, d9, #1 \n" - "vmovl.s8 q8, d9 \n" - "vext.s8 d9, d9, d9, #1 \n" - "vmovl.s8 q9, d9 \n" - "vmlal.s16 q5, d14, d6 \n" - "vmlal.s16 q5, d16, d7 \n" - "vmlal.s16 q5, d18, d8 \n" - "vmlal.s16 q6, d15, d6 \n" - "vmlal.s16 q6, d17, d7 \n" - "vmlal.s16 q6, d19, d8 \n" - - "cmp %[remain], #4 \n" - "blt store_4h2w_%= \n" - "vst1.32 {q10}, [%[output_ptr0]]! \n" - "vst1.32 {q12}, [%[output_ptr1]]! \n" - "vst1.32 {q14}, [%[output_ptr2]]! \n" - "vst1.32 {q5}, [%[output_ptr3]]! 
\n" - "cmp %[remain], #5 \n" - "blt end_%= \n" - "vst1.32 {d22[0]}, [%[output_ptr0]]! \n" - "vst1.32 {d26[0]}, [%[output_ptr1]]! \n" - "vst1.32 {d30[0]}, [%[output_ptr2]]! \n" - "vst1.32 {d12[0]}, [%[output_ptr3]]! \n" - "b end_%= \n" - - "store_4h2w_%=: \n" - "cmp %[remain], #2 \n" - "blt store_4h1w_%= \n" - "vst1.32 {d20}, [%[output_ptr0]]! \n" - "vst1.32 {d24}, [%[output_ptr1]]! \n" - "vst1.32 {d28}, [%[output_ptr2]]! \n" - "vst1.32 {d10}, [%[output_ptr3]]! \n" - "cmp %[remain], #3 \n" - "blt end_%= \n" - "vst1.32 {d21[0]}, [%[output_ptr0]]! \n" - "vst1.32 {d25[0]}, [%[output_ptr1]]! \n" - "vst1.32 {d29[0]}, [%[output_ptr2]]! \n" - "vst1.32 {d11[0]}, [%[output_ptr3]]! \n" - "b end_%= \n" - - "store_4h1w_%=: \n" - "cmp %[remain], #1 \n" - "blt end_%= \n" - "vst1.32 {d20[0]}, [%[output_ptr0]]! \n" - "vst1.32 {d24[0]}, [%[output_ptr1]]! \n" - "vst1.32 {d28[0]}, [%[output_ptr2]]! \n" - "vst1.32 {d10[0]}, [%[output_ptr3]]! \n" - "end_%=: \n" + "subs %[loop], #1 \n" + "bne loop_4h6w_%= \n" + + "start_remain_%=: \n" + "cmp %[remain], #0 \n" + "ble end_%= \n" + + "mov r0, %[remain] \n" + "vld1.32 {d9}, [%[input_ptr0]], r0 \n" + "vmovl.s8 q7, d9 \n" + "vext.s8 d9, d9, d9, #1 \n" + "vmovl.s8 q8, d9 \n" + "vext.s8 d9, d9, d9, #1 \n" + "vmovl.s8 q9, d9 \n" + "vmull.s16 q10, d14, %e[ker0][0] \n" + "vmlal.s16 q10, d16, %e[ker0][1] \n" + "vmlal.s16 q10, d18, %e[ker0][2] \n" + "vld1.32 {d9}, [%[input_ptr1]], r0 \n" + "vmull.s16 q11, d15, %e[ker0][0] \n" + "vmlal.s16 q11, d17, %e[ker0][1] \n" + "vmlal.s16 q11, d19, %e[ker0][2] \n" + + "vmovl.s8 q7, d9 \n" + "vext.s8 d9, d9, d9, #1 \n" + "vmovl.s8 q8, d9 \n" + "vext.s8 d9, d9, d9, #1 \n" + "vmovl.s8 q9, d9 \n" + "vmlal.s16 q10, d14, %f[ker0][0] \n" + "vmlal.s16 q10, d16, %f[ker0][1] \n" + "vmlal.s16 q10, d18, %f[ker0][2] \n" + "vmlal.s16 q11, d15, %f[ker0][0] \n" + "vmlal.s16 q11, d17, %f[ker0][1] \n" + "vmlal.s16 q11, d19, %f[ker0][2] \n" + + "vmull.s16 q12, d14, %e[ker0][0] \n" + "vmlal.s16 q12, d16, %e[ker0][1] \n" + "vmlal.s16 q12, d18, %e[ker0][2] \n" + "vld1.32 {d9}, [%[input_ptr2]], r0 \n" + "vmull.s16 q13, d15, %e[ker0][0] \n" + "vmlal.s16 q13, d17, %e[ker0][1] \n" + "vmlal.s16 q13, d19, %e[ker0][2] \n" + + "vmovl.s8 q7, d9 \n" + "vext.s8 d9, d9, d9, #1 \n" + "vmovl.s8 q8, d9 \n" + "vext.s8 d9, d9, d9, #1 \n" + "vmovl.s8 q9, d9 \n" + "vmlal.s16 q10, d14, %e[ker1][0] \n" + "vmlal.s16 q10, d16, %e[ker1][1] \n" + "vmlal.s16 q10, d18, %e[ker1][2] \n" + "vmlal.s16 q11, d15, %e[ker1][0] \n" + "vmlal.s16 q11, d17, %e[ker1][1] \n" + "vmlal.s16 q11, d19, %e[ker1][2] \n" + + "vmlal.s16 q12, d14, %f[ker0][0] \n" + "vmlal.s16 q12, d16, %f[ker0][1] \n" + "vmlal.s16 q12, d18, %f[ker0][2] \n" + "vmlal.s16 q13, d15, %f[ker0][0] \n" + "vmlal.s16 q13, d17, %f[ker0][1] \n" + "vmlal.s16 q13, d19, %f[ker0][2] \n" + + "vmull.s16 q14, d14, %e[ker0][0] \n" + "vmlal.s16 q14, d16, %e[ker0][1] \n" + "vmlal.s16 q14, d18, %e[ker0][2] \n" + "vld1.32 {d9}, [%[input_ptr3]], r0 \n" + "vmull.s16 q15, d15, %e[ker0][0] \n" + "vmlal.s16 q15, d17, %e[ker0][1] \n" + "vmlal.s16 q15, d19, %e[ker0][2] \n" + + "vmovl.s8 q7, d9 \n" + "vext.s8 d9, d9, d9, #1 \n" + "vmovl.s8 q8, d9 \n" + "vext.s8 d9, d9, d9, #1 \n" + "vmovl.s8 q9, d9 \n" + "vmlal.s16 q12, d14, %e[ker1][0] \n" + "vmlal.s16 q12, d16, %e[ker1][1] \n" + "vmlal.s16 q12, d18, %e[ker1][2] \n" + "vmlal.s16 q13, d15, %e[ker1][0] \n" + "vmlal.s16 q13, d17, %e[ker1][1] \n" + "vmlal.s16 q13, d19, %e[ker1][2] \n" + + "vmlal.s16 q14, d14, %f[ker0][0] \n" + "vmlal.s16 q14, d16, %f[ker0][1] \n" + "vmlal.s16 q14, d18, %f[ker0][2] \n" + 
"vmlal.s16 q15, d15, %f[ker0][0] \n" + "vmlal.s16 q15, d17, %f[ker0][1] \n" + "vmlal.s16 q15, d19, %f[ker0][2] \n" + + "vmull.s16 q5, d14, %e[ker0][0] \n" + "vmlal.s16 q5, d16, %e[ker0][1] \n" + "vmlal.s16 q5, d18, %e[ker0][2] \n" + "vld1.32 {d9}, [%[input_ptr4]], r0 \n" + "vmull.s16 q6, d15, %e[ker0][0] \n" + "vmlal.s16 q6, d17, %e[ker0][1] \n" + "vmlal.s16 q6, d19, %e[ker0][2] \n" + + "vmovl.s8 q7, d9 \n" + "vext.s8 d9, d9, d9, #1 \n" + "vmovl.s8 q8, d9 \n" + "vext.s8 d9, d9, d9, #1 \n" + "vmovl.s8 q9, d9 \n" + "vmlal.s16 q14, d14, %e[ker1][0] \n" + "vmlal.s16 q14, d16, %e[ker1][1] \n" + "vmlal.s16 q14, d18, %e[ker1][2] \n" + "vmlal.s16 q15, d15, %e[ker1][0] \n" + "vmlal.s16 q15, d17, %e[ker1][1] \n" + "vmlal.s16 q15, d19, %e[ker1][2] \n" + + "vmlal.s16 q5, d14, %f[ker0][0] \n" + "vmlal.s16 q5, d16, %f[ker0][1] \n" + "vmlal.s16 q5, d18, %f[ker0][2] \n" + "vld1.32 {d9}, [%[input_ptr5]], r0 \n" + "vmlal.s16 q6, d15, %f[ker0][0] \n" + "vmlal.s16 q6, d17, %f[ker0][1] \n" + "vmlal.s16 q6, d19, %f[ker0][2] \n" + + "vmovl.s8 q7, d9 \n" + "vext.s8 d9, d9, d9, #1 \n" + "vmovl.s8 q8, d9 \n" + "vext.s8 d9, d9, d9, #1 \n" + "vmovl.s8 q9, d9 \n" + "vmlal.s16 q5, d14, %e[ker1][0] \n" + "vmlal.s16 q5, d16, %e[ker1][1] \n" + "vmlal.s16 q5, d18, %e[ker1][2] \n" + "vmlal.s16 q6, d15, %e[ker1][0] \n" + "vmlal.s16 q6, d17, %e[ker1][1] \n" + "vmlal.s16 q6, d19, %e[ker1][2] \n" + + "cmp %[remain], #4 \n" + "blt store_4h2w_%= \n" + "vst1.32 {q10}, [%[output_ptr0]]! \n" + "vst1.32 {q12}, [%[output_ptr1]]! \n" + "vst1.32 {q14}, [%[output_ptr2]]! \n" + "vst1.32 {q5}, [%[output_ptr3]]! \n" + "cmp %[remain], #5 \n" + "blt end_%= \n" + "vst1.32 {d22[0]}, [%[output_ptr0]]! \n" + "vst1.32 {d26[0]}, [%[output_ptr1]]! \n" + "vst1.32 {d30[0]}, [%[output_ptr2]]! \n" + "vst1.32 {d12[0]}, [%[output_ptr3]]! \n" + "b end_%= \n" + + "store_4h2w_%=: \n" + "cmp %[remain], #2 \n" + "blt store_4h1w_%= \n" + "vst1.32 {d20}, [%[output_ptr0]]! \n" + "vst1.32 {d24}, [%[output_ptr1]]! \n" + "vst1.32 {d28}, [%[output_ptr2]]! \n" + "vst1.32 {d10}, [%[output_ptr3]]! \n" + "cmp %[remain], #3 \n" + "blt end_%= \n" + "vst1.32 {d21[0]}, [%[output_ptr0]]! \n" + "vst1.32 {d25[0]}, [%[output_ptr1]]! \n" + "vst1.32 {d29[0]}, [%[output_ptr2]]! \n" + "vst1.32 {d11[0]}, [%[output_ptr3]]! \n" + "b end_%= \n" + + "store_4h1w_%=: \n" + "cmp %[remain], #1 \n" + "blt end_%= \n" + "vst1.32 {d20[0]}, [%[output_ptr0]]! \n" + "vst1.32 {d24[0]}, [%[output_ptr1]]! \n" + "vst1.32 {d28[0]}, [%[output_ptr2]]! \n" + "vst1.32 {d10[0]}, [%[output_ptr3]]! 
\n" + "end_%=: \n" : [output_ptr0] "+r"(output_ptr0), [output_ptr1] "+r"(output_ptr1), [output_ptr2] "+r"(output_ptr2), [output_ptr3] "+r"(output_ptr3), [input_ptr0] "+r"(input_ptr0), [input_ptr1] "+r"(input_ptr1), [input_ptr2] "+r"(input_ptr2), [input_ptr3] "+r"(input_ptr3), [input_ptr4] "+r"(input_ptr4), [input_ptr5] "+r"(input_ptr5), [loop] "+r"(loop) - : [remain] "r"(output_w_remain) - : "cc", "memory", "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7", - "q8", "q9", "q10", "q11", "q12", "q13", "q14", "q15", "r0"); + : [remain] "r"(output_w_remain), [ker0] "w"(_ker0), [ker1] "w"(_ker1) + : "cc", "memory", "q4", "q5", "q6", "q7", "q8", "q9", "q10", "q11", + "q12", "q13", "q14", "q15", "r0"); + // pad right + if (padding_w) { + int16x4_t row0 = vget_low_s16(vmovl_s8(vld1_s8(input_ptr0 - 2))); + int16x4_t row1 = vget_low_s16(vmovl_s8(vld1_s8(input_ptr1 - 2))); + int16x4_t row2 = vget_low_s16(vmovl_s8(vld1_s8(input_ptr2 - 2))); + int16x4_t row3 = vget_low_s16(vmovl_s8(vld1_s8(input_ptr3 - 2))); + int16x4_t row4 = vget_low_s16(vmovl_s8(vld1_s8(input_ptr4 - 2))); + int16x4_t row5 = vget_low_s16(vmovl_s8(vld1_s8(input_ptr5 - 2))); + row0 = vext_s16(row0, zero, 2); + row1 = vext_s16(row1, zero, 2); + row2 = vext_s16(row2, zero, 2); + row3 = vext_s16(row3, zero, 2); + row4 = vext_s16(row4, zero, 2); + row5 = vext_s16(row5, zero, 2); + int32x4_t acc; + for (int w = valid_w_end; w < output_w; ++w) { + int padding = w + 3 - (padding_w + input_w); + if (padding >= 3) { + *output_ptr0 = 0; + *output_ptr1 = 0; + *output_ptr2 = 0; + *output_ptr3 = 0; + } else { + acc = vmull_s16(row0, _ker[0]); + acc = vmlal_s16(acc, row1, _ker[1]); + acc = vmlal_s16(acc, row2, _ker[2]); + *output_ptr0 = vgetq_lane_s32(acc, 0) + vgetq_lane_s32(acc, 1); + acc = vmull_s16(row1, _ker[0]); + acc = vmlal_s16(acc, row2, _ker[1]); + acc = vmlal_s16(acc, row3, _ker[2]); + *output_ptr1 = vgetq_lane_s32(acc, 0) + vgetq_lane_s32(acc, 1); + acc = vmull_s16(row2, _ker[0]); + acc = vmlal_s16(acc, row3, _ker[1]); + acc = vmlal_s16(acc, row4, _ker[2]); + *output_ptr2 = vgetq_lane_s32(acc, 0) + vgetq_lane_s32(acc, 1); + acc = vmull_s16(row3, _ker[0]); + acc = vmlal_s16(acc, row4, _ker[1]); + acc = vmlal_s16(acc, row5, _ker[2]); + *output_ptr3 = vgetq_lane_s32(acc, 0) + vgetq_lane_s32(acc, 1); + + row0 = vext_s16(row0, zero, 1); + row1 = vext_s16(row1, zero, 1); + row2 = vext_s16(row2, zero, 1); + row3 = vext_s16(row3, zero, 1); + row4 = vext_s16(row4, zero, 1); + row5 = vext_s16(row5, zero, 1); + } + output_ptr0++; + output_ptr1++; + output_ptr2++; + output_ptr3++; + } + } } // remain height int start_h = valid_h_start + (valid_h & 0xFFFC); @@ -712,208 +586,259 @@ void DepthwiseConv3x3S1(const framework::Tensor &input, const int8_t *input_ptr1 = input_ptr0 + input_w; const int8_t *input_ptr2 = input_ptr1 + input_w; const int8_t *input_ptr3 = input_ptr2 + input_w; - int32_t *output_ptr0 = output_ptr + h * output_w + valid_w_start; + int32_t *output_ptr0 = output_ptr + h * output_w; int32_t *output_ptr1 = output_ptr0 + output_w; + // pad left + if (padding_w) { + int16x4_t row0 = vget_low_s16(vmovl_s8(vld1_s8(input_ptr0))); + int16x4_t row1 = vget_low_s16(vmovl_s8(vld1_s8(input_ptr1))); + int16x4_t row2 = vget_low_s16(vmovl_s8(vld1_s8(input_ptr2))); + int16x4_t row3 = vget_low_s16(vmovl_s8(vld1_s8(input_ptr3))); + int32x4_t acc; + for (int w = valid_w_start - 1; w >= 0; --w) { + int padding = padding_w - w; + if (padding >= 3) { + output_ptr0[w] = 0; + output_ptr1[w] = 0; + } else { + row0 = vext_s16(zero, row0, 3); + row1 = vext_s16(zero, 
row1, 3); + row2 = vext_s16(zero, row2, 3); + row3 = vext_s16(zero, row3, 3); + acc = vmull_s16(row0, _ker[0]); + acc = vmlal_s16(acc, row1, _ker[1]); + acc = vmlal_s16(acc, row2, _ker[2]); + output_ptr0[w] = vgetq_lane_s32(acc, 1) + vgetq_lane_s32(acc, 2); + acc = vmull_s16(row1, _ker[0]); + acc = vmlal_s16(acc, row2, _ker[1]); + acc = vmlal_s16(acc, row3, _ker[2]); + output_ptr1[w] = vgetq_lane_s32(acc, 1) + vgetq_lane_s32(acc, 2); + } + } + output_ptr0 += valid_w_start; + output_ptr1 += valid_w_start; + } + // valid int loop = output_w_tiles; asm volatile( - "vld1.32 {q0}, [%[filter_ptr]] \n" - "vmovl.s8 q14, d0 \n" - "vmovl.s8 q15, d1 \n" - "vdup.s16 d0, d28[0] \n" - "vdup.s16 d1, d28[1] \n" - "vdup.s16 d2, d28[2] \n" - "vdup.s16 d3, d28[3] \n" - "vdup.s16 d4, d29[0] \n" - "vdup.s16 d5, d29[1] \n" - "vdup.s16 d6, d29[2] \n" - "vdup.s16 d7, d29[3] \n" - "vdup.s16 d8, d30[0] \n" - : - : [filter_ptr] "r"(filter_ptr) - : "cc", "memory", "q0", "q1", "q2", "q3", "q4", "q14", "q15"); - asm volatile( - "cmp %[loop], #0 \n" - "ble start_remain_%= \n" - "mov r0, #6 \n" + "cmp %[loop], #0 \n" + "ble start_remain_%= \n" + "mov r0, #6 \n" // loop 6 widths - "loop_2h6w_%=: \n" - "vld1.32 {d9}, [%[input_ptr0]], r0 \n" - "vld1.32 {d10}, [%[input_ptr1]], r0 \n" - "vld1.32 {d11}, [%[input_ptr2]], r0 \n" - "vext.s8 d12, d9, d9, #1 \n" - "vext.s8 d13, d9, d9, #2 \n" - "vmovl.s8 q7, d9 \n" - "vmovl.s8 q8, d12 \n" - "vmovl.s8 q9, d13 \n" - "vmull.s16 q10, d14, d0 \n" - "vmlal.s16 q10, d16, d1 \n" - "vmlal.s16 q10, d18, d2 \n" - "vmull.s16 q11, d15, d0 \n" - "vmlal.s16 q11, d17, d1 \n" - "vmlal.s16 q11, d19, d2 \n" - - "vext.s8 d12, d10, d10, #1 \n" - "vext.s8 d13, d10, d10, #2 \n" - "vmovl.s8 q7, d10 \n" - "vmovl.s8 q8, d12 \n" - "vmovl.s8 q9, d13 \n" - "vmlal.s16 q10, d14, d3 \n" - "vmlal.s16 q10, d16, d4 \n" - "vmlal.s16 q10, d18, d5 \n" - "vmlal.s16 q11, d15, d3 \n" - "vmlal.s16 q11, d17, d4 \n" - "vmlal.s16 q11, d19, d5 \n" - - "vmull.s16 q12, d14, d0 \n" - "vmlal.s16 q12, d16, d1 \n" - "vmlal.s16 q12, d18, d2 \n" - "vmull.s16 q13, d15, d0 \n" - "vmlal.s16 q13, d17, d1 \n" - "vmlal.s16 q13, d19, d2 \n" - - "vext.s8 d12, d11, d11, #1 \n" - "vext.s8 d13, d11, d11, #2 \n" - "vmovl.s8 q7, d11 \n" - "vmovl.s8 q8, d12 \n" - "vmovl.s8 q9, d13 \n" - "vmlal.s16 q10, d14, d6 \n" - "vmlal.s16 q10, d16, d7 \n" - "vmlal.s16 q10, d18, d8 \n" - "vmlal.s16 q11, d15, d6 \n" - "vmlal.s16 q11, d17, d7 \n" - "vmlal.s16 q11, d19, d8 \n" + "loop_2h6w_%=: \n" + "vld1.32 {d9}, [%[input_ptr0]], r0 \n" + "vld1.32 {d10}, [%[input_ptr1]], r0 \n" + "vld1.32 {d11}, [%[input_ptr2]], r0 \n" + "vext.s8 d12, d9, d9, #1 \n" + "vext.s8 d13, d9, d9, #2 \n" + "vmovl.s8 q7, d9 \n" + "vmovl.s8 q8, d12 \n" + "vmovl.s8 q9, d13 \n" + "vmull.s16 q10, d14, %e[ker0][0] \n" + "vmlal.s16 q10, d16, %e[ker0][1] \n" + "vmlal.s16 q10, d18, %e[ker0][2] \n" + "vmull.s16 q11, d15, %e[ker0][0] \n" + "vmlal.s16 q11, d17, %e[ker0][1] \n" + "vmlal.s16 q11, d19, %e[ker0][2] \n" + + "vext.s8 d12, d10, d10, #1 \n" + "vext.s8 d13, d10, d10, #2 \n" + "vmovl.s8 q7, d10 \n" + "vmovl.s8 q8, d12 \n" + "vmovl.s8 q9, d13 \n" + "vmlal.s16 q10, d14, %f[ker0][0] \n" + "vmlal.s16 q10, d16, %f[ker0][1] \n" + "vmlal.s16 q10, d18, %f[ker0][2] \n" + "vmlal.s16 q11, d15, %f[ker0][0] \n" + "vmlal.s16 q11, d17, %f[ker0][1] \n" + "vmlal.s16 q11, d19, %f[ker0][2] \n" + + "vmull.s16 q12, d14, %e[ker0][0] \n" + "vmlal.s16 q12, d16, %e[ker0][1] \n" + "vmlal.s16 q12, d18, %e[ker0][2] \n" + "vmull.s16 q13, d15, %e[ker0][0] \n" + "vmlal.s16 q13, d17, %e[ker0][1] \n" + "vmlal.s16 q13, d19, 
%e[ker0][2] \n" + + "vext.s8 d12, d11, d11, #1 \n" + "vext.s8 d13, d11, d11, #2 \n" + "vmovl.s8 q7, d11 \n" + "vmovl.s8 q8, d12 \n" + "vmovl.s8 q9, d13 \n" + "vmlal.s16 q10, d14, %e[ker1][0] \n" + "vmlal.s16 q10, d16, %e[ker1][1] \n" + "vmlal.s16 q10, d18, %e[ker1][2] \n" + "vmlal.s16 q11, d15, %e[ker1][0] \n" + "vmlal.s16 q11, d17, %e[ker1][1] \n" + "vmlal.s16 q11, d19, %e[ker1][2] \n" // store row 0, reuse q10/q11 "vst1.32 {d20-d22}, [%[output_ptr0]]! \n" - "vmlal.s16 q12, d14, d3 \n" - "vmlal.s16 q12, d16, d4 \n" - "vmlal.s16 q12, d18, d5 \n" - "vmlal.s16 q13, d15, d3 \n" - "vmlal.s16 q13, d17, d4 \n" - "vmlal.s16 q13, d19, d5 \n" + "vmlal.s16 q12, d14, %f[ker0][0] \n" + "vmlal.s16 q12, d16, %f[ker0][1] \n" + "vmlal.s16 q12, d18, %f[ker0][2] \n" + "vmlal.s16 q13, d15, %f[ker0][0] \n" + "vmlal.s16 q13, d17, %f[ker0][1] \n" + "vmlal.s16 q13, d19, %f[ker0][2] \n" "vld1.32 {d9}, [%[input_ptr3]], r0 \n" - "vext.s8 d12, d9, d9, #1 \n" - "vext.s8 d13, d9, d9, #2 \n" - "vmovl.s8 q7, d9 \n" - "vmovl.s8 q8, d12 \n" - "vmovl.s8 q9, d13 \n" - "vmlal.s16 q12, d14, d6 \n" - "vmlal.s16 q12, d16, d7 \n" - "vmlal.s16 q12, d18, d8 \n" - "vmlal.s16 q13, d15, d6 \n" - "vmlal.s16 q13, d17, d7 \n" - "vmlal.s16 q13, d19, d8 \n" + "vext.s8 d12, d9, d9, #1 \n" + "vext.s8 d13, d9, d9, #2 \n" + "vmovl.s8 q7, d9 \n" + "vmovl.s8 q8, d12 \n" + "vmovl.s8 q9, d13 \n" + "vmlal.s16 q12, d14, %e[ker1][0] \n" + "vmlal.s16 q12, d16, %e[ker1][1] \n" + "vmlal.s16 q12, d18, %e[ker1][2] \n" + "vmlal.s16 q13, d15, %e[ker1][0] \n" + "vmlal.s16 q13, d17, %e[ker1][1] \n" + "vmlal.s16 q13, d19, %e[ker1][2] \n" // store row 1 "vst1.32 {d24-d26}, [%[output_ptr1]]! \n" - "subs %[loop], #1 \n" - "bne loop_2h6w_%= \n" - - "start_remain_%=: \n" - "cmp %[remain], #0 \n" - "ble end_%= \n" - - "vld1.32 {d9}, [%[input_ptr0]] \n" - "vld1.32 {d10}, [%[input_ptr1]] \n" - "vld1.32 {d11}, [%[input_ptr2]] \n" - "vext.s8 d12, d9, d9, #1 \n" - "vext.s8 d13, d9, d9, #2 \n" - "vmovl.s8 q7, d9 \n" - "vmovl.s8 q8, d12 \n" - "vmovl.s8 q9, d13 \n" - "vmull.s16 q10, d14, d0 \n" - "vmlal.s16 q10, d16, d1 \n" - "vmlal.s16 q10, d18, d2 \n" - "vmull.s16 q11, d15, d0 \n" - "vmlal.s16 q11, d17, d1 \n" - "vmlal.s16 q11, d19, d2 \n" - - "vext.s8 d12, d10, d10, #1 \n" - "vext.s8 d13, d10, d10, #2 \n" - "vmovl.s8 q7, d10 \n" - "vmovl.s8 q8, d12 \n" - "vmovl.s8 q9, d13 \n" - "vmlal.s16 q10, d14, d3 \n" - "vmlal.s16 q10, d16, d4 \n" - "vmlal.s16 q10, d18, d5 \n" - "vmlal.s16 q11, d15, d3 \n" - "vmlal.s16 q11, d17, d4 \n" - "vmlal.s16 q11, d19, d5 \n" - - "vmull.s16 q12, d14, d0 \n" - "vmlal.s16 q12, d16, d1 \n" - "vmlal.s16 q12, d18, d2 \n" - "vmull.s16 q13, d15, d0 \n" - "vmlal.s16 q13, d17, d1 \n" - "vmlal.s16 q13, d19, d2 \n" - - "vext.s8 d12, d11, d11, #1 \n" - "vext.s8 d13, d11, d11, #2 \n" - "vmovl.s8 q7, d11 \n" - "vmovl.s8 q8, d12 \n" - "vmovl.s8 q9, d13 \n" - "vmlal.s16 q10, d14, d6 \n" - "vmlal.s16 q10, d16, d7 \n" - "vmlal.s16 q10, d18, d8 \n" - "vmlal.s16 q11, d15, d6 \n" - "vmlal.s16 q11, d17, d7 \n" - "vmlal.s16 q11, d19, d8 \n" - - "vmlal.s16 q12, d14, d3 \n" - "vmlal.s16 q12, d16, d4 \n" - "vmlal.s16 q12, d18, d5 \n" - "vmlal.s16 q13, d15, d3 \n" - "vmlal.s16 q13, d17, d4 \n" - "vmlal.s16 q13, d19, d5 \n" - - "vld1.32 {d9}, [%[input_ptr3]] \n" - "vext.s8 d12, d9, d9, #1 \n" - "vext.s8 d13, d9, d9, #2 \n" - "vmovl.s8 q7, d9 \n" - "vmovl.s8 q8, d12 \n" - "vmovl.s8 q9, d13 \n" - "vmlal.s16 q12, d14, d6 \n" - "vmlal.s16 q12, d16, d7 \n" - "vmlal.s16 q12, d18, d8 \n" - "vmlal.s16 q13, d15, d6 \n" - "vmlal.s16 q13, d17, d7 \n" - "vmlal.s16 q13, d19, d8 \n" - 
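// Editor's note (not part of the original patch): the change above replaces the
// old per-block "vdup the int8 filter into d0-d8" preamble with two pre-widened
// kernel vectors (_ker0/_ker1) that are handed to the inline asm through
// "w"-constrained operands and read as %e[..][n]/%f[..][n] half-register lanes,
// and it adds intrinsic border code for the padded columns instead of the old
// DepthwiseConv3x3ValidCol calls.  The sketch below isolates that border idea
// for one output row; the function name and parameters are hypothetical, it
// assumes stride 1 (so the first valid output column equals padding_w), rows
// that lie fully inside the image vertically, and at least 8 readable bytes per
// row, just like the patch's own border loops.
#include <arm_neon.h>
#include <cstdint>

static void DepthwiseConv3x3S1BorderLeftRow(const int8_t *in0, const int8_t *in1,
                                            const int8_t *in2,       // input rows at column 0
                                            const int16x4_t ker[3],  // widened 3x3 filter rows
                                            int padding_w,
                                            int32_t *out) {          // output row at column 0
  const int16x4_t zero = vdup_n_s16(0);
  // Widen the first input columns to 16 bits, same as vmovl.s8 in the asm body.
  int16x4_t r0 = vget_low_s16(vmovl_s8(vld1_s8(in0)));
  int16x4_t r1 = vget_low_s16(vmovl_s8(vld1_s8(in1)));
  int16x4_t r2 = vget_low_s16(vmovl_s8(vld1_s8(in2)));
  for (int w = padding_w - 1; w >= 0; --w) {
    int padding = padding_w - w;  // taps of the 3x3 window that fall into the padding
    if (padding >= 3) {
      out[w] = 0;  // the whole window is zero padding
    } else {
      // Prepend one zero column per step so the window slides further into the padding.
      r0 = vext_s16(zero, r0, 3);
      r1 = vext_s16(zero, r1, 3);
      r2 = vext_s16(zero, r2, 3);
      int32x4_t acc = vmull_s16(r0, ker[0]);
      acc = vmlal_s16(acc, r1, ker[1]);
      acc = vmlal_s16(acc, r2, ker[2]);
      // Lane 0 only ever multiplies padding and lane 3 is past the 3-tap filter,
      // so the valid products sit in lanes 1 and 2.
      out[w] = vgetq_lane_s32(acc, 1) + vgetq_lane_s32(acc, 2);
    }
  }
}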
- "cmp %[remain], #4 \n" - "blt store_2h2w_%= \n" - "vst1.32 {q10}, [%[output_ptr0]]! \n" - "vst1.32 {q12}, [%[output_ptr1]]! \n" - "cmp %[remain], #5 \n" - "blt end_%= \n" - "vst1.32 {d22[0]}, [%[output_ptr0]]! \n" - "vst1.32 {d26[0]}, [%[output_ptr1]]! \n" - "b end_%= \n" - - "store_2h2w_%=: \n" - "cmp %[remain], #2 \n" - "blt store_2h1w_%= \n" - "vst1.32 {d20}, [%[output_ptr0]]! \n" - "vst1.32 {d24}, [%[output_ptr1]]! \n" - "cmp %[remain], #3 \n" - "blt end_%= \n" - "vst1.32 {d21[0]}, [%[output_ptr0]]! \n" - "vst1.32 {d25[0]}, [%[output_ptr1]]! \n" - "b end_%= \n" - - "store_2h1w_%=: \n" - "cmp %[remain], #1 \n" - "blt end_%= \n" - "vst1.32 {d20[0]}, [%[output_ptr0]]! \n" - "vst1.32 {d24[0]}, [%[output_ptr1]]! \n" - "end_%=: \n" + "subs %[loop], #1 \n" + "bne loop_2h6w_%= \n" + + "start_remain_%=: \n" + "cmp %[remain], #0 \n" + "ble end_%= \n" + + "mov r0, %[remain] \n" + "vld1.32 {d9}, [%[input_ptr0]], r0 \n" + "vld1.32 {d10}, [%[input_ptr1]], r0 \n" + "vld1.32 {d11}, [%[input_ptr2]], r0 \n" + "vext.s8 d12, d9, d9, #1 \n" + "vext.s8 d13, d9, d9, #2 \n" + "vmovl.s8 q7, d9 \n" + "vmovl.s8 q8, d12 \n" + "vmovl.s8 q9, d13 \n" + "vmull.s16 q10, d14, %e[ker0][0] \n" + "vmlal.s16 q10, d16, %e[ker0][1] \n" + "vmlal.s16 q10, d18, %e[ker0][2] \n" + "vmull.s16 q11, d15, %e[ker0][0] \n" + "vmlal.s16 q11, d17, %e[ker0][1] \n" + "vmlal.s16 q11, d19, %e[ker0][2] \n" + + "vext.s8 d12, d10, d10, #1 \n" + "vext.s8 d13, d10, d10, #2 \n" + "vmovl.s8 q7, d10 \n" + "vmovl.s8 q8, d12 \n" + "vmovl.s8 q9, d13 \n" + "vmlal.s16 q10, d14, %f[ker0][0] \n" + "vmlal.s16 q10, d16, %f[ker0][1] \n" + "vmlal.s16 q10, d18, %f[ker0][2] \n" + "vmlal.s16 q11, d15, %f[ker0][0] \n" + "vmlal.s16 q11, d17, %f[ker0][1] \n" + "vmlal.s16 q11, d19, %f[ker0][2] \n" + + "vmull.s16 q12, d14, %e[ker0][0] \n" + "vmlal.s16 q12, d16, %e[ker0][1] \n" + "vmlal.s16 q12, d18, %e[ker0][2] \n" + "vmull.s16 q13, d15, %e[ker0][0] \n" + "vmlal.s16 q13, d17, %e[ker0][1] \n" + "vmlal.s16 q13, d19, %e[ker0][2] \n" + + "vext.s8 d12, d11, d11, #1 \n" + "vext.s8 d13, d11, d11, #2 \n" + "vmovl.s8 q7, d11 \n" + "vmovl.s8 q8, d12 \n" + "vmovl.s8 q9, d13 \n" + "vmlal.s16 q10, d14, %e[ker1][0] \n" + "vmlal.s16 q10, d16, %e[ker1][1] \n" + "vmlal.s16 q10, d18, %e[ker1][2] \n" + "vmlal.s16 q11, d15, %e[ker1][0] \n" + "vmlal.s16 q11, d17, %e[ker1][1] \n" + "vmlal.s16 q11, d19, %e[ker1][2] \n" + + "vmlal.s16 q12, d14, %f[ker0][0] \n" + "vmlal.s16 q12, d16, %f[ker0][1] \n" + "vmlal.s16 q12, d18, %f[ker0][2] \n" + "vmlal.s16 q13, d15, %f[ker0][0] \n" + "vmlal.s16 q13, d17, %f[ker0][1] \n" + "vmlal.s16 q13, d19, %f[ker0][2] \n" + + "vld1.32 {d9}, [%[input_ptr3]], r0 \n" + "vext.s8 d12, d9, d9, #1 \n" + "vext.s8 d13, d9, d9, #2 \n" + "vmovl.s8 q7, d9 \n" + "vmovl.s8 q8, d12 \n" + "vmovl.s8 q9, d13 \n" + "vmlal.s16 q12, d14, %e[ker1][0] \n" + "vmlal.s16 q12, d16, %e[ker1][1] \n" + "vmlal.s16 q12, d18, %e[ker1][2] \n" + "vmlal.s16 q13, d15, %e[ker1][0] \n" + "vmlal.s16 q13, d17, %e[ker1][1] \n" + "vmlal.s16 q13, d19, %e[ker1][2] \n" + + "cmp %[remain], #4 \n" + "blt store_2h2w_%= \n" + "vst1.32 {q10}, [%[output_ptr0]]! \n" + "vst1.32 {q12}, [%[output_ptr1]]! \n" + "cmp %[remain], #5 \n" + "blt end_%= \n" + "vst1.32 {d22[0]}, [%[output_ptr0]]! \n" + "vst1.32 {d26[0]}, [%[output_ptr1]]! \n" + "b end_%= \n" + + "store_2h2w_%=: \n" + "cmp %[remain], #2 \n" + "blt store_2h1w_%= \n" + "vst1.32 {d20}, [%[output_ptr0]]! \n" + "vst1.32 {d24}, [%[output_ptr1]]! \n" + "cmp %[remain], #3 \n" + "blt end_%= \n" + "vst1.32 {d21[0]}, [%[output_ptr0]]! 
\n" + "vst1.32 {d25[0]}, [%[output_ptr1]]! \n" + "b end_%= \n" + + "store_2h1w_%=: \n" + "cmp %[remain], #1 \n" + "blt end_%= \n" + "vst1.32 {d20[0]}, [%[output_ptr0]]! \n" + "vst1.32 {d24[0]}, [%[output_ptr1]]! \n" + "end_%=: \n" : [output_ptr0] "+r"(output_ptr0), [output_ptr1] "+r"(output_ptr1), [input_ptr0] "+r"(input_ptr0), [input_ptr1] "+r"(input_ptr1), [input_ptr2] "+r"(input_ptr2), [input_ptr3] "+r"(input_ptr3), [loop] "+r"(loop) - : [remain] "r"(output_w_remain) - : "cc", "memory", "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7", - "q8", "q9", "q10", "q11", "q12", "q13", "r0"); + : [remain] "r"(output_w_remain), [ker0] "w"(_ker0), [ker1] "w"(_ker1) + : "cc", "memory", "q4", "q5", "q6", "q7", "q8", "q9", "q10", "q11", + "q12", "q13", "q14", "q15", "r0"); + // pad right + if (padding_w) { + int16x4_t row0 = vget_low_s16(vmovl_s8(vld1_s8(input_ptr0 - 2))); + int16x4_t row1 = vget_low_s16(vmovl_s8(vld1_s8(input_ptr1 - 2))); + int16x4_t row2 = vget_low_s16(vmovl_s8(vld1_s8(input_ptr2 - 2))); + int16x4_t row3 = vget_low_s16(vmovl_s8(vld1_s8(input_ptr3 - 2))); + row0 = vext_s16(row0, zero, 2); + row1 = vext_s16(row1, zero, 2); + row2 = vext_s16(row2, zero, 2); + row3 = vext_s16(row3, zero, 2); + int32x4_t acc; + for (int w = valid_w_end; w < output_w; ++w) { + int padding = w + 3 - (padding_w + input_w); + if (padding >= 3) { + *output_ptr0 = 0; + *output_ptr1 = 0; + } else { + acc = vmull_s16(row0, _ker[0]); + acc = vmlal_s16(acc, row1, _ker[1]); + acc = vmlal_s16(acc, row2, _ker[2]); + *output_ptr0 = vgetq_lane_s32(acc, 0) + vgetq_lane_s32(acc, 1); + acc = vmull_s16(row1, _ker[0]); + acc = vmlal_s16(acc, row2, _ker[1]); + acc = vmlal_s16(acc, row3, _ker[2]); + *output_ptr1 = vgetq_lane_s32(acc, 0) + vgetq_lane_s32(acc, 1); + + row0 = vext_s16(row0, zero, 1); + row1 = vext_s16(row1, zero, 1); + row2 = vext_s16(row2, zero, 1); + row3 = vext_s16(row3, zero, 1); + } + output_ptr0++; + output_ptr1++; + } + } } start_h = valid_h_start + (valid_h & 0xFFFE); @@ -921,145 +846,185 @@ void DepthwiseConv3x3S1(const framework::Tensor &input, const int8_t *input_ptr0 = input_ptr + (start_h - padding_h) * input_w; const int8_t *input_ptr1 = input_ptr0 + input_w; const int8_t *input_ptr2 = input_ptr1 + input_w; - int32_t *output_ptr0 = output_ptr + start_h * output_w + valid_w_start; + int32_t *output_ptr0 = output_ptr + start_h * output_w; + // pad left + if (padding_w) { + int16x4_t row0 = vget_low_s16(vmovl_s8(vld1_s8(input_ptr0))); + int16x4_t row1 = vget_low_s16(vmovl_s8(vld1_s8(input_ptr1))); + int16x4_t row2 = vget_low_s16(vmovl_s8(vld1_s8(input_ptr2))); + int32x4_t acc; + for (int w = valid_w_start - 1; w >= 0; --w) { + int padding = padding_w - w; + if (padding >= 3) { + output_ptr0[w] = 0; + } else { + row0 = vext_s16(zero, row0, 3); + row1 = vext_s16(zero, row1, 3); + row2 = vext_s16(zero, row2, 3); + acc = vmull_s16(row0, _ker[0]); + acc = vmlal_s16(acc, row1, _ker[1]); + acc = vmlal_s16(acc, row2, _ker[2]); + output_ptr0[w] = vgetq_lane_s32(acc, 1) + vgetq_lane_s32(acc, 2); + } + } + output_ptr0 += valid_w_start; + } + // valid int loop = output_w_tiles; asm volatile( - "vld1.32 {q0}, [%[filter_ptr]] \n" - "vmovl.s8 q14, d0 \n" - "vmovl.s8 q15, d1 \n" - "vdup.s16 d0, d28[0] \n" - "vdup.s16 d1, d28[1] \n" - "vdup.s16 d2, d28[2] \n" - "vdup.s16 d3, d28[3] \n" - "vdup.s16 d4, d29[0] \n" - "vdup.s16 d5, d29[1] \n" - "vdup.s16 d6, d29[2] \n" - "vdup.s16 d7, d29[3] \n" - "vdup.s16 d8, d30[0] \n" - : - : [filter_ptr] "r"(filter_ptr) - : "memory", "q0", "q1", "q2", "q3", "q4", "q14", "q15"); - asm 
volatile( - "cmp %[loop], #0 \n" - "ble start_remain_%= \n" - "mov r0, #6 \n" + "cmp %[loop], #0 \n" + "ble start_remain_%= \n" + "mov r0, #6 \n" // loop 6 widths - "loop_1h6w_%=: \n" - "vld1.32 {d9}, [%[input_ptr0]], r0 \n" - "vld1.32 {d10}, [%[input_ptr1]], r0 \n" - "vld1.32 {d11}, [%[input_ptr2]], r0 \n" - "vext.s8 d12, d9, d9, #1 \n" - "vext.s8 d13, d9, d9, #2 \n" - "vmovl.s8 q7, d9 \n" - "vmovl.s8 q8, d12 \n" - "vmovl.s8 q9, d13 \n" - "vmull.s16 q10, d14, d0 \n" - "vmlal.s16 q10, d16, d1 \n" - "vmlal.s16 q10, d18, d2 \n" - "vmull.s16 q11, d15, d0 \n" - "vmlal.s16 q11, d17, d1 \n" - "vmlal.s16 q11, d19, d2 \n" - - "vext.s8 d12, d10, d10, #1 \n" - "vext.s8 d13, d10, d10, #2 \n" - "vmovl.s8 q7, d10 \n" - "vmovl.s8 q8, d12 \n" - "vmovl.s8 q9, d13 \n" - "vmlal.s16 q10, d14, d3 \n" - "vmlal.s16 q10, d16, d4 \n" - "vmlal.s16 q10, d18, d5 \n" - "vmlal.s16 q11, d15, d3 \n" - "vmlal.s16 q11, d17, d4 \n" - "vmlal.s16 q11, d19, d5 \n" - - "vext.s8 d12, d11, d11, #1 \n" - "vext.s8 d13, d11, d11, #2 \n" - "vmovl.s8 q7, d11 \n" - "vmovl.s8 q8, d12 \n" - "vmovl.s8 q9, d13 \n" - "vmlal.s16 q10, d14, d6 \n" - "vmlal.s16 q10, d16, d7 \n" - "vmlal.s16 q10, d18, d8 \n" - "vmlal.s16 q11, d15, d6 \n" - "vmlal.s16 q11, d17, d7 \n" - "vmlal.s16 q11, d19, d8 \n" + "loop_1h6w_%=: \n" + "vld1.32 {d9}, [%[input_ptr0]], r0 \n" + "vld1.32 {d10}, [%[input_ptr1]], r0 \n" + "vld1.32 {d11}, [%[input_ptr2]], r0 \n" + "vext.s8 d12, d9, d9, #1 \n" + "vext.s8 d13, d9, d9, #2 \n" + "vmovl.s8 q7, d9 \n" + "vmovl.s8 q8, d12 \n" + "vmovl.s8 q9, d13 \n" + "vmull.s16 q10, d14, %e[ker0][0] \n" + "vmlal.s16 q10, d16, %e[ker0][1] \n" + "vmlal.s16 q10, d18, %e[ker0][2] \n" + "vmull.s16 q11, d15, %e[ker0][0] \n" + "vmlal.s16 q11, d17, %e[ker0][1] \n" + "vmlal.s16 q11, d19, %e[ker0][2] \n" + + "vext.s8 d12, d10, d10, #1 \n" + "vext.s8 d13, d10, d10, #2 \n" + "vmovl.s8 q7, d10 \n" + "vmovl.s8 q8, d12 \n" + "vmovl.s8 q9, d13 \n" + "vmlal.s16 q10, d14, %f[ker0][0] \n" + "vmlal.s16 q10, d16, %f[ker0][1] \n" + "vmlal.s16 q10, d18, %f[ker0][2] \n" + "vmlal.s16 q11, d15, %f[ker0][0] \n" + "vmlal.s16 q11, d17, %f[ker0][1] \n" + "vmlal.s16 q11, d19, %f[ker0][2] \n" + + "vext.s8 d12, d11, d11, #1 \n" + "vext.s8 d13, d11, d11, #2 \n" + "vmovl.s8 q7, d11 \n" + "vmovl.s8 q8, d12 \n" + "vmovl.s8 q9, d13 \n" + "vmlal.s16 q10, d14, %e[ker1][0] \n" + "vmlal.s16 q10, d16, %e[ker1][1] \n" + "vmlal.s16 q10, d18, %e[ker1][2] \n" + "vmlal.s16 q11, d15, %e[ker1][0] \n" + "vmlal.s16 q11, d17, %e[ker1][1] \n" + "vmlal.s16 q11, d19, %e[ker1][2] \n" // store row 0, reuse q10/q11 "vst1.32 {d20-d22}, [%[output_ptr0]]! 
\n" - "subs %[loop], #1 \n" - "bne loop_1h6w_%= \n" - - "start_remain_%=: \n" - "cmp %[remain], #0 \n" - "ble end_%= \n" - - "vld1.32 {d9}, [%[input_ptr0]] \n" - "vld1.32 {d10}, [%[input_ptr1]] \n" - "vld1.32 {d11}, [%[input_ptr2]] \n" - "vext.s8 d12, d9, d9, #1 \n" - "vext.s8 d13, d9, d9, #2 \n" - "vmovl.s8 q7, d9 \n" - "vmovl.s8 q8, d12 \n" - "vmovl.s8 q9, d13 \n" - "vmull.s16 q10, d14, d0 \n" - "vmlal.s16 q10, d16, d1 \n" - "vmlal.s16 q10, d18, d2 \n" - "vmull.s16 q11, d15, d0 \n" - "vmlal.s16 q11, d17, d1 \n" - "vmlal.s16 q11, d19, d2 \n" - - "vext.s8 d12, d10, d10, #1 \n" - "vext.s8 d13, d10, d10, #2 \n" - "vmovl.s8 q7, d10 \n" - "vmovl.s8 q8, d12 \n" - "vmovl.s8 q9, d13 \n" - "vmlal.s16 q10, d14, d3 \n" - "vmlal.s16 q10, d16, d4 \n" - "vmlal.s16 q10, d18, d5 \n" - "vmlal.s16 q11, d15, d3 \n" - "vmlal.s16 q11, d17, d4 \n" - "vmlal.s16 q11, d19, d5 \n" - - "vext.s8 d12, d11, d11, #1 \n" - "vext.s8 d13, d11, d11, #2 \n" - "vmovl.s8 q7, d11 \n" - "vmovl.s8 q8, d12 \n" - "vmovl.s8 q9, d13 \n" - "vmlal.s16 q10, d14, d6 \n" - "vmlal.s16 q10, d16, d7 \n" - "vmlal.s16 q10, d18, d8 \n" - "vmlal.s16 q11, d15, d6 \n" - "vmlal.s16 q11, d17, d7 \n" - "vmlal.s16 q11, d19, d8 \n" - - "cmp %[remain], #4 \n" - "blt store_1h2w_%= \n" - "vst1.32 {q10}, [%[output_ptr0]]! \n" - "cmp %[remain], #5 \n" - "blt end_%= \n" - "vst1.32 {d22[0]}, [%[output_ptr0]]! \n" - "b end_%= \n" - - "store_1h2w_%=: \n" - "cmp %[remain], #2 \n" - "blt store_1h1w_%= \n" - "vst1.32 {d20}, [%[output_ptr0]]! \n" - "cmp %[remain], #3 \n" - "blt end_%= \n" - "vst1.32 {d21[0]}, [%[output_ptr0]]! \n" - "b end_%= \n" - - "store_1h1w_%=: \n" - "cmp %[remain], #1 \n" - "blt end_%= \n" - "vst1.32 {d20[0]}, [%[output_ptr0]]! \n" - "end_%=: \n" + "subs %[loop], #1 \n" + "bne loop_1h6w_%= \n" + + "start_remain_%=: \n" + "cmp %[remain], #0 \n" + "ble end_%= \n" + "mov r0, %[remain] \n" + + "vld1.32 {d9}, [%[input_ptr0]], r0 \n" + "vld1.32 {d10}, [%[input_ptr1]], r0 \n" + "vld1.32 {d11}, [%[input_ptr2]], r0 \n" + "vext.s8 d12, d9, d9, #1 \n" + "vext.s8 d13, d9, d9, #2 \n" + "vmovl.s8 q7, d9 \n" + "vmovl.s8 q8, d12 \n" + "vmovl.s8 q9, d13 \n" + "vmull.s16 q10, d14, %e[ker0][0] \n" + "vmlal.s16 q10, d16, %e[ker0][1] \n" + "vmlal.s16 q10, d18, %e[ker0][2] \n" + "vmull.s16 q11, d15, %e[ker0][0] \n" + "vmlal.s16 q11, d17, %e[ker0][1] \n" + "vmlal.s16 q11, d19, %e[ker0][2] \n" + + "vext.s8 d12, d10, d10, #1 \n" + "vext.s8 d13, d10, d10, #2 \n" + "vmovl.s8 q7, d10 \n" + "vmovl.s8 q8, d12 \n" + "vmovl.s8 q9, d13 \n" + "vmlal.s16 q10, d14, %f[ker0][0] \n" + "vmlal.s16 q10, d16, %f[ker0][1] \n" + "vmlal.s16 q10, d18, %f[ker0][2] \n" + "vmlal.s16 q11, d15, %f[ker0][0] \n" + "vmlal.s16 q11, d17, %f[ker0][1] \n" + "vmlal.s16 q11, d19, %f[ker0][2] \n" + + "vext.s8 d12, d11, d11, #1 \n" + "vext.s8 d13, d11, d11, #2 \n" + "vmovl.s8 q7, d11 \n" + "vmovl.s8 q8, d12 \n" + "vmovl.s8 q9, d13 \n" + "vmlal.s16 q10, d14, %e[ker1][0] \n" + "vmlal.s16 q10, d16, %e[ker1][1] \n" + "vmlal.s16 q10, d18, %e[ker1][2] \n" + "vmlal.s16 q11, d15, %e[ker1][0] \n" + "vmlal.s16 q11, d17, %e[ker1][1] \n" + "vmlal.s16 q11, d19, %e[ker1][2] \n" + + "cmp %[remain], #4 \n" + "blt store_1h2w_%= \n" + "vst1.32 {q10}, [%[output_ptr0]]! \n" + "cmp %[remain], #5 \n" + "blt end_%= \n" + "vst1.32 {d22[0]}, [%[output_ptr0]]! \n" + "b end_%= \n" + + "store_1h2w_%=: \n" + "cmp %[remain], #2 \n" + "blt store_1h1w_%= \n" + "vst1.32 {d20}, [%[output_ptr0]]! \n" + "cmp %[remain], #3 \n" + "blt end_%= \n" + "vst1.32 {d21[0]}, [%[output_ptr0]]! 
\n" + "b end_%= \n" + + "store_1h1w_%=: \n" + "cmp %[remain], #1 \n" + "blt end_%= \n" + "vst1.32 {d20[0]}, [%[output_ptr0]]! \n" + "end_%=: \n" : [output_ptr0] "+r"(output_ptr0), [input_ptr0] "+r"(input_ptr0), [input_ptr1] "+r"(input_ptr1), [input_ptr2] "+r"(input_ptr2), [loop] "+r"(loop) - : [remain] "r"(output_w_remain) - : "cc", "memory", "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7", - "q8", "q9", "q10", "q11", "r0"); + : [remain] "r"(output_w_remain), [ker0] "w"(_ker0), [ker1] "w"(_ker1) + : "cc", "memory", "q4", "q5", "q6", "q7", "q8", "q9", "q10", "q11", + "q12", "q13", "q14", "q15", "r0"); + // pad right + if (padding_w) { + int16x4_t row0 = vget_low_s16(vmovl_s8(vld1_s8(input_ptr0 - 2))); + int16x4_t row1 = vget_low_s16(vmovl_s8(vld1_s8(input_ptr1 - 2))); + int16x4_t row2 = vget_low_s16(vmovl_s8(vld1_s8(input_ptr2 - 2))); + row0 = vext_s16(row0, zero, 2); + row1 = vext_s16(row1, zero, 2); + row2 = vext_s16(row2, zero, 2); + int32x4_t acc; + for (int w = valid_w_end; w < output_w; ++w) { + int padding = w + 3 - (padding_w + input_w); + if (padding >= 3) { + *output_ptr0 = 0; + } else { + acc = vmull_s16(row0, _ker[0]); + acc = vmlal_s16(acc, row1, _ker[1]); + acc = vmlal_s16(acc, row2, _ker[2]); + *output_ptr0 = vgetq_lane_s32(acc, 0) + vgetq_lane_s32(acc, 1); + + row0 = vext_s16(row0, zero, 1); + row1 = vext_s16(row1, zero, 1); + row2 = vext_s16(row2, zero, 1); + } + output_ptr0++; + } + } + } + // bottom + for (int h = valid_h_end; h < output_h; ++h) { + DepthwiseConv3x3NormalRow<1, 1>(input_ptr, filter_ptr, h, input_h, + input_w, padding_h, padding_w, output_w, + output_ptr, _ker); } } } @@ -1081,11 +1046,13 @@ void DepthwiseConv3x3S2(const framework::Tensor &input, int image_size = input_h * input_w; int out_image_size = output_h * output_w; int valid_h_start = (padding_h + 1) / 2; - int valid_h_end = output_h - valid_h_start; + int valid_h_end = (input_h + padding_h - 1) / 2; int valid_h = valid_h_end - valid_h_start; int valid_w_start = (padding_w + 1) / 2; - int valid_w_end = output_w - valid_w_start; + int valid_w_end = (input_w + padding_w - 1) / 2; int valid_w = valid_w_end - valid_w_start; + // for pad left + int valid_input_w_start = (valid_w_start << 1) - padding_w; // DLOG << "valid_h_start: " << valid_h_start; // DLOG << "valid_h_end: " << valid_h_end; @@ -1097,459 +1064,579 @@ void DepthwiseConv3x3S2(const framework::Tensor &input, const int8_t *input_ptr = input_data + g * image_size; const int8_t *filter_ptr = filter_data + g * 9; int32_t *output_ptr = out_data + g * out_image_size; + + const int8_t *filter_ptr0 = filter_ptr; + const int8_t *filter_ptr1 = filter_ptr0 + 3; + const int8_t *filter_ptr2 = filter_ptr1 + 3; + int16x4_t _k0 = vget_low_s16(vmovl_s8(vld1_s8(filter_ptr0))); + int16x4_t _k1 = vget_low_s16(vmovl_s8(vld1_s8(filter_ptr1))); + int16x4_t _k2 = vget_low_s16(vmovl_s8(vld1_s8(filter_ptr2))); + int16x8_t _ker0 = vcombine_s16(_k0, _k1); + int16x8_t _ker1 = vcombine_s16(_k2, _k2); + int16x4_t _ker[3] = {_k0, _k1, _k2}; + // top for (int h = 0; h < valid_h_start; ++h) { DepthwiseConv3x3NormalRow<2, 2>(input_ptr, filter_ptr, h, input_h, input_w, padding_h, padding_w, output_w, - output_ptr); - } - // left - for (int w = 0; w < valid_w_start; ++w) { - DepthwiseConv3x3ValidCol<2, 2>( - input_ptr, filter_ptr, valid_h_start, valid_h_end, w, input_h, - input_w, padding_h, padding_w, output_w, output_ptr); - } - // right - for (int w = valid_w_end; w < output_w; ++w) { - DepthwiseConv3x3ValidCol<2, 2>( - input_ptr, filter_ptr, valid_h_start, valid_h_end, w, 
input_h, - input_w, padding_h, padding_w, output_w, output_ptr); - } - // bottom - for (int h = valid_h_end; h < output_h; ++h) { - DepthwiseConv3x3NormalRow<2, 2>(input_ptr, filter_ptr, h, input_h, - input_w, padding_h, padding_w, output_w, - output_ptr); + output_ptr, _ker); } // valid int input_w_start = 2 * valid_w_start - padding_w; int output_w_tiles = valid_w / 6; int output_w_remain = valid_w - output_w_tiles * 6; for (int h = valid_h_start; h < valid_h_end - 2; h += 3) { - size_t offset = (2 * h - padding_h) * input_w + input_w_start; - const int8_t *input_ptr0 = input_ptr + offset; + const int8_t *input_ptr0 = input_ptr + (2 * h - padding_h) * input_w; const int8_t *input_ptr1 = input_ptr0 + input_w; const int8_t *input_ptr2 = input_ptr1 + input_w; const int8_t *input_ptr3 = input_ptr2 + input_w; const int8_t *input_ptr4 = input_ptr3 + input_w; const int8_t *input_ptr5 = input_ptr4 + input_w; const int8_t *input_ptr6 = input_ptr5 + input_w; - int32_t *output_ptr0 = output_ptr + h * output_w + valid_w_start; + int32_t *output_ptr0 = output_ptr + h * output_w; int32_t *output_ptr1 = output_ptr0 + output_w; int32_t *output_ptr2 = output_ptr1 + output_w; + // pad left + if (padding_w) { + for (int w = valid_w_start - 1; w >= 0; --w) { + int padding = padding_w - (w << 1); + if (padding >= 3) { + output_ptr0[w] = 0; + output_ptr1[w] = 0; + output_ptr2[w] = 0; + } else { + int16x4_t row0 = + vget_low_s16(vmovl_s8(vld1_s8(input_ptr0 - padding))); + int16x4_t row1 = + vget_low_s16(vmovl_s8(vld1_s8(input_ptr1 - padding))); + int16x4_t row2 = + vget_low_s16(vmovl_s8(vld1_s8(input_ptr2 - padding))); + int16x4_t row3 = + vget_low_s16(vmovl_s8(vld1_s8(input_ptr3 - padding))); + int16x4_t row4 = + vget_low_s16(vmovl_s8(vld1_s8(input_ptr4 - padding))); + int16x4_t row5 = + vget_low_s16(vmovl_s8(vld1_s8(input_ptr5 - padding))); + int16x4_t row6 = + vget_low_s16(vmovl_s8(vld1_s8(input_ptr6 - padding))); + int32x4_t acc0 = vmull_s16(row0, _ker[0]); + acc0 = vmlal_s16(acc0, row1, _ker[1]); + acc0 = vmlal_s16(acc0, row2, _ker[2]); + int32x4_t acc1 = vmull_s16(row2, _ker[0]); + acc1 = vmlal_s16(acc1, row3, _ker[1]); + acc1 = vmlal_s16(acc1, row4, _ker[2]); + int32x4_t acc2 = vmull_s16(row4, _ker[0]); + acc2 = vmlal_s16(acc2, row5, _ker[1]); + acc2 = vmlal_s16(acc2, row6, _ker[2]); + int32_t sum0 = vgetq_lane_s32(acc0, 2); + int32_t sum1 = vgetq_lane_s32(acc1, 2); + int32_t sum2 = vgetq_lane_s32(acc2, 2); + if (padding == 1) { + sum0 += vgetq_lane_s32(acc0, 1); + sum1 += vgetq_lane_s32(acc1, 1); + sum2 += vgetq_lane_s32(acc2, 1); + } + output_ptr0[w] = sum0; + output_ptr1[w] = sum1; + output_ptr2[w] = sum2; + } + } + input_ptr0 += valid_input_w_start; + input_ptr1 += valid_input_w_start; + input_ptr2 += valid_input_w_start; + input_ptr3 += valid_input_w_start; + input_ptr4 += valid_input_w_start; + input_ptr5 += valid_input_w_start; + input_ptr6 += valid_input_w_start; + output_ptr0 += valid_w_start; + output_ptr1 += valid_w_start; + output_ptr2 += valid_w_start; + } + // valid int loop = output_w_tiles; asm volatile( - "vld1.32 {q0}, [%[filter_ptr]] \n" - "vmovl.s8 q14, d0 \n" - "vmovl.s8 q15, d1 \n" - "vdup.s16 d0, d28[0] \n" - "vdup.s16 d1, d28[1] \n" - "vdup.s16 d2, d28[2] \n" - "vdup.s16 d3, d28[3] \n" - "vdup.s16 d4, d29[0] \n" - "vdup.s16 d5, d29[1] \n" - "vdup.s16 d6, d29[2] \n" - "vdup.s16 d7, d29[3] \n" - "vdup.s16 d8, d30[0] \n" - : - : [filter_ptr] "r"(filter_ptr) - : "memory", "q0", "q1", "q2", "q3", "q4", "q14", "q15"); - asm volatile( - "cmp %[loop], #0 \n" - "ble start_remain_%= \n" - 
"mov r0, #12 \n" + "cmp %[loop], #0 \n" + "ble start_remain_%= \n" + "mov r0, #12 \n" // loop 6 widths - "loop_3h6w_%=: \n" - "vld2.8 {d10, d11}, [%[input_ptr0]], r0 \n" - "vld2.8 {d12, d13}, [%[input_ptr1]], r0 \n" - "vld2.8 {d14, d15}, [%[input_ptr2]], r0 \n" - "vext.s8 d9, d10, d10, #1 \n" - "vmovl.s8 q10, d9 \n" - "vmovl.s8 q8, d10 \n" - "vmovl.s8 q9, d11 \n" - "vmull.s16 q11, d16, d0 \n" - "vmlal.s16 q11, d18, d1 \n" - "vmlal.s16 q11, d20, d2 \n" - "vmull.s16 q12, d17, d0 \n" - "vmlal.s16 q12, d19, d1 \n" - "vmlal.s16 q12, d21, d2 \n" - - "vext.s8 d9, d12, d12, #1 \n" - "vmovl.s8 q10, d9 \n" - "vmovl.s8 q8, d12 \n" - "vmovl.s8 q9, d13 \n" - "vmlal.s16 q11, d16, d3 \n" - "vmlal.s16 q11, d18, d4 \n" - "vmlal.s16 q11, d20, d5 \n" - "vmlal.s16 q12, d17, d3 \n" - "vmlal.s16 q12, d19, d4 \n" - "vmlal.s16 q12, d21, d5 \n" - - "vext.s8 d9, d14, d14, #1 \n" - "vmovl.s8 q10, d9 \n" - "vmovl.s8 q8, d14 \n" - "vmovl.s8 q9, d15 \n" - "vmlal.s16 q11, d16, d6 \n" - "vmlal.s16 q11, d18, d7 \n" - "vmlal.s16 q11, d20, d8 \n" - "vmlal.s16 q12, d17, d6 \n" - "vmlal.s16 q12, d19, d7 \n" - "vmlal.s16 q12, d21, d8 \n" + "loop_3h6w_%=: \n" + "vld2.8 {d10-d11}, [%[input_ptr0]], r0 \n" + "vld2.8 {d12-d13}, [%[input_ptr1]], r0 \n" + "vld2.8 {d14-d15}, [%[input_ptr2]], r0 \n" + "vext.s8 d9, d10, d10, #1 \n" + "vmovl.s8 q10, d9 \n" + "vmovl.s8 q8, d10 \n" + "vmovl.s8 q9, d11 \n" + "vmull.s16 q11, d16, %e[ker0][0] \n" + "vmlal.s16 q11, d18, %e[ker0][1] \n" + "vmlal.s16 q11, d20, %e[ker0][2] \n" + "vmull.s16 q12, d17, %e[ker0][0] \n" + "vmlal.s16 q12, d19, %e[ker0][1] \n" + "vmlal.s16 q12, d21, %e[ker0][2] \n" + + "vext.s8 d9, d12, d12, #1 \n" + "vmovl.s8 q10, d9 \n" + "vmovl.s8 q8, d12 \n" + "vmovl.s8 q9, d13 \n" + "vmlal.s16 q11, d16, %f[ker0][0] \n" + "vmlal.s16 q11, d18, %f[ker0][1] \n" + "vmlal.s16 q11, d20, %f[ker0][2] \n" + "vmlal.s16 q12, d17, %f[ker0][0] \n" + "vmlal.s16 q12, d19, %f[ker0][1] \n" + "vmlal.s16 q12, d21, %f[ker0][2] \n" + + "vext.s8 d9, d14, d14, #1 \n" + "vmovl.s8 q10, d9 \n" + "vmovl.s8 q8, d14 \n" + "vmovl.s8 q9, d15 \n" + "vmlal.s16 q11, d16, %e[ker1][0] \n" + "vmlal.s16 q11, d18, %e[ker1][1] \n" + "vmlal.s16 q11, d20, %e[ker1][2] \n" + "vmlal.s16 q12, d17, %e[ker1][0] \n" + "vmlal.s16 q12, d19, %e[ker1][1] \n" + "vmlal.s16 q12, d21, %e[ker1][2] \n" // store row 0, reuse q11/q12 - "vst1.32 {d22-d24}, [%[output_ptr0]]! \n" - - "vmull.s16 q13, d16, d0 \n" - "vmlal.s16 q13, d18, d1 \n" - "vmlal.s16 q13, d20, d2 \n" - "vmull.s16 q14, d17, d0 \n" - "vmlal.s16 q14, d19, d1 \n" - "vmlal.s16 q14, d21, d2 \n" - - "vld2.8 {d10, d11}, [%[input_ptr3]], r0 \n" - "vld2.8 {d12, d13}, [%[input_ptr4]], r0 \n" - "vld2.8 {d14, d15}, [%[input_ptr5]], r0 \n" - "vext.s8 d9, d10, d10, #1 \n" - "vmovl.s8 q10, d9 \n" - "vmovl.s8 q8, d10 \n" - "vmovl.s8 q9, d11 \n" - "vmlal.s16 q13, d16, d3 \n" - "vmlal.s16 q13, d18, d4 \n" - "vmlal.s16 q13, d20, d5 \n" - "vmlal.s16 q14, d17, d3 \n" - "vmlal.s16 q14, d19, d4 \n" - "vmlal.s16 q14, d21, d5 \n" - - "vext.s8 d9, d12, d12, #1 \n" - "vmovl.s8 q10, d9 \n" - "vmovl.s8 q8, d12 \n" - "vmovl.s8 q9, d13 \n" - "vmlal.s16 q13, d16, d6 \n" - "vmlal.s16 q13, d18, d7 \n" - "vmlal.s16 q13, d20, d8 \n" - "vmlal.s16 q14, d17, d6 \n" - "vmlal.s16 q14, d19, d7 \n" - "vmlal.s16 q14, d21, d8 \n" + "vst1.32 {d22-d24}, [%[output_ptr0]]! 
\n" + + "vmull.s16 q13, d16, %e[ker0][0] \n" + "vmlal.s16 q13, d18, %e[ker0][1] \n" + "vmlal.s16 q13, d20, %e[ker0][2] \n" + "vmull.s16 q14, d17, %e[ker0][0] \n" + "vmlal.s16 q14, d19, %e[ker0][1] \n" + "vmlal.s16 q14, d21, %e[ker0][2] \n" + + "vld2.8 {d10-d11}, [%[input_ptr3]], r0 \n" + "vld2.8 {d12-d13}, [%[input_ptr4]], r0 \n" + "vld2.8 {d14-d15}, [%[input_ptr5]], r0 \n" + "vext.s8 d9, d10, d10, #1 \n" + "vmovl.s8 q10, d9 \n" + "vmovl.s8 q8, d10 \n" + "vmovl.s8 q9, d11 \n" + "vmlal.s16 q13, d16, %f[ker0][0] \n" + "vmlal.s16 q13, d18, %f[ker0][1] \n" + "vmlal.s16 q13, d20, %f[ker0][2] \n" + "vmlal.s16 q14, d17, %f[ker0][0] \n" + "vmlal.s16 q14, d19, %f[ker0][1] \n" + "vmlal.s16 q14, d21, %f[ker0][2] \n" + + "vext.s8 d9, d12, d12, #1 \n" + "vmovl.s8 q10, d9 \n" + "vmovl.s8 q8, d12 \n" + "vmovl.s8 q9, d13 \n" + "vmlal.s16 q13, d16, %e[ker1][0] \n" + "vmlal.s16 q13, d18, %e[ker1][1] \n" + "vmlal.s16 q13, d20, %e[ker1][2] \n" + "vmlal.s16 q14, d17, %e[ker1][0] \n" + "vmlal.s16 q14, d19, %e[ker1][1] \n" + "vmlal.s16 q14, d21, %e[ker1][2] \n" // store row 1 - "vst1.32 {d26-d28}, [%[output_ptr1]]! \n" - - "vmull.s16 q11, d16, d0 \n" - "vmlal.s16 q11, d18, d1 \n" - "vmlal.s16 q11, d20, d2 \n" - "vmull.s16 q12, d17, d0 \n" - "vmlal.s16 q12, d19, d1 \n" - "vmlal.s16 q12, d21, d2 \n" - - "vext.s8 d9, d14, d14, #1 \n" - "vmovl.s8 q10, d9 \n" - "vmovl.s8 q8, d14 \n" - "vmovl.s8 q9, d15 \n" - "vmlal.s16 q11, d16, d3 \n" - "vmlal.s16 q11, d18, d4 \n" - "vmlal.s16 q11, d20, d5 \n" - "vmlal.s16 q12, d17, d3 \n" - "vmlal.s16 q12, d19, d4 \n" - "vmlal.s16 q12, d21, d5 \n" - - "vld2.8 {d10, d11}, [%[input_ptr6]], r0 \n" - "vext.s8 d9, d10, d10, #1 \n" - "vmovl.s8 q10, d9 \n" - "vmovl.s8 q8, d10 \n" - "vmovl.s8 q9, d11 \n" - "vmlal.s16 q11, d16, d6 \n" - "vmlal.s16 q11, d18, d7 \n" - "vmlal.s16 q11, d20, d8 \n" - "vmlal.s16 q12, d17, d6 \n" - "vmlal.s16 q12, d19, d7 \n" - "vmlal.s16 q12, d21, d8 \n" + "vst1.32 {d26-d28}, [%[output_ptr1]]! \n" + + "vmull.s16 q11, d16, %e[ker0][0] \n" + "vmlal.s16 q11, d18, %e[ker0][1] \n" + "vmlal.s16 q11, d20, %e[ker0][2] \n" + "vmull.s16 q12, d17, %e[ker0][0] \n" + "vmlal.s16 q12, d19, %e[ker0][1] \n" + "vmlal.s16 q12, d21, %e[ker0][2] \n" + + "vext.s8 d9, d14, d14, #1 \n" + "vmovl.s8 q10, d9 \n" + "vmovl.s8 q8, d14 \n" + "vmovl.s8 q9, d15 \n" + "vmlal.s16 q11, d16, %f[ker0][0] \n" + "vmlal.s16 q11, d18, %f[ker0][1] \n" + "vmlal.s16 q11, d20, %f[ker0][2] \n" + "vmlal.s16 q12, d17, %f[ker0][0] \n" + "vmlal.s16 q12, d19, %f[ker0][1] \n" + "vmlal.s16 q12, d21, %f[ker0][2] \n" + + "vld2.8 {d10-d11}, [%[input_ptr6]], r0 \n" + "vext.s8 d9, d10, d10, #1 \n" + "vmovl.s8 q10, d9 \n" + "vmovl.s8 q8, d10 \n" + "vmovl.s8 q9, d11 \n" + "vmlal.s16 q11, d16, %e[ker1][0] \n" + "vmlal.s16 q11, d18, %e[ker1][1] \n" + "vmlal.s16 q11, d20, %e[ker1][2] \n" + "vmlal.s16 q12, d17, %e[ker1][0] \n" + "vmlal.s16 q12, d19, %e[ker1][1] \n" + "vmlal.s16 q12, d21, %e[ker1][2] \n" // store row 2 - "vst1.32 {d22-d24}, [%[output_ptr2]]! 
\n" - - "subs %[loop], #1 \n" - "bne loop_3h6w_%= \n" - - "start_remain_%=: \n" - "cmp %[remain], #0 \n" - "ble end_%= \n" - - "vld2.8 {d10, d11}, [%[input_ptr0]] \n" - "vld2.8 {d12, d13}, [%[input_ptr1]] \n" - "vext.s8 d9, d10, d10, #1 \n" - "vmovl.s8 q9, d9 \n" - "vmovl.s8 q7, d10 \n" - "vmovl.s8 q8, d11 \n" - "vmull.s16 q10, d14, d0 \n" - "vmlal.s16 q10, d16, d1 \n" - "vmlal.s16 q10, d18, d2 \n" - "vmull.s16 q11, d15, d0 \n" - "vmlal.s16 q11, d17, d1 \n" - "vmlal.s16 q11, d19, d2 \n" - - "vext.s8 d9, d12, d12, #1 \n" - "vmovl.s8 q9, d9 \n" - "vmovl.s8 q7, d12 \n" - "vmovl.s8 q8, d13 \n" - "vmlal.s16 q10, d14, d3 \n" - "vmlal.s16 q10, d16, d4 \n" - "vmlal.s16 q10, d18, d5 \n" - "vmlal.s16 q11, d15, d3 \n" - "vmlal.s16 q11, d17, d4 \n" - "vmlal.s16 q11, d19, d5 \n" - - "vld2.8 {d10, d11}, [%[input_ptr2]] \n" - "vld2.8 {d12, d13}, [%[input_ptr3]] \n" - "vext.s8 d9, d10, d10, #1 \n" - "vmovl.s8 q9, d9 \n" - "vmovl.s8 q7, d10 \n" - "vmovl.s8 q8, d11 \n" - "vmlal.s16 q10, d14, d6 \n" - "vmlal.s16 q10, d16, d7 \n" - "vmlal.s16 q10, d18, d8 \n" - "vmlal.s16 q11, d15, d6 \n" - "vmlal.s16 q11, d17, d7 \n" - "vmlal.s16 q11, d19, d8 \n" - - "vmull.s16 q12, d14, d0 \n" - "vmlal.s16 q12, d16, d1 \n" - "vmlal.s16 q12, d18, d2 \n" - "vmull.s16 q13, d15, d0 \n" - "vmlal.s16 q13, d17, d1 \n" - "vmlal.s16 q13, d19, d2 \n" - - "vext.s8 d9, d12, d12, #1 \n" - "vmovl.s8 q9, d9 \n" - "vmovl.s8 q7, d12 \n" - "vmovl.s8 q8, d13 \n" - "vmlal.s16 q12, d14, d3 \n" - "vmlal.s16 q12, d16, d4 \n" - "vmlal.s16 q12, d18, d5 \n" - "vmlal.s16 q13, d15, d3 \n" - "vmlal.s16 q13, d17, d4 \n" - "vmlal.s16 q13, d19, d5 \n" - - "vld2.8 {d10, d11}, [%[input_ptr4]] \n" - "vld2.8 {d12, d13}, [%[input_ptr5]] \n" - "vext.s8 d9, d10, d10, #1 \n" - "vmovl.s8 q9, d9 \n" - "vmovl.s8 q7, d10 \n" - "vmovl.s8 q8, d11 \n" - "vmlal.s16 q12, d14, d6 \n" - "vmlal.s16 q12, d16, d7 \n" - "vmlal.s16 q12, d18, d8 \n" - "vmlal.s16 q13, d15, d6 \n" - "vmlal.s16 q13, d17, d7 \n" - "vmlal.s16 q13, d19, d8 \n" - - "vmull.s16 q14, d14, d0 \n" - "vmlal.s16 q14, d16, d1 \n" - "vmlal.s16 q14, d18, d2 \n" - "vmull.s16 q15, d15, d0 \n" - "vmlal.s16 q15, d17, d1 \n" - "vmlal.s16 q15, d19, d2 \n" - - "vext.s8 d9, d12, d12, #1 \n" - "vmovl.s8 q9, d9 \n" - "vmovl.s8 q7, d12 \n" - "vmovl.s8 q8, d13 \n" - "vmlal.s16 q14, d14, d3 \n" - "vmlal.s16 q14, d16, d4 \n" - "vmlal.s16 q14, d18, d5 \n" - "vmlal.s16 q15, d15, d3 \n" - "vmlal.s16 q15, d17, d4 \n" - "vmlal.s16 q15, d19, d5 \n" - - "vld2.8 {d10, d11}, [%[input_ptr6]] \n" - "vext.s8 d9, d10, d10, #1 \n" - "vmovl.s8 q9, d9 \n" - "vmovl.s8 q7, d10 \n" - "vmovl.s8 q8, d11 \n" - "vmlal.s16 q14, d14, d6 \n" - "vmlal.s16 q14, d16, d7 \n" - "vmlal.s16 q14, d18, d8 \n" - "vmlal.s16 q15, d15, d6 \n" - "vmlal.s16 q15, d17, d7 \n" - "vmlal.s16 q15, d19, d8 \n" - - "cmp %[remain], #4 \n" - "blt store_3h2w_%= \n" - "vst1.32 {q10}, [%[output_ptr0]]! \n" - "vst1.32 {q12}, [%[output_ptr1]]! \n" - "vst1.32 {q14}, [%[output_ptr2]]! \n" - "cmp %[remain], #5 \n" - "blt end_%= \n" - "vst1.32 {d22[0]}, [%[output_ptr0]]! \n" - "vst1.32 {d26[0]}, [%[output_ptr1]]! \n" - "vst1.32 {d30[0]}, [%[output_ptr2]]! \n" - "b end_%= \n" - - "store_3h2w_%=: \n" - "cmp %[remain], #2 \n" - "blt store_3h1w_%= \n" - "vst1.32 {d20}, [%[output_ptr0]]! \n" - "vst1.32 {d24}, [%[output_ptr1]]! \n" - "vst1.32 {d28}, [%[output_ptr2]]! \n" - "cmp %[remain], #3 \n" - "blt end_%= \n" - "vst1.32 {d21[0]}, [%[output_ptr0]]! \n" - "vst1.32 {d25[0]}, [%[output_ptr1]]! \n" - "vst1.32 {d29[0]}, [%[output_ptr2]]! 
\n" - "b end_%= \n" - - "store_3h1w_%=: \n" - "cmp %[remain], #1 \n" - "blt end_%= \n" - "vst1.32 {d20[0]}, [%[output_ptr0]]! \n" - "vst1.32 {d24[0]}, [%[output_ptr1]]! \n" - "vst1.32 {d28[0]}, [%[output_ptr2]]! \n" - "end_%=: \n" + "vst1.32 {d22-d24}, [%[output_ptr2]]! \n" + + "subs %[loop], #1 \n" + "bne loop_3h6w_%= \n" + + "start_remain_%=: \n" + "cmp %[remain], #0 \n" + "ble end_%= \n" + "mov r0, %[remain], lsl #1 \n" + + "vld2.8 {d10-d11}, [%[input_ptr0]], r0 \n" + "vld2.8 {d12-d13}, [%[input_ptr1]], r0 \n" + "vext.s8 d9, d10, d10, #1 \n" + "vmovl.s8 q9, d9 \n" + "vmovl.s8 q7, d10 \n" + "vmovl.s8 q8, d11 \n" + "vmull.s16 q10, d14, %e[ker0][0] \n" + "vmlal.s16 q10, d16, %e[ker0][1] \n" + "vmlal.s16 q10, d18, %e[ker0][2] \n" + "vmull.s16 q11, d15, %e[ker0][0] \n" + "vmlal.s16 q11, d17, %e[ker0][1] \n" + "vmlal.s16 q11, d19, %e[ker0][2] \n" + + "vext.s8 d9, d12, d12, #1 \n" + "vmovl.s8 q9, d9 \n" + "vmovl.s8 q7, d12 \n" + "vmovl.s8 q8, d13 \n" + "vmlal.s16 q10, d14, %f[ker0][0] \n" + "vmlal.s16 q10, d16, %f[ker0][1] \n" + "vmlal.s16 q10, d18, %f[ker0][2] \n" + "vmlal.s16 q11, d15, %f[ker0][0] \n" + "vmlal.s16 q11, d17, %f[ker0][1] \n" + "vmlal.s16 q11, d19, %f[ker0][2] \n" + + "vld2.8 {d10-d11}, [%[input_ptr2]], r0 \n" + "vld2.8 {d12-d13}, [%[input_ptr3]], r0 \n" + "vext.s8 d9, d10, d10, #1 \n" + "vmovl.s8 q9, d9 \n" + "vmovl.s8 q7, d10 \n" + "vmovl.s8 q8, d11 \n" + "vmlal.s16 q10, d14, %e[ker1][0] \n" + "vmlal.s16 q10, d16, %e[ker1][1] \n" + "vmlal.s16 q10, d18, %e[ker1][2] \n" + "vmlal.s16 q11, d15, %e[ker1][0] \n" + "vmlal.s16 q11, d17, %e[ker1][1] \n" + "vmlal.s16 q11, d19, %e[ker1][2] \n" + + "vmull.s16 q12, d14, %e[ker0][0] \n" + "vmlal.s16 q12, d16, %e[ker0][1] \n" + "vmlal.s16 q12, d18, %e[ker0][2] \n" + "vmull.s16 q13, d15, %e[ker0][0] \n" + "vmlal.s16 q13, d17, %e[ker0][1] \n" + "vmlal.s16 q13, d19, %e[ker0][2] \n" + + "vext.s8 d9, d12, d12, #1 \n" + "vmovl.s8 q9, d9 \n" + "vmovl.s8 q7, d12 \n" + "vmovl.s8 q8, d13 \n" + "vmlal.s16 q12, d14, %f[ker0][0] \n" + "vmlal.s16 q12, d16, %f[ker0][1] \n" + "vmlal.s16 q12, d18, %f[ker0][2] \n" + "vmlal.s16 q13, d15, %f[ker0][0] \n" + "vmlal.s16 q13, d17, %f[ker0][1] \n" + "vmlal.s16 q13, d19, %f[ker0][2] \n" + + "vld2.8 {d10-d11}, [%[input_ptr4]], r0 \n" + "vld2.8 {d12-d13}, [%[input_ptr5]], r0 \n" + "vext.s8 d9, d10, d10, #1 \n" + "vmovl.s8 q9, d9 \n" + "vmovl.s8 q7, d10 \n" + "vmovl.s8 q8, d11 \n" + "vmlal.s16 q12, d14, %e[ker1][0] \n" + "vmlal.s16 q12, d16, %e[ker1][1] \n" + "vmlal.s16 q12, d18, %e[ker1][2] \n" + "vmlal.s16 q13, d15, %e[ker1][0] \n" + "vmlal.s16 q13, d17, %e[ker1][1] \n" + "vmlal.s16 q13, d19, %e[ker1][2] \n" + + "vmull.s16 q14, d14, %e[ker0][0] \n" + "vmlal.s16 q14, d16, %e[ker0][1] \n" + "vmlal.s16 q14, d18, %e[ker0][2] \n" + "vmull.s16 q15, d15, %e[ker0][0] \n" + "vmlal.s16 q15, d17, %e[ker0][1] \n" + "vmlal.s16 q15, d19, %e[ker0][2] \n" + + "vext.s8 d9, d12, d12, #1 \n" + "vmovl.s8 q9, d9 \n" + "vmovl.s8 q7, d12 \n" + "vmovl.s8 q8, d13 \n" + "vmlal.s16 q14, d14, %f[ker0][0] \n" + "vmlal.s16 q14, d16, %f[ker0][1] \n" + "vmlal.s16 q14, d18, %f[ker0][2] \n" + "vmlal.s16 q15, d15, %f[ker0][0] \n" + "vmlal.s16 q15, d17, %f[ker0][1] \n" + "vmlal.s16 q15, d19, %f[ker0][2] \n" + + "vld2.8 {d10-d11}, [%[input_ptr6]], r0 \n" + "vext.s8 d9, d10, d10, #1 \n" + "vmovl.s8 q9, d9 \n" + "vmovl.s8 q7, d10 \n" + "vmovl.s8 q8, d11 \n" + "vmlal.s16 q14, d14, %e[ker1][0] \n" + "vmlal.s16 q14, d16, %e[ker1][1] \n" + "vmlal.s16 q14, d18, %e[ker1][2] \n" + "vmlal.s16 q15, d15, %e[ker1][0] \n" + "vmlal.s16 q15, d17, %e[ker1][1] \n" + 
"vmlal.s16 q15, d19, %e[ker1][2] \n" + + "cmp %[remain], #4 \n" + "blt store_3h2w_%= \n" + "vst1.32 {q10}, [%[output_ptr0]]! \n" + "vst1.32 {q12}, [%[output_ptr1]]! \n" + "vst1.32 {q14}, [%[output_ptr2]]! \n" + "cmp %[remain], #5 \n" + "blt end_%= \n" + "vst1.32 {d22[0]}, [%[output_ptr0]]! \n" + "vst1.32 {d26[0]}, [%[output_ptr1]]! \n" + "vst1.32 {d30[0]}, [%[output_ptr2]]! \n" + "b end_%= \n" + + "store_3h2w_%=: \n" + "cmp %[remain], #2 \n" + "blt store_3h1w_%= \n" + "vst1.32 {d20}, [%[output_ptr0]]! \n" + "vst1.32 {d24}, [%[output_ptr1]]! \n" + "vst1.32 {d28}, [%[output_ptr2]]! \n" + "cmp %[remain], #3 \n" + "blt end_%= \n" + "vst1.32 {d21[0]}, [%[output_ptr0]]! \n" + "vst1.32 {d25[0]}, [%[output_ptr1]]! \n" + "vst1.32 {d29[0]}, [%[output_ptr2]]! \n" + "b end_%= \n" + + "store_3h1w_%=: \n" + "cmp %[remain], #1 \n" + "blt end_%= \n" + "vst1.32 {d20[0]}, [%[output_ptr0]]! \n" + "vst1.32 {d24[0]}, [%[output_ptr1]]! \n" + "vst1.32 {d28[0]}, [%[output_ptr2]]! \n" + "end_%=: \n" : [output_ptr0] "+r"(output_ptr0), [output_ptr1] "+r"(output_ptr1), [output_ptr2] "+r"(output_ptr2), [input_ptr6] "+r"(input_ptr6), [input_ptr0] "+r"(input_ptr0), [input_ptr1] "+r"(input_ptr1), [input_ptr2] "+r"(input_ptr2), [input_ptr3] "+r"(input_ptr3), [input_ptr4] "+r"(input_ptr4), [input_ptr5] "+r"(input_ptr5), [loop] "+r"(loop) - : [remain] "r"(output_w_remain) - : "cc", "memory", "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7", - "q8", "q9", "q10", "q11", "q12", "q13", "q14", "q15", "r0"); + : [remain] "r"(output_w_remain), [ker0] "w"(_ker0), [ker1] "w"(_ker1) + : "cc", "memory", "q4", "q5", "q6", "q7", "q8", "q9", "q10", "q11", + "q12", "q13", "q14", "q15", "r0"); + // pad right + if (padding_w > 0) { + int16x4_t row0 = vget_low_s16(vmovl_s8(vld1_s8(input_ptr0))); + int16x4_t row1 = vget_low_s16(vmovl_s8(vld1_s8(input_ptr1))); + int16x4_t row2 = vget_low_s16(vmovl_s8(vld1_s8(input_ptr2))); + int16x4_t row3 = vget_low_s16(vmovl_s8(vld1_s8(input_ptr3))); + int16x4_t row4 = vget_low_s16(vmovl_s8(vld1_s8(input_ptr4))); + int16x4_t row5 = vget_low_s16(vmovl_s8(vld1_s8(input_ptr5))); + int16x4_t row6 = vget_low_s16(vmovl_s8(vld1_s8(input_ptr6))); + int32x4_t acc0, acc1, acc2; + for (int w = valid_w_end; w < output_w; ++w) { + int padding = 2 * w + 3 - (padding_w + input_w); + if (padding >= 3) { + *output_ptr0 = 0; + *output_ptr1 = 0; + *output_ptr2 = 0; + } else { + acc0 = vmull_s16(row0, _ker[0]); + acc0 = vmlal_s16(acc0, row1, _ker[1]); + acc0 = vmlal_s16(acc0, row2, _ker[2]); + acc1 = vmull_s16(row2, _ker[0]); + acc1 = vmlal_s16(acc1, row3, _ker[1]); + acc1 = vmlal_s16(acc1, row4, _ker[2]); + acc2 = vmull_s16(row4, _ker[0]); + acc2 = vmlal_s16(acc2, row5, _ker[1]); + acc2 = vmlal_s16(acc2, row6, _ker[2]); + int32_t sum0 = vgetq_lane_s32(acc0, 0); + int32_t sum1 = vgetq_lane_s32(acc1, 0); + int32_t sum2 = vgetq_lane_s32(acc2, 0); + if (padding == 1) { + sum0 += vgetq_lane_s32(acc0, 1); + sum1 += vgetq_lane_s32(acc1, 1); + sum2 += vgetq_lane_s32(acc2, 1); + } + *output_ptr0 = sum0; + *output_ptr1 = sum1; + *output_ptr2 = sum2; + } + output_ptr0++; + output_ptr1++; + output_ptr2++; + } + } } - + // remain height int start_h = valid_h_start + valid_h / 3 * 3; for (int h = start_h; h < valid_h_end; ++h) { - size_t offset = (2 * h - padding_h) * input_w + input_w_start; - const int8_t *input_ptr0 = input_ptr + offset; + const int8_t *input_ptr0 = input_ptr + (2 * h - padding_h) * input_w; const int8_t *input_ptr1 = input_ptr0 + input_w; const int8_t *input_ptr2 = input_ptr1 + input_w; - int32_t *output_ptr0 = output_ptr 
+ h * output_w + valid_w_start; + int32_t *output_ptr0 = output_ptr + h * output_w; + // pad left + if (padding_w) { + for (int w = valid_w_start - 1; w >= 0; --w) { + int padding = padding_w - (w << 1); + if (padding >= 3) { + output_ptr0[w] = 0; + } else { + int16x4_t row0 = + vget_low_s16(vmovl_s8(vld1_s8(input_ptr0 - padding))); + int16x4_t row1 = + vget_low_s16(vmovl_s8(vld1_s8(input_ptr1 - padding))); + int16x4_t row2 = + vget_low_s16(vmovl_s8(vld1_s8(input_ptr2 - padding))); + int32x4_t acc = vmull_s16(row0, _ker[0]); + acc = vmlal_s16(acc, row1, _ker[1]); + acc = vmlal_s16(acc, row2, _ker[2]); + int32_t sum0 = vgetq_lane_s32(acc, 2); + if (padding == 1) { + sum0 += vgetq_lane_s32(acc, 1); + } + output_ptr0[w] = sum0; + } + } + input_ptr0 += valid_input_w_start; + input_ptr1 += valid_input_w_start; + input_ptr2 += valid_input_w_start; + output_ptr0 += valid_w_start; + } + // valid int loop = output_w_tiles; asm volatile( - "vld1.32 {q0}, [%[filter_ptr]] \n" - "vmovl.s8 q14, d0 \n" - "vmovl.s8 q15, d1 \n" - "vdup.s16 d0, d28[0] \n" - "vdup.s16 d1, d28[1] \n" - "vdup.s16 d2, d28[2] \n" - "vdup.s16 d3, d28[3] \n" - "vdup.s16 d4, d29[0] \n" - "vdup.s16 d5, d29[1] \n" - "vdup.s16 d6, d29[2] \n" - "vdup.s16 d7, d29[3] \n" - "vdup.s16 d8, d30[0] \n" - : - : [filter_ptr] "r"(filter_ptr) - : "memory", "q0", "q1", "q2", "q3", "q4", "q14", "q15"); - asm volatile( - "cmp %[loop], #0 \n" - "ble start_remain_%= \n" - "mov r0, #12 \n" + "cmp %[loop], #0 \n" + "ble start_remain_%= \n" + "mov r0, #12 \n" // loop 6 widths "loop_1h6w_%=: \n" "vld2.8 {d10, d11}, [%[input_ptr0]], r0 \n" "vld2.8 {d12, d13}, [%[input_ptr1]], r0 \n" "vld2.8 {d14, d15}, [%[input_ptr2]], r0 \n" - "vext.s8 d9, d10, d10, #1 \n" - "vmovl.s8 q10, d9 \n" - "vmovl.s8 q8, d10 \n" - "vmovl.s8 q9, d11 \n" - "vmull.s16 q11, d16, d0 \n" - "vmlal.s16 q11, d18, d1 \n" - "vmlal.s16 q11, d20, d2 \n" - "vmull.s16 q12, d17, d0 \n" - "vmlal.s16 q12, d19, d1 \n" - "vmlal.s16 q12, d21, d2 \n" - - "vext.s8 d9, d12, d12, #1 \n" - "vmovl.s8 q10, d9 \n" - "vmovl.s8 q8, d12 \n" - "vmovl.s8 q9, d13 \n" - "vmlal.s16 q11, d16, d3 \n" - "vmlal.s16 q11, d18, d4 \n" - "vmlal.s16 q11, d20, d5 \n" - "vmlal.s16 q12, d17, d3 \n" - "vmlal.s16 q12, d19, d4 \n" - "vmlal.s16 q12, d21, d5 \n" - - "vext.s8 d9, d14, d14, #1 \n" - "vmovl.s8 q10, d9 \n" - "vmovl.s8 q8, d14 \n" - "vmovl.s8 q9, d15 \n" - "vmlal.s16 q11, d16, d6 \n" - "vmlal.s16 q11, d18, d7 \n" - "vmlal.s16 q11, d20, d8 \n" - "vmlal.s16 q12, d17, d6 \n" - "vmlal.s16 q12, d19, d7 \n" - "vmlal.s16 q12, d21, d8 \n" + "vext.s8 d9, d10, d10, #1 \n" + "vmovl.s8 q10, d9 \n" + "vmovl.s8 q8, d10 \n" + "vmovl.s8 q9, d11 \n" + "vmull.s16 q11, d16, %e[ker0][0] \n" + "vmlal.s16 q11, d18, %e[ker0][1] \n" + "vmlal.s16 q11, d20, %e[ker0][2] \n" + "vmull.s16 q12, d17, %e[ker0][0] \n" + "vmlal.s16 q12, d19, %e[ker0][1] \n" + "vmlal.s16 q12, d21, %e[ker0][2] \n" + + "vext.s8 d9, d12, d12, #1 \n" + "vmovl.s8 q10, d9 \n" + "vmovl.s8 q8, d12 \n" + "vmovl.s8 q9, d13 \n" + "vmlal.s16 q11, d16, %f[ker0][0] \n" + "vmlal.s16 q11, d18, %f[ker0][1] \n" + "vmlal.s16 q11, d20, %f[ker0][2] \n" + "vmlal.s16 q12, d17, %f[ker0][0] \n" + "vmlal.s16 q12, d19, %f[ker0][1] \n" + "vmlal.s16 q12, d21, %f[ker0][2] \n" + + "vext.s8 d9, d14, d14, #1 \n" + "vmovl.s8 q10, d9 \n" + "vmovl.s8 q8, d14 \n" + "vmovl.s8 q9, d15 \n" + "vmlal.s16 q11, d16, %e[ker1][0] \n" + "vmlal.s16 q11, d18, %e[ker1][1] \n" + "vmlal.s16 q11, d20, %e[ker1][2] \n" + "vmlal.s16 q12, d17, %e[ker1][0] \n" + "vmlal.s16 q12, d19, %e[ker1][1] \n" + "vmlal.s16 q12, d21, 
%e[ker1][2] \n" // store row 0 - "vst1.32 {d22-d24}, [%[output_ptr0]]! \n" - - "subs %[loop], #1 \n" - "bne loop_1h6w_%= \n" - - "start_remain_%=: \n" - "cmp %[remain], #0 \n" - "ble end_%= \n" - "vld2.8 {d10, d11}, [%[input_ptr0]] \n" - "vld2.8 {d12, d13}, [%[input_ptr1]] \n" - "vld2.8 {d14, d15}, [%[input_ptr2]] \n" - "vext.s8 d9, d10, d10, #1 \n" - "vmovl.s8 q10, d9 \n" - "vmovl.s8 q8, d10 \n" - "vmovl.s8 q9, d11 \n" - "vmull.s16 q11, d16, d0 \n" - "vmlal.s16 q11, d18, d1 \n" - "vmlal.s16 q11, d20, d2 \n" - "vmull.s16 q12, d17, d0 \n" - "vmlal.s16 q12, d19, d1 \n" - "vmlal.s16 q12, d21, d2 \n" - - "vext.s8 d9, d12, d12, #1 \n" - "vmovl.s8 q10, d9 \n" - "vmovl.s8 q8, d12 \n" - "vmovl.s8 q9, d13 \n" - "vmlal.s16 q11, d16, d3 \n" - "vmlal.s16 q11, d18, d4 \n" - "vmlal.s16 q11, d20, d5 \n" - "vmlal.s16 q12, d17, d3 \n" - "vmlal.s16 q12, d19, d4 \n" - "vmlal.s16 q12, d21, d5 \n" - - "vext.s8 d9, d14, d14, #1 \n" - "vmovl.s8 q10, d9 \n" - "vmovl.s8 q8, d14 \n" - "vmovl.s8 q9, d15 \n" - "vmlal.s16 q11, d16, d6 \n" - "vmlal.s16 q11, d18, d7 \n" - "vmlal.s16 q11, d20, d8 \n" - "vmlal.s16 q12, d17, d6 \n" - "vmlal.s16 q12, d19, d7 \n" - "vmlal.s16 q12, d21, d8 \n" - - "cmp %[remain], #4 \n" - "blt store_1h2w_%= \n" - "vst1.32 {q11}, [%[output_ptr0]]! \n" - "cmp %[remain], #5 \n" - "blt end_%= \n" - "vst1.32 {d24[0]}, [%[output_ptr0]]! \n" - "b end_%= \n" - - "store_1h2w_%=: \n" - "cmp %[remain], #2 \n" - "blt store_1h1w_%= \n" - "vst1.32 {d22}, [%[output_ptr0]]! \n" - "cmp %[remain], #3 \n" - "blt end_%= \n" - "vst1.32 {d23[0]}, [%[output_ptr0]]! \n" - "b end_%= \n" - - "store_1h1w_%=: \n" - "cmp %[remain], #1 \n" - "blt end_%= \n" - "vst1.32 {d22[0]}, [%[output_ptr0]]! \n" - "end_%=: \n" + "vst1.32 {d22-d24}, [%[output_ptr0]]! \n" + + "subs %[loop], #1 \n" + "bne loop_1h6w_%= \n" + + "start_remain_%=: \n" + "cmp %[remain], #0 \n" + "ble end_%= \n" + "mov r0, %[remain], lsl #1 \n" + + "vld2.8 {d10, d11}, [%[input_ptr0]], r0 \n" + "vld2.8 {d12, d13}, [%[input_ptr1]], r0 \n" + "vld2.8 {d14, d15}, [%[input_ptr2]], r0 \n" + "vext.s8 d9, d10, d10, #1 \n" + "vmovl.s8 q10, d9 \n" + "vmovl.s8 q8, d10 \n" + "vmovl.s8 q9, d11 \n" + "vmull.s16 q11, d16, %e[ker0][0] \n" + "vmlal.s16 q11, d18, %e[ker0][1] \n" + "vmlal.s16 q11, d20, %e[ker0][2] \n" + "vmull.s16 q12, d17, %e[ker0][0] \n" + "vmlal.s16 q12, d19, %e[ker0][1] \n" + "vmlal.s16 q12, d21, %e[ker0][2] \n" + + "vext.s8 d9, d12, d12, #1 \n" + "vmovl.s8 q10, d9 \n" + "vmovl.s8 q8, d12 \n" + "vmovl.s8 q9, d13 \n" + "vmlal.s16 q11, d16, %f[ker0][0] \n" + "vmlal.s16 q11, d18, %f[ker0][1] \n" + "vmlal.s16 q11, d20, %f[ker0][2] \n" + "vmlal.s16 q12, d17, %f[ker0][0] \n" + "vmlal.s16 q12, d19, %f[ker0][1] \n" + "vmlal.s16 q12, d21, %f[ker0][2] \n" + + "vext.s8 d9, d14, d14, #1 \n" + "vmovl.s8 q10, d9 \n" + "vmovl.s8 q8, d14 \n" + "vmovl.s8 q9, d15 \n" + "vmlal.s16 q11, d16, %e[ker1][0] \n" + "vmlal.s16 q11, d18, %e[ker1][1] \n" + "vmlal.s16 q11, d20, %e[ker1][2] \n" + "vmlal.s16 q12, d17, %e[ker1][0] \n" + "vmlal.s16 q12, d19, %e[ker1][1] \n" + "vmlal.s16 q12, d21, %e[ker1][2] \n" + + "cmp %[remain], #4 \n" + "blt store_1h2w_%= \n" + "vst1.32 {q11}, [%[output_ptr0]]! \n" + "cmp %[remain], #5 \n" + "blt end_%= \n" + "vst1.32 {d24[0]}, [%[output_ptr0]]! \n" + "b end_%= \n" + + "store_1h2w_%=: \n" + "cmp %[remain], #2 \n" + "blt store_1h1w_%= \n" + "vst1.32 {d22}, [%[output_ptr0]]! \n" + "cmp %[remain], #3 \n" + "blt end_%= \n" + "vst1.32 {d23[0]}, [%[output_ptr0]]! 
\n" + "b end_%= \n" + + "store_1h1w_%=: \n" + "cmp %[remain], #1 \n" + "blt end_%= \n" + "vst1.32 {d22[0]}, [%[output_ptr0]]! \n" + "end_%=: \n" : [output_ptr0] "+r"(output_ptr0), [input_ptr0] "+r"(input_ptr0), [input_ptr1] "+r"(input_ptr1), [input_ptr2] "+r"(input_ptr2), [loop] "+r"(loop) - : [remain] "r"(output_w_remain) - : "cc", "memory", "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7", - "q8", "q9", "q10", "q11", "q12", "r0"); + : [remain] "r"(output_w_remain), [ker0] "w"(_ker0), [ker1] "w"(_ker1) + : "cc", "memory", "q4", "q5", "q6", "q7", "q8", "q9", "q10", "q11", + "q12", "q13", "q14", "q15", "r0"); + // pad right + if (padding_w > 0) { + int16x4_t row0 = vget_low_s16(vmovl_s8(vld1_s8(input_ptr0))); + int16x4_t row1 = vget_low_s16(vmovl_s8(vld1_s8(input_ptr1))); + int16x4_t row2 = vget_low_s16(vmovl_s8(vld1_s8(input_ptr2))); + int32x4_t acc; + for (int w = valid_w_end; w < output_w; ++w) { + int padding = 2 * w + 3 - (padding_w + input_w); + if (padding >= 3) { + *output_ptr0 = 0; + } else { + acc = vmull_s16(row0, _ker[0]); + acc = vmlal_s16(acc, row1, _ker[1]); + acc = vmlal_s16(acc, row2, _ker[2]); + int32_t sum0 = vgetq_lane_s32(acc, 0); + if (padding == 1) { + sum0 += vgetq_lane_s32(acc, 1); + } + *output_ptr0 = sum0; + } + output_ptr0++; + } + } + } + // bottom + for (int h = valid_h_end; h < output_h; ++h) { + DepthwiseConv3x3NormalRow<2, 2>(input_ptr, filter_ptr, h, input_h, + input_w, padding_h, padding_w, output_w, + output_ptr, _ker); } } } diff --git a/src/operators/math/depthwise_conv5x5.cpp b/src/operators/math/depthwise_conv5x5.cpp new file mode 100644 index 0000000000000000000000000000000000000000..792a98659e7b03d4220b7e2ded540782ce880931 --- /dev/null +++ b/src/operators/math/depthwise_conv5x5.cpp @@ -0,0 +1,737 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
*/ + +#if defined(__ARM_NEON__) && !defined(__aarch64__) + +#include "operators/math/depthwise_conv5x5.h" +#include + +namespace paddle_mobile { +namespace operators { +namespace math { + +#ifndef __aarch64__ +inline float32x4_t vpaddq_f32(float32x4_t r0, float32x4_t r1) { + float32x2_t sum0 = vpadd_f32(vget_low_f32(r0), vget_high_f32(r0)); + float32x2_t sum1 = vpadd_f32(vget_low_f32(r1), vget_high_f32(r1)); + return vcombine_f32(sum0, sum1); +} +#endif + +template +inline void Depth5x5NormalRowLoadInput(const float *input, float32x4_t *y) { + y[0] = vld1q_f32(input); + y[4] = vld1q_f32(input + 4); + y[1] = vextq_f32(y[0], y[4], 1); + y[2] = vextq_f32(y[0], y[4], 2); + y[3] = vextq_f32(y[0], y[4], 3); +} + +template <> +inline void Depth5x5NormalRowLoadInput<2>(const float *input, float32x4_t *y) { + float32x4x2_t x = vld2q_f32(input); + y[0] = x.val[0]; + y[1] = x.val[1]; + y[2] = vextq_f32(y[0], y[0], 1); + y[3] = vextq_f32(y[1], y[1], 1); + y[4] = vextq_f32(y[0], y[0], 2); +} + +#define DEPTHWISE_CONV_NORMAL_BORDER(start, end) \ + for (int w = start; w < end; ++w) { \ + const int w_in_start = -padding_w + w * Stride_w; \ + const int w_in_end = w_in_start + 5; \ + const int w_start = w_in_start > 0 ? w_in_start : 0; \ + const int w_end = w_in_end < input_w ? w_in_end : input_w; \ + float value = 0; \ + for (int h_in = h_start; h_in < h_end; ++h_in) { \ + for (int w_in = w_start; w_in < w_end; ++w_in) { \ + value += filter[(h_in - h_in_start) * 5 + (w_in - w_in_start)] * \ + input[h_in * input_w + w_in]; \ + } \ + } \ + output_ptr[w] = value; \ + } + +template +inline void DepthwiseConv5x5NormalRow(const float *input, const float *filter, + const int h_output, const int input_h, + const int input_w, const int padding_h, + const int padding_w, const int output_w, + float *output, float32x4_t *ker, + float32_t *ker1) { + const int h_in_start = -padding_h + h_output * Stride_h; + const int h_in_end = h_in_start + 5; + const int h_start = h_in_start > 0 ? h_in_start : 0; + const int h_end = h_in_end < input_h ? 
h_in_end : input_h; + + int valid_w_start = (padding_w + Stride_w - 1) / Stride_w; + int valid_w_end = output_w - valid_w_start; + float *output_ptr = output + h_output * output_w; + // border left + DEPTHWISE_CONV_NORMAL_BORDER(0, valid_w_start) + // middle + int output_tiles = (valid_w_end - valid_w_start) >> 2; + float32x4_t _sum, _x[5]; + // valid w + for (int w = 0; w < output_tiles * 4; w += 4) { + _sum = vdupq_n_f32(0.f); + int output_offset = valid_w_start + w; + int input_w_offset = output_offset * Stride_w - padding_w; + for (int h_in = h_start; h_in < h_end; ++h_in) { + int index = h_in - h_in_start; + Depth5x5NormalRowLoadInput( + input + h_in * input_w + input_w_offset, _x); + _sum = vmlaq_n_f32(_sum, _x[0], ker1[index]); + _sum = vmlaq_lane_f32(_sum, _x[1], vget_low_f32(ker[index]), 0); + _sum = vmlaq_lane_f32(_sum, _x[2], vget_low_f32(ker[index]), 1); + _sum = vmlaq_lane_f32(_sum, _x[3], vget_high_f32(ker[index]), 0); + _sum = vmlaq_lane_f32(_sum, _x[4], vget_high_f32(ker[index]), 1); + } + vst1q_f32(output_ptr + output_offset, _sum); + } + // remain valid w + int remain = (valid_w_end - valid_w_start) & 0x3; + if (remain > 0) { + _sum = vdupq_n_f32(0.f); + int remain_start = valid_w_start + (output_tiles << 2); + int input_w_offset = remain_start * Stride_w - padding_w; + float *output_ptr0 = output_ptr + remain_start; + + for (int h_in = h_start; h_in < h_end; ++h_in) { + int index = h_in - h_in_start; + Depth5x5NormalRowLoadInput( + input + h_in * input_w + input_w_offset, _x); + _sum = vmlaq_n_f32(_sum, _x[0], ker1[index]); + _sum = vmlaq_lane_f32(_sum, _x[1], vget_low_f32(ker[index]), 0); + _sum = vmlaq_lane_f32(_sum, _x[2], vget_low_f32(ker[index]), 1); + _sum = vmlaq_lane_f32(_sum, _x[3], vget_high_f32(ker[index]), 0); + _sum = vmlaq_lane_f32(_sum, _x[4], vget_high_f32(ker[index]), 1); + } + switch (remain) { + case 1: + vst1_lane_f32(output_ptr0, vget_low_f32(_sum), 0); + break; + case 2: + vst1_f32(output_ptr0, vget_low_f32(_sum)); + break; + case 3: + vst1_f32(output_ptr0, vget_low_f32(_sum)); + vst1_lane_f32(output_ptr0 + 2, vget_high_f32(_sum), 0); + break; + } + } + // border right + DEPTHWISE_CONV_NORMAL_BORDER(valid_w_end, output_w) +} + +template <> +void DepthwiseConv5x5S1(const framework::Tensor &input, + const framework::Tensor &filter, + const std::vector &paddings, + framework::Tensor *output) { + const float *input_data = input.data(); + const float *filter_data = filter.data(); + float *out_data = output->mutable_data(); + int input_h = input.dims()[2]; + int input_w = input.dims()[3]; + int output_h = output->dims()[2]; + int output_w = output->dims()[3]; + int padding_h = paddings[0]; + int padding_w = paddings[1]; + int image_size = input_h * input_w; + int out_image_size = output_h * output_w; + int valid_h_start = padding_h; + int valid_h_end = output_h - valid_h_start; + int valid_h = valid_h_end - valid_h_start; + int valid_w_start = padding_w; + int valid_w_end = output_w - valid_w_start; + int valid_w = valid_w_end - valid_w_start; + + #pragma omp parallel for + for (int g = 0; g < input.dims()[1]; ++g) { + const float *input_ptr = input_data + g * image_size; + const float *filter_ptr = filter_data + g * 25; + float *output_ptr = out_data + g * out_image_size; + + const float *filter_ptr0 = filter_ptr; + const float *filter_ptr1 = filter_ptr0 + 5; + const float *filter_ptr2 = filter_ptr1 + 5; + const float *filter_ptr3 = filter_ptr2 + 5; + const float *filter_ptr4 = filter_ptr3 + 5; + float32x4_t _ker[7]; + float32_t _ker1[5] = {*filter_ptr0, 
*filter_ptr1, *filter_ptr2, + *filter_ptr3, *filter_ptr4}; + _ker[0] = vld1q_f32(filter_ptr0 + 1); + _ker[1] = vld1q_f32(filter_ptr1 + 1); + _ker[2] = vld1q_f32(filter_ptr2 + 1); + _ker[3] = vld1q_f32(filter_ptr3 + 1); + _ker[4] = vld1q_f32(filter_ptr4 + 1); + _ker[5] = vld1q_f32(_ker1); + _ker[6] = vld1q_f32(_ker1 + 4); + + // pad top + for (int h = 0; h < valid_h_start; ++h) { + DepthwiseConv5x5NormalRow<1, 1>(input_ptr, filter_ptr, h, input_h, + input_w, padding_h, padding_w, output_w, + output_ptr, _ker, _ker1); + } + + // output 4x4 + int output_w_tiles = valid_w / 4; + int output_w_remain = valid_w - output_w_tiles * 4; + for (int h = valid_h_start; h < valid_h_end - 1; h += 2) { + const float *input_ptr0 = input_ptr + (h - padding_h) * input_w; + const float *input_ptr1 = input_ptr0 + input_w; + const float *input_ptr2 = input_ptr1 + input_w; + const float *input_ptr3 = input_ptr2 + input_w; + const float *input_ptr4 = input_ptr3 + input_w; + const float *input_ptr5 = input_ptr4 + input_w; + float *output_ptr0 = output_ptr + h * output_w; + float *output_ptr1 = output_ptr0 + output_w; + // pad left + if (padding_w) { + float32x4_t row0 = vld1q_f32(input_ptr0); + float32x4_t row1 = vld1q_f32(input_ptr1); + float32x4_t row2 = vld1q_f32(input_ptr2); + float32x4_t row3 = vld1q_f32(input_ptr3); + float32x4_t row4 = vld1q_f32(input_ptr4); + float32x4_t row5 = vld1q_f32(input_ptr5); + float32x4_t zero = vdupq_n_f32(0.f); + float32x4_t acc0, acc1; + for (int w = valid_w_start - 1; w >= 0; --w) { + int padding = padding_w - w; + if (padding >= 5) { + output_ptr0[w] = 0.f; + output_ptr1[w] = 0.f; + } else { + acc0 = vmulq_f32(row0, _ker[0]); + acc0 = vmlaq_f32(acc0, row1, _ker[1]); + acc0 = vmlaq_f32(acc0, row2, _ker[2]); + acc0 = vmlaq_f32(acc0, row3, _ker[3]); + acc0 = vmlaq_f32(acc0, row4, _ker[4]); + acc1 = vmulq_f32(row1, _ker[0]); + acc1 = vmlaq_f32(acc1, row2, _ker[1]); + acc1 = vmlaq_f32(acc1, row3, _ker[2]); + acc1 = vmlaq_f32(acc1, row4, _ker[3]); + acc1 = vmlaq_f32(acc1, row5, _ker[4]); + acc0 = vpaddq_f32(acc0, acc1); + float32x2_t sum = + vpadd_f32(vget_low_f32(acc0), vget_high_f32(acc0)); + vst1_lane_f32(output_ptr0 + w, sum, 0); + vst1_lane_f32(output_ptr1 + w, sum, 1); + + row0 = vextq_f32(zero, row0, 3); + row1 = vextq_f32(zero, row1, 3); + row2 = vextq_f32(zero, row2, 3); + row3 = vextq_f32(zero, row3, 3); + row4 = vextq_f32(zero, row4, 3); + row5 = vextq_f32(zero, row5, 3); + } + } + output_ptr0 += valid_w_start; + output_ptr1 += valid_w_start; + } + // valid + int loop = output_w_tiles; + asm volatile( + "cmp %[loop], #0 \n" + "ble start_remain_%= \n" + "mov r0, #16 \n" + "loop_2h4w_%=: \n" + "vld1.32 {d14-d17}, [%[input_ptr0]], r0 \n" + "vld1.32 {d18-d21}, [%[input_ptr1]], r0 \n" + "vld1.32 {d22-d25}, [%[input_ptr2]], r0 \n" + "vmul.f32 q14, q7, %e[ker0][0] \n" + "vext.32 q13, q7, q8, #1 \n" + "vmla.f32 q14, q13, %e[kr0][0] \n" + "vext.32 q13, q7, q8, #2 \n" + "vmla.f32 q14, q13, %e[kr0][1] \n" + "vext.32 q13, q7, q8, #3 \n" + "vmla.f32 q14, q13, %f[kr0][0] \n" + "vmla.f32 q14, q8, %f[kr0][1] \n" + + "vmla.f32 q14, q9, %e[ker0][1] \n" + "vmul.f32 q15, q9, %e[ker0][0] \n" + "vext.32 q13, q9, q10, #1 \n" + "vmla.f32 q14, q13, %e[kr1][0] \n" + "vmla.f32 q15, q13, %e[kr0][0] \n" + "vext.32 q13, q9, q10, #2 \n" + "vmla.f32 q14, q13, %e[kr1][1] \n" + "vmla.f32 q15, q13, %e[kr0][1] \n" + "vext.32 q13, q9, q10, #3 \n" + "vmla.f32 q14, q13, %f[kr1][0] \n" + "vmla.f32 q15, q13, %f[kr0][0] \n" + "vmla.f32 q14, q10, %f[kr1][1] \n" + "vmla.f32 q15, q10, %f[kr0][1] \n" + + "vmla.f32 
q14, q11, %f[ker0][0] \n" + "vmla.f32 q15, q11, %e[ker0][1] \n" + "vext.32 q13, q11, q12, #1 \n" + "vmla.f32 q14, q13, %e[kr2][0] \n" + "vmla.f32 q15, q13, %e[kr1][0] \n" + "vext.32 q13, q11, q12, #2 \n" + "vmla.f32 q14, q13, %e[kr2][1] \n" + "vmla.f32 q15, q13, %e[kr1][1] \n" + "vext.32 q13, q11, q12, #3 \n" + "vmla.f32 q14, q13, %f[kr2][0] \n" + "vmla.f32 q15, q13, %f[kr1][0] \n" + "vmla.f32 q14, q12, %f[kr2][1] \n" + "vmla.f32 q15, q12, %f[kr1][1] \n" + + "vld1.32 {d14-d17}, [%[input_ptr3]], r0 \n" + "vld1.32 {d18-d21}, [%[input_ptr4]], r0 \n" + "vld1.32 {d22-d25}, [%[input_ptr5]], r0 \n" + "vmla.f32 q14, q7, %f[ker0][1] \n" + "vmla.f32 q15, q7, %f[ker0][0] \n" + "vext.32 q13, q7, q8, #1 \n" + "vmla.f32 q14, q13, %e[kr3][0] \n" + "vmla.f32 q15, q13, %e[kr2][0] \n" + "vext.32 q13, q7, q8, #2 \n" + "vmla.f32 q14, q13, %e[kr3][1] \n" + "vmla.f32 q15, q13, %e[kr2][1] \n" + "vext.32 q13, q7, q8, #3 \n" + "vmla.f32 q14, q13, %f[kr3][0] \n" + "vmla.f32 q15, q13, %f[kr2][0] \n" + "vmla.f32 q14, q8, %f[kr3][1] \n" + "vmla.f32 q15, q8, %f[kr2][1] \n" + + "vmla.f32 q14, q9, %e[ker1][0] \n" + "vmla.f32 q15, q9, %f[ker0][1] \n" + "vext.32 q13, q9, q10, #1 \n" + "vmla.f32 q14, q13, %e[kr4][0] \n" + "vmla.f32 q15, q13, %e[kr3][0] \n" + "vext.32 q13, q9, q10, #2 \n" + "vmla.f32 q14, q13, %e[kr4][1] \n" + "vmla.f32 q15, q13, %e[kr3][1] \n" + "vext.32 q13, q9, q10, #3 \n" + "vmla.f32 q14, q13, %f[kr4][0] \n" + "vmla.f32 q15, q13, %f[kr3][0] \n" + "vmla.f32 q14, q10, %f[kr4][1] \n" + "vmla.f32 q15, q10, %f[kr3][1] \n" + + "vmla.f32 q15, q11, %e[ker1][0] \n" + "vext.32 q13, q11, q12, #1 \n" + "vmla.f32 q15, q13, %e[kr4][0] \n" + "vext.32 q13, q11, q12, #2 \n" + "vmla.f32 q15, q13, %e[kr4][1] \n" + "vext.32 q13, q11, q12, #3 \n" + "vmla.f32 q15, q13, %f[kr4][0] \n" + "vmla.f32 q15, q12, %f[kr4][1] \n" + // restore output + "vst1.32 {q14}, [%[output_ptr0]]! \n" + "vst1.32 {q15}, [%[output_ptr1]]! 
\n" + "subs %[loop], #1 \n" + "bne loop_2h4w_%= \n" + + "start_remain_%=: \n" + "cmp %[remain], #0 \n" + "ble end_%= \n" + "mov r0, %[remain], lsl #2 \n" + "vld1.32 {d14-d17}, [%[input_ptr0]], r0 \n" + "vld1.32 {d18-d21}, [%[input_ptr1]], r0 \n" + "vld1.32 {d22-d25}, [%[input_ptr2]], r0 \n" + "vmul.f32 q14, q7, %e[ker0][0] \n" + "vext.32 q13, q7, q8, #1 \n" + "vmla.f32 q14, q13, %e[kr0][0] \n" + "vext.32 q13, q7, q8, #2 \n" + "vmla.f32 q14, q13, %e[kr0][1] \n" + "vext.32 q13, q7, q8, #3 \n" + "vmla.f32 q14, q13, %f[kr0][0] \n" + "vmla.f32 q14, q8, %f[kr0][1] \n" + + "vmla.f32 q14, q9, %e[ker0][1] \n" + "vmul.f32 q15, q9, %e[ker0][0] \n" + "vext.32 q13, q9, q10, #1 \n" + "vmla.f32 q14, q13, %e[kr1][0] \n" + "vmla.f32 q15, q13, %e[kr0][0] \n" + "vext.32 q13, q9, q10, #2 \n" + "vmla.f32 q14, q13, %e[kr1][1] \n" + "vmla.f32 q15, q13, %e[kr0][1] \n" + "vext.32 q13, q9, q10, #3 \n" + "vmla.f32 q14, q13, %f[kr1][0] \n" + "vmla.f32 q15, q13, %f[kr0][0] \n" + "vmla.f32 q14, q10, %f[kr1][1] \n" + "vmla.f32 q15, q10, %f[kr0][1] \n" + + "vmla.f32 q14, q11, %f[ker0][0] \n" + "vmla.f32 q15, q11, %e[ker0][1] \n" + "vext.32 q13, q11, q12, #1 \n" + "vmla.f32 q14, q13, %e[kr2][0] \n" + "vmla.f32 q15, q13, %e[kr1][0] \n" + "vext.32 q13, q11, q12, #2 \n" + "vmla.f32 q14, q13, %e[kr2][1] \n" + "vmla.f32 q15, q13, %e[kr1][1] \n" + "vext.32 q13, q11, q12, #3 \n" + "vmla.f32 q14, q13, %f[kr2][0] \n" + "vmla.f32 q15, q13, %f[kr1][0] \n" + "vmla.f32 q14, q12, %f[kr2][1] \n" + "vmla.f32 q15, q12, %f[kr1][1] \n" + + "vld1.32 {d14-d17}, [%[input_ptr3]], r0 \n" + "vld1.32 {d18-d21}, [%[input_ptr4]], r0 \n" + "vld1.32 {d22-d25}, [%[input_ptr5]], r0 \n" + "vmla.f32 q14, q7, %f[ker0][1] \n" + "vmla.f32 q15, q7, %f[ker0][0] \n" + "vext.32 q13, q7, q8, #1 \n" + "vmla.f32 q14, q13, %e[kr3][0] \n" + "vmla.f32 q15, q13, %e[kr2][0] \n" + "vext.32 q13, q7, q8, #2 \n" + "vmla.f32 q14, q13, %e[kr3][1] \n" + "vmla.f32 q15, q13, %e[kr2][1] \n" + "vext.32 q13, q7, q8, #3 \n" + "vmla.f32 q14, q13, %f[kr3][0] \n" + "vmla.f32 q15, q13, %f[kr2][0] \n" + "vmla.f32 q14, q8, %f[kr3][1] \n" + "vmla.f32 q15, q8, %f[kr2][1] \n" + + "vmla.f32 q14, q9, %e[ker1][0] \n" + "vmla.f32 q15, q9, %f[ker0][1] \n" + "vext.32 q13, q9, q10, #1 \n" + "vmla.f32 q14, q13, %e[kr4][0] \n" + "vmla.f32 q15, q13, %e[kr3][0] \n" + "vext.32 q13, q9, q10, #2 \n" + "vmla.f32 q14, q13, %e[kr4][1] \n" + "vmla.f32 q15, q13, %e[kr3][1] \n" + "vext.32 q13, q9, q10, #3 \n" + "vmla.f32 q14, q13, %f[kr4][0] \n" + "vmla.f32 q15, q13, %f[kr3][0] \n" + "vmla.f32 q14, q10, %f[kr4][1] \n" + "vmla.f32 q15, q10, %f[kr3][1] \n" + + "vmla.f32 q15, q11, %e[ker1][0] \n" + "vext.32 q13, q11, q12, #1 \n" + "vmla.f32 q15, q13, %e[kr4][0] \n" + "vext.32 q13, q11, q12, #2 \n" + "vmla.f32 q15, q13, %e[kr4][1] \n" + "vext.32 q13, q11, q12, #3 \n" + "vmla.f32 q15, q13, %f[kr4][0] \n" + "vmla.f32 q15, q12, %f[kr4][1] \n" + + "cmp %[remain], #2 \n" + "blt store_2h1w_%= \n" + "vst1.32 {d28}, [%[output_ptr0]]! \n" + "vst1.32 {d30}, [%[output_ptr1]]! \n" + "cmp %[remain], #3 \n" + "blt end_%= \n" + "vst1.32 {d29[0]}, [%[output_ptr0]]! \n" + "vst1.32 {d31[0]}, [%[output_ptr1]]! \n" + "b end_%= \n" + + "store_2h1w_%=: \n" + "vst1.32 {d28[0]}, [%[output_ptr0]]! \n" + "vst1.32 {d30[0]}, [%[output_ptr1]]! 
\n" + "end_%=: \n" + : [input_ptr0] "+r"(input_ptr0), [input_ptr1] "+r"(input_ptr1), + [input_ptr2] "+r"(input_ptr2), [input_ptr3] "+r"(input_ptr3), + [input_ptr4] "+r"(input_ptr4), [input_ptr5] "+r"(input_ptr5), + [output_ptr0] "+r"(output_ptr0), [output_ptr1] "+r"(output_ptr1), + [loop] "+r"(loop) + : [remain] "r"(output_w_remain), [kr0] "w"(_ker[0]), + [kr1] "w"(_ker[1]), [kr2] "w"(_ker[2]), [kr3] "w"(_ker[3]), + [kr4] "w"(_ker[4]), [ker0] "w"(_ker[5]), [ker1] "w"(_ker[6]) + : "cc", "memory", "q7", "q8", "q9", "q10", "q11", "q12", "q13", "q14", + "q15", "r0"); + // pad right + if (padding_w) { + float32x4_t row0 = vld1q_f32(input_ptr0); + float32x4_t row1 = vld1q_f32(input_ptr1); + float32x4_t row2 = vld1q_f32(input_ptr2); + float32x4_t row3 = vld1q_f32(input_ptr3); + float32x4_t row4 = vld1q_f32(input_ptr4); + float32x4_t row5 = vld1q_f32(input_ptr5); + float32x4_t zero = vdupq_n_f32(0.f); + float32x4_t acc0, acc1; + for (int w = valid_w_end; w < output_w; ++w) { + int padding = w + 5 - (padding_w + input_w); + if (padding >= 5) { + *output_ptr0 = 0.f; + *output_ptr1 = 0.f; + } else { + int iw = w - valid_w_end; + float sum0 = input_ptr0[iw] * filter_ptr0[0] + + input_ptr1[iw] * filter_ptr1[0] + + input_ptr2[iw] * filter_ptr2[0] + + input_ptr3[iw] * filter_ptr3[0] + + input_ptr4[iw] * filter_ptr4[0]; + float sum1 = input_ptr1[iw] * filter_ptr0[0] + + input_ptr2[iw] * filter_ptr1[0] + + input_ptr3[iw] * filter_ptr2[0] + + input_ptr4[iw] * filter_ptr3[0] + + input_ptr5[iw] * filter_ptr4[0]; + row0 = vextq_f32(row0, zero, 1); + row1 = vextq_f32(row1, zero, 1); + row2 = vextq_f32(row2, zero, 1); + row3 = vextq_f32(row3, zero, 1); + row4 = vextq_f32(row4, zero, 1); + row5 = vextq_f32(row5, zero, 1); + acc0 = vmulq_f32(row0, _ker[0]); + acc0 = vmlaq_f32(acc0, row1, _ker[1]); + acc0 = vmlaq_f32(acc0, row2, _ker[2]); + acc0 = vmlaq_f32(acc0, row3, _ker[3]); + acc0 = vmlaq_f32(acc0, row4, _ker[4]); + acc1 = vmulq_f32(row1, _ker[0]); + acc1 = vmlaq_f32(acc1, row2, _ker[1]); + acc1 = vmlaq_f32(acc1, row3, _ker[2]); + acc1 = vmlaq_f32(acc1, row4, _ker[3]); + acc1 = vmlaq_f32(acc1, row5, _ker[4]); + acc0 = vpaddq_f32(acc0, acc1); + float32x2_t sum = + vpadd_f32(vget_low_f32(acc0), vget_high_f32(acc0)); + sum0 += vget_lane_f32(sum, 0); + sum1 += vget_lane_f32(sum, 1); + *output_ptr0 = sum0; + *output_ptr1 = sum1; + } + output_ptr0++; + output_ptr1++; + } + } + } + // remain height + int start_h = valid_h_start + (valid_h & 0xfffe); + if (start_h < valid_h_end) { + const float *input_ptr0 = input_ptr + (start_h - padding_h) * input_w; + const float *input_ptr1 = input_ptr0 + input_w; + const float *input_ptr2 = input_ptr1 + input_w; + const float *input_ptr3 = input_ptr2 + input_w; + const float *input_ptr4 = input_ptr3 + input_w; + float *output_ptr0 = output_ptr + start_h * output_w; + // pad left + if (padding_w) { + float32x4_t row0 = vld1q_f32(input_ptr0); + float32x4_t row1 = vld1q_f32(input_ptr1); + float32x4_t row2 = vld1q_f32(input_ptr2); + float32x4_t row3 = vld1q_f32(input_ptr3); + float32x4_t row4 = vld1q_f32(input_ptr4); + float32x4_t zero = vdupq_n_f32(0.f); + float32x4_t acc; + for (int w = valid_w_start - 1; w >= 0; --w) { + int padding = padding_w - w; + if (padding >= 5) { + output_ptr0[w] = 0.f; + } else { + acc = vmulq_f32(row0, _ker[0]); + acc = vmlaq_f32(acc, row1, _ker[1]); + acc = vmlaq_f32(acc, row2, _ker[2]); + acc = vmlaq_f32(acc, row3, _ker[3]); + acc = vmlaq_f32(acc, row4, _ker[4]); + float32x2_t sum = vpadd_f32(vget_low_f32(acc), vget_high_f32(acc)); + sum = 
vpadd_f32(sum, sum); + vst1_lane_f32(output_ptr0 + w, sum, 0); + + row0 = vextq_f32(zero, row0, 3); + row1 = vextq_f32(zero, row1, 3); + row2 = vextq_f32(zero, row2, 3); + row3 = vextq_f32(zero, row3, 3); + row4 = vextq_f32(zero, row4, 3); + } + } + output_ptr0 += valid_w_start; + } + // valid + int loop = output_w_tiles; + asm volatile( + "cmp %[loop], #0 \n" + "ble start_remain_%= \n" + "mov r0, #16 \n" + "loop_1h4w_%=: \n" + "vld1.32 {d14-d17}, [%[input_ptr0]], r0 \n" + "vld1.32 {d18-d21}, [%[input_ptr1]], r0 \n" + "vld1.32 {d22-d25}, [%[input_ptr2]], r0 \n" + "vmul.f32 q14, q7, %e[ker0][0] \n" + "vext.32 q13, q7, q8, #1 \n" + "vmla.f32 q14, q13, %e[kr0][0] \n" + "vext.32 q13, q7, q8, #2 \n" + "vmla.f32 q14, q13, %e[kr0][1] \n" + "vext.32 q13, q7, q8, #3 \n" + "vmla.f32 q14, q13, %f[kr0][0] \n" + "vmla.f32 q14, q8, %f[kr0][1] \n" + + "vmla.f32 q14, q9, %e[ker0][1] \n" + "vext.32 q13, q9, q10, #1 \n" + "vmla.f32 q14, q13, %e[kr1][0] \n" + "vext.32 q13, q9, q10, #2 \n" + "vmla.f32 q14, q13, %e[kr1][1] \n" + "vext.32 q13, q9, q10, #3 \n" + "vmla.f32 q14, q13, %f[kr1][0] \n" + "vmla.f32 q14, q10, %f[kr1][1] \n" + + "vmla.f32 q14, q11, %f[ker0][0] \n" + "vext.32 q13, q11, q12, #1 \n" + "vmla.f32 q14, q13, %e[kr2][0] \n" + "vext.32 q13, q11, q12, #2 \n" + "vmla.f32 q14, q13, %e[kr2][1] \n" + "vext.32 q13, q11, q12, #3 \n" + "vmla.f32 q14, q13, %f[kr2][0] \n" + "vmla.f32 q14, q12, %f[kr2][1] \n" + + "vld1.32 {d14-d17}, [%[input_ptr3]], r0 \n" + "vld1.32 {d18-d21}, [%[input_ptr4]], r0 \n" + "vmla.f32 q14, q7, %f[ker0][1] \n" + "vext.32 q13, q7, q8, #1 \n" + "vmla.f32 q14, q13, %e[kr3][0] \n" + "vext.32 q13, q7, q8, #2 \n" + "vmla.f32 q14, q13, %e[kr3][1] \n" + "vext.32 q13, q7, q8, #3 \n" + "vmla.f32 q14, q13, %f[kr3][0] \n" + "vmla.f32 q14, q8, %f[kr3][1] \n" + + "vmla.f32 q14, q9, %e[ker1][0] \n" + "vext.32 q13, q9, q10, #1 \n" + "vmla.f32 q14, q13, %e[kr4][0] \n" + "vext.32 q13, q9, q10, #2 \n" + "vmla.f32 q14, q13, %e[kr4][1] \n" + "vext.32 q13, q9, q10, #3 \n" + "vmla.f32 q14, q13, %f[kr4][0] \n" + "vmla.f32 q14, q10, %f[kr4][1] \n" + + // restore output + "vst1.32 {q14}, [%[output_ptr0]]! 
\n" + "subs %[loop], #1 \n" + "bne loop_1h4w_%= \n" + + "start_remain_%=: \n" + "cmp %[remain], #0 \n" + "ble end_%= \n" + "mov r0, %[remain], lsl #2 \n" + "vld1.32 {d14-d17}, [%[input_ptr0]], r0 \n" + "vld1.32 {d18-d21}, [%[input_ptr1]], r0 \n" + "vld1.32 {d22-d25}, [%[input_ptr2]], r0 \n" + "vmul.f32 q14, q7, %e[ker0][0] \n" + "vext.32 q13, q7, q8, #1 \n" + "vmla.f32 q14, q13, %e[kr0][0] \n" + "vext.32 q13, q7, q8, #2 \n" + "vmla.f32 q14, q13, %e[kr0][1] \n" + "vext.32 q13, q7, q8, #3 \n" + "vmla.f32 q14, q13, %f[kr0][0] \n" + "vmla.f32 q14, q8, %f[kr0][1] \n" + + "vmla.f32 q14, q9, %e[ker0][1] \n" + "vext.32 q13, q9, q10, #1 \n" + "vmla.f32 q14, q13, %e[kr1][0] \n" + "vext.32 q13, q9, q10, #2 \n" + "vmla.f32 q14, q13, %e[kr1][1] \n" + "vext.32 q13, q9, q10, #3 \n" + "vmla.f32 q14, q13, %f[kr1][0] \n" + "vmla.f32 q14, q10, %f[kr1][1] \n" + + "vmla.f32 q14, q11, %f[ker0][0] \n" + "vext.32 q13, q11, q12, #1 \n" + "vmla.f32 q14, q13, %e[kr2][0] \n" + "vext.32 q13, q11, q12, #2 \n" + "vmla.f32 q14, q13, %e[kr2][1] \n" + "vext.32 q13, q11, q12, #3 \n" + "vmla.f32 q14, q13, %f[kr2][0] \n" + "vmla.f32 q14, q12, %f[kr2][1] \n" + + "vld1.32 {d14-d17}, [%[input_ptr3]], r0 \n" + "vld1.32 {d18-d21}, [%[input_ptr4]], r0 \n" + "vmla.f32 q14, q7, %f[ker0][1] \n" + "vext.32 q13, q7, q8, #1 \n" + "vmla.f32 q14, q13, %e[kr3][0] \n" + "vext.32 q13, q7, q8, #2 \n" + "vmla.f32 q14, q13, %e[kr3][1] \n" + "vext.32 q13, q7, q8, #3 \n" + "vmla.f32 q14, q13, %f[kr3][0] \n" + "vmla.f32 q14, q8, %f[kr3][1] \n" + + "vmla.f32 q14, q9, %e[ker1][0] \n" + "vext.32 q13, q9, q10, #1 \n" + "vmla.f32 q14, q13, %e[kr4][0] \n" + "vext.32 q13, q9, q10, #2 \n" + "vmla.f32 q14, q13, %e[kr4][1] \n" + "vext.32 q13, q9, q10, #3 \n" + "vmla.f32 q14, q13, %f[kr4][0] \n" + "vmla.f32 q14, q10, %f[kr4][1] \n" + + "cmp %[remain], #2 \n" + "blt store_1h1w_%= \n" + "vst1.32 {d28}, [%[output_ptr0]]! \n" + "cmp %[remain], #3 \n" + "blt end_%= \n" + "vst1.32 {d29[0]}, [%[output_ptr0]]! \n" + "b end_%= \n" + + "store_1h1w_%=: \n" + "vst1.32 {d28[0]}, [%[output_ptr0]]! 
\n" + "end_%=: \n" + : [input_ptr0] "+r"(input_ptr0), [input_ptr1] "+r"(input_ptr1), + [input_ptr2] "+r"(input_ptr2), [input_ptr3] "+r"(input_ptr3), + [input_ptr4] "+r"(input_ptr4), [output_ptr0] "+r"(output_ptr0), + [loop] "+r"(loop) + : [remain] "r"(output_w_remain), [kr0] "w"(_ker[0]), + [kr1] "w"(_ker[1]), [kr2] "w"(_ker[2]), [kr3] "w"(_ker[3]), + [kr4] "w"(_ker[4]), [ker0] "w"(_ker[5]), [ker1] "w"(_ker[6]) + : "cc", "memory", "q7", "q8", "q9", "q10", "q11", "q12", "q13", "q14", + "q15", "r0"); + // pad right + if (padding_w) { + float32x4_t row0 = vld1q_f32(input_ptr0); + float32x4_t row1 = vld1q_f32(input_ptr1); + float32x4_t row2 = vld1q_f32(input_ptr2); + float32x4_t row3 = vld1q_f32(input_ptr3); + float32x4_t row4 = vld1q_f32(input_ptr4); + float32x4_t zero = vdupq_n_f32(0.f); + float32x4_t acc; + for (int w = valid_w_end; w < output_w; ++w) { + int padding = w + 5 - (padding_w + input_w); + if (padding >= 5) { + *output_ptr0 = 0.f; + } else { + int iw = w - valid_w_end; + float sum0 = input_ptr0[iw] * filter_ptr0[0] + + input_ptr1[iw] * filter_ptr1[0] + + input_ptr2[iw] * filter_ptr2[0] + + input_ptr3[iw] * filter_ptr3[0] + + input_ptr4[iw] * filter_ptr4[0]; + row0 = vextq_f32(row0, zero, 1); + row1 = vextq_f32(row1, zero, 1); + row2 = vextq_f32(row2, zero, 1); + row3 = vextq_f32(row3, zero, 1); + row4 = vextq_f32(row4, zero, 1); + acc = vmulq_f32(row0, _ker[0]); + acc = vmlaq_f32(acc, row1, _ker[1]); + acc = vmlaq_f32(acc, row2, _ker[2]); + acc = vmlaq_f32(acc, row3, _ker[3]); + acc = vmlaq_f32(acc, row4, _ker[4]); + float32x2_t sum = vpadd_f32(vget_low_f32(acc), vget_high_f32(acc)); + sum = vpadd_f32(sum, sum); + sum0 += vget_lane_f32(sum, 0); + *output_ptr0 = sum0; + } + output_ptr0++; + } + } + } + // pad bottom + for (int h = valid_h_end; h < output_h; ++h) { + DepthwiseConv5x5NormalRow<1, 1>(input_ptr, filter_ptr, h, input_h, + input_w, padding_h, padding_w, output_w, + output_ptr, _ker, _ker1); + } + } +} + +template <> +void DepthwiseConv5x5S2(const framework::Tensor &input, + const framework::Tensor &filter, + const std::vector &paddings, + framework::Tensor *output) {} + +} // namespace math +} // namespace operators +} // namespace paddle_mobile + +#endif // __ARM_NEON__ diff --git a/src/operators/math/depthwise_conv5x5.h b/src/operators/math/depthwise_conv5x5.h new file mode 100644 index 0000000000000000000000000000000000000000..d047bbfa1ac179e0ef0b1b6705e349890b25e800 --- /dev/null +++ b/src/operators/math/depthwise_conv5x5.h @@ -0,0 +1,48 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
*/ + +#pragma once + +#include +#include +#include "framework/tensor.h" +#include "operators/math/conv_func.h" + +namespace paddle_mobile { +namespace operators { +namespace math { + +// TODO(hjchen2) need to be implemented +// template +// void DepthwiseConv5x5(const framework::Tensor *input, +// const framework::Tensor *filter, +// const std::vector &strides, +// const std::vector &paddings, +// framework::Tensor *output); + +template +void DepthwiseConv5x5S1(const framework::Tensor &input, + const framework::Tensor &filter, + const std::vector &paddings, + framework::Tensor *output); + +template +void DepthwiseConv5x5S2(const framework::Tensor &input, + const framework::Tensor &filter, + const std::vector &paddings, + framework::Tensor *output); + +} // namespace math +} // namespace operators +} // namespace paddle_mobile diff --git a/src/operators/math/depthwise_conv5x5_int8.cpp b/src/operators/math/depthwise_conv5x5_int8.cpp new file mode 100644 index 0000000000000000000000000000000000000000..a92d48272f3c3abbc9c86a652521db4564498d2e --- /dev/null +++ b/src/operators/math/depthwise_conv5x5_int8.cpp @@ -0,0 +1,1041 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#if defined(__ARM_NEON__) && !defined(__aarch64__) + +#include +#include "operators/math/depthwise_conv5x5.h" + +namespace paddle_mobile { +namespace operators { +namespace math { + +#ifndef __aarch64__ +inline int32x4_t vpaddq_s32(int32x4_t r0, int32x4_t r1) { + int32x2_t sum0 = vpadd_s32(vget_low_s32(r0), vget_high_s32(r0)); + int32x2_t sum1 = vpadd_s32(vget_low_s32(r1), vget_high_s32(r1)); + return vcombine_s32(sum0, sum1); +} +#endif + +template +inline void Depth5x5NormalRowLoadInput(const int8_t *input, int16x4_t *y) { + int16x8_t x = vmovl_s8(vld1_s8(input)); + y[0] = vget_low_s16(x); + y[4] = vget_high_s16(x); + y[1] = vext_s16(y[0], y[4], 1); + y[2] = vext_s16(y[0], y[4], 2); + y[3] = vext_s16(y[0], y[4], 3); +} + +template <> +inline void Depth5x5NormalRowLoadInput<2>(const int8_t *input, int16x4_t *y) { + int8x8x2_t x = vld2_s8(input); + y[0] = vget_low_s16(vmovl_s8(x.val[0])); + y[1] = vget_low_s16(vmovl_s8(x.val[1])); + y[2] = vext_s16(y[0], y[0], 1); + y[3] = vext_s16(y[1], y[1], 1); + y[4] = vext_s16(y[0], y[0], 2); +} + +#define DEPTHWISE_CONV_NORMAL_BORDER(start, end) \ + for (int w = start; w < end; ++w) { \ + const int w_in_start = -padding_w + w * Stride_w; \ + const int w_in_end = w_in_start + 5; \ + const int w_start = w_in_start > 0 ? w_in_start : 0; \ + const int w_end = w_in_end < input_w ? 
w_in_end : input_w; \ + int32_t value = 0; \ + for (int h_in = h_start; h_in < h_end; ++h_in) { \ + for (int w_in = w_start; w_in < w_end; ++w_in) { \ + value += filter[(h_in - h_in_start) * 5 + (w_in - w_in_start)] * \ + input[h_in * input_w + w_in]; \ + } \ + } \ + output_ptr[w] = value; \ + } + +template +inline void DepthwiseConv5x5NormalRow(const int8_t *input, const int8_t *filter, + const int h_output, const int input_h, + const int input_w, const int padding_h, + const int padding_w, const int output_w, + int32_t *output, int16x4_t *ker, + int16_t *ker1) { + const int h_in_start = -padding_h + h_output * Stride_h; + const int h_in_end = h_in_start + 5; + const int h_start = h_in_start > 0 ? h_in_start : 0; + const int h_end = h_in_end < input_h ? h_in_end : input_h; + + int valid_w_start = (padding_w + Stride_w - 1) / Stride_w; + int valid_w_end = output_w - valid_w_start; + int32_t *output_ptr = output + h_output * output_w; + // border left + DEPTHWISE_CONV_NORMAL_BORDER(0, valid_w_start) + // middle + int output_tiles = (valid_w_end - valid_w_start) >> 2; + int16x4_t _x[5]; + int32x4_t _sum; + // valid w + for (int w = 0; w < output_tiles * 4; w += 4) { + _sum = vdupq_n_s32(0); + int output_offset = valid_w_start + w; + int input_w_offset = output_offset * Stride_w - padding_w; + for (int h_in = h_start; h_in < h_end; ++h_in) { + int index = h_in - h_in_start; + Depth5x5NormalRowLoadInput( + input + h_in * input_w + input_w_offset, _x); + _sum = vmlal_n_s16(_sum, _x[0], ker1[index]); + _sum = vmlal_lane_s16(_sum, _x[1], ker[index], 0); + _sum = vmlal_lane_s16(_sum, _x[2], ker[index], 1); + _sum = vmlal_lane_s16(_sum, _x[3], ker[index], 2); + _sum = vmlal_lane_s16(_sum, _x[4], ker[index], 3); + } + vst1q_s32(output_ptr + output_offset, _sum); + } + // remain valid w + int remain = (valid_w_end - valid_w_start) & 0x3; + if (remain > 0) { + _sum = vdupq_n_s32(0); + int remain_start = valid_w_start + (output_tiles << 2); + int input_w_offset = remain_start * Stride_w - padding_w; + int32_t *output_ptr0 = output_ptr + remain_start; + + for (int h_in = h_start; h_in < h_end; ++h_in) { + int index = h_in - h_in_start; + Depth5x5NormalRowLoadInput( + input + h_in * input_w + input_w_offset, _x); + _sum = vmlal_n_s16(_sum, _x[0], ker1[index]); + _sum = vmlal_lane_s16(_sum, _x[1], ker[index], 0); + _sum = vmlal_lane_s16(_sum, _x[2], ker[index], 1); + _sum = vmlal_lane_s16(_sum, _x[3], ker[index], 2); + _sum = vmlal_lane_s16(_sum, _x[4], ker[index], 3); + } + switch (remain) { + case 1: + vst1_lane_s32(output_ptr0, vget_low_s32(_sum), 0); + break; + case 2: + vst1_s32(output_ptr0, vget_low_s32(_sum)); + break; + case 3: + vst1_s32(output_ptr0, vget_low_s32(_sum)); + vst1_lane_s32(output_ptr0 + 2, vget_high_s32(_sum), 0); + break; + } + } + // border right + DEPTHWISE_CONV_NORMAL_BORDER(valid_w_end, output_w) +} + +template <> +void DepthwiseConv5x5S1(const framework::Tensor &input, + const framework::Tensor &filter, + const std::vector &paddings, + framework::Tensor *output) { + const int8_t *input_data = input.data(); + const int8_t *filter_data = filter.data(); + int32_t *out_data = output->mutable_data(); + int input_h = input.dims()[2]; + int input_w = input.dims()[3]; + int output_h = output->dims()[2]; + int output_w = output->dims()[3]; + int padding_h = paddings[0]; + int padding_w = paddings[1]; + int image_size = input_h * input_w; + int out_image_size = output_h * output_w; + int valid_h_start = padding_h; + int valid_h_end = output_h - valid_h_start; + int valid_h = valid_h_end 
- valid_h_start; + int valid_w_start = padding_w; + int valid_w_end = output_w - valid_w_start; + int valid_w = valid_w_end - valid_w_start; + + #pragma omp parallel for + for (int g = 0; g < input.dims()[1]; ++g) { + const int8_t *input_ptr = input_data + g * image_size; + const int8_t *filter_ptr = filter_data + g * 25; + int32_t *output_ptr = out_data + g * out_image_size; + + const int8_t *filter_ptr0 = filter_ptr; + const int8_t *filter_ptr1 = filter_ptr0 + 5; + const int8_t *filter_ptr2 = filter_ptr1 + 5; + const int8_t *filter_ptr3 = filter_ptr2 + 5; + const int8_t *filter_ptr4 = filter_ptr3 + 5; + int16_t kernel[5] = {*filter_ptr0, *filter_ptr1, *filter_ptr2, *filter_ptr3, + *filter_ptr4}; + int16x4_t _k0 = vget_low_s16(vmovl_s8(vld1_s8(filter_ptr0 + 1))); + int16x4_t _k1 = vget_low_s16(vmovl_s8(vld1_s8(filter_ptr1 + 1))); + int16x4_t _k2 = vget_low_s16(vmovl_s8(vld1_s8(filter_ptr2 + 1))); + int16x4_t _k3 = vget_low_s16(vmovl_s8(vld1_s8(filter_ptr3 + 1))); + int16x4_t _k4 = vget_low_s16(vmovl_s8(vld1_s8(filter_ptr4 + 1))); + int16x4_t _k5 = vld1_s16(kernel); + int16x4_t _k6 = vld1_s16(kernel + 4); + int16x8_t _ker0 = vcombine_s16(_k0, _k1); + int16x8_t _ker1 = vcombine_s16(_k2, _k3); + int16x8_t _ker2 = vcombine_s16(_k4, _k5); + int16x8_t _ker3 = vcombine_s16(_k6, _k6); + int16x4_t _ker[7] = {_k0, _k1, _k2, _k3, _k4, _k5, _k6}; + + // pad top + for (int h = 0; h < valid_h_start; ++h) { + DepthwiseConv5x5NormalRow<1, 1>(input_ptr, filter_ptr, h, input_h, + input_w, padding_h, padding_w, output_w, + output_ptr, _ker, kernel); + } + + // output 4x4 + int output_w_tiles = valid_w / 8; + int output_w_remain = valid_w - output_w_tiles * 8; + for (int h = valid_h_start; h < valid_h_end - 1; h += 2) { + const int8_t *input_ptr0 = input_ptr + (h - padding_h) * input_w; + const int8_t *input_ptr1 = input_ptr0 + input_w; + const int8_t *input_ptr2 = input_ptr1 + input_w; + const int8_t *input_ptr3 = input_ptr2 + input_w; + const int8_t *input_ptr4 = input_ptr3 + input_w; + const int8_t *input_ptr5 = input_ptr4 + input_w; + int32_t *output_ptr0 = output_ptr + h * output_w; + int32_t *output_ptr1 = output_ptr0 + output_w; + // pad left + if (padding_w) { + int16x4_t row0 = vget_low_s16(vmovl_s8(vld1_s8(input_ptr0))); + int16x4_t row1 = vget_low_s16(vmovl_s8(vld1_s8(input_ptr1))); + int16x4_t row2 = vget_low_s16(vmovl_s8(vld1_s8(input_ptr2))); + int16x4_t row3 = vget_low_s16(vmovl_s8(vld1_s8(input_ptr3))); + int16x4_t row4 = vget_low_s16(vmovl_s8(vld1_s8(input_ptr4))); + int16x4_t row5 = vget_low_s16(vmovl_s8(vld1_s8(input_ptr5))); + int16x4_t zero = vdup_n_s16(0); + int32x4_t acc0, acc1; + for (int w = valid_w_start - 1; w >= 0; --w) { + int padding = padding_w - w; + if (padding >= 5) { + output_ptr0[w] = 0; + output_ptr1[w] = 0; + } else { + acc0 = vmull_s16(row0, _ker[0]); + acc0 = vmlal_s16(acc0, row1, _ker[1]); + acc0 = vmlal_s16(acc0, row2, _ker[2]); + acc0 = vmlal_s16(acc0, row3, _ker[3]); + acc0 = vmlal_s16(acc0, row4, _ker[4]); + acc1 = vmull_s16(row1, _ker[0]); + acc1 = vmlal_s16(acc1, row2, _ker[1]); + acc1 = vmlal_s16(acc1, row3, _ker[2]); + acc1 = vmlal_s16(acc1, row4, _ker[3]); + acc1 = vmlal_s16(acc1, row5, _ker[4]); + acc0 = vpaddq_s32(acc0, acc1); + int32x2_t sum = vpadd_s32(vget_low_s32(acc0), vget_high_s32(acc0)); + vst1_lane_s32(output_ptr0 + w, sum, 0); + vst1_lane_s32(output_ptr1 + w, sum, 1); + + row0 = vext_s16(zero, row0, 3); + row1 = vext_s16(zero, row1, 3); + row2 = vext_s16(zero, row2, 3); + row3 = vext_s16(zero, row3, 3); + row4 = vext_s16(zero, row4, 3); + row5 = 
vext_s16(zero, row5, 3); + } + } + output_ptr0 += valid_w_start; + output_ptr1 += valid_w_start; + } + // valid + int loop = output_w_tiles; + int w_remain = output_w_remain; + asm volatile( + "cmp %[loop], #0 \n" + "ble start_remain4_%= \n" + "mov r0, #8 \n" + "loop_2h8w_%=: \n" + "vld1.s8 {d10-d11}, [%[input_ptr0]], r0 \n" + "vld1.s8 {d12-d13}, [%[input_ptr1]], r0 \n" + "vld1.s8 {d14-d15}, [%[input_ptr2]], r0 \n" + "vmovl.s8 q8, d10 \n" + "vmovl.s8 q9, d11 \n" + "vmull.s16 q12, d16, %f[ker2][0] \n" + "vmull.s16 q13, d17, %f[ker2][0] \n" + "vext.s16 q10, q8, q9, #1 \n" + "vmlal.s16 q12, d20, %e[ker0][0] \n" + "vmlal.s16 q13, d21, %e[ker0][0] \n" + "vext.s16 q10, q8, q9, #2 \n" + "vmlal.s16 q12, d20, %e[ker0][1] \n" + "vmlal.s16 q13, d21, %e[ker0][1] \n" + "vext.s16 q10, q8, q9, #3 \n" + "vmlal.s16 q12, d20, %e[ker0][2] \n" + "vmlal.s16 q13, d21, %e[ker0][2] \n" + "vext.s16 q10, q8, q9, #4 \n" + "vmlal.s16 q12, d20, %e[ker0][3] \n" + "vmlal.s16 q13, d21, %e[ker0][3] \n" + + "vmovl.s8 q8, d12 \n" + "vmovl.s8 q9, d13 \n" + "vmlal.s16 q12, d16, %f[ker2][1] \n" + "vmlal.s16 q13, d17, %f[ker2][1] \n" + "vmull.s16 q14, d16, %f[ker2][0] \n" + "vmull.s16 q15, d17, %f[ker2][0] \n" + "vext.s16 q10, q8, q9, #1 \n" + "vmlal.s16 q12, d20, %f[ker0][0] \n" + "vmlal.s16 q13, d21, %f[ker0][0] \n" + "vmlal.s16 q14, d20, %e[ker0][0] \n" + "vmlal.s16 q15, d21, %e[ker0][0] \n" + "vext.s16 q10, q8, q9, #2 \n" + "vmlal.s16 q12, d20, %f[ker0][1] \n" + "vmlal.s16 q13, d21, %f[ker0][1] \n" + "vmlal.s16 q14, d20, %e[ker0][1] \n" + "vmlal.s16 q15, d21, %e[ker0][1] \n" + "vext.s16 q10, q8, q9, #3 \n" + "vmlal.s16 q12, d20, %f[ker0][2] \n" + "vmlal.s16 q13, d21, %f[ker0][2] \n" + "vmlal.s16 q14, d20, %e[ker0][2] \n" + "vmlal.s16 q15, d21, %e[ker0][2] \n" + "vext.s16 q10, q8, q9, #4 \n" + "vmlal.s16 q12, d20, %f[ker0][3] \n" + "vmlal.s16 q13, d21, %f[ker0][3] \n" + "vmlal.s16 q14, d20, %e[ker0][3] \n" + "vmlal.s16 q15, d21, %e[ker0][3] \n" + + "vmovl.s8 q8, d14 \n" + "vmovl.s8 q9, d15 \n" + "vmlal.s16 q12, d16, %f[ker2][2] \n" + "vmlal.s16 q13, d17, %f[ker2][2] \n" + "vmlal.s16 q14, d16, %f[ker2][1] \n" + "vmlal.s16 q15, d17, %f[ker2][1] \n" + "vext.s16 q10, q8, q9, #1 \n" + "vmlal.s16 q12, d20, %e[ker1][0] \n" + "vmlal.s16 q13, d21, %e[ker1][0] \n" + "vmlal.s16 q14, d20, %f[ker0][0] \n" + "vmlal.s16 q15, d21, %f[ker0][0] \n" + "vext.s16 q10, q8, q9, #2 \n" + "vmlal.s16 q12, d20, %e[ker1][1] \n" + "vmlal.s16 q13, d21, %e[ker1][1] \n" + "vmlal.s16 q14, d20, %f[ker0][1] \n" + "vmlal.s16 q15, d21, %f[ker0][1] \n" + "vext.s16 q10, q8, q9, #3 \n" + "vmlal.s16 q12, d20, %e[ker1][2] \n" + "vmlal.s16 q13, d21, %e[ker1][2] \n" + "vmlal.s16 q14, d20, %f[ker0][2] \n" + "vmlal.s16 q15, d21, %f[ker0][2] \n" + "vext.s16 q10, q8, q9, #4 \n" + "vmlal.s16 q12, d20, %e[ker1][3] \n" + "vmlal.s16 q13, d21, %e[ker1][3] \n" + "vmlal.s16 q14, d20, %f[ker0][3] \n" + "vmlal.s16 q15, d21, %f[ker0][3] \n" + + "vld1.s8 {d10-d11}, [%[input_ptr3]], r0 \n" + "vld1.s8 {d12-d13}, [%[input_ptr4]], r0 \n" + "vld1.s8 {d14-d15}, [%[input_ptr5]], r0 \n" + "vmovl.s8 q8, d10 \n" + "vmovl.s8 q9, d11 \n" + "vmlal.s16 q12, d16, %f[ker2][3] \n" + "vmlal.s16 q13, d17, %f[ker2][3] \n" + "vmlal.s16 q14, d16, %f[ker2][2] \n" + "vmlal.s16 q15, d17, %f[ker2][2] \n" + "vext.s16 q10, q8, q9, #1 \n" + "vmlal.s16 q12, d20, %f[ker1][0] \n" + "vmlal.s16 q13, d21, %f[ker1][0] \n" + "vmlal.s16 q14, d20, %e[ker1][0] \n" + "vmlal.s16 q15, d21, %e[ker1][0] \n" + "vext.s16 q10, q8, q9, #2 \n" + "vmlal.s16 q12, d20, %f[ker1][1] \n" + "vmlal.s16 q13, d21, %f[ker1][1] \n" + 
"vmlal.s16 q14, d20, %e[ker1][1] \n" + "vmlal.s16 q15, d21, %e[ker1][1] \n" + "vext.s16 q10, q8, q9, #3 \n" + "vmlal.s16 q12, d20, %f[ker1][2] \n" + "vmlal.s16 q13, d21, %f[ker1][2] \n" + "vmlal.s16 q14, d20, %e[ker1][2] \n" + "vmlal.s16 q15, d21, %e[ker1][2] \n" + "vext.s16 q10, q8, q9, #4 \n" + "vmlal.s16 q12, d20, %f[ker1][3] \n" + "vmlal.s16 q13, d21, %f[ker1][3] \n" + "vmlal.s16 q14, d20, %e[ker1][3] \n" + "vmlal.s16 q15, d21, %e[ker1][3] \n" + + "vmovl.s8 q8, d12 \n" + "vmovl.s8 q9, d13 \n" + "vmlal.s16 q12, d16, %e[ker3][0] \n" + "vmlal.s16 q13, d17, %e[ker3][0] \n" + "vmlal.s16 q14, d16, %f[ker2][3] \n" + "vmlal.s16 q15, d17, %f[ker2][3] \n" + "vext.s16 q10, q8, q9, #1 \n" + "vmlal.s16 q12, d20, %e[ker2][0] \n" + "vmlal.s16 q13, d21, %e[ker2][0] \n" + "vmlal.s16 q14, d20, %f[ker1][0] \n" + "vmlal.s16 q15, d21, %f[ker1][0] \n" + "vext.s16 q10, q8, q9, #2 \n" + "vmlal.s16 q12, d20, %e[ker2][1] \n" + "vmlal.s16 q13, d21, %e[ker2][1] \n" + "vmlal.s16 q14, d20, %f[ker1][1] \n" + "vmlal.s16 q15, d21, %f[ker1][1] \n" + "vext.s16 q10, q8, q9, #3 \n" + "vmlal.s16 q12, d20, %e[ker2][2] \n" + "vmlal.s16 q13, d21, %e[ker2][2] \n" + "vmlal.s16 q14, d20, %f[ker1][2] \n" + "vmlal.s16 q15, d21, %f[ker1][2] \n" + "vext.s16 q10, q8, q9, #4 \n" + "vmlal.s16 q12, d20, %e[ker2][3] \n" + "vmlal.s16 q13, d21, %e[ker2][3] \n" + "vmlal.s16 q14, d20, %f[ker1][3] \n" + "vmlal.s16 q15, d21, %f[ker1][3] \n" + + "vmovl.s8 q8, d14 \n" + "vmovl.s8 q9, d15 \n" + "vmlal.s16 q14, d16, %e[ker3][0] \n" + "vmlal.s16 q15, d17, %e[ker3][0] \n" + "vext.s16 q10, q8, q9, #1 \n" + "vmlal.s16 q14, d20, %e[ker2][0] \n" + "vmlal.s16 q15, d21, %e[ker2][0] \n" + "vext.s16 q10, q8, q9, #2 \n" + "vmlal.s16 q14, d20, %e[ker2][1] \n" + "vmlal.s16 q15, d21, %e[ker2][1] \n" + "vext.s16 q10, q8, q9, #3 \n" + "vmlal.s16 q14, d20, %e[ker2][2] \n" + "vmlal.s16 q15, d21, %e[ker2][2] \n" + "vext.s16 q10, q8, q9, #4 \n" + "vmlal.s16 q14, d20, %e[ker2][3] \n" + "vmlal.s16 q15, d21, %e[ker2][3] \n" + + // restore output + "vst1.32 {q12-q13}, [%[output_ptr0]]! \n" + "vst1.32 {q14-q15}, [%[output_ptr1]]! 
\n" + "subs %[loop], #1 \n" + "bne loop_2h8w_%= \n" + + "start_remain4_%=: \n" + "cmp %[remain], #4 \n" + "blt start_remain_%= \n" + "mov r0, #4 \n" + "vld1.s8 {d10}, [%[input_ptr0]], r0 \n" + "vld1.s8 {d12}, [%[input_ptr1]], r0 \n" + "vld1.s8 {d14}, [%[input_ptr2]], r0 \n" + "vmovl.s8 q8, d10 \n" + "vmull.s16 q12, d16, %f[ker2][0] \n" + "vext.s16 q10, q8, q9, #1 \n" + "vmlal.s16 q12, d20, %e[ker0][0] \n" + "vext.s16 q10, q8, q9, #2 \n" + "vmlal.s16 q12, d20, %e[ker0][1] \n" + "vext.s16 q10, q8, q9, #3 \n" + "vmlal.s16 q12, d20, %e[ker0][2] \n" + "vext.s16 q10, q8, q9, #4 \n" + "vmlal.s16 q12, d20, %e[ker0][3] \n" + + "vmovl.s8 q8, d12 \n" + "vmlal.s16 q12, d16, %f[ker2][1] \n" + "vmull.s16 q14, d16, %f[ker2][0] \n" + "vext.s16 q10, q8, q9, #1 \n" + "vmlal.s16 q12, d20, %f[ker0][0] \n" + "vmlal.s16 q14, d20, %e[ker0][0] \n" + "vext.s16 q10, q8, q9, #2 \n" + "vmlal.s16 q12, d20, %f[ker0][1] \n" + "vmlal.s16 q14, d20, %e[ker0][1] \n" + "vext.s16 q10, q8, q9, #3 \n" + "vmlal.s16 q12, d20, %f[ker0][2] \n" + "vmlal.s16 q14, d20, %e[ker0][2] \n" + "vext.s16 q10, q8, q9, #4 \n" + "vmlal.s16 q12, d20, %f[ker0][3] \n" + "vmlal.s16 q14, d20, %e[ker0][3] \n" + + "vmovl.s8 q8, d14 \n" + "vmlal.s16 q12, d16, %f[ker2][2] \n" + "vmlal.s16 q14, d16, %f[ker2][1] \n" + "vext.s16 q10, q8, q9, #1 \n" + "vmlal.s16 q12, d20, %e[ker1][0] \n" + "vmlal.s16 q14, d20, %f[ker0][0] \n" + "vext.s16 q10, q8, q9, #2 \n" + "vmlal.s16 q12, d20, %e[ker1][1] \n" + "vmlal.s16 q14, d20, %f[ker0][1] \n" + "vext.s16 q10, q8, q9, #3 \n" + "vmlal.s16 q12, d20, %e[ker1][2] \n" + "vmlal.s16 q14, d20, %f[ker0][2] \n" + "vext.s16 q10, q8, q9, #4 \n" + "vmlal.s16 q12, d20, %e[ker1][3] \n" + "vmlal.s16 q14, d20, %f[ker0][3] \n" + + "vld1.s8 {d10}, [%[input_ptr3]], r0 \n" + "vld1.s8 {d12}, [%[input_ptr4]], r0 \n" + "vld1.s8 {d14}, [%[input_ptr5]], r0 \n" + "vmovl.s8 q8, d10 \n" + "vmlal.s16 q12, d16, %f[ker2][3] \n" + "vmlal.s16 q14, d16, %f[ker2][2] \n" + "vext.s16 q10, q8, q9, #1 \n" + "vmlal.s16 q12, d20, %f[ker1][0] \n" + "vmlal.s16 q14, d20, %e[ker1][0] \n" + "vext.s16 q10, q8, q9, #2 \n" + "vmlal.s16 q12, d20, %f[ker1][1] \n" + "vmlal.s16 q14, d20, %e[ker1][1] \n" + "vext.s16 q10, q8, q9, #3 \n" + "vmlal.s16 q12, d20, %f[ker1][2] \n" + "vmlal.s16 q14, d20, %e[ker1][2] \n" + "vext.s16 q10, q8, q9, #4 \n" + "vmlal.s16 q12, d20, %f[ker1][3] \n" + "vmlal.s16 q14, d20, %e[ker1][3] \n" + + "vmovl.s8 q8, d12 \n" + "vmlal.s16 q12, d16, %e[ker3][0] \n" + "vmlal.s16 q14, d16, %f[ker2][3] \n" + "vext.s16 q10, q8, q9, #1 \n" + "vmlal.s16 q12, d20, %e[ker2][0] \n" + "vmlal.s16 q14, d20, %f[ker1][0] \n" + "vext.s16 q10, q8, q9, #2 \n" + "vmlal.s16 q12, d20, %e[ker2][1] \n" + "vmlal.s16 q14, d20, %f[ker1][1] \n" + "vext.s16 q10, q8, q9, #3 \n" + "vmlal.s16 q12, d20, %e[ker2][2] \n" + "vmlal.s16 q14, d20, %f[ker1][2] \n" + "vext.s16 q10, q8, q9, #4 \n" + "vmlal.s16 q12, d20, %e[ker2][3] \n" + "vmlal.s16 q14, d20, %f[ker1][3] \n" + + "vmovl.s8 q8, d14 \n" + "vmlal.s16 q14, d16, %e[ker3][0] \n" + "vext.s16 q10, q8, q9, #1 \n" + "vmlal.s16 q14, d20, %e[ker2][0] \n" + "vext.s16 q10, q8, q9, #2 \n" + "vmlal.s16 q14, d20, %e[ker2][1] \n" + "vext.s16 q10, q8, q9, #3 \n" + "vmlal.s16 q14, d20, %e[ker2][2] \n" + "vext.s16 q10, q8, q9, #4 \n" + "vmlal.s16 q14, d20, %e[ker2][3] \n" + + // restore output + "vst1.32 {d24-d25}, [%[output_ptr0]]! \n" + "vst1.32 {d28-d29}, [%[output_ptr1]]! 
\n" + "sub %[remain], #4 \n" + + "start_remain_%=: \n" + "cmp %[remain], #0 \n" + "ble end_%= \n" + "mov r0, %[remain] \n" + "vld1.s8 {d10}, [%[input_ptr0]], r0 \n" + "vld1.s8 {d12}, [%[input_ptr1]], r0 \n" + "vld1.s8 {d14}, [%[input_ptr2]], r0 \n" + "vmovl.s8 q8, d10 \n" + "vmull.s16 q12, d16, %f[ker2][0] \n" + "vext.s16 q10, q8, q9, #1 \n" + "vmlal.s16 q12, d20, %e[ker0][0] \n" + "vext.s16 q10, q8, q9, #2 \n" + "vmlal.s16 q12, d20, %e[ker0][1] \n" + "vext.s16 q10, q8, q9, #3 \n" + "vmlal.s16 q12, d20, %e[ker0][2] \n" + "vext.s16 q10, q8, q9, #4 \n" + "vmlal.s16 q12, d20, %e[ker0][3] \n" + + "vmovl.s8 q8, d12 \n" + "vmlal.s16 q12, d16, %f[ker2][1] \n" + "vmull.s16 q14, d16, %f[ker2][0] \n" + "vext.s16 q10, q8, q9, #1 \n" + "vmlal.s16 q12, d20, %f[ker0][0] \n" + "vmlal.s16 q14, d20, %e[ker0][0] \n" + "vext.s16 q10, q8, q9, #2 \n" + "vmlal.s16 q12, d20, %f[ker0][1] \n" + "vmlal.s16 q14, d20, %e[ker0][1] \n" + "vext.s16 q10, q8, q9, #3 \n" + "vmlal.s16 q12, d20, %f[ker0][2] \n" + "vmlal.s16 q14, d20, %e[ker0][2] \n" + "vext.s16 q10, q8, q9, #4 \n" + "vmlal.s16 q12, d20, %f[ker0][3] \n" + "vmlal.s16 q14, d20, %e[ker0][3] \n" + + "vmovl.s8 q8, d14 \n" + "vmlal.s16 q12, d16, %f[ker2][2] \n" + "vmlal.s16 q14, d16, %f[ker2][1] \n" + "vext.s16 q10, q8, q9, #1 \n" + "vmlal.s16 q12, d20, %e[ker1][0] \n" + "vmlal.s16 q14, d20, %f[ker0][0] \n" + "vext.s16 q10, q8, q9, #2 \n" + "vmlal.s16 q12, d20, %e[ker1][1] \n" + "vmlal.s16 q14, d20, %f[ker0][1] \n" + "vext.s16 q10, q8, q9, #3 \n" + "vmlal.s16 q12, d20, %e[ker1][2] \n" + "vmlal.s16 q14, d20, %f[ker0][2] \n" + "vext.s16 q10, q8, q9, #4 \n" + "vmlal.s16 q12, d20, %e[ker1][3] \n" + "vmlal.s16 q14, d20, %f[ker0][3] \n" + + "vld1.s8 {d10}, [%[input_ptr3]], r0 \n" + "vld1.s8 {d12}, [%[input_ptr4]], r0 \n" + "vld1.s8 {d14}, [%[input_ptr5]], r0 \n" + "vmovl.s8 q8, d10 \n" + "vmlal.s16 q12, d16, %f[ker2][3] \n" + "vmlal.s16 q14, d16, %f[ker2][2] \n" + "vext.s16 q10, q8, q9, #1 \n" + "vmlal.s16 q12, d20, %f[ker1][0] \n" + "vmlal.s16 q14, d20, %e[ker1][0] \n" + "vext.s16 q10, q8, q9, #2 \n" + "vmlal.s16 q12, d20, %f[ker1][1] \n" + "vmlal.s16 q14, d20, %e[ker1][1] \n" + "vext.s16 q10, q8, q9, #3 \n" + "vmlal.s16 q12, d20, %f[ker1][2] \n" + "vmlal.s16 q14, d20, %e[ker1][2] \n" + "vext.s16 q10, q8, q9, #4 \n" + "vmlal.s16 q12, d20, %f[ker1][3] \n" + "vmlal.s16 q14, d20, %e[ker1][3] \n" + + "vmovl.s8 q8, d12 \n" + "vmlal.s16 q12, d16, %e[ker3][0] \n" + "vmlal.s16 q14, d16, %f[ker2][3] \n" + "vext.s16 q10, q8, q9, #1 \n" + "vmlal.s16 q12, d20, %e[ker2][0] \n" + "vmlal.s16 q14, d20, %f[ker1][0] \n" + "vext.s16 q10, q8, q9, #2 \n" + "vmlal.s16 q12, d20, %e[ker2][1] \n" + "vmlal.s16 q14, d20, %f[ker1][1] \n" + "vext.s16 q10, q8, q9, #3 \n" + "vmlal.s16 q12, d20, %e[ker2][2] \n" + "vmlal.s16 q14, d20, %f[ker1][2] \n" + "vext.s16 q10, q8, q9, #4 \n" + "vmlal.s16 q12, d20, %e[ker2][3] \n" + "vmlal.s16 q14, d20, %f[ker1][3] \n" + + "vmovl.s8 q8, d14 \n" + "vmlal.s16 q14, d16, %e[ker3][0] \n" + "vext.s16 q10, q8, q9, #1 \n" + "vmlal.s16 q14, d20, %e[ker2][0] \n" + "vext.s16 q10, q8, q9, #2 \n" + "vmlal.s16 q14, d20, %e[ker2][1] \n" + "vext.s16 q10, q8, q9, #3 \n" + "vmlal.s16 q14, d20, %e[ker2][2] \n" + "vext.s16 q10, q8, q9, #4 \n" + "vmlal.s16 q14, d20, %e[ker2][3] \n" + + "cmp %[remain], #2 \n" + "blt store_2h1w_%= \n" + "vst1.32 {d24}, [%[output_ptr0]]! \n" + "vst1.32 {d28}, [%[output_ptr1]]! \n" + "cmp %[remain], #3 \n" + "blt end_%= \n" + "vst1.32 {d25[0]}, [%[output_ptr0]]! \n" + "vst1.32 {d29[0]}, [%[output_ptr1]]! 
\n" + "b end_%= \n" + + "store_2h1w_%=: \n" + "vst1.32 {d24[0]}, [%[output_ptr0]]! \n" + "vst1.32 {d28[0]}, [%[output_ptr1]]! \n" + "end_%=: \n" + : [input_ptr0] "+r"(input_ptr0), [input_ptr1] "+r"(input_ptr1), + [input_ptr2] "+r"(input_ptr2), [input_ptr3] "+r"(input_ptr3), + [input_ptr4] "+r"(input_ptr4), [input_ptr5] "+r"(input_ptr5), + [output_ptr0] "+r"(output_ptr0), [output_ptr1] "+r"(output_ptr1), + [loop] "+r"(loop), [remain] "+r"(w_remain) + : [ker0] "w"(_ker0), [ker1] "w"(_ker1), [ker2] "w"(_ker2), + [ker3] "w"(_ker3) + : "cc", "memory", "q4", "q5", "q6", "q7", "q8", "q9", "q10", "q11", + "q12", "q13", "q14", "q15", "r0"); + // pad right + if (padding_w) { + int16x4_t row0 = vget_low_s16(vmovl_s8(vld1_s8(input_ptr0))); + int16x4_t row1 = vget_low_s16(vmovl_s8(vld1_s8(input_ptr1))); + int16x4_t row2 = vget_low_s16(vmovl_s8(vld1_s8(input_ptr2))); + int16x4_t row3 = vget_low_s16(vmovl_s8(vld1_s8(input_ptr3))); + int16x4_t row4 = vget_low_s16(vmovl_s8(vld1_s8(input_ptr4))); + int16x4_t row5 = vget_low_s16(vmovl_s8(vld1_s8(input_ptr5))); + int16x4_t zero = vdup_n_s16(0); + int32x4_t acc0, acc1; + for (int w = valid_w_end; w < output_w; ++w) { + int padding = w + 5 - (padding_w + input_w); + if (padding >= 5) { + *output_ptr0 = 0; + *output_ptr1 = 0; + } else { + int iw = w - valid_w_end; + int32_t sum0 = input_ptr0[iw] * filter_ptr0[0] + + input_ptr1[iw] * filter_ptr1[0] + + input_ptr2[iw] * filter_ptr2[0] + + input_ptr3[iw] * filter_ptr3[0] + + input_ptr4[iw] * filter_ptr4[0]; + int32_t sum1 = input_ptr1[iw] * filter_ptr0[0] + + input_ptr2[iw] * filter_ptr1[0] + + input_ptr3[iw] * filter_ptr2[0] + + input_ptr4[iw] * filter_ptr3[0] + + input_ptr5[iw] * filter_ptr4[0]; + row0 = vext_s16(row0, zero, 1); + row1 = vext_s16(row1, zero, 1); + row2 = vext_s16(row2, zero, 1); + row3 = vext_s16(row3, zero, 1); + row4 = vext_s16(row4, zero, 1); + row5 = vext_s16(row5, zero, 1); + acc0 = vmull_s16(row0, _ker[0]); + acc0 = vmlal_s16(acc0, row1, _ker[1]); + acc0 = vmlal_s16(acc0, row2, _ker[2]); + acc0 = vmlal_s16(acc0, row3, _ker[3]); + acc0 = vmlal_s16(acc0, row4, _ker[4]); + acc1 = vmull_s16(row1, _ker[0]); + acc1 = vmlal_s16(acc1, row2, _ker[1]); + acc1 = vmlal_s16(acc1, row3, _ker[2]); + acc1 = vmlal_s16(acc1, row4, _ker[3]); + acc1 = vmlal_s16(acc1, row5, _ker[4]); + acc0 = vpaddq_s32(acc0, acc1); + int32x2_t sum = vpadd_s32(vget_low_s32(acc0), vget_high_s32(acc0)); + sum0 += vget_lane_s32(sum, 0); + sum1 += vget_lane_s32(sum, 1); + *output_ptr0 = sum0; + *output_ptr1 = sum1; + } + output_ptr0++; + output_ptr1++; + } + } + } + // remain height + int start_h = valid_h_start + (valid_h & 0xfffe); + if (start_h < valid_h_end) { + const int8_t *input_ptr0 = input_ptr + (start_h - padding_h) * input_w; + const int8_t *input_ptr1 = input_ptr0 + input_w; + const int8_t *input_ptr2 = input_ptr1 + input_w; + const int8_t *input_ptr3 = input_ptr2 + input_w; + const int8_t *input_ptr4 = input_ptr3 + input_w; + int32_t *output_ptr0 = output_ptr + start_h * output_w; + // pad left + if (padding_w) { + int16x4_t row0 = vget_low_s16(vmovl_s8(vld1_s8(input_ptr0))); + int16x4_t row1 = vget_low_s16(vmovl_s8(vld1_s8(input_ptr1))); + int16x4_t row2 = vget_low_s16(vmovl_s8(vld1_s8(input_ptr2))); + int16x4_t row3 = vget_low_s16(vmovl_s8(vld1_s8(input_ptr3))); + int16x4_t row4 = vget_low_s16(vmovl_s8(vld1_s8(input_ptr4))); + int16x4_t zero = vdup_n_s16(0); + int32x4_t acc; + for (int w = valid_w_start - 1; w >= 0; --w) { + int padding = padding_w - w; + if (padding >= 5) { + output_ptr0[w] = 0; + } else { + acc = 
vmull_s16(row0, _ker[0]); + acc = vmlal_s16(acc, row1, _ker[1]); + acc = vmlal_s16(acc, row2, _ker[2]); + acc = vmlal_s16(acc, row3, _ker[3]); + acc = vmlal_s16(acc, row4, _ker[4]); + int32x2_t sum = vpadd_s32(vget_low_s32(acc), vget_high_s32(acc)); + sum = vpadd_s32(sum, sum); + vst1_lane_s32(output_ptr0 + w, sum, 0); + + row0 = vext_s16(zero, row0, 3); + row1 = vext_s16(zero, row1, 3); + row2 = vext_s16(zero, row2, 3); + row3 = vext_s16(zero, row3, 3); + row4 = vext_s16(zero, row4, 3); + } + } + output_ptr0 += valid_w_start; + } + // valid + int loop = output_w_tiles; + int w_remain = output_w_remain; + asm volatile( + "cmp %[loop], #0 \n" + "ble start_remain4_%= \n" + "mov r0, #8 \n" + "loop_1h8w_%=: \n" + "vld1.s8 {d10-d11}, [%[input_ptr0]], r0 \n" + "vld1.s8 {d12-d13}, [%[input_ptr1]], r0 \n" + "vld1.s8 {d14-d15}, [%[input_ptr2]], r0 \n" + "vmovl.s8 q8, d10 \n" + "vmovl.s8 q9, d11 \n" + "vmull.s16 q12, d16, %f[ker2][0] \n" + "vmull.s16 q13, d17, %f[ker2][0] \n" + "vext.s16 q10, q8, q9, #1 \n" + "vmlal.s16 q12, d20, %e[ker0][0] \n" + "vmlal.s16 q13, d21, %e[ker0][0] \n" + "vext.s16 q10, q8, q9, #2 \n" + "vmlal.s16 q12, d20, %e[ker0][1] \n" + "vmlal.s16 q13, d21, %e[ker0][1] \n" + "vext.s16 q10, q8, q9, #3 \n" + "vmlal.s16 q12, d20, %e[ker0][2] \n" + "vmlal.s16 q13, d21, %e[ker0][2] \n" + "vext.s16 q10, q8, q9, #4 \n" + "vmlal.s16 q12, d20, %e[ker0][3] \n" + "vmlal.s16 q13, d21, %e[ker0][3] \n" + + "vmovl.s8 q8, d12 \n" + "vmovl.s8 q9, d13 \n" + "vmlal.s16 q12, d16, %f[ker2][1] \n" + "vmlal.s16 q13, d17, %f[ker2][1] \n" + "vext.s16 q10, q8, q9, #1 \n" + "vmlal.s16 q12, d20, %f[ker0][0] \n" + "vmlal.s16 q13, d21, %f[ker0][0] \n" + "vext.s16 q10, q8, q9, #2 \n" + "vmlal.s16 q12, d20, %f[ker0][1] \n" + "vmlal.s16 q13, d21, %f[ker0][1] \n" + "vext.s16 q10, q8, q9, #3 \n" + "vmlal.s16 q12, d20, %f[ker0][2] \n" + "vmlal.s16 q13, d21, %f[ker0][2] \n" + "vext.s16 q10, q8, q9, #4 \n" + "vmlal.s16 q12, d20, %f[ker0][3] \n" + "vmlal.s16 q13, d21, %f[ker0][3] \n" + + "vmovl.s8 q8, d14 \n" + "vmovl.s8 q9, d15 \n" + "vmlal.s16 q12, d16, %f[ker2][2] \n" + "vmlal.s16 q13, d17, %f[ker2][2] \n" + "vext.s16 q10, q8, q9, #1 \n" + "vmlal.s16 q12, d20, %e[ker1][0] \n" + "vmlal.s16 q13, d21, %e[ker1][0] \n" + "vext.s16 q10, q8, q9, #2 \n" + "vmlal.s16 q12, d20, %e[ker1][1] \n" + "vmlal.s16 q13, d21, %e[ker1][1] \n" + "vext.s16 q10, q8, q9, #3 \n" + "vmlal.s16 q12, d20, %e[ker1][2] \n" + "vmlal.s16 q13, d21, %e[ker1][2] \n" + "vext.s16 q10, q8, q9, #4 \n" + "vmlal.s16 q12, d20, %e[ker1][3] \n" + "vmlal.s16 q13, d21, %e[ker1][3] \n" + + "vld1.s8 {d10-d11}, [%[input_ptr3]], r0 \n" + "vld1.s8 {d12-d13}, [%[input_ptr4]], r0 \n" + "vmovl.s8 q8, d10 \n" + "vmovl.s8 q9, d11 \n" + "vmlal.s16 q12, d16, %f[ker2][3] \n" + "vmlal.s16 q13, d17, %f[ker2][3] \n" + "vext.s16 q10, q8, q9, #1 \n" + "vmlal.s16 q12, d20, %f[ker1][0] \n" + "vmlal.s16 q13, d21, %f[ker1][0] \n" + "vext.s16 q10, q8, q9, #2 \n" + "vmlal.s16 q12, d20, %f[ker1][1] \n" + "vmlal.s16 q13, d21, %f[ker1][1] \n" + "vext.s16 q10, q8, q9, #3 \n" + "vmlal.s16 q12, d20, %f[ker1][2] \n" + "vmlal.s16 q13, d21, %f[ker1][2] \n" + "vext.s16 q10, q8, q9, #4 \n" + "vmlal.s16 q12, d20, %f[ker1][3] \n" + "vmlal.s16 q13, d21, %f[ker1][3] \n" + + "vmovl.s8 q8, d12 \n" + "vmovl.s8 q9, d13 \n" + "vmlal.s16 q12, d16, %e[ker3][0] \n" + "vmlal.s16 q13, d17, %e[ker3][0] \n" + "vext.s16 q10, q8, q9, #1 \n" + "vmlal.s16 q12, d20, %e[ker2][0] \n" + "vmlal.s16 q13, d21, %e[ker2][0] \n" + "vext.s16 q10, q8, q9, #2 \n" + "vmlal.s16 q12, d20, %e[ker2][1] \n" + "vmlal.s16 q13, d21, 
%e[ker2][1] \n" + "vext.s16 q10, q8, q9, #3 \n" + "vmlal.s16 q12, d20, %e[ker2][2] \n" + "vmlal.s16 q13, d21, %e[ker2][2] \n" + "vext.s16 q10, q8, q9, #4 \n" + "vmlal.s16 q12, d20, %e[ker2][3] \n" + "vmlal.s16 q13, d21, %e[ker2][3] \n" + + // restore output + "vst1.32 {q12-q13}, [%[output_ptr0]]! \n" + "subs %[loop], #1 \n" + "bne loop_1h8w_%= \n" + + "start_remain4_%=: \n" + "cmp %[remain], #4 \n" + "blt start_remain_%= \n" + "mov r0, #4 \n" + "vld1.s8 {d10}, [%[input_ptr0]], r0 \n" + "vld1.s8 {d12}, [%[input_ptr1]], r0 \n" + "vld1.s8 {d14}, [%[input_ptr2]], r0 \n" + "vmovl.s8 q8, d10 \n" + "vmull.s16 q12, d16, %f[ker2][0] \n" + "vext.s16 q10, q8, q9, #1 \n" + "vmlal.s16 q12, d20, %e[ker0][0] \n" + "vext.s16 q10, q8, q9, #2 \n" + "vmlal.s16 q12, d20, %e[ker0][1] \n" + "vext.s16 q10, q8, q9, #3 \n" + "vmlal.s16 q12, d20, %e[ker0][2] \n" + "vext.s16 q10, q8, q9, #4 \n" + "vmlal.s16 q12, d20, %e[ker0][3] \n" + + "vmovl.s8 q8, d12 \n" + "vmlal.s16 q12, d16, %f[ker2][1] \n" + "vext.s16 q10, q8, q9, #1 \n" + "vmlal.s16 q12, d20, %f[ker0][0] \n" + "vext.s16 q10, q8, q9, #2 \n" + "vmlal.s16 q12, d20, %f[ker0][1] \n" + "vext.s16 q10, q8, q9, #3 \n" + "vmlal.s16 q12, d20, %f[ker0][2] \n" + "vext.s16 q10, q8, q9, #4 \n" + "vmlal.s16 q12, d20, %f[ker0][3] \n" + + "vmovl.s8 q8, d14 \n" + "vmlal.s16 q12, d16, %f[ker2][2] \n" + "vext.s16 q10, q8, q9, #1 \n" + "vmlal.s16 q12, d20, %e[ker1][0] \n" + "vext.s16 q10, q8, q9, #2 \n" + "vmlal.s16 q12, d20, %e[ker1][1] \n" + "vext.s16 q10, q8, q9, #3 \n" + "vmlal.s16 q12, d20, %e[ker1][2] \n" + "vext.s16 q10, q8, q9, #4 \n" + "vmlal.s16 q12, d20, %e[ker1][3] \n" + + "vld1.s8 {d10}, [%[input_ptr3]], r0 \n" + "vld1.s8 {d12}, [%[input_ptr4]], r0 \n" + "vmovl.s8 q8, d10 \n" + "vmlal.s16 q12, d16, %f[ker2][3] \n" + "vext.s16 q10, q8, q9, #1 \n" + "vmlal.s16 q12, d20, %f[ker1][0] \n" + "vext.s16 q10, q8, q9, #2 \n" + "vmlal.s16 q12, d20, %f[ker1][1] \n" + "vext.s16 q10, q8, q9, #3 \n" + "vmlal.s16 q12, d20, %f[ker1][2] \n" + "vext.s16 q10, q8, q9, #4 \n" + "vmlal.s16 q12, d20, %f[ker1][3] \n" + + "vmovl.s8 q8, d12 \n" + "vmlal.s16 q12, d16, %e[ker3][0] \n" + "vext.s16 q10, q8, q9, #1 \n" + "vmlal.s16 q12, d20, %e[ker2][0] \n" + "vext.s16 q10, q8, q9, #2 \n" + "vmlal.s16 q12, d20, %e[ker2][1] \n" + "vext.s16 q10, q8, q9, #3 \n" + "vmlal.s16 q12, d20, %e[ker2][2] \n" + "vext.s16 q10, q8, q9, #4 \n" + "vmlal.s16 q12, d20, %e[ker2][3] \n" + + // restore output + "vst1.32 {d24-d25}, [%[output_ptr0]]! 
\n" + "sub %[remain], #4 \n" + + "start_remain_%=: \n" + "cmp %[remain], #0 \n" + "ble end_%= \n" + "mov r0, %[remain] \n" + "vld1.s8 {d10}, [%[input_ptr0]], r0 \n" + "vld1.s8 {d12}, [%[input_ptr1]], r0 \n" + "vld1.s8 {d14}, [%[input_ptr2]], r0 \n" + "vmovl.s8 q8, d10 \n" + "vmull.s16 q12, d16, %f[ker2][0] \n" + "vext.s16 q10, q8, q9, #1 \n" + "vmlal.s16 q12, d20, %e[ker0][0] \n" + "vext.s16 q10, q8, q9, #2 \n" + "vmlal.s16 q12, d20, %e[ker0][1] \n" + "vext.s16 q10, q8, q9, #3 \n" + "vmlal.s16 q12, d20, %e[ker0][2] \n" + "vext.s16 q10, q8, q9, #4 \n" + "vmlal.s16 q12, d20, %e[ker0][3] \n" + + "vmovl.s8 q8, d12 \n" + "vmlal.s16 q12, d16, %f[ker2][1] \n" + "vext.s16 q10, q8, q9, #1 \n" + "vmlal.s16 q12, d20, %f[ker0][0] \n" + "vext.s16 q10, q8, q9, #2 \n" + "vmlal.s16 q12, d20, %f[ker0][1] \n" + "vext.s16 q10, q8, q9, #3 \n" + "vmlal.s16 q12, d20, %f[ker0][2] \n" + "vext.s16 q10, q8, q9, #4 \n" + "vmlal.s16 q12, d20, %f[ker0][3] \n" + + "vmovl.s8 q8, d14 \n" + "vmlal.s16 q12, d16, %f[ker2][2] \n" + "vext.s16 q10, q8, q9, #1 \n" + "vmlal.s16 q12, d20, %e[ker1][0] \n" + "vext.s16 q10, q8, q9, #2 \n" + "vmlal.s16 q12, d20, %e[ker1][1] \n" + "vext.s16 q10, q8, q9, #3 \n" + "vmlal.s16 q12, d20, %e[ker1][2] \n" + "vext.s16 q10, q8, q9, #4 \n" + "vmlal.s16 q12, d20, %e[ker1][3] \n" + + "vld1.s8 {d10}, [%[input_ptr3]], r0 \n" + "vld1.s8 {d12}, [%[input_ptr4]], r0 \n" + "vmovl.s8 q8, d10 \n" + "vmlal.s16 q12, d16, %f[ker2][3] \n" + "vext.s16 q10, q8, q9, #1 \n" + "vmlal.s16 q12, d20, %f[ker1][0] \n" + "vext.s16 q10, q8, q9, #2 \n" + "vmlal.s16 q12, d20, %f[ker1][1] \n" + "vext.s16 q10, q8, q9, #3 \n" + "vmlal.s16 q12, d20, %f[ker1][2] \n" + "vext.s16 q10, q8, q9, #4 \n" + "vmlal.s16 q12, d20, %f[ker1][3] \n" + + "vmovl.s8 q8, d12 \n" + "vmlal.s16 q12, d16, %e[ker3][0] \n" + "vext.s16 q10, q8, q9, #1 \n" + "vmlal.s16 q12, d20, %e[ker2][0] \n" + "vext.s16 q10, q8, q9, #2 \n" + "vmlal.s16 q12, d20, %e[ker2][1] \n" + "vext.s16 q10, q8, q9, #3 \n" + "vmlal.s16 q12, d20, %e[ker2][2] \n" + "vext.s16 q10, q8, q9, #4 \n" + "vmlal.s16 q12, d20, %e[ker2][3] \n" + + "cmp %[remain], #2 \n" + "blt store_1h1w_%= \n" + "vst1.32 {d24}, [%[output_ptr0]]! \n" + "cmp %[remain], #3 \n" + "blt end_%= \n" + "vst1.32 {d25[0]}, [%[output_ptr0]]! \n" + "b end_%= \n" + + "store_1h1w_%=: \n" + "vst1.32 {d24[0]}, [%[output_ptr0]]! 
\n" + "end_%=: \n" + : [input_ptr0] "+r"(input_ptr0), [input_ptr1] "+r"(input_ptr1), + [input_ptr2] "+r"(input_ptr2), [input_ptr3] "+r"(input_ptr3), + [input_ptr4] "+r"(input_ptr4), [output_ptr0] "+r"(output_ptr0), + [loop] "+r"(loop), [remain] "+r"(w_remain) + : [ker0] "w"(_ker0), [ker1] "w"(_ker1), [ker2] "w"(_ker2), + [ker3] "w"(_ker3) + : "cc", "memory", "q4", "q5", "q6", "q7", "q8", "q9", "q10", "q11", + "q12", "q13", "q14", "q15", "r0"); + // pad right + if (padding_w) { + int16x4_t row0 = vget_low_s16(vmovl_s8(vld1_s8(input_ptr0))); + int16x4_t row1 = vget_low_s16(vmovl_s8(vld1_s8(input_ptr1))); + int16x4_t row2 = vget_low_s16(vmovl_s8(vld1_s8(input_ptr2))); + int16x4_t row3 = vget_low_s16(vmovl_s8(vld1_s8(input_ptr3))); + int16x4_t row4 = vget_low_s16(vmovl_s8(vld1_s8(input_ptr4))); + int16x4_t zero = vdup_n_s16(0); + int32x4_t acc; + for (int w = valid_w_end; w < output_w; ++w) { + int padding = w + 5 - (padding_w + input_w); + if (padding >= 5) { + *output_ptr0 = 0; + } else { + int iw = w - valid_w_end; + int32_t sum0 = input_ptr0[iw] * filter_ptr0[0] + + input_ptr1[iw] * filter_ptr1[0] + + input_ptr2[iw] * filter_ptr2[0] + + input_ptr3[iw] * filter_ptr3[0] + + input_ptr4[iw] * filter_ptr4[0]; + row0 = vext_s16(row0, zero, 1); + row1 = vext_s16(row1, zero, 1); + row2 = vext_s16(row2, zero, 1); + row3 = vext_s16(row3, zero, 1); + row4 = vext_s16(row4, zero, 1); + acc = vmull_s16(row0, _ker[0]); + acc = vmlal_s16(acc, row1, _ker[1]); + acc = vmlal_s16(acc, row2, _ker[2]); + acc = vmlal_s16(acc, row3, _ker[3]); + acc = vmlal_s16(acc, row4, _ker[4]); + int32x2_t sum = vpadd_s32(vget_low_s32(acc), vget_high_s32(acc)); + sum = vpadd_s32(sum, sum); + sum0 += vget_lane_s32(sum, 0); + *output_ptr0 = sum0; + } + output_ptr0++; + } + } + } + // pad bottom + for (int h = valid_h_end; h < output_h; ++h) { + DepthwiseConv5x5NormalRow<1, 1>(input_ptr, filter_ptr, h, input_h, + input_w, padding_h, padding_w, output_w, + output_ptr, _ker, kernel); + } + } +} + +template <> +void DepthwiseConv5x5S2(const framework::Tensor &input, + const framework::Tensor &filter, + const std::vector &paddings, + framework::Tensor *output) {} + +} // namespace math +} // namespace operators +} // namespace paddle_mobile + +#endif // __ARM_NEON__ diff --git a/src/operators/math/gemm.cpp b/src/operators/math/gemm.cpp index 3c901330dc7ed0fc82e4b25c15292c54722bec2b..92eb8a8a2e0f2573d9dca86ce0bd0369b846e91e 100644 --- a/src/operators/math/gemm.cpp +++ b/src/operators/math/gemm.cpp @@ -3150,9 +3150,11 @@ void Gemm::SgemmWithPRelu(int m, int n, int k, const float *A, int lda, void Gemm::Sgemm_omp(int m, int n, int k, float alpha, const float *A, int lda, const float *B, int ldb, float beta, float *C, int ldc, bool relu, float *bias) { +#ifndef __aarch64__ if (m == 1 && bias == nullptr) { return VectorKernel(m, n, k, alpha, A, lda, B, ldb, beta, C, ldc, relu); } +#endif // __aarch64__ #ifdef _OPENMP int max_threads = omp_get_max_threads(); #else diff --git a/src/operators/math/pooling.h b/src/operators/math/pooling.h index 4239cf8cbc87e786e2e07ac77614f4d2f96d73dd..0f0b4e2630294aca069883932ad8115b50eb2ed4 100644 --- a/src/operators/math/pooling.h +++ b/src/operators/math/pooling.h @@ -53,7 +53,7 @@ struct PoolingVal { ++count; return *this; } - inline float Value() { return (count > 0) ? val / count : 0.f; } + inline float Value() { return (count > 0) ? 
val * (1.f / count) : 0.f; } }; #if defined(__ARM_NEON) || defined(__ARM_NEON__) @@ -67,6 +67,16 @@ inline float32x4_t vPoolInitq_f32() { return vdupq_n_f32(0.f); } +template +inline float32x2_t vPoolInit_f32() { + return vdup_n_f32(-std::numeric_limits::max()); +} + +template <> +inline float32x2_t vPoolInit_f32() { + return vdup_n_f32(0.f); +} + template inline float32x4_t vPoolPreq_f32(const float32x4_t &x1, const float32x4_t &x2) { return vmaxq_f32(x1, x2); @@ -78,6 +88,28 @@ inline float32x4_t vPoolPreq_f32(const float32x4_t &x1, return vaddq_f32(x1, x2); } +template +inline float32x2_t vPoolPre_f32(const float32x2_t &x1, const float32x2_t &x2) { + return vmax_f32(x1, x2); +} + +template <> +inline float32x2_t vPoolPre_f32(const float32x2_t &x1, + const float32x2_t &x2) { + return vadd_f32(x1, x2); +} + +template +inline float32x2_t vpPoolPre_f32(const float32x2_t &x1, const float32x2_t &x2) { + return vpmax_f32(x1, x2); +} + +template <> +inline float32x2_t vpPoolPre_f32(const float32x2_t &x1, + const float32x2_t &x2) { + return vpadd_f32(x1, x2); +} + template inline float32x4_t vPoolPostq_f32(const float32x4_t &x, const float32x4_t &post) { @@ -89,6 +121,18 @@ inline float32x4_t vPoolPostq_f32(const float32x4_t &x, const float32x4_t &post) { return vmulq_f32(x, post); } + +template +inline float32x2_t vPoolPost_f32(const float32x2_t &x, + const float32x2_t &post) { + return x; +} + +template <> +inline float32x2_t vPoolPost_f32(const float32x2_t &x, + const float32x2_t &post) { + return vmul_f32(x, post); +} #endif // __ARM_NEON__ template diff --git a/src/operators/math/pooling2x2.cpp b/src/operators/math/pooling2x2.cpp new file mode 100644 index 0000000000000000000000000000000000000000..675a6392ed21dce4f9e324bc2dacd8609e2de999 --- /dev/null +++ b/src/operators/math/pooling2x2.cpp @@ -0,0 +1,791 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#ifdef POOL_OP + +#if defined(__ARM_NEON) || defined(__ARM_NEON__) + +#include +#include "operators/math/pooling.h" + +// TODO(hjchen2): Optimize Pooling2x2NormalRow and use inline assembly + +namespace paddle_mobile { +namespace operators { +namespace math { + +#define POOLING2X2_NORMAL_BORDER(start, end) \ + for (int w = start; w < end; ++w) { \ + const int w_in_start = -padding_w + w * Stride; \ + const int w_in_end = w_in_start + 2; \ + const int w_start = w_in_start > 0 ? w_in_start : 0; \ + const int w_end = w_in_end < input_w ? w_in_end : input_w; \ + PoolingVal
<P>
val; \ + for (int h_in = h_start; h_in < h_end; ++h_in) { \ + for (int w_in = w_start; w_in < w_end; ++w_in) { \ + val += input[h_in * input_w + w_in]; \ + } \ + } \ + output_ptr[w] = val.Value(); \ + } + +template +struct Pooling2x2NormalRowLoadInput { + void operator()(const float *input, float32x4_t *x0, float32x4_t *x1) { + x0[0] = vld1q_f32(input); + x0[1] = vld1q_f32(input + 4); + x1[0] = vextq_f32(x0[0], x0[1], 1); + x1[1] = vextq_f32(x0[1], x0[1], 1); + } +}; + +template +struct Pooling2x2NormalRowLoadInput { + void operator()(const float *input, float32x4_t *x0, float32x4_t *x1) { + float32x4x2_t t0 = vld2q_f32(input); + float32x4x2_t t1 = vld2q_f32(input + 8); + x0[0] = t0.val[0]; + x0[1] = t1.val[0]; + x1[0] = t0.val[1]; + x1[1] = t1.val[1]; + } +}; + +template +inline void Pooling2x2NormalRow(const float *input, const int h_output, + const int input_h, const int input_w, + const int padding_h, const int padding_w, + const int output_w, float *output) { + const int h_in_start = -padding_h + h_output * Stride; + const int h_in_end = h_in_start + 2; + const int h_start = h_in_start > 0 ? h_in_start : 0; + const int h_end = h_in_end < input_h ? h_in_end : input_h; + + float *output_ptr = output + h_output * output_w; + if (h_end - h_start <= 0) { + memset(output_ptr, 0, output_w * sizeof(float)); + return; + } + + const int valid_w_start = (padding_w + Stride - 1) / Stride; + const int valid_w_end = (input_w + padding_w - 2) / Stride + 1; + const int valid_w = valid_w_end - valid_w_start; + + // border left + POOLING2X2_NORMAL_BORDER(0, valid_w_start) + // valid w + Pooling2x2NormalRowLoadInput load_input; + int output_tiles = valid_w / 6; + int output_tiles_w = output_tiles * 6; + float32x4_t x0[2], x1[2], y0[2]; + float32x4_t post = vdupq_n_f32(1.f / (2 * (h_end - h_start))); + for (int w = 0; w < output_tiles_w; w += 6) { + int output_offset = valid_w_start + w; + int input_w_offset = output_offset * Stride - padding_w; + y0[0] = vPoolInitq_f32

<P>(); + y0[1] = vPoolInitq_f32<P>(); + for (int h_in = h_start; h_in < h_end; ++h_in) { + load_input(input + h_in * input_w + input_w_offset, x0, x1); + y0[0] = vPoolPreq_f32<P>(y0[0], x0[0]); + y0[0] = vPoolPreq_f32<P>(y0[0], x1[0]); + y0[1] = vPoolPreq_f32<P>(y0[1], x0[1]); + y0[1] = vPoolPreq_f32<P>(y0[1], x1[1]); + } + y0[0] = vPoolPostq_f32<P>(y0[0], post); + y0[1] = vPoolPostq_f32<P>
(y0[1], post); + vst1q_f32(output_ptr + output_offset, y0[0]); + vst1_f32(output_ptr + output_offset + 4, vget_low_f32(y0[1])); + } + // remain valid w + int remain = valid_w - output_tiles_w; + if (remain > 0) { + int remain_start = valid_w_start + output_tiles_w; + int input_w_offset = remain_start * Stride - padding_w; + float *output_ptr0 = output_ptr + remain_start; + y0[0] = vPoolInitq_f32

<P>(); + y0[1] = vPoolInitq_f32<P>(); + for (int h_in = h_start; h_in < h_end; ++h_in) { + load_input(input + h_in * input_w + input_w_offset, x0, x1); + y0[0] = vPoolPreq_f32<P>(y0[0], x0[0]); + y0[0] = vPoolPreq_f32<P>(y0[0], x1[0]); + y0[1] = vPoolPreq_f32<P>(y0[1], x0[1]); + y0[1] = vPoolPreq_f32<P>(y0[1], x1[1]); + } + y0[0] = vPoolPostq_f32<P>(y0[0], post); + y0[1] = vPoolPostq_f32<P>
(y0[1], post); + switch (remain) { + case 1: + vst1q_lane_f32(output_ptr0, y0[0], 0); + break; + case 2: + vst1_f32(output_ptr0, vget_low_f32(y0[0])); + break; + case 3: + vst1_f32(output_ptr0, vget_low_f32(y0[0])); + vst1q_lane_f32(output_ptr0 + 2, y0[0], 2); + break; + case 4: + vst1q_f32(output_ptr0, y0[0]); + break; + case 5: + vst1q_f32(output_ptr0, y0[0]); + vst1q_lane_f32(output_ptr0 + 4, y0[1], 0); + break; + } + } + // border right + POOLING2X2_NORMAL_BORDER(valid_w_end, output_w) +} + +template +struct Pooling2x2 { + inline void operator()(const framework::Tensor &input, + const std::vector &paddings, + framework::Tensor *output) { + const float *input_data = input.data(); + float *output_data = output->mutable_data(); + int input_h = input.dims()[2]; + int input_w = input.dims()[3]; + int output_h = output->dims()[2]; + int output_w = output->dims()[3]; + int padding_h = paddings[0]; + int padding_w = paddings[1]; + int image_size = input_h * input_w; + int out_image_size = output_h * output_w; + int valid_h_start = padding_h; + int valid_h_end = output_h - valid_h_start; + int valid_h = valid_h_end - valid_h_start; + int valid_w_start = padding_w; + int valid_w_end = output_w - valid_w_start; + int valid_w = valid_w_end - valid_w_start; + + #pragma omp parallel for collapse(2) + for (int batch = 0; batch < output->dims()[0]; ++batch) { + for (int c = 0; c < output->dims()[1]; ++c) { + int channel = batch * output->dims()[1] + c; + const float *input_ptr = input_data + channel * image_size; + float *output_ptr = output_data + channel * out_image_size; + // top + for (int h = 0; h < valid_h_start; ++h) { + Pooling2x2NormalRow(input_ptr, h, input_h, input_w, padding_h, + padding_w, output_w, output_ptr); + } + // valid + int output_w_tiles = valid_w / 6; + int output_w_remain = valid_w - output_w_tiles * 6; + for (int h = valid_h_start; h < valid_h_end - 3; h += 4) { + const float *input_ptr0 = input_ptr + (h - padding_h) * input_w; + const float *input_ptr1 = input_ptr0 + input_w; + const float *input_ptr2 = input_ptr1 + input_w; + const float *input_ptr3 = input_ptr2 + input_w; + const float *input_ptr4 = input_ptr3 + input_w; + float *output_ptr0 = output_ptr + h * output_w; + float *output_ptr1 = output_ptr0 + output_w; + float *output_ptr2 = output_ptr1 + output_w; + float *output_ptr3 = output_ptr2 + output_w; + // pad left + if (padding_w) { + for (int w = valid_w_start - 1; w >= 0; --w) { + int padding = padding_w - w; + if (padding >= 2) { + output_ptr0[w] = 0.f; + output_ptr1[w] = 0.f; + output_ptr2[w] = 0.f; + output_ptr3[w] = 0.f; + } else { + float acc0 = PoolPre

<P>(*input_ptr0, *input_ptr1); + float acc1 = PoolPre<P>(*input_ptr1, *input_ptr2); + float acc2 = PoolPre<P>(*input_ptr2, *input_ptr3); + float acc3 = PoolPre<P>(*input_ptr3, *input_ptr4); + output_ptr0[w] = PoolPost<P>(acc0, 0.5f); + output_ptr1[w] = PoolPost<P>(acc1, 0.5f); + output_ptr2[w] = PoolPost<P>(acc2, 0.5f); + output_ptr3[w] = PoolPost<P>
(acc3, 0.5f); + } + } + output_ptr0 += valid_w_start; + output_ptr1 += valid_w_start; + output_ptr2 += valid_w_start; + output_ptr3 += valid_w_start; + } + // valid + float32x4x2_t x0, x1, q0; + float32x4x2_t y0, y1; + float32x4_t post = vdupq_n_f32(0.25f); + for (int loop = 0; loop < output_w_tiles; ++loop) { + x0.val[0] = vld1q_f32(input_ptr0); + x0.val[1] = vld1q_f32(input_ptr0 + 4); + x1.val[0] = vld1q_f32(input_ptr1); + x1.val[1] = vld1q_f32(input_ptr1 + 4); + q0.val[0] = vextq_f32(x0.val[0], x0.val[1], 1); + q0.val[1] = vextq_f32(x0.val[1], x0.val[1], 1); + y0.val[0] = vPoolPreq_f32

<P>(x0.val[0], q0.val[0]); + y0.val[1] = vPoolPreq_f32<P>(x0.val[1], q0.val[1]); + + q0.val[0] = vextq_f32(x1.val[0], x1.val[1], 1); + q0.val[1] = vextq_f32(x1.val[1], x1.val[1], 1); + y1.val[0] = vPoolPreq_f32<P>(x1.val[0], q0.val[0]); + y1.val[1] = vPoolPreq_f32<P>(x1.val[1], q0.val[1]); + y0.val[0] = vPoolPreq_f32<P>(y0.val[0], y1.val[0]); + y0.val[1] = vPoolPreq_f32<P>(y0.val[1], y1.val[1]); + y0.val[0] = vPoolPostq_f32<P>(y0.val[0], post); + y0.val[1] = vPoolPostq_f32<P>
(y0.val[1], post); + vst1q_f32(output_ptr0, y0.val[0]); + vst1_f32(output_ptr0 + 4, vget_low_f32(y0.val[1])); + + x0.val[0] = vld1q_f32(input_ptr2); + x0.val[1] = vld1q_f32(input_ptr2 + 4); + x1.val[0] = vld1q_f32(input_ptr3); + x1.val[1] = vld1q_f32(input_ptr3 + 4); + q0.val[0] = vextq_f32(x0.val[0], x0.val[1], 1); + q0.val[1] = vextq_f32(x0.val[1], x0.val[1], 1); + y0.val[0] = vPoolPreq_f32

<P>(x0.val[0], q0.val[0]); + y0.val[1] = vPoolPreq_f32<P>(x0.val[1], q0.val[1]); + y1.val[0] = vPoolPreq_f32<P>(y1.val[0], y0.val[0]); + y1.val[1] = vPoolPreq_f32<P>(y1.val[1], y0.val[1]); + y1.val[0] = vPoolPostq_f32<P>(y1.val[0], post); + y1.val[1] = vPoolPostq_f32<P>
(y1.val[1], post); + vst1q_f32(output_ptr1, y1.val[0]); + vst1_f32(output_ptr1 + 4, vget_low_f32(y1.val[1])); + + q0.val[0] = vextq_f32(x1.val[0], x1.val[1], 1); + q0.val[1] = vextq_f32(x1.val[1], x1.val[1], 1); + y1.val[0] = vPoolPreq_f32

<P>(x1.val[0], q0.val[0]); + y1.val[1] = vPoolPreq_f32<P>(x1.val[1], q0.val[1]); + y0.val[0] = vPoolPreq_f32<P>(y0.val[0], y1.val[0]); + y0.val[1] = vPoolPreq_f32<P>(y0.val[1], y1.val[1]); + y0.val[0] = vPoolPostq_f32<P>(y0.val[0], post); + y0.val[1] = vPoolPostq_f32<P>
(y0.val[1], post); + vst1q_f32(output_ptr2, y0.val[0]); + vst1_f32(output_ptr2 + 4, vget_low_f32(y0.val[1])); + + x0.val[0] = vld1q_f32(input_ptr4); + x0.val[1] = vld1q_f32(input_ptr4 + 4); + q0.val[0] = vextq_f32(x0.val[0], x0.val[1], 1); + q0.val[1] = vextq_f32(x0.val[1], x0.val[1], 1); + y1.val[0] = vPoolPreq_f32

<P>(y1.val[0], x0.val[0]); + y1.val[0] = vPoolPreq_f32<P>(y1.val[0], q0.val[0]); + y1.val[1] = vPoolPreq_f32<P>(y1.val[1], x0.val[1]); + y1.val[1] = vPoolPreq_f32<P>(y1.val[1], q0.val[1]); + y1.val[0] = vPoolPostq_f32<P>(y1.val[0], post); + y1.val[1] = vPoolPostq_f32<P>
(y1.val[1], post); + vst1q_f32(output_ptr3, y1.val[0]); + vst1_f32(output_ptr3 + 4, vget_low_f32(y1.val[1])); + + input_ptr0 += 6; + input_ptr1 += 6; + input_ptr2 += 6; + input_ptr3 += 6; + input_ptr4 += 6; + output_ptr0 += 6; + output_ptr1 += 6; + output_ptr2 += 6; + output_ptr3 += 6; + } + // remain width + if (output_w_remain > 0) { + float32x4x2_t y2, y3; + x0.val[0] = vld1q_f32(input_ptr0); + x0.val[1] = vld1q_f32(input_ptr0 + 4); + x1.val[0] = vld1q_f32(input_ptr1); + x1.val[1] = vld1q_f32(input_ptr1 + 4); + q0.val[0] = vextq_f32(x0.val[0], x0.val[1], 1); + q0.val[1] = vextq_f32(x0.val[1], x0.val[1], 1); + y0.val[0] = vPoolPreq_f32

<P>(x0.val[0], q0.val[0]); + y0.val[1] = vPoolPreq_f32<P>(x0.val[1], q0.val[1]); + + q0.val[0] = vextq_f32(x1.val[0], x1.val[1], 1); + q0.val[1] = vextq_f32(x1.val[1], x1.val[1], 1); + y1.val[0] = vPoolPreq_f32<P>(x1.val[0], q0.val[0]); + y1.val[1] = vPoolPreq_f32<P>(x1.val[1], q0.val[1]); + y0.val[0] = vPoolPreq_f32<P>(y0.val[0], y1.val[0]); + y0.val[1] = vPoolPreq_f32<P>(y0.val[1], y1.val[1]); + y0.val[0] = vPoolPostq_f32<P>(y0.val[0], post); + y0.val[1] = vPoolPostq_f32<P>
(y0.val[1], post); + + x0.val[0] = vld1q_f32(input_ptr2); + x0.val[1] = vld1q_f32(input_ptr2 + 4); + x1.val[0] = vld1q_f32(input_ptr3); + x1.val[1] = vld1q_f32(input_ptr3 + 4); + q0.val[0] = vextq_f32(x0.val[0], x0.val[1], 1); + q0.val[1] = vextq_f32(x0.val[1], x0.val[1], 1); + y2.val[0] = vPoolPreq_f32

<P>(x0.val[0], q0.val[0]); + y2.val[1] = vPoolPreq_f32<P>(x0.val[1], q0.val[1]); + y1.val[0] = vPoolPreq_f32<P>(y1.val[0], y2.val[0]); + y1.val[1] = vPoolPreq_f32<P>(y1.val[1], y2.val[1]); + y1.val[0] = vPoolPostq_f32<P>(y1.val[0], post); + y1.val[1] = vPoolPostq_f32<P>
(y1.val[1], post); + + q0.val[0] = vextq_f32(x1.val[0], x1.val[1], 1); + q0.val[1] = vextq_f32(x1.val[1], x1.val[1], 1); + y3.val[0] = vPoolPreq_f32

<P>(x1.val[0], q0.val[0]); + y3.val[1] = vPoolPreq_f32<P>(x1.val[1], q0.val[1]); + y2.val[0] = vPoolPreq_f32<P>(y2.val[0], y3.val[0]); + y2.val[1] = vPoolPreq_f32<P>(y2.val[1], y3.val[1]); + y2.val[0] = vPoolPostq_f32<P>(y2.val[0], post); + y2.val[1] = vPoolPostq_f32<P>
(y2.val[1], post); + + x0.val[0] = vld1q_f32(input_ptr4); + x0.val[1] = vld1q_f32(input_ptr4 + 4); + q0.val[0] = vextq_f32(x0.val[0], x0.val[1], 1); + q0.val[1] = vextq_f32(x0.val[1], x0.val[1], 1); + y3.val[0] = vPoolPreq_f32

<P>(y3.val[0], x0.val[0]); + y3.val[0] = vPoolPreq_f32<P>(y3.val[0], q0.val[0]); + y3.val[1] = vPoolPreq_f32<P>(y3.val[1], x0.val[1]); + y3.val[1] = vPoolPreq_f32<P>(y3.val[1], q0.val[1]); + y3.val[0] = vPoolPostq_f32<P>(y3.val[0], post); + y3.val[1] = vPoolPostq_f32<P>
(y3.val[1], post); + + switch (output_w_remain) { + case 1: + vst1q_lane_f32(output_ptr0, y0.val[0], 0); + vst1q_lane_f32(output_ptr1, y1.val[0], 0); + vst1q_lane_f32(output_ptr2, y2.val[0], 0); + vst1q_lane_f32(output_ptr3, y3.val[0], 0); + break; + case 2: + vst1_f32(output_ptr0, vget_low_f32(y0.val[0])); + vst1_f32(output_ptr1, vget_low_f32(y1.val[0])); + vst1_f32(output_ptr2, vget_low_f32(y2.val[0])); + vst1_f32(output_ptr3, vget_low_f32(y3.val[0])); + break; + case 3: + vst1_f32(output_ptr0, vget_low_f32(y0.val[0])); + vst1_f32(output_ptr1, vget_low_f32(y1.val[0])); + vst1_f32(output_ptr2, vget_low_f32(y2.val[0])); + vst1_f32(output_ptr3, vget_low_f32(y3.val[0])); + vst1q_lane_f32(output_ptr0 + 2, y0.val[0], 2); + vst1q_lane_f32(output_ptr1 + 2, y1.val[0], 2); + vst1q_lane_f32(output_ptr2 + 2, y2.val[0], 2); + vst1q_lane_f32(output_ptr3 + 2, y3.val[0], 2); + break; + case 4: + vst1q_f32(output_ptr0, y0.val[0]); + vst1q_f32(output_ptr1, y1.val[0]); + vst1q_f32(output_ptr2, y2.val[0]); + vst1q_f32(output_ptr3, y3.val[0]); + break; + case 5: + vst1q_f32(output_ptr0, y0.val[0]); + vst1q_f32(output_ptr1, y1.val[0]); + vst1q_f32(output_ptr2, y2.val[0]); + vst1q_f32(output_ptr3, y3.val[0]); + vst1q_lane_f32(output_ptr0 + 4, y0.val[1], 0); + vst1q_lane_f32(output_ptr1 + 4, y1.val[1], 0); + vst1q_lane_f32(output_ptr2 + 4, y2.val[1], 0); + vst1q_lane_f32(output_ptr3 + 4, y3.val[1], 0); + break; + } + input_ptr0 += output_w_remain; + input_ptr1 += output_w_remain; + input_ptr2 += output_w_remain; + input_ptr3 += output_w_remain; + input_ptr4 += output_w_remain; + output_ptr0 += output_w_remain; + output_ptr1 += output_w_remain; + output_ptr2 += output_w_remain; + output_ptr3 += output_w_remain; + } + // pad right + if (padding_w) { + for (int w = valid_w_end; w < output_w; ++w) { + int padding = w + 2 - (padding_w + input_w); + if (padding >= 2) { + *output_ptr0 = 0.f; + *output_ptr1 = 0.f; + *output_ptr2 = 0.f; + *output_ptr3 = 0.f; + } else { + float acc0 = PoolPre

<P>(*input_ptr0, *input_ptr1); + float acc1 = PoolPre<P>(*input_ptr1, *input_ptr2); + float acc2 = PoolPre<P>(*input_ptr2, *input_ptr3); + float acc3 = PoolPre<P>(*input_ptr3, *input_ptr4); + *output_ptr0 = PoolPost<P>(acc0, 0.5f); + *output_ptr1 = PoolPost<P>(acc1, 0.5f); + *output_ptr2 = PoolPost<P>(acc2, 0.5f); + *output_ptr3 = PoolPost<P>
(acc3, 0.5f); + } + output_ptr0++; + output_ptr1++; + output_ptr2++; + output_ptr3++; + } + } + } + // remain height + int start_h = valid_h_start + (valid_h & 0xFFFC); + for (int h = start_h; h < valid_h_end; ++h) { + const float *input_ptr0 = input_ptr + (h - padding_h) * input_w; + const float *input_ptr1 = input_ptr0 + input_w; + float *output_ptr0 = output_ptr + h * output_w; + // pad left + if (padding_w) { + for (int w = valid_w_start - 1; w >= 0; --w) { + int padding = padding_w - w; + if (padding >= 2) { + output_ptr0[w] = 0.f; + } else { + float acc0 = PoolPre

<P>(*input_ptr0, *input_ptr1); + output_ptr0[w] = PoolPost<P>
(acc0, 0.5f); + } + } + output_ptr0 += valid_w_start; + } + // valid + float32x4x2_t x0, x1, q0, y0; + float32x4_t post = vdupq_n_f32(0.25f); + for (int loop = 0; loop < output_w_tiles; ++loop) { + x0.val[0] = vld1q_f32(input_ptr0); + x0.val[1] = vld1q_f32(input_ptr0 + 4); + x1.val[0] = vld1q_f32(input_ptr1); + x1.val[1] = vld1q_f32(input_ptr1 + 4); + q0.val[0] = vextq_f32(x0.val[0], x0.val[1], 1); + q0.val[1] = vextq_f32(x0.val[1], x0.val[1], 1); + y0.val[0] = vPoolPreq_f32

<P>(x0.val[0], q0.val[0]); + y0.val[1] = vPoolPreq_f32<P>(x0.val[1], q0.val[1]); + + q0.val[0] = vextq_f32(x1.val[0], x1.val[1], 1); + q0.val[1] = vextq_f32(x1.val[1], x1.val[1], 1); + y0.val[0] = vPoolPreq_f32<P>(y0.val[0], x1.val[0]); + y0.val[1] = vPoolPreq_f32<P>(y0.val[1], x1.val[1]); + y0.val[0] = vPoolPreq_f32<P>(y0.val[0], q0.val[0]); + y0.val[1] = vPoolPreq_f32<P>(y0.val[1], q0.val[1]); + y0.val[0] = vPoolPostq_f32<P>(y0.val[0], post); + y0.val[1] = vPoolPostq_f32<P>
(y0.val[1], post); + vst1q_f32(output_ptr0, y0.val[0]); + vst1_f32(output_ptr0 + 4, vget_low_f32(y0.val[1])); + + input_ptr0 += 6; + input_ptr1 += 6; + output_ptr0 += 6; + } + // remain width + if (output_w_remain > 0) { + x0.val[0] = vld1q_f32(input_ptr0); + x0.val[1] = vld1q_f32(input_ptr0 + 4); + x1.val[0] = vld1q_f32(input_ptr1); + x1.val[1] = vld1q_f32(input_ptr1 + 4); + q0.val[0] = vextq_f32(x0.val[0], x0.val[1], 1); + q0.val[1] = vextq_f32(x0.val[1], x0.val[1], 1); + y0.val[0] = vPoolPreq_f32

<P>(x0.val[0], q0.val[0]); + y0.val[1] = vPoolPreq_f32<P>(x0.val[1], q0.val[1]); + + q0.val[0] = vextq_f32(x1.val[0], x1.val[1], 1); + q0.val[1] = vextq_f32(x1.val[1], x1.val[1], 1); + y0.val[0] = vPoolPreq_f32<P>(y0.val[0], x1.val[0]); + y0.val[1] = vPoolPreq_f32<P>(y0.val[1], x1.val[1]); + y0.val[0] = vPoolPreq_f32<P>(y0.val[0], q0.val[0]); + y0.val[1] = vPoolPreq_f32<P>(y0.val[1], q0.val[1]); + y0.val[0] = vPoolPostq_f32<P>(y0.val[0], post); + y0.val[1] = vPoolPostq_f32<P>
(y0.val[1], post); + + switch (output_w_remain) { + case 1: + vst1q_lane_f32(output_ptr0, y0.val[0], 0); + break; + case 2: + vst1_f32(output_ptr0, vget_low_f32(y0.val[0])); + break; + case 3: + vst1_f32(output_ptr0, vget_low_f32(y0.val[0])); + vst1q_lane_f32(output_ptr0 + 2, y0.val[0], 2); + break; + case 4: + vst1q_f32(output_ptr0, y0.val[0]); + break; + case 5: + vst1q_f32(output_ptr0, y0.val[0]); + vst1q_lane_f32(output_ptr0 + 4, y0.val[1], 0); + break; + } + input_ptr0 += output_w_remain; + input_ptr1 += output_w_remain; + output_ptr0 += output_w_remain; + } + // pad right + if (padding_w) { + for (int w = valid_w_end; w < output_w; ++w) { + int padding = w + 2 - (padding_w + input_w); + if (padding >= 2) { + *output_ptr0 = 0.f; + } else { + float acc0 = PoolPre

<P>(*input_ptr0, *input_ptr1); + *output_ptr0 = PoolPost<P>
(acc0, 0.5f); + } + output_ptr0++; + } + } + } + // bottom + for (int h = valid_h_end; h < output_h; ++h) { + Pooling2x2NormalRow(input_ptr, h, input_h, input_w, padding_h, + padding_w, output_w, output_ptr); + } + } + } + } +}; + +template +struct Pooling2x2 { + inline void operator()(const framework::Tensor &input, + const std::vector &paddings, + framework::Tensor *output) { + const float *input_data = input.data(); + float *output_data = output->mutable_data(); + int input_h = input.dims()[2]; + int input_w = input.dims()[3]; + int output_h = output->dims()[2]; + int output_w = output->dims()[3]; + int padding_h = paddings[0]; + int padding_w = paddings[1]; + int image_size = input_h * input_w; + int out_image_size = output_h * output_w; + int valid_h_start = (padding_h + 1) / 2; + int valid_h_end = (input_h + padding_h) / 2; + int valid_h = valid_h_end - valid_h_start; + int valid_w_start = (padding_w + 1) / 2; + int valid_w_end = (input_w + padding_w) / 2; + int valid_w = valid_w_end - valid_w_start; + + bool ceil_mode = (((input_h + 2 * padding_h) / 2) < output_h) || + (((input_w + 2 * padding_w) / 2) < output_w); + int padding_b = + padding_h + (ceil_mode ? 2 * output_h - (input_h + 2 * padding_h) : 0); + int padding_r = + padding_w + (ceil_mode ? 2 * output_w - (input_w + 2 * padding_w) : 0); + + #pragma omp parallel for collapse(2) + for (int batch = 0; batch < output->dims()[0]; ++batch) { + for (int c = 0; c < output->dims()[1]; ++c) { + int channel = batch * output->dims()[1] + c; + const float *input_ptr = input_data + channel * image_size; + float *output_ptr = output_data + channel * out_image_size; + // top + for (int h = 0; h < valid_h_start; ++h) { + Pooling2x2NormalRow(input_ptr, h, input_h, input_w, padding_h, + padding_w, output_w, output_ptr); + } + // valid + int output_w_tiles = valid_w / 4; + int output_w_remain = valid_w - output_w_tiles * 4; + for (int h = valid_h_start; h < valid_h_end - 1; h += 2) { + const float *input_ptr0 = input_ptr + (2 * h - padding_h) * input_w; + const float *input_ptr1 = input_ptr0 + input_w; + const float *input_ptr2 = input_ptr1 + input_w; + const float *input_ptr3 = input_ptr2 + input_w; + float *output_ptr0 = output_ptr + h * output_w; + float *output_ptr1 = output_ptr0 + output_w; + // pad left + if (padding_w) { + for (int w = valid_w_start - 1; w >= 0; --w) { + int padding = padding_w - w * 2; + if (padding >= 2) { + output_ptr0[w] = 0.f; + output_ptr1[w] = 0.f; + } else { + float acc0 = PoolPre

<P>(*input_ptr0, *input_ptr1); + float acc1 = PoolPre<P>(*input_ptr2, *input_ptr3); + output_ptr0[w] = PoolPost<P>(acc0, 0.5f); + output_ptr1[w] = PoolPost<P>
(acc1, 0.5f); + } + } + input_ptr0 += (padding_w & 0x1); + input_ptr1 += (padding_w & 0x1); + input_ptr2 += (padding_w & 0x1); + input_ptr3 += (padding_w & 0x1); + output_ptr0 += valid_w_start; + output_ptr1 += valid_w_start; + } + // valid + float32x4x2_t x0, x1, x2, x3; + float32x4_t y0, y1; + float32x4_t post = vdupq_n_f32(0.25f); + for (int loop = 0; loop < output_w_tiles; ++loop) { + x0 = vld2q_f32(input_ptr0); + x1 = vld2q_f32(input_ptr1); + x2 = vld2q_f32(input_ptr2); + x3 = vld2q_f32(input_ptr3); + y0 = vPoolPreq_f32

<P>(x0.val[0], x0.val[1]); + y1 = vPoolPreq_f32<P>(x2.val[0], x2.val[1]); + y0 = vPoolPreq_f32<P>(y0, x1.val[0]); + y1 = vPoolPreq_f32<P>(y1, x3.val[0]); + y0 = vPoolPreq_f32<P>(y0, x1.val[1]); + y1 = vPoolPreq_f32<P>(y1, x3.val[1]); + y0 = vPoolPostq_f32<P>(y0, post); + y1 = vPoolPostq_f32<P>
(y1, post); + vst1q_f32(output_ptr0, y0); + vst1q_f32(output_ptr1, y1); + + input_ptr0 += 8; + input_ptr1 += 8; + input_ptr2 += 8; + input_ptr3 += 8; + output_ptr0 += 4; + output_ptr1 += 4; + } + // remain width + if (output_w_remain > 0) { + x0 = vld2q_f32(input_ptr0); + x1 = vld2q_f32(input_ptr1); + x2 = vld2q_f32(input_ptr2); + x3 = vld2q_f32(input_ptr3); + y0 = vPoolPreq_f32

<P>(x0.val[0], x0.val[1]); + y1 = vPoolPreq_f32<P>(x2.val[0], x2.val[1]); + y0 = vPoolPreq_f32<P>(y0, x1.val[0]); + y1 = vPoolPreq_f32<P>(y1, x3.val[0]); + y0 = vPoolPreq_f32<P>(y0, x1.val[1]); + y1 = vPoolPreq_f32<P>(y1, x3.val[1]); + y0 = vPoolPostq_f32<P>(y0, post); + y1 = vPoolPostq_f32<P>
(y1, post); + + switch (output_w_remain) { + case 1: + vst1q_lane_f32(output_ptr0, y0, 0); + vst1q_lane_f32(output_ptr1, y1, 0); + break; + case 2: + vst1_f32(output_ptr0, vget_low_f32(y0)); + vst1_f32(output_ptr1, vget_low_f32(y1)); + break; + case 3: + vst1_f32(output_ptr0, vget_low_f32(y0)); + vst1q_lane_f32(output_ptr0 + 2, y0, 2); + vst1_f32(output_ptr1, vget_low_f32(y1)); + vst1q_lane_f32(output_ptr1 + 2, y1, 2); + break; + } + input_ptr0 += 2 * output_w_remain; + input_ptr1 += 2 * output_w_remain; + input_ptr2 += 2 * output_w_remain; + input_ptr3 += 2 * output_w_remain; + output_ptr0 += output_w_remain; + output_ptr1 += output_w_remain; + } + // pad right + if (padding_r) { + for (int w = valid_w_end; w < output_w; ++w) { + int padding = 2 * w + 2 - (padding_w + input_w); + if (padding >= 2) { + *output_ptr0 = 0.f; + *output_ptr1 = 0.f; + } else { + float acc0 = PoolPre

<P>(*input_ptr0, *input_ptr1); + float acc1 = PoolPre<P>(*input_ptr2, *input_ptr3); + *output_ptr0 = PoolPost<P>(acc0, 0.5f); + *output_ptr1 = PoolPost<P>
(acc1, 0.5f); + } + output_ptr0++; + output_ptr1++; + } + } + } + // remain height + int start_h = valid_h_start + (valid_h & 0xfffe); + for (int h = start_h; h < valid_h_end; ++h) { + const float *input_ptr0 = input_ptr + (2 * h - padding_h) * input_w; + const float *input_ptr1 = input_ptr0 + input_w; + float *output_ptr0 = output_ptr + h * output_w; + // pad left + if (padding_w) { + for (int w = valid_w_start - 1; w >= 0; --w) { + int padding = padding_w - 2 * w; + if (padding >= 2) { + output_ptr0[w] = 0.f; + } else { + float acc0 = PoolPre

<P>(*input_ptr0, *input_ptr1); + output_ptr0[w] = PoolPost<P>
(acc0, 0.5f); + } + } + input_ptr0 += (padding_w & 0x1); + input_ptr1 += (padding_w & 0x1); + output_ptr0 += valid_w_start; + } + // valid + float32x4x2_t x0, x1; + float32x4_t y0; + float32x4_t post = vdupq_n_f32(0.25f); + for (int loop = 0; loop < output_w_tiles; ++loop) { + x0 = vld2q_f32(input_ptr0); + x1 = vld2q_f32(input_ptr1); + y0 = vPoolPreq_f32

<P>(x0.val[0], x0.val[1]); + y0 = vPoolPreq_f32<P>(y0, x1.val[0]); + y0 = vPoolPreq_f32<P>(y0, x1.val[1]); + y0 = vPoolPostq_f32<P>
(y0, post); + vst1q_f32(output_ptr0, y0); + + input_ptr0 += 8; + input_ptr1 += 8; + output_ptr0 += 4; + } + // remain width + if (output_w_remain > 0) { + x0 = vld2q_f32(input_ptr0); + x1 = vld2q_f32(input_ptr1); + y0 = vPoolPreq_f32

<P>(x0.val[0], x0.val[1]); + y0 = vPoolPreq_f32<P>(y0, x1.val[0]); + y0 = vPoolPreq_f32<P>(y0, x1.val[1]); + y0 = vPoolPostq_f32<P>
(y0, post); + + switch (output_w_remain) { + case 1: + vst1q_lane_f32(output_ptr0, y0, 0); + break; + case 2: + vst1_f32(output_ptr0, vget_low_f32(y0)); + break; + case 3: + vst1_f32(output_ptr0, vget_low_f32(y0)); + vst1q_lane_f32(output_ptr0 + 2, y0, 2); + break; + } + input_ptr0 += 2 * output_w_remain; + input_ptr1 += 2 * output_w_remain; + output_ptr0 += output_w_remain; + } + // pad right + if (padding_r) { + for (int w = valid_w_end; w < output_w; ++w) { + int padding = 2 * w + 2 - (padding_w + input_w); + if (padding >= 2) { + *output_ptr0 = 0.f; + } else { + float acc0 = PoolPre

<P>(*input_ptr0, *input_ptr1); + *output_ptr0 = PoolPost<P>
(acc0, 0.5f); + } + output_ptr0++; + } + } + } + // bottom + for (int h = valid_h_end; h < output_h; ++h) { + Pooling2x2NormalRow(input_ptr, h, input_h, input_w, padding_h, + padding_w, output_w, output_ptr); + } + } + } + } +}; + +template struct Pooling2x2; +template struct Pooling2x2; +template struct Pooling2x2; +template struct Pooling2x2; + +} // namespace math +} // namespace operators +} // namespace paddle_mobile + +#endif // __ARM_NEON__ +#endif // POOL_OP diff --git a/src/operators/math/pooling3x3.cpp b/src/operators/math/pooling3x3.cpp index 72ffb6161a96fbde432768fbda455cf4d869de61..35029c6425c07b4bed03d667a014bc3e7d960df6 100644 --- a/src/operators/math/pooling3x3.cpp +++ b/src/operators/math/pooling3x3.cpp @@ -14,10 +14,10 @@ limitations under the License. */ #ifdef POOL_OP -#include "operators/math/pooling.h" #if defined(__ARM_NEON) || defined(__ARM_NEON__) + #include -#endif // __ARM_NEON +#include "operators/math/pooling.h" namespace paddle_mobile { namespace operators { @@ -38,87 +38,6 @@ namespace math { output_ptr[w] = val.Value(); \ } -#if defined(__ARM_NEON) || defined(__ARM_NEON__) -template -struct Pooling3x3ValidColLoadInput { - inline void operator()(const float *input, const int input_w, - const int valid_cols, float32x4x2_t &x0, // NOLINT - float32x4x2_t &x1, float32x4x2_t &x2, // NOLINT - float32x4x2_t &y0) { // NOLINT - float fake_input[3][8]; - if (valid_cols == 1) { - for (int i = 0; i < 8; ++i, input += input_w) { - fake_input[0][i] = input[0]; - } - } else if (valid_cols == 2) { - for (int i = 0; i < 8; ++i, input += input_w) { - fake_input[0][i] = input[0]; - fake_input[1][i] = input[1]; - } - } else { - for (int i = 0; i < 8; ++i, input += input_w) { - fake_input[0][i] = input[0]; - fake_input[1][i] = input[1]; - fake_input[2][i] = input[2]; - } - } - y0.val[0] = vPoolInitq_f32

(); - y0.val[1] = vPoolInitq_f32

(); - for (int i = 0; i < valid_cols; ++i) { - x0.val[0] = vld1q_f32(fake_input[i]); - x0.val[1] = vld1q_f32(fake_input[i] + 4); - x1.val[0] = vextq_f32(x0.val[0], x0.val[1], 1); - x1.val[1] = vextq_f32(x0.val[1], x0.val[1], 1); - x2.val[0] = vextq_f32(x0.val[0], x0.val[1], 2); - x2.val[1] = vextq_f32(x0.val[1], x0.val[1], 2); - y0.val[0] = vPoolPreq_f32

(x0.val[0], y0.val[0]); - y0.val[1] = vPoolPreq_f32

(x0.val[1], y0.val[1]); - y0.val[0] = vPoolPreq_f32

(x1.val[0], y0.val[0]); - y0.val[1] = vPoolPreq_f32

(x1.val[1], y0.val[1]); - y0.val[0] = vPoolPreq_f32

(x2.val[0], y0.val[0]); - y0.val[1] = vPoolPreq_f32

(x2.val[1], y0.val[1]); - } - } -}; - -template -struct Pooling3x3ValidColLoadInput { - inline void operator()(const float *input, const int input_w, - const int valid_cols, float32x4x2_t &x0, // NOLINT - float32x4x2_t &x1, float32x4x2_t &x2, // NOLINT - float32x4x2_t &y0) { // NOLINT - float fake_input[3][13]; - if (valid_cols == 1) { - for (int i = 0; i < 13; ++i, input += input_w) { - fake_input[0][i] = input[0]; - } - } else if (valid_cols == 2) { - for (int i = 0; i < 13; ++i, input += input_w) { - fake_input[0][i] = input[0]; - fake_input[1][i] = input[1]; - } - } else { - for (int i = 0; i < 13; ++i, input += input_w) { - fake_input[0][i] = input[0]; - fake_input[1][i] = input[1]; - fake_input[2][i] = input[2]; - } - } - for (int i = 0; i < valid_cols; ++i) { - x0 = vld2q_f32(fake_input[i]); - x1 = vld2q_f32(fake_input[i] + 8); - x2.val[0] = vextq_f32(x0.val[0], x1.val[0], 1); - x2.val[1] = vextq_f32(x1.val[0], x1.val[0], 1); - x0.val[0] = vPoolPreq_f32

(x0.val[0], x0.val[1]); - x0.val[1] = vPoolPreq_f32

(x1.val[0], x1.val[1]); - x0.val[0] = vPoolPreq_f32

(x0.val[0], x2.val[0]); - x0.val[1] = vPoolPreq_f32

(x0.val[1], x2.val[1]); - y0.val[0] = vPoolPreq_f32

(x0.val[0], y0.val[0]); - y0.val[1] = vPoolPreq_f32

(x0.val[1], y0.val[1]); - } - } -}; - template struct Pooling3x3NormalRowLoadInput { inline void operator()(const float *input, float32x4x2_t &x0, // NOLINT @@ -156,62 +75,6 @@ struct Pooling3x3NormalRowLoadInput { y0.val[1] = vPoolPreq_f32

(x0.val[1], y0.val[1]); } }; -#endif // __ARM_NEON__ - -template -inline void Pooling3x3ValidCol(const float *input, const int h_output, - const int h_output_end, const int w_output, - const int input_h, const int input_w, - const int padding_h, const int padding_w, - const int output_w, float *output) { - const int w_in_start = -padding_w + w_output * Stride; - const int w_in_end = w_in_start + 3; - const int w_start = w_in_start > 0 ? w_in_start : 0; - const int w_end = w_in_end < input_w ? w_in_end : input_w; - int remain_start = h_output; -#if defined(__ARM_NEON) || defined(__ARM_NEON__) - int output_tiles = (h_output_end - h_output) / 6; - remain_start = h_output + output_tiles * 6; - int input_h_start = h_output * Stride - padding_h; - size_t input_offset = input_h_start * input_w + w_start; - size_t output_offset = h_output * output_w + w_output; - int valid_cols = w_end - w_start; - Pooling3x3ValidColLoadInput PoolingCompute; - float32x4x2_t x0, x1, x2, y0; - float32x4_t avg = vdupq_n_f32(1.f / (3 * valid_cols)); - for (int h = 0; h < output_tiles * 6; h += 6) { - float *output0 = output + output_offset; - float *output1 = output0 + output_w; - float *output2 = output1 + output_w; - float *output3 = output2 + output_w; - float *output4 = output3 + output_w; - float *output5 = output4 + output_w; - y0.val[0] = vPoolInitq_f32

(); - y0.val[1] = vPoolInitq_f32

(); - PoolingCompute(input + input_offset, input_w, valid_cols, x0, x1, x2, y0); - y0.val[0] = vPoolPostq_f32

(y0.val[0], avg); - y0.val[1] = vPoolPostq_f32

(y0.val[1], avg); - vst1q_lane_f32(output0, y0.val[0], 0); - vst1q_lane_f32(output1, y0.val[0], 1); - vst1q_lane_f32(output2, y0.val[0], 2); - vst1q_lane_f32(output3, y0.val[0], 3); - vst1q_lane_f32(output4, y0.val[1], 0); - vst1q_lane_f32(output5, y0.val[1], 1); - input_offset += 6 * Stride * input_w; - output_offset += 6 * output_w; - } -#endif - for (int h = remain_start; h < h_output_end; ++h) { - PoolingVal
<P>
val; - const int h_in_start = -padding_h + h * Stride; - for (int i = 0; i < 3; ++i) { - for (int w_in = w_start; w_in < w_end; ++w_in) { - val += input[(h_in_start + i) * input_w + w_in]; - } - } - output[h * output_w + w_output] = val.Value(); - } -} template inline void Pooling3x3NormalRow(const float *input, const int h_output, @@ -223,21 +86,25 @@ inline void Pooling3x3NormalRow(const float *input, const int h_output, const int h_start = h_in_start > 0 ? h_in_start : 0; const int h_end = h_in_end < input_h ? h_in_end : input_h; - int valid_w_start = (padding_w + Stride - 1) / Stride; - int valid_w_end = (input_w - 3) / Stride + 1 + valid_w_start; - float *output_ptr = output + h_output * output_w; + if (h_end - h_start <= 0) { + memset(output_ptr, 0, output_w * sizeof(float)); + return; + } + + const int valid_w_start = (padding_w + Stride - 1) / Stride; + const int valid_w_end = (input_w + padding_w - 3) / Stride + 1; + const int valid_w = valid_w_end - valid_w_start; + // border left POOLING3X3_NORMAL_BORDER(0, valid_w_start) // middle - int remain_start = valid_w_start; -#if defined(__ARM_NEON) || defined(__ARM_NEON__) int output_tiles = (valid_w_end - valid_w_start) / 6; - remain_start = valid_w_start + output_tiles * 6; + int output_tiles_w = output_tiles * 6; Pooling3x3NormalRowLoadInput PoolingCompute; float32x4x2_t x0, x1, x2, y0; float32x4_t post = vdupq_n_f32(1.f / (3 * (h_end - h_start))); - for (int w = 0; w < output_tiles * 6; w += 6) { + for (int w = 0; w < output_tiles_w; w += 6) { int output_offset = valid_w_start + w; int input_w_offset = output_offset * Stride - padding_w; y0.val[0] = vPoolInitq_f32

(); @@ -250,16 +117,37 @@ inline void Pooling3x3NormalRow(const float *input, const int h_output, vst1q_f32(output_ptr + output_offset, y0.val[0]); vst1_f32(output_ptr + output_offset + 4, vget_low_f32(y0.val[1])); } -#endif // __ARM_NEON__ - for (int w = remain_start; w < valid_w_end; ++w) { - PoolingVal
<P>
val; - int input_start = -padding_w + w * Stride; + int remain = valid_w - output_tiles_w; + if (remain > 0) { + int remain_start = valid_w_start + output_tiles_w; + int input_w_offset = remain_start * Stride - padding_w; + float *output_ptr0 = output_ptr + remain_start; + y0.val[0] = vPoolInitq_f32

(); + y0.val[1] = vPoolInitq_f32

(); for (int h_in = h_start; h_in < h_end; ++h_in) { - for (int j = 0; j < 3; ++j) { - val += input[h_in * input_w + j + input_start]; - } + PoolingCompute(input + h_in * input_w + input_w_offset, x0, x1, x2, y0); + } + y0.val[0] = vPoolPostq_f32

(y0.val[0], post); + y0.val[1] = vPoolPostq_f32

(y0.val[1], post); + switch (remain) { + case 1: + vst1q_lane_f32(output_ptr0, y0.val[0], 0); + break; + case 2: + vst1_f32(output_ptr0, vget_low_f32(y0.val[0])); + break; + case 3: + vst1_f32(output_ptr0, vget_low_f32(y0.val[0])); + vst1q_lane_f32(output_ptr0 + 2, y0.val[0], 2); + break; + case 4: + vst1q_f32(output_ptr0, y0.val[0]); + break; + case 5: + vst1q_f32(output_ptr0, y0.val[0]); + vst1q_lane_f32(output_ptr0 + 4, y0.val[1], 0); + break; } - output_ptr[w] = val.Value(); } // border right POOLING3X3_NORMAL_BORDER(valid_w_end, output_w) @@ -286,7 +174,6 @@ struct Pooling3x3 { int valid_w_start = padding_w; int valid_w = input_w - 2; int valid_w_end = valid_w_start + valid_w; - float avg = 1.f / 9; #pragma omp parallel for collapse(2) for (int batch = 0; batch < output->dims()[0]; ++batch) { @@ -299,23 +186,6 @@ struct Pooling3x3 { Pooling3x3NormalRow(input_ptr, h, input_h, input_w, padding_h, padding_w, output_w, output_ptr); } - // left - for (int w = 0; w < valid_w_start; ++w) { - Pooling3x3ValidCol(input_ptr, valid_h_start, valid_h_end, w, - input_h, input_w, padding_h, padding_w, - output_w, output_ptr); - } - // right - for (int w = valid_w_end; w < output_w; ++w) { - Pooling3x3ValidCol(input_ptr, valid_h_start, valid_h_end, w, - input_h, input_w, padding_h, padding_w, - output_w, output_ptr); - } - // bottom - for (int h = valid_h_end; h < output_h; ++h) { - Pooling3x3NormalRow(input_ptr, h, input_h, input_w, padding_h, - padding_w, output_w, output_ptr); - } // valid int output_w_tiles = valid_w / 6; int output_w_remain = valid_w - output_w_tiles * 6; @@ -326,12 +196,61 @@ struct Pooling3x3 { const float *input_ptr3 = input_ptr2 + input_w; const float *input_ptr4 = input_ptr3 + input_w; const float *input_ptr5 = input_ptr4 + input_w; - float *output_ptr0 = output_ptr + h * output_w + valid_w_start; + float *output_ptr0 = output_ptr + h * output_w; float *output_ptr1 = output_ptr0 + output_w; float *output_ptr2 = output_ptr1 + output_w; float *output_ptr3 = output_ptr2 + output_w; - int remain = output_w_remain; -#if defined(__ARM_NEON__) || defined(__ARM_NEON) + // pad left + if (padding_w) { + float32x2_t row0 = vld1_f32(input_ptr0); + float32x2_t row1 = vld1_f32(input_ptr1); + float32x2_t row2 = vld1_f32(input_ptr2); + float32x2_t row3 = vld1_f32(input_ptr3); + float32x2_t row4 = vld1_f32(input_ptr4); + float32x2_t row5 = vld1_f32(input_ptr5); + float32x2_t pad0 = vPoolInit_f32

<P>(); + float32x2_t acc0, acc1, acc2, acc3, acc12, acc34, post; + for (int w = valid_w_start - 1; w >= 0; --w) { + int padding = padding_w - w; + if (padding >= 3) { + output_ptr0[w] = 0.f; + output_ptr1[w] = 0.f; + output_ptr2[w] = 0.f; + output_ptr3[w] = 0.f; + } else { + post = vdup_n_f32(1.f / (3 * (3 - padding))); + acc12 = vPoolPre_f32<P>(row1, row2); + acc34 = vPoolPre_f32<P>(row3, row4); + acc0 = vPoolPre_f32<P>(row0, acc12); + acc1 = vPoolPre_f32<P>(row3, acc12); + acc2 = vPoolPre_f32<P>(row2, acc34); + acc3 = vPoolPre_f32<P>(row5, acc34); + acc0 = vpPoolPre_f32<P>(acc0, acc0); + acc1 = vpPoolPre_f32<P>(acc1, acc1); + acc2 = vpPoolPre_f32<P>(acc2, acc2); + acc3 = vpPoolPre_f32<P>(acc3, acc3); + acc0 = vPoolPost_f32<P>(acc0, post); + acc1 = vPoolPost_f32<P>(acc1, post); + acc2 = vPoolPost_f32<P>(acc2, post); + acc3 = vPoolPost_f32<P>
(acc3, post); + vst1_lane_f32(output_ptr0 + w, acc0, 0); + vst1_lane_f32(output_ptr1 + w, acc1, 0); + vst1_lane_f32(output_ptr2 + w, acc2, 0); + vst1_lane_f32(output_ptr3 + w, acc3, 0); + row0 = vext_f32(pad0, row0, 1); + row1 = vext_f32(pad0, row1, 1); + row2 = vext_f32(pad0, row2, 1); + row3 = vext_f32(pad0, row3, 1); + row4 = vext_f32(pad0, row4, 1); + row5 = vext_f32(pad0, row5, 1); + } + } + output_ptr0 += valid_w_start; + output_ptr1 += valid_w_start; + output_ptr2 += valid_w_start; + output_ptr3 += valid_w_start; + } + // valid float32x4x2_t x0, x1, x2; float32x4x2_t y0, y1, y2; float32x4_t post = vdupq_n_f32(1.f / 9); @@ -446,100 +365,198 @@ struct Pooling3x3 { output_ptr3 += 6; } // remain width - if (remain >= 4) { + if (output_w_remain > 0) { + float32x4x2_t y3; x0.val[0] = vld1q_f32(input_ptr0); x0.val[1] = vld1q_f32(input_ptr0 + 4); x1.val[0] = vextq_f32(x0.val[0], x0.val[1], 1); + x1.val[1] = vextq_f32(x0.val[1], x0.val[1], 1); x2.val[0] = vextq_f32(x0.val[0], x0.val[1], 2); + x2.val[1] = vextq_f32(x0.val[1], x0.val[1], 2); x0.val[0] = vPoolPreq_f32

(x0.val[0], x1.val[0]); + x0.val[1] = vPoolPreq_f32

(x0.val[1], x1.val[1]); y0.val[0] = vPoolPreq_f32

(x0.val[0], x2.val[0]); + y0.val[1] = vPoolPreq_f32

(x0.val[1], x2.val[1]); x0.val[0] = vld1q_f32(input_ptr1); x0.val[1] = vld1q_f32(input_ptr1 + 4); x1.val[0] = vextq_f32(x0.val[0], x0.val[1], 1); + x1.val[1] = vextq_f32(x0.val[1], x0.val[1], 1); x2.val[0] = vextq_f32(x0.val[0], x0.val[1], 2); + x2.val[1] = vextq_f32(x0.val[1], x0.val[1], 2); x0.val[0] = vPoolPreq_f32

(x0.val[0], x1.val[0]); + x0.val[1] = vPoolPreq_f32

(x0.val[1], x1.val[1]); y1.val[0] = vPoolPreq_f32

(x0.val[0], x2.val[0]); + y1.val[1] = vPoolPreq_f32

(x0.val[1], x2.val[1]); y0.val[0] = vPoolPreq_f32

(y1.val[0], y0.val[0]); + y0.val[1] = vPoolPreq_f32

(y1.val[1], y0.val[1]); x0.val[0] = vld1q_f32(input_ptr2); x0.val[1] = vld1q_f32(input_ptr2 + 4); x1.val[0] = vextq_f32(x0.val[0], x0.val[1], 1); + x1.val[1] = vextq_f32(x0.val[1], x0.val[1], 1); x2.val[0] = vextq_f32(x0.val[0], x0.val[1], 2); + x2.val[1] = vextq_f32(x0.val[1], x0.val[1], 2); x0.val[0] = vPoolPreq_f32

(x0.val[0], x1.val[0]); + x0.val[1] = vPoolPreq_f32

(x0.val[1], x1.val[1]); y2.val[0] = vPoolPreq_f32

(x0.val[0], x2.val[0]); + y2.val[1] = vPoolPreq_f32

(x0.val[1], x2.val[1]); y1.val[0] = vPoolPreq_f32

(y2.val[0], y1.val[0]); + y1.val[1] = vPoolPreq_f32

(y2.val[1], y1.val[1]); y0.val[0] = vPoolPreq_f32

(y2.val[0], y0.val[0]); + y0.val[1] = vPoolPreq_f32

(y2.val[1], y0.val[1]); y0.val[0] = vPoolPostq_f32

(y0.val[0], post); - vst1q_f32(output_ptr0, y0.val[0]); + y0.val[1] = vPoolPostq_f32

(y0.val[1], post); x0.val[0] = vld1q_f32(input_ptr3); x0.val[1] = vld1q_f32(input_ptr3 + 4); x1.val[0] = vextq_f32(x0.val[0], x0.val[1], 1); + x1.val[1] = vextq_f32(x0.val[1], x0.val[1], 1); x2.val[0] = vextq_f32(x0.val[0], x0.val[1], 2); + x2.val[1] = vextq_f32(x0.val[1], x0.val[1], 2); x0.val[0] = vPoolPreq_f32

(x0.val[0], x1.val[0]); - y0.val[0] = vPoolPreq_f32

(x0.val[0], x2.val[0]); - y1.val[0] = vPoolPreq_f32

(y0.val[0], y1.val[0]); - y2.val[0] = vPoolPreq_f32

(y0.val[0], y2.val[0]); + x0.val[1] = vPoolPreq_f32

(x0.val[1], x1.val[1]); + y3.val[0] = vPoolPreq_f32

(x0.val[0], x2.val[0]); + y3.val[1] = vPoolPreq_f32

(x0.val[1], x2.val[1]); + y1.val[0] = vPoolPreq_f32

(y3.val[0], y1.val[0]); + y1.val[1] = vPoolPreq_f32

(y3.val[1], y1.val[1]); + y2.val[0] = vPoolPreq_f32

(y3.val[0], y2.val[0]); + y2.val[1] = vPoolPreq_f32

(y3.val[1], y2.val[1]); y1.val[0] = vPoolPostq_f32

(y1.val[0], post); - vst1q_f32(output_ptr1, y1.val[0]); + y1.val[1] = vPoolPostq_f32

(y1.val[1], post); x0.val[0] = vld1q_f32(input_ptr4); x0.val[1] = vld1q_f32(input_ptr4 + 4); x1.val[0] = vextq_f32(x0.val[0], x0.val[1], 1); + x1.val[1] = vextq_f32(x0.val[1], x0.val[1], 1); x2.val[0] = vextq_f32(x0.val[0], x0.val[1], 2); + x2.val[1] = vextq_f32(x0.val[1], x0.val[1], 2); x0.val[0] = vPoolPreq_f32

(x0.val[0], x1.val[0]); + x0.val[1] = vPoolPreq_f32

(x0.val[1], x1.val[1]); x0.val[0] = vPoolPreq_f32

(x0.val[0], x2.val[0]); - y0.val[0] = vPoolPreq_f32

(x0.val[0], y0.val[0]); + x0.val[1] = vPoolPreq_f32

(x0.val[1], x2.val[1]); + y3.val[0] = vPoolPreq_f32

(x0.val[0], y3.val[0]); + y3.val[1] = vPoolPreq_f32

(x0.val[1], y3.val[1]); y2.val[0] = vPoolPreq_f32

(x0.val[0], y2.val[0]); + y2.val[1] = vPoolPreq_f32

(x0.val[1], y2.val[1]); y2.val[0] = vPoolPostq_f32

(y2.val[0], post); - vst1q_f32(output_ptr2, y2.val[0]); + y2.val[1] = vPoolPostq_f32

(y2.val[1], post); x0.val[0] = vld1q_f32(input_ptr5); x0.val[1] = vld1q_f32(input_ptr5 + 4); x1.val[0] = vextq_f32(x0.val[0], x0.val[1], 1); + x1.val[1] = vextq_f32(x0.val[1], x0.val[1], 1); x2.val[0] = vextq_f32(x0.val[0], x0.val[1], 2); + x2.val[1] = vextq_f32(x0.val[1], x0.val[1], 2); x0.val[0] = vPoolPreq_f32

(x0.val[0], x1.val[0]); + x0.val[1] = vPoolPreq_f32

(x0.val[1], x1.val[1]); x0.val[0] = vPoolPreq_f32

(x0.val[0], x2.val[0]); - y0.val[0] = vPoolPreq_f32

(x0.val[0], y0.val[0]); - y0.val[0] = vPoolPostq_f32

(y0.val[0], post); - vst1q_f32(output_ptr3, y0.val[0]); - - input_ptr0 += 4; - input_ptr1 += 4; - input_ptr2 += 4; - input_ptr3 += 4; - input_ptr4 += 4; - input_ptr5 += 4; - output_ptr0 += 4; - output_ptr1 += 4; - output_ptr2 += 4; - output_ptr3 += 4; - remain -= 4; + x0.val[1] = vPoolPreq_f32

(x0.val[1], x2.val[1]); + y3.val[0] = vPoolPreq_f32

(x0.val[0], y3.val[0]); + y3.val[1] = vPoolPreq_f32

(x0.val[1], y3.val[1]); + y3.val[0] = vPoolPostq_f32

(y3.val[0], post); + y3.val[1] = vPoolPostq_f32

(y3.val[1], post); + + switch (output_w_remain) { + case 1: + vst1q_lane_f32(output_ptr0, y0.val[0], 0); + vst1q_lane_f32(output_ptr1, y1.val[0], 0); + vst1q_lane_f32(output_ptr2, y2.val[0], 0); + vst1q_lane_f32(output_ptr3, y3.val[0], 0); + break; + case 2: + vst1_f32(output_ptr0, vget_low_f32(y0.val[0])); + vst1_f32(output_ptr1, vget_low_f32(y1.val[0])); + vst1_f32(output_ptr2, vget_low_f32(y2.val[0])); + vst1_f32(output_ptr3, vget_low_f32(y3.val[0])); + break; + case 3: + vst1_f32(output_ptr0, vget_low_f32(y0.val[0])); + vst1_f32(output_ptr1, vget_low_f32(y1.val[0])); + vst1_f32(output_ptr2, vget_low_f32(y2.val[0])); + vst1_f32(output_ptr3, vget_low_f32(y3.val[0])); + vst1q_lane_f32(output_ptr0 + 2, y0.val[0], 2); + vst1q_lane_f32(output_ptr1 + 2, y1.val[0], 2); + vst1q_lane_f32(output_ptr2 + 2, y2.val[0], 2); + vst1q_lane_f32(output_ptr3 + 2, y3.val[0], 2); + break; + case 4: + vst1q_f32(output_ptr0, y0.val[0]); + vst1q_f32(output_ptr1, y1.val[0]); + vst1q_f32(output_ptr2, y2.val[0]); + vst1q_f32(output_ptr3, y3.val[0]); + break; + case 5: + vst1q_f32(output_ptr0, y0.val[0]); + vst1q_f32(output_ptr1, y1.val[0]); + vst1q_f32(output_ptr2, y2.val[0]); + vst1q_f32(output_ptr3, y3.val[0]); + vst1q_lane_f32(output_ptr0 + 4, y0.val[1], 0); + vst1q_lane_f32(output_ptr1 + 4, y1.val[1], 0); + vst1q_lane_f32(output_ptr2 + 4, y2.val[1], 0); + vst1q_lane_f32(output_ptr3 + 4, y3.val[1], 0); + break; + } + input_ptr0 += output_w_remain; + input_ptr1 += output_w_remain; + input_ptr2 += output_w_remain; + input_ptr3 += output_w_remain; + input_ptr4 += output_w_remain; + input_ptr5 += output_w_remain; + output_ptr0 += output_w_remain; + output_ptr1 += output_w_remain; + output_ptr2 += output_w_remain; + output_ptr3 += output_w_remain; } -#endif // __ARM_NEON__ - for (int r = 0; r < remain; ++r) { - float m0 = PoolPre

(input_ptr0[r], input_ptr0[r + 1]); - m0 = PoolPre

(m0, input_ptr0[r + 2]); - float m1 = PoolPre

(input_ptr1[r], input_ptr1[r + 1]); - m1 = PoolPre

(m1, input_ptr1[r + 2]); - float m2 = PoolPre

(input_ptr2[r], input_ptr2[r + 1]); - m2 = PoolPre

(m2, input_ptr2[r + 2]); - float m3 = PoolPre

(input_ptr3[r], input_ptr3[r + 1]); - m3 = PoolPre

(m3, input_ptr3[r + 2]); - float m4 = PoolPre

(input_ptr4[r], input_ptr4[r + 1]); - m4 = PoolPre

(m4, input_ptr4[r + 2]); - float m5 = PoolPre

(input_ptr5[r], input_ptr5[r + 1]); - m5 = PoolPre

(m5, input_ptr5[r + 2]); - - m0 = PoolPre

(PoolPre

(m0, m1), m2); - m1 = PoolPre

(PoolPre

(m1, m2), m3); - m2 = PoolPre

(PoolPre

(m2, m3), m4); - m3 = PoolPre

(PoolPre

(m3, m4), m5); - output_ptr0[r] = PoolPost

(m0, avg); - output_ptr1[r] = PoolPost

(m1, avg); - output_ptr2[r] = PoolPost

(m2, avg); - output_ptr3[r] = PoolPost

(m3, avg); + // pad right + if (padding_w) { + float32x2_t row0 = vld1_f32(input_ptr0); + float32x2_t row1 = vld1_f32(input_ptr1); + float32x2_t row2 = vld1_f32(input_ptr2); + float32x2_t row3 = vld1_f32(input_ptr3); + float32x2_t row4 = vld1_f32(input_ptr4); + float32x2_t row5 = vld1_f32(input_ptr5); + float32x2_t pad0 = vPoolInit_f32

<P>(); + float32x2_t acc0, acc1, acc2, acc3, acc12, acc34, post; + for (int w = valid_w_end; w < output_w; ++w) { + int padding = w + 3 - (padding_w + input_w); + if (padding >= 3) { + *output_ptr0 = 0.f; + *output_ptr1 = 0.f; + *output_ptr2 = 0.f; + *output_ptr3 = 0.f; + } else { + post = vdup_n_f32(1.f / (3 * (3 - padding))); + acc12 = vPoolPre_f32<P>(row1, row2); + acc34 = vPoolPre_f32<P>(row3, row4); + acc0 = vPoolPre_f32<P>(row0, acc12); + acc1 = vPoolPre_f32<P>(row3, acc12); + acc2 = vPoolPre_f32<P>(row2, acc34); + acc3 = vPoolPre_f32<P>(row5, acc34); + acc0 = vpPoolPre_f32<P>(acc0, acc0); + acc1 = vpPoolPre_f32<P>(acc1, acc1); + acc2 = vpPoolPre_f32<P>(acc2, acc2); + acc3 = vpPoolPre_f32<P>(acc3, acc3); + acc0 = vPoolPost_f32<P>(acc0, post); + acc1 = vPoolPost_f32<P>(acc1, post); + acc2 = vPoolPost_f32<P>(acc2, post); + acc3 = vPoolPost_f32<P>
(acc3, post); + vst1_lane_f32(output_ptr0, acc0, 0); + vst1_lane_f32(output_ptr1, acc1, 0); + vst1_lane_f32(output_ptr2, acc2, 0); + vst1_lane_f32(output_ptr3, acc3, 0); + row0 = vext_f32(row0, pad0, 1); + row1 = vext_f32(row1, pad0, 1); + row2 = vext_f32(row2, pad0, 1); + row3 = vext_f32(row3, pad0, 1); + row4 = vext_f32(row4, pad0, 1); + row5 = vext_f32(row5, pad0, 1); + } + output_ptr0++; + output_ptr1++; + output_ptr2++; + output_ptr3++; + } } } // remain height @@ -548,9 +565,33 @@ struct Pooling3x3 { const float *input_ptr0 = input_ptr + (h - padding_h) * input_w; const float *input_ptr1 = input_ptr0 + input_w; const float *input_ptr2 = input_ptr1 + input_w; - float *output_ptr0 = output_ptr + h * output_w + valid_w_start; - int remain = output_w_remain; -#if defined(__ARM_NEON__) || defined(__ARM_NEON) + float *output_ptr0 = output_ptr + h * output_w; + // pad left + if (padding_w) { + float32x2_t row0 = vld1_f32(input_ptr0); + float32x2_t row1 = vld1_f32(input_ptr1); + float32x2_t row2 = vld1_f32(input_ptr2); + float32x2_t pad0 = vPoolInit_f32

<P>(); + float32x2_t acc0, post; + for (int w = valid_w_start - 1; w >= 0; --w) { + int padding = padding_w - w; + if (padding >= 3) { + output_ptr0[w] = 0.f; + } else { + post = vdup_n_f32(1.f / (3 * (3 - padding))); + acc0 = vPoolPre_f32<P>(row0, row1); + acc0 = vPoolPre_f32<P>(acc0, row2); + acc0 = vpPoolPre_f32<P>(acc0, acc0); + acc0 = vPoolPost_f32<P>
(acc0, post); + vst1_lane_f32(output_ptr0 + w, acc0, 0); + row0 = vext_f32(pad0, row0, 1); + row1 = vext_f32(pad0, row1, 1); + row2 = vext_f32(pad0, row2, 1); + } + } + output_ptr0 += valid_w_start; + } + // valid float32x4x2_t x0, x1, x2, y0; float32x4_t post = vdupq_n_f32(1.f / 9); for (int loop = 0; loop < output_w_tiles; ++loop) { @@ -601,51 +642,101 @@ struct Pooling3x3 { output_ptr0 += 6; } // remain width - if (remain >= 4) { + if (output_w_remain > 0) { x0.val[0] = vld1q_f32(input_ptr0); x0.val[1] = vld1q_f32(input_ptr0 + 4); x1.val[0] = vextq_f32(x0.val[0], x0.val[1], 1); + x1.val[1] = vextq_f32(x0.val[1], x0.val[1], 1); x2.val[0] = vextq_f32(x0.val[0], x0.val[1], 2); + x2.val[1] = vextq_f32(x0.val[1], x0.val[1], 2); x0.val[0] = vPoolPreq_f32

(x0.val[0], x1.val[0]); + x0.val[1] = vPoolPreq_f32

(x0.val[1], x1.val[1]); y0.val[0] = vPoolPreq_f32

(x0.val[0], x2.val[0]); + y0.val[1] = vPoolPreq_f32

(x0.val[1], x2.val[1]); x0.val[0] = vld1q_f32(input_ptr1); x0.val[1] = vld1q_f32(input_ptr1 + 4); x1.val[0] = vextq_f32(x0.val[0], x0.val[1], 1); + x1.val[1] = vextq_f32(x0.val[1], x0.val[1], 1); x2.val[0] = vextq_f32(x0.val[0], x0.val[1], 2); + x2.val[1] = vextq_f32(x0.val[1], x0.val[1], 2); x0.val[0] = vPoolPreq_f32

(x0.val[0], x1.val[0]); + x0.val[1] = vPoolPreq_f32

(x0.val[1], x1.val[1]); x0.val[0] = vPoolPreq_f32

(x0.val[0], x2.val[0]); + x0.val[1] = vPoolPreq_f32

(x0.val[1], x2.val[1]); y0.val[0] = vPoolPreq_f32

(x0.val[0], y0.val[0]); + y0.val[1] = vPoolPreq_f32

(x0.val[1], y0.val[1]); x0.val[0] = vld1q_f32(input_ptr2); x0.val[1] = vld1q_f32(input_ptr2 + 4); x1.val[0] = vextq_f32(x0.val[0], x0.val[1], 1); + x1.val[1] = vextq_f32(x0.val[1], x0.val[1], 1); x2.val[0] = vextq_f32(x0.val[0], x0.val[1], 2); + x2.val[1] = vextq_f32(x0.val[1], x0.val[1], 2); x0.val[0] = vPoolPreq_f32

(x0.val[0], x1.val[0]); + x0.val[1] = vPoolPreq_f32

(x0.val[1], x1.val[1]); x0.val[0] = vPoolPreq_f32

(x0.val[0], x2.val[0]); + x0.val[1] = vPoolPreq_f32

(x0.val[1], x2.val[1]); y0.val[0] = vPoolPreq_f32

(x0.val[0], y0.val[0]); + y0.val[1] = vPoolPreq_f32

(x0.val[1], y0.val[1]); y0.val[0] = vPoolPostq_f32

(y0.val[0], post); - vst1q_f32(output_ptr0, y0.val[0]); - - input_ptr0 += 4; - input_ptr1 += 4; - input_ptr2 += 4; - output_ptr0 += 4; - remain -= 4; + y0.val[1] = vPoolPostq_f32

(y0.val[1], post); + // restore + switch (output_w_remain) { + case 1: + vst1q_lane_f32(output_ptr0, y0.val[0], 0); + break; + case 2: + vst1_f32(output_ptr0, vget_low_f32(y0.val[0])); + break; + case 3: + vst1_f32(output_ptr0, vget_low_f32(y0.val[0])); + vst1q_lane_f32(output_ptr0 + 2, y0.val[0], 2); + break; + case 4: + vst1q_f32(output_ptr0, y0.val[0]); + break; + case 5: + vst1q_f32(output_ptr0, y0.val[0]); + vst1q_lane_f32(output_ptr0 + 4, y0.val[1], 0); + break; + } + input_ptr0 += output_w_remain; + input_ptr1 += output_w_remain; + input_ptr2 += output_w_remain; + output_ptr0 += output_w_remain; } -#endif // __ARM_NEON__ - for (int r = 0; r < remain; ++r) { - float m0 = PoolPre

(input_ptr0[r], input_ptr0[r + 1]); - m0 = PoolPre

(m0, input_ptr0[r + 2]); - float m1 = PoolPre

(input_ptr1[r], input_ptr1[r + 1]); - m1 = PoolPre

(m1, input_ptr1[r + 2]); - float m2 = PoolPre

(input_ptr2[r], input_ptr2[r + 1]); - m2 = PoolPre

(m2, input_ptr2[r + 2]); - - m0 = PoolPre

(PoolPre

(m0, m1), m2); - output_ptr0[r] = PoolPost

(m0, avg); + // pad right + if (padding_w) { + float32x2_t row0 = vld1_f32(input_ptr0); + float32x2_t row1 = vld1_f32(input_ptr1); + float32x2_t row2 = vld1_f32(input_ptr2); + float32x2_t pad0 = vPoolInit_f32

<P>(); + float32x2_t acc0, post; + for (int w = valid_w_end; w < output_w; ++w) { + int padding = w + 3 - (padding_w + input_w); + if (padding >= 3) { + *output_ptr0 = 0.f; + } else { + post = vdup_n_f32(1.f / (3 * (3 - padding))); + acc0 = vPoolPre_f32<P>(row0, row1); + acc0 = vPoolPre_f32<P>(acc0, row2); + acc0 = vpPoolPre_f32<P>(acc0, acc0); + acc0 = vPoolPost_f32<P>
(acc0, post); + vst1_lane_f32(output_ptr0, acc0, 0); + row0 = vext_f32(row0, pad0, 1); + row1 = vext_f32(row1, pad0, 1); + row2 = vext_f32(row2, pad0, 1); + } + output_ptr0++; + } } } + // pad bottom + for (int h = valid_h_end; h < output_h; ++h) { + Pooling3x3NormalRow(input_ptr, h, input_h, input_w, padding_h, + padding_w, output_w, output_ptr); + } } } } @@ -667,12 +758,22 @@ struct Pooling3x3 { int image_size = input_h * input_w; int out_image_size = output_h * output_w; int valid_h_start = (padding_h + 1) / 2; - int valid_h = (input_h - 3) / 2 + 1; - int valid_h_end = valid_h_start + valid_h; + int valid_h_end = (input_h + padding_h - 1) / 2; + int valid_h = valid_h_end - valid_h_start; int valid_w_start = (padding_w + 1) / 2; - int valid_w = (input_w - 3) / 2 + 1; - int valid_w_end = valid_w_start + valid_w; - float avg = 1.f / 9; + int valid_w_end = (input_w + padding_w - 1) / 2; + int valid_w = valid_w_end - valid_w_start; + + int padding_height = input_h + 2 * padding_h; + int padding_width = input_w + 2 * padding_w; + bool ceil_mode = (((padding_height - 1) / 2) < output_h) || + (((padding_width - 1) / 2) < output_w); + int padding_b = + padding_h + (ceil_mode ? 2 * output_h - (padding_height - 1) : 0); + int padding_r = + padding_w + (ceil_mode ? 2 * output_w - (padding_width - 1) : 0); + // for pad left + int valid_input_w_start = (valid_w_start << 1) - padding_w; #pragma omp parallel for collapse(2) for (int batch = 0; batch < output->dims()[0]; ++batch) { @@ -685,41 +786,70 @@ struct Pooling3x3 { Pooling3x3NormalRow(input_ptr, h, input_h, input_w, padding_h, padding_w, output_w, output_ptr); } - // left - for (int w = 0; w < valid_w_start; ++w) { - Pooling3x3ValidCol(input_ptr, valid_h_start, valid_h_end, w, - input_h, input_w, padding_h, padding_w, - output_w, output_ptr); - } - // right - for (int w = valid_w_end; w < output_w; ++w) { - Pooling3x3ValidCol(input_ptr, valid_h_start, valid_h_end, w, - input_h, input_w, padding_h, padding_w, - output_w, output_ptr); - } - // bottom - for (int h = valid_h_end; h < output_h; ++h) { - Pooling3x3NormalRow(input_ptr, h, input_h, input_w, padding_h, - padding_w, output_w, output_ptr); - } // valid - int input_w_start = 2 * valid_w_start - padding_w; int output_w_tiles = valid_w / 6; int output_w_remain = valid_w - output_w_tiles * 6; for (int h = valid_h_start; h < valid_h_end - 2; h += 3) { - size_t offset = (2 * h - padding_h) * input_w + input_w_start; - const float *input_ptr0 = input_ptr + offset; + const float *input_ptr0 = input_ptr + (2 * h - padding_h) * input_w; const float *input_ptr1 = input_ptr0 + input_w; const float *input_ptr2 = input_ptr1 + input_w; const float *input_ptr3 = input_ptr2 + input_w; const float *input_ptr4 = input_ptr3 + input_w; const float *input_ptr5 = input_ptr4 + input_w; const float *input_ptr6 = input_ptr5 + input_w; - float *output_ptr0 = output_ptr + h * output_w + valid_w_start; + float *output_ptr0 = output_ptr + h * output_w; float *output_ptr1 = output_ptr0 + output_w; float *output_ptr2 = output_ptr1 + output_w; - int remain = output_w_remain; -#if defined(__ARM_NEON__) || defined(__ARM_NEON) + // pad left + if (padding_w) { + float32x2_t row0 = vld1_f32(input_ptr0); + float32x2_t row1 = vld1_f32(input_ptr1); + float32x2_t row2 = vld1_f32(input_ptr2); + float32x2_t row3 = vld1_f32(input_ptr3); + float32x2_t row4 = vld1_f32(input_ptr4); + float32x2_t row5 = vld1_f32(input_ptr5); + float32x2_t row6 = vld1_f32(input_ptr6); + float32x2_t pad0 = vPoolInit_f32

<P>(); + float32x2_t acc0, acc1, acc2, post; + for (int w = valid_w_start - 1; w >= 0; --w) { + int padding = padding_w - (w << 1); + if (padding >= 3) { + output_ptr0[w] = 0.f; + output_ptr1[w] = 0.f; + output_ptr2[w] = 0.f; + } else { + post = vdup_n_f32(1.f / (3 * (3 - padding))); + acc0 = vPoolPre_f32<P>(row0, row1); + acc1 = vPoolPre_f32<P>(row2, row3); + acc2 = vPoolPre_f32<P>(row4, row5); + acc0 = vPoolPre_f32<P>(acc0, row2); + acc1 = vPoolPre_f32<P>(acc1, row4); + acc2 = vPoolPre_f32<P>(acc2, row6); + if (padding == 1) { + acc0 = vpPoolPre_f32<P>(acc0, acc0); + acc1 = vpPoolPre_f32<P>(acc1, acc1); + acc2 = vpPoolPre_f32<P>(acc2, acc2); + } + acc0 = vPoolPost_f32<P>(acc0, post); + acc1 = vPoolPost_f32<P>(acc1, post); + acc2 = vPoolPost_f32<P>
(acc2, post); + vst1_lane_f32(output_ptr0 + w, acc0, 0); + vst1_lane_f32(output_ptr1 + w, acc1, 0); + vst1_lane_f32(output_ptr2 + w, acc2, 0); + } + } + input_ptr0 += valid_input_w_start; + input_ptr1 += valid_input_w_start; + input_ptr2 += valid_input_w_start; + input_ptr3 += valid_input_w_start; + input_ptr4 += valid_input_w_start; + input_ptr5 += valid_input_w_start; + input_ptr6 += valid_input_w_start; + output_ptr0 += valid_w_start; + output_ptr1 += valid_w_start; + output_ptr2 += valid_w_start; + } + // valid float32x4x2_t x0, x1, x2; float32x4x2_t y0, y1, y2; float32x4_t post = vdupq_n_f32(1.f / 9); @@ -823,108 +953,210 @@ struct Pooling3x3 { output_ptr2 += 6; } // remain width - if (remain >= 4) { + if (output_w_remain > 0) { x0 = vld2q_f32(input_ptr0); - x1.val[0] = vdupq_n_f32(input_ptr0[8]); + x1 = vld2q_f32(input_ptr0 + 8); x2.val[0] = vextq_f32(x0.val[0], x1.val[0], 1); + x2.val[1] = vextq_f32(x1.val[0], x1.val[0], 1); x0.val[0] = vPoolPreq_f32

(x0.val[0], x0.val[1]); + x0.val[1] = vPoolPreq_f32

(x1.val[0], x1.val[1]); y0.val[0] = vPoolPreq_f32

(x0.val[0], x2.val[0]); + y0.val[1] = vPoolPreq_f32

<P>(x0.val[1], x2.val[1]); x0 = vld2q_f32(input_ptr1); - x1.val[0] = vdupq_n_f32(input_ptr1[8]); + x1 = vld2q_f32(input_ptr1 + 8); x2.val[0] = vextq_f32(x0.val[0], x1.val[0], 1); + x2.val[1] = vextq_f32(x1.val[0], x1.val[0], 1); x0.val[0] = vPoolPreq_f32<P>(x0.val[0], x0.val[1]); + x0.val[1] = vPoolPreq_f32<P>(x1.val[0], x1.val[1]); x0.val[0] = vPoolPreq_f32<P>(x0.val[0], x2.val[0]); + x0.val[1] = vPoolPreq_f32<P>(x0.val[1], x2.val[1]); y0.val[0] = vPoolPreq_f32<P>(x0.val[0], y0.val[0]); + y0.val[1] = vPoolPreq_f32<P>
<P>(x0.val[1], y0.val[1]); x0 = vld2q_f32(input_ptr2); - x1.val[0] = vdupq_n_f32(input_ptr2[8]); + x1 = vld2q_f32(input_ptr2 + 8); x2.val[0] = vextq_f32(x0.val[0], x1.val[0], 1); + x2.val[1] = vextq_f32(x1.val[0], x1.val[0], 1); x0.val[0] = vPoolPreq_f32<P>(x0.val[0], x0.val[1]); + x0.val[1] = vPoolPreq_f32<P>(x1.val[0], x1.val[1]); y1.val[0] = vPoolPreq_f32<P>(x0.val[0], x2.val[0]); + y1.val[1] = vPoolPreq_f32<P>(x0.val[1], x2.val[1]); y0.val[0] = vPoolPreq_f32<P>(y1.val[0], y0.val[0]); + y0.val[1] = vPoolPreq_f32<P>(y1.val[1], y0.val[1]); y0.val[0] = vPoolPostq_f32<P>(y0.val[0], post); - vst1q_f32(output_ptr0, y0.val[0]); + y0.val[1] = vPoolPostq_f32<P>
<P>(y0.val[1], post); x0 = vld2q_f32(input_ptr3); - x1.val[0] = vdupq_n_f32(input_ptr3[8]); + x1 = vld2q_f32(input_ptr3 + 8); x2.val[0] = vextq_f32(x0.val[0], x1.val[0], 1); + x2.val[1] = vextq_f32(x1.val[0], x1.val[0], 1); x0.val[0] = vPoolPreq_f32<P>(x0.val[0], x0.val[1]); + x0.val[1] = vPoolPreq_f32<P>(x1.val[0], x1.val[1]); x0.val[0] = vPoolPreq_f32<P>(x0.val[0], x2.val[0]); + x0.val[1] = vPoolPreq_f32<P>(x0.val[1], x2.val[1]); y1.val[0] = vPoolPreq_f32<P>(x0.val[0], y1.val[0]); + y1.val[1] = vPoolPreq_f32<P>(x0.val[1], y1.val[1]); x0 = vld2q_f32(input_ptr4); - x1.val[0] = vdupq_n_f32(input_ptr4[8]); + x1 = vld2q_f32(input_ptr4 + 8); x2.val[0] = vextq_f32(x0.val[0], x1.val[0], 1); + x2.val[1] = vextq_f32(x1.val[0], x1.val[0], 1); x0.val[0] = vPoolPreq_f32<P>(x0.val[0], x0.val[1]); - y0.val[0] = vPoolPreq_f32<P>(x0.val[0], x2.val[0]); - y1.val[0] = vPoolPreq_f32<P>
<P>(y0.val[0], y1.val[0]); + x0.val[1] = vPoolPreq_f32<P>(x1.val[0], x1.val[1]); + y2.val[0] = vPoolPreq_f32<P>(x0.val[0], x2.val[0]); + y2.val[1] = vPoolPreq_f32<P>(x0.val[1], x2.val[1]); + y1.val[0] = vPoolPreq_f32<P>(y2.val[0], y1.val[0]); + y1.val[1] = vPoolPreq_f32<P>(y2.val[1], y1.val[1]); y1.val[0] = vPoolPostq_f32<P>(y1.val[0], post); - vst1q_f32(output_ptr1, y1.val[0]); + y1.val[1] = vPoolPostq_f32<P>
<P>(y1.val[1], post); x0 = vld2q_f32(input_ptr5); - x1.val[0] = vdupq_n_f32(input_ptr5[8]); + x1 = vld2q_f32(input_ptr5 + 8); x2.val[0] = vextq_f32(x0.val[0], x1.val[0], 1); + x2.val[1] = vextq_f32(x1.val[0], x1.val[0], 1); x0.val[0] = vPoolPreq_f32<P>(x0.val[0], x0.val[1]); + x0.val[1] = vPoolPreq_f32<P>(x1.val[0], x1.val[1]); x0.val[0] = vPoolPreq_f32<P>(x0.val[0], x2.val[0]); - y0.val[0] = vPoolPreq_f32<P>(x0.val[0], y0.val[0]); + x0.val[1] = vPoolPreq_f32<P>(x0.val[1], x2.val[1]); + y2.val[0] = vPoolPreq_f32<P>(x0.val[0], y2.val[0]); + y2.val[1] = vPoolPreq_f32<P>(x0.val[1], y2.val[1]); x0 = vld2q_f32(input_ptr6); - x1.val[0] = vdupq_n_f32(input_ptr6[8]); + x1 = vld2q_f32(input_ptr6 + 8); x2.val[0] = vextq_f32(x0.val[0], x1.val[0], 1); + x2.val[1] = vextq_f32(x1.val[0], x1.val[0], 1); x0.val[0] = vPoolPreq_f32<P>
<P>(x0.val[0], x0.val[1]); + x0.val[1] = vPoolPreq_f32<P>(x1.val[0], x1.val[1]); x0.val[0] = vPoolPreq_f32<P>(x0.val[0], x2.val[0]); - y0.val[0] = vPoolPreq_f32<P>(x0.val[0], y0.val[0]); - y0.val[0] = vPoolPostq_f32<P>(y0.val[0], post); - vst1q_f32(output_ptr2, y0.val[0]); + x0.val[1] = vPoolPreq_f32<P>(x0.val[1], x2.val[1]); + y2.val[0] = vPoolPreq_f32<P>(x0.val[0], y2.val[0]); + y2.val[1] = vPoolPreq_f32<P>(x0.val[1], y2.val[1]); + y2.val[0] = vPoolPostq_f32<P>(y2.val[0], post); + y2.val[1] = vPoolPostq_f32<P>
(y2.val[1], post); - input_ptr0 += 8; - input_ptr1 += 8; - input_ptr2 += 8; - input_ptr3 += 8; - input_ptr4 += 8; - input_ptr5 += 8; - input_ptr6 += 8; - output_ptr0 += 4; - output_ptr1 += 4; - output_ptr2 += 4; - remain -= 4; + switch (output_w_remain) { + case 1: + vst1q_lane_f32(output_ptr0, y0.val[0], 0); + vst1q_lane_f32(output_ptr1, y1.val[0], 0); + vst1q_lane_f32(output_ptr2, y2.val[0], 0); + break; + case 2: + vst1_f32(output_ptr0, vget_low_f32(y0.val[0])); + vst1_f32(output_ptr1, vget_low_f32(y1.val[0])); + vst1_f32(output_ptr2, vget_low_f32(y2.val[0])); + break; + case 3: + vst1_f32(output_ptr0, vget_low_f32(y0.val[0])); + vst1_f32(output_ptr1, vget_low_f32(y1.val[0])); + vst1_f32(output_ptr2, vget_low_f32(y2.val[0])); + vst1q_lane_f32(output_ptr0 + 2, y0.val[0], 2); + vst1q_lane_f32(output_ptr1 + 2, y1.val[0], 2); + vst1q_lane_f32(output_ptr2 + 2, y2.val[0], 2); + break; + case 4: + vst1q_f32(output_ptr0, y0.val[0]); + vst1q_f32(output_ptr1, y1.val[0]); + vst1q_f32(output_ptr2, y2.val[0]); + break; + case 5: + vst1q_f32(output_ptr0, y0.val[0]); + vst1q_f32(output_ptr1, y1.val[0]); + vst1q_f32(output_ptr2, y2.val[0]); + vst1q_lane_f32(output_ptr0 + 4, y0.val[1], 0); + vst1q_lane_f32(output_ptr1 + 4, y1.val[1], 0); + vst1q_lane_f32(output_ptr2 + 4, y2.val[1], 0); + break; + } + input_ptr0 += (output_w_remain << 1); + input_ptr1 += (output_w_remain << 1); + input_ptr2 += (output_w_remain << 1); + input_ptr3 += (output_w_remain << 1); + input_ptr4 += (output_w_remain << 1); + input_ptr5 += (output_w_remain << 1); + input_ptr6 += (output_w_remain << 1); + output_ptr0 += output_w_remain; + output_ptr1 += output_w_remain; + output_ptr2 += output_w_remain; } -#endif // __ARM_NEON__ - for (int r = 0; r < remain; ++r) { - float m0 = PoolPre

<P>(input_ptr0[2 * r], input_ptr0[2 * r + 1]); - m0 = PoolPre<P>(m0, input_ptr0[2 * r + 2]); - float m1 = PoolPre<P>(input_ptr1[2 * r], input_ptr1[2 * r + 1]); - m1 = PoolPre<P>(m1, input_ptr1[2 * r + 2]); - float m2 = PoolPre<P>(input_ptr2[2 * r], input_ptr2[2 * r + 1]); - m2 = PoolPre<P>(m2, input_ptr2[2 * r + 2]); - float m3 = PoolPre<P>(input_ptr3[2 * r], input_ptr3[2 * r + 1]); - m3 = PoolPre<P>(m3, input_ptr3[2 * r + 2]); - float m4 = PoolPre<P>(input_ptr4[2 * r], input_ptr4[2 * r + 1]); - m4 = PoolPre<P>(m4, input_ptr4[2 * r + 2]); - float m5 = PoolPre<P>(input_ptr5[2 * r], input_ptr5[2 * r + 1]); - m5 = PoolPre<P>(m5, input_ptr5[2 * r + 2]); - float m6 = PoolPre<P>(input_ptr6[2 * r], input_ptr6[2 * r + 1]); - m6 = PoolPre<P>(m6, input_ptr6[2 * r + 2]); - - m0 = PoolPre<P>(PoolPre<P>(m0, m1), m2); - m1 = PoolPre<P>(PoolPre<P>(m2, m3), m4); - m2 = PoolPre<P>(PoolPre<P>(m4, m5), m6); - output_ptr0[r] = PoolPost<P>(m0, avg); - output_ptr1[r] = PoolPost<P>(m1, avg); - output_ptr2[r] = PoolPost<P>
(m2, avg); + // pad right + if (padding_r > 0) { + float32x2_t row0 = vld1_f32(input_ptr0); + float32x2_t row1 = vld1_f32(input_ptr1); + float32x2_t row2 = vld1_f32(input_ptr2); + float32x2_t row3 = vld1_f32(input_ptr3); + float32x2_t row4 = vld1_f32(input_ptr4); + float32x2_t row5 = vld1_f32(input_ptr5); + float32x2_t row6 = vld1_f32(input_ptr6); + float32x2_t pad0 = vPoolInit_f32

<P>(); + float32x2_t acc0, acc1, acc2, post; + for (int w = valid_w_end; w < output_w; ++w) { + int padding = 2 * w + 3 - (padding_w + input_w); + if (padding >= 3) { + *output_ptr0 = 0.f; + *output_ptr1 = 0.f; + *output_ptr2 = 0.f; + } else { + post = vdup_n_f32(1.f / (3 * (3 - padding))); + acc0 = vPoolPre_f32<P>(row0, row1); + acc1 = vPoolPre_f32<P>(row2, row3); + acc2 = vPoolPre_f32<P>(row4, row5); + acc0 = vPoolPre_f32<P>(acc0, row2); + acc1 = vPoolPre_f32<P>(acc1, row4); + acc2 = vPoolPre_f32<P>(acc2, row6); + if (padding == 1) { + acc0 = vpPoolPre_f32<P>(acc0, acc0); + acc1 = vpPoolPre_f32<P>(acc1, acc1); + acc2 = vpPoolPre_f32<P>(acc2, acc2); + } + acc0 = vPoolPost_f32<P>(acc0, post); + acc1 = vPoolPost_f32<P>(acc1, post); + acc2 = vPoolPost_f32<P>
(acc2, post); + vst1_lane_f32(output_ptr0, acc0, 0); + vst1_lane_f32(output_ptr1, acc1, 0); + vst1_lane_f32(output_ptr2, acc2, 0); + } + output_ptr0++; + output_ptr1++; + output_ptr2++; + } } } // remain height int start_h = valid_h_start + valid_h / 3 * 3; for (int h = start_h; h < valid_h_end; ++h) { - size_t offset = (2 * h - padding_h) * input_w + input_w_start; - const float *input_ptr0 = input_ptr + offset; + const float *input_ptr0 = input_ptr + (2 * h - padding_h) * input_w; const float *input_ptr1 = input_ptr0 + input_w; const float *input_ptr2 = input_ptr1 + input_w; - float *output_ptr0 = output_ptr + h * output_w + valid_w_start; - int remain = output_w_remain; -#if defined(__ARM_NEON__) || defined(__ARM_NEON) + float *output_ptr0 = output_ptr + h * output_w; + // pad left + if (padding_w) { + float32x2_t row0 = vld1_f32(input_ptr0); + float32x2_t row1 = vld1_f32(input_ptr1); + float32x2_t row2 = vld1_f32(input_ptr2); + float32x2_t pad0 = vPoolInit_f32

<P>(); + float32x2_t acc0, post; + for (int w = valid_w_start - 1; w >= 0; --w) { + int padding = padding_w - (w << 1); + if (padding >= 3) { + output_ptr0[w] = 0.f; + } else { + post = vdup_n_f32(1.f / (3 * (3 - padding))); + acc0 = vPoolPre_f32<P>(row0, row1); + acc0 = vPoolPre_f32<P>(acc0, row2); + if (padding == 1) { + acc0 = vpPoolPre_f32<P>(acc0, acc0); + } + acc0 = vPoolPost_f32<P>
(acc0, post); + vst1_lane_f32(output_ptr0 + w, acc0, 0); + } + } + input_ptr0 += valid_input_w_start; + input_ptr1 += valid_input_w_start; + input_ptr2 += valid_input_w_start; + output_ptr0 += valid_w_start; + } + // valid float32x4x2_t x0, x1, x2, y0; float32x4_t post = vdupq_n_f32(1.f / 9); for (int loop = 0; loop < output_w_tiles; ++loop) { @@ -969,48 +1201,94 @@ struct Pooling3x3 { output_ptr0 += 6; } // remain width - if (remain >= 4) { + if (output_w_remain > 0) { x0 = vld2q_f32(input_ptr0); - x1.val[0] = vdupq_n_f32(input_ptr0[8]); + x1 = vld2q_f32(input_ptr0 + 8); x2.val[0] = vextq_f32(x0.val[0], x1.val[0], 1); + x2.val[1] = vextq_f32(x1.val[0], x1.val[0], 1); x0.val[0] = vPoolPreq_f32

<P>(x0.val[0], x0.val[1]); + x0.val[1] = vPoolPreq_f32<P>(x1.val[0], x1.val[1]); y0.val[0] = vPoolPreq_f32<P>(x0.val[0], x2.val[0]); + y0.val[1] = vPoolPreq_f32<P>(x0.val[1], x2.val[1]); x0 = vld2q_f32(input_ptr1); - x1.val[0] = vdupq_n_f32(input_ptr1[8]); + x1 = vld2q_f32(input_ptr1 + 8); x2.val[0] = vextq_f32(x0.val[0], x1.val[0], 1); + x2.val[1] = vextq_f32(x1.val[0], x1.val[0], 1); x0.val[0] = vPoolPreq_f32<P>(x0.val[0], x0.val[1]); + x0.val[1] = vPoolPreq_f32<P>(x1.val[0], x1.val[1]); x0.val[0] = vPoolPreq_f32<P>(x0.val[0], x2.val[0]); + x0.val[1] = vPoolPreq_f32<P>(x0.val[1], x2.val[1]); y0.val[0] = vPoolPreq_f32<P>(x0.val[0], y0.val[0]); + y0.val[1] = vPoolPreq_f32<P>(x0.val[1], y0.val[1]); x0 = vld2q_f32(input_ptr2); - x1.val[0] = vdupq_n_f32(input_ptr2[8]); + x1 = vld2q_f32(input_ptr2 + 8); x2.val[0] = vextq_f32(x0.val[0], x1.val[0], 1); + x2.val[1] = vextq_f32(x1.val[0], x1.val[0], 1); x0.val[0] = vPoolPreq_f32<P>(x0.val[0], x0.val[1]); + x0.val[1] = vPoolPreq_f32<P>(x1.val[0], x1.val[1]); x0.val[0] = vPoolPreq_f32<P>(x0.val[0], x2.val[0]); + x0.val[1] = vPoolPreq_f32<P>(x0.val[1], x2.val[1]); y0.val[0] = vPoolPreq_f32<P>(x0.val[0], y0.val[0]); + y0.val[1] = vPoolPreq_f32<P>(x0.val[1], y0.val[1]); y0.val[0] = vPoolPostq_f32<P>(y0.val[0], post); - vst1q_f32(output_ptr0, y0.val[0]); - - input_ptr0 += 8; - input_ptr1 += 8; - input_ptr2 += 8; - output_ptr0 += 4; - remain -= 4; + y0.val[1] = vPoolPostq_f32<P>
(y0.val[1], post); + // restore + switch (output_w_remain) { + case 1: + vst1q_lane_f32(output_ptr0, y0.val[0], 0); + break; + case 2: + vst1_f32(output_ptr0, vget_low_f32(y0.val[0])); + break; + case 3: + vst1_f32(output_ptr0, vget_low_f32(y0.val[0])); + vst1q_lane_f32(output_ptr0 + 2, y0.val[0], 2); + break; + case 4: + vst1q_f32(output_ptr0, y0.val[0]); + break; + case 5: + vst1q_f32(output_ptr0, y0.val[0]); + vst1q_lane_f32(output_ptr0 + 4, y0.val[1], 0); + break; + } + input_ptr0 += (output_w_remain << 1); + input_ptr1 += (output_w_remain << 1); + input_ptr2 += (output_w_remain << 1); + output_ptr0 += output_w_remain; } -#endif // __ARM_NEON__ - for (int r = 0; r < remain; ++r) { - float m0 = PoolPre

<P>(input_ptr0[2 * r], input_ptr0[2 * r + 1]); - m0 = PoolPre<P>(m0, input_ptr0[2 * r + 2]); - float m1 = PoolPre<P>(input_ptr1[2 * r], input_ptr1[2 * r + 1]); - m1 = PoolPre<P>(m1, input_ptr1[2 * r + 2]); - float m2 = PoolPre<P>(input_ptr2[2 * r], input_ptr2[2 * r + 1]); - m2 = PoolPre<P>(m2, input_ptr2[2 * r + 2]); - - m0 = PoolPre<P>(PoolPre<P>(m0, m1), m2); - output_ptr0[r] = PoolPost<P>(m0, avg); + // pad right + if (padding_r > 0) { + float32x2_t row0 = vld1_f32(input_ptr0); + float32x2_t row1 = vld1_f32(input_ptr1); + float32x2_t row2 = vld1_f32(input_ptr2); + float32x2_t pad0 = vPoolInit_f32<P>(); + float32x2_t acc0, post; + for (int w = valid_w_end; w < output_w; ++w) { + int padding = 2 * w + 3 - (padding_w + input_w); + if (padding >= 3) { + *output_ptr0 = 0.f; + } else { + post = vdup_n_f32(1.f / (3 * (3 - padding))); + acc0 = vPoolPre_f32<P>(row0, row1); + acc0 = vPoolPre_f32<P>(acc0, row2); + if (padding == 1) { + acc0 = vpPoolPre_f32<P>(acc0, acc0); + } + acc0 = vPoolPost_f32<P>
(acc0, post); + vst1_lane_f32(output_ptr0, acc0, 0); + } + output_ptr0++; + } } } + // bottom + for (int h = valid_h_end; h < output_h; ++h) { + Pooling3x3NormalRow(input_ptr, h, input_h, input_w, padding_h, + padding_w, output_w, output_ptr); + } } } } @@ -1025,4 +1303,5 @@ template struct Pooling3x3; } // namespace operators } // namespace paddle_mobile +#endif // __ARM_NEON #endif // POOL_OP diff --git a/src/operators/math/quantize.h b/src/operators/math/quantize.h index b6e9d1a24d0d760a3eadbeec84d8c855072f6784..9f9e91330cd9645eb1fbc2fd01192eda4ef0fc7a 100644 --- a/src/operators/math/quantize.h +++ b/src/operators/math/quantize.h @@ -56,6 +56,9 @@ inline int32x4_t vRoundq_f32(const float32x4_t &x) { template <> inline int32x4_t vRoundq_f32(const float32x4_t &x) { +#if __aarch64__ + return vcvtaq_s32_f32(x); +#else float32x4_t plus = vdupq_n_f32(0.5); float32x4_t minus = vdupq_n_f32(-0.5); float32x4_t zero = vdupq_n_f32(0); @@ -64,10 +67,14 @@ inline int32x4_t vRoundq_f32(const float32x4_t &x) { temp = vaddq_f32(x, temp); int32x4_t ret = vcvtq_s32_f32(temp); return ret; +#endif } template <> inline int32x4_t vRoundq_f32(const float32x4_t &x) { +#if __aarch64__ + return vcvtnq_s32_f32(x); +#else float32x4_t point5 = vdupq_n_f32(0.5); int32x4_t one = vdupq_n_s32(1); int32x4_t zero = vdupq_n_s32(0); @@ -90,6 +97,7 @@ inline int32x4_t vRoundq_f32(const float32x4_t &x) { smask = vsubq_s32(smask, one); rnd = vaddq_s32(rnd, smask); return rnd; +#endif } #endif // __ARM_NEON__ diff --git a/src/operators/op_param.h b/src/operators/op_param.h index 0362ee44454569980f017c2b527acf8b77e2ff10..9d7c213afa8277c421c0e6cce6cdaefa5ef58dd9 100644 --- a/src/operators/op_param.h +++ b/src/operators/op_param.h @@ -424,8 +424,10 @@ class ConvParam : public OpParam { EXEC_DEPTHWISE3x3_FLOAT, EXEC_WINOGRAD3X3_FLOAT, EXEC_WINOGRAD5X5_FLOAT, + EXEC_DEPTHWISE5x5_FLOAT, EXEC_GEMM_INT8, EXEC_DEPTHWISE3x3_INT8, + EXEC_DEPTHWISE5x5_INT8, }; ExecMode &ExecMode() const { return exec_mode_; } @@ -2605,8 +2607,8 @@ class QuantizeParam : public OpParam { // if offine scale or not bool offline_ = false; // round method type - // RoundType round_type_ = ROUND_NEAREST_AWAY_ZERO; - RoundType round_type_ = ROUND_NEAREST_TOWARDS_ZERO; + RoundType round_type_ = ROUND_NEAREST_AWAY_ZERO; + // RoundType round_type_ = ROUND_NEAREST_TOWARDS_ZERO; }; #endif diff --git a/test/operators/test_conv_op.cpp b/test/operators/test_conv_op.cpp index fdc516a001489e56b020a8901f1ffa6dafe029dc..c596c1def4006853532395f151c6e9c47cf8e3e8 100644 --- a/test/operators/test_conv_op.cpp +++ b/test/operators/test_conv_op.cpp @@ -165,14 +165,12 @@ int TestConvOp(int in_channels, int in_height, int in_width, int out_channels, auto filter = filter_var->template GetMutable(); SetupTensor(filter, filter_shape, -20, 20); - for (int i = 0; i < input->numel(); ++i) { - DLOG << "input[" << i - << "] = " << static_cast(input->data()[i]); - } - for (int i = 0; i < filter->numel(); ++i) { - DLOG << "filter[" << i - << "] = " << static_cast(filter->data()[i]); - } + // for (int i = 0; i < input->numel(); ++i) { + // DLOG << "input[" << i << "] = " << float(input->data()[i]); + // } + // for (int i = 0; i < filter->numel(); ++i) { + // DLOG << "filter[" << i << "] = " << float(filter->data()[i]); + // } auto output_var = scope.get()->Var("output"); framework::AttributeMap attrs; @@ -198,18 +196,12 @@ int TestConvOp(int in_channels, int in_height, int in_width, int out_channels, // (ts_end.tv_nsec - ts_begin.tv_nsec) / 1e6; // LOG(kLOG_INFO) << "elapsed: " << elapsed / 10.0 
<< " ms"; - int kernel_extent_h = dilation_h * (kernel_h - 1) + 1; - int kernel_extent_w = dilation_w * (kernel_w - 1) + 1; - int output_h = (input_h + 2 * pad_h - kernel_extent_h) / stride_h + 1; - int output_w = (input_w + 2 * pad_w - kernel_extent_w) / stride_w + 1; - auto output_shape = framework::make_ddim( - std::vector({batch_size, output_c, output_h, output_w})); + // compare results + auto *output = output_var->template Get(); framework::Tensor output_cmp; - output_cmp.mutable_data(output_shape); + output_cmp.mutable_data(output->dims()); conv2d(input, filter, attrs, &output_cmp); - // compare results - auto output = output_var->template Get(); const Otype *output_data = output->data(); Otype *output_cmp_data = output_cmp.data(); for (int i = 0; i < output->numel(); ++i) { @@ -285,96 +277,39 @@ int main(int argc, char *argv[]) { paddle_mobile::TestConvOp( in_channels, in_height, in_width, out_channels, groups); - // // kernel = 7, pad = 0, stride = 2 - // LOG(paddle_mobile::kLOG_INFO) << "int8, kernel=7, pad=0, stride=2"; - // paddle_mobile::TestConvOp(in_channels, - // in_height, - // in_width, - // out_channels, groups); - // // kernel = 7, pad = 1, stride = 2 - // LOG(paddle_mobile::kLOG_INFO) << "int8, kernel=7, pad=1, stride=2"; - // paddle_mobile::TestConvOp(in_channels, - // in_height, - // in_width, - // out_channels, groups); - // // kernel = 7, pad = 3, stride = 2 - // LOG(paddle_mobile::kLOG_INFO) << "int8, kernel=7, pad=3, stride=2"; - // paddle_mobile::TestConvOp(in_channels, - // in_height, - // in_width, - // out_channels, groups); - // // kernel = 7, pad = 0, stride = 1 - // LOG(paddle_mobile::kLOG_INFO) << "int8, kernel=7, pad=0, stride=1"; - // paddle_mobile::TestConvOp(in_channels, - // in_height, - // in_width, - // out_channels, groups); - // // kernel = 7, pad = 1, stride = 1 - // LOG(paddle_mobile::kLOG_INFO) << "int8, kernel=7, pad=1, stride=1"; - // paddle_mobile::TestConvOp(in_channels, - // in_height, - // in_width, - // out_channels, groups); - // // kernel = 7, pad = 3, stride = 1 - // LOG(paddle_mobile::kLOG_INFO) << "int8, kernel=7, pad=3, stride=1"; - // paddle_mobile::TestConvOp(in_channels, - // in_height, - // in_width, - // out_channels, groups); - // // kernel = 7, pad = 5, stride = 3 - // LOG(paddle_mobile::kLOG_INFO) << "int8, kernel=7, pad=5, stride=3"; - // paddle_mobile::TestConvOp(in_channels, - // in_height, - // in_width, - // out_channels, groups); - // // kernel = 7, pad = 3, stride = 4 - // LOG(paddle_mobile::kLOG_INFO) << "int8, kernel=7, pad=3, stride=4"; - // paddle_mobile::TestConvOp(in_channels, - // in_height, - // in_width, - // out_channels, groups); - // // kernel = 3, pad = 0, stride = 1 - // LOG(paddle_mobile::kLOG_INFO) << "int8, kernel=3, pad=0, stride=1"; - // paddle_mobile::TestConvOp(in_channels, - // in_height, - // in_width, - // out_channels, groups); - // // kernel = 3, pad = 0, stride = 1 - // LOG(paddle_mobile::kLOG_INFO) << "float, kernel=3, pad=0, stride=1"; - // paddle_mobile::TestConvOp(in_channels, in_height, - // in_width, out_channels, - // groups); - // // kernel = 3, pad = 1, stride = 1 - // LOG(paddle_mobile::kLOG_INFO) << "int8, kernel=3, pad=1, stride=1"; - // paddle_mobile::TestConvOp(in_channels, - // in_height, - // in_width, - // out_channels, groups); - // // kernel = 3, pad = 1, stride = 1 - // LOG(paddle_mobile::kLOG_INFO) << "float, kernel=3, pad=1, stride=1"; - // paddle_mobile::TestConvOp(in_channels, in_height, - // in_width, out_channels, - // groups); - // // kernel = 5, pad = 0, stride 
= 1 - // LOG(paddle_mobile::kLOG_INFO) << "int8, kernel=5, pad=0, stride=1"; - // paddle_mobile::TestConvOp(in_channels, - // in_height, - // in_width, - // out_channels, groups); - // // kernel = 5, pad = 0, stride = 1 - // LOG(paddle_mobile::kLOG_INFO) << "float, kernel=5, pad=0, stride=1"; - // paddle_mobile::TestConvOp(in_channels, in_height, - // in_width, out_channels, - // groups); - // // kernel = 5, pad = 2, stride = 1 - // LOG(paddle_mobile::kLOG_INFO) << "int8, kernel=5, pad=2, stride=1"; - // paddle_mobile::TestConvOp(in_channels, - // in_height, - // in_width, - // out_channels, groups); - // // kernel = 5, pad = 2, stride = 1 - // LOG(paddle_mobile::kLOG_INFO) << "float, kernel=5, pad=2, stride=1"; - // paddle_mobile::TestConvOp(in_channels, in_height, - // in_width, out_channels, - // groups); + // kernel = 5, pad = 0, stride = 1 + LOG(paddle_mobile::kLOG_INFO) << "float, kernel=5, pad=0, stride=1"; + paddle_mobile::TestConvOp( + in_channels, in_height, in_width, out_channels, groups); + // kernel = 5, pad = 1, stride = 1 + LOG(paddle_mobile::kLOG_INFO) << "float, kernel=5, pad=1, stride=1"; + paddle_mobile::TestConvOp( + in_channels, in_height, in_width, out_channels, groups); + // kernel = 5, pad = 2, stride = 1 + LOG(paddle_mobile::kLOG_INFO) << "float, kernel=5, pad=2, stride=1"; + paddle_mobile::TestConvOp( + in_channels, in_height, in_width, out_channels, groups); + // kernel = 5, pad = 5, stride = 1 + LOG(paddle_mobile::kLOG_INFO) << "float, kernel=5, pad=5, stride=1"; + paddle_mobile::TestConvOp( + in_channels, in_height, in_width, out_channels, groups); + + // kernel = 5, pad = 0, stride = 1 + LOG(paddle_mobile::kLOG_INFO) << "int8, kernel=5, pad=0, stride=1"; + paddle_mobile::TestConvOp( + in_channels, in_height, in_width, out_channels, groups); + // kernel = 5, pad = 1, stride = 1 + LOG(paddle_mobile::kLOG_INFO) << "int8, kernel=5, pad=1, stride=1"; + paddle_mobile::TestConvOp( + in_channels, in_height, in_width, out_channels, groups); + // kernel = 5, pad = 2, stride = 1 + LOG(paddle_mobile::kLOG_INFO) << "int8, kernel=5, pad=2, stride=1"; + paddle_mobile::TestConvOp( + in_channels, in_height, in_width, out_channels, groups); + // kernel = 5, pad = 5, stride = 1 + LOG(paddle_mobile::kLOG_INFO) << "int8, kernel=5, pad=5, stride=1"; + paddle_mobile::TestConvOp( + in_channels, in_height, in_width, out_channels, groups); + + return 0; } diff --git a/test/operators/test_pool_op.cpp b/test/operators/test_pool_op.cpp index 5d3c4374a403e0f3050d9b9babd3d09bdff03bc9..acbf0eaf34c8cb7b35a94fd4e8a4a3867a7c1dff 100644 --- a/test/operators/test_pool_op.cpp +++ b/test/operators/test_pool_op.cpp @@ -169,28 +169,55 @@ int main(int argc, char *argv[]) { << "float, pooling_type=avg, kernel=3, pad=5, stride=2"; paddle_mobile::TestPoolOp<1, 3, 5, 2>(in_channels, in_height, in_width); - // // kernel = 5, pad = 0, stride = 1 - // LOG(paddle_mobile::kLOG_INFO) - // << "float, ceil_mode=false, pooling_type=avg, kernel=5, pad=0, - // stride=1"; - // paddle_mobile::TestPoolOp(in_channels, in_height, - // in_width); - // // kernel = 5, pad = 0, stride = 2 - // LOG(paddle_mobile::kLOG_INFO) - // << "float, ceil_mode=false, pooling_type=avg, kernel=5, pad=0, - // stride=1"; - // paddle_mobile::TestPoolOp(in_channels, in_height, - // in_width); - // // kernel = 7, pad = 0, stride = 1 - // LOG(paddle_mobile::kLOG_INFO) - // << "float, ceil_mode=false, pooling_type=avg, kernel=7, pad=0, - // stride=1"; - // paddle_mobile::TestPoolOp(in_channels, in_height, - // in_width); - // // kernel = 7, pad 
= 0, stride = 4 - // LOG(paddle_mobile::kLOG_INFO) - // << "float, ceil_mode=false, pooling_type=avg, kernel=7, pad=0, - // stride=4"; - // paddle_mobile::TestPoolOp(in_channels, in_height, - // in_width); + // LOG(paddle_mobile::kLOG_INFO) + // << "float, pooling_type=max, kernel=2, pad=0, stride=1"; + // paddle_mobile::TestPoolOp<0, 2, 0, 1>(in_channels, in_height, in_width); + // LOG(paddle_mobile::kLOG_INFO) + // << "float, pooling_type=max, kernel=2, pad=1, stride=1"; + // paddle_mobile::TestPoolOp<0, 2, 1, 1>(in_channels, in_height, in_width); + // LOG(paddle_mobile::kLOG_INFO) + // << "float, pooling_type=max, kernel=2, pad=2, stride=1"; + // paddle_mobile::TestPoolOp<0, 2, 2, 1>(in_channels, in_height, in_width); + // LOG(paddle_mobile::kLOG_INFO) + // << "float, pooling_type=max, kernel=2, pad=5, stride=1"; + // paddle_mobile::TestPoolOp<0, 2, 5, 1>(in_channels, in_height, in_width); + // + // LOG(paddle_mobile::kLOG_INFO) + // << "float, pooling_type=avg, kernel=2, pad=0, stride=1"; + // paddle_mobile::TestPoolOp<1, 2, 0, 1>(in_channels, in_height, in_width); + // LOG(paddle_mobile::kLOG_INFO) + // << "float, pooling_type=avg, kernel=2, pad=1, stride=1"; + // paddle_mobile::TestPoolOp<1, 2, 1, 1>(in_channels, in_height, in_width); + // LOG(paddle_mobile::kLOG_INFO) + // << "float, pooling_type=avg, kernel=2, pad=2, stride=1"; + // paddle_mobile::TestPoolOp<1, 2, 2, 1>(in_channels, in_height, in_width); + // LOG(paddle_mobile::kLOG_INFO) + // << "float, pooling_type=avg, kernel=2, pad=5, stride=1"; + // paddle_mobile::TestPoolOp<1, 2, 5, 1>(in_channels, in_height, in_width); + // + // LOG(paddle_mobile::kLOG_INFO) + // << "float, pooling_type=max, kernel=2, pad=0, stride=2"; + // paddle_mobile::TestPoolOp<0, 2, 0, 2>(in_channels, in_height, in_width); + // LOG(paddle_mobile::kLOG_INFO) + // << "float, pooling_type=max, kernel=2, pad=1, stride=2"; + // paddle_mobile::TestPoolOp<0, 2, 1, 2>(in_channels, in_height, in_width); + // LOG(paddle_mobile::kLOG_INFO) + // << "float, pooling_type=max, kernel=2, pad=2, stride=2"; + // paddle_mobile::TestPoolOp<0, 2, 2, 2>(in_channels, in_height, in_width); + // LOG(paddle_mobile::kLOG_INFO) + // << "float, pooling_type=max, kernel=2, pad=5, stride=2"; + // paddle_mobile::TestPoolOp<0, 2, 5, 2>(in_channels, in_height, in_width); + // + // LOG(paddle_mobile::kLOG_INFO) + // << "float, pooling_type=avg, kernel=2, pad=0, stride=2"; + // paddle_mobile::TestPoolOp<1, 2, 0, 2>(in_channels, in_height, in_width); + // LOG(paddle_mobile::kLOG_INFO) + // << "float, pooling_type=avg, kernel=2, pad=1, stride=2"; + // paddle_mobile::TestPoolOp<1, 2, 1, 2>(in_channels, in_height, in_width); + // LOG(paddle_mobile::kLOG_INFO) + // << "float, pooling_type=avg, kernel=2, pad=2, stride=2"; + // paddle_mobile::TestPoolOp<1, 2, 2, 2>(in_channels, in_height, in_width); + // LOG(paddle_mobile::kLOG_INFO) + // << "float, pooling_type=avg, kernel=2, pad=5, stride=2"; + // paddle_mobile::TestPoolOp<1, 2, 5, 2>(in_channels, in_height, in_width); }
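
Note on the border handling added in the pooling3x3 hunks above: for output columns whose 3x3, stride-2 window overhangs the input, the new code scales the reduction by 1.f / (3 * (3 - padding)), i.e. it averages only the in-bounds elements, and writes 0.f when the window lies entirely outside the input (padding >= 3). A minimal scalar sketch of that rule for the average case; the function name and arguments are made up for illustration and are not code from this patch:

    // Average over a 3x3 window whose last `pad_cols` columns (0, 1 or 2)
    // fall outside the input; only in-bounds values are summed and counted,
    // mirroring the 1.f / (3 * (3 - padding)) factor in the NEON path.
    inline float AvgPool3x3Edge(const float *row0, const float *row1,
                                const float *row2, int pad_cols) {
      const int valid_cols = 3 - pad_cols;
      float sum = 0.f;
      for (int c = 0; c < valid_cols; ++c) {
        sum += row0[c] + row1[c] + row2[c];
      }
      return sum / static_cast<float>(3 * valid_cols);
    }

The pairwise vpPoolPre_f32 step follows the same idea: with one column of padding, two lanes of the 2-lane vector are still valid, so they are combined pairwise before the final scale.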
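On the quantize.h hunk earlier in this patch: the new #if __aarch64__ branches replace the hand-rolled rounding sequences with the native conversions vcvtaq_s32_f32 (round to nearest, ties away from zero) and vcvtnq_s32_f32 (round to nearest, ties to even). A rough scalar equivalent of the two behaviours, assuming the default FE_TONEAREST floating-point environment; illustrative only, not part of the patch:

    #include <cmath>

    // Ties away from zero, e.g. 2.5 -> 3 and -2.5 -> -3 (vcvtaq_s32_f32).
    inline int RoundAwayFromZero(float x) {
      return static_cast<int>(std::round(x));
    }

    // Ties to nearest even, e.g. 2.5 -> 2 and 3.5 -> 4 (vcvtnq_s32_f32).
    inline int RoundToEven(float x) {
      return static_cast<int>(std::nearbyint(x));
    }

The away-from-zero variant lines up with the ROUND_NEAREST_AWAY_ZERO default restored in op_param.h.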