未验证 提交 96dfa087 编写于 作者: R Ray Liu 提交者: GitHub

Merge pull request #1375 from hjchen2/ocr_ctc

Optimize int8/float 5x5 depthwise conv and 2x2 pooling, add aarch64 macros to make compilation no problem
......@@ -64,9 +64,9 @@ void OperatorBase<Dtype>::Run() {
for (const auto key : input_keys) {
auto var_vec_in = inputs_.at(key);
for (int i = 0; i < var_vec_in.size(); ++i) {
auto vari = scope_->FindVar(var_vec_in[i]);
auto vari = this->scope_->FindVar(var_vec_in[i]);
if (vari->IsInitialized()) {
Tensor *tensor = vari->template GetMutable<framework::LoDTensor>();
const Tensor *tensor = vari->template Get<framework::LoDTensor>();
if (tensor) DLOG << type_ << " input- " << key << "=" << *tensor;
}
}
......@@ -76,7 +76,7 @@ void OperatorBase<Dtype>::Run() {
for (int i = 0; i < var_vec_out.size(); ++i) {
auto vari = scope_->FindVar(var_vec_out[i]);
if (vari->IsInitialized()) {
Tensor *tensor = vari->template GetMutable<framework::LoDTensor>();
const Tensor *tensor = vari->template Get<framework::LoDTensor>();
if (tensor) DLOG << type_ << " output- " << key << "=" << *tensor;
}
}
......@@ -97,10 +97,10 @@ void OperatorBase<GPU_CL>::Run() {
auto vari = scope_->FindVar(var_vec_in[i]);
if (vari->IsInitialized()) {
if (type_ == "feed") {
Tensor *tensor = vari->template GetMutable<framework::LoDTensor>();
const Tensor *tensor = vari->template Get<framework::LoDTensor>();
if (tensor) DLOG << type_ << " input- " << key << "=" << *tensor;
} else {
CLImage *cl_image = vari->template GetMutable<framework::CLImage>();
const CLImage *cl_image = vari->template Get<framework::CLImage>();
if (cl_image) {
DLOG << type_ << " input- " << key << "=" << *cl_image;
}
......@@ -114,12 +114,12 @@ void OperatorBase<GPU_CL>::Run() {
auto vari = scope_->FindVar(var_vec_out[i]);
if (vari->IsInitialized()) {
if (type_ == "fetch") {
Tensor *tensor = vari->template GetMutable<framework::LoDTensor>();
const Tensor *tensor = vari->template Get<framework::LoDTensor>();
if (tensor) {
DLOG << type_ << " output- " << key << "=" << *tensor;
}
} else {
CLImage *cl_image = vari->template GetMutable<framework::CLImage>();
const CLImage *cl_image = vari->template Get<framework::CLImage>();
if (cl_image) {
DLOG << type_ << " output- " << key << "=" << *cl_image;
}
......
......@@ -14,6 +14,7 @@
#include "io/api_paddle_mobile.h"
#include <vector>
#include "common/enforce.h"
#include "framework/tensor.h"
namespace paddle_mobile {
......
......@@ -12,19 +12,11 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
/*
* This file contains the implementation of inference API with Anakin engine
* embeded, this API can only support Anakin models.
*/
#pragma once
#include <vector>
#include "io/paddle_inference_api.h"
// from paddle_mobile
#include "common/enforce.h"
#include "common/types.h"
#include "io/paddle_inference_api.h"
#include "io/paddle_mobile.h"
namespace paddle_mobile {
......
......@@ -104,6 +104,8 @@ class PaddlePredictor {
// The common configs for all the predictors.
struct Config {
std::string model_dir; // path to the model directory.
std::string prog_file;
std::string param_file;
};
protected:
......@@ -128,9 +130,8 @@ struct PaddleMobileConfig : public PaddlePredictor::Config {
int batch_size = 1;
bool optimize = true;
bool quantification = false;
bool lod_mode = false;
int thread_num = 1;
std::string prog_file;
std::string param_file;
std::string cl_path;
struct PaddleModelMemoryPack memory_pack;
};
......
......@@ -15,6 +15,9 @@ limitations under the License. */
#include "io/paddle_mobile.h"
#include <utility>
#include "common/common.h"
#ifdef _OPENMP
#include <omp.h>
#endif // _OPENMP
#ifdef PADDLE_MOBILE_CL
#include <CL/cl.h>
#include "framework/cl/cl_tensor.h"
......@@ -33,7 +36,7 @@ void PaddleMobile<Device, T>::SetThreadNum(int num) {
template <typename Device, typename T>
PMStatus PaddleMobile<Device, T>::Load(const std::string &dirname,
bool optimize, bool quantification,
int batch_size, bool loddable) {
int batch_size, bool lod_mode) {
if (loader_.get() == nullptr) {
loader_ = std::make_shared<framework::Loader<Device, T>>();
} else {
......@@ -43,7 +46,7 @@ PMStatus PaddleMobile<Device, T>::Load(const std::string &dirname,
if (executor_.get() == nullptr) {
executor_ = std::make_shared<framework::Executor<Device, T>>(
loader_->Load(dirname, optimize, quantification), batch_size, optimize,
loddable);
lod_mode);
} else {
LOG(kLOG_INFO) << "executor inited";
}
......@@ -55,7 +58,7 @@ template <typename Device, typename T>
PMStatus PaddleMobile<Device, T>::Load(const std::string &model_path,
const std::string &para_path,
bool optimize, bool quantification,
int batch_size, bool loddable) {
int batch_size, bool lod_mode) {
if (loader_.get() == nullptr) {
loader_ = std::make_shared<framework::Loader<Device, T>>();
} else {
......@@ -65,7 +68,7 @@ PMStatus PaddleMobile<Device, T>::Load(const std::string &model_path,
if (executor_.get() == nullptr) {
executor_ = std::make_shared<framework::Executor<Device, T>>(
loader_->Load(model_path, para_path, optimize, quantification),
batch_size, optimize, loddable);
batch_size, optimize, lod_mode);
} else {
LOG(kLOG_INFO) << "executor inited";
}
......@@ -73,11 +76,26 @@ PMStatus PaddleMobile<Device, T>::Load(const std::string &model_path,
return PMSuccess;
}
template <typename Device, typename T>
PMStatus PaddleMobile<Device, T>::Load(const PaddleMobileConfig &config) {
if (!config.model_dir.empty()) {
return this->Load(config.model_dir, config.optimize, config.quantification,
config.batch_size, config.lod_mode);
} else if (!config.prog_file.empty() && !config.param_file.empty()) {
return this->Load(config.prog_file, config.param_file, config.optimize,
config.quantification, config.batch_size,
config.lod_mode);
} else {
LOG(kLOG_ERROR) << "Failed to load inference model";
return PMNotInitialized;
}
}
template <typename Device, typename T>
bool PaddleMobile<Device, T>::LoadCombinedMemory(
size_t model_len, const uint8_t *model_buf, size_t combined_params_len,
uint8_t *combined_params_buf, bool optimize, bool quantification,
int batch_size, bool loddable) {
int batch_size, bool lod_mode) {
if (loader_.get() == nullptr) {
loader_ = std::make_shared<framework::Loader<Device, T>>();
} else {
......@@ -88,7 +106,7 @@ bool PaddleMobile<Device, T>::LoadCombinedMemory(
loader_->LoadCombinedMemory(model_len, model_buf, combined_params_len,
combined_params_buf, optimize,
quantification),
batch_size, optimize, loddable);
batch_size, optimize, lod_mode);
} else {
LOG(kLOG_INFO) << "executor inited";
}
......
......@@ -18,15 +18,12 @@ limitations under the License. */
#include <string>
#include <utility>
#include <vector>
#ifdef _OPENMP
#include <omp.h>
#endif // _OPENMP
#include "common/types.h"
#include "framework/executor.h"
#include "framework/load_ops.h"
#include "framework/loader.h"
#include "framework/tensor.h"
#include "io/paddle_inference_api.h"
#ifdef PADDLE_MOBILE_CL
#include "framework/cl/cl_engine.h"
#endif
......@@ -46,10 +43,12 @@ class PaddleMobile {
PMStatus Load(const std::string &dirname, const bool optimize = false,
const bool quantification = false, const int batch_size = 1,
const bool lod = false);
const bool lod_mode = false);
PMStatus Load(const std::string &model_path, const std::string &para_path,
const bool optimize = false, const bool quantification = false,
const int batch_size = 1, const bool lod = false);
const int batch_size = 1, const bool lod_mode = false);
PMStatus Load(const PaddleMobileConfig &config);
PMStatus Predict(const framework::Tensor &input);
PMStatus Predict(const framework::LoDTensor &input);
......@@ -75,7 +74,7 @@ class PaddleMobile {
size_t combined_params_len,
uint8_t *combined_params_buf, bool optimize = false,
bool quantification = false, int batch_size = 1,
bool loddable = false);
bool lod_mode = false);
void SetThreadNum(int count);
void Clear();
......
......@@ -24,15 +24,26 @@ template <>
bool ConvKernel<CPU, float>::Init(ConvParam<CPU> *param) {
bool conv3x3 = param->Filter()->dims()[2] == param->Filter()->dims()[3] &&
param->Filter()->dims()[2] == 3;
bool conv5x5 = param->Filter()->dims()[2] == param->Filter()->dims()[3] &&
param->Filter()->dims()[2] == 5;
bool depth3x3 = conv3x3 && param->Groups() == param->Input()->dims()[1] &&
param->Input()->dims()[1] == param->Output()->dims()[1];
bool depth5x5 = conv5x5 && param->Groups() == param->Input()->dims()[1] &&
param->Input()->dims()[1] == param->Output()->dims()[1];
if (param->Filter()->type() == typeid(int8_t)) {
#ifndef __aarch64__
if (depth3x3 && param->Strides()[0] < 3 &&
param->Strides()[0] == param->Strides()[1]) {
param->ExecMode() = ConvParam<CPU>::EXEC_DEPTHWISE3x3_INT8;
} else if (depth5x5 && param->Strides()[0] < 2 &&
param->Strides()[0] == param->Strides()[1]) {
param->ExecMode() = ConvParam<CPU>::EXEC_DEPTHWISE5x5_INT8;
} else {
#endif // __aarch64__
param->ExecMode() = ConvParam<CPU>::EXEC_GEMM_INT8;
#ifndef __aarch64__
}
#endif // __aarch64__
} else {
if (depth3x3 && param->Strides()[0] == param->Strides()[1] &&
param->Strides()[0] == 1 && param->Paddings()[0] == 1 &&
......@@ -47,6 +58,9 @@ bool ConvKernel<CPU, float>::Init(ConvParam<CPU> *param) {
param->Paddings()[0] == param->Paddings()[1]) {
param->ExecMode() = ConvParam<CPU>::EXEC_DEPTHWISE3x3S2P1_FLOAT;
#ifndef __aarch64__
} else if (depth5x5 && param->Strides()[0] == param->Strides()[1] &&
param->Strides()[0] == 1) {
param->ExecMode() = ConvParam<CPU>::EXEC_DEPTHWISE5x5_FLOAT;
} else if (conv3x3 && param->Strides()[0] == param->Strides()[1] &&
param->Dilations()[0] == param->Dilations()[1] &&
param->Strides()[0] == 1 && param->Dilations()[0] == 1 &&
......@@ -72,9 +86,14 @@ void ConvKernel<CPU, float>::Compute(const ConvParam<CPU> &param) {
case ConvParam<CPU>::EXEC_GEMM_INT8:
GemmConv<int8_t, int32_t>(param);
break;
#ifndef __aarch64__
case ConvParam<CPU>::EXEC_DEPTHWISE3x3_INT8:
DepthwiseConv3x3<int8_t, int32_t>(param);
break;
case ConvParam<CPU>::EXEC_DEPTHWISE5x5_INT8:
DepthwiseConv5x5<int8_t, int32_t>(param);
break;
#endif // __aarch64__
case ConvParam<CPU>::EXEC_DEPTHWISE3x3S1P1_FLOAT:
math::DepthwiseConv3x3s1p1(param.Input(), param.Filter(), param.Output(),
nullptr, false, false);
......@@ -87,9 +106,14 @@ void ConvKernel<CPU, float>::Compute(const ConvParam<CPU> &param) {
math::DepthwiseConv3x3s2p0(param.Input(), param.Filter(), param.Output(),
nullptr, false, false);
break;
#ifndef __aarch64__
case ConvParam<CPU>::EXEC_DEPTHWISE5x5_FLOAT:
DepthwiseConv5x5<float, float>(param);
break;
case ConvParam<CPU>::EXEC_WINOGRAD3X3_FLOAT:
WinogradConv3x3<8, 3>(param);
break;
#endif // __aarch64__
case ConvParam<CPU>::EXEC_GEMM_FLOAT:
GemmConv<float, float>(param);
break;
......
......@@ -15,7 +15,8 @@ limitations under the License. */
#ifdef POOL_OP
#include "operators/kernel/pool_kernel.h"
#include "../central-arm-func/pool_arm_func.h"
#include "operators/kernel/central-arm-func/pool_arm_func.h"
namespace paddle_mobile {
namespace operators {
......@@ -28,7 +29,8 @@ template <>
void PoolKernel<CPU, float>::Compute(const PoolParam<CPU> &param) {
PoolCompute<float>(param);
}
} // namespace operators
} // namespace paddle_mobile
#endif
#endif // POOL_OP
......@@ -18,6 +18,7 @@ limitations under the License. */
#include <vector>
#include "operators/math/conv_func.h"
#include "operators/math/depthwise_conv3x3.h"
#include "operators/math/depthwise_conv5x5.h"
#include "operators/math/im2col.h"
#include "operators/math/math_function.h"
#include "operators/math/pad.h"
......@@ -160,6 +161,7 @@ inline void WinogradConv3x3(const ConvParam<CPU> &param) {
}
}
#ifndef __aarch64__
template <typename Itype, typename Otype>
inline void DepthwiseConv3x3(const ConvParam<CPU> &param) {
const Tensor *input = param.Input();
......@@ -180,14 +182,34 @@ inline void DepthwiseConv3x3(const ConvParam<CPU> &param) {
math::DepthwiseConv3x3S2<Itype, Otype>(in_batch, *filter, paddings,
&out_batch);
} else {
// math::DepthwiseConv3x3<Itype, Otype>(input_pad, *filter,
// &out_batch);
PADDLE_MOBILE_THROW_EXCEPTION(
"Depthwise conv with generic strides has not been implemented.");
GemmConv<Itype, Otype>(param);
}
}
}
template <typename Itype, typename Otype>
inline void DepthwiseConv5x5(const ConvParam<CPU> &param) {
const Tensor *input = param.Input();
const Tensor *filter = param.Filter();
const std::vector<int> &paddings = param.Paddings();
const std::vector<int> &strides = param.Strides();
const int batch_size = input->dims()[0];
Tensor *output = param.Output();
output->mutable_data<Otype>();
if (strides[0] == 1) {
for (int i = 0; i < batch_size; i++) {
Tensor in_batch = input->Slice(i, i + 1);
Tensor out_batch = output->Slice(i, i + 1);
math::DepthwiseConv5x5S1<Itype, Otype>(in_batch, *filter, paddings,
&out_batch);
}
} else {
GemmConv<Itype, Otype>(param);
}
}
#endif // __aarch64__
} // namespace operators
} // namespace paddle_mobile
......
......@@ -59,12 +59,11 @@ inline void ElementwiseAddCompute(const ElementwiseAddParam<CPU> &param) {
const float *input = input_data + offset;
const float bias = bias_data[j];
float *output = output_data + offset;
int remain = elementwise_num;
#if defined(__ARM_NEON__) || defined(__ARM_NEON)
int loop = elementwise_num >> 0x4;
remain = elementwise_num & 0xF;
int remain = elementwise_num & 0xF;
float32x4_t rb = vdupq_n_f32(bias);
for (int k = 0; k < loop; ++k) {
float32x4_t rb = vdupq_n_f32(bias);
float32x4_t r0 = vld1q_f32(input);
float32x4_t r1 = vld1q_f32(input + 4);
float32x4_t r2 = vld1q_f32(input + 8);
......@@ -80,10 +79,46 @@ inline void ElementwiseAddCompute(const ElementwiseAddParam<CPU> &param) {
input += 16;
output += 16;
}
#endif
for (int k = 0; k < remain; ++k) {
if (remain >= 8) {
float32x4_t r0 = vld1q_f32(input);
float32x4_t r1 = vld1q_f32(input + 4);
r0 = vaddq_f32(r0, rb);
r1 = vaddq_f32(r1, rb);
vst1q_f32(output, r0);
vst1q_f32(output + 4, r1);
input += 8;
output += 8;
remain -= 8;
}
if (remain >= 4) {
float32x4_t r0 = vld1q_f32(input);
r0 = vaddq_f32(r0, rb);
vst1q_f32(output, r0);
input += 4;
output += 4;
remain -= 4;
}
if (remain > 0) {
float32x4_t r0 = vld1q_f32(input);
r0 = vaddq_f32(r0, rb);
switch (remain) {
case 1:
vst1q_lane_f32(output, r0, 0);
break;
case 2:
vst1_f32(output, vget_low_f32(r0));
break;
case 3:
vst1_f32(output, vget_low_f32(r0));
vst1q_lane_f32(output, r0, 2);
break;
}
}
#else
for (int k = 0; k < elementwise_num; ++k) {
output[k] = input[k] + bias;
}
#endif // __ARM_NEON__
}
}
}
......
......@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License. */
#ifdef POOL_OP
#pragma once
#include <string>
......@@ -54,8 +55,24 @@ void PoolCompute(const PoolParam<CPU> &param) {
} else {
math::Pooling<AVG>()(*input, ksize, strides, paddings, output);
}
} else {
// Others
}
} else if (ksize[0] == 2 && ksize[0] == ksize[1]) {
if (pooling_type == "max" && strides[0] == strides[1]) {
if (strides[0] == 1) {
math::Pooling2x2<MAX, 1>()(*input, paddings, output);
} else if (strides[0] == 2) {
math::Pooling2x2<MAX, 2>()(*input, paddings, output);
} else {
math::Pooling<MAX>()(*input, ksize, strides, paddings, output);
}
} else if (pooling_type == "avg" && strides[0] == strides[1]) {
if (strides[0] == 1) {
math::Pooling2x2<AVG, 1>()(*input, paddings, output);
} else if (strides[0] == 2) {
math::Pooling2x2<AVG, 2>()(*input, paddings, output);
} else {
math::Pooling<AVG>()(*input, ksize, strides, paddings, output);
}
}
} else {
if (pooling_type == "max") {
......
......@@ -253,7 +253,6 @@ void DepthwiseConv3x3s1p1(const framework::Tensor *input,
framework::Tensor *output, framework::Tensor *bias,
bool if_bias, bool if_relu) {
#if __ARM_NEON
const float *bias_data = bias->data<float>();
const int batch_size = static_cast<int>(input->dims()[0]);
const int c = static_cast<int>(input->dims()[1]);
const int h = static_cast<int>(input->dims()[2]);
......@@ -267,6 +266,11 @@ void DepthwiseConv3x3s1p1(const framework::Tensor *input,
const int lb = (h - 1) * w;
const int rb = h * w - 1;
const float *bias_data;
if (if_bias) {
bias_data = bias->data<float>();
}
float32x4_t zero = vdupq_n_f32(0.0);
for (int b = 0; b < batch_size; ++b) {
......@@ -1966,7 +1970,6 @@ void DepthwiseConv3x3s2p0(const framework::Tensor *input,
framework::Tensor *output, framework::Tensor *bias,
bool if_bias, bool if_relu) {
#if __ARM_NEON
const int batch_size = static_cast<int>(input->dims()[0]);
const int input_channel = static_cast<int>(input->dims()[1]);
......@@ -1983,7 +1986,12 @@ void DepthwiseConv3x3s2p0(const framework::Tensor *input,
for (int c = 0; c < input_channel; c++) {
const float *filter_data = filter->data<float>() + c * 9;
const float *input_data = input->data<float>() + c * inhxw;
const float *bias_data = bias->data<float>() + c;
const float *bias_data;
float32x4_t biasv;
if (if_bias) {
bias_data = bias->data<float>() + c;
biasv = vld1q_dup_f32(bias_data);
}
float *output_data = output->data<float>() + c * outhxw;
float w00 = filter_data[0];
float w01 = filter_data[1];
......@@ -1994,7 +2002,6 @@ void DepthwiseConv3x3s2p0(const framework::Tensor *input,
float w20 = filter_data[6];
float w21 = filter_data[7];
float w22 = filter_data[8];
float32x4_t biasv = vld1q_dup_f32(bias_data);
for (int i = 0; i < output_height; i += 1) {
for (int m = 0; m < output_width - 2; m += 3) {
float *output_ptr = output_data + i * output_width + m;
......
......@@ -14,185 +14,13 @@ limitations under the License. */
#if defined(__ARM_NEON__) && !defined(__aarch64__)
#include "operators/math/depthwise_conv3x3.h"
#ifdef __ARM_NEON__
#include <arm_neon.h>
#endif
#include "operators/math/depthwise_conv3x3.h"
namespace paddle_mobile {
namespace operators {
namespace math {
template <int Stride>
inline void Depth3x3ValidColLoadInput(const int8_t *input, const int input_w,
const int valid_cols, int16x8_t *y0,
int16x8_t *y1, int16x8_t *y2) {
PADDLE_MOBILE_THROW_EXCEPTION("Stride %d is not supported.", Stride);
}
template <>
inline void Depth3x3ValidColLoadInput<1>(const int8_t *input, const int input_w,
const int valid_cols, int16x8_t *y0,
int16x8_t *y1, int16x8_t *y2) {
int8_t fake_input[3][8];
if (valid_cols == 1) {
for (int i = 0; i < 8; ++i, input += input_w) {
fake_input[0][i] = input[0];
}
} else if (valid_cols == 2) {
for (int i = 0; i < 8; ++i, input += input_w) {
fake_input[0][i] = input[0];
fake_input[1][i] = input[1];
}
} else {
for (int i = 0; i < 8; ++i, input += input_w) {
fake_input[0][i] = input[0];
fake_input[1][i] = input[1];
fake_input[2][i] = input[2];
}
}
int8x8_t input0 = vld1_s8(fake_input[0]);
int8x8_t input1 = vld1_s8(fake_input[1]);
int8x8_t input2 = vld1_s8(fake_input[2]);
y0[0] = vmovl_s8(input0);
y1[0] = vmovl_s8(input1);
y2[0] = vmovl_s8(input2);
y0[1] = vextq_s16(y0[0], y0[0], 1);
y0[2] = vextq_s16(y0[0], y0[0], 2);
y1[1] = vextq_s16(y1[0], y1[0], 1);
y1[2] = vextq_s16(y1[0], y1[0], 2);
y2[1] = vextq_s16(y2[0], y2[0], 1);
y2[2] = vextq_s16(y2[0], y2[0], 2);
}
template <>
inline void Depth3x3ValidColLoadInput<2>(const int8_t *input, const int input_w,
const int valid_cols, int16x8_t *y0,
int16x8_t *y1, int16x8_t *y2) {
int8_t fake_input[3][13];
if (valid_cols == 1) {
for (int i = 0; i < 13; ++i, input += input_w) {
fake_input[0][i] = input[0];
}
} else if (valid_cols == 2) {
for (int i = 0; i < 13; ++i, input += input_w) {
fake_input[0][i] = input[0];
fake_input[1][i] = input[1];
}
} else {
for (int i = 0; i < 13; ++i, input += input_w) {
fake_input[0][i] = input[0];
fake_input[1][i] = input[1];
fake_input[2][i] = input[2];
}
}
int8x8x2_t input0 = vld2_s8(fake_input[0]);
int8x8x2_t input1 = vld2_s8(fake_input[1]);
int8x8x2_t input2 = vld2_s8(fake_input[2]);
y0[0] = vmovl_s8(input0.val[0]);
y0[1] = vmovl_s8(input0.val[1]);
y0[2] = vextq_s16(y0[0], y0[0], 1);
y1[0] = vmovl_s8(input1.val[0]);
y1[1] = vmovl_s8(input1.val[1]);
y1[2] = vextq_s16(y1[0], y1[0], 1);
y2[0] = vmovl_s8(input2.val[0]);
y2[1] = vmovl_s8(input2.val[1]);
y2[2] = vextq_s16(y2[0], y2[0], 1);
}
template <int Stride_h, int Stride_w>
inline void DepthwiseConv3x3ValidCol(const int8_t *input, const int8_t *filter,
const int h_output, const int h_output_end,
const int w_output, const int input_h,
const int input_w, const int padding_h,
const int padding_w, const int output_w,
int32_t *output) {
const int w_in_start = -padding_w + w_output * Stride_w;
const int w_in_end = w_in_start + 3;
const int w_start = w_in_start > 0 ? w_in_start : 0;
const int w_end = w_in_end < input_w ? w_in_end : input_w;
int remain_start = h_output;
#ifdef __ARM_NEON__
int output_tiles = (h_output_end - h_output) / 6;
remain_start = h_output + output_tiles * 6;
int input_h_start = h_output * Stride_h - padding_h;
size_t input_offset = input_h_start * input_w + w_start;
size_t output_offset = h_output * output_w + w_output;
int16x8_t _input[3][3];
int16x4_t _kernel[3];
int32x4_t _sum0, _sum1;
const int8_t *filter_ptr = filter;
asm volatile(
"mov r0, #3 \n"
"vld1.s8 d10, [%[filter]], r0 \n"
"vld1.s8 d11, [%[filter]], r0 \n"
"vld1.s8 d12, [%[filter]] \n"
"vtrn.8 d10, d11 \n"
"vtrn.8 d12, d13 \n"
"vtrn.16 d10, d12 \n"
"vtrn.16 d11, d13 \n"
"vmovl.s8 q7, d10 \n"
"vmovl.s8 q8, d11 \n"
"vmovl.s8 q9, d12 \n"
"vmov.32 %[_kernel0], d14 \n"
"vmov.32 %[_kernel1], d16 \n"
"vmov.32 %[_kernel2], d18 \n"
: [_kernel0] "+w"(_kernel[0]), [_kernel1] "+w"(_kernel[1]),
[_kernel2] "+w"(_kernel[2])
: [filter] "r"(filter_ptr)
: "memory", "q5", "q6", "q7", "q8", "q9", "r0");
int valid_cols = w_end - w_start;
for (int h = 0; h < output_tiles * 6; h += 6) {
int32_t *output0 = output + output_offset;
int32_t *output1 = output0 + output_w;
int32_t *output2 = output1 + output_w;
int32_t *output3 = output2 + output_w;
int32_t *output4 = output3 + output_w;
int32_t *output5 = output4 + output_w;
Depth3x3ValidColLoadInput<Stride_w>(input + input_offset, input_w,
valid_cols, _input[0], _input[1],
_input[2]);
_sum0 = veorq_s32(_sum0, _sum0);
_sum1 = veorq_s32(_sum1, _sum1);
for (int w_in = 0; w_in < valid_cols; ++w_in) {
int index = w_in + w_start - w_in_start;
_sum0 = vmlal_lane_s16(_sum0, vget_low_s16(_input[w_in][0]),
_kernel[index], 0);
_sum0 = vmlal_lane_s16(_sum0, vget_low_s16(_input[w_in][1]),
_kernel[index], 1);
_sum0 = vmlal_lane_s16(_sum0, vget_low_s16(_input[w_in][2]),
_kernel[index], 2);
_sum1 = vmlal_lane_s16(_sum1, vget_high_s16(_input[w_in][0]),
_kernel[index], 0);
_sum1 = vmlal_lane_s16(_sum1, vget_high_s16(_input[w_in][1]),
_kernel[index], 1);
_sum1 = vmlal_lane_s16(_sum1, vget_high_s16(_input[w_in][2]),
_kernel[index], 2);
}
vst1q_lane_s32(output0, _sum0, 0);
vst1q_lane_s32(output1, _sum0, 1);
vst1q_lane_s32(output2, _sum0, 2);
vst1q_lane_s32(output3, _sum0, 3);
vst1q_lane_s32(output4, _sum1, 0);
vst1q_lane_s32(output5, _sum1, 1);
input_offset += 6 * Stride_h * input_w;
output_offset += 6 * output_w;
}
#endif
for (int h = remain_start; h < h_output_end; ++h) {
int32_t value = 0;
const int h_in_start = -padding_h + h * Stride_h;
for (int i = 0; i < 3; ++i) {
for (int w_in = w_start; w_in < w_end; ++w_in) {
value += filter[i * 3 + (w_in - w_in_start)] *
input[(h_in_start + i) * input_w + w_in];
}
}
output[h * output_w + w_output] = value;
}
}
#define DEPTHWISE_CONV_NORMAL_BORDER(start, end) \
for (int w = start; w < end; ++w) { \
const int w_in_start = -padding_w + w * Stride_w; \
......@@ -209,34 +37,19 @@ inline void DepthwiseConv3x3ValidCol(const int8_t *input, const int8_t *filter,
output_ptr[w] = value; \
}
template <int Stride>
inline void Depth3x3NormalRowLoadInput(const int8_t *input,
int16x8_t &y0, // NOLINT
int16x8_t &y1, // NOLINT
int16x8_t &y2) { // NOLINT
PADDLE_MOBILE_THROW_EXCEPTION("Stride %d is not supported.", Stride);
}
template <>
inline void Depth3x3NormalRowLoadInput<1>(const int8_t *input,
int16x8_t &y0, // NOLINT
int16x8_t &y1, // NOLINT
int16x8_t &y2) { // NOLINT
int8x8_t x0 = vld1_s8(input);
y0 = vmovl_s8(x0);
y1 = vextq_s16(y0, y0, 1);
y2 = vextq_s16(y1, y1, 1);
template <int Stride = 1>
inline void Depth3x3NormalRowLoadInput(const int8_t *input, int16x8_t *y) {
y[0] = vmovl_s8(vld1_s8(input));
y[1] = vextq_s16(y[0], y[0], 1);
y[2] = vextq_s16(y[1], y[1], 1);
}
template <>
inline void Depth3x3NormalRowLoadInput<2>(const int8_t *input,
int16x8_t &y0, // NOLINT
int16x8_t &y1, // NOLINT
int16x8_t &y2) { // NOLINT
inline void Depth3x3NormalRowLoadInput<2>(const int8_t *input, int16x8_t *y) {
int8x8x2_t x0 = vld2_s8(input);
y0 = vmovl_s8(x0.val[0]);
y1 = vmovl_s8(x0.val[1]);
y2 = vextq_s16(y0, y0, 1);
y[0] = vmovl_s8(x0.val[0]);
y[1] = vmovl_s8(x0.val[1]);
y[2] = vextq_s16(y[0], y[0], 1);
}
template <int Stride_h, int Stride_w>
......@@ -244,15 +57,14 @@ inline void DepthwiseConv3x3NormalRow(const int8_t *input, const int8_t *filter,
const int h_output, const int input_h,
const int input_w, const int padding_h,
const int padding_w, const int output_w,
int32_t *output) {
int32_t *output, int16x4_t *ker) {
const int h_in_start = -padding_h + h_output * Stride_h;
const int h_in_end = h_in_start + 3;
const int h_start = h_in_start > 0 ? h_in_start : 0;
const int h_end = h_in_end < input_h ? h_in_end : input_h;
int valid_w_start = (padding_w + Stride_w - 1) / Stride_w;
int valid_w_end = output_w - valid_w_start;
const int valid_w_start = (padding_w + Stride_w - 1) / Stride_w;
const int valid_w_end = (input_w + padding_w - 3) / Stride_w + 1;
int32_t *output_ptr = output + h_output * output_w;
// border left
DEPTHWISE_CONV_NORMAL_BORDER(0, valid_w_start)
......@@ -262,14 +74,7 @@ inline void DepthwiseConv3x3NormalRow(const int8_t *input, const int8_t *filter,
int output_tiles = (valid_w_end - valid_w_start) / 6;
remain_start = valid_w_start + output_tiles * 6;
int32x4_t _sum0, _sum1;
int16x8_t y0, y1, y2;
int16x4_t _kernel[3];
for (int h_in = h_start; h_in < h_end; ++h_in) {
int index = h_in - h_in_start;
int8x8_t w0 = vld1_s8(filter + index * 3);
int16x8_t w1 = vmovl_s8(w0);
_kernel[index] = vget_low_s16(w1);
}
int16x8_t _y[3];
for (int w = 0; w < output_tiles * 6; w += 6) {
_sum0 = veorq_s32(_sum0, _sum0);
_sum1 = veorq_s32(_sum1, _sum1);
......@@ -278,19 +83,18 @@ inline void DepthwiseConv3x3NormalRow(const int8_t *input, const int8_t *filter,
for (int h_in = h_start; h_in < h_end; ++h_in) {
int index = h_in - h_in_start;
Depth3x3NormalRowLoadInput<Stride_w>(
input + h_in * input_w + input_w_offset, y0, y1, y2);
_sum0 = vmlal_lane_s16(_sum0, vget_low_s16(y0), _kernel[index], 0);
_sum0 = vmlal_lane_s16(_sum0, vget_low_s16(y1), _kernel[index], 1);
_sum0 = vmlal_lane_s16(_sum0, vget_low_s16(y2), _kernel[index], 2);
_sum1 = vmlal_lane_s16(_sum1, vget_high_s16(y0), _kernel[index], 0);
_sum1 = vmlal_lane_s16(_sum1, vget_high_s16(y1), _kernel[index], 1);
_sum1 = vmlal_lane_s16(_sum1, vget_high_s16(y2), _kernel[index], 2);
input + h_in * input_w + input_w_offset, _y);
_sum0 = vmlal_lane_s16(_sum0, vget_low_s16(_y[0]), ker[index], 0);
_sum0 = vmlal_lane_s16(_sum0, vget_low_s16(_y[1]), ker[index], 1);
_sum0 = vmlal_lane_s16(_sum0, vget_low_s16(_y[2]), ker[index], 2);
_sum1 = vmlal_lane_s16(_sum1, vget_high_s16(_y[0]), ker[index], 0);
_sum1 = vmlal_lane_s16(_sum1, vget_high_s16(_y[1]), ker[index], 1);
_sum1 = vmlal_lane_s16(_sum1, vget_high_s16(_y[2]), ker[index], 2);
}
vst1q_s32(output_ptr + output_offset, _sum0);
vst1q_lane_s32(output_ptr + output_offset + 4, _sum1, 0);
vst1q_lane_s32(output_ptr + output_offset + 5, _sum1, 1);
vst1_s32(output_ptr + output_offset + 4, vget_low_s32(_sum1));
}
#endif
#endif // __ARM_NEON__
for (int w = remain_start; w < valid_w_end; ++w) {
int32_t value = 0;
int input_start = -padding_w + w * Stride_w;
......@@ -306,14 +110,6 @@ inline void DepthwiseConv3x3NormalRow(const int8_t *input, const int8_t *filter,
DEPTHWISE_CONV_NORMAL_BORDER(valid_w_end, output_w)
}
// template<>
// void DepthwiseConv3x3<int8_t, int32_t>(
// const framework::Tensor *input, const framework::Tensor *filter,
// const std::vector<int> &strides, framework::Tensor *output) {
// PADDLE_MOBILE_THROW_EXCEPTION(
// "Depthwise conv with generic strides has not been implemented.");
// }
template <>
void DepthwiseConv3x3S1<int8_t, int32_t>(const framework::Tensor &input,
const framework::Tensor &filter,
......@@ -342,29 +138,22 @@ void DepthwiseConv3x3S1<int8_t, int32_t>(const framework::Tensor &input,
const int8_t *input_ptr = input_data + g * image_size;
const int8_t *filter_ptr = filter_data + g * 9;
int32_t *output_ptr = out_data + g * out_image_size;
const int8_t *filter_ptr0 = filter_ptr;
const int8_t *filter_ptr1 = filter_ptr0 + 3;
const int8_t *filter_ptr2 = filter_ptr1 + 3;
int16x4_t _k0 = vget_low_s16(vmovl_s8(vld1_s8(filter_ptr0)));
int16x4_t _k1 = vget_low_s16(vmovl_s8(vld1_s8(filter_ptr1)));
int16x4_t _k2 = vget_low_s16(vmovl_s8(vld1_s8(filter_ptr2)));
int16x8_t _ker0 = vcombine_s16(_k0, _k1);
int16x8_t _ker1 = vcombine_s16(_k2, _k2);
int16x4_t zero = vdup_n_s16(0);
int16x4_t _ker[3] = {_k0, _k1, _k2};
// top
for (int h = 0; h < valid_h_start; ++h) {
DepthwiseConv3x3NormalRow<1, 1>(input_ptr, filter_ptr, h, input_h,
input_w, padding_h, padding_w, output_w,
output_ptr);
}
// left
for (int w = 0; w < valid_w_start; ++w) {
DepthwiseConv3x3ValidCol<1, 1>(
input_ptr, filter_ptr, valid_h_start, valid_h_end, w, input_h,
input_w, padding_h, padding_w, output_w, output_ptr);
}
// right
for (int w = valid_w_end; w < output_w; ++w) {
DepthwiseConv3x3ValidCol<1, 1>(
input_ptr, filter_ptr, valid_h_start, valid_h_end, w, input_h,
input_w, padding_h, padding_w, output_w, output_ptr);
}
// bottom
for (int h = valid_h_end; h < output_h; ++h) {
DepthwiseConv3x3NormalRow<1, 1>(input_ptr, filter_ptr, h, input_h,
input_w, padding_h, padding_w, output_w,
output_ptr);
output_ptr, _ker);
}
// valid
int output_w_tiles = valid_w / 6;
......@@ -376,334 +165,419 @@ void DepthwiseConv3x3S1<int8_t, int32_t>(const framework::Tensor &input,
const int8_t *input_ptr3 = input_ptr2 + input_w;
const int8_t *input_ptr4 = input_ptr3 + input_w;
const int8_t *input_ptr5 = input_ptr4 + input_w;
int32_t *output_ptr0 = output_ptr + h * output_w + valid_w_start;
int32_t *output_ptr0 = output_ptr + h * output_w;
int32_t *output_ptr1 = output_ptr0 + output_w;
int32_t *output_ptr2 = output_ptr1 + output_w;
int32_t *output_ptr3 = output_ptr2 + output_w;
// pad left
if (padding_w) {
int16x4_t row0 = vget_low_s16(vmovl_s8(vld1_s8(input_ptr0)));
int16x4_t row1 = vget_low_s16(vmovl_s8(vld1_s8(input_ptr1)));
int16x4_t row2 = vget_low_s16(vmovl_s8(vld1_s8(input_ptr2)));
int16x4_t row3 = vget_low_s16(vmovl_s8(vld1_s8(input_ptr3)));
int16x4_t row4 = vget_low_s16(vmovl_s8(vld1_s8(input_ptr4)));
int16x4_t row5 = vget_low_s16(vmovl_s8(vld1_s8(input_ptr5)));
int32x4_t acc;
for (int w = valid_w_start - 1; w >= 0; --w) {
int padding = padding_w - w;
if (padding >= 3) {
output_ptr0[w] = 0;
output_ptr1[w] = 0;
output_ptr2[w] = 0;
output_ptr3[w] = 0;
} else {
row0 = vext_s16(zero, row0, 3);
row1 = vext_s16(zero, row1, 3);
row2 = vext_s16(zero, row2, 3);
row3 = vext_s16(zero, row3, 3);
row4 = vext_s16(zero, row4, 3);
row5 = vext_s16(zero, row5, 3);
acc = vmull_s16(row0, _ker[0]);
acc = vmlal_s16(acc, row1, _ker[1]);
acc = vmlal_s16(acc, row2, _ker[2]);
output_ptr0[w] = vgetq_lane_s32(acc, 1) + vgetq_lane_s32(acc, 2);
acc = vmull_s16(row1, _ker[0]);
acc = vmlal_s16(acc, row2, _ker[1]);
acc = vmlal_s16(acc, row3, _ker[2]);
output_ptr1[w] = vgetq_lane_s32(acc, 1) + vgetq_lane_s32(acc, 2);
acc = vmull_s16(row2, _ker[0]);
acc = vmlal_s16(acc, row3, _ker[1]);
acc = vmlal_s16(acc, row4, _ker[2]);
output_ptr2[w] = vgetq_lane_s32(acc, 1) + vgetq_lane_s32(acc, 2);
acc = vmull_s16(row3, _ker[0]);
acc = vmlal_s16(acc, row4, _ker[1]);
acc = vmlal_s16(acc, row5, _ker[2]);
output_ptr3[w] = vgetq_lane_s32(acc, 1) + vgetq_lane_s32(acc, 2);
}
}
output_ptr0 += valid_w_start;
output_ptr1 += valid_w_start;
output_ptr2 += valid_w_start;
output_ptr3 += valid_w_start;
}
// valid
int loop = output_w_tiles;
asm volatile(
"vld1.32 {q0}, [%[filter_ptr]] \n"
"vmovl.s8 q14, d0 \n"
"vmovl.s8 q15, d1 \n"
"vdup.s16 d0, d28[0] \n"
"vdup.s16 d1, d28[1] \n"
"vdup.s16 d2, d28[2] \n"
"vdup.s16 d3, d28[3] \n"
"vdup.s16 d4, d29[0] \n"
"vdup.s16 d5, d29[1] \n"
"vdup.s16 d6, d29[2] \n"
"vdup.s16 d7, d29[3] \n"
"vdup.s16 d8, d30[0] \n"
:
: [filter_ptr] "r"(filter_ptr)
: "cc", "memory", "q0", "q1", "q2", "q3", "q4", "q14", "q15");
asm volatile(
"mov r0, #6 \n"
"cmp %[loop], #0 \n"
"ble start_remain_%= \n"
// loop 6 widths
"loop_4h6w_%=: \n"
"vld1.32 {d9}, [%[input_ptr0]], r0 \n"
"vld1.32 {d10}, [%[input_ptr1]], r0 \n"
"vld1.32 {d11}, [%[input_ptr2]], r0 \n"
"vext.s8 d12, d9, d9, #1 \n"
"vext.s8 d13, d9, d9, #2 \n"
"vmovl.s8 q7, d9 \n"
"vmovl.s8 q8, d12 \n"
"vmovl.s8 q9, d13 \n"
"vmull.s16 q10, d14, d0 \n"
"vmlal.s16 q10, d16, d1 \n"
"vmlal.s16 q10, d18, d2 \n"
"vmull.s16 q11, d15, d0 \n"
"vmlal.s16 q11, d17, d1 \n"
"vmlal.s16 q11, d19, d2 \n"
"vext.s8 d12, d10, d10, #1 \n"
"vext.s8 d13, d10, d10, #2 \n"
"vmovl.s8 q7, d10 \n"
"vmovl.s8 q8, d12 \n"
"vmovl.s8 q9, d13 \n"
"vmlal.s16 q10, d14, d3 \n"
"vmlal.s16 q10, d16, d4 \n"
"vmlal.s16 q10, d18, d5 \n"
"vmlal.s16 q11, d15, d3 \n"
"vmlal.s16 q11, d17, d4 \n"
"vmlal.s16 q11, d19, d5 \n"
"vmull.s16 q12, d14, d0 \n"
"vmlal.s16 q12, d16, d1 \n"
"vmlal.s16 q12, d18, d2 \n"
"vmull.s16 q13, d15, d0 \n"
"vmlal.s16 q13, d17, d1 \n"
"vmlal.s16 q13, d19, d2 \n"
"vext.s8 d12, d11, d11, #1 \n"
"vext.s8 d13, d11, d11, #2 \n"
"vmovl.s8 q7, d11 \n"
"vmovl.s8 q8, d12 \n"
"vmovl.s8 q9, d13 \n"
"vmlal.s16 q10, d14, d6 \n"
"vmlal.s16 q10, d16, d7 \n"
"vmlal.s16 q10, d18, d8 \n"
"vmlal.s16 q11, d15, d6 \n"
"vmlal.s16 q11, d17, d7 \n"
"vmlal.s16 q11, d19, d8 \n"
"cmp %[loop], #0 \n"
"ble start_remain_%= \n"
"mov r0, #6 \n"
// loop 6 width
"loop_4h6w_%=: \n"
"vld1.32 {d9}, [%[input_ptr0]], r0 \n"
"vld1.32 {d10}, [%[input_ptr1]], r0 \n"
"vld1.32 {d11}, [%[input_ptr2]], r0 \n"
"vext.s8 d12, d9, d9, #1 \n"
"vext.s8 d13, d9, d9, #2 \n"
"vmovl.s8 q7, d9 \n"
"vmovl.s8 q8, d12 \n"
"vmovl.s8 q9, d13 \n"
"vmull.s16 q10, d14, %e[ker0][0] \n"
"vmlal.s16 q10, d16, %e[ker0][1] \n"
"vmlal.s16 q10, d18, %e[ker0][2] \n"
"vmull.s16 q11, d15, %e[ker0][0] \n"
"vmlal.s16 q11, d17, %e[ker0][1] \n"
"vmlal.s16 q11, d19, %e[ker0][2] \n"
"vext.s8 d12, d10, d10, #1 \n"
"vext.s8 d13, d10, d10, #2 \n"
"vmovl.s8 q7, d10 \n"
"vmovl.s8 q8, d12 \n"
"vmovl.s8 q9, d13 \n"
"vmlal.s16 q10, d14, %f[ker0][0] \n"
"vmlal.s16 q10, d16, %f[ker0][1] \n"
"vmlal.s16 q10, d18, %f[ker0][2] \n"
"vmlal.s16 q11, d15, %f[ker0][0] \n"
"vmlal.s16 q11, d17, %f[ker0][1] \n"
"vmlal.s16 q11, d19, %f[ker0][2] \n"
"vmull.s16 q12, d14, %e[ker0][0] \n"
"vmlal.s16 q12, d16, %e[ker0][1] \n"
"vmlal.s16 q12, d18, %e[ker0][2] \n"
"vmull.s16 q13, d15, %e[ker0][0] \n"
"vmlal.s16 q13, d17, %e[ker0][1] \n"
"vmlal.s16 q13, d19, %e[ker0][2] \n"
"vext.s8 d12, d11, d11, #1 \n"
"vext.s8 d13, d11, d11, #2 \n"
"vmovl.s8 q7, d11 \n"
"vmovl.s8 q8, d12 \n"
"vmovl.s8 q9, d13 \n"
"vmlal.s16 q10, d14, %e[ker1][0] \n"
"vmlal.s16 q10, d16, %e[ker1][1] \n"
"vmlal.s16 q10, d18, %e[ker1][2] \n"
"vmlal.s16 q11, d15, %e[ker1][0] \n"
"vmlal.s16 q11, d17, %e[ker1][1] \n"
"vmlal.s16 q11, d19, %e[ker1][2] \n"
// store row 0, reuse q10/q11
"vst1.32 {d20-d22}, [%[output_ptr0]]! \n"
"vmlal.s16 q12, d14, d3 \n"
"vmlal.s16 q12, d16, d4 \n"
"vmlal.s16 q12, d18, d5 \n"
"vmlal.s16 q13, d15, d3 \n"
"vmlal.s16 q13, d17, d4 \n"
"vmlal.s16 q13, d19, d5 \n"
"vmlal.s16 q12, d14, %f[ker0][0] \n"
"vmlal.s16 q12, d16, %f[ker0][1] \n"
"vmlal.s16 q12, d18, %f[ker0][2] \n"
"vmlal.s16 q13, d15, %f[ker0][0] \n"
"vmlal.s16 q13, d17, %f[ker0][1] \n"
"vmlal.s16 q13, d19, %f[ker0][2] \n"
"vmull.s16 q14, d14, d0 \n"
"vmlal.s16 q14, d16, d1 \n"
"vmlal.s16 q14, d18, d2 \n"
"vmull.s16 q15, d15, d0 \n"
"vmlal.s16 q15, d17, d1 \n"
"vmlal.s16 q15, d19, d2 \n"
"vmull.s16 q14, d14, %e[ker0][0] \n"
"vmlal.s16 q14, d16, %e[ker0][1] \n"
"vmlal.s16 q14, d18, %e[ker0][2] \n"
"vmull.s16 q15, d15, %e[ker0][0] \n"
"vmlal.s16 q15, d17, %e[ker0][1] \n"
"vmlal.s16 q15, d19, %e[ker0][2] \n"
"vld1.32 {d9}, [%[input_ptr3]], r0 \n"
"vld1.32 {d10}, [%[input_ptr4]], r0 \n"
"vld1.32 {d11}, [%[input_ptr5]], r0 \n"
"vext.s8 d12, d9, d9, #1 \n"
"vext.s8 d13, d9, d9, #2 \n"
"vmovl.s8 q7, d9 \n"
"vmovl.s8 q8, d12 \n"
"vmovl.s8 q9, d13 \n"
"vmlal.s16 q12, d14, d6 \n"
"vmlal.s16 q12, d16, d7 \n"
"vmlal.s16 q12, d18, d8 \n"
"vmlal.s16 q13, d15, d6 \n"
"vmlal.s16 q13, d17, d7 \n"
"vmlal.s16 q13, d19, d8 \n"
"vext.s8 d12, d9, d9, #1 \n"
"vext.s8 d13, d9, d9, #2 \n"
"vmovl.s8 q7, d9 \n"
"vmovl.s8 q8, d12 \n"
"vmovl.s8 q9, d13 \n"
"vmlal.s16 q12, d14, %e[ker1][0] \n"
"vmlal.s16 q12, d16, %e[ker1][1] \n"
"vmlal.s16 q12, d18, %e[ker1][2] \n"
"vmlal.s16 q13, d15, %e[ker1][0] \n"
"vmlal.s16 q13, d17, %e[ker1][1] \n"
"vmlal.s16 q13, d19, %e[ker1][2] \n"
// store row 1
"vst1.32 {d24-d26}, [%[output_ptr1]]! \n"
"vmlal.s16 q14, d14, d3 \n"
"vmlal.s16 q14, d16, d4 \n"
"vmlal.s16 q14, d18, d5 \n"
"vmlal.s16 q15, d15, d3 \n"
"vmlal.s16 q15, d17, d4 \n"
"vmlal.s16 q15, d19, d5 \n"
"vmull.s16 q10, d14, d0 \n"
"vmlal.s16 q10, d16, d1 \n"
"vmlal.s16 q10, d18, d2 \n"
"vmull.s16 q11, d15, d0 \n"
"vmlal.s16 q11, d17, d1 \n"
"vmlal.s16 q11, d19, d2 \n"
"vext.s8 d12, d10, d10, #1 \n"
"vext.s8 d13, d10, d10, #2 \n"
"vmovl.s8 q7, d10 \n"
"vmovl.s8 q8, d12 \n"
"vmovl.s8 q9, d13 \n"
"vmlal.s16 q14, d14, d6 \n"
"vmlal.s16 q14, d16, d7 \n"
"vmlal.s16 q14, d18, d8 \n"
"vmlal.s16 q15, d15, d6 \n"
"vmlal.s16 q15, d17, d7 \n"
"vmlal.s16 q15, d19, d8 \n"
"vmlal.s16 q14, d14, %f[ker0][0] \n"
"vmlal.s16 q14, d16, %f[ker0][1] \n"
"vmlal.s16 q14, d18, %f[ker0][2] \n"
"vmlal.s16 q15, d15, %f[ker0][0] \n"
"vmlal.s16 q15, d17, %f[ker0][1] \n"
"vmlal.s16 q15, d19, %f[ker0][2] \n"
"vmull.s16 q10, d14, %e[ker0][0] \n"
"vmlal.s16 q10, d16, %e[ker0][1] \n"
"vmlal.s16 q10, d18, %e[ker0][2] \n"
"vmull.s16 q11, d15, %e[ker0][0] \n"
"vmlal.s16 q11, d17, %e[ker0][1] \n"
"vmlal.s16 q11, d19, %e[ker0][2] \n"
"vext.s8 d12, d10, d10, #1 \n"
"vext.s8 d13, d10, d10, #2 \n"
"vmovl.s8 q7, d10 \n"
"vmovl.s8 q8, d12 \n"
"vmovl.s8 q9, d13 \n"
"vmlal.s16 q14, d14, %e[ker1][0] \n"
"vmlal.s16 q14, d16, %e[ker1][1] \n"
"vmlal.s16 q14, d18, %e[ker1][2] \n"
"vmlal.s16 q15, d15, %e[ker1][0] \n"
"vmlal.s16 q15, d17, %e[ker1][1] \n"
"vmlal.s16 q15, d19, %e[ker1][2] \n"
// store row 2
"vst1.32 {d28-d30}, [%[output_ptr2]]! \n"
"vmlal.s16 q10, d14, d3 \n"
"vmlal.s16 q10, d16, d4 \n"
"vmlal.s16 q10, d18, d5 \n"
"vmlal.s16 q11, d15, d3 \n"
"vmlal.s16 q11, d17, d4 \n"
"vmlal.s16 q11, d19, d5 \n"
"vext.s8 d12, d11, d11, #1 \n"
"vext.s8 d13, d11, d11, #2 \n"
"vmovl.s8 q7, d11 \n"
"vmovl.s8 q8, d12 \n"
"vmovl.s8 q9, d13 \n"
"vmlal.s16 q10, d14, d6 \n"
"vmlal.s16 q10, d16, d7 \n"
"vmlal.s16 q10, d18, d8 \n"
"vmlal.s16 q11, d15, d6 \n"
"vmlal.s16 q11, d17, d7 \n"
"vmlal.s16 q11, d19, d8 \n"
"vmlal.s16 q10, d14, %f[ker0][0] \n"
"vmlal.s16 q10, d16, %f[ker0][1] \n"
"vmlal.s16 q10, d18, %f[ker0][2] \n"
"vmlal.s16 q11, d15, %f[ker0][0] \n"
"vmlal.s16 q11, d17, %f[ker0][1] \n"
"vmlal.s16 q11, d19, %f[ker0][2] \n"
"vext.s8 d12, d11, d11, #1 \n"
"vext.s8 d13, d11, d11, #2 \n"
"vmovl.s8 q7, d11 \n"
"vmovl.s8 q8, d12 \n"
"vmovl.s8 q9, d13 \n"
"vmlal.s16 q10, d14, %e[ker1][0] \n"
"vmlal.s16 q10, d16, %e[ker1][1] \n"
"vmlal.s16 q10, d18, %e[ker1][2] \n"
"vmlal.s16 q11, d15, %e[ker1][0] \n"
"vmlal.s16 q11, d17, %e[ker1][1] \n"
"vmlal.s16 q11, d19, %e[ker1][2] \n"
// store row 3
"vst1.32 {d20-d22}, [%[output_ptr3]]! \n"
"subs %[loop], #1 \n"
"bne loop_4h6w_%= \n"
"start_remain_%=: \n"
"cmp %[remain], #0 \n"
"ble end_%= \n"
"vld1.32 {d9}, [%[input_ptr0]] \n"
"vmovl.s8 q7, d9 \n"
"vext.s8 d9, d9, d9, #1 \n"
"vmovl.s8 q8, d9 \n"
"vext.s8 d9, d9, d9, #1 \n"
"vmovl.s8 q9, d9 \n"
"vmull.s16 q10, d14, d0 \n"
"vmlal.s16 q10, d16, d1 \n"
"vmlal.s16 q10, d18, d2 \n"
"vld1.32 {d9}, [%[input_ptr1]] \n"
"vmull.s16 q11, d15, d0 \n"
"vmlal.s16 q11, d17, d1 \n"
"vmlal.s16 q11, d19, d2 \n"
"vmovl.s8 q7, d9 \n"
"vext.s8 d9, d9, d9, #1 \n"
"vmovl.s8 q8, d9 \n"
"vext.s8 d9, d9, d9, #1 \n"
"vmovl.s8 q9, d9 \n"
"vmlal.s16 q10, d14, d3 \n"
"vmlal.s16 q10, d16, d4 \n"
"vmlal.s16 q10, d18, d5 \n"
"vmlal.s16 q11, d15, d3 \n"
"vmlal.s16 q11, d17, d4 \n"
"vmlal.s16 q11, d19, d5 \n"
"vmull.s16 q12, d14, d0 \n"
"vmlal.s16 q12, d16, d1 \n"
"vmlal.s16 q12, d18, d2 \n"
"vld1.32 {d9}, [%[input_ptr2]] \n"
"vmull.s16 q13, d15, d0 \n"
"vmlal.s16 q13, d17, d1 \n"
"vmlal.s16 q13, d19, d2 \n"
"vmovl.s8 q7, d9 \n"
"vext.s8 d9, d9, d9, #1 \n"
"vmovl.s8 q8, d9 \n"
"vext.s8 d9, d9, d9, #1 \n"
"vmovl.s8 q9, d9 \n"
"vmlal.s16 q10, d14, d6 \n"
"vmlal.s16 q10, d16, d7 \n"
"vmlal.s16 q10, d18, d8 \n"
"vmlal.s16 q11, d15, d6 \n"
"vmlal.s16 q11, d17, d7 \n"
"vmlal.s16 q11, d19, d8 \n"
"vmlal.s16 q12, d14, d3 \n"
"vmlal.s16 q12, d16, d4 \n"
"vmlal.s16 q12, d18, d5 \n"
"vmlal.s16 q13, d15, d3 \n"
"vmlal.s16 q13, d17, d4 \n"
"vmlal.s16 q13, d19, d5 \n"
"vmull.s16 q14, d14, d0 \n"
"vmlal.s16 q14, d16, d1 \n"
"vmlal.s16 q14, d18, d2 \n"
"vld1.32 {d9}, [%[input_ptr3]] \n"
"vmull.s16 q15, d15, d0 \n"
"vmlal.s16 q15, d17, d1 \n"
"vmlal.s16 q15, d19, d2 \n"
"vmovl.s8 q7, d9 \n"
"vext.s8 d9, d9, d9, #1 \n"
"vmovl.s8 q8, d9 \n"
"vext.s8 d9, d9, d9, #1 \n"
"vmovl.s8 q9, d9 \n"
"vmlal.s16 q12, d14, d6 \n"
"vmlal.s16 q12, d16, d7 \n"
"vmlal.s16 q12, d18, d8 \n"
"vmlal.s16 q13, d15, d6 \n"
"vmlal.s16 q13, d17, d7 \n"
"vmlal.s16 q13, d19, d8 \n"
"vmlal.s16 q14, d14, d3 \n"
"vmlal.s16 q14, d16, d4 \n"
"vmlal.s16 q14, d18, d5 \n"
"vmlal.s16 q15, d15, d3 \n"
"vmlal.s16 q15, d17, d4 \n"
"vmlal.s16 q15, d19, d5 \n"
"vmull.s16 q5, d14, d0 \n"
"vmlal.s16 q5, d16, d1 \n"
"vmlal.s16 q5, d18, d2 \n"
"vld1.32 {d9}, [%[input_ptr4]] \n"
"vmull.s16 q6, d15, d0 \n"
"vmlal.s16 q6, d17, d1 \n"
"vmlal.s16 q6, d19, d2 \n"
"vmovl.s8 q7, d9 \n"
"vext.s8 d9, d9, d9, #1 \n"
"vmovl.s8 q8, d9 \n"
"vext.s8 d9, d9, d9, #1 \n"
"vmovl.s8 q9, d9 \n"
"vmlal.s16 q14, d14, d6 \n"
"vmlal.s16 q14, d16, d7 \n"
"vmlal.s16 q14, d18, d8 \n"
"vmlal.s16 q15, d15, d6 \n"
"vmlal.s16 q15, d17, d7 \n"
"vmlal.s16 q15, d19, d8 \n"
"vmlal.s16 q5, d14, d3 \n"
"vmlal.s16 q5, d16, d4 \n"
"vmlal.s16 q5, d18, d5 \n"
"vld1.32 {d9}, [%[input_ptr5]] \n"
"vmlal.s16 q6, d15, d3 \n"
"vmlal.s16 q6, d17, d4 \n"
"vmlal.s16 q6, d19, d5 \n"
"vmovl.s8 q7, d9 \n"
"vext.s8 d9, d9, d9, #1 \n"
"vmovl.s8 q8, d9 \n"
"vext.s8 d9, d9, d9, #1 \n"
"vmovl.s8 q9, d9 \n"
"vmlal.s16 q5, d14, d6 \n"
"vmlal.s16 q5, d16, d7 \n"
"vmlal.s16 q5, d18, d8 \n"
"vmlal.s16 q6, d15, d6 \n"
"vmlal.s16 q6, d17, d7 \n"
"vmlal.s16 q6, d19, d8 \n"
"cmp %[remain], #4 \n"
"blt store_4h2w_%= \n"
"vst1.32 {q10}, [%[output_ptr0]]! \n"
"vst1.32 {q12}, [%[output_ptr1]]! \n"
"vst1.32 {q14}, [%[output_ptr2]]! \n"
"vst1.32 {q5}, [%[output_ptr3]]! \n"
"cmp %[remain], #5 \n"
"blt end_%= \n"
"vst1.32 {d22[0]}, [%[output_ptr0]]! \n"
"vst1.32 {d26[0]}, [%[output_ptr1]]! \n"
"vst1.32 {d30[0]}, [%[output_ptr2]]! \n"
"vst1.32 {d12[0]}, [%[output_ptr3]]! \n"
"b end_%= \n"
"store_4h2w_%=: \n"
"cmp %[remain], #2 \n"
"blt store_4h1w_%= \n"
"vst1.32 {d20}, [%[output_ptr0]]! \n"
"vst1.32 {d24}, [%[output_ptr1]]! \n"
"vst1.32 {d28}, [%[output_ptr2]]! \n"
"vst1.32 {d10}, [%[output_ptr3]]! \n"
"cmp %[remain], #3 \n"
"blt end_%= \n"
"vst1.32 {d21[0]}, [%[output_ptr0]]! \n"
"vst1.32 {d25[0]}, [%[output_ptr1]]! \n"
"vst1.32 {d29[0]}, [%[output_ptr2]]! \n"
"vst1.32 {d11[0]}, [%[output_ptr3]]! \n"
"b end_%= \n"
"store_4h1w_%=: \n"
"cmp %[remain], #1 \n"
"blt end_%= \n"
"vst1.32 {d20[0]}, [%[output_ptr0]]! \n"
"vst1.32 {d24[0]}, [%[output_ptr1]]! \n"
"vst1.32 {d28[0]}, [%[output_ptr2]]! \n"
"vst1.32 {d10[0]}, [%[output_ptr3]]! \n"
"end_%=: \n"
"subs %[loop], #1 \n"
"bne loop_4h6w_%= \n"
"start_remain_%=: \n"
"cmp %[remain], #0 \n"
"ble end_%= \n"
"mov r0, %[remain] \n"
"vld1.32 {d9}, [%[input_ptr0]], r0 \n"
"vmovl.s8 q7, d9 \n"
"vext.s8 d9, d9, d9, #1 \n"
"vmovl.s8 q8, d9 \n"
"vext.s8 d9, d9, d9, #1 \n"
"vmovl.s8 q9, d9 \n"
"vmull.s16 q10, d14, %e[ker0][0] \n"
"vmlal.s16 q10, d16, %e[ker0][1] \n"
"vmlal.s16 q10, d18, %e[ker0][2] \n"
"vld1.32 {d9}, [%[input_ptr1]], r0 \n"
"vmull.s16 q11, d15, %e[ker0][0] \n"
"vmlal.s16 q11, d17, %e[ker0][1] \n"
"vmlal.s16 q11, d19, %e[ker0][2] \n"
"vmovl.s8 q7, d9 \n"
"vext.s8 d9, d9, d9, #1 \n"
"vmovl.s8 q8, d9 \n"
"vext.s8 d9, d9, d9, #1 \n"
"vmovl.s8 q9, d9 \n"
"vmlal.s16 q10, d14, %f[ker0][0] \n"
"vmlal.s16 q10, d16, %f[ker0][1] \n"
"vmlal.s16 q10, d18, %f[ker0][2] \n"
"vmlal.s16 q11, d15, %f[ker0][0] \n"
"vmlal.s16 q11, d17, %f[ker0][1] \n"
"vmlal.s16 q11, d19, %f[ker0][2] \n"
"vmull.s16 q12, d14, %e[ker0][0] \n"
"vmlal.s16 q12, d16, %e[ker0][1] \n"
"vmlal.s16 q12, d18, %e[ker0][2] \n"
"vld1.32 {d9}, [%[input_ptr2]], r0 \n"
"vmull.s16 q13, d15, %e[ker0][0] \n"
"vmlal.s16 q13, d17, %e[ker0][1] \n"
"vmlal.s16 q13, d19, %e[ker0][2] \n"
"vmovl.s8 q7, d9 \n"
"vext.s8 d9, d9, d9, #1 \n"
"vmovl.s8 q8, d9 \n"
"vext.s8 d9, d9, d9, #1 \n"
"vmovl.s8 q9, d9 \n"
"vmlal.s16 q10, d14, %e[ker1][0] \n"
"vmlal.s16 q10, d16, %e[ker1][1] \n"
"vmlal.s16 q10, d18, %e[ker1][2] \n"
"vmlal.s16 q11, d15, %e[ker1][0] \n"
"vmlal.s16 q11, d17, %e[ker1][1] \n"
"vmlal.s16 q11, d19, %e[ker1][2] \n"
"vmlal.s16 q12, d14, %f[ker0][0] \n"
"vmlal.s16 q12, d16, %f[ker0][1] \n"
"vmlal.s16 q12, d18, %f[ker0][2] \n"
"vmlal.s16 q13, d15, %f[ker0][0] \n"
"vmlal.s16 q13, d17, %f[ker0][1] \n"
"vmlal.s16 q13, d19, %f[ker0][2] \n"
"vmull.s16 q14, d14, %e[ker0][0] \n"
"vmlal.s16 q14, d16, %e[ker0][1] \n"
"vmlal.s16 q14, d18, %e[ker0][2] \n"
"vld1.32 {d9}, [%[input_ptr3]], r0 \n"
"vmull.s16 q15, d15, %e[ker0][0] \n"
"vmlal.s16 q15, d17, %e[ker0][1] \n"
"vmlal.s16 q15, d19, %e[ker0][2] \n"
"vmovl.s8 q7, d9 \n"
"vext.s8 d9, d9, d9, #1 \n"
"vmovl.s8 q8, d9 \n"
"vext.s8 d9, d9, d9, #1 \n"
"vmovl.s8 q9, d9 \n"
"vmlal.s16 q12, d14, %e[ker1][0] \n"
"vmlal.s16 q12, d16, %e[ker1][1] \n"
"vmlal.s16 q12, d18, %e[ker1][2] \n"
"vmlal.s16 q13, d15, %e[ker1][0] \n"
"vmlal.s16 q13, d17, %e[ker1][1] \n"
"vmlal.s16 q13, d19, %e[ker1][2] \n"
"vmlal.s16 q14, d14, %f[ker0][0] \n"
"vmlal.s16 q14, d16, %f[ker0][1] \n"
"vmlal.s16 q14, d18, %f[ker0][2] \n"
"vmlal.s16 q15, d15, %f[ker0][0] \n"
"vmlal.s16 q15, d17, %f[ker0][1] \n"
"vmlal.s16 q15, d19, %f[ker0][2] \n"
"vmull.s16 q5, d14, %e[ker0][0] \n"
"vmlal.s16 q5, d16, %e[ker0][1] \n"
"vmlal.s16 q5, d18, %e[ker0][2] \n"
"vld1.32 {d9}, [%[input_ptr4]], r0 \n"
"vmull.s16 q6, d15, %e[ker0][0] \n"
"vmlal.s16 q6, d17, %e[ker0][1] \n"
"vmlal.s16 q6, d19, %e[ker0][2] \n"
"vmovl.s8 q7, d9 \n"
"vext.s8 d9, d9, d9, #1 \n"
"vmovl.s8 q8, d9 \n"
"vext.s8 d9, d9, d9, #1 \n"
"vmovl.s8 q9, d9 \n"
"vmlal.s16 q14, d14, %e[ker1][0] \n"
"vmlal.s16 q14, d16, %e[ker1][1] \n"
"vmlal.s16 q14, d18, %e[ker1][2] \n"
"vmlal.s16 q15, d15, %e[ker1][0] \n"
"vmlal.s16 q15, d17, %e[ker1][1] \n"
"vmlal.s16 q15, d19, %e[ker1][2] \n"
"vmlal.s16 q5, d14, %f[ker0][0] \n"
"vmlal.s16 q5, d16, %f[ker0][1] \n"
"vmlal.s16 q5, d18, %f[ker0][2] \n"
"vld1.32 {d9}, [%[input_ptr5]], r0 \n"
"vmlal.s16 q6, d15, %f[ker0][0] \n"
"vmlal.s16 q6, d17, %f[ker0][1] \n"
"vmlal.s16 q6, d19, %f[ker0][2] \n"
"vmovl.s8 q7, d9 \n"
"vext.s8 d9, d9, d9, #1 \n"
"vmovl.s8 q8, d9 \n"
"vext.s8 d9, d9, d9, #1 \n"
"vmovl.s8 q9, d9 \n"
"vmlal.s16 q5, d14, %e[ker1][0] \n"
"vmlal.s16 q5, d16, %e[ker1][1] \n"
"vmlal.s16 q5, d18, %e[ker1][2] \n"
"vmlal.s16 q6, d15, %e[ker1][0] \n"
"vmlal.s16 q6, d17, %e[ker1][1] \n"
"vmlal.s16 q6, d19, %e[ker1][2] \n"
"cmp %[remain], #4 \n"
"blt store_4h2w_%= \n"
"vst1.32 {q10}, [%[output_ptr0]]! \n"
"vst1.32 {q12}, [%[output_ptr1]]! \n"
"vst1.32 {q14}, [%[output_ptr2]]! \n"
"vst1.32 {q5}, [%[output_ptr3]]! \n"
"cmp %[remain], #5 \n"
"blt end_%= \n"
"vst1.32 {d22[0]}, [%[output_ptr0]]! \n"
"vst1.32 {d26[0]}, [%[output_ptr1]]! \n"
"vst1.32 {d30[0]}, [%[output_ptr2]]! \n"
"vst1.32 {d12[0]}, [%[output_ptr3]]! \n"
"b end_%= \n"
"store_4h2w_%=: \n"
"cmp %[remain], #2 \n"
"blt store_4h1w_%= \n"
"vst1.32 {d20}, [%[output_ptr0]]! \n"
"vst1.32 {d24}, [%[output_ptr1]]! \n"
"vst1.32 {d28}, [%[output_ptr2]]! \n"
"vst1.32 {d10}, [%[output_ptr3]]! \n"
"cmp %[remain], #3 \n"
"blt end_%= \n"
"vst1.32 {d21[0]}, [%[output_ptr0]]! \n"
"vst1.32 {d25[0]}, [%[output_ptr1]]! \n"
"vst1.32 {d29[0]}, [%[output_ptr2]]! \n"
"vst1.32 {d11[0]}, [%[output_ptr3]]! \n"
"b end_%= \n"
"store_4h1w_%=: \n"
"cmp %[remain], #1 \n"
"blt end_%= \n"
"vst1.32 {d20[0]}, [%[output_ptr0]]! \n"
"vst1.32 {d24[0]}, [%[output_ptr1]]! \n"
"vst1.32 {d28[0]}, [%[output_ptr2]]! \n"
"vst1.32 {d10[0]}, [%[output_ptr3]]! \n"
"end_%=: \n"
: [output_ptr0] "+r"(output_ptr0), [output_ptr1] "+r"(output_ptr1),
[output_ptr2] "+r"(output_ptr2), [output_ptr3] "+r"(output_ptr3),
[input_ptr0] "+r"(input_ptr0), [input_ptr1] "+r"(input_ptr1),
[input_ptr2] "+r"(input_ptr2), [input_ptr3] "+r"(input_ptr3),
[input_ptr4] "+r"(input_ptr4), [input_ptr5] "+r"(input_ptr5),
[loop] "+r"(loop)
: [remain] "r"(output_w_remain)
: "cc", "memory", "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7",
"q8", "q9", "q10", "q11", "q12", "q13", "q14", "q15", "r0");
: [remain] "r"(output_w_remain), [ker0] "w"(_ker0), [ker1] "w"(_ker1)
: "cc", "memory", "q4", "q5", "q6", "q7", "q8", "q9", "q10", "q11",
"q12", "q13", "q14", "q15", "r0");
// pad right
if (padding_w) {
int16x4_t row0 = vget_low_s16(vmovl_s8(vld1_s8(input_ptr0 - 2)));
int16x4_t row1 = vget_low_s16(vmovl_s8(vld1_s8(input_ptr1 - 2)));
int16x4_t row2 = vget_low_s16(vmovl_s8(vld1_s8(input_ptr2 - 2)));
int16x4_t row3 = vget_low_s16(vmovl_s8(vld1_s8(input_ptr3 - 2)));
int16x4_t row4 = vget_low_s16(vmovl_s8(vld1_s8(input_ptr4 - 2)));
int16x4_t row5 = vget_low_s16(vmovl_s8(vld1_s8(input_ptr5 - 2)));
row0 = vext_s16(row0, zero, 2);
row1 = vext_s16(row1, zero, 2);
row2 = vext_s16(row2, zero, 2);
row3 = vext_s16(row3, zero, 2);
row4 = vext_s16(row4, zero, 2);
row5 = vext_s16(row5, zero, 2);
int32x4_t acc;
for (int w = valid_w_end; w < output_w; ++w) {
int padding = w + 3 - (padding_w + input_w);
if (padding >= 3) {
*output_ptr0 = 0;
*output_ptr1 = 0;
*output_ptr2 = 0;
*output_ptr3 = 0;
} else {
acc = vmull_s16(row0, _ker[0]);
acc = vmlal_s16(acc, row1, _ker[1]);
acc = vmlal_s16(acc, row2, _ker[2]);
*output_ptr0 = vgetq_lane_s32(acc, 0) + vgetq_lane_s32(acc, 1);
acc = vmull_s16(row1, _ker[0]);
acc = vmlal_s16(acc, row2, _ker[1]);
acc = vmlal_s16(acc, row3, _ker[2]);
*output_ptr1 = vgetq_lane_s32(acc, 0) + vgetq_lane_s32(acc, 1);
acc = vmull_s16(row2, _ker[0]);
acc = vmlal_s16(acc, row3, _ker[1]);
acc = vmlal_s16(acc, row4, _ker[2]);
*output_ptr2 = vgetq_lane_s32(acc, 0) + vgetq_lane_s32(acc, 1);
acc = vmull_s16(row3, _ker[0]);
acc = vmlal_s16(acc, row4, _ker[1]);
acc = vmlal_s16(acc, row5, _ker[2]);
*output_ptr3 = vgetq_lane_s32(acc, 0) + vgetq_lane_s32(acc, 1);
row0 = vext_s16(row0, zero, 1);
row1 = vext_s16(row1, zero, 1);
row2 = vext_s16(row2, zero, 1);
row3 = vext_s16(row3, zero, 1);
row4 = vext_s16(row4, zero, 1);
row5 = vext_s16(row5, zero, 1);
}
output_ptr0++;
output_ptr1++;
output_ptr2++;
output_ptr3++;
}
}
}
// remain height
int start_h = valid_h_start + (valid_h & 0xFFFC);
......@@ -712,208 +586,259 @@ void DepthwiseConv3x3S1<int8_t, int32_t>(const framework::Tensor &input,
const int8_t *input_ptr1 = input_ptr0 + input_w;
const int8_t *input_ptr2 = input_ptr1 + input_w;
const int8_t *input_ptr3 = input_ptr2 + input_w;
int32_t *output_ptr0 = output_ptr + h * output_w + valid_w_start;
int32_t *output_ptr0 = output_ptr + h * output_w;
int32_t *output_ptr1 = output_ptr0 + output_w;
// pad left
if (padding_w) {
int16x4_t row0 = vget_low_s16(vmovl_s8(vld1_s8(input_ptr0)));
int16x4_t row1 = vget_low_s16(vmovl_s8(vld1_s8(input_ptr1)));
int16x4_t row2 = vget_low_s16(vmovl_s8(vld1_s8(input_ptr2)));
int16x4_t row3 = vget_low_s16(vmovl_s8(vld1_s8(input_ptr3)));
int32x4_t acc;
for (int w = valid_w_start - 1; w >= 0; --w) {
int padding = padding_w - w;
if (padding >= 3) {
output_ptr0[w] = 0;
output_ptr1[w] = 0;
} else {
row0 = vext_s16(zero, row0, 3);
row1 = vext_s16(zero, row1, 3);
row2 = vext_s16(zero, row2, 3);
row3 = vext_s16(zero, row3, 3);
acc = vmull_s16(row0, _ker[0]);
acc = vmlal_s16(acc, row1, _ker[1]);
acc = vmlal_s16(acc, row2, _ker[2]);
output_ptr0[w] = vgetq_lane_s32(acc, 1) + vgetq_lane_s32(acc, 2);
acc = vmull_s16(row1, _ker[0]);
acc = vmlal_s16(acc, row2, _ker[1]);
acc = vmlal_s16(acc, row3, _ker[2]);
output_ptr1[w] = vgetq_lane_s32(acc, 1) + vgetq_lane_s32(acc, 2);
}
}
output_ptr0 += valid_w_start;
output_ptr1 += valid_w_start;
}
// valid
int loop = output_w_tiles;
asm volatile(
"vld1.32 {q0}, [%[filter_ptr]] \n"
"vmovl.s8 q14, d0 \n"
"vmovl.s8 q15, d1 \n"
"vdup.s16 d0, d28[0] \n"
"vdup.s16 d1, d28[1] \n"
"vdup.s16 d2, d28[2] \n"
"vdup.s16 d3, d28[3] \n"
"vdup.s16 d4, d29[0] \n"
"vdup.s16 d5, d29[1] \n"
"vdup.s16 d6, d29[2] \n"
"vdup.s16 d7, d29[3] \n"
"vdup.s16 d8, d30[0] \n"
:
: [filter_ptr] "r"(filter_ptr)
: "cc", "memory", "q0", "q1", "q2", "q3", "q4", "q14", "q15");
asm volatile(
"cmp %[loop], #0 \n"
"ble start_remain_%= \n"
"mov r0, #6 \n"
"cmp %[loop], #0 \n"
"ble start_remain_%= \n"
"mov r0, #6 \n"
// loop 6 widths
"loop_2h6w_%=: \n"
"vld1.32 {d9}, [%[input_ptr0]], r0 \n"
"vld1.32 {d10}, [%[input_ptr1]], r0 \n"
"vld1.32 {d11}, [%[input_ptr2]], r0 \n"
"vext.s8 d12, d9, d9, #1 \n"
"vext.s8 d13, d9, d9, #2 \n"
"vmovl.s8 q7, d9 \n"
"vmovl.s8 q8, d12 \n"
"vmovl.s8 q9, d13 \n"
"vmull.s16 q10, d14, d0 \n"
"vmlal.s16 q10, d16, d1 \n"
"vmlal.s16 q10, d18, d2 \n"
"vmull.s16 q11, d15, d0 \n"
"vmlal.s16 q11, d17, d1 \n"
"vmlal.s16 q11, d19, d2 \n"
"vext.s8 d12, d10, d10, #1 \n"
"vext.s8 d13, d10, d10, #2 \n"
"vmovl.s8 q7, d10 \n"
"vmovl.s8 q8, d12 \n"
"vmovl.s8 q9, d13 \n"
"vmlal.s16 q10, d14, d3 \n"
"vmlal.s16 q10, d16, d4 \n"
"vmlal.s16 q10, d18, d5 \n"
"vmlal.s16 q11, d15, d3 \n"
"vmlal.s16 q11, d17, d4 \n"
"vmlal.s16 q11, d19, d5 \n"
"vmull.s16 q12, d14, d0 \n"
"vmlal.s16 q12, d16, d1 \n"
"vmlal.s16 q12, d18, d2 \n"
"vmull.s16 q13, d15, d0 \n"
"vmlal.s16 q13, d17, d1 \n"
"vmlal.s16 q13, d19, d2 \n"
"vext.s8 d12, d11, d11, #1 \n"
"vext.s8 d13, d11, d11, #2 \n"
"vmovl.s8 q7, d11 \n"
"vmovl.s8 q8, d12 \n"
"vmovl.s8 q9, d13 \n"
"vmlal.s16 q10, d14, d6 \n"
"vmlal.s16 q10, d16, d7 \n"
"vmlal.s16 q10, d18, d8 \n"
"vmlal.s16 q11, d15, d6 \n"
"vmlal.s16 q11, d17, d7 \n"
"vmlal.s16 q11, d19, d8 \n"
"loop_2h6w_%=: \n"
"vld1.32 {d9}, [%[input_ptr0]], r0 \n"
"vld1.32 {d10}, [%[input_ptr1]], r0 \n"
"vld1.32 {d11}, [%[input_ptr2]], r0 \n"
"vext.s8 d12, d9, d9, #1 \n"
"vext.s8 d13, d9, d9, #2 \n"
"vmovl.s8 q7, d9 \n"
"vmovl.s8 q8, d12 \n"
"vmovl.s8 q9, d13 \n"
"vmull.s16 q10, d14, %e[ker0][0] \n"
"vmlal.s16 q10, d16, %e[ker0][1] \n"
"vmlal.s16 q10, d18, %e[ker0][2] \n"
"vmull.s16 q11, d15, %e[ker0][0] \n"
"vmlal.s16 q11, d17, %e[ker0][1] \n"
"vmlal.s16 q11, d19, %e[ker0][2] \n"
"vext.s8 d12, d10, d10, #1 \n"
"vext.s8 d13, d10, d10, #2 \n"
"vmovl.s8 q7, d10 \n"
"vmovl.s8 q8, d12 \n"
"vmovl.s8 q9, d13 \n"
"vmlal.s16 q10, d14, %f[ker0][0] \n"
"vmlal.s16 q10, d16, %f[ker0][1] \n"
"vmlal.s16 q10, d18, %f[ker0][2] \n"
"vmlal.s16 q11, d15, %f[ker0][0] \n"
"vmlal.s16 q11, d17, %f[ker0][1] \n"
"vmlal.s16 q11, d19, %f[ker0][2] \n"
"vmull.s16 q12, d14, %e[ker0][0] \n"
"vmlal.s16 q12, d16, %e[ker0][1] \n"
"vmlal.s16 q12, d18, %e[ker0][2] \n"
"vmull.s16 q13, d15, %e[ker0][0] \n"
"vmlal.s16 q13, d17, %e[ker0][1] \n"
"vmlal.s16 q13, d19, %e[ker0][2] \n"
"vext.s8 d12, d11, d11, #1 \n"
"vext.s8 d13, d11, d11, #2 \n"
"vmovl.s8 q7, d11 \n"
"vmovl.s8 q8, d12 \n"
"vmovl.s8 q9, d13 \n"
"vmlal.s16 q10, d14, %e[ker1][0] \n"
"vmlal.s16 q10, d16, %e[ker1][1] \n"
"vmlal.s16 q10, d18, %e[ker1][2] \n"
"vmlal.s16 q11, d15, %e[ker1][0] \n"
"vmlal.s16 q11, d17, %e[ker1][1] \n"
"vmlal.s16 q11, d19, %e[ker1][2] \n"
// store row 0, reuse q10/q11
"vst1.32 {d20-d22}, [%[output_ptr0]]! \n"
"vmlal.s16 q12, d14, d3 \n"
"vmlal.s16 q12, d16, d4 \n"
"vmlal.s16 q12, d18, d5 \n"
"vmlal.s16 q13, d15, d3 \n"
"vmlal.s16 q13, d17, d4 \n"
"vmlal.s16 q13, d19, d5 \n"
"vmlal.s16 q12, d14, %f[ker0][0] \n"
"vmlal.s16 q12, d16, %f[ker0][1] \n"
"vmlal.s16 q12, d18, %f[ker0][2] \n"
"vmlal.s16 q13, d15, %f[ker0][0] \n"
"vmlal.s16 q13, d17, %f[ker0][1] \n"
"vmlal.s16 q13, d19, %f[ker0][2] \n"
"vld1.32 {d9}, [%[input_ptr3]], r0 \n"
"vext.s8 d12, d9, d9, #1 \n"
"vext.s8 d13, d9, d9, #2 \n"
"vmovl.s8 q7, d9 \n"
"vmovl.s8 q8, d12 \n"
"vmovl.s8 q9, d13 \n"
"vmlal.s16 q12, d14, d6 \n"
"vmlal.s16 q12, d16, d7 \n"
"vmlal.s16 q12, d18, d8 \n"
"vmlal.s16 q13, d15, d6 \n"
"vmlal.s16 q13, d17, d7 \n"
"vmlal.s16 q13, d19, d8 \n"
"vext.s8 d12, d9, d9, #1 \n"
"vext.s8 d13, d9, d9, #2 \n"
"vmovl.s8 q7, d9 \n"
"vmovl.s8 q8, d12 \n"
"vmovl.s8 q9, d13 \n"
"vmlal.s16 q12, d14, %e[ker1][0] \n"
"vmlal.s16 q12, d16, %e[ker1][1] \n"
"vmlal.s16 q12, d18, %e[ker1][2] \n"
"vmlal.s16 q13, d15, %e[ker1][0] \n"
"vmlal.s16 q13, d17, %e[ker1][1] \n"
"vmlal.s16 q13, d19, %e[ker1][2] \n"
// store row 1
"vst1.32 {d24-d26}, [%[output_ptr1]]! \n"
"subs %[loop], #1 \n"
"bne loop_2h6w_%= \n"
"start_remain_%=: \n"
"cmp %[remain], #0 \n"
"ble end_%= \n"
"vld1.32 {d9}, [%[input_ptr0]] \n"
"vld1.32 {d10}, [%[input_ptr1]] \n"
"vld1.32 {d11}, [%[input_ptr2]] \n"
"vext.s8 d12, d9, d9, #1 \n"
"vext.s8 d13, d9, d9, #2 \n"
"vmovl.s8 q7, d9 \n"
"vmovl.s8 q8, d12 \n"
"vmovl.s8 q9, d13 \n"
"vmull.s16 q10, d14, d0 \n"
"vmlal.s16 q10, d16, d1 \n"
"vmlal.s16 q10, d18, d2 \n"
"vmull.s16 q11, d15, d0 \n"
"vmlal.s16 q11, d17, d1 \n"
"vmlal.s16 q11, d19, d2 \n"
"vext.s8 d12, d10, d10, #1 \n"
"vext.s8 d13, d10, d10, #2 \n"
"vmovl.s8 q7, d10 \n"
"vmovl.s8 q8, d12 \n"
"vmovl.s8 q9, d13 \n"
"vmlal.s16 q10, d14, d3 \n"
"vmlal.s16 q10, d16, d4 \n"
"vmlal.s16 q10, d18, d5 \n"
"vmlal.s16 q11, d15, d3 \n"
"vmlal.s16 q11, d17, d4 \n"
"vmlal.s16 q11, d19, d5 \n"
"vmull.s16 q12, d14, d0 \n"
"vmlal.s16 q12, d16, d1 \n"
"vmlal.s16 q12, d18, d2 \n"
"vmull.s16 q13, d15, d0 \n"
"vmlal.s16 q13, d17, d1 \n"
"vmlal.s16 q13, d19, d2 \n"
"vext.s8 d12, d11, d11, #1 \n"
"vext.s8 d13, d11, d11, #2 \n"
"vmovl.s8 q7, d11 \n"
"vmovl.s8 q8, d12 \n"
"vmovl.s8 q9, d13 \n"
"vmlal.s16 q10, d14, d6 \n"
"vmlal.s16 q10, d16, d7 \n"
"vmlal.s16 q10, d18, d8 \n"
"vmlal.s16 q11, d15, d6 \n"
"vmlal.s16 q11, d17, d7 \n"
"vmlal.s16 q11, d19, d8 \n"
"vmlal.s16 q12, d14, d3 \n"
"vmlal.s16 q12, d16, d4 \n"
"vmlal.s16 q12, d18, d5 \n"
"vmlal.s16 q13, d15, d3 \n"
"vmlal.s16 q13, d17, d4 \n"
"vmlal.s16 q13, d19, d5 \n"
"vld1.32 {d9}, [%[input_ptr3]] \n"
"vext.s8 d12, d9, d9, #1 \n"
"vext.s8 d13, d9, d9, #2 \n"
"vmovl.s8 q7, d9 \n"
"vmovl.s8 q8, d12 \n"
"vmovl.s8 q9, d13 \n"
"vmlal.s16 q12, d14, d6 \n"
"vmlal.s16 q12, d16, d7 \n"
"vmlal.s16 q12, d18, d8 \n"
"vmlal.s16 q13, d15, d6 \n"
"vmlal.s16 q13, d17, d7 \n"
"vmlal.s16 q13, d19, d8 \n"
"cmp %[remain], #4 \n"
"blt store_2h2w_%= \n"
"vst1.32 {q10}, [%[output_ptr0]]! \n"
"vst1.32 {q12}, [%[output_ptr1]]! \n"
"cmp %[remain], #5 \n"
"blt end_%= \n"
"vst1.32 {d22[0]}, [%[output_ptr0]]! \n"
"vst1.32 {d26[0]}, [%[output_ptr1]]! \n"
"b end_%= \n"
"store_2h2w_%=: \n"
"cmp %[remain], #2 \n"
"blt store_2h1w_%= \n"
"vst1.32 {d20}, [%[output_ptr0]]! \n"
"vst1.32 {d24}, [%[output_ptr1]]! \n"
"cmp %[remain], #3 \n"
"blt end_%= \n"
"vst1.32 {d21[0]}, [%[output_ptr0]]! \n"
"vst1.32 {d25[0]}, [%[output_ptr1]]! \n"
"b end_%= \n"
"store_2h1w_%=: \n"
"cmp %[remain], #1 \n"
"blt end_%= \n"
"vst1.32 {d20[0]}, [%[output_ptr0]]! \n"
"vst1.32 {d24[0]}, [%[output_ptr1]]! \n"
"end_%=: \n"
"subs %[loop], #1 \n"
"bne loop_2h6w_%= \n"
"start_remain_%=: \n"
"cmp %[remain], #0 \n"
"ble end_%= \n"
"mov r0, %[remain] \n"
"vld1.32 {d9}, [%[input_ptr0]], r0 \n"
"vld1.32 {d10}, [%[input_ptr1]], r0 \n"
"vld1.32 {d11}, [%[input_ptr2]], r0 \n"
"vext.s8 d12, d9, d9, #1 \n"
"vext.s8 d13, d9, d9, #2 \n"
"vmovl.s8 q7, d9 \n"
"vmovl.s8 q8, d12 \n"
"vmovl.s8 q9, d13 \n"
"vmull.s16 q10, d14, %e[ker0][0] \n"
"vmlal.s16 q10, d16, %e[ker0][1] \n"
"vmlal.s16 q10, d18, %e[ker0][2] \n"
"vmull.s16 q11, d15, %e[ker0][0] \n"
"vmlal.s16 q11, d17, %e[ker0][1] \n"
"vmlal.s16 q11, d19, %e[ker0][2] \n"
"vext.s8 d12, d10, d10, #1 \n"
"vext.s8 d13, d10, d10, #2 \n"
"vmovl.s8 q7, d10 \n"
"vmovl.s8 q8, d12 \n"
"vmovl.s8 q9, d13 \n"
"vmlal.s16 q10, d14, %f[ker0][0] \n"
"vmlal.s16 q10, d16, %f[ker0][1] \n"
"vmlal.s16 q10, d18, %f[ker0][2] \n"
"vmlal.s16 q11, d15, %f[ker0][0] \n"
"vmlal.s16 q11, d17, %f[ker0][1] \n"
"vmlal.s16 q11, d19, %f[ker0][2] \n"
"vmull.s16 q12, d14, %e[ker0][0] \n"
"vmlal.s16 q12, d16, %e[ker0][1] \n"
"vmlal.s16 q12, d18, %e[ker0][2] \n"
"vmull.s16 q13, d15, %e[ker0][0] \n"
"vmlal.s16 q13, d17, %e[ker0][1] \n"
"vmlal.s16 q13, d19, %e[ker0][2] \n"
"vext.s8 d12, d11, d11, #1 \n"
"vext.s8 d13, d11, d11, #2 \n"
"vmovl.s8 q7, d11 \n"
"vmovl.s8 q8, d12 \n"
"vmovl.s8 q9, d13 \n"
"vmlal.s16 q10, d14, %e[ker1][0] \n"
"vmlal.s16 q10, d16, %e[ker1][1] \n"
"vmlal.s16 q10, d18, %e[ker1][2] \n"
"vmlal.s16 q11, d15, %e[ker1][0] \n"
"vmlal.s16 q11, d17, %e[ker1][1] \n"
"vmlal.s16 q11, d19, %e[ker1][2] \n"
"vmlal.s16 q12, d14, %f[ker0][0] \n"
"vmlal.s16 q12, d16, %f[ker0][1] \n"
"vmlal.s16 q12, d18, %f[ker0][2] \n"
"vmlal.s16 q13, d15, %f[ker0][0] \n"
"vmlal.s16 q13, d17, %f[ker0][1] \n"
"vmlal.s16 q13, d19, %f[ker0][2] \n"
"vld1.32 {d9}, [%[input_ptr3]], r0 \n"
"vext.s8 d12, d9, d9, #1 \n"
"vext.s8 d13, d9, d9, #2 \n"
"vmovl.s8 q7, d9 \n"
"vmovl.s8 q8, d12 \n"
"vmovl.s8 q9, d13 \n"
"vmlal.s16 q12, d14, %e[ker1][0] \n"
"vmlal.s16 q12, d16, %e[ker1][1] \n"
"vmlal.s16 q12, d18, %e[ker1][2] \n"
"vmlal.s16 q13, d15, %e[ker1][0] \n"
"vmlal.s16 q13, d17, %e[ker1][1] \n"
"vmlal.s16 q13, d19, %e[ker1][2] \n"
"cmp %[remain], #4 \n"
"blt store_2h2w_%= \n"
"vst1.32 {q10}, [%[output_ptr0]]! \n"
"vst1.32 {q12}, [%[output_ptr1]]! \n"
"cmp %[remain], #5 \n"
"blt end_%= \n"
"vst1.32 {d22[0]}, [%[output_ptr0]]! \n"
"vst1.32 {d26[0]}, [%[output_ptr1]]! \n"
"b end_%= \n"
"store_2h2w_%=: \n"
"cmp %[remain], #2 \n"
"blt store_2h1w_%= \n"
"vst1.32 {d20}, [%[output_ptr0]]! \n"
"vst1.32 {d24}, [%[output_ptr1]]! \n"
"cmp %[remain], #3 \n"
"blt end_%= \n"
"vst1.32 {d21[0]}, [%[output_ptr0]]! \n"
"vst1.32 {d25[0]}, [%[output_ptr1]]! \n"
"b end_%= \n"
"store_2h1w_%=: \n"
"cmp %[remain], #1 \n"
"blt end_%= \n"
"vst1.32 {d20[0]}, [%[output_ptr0]]! \n"
"vst1.32 {d24[0]}, [%[output_ptr1]]! \n"
"end_%=: \n"
: [output_ptr0] "+r"(output_ptr0), [output_ptr1] "+r"(output_ptr1),
[input_ptr0] "+r"(input_ptr0), [input_ptr1] "+r"(input_ptr1),
[input_ptr2] "+r"(input_ptr2), [input_ptr3] "+r"(input_ptr3),
[loop] "+r"(loop)
: [remain] "r"(output_w_remain)
: "cc", "memory", "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7",
"q8", "q9", "q10", "q11", "q12", "q13", "r0");
: [remain] "r"(output_w_remain), [ker0] "w"(_ker0), [ker1] "w"(_ker1)
: "cc", "memory", "q4", "q5", "q6", "q7", "q8", "q9", "q10", "q11",
"q12", "q13", "q14", "q15", "r0");
// pad right
if (padding_w) {
int16x4_t row0 = vget_low_s16(vmovl_s8(vld1_s8(input_ptr0 - 2)));
int16x4_t row1 = vget_low_s16(vmovl_s8(vld1_s8(input_ptr1 - 2)));
int16x4_t row2 = vget_low_s16(vmovl_s8(vld1_s8(input_ptr2 - 2)));
int16x4_t row3 = vget_low_s16(vmovl_s8(vld1_s8(input_ptr3 - 2)));
row0 = vext_s16(row0, zero, 2);
row1 = vext_s16(row1, zero, 2);
row2 = vext_s16(row2, zero, 2);
row3 = vext_s16(row3, zero, 2);
int32x4_t acc;
for (int w = valid_w_end; w < output_w; ++w) {
int padding = w + 3 - (padding_w + input_w);
if (padding >= 3) {
*output_ptr0 = 0;
*output_ptr1 = 0;
} else {
acc = vmull_s16(row0, _ker[0]);
acc = vmlal_s16(acc, row1, _ker[1]);
acc = vmlal_s16(acc, row2, _ker[2]);
*output_ptr0 = vgetq_lane_s32(acc, 0) + vgetq_lane_s32(acc, 1);
acc = vmull_s16(row1, _ker[0]);
acc = vmlal_s16(acc, row2, _ker[1]);
acc = vmlal_s16(acc, row3, _ker[2]);
*output_ptr1 = vgetq_lane_s32(acc, 0) + vgetq_lane_s32(acc, 1);
row0 = vext_s16(row0, zero, 1);
row1 = vext_s16(row1, zero, 1);
row2 = vext_s16(row2, zero, 1);
row3 = vext_s16(row3, zero, 1);
}
output_ptr0++;
output_ptr1++;
}
}
}
start_h = valid_h_start + (valid_h & 0xFFFE);
......@@ -921,145 +846,185 @@ void DepthwiseConv3x3S1<int8_t, int32_t>(const framework::Tensor &input,
const int8_t *input_ptr0 = input_ptr + (start_h - padding_h) * input_w;
const int8_t *input_ptr1 = input_ptr0 + input_w;
const int8_t *input_ptr2 = input_ptr1 + input_w;
int32_t *output_ptr0 = output_ptr + start_h * output_w + valid_w_start;
int32_t *output_ptr0 = output_ptr + start_h * output_w;
// pad left
if (padding_w) {
int16x4_t row0 = vget_low_s16(vmovl_s8(vld1_s8(input_ptr0)));
int16x4_t row1 = vget_low_s16(vmovl_s8(vld1_s8(input_ptr1)));
int16x4_t row2 = vget_low_s16(vmovl_s8(vld1_s8(input_ptr2)));
int32x4_t acc;
for (int w = valid_w_start - 1; w >= 0; --w) {
int padding = padding_w - w;
if (padding >= 3) {
output_ptr0[w] = 0;
} else {
row0 = vext_s16(zero, row0, 3);
row1 = vext_s16(zero, row1, 3);
row2 = vext_s16(zero, row2, 3);
acc = vmull_s16(row0, _ker[0]);
acc = vmlal_s16(acc, row1, _ker[1]);
acc = vmlal_s16(acc, row2, _ker[2]);
output_ptr0[w] = vgetq_lane_s32(acc, 1) + vgetq_lane_s32(acc, 2);
}
}
output_ptr0 += valid_w_start;
}
// valid
int loop = output_w_tiles;
asm volatile(
"vld1.32 {q0}, [%[filter_ptr]] \n"
"vmovl.s8 q14, d0 \n"
"vmovl.s8 q15, d1 \n"
"vdup.s16 d0, d28[0] \n"
"vdup.s16 d1, d28[1] \n"
"vdup.s16 d2, d28[2] \n"
"vdup.s16 d3, d28[3] \n"
"vdup.s16 d4, d29[0] \n"
"vdup.s16 d5, d29[1] \n"
"vdup.s16 d6, d29[2] \n"
"vdup.s16 d7, d29[3] \n"
"vdup.s16 d8, d30[0] \n"
:
: [filter_ptr] "r"(filter_ptr)
: "memory", "q0", "q1", "q2", "q3", "q4", "q14", "q15");
asm volatile(
"cmp %[loop], #0 \n"
"ble start_remain_%= \n"
"mov r0, #6 \n"
"cmp %[loop], #0 \n"
"ble start_remain_%= \n"
"mov r0, #6 \n"
// loop 6 widths
"loop_1h6w_%=: \n"
"vld1.32 {d9}, [%[input_ptr0]], r0 \n"
"vld1.32 {d10}, [%[input_ptr1]], r0 \n"
"vld1.32 {d11}, [%[input_ptr2]], r0 \n"
"vext.s8 d12, d9, d9, #1 \n"
"vext.s8 d13, d9, d9, #2 \n"
"vmovl.s8 q7, d9 \n"
"vmovl.s8 q8, d12 \n"
"vmovl.s8 q9, d13 \n"
"vmull.s16 q10, d14, d0 \n"
"vmlal.s16 q10, d16, d1 \n"
"vmlal.s16 q10, d18, d2 \n"
"vmull.s16 q11, d15, d0 \n"
"vmlal.s16 q11, d17, d1 \n"
"vmlal.s16 q11, d19, d2 \n"
"vext.s8 d12, d10, d10, #1 \n"
"vext.s8 d13, d10, d10, #2 \n"
"vmovl.s8 q7, d10 \n"
"vmovl.s8 q8, d12 \n"
"vmovl.s8 q9, d13 \n"
"vmlal.s16 q10, d14, d3 \n"
"vmlal.s16 q10, d16, d4 \n"
"vmlal.s16 q10, d18, d5 \n"
"vmlal.s16 q11, d15, d3 \n"
"vmlal.s16 q11, d17, d4 \n"
"vmlal.s16 q11, d19, d5 \n"
"vext.s8 d12, d11, d11, #1 \n"
"vext.s8 d13, d11, d11, #2 \n"
"vmovl.s8 q7, d11 \n"
"vmovl.s8 q8, d12 \n"
"vmovl.s8 q9, d13 \n"
"vmlal.s16 q10, d14, d6 \n"
"vmlal.s16 q10, d16, d7 \n"
"vmlal.s16 q10, d18, d8 \n"
"vmlal.s16 q11, d15, d6 \n"
"vmlal.s16 q11, d17, d7 \n"
"vmlal.s16 q11, d19, d8 \n"
"loop_1h6w_%=: \n"
"vld1.32 {d9}, [%[input_ptr0]], r0 \n"
"vld1.32 {d10}, [%[input_ptr1]], r0 \n"
"vld1.32 {d11}, [%[input_ptr2]], r0 \n"
"vext.s8 d12, d9, d9, #1 \n"
"vext.s8 d13, d9, d9, #2 \n"
"vmovl.s8 q7, d9 \n"
"vmovl.s8 q8, d12 \n"
"vmovl.s8 q9, d13 \n"
"vmull.s16 q10, d14, %e[ker0][0] \n"
"vmlal.s16 q10, d16, %e[ker0][1] \n"
"vmlal.s16 q10, d18, %e[ker0][2] \n"
"vmull.s16 q11, d15, %e[ker0][0] \n"
"vmlal.s16 q11, d17, %e[ker0][1] \n"
"vmlal.s16 q11, d19, %e[ker0][2] \n"
"vext.s8 d12, d10, d10, #1 \n"
"vext.s8 d13, d10, d10, #2 \n"
"vmovl.s8 q7, d10 \n"
"vmovl.s8 q8, d12 \n"
"vmovl.s8 q9, d13 \n"
"vmlal.s16 q10, d14, %f[ker0][0] \n"
"vmlal.s16 q10, d16, %f[ker0][1] \n"
"vmlal.s16 q10, d18, %f[ker0][2] \n"
"vmlal.s16 q11, d15, %f[ker0][0] \n"
"vmlal.s16 q11, d17, %f[ker0][1] \n"
"vmlal.s16 q11, d19, %f[ker0][2] \n"
"vext.s8 d12, d11, d11, #1 \n"
"vext.s8 d13, d11, d11, #2 \n"
"vmovl.s8 q7, d11 \n"
"vmovl.s8 q8, d12 \n"
"vmovl.s8 q9, d13 \n"
"vmlal.s16 q10, d14, %e[ker1][0] \n"
"vmlal.s16 q10, d16, %e[ker1][1] \n"
"vmlal.s16 q10, d18, %e[ker1][2] \n"
"vmlal.s16 q11, d15, %e[ker1][0] \n"
"vmlal.s16 q11, d17, %e[ker1][1] \n"
"vmlal.s16 q11, d19, %e[ker1][2] \n"
// store row 0, reuse q10/q11
"vst1.32 {d20-d22}, [%[output_ptr0]]! \n"
"subs %[loop], #1 \n"
"bne loop_1h6w_%= \n"
"start_remain_%=: \n"
"cmp %[remain], #0 \n"
"ble end_%= \n"
"vld1.32 {d9}, [%[input_ptr0]] \n"
"vld1.32 {d10}, [%[input_ptr1]] \n"
"vld1.32 {d11}, [%[input_ptr2]] \n"
"vext.s8 d12, d9, d9, #1 \n"
"vext.s8 d13, d9, d9, #2 \n"
"vmovl.s8 q7, d9 \n"
"vmovl.s8 q8, d12 \n"
"vmovl.s8 q9, d13 \n"
"vmull.s16 q10, d14, d0 \n"
"vmlal.s16 q10, d16, d1 \n"
"vmlal.s16 q10, d18, d2 \n"
"vmull.s16 q11, d15, d0 \n"
"vmlal.s16 q11, d17, d1 \n"
"vmlal.s16 q11, d19, d2 \n"
"vext.s8 d12, d10, d10, #1 \n"
"vext.s8 d13, d10, d10, #2 \n"
"vmovl.s8 q7, d10 \n"
"vmovl.s8 q8, d12 \n"
"vmovl.s8 q9, d13 \n"
"vmlal.s16 q10, d14, d3 \n"
"vmlal.s16 q10, d16, d4 \n"
"vmlal.s16 q10, d18, d5 \n"
"vmlal.s16 q11, d15, d3 \n"
"vmlal.s16 q11, d17, d4 \n"
"vmlal.s16 q11, d19, d5 \n"
"vext.s8 d12, d11, d11, #1 \n"
"vext.s8 d13, d11, d11, #2 \n"
"vmovl.s8 q7, d11 \n"
"vmovl.s8 q8, d12 \n"
"vmovl.s8 q9, d13 \n"
"vmlal.s16 q10, d14, d6 \n"
"vmlal.s16 q10, d16, d7 \n"
"vmlal.s16 q10, d18, d8 \n"
"vmlal.s16 q11, d15, d6 \n"
"vmlal.s16 q11, d17, d7 \n"
"vmlal.s16 q11, d19, d8 \n"
"cmp %[remain], #4 \n"
"blt store_1h2w_%= \n"
"vst1.32 {q10}, [%[output_ptr0]]! \n"
"cmp %[remain], #5 \n"
"blt end_%= \n"
"vst1.32 {d22[0]}, [%[output_ptr0]]! \n"
"b end_%= \n"
"store_1h2w_%=: \n"
"cmp %[remain], #2 \n"
"blt store_1h1w_%= \n"
"vst1.32 {d20}, [%[output_ptr0]]! \n"
"cmp %[remain], #3 \n"
"blt end_%= \n"
"vst1.32 {d21[0]}, [%[output_ptr0]]! \n"
"b end_%= \n"
"store_1h1w_%=: \n"
"cmp %[remain], #1 \n"
"blt end_%= \n"
"vst1.32 {d20[0]}, [%[output_ptr0]]! \n"
"end_%=: \n"
"subs %[loop], #1 \n"
"bne loop_1h6w_%= \n"
"start_remain_%=: \n"
"cmp %[remain], #0 \n"
"ble end_%= \n"
"mov r0, %[remain] \n"
"vld1.32 {d9}, [%[input_ptr0]], r0 \n"
"vld1.32 {d10}, [%[input_ptr1]], r0 \n"
"vld1.32 {d11}, [%[input_ptr2]], r0 \n"
"vext.s8 d12, d9, d9, #1 \n"
"vext.s8 d13, d9, d9, #2 \n"
"vmovl.s8 q7, d9 \n"
"vmovl.s8 q8, d12 \n"
"vmovl.s8 q9, d13 \n"
"vmull.s16 q10, d14, %e[ker0][0] \n"
"vmlal.s16 q10, d16, %e[ker0][1] \n"
"vmlal.s16 q10, d18, %e[ker0][2] \n"
"vmull.s16 q11, d15, %e[ker0][0] \n"
"vmlal.s16 q11, d17, %e[ker0][1] \n"
"vmlal.s16 q11, d19, %e[ker0][2] \n"
"vext.s8 d12, d10, d10, #1 \n"
"vext.s8 d13, d10, d10, #2 \n"
"vmovl.s8 q7, d10 \n"
"vmovl.s8 q8, d12 \n"
"vmovl.s8 q9, d13 \n"
"vmlal.s16 q10, d14, %f[ker0][0] \n"
"vmlal.s16 q10, d16, %f[ker0][1] \n"
"vmlal.s16 q10, d18, %f[ker0][2] \n"
"vmlal.s16 q11, d15, %f[ker0][0] \n"
"vmlal.s16 q11, d17, %f[ker0][1] \n"
"vmlal.s16 q11, d19, %f[ker0][2] \n"
"vext.s8 d12, d11, d11, #1 \n"
"vext.s8 d13, d11, d11, #2 \n"
"vmovl.s8 q7, d11 \n"
"vmovl.s8 q8, d12 \n"
"vmovl.s8 q9, d13 \n"
"vmlal.s16 q10, d14, %e[ker1][0] \n"
"vmlal.s16 q10, d16, %e[ker1][1] \n"
"vmlal.s16 q10, d18, %e[ker1][2] \n"
"vmlal.s16 q11, d15, %e[ker1][0] \n"
"vmlal.s16 q11, d17, %e[ker1][1] \n"
"vmlal.s16 q11, d19, %e[ker1][2] \n"
"cmp %[remain], #4 \n"
"blt store_1h2w_%= \n"
"vst1.32 {q10}, [%[output_ptr0]]! \n"
"cmp %[remain], #5 \n"
"blt end_%= \n"
"vst1.32 {d22[0]}, [%[output_ptr0]]! \n"
"b end_%= \n"
"store_1h2w_%=: \n"
"cmp %[remain], #2 \n"
"blt store_1h1w_%= \n"
"vst1.32 {d20}, [%[output_ptr0]]! \n"
"cmp %[remain], #3 \n"
"blt end_%= \n"
"vst1.32 {d21[0]}, [%[output_ptr0]]! \n"
"b end_%= \n"
"store_1h1w_%=: \n"
"cmp %[remain], #1 \n"
"blt end_%= \n"
"vst1.32 {d20[0]}, [%[output_ptr0]]! \n"
"end_%=: \n"
: [output_ptr0] "+r"(output_ptr0), [input_ptr0] "+r"(input_ptr0),
[input_ptr1] "+r"(input_ptr1), [input_ptr2] "+r"(input_ptr2),
[loop] "+r"(loop)
: [remain] "r"(output_w_remain)
: "cc", "memory", "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7",
"q8", "q9", "q10", "q11", "r0");
: [remain] "r"(output_w_remain), [ker0] "w"(_ker0), [ker1] "w"(_ker1)
: "cc", "memory", "q4", "q5", "q6", "q7", "q8", "q9", "q10", "q11",
"q12", "q13", "q14", "q15", "r0");
// pad right
if (padding_w) {
int16x4_t row0 = vget_low_s16(vmovl_s8(vld1_s8(input_ptr0 - 2)));
int16x4_t row1 = vget_low_s16(vmovl_s8(vld1_s8(input_ptr1 - 2)));
int16x4_t row2 = vget_low_s16(vmovl_s8(vld1_s8(input_ptr2 - 2)));
row0 = vext_s16(row0, zero, 2);
row1 = vext_s16(row1, zero, 2);
row2 = vext_s16(row2, zero, 2);
int32x4_t acc;
for (int w = valid_w_end; w < output_w; ++w) {
int padding = w + 3 - (padding_w + input_w);
if (padding >= 3) {
*output_ptr0 = 0;
} else {
acc = vmull_s16(row0, _ker[0]);
acc = vmlal_s16(acc, row1, _ker[1]);
acc = vmlal_s16(acc, row2, _ker[2]);
*output_ptr0 = vgetq_lane_s32(acc, 0) + vgetq_lane_s32(acc, 1);
row0 = vext_s16(row0, zero, 1);
row1 = vext_s16(row1, zero, 1);
row2 = vext_s16(row2, zero, 1);
}
output_ptr0++;
}
}
}
// bottom
for (int h = valid_h_end; h < output_h; ++h) {
DepthwiseConv3x3NormalRow<1, 1>(input_ptr, filter_ptr, h, input_h,
input_w, padding_h, padding_w, output_w,
output_ptr, _ker);
}
}
}
......@@ -1081,11 +1046,13 @@ void DepthwiseConv3x3S2<int8_t, int32_t>(const framework::Tensor &input,
int image_size = input_h * input_w;
int out_image_size = output_h * output_w;
int valid_h_start = (padding_h + 1) / 2;
int valid_h_end = output_h - valid_h_start;
int valid_h_end = (input_h + padding_h - 1) / 2;
int valid_h = valid_h_end - valid_h_start;
int valid_w_start = (padding_w + 1) / 2;
int valid_w_end = output_w - valid_w_start;
int valid_w_end = (input_w + padding_w - 1) / 2;
int valid_w = valid_w_end - valid_w_start;
// for pad left
int valid_input_w_start = (valid_w_start << 1) - padding_w;
// DLOG << "valid_h_start: " << valid_h_start;
// DLOG << "valid_h_end: " << valid_h_end;
......@@ -1097,459 +1064,579 @@ void DepthwiseConv3x3S2<int8_t, int32_t>(const framework::Tensor &input,
const int8_t *input_ptr = input_data + g * image_size;
const int8_t *filter_ptr = filter_data + g * 9;
int32_t *output_ptr = out_data + g * out_image_size;
const int8_t *filter_ptr0 = filter_ptr;
const int8_t *filter_ptr1 = filter_ptr0 + 3;
const int8_t *filter_ptr2 = filter_ptr1 + 3;
int16x4_t _k0 = vget_low_s16(vmovl_s8(vld1_s8(filter_ptr0)));
int16x4_t _k1 = vget_low_s16(vmovl_s8(vld1_s8(filter_ptr1)));
int16x4_t _k2 = vget_low_s16(vmovl_s8(vld1_s8(filter_ptr2)));
int16x8_t _ker0 = vcombine_s16(_k0, _k1);
int16x8_t _ker1 = vcombine_s16(_k2, _k2);
int16x4_t _ker[3] = {_k0, _k1, _k2};
// top
for (int h = 0; h < valid_h_start; ++h) {
DepthwiseConv3x3NormalRow<2, 2>(input_ptr, filter_ptr, h, input_h,
input_w, padding_h, padding_w, output_w,
output_ptr);
}
// left
for (int w = 0; w < valid_w_start; ++w) {
DepthwiseConv3x3ValidCol<2, 2>(
input_ptr, filter_ptr, valid_h_start, valid_h_end, w, input_h,
input_w, padding_h, padding_w, output_w, output_ptr);
}
// right
for (int w = valid_w_end; w < output_w; ++w) {
DepthwiseConv3x3ValidCol<2, 2>(
input_ptr, filter_ptr, valid_h_start, valid_h_end, w, input_h,
input_w, padding_h, padding_w, output_w, output_ptr);
}
// bottom
for (int h = valid_h_end; h < output_h; ++h) {
DepthwiseConv3x3NormalRow<2, 2>(input_ptr, filter_ptr, h, input_h,
input_w, padding_h, padding_w, output_w,
output_ptr);
output_ptr, _ker);
}
// valid
int input_w_start = 2 * valid_w_start - padding_w;
int output_w_tiles = valid_w / 6;
int output_w_remain = valid_w - output_w_tiles * 6;
for (int h = valid_h_start; h < valid_h_end - 2; h += 3) {
size_t offset = (2 * h - padding_h) * input_w + input_w_start;
const int8_t *input_ptr0 = input_ptr + offset;
const int8_t *input_ptr0 = input_ptr + (2 * h - padding_h) * input_w;
const int8_t *input_ptr1 = input_ptr0 + input_w;
const int8_t *input_ptr2 = input_ptr1 + input_w;
const int8_t *input_ptr3 = input_ptr2 + input_w;
const int8_t *input_ptr4 = input_ptr3 + input_w;
const int8_t *input_ptr5 = input_ptr4 + input_w;
const int8_t *input_ptr6 = input_ptr5 + input_w;
int32_t *output_ptr0 = output_ptr + h * output_w + valid_w_start;
int32_t *output_ptr0 = output_ptr + h * output_w;
int32_t *output_ptr1 = output_ptr0 + output_w;
int32_t *output_ptr2 = output_ptr1 + output_w;
// pad left
if (padding_w) {
for (int w = valid_w_start - 1; w >= 0; --w) {
int padding = padding_w - (w << 1);
if (padding >= 3) {
output_ptr0[w] = 0;
output_ptr1[w] = 0;
output_ptr2[w] = 0;
} else {
int16x4_t row0 =
vget_low_s16(vmovl_s8(vld1_s8(input_ptr0 - padding)));
int16x4_t row1 =
vget_low_s16(vmovl_s8(vld1_s8(input_ptr1 - padding)));
int16x4_t row2 =
vget_low_s16(vmovl_s8(vld1_s8(input_ptr2 - padding)));
int16x4_t row3 =
vget_low_s16(vmovl_s8(vld1_s8(input_ptr3 - padding)));
int16x4_t row4 =
vget_low_s16(vmovl_s8(vld1_s8(input_ptr4 - padding)));
int16x4_t row5 =
vget_low_s16(vmovl_s8(vld1_s8(input_ptr5 - padding)));
int16x4_t row6 =
vget_low_s16(vmovl_s8(vld1_s8(input_ptr6 - padding)));
int32x4_t acc0 = vmull_s16(row0, _ker[0]);
acc0 = vmlal_s16(acc0, row1, _ker[1]);
acc0 = vmlal_s16(acc0, row2, _ker[2]);
int32x4_t acc1 = vmull_s16(row2, _ker[0]);
acc1 = vmlal_s16(acc1, row3, _ker[1]);
acc1 = vmlal_s16(acc1, row4, _ker[2]);
int32x4_t acc2 = vmull_s16(row4, _ker[0]);
acc2 = vmlal_s16(acc2, row5, _ker[1]);
acc2 = vmlal_s16(acc2, row6, _ker[2]);
int32_t sum0 = vgetq_lane_s32(acc0, 2);
int32_t sum1 = vgetq_lane_s32(acc1, 2);
int32_t sum2 = vgetq_lane_s32(acc2, 2);
if (padding == 1) {
sum0 += vgetq_lane_s32(acc0, 1);
sum1 += vgetq_lane_s32(acc1, 1);
sum2 += vgetq_lane_s32(acc2, 1);
}
output_ptr0[w] = sum0;
output_ptr1[w] = sum1;
output_ptr2[w] = sum2;
}
}
input_ptr0 += valid_input_w_start;
input_ptr1 += valid_input_w_start;
input_ptr2 += valid_input_w_start;
input_ptr3 += valid_input_w_start;
input_ptr4 += valid_input_w_start;
input_ptr5 += valid_input_w_start;
input_ptr6 += valid_input_w_start;
output_ptr0 += valid_w_start;
output_ptr1 += valid_w_start;
output_ptr2 += valid_w_start;
}
// valid
int loop = output_w_tiles;
asm volatile(
"vld1.32 {q0}, [%[filter_ptr]] \n"
"vmovl.s8 q14, d0 \n"
"vmovl.s8 q15, d1 \n"
"vdup.s16 d0, d28[0] \n"
"vdup.s16 d1, d28[1] \n"
"vdup.s16 d2, d28[2] \n"
"vdup.s16 d3, d28[3] \n"
"vdup.s16 d4, d29[0] \n"
"vdup.s16 d5, d29[1] \n"
"vdup.s16 d6, d29[2] \n"
"vdup.s16 d7, d29[3] \n"
"vdup.s16 d8, d30[0] \n"
:
: [filter_ptr] "r"(filter_ptr)
: "memory", "q0", "q1", "q2", "q3", "q4", "q14", "q15");
asm volatile(
"cmp %[loop], #0 \n"
"ble start_remain_%= \n"
"mov r0, #12 \n"
"cmp %[loop], #0 \n"
"ble start_remain_%= \n"
"mov r0, #12 \n"
// loop 6 widths
"loop_3h6w_%=: \n"
"vld2.8 {d10, d11}, [%[input_ptr0]], r0 \n"
"vld2.8 {d12, d13}, [%[input_ptr1]], r0 \n"
"vld2.8 {d14, d15}, [%[input_ptr2]], r0 \n"
"vext.s8 d9, d10, d10, #1 \n"
"vmovl.s8 q10, d9 \n"
"vmovl.s8 q8, d10 \n"
"vmovl.s8 q9, d11 \n"
"vmull.s16 q11, d16, d0 \n"
"vmlal.s16 q11, d18, d1 \n"
"vmlal.s16 q11, d20, d2 \n"
"vmull.s16 q12, d17, d0 \n"
"vmlal.s16 q12, d19, d1 \n"
"vmlal.s16 q12, d21, d2 \n"
"vext.s8 d9, d12, d12, #1 \n"
"vmovl.s8 q10, d9 \n"
"vmovl.s8 q8, d12 \n"
"vmovl.s8 q9, d13 \n"
"vmlal.s16 q11, d16, d3 \n"
"vmlal.s16 q11, d18, d4 \n"
"vmlal.s16 q11, d20, d5 \n"
"vmlal.s16 q12, d17, d3 \n"
"vmlal.s16 q12, d19, d4 \n"
"vmlal.s16 q12, d21, d5 \n"
"vext.s8 d9, d14, d14, #1 \n"
"vmovl.s8 q10, d9 \n"
"vmovl.s8 q8, d14 \n"
"vmovl.s8 q9, d15 \n"
"vmlal.s16 q11, d16, d6 \n"
"vmlal.s16 q11, d18, d7 \n"
"vmlal.s16 q11, d20, d8 \n"
"vmlal.s16 q12, d17, d6 \n"
"vmlal.s16 q12, d19, d7 \n"
"vmlal.s16 q12, d21, d8 \n"
"loop_3h6w_%=: \n"
"vld2.8 {d10-d11}, [%[input_ptr0]], r0 \n"
"vld2.8 {d12-d13}, [%[input_ptr1]], r0 \n"
"vld2.8 {d14-d15}, [%[input_ptr2]], r0 \n"
"vext.s8 d9, d10, d10, #1 \n"
"vmovl.s8 q10, d9 \n"
"vmovl.s8 q8, d10 \n"
"vmovl.s8 q9, d11 \n"
"vmull.s16 q11, d16, %e[ker0][0] \n"
"vmlal.s16 q11, d18, %e[ker0][1] \n"
"vmlal.s16 q11, d20, %e[ker0][2] \n"
"vmull.s16 q12, d17, %e[ker0][0] \n"
"vmlal.s16 q12, d19, %e[ker0][1] \n"
"vmlal.s16 q12, d21, %e[ker0][2] \n"
"vext.s8 d9, d12, d12, #1 \n"
"vmovl.s8 q10, d9 \n"
"vmovl.s8 q8, d12 \n"
"vmovl.s8 q9, d13 \n"
"vmlal.s16 q11, d16, %f[ker0][0] \n"
"vmlal.s16 q11, d18, %f[ker0][1] \n"
"vmlal.s16 q11, d20, %f[ker0][2] \n"
"vmlal.s16 q12, d17, %f[ker0][0] \n"
"vmlal.s16 q12, d19, %f[ker0][1] \n"
"vmlal.s16 q12, d21, %f[ker0][2] \n"
"vext.s8 d9, d14, d14, #1 \n"
"vmovl.s8 q10, d9 \n"
"vmovl.s8 q8, d14 \n"
"vmovl.s8 q9, d15 \n"
"vmlal.s16 q11, d16, %e[ker1][0] \n"
"vmlal.s16 q11, d18, %e[ker1][1] \n"
"vmlal.s16 q11, d20, %e[ker1][2] \n"
"vmlal.s16 q12, d17, %e[ker1][0] \n"
"vmlal.s16 q12, d19, %e[ker1][1] \n"
"vmlal.s16 q12, d21, %e[ker1][2] \n"
// store row 0, reuse q11/q12
"vst1.32 {d22-d24}, [%[output_ptr0]]! \n"
"vmull.s16 q13, d16, d0 \n"
"vmlal.s16 q13, d18, d1 \n"
"vmlal.s16 q13, d20, d2 \n"
"vmull.s16 q14, d17, d0 \n"
"vmlal.s16 q14, d19, d1 \n"
"vmlal.s16 q14, d21, d2 \n"
"vld2.8 {d10, d11}, [%[input_ptr3]], r0 \n"
"vld2.8 {d12, d13}, [%[input_ptr4]], r0 \n"
"vld2.8 {d14, d15}, [%[input_ptr5]], r0 \n"
"vext.s8 d9, d10, d10, #1 \n"
"vmovl.s8 q10, d9 \n"
"vmovl.s8 q8, d10 \n"
"vmovl.s8 q9, d11 \n"
"vmlal.s16 q13, d16, d3 \n"
"vmlal.s16 q13, d18, d4 \n"
"vmlal.s16 q13, d20, d5 \n"
"vmlal.s16 q14, d17, d3 \n"
"vmlal.s16 q14, d19, d4 \n"
"vmlal.s16 q14, d21, d5 \n"
"vext.s8 d9, d12, d12, #1 \n"
"vmovl.s8 q10, d9 \n"
"vmovl.s8 q8, d12 \n"
"vmovl.s8 q9, d13 \n"
"vmlal.s16 q13, d16, d6 \n"
"vmlal.s16 q13, d18, d7 \n"
"vmlal.s16 q13, d20, d8 \n"
"vmlal.s16 q14, d17, d6 \n"
"vmlal.s16 q14, d19, d7 \n"
"vmlal.s16 q14, d21, d8 \n"
"vst1.32 {d22-d24}, [%[output_ptr0]]! \n"
"vmull.s16 q13, d16, %e[ker0][0] \n"
"vmlal.s16 q13, d18, %e[ker0][1] \n"
"vmlal.s16 q13, d20, %e[ker0][2] \n"
"vmull.s16 q14, d17, %e[ker0][0] \n"
"vmlal.s16 q14, d19, %e[ker0][1] \n"
"vmlal.s16 q14, d21, %e[ker0][2] \n"
"vld2.8 {d10-d11}, [%[input_ptr3]], r0 \n"
"vld2.8 {d12-d13}, [%[input_ptr4]], r0 \n"
"vld2.8 {d14-d15}, [%[input_ptr5]], r0 \n"
"vext.s8 d9, d10, d10, #1 \n"
"vmovl.s8 q10, d9 \n"
"vmovl.s8 q8, d10 \n"
"vmovl.s8 q9, d11 \n"
"vmlal.s16 q13, d16, %f[ker0][0] \n"
"vmlal.s16 q13, d18, %f[ker0][1] \n"
"vmlal.s16 q13, d20, %f[ker0][2] \n"
"vmlal.s16 q14, d17, %f[ker0][0] \n"
"vmlal.s16 q14, d19, %f[ker0][1] \n"
"vmlal.s16 q14, d21, %f[ker0][2] \n"
"vext.s8 d9, d12, d12, #1 \n"
"vmovl.s8 q10, d9 \n"
"vmovl.s8 q8, d12 \n"
"vmovl.s8 q9, d13 \n"
"vmlal.s16 q13, d16, %e[ker1][0] \n"
"vmlal.s16 q13, d18, %e[ker1][1] \n"
"vmlal.s16 q13, d20, %e[ker1][2] \n"
"vmlal.s16 q14, d17, %e[ker1][0] \n"
"vmlal.s16 q14, d19, %e[ker1][1] \n"
"vmlal.s16 q14, d21, %e[ker1][2] \n"
// store row 1
"vst1.32 {d26-d28}, [%[output_ptr1]]! \n"
"vmull.s16 q11, d16, d0 \n"
"vmlal.s16 q11, d18, d1 \n"
"vmlal.s16 q11, d20, d2 \n"
"vmull.s16 q12, d17, d0 \n"
"vmlal.s16 q12, d19, d1 \n"
"vmlal.s16 q12, d21, d2 \n"
"vext.s8 d9, d14, d14, #1 \n"
"vmovl.s8 q10, d9 \n"
"vmovl.s8 q8, d14 \n"
"vmovl.s8 q9, d15 \n"
"vmlal.s16 q11, d16, d3 \n"
"vmlal.s16 q11, d18, d4 \n"
"vmlal.s16 q11, d20, d5 \n"
"vmlal.s16 q12, d17, d3 \n"
"vmlal.s16 q12, d19, d4 \n"
"vmlal.s16 q12, d21, d5 \n"
"vld2.8 {d10, d11}, [%[input_ptr6]], r0 \n"
"vext.s8 d9, d10, d10, #1 \n"
"vmovl.s8 q10, d9 \n"
"vmovl.s8 q8, d10 \n"
"vmovl.s8 q9, d11 \n"
"vmlal.s16 q11, d16, d6 \n"
"vmlal.s16 q11, d18, d7 \n"
"vmlal.s16 q11, d20, d8 \n"
"vmlal.s16 q12, d17, d6 \n"
"vmlal.s16 q12, d19, d7 \n"
"vmlal.s16 q12, d21, d8 \n"
"vst1.32 {d26-d28}, [%[output_ptr1]]! \n"
"vmull.s16 q11, d16, %e[ker0][0] \n"
"vmlal.s16 q11, d18, %e[ker0][1] \n"
"vmlal.s16 q11, d20, %e[ker0][2] \n"
"vmull.s16 q12, d17, %e[ker0][0] \n"
"vmlal.s16 q12, d19, %e[ker0][1] \n"
"vmlal.s16 q12, d21, %e[ker0][2] \n"
"vext.s8 d9, d14, d14, #1 \n"
"vmovl.s8 q10, d9 \n"
"vmovl.s8 q8, d14 \n"
"vmovl.s8 q9, d15 \n"
"vmlal.s16 q11, d16, %f[ker0][0] \n"
"vmlal.s16 q11, d18, %f[ker0][1] \n"
"vmlal.s16 q11, d20, %f[ker0][2] \n"
"vmlal.s16 q12, d17, %f[ker0][0] \n"
"vmlal.s16 q12, d19, %f[ker0][1] \n"
"vmlal.s16 q12, d21, %f[ker0][2] \n"
"vld2.8 {d10-d11}, [%[input_ptr6]], r0 \n"
"vext.s8 d9, d10, d10, #1 \n"
"vmovl.s8 q10, d9 \n"
"vmovl.s8 q8, d10 \n"
"vmovl.s8 q9, d11 \n"
"vmlal.s16 q11, d16, %e[ker1][0] \n"
"vmlal.s16 q11, d18, %e[ker1][1] \n"
"vmlal.s16 q11, d20, %e[ker1][2] \n"
"vmlal.s16 q12, d17, %e[ker1][0] \n"
"vmlal.s16 q12, d19, %e[ker1][1] \n"
"vmlal.s16 q12, d21, %e[ker1][2] \n"
// store row 2
"vst1.32 {d22-d24}, [%[output_ptr2]]! \n"
"subs %[loop], #1 \n"
"bne loop_3h6w_%= \n"
"start_remain_%=: \n"
"cmp %[remain], #0 \n"
"ble end_%= \n"
"vld2.8 {d10, d11}, [%[input_ptr0]] \n"
"vld2.8 {d12, d13}, [%[input_ptr1]] \n"
"vext.s8 d9, d10, d10, #1 \n"
"vmovl.s8 q9, d9 \n"
"vmovl.s8 q7, d10 \n"
"vmovl.s8 q8, d11 \n"
"vmull.s16 q10, d14, d0 \n"
"vmlal.s16 q10, d16, d1 \n"
"vmlal.s16 q10, d18, d2 \n"
"vmull.s16 q11, d15, d0 \n"
"vmlal.s16 q11, d17, d1 \n"
"vmlal.s16 q11, d19, d2 \n"
"vext.s8 d9, d12, d12, #1 \n"
"vmovl.s8 q9, d9 \n"
"vmovl.s8 q7, d12 \n"
"vmovl.s8 q8, d13 \n"
"vmlal.s16 q10, d14, d3 \n"
"vmlal.s16 q10, d16, d4 \n"
"vmlal.s16 q10, d18, d5 \n"
"vmlal.s16 q11, d15, d3 \n"
"vmlal.s16 q11, d17, d4 \n"
"vmlal.s16 q11, d19, d5 \n"
"vld2.8 {d10, d11}, [%[input_ptr2]] \n"
"vld2.8 {d12, d13}, [%[input_ptr3]] \n"
"vext.s8 d9, d10, d10, #1 \n"
"vmovl.s8 q9, d9 \n"
"vmovl.s8 q7, d10 \n"
"vmovl.s8 q8, d11 \n"
"vmlal.s16 q10, d14, d6 \n"
"vmlal.s16 q10, d16, d7 \n"
"vmlal.s16 q10, d18, d8 \n"
"vmlal.s16 q11, d15, d6 \n"
"vmlal.s16 q11, d17, d7 \n"
"vmlal.s16 q11, d19, d8 \n"
"vmull.s16 q12, d14, d0 \n"
"vmlal.s16 q12, d16, d1 \n"
"vmlal.s16 q12, d18, d2 \n"
"vmull.s16 q13, d15, d0 \n"
"vmlal.s16 q13, d17, d1 \n"
"vmlal.s16 q13, d19, d2 \n"
"vext.s8 d9, d12, d12, #1 \n"
"vmovl.s8 q9, d9 \n"
"vmovl.s8 q7, d12 \n"
"vmovl.s8 q8, d13 \n"
"vmlal.s16 q12, d14, d3 \n"
"vmlal.s16 q12, d16, d4 \n"
"vmlal.s16 q12, d18, d5 \n"
"vmlal.s16 q13, d15, d3 \n"
"vmlal.s16 q13, d17, d4 \n"
"vmlal.s16 q13, d19, d5 \n"
"vld2.8 {d10, d11}, [%[input_ptr4]] \n"
"vld2.8 {d12, d13}, [%[input_ptr5]] \n"
"vext.s8 d9, d10, d10, #1 \n"
"vmovl.s8 q9, d9 \n"
"vmovl.s8 q7, d10 \n"
"vmovl.s8 q8, d11 \n"
"vmlal.s16 q12, d14, d6 \n"
"vmlal.s16 q12, d16, d7 \n"
"vmlal.s16 q12, d18, d8 \n"
"vmlal.s16 q13, d15, d6 \n"
"vmlal.s16 q13, d17, d7 \n"
"vmlal.s16 q13, d19, d8 \n"
"vmull.s16 q14, d14, d0 \n"
"vmlal.s16 q14, d16, d1 \n"
"vmlal.s16 q14, d18, d2 \n"
"vmull.s16 q15, d15, d0 \n"
"vmlal.s16 q15, d17, d1 \n"
"vmlal.s16 q15, d19, d2 \n"
"vext.s8 d9, d12, d12, #1 \n"
"vmovl.s8 q9, d9 \n"
"vmovl.s8 q7, d12 \n"
"vmovl.s8 q8, d13 \n"
"vmlal.s16 q14, d14, d3 \n"
"vmlal.s16 q14, d16, d4 \n"
"vmlal.s16 q14, d18, d5 \n"
"vmlal.s16 q15, d15, d3 \n"
"vmlal.s16 q15, d17, d4 \n"
"vmlal.s16 q15, d19, d5 \n"
"vld2.8 {d10, d11}, [%[input_ptr6]] \n"
"vext.s8 d9, d10, d10, #1 \n"
"vmovl.s8 q9, d9 \n"
"vmovl.s8 q7, d10 \n"
"vmovl.s8 q8, d11 \n"
"vmlal.s16 q14, d14, d6 \n"
"vmlal.s16 q14, d16, d7 \n"
"vmlal.s16 q14, d18, d8 \n"
"vmlal.s16 q15, d15, d6 \n"
"vmlal.s16 q15, d17, d7 \n"
"vmlal.s16 q15, d19, d8 \n"
"cmp %[remain], #4 \n"
"blt store_3h2w_%= \n"
"vst1.32 {q10}, [%[output_ptr0]]! \n"
"vst1.32 {q12}, [%[output_ptr1]]! \n"
"vst1.32 {q14}, [%[output_ptr2]]! \n"
"cmp %[remain], #5 \n"
"blt end_%= \n"
"vst1.32 {d22[0]}, [%[output_ptr0]]! \n"
"vst1.32 {d26[0]}, [%[output_ptr1]]! \n"
"vst1.32 {d30[0]}, [%[output_ptr2]]! \n"
"b end_%= \n"
"store_3h2w_%=: \n"
"cmp %[remain], #2 \n"
"blt store_3h1w_%= \n"
"vst1.32 {d20}, [%[output_ptr0]]! \n"
"vst1.32 {d24}, [%[output_ptr1]]! \n"
"vst1.32 {d28}, [%[output_ptr2]]! \n"
"cmp %[remain], #3 \n"
"blt end_%= \n"
"vst1.32 {d21[0]}, [%[output_ptr0]]! \n"
"vst1.32 {d25[0]}, [%[output_ptr1]]! \n"
"vst1.32 {d29[0]}, [%[output_ptr2]]! \n"
"b end_%= \n"
"store_3h1w_%=: \n"
"cmp %[remain], #1 \n"
"blt end_%= \n"
"vst1.32 {d20[0]}, [%[output_ptr0]]! \n"
"vst1.32 {d24[0]}, [%[output_ptr1]]! \n"
"vst1.32 {d28[0]}, [%[output_ptr2]]! \n"
"end_%=: \n"
"vst1.32 {d22-d24}, [%[output_ptr2]]! \n"
"subs %[loop], #1 \n"
"bne loop_3h6w_%= \n"
"start_remain_%=: \n"
"cmp %[remain], #0 \n"
"ble end_%= \n"
"mov r0, %[remain], lsl #1 \n"
"vld2.8 {d10-d11}, [%[input_ptr0]], r0 \n"
"vld2.8 {d12-d13}, [%[input_ptr1]], r0 \n"
"vext.s8 d9, d10, d10, #1 \n"
"vmovl.s8 q9, d9 \n"
"vmovl.s8 q7, d10 \n"
"vmovl.s8 q8, d11 \n"
"vmull.s16 q10, d14, %e[ker0][0] \n"
"vmlal.s16 q10, d16, %e[ker0][1] \n"
"vmlal.s16 q10, d18, %e[ker0][2] \n"
"vmull.s16 q11, d15, %e[ker0][0] \n"
"vmlal.s16 q11, d17, %e[ker0][1] \n"
"vmlal.s16 q11, d19, %e[ker0][2] \n"
"vext.s8 d9, d12, d12, #1 \n"
"vmovl.s8 q9, d9 \n"
"vmovl.s8 q7, d12 \n"
"vmovl.s8 q8, d13 \n"
"vmlal.s16 q10, d14, %f[ker0][0] \n"
"vmlal.s16 q10, d16, %f[ker0][1] \n"
"vmlal.s16 q10, d18, %f[ker0][2] \n"
"vmlal.s16 q11, d15, %f[ker0][0] \n"
"vmlal.s16 q11, d17, %f[ker0][1] \n"
"vmlal.s16 q11, d19, %f[ker0][2] \n"
"vld2.8 {d10-d11}, [%[input_ptr2]], r0 \n"
"vld2.8 {d12-d13}, [%[input_ptr3]], r0 \n"
"vext.s8 d9, d10, d10, #1 \n"
"vmovl.s8 q9, d9 \n"
"vmovl.s8 q7, d10 \n"
"vmovl.s8 q8, d11 \n"
"vmlal.s16 q10, d14, %e[ker1][0] \n"
"vmlal.s16 q10, d16, %e[ker1][1] \n"
"vmlal.s16 q10, d18, %e[ker1][2] \n"
"vmlal.s16 q11, d15, %e[ker1][0] \n"
"vmlal.s16 q11, d17, %e[ker1][1] \n"
"vmlal.s16 q11, d19, %e[ker1][2] \n"
"vmull.s16 q12, d14, %e[ker0][0] \n"
"vmlal.s16 q12, d16, %e[ker0][1] \n"
"vmlal.s16 q12, d18, %e[ker0][2] \n"
"vmull.s16 q13, d15, %e[ker0][0] \n"
"vmlal.s16 q13, d17, %e[ker0][1] \n"
"vmlal.s16 q13, d19, %e[ker0][2] \n"
"vext.s8 d9, d12, d12, #1 \n"
"vmovl.s8 q9, d9 \n"
"vmovl.s8 q7, d12 \n"
"vmovl.s8 q8, d13 \n"
"vmlal.s16 q12, d14, %f[ker0][0] \n"
"vmlal.s16 q12, d16, %f[ker0][1] \n"
"vmlal.s16 q12, d18, %f[ker0][2] \n"
"vmlal.s16 q13, d15, %f[ker0][0] \n"
"vmlal.s16 q13, d17, %f[ker0][1] \n"
"vmlal.s16 q13, d19, %f[ker0][2] \n"
"vld2.8 {d10-d11}, [%[input_ptr4]], r0 \n"
"vld2.8 {d12-d13}, [%[input_ptr5]], r0 \n"
"vext.s8 d9, d10, d10, #1 \n"
"vmovl.s8 q9, d9 \n"
"vmovl.s8 q7, d10 \n"
"vmovl.s8 q8, d11 \n"
"vmlal.s16 q12, d14, %e[ker1][0] \n"
"vmlal.s16 q12, d16, %e[ker1][1] \n"
"vmlal.s16 q12, d18, %e[ker1][2] \n"
"vmlal.s16 q13, d15, %e[ker1][0] \n"
"vmlal.s16 q13, d17, %e[ker1][1] \n"
"vmlal.s16 q13, d19, %e[ker1][2] \n"
"vmull.s16 q14, d14, %e[ker0][0] \n"
"vmlal.s16 q14, d16, %e[ker0][1] \n"
"vmlal.s16 q14, d18, %e[ker0][2] \n"
"vmull.s16 q15, d15, %e[ker0][0] \n"
"vmlal.s16 q15, d17, %e[ker0][1] \n"
"vmlal.s16 q15, d19, %e[ker0][2] \n"
"vext.s8 d9, d12, d12, #1 \n"
"vmovl.s8 q9, d9 \n"
"vmovl.s8 q7, d12 \n"
"vmovl.s8 q8, d13 \n"
"vmlal.s16 q14, d14, %f[ker0][0] \n"
"vmlal.s16 q14, d16, %f[ker0][1] \n"
"vmlal.s16 q14, d18, %f[ker0][2] \n"
"vmlal.s16 q15, d15, %f[ker0][0] \n"
"vmlal.s16 q15, d17, %f[ker0][1] \n"
"vmlal.s16 q15, d19, %f[ker0][2] \n"
"vld2.8 {d10-d11}, [%[input_ptr6]], r0 \n"
"vext.s8 d9, d10, d10, #1 \n"
"vmovl.s8 q9, d9 \n"
"vmovl.s8 q7, d10 \n"
"vmovl.s8 q8, d11 \n"
"vmlal.s16 q14, d14, %e[ker1][0] \n"
"vmlal.s16 q14, d16, %e[ker1][1] \n"
"vmlal.s16 q14, d18, %e[ker1][2] \n"
"vmlal.s16 q15, d15, %e[ker1][0] \n"
"vmlal.s16 q15, d17, %e[ker1][1] \n"
"vmlal.s16 q15, d19, %e[ker1][2] \n"
"cmp %[remain], #4 \n"
"blt store_3h2w_%= \n"
"vst1.32 {q10}, [%[output_ptr0]]! \n"
"vst1.32 {q12}, [%[output_ptr1]]! \n"
"vst1.32 {q14}, [%[output_ptr2]]! \n"
"cmp %[remain], #5 \n"
"blt end_%= \n"
"vst1.32 {d22[0]}, [%[output_ptr0]]! \n"
"vst1.32 {d26[0]}, [%[output_ptr1]]! \n"
"vst1.32 {d30[0]}, [%[output_ptr2]]! \n"
"b end_%= \n"
"store_3h2w_%=: \n"
"cmp %[remain], #2 \n"
"blt store_3h1w_%= \n"
"vst1.32 {d20}, [%[output_ptr0]]! \n"
"vst1.32 {d24}, [%[output_ptr1]]! \n"
"vst1.32 {d28}, [%[output_ptr2]]! \n"
"cmp %[remain], #3 \n"
"blt end_%= \n"
"vst1.32 {d21[0]}, [%[output_ptr0]]! \n"
"vst1.32 {d25[0]}, [%[output_ptr1]]! \n"
"vst1.32 {d29[0]}, [%[output_ptr2]]! \n"
"b end_%= \n"
"store_3h1w_%=: \n"
"cmp %[remain], #1 \n"
"blt end_%= \n"
"vst1.32 {d20[0]}, [%[output_ptr0]]! \n"
"vst1.32 {d24[0]}, [%[output_ptr1]]! \n"
"vst1.32 {d28[0]}, [%[output_ptr2]]! \n"
"end_%=: \n"
: [output_ptr0] "+r"(output_ptr0), [output_ptr1] "+r"(output_ptr1),
[output_ptr2] "+r"(output_ptr2), [input_ptr6] "+r"(input_ptr6),
[input_ptr0] "+r"(input_ptr0), [input_ptr1] "+r"(input_ptr1),
[input_ptr2] "+r"(input_ptr2), [input_ptr3] "+r"(input_ptr3),
[input_ptr4] "+r"(input_ptr4), [input_ptr5] "+r"(input_ptr5),
[loop] "+r"(loop)
: [remain] "r"(output_w_remain)
: "cc", "memory", "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7",
"q8", "q9", "q10", "q11", "q12", "q13", "q14", "q15", "r0");
: [remain] "r"(output_w_remain), [ker0] "w"(_ker0), [ker1] "w"(_ker1)
: "cc", "memory", "q4", "q5", "q6", "q7", "q8", "q9", "q10", "q11",
"q12", "q13", "q14", "q15", "r0");
// pad right
if (padding_w > 0) {
int16x4_t row0 = vget_low_s16(vmovl_s8(vld1_s8(input_ptr0)));
int16x4_t row1 = vget_low_s16(vmovl_s8(vld1_s8(input_ptr1)));
int16x4_t row2 = vget_low_s16(vmovl_s8(vld1_s8(input_ptr2)));
int16x4_t row3 = vget_low_s16(vmovl_s8(vld1_s8(input_ptr3)));
int16x4_t row4 = vget_low_s16(vmovl_s8(vld1_s8(input_ptr4)));
int16x4_t row5 = vget_low_s16(vmovl_s8(vld1_s8(input_ptr5)));
int16x4_t row6 = vget_low_s16(vmovl_s8(vld1_s8(input_ptr6)));
int32x4_t acc0, acc1, acc2;
for (int w = valid_w_end; w < output_w; ++w) {
int padding = 2 * w + 3 - (padding_w + input_w);
if (padding >= 3) {
*output_ptr0 = 0;
*output_ptr1 = 0;
*output_ptr2 = 0;
} else {
acc0 = vmull_s16(row0, _ker[0]);
acc0 = vmlal_s16(acc0, row1, _ker[1]);
acc0 = vmlal_s16(acc0, row2, _ker[2]);
acc1 = vmull_s16(row2, _ker[0]);
acc1 = vmlal_s16(acc1, row3, _ker[1]);
acc1 = vmlal_s16(acc1, row4, _ker[2]);
acc2 = vmull_s16(row4, _ker[0]);
acc2 = vmlal_s16(acc2, row5, _ker[1]);
acc2 = vmlal_s16(acc2, row6, _ker[2]);
int32_t sum0 = vgetq_lane_s32(acc0, 0);
int32_t sum1 = vgetq_lane_s32(acc1, 0);
int32_t sum2 = vgetq_lane_s32(acc2, 0);
if (padding == 1) {
sum0 += vgetq_lane_s32(acc0, 1);
sum1 += vgetq_lane_s32(acc1, 1);
sum2 += vgetq_lane_s32(acc2, 1);
}
*output_ptr0 = sum0;
*output_ptr1 = sum1;
*output_ptr2 = sum2;
}
output_ptr0++;
output_ptr1++;
output_ptr2++;
}
}
}
// remain height
int start_h = valid_h_start + valid_h / 3 * 3;
for (int h = start_h; h < valid_h_end; ++h) {
size_t offset = (2 * h - padding_h) * input_w + input_w_start;
const int8_t *input_ptr0 = input_ptr + offset;
const int8_t *input_ptr0 = input_ptr + (2 * h - padding_h) * input_w;
const int8_t *input_ptr1 = input_ptr0 + input_w;
const int8_t *input_ptr2 = input_ptr1 + input_w;
int32_t *output_ptr0 = output_ptr + h * output_w + valid_w_start;
int32_t *output_ptr0 = output_ptr + h * output_w;
// pad left
if (padding_w) {
for (int w = valid_w_start - 1; w >= 0; --w) {
int padding = padding_w - (w << 1);
if (padding >= 3) {
output_ptr0[w] = 0;
} else {
int16x4_t row0 =
vget_low_s16(vmovl_s8(vld1_s8(input_ptr0 - padding)));
int16x4_t row1 =
vget_low_s16(vmovl_s8(vld1_s8(input_ptr1 - padding)));
int16x4_t row2 =
vget_low_s16(vmovl_s8(vld1_s8(input_ptr2 - padding)));
int32x4_t acc = vmull_s16(row0, _ker[0]);
acc = vmlal_s16(acc, row1, _ker[1]);
acc = vmlal_s16(acc, row2, _ker[2]);
int32_t sum0 = vgetq_lane_s32(acc, 2);
if (padding == 1) {
sum0 += vgetq_lane_s32(acc, 1);
}
output_ptr0[w] = sum0;
}
}
input_ptr0 += valid_input_w_start;
input_ptr1 += valid_input_w_start;
input_ptr2 += valid_input_w_start;
output_ptr0 += valid_w_start;
}
// valid
int loop = output_w_tiles;
asm volatile(
"vld1.32 {q0}, [%[filter_ptr]] \n"
"vmovl.s8 q14, d0 \n"
"vmovl.s8 q15, d1 \n"
"vdup.s16 d0, d28[0] \n"
"vdup.s16 d1, d28[1] \n"
"vdup.s16 d2, d28[2] \n"
"vdup.s16 d3, d28[3] \n"
"vdup.s16 d4, d29[0] \n"
"vdup.s16 d5, d29[1] \n"
"vdup.s16 d6, d29[2] \n"
"vdup.s16 d7, d29[3] \n"
"vdup.s16 d8, d30[0] \n"
:
: [filter_ptr] "r"(filter_ptr)
: "memory", "q0", "q1", "q2", "q3", "q4", "q14", "q15");
asm volatile(
"cmp %[loop], #0 \n"
"ble start_remain_%= \n"
"mov r0, #12 \n"
"cmp %[loop], #0 \n"
"ble start_remain_%= \n"
"mov r0, #12 \n"
// loop 6 widths
"loop_1h6w_%=: \n"
"vld2.8 {d10, d11}, [%[input_ptr0]], r0 \n"
"vld2.8 {d12, d13}, [%[input_ptr1]], r0 \n"
"vld2.8 {d14, d15}, [%[input_ptr2]], r0 \n"
"vext.s8 d9, d10, d10, #1 \n"
"vmovl.s8 q10, d9 \n"
"vmovl.s8 q8, d10 \n"
"vmovl.s8 q9, d11 \n"
"vmull.s16 q11, d16, d0 \n"
"vmlal.s16 q11, d18, d1 \n"
"vmlal.s16 q11, d20, d2 \n"
"vmull.s16 q12, d17, d0 \n"
"vmlal.s16 q12, d19, d1 \n"
"vmlal.s16 q12, d21, d2 \n"
"vext.s8 d9, d12, d12, #1 \n"
"vmovl.s8 q10, d9 \n"
"vmovl.s8 q8, d12 \n"
"vmovl.s8 q9, d13 \n"
"vmlal.s16 q11, d16, d3 \n"
"vmlal.s16 q11, d18, d4 \n"
"vmlal.s16 q11, d20, d5 \n"
"vmlal.s16 q12, d17, d3 \n"
"vmlal.s16 q12, d19, d4 \n"
"vmlal.s16 q12, d21, d5 \n"
"vext.s8 d9, d14, d14, #1 \n"
"vmovl.s8 q10, d9 \n"
"vmovl.s8 q8, d14 \n"
"vmovl.s8 q9, d15 \n"
"vmlal.s16 q11, d16, d6 \n"
"vmlal.s16 q11, d18, d7 \n"
"vmlal.s16 q11, d20, d8 \n"
"vmlal.s16 q12, d17, d6 \n"
"vmlal.s16 q12, d19, d7 \n"
"vmlal.s16 q12, d21, d8 \n"
"vext.s8 d9, d10, d10, #1 \n"
"vmovl.s8 q10, d9 \n"
"vmovl.s8 q8, d10 \n"
"vmovl.s8 q9, d11 \n"
"vmull.s16 q11, d16, %e[ker0][0] \n"
"vmlal.s16 q11, d18, %e[ker0][1] \n"
"vmlal.s16 q11, d20, %e[ker0][2] \n"
"vmull.s16 q12, d17, %e[ker0][0] \n"
"vmlal.s16 q12, d19, %e[ker0][1] \n"
"vmlal.s16 q12, d21, %e[ker0][2] \n"
"vext.s8 d9, d12, d12, #1 \n"
"vmovl.s8 q10, d9 \n"
"vmovl.s8 q8, d12 \n"
"vmovl.s8 q9, d13 \n"
"vmlal.s16 q11, d16, %f[ker0][0] \n"
"vmlal.s16 q11, d18, %f[ker0][1] \n"
"vmlal.s16 q11, d20, %f[ker0][2] \n"
"vmlal.s16 q12, d17, %f[ker0][0] \n"
"vmlal.s16 q12, d19, %f[ker0][1] \n"
"vmlal.s16 q12, d21, %f[ker0][2] \n"
"vext.s8 d9, d14, d14, #1 \n"
"vmovl.s8 q10, d9 \n"
"vmovl.s8 q8, d14 \n"
"vmovl.s8 q9, d15 \n"
"vmlal.s16 q11, d16, %e[ker1][0] \n"
"vmlal.s16 q11, d18, %e[ker1][1] \n"
"vmlal.s16 q11, d20, %e[ker1][2] \n"
"vmlal.s16 q12, d17, %e[ker1][0] \n"
"vmlal.s16 q12, d19, %e[ker1][1] \n"
"vmlal.s16 q12, d21, %e[ker1][2] \n"
// store row 0
"vst1.32 {d22-d24}, [%[output_ptr0]]! \n"
"subs %[loop], #1 \n"
"bne loop_1h6w_%= \n"
"start_remain_%=: \n"
"cmp %[remain], #0 \n"
"ble end_%= \n"
"vld2.8 {d10, d11}, [%[input_ptr0]] \n"
"vld2.8 {d12, d13}, [%[input_ptr1]] \n"
"vld2.8 {d14, d15}, [%[input_ptr2]] \n"
"vext.s8 d9, d10, d10, #1 \n"
"vmovl.s8 q10, d9 \n"
"vmovl.s8 q8, d10 \n"
"vmovl.s8 q9, d11 \n"
"vmull.s16 q11, d16, d0 \n"
"vmlal.s16 q11, d18, d1 \n"
"vmlal.s16 q11, d20, d2 \n"
"vmull.s16 q12, d17, d0 \n"
"vmlal.s16 q12, d19, d1 \n"
"vmlal.s16 q12, d21, d2 \n"
"vext.s8 d9, d12, d12, #1 \n"
"vmovl.s8 q10, d9 \n"
"vmovl.s8 q8, d12 \n"
"vmovl.s8 q9, d13 \n"
"vmlal.s16 q11, d16, d3 \n"
"vmlal.s16 q11, d18, d4 \n"
"vmlal.s16 q11, d20, d5 \n"
"vmlal.s16 q12, d17, d3 \n"
"vmlal.s16 q12, d19, d4 \n"
"vmlal.s16 q12, d21, d5 \n"
"vext.s8 d9, d14, d14, #1 \n"
"vmovl.s8 q10, d9 \n"
"vmovl.s8 q8, d14 \n"
"vmovl.s8 q9, d15 \n"
"vmlal.s16 q11, d16, d6 \n"
"vmlal.s16 q11, d18, d7 \n"
"vmlal.s16 q11, d20, d8 \n"
"vmlal.s16 q12, d17, d6 \n"
"vmlal.s16 q12, d19, d7 \n"
"vmlal.s16 q12, d21, d8 \n"
"cmp %[remain], #4 \n"
"blt store_1h2w_%= \n"
"vst1.32 {q11}, [%[output_ptr0]]! \n"
"cmp %[remain], #5 \n"
"blt end_%= \n"
"vst1.32 {d24[0]}, [%[output_ptr0]]! \n"
"b end_%= \n"
"store_1h2w_%=: \n"
"cmp %[remain], #2 \n"
"blt store_1h1w_%= \n"
"vst1.32 {d22}, [%[output_ptr0]]! \n"
"cmp %[remain], #3 \n"
"blt end_%= \n"
"vst1.32 {d23[0]}, [%[output_ptr0]]! \n"
"b end_%= \n"
"store_1h1w_%=: \n"
"cmp %[remain], #1 \n"
"blt end_%= \n"
"vst1.32 {d22[0]}, [%[output_ptr0]]! \n"
"end_%=: \n"
"vst1.32 {d22-d24}, [%[output_ptr0]]! \n"
"subs %[loop], #1 \n"
"bne loop_1h6w_%= \n"
"start_remain_%=: \n"
"cmp %[remain], #0 \n"
"ble end_%= \n"
"mov r0, %[remain], lsl #1 \n"
"vld2.8 {d10, d11}, [%[input_ptr0]], r0 \n"
"vld2.8 {d12, d13}, [%[input_ptr1]], r0 \n"
"vld2.8 {d14, d15}, [%[input_ptr2]], r0 \n"
"vext.s8 d9, d10, d10, #1 \n"
"vmovl.s8 q10, d9 \n"
"vmovl.s8 q8, d10 \n"
"vmovl.s8 q9, d11 \n"
"vmull.s16 q11, d16, %e[ker0][0] \n"
"vmlal.s16 q11, d18, %e[ker0][1] \n"
"vmlal.s16 q11, d20, %e[ker0][2] \n"
"vmull.s16 q12, d17, %e[ker0][0] \n"
"vmlal.s16 q12, d19, %e[ker0][1] \n"
"vmlal.s16 q12, d21, %e[ker0][2] \n"
"vext.s8 d9, d12, d12, #1 \n"
"vmovl.s8 q10, d9 \n"
"vmovl.s8 q8, d12 \n"
"vmovl.s8 q9, d13 \n"
"vmlal.s16 q11, d16, %f[ker0][0] \n"
"vmlal.s16 q11, d18, %f[ker0][1] \n"
"vmlal.s16 q11, d20, %f[ker0][2] \n"
"vmlal.s16 q12, d17, %f[ker0][0] \n"
"vmlal.s16 q12, d19, %f[ker0][1] \n"
"vmlal.s16 q12, d21, %f[ker0][2] \n"
"vext.s8 d9, d14, d14, #1 \n"
"vmovl.s8 q10, d9 \n"
"vmovl.s8 q8, d14 \n"
"vmovl.s8 q9, d15 \n"
"vmlal.s16 q11, d16, %e[ker1][0] \n"
"vmlal.s16 q11, d18, %e[ker1][1] \n"
"vmlal.s16 q11, d20, %e[ker1][2] \n"
"vmlal.s16 q12, d17, %e[ker1][0] \n"
"vmlal.s16 q12, d19, %e[ker1][1] \n"
"vmlal.s16 q12, d21, %e[ker1][2] \n"
"cmp %[remain], #4 \n"
"blt store_1h2w_%= \n"
"vst1.32 {q11}, [%[output_ptr0]]! \n"
"cmp %[remain], #5 \n"
"blt end_%= \n"
"vst1.32 {d24[0]}, [%[output_ptr0]]! \n"
"b end_%= \n"
"store_1h2w_%=: \n"
"cmp %[remain], #2 \n"
"blt store_1h1w_%= \n"
"vst1.32 {d22}, [%[output_ptr0]]! \n"
"cmp %[remain], #3 \n"
"blt end_%= \n"
"vst1.32 {d23[0]}, [%[output_ptr0]]! \n"
"b end_%= \n"
"store_1h1w_%=: \n"
"cmp %[remain], #1 \n"
"blt end_%= \n"
"vst1.32 {d22[0]}, [%[output_ptr0]]! \n"
"end_%=: \n"
: [output_ptr0] "+r"(output_ptr0), [input_ptr0] "+r"(input_ptr0),
[input_ptr1] "+r"(input_ptr1), [input_ptr2] "+r"(input_ptr2),
[loop] "+r"(loop)
: [remain] "r"(output_w_remain)
: "cc", "memory", "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7",
"q8", "q9", "q10", "q11", "q12", "r0");
: [remain] "r"(output_w_remain), [ker0] "w"(_ker0), [ker1] "w"(_ker1)
: "cc", "memory", "q4", "q5", "q6", "q7", "q8", "q9", "q10", "q11",
"q12", "q13", "q14", "q15", "r0");
// pad right
if (padding_w > 0) {
int16x4_t row0 = vget_low_s16(vmovl_s8(vld1_s8(input_ptr0)));
int16x4_t row1 = vget_low_s16(vmovl_s8(vld1_s8(input_ptr1)));
int16x4_t row2 = vget_low_s16(vmovl_s8(vld1_s8(input_ptr2)));
int32x4_t acc;
for (int w = valid_w_end; w < output_w; ++w) {
int padding = 2 * w + 3 - (padding_w + input_w);
if (padding >= 3) {
*output_ptr0 = 0;
} else {
acc = vmull_s16(row0, _ker[0]);
acc = vmlal_s16(acc, row1, _ker[1]);
acc = vmlal_s16(acc, row2, _ker[2]);
int32_t sum0 = vgetq_lane_s32(acc, 0);
if (padding == 1) {
sum0 += vgetq_lane_s32(acc, 1);
}
*output_ptr0 = sum0;
}
output_ptr0++;
}
}
}
// bottom
for (int h = valid_h_end; h < output_h; ++h) {
DepthwiseConv3x3NormalRow<2, 2>(input_ptr, filter_ptr, h, input_h,
input_w, padding_h, padding_w, output_w,
output_ptr, _ker);
}
}
}
......
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#if defined(__ARM_NEON__) && !defined(__aarch64__)
#include "operators/math/depthwise_conv5x5.h"
#include <arm_neon.h>
namespace paddle_mobile {
namespace operators {
namespace math {
#ifndef __aarch64__
inline float32x4_t vpaddq_f32(float32x4_t r0, float32x4_t r1) {
float32x2_t sum0 = vpadd_f32(vget_low_f32(r0), vget_high_f32(r0));
float32x2_t sum1 = vpadd_f32(vget_low_f32(r1), vget_high_f32(r1));
return vcombine_f32(sum0, sum1);
}
#endif
template <int Stride = 1>
inline void Depth5x5NormalRowLoadInput(const float *input, float32x4_t *y) {
y[0] = vld1q_f32(input);
y[4] = vld1q_f32(input + 4);
y[1] = vextq_f32(y[0], y[4], 1);
y[2] = vextq_f32(y[0], y[4], 2);
y[3] = vextq_f32(y[0], y[4], 3);
}
template <>
inline void Depth5x5NormalRowLoadInput<2>(const float *input, float32x4_t *y) {
float32x4x2_t x = vld2q_f32(input);
y[0] = x.val[0];
y[1] = x.val[1];
y[2] = vextq_f32(y[0], y[0], 1);
y[3] = vextq_f32(y[1], y[1], 1);
y[4] = vextq_f32(y[0], y[0], 2);
}
#define DEPTHWISE_CONV_NORMAL_BORDER(start, end) \
for (int w = start; w < end; ++w) { \
const int w_in_start = -padding_w + w * Stride_w; \
const int w_in_end = w_in_start + 5; \
const int w_start = w_in_start > 0 ? w_in_start : 0; \
const int w_end = w_in_end < input_w ? w_in_end : input_w; \
float value = 0; \
for (int h_in = h_start; h_in < h_end; ++h_in) { \
for (int w_in = w_start; w_in < w_end; ++w_in) { \
value += filter[(h_in - h_in_start) * 5 + (w_in - w_in_start)] * \
input[h_in * input_w + w_in]; \
} \
} \
output_ptr[w] = value; \
}
template <int Stride_h, int Stride_w>
inline void DepthwiseConv5x5NormalRow(const float *input, const float *filter,
const int h_output, const int input_h,
const int input_w, const int padding_h,
const int padding_w, const int output_w,
float *output, float32x4_t *ker,
float32_t *ker1) {
const int h_in_start = -padding_h + h_output * Stride_h;
const int h_in_end = h_in_start + 5;
const int h_start = h_in_start > 0 ? h_in_start : 0;
const int h_end = h_in_end < input_h ? h_in_end : input_h;
int valid_w_start = (padding_w + Stride_w - 1) / Stride_w;
int valid_w_end = output_w - valid_w_start;
float *output_ptr = output + h_output * output_w;
// border left
DEPTHWISE_CONV_NORMAL_BORDER(0, valid_w_start)
// middle
int output_tiles = (valid_w_end - valid_w_start) >> 2;
float32x4_t _sum, _x[5];
// valid w
for (int w = 0; w < output_tiles * 4; w += 4) {
_sum = vdupq_n_f32(0.f);
int output_offset = valid_w_start + w;
int input_w_offset = output_offset * Stride_w - padding_w;
for (int h_in = h_start; h_in < h_end; ++h_in) {
int index = h_in - h_in_start;
Depth5x5NormalRowLoadInput<Stride_w>(
input + h_in * input_w + input_w_offset, _x);
_sum = vmlaq_n_f32(_sum, _x[0], ker1[index]);
_sum = vmlaq_lane_f32(_sum, _x[1], vget_low_f32(ker[index]), 0);
_sum = vmlaq_lane_f32(_sum, _x[2], vget_low_f32(ker[index]), 1);
_sum = vmlaq_lane_f32(_sum, _x[3], vget_high_f32(ker[index]), 0);
_sum = vmlaq_lane_f32(_sum, _x[4], vget_high_f32(ker[index]), 1);
}
vst1q_f32(output_ptr + output_offset, _sum);
}
// remain valid w
int remain = (valid_w_end - valid_w_start) & 0x3;
if (remain > 0) {
_sum = vdupq_n_f32(0.f);
int remain_start = valid_w_start + (output_tiles << 2);
int input_w_offset = remain_start * Stride_w - padding_w;
float *output_ptr0 = output_ptr + remain_start;
for (int h_in = h_start; h_in < h_end; ++h_in) {
int index = h_in - h_in_start;
Depth5x5NormalRowLoadInput<Stride_w>(
input + h_in * input_w + input_w_offset, _x);
_sum = vmlaq_n_f32(_sum, _x[0], ker1[index]);
_sum = vmlaq_lane_f32(_sum, _x[1], vget_low_f32(ker[index]), 0);
_sum = vmlaq_lane_f32(_sum, _x[2], vget_low_f32(ker[index]), 1);
_sum = vmlaq_lane_f32(_sum, _x[3], vget_high_f32(ker[index]), 0);
_sum = vmlaq_lane_f32(_sum, _x[4], vget_high_f32(ker[index]), 1);
}
switch (remain) {
case 1:
vst1_lane_f32(output_ptr0, vget_low_f32(_sum), 0);
break;
case 2:
vst1_f32(output_ptr0, vget_low_f32(_sum));
break;
case 3:
vst1_f32(output_ptr0, vget_low_f32(_sum));
vst1_lane_f32(output_ptr0 + 2, vget_high_f32(_sum), 0);
break;
}
}
// border right
DEPTHWISE_CONV_NORMAL_BORDER(valid_w_end, output_w)
}
template <>
void DepthwiseConv5x5S1<float, float>(const framework::Tensor &input,
const framework::Tensor &filter,
const std::vector<int> &paddings,
framework::Tensor *output) {
const float *input_data = input.data<float>();
const float *filter_data = filter.data<float>();
float *out_data = output->mutable_data<float>();
int input_h = input.dims()[2];
int input_w = input.dims()[3];
int output_h = output->dims()[2];
int output_w = output->dims()[3];
int padding_h = paddings[0];
int padding_w = paddings[1];
int image_size = input_h * input_w;
int out_image_size = output_h * output_w;
int valid_h_start = padding_h;
int valid_h_end = output_h - valid_h_start;
int valid_h = valid_h_end - valid_h_start;
int valid_w_start = padding_w;
int valid_w_end = output_w - valid_w_start;
int valid_w = valid_w_end - valid_w_start;
#pragma omp parallel for
for (int g = 0; g < input.dims()[1]; ++g) {
const float *input_ptr = input_data + g * image_size;
const float *filter_ptr = filter_data + g * 25;
float *output_ptr = out_data + g * out_image_size;
const float *filter_ptr0 = filter_ptr;
const float *filter_ptr1 = filter_ptr0 + 5;
const float *filter_ptr2 = filter_ptr1 + 5;
const float *filter_ptr3 = filter_ptr2 + 5;
const float *filter_ptr4 = filter_ptr3 + 5;
float32x4_t _ker[7];
float32_t _ker1[5] = {*filter_ptr0, *filter_ptr1, *filter_ptr2,
*filter_ptr3, *filter_ptr4};
_ker[0] = vld1q_f32(filter_ptr0 + 1);
_ker[1] = vld1q_f32(filter_ptr1 + 1);
_ker[2] = vld1q_f32(filter_ptr2 + 1);
_ker[3] = vld1q_f32(filter_ptr3 + 1);
_ker[4] = vld1q_f32(filter_ptr4 + 1);
_ker[5] = vld1q_f32(_ker1);
_ker[6] = vld1q_f32(_ker1 + 4);
// pad top
for (int h = 0; h < valid_h_start; ++h) {
DepthwiseConv5x5NormalRow<1, 1>(input_ptr, filter_ptr, h, input_h,
input_w, padding_h, padding_w, output_w,
output_ptr, _ker, _ker1);
}
// output 4x4
int output_w_tiles = valid_w / 4;
int output_w_remain = valid_w - output_w_tiles * 4;
for (int h = valid_h_start; h < valid_h_end - 1; h += 2) {
const float *input_ptr0 = input_ptr + (h - padding_h) * input_w;
const float *input_ptr1 = input_ptr0 + input_w;
const float *input_ptr2 = input_ptr1 + input_w;
const float *input_ptr3 = input_ptr2 + input_w;
const float *input_ptr4 = input_ptr3 + input_w;
const float *input_ptr5 = input_ptr4 + input_w;
float *output_ptr0 = output_ptr + h * output_w;
float *output_ptr1 = output_ptr0 + output_w;
// pad left
if (padding_w) {
float32x4_t row0 = vld1q_f32(input_ptr0);
float32x4_t row1 = vld1q_f32(input_ptr1);
float32x4_t row2 = vld1q_f32(input_ptr2);
float32x4_t row3 = vld1q_f32(input_ptr3);
float32x4_t row4 = vld1q_f32(input_ptr4);
float32x4_t row5 = vld1q_f32(input_ptr5);
float32x4_t zero = vdupq_n_f32(0.f);
float32x4_t acc0, acc1;
for (int w = valid_w_start - 1; w >= 0; --w) {
int padding = padding_w - w;
if (padding >= 5) {
output_ptr0[w] = 0.f;
output_ptr1[w] = 0.f;
} else {
acc0 = vmulq_f32(row0, _ker[0]);
acc0 = vmlaq_f32(acc0, row1, _ker[1]);
acc0 = vmlaq_f32(acc0, row2, _ker[2]);
acc0 = vmlaq_f32(acc0, row3, _ker[3]);
acc0 = vmlaq_f32(acc0, row4, _ker[4]);
acc1 = vmulq_f32(row1, _ker[0]);
acc1 = vmlaq_f32(acc1, row2, _ker[1]);
acc1 = vmlaq_f32(acc1, row3, _ker[2]);
acc1 = vmlaq_f32(acc1, row4, _ker[3]);
acc1 = vmlaq_f32(acc1, row5, _ker[4]);
acc0 = vpaddq_f32(acc0, acc1);
float32x2_t sum =
vpadd_f32(vget_low_f32(acc0), vget_high_f32(acc0));
vst1_lane_f32(output_ptr0 + w, sum, 0);
vst1_lane_f32(output_ptr1 + w, sum, 1);
row0 = vextq_f32(zero, row0, 3);
row1 = vextq_f32(zero, row1, 3);
row2 = vextq_f32(zero, row2, 3);
row3 = vextq_f32(zero, row3, 3);
row4 = vextq_f32(zero, row4, 3);
row5 = vextq_f32(zero, row5, 3);
}
}
output_ptr0 += valid_w_start;
output_ptr1 += valid_w_start;
}
// valid
int loop = output_w_tiles;
asm volatile(
"cmp %[loop], #0 \n"
"ble start_remain_%= \n"
"mov r0, #16 \n"
"loop_2h4w_%=: \n"
"vld1.32 {d14-d17}, [%[input_ptr0]], r0 \n"
"vld1.32 {d18-d21}, [%[input_ptr1]], r0 \n"
"vld1.32 {d22-d25}, [%[input_ptr2]], r0 \n"
"vmul.f32 q14, q7, %e[ker0][0] \n"
"vext.32 q13, q7, q8, #1 \n"
"vmla.f32 q14, q13, %e[kr0][0] \n"
"vext.32 q13, q7, q8, #2 \n"
"vmla.f32 q14, q13, %e[kr0][1] \n"
"vext.32 q13, q7, q8, #3 \n"
"vmla.f32 q14, q13, %f[kr0][0] \n"
"vmla.f32 q14, q8, %f[kr0][1] \n"
"vmla.f32 q14, q9, %e[ker0][1] \n"
"vmul.f32 q15, q9, %e[ker0][0] \n"
"vext.32 q13, q9, q10, #1 \n"
"vmla.f32 q14, q13, %e[kr1][0] \n"
"vmla.f32 q15, q13, %e[kr0][0] \n"
"vext.32 q13, q9, q10, #2 \n"
"vmla.f32 q14, q13, %e[kr1][1] \n"
"vmla.f32 q15, q13, %e[kr0][1] \n"
"vext.32 q13, q9, q10, #3 \n"
"vmla.f32 q14, q13, %f[kr1][0] \n"
"vmla.f32 q15, q13, %f[kr0][0] \n"
"vmla.f32 q14, q10, %f[kr1][1] \n"
"vmla.f32 q15, q10, %f[kr0][1] \n"
"vmla.f32 q14, q11, %f[ker0][0] \n"
"vmla.f32 q15, q11, %e[ker0][1] \n"
"vext.32 q13, q11, q12, #1 \n"
"vmla.f32 q14, q13, %e[kr2][0] \n"
"vmla.f32 q15, q13, %e[kr1][0] \n"
"vext.32 q13, q11, q12, #2 \n"
"vmla.f32 q14, q13, %e[kr2][1] \n"
"vmla.f32 q15, q13, %e[kr1][1] \n"
"vext.32 q13, q11, q12, #3 \n"
"vmla.f32 q14, q13, %f[kr2][0] \n"
"vmla.f32 q15, q13, %f[kr1][0] \n"
"vmla.f32 q14, q12, %f[kr2][1] \n"
"vmla.f32 q15, q12, %f[kr1][1] \n"
"vld1.32 {d14-d17}, [%[input_ptr3]], r0 \n"
"vld1.32 {d18-d21}, [%[input_ptr4]], r0 \n"
"vld1.32 {d22-d25}, [%[input_ptr5]], r0 \n"
"vmla.f32 q14, q7, %f[ker0][1] \n"
"vmla.f32 q15, q7, %f[ker0][0] \n"
"vext.32 q13, q7, q8, #1 \n"
"vmla.f32 q14, q13, %e[kr3][0] \n"
"vmla.f32 q15, q13, %e[kr2][0] \n"
"vext.32 q13, q7, q8, #2 \n"
"vmla.f32 q14, q13, %e[kr3][1] \n"
"vmla.f32 q15, q13, %e[kr2][1] \n"
"vext.32 q13, q7, q8, #3 \n"
"vmla.f32 q14, q13, %f[kr3][0] \n"
"vmla.f32 q15, q13, %f[kr2][0] \n"
"vmla.f32 q14, q8, %f[kr3][1] \n"
"vmla.f32 q15, q8, %f[kr2][1] \n"
"vmla.f32 q14, q9, %e[ker1][0] \n"
"vmla.f32 q15, q9, %f[ker0][1] \n"
"vext.32 q13, q9, q10, #1 \n"
"vmla.f32 q14, q13, %e[kr4][0] \n"
"vmla.f32 q15, q13, %e[kr3][0] \n"
"vext.32 q13, q9, q10, #2 \n"
"vmla.f32 q14, q13, %e[kr4][1] \n"
"vmla.f32 q15, q13, %e[kr3][1] \n"
"vext.32 q13, q9, q10, #3 \n"
"vmla.f32 q14, q13, %f[kr4][0] \n"
"vmla.f32 q15, q13, %f[kr3][0] \n"
"vmla.f32 q14, q10, %f[kr4][1] \n"
"vmla.f32 q15, q10, %f[kr3][1] \n"
"vmla.f32 q15, q11, %e[ker1][0] \n"
"vext.32 q13, q11, q12, #1 \n"
"vmla.f32 q15, q13, %e[kr4][0] \n"
"vext.32 q13, q11, q12, #2 \n"
"vmla.f32 q15, q13, %e[kr4][1] \n"
"vext.32 q13, q11, q12, #3 \n"
"vmla.f32 q15, q13, %f[kr4][0] \n"
"vmla.f32 q15, q12, %f[kr4][1] \n"
// restore output
"vst1.32 {q14}, [%[output_ptr0]]! \n"
"vst1.32 {q15}, [%[output_ptr1]]! \n"
"subs %[loop], #1 \n"
"bne loop_2h4w_%= \n"
"start_remain_%=: \n"
"cmp %[remain], #0 \n"
"ble end_%= \n"
"mov r0, %[remain], lsl #2 \n"
"vld1.32 {d14-d17}, [%[input_ptr0]], r0 \n"
"vld1.32 {d18-d21}, [%[input_ptr1]], r0 \n"
"vld1.32 {d22-d25}, [%[input_ptr2]], r0 \n"
"vmul.f32 q14, q7, %e[ker0][0] \n"
"vext.32 q13, q7, q8, #1 \n"
"vmla.f32 q14, q13, %e[kr0][0] \n"
"vext.32 q13, q7, q8, #2 \n"
"vmla.f32 q14, q13, %e[kr0][1] \n"
"vext.32 q13, q7, q8, #3 \n"
"vmla.f32 q14, q13, %f[kr0][0] \n"
"vmla.f32 q14, q8, %f[kr0][1] \n"
"vmla.f32 q14, q9, %e[ker0][1] \n"
"vmul.f32 q15, q9, %e[ker0][0] \n"
"vext.32 q13, q9, q10, #1 \n"
"vmla.f32 q14, q13, %e[kr1][0] \n"
"vmla.f32 q15, q13, %e[kr0][0] \n"
"vext.32 q13, q9, q10, #2 \n"
"vmla.f32 q14, q13, %e[kr1][1] \n"
"vmla.f32 q15, q13, %e[kr0][1] \n"
"vext.32 q13, q9, q10, #3 \n"
"vmla.f32 q14, q13, %f[kr1][0] \n"
"vmla.f32 q15, q13, %f[kr0][0] \n"
"vmla.f32 q14, q10, %f[kr1][1] \n"
"vmla.f32 q15, q10, %f[kr0][1] \n"
"vmla.f32 q14, q11, %f[ker0][0] \n"
"vmla.f32 q15, q11, %e[ker0][1] \n"
"vext.32 q13, q11, q12, #1 \n"
"vmla.f32 q14, q13, %e[kr2][0] \n"
"vmla.f32 q15, q13, %e[kr1][0] \n"
"vext.32 q13, q11, q12, #2 \n"
"vmla.f32 q14, q13, %e[kr2][1] \n"
"vmla.f32 q15, q13, %e[kr1][1] \n"
"vext.32 q13, q11, q12, #3 \n"
"vmla.f32 q14, q13, %f[kr2][0] \n"
"vmla.f32 q15, q13, %f[kr1][0] \n"
"vmla.f32 q14, q12, %f[kr2][1] \n"
"vmla.f32 q15, q12, %f[kr1][1] \n"
"vld1.32 {d14-d17}, [%[input_ptr3]], r0 \n"
"vld1.32 {d18-d21}, [%[input_ptr4]], r0 \n"
"vld1.32 {d22-d25}, [%[input_ptr5]], r0 \n"
"vmla.f32 q14, q7, %f[ker0][1] \n"
"vmla.f32 q15, q7, %f[ker0][0] \n"
"vext.32 q13, q7, q8, #1 \n"
"vmla.f32 q14, q13, %e[kr3][0] \n"
"vmla.f32 q15, q13, %e[kr2][0] \n"
"vext.32 q13, q7, q8, #2 \n"
"vmla.f32 q14, q13, %e[kr3][1] \n"
"vmla.f32 q15, q13, %e[kr2][1] \n"
"vext.32 q13, q7, q8, #3 \n"
"vmla.f32 q14, q13, %f[kr3][0] \n"
"vmla.f32 q15, q13, %f[kr2][0] \n"
"vmla.f32 q14, q8, %f[kr3][1] \n"
"vmla.f32 q15, q8, %f[kr2][1] \n"
"vmla.f32 q14, q9, %e[ker1][0] \n"
"vmla.f32 q15, q9, %f[ker0][1] \n"
"vext.32 q13, q9, q10, #1 \n"
"vmla.f32 q14, q13, %e[kr4][0] \n"
"vmla.f32 q15, q13, %e[kr3][0] \n"
"vext.32 q13, q9, q10, #2 \n"
"vmla.f32 q14, q13, %e[kr4][1] \n"
"vmla.f32 q15, q13, %e[kr3][1] \n"
"vext.32 q13, q9, q10, #3 \n"
"vmla.f32 q14, q13, %f[kr4][0] \n"
"vmla.f32 q15, q13, %f[kr3][0] \n"
"vmla.f32 q14, q10, %f[kr4][1] \n"
"vmla.f32 q15, q10, %f[kr3][1] \n"
"vmla.f32 q15, q11, %e[ker1][0] \n"
"vext.32 q13, q11, q12, #1 \n"
"vmla.f32 q15, q13, %e[kr4][0] \n"
"vext.32 q13, q11, q12, #2 \n"
"vmla.f32 q15, q13, %e[kr4][1] \n"
"vext.32 q13, q11, q12, #3 \n"
"vmla.f32 q15, q13, %f[kr4][0] \n"
"vmla.f32 q15, q12, %f[kr4][1] \n"
"cmp %[remain], #2 \n"
"blt store_2h1w_%= \n"
"vst1.32 {d28}, [%[output_ptr0]]! \n"
"vst1.32 {d30}, [%[output_ptr1]]! \n"
"cmp %[remain], #3 \n"
"blt end_%= \n"
"vst1.32 {d29[0]}, [%[output_ptr0]]! \n"
"vst1.32 {d31[0]}, [%[output_ptr1]]! \n"
"b end_%= \n"
"store_2h1w_%=: \n"
"vst1.32 {d28[0]}, [%[output_ptr0]]! \n"
"vst1.32 {d30[0]}, [%[output_ptr1]]! \n"
"end_%=: \n"
: [input_ptr0] "+r"(input_ptr0), [input_ptr1] "+r"(input_ptr1),
[input_ptr2] "+r"(input_ptr2), [input_ptr3] "+r"(input_ptr3),
[input_ptr4] "+r"(input_ptr4), [input_ptr5] "+r"(input_ptr5),
[output_ptr0] "+r"(output_ptr0), [output_ptr1] "+r"(output_ptr1),
[loop] "+r"(loop)
: [remain] "r"(output_w_remain), [kr0] "w"(_ker[0]),
[kr1] "w"(_ker[1]), [kr2] "w"(_ker[2]), [kr3] "w"(_ker[3]),
[kr4] "w"(_ker[4]), [ker0] "w"(_ker[5]), [ker1] "w"(_ker[6])
: "cc", "memory", "q7", "q8", "q9", "q10", "q11", "q12", "q13", "q14",
"q15", "r0");
// pad right
if (padding_w) {
float32x4_t row0 = vld1q_f32(input_ptr0);
float32x4_t row1 = vld1q_f32(input_ptr1);
float32x4_t row2 = vld1q_f32(input_ptr2);
float32x4_t row3 = vld1q_f32(input_ptr3);
float32x4_t row4 = vld1q_f32(input_ptr4);
float32x4_t row5 = vld1q_f32(input_ptr5);
float32x4_t zero = vdupq_n_f32(0.f);
float32x4_t acc0, acc1;
for (int w = valid_w_end; w < output_w; ++w) {
int padding = w + 5 - (padding_w + input_w);
if (padding >= 5) {
*output_ptr0 = 0.f;
*output_ptr1 = 0.f;
} else {
int iw = w - valid_w_end;
float sum0 = input_ptr0[iw] * filter_ptr0[0] +
input_ptr1[iw] * filter_ptr1[0] +
input_ptr2[iw] * filter_ptr2[0] +
input_ptr3[iw] * filter_ptr3[0] +
input_ptr4[iw] * filter_ptr4[0];
float sum1 = input_ptr1[iw] * filter_ptr0[0] +
input_ptr2[iw] * filter_ptr1[0] +
input_ptr3[iw] * filter_ptr2[0] +
input_ptr4[iw] * filter_ptr3[0] +
input_ptr5[iw] * filter_ptr4[0];
row0 = vextq_f32(row0, zero, 1);
row1 = vextq_f32(row1, zero, 1);
row2 = vextq_f32(row2, zero, 1);
row3 = vextq_f32(row3, zero, 1);
row4 = vextq_f32(row4, zero, 1);
row5 = vextq_f32(row5, zero, 1);
acc0 = vmulq_f32(row0, _ker[0]);
acc0 = vmlaq_f32(acc0, row1, _ker[1]);
acc0 = vmlaq_f32(acc0, row2, _ker[2]);
acc0 = vmlaq_f32(acc0, row3, _ker[3]);
acc0 = vmlaq_f32(acc0, row4, _ker[4]);
acc1 = vmulq_f32(row1, _ker[0]);
acc1 = vmlaq_f32(acc1, row2, _ker[1]);
acc1 = vmlaq_f32(acc1, row3, _ker[2]);
acc1 = vmlaq_f32(acc1, row4, _ker[3]);
acc1 = vmlaq_f32(acc1, row5, _ker[4]);
acc0 = vpaddq_f32(acc0, acc1);
float32x2_t sum =
vpadd_f32(vget_low_f32(acc0), vget_high_f32(acc0));
sum0 += vget_lane_f32(sum, 0);
sum1 += vget_lane_f32(sum, 1);
*output_ptr0 = sum0;
*output_ptr1 = sum1;
}
output_ptr0++;
output_ptr1++;
}
}
}
// remain height
int start_h = valid_h_start + (valid_h & 0xfffe);
if (start_h < valid_h_end) {
const float *input_ptr0 = input_ptr + (start_h - padding_h) * input_w;
const float *input_ptr1 = input_ptr0 + input_w;
const float *input_ptr2 = input_ptr1 + input_w;
const float *input_ptr3 = input_ptr2 + input_w;
const float *input_ptr4 = input_ptr3 + input_w;
float *output_ptr0 = output_ptr + start_h * output_w;
// pad left
if (padding_w) {
float32x4_t row0 = vld1q_f32(input_ptr0);
float32x4_t row1 = vld1q_f32(input_ptr1);
float32x4_t row2 = vld1q_f32(input_ptr2);
float32x4_t row3 = vld1q_f32(input_ptr3);
float32x4_t row4 = vld1q_f32(input_ptr4);
float32x4_t zero = vdupq_n_f32(0.f);
float32x4_t acc;
for (int w = valid_w_start - 1; w >= 0; --w) {
int padding = padding_w - w;
if (padding >= 5) {
output_ptr0[w] = 0.f;
} else {
acc = vmulq_f32(row0, _ker[0]);
acc = vmlaq_f32(acc, row1, _ker[1]);
acc = vmlaq_f32(acc, row2, _ker[2]);
acc = vmlaq_f32(acc, row3, _ker[3]);
acc = vmlaq_f32(acc, row4, _ker[4]);
float32x2_t sum = vpadd_f32(vget_low_f32(acc), vget_high_f32(acc));
sum = vpadd_f32(sum, sum);
vst1_lane_f32(output_ptr0 + w, sum, 0);
row0 = vextq_f32(zero, row0, 3);
row1 = vextq_f32(zero, row1, 3);
row2 = vextq_f32(zero, row2, 3);
row3 = vextq_f32(zero, row3, 3);
row4 = vextq_f32(zero, row4, 3);
}
}
output_ptr0 += valid_w_start;
}
// valid
int loop = output_w_tiles;
asm volatile(
"cmp %[loop], #0 \n"
"ble start_remain_%= \n"
"mov r0, #16 \n"
"loop_1h4w_%=: \n"
"vld1.32 {d14-d17}, [%[input_ptr0]], r0 \n"
"vld1.32 {d18-d21}, [%[input_ptr1]], r0 \n"
"vld1.32 {d22-d25}, [%[input_ptr2]], r0 \n"
"vmul.f32 q14, q7, %e[ker0][0] \n"
"vext.32 q13, q7, q8, #1 \n"
"vmla.f32 q14, q13, %e[kr0][0] \n"
"vext.32 q13, q7, q8, #2 \n"
"vmla.f32 q14, q13, %e[kr0][1] \n"
"vext.32 q13, q7, q8, #3 \n"
"vmla.f32 q14, q13, %f[kr0][0] \n"
"vmla.f32 q14, q8, %f[kr0][1] \n"
"vmla.f32 q14, q9, %e[ker0][1] \n"
"vext.32 q13, q9, q10, #1 \n"
"vmla.f32 q14, q13, %e[kr1][0] \n"
"vext.32 q13, q9, q10, #2 \n"
"vmla.f32 q14, q13, %e[kr1][1] \n"
"vext.32 q13, q9, q10, #3 \n"
"vmla.f32 q14, q13, %f[kr1][0] \n"
"vmla.f32 q14, q10, %f[kr1][1] \n"
"vmla.f32 q14, q11, %f[ker0][0] \n"
"vext.32 q13, q11, q12, #1 \n"
"vmla.f32 q14, q13, %e[kr2][0] \n"
"vext.32 q13, q11, q12, #2 \n"
"vmla.f32 q14, q13, %e[kr2][1] \n"
"vext.32 q13, q11, q12, #3 \n"
"vmla.f32 q14, q13, %f[kr2][0] \n"
"vmla.f32 q14, q12, %f[kr2][1] \n"
"vld1.32 {d14-d17}, [%[input_ptr3]], r0 \n"
"vld1.32 {d18-d21}, [%[input_ptr4]], r0 \n"
"vmla.f32 q14, q7, %f[ker0][1] \n"
"vext.32 q13, q7, q8, #1 \n"
"vmla.f32 q14, q13, %e[kr3][0] \n"
"vext.32 q13, q7, q8, #2 \n"
"vmla.f32 q14, q13, %e[kr3][1] \n"
"vext.32 q13, q7, q8, #3 \n"
"vmla.f32 q14, q13, %f[kr3][0] \n"
"vmla.f32 q14, q8, %f[kr3][1] \n"
"vmla.f32 q14, q9, %e[ker1][0] \n"
"vext.32 q13, q9, q10, #1 \n"
"vmla.f32 q14, q13, %e[kr4][0] \n"
"vext.32 q13, q9, q10, #2 \n"
"vmla.f32 q14, q13, %e[kr4][1] \n"
"vext.32 q13, q9, q10, #3 \n"
"vmla.f32 q14, q13, %f[kr4][0] \n"
"vmla.f32 q14, q10, %f[kr4][1] \n"
// restore output
"vst1.32 {q14}, [%[output_ptr0]]! \n"
"subs %[loop], #1 \n"
"bne loop_1h4w_%= \n"
"start_remain_%=: \n"
"cmp %[remain], #0 \n"
"ble end_%= \n"
"mov r0, %[remain], lsl #2 \n"
"vld1.32 {d14-d17}, [%[input_ptr0]], r0 \n"
"vld1.32 {d18-d21}, [%[input_ptr1]], r0 \n"
"vld1.32 {d22-d25}, [%[input_ptr2]], r0 \n"
"vmul.f32 q14, q7, %e[ker0][0] \n"
"vext.32 q13, q7, q8, #1 \n"
"vmla.f32 q14, q13, %e[kr0][0] \n"
"vext.32 q13, q7, q8, #2 \n"
"vmla.f32 q14, q13, %e[kr0][1] \n"
"vext.32 q13, q7, q8, #3 \n"
"vmla.f32 q14, q13, %f[kr0][0] \n"
"vmla.f32 q14, q8, %f[kr0][1] \n"
"vmla.f32 q14, q9, %e[ker0][1] \n"
"vext.32 q13, q9, q10, #1 \n"
"vmla.f32 q14, q13, %e[kr1][0] \n"
"vext.32 q13, q9, q10, #2 \n"
"vmla.f32 q14, q13, %e[kr1][1] \n"
"vext.32 q13, q9, q10, #3 \n"
"vmla.f32 q14, q13, %f[kr1][0] \n"
"vmla.f32 q14, q10, %f[kr1][1] \n"
"vmla.f32 q14, q11, %f[ker0][0] \n"
"vext.32 q13, q11, q12, #1 \n"
"vmla.f32 q14, q13, %e[kr2][0] \n"
"vext.32 q13, q11, q12, #2 \n"
"vmla.f32 q14, q13, %e[kr2][1] \n"
"vext.32 q13, q11, q12, #3 \n"
"vmla.f32 q14, q13, %f[kr2][0] \n"
"vmla.f32 q14, q12, %f[kr2][1] \n"
"vld1.32 {d14-d17}, [%[input_ptr3]], r0 \n"
"vld1.32 {d18-d21}, [%[input_ptr4]], r0 \n"
"vmla.f32 q14, q7, %f[ker0][1] \n"
"vext.32 q13, q7, q8, #1 \n"
"vmla.f32 q14, q13, %e[kr3][0] \n"
"vext.32 q13, q7, q8, #2 \n"
"vmla.f32 q14, q13, %e[kr3][1] \n"
"vext.32 q13, q7, q8, #3 \n"
"vmla.f32 q14, q13, %f[kr3][0] \n"
"vmla.f32 q14, q8, %f[kr3][1] \n"
"vmla.f32 q14, q9, %e[ker1][0] \n"
"vext.32 q13, q9, q10, #1 \n"
"vmla.f32 q14, q13, %e[kr4][0] \n"
"vext.32 q13, q9, q10, #2 \n"
"vmla.f32 q14, q13, %e[kr4][1] \n"
"vext.32 q13, q9, q10, #3 \n"
"vmla.f32 q14, q13, %f[kr4][0] \n"
"vmla.f32 q14, q10, %f[kr4][1] \n"
"cmp %[remain], #2 \n"
"blt store_1h1w_%= \n"
"vst1.32 {d28}, [%[output_ptr0]]! \n"
"cmp %[remain], #3 \n"
"blt end_%= \n"
"vst1.32 {d29[0]}, [%[output_ptr0]]! \n"
"b end_%= \n"
"store_1h1w_%=: \n"
"vst1.32 {d28[0]}, [%[output_ptr0]]! \n"
"end_%=: \n"
: [input_ptr0] "+r"(input_ptr0), [input_ptr1] "+r"(input_ptr1),
[input_ptr2] "+r"(input_ptr2), [input_ptr3] "+r"(input_ptr3),
[input_ptr4] "+r"(input_ptr4), [output_ptr0] "+r"(output_ptr0),
[loop] "+r"(loop)
: [remain] "r"(output_w_remain), [kr0] "w"(_ker[0]),
[kr1] "w"(_ker[1]), [kr2] "w"(_ker[2]), [kr3] "w"(_ker[3]),
[kr4] "w"(_ker[4]), [ker0] "w"(_ker[5]), [ker1] "w"(_ker[6])
: "cc", "memory", "q7", "q8", "q9", "q10", "q11", "q12", "q13", "q14",
"q15", "r0");
// pad right
if (padding_w) {
float32x4_t row0 = vld1q_f32(input_ptr0);
float32x4_t row1 = vld1q_f32(input_ptr1);
float32x4_t row2 = vld1q_f32(input_ptr2);
float32x4_t row3 = vld1q_f32(input_ptr3);
float32x4_t row4 = vld1q_f32(input_ptr4);
float32x4_t zero = vdupq_n_f32(0.f);
float32x4_t acc;
for (int w = valid_w_end; w < output_w; ++w) {
int padding = w + 5 - (padding_w + input_w);
if (padding >= 5) {
*output_ptr0 = 0.f;
} else {
int iw = w - valid_w_end;
float sum0 = input_ptr0[iw] * filter_ptr0[0] +
input_ptr1[iw] * filter_ptr1[0] +
input_ptr2[iw] * filter_ptr2[0] +
input_ptr3[iw] * filter_ptr3[0] +
input_ptr4[iw] * filter_ptr4[0];
row0 = vextq_f32(row0, zero, 1);
row1 = vextq_f32(row1, zero, 1);
row2 = vextq_f32(row2, zero, 1);
row3 = vextq_f32(row3, zero, 1);
row4 = vextq_f32(row4, zero, 1);
acc = vmulq_f32(row0, _ker[0]);
acc = vmlaq_f32(acc, row1, _ker[1]);
acc = vmlaq_f32(acc, row2, _ker[2]);
acc = vmlaq_f32(acc, row3, _ker[3]);
acc = vmlaq_f32(acc, row4, _ker[4]);
float32x2_t sum = vpadd_f32(vget_low_f32(acc), vget_high_f32(acc));
sum = vpadd_f32(sum, sum);
sum0 += vget_lane_f32(sum, 0);
*output_ptr0 = sum0;
}
output_ptr0++;
}
}
}
// pad bottom
for (int h = valid_h_end; h < output_h; ++h) {
DepthwiseConv5x5NormalRow<1, 1>(input_ptr, filter_ptr, h, input_h,
input_w, padding_h, padding_w, output_w,
output_ptr, _ker, _ker1);
}
}
}
template <>
void DepthwiseConv5x5S2<float, float>(const framework::Tensor &input,
const framework::Tensor &filter,
const std::vector<int> &paddings,
framework::Tensor *output) {}
} // namespace math
} // namespace operators
} // namespace paddle_mobile
#endif // __ARM_NEON__
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <algorithm>
#include <vector>
#include "framework/tensor.h"
#include "operators/math/conv_func.h"
namespace paddle_mobile {
namespace operators {
namespace math {
// TODO(hjchen2) need to be implemented
// template<typename Itype, typename Otype>
// void DepthwiseConv5x5(const framework::Tensor *input,
// const framework::Tensor *filter,
// const std::vector<int> &strides,
// const std::vector<int> &paddings,
// framework::Tensor *output);
template <typename Itype, typename Otype>
void DepthwiseConv5x5S1(const framework::Tensor &input,
const framework::Tensor &filter,
const std::vector<int> &paddings,
framework::Tensor *output);
template <typename Itype, typename Otype>
void DepthwiseConv5x5S2(const framework::Tensor &input,
const framework::Tensor &filter,
const std::vector<int> &paddings,
framework::Tensor *output);
} // namespace math
} // namespace operators
} // namespace paddle_mobile
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#if defined(__ARM_NEON__) && !defined(__aarch64__)
#include <arm_neon.h>
#include "operators/math/depthwise_conv5x5.h"
namespace paddle_mobile {
namespace operators {
namespace math {
#ifndef __aarch64__
inline int32x4_t vpaddq_s32(int32x4_t r0, int32x4_t r1) {
int32x2_t sum0 = vpadd_s32(vget_low_s32(r0), vget_high_s32(r0));
int32x2_t sum1 = vpadd_s32(vget_low_s32(r1), vget_high_s32(r1));
return vcombine_s32(sum0, sum1);
}
#endif
template <int Stride = 1>
inline void Depth5x5NormalRowLoadInput(const int8_t *input, int16x4_t *y) {
int16x8_t x = vmovl_s8(vld1_s8(input));
y[0] = vget_low_s16(x);
y[4] = vget_high_s16(x);
y[1] = vext_s16(y[0], y[4], 1);
y[2] = vext_s16(y[0], y[4], 2);
y[3] = vext_s16(y[0], y[4], 3);
}
template <>
inline void Depth5x5NormalRowLoadInput<2>(const int8_t *input, int16x4_t *y) {
int8x8x2_t x = vld2_s8(input);
y[0] = vget_low_s16(vmovl_s8(x.val[0]));
y[1] = vget_low_s16(vmovl_s8(x.val[1]));
y[2] = vext_s16(y[0], y[0], 1);
y[3] = vext_s16(y[1], y[1], 1);
y[4] = vext_s16(y[0], y[0], 2);
}
#define DEPTHWISE_CONV_NORMAL_BORDER(start, end) \
for (int w = start; w < end; ++w) { \
const int w_in_start = -padding_w + w * Stride_w; \
const int w_in_end = w_in_start + 5; \
const int w_start = w_in_start > 0 ? w_in_start : 0; \
const int w_end = w_in_end < input_w ? w_in_end : input_w; \
int32_t value = 0; \
for (int h_in = h_start; h_in < h_end; ++h_in) { \
for (int w_in = w_start; w_in < w_end; ++w_in) { \
value += filter[(h_in - h_in_start) * 5 + (w_in - w_in_start)] * \
input[h_in * input_w + w_in]; \
} \
} \
output_ptr[w] = value; \
}
template <int Stride_h, int Stride_w>
inline void DepthwiseConv5x5NormalRow(const int8_t *input, const int8_t *filter,
const int h_output, const int input_h,
const int input_w, const int padding_h,
const int padding_w, const int output_w,
int32_t *output, int16x4_t *ker,
int16_t *ker1) {
const int h_in_start = -padding_h + h_output * Stride_h;
const int h_in_end = h_in_start + 5;
const int h_start = h_in_start > 0 ? h_in_start : 0;
const int h_end = h_in_end < input_h ? h_in_end : input_h;
int valid_w_start = (padding_w + Stride_w - 1) / Stride_w;
int valid_w_end = output_w - valid_w_start;
int32_t *output_ptr = output + h_output * output_w;
// border left
DEPTHWISE_CONV_NORMAL_BORDER(0, valid_w_start)
// middle
int output_tiles = (valid_w_end - valid_w_start) >> 2;
int16x4_t _x[5];
int32x4_t _sum;
// valid w
for (int w = 0; w < output_tiles * 4; w += 4) {
_sum = vdupq_n_s32(0);
int output_offset = valid_w_start + w;
int input_w_offset = output_offset * Stride_w - padding_w;
for (int h_in = h_start; h_in < h_end; ++h_in) {
int index = h_in - h_in_start;
Depth5x5NormalRowLoadInput<Stride_w>(
input + h_in * input_w + input_w_offset, _x);
_sum = vmlal_n_s16(_sum, _x[0], ker1[index]);
_sum = vmlal_lane_s16(_sum, _x[1], ker[index], 0);
_sum = vmlal_lane_s16(_sum, _x[2], ker[index], 1);
_sum = vmlal_lane_s16(_sum, _x[3], ker[index], 2);
_sum = vmlal_lane_s16(_sum, _x[4], ker[index], 3);
}
vst1q_s32(output_ptr + output_offset, _sum);
}
// remain valid w
int remain = (valid_w_end - valid_w_start) & 0x3;
if (remain > 0) {
_sum = vdupq_n_s32(0);
int remain_start = valid_w_start + (output_tiles << 2);
int input_w_offset = remain_start * Stride_w - padding_w;
int32_t *output_ptr0 = output_ptr + remain_start;
for (int h_in = h_start; h_in < h_end; ++h_in) {
int index = h_in - h_in_start;
Depth5x5NormalRowLoadInput<Stride_w>(
input + h_in * input_w + input_w_offset, _x);
_sum = vmlal_n_s16(_sum, _x[0], ker1[index]);
_sum = vmlal_lane_s16(_sum, _x[1], ker[index], 0);
_sum = vmlal_lane_s16(_sum, _x[2], ker[index], 1);
_sum = vmlal_lane_s16(_sum, _x[3], ker[index], 2);
_sum = vmlal_lane_s16(_sum, _x[4], ker[index], 3);
}
switch (remain) {
case 1:
vst1_lane_s32(output_ptr0, vget_low_s32(_sum), 0);
break;
case 2:
vst1_s32(output_ptr0, vget_low_s32(_sum));
break;
case 3:
vst1_s32(output_ptr0, vget_low_s32(_sum));
vst1_lane_s32(output_ptr0 + 2, vget_high_s32(_sum), 0);
break;
}
}
// border right
DEPTHWISE_CONV_NORMAL_BORDER(valid_w_end, output_w)
}
template <>
void DepthwiseConv5x5S1<int8_t, int32_t>(const framework::Tensor &input,
const framework::Tensor &filter,
const std::vector<int> &paddings,
framework::Tensor *output) {
const int8_t *input_data = input.data<int8_t>();
const int8_t *filter_data = filter.data<int8_t>();
int32_t *out_data = output->mutable_data<int32_t>();
int input_h = input.dims()[2];
int input_w = input.dims()[3];
int output_h = output->dims()[2];
int output_w = output->dims()[3];
int padding_h = paddings[0];
int padding_w = paddings[1];
int image_size = input_h * input_w;
int out_image_size = output_h * output_w;
int valid_h_start = padding_h;
int valid_h_end = output_h - valid_h_start;
int valid_h = valid_h_end - valid_h_start;
int valid_w_start = padding_w;
int valid_w_end = output_w - valid_w_start;
int valid_w = valid_w_end - valid_w_start;
#pragma omp parallel for
for (int g = 0; g < input.dims()[1]; ++g) {
const int8_t *input_ptr = input_data + g * image_size;
const int8_t *filter_ptr = filter_data + g * 25;
int32_t *output_ptr = out_data + g * out_image_size;
const int8_t *filter_ptr0 = filter_ptr;
const int8_t *filter_ptr1 = filter_ptr0 + 5;
const int8_t *filter_ptr2 = filter_ptr1 + 5;
const int8_t *filter_ptr3 = filter_ptr2 + 5;
const int8_t *filter_ptr4 = filter_ptr3 + 5;
int16_t kernel[5] = {*filter_ptr0, *filter_ptr1, *filter_ptr2, *filter_ptr3,
*filter_ptr4};
int16x4_t _k0 = vget_low_s16(vmovl_s8(vld1_s8(filter_ptr0 + 1)));
int16x4_t _k1 = vget_low_s16(vmovl_s8(vld1_s8(filter_ptr1 + 1)));
int16x4_t _k2 = vget_low_s16(vmovl_s8(vld1_s8(filter_ptr2 + 1)));
int16x4_t _k3 = vget_low_s16(vmovl_s8(vld1_s8(filter_ptr3 + 1)));
int16x4_t _k4 = vget_low_s16(vmovl_s8(vld1_s8(filter_ptr4 + 1)));
int16x4_t _k5 = vld1_s16(kernel);
int16x4_t _k6 = vld1_s16(kernel + 4);
int16x8_t _ker0 = vcombine_s16(_k0, _k1);
int16x8_t _ker1 = vcombine_s16(_k2, _k3);
int16x8_t _ker2 = vcombine_s16(_k4, _k5);
int16x8_t _ker3 = vcombine_s16(_k6, _k6);
int16x4_t _ker[7] = {_k0, _k1, _k2, _k3, _k4, _k5, _k6};
// pad top
for (int h = 0; h < valid_h_start; ++h) {
DepthwiseConv5x5NormalRow<1, 1>(input_ptr, filter_ptr, h, input_h,
input_w, padding_h, padding_w, output_w,
output_ptr, _ker, kernel);
}
// output 4x4
int output_w_tiles = valid_w / 8;
int output_w_remain = valid_w - output_w_tiles * 8;
for (int h = valid_h_start; h < valid_h_end - 1; h += 2) {
const int8_t *input_ptr0 = input_ptr + (h - padding_h) * input_w;
const int8_t *input_ptr1 = input_ptr0 + input_w;
const int8_t *input_ptr2 = input_ptr1 + input_w;
const int8_t *input_ptr3 = input_ptr2 + input_w;
const int8_t *input_ptr4 = input_ptr3 + input_w;
const int8_t *input_ptr5 = input_ptr4 + input_w;
int32_t *output_ptr0 = output_ptr + h * output_w;
int32_t *output_ptr1 = output_ptr0 + output_w;
// pad left
if (padding_w) {
int16x4_t row0 = vget_low_s16(vmovl_s8(vld1_s8(input_ptr0)));
int16x4_t row1 = vget_low_s16(vmovl_s8(vld1_s8(input_ptr1)));
int16x4_t row2 = vget_low_s16(vmovl_s8(vld1_s8(input_ptr2)));
int16x4_t row3 = vget_low_s16(vmovl_s8(vld1_s8(input_ptr3)));
int16x4_t row4 = vget_low_s16(vmovl_s8(vld1_s8(input_ptr4)));
int16x4_t row5 = vget_low_s16(vmovl_s8(vld1_s8(input_ptr5)));
int16x4_t zero = vdup_n_s16(0);
int32x4_t acc0, acc1;
for (int w = valid_w_start - 1; w >= 0; --w) {
int padding = padding_w - w;
if (padding >= 5) {
output_ptr0[w] = 0;
output_ptr1[w] = 0;
} else {
acc0 = vmull_s16(row0, _ker[0]);
acc0 = vmlal_s16(acc0, row1, _ker[1]);
acc0 = vmlal_s16(acc0, row2, _ker[2]);
acc0 = vmlal_s16(acc0, row3, _ker[3]);
acc0 = vmlal_s16(acc0, row4, _ker[4]);
acc1 = vmull_s16(row1, _ker[0]);
acc1 = vmlal_s16(acc1, row2, _ker[1]);
acc1 = vmlal_s16(acc1, row3, _ker[2]);
acc1 = vmlal_s16(acc1, row4, _ker[3]);
acc1 = vmlal_s16(acc1, row5, _ker[4]);
acc0 = vpaddq_s32(acc0, acc1);
int32x2_t sum = vpadd_s32(vget_low_s32(acc0), vget_high_s32(acc0));
vst1_lane_s32(output_ptr0 + w, sum, 0);
vst1_lane_s32(output_ptr1 + w, sum, 1);
row0 = vext_s16(zero, row0, 3);
row1 = vext_s16(zero, row1, 3);
row2 = vext_s16(zero, row2, 3);
row3 = vext_s16(zero, row3, 3);
row4 = vext_s16(zero, row4, 3);
row5 = vext_s16(zero, row5, 3);
}
}
output_ptr0 += valid_w_start;
output_ptr1 += valid_w_start;
}
// valid
int loop = output_w_tiles;
int w_remain = output_w_remain;
asm volatile(
"cmp %[loop], #0 \n"
"ble start_remain4_%= \n"
"mov r0, #8 \n"
"loop_2h8w_%=: \n"
"vld1.s8 {d10-d11}, [%[input_ptr0]], r0 \n"
"vld1.s8 {d12-d13}, [%[input_ptr1]], r0 \n"
"vld1.s8 {d14-d15}, [%[input_ptr2]], r0 \n"
"vmovl.s8 q8, d10 \n"
"vmovl.s8 q9, d11 \n"
"vmull.s16 q12, d16, %f[ker2][0] \n"
"vmull.s16 q13, d17, %f[ker2][0] \n"
"vext.s16 q10, q8, q9, #1 \n"
"vmlal.s16 q12, d20, %e[ker0][0] \n"
"vmlal.s16 q13, d21, %e[ker0][0] \n"
"vext.s16 q10, q8, q9, #2 \n"
"vmlal.s16 q12, d20, %e[ker0][1] \n"
"vmlal.s16 q13, d21, %e[ker0][1] \n"
"vext.s16 q10, q8, q9, #3 \n"
"vmlal.s16 q12, d20, %e[ker0][2] \n"
"vmlal.s16 q13, d21, %e[ker0][2] \n"
"vext.s16 q10, q8, q9, #4 \n"
"vmlal.s16 q12, d20, %e[ker0][3] \n"
"vmlal.s16 q13, d21, %e[ker0][3] \n"
"vmovl.s8 q8, d12 \n"
"vmovl.s8 q9, d13 \n"
"vmlal.s16 q12, d16, %f[ker2][1] \n"
"vmlal.s16 q13, d17, %f[ker2][1] \n"
"vmull.s16 q14, d16, %f[ker2][0] \n"
"vmull.s16 q15, d17, %f[ker2][0] \n"
"vext.s16 q10, q8, q9, #1 \n"
"vmlal.s16 q12, d20, %f[ker0][0] \n"
"vmlal.s16 q13, d21, %f[ker0][0] \n"
"vmlal.s16 q14, d20, %e[ker0][0] \n"
"vmlal.s16 q15, d21, %e[ker0][0] \n"
"vext.s16 q10, q8, q9, #2 \n"
"vmlal.s16 q12, d20, %f[ker0][1] \n"
"vmlal.s16 q13, d21, %f[ker0][1] \n"
"vmlal.s16 q14, d20, %e[ker0][1] \n"
"vmlal.s16 q15, d21, %e[ker0][1] \n"
"vext.s16 q10, q8, q9, #3 \n"
"vmlal.s16 q12, d20, %f[ker0][2] \n"
"vmlal.s16 q13, d21, %f[ker0][2] \n"
"vmlal.s16 q14, d20, %e[ker0][2] \n"
"vmlal.s16 q15, d21, %e[ker0][2] \n"
"vext.s16 q10, q8, q9, #4 \n"
"vmlal.s16 q12, d20, %f[ker0][3] \n"
"vmlal.s16 q13, d21, %f[ker0][3] \n"
"vmlal.s16 q14, d20, %e[ker0][3] \n"
"vmlal.s16 q15, d21, %e[ker0][3] \n"
"vmovl.s8 q8, d14 \n"
"vmovl.s8 q9, d15 \n"
"vmlal.s16 q12, d16, %f[ker2][2] \n"
"vmlal.s16 q13, d17, %f[ker2][2] \n"
"vmlal.s16 q14, d16, %f[ker2][1] \n"
"vmlal.s16 q15, d17, %f[ker2][1] \n"
"vext.s16 q10, q8, q9, #1 \n"
"vmlal.s16 q12, d20, %e[ker1][0] \n"
"vmlal.s16 q13, d21, %e[ker1][0] \n"
"vmlal.s16 q14, d20, %f[ker0][0] \n"
"vmlal.s16 q15, d21, %f[ker0][0] \n"
"vext.s16 q10, q8, q9, #2 \n"
"vmlal.s16 q12, d20, %e[ker1][1] \n"
"vmlal.s16 q13, d21, %e[ker1][1] \n"
"vmlal.s16 q14, d20, %f[ker0][1] \n"
"vmlal.s16 q15, d21, %f[ker0][1] \n"
"vext.s16 q10, q8, q9, #3 \n"
"vmlal.s16 q12, d20, %e[ker1][2] \n"
"vmlal.s16 q13, d21, %e[ker1][2] \n"
"vmlal.s16 q14, d20, %f[ker0][2] \n"
"vmlal.s16 q15, d21, %f[ker0][2] \n"
"vext.s16 q10, q8, q9, #4 \n"
"vmlal.s16 q12, d20, %e[ker1][3] \n"
"vmlal.s16 q13, d21, %e[ker1][3] \n"
"vmlal.s16 q14, d20, %f[ker0][3] \n"
"vmlal.s16 q15, d21, %f[ker0][3] \n"
"vld1.s8 {d10-d11}, [%[input_ptr3]], r0 \n"
"vld1.s8 {d12-d13}, [%[input_ptr4]], r0 \n"
"vld1.s8 {d14-d15}, [%[input_ptr5]], r0 \n"
"vmovl.s8 q8, d10 \n"
"vmovl.s8 q9, d11 \n"
"vmlal.s16 q12, d16, %f[ker2][3] \n"
"vmlal.s16 q13, d17, %f[ker2][3] \n"
"vmlal.s16 q14, d16, %f[ker2][2] \n"
"vmlal.s16 q15, d17, %f[ker2][2] \n"
"vext.s16 q10, q8, q9, #1 \n"
"vmlal.s16 q12, d20, %f[ker1][0] \n"
"vmlal.s16 q13, d21, %f[ker1][0] \n"
"vmlal.s16 q14, d20, %e[ker1][0] \n"
"vmlal.s16 q15, d21, %e[ker1][0] \n"
"vext.s16 q10, q8, q9, #2 \n"
"vmlal.s16 q12, d20, %f[ker1][1] \n"
"vmlal.s16 q13, d21, %f[ker1][1] \n"
"vmlal.s16 q14, d20, %e[ker1][1] \n"
"vmlal.s16 q15, d21, %e[ker1][1] \n"
"vext.s16 q10, q8, q9, #3 \n"
"vmlal.s16 q12, d20, %f[ker1][2] \n"
"vmlal.s16 q13, d21, %f[ker1][2] \n"
"vmlal.s16 q14, d20, %e[ker1][2] \n"
"vmlal.s16 q15, d21, %e[ker1][2] \n"
"vext.s16 q10, q8, q9, #4 \n"
"vmlal.s16 q12, d20, %f[ker1][3] \n"
"vmlal.s16 q13, d21, %f[ker1][3] \n"
"vmlal.s16 q14, d20, %e[ker1][3] \n"
"vmlal.s16 q15, d21, %e[ker1][3] \n"
"vmovl.s8 q8, d12 \n"
"vmovl.s8 q9, d13 \n"
"vmlal.s16 q12, d16, %e[ker3][0] \n"
"vmlal.s16 q13, d17, %e[ker3][0] \n"
"vmlal.s16 q14, d16, %f[ker2][3] \n"
"vmlal.s16 q15, d17, %f[ker2][3] \n"
"vext.s16 q10, q8, q9, #1 \n"
"vmlal.s16 q12, d20, %e[ker2][0] \n"
"vmlal.s16 q13, d21, %e[ker2][0] \n"
"vmlal.s16 q14, d20, %f[ker1][0] \n"
"vmlal.s16 q15, d21, %f[ker1][0] \n"
"vext.s16 q10, q8, q9, #2 \n"
"vmlal.s16 q12, d20, %e[ker2][1] \n"
"vmlal.s16 q13, d21, %e[ker2][1] \n"
"vmlal.s16 q14, d20, %f[ker1][1] \n"
"vmlal.s16 q15, d21, %f[ker1][1] \n"
"vext.s16 q10, q8, q9, #3 \n"
"vmlal.s16 q12, d20, %e[ker2][2] \n"
"vmlal.s16 q13, d21, %e[ker2][2] \n"
"vmlal.s16 q14, d20, %f[ker1][2] \n"
"vmlal.s16 q15, d21, %f[ker1][2] \n"
"vext.s16 q10, q8, q9, #4 \n"
"vmlal.s16 q12, d20, %e[ker2][3] \n"
"vmlal.s16 q13, d21, %e[ker2][3] \n"
"vmlal.s16 q14, d20, %f[ker1][3] \n"
"vmlal.s16 q15, d21, %f[ker1][3] \n"
"vmovl.s8 q8, d14 \n"
"vmovl.s8 q9, d15 \n"
"vmlal.s16 q14, d16, %e[ker3][0] \n"
"vmlal.s16 q15, d17, %e[ker3][0] \n"
"vext.s16 q10, q8, q9, #1 \n"
"vmlal.s16 q14, d20, %e[ker2][0] \n"
"vmlal.s16 q15, d21, %e[ker2][0] \n"
"vext.s16 q10, q8, q9, #2 \n"
"vmlal.s16 q14, d20, %e[ker2][1] \n"
"vmlal.s16 q15, d21, %e[ker2][1] \n"
"vext.s16 q10, q8, q9, #3 \n"
"vmlal.s16 q14, d20, %e[ker2][2] \n"
"vmlal.s16 q15, d21, %e[ker2][2] \n"
"vext.s16 q10, q8, q9, #4 \n"
"vmlal.s16 q14, d20, %e[ker2][3] \n"
"vmlal.s16 q15, d21, %e[ker2][3] \n"
// restore output
"vst1.32 {q12-q13}, [%[output_ptr0]]! \n"
"vst1.32 {q14-q15}, [%[output_ptr1]]! \n"
"subs %[loop], #1 \n"
"bne loop_2h8w_%= \n"
"start_remain4_%=: \n"
"cmp %[remain], #4 \n"
"blt start_remain_%= \n"
"mov r0, #4 \n"
"vld1.s8 {d10}, [%[input_ptr0]], r0 \n"
"vld1.s8 {d12}, [%[input_ptr1]], r0 \n"
"vld1.s8 {d14}, [%[input_ptr2]], r0 \n"
"vmovl.s8 q8, d10 \n"
"vmull.s16 q12, d16, %f[ker2][0] \n"
"vext.s16 q10, q8, q9, #1 \n"
"vmlal.s16 q12, d20, %e[ker0][0] \n"
"vext.s16 q10, q8, q9, #2 \n"
"vmlal.s16 q12, d20, %e[ker0][1] \n"
"vext.s16 q10, q8, q9, #3 \n"
"vmlal.s16 q12, d20, %e[ker0][2] \n"
"vext.s16 q10, q8, q9, #4 \n"
"vmlal.s16 q12, d20, %e[ker0][3] \n"
"vmovl.s8 q8, d12 \n"
"vmlal.s16 q12, d16, %f[ker2][1] \n"
"vmull.s16 q14, d16, %f[ker2][0] \n"
"vext.s16 q10, q8, q9, #1 \n"
"vmlal.s16 q12, d20, %f[ker0][0] \n"
"vmlal.s16 q14, d20, %e[ker0][0] \n"
"vext.s16 q10, q8, q9, #2 \n"
"vmlal.s16 q12, d20, %f[ker0][1] \n"
"vmlal.s16 q14, d20, %e[ker0][1] \n"
"vext.s16 q10, q8, q9, #3 \n"
"vmlal.s16 q12, d20, %f[ker0][2] \n"
"vmlal.s16 q14, d20, %e[ker0][2] \n"
"vext.s16 q10, q8, q9, #4 \n"
"vmlal.s16 q12, d20, %f[ker0][3] \n"
"vmlal.s16 q14, d20, %e[ker0][3] \n"
"vmovl.s8 q8, d14 \n"
"vmlal.s16 q12, d16, %f[ker2][2] \n"
"vmlal.s16 q14, d16, %f[ker2][1] \n"
"vext.s16 q10, q8, q9, #1 \n"
"vmlal.s16 q12, d20, %e[ker1][0] \n"
"vmlal.s16 q14, d20, %f[ker0][0] \n"
"vext.s16 q10, q8, q9, #2 \n"
"vmlal.s16 q12, d20, %e[ker1][1] \n"
"vmlal.s16 q14, d20, %f[ker0][1] \n"
"vext.s16 q10, q8, q9, #3 \n"
"vmlal.s16 q12, d20, %e[ker1][2] \n"
"vmlal.s16 q14, d20, %f[ker0][2] \n"
"vext.s16 q10, q8, q9, #4 \n"
"vmlal.s16 q12, d20, %e[ker1][3] \n"
"vmlal.s16 q14, d20, %f[ker0][3] \n"
"vld1.s8 {d10}, [%[input_ptr3]], r0 \n"
"vld1.s8 {d12}, [%[input_ptr4]], r0 \n"
"vld1.s8 {d14}, [%[input_ptr5]], r0 \n"
"vmovl.s8 q8, d10 \n"
"vmlal.s16 q12, d16, %f[ker2][3] \n"
"vmlal.s16 q14, d16, %f[ker2][2] \n"
"vext.s16 q10, q8, q9, #1 \n"
"vmlal.s16 q12, d20, %f[ker1][0] \n"
"vmlal.s16 q14, d20, %e[ker1][0] \n"
"vext.s16 q10, q8, q9, #2 \n"
"vmlal.s16 q12, d20, %f[ker1][1] \n"
"vmlal.s16 q14, d20, %e[ker1][1] \n"
"vext.s16 q10, q8, q9, #3 \n"
"vmlal.s16 q12, d20, %f[ker1][2] \n"
"vmlal.s16 q14, d20, %e[ker1][2] \n"
"vext.s16 q10, q8, q9, #4 \n"
"vmlal.s16 q12, d20, %f[ker1][3] \n"
"vmlal.s16 q14, d20, %e[ker1][3] \n"
"vmovl.s8 q8, d12 \n"
"vmlal.s16 q12, d16, %e[ker3][0] \n"
"vmlal.s16 q14, d16, %f[ker2][3] \n"
"vext.s16 q10, q8, q9, #1 \n"
"vmlal.s16 q12, d20, %e[ker2][0] \n"
"vmlal.s16 q14, d20, %f[ker1][0] \n"
"vext.s16 q10, q8, q9, #2 \n"
"vmlal.s16 q12, d20, %e[ker2][1] \n"
"vmlal.s16 q14, d20, %f[ker1][1] \n"
"vext.s16 q10, q8, q9, #3 \n"
"vmlal.s16 q12, d20, %e[ker2][2] \n"
"vmlal.s16 q14, d20, %f[ker1][2] \n"
"vext.s16 q10, q8, q9, #4 \n"
"vmlal.s16 q12, d20, %e[ker2][3] \n"
"vmlal.s16 q14, d20, %f[ker1][3] \n"
"vmovl.s8 q8, d14 \n"
"vmlal.s16 q14, d16, %e[ker3][0] \n"
"vext.s16 q10, q8, q9, #1 \n"
"vmlal.s16 q14, d20, %e[ker2][0] \n"
"vext.s16 q10, q8, q9, #2 \n"
"vmlal.s16 q14, d20, %e[ker2][1] \n"
"vext.s16 q10, q8, q9, #3 \n"
"vmlal.s16 q14, d20, %e[ker2][2] \n"
"vext.s16 q10, q8, q9, #4 \n"
"vmlal.s16 q14, d20, %e[ker2][3] \n"
// restore output
"vst1.32 {d24-d25}, [%[output_ptr0]]! \n"
"vst1.32 {d28-d29}, [%[output_ptr1]]! \n"
"sub %[remain], #4 \n"
"start_remain_%=: \n"
"cmp %[remain], #0 \n"
"ble end_%= \n"
"mov r0, %[remain] \n"
"vld1.s8 {d10}, [%[input_ptr0]], r0 \n"
"vld1.s8 {d12}, [%[input_ptr1]], r0 \n"
"vld1.s8 {d14}, [%[input_ptr2]], r0 \n"
"vmovl.s8 q8, d10 \n"
"vmull.s16 q12, d16, %f[ker2][0] \n"
"vext.s16 q10, q8, q9, #1 \n"
"vmlal.s16 q12, d20, %e[ker0][0] \n"
"vext.s16 q10, q8, q9, #2 \n"
"vmlal.s16 q12, d20, %e[ker0][1] \n"
"vext.s16 q10, q8, q9, #3 \n"
"vmlal.s16 q12, d20, %e[ker0][2] \n"
"vext.s16 q10, q8, q9, #4 \n"
"vmlal.s16 q12, d20, %e[ker0][3] \n"
"vmovl.s8 q8, d12 \n"
"vmlal.s16 q12, d16, %f[ker2][1] \n"
"vmull.s16 q14, d16, %f[ker2][0] \n"
"vext.s16 q10, q8, q9, #1 \n"
"vmlal.s16 q12, d20, %f[ker0][0] \n"
"vmlal.s16 q14, d20, %e[ker0][0] \n"
"vext.s16 q10, q8, q9, #2 \n"
"vmlal.s16 q12, d20, %f[ker0][1] \n"
"vmlal.s16 q14, d20, %e[ker0][1] \n"
"vext.s16 q10, q8, q9, #3 \n"
"vmlal.s16 q12, d20, %f[ker0][2] \n"
"vmlal.s16 q14, d20, %e[ker0][2] \n"
"vext.s16 q10, q8, q9, #4 \n"
"vmlal.s16 q12, d20, %f[ker0][3] \n"
"vmlal.s16 q14, d20, %e[ker0][3] \n"
"vmovl.s8 q8, d14 \n"
"vmlal.s16 q12, d16, %f[ker2][2] \n"
"vmlal.s16 q14, d16, %f[ker2][1] \n"
"vext.s16 q10, q8, q9, #1 \n"
"vmlal.s16 q12, d20, %e[ker1][0] \n"
"vmlal.s16 q14, d20, %f[ker0][0] \n"
"vext.s16 q10, q8, q9, #2 \n"
"vmlal.s16 q12, d20, %e[ker1][1] \n"
"vmlal.s16 q14, d20, %f[ker0][1] \n"
"vext.s16 q10, q8, q9, #3 \n"
"vmlal.s16 q12, d20, %e[ker1][2] \n"
"vmlal.s16 q14, d20, %f[ker0][2] \n"
"vext.s16 q10, q8, q9, #4 \n"
"vmlal.s16 q12, d20, %e[ker1][3] \n"
"vmlal.s16 q14, d20, %f[ker0][3] \n"
"vld1.s8 {d10}, [%[input_ptr3]], r0 \n"
"vld1.s8 {d12}, [%[input_ptr4]], r0 \n"
"vld1.s8 {d14}, [%[input_ptr5]], r0 \n"
"vmovl.s8 q8, d10 \n"
"vmlal.s16 q12, d16, %f[ker2][3] \n"
"vmlal.s16 q14, d16, %f[ker2][2] \n"
"vext.s16 q10, q8, q9, #1 \n"
"vmlal.s16 q12, d20, %f[ker1][0] \n"
"vmlal.s16 q14, d20, %e[ker1][0] \n"
"vext.s16 q10, q8, q9, #2 \n"
"vmlal.s16 q12, d20, %f[ker1][1] \n"
"vmlal.s16 q14, d20, %e[ker1][1] \n"
"vext.s16 q10, q8, q9, #3 \n"
"vmlal.s16 q12, d20, %f[ker1][2] \n"
"vmlal.s16 q14, d20, %e[ker1][2] \n"
"vext.s16 q10, q8, q9, #4 \n"
"vmlal.s16 q12, d20, %f[ker1][3] \n"
"vmlal.s16 q14, d20, %e[ker1][3] \n"
"vmovl.s8 q8, d12 \n"
"vmlal.s16 q12, d16, %e[ker3][0] \n"
"vmlal.s16 q14, d16, %f[ker2][3] \n"
"vext.s16 q10, q8, q9, #1 \n"
"vmlal.s16 q12, d20, %e[ker2][0] \n"
"vmlal.s16 q14, d20, %f[ker1][0] \n"
"vext.s16 q10, q8, q9, #2 \n"
"vmlal.s16 q12, d20, %e[ker2][1] \n"
"vmlal.s16 q14, d20, %f[ker1][1] \n"
"vext.s16 q10, q8, q9, #3 \n"
"vmlal.s16 q12, d20, %e[ker2][2] \n"
"vmlal.s16 q14, d20, %f[ker1][2] \n"
"vext.s16 q10, q8, q9, #4 \n"
"vmlal.s16 q12, d20, %e[ker2][3] \n"
"vmlal.s16 q14, d20, %f[ker1][3] \n"
"vmovl.s8 q8, d14 \n"
"vmlal.s16 q14, d16, %e[ker3][0] \n"
"vext.s16 q10, q8, q9, #1 \n"
"vmlal.s16 q14, d20, %e[ker2][0] \n"
"vext.s16 q10, q8, q9, #2 \n"
"vmlal.s16 q14, d20, %e[ker2][1] \n"
"vext.s16 q10, q8, q9, #3 \n"
"vmlal.s16 q14, d20, %e[ker2][2] \n"
"vext.s16 q10, q8, q9, #4 \n"
"vmlal.s16 q14, d20, %e[ker2][3] \n"
"cmp %[remain], #2 \n"
"blt store_2h1w_%= \n"
"vst1.32 {d24}, [%[output_ptr0]]! \n"
"vst1.32 {d28}, [%[output_ptr1]]! \n"
"cmp %[remain], #3 \n"
"blt end_%= \n"
"vst1.32 {d25[0]}, [%[output_ptr0]]! \n"
"vst1.32 {d29[0]}, [%[output_ptr1]]! \n"
"b end_%= \n"
"store_2h1w_%=: \n"
"vst1.32 {d24[0]}, [%[output_ptr0]]! \n"
"vst1.32 {d28[0]}, [%[output_ptr1]]! \n"
"end_%=: \n"
: [input_ptr0] "+r"(input_ptr0), [input_ptr1] "+r"(input_ptr1),
[input_ptr2] "+r"(input_ptr2), [input_ptr3] "+r"(input_ptr3),
[input_ptr4] "+r"(input_ptr4), [input_ptr5] "+r"(input_ptr5),
[output_ptr0] "+r"(output_ptr0), [output_ptr1] "+r"(output_ptr1),
[loop] "+r"(loop), [remain] "+r"(w_remain)
: [ker0] "w"(_ker0), [ker1] "w"(_ker1), [ker2] "w"(_ker2),
[ker3] "w"(_ker3)
: "cc", "memory", "q4", "q5", "q6", "q7", "q8", "q9", "q10", "q11",
"q12", "q13", "q14", "q15", "r0");
// pad right
if (padding_w) {
int16x4_t row0 = vget_low_s16(vmovl_s8(vld1_s8(input_ptr0)));
int16x4_t row1 = vget_low_s16(vmovl_s8(vld1_s8(input_ptr1)));
int16x4_t row2 = vget_low_s16(vmovl_s8(vld1_s8(input_ptr2)));
int16x4_t row3 = vget_low_s16(vmovl_s8(vld1_s8(input_ptr3)));
int16x4_t row4 = vget_low_s16(vmovl_s8(vld1_s8(input_ptr4)));
int16x4_t row5 = vget_low_s16(vmovl_s8(vld1_s8(input_ptr5)));
int16x4_t zero = vdup_n_s16(0);
int32x4_t acc0, acc1;
for (int w = valid_w_end; w < output_w; ++w) {
int padding = w + 5 - (padding_w + input_w);
if (padding >= 5) {
*output_ptr0 = 0;
*output_ptr1 = 0;
} else {
int iw = w - valid_w_end;
int32_t sum0 = input_ptr0[iw] * filter_ptr0[0] +
input_ptr1[iw] * filter_ptr1[0] +
input_ptr2[iw] * filter_ptr2[0] +
input_ptr3[iw] * filter_ptr3[0] +
input_ptr4[iw] * filter_ptr4[0];
int32_t sum1 = input_ptr1[iw] * filter_ptr0[0] +
input_ptr2[iw] * filter_ptr1[0] +
input_ptr3[iw] * filter_ptr2[0] +
input_ptr4[iw] * filter_ptr3[0] +
input_ptr5[iw] * filter_ptr4[0];
row0 = vext_s16(row0, zero, 1);
row1 = vext_s16(row1, zero, 1);
row2 = vext_s16(row2, zero, 1);
row3 = vext_s16(row3, zero, 1);
row4 = vext_s16(row4, zero, 1);
row5 = vext_s16(row5, zero, 1);
acc0 = vmull_s16(row0, _ker[0]);
acc0 = vmlal_s16(acc0, row1, _ker[1]);
acc0 = vmlal_s16(acc0, row2, _ker[2]);
acc0 = vmlal_s16(acc0, row3, _ker[3]);
acc0 = vmlal_s16(acc0, row4, _ker[4]);
acc1 = vmull_s16(row1, _ker[0]);
acc1 = vmlal_s16(acc1, row2, _ker[1]);
acc1 = vmlal_s16(acc1, row3, _ker[2]);
acc1 = vmlal_s16(acc1, row4, _ker[3]);
acc1 = vmlal_s16(acc1, row5, _ker[4]);
acc0 = vpaddq_s32(acc0, acc1);
int32x2_t sum = vpadd_s32(vget_low_s32(acc0), vget_high_s32(acc0));
sum0 += vget_lane_s32(sum, 0);
sum1 += vget_lane_s32(sum, 1);
*output_ptr0 = sum0;
*output_ptr1 = sum1;
}
output_ptr0++;
output_ptr1++;
}
}
}
// remain height
int start_h = valid_h_start + (valid_h & 0xfffe);
if (start_h < valid_h_end) {
const int8_t *input_ptr0 = input_ptr + (start_h - padding_h) * input_w;
const int8_t *input_ptr1 = input_ptr0 + input_w;
const int8_t *input_ptr2 = input_ptr1 + input_w;
const int8_t *input_ptr3 = input_ptr2 + input_w;
const int8_t *input_ptr4 = input_ptr3 + input_w;
int32_t *output_ptr0 = output_ptr + start_h * output_w;
// pad left
if (padding_w) {
int16x4_t row0 = vget_low_s16(vmovl_s8(vld1_s8(input_ptr0)));
int16x4_t row1 = vget_low_s16(vmovl_s8(vld1_s8(input_ptr1)));
int16x4_t row2 = vget_low_s16(vmovl_s8(vld1_s8(input_ptr2)));
int16x4_t row3 = vget_low_s16(vmovl_s8(vld1_s8(input_ptr3)));
int16x4_t row4 = vget_low_s16(vmovl_s8(vld1_s8(input_ptr4)));
int16x4_t zero = vdup_n_s16(0);
int32x4_t acc;
for (int w = valid_w_start - 1; w >= 0; --w) {
int padding = padding_w - w;
if (padding >= 5) {
output_ptr0[w] = 0;
} else {
acc = vmull_s16(row0, _ker[0]);
acc = vmlal_s16(acc, row1, _ker[1]);
acc = vmlal_s16(acc, row2, _ker[2]);
acc = vmlal_s16(acc, row3, _ker[3]);
acc = vmlal_s16(acc, row4, _ker[4]);
int32x2_t sum = vpadd_s32(vget_low_s32(acc), vget_high_s32(acc));
sum = vpadd_s32(sum, sum);
vst1_lane_s32(output_ptr0 + w, sum, 0);
row0 = vext_s16(zero, row0, 3);
row1 = vext_s16(zero, row1, 3);
row2 = vext_s16(zero, row2, 3);
row3 = vext_s16(zero, row3, 3);
row4 = vext_s16(zero, row4, 3);
}
}
output_ptr0 += valid_w_start;
}
// valid
int loop = output_w_tiles;
int w_remain = output_w_remain;
asm volatile(
"cmp %[loop], #0 \n"
"ble start_remain4_%= \n"
"mov r0, #8 \n"
"loop_1h8w_%=: \n"
"vld1.s8 {d10-d11}, [%[input_ptr0]], r0 \n"
"vld1.s8 {d12-d13}, [%[input_ptr1]], r0 \n"
"vld1.s8 {d14-d15}, [%[input_ptr2]], r0 \n"
"vmovl.s8 q8, d10 \n"
"vmovl.s8 q9, d11 \n"
"vmull.s16 q12, d16, %f[ker2][0] \n"
"vmull.s16 q13, d17, %f[ker2][0] \n"
"vext.s16 q10, q8, q9, #1 \n"
"vmlal.s16 q12, d20, %e[ker0][0] \n"
"vmlal.s16 q13, d21, %e[ker0][0] \n"
"vext.s16 q10, q8, q9, #2 \n"
"vmlal.s16 q12, d20, %e[ker0][1] \n"
"vmlal.s16 q13, d21, %e[ker0][1] \n"
"vext.s16 q10, q8, q9, #3 \n"
"vmlal.s16 q12, d20, %e[ker0][2] \n"
"vmlal.s16 q13, d21, %e[ker0][2] \n"
"vext.s16 q10, q8, q9, #4 \n"
"vmlal.s16 q12, d20, %e[ker0][3] \n"
"vmlal.s16 q13, d21, %e[ker0][3] \n"
"vmovl.s8 q8, d12 \n"
"vmovl.s8 q9, d13 \n"
"vmlal.s16 q12, d16, %f[ker2][1] \n"
"vmlal.s16 q13, d17, %f[ker2][1] \n"
"vext.s16 q10, q8, q9, #1 \n"
"vmlal.s16 q12, d20, %f[ker0][0] \n"
"vmlal.s16 q13, d21, %f[ker0][0] \n"
"vext.s16 q10, q8, q9, #2 \n"
"vmlal.s16 q12, d20, %f[ker0][1] \n"
"vmlal.s16 q13, d21, %f[ker0][1] \n"
"vext.s16 q10, q8, q9, #3 \n"
"vmlal.s16 q12, d20, %f[ker0][2] \n"
"vmlal.s16 q13, d21, %f[ker0][2] \n"
"vext.s16 q10, q8, q9, #4 \n"
"vmlal.s16 q12, d20, %f[ker0][3] \n"
"vmlal.s16 q13, d21, %f[ker0][3] \n"
"vmovl.s8 q8, d14 \n"
"vmovl.s8 q9, d15 \n"
"vmlal.s16 q12, d16, %f[ker2][2] \n"
"vmlal.s16 q13, d17, %f[ker2][2] \n"
"vext.s16 q10, q8, q9, #1 \n"
"vmlal.s16 q12, d20, %e[ker1][0] \n"
"vmlal.s16 q13, d21, %e[ker1][0] \n"
"vext.s16 q10, q8, q9, #2 \n"
"vmlal.s16 q12, d20, %e[ker1][1] \n"
"vmlal.s16 q13, d21, %e[ker1][1] \n"
"vext.s16 q10, q8, q9, #3 \n"
"vmlal.s16 q12, d20, %e[ker1][2] \n"
"vmlal.s16 q13, d21, %e[ker1][2] \n"
"vext.s16 q10, q8, q9, #4 \n"
"vmlal.s16 q12, d20, %e[ker1][3] \n"
"vmlal.s16 q13, d21, %e[ker1][3] \n"
"vld1.s8 {d10-d11}, [%[input_ptr3]], r0 \n"
"vld1.s8 {d12-d13}, [%[input_ptr4]], r0 \n"
"vmovl.s8 q8, d10 \n"
"vmovl.s8 q9, d11 \n"
"vmlal.s16 q12, d16, %f[ker2][3] \n"
"vmlal.s16 q13, d17, %f[ker2][3] \n"
"vext.s16 q10, q8, q9, #1 \n"
"vmlal.s16 q12, d20, %f[ker1][0] \n"
"vmlal.s16 q13, d21, %f[ker1][0] \n"
"vext.s16 q10, q8, q9, #2 \n"
"vmlal.s16 q12, d20, %f[ker1][1] \n"
"vmlal.s16 q13, d21, %f[ker1][1] \n"
"vext.s16 q10, q8, q9, #3 \n"
"vmlal.s16 q12, d20, %f[ker1][2] \n"
"vmlal.s16 q13, d21, %f[ker1][2] \n"
"vext.s16 q10, q8, q9, #4 \n"
"vmlal.s16 q12, d20, %f[ker1][3] \n"
"vmlal.s16 q13, d21, %f[ker1][3] \n"
"vmovl.s8 q8, d12 \n"
"vmovl.s8 q9, d13 \n"
"vmlal.s16 q12, d16, %e[ker3][0] \n"
"vmlal.s16 q13, d17, %e[ker3][0] \n"
"vext.s16 q10, q8, q9, #1 \n"
"vmlal.s16 q12, d20, %e[ker2][0] \n"
"vmlal.s16 q13, d21, %e[ker2][0] \n"
"vext.s16 q10, q8, q9, #2 \n"
"vmlal.s16 q12, d20, %e[ker2][1] \n"
"vmlal.s16 q13, d21, %e[ker2][1] \n"
"vext.s16 q10, q8, q9, #3 \n"
"vmlal.s16 q12, d20, %e[ker2][2] \n"
"vmlal.s16 q13, d21, %e[ker2][2] \n"
"vext.s16 q10, q8, q9, #4 \n"
"vmlal.s16 q12, d20, %e[ker2][3] \n"
"vmlal.s16 q13, d21, %e[ker2][3] \n"
// restore output
"vst1.32 {q12-q13}, [%[output_ptr0]]! \n"
"subs %[loop], #1 \n"
"bne loop_1h8w_%= \n"
"start_remain4_%=: \n"
"cmp %[remain], #4 \n"
"blt start_remain_%= \n"
"mov r0, #4 \n"
"vld1.s8 {d10}, [%[input_ptr0]], r0 \n"
"vld1.s8 {d12}, [%[input_ptr1]], r0 \n"
"vld1.s8 {d14}, [%[input_ptr2]], r0 \n"
"vmovl.s8 q8, d10 \n"
"vmull.s16 q12, d16, %f[ker2][0] \n"
"vext.s16 q10, q8, q9, #1 \n"
"vmlal.s16 q12, d20, %e[ker0][0] \n"
"vext.s16 q10, q8, q9, #2 \n"
"vmlal.s16 q12, d20, %e[ker0][1] \n"
"vext.s16 q10, q8, q9, #3 \n"
"vmlal.s16 q12, d20, %e[ker0][2] \n"
"vext.s16 q10, q8, q9, #4 \n"
"vmlal.s16 q12, d20, %e[ker0][3] \n"
"vmovl.s8 q8, d12 \n"
"vmlal.s16 q12, d16, %f[ker2][1] \n"
"vext.s16 q10, q8, q9, #1 \n"
"vmlal.s16 q12, d20, %f[ker0][0] \n"
"vext.s16 q10, q8, q9, #2 \n"
"vmlal.s16 q12, d20, %f[ker0][1] \n"
"vext.s16 q10, q8, q9, #3 \n"
"vmlal.s16 q12, d20, %f[ker0][2] \n"
"vext.s16 q10, q8, q9, #4 \n"
"vmlal.s16 q12, d20, %f[ker0][3] \n"
"vmovl.s8 q8, d14 \n"
"vmlal.s16 q12, d16, %f[ker2][2] \n"
"vext.s16 q10, q8, q9, #1 \n"
"vmlal.s16 q12, d20, %e[ker1][0] \n"
"vext.s16 q10, q8, q9, #2 \n"
"vmlal.s16 q12, d20, %e[ker1][1] \n"
"vext.s16 q10, q8, q9, #3 \n"
"vmlal.s16 q12, d20, %e[ker1][2] \n"
"vext.s16 q10, q8, q9, #4 \n"
"vmlal.s16 q12, d20, %e[ker1][3] \n"
"vld1.s8 {d10}, [%[input_ptr3]], r0 \n"
"vld1.s8 {d12}, [%[input_ptr4]], r0 \n"
"vmovl.s8 q8, d10 \n"
"vmlal.s16 q12, d16, %f[ker2][3] \n"
"vext.s16 q10, q8, q9, #1 \n"
"vmlal.s16 q12, d20, %f[ker1][0] \n"
"vext.s16 q10, q8, q9, #2 \n"
"vmlal.s16 q12, d20, %f[ker1][1] \n"
"vext.s16 q10, q8, q9, #3 \n"
"vmlal.s16 q12, d20, %f[ker1][2] \n"
"vext.s16 q10, q8, q9, #4 \n"
"vmlal.s16 q12, d20, %f[ker1][3] \n"
"vmovl.s8 q8, d12 \n"
"vmlal.s16 q12, d16, %e[ker3][0] \n"
"vext.s16 q10, q8, q9, #1 \n"
"vmlal.s16 q12, d20, %e[ker2][0] \n"
"vext.s16 q10, q8, q9, #2 \n"
"vmlal.s16 q12, d20, %e[ker2][1] \n"
"vext.s16 q10, q8, q9, #3 \n"
"vmlal.s16 q12, d20, %e[ker2][2] \n"
"vext.s16 q10, q8, q9, #4 \n"
"vmlal.s16 q12, d20, %e[ker2][3] \n"
// restore output
"vst1.32 {d24-d25}, [%[output_ptr0]]! \n"
"sub %[remain], #4 \n"
"start_remain_%=: \n"
"cmp %[remain], #0 \n"
"ble end_%= \n"
"mov r0, %[remain] \n"
"vld1.s8 {d10}, [%[input_ptr0]], r0 \n"
"vld1.s8 {d12}, [%[input_ptr1]], r0 \n"
"vld1.s8 {d14}, [%[input_ptr2]], r0 \n"
"vmovl.s8 q8, d10 \n"
"vmull.s16 q12, d16, %f[ker2][0] \n"
"vext.s16 q10, q8, q9, #1 \n"
"vmlal.s16 q12, d20, %e[ker0][0] \n"
"vext.s16 q10, q8, q9, #2 \n"
"vmlal.s16 q12, d20, %e[ker0][1] \n"
"vext.s16 q10, q8, q9, #3 \n"
"vmlal.s16 q12, d20, %e[ker0][2] \n"
"vext.s16 q10, q8, q9, #4 \n"
"vmlal.s16 q12, d20, %e[ker0][3] \n"
"vmovl.s8 q8, d12 \n"
"vmlal.s16 q12, d16, %f[ker2][1] \n"
"vext.s16 q10, q8, q9, #1 \n"
"vmlal.s16 q12, d20, %f[ker0][0] \n"
"vext.s16 q10, q8, q9, #2 \n"
"vmlal.s16 q12, d20, %f[ker0][1] \n"
"vext.s16 q10, q8, q9, #3 \n"
"vmlal.s16 q12, d20, %f[ker0][2] \n"
"vext.s16 q10, q8, q9, #4 \n"
"vmlal.s16 q12, d20, %f[ker0][3] \n"
"vmovl.s8 q8, d14 \n"
"vmlal.s16 q12, d16, %f[ker2][2] \n"
"vext.s16 q10, q8, q9, #1 \n"
"vmlal.s16 q12, d20, %e[ker1][0] \n"
"vext.s16 q10, q8, q9, #2 \n"
"vmlal.s16 q12, d20, %e[ker1][1] \n"
"vext.s16 q10, q8, q9, #3 \n"
"vmlal.s16 q12, d20, %e[ker1][2] \n"
"vext.s16 q10, q8, q9, #4 \n"
"vmlal.s16 q12, d20, %e[ker1][3] \n"
"vld1.s8 {d10}, [%[input_ptr3]], r0 \n"
"vld1.s8 {d12}, [%[input_ptr4]], r0 \n"
"vmovl.s8 q8, d10 \n"
"vmlal.s16 q12, d16, %f[ker2][3] \n"
"vext.s16 q10, q8, q9, #1 \n"
"vmlal.s16 q12, d20, %f[ker1][0] \n"
"vext.s16 q10, q8, q9, #2 \n"
"vmlal.s16 q12, d20, %f[ker1][1] \n"
"vext.s16 q10, q8, q9, #3 \n"
"vmlal.s16 q12, d20, %f[ker1][2] \n"
"vext.s16 q10, q8, q9, #4 \n"
"vmlal.s16 q12, d20, %f[ker1][3] \n"
"vmovl.s8 q8, d12 \n"
"vmlal.s16 q12, d16, %e[ker3][0] \n"
"vext.s16 q10, q8, q9, #1 \n"
"vmlal.s16 q12, d20, %e[ker2][0] \n"
"vext.s16 q10, q8, q9, #2 \n"
"vmlal.s16 q12, d20, %e[ker2][1] \n"
"vext.s16 q10, q8, q9, #3 \n"
"vmlal.s16 q12, d20, %e[ker2][2] \n"
"vext.s16 q10, q8, q9, #4 \n"
"vmlal.s16 q12, d20, %e[ker2][3] \n"
"cmp %[remain], #2 \n"
"blt store_1h1w_%= \n"
"vst1.32 {d24}, [%[output_ptr0]]! \n"
"cmp %[remain], #3 \n"
"blt end_%= \n"
"vst1.32 {d25[0]}, [%[output_ptr0]]! \n"
"b end_%= \n"
"store_1h1w_%=: \n"
"vst1.32 {d24[0]}, [%[output_ptr0]]! \n"
"end_%=: \n"
: [input_ptr0] "+r"(input_ptr0), [input_ptr1] "+r"(input_ptr1),
[input_ptr2] "+r"(input_ptr2), [input_ptr3] "+r"(input_ptr3),
[input_ptr4] "+r"(input_ptr4), [output_ptr0] "+r"(output_ptr0),
[loop] "+r"(loop), [remain] "+r"(w_remain)
: [ker0] "w"(_ker0), [ker1] "w"(_ker1), [ker2] "w"(_ker2),
[ker3] "w"(_ker3)
: "cc", "memory", "q4", "q5", "q6", "q7", "q8", "q9", "q10", "q11",
"q12", "q13", "q14", "q15", "r0");
// pad right
if (padding_w) {
int16x4_t row0 = vget_low_s16(vmovl_s8(vld1_s8(input_ptr0)));
int16x4_t row1 = vget_low_s16(vmovl_s8(vld1_s8(input_ptr1)));
int16x4_t row2 = vget_low_s16(vmovl_s8(vld1_s8(input_ptr2)));
int16x4_t row3 = vget_low_s16(vmovl_s8(vld1_s8(input_ptr3)));
int16x4_t row4 = vget_low_s16(vmovl_s8(vld1_s8(input_ptr4)));
int16x4_t zero = vdup_n_s16(0);
int32x4_t acc;
for (int w = valid_w_end; w < output_w; ++w) {
int padding = w + 5 - (padding_w + input_w);
if (padding >= 5) {
*output_ptr0 = 0;
} else {
int iw = w - valid_w_end;
int32_t sum0 = input_ptr0[iw] * filter_ptr0[0] +
input_ptr1[iw] * filter_ptr1[0] +
input_ptr2[iw] * filter_ptr2[0] +
input_ptr3[iw] * filter_ptr3[0] +
input_ptr4[iw] * filter_ptr4[0];
row0 = vext_s16(row0, zero, 1);
row1 = vext_s16(row1, zero, 1);
row2 = vext_s16(row2, zero, 1);
row3 = vext_s16(row3, zero, 1);
row4 = vext_s16(row4, zero, 1);
acc = vmull_s16(row0, _ker[0]);
acc = vmlal_s16(acc, row1, _ker[1]);
acc = vmlal_s16(acc, row2, _ker[2]);
acc = vmlal_s16(acc, row3, _ker[3]);
acc = vmlal_s16(acc, row4, _ker[4]);
int32x2_t sum = vpadd_s32(vget_low_s32(acc), vget_high_s32(acc));
sum = vpadd_s32(sum, sum);
sum0 += vget_lane_s32(sum, 0);
*output_ptr0 = sum0;
}
output_ptr0++;
}
}
}
// pad bottom
for (int h = valid_h_end; h < output_h; ++h) {
DepthwiseConv5x5NormalRow<1, 1>(input_ptr, filter_ptr, h, input_h,
input_w, padding_h, padding_w, output_w,
output_ptr, _ker, kernel);
}
}
}
template <>
void DepthwiseConv5x5S2<int8_t, int32_t>(const framework::Tensor &input,
const framework::Tensor &filter,
const std::vector<int> &paddings,
framework::Tensor *output) {}
} // namespace math
} // namespace operators
} // namespace paddle_mobile
#endif // __ARM_NEON__
......@@ -3150,9 +3150,11 @@ void Gemm::SgemmWithPRelu(int m, int n, int k, const float *A, int lda,
void Gemm::Sgemm_omp(int m, int n, int k, float alpha, const float *A, int lda,
const float *B, int ldb, float beta, float *C, int ldc,
bool relu, float *bias) {
#ifndef __aarch64__
if (m == 1 && bias == nullptr) {
return VectorKernel(m, n, k, alpha, A, lda, B, ldb, beta, C, ldc, relu);
}
#endif // __aarch64__
#ifdef _OPENMP
int max_threads = omp_get_max_threads();
#else
......
......@@ -53,7 +53,7 @@ struct PoolingVal<AVG> {
++count;
return *this;
}
inline float Value() { return (count > 0) ? val / count : 0.f; }
inline float Value() { return (count > 0) ? val * (1.f / count) : 0.f; }
};
#if defined(__ARM_NEON) || defined(__ARM_NEON__)
......@@ -67,6 +67,16 @@ inline float32x4_t vPoolInitq_f32<AVG>() {
return vdupq_n_f32(0.f);
}
template <PoolingType P = MAX>
inline float32x2_t vPoolInit_f32() {
return vdup_n_f32(-std::numeric_limits<float>::max());
}
template <>
inline float32x2_t vPoolInit_f32<AVG>() {
return vdup_n_f32(0.f);
}
template <PoolingType P = MAX>
inline float32x4_t vPoolPreq_f32(const float32x4_t &x1, const float32x4_t &x2) {
return vmaxq_f32(x1, x2);
......@@ -78,6 +88,28 @@ inline float32x4_t vPoolPreq_f32<AVG>(const float32x4_t &x1,
return vaddq_f32(x1, x2);
}
template <PoolingType P = MAX>
inline float32x2_t vPoolPre_f32(const float32x2_t &x1, const float32x2_t &x2) {
return vmax_f32(x1, x2);
}
template <>
inline float32x2_t vPoolPre_f32<AVG>(const float32x2_t &x1,
const float32x2_t &x2) {
return vadd_f32(x1, x2);
}
template <PoolingType P = MAX>
inline float32x2_t vpPoolPre_f32(const float32x2_t &x1, const float32x2_t &x2) {
return vpmax_f32(x1, x2);
}
template <>
inline float32x2_t vpPoolPre_f32<AVG>(const float32x2_t &x1,
const float32x2_t &x2) {
return vpadd_f32(x1, x2);
}
template <PoolingType P = MAX>
inline float32x4_t vPoolPostq_f32(const float32x4_t &x,
const float32x4_t &post) {
......@@ -89,6 +121,18 @@ inline float32x4_t vPoolPostq_f32<AVG>(const float32x4_t &x,
const float32x4_t &post) {
return vmulq_f32(x, post);
}
template <PoolingType P = MAX>
inline float32x2_t vPoolPost_f32(const float32x2_t &x,
const float32x2_t &post) {
return x;
}
template <>
inline float32x2_t vPoolPost_f32<AVG>(const float32x2_t &x,
const float32x2_t &post) {
return vmul_f32(x, post);
}
#endif // __ARM_NEON__
template <PoolingType P = MAX>
......
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#ifdef POOL_OP
#if defined(__ARM_NEON) || defined(__ARM_NEON__)
#include <arm_neon.h>
#include "operators/math/pooling.h"
// TODO(hjchen2): Optimize Pooling2x2NormalRow and use inline assembly
namespace paddle_mobile {
namespace operators {
namespace math {
#define POOLING2X2_NORMAL_BORDER(start, end) \
for (int w = start; w < end; ++w) { \
const int w_in_start = -padding_w + w * Stride; \
const int w_in_end = w_in_start + 2; \
const int w_start = w_in_start > 0 ? w_in_start : 0; \
const int w_end = w_in_end < input_w ? w_in_end : input_w; \
PoolingVal<P> val; \
for (int h_in = h_start; h_in < h_end; ++h_in) { \
for (int w_in = w_start; w_in < w_end; ++w_in) { \
val += input[h_in * input_w + w_in]; \
} \
} \
output_ptr[w] = val.Value(); \
}
template <PoolingType P, int Stride = 1>
struct Pooling2x2NormalRowLoadInput {
void operator()(const float *input, float32x4_t *x0, float32x4_t *x1) {
x0[0] = vld1q_f32(input);
x0[1] = vld1q_f32(input + 4);
x1[0] = vextq_f32(x0[0], x0[1], 1);
x1[1] = vextq_f32(x0[1], x0[1], 1);
}
};
template <PoolingType P>
struct Pooling2x2NormalRowLoadInput<P, 2> {
void operator()(const float *input, float32x4_t *x0, float32x4_t *x1) {
float32x4x2_t t0 = vld2q_f32(input);
float32x4x2_t t1 = vld2q_f32(input + 8);
x0[0] = t0.val[0];
x0[1] = t1.val[0];
x1[0] = t0.val[1];
x1[1] = t1.val[1];
}
};
template <PoolingType P, int Stride>
inline void Pooling2x2NormalRow(const float *input, const int h_output,
const int input_h, const int input_w,
const int padding_h, const int padding_w,
const int output_w, float *output) {
const int h_in_start = -padding_h + h_output * Stride;
const int h_in_end = h_in_start + 2;
const int h_start = h_in_start > 0 ? h_in_start : 0;
const int h_end = h_in_end < input_h ? h_in_end : input_h;
float *output_ptr = output + h_output * output_w;
if (h_end - h_start <= 0) {
memset(output_ptr, 0, output_w * sizeof(float));
return;
}
const int valid_w_start = (padding_w + Stride - 1) / Stride;
const int valid_w_end = (input_w + padding_w - 2) / Stride + 1;
const int valid_w = valid_w_end - valid_w_start;
// border left
POOLING2X2_NORMAL_BORDER(0, valid_w_start)
// valid w
Pooling2x2NormalRowLoadInput<P, Stride> load_input;
int output_tiles = valid_w / 6;
int output_tiles_w = output_tiles * 6;
float32x4_t x0[2], x1[2], y0[2];
float32x4_t post = vdupq_n_f32(1.f / (2 * (h_end - h_start)));
for (int w = 0; w < output_tiles_w; w += 6) {
int output_offset = valid_w_start + w;
int input_w_offset = output_offset * Stride - padding_w;
y0[0] = vPoolInitq_f32<P>();
y0[1] = vPoolInitq_f32<P>();
for (int h_in = h_start; h_in < h_end; ++h_in) {
load_input(input + h_in * input_w + input_w_offset, x0, x1);
y0[0] = vPoolPreq_f32<P>(y0[0], x0[0]);
y0[0] = vPoolPreq_f32<P>(y0[0], x1[0]);
y0[1] = vPoolPreq_f32<P>(y0[1], x0[1]);
y0[1] = vPoolPreq_f32<P>(y0[1], x1[1]);
}
y0[0] = vPoolPostq_f32<P>(y0[0], post);
y0[1] = vPoolPostq_f32<P>(y0[1], post);
vst1q_f32(output_ptr + output_offset, y0[0]);
vst1_f32(output_ptr + output_offset + 4, vget_low_f32(y0[1]));
}
// remain valid w
int remain = valid_w - output_tiles_w;
if (remain > 0) {
int remain_start = valid_w_start + output_tiles_w;
int input_w_offset = remain_start * Stride - padding_w;
float *output_ptr0 = output_ptr + remain_start;
y0[0] = vPoolInitq_f32<P>();
y0[1] = vPoolInitq_f32<P>();
for (int h_in = h_start; h_in < h_end; ++h_in) {
load_input(input + h_in * input_w + input_w_offset, x0, x1);
y0[0] = vPoolPreq_f32<P>(y0[0], x0[0]);
y0[0] = vPoolPreq_f32<P>(y0[0], x1[0]);
y0[1] = vPoolPreq_f32<P>(y0[1], x0[1]);
y0[1] = vPoolPreq_f32<P>(y0[1], x1[1]);
}
y0[0] = vPoolPostq_f32<P>(y0[0], post);
y0[1] = vPoolPostq_f32<P>(y0[1], post);
switch (remain) {
case 1:
vst1q_lane_f32(output_ptr0, y0[0], 0);
break;
case 2:
vst1_f32(output_ptr0, vget_low_f32(y0[0]));
break;
case 3:
vst1_f32(output_ptr0, vget_low_f32(y0[0]));
vst1q_lane_f32(output_ptr0 + 2, y0[0], 2);
break;
case 4:
vst1q_f32(output_ptr0, y0[0]);
break;
case 5:
vst1q_f32(output_ptr0, y0[0]);
vst1q_lane_f32(output_ptr0 + 4, y0[1], 0);
break;
}
}
// border right
POOLING2X2_NORMAL_BORDER(valid_w_end, output_w)
}
template <PoolingType P>
struct Pooling2x2<P, 1> {
inline void operator()(const framework::Tensor &input,
const std::vector<int> &paddings,
framework::Tensor *output) {
const float *input_data = input.data<float>();
float *output_data = output->mutable_data<float>();
int input_h = input.dims()[2];
int input_w = input.dims()[3];
int output_h = output->dims()[2];
int output_w = output->dims()[3];
int padding_h = paddings[0];
int padding_w = paddings[1];
int image_size = input_h * input_w;
int out_image_size = output_h * output_w;
int valid_h_start = padding_h;
int valid_h_end = output_h - valid_h_start;
int valid_h = valid_h_end - valid_h_start;
int valid_w_start = padding_w;
int valid_w_end = output_w - valid_w_start;
int valid_w = valid_w_end - valid_w_start;
#pragma omp parallel for collapse(2)
for (int batch = 0; batch < output->dims()[0]; ++batch) {
for (int c = 0; c < output->dims()[1]; ++c) {
int channel = batch * output->dims()[1] + c;
const float *input_ptr = input_data + channel * image_size;
float *output_ptr = output_data + channel * out_image_size;
// top
for (int h = 0; h < valid_h_start; ++h) {
Pooling2x2NormalRow<P, 1>(input_ptr, h, input_h, input_w, padding_h,
padding_w, output_w, output_ptr);
}
// valid
int output_w_tiles = valid_w / 6;
int output_w_remain = valid_w - output_w_tiles * 6;
for (int h = valid_h_start; h < valid_h_end - 3; h += 4) {
const float *input_ptr0 = input_ptr + (h - padding_h) * input_w;
const float *input_ptr1 = input_ptr0 + input_w;
const float *input_ptr2 = input_ptr1 + input_w;
const float *input_ptr3 = input_ptr2 + input_w;
const float *input_ptr4 = input_ptr3 + input_w;
float *output_ptr0 = output_ptr + h * output_w;
float *output_ptr1 = output_ptr0 + output_w;
float *output_ptr2 = output_ptr1 + output_w;
float *output_ptr3 = output_ptr2 + output_w;
// pad left
if (padding_w) {
for (int w = valid_w_start - 1; w >= 0; --w) {
int padding = padding_w - w;
if (padding >= 2) {
output_ptr0[w] = 0.f;
output_ptr1[w] = 0.f;
output_ptr2[w] = 0.f;
output_ptr3[w] = 0.f;
} else {
float acc0 = PoolPre<P>(*input_ptr0, *input_ptr1);
float acc1 = PoolPre<P>(*input_ptr1, *input_ptr2);
float acc2 = PoolPre<P>(*input_ptr2, *input_ptr3);
float acc3 = PoolPre<P>(*input_ptr3, *input_ptr4);
output_ptr0[w] = PoolPost<P>(acc0, 0.5f);
output_ptr1[w] = PoolPost<P>(acc1, 0.5f);
output_ptr2[w] = PoolPost<P>(acc2, 0.5f);
output_ptr3[w] = PoolPost<P>(acc3, 0.5f);
}
}
output_ptr0 += valid_w_start;
output_ptr1 += valid_w_start;
output_ptr2 += valid_w_start;
output_ptr3 += valid_w_start;
}
// valid
float32x4x2_t x0, x1, q0;
float32x4x2_t y0, y1;
float32x4_t post = vdupq_n_f32(0.25f);
for (int loop = 0; loop < output_w_tiles; ++loop) {
x0.val[0] = vld1q_f32(input_ptr0);
x0.val[1] = vld1q_f32(input_ptr0 + 4);
x1.val[0] = vld1q_f32(input_ptr1);
x1.val[1] = vld1q_f32(input_ptr1 + 4);
q0.val[0] = vextq_f32(x0.val[0], x0.val[1], 1);
q0.val[1] = vextq_f32(x0.val[1], x0.val[1], 1);
y0.val[0] = vPoolPreq_f32<P>(x0.val[0], q0.val[0]);
y0.val[1] = vPoolPreq_f32<P>(x0.val[1], q0.val[1]);
q0.val[0] = vextq_f32(x1.val[0], x1.val[1], 1);
q0.val[1] = vextq_f32(x1.val[1], x1.val[1], 1);
y1.val[0] = vPoolPreq_f32<P>(x1.val[0], q0.val[0]);
y1.val[1] = vPoolPreq_f32<P>(x1.val[1], q0.val[1]);
y0.val[0] = vPoolPreq_f32<P>(y0.val[0], y1.val[0]);
y0.val[1] = vPoolPreq_f32<P>(y0.val[1], y1.val[1]);
y0.val[0] = vPoolPostq_f32<P>(y0.val[0], post);
y0.val[1] = vPoolPostq_f32<P>(y0.val[1], post);
vst1q_f32(output_ptr0, y0.val[0]);
vst1_f32(output_ptr0 + 4, vget_low_f32(y0.val[1]));
x0.val[0] = vld1q_f32(input_ptr2);
x0.val[1] = vld1q_f32(input_ptr2 + 4);
x1.val[0] = vld1q_f32(input_ptr3);
x1.val[1] = vld1q_f32(input_ptr3 + 4);
q0.val[0] = vextq_f32(x0.val[0], x0.val[1], 1);
q0.val[1] = vextq_f32(x0.val[1], x0.val[1], 1);
y0.val[0] = vPoolPreq_f32<P>(x0.val[0], q0.val[0]);
y0.val[1] = vPoolPreq_f32<P>(x0.val[1], q0.val[1]);
y1.val[0] = vPoolPreq_f32<P>(y1.val[0], y0.val[0]);
y1.val[1] = vPoolPreq_f32<P>(y1.val[1], y0.val[1]);
y1.val[0] = vPoolPostq_f32<P>(y1.val[0], post);
y1.val[1] = vPoolPostq_f32<P>(y1.val[1], post);
vst1q_f32(output_ptr1, y1.val[0]);
vst1_f32(output_ptr1 + 4, vget_low_f32(y1.val[1]));
q0.val[0] = vextq_f32(x1.val[0], x1.val[1], 1);
q0.val[1] = vextq_f32(x1.val[1], x1.val[1], 1);
y1.val[0] = vPoolPreq_f32<P>(x1.val[0], q0.val[0]);
y1.val[1] = vPoolPreq_f32<P>(x1.val[1], q0.val[1]);
y0.val[0] = vPoolPreq_f32<P>(y0.val[0], y1.val[0]);
y0.val[1] = vPoolPreq_f32<P>(y0.val[1], y1.val[1]);
y0.val[0] = vPoolPostq_f32<P>(y0.val[0], post);
y0.val[1] = vPoolPostq_f32<P>(y0.val[1], post);
vst1q_f32(output_ptr2, y0.val[0]);
vst1_f32(output_ptr2 + 4, vget_low_f32(y0.val[1]));
x0.val[0] = vld1q_f32(input_ptr4);
x0.val[1] = vld1q_f32(input_ptr4 + 4);
q0.val[0] = vextq_f32(x0.val[0], x0.val[1], 1);
q0.val[1] = vextq_f32(x0.val[1], x0.val[1], 1);
y1.val[0] = vPoolPreq_f32<P>(y1.val[0], x0.val[0]);
y1.val[0] = vPoolPreq_f32<P>(y1.val[0], q0.val[0]);
y1.val[1] = vPoolPreq_f32<P>(y1.val[1], x0.val[1]);
y1.val[1] = vPoolPreq_f32<P>(y1.val[1], q0.val[1]);
y1.val[0] = vPoolPostq_f32<P>(y1.val[0], post);
y1.val[1] = vPoolPostq_f32<P>(y1.val[1], post);
vst1q_f32(output_ptr3, y1.val[0]);
vst1_f32(output_ptr3 + 4, vget_low_f32(y1.val[1]));
input_ptr0 += 6;
input_ptr1 += 6;
input_ptr2 += 6;
input_ptr3 += 6;
input_ptr4 += 6;
output_ptr0 += 6;
output_ptr1 += 6;
output_ptr2 += 6;
output_ptr3 += 6;
}
// remain width
if (output_w_remain > 0) {
float32x4x2_t y2, y3;
x0.val[0] = vld1q_f32(input_ptr0);
x0.val[1] = vld1q_f32(input_ptr0 + 4);
x1.val[0] = vld1q_f32(input_ptr1);
x1.val[1] = vld1q_f32(input_ptr1 + 4);
q0.val[0] = vextq_f32(x0.val[0], x0.val[1], 1);
q0.val[1] = vextq_f32(x0.val[1], x0.val[1], 1);
y0.val[0] = vPoolPreq_f32<P>(x0.val[0], q0.val[0]);
y0.val[1] = vPoolPreq_f32<P>(x0.val[1], q0.val[1]);
q0.val[0] = vextq_f32(x1.val[0], x1.val[1], 1);
q0.val[1] = vextq_f32(x1.val[1], x1.val[1], 1);
y1.val[0] = vPoolPreq_f32<P>(x1.val[0], q0.val[0]);
y1.val[1] = vPoolPreq_f32<P>(x1.val[1], q0.val[1]);
y0.val[0] = vPoolPreq_f32<P>(y0.val[0], y1.val[0]);
y0.val[1] = vPoolPreq_f32<P>(y0.val[1], y1.val[1]);
y0.val[0] = vPoolPostq_f32<P>(y0.val[0], post);
y0.val[1] = vPoolPostq_f32<P>(y0.val[1], post);
x0.val[0] = vld1q_f32(input_ptr2);
x0.val[1] = vld1q_f32(input_ptr2 + 4);
x1.val[0] = vld1q_f32(input_ptr3);
x1.val[1] = vld1q_f32(input_ptr3 + 4);
q0.val[0] = vextq_f32(x0.val[0], x0.val[1], 1);
q0.val[1] = vextq_f32(x0.val[1], x0.val[1], 1);
y2.val[0] = vPoolPreq_f32<P>(x0.val[0], q0.val[0]);
y2.val[1] = vPoolPreq_f32<P>(x0.val[1], q0.val[1]);
y1.val[0] = vPoolPreq_f32<P>(y1.val[0], y2.val[0]);
y1.val[1] = vPoolPreq_f32<P>(y1.val[1], y2.val[1]);
y1.val[0] = vPoolPostq_f32<P>(y1.val[0], post);
y1.val[1] = vPoolPostq_f32<P>(y1.val[1], post);
q0.val[0] = vextq_f32(x1.val[0], x1.val[1], 1);
q0.val[1] = vextq_f32(x1.val[1], x1.val[1], 1);
y3.val[0] = vPoolPreq_f32<P>(x1.val[0], q0.val[0]);
y3.val[1] = vPoolPreq_f32<P>(x1.val[1], q0.val[1]);
y2.val[0] = vPoolPreq_f32<P>(y2.val[0], y3.val[0]);
y2.val[1] = vPoolPreq_f32<P>(y2.val[1], y3.val[1]);
y2.val[0] = vPoolPostq_f32<P>(y2.val[0], post);
y2.val[1] = vPoolPostq_f32<P>(y2.val[1], post);
x0.val[0] = vld1q_f32(input_ptr4);
x0.val[1] = vld1q_f32(input_ptr4 + 4);
q0.val[0] = vextq_f32(x0.val[0], x0.val[1], 1);
q0.val[1] = vextq_f32(x0.val[1], x0.val[1], 1);
y3.val[0] = vPoolPreq_f32<P>(y3.val[0], x0.val[0]);
y3.val[0] = vPoolPreq_f32<P>(y3.val[0], q0.val[0]);
y3.val[1] = vPoolPreq_f32<P>(y3.val[1], x0.val[1]);
y3.val[1] = vPoolPreq_f32<P>(y3.val[1], q0.val[1]);
y3.val[0] = vPoolPostq_f32<P>(y3.val[0], post);
y3.val[1] = vPoolPostq_f32<P>(y3.val[1], post);
switch (output_w_remain) {
case 1:
vst1q_lane_f32(output_ptr0, y0.val[0], 0);
vst1q_lane_f32(output_ptr1, y1.val[0], 0);
vst1q_lane_f32(output_ptr2, y2.val[0], 0);
vst1q_lane_f32(output_ptr3, y3.val[0], 0);
break;
case 2:
vst1_f32(output_ptr0, vget_low_f32(y0.val[0]));
vst1_f32(output_ptr1, vget_low_f32(y1.val[0]));
vst1_f32(output_ptr2, vget_low_f32(y2.val[0]));
vst1_f32(output_ptr3, vget_low_f32(y3.val[0]));
break;
case 3:
vst1_f32(output_ptr0, vget_low_f32(y0.val[0]));
vst1_f32(output_ptr1, vget_low_f32(y1.val[0]));
vst1_f32(output_ptr2, vget_low_f32(y2.val[0]));
vst1_f32(output_ptr3, vget_low_f32(y3.val[0]));
vst1q_lane_f32(output_ptr0 + 2, y0.val[0], 2);
vst1q_lane_f32(output_ptr1 + 2, y1.val[0], 2);
vst1q_lane_f32(output_ptr2 + 2, y2.val[0], 2);
vst1q_lane_f32(output_ptr3 + 2, y3.val[0], 2);
break;
case 4:
vst1q_f32(output_ptr0, y0.val[0]);
vst1q_f32(output_ptr1, y1.val[0]);
vst1q_f32(output_ptr2, y2.val[0]);
vst1q_f32(output_ptr3, y3.val[0]);
break;
case 5:
vst1q_f32(output_ptr0, y0.val[0]);
vst1q_f32(output_ptr1, y1.val[0]);
vst1q_f32(output_ptr2, y2.val[0]);
vst1q_f32(output_ptr3, y3.val[0]);
vst1q_lane_f32(output_ptr0 + 4, y0.val[1], 0);
vst1q_lane_f32(output_ptr1 + 4, y1.val[1], 0);
vst1q_lane_f32(output_ptr2 + 4, y2.val[1], 0);
vst1q_lane_f32(output_ptr3 + 4, y3.val[1], 0);
break;
}
input_ptr0 += output_w_remain;
input_ptr1 += output_w_remain;
input_ptr2 += output_w_remain;
input_ptr3 += output_w_remain;
input_ptr4 += output_w_remain;
output_ptr0 += output_w_remain;
output_ptr1 += output_w_remain;
output_ptr2 += output_w_remain;
output_ptr3 += output_w_remain;
}
// pad right
if (padding_w) {
for (int w = valid_w_end; w < output_w; ++w) {
int padding = w + 2 - (padding_w + input_w);
if (padding >= 2) {
*output_ptr0 = 0.f;
*output_ptr1 = 0.f;
*output_ptr2 = 0.f;
*output_ptr3 = 0.f;
} else {
float acc0 = PoolPre<P>(*input_ptr0, *input_ptr1);
float acc1 = PoolPre<P>(*input_ptr1, *input_ptr2);
float acc2 = PoolPre<P>(*input_ptr2, *input_ptr3);
float acc3 = PoolPre<P>(*input_ptr3, *input_ptr4);
*output_ptr0 = PoolPost<P>(acc0, 0.5f);
*output_ptr1 = PoolPost<P>(acc1, 0.5f);
*output_ptr2 = PoolPost<P>(acc2, 0.5f);
*output_ptr3 = PoolPost<P>(acc3, 0.5f);
}
output_ptr0++;
output_ptr1++;
output_ptr2++;
output_ptr3++;
}
}
}
// remain height
int start_h = valid_h_start + (valid_h & 0xFFFC);
for (int h = start_h; h < valid_h_end; ++h) {
const float *input_ptr0 = input_ptr + (h - padding_h) * input_w;
const float *input_ptr1 = input_ptr0 + input_w;
float *output_ptr0 = output_ptr + h * output_w;
// pad left
if (padding_w) {
for (int w = valid_w_start - 1; w >= 0; --w) {
int padding = padding_w - w;
if (padding >= 2) {
output_ptr0[w] = 0.f;
} else {
float acc0 = PoolPre<P>(*input_ptr0, *input_ptr1);
output_ptr0[w] = PoolPost<P>(acc0, 0.5f);
}
}
output_ptr0 += valid_w_start;
}
// valid
float32x4x2_t x0, x1, q0, y0;
float32x4_t post = vdupq_n_f32(0.25f);
for (int loop = 0; loop < output_w_tiles; ++loop) {
x0.val[0] = vld1q_f32(input_ptr0);
x0.val[1] = vld1q_f32(input_ptr0 + 4);
x1.val[0] = vld1q_f32(input_ptr1);
x1.val[1] = vld1q_f32(input_ptr1 + 4);
q0.val[0] = vextq_f32(x0.val[0], x0.val[1], 1);
q0.val[1] = vextq_f32(x0.val[1], x0.val[1], 1);
y0.val[0] = vPoolPreq_f32<P>(x0.val[0], q0.val[0]);
y0.val[1] = vPoolPreq_f32<P>(x0.val[1], q0.val[1]);
q0.val[0] = vextq_f32(x1.val[0], x1.val[1], 1);
q0.val[1] = vextq_f32(x1.val[1], x1.val[1], 1);
y0.val[0] = vPoolPreq_f32<P>(y0.val[0], x1.val[0]);
y0.val[1] = vPoolPreq_f32<P>(y0.val[1], x1.val[1]);
y0.val[0] = vPoolPreq_f32<P>(y0.val[0], q0.val[0]);
y0.val[1] = vPoolPreq_f32<P>(y0.val[1], q0.val[1]);
y0.val[0] = vPoolPostq_f32<P>(y0.val[0], post);
y0.val[1] = vPoolPostq_f32<P>(y0.val[1], post);
vst1q_f32(output_ptr0, y0.val[0]);
vst1_f32(output_ptr0 + 4, vget_low_f32(y0.val[1]));
input_ptr0 += 6;
input_ptr1 += 6;
output_ptr0 += 6;
}
// remain width
if (output_w_remain > 0) {
x0.val[0] = vld1q_f32(input_ptr0);
x0.val[1] = vld1q_f32(input_ptr0 + 4);
x1.val[0] = vld1q_f32(input_ptr1);
x1.val[1] = vld1q_f32(input_ptr1 + 4);
q0.val[0] = vextq_f32(x0.val[0], x0.val[1], 1);
q0.val[1] = vextq_f32(x0.val[1], x0.val[1], 1);
y0.val[0] = vPoolPreq_f32<P>(x0.val[0], q0.val[0]);
y0.val[1] = vPoolPreq_f32<P>(x0.val[1], q0.val[1]);
q0.val[0] = vextq_f32(x1.val[0], x1.val[1], 1);
q0.val[1] = vextq_f32(x1.val[1], x1.val[1], 1);
y0.val[0] = vPoolPreq_f32<P>(y0.val[0], x1.val[0]);
y0.val[1] = vPoolPreq_f32<P>(y0.val[1], x1.val[1]);
y0.val[0] = vPoolPreq_f32<P>(y0.val[0], q0.val[0]);
y0.val[1] = vPoolPreq_f32<P>(y0.val[1], q0.val[1]);
y0.val[0] = vPoolPostq_f32<P>(y0.val[0], post);
y0.val[1] = vPoolPostq_f32<P>(y0.val[1], post);
switch (output_w_remain) {
case 1:
vst1q_lane_f32(output_ptr0, y0.val[0], 0);
break;
case 2:
vst1_f32(output_ptr0, vget_low_f32(y0.val[0]));
break;
case 3:
vst1_f32(output_ptr0, vget_low_f32(y0.val[0]));
vst1q_lane_f32(output_ptr0 + 2, y0.val[0], 2);
break;
case 4:
vst1q_f32(output_ptr0, y0.val[0]);
break;
case 5:
vst1q_f32(output_ptr0, y0.val[0]);
vst1q_lane_f32(output_ptr0 + 4, y0.val[1], 0);
break;
}
input_ptr0 += output_w_remain;
input_ptr1 += output_w_remain;
output_ptr0 += output_w_remain;
}
// pad right
if (padding_w) {
for (int w = valid_w_end; w < output_w; ++w) {
int padding = w + 2 - (padding_w + input_w);
if (padding >= 2) {
*output_ptr0 = 0.f;
} else {
float acc0 = PoolPre<P>(*input_ptr0, *input_ptr1);
*output_ptr0 = PoolPost<P>(acc0, 0.5f);
}
output_ptr0++;
}
}
}
// bottom
for (int h = valid_h_end; h < output_h; ++h) {
Pooling2x2NormalRow<P, 1>(input_ptr, h, input_h, input_w, padding_h,
padding_w, output_w, output_ptr);
}
}
}
}
};
template <PoolingType P>
struct Pooling2x2<P, 2> {
inline void operator()(const framework::Tensor &input,
const std::vector<int> &paddings,
framework::Tensor *output) {
const float *input_data = input.data<float>();
float *output_data = output->mutable_data<float>();
int input_h = input.dims()[2];
int input_w = input.dims()[3];
int output_h = output->dims()[2];
int output_w = output->dims()[3];
int padding_h = paddings[0];
int padding_w = paddings[1];
int image_size = input_h * input_w;
int out_image_size = output_h * output_w;
int valid_h_start = (padding_h + 1) / 2;
int valid_h_end = (input_h + padding_h) / 2;
int valid_h = valid_h_end - valid_h_start;
int valid_w_start = (padding_w + 1) / 2;
int valid_w_end = (input_w + padding_w) / 2;
int valid_w = valid_w_end - valid_w_start;
bool ceil_mode = (((input_h + 2 * padding_h) / 2) < output_h) ||
(((input_w + 2 * padding_w) / 2) < output_w);
int padding_b =
padding_h + (ceil_mode ? 2 * output_h - (input_h + 2 * padding_h) : 0);
int padding_r =
padding_w + (ceil_mode ? 2 * output_w - (input_w + 2 * padding_w) : 0);
#pragma omp parallel for collapse(2)
for (int batch = 0; batch < output->dims()[0]; ++batch) {
for (int c = 0; c < output->dims()[1]; ++c) {
int channel = batch * output->dims()[1] + c;
const float *input_ptr = input_data + channel * image_size;
float *output_ptr = output_data + channel * out_image_size;
// top
for (int h = 0; h < valid_h_start; ++h) {
Pooling2x2NormalRow<P, 2>(input_ptr, h, input_h, input_w, padding_h,
padding_w, output_w, output_ptr);
}
// valid
int output_w_tiles = valid_w / 4;
int output_w_remain = valid_w - output_w_tiles * 4;
for (int h = valid_h_start; h < valid_h_end - 1; h += 2) {
const float *input_ptr0 = input_ptr + (2 * h - padding_h) * input_w;
const float *input_ptr1 = input_ptr0 + input_w;
const float *input_ptr2 = input_ptr1 + input_w;
const float *input_ptr3 = input_ptr2 + input_w;
float *output_ptr0 = output_ptr + h * output_w;
float *output_ptr1 = output_ptr0 + output_w;
// pad left
if (padding_w) {
for (int w = valid_w_start - 1; w >= 0; --w) {
int padding = padding_w - w * 2;
if (padding >= 2) {
output_ptr0[w] = 0.f;
output_ptr1[w] = 0.f;
} else {
float acc0 = PoolPre<P>(*input_ptr0, *input_ptr1);
float acc1 = PoolPre<P>(*input_ptr2, *input_ptr3);
output_ptr0[w] = PoolPost<P>(acc0, 0.5f);
output_ptr1[w] = PoolPost<P>(acc1, 0.5f);
}
}
input_ptr0 += (padding_w & 0x1);
input_ptr1 += (padding_w & 0x1);
input_ptr2 += (padding_w & 0x1);
input_ptr3 += (padding_w & 0x1);
output_ptr0 += valid_w_start;
output_ptr1 += valid_w_start;
}
// valid
float32x4x2_t x0, x1, x2, x3;
float32x4_t y0, y1;
float32x4_t post = vdupq_n_f32(0.25f);
for (int loop = 0; loop < output_w_tiles; ++loop) {
x0 = vld2q_f32(input_ptr0);
x1 = vld2q_f32(input_ptr1);
x2 = vld2q_f32(input_ptr2);
x3 = vld2q_f32(input_ptr3);
y0 = vPoolPreq_f32<P>(x0.val[0], x0.val[1]);
y1 = vPoolPreq_f32<P>(x2.val[0], x2.val[1]);
y0 = vPoolPreq_f32<P>(y0, x1.val[0]);
y1 = vPoolPreq_f32<P>(y1, x3.val[0]);
y0 = vPoolPreq_f32<P>(y0, x1.val[1]);
y1 = vPoolPreq_f32<P>(y1, x3.val[1]);
y0 = vPoolPostq_f32<P>(y0, post);
y1 = vPoolPostq_f32<P>(y1, post);
vst1q_f32(output_ptr0, y0);
vst1q_f32(output_ptr1, y1);
input_ptr0 += 8;
input_ptr1 += 8;
input_ptr2 += 8;
input_ptr3 += 8;
output_ptr0 += 4;
output_ptr1 += 4;
}
// remain width
if (output_w_remain > 0) {
x0 = vld2q_f32(input_ptr0);
x1 = vld2q_f32(input_ptr1);
x2 = vld2q_f32(input_ptr2);
x3 = vld2q_f32(input_ptr3);
y0 = vPoolPreq_f32<P>(x0.val[0], x0.val[1]);
y1 = vPoolPreq_f32<P>(x2.val[0], x2.val[1]);
y0 = vPoolPreq_f32<P>(y0, x1.val[0]);
y1 = vPoolPreq_f32<P>(y1, x3.val[0]);
y0 = vPoolPreq_f32<P>(y0, x1.val[1]);
y1 = vPoolPreq_f32<P>(y1, x3.val[1]);
y0 = vPoolPostq_f32<P>(y0, post);
y1 = vPoolPostq_f32<P>(y1, post);
switch (output_w_remain) {
case 1:
vst1q_lane_f32(output_ptr0, y0, 0);
vst1q_lane_f32(output_ptr1, y1, 0);
break;
case 2:
vst1_f32(output_ptr0, vget_low_f32(y0));
vst1_f32(output_ptr1, vget_low_f32(y1));
break;
case 3:
vst1_f32(output_ptr0, vget_low_f32(y0));
vst1q_lane_f32(output_ptr0 + 2, y0, 2);
vst1_f32(output_ptr1, vget_low_f32(y1));
vst1q_lane_f32(output_ptr1 + 2, y1, 2);
break;
}
input_ptr0 += 2 * output_w_remain;
input_ptr1 += 2 * output_w_remain;
input_ptr2 += 2 * output_w_remain;
input_ptr3 += 2 * output_w_remain;
output_ptr0 += output_w_remain;
output_ptr1 += output_w_remain;
}
// pad right
if (padding_r) {
for (int w = valid_w_end; w < output_w; ++w) {
int padding = 2 * w + 2 - (padding_w + input_w);
if (padding >= 2) {
*output_ptr0 = 0.f;
*output_ptr1 = 0.f;
} else {
float acc0 = PoolPre<P>(*input_ptr0, *input_ptr1);
float acc1 = PoolPre<P>(*input_ptr2, *input_ptr3);
*output_ptr0 = PoolPost<P>(acc0, 0.5f);
*output_ptr1 = PoolPost<P>(acc1, 0.5f);
}
output_ptr0++;
output_ptr1++;
}
}
}
// remain height
int start_h = valid_h_start + (valid_h & 0xfffe);
for (int h = start_h; h < valid_h_end; ++h) {
const float *input_ptr0 = input_ptr + (2 * h - padding_h) * input_w;
const float *input_ptr1 = input_ptr0 + input_w;
float *output_ptr0 = output_ptr + h * output_w;
// pad left
if (padding_w) {
for (int w = valid_w_start - 1; w >= 0; --w) {
int padding = padding_w - 2 * w;
if (padding >= 2) {
output_ptr0[w] = 0.f;
} else {
float acc0 = PoolPre<P>(*input_ptr0, *input_ptr1);
output_ptr0[w] = PoolPost<P>(acc0, 0.5f);
}
}
input_ptr0 += (padding_w & 0x1);
input_ptr1 += (padding_w & 0x1);
output_ptr0 += valid_w_start;
}
// valid
float32x4x2_t x0, x1;
float32x4_t y0;
float32x4_t post = vdupq_n_f32(0.25f);
for (int loop = 0; loop < output_w_tiles; ++loop) {
x0 = vld2q_f32(input_ptr0);
x1 = vld2q_f32(input_ptr1);
y0 = vPoolPreq_f32<P>(x0.val[0], x0.val[1]);
y0 = vPoolPreq_f32<P>(y0, x1.val[0]);
y0 = vPoolPreq_f32<P>(y0, x1.val[1]);
y0 = vPoolPostq_f32<P>(y0, post);
vst1q_f32(output_ptr0, y0);
input_ptr0 += 8;
input_ptr1 += 8;
output_ptr0 += 4;
}
// remain width
if (output_w_remain > 0) {
x0 = vld2q_f32(input_ptr0);
x1 = vld2q_f32(input_ptr1);
y0 = vPoolPreq_f32<P>(x0.val[0], x0.val[1]);
y0 = vPoolPreq_f32<P>(y0, x1.val[0]);
y0 = vPoolPreq_f32<P>(y0, x1.val[1]);
y0 = vPoolPostq_f32<P>(y0, post);
switch (output_w_remain) {
case 1:
vst1q_lane_f32(output_ptr0, y0, 0);
break;
case 2:
vst1_f32(output_ptr0, vget_low_f32(y0));
break;
case 3:
vst1_f32(output_ptr0, vget_low_f32(y0));
vst1q_lane_f32(output_ptr0 + 2, y0, 2);
break;
}
input_ptr0 += 2 * output_w_remain;
input_ptr1 += 2 * output_w_remain;
output_ptr0 += output_w_remain;
}
// pad right
if (padding_r) {
for (int w = valid_w_end; w < output_w; ++w) {
int padding = 2 * w + 2 - (padding_w + input_w);
if (padding >= 2) {
*output_ptr0 = 0.f;
} else {
float acc0 = PoolPre<P>(*input_ptr0, *input_ptr1);
*output_ptr0 = PoolPost<P>(acc0, 0.5f);
}
output_ptr0++;
}
}
}
// bottom
for (int h = valid_h_end; h < output_h; ++h) {
Pooling2x2NormalRow<P, 2>(input_ptr, h, input_h, input_w, padding_h,
padding_w, output_w, output_ptr);
}
}
}
}
};
template struct Pooling2x2<MAX, 1>;
template struct Pooling2x2<AVG, 1>;
template struct Pooling2x2<MAX, 2>;
template struct Pooling2x2<AVG, 2>;
} // namespace math
} // namespace operators
} // namespace paddle_mobile
#endif // __ARM_NEON__
#endif // POOL_OP
......@@ -14,10 +14,10 @@ limitations under the License. */
#ifdef POOL_OP
#include "operators/math/pooling.h"
#if defined(__ARM_NEON) || defined(__ARM_NEON__)
#include <arm_neon.h>
#endif // __ARM_NEON
#include "operators/math/pooling.h"
namespace paddle_mobile {
namespace operators {
......@@ -38,87 +38,6 @@ namespace math {
output_ptr[w] = val.Value(); \
}
#if defined(__ARM_NEON) || defined(__ARM_NEON__)
template <PoolingType P, int Stride = 1>
struct Pooling3x3ValidColLoadInput {
inline void operator()(const float *input, const int input_w,
const int valid_cols, float32x4x2_t &x0, // NOLINT
float32x4x2_t &x1, float32x4x2_t &x2, // NOLINT
float32x4x2_t &y0) { // NOLINT
float fake_input[3][8];
if (valid_cols == 1) {
for (int i = 0; i < 8; ++i, input += input_w) {
fake_input[0][i] = input[0];
}
} else if (valid_cols == 2) {
for (int i = 0; i < 8; ++i, input += input_w) {
fake_input[0][i] = input[0];
fake_input[1][i] = input[1];
}
} else {
for (int i = 0; i < 8; ++i, input += input_w) {
fake_input[0][i] = input[0];
fake_input[1][i] = input[1];
fake_input[2][i] = input[2];
}
}
y0.val[0] = vPoolInitq_f32<P>();
y0.val[1] = vPoolInitq_f32<P>();
for (int i = 0; i < valid_cols; ++i) {
x0.val[0] = vld1q_f32(fake_input[i]);
x0.val[1] = vld1q_f32(fake_input[i] + 4);
x1.val[0] = vextq_f32(x0.val[0], x0.val[1], 1);
x1.val[1] = vextq_f32(x0.val[1], x0.val[1], 1);
x2.val[0] = vextq_f32(x0.val[0], x0.val[1], 2);
x2.val[1] = vextq_f32(x0.val[1], x0.val[1], 2);
y0.val[0] = vPoolPreq_f32<P>(x0.val[0], y0.val[0]);
y0.val[1] = vPoolPreq_f32<P>(x0.val[1], y0.val[1]);
y0.val[0] = vPoolPreq_f32<P>(x1.val[0], y0.val[0]);
y0.val[1] = vPoolPreq_f32<P>(x1.val[1], y0.val[1]);
y0.val[0] = vPoolPreq_f32<P>(x2.val[0], y0.val[0]);
y0.val[1] = vPoolPreq_f32<P>(x2.val[1], y0.val[1]);
}
}
};
template <PoolingType P>
struct Pooling3x3ValidColLoadInput<P, 2> {
inline void operator()(const float *input, const int input_w,
const int valid_cols, float32x4x2_t &x0, // NOLINT
float32x4x2_t &x1, float32x4x2_t &x2, // NOLINT
float32x4x2_t &y0) { // NOLINT
float fake_input[3][13];
if (valid_cols == 1) {
for (int i = 0; i < 13; ++i, input += input_w) {
fake_input[0][i] = input[0];
}
} else if (valid_cols == 2) {
for (int i = 0; i < 13; ++i, input += input_w) {
fake_input[0][i] = input[0];
fake_input[1][i] = input[1];
}
} else {
for (int i = 0; i < 13; ++i, input += input_w) {
fake_input[0][i] = input[0];
fake_input[1][i] = input[1];
fake_input[2][i] = input[2];
}
}
for (int i = 0; i < valid_cols; ++i) {
x0 = vld2q_f32(fake_input[i]);
x1 = vld2q_f32(fake_input[i] + 8);
x2.val[0] = vextq_f32(x0.val[0], x1.val[0], 1);
x2.val[1] = vextq_f32(x1.val[0], x1.val[0], 1);
x0.val[0] = vPoolPreq_f32<P>(x0.val[0], x0.val[1]);
x0.val[1] = vPoolPreq_f32<P>(x1.val[0], x1.val[1]);
x0.val[0] = vPoolPreq_f32<P>(x0.val[0], x2.val[0]);
x0.val[1] = vPoolPreq_f32<P>(x0.val[1], x2.val[1]);
y0.val[0] = vPoolPreq_f32<P>(x0.val[0], y0.val[0]);
y0.val[1] = vPoolPreq_f32<P>(x0.val[1], y0.val[1]);
}
}
};
template <PoolingType P, int Stride = 1>
struct Pooling3x3NormalRowLoadInput {
inline void operator()(const float *input, float32x4x2_t &x0, // NOLINT
......@@ -156,62 +75,6 @@ struct Pooling3x3NormalRowLoadInput<P, 2> {
y0.val[1] = vPoolPreq_f32<P>(x0.val[1], y0.val[1]);
}
};
#endif // __ARM_NEON__
template <PoolingType P, int Stride>
inline void Pooling3x3ValidCol(const float *input, const int h_output,
const int h_output_end, const int w_output,
const int input_h, const int input_w,
const int padding_h, const int padding_w,
const int output_w, float *output) {
const int w_in_start = -padding_w + w_output * Stride;
const int w_in_end = w_in_start + 3;
const int w_start = w_in_start > 0 ? w_in_start : 0;
const int w_end = w_in_end < input_w ? w_in_end : input_w;
int remain_start = h_output;
#if defined(__ARM_NEON) || defined(__ARM_NEON__)
int output_tiles = (h_output_end - h_output) / 6;
remain_start = h_output + output_tiles * 6;
int input_h_start = h_output * Stride - padding_h;
size_t input_offset = input_h_start * input_w + w_start;
size_t output_offset = h_output * output_w + w_output;
int valid_cols = w_end - w_start;
Pooling3x3ValidColLoadInput<P, Stride> PoolingCompute;
float32x4x2_t x0, x1, x2, y0;
float32x4_t avg = vdupq_n_f32(1.f / (3 * valid_cols));
for (int h = 0; h < output_tiles * 6; h += 6) {
float *output0 = output + output_offset;
float *output1 = output0 + output_w;
float *output2 = output1 + output_w;
float *output3 = output2 + output_w;
float *output4 = output3 + output_w;
float *output5 = output4 + output_w;
y0.val[0] = vPoolInitq_f32<P>();
y0.val[1] = vPoolInitq_f32<P>();
PoolingCompute(input + input_offset, input_w, valid_cols, x0, x1, x2, y0);
y0.val[0] = vPoolPostq_f32<P>(y0.val[0], avg);
y0.val[1] = vPoolPostq_f32<P>(y0.val[1], avg);
vst1q_lane_f32(output0, y0.val[0], 0);
vst1q_lane_f32(output1, y0.val[0], 1);
vst1q_lane_f32(output2, y0.val[0], 2);
vst1q_lane_f32(output3, y0.val[0], 3);
vst1q_lane_f32(output4, y0.val[1], 0);
vst1q_lane_f32(output5, y0.val[1], 1);
input_offset += 6 * Stride * input_w;
output_offset += 6 * output_w;
}
#endif
for (int h = remain_start; h < h_output_end; ++h) {
PoolingVal<P> val;
const int h_in_start = -padding_h + h * Stride;
for (int i = 0; i < 3; ++i) {
for (int w_in = w_start; w_in < w_end; ++w_in) {
val += input[(h_in_start + i) * input_w + w_in];
}
}
output[h * output_w + w_output] = val.Value();
}
}
template <PoolingType P, int Stride>
inline void Pooling3x3NormalRow(const float *input, const int h_output,
......@@ -223,21 +86,25 @@ inline void Pooling3x3NormalRow(const float *input, const int h_output,
const int h_start = h_in_start > 0 ? h_in_start : 0;
const int h_end = h_in_end < input_h ? h_in_end : input_h;
int valid_w_start = (padding_w + Stride - 1) / Stride;
int valid_w_end = (input_w - 3) / Stride + 1 + valid_w_start;
float *output_ptr = output + h_output * output_w;
if (h_end - h_start <= 0) {
memset(output_ptr, 0, output_w * sizeof(float));
return;
}
const int valid_w_start = (padding_w + Stride - 1) / Stride;
const int valid_w_end = (input_w + padding_w - 3) / Stride + 1;
const int valid_w = valid_w_end - valid_w_start;
// border left
POOLING3X3_NORMAL_BORDER(0, valid_w_start)
// middle
int remain_start = valid_w_start;
#if defined(__ARM_NEON) || defined(__ARM_NEON__)
int output_tiles = (valid_w_end - valid_w_start) / 6;
remain_start = valid_w_start + output_tiles * 6;
int output_tiles_w = output_tiles * 6;
Pooling3x3NormalRowLoadInput<P, Stride> PoolingCompute;
float32x4x2_t x0, x1, x2, y0;
float32x4_t post = vdupq_n_f32(1.f / (3 * (h_end - h_start)));
for (int w = 0; w < output_tiles * 6; w += 6) {
for (int w = 0; w < output_tiles_w; w += 6) {
int output_offset = valid_w_start + w;
int input_w_offset = output_offset * Stride - padding_w;
y0.val[0] = vPoolInitq_f32<P>();
......@@ -250,16 +117,37 @@ inline void Pooling3x3NormalRow(const float *input, const int h_output,
vst1q_f32(output_ptr + output_offset, y0.val[0]);
vst1_f32(output_ptr + output_offset + 4, vget_low_f32(y0.val[1]));
}
#endif // __ARM_NEON__
for (int w = remain_start; w < valid_w_end; ++w) {
PoolingVal<P> val;
int input_start = -padding_w + w * Stride;
int remain = valid_w - output_tiles_w;
if (remain > 0) {
int remain_start = valid_w_start + output_tiles_w;
int input_w_offset = remain_start * Stride - padding_w;
float *output_ptr0 = output_ptr + remain_start;
y0.val[0] = vPoolInitq_f32<P>();
y0.val[1] = vPoolInitq_f32<P>();
for (int h_in = h_start; h_in < h_end; ++h_in) {
for (int j = 0; j < 3; ++j) {
val += input[h_in * input_w + j + input_start];
}
PoolingCompute(input + h_in * input_w + input_w_offset, x0, x1, x2, y0);
}
y0.val[0] = vPoolPostq_f32<P>(y0.val[0], post);
y0.val[1] = vPoolPostq_f32<P>(y0.val[1], post);
switch (remain) {
case 1:
vst1q_lane_f32(output_ptr0, y0.val[0], 0);
break;
case 2:
vst1_f32(output_ptr0, vget_low_f32(y0.val[0]));
break;
case 3:
vst1_f32(output_ptr0, vget_low_f32(y0.val[0]));
vst1q_lane_f32(output_ptr0 + 2, y0.val[0], 2);
break;
case 4:
vst1q_f32(output_ptr0, y0.val[0]);
break;
case 5:
vst1q_f32(output_ptr0, y0.val[0]);
vst1q_lane_f32(output_ptr0 + 4, y0.val[1], 0);
break;
}
output_ptr[w] = val.Value();
}
// border right
POOLING3X3_NORMAL_BORDER(valid_w_end, output_w)
......@@ -286,7 +174,6 @@ struct Pooling3x3<P, 1> {
int valid_w_start = padding_w;
int valid_w = input_w - 2;
int valid_w_end = valid_w_start + valid_w;
float avg = 1.f / 9;
#pragma omp parallel for collapse(2)
for (int batch = 0; batch < output->dims()[0]; ++batch) {
......@@ -299,23 +186,6 @@ struct Pooling3x3<P, 1> {
Pooling3x3NormalRow<P, 1>(input_ptr, h, input_h, input_w, padding_h,
padding_w, output_w, output_ptr);
}
// left
for (int w = 0; w < valid_w_start; ++w) {
Pooling3x3ValidCol<P, 1>(input_ptr, valid_h_start, valid_h_end, w,
input_h, input_w, padding_h, padding_w,
output_w, output_ptr);
}
// right
for (int w = valid_w_end; w < output_w; ++w) {
Pooling3x3ValidCol<P, 1>(input_ptr, valid_h_start, valid_h_end, w,
input_h, input_w, padding_h, padding_w,
output_w, output_ptr);
}
// bottom
for (int h = valid_h_end; h < output_h; ++h) {
Pooling3x3NormalRow<P, 1>(input_ptr, h, input_h, input_w, padding_h,
padding_w, output_w, output_ptr);
}
// valid
int output_w_tiles = valid_w / 6;
int output_w_remain = valid_w - output_w_tiles * 6;
......@@ -326,12 +196,61 @@ struct Pooling3x3<P, 1> {
const float *input_ptr3 = input_ptr2 + input_w;
const float *input_ptr4 = input_ptr3 + input_w;
const float *input_ptr5 = input_ptr4 + input_w;
float *output_ptr0 = output_ptr + h * output_w + valid_w_start;
float *output_ptr0 = output_ptr + h * output_w;
float *output_ptr1 = output_ptr0 + output_w;
float *output_ptr2 = output_ptr1 + output_w;
float *output_ptr3 = output_ptr2 + output_w;
int remain = output_w_remain;
#if defined(__ARM_NEON__) || defined(__ARM_NEON)
// pad left
if (padding_w) {
float32x2_t row0 = vld1_f32(input_ptr0);
float32x2_t row1 = vld1_f32(input_ptr1);
float32x2_t row2 = vld1_f32(input_ptr2);
float32x2_t row3 = vld1_f32(input_ptr3);
float32x2_t row4 = vld1_f32(input_ptr4);
float32x2_t row5 = vld1_f32(input_ptr5);
float32x2_t pad0 = vPoolInit_f32<P>();
float32x2_t acc0, acc1, acc2, acc3, acc12, acc34, post;
for (int w = valid_w_start - 1; w >= 0; --w) {
int padding = padding_w - w;
if (padding >= 3) {
output_ptr0[w] = 0.f;
output_ptr1[w] = 0.f;
output_ptr2[w] = 0.f;
output_ptr3[w] = 0.f;
} else {
post = vdup_n_f32(1.f / (3 * (3 - padding)));
acc12 = vPoolPre_f32<P>(row1, row2);
acc34 = vPoolPre_f32<P>(row3, row4);
acc0 = vPoolPre_f32<P>(row0, acc12);
acc1 = vPoolPre_f32<P>(row3, acc12);
acc2 = vPoolPre_f32<P>(row2, acc34);
acc3 = vPoolPre_f32<P>(row5, acc34);
acc0 = vpPoolPre_f32<P>(acc0, acc0);
acc1 = vpPoolPre_f32<P>(acc1, acc1);
acc2 = vpPoolPre_f32<P>(acc2, acc2);
acc3 = vpPoolPre_f32<P>(acc3, acc3);
acc0 = vPoolPost_f32<P>(acc0, post);
acc1 = vPoolPost_f32<P>(acc1, post);
acc2 = vPoolPost_f32<P>(acc2, post);
acc3 = vPoolPost_f32<P>(acc3, post);
vst1_lane_f32(output_ptr0 + w, acc0, 0);
vst1_lane_f32(output_ptr1 + w, acc1, 0);
vst1_lane_f32(output_ptr2 + w, acc2, 0);
vst1_lane_f32(output_ptr3 + w, acc3, 0);
row0 = vext_f32(pad0, row0, 1);
row1 = vext_f32(pad0, row1, 1);
row2 = vext_f32(pad0, row2, 1);
row3 = vext_f32(pad0, row3, 1);
row4 = vext_f32(pad0, row4, 1);
row5 = vext_f32(pad0, row5, 1);
}
}
output_ptr0 += valid_w_start;
output_ptr1 += valid_w_start;
output_ptr2 += valid_w_start;
output_ptr3 += valid_w_start;
}
// valid
float32x4x2_t x0, x1, x2;
float32x4x2_t y0, y1, y2;
float32x4_t post = vdupq_n_f32(1.f / 9);
......@@ -446,100 +365,198 @@ struct Pooling3x3<P, 1> {
output_ptr3 += 6;
}
// remain width
if (remain >= 4) {
if (output_w_remain > 0) {
float32x4x2_t y3;
x0.val[0] = vld1q_f32(input_ptr0);
x0.val[1] = vld1q_f32(input_ptr0 + 4);
x1.val[0] = vextq_f32(x0.val[0], x0.val[1], 1);
x1.val[1] = vextq_f32(x0.val[1], x0.val[1], 1);
x2.val[0] = vextq_f32(x0.val[0], x0.val[1], 2);
x2.val[1] = vextq_f32(x0.val[1], x0.val[1], 2);
x0.val[0] = vPoolPreq_f32<P>(x0.val[0], x1.val[0]);
x0.val[1] = vPoolPreq_f32<P>(x0.val[1], x1.val[1]);
y0.val[0] = vPoolPreq_f32<P>(x0.val[0], x2.val[0]);
y0.val[1] = vPoolPreq_f32<P>(x0.val[1], x2.val[1]);
x0.val[0] = vld1q_f32(input_ptr1);
x0.val[1] = vld1q_f32(input_ptr1 + 4);
x1.val[0] = vextq_f32(x0.val[0], x0.val[1], 1);
x1.val[1] = vextq_f32(x0.val[1], x0.val[1], 1);
x2.val[0] = vextq_f32(x0.val[0], x0.val[1], 2);
x2.val[1] = vextq_f32(x0.val[1], x0.val[1], 2);
x0.val[0] = vPoolPreq_f32<P>(x0.val[0], x1.val[0]);
x0.val[1] = vPoolPreq_f32<P>(x0.val[1], x1.val[1]);
y1.val[0] = vPoolPreq_f32<P>(x0.val[0], x2.val[0]);
y1.val[1] = vPoolPreq_f32<P>(x0.val[1], x2.val[1]);
y0.val[0] = vPoolPreq_f32<P>(y1.val[0], y0.val[0]);
y0.val[1] = vPoolPreq_f32<P>(y1.val[1], y0.val[1]);
x0.val[0] = vld1q_f32(input_ptr2);
x0.val[1] = vld1q_f32(input_ptr2 + 4);
x1.val[0] = vextq_f32(x0.val[0], x0.val[1], 1);
x1.val[1] = vextq_f32(x0.val[1], x0.val[1], 1);
x2.val[0] = vextq_f32(x0.val[0], x0.val[1], 2);
x2.val[1] = vextq_f32(x0.val[1], x0.val[1], 2);
x0.val[0] = vPoolPreq_f32<P>(x0.val[0], x1.val[0]);
x0.val[1] = vPoolPreq_f32<P>(x0.val[1], x1.val[1]);
y2.val[0] = vPoolPreq_f32<P>(x0.val[0], x2.val[0]);
y2.val[1] = vPoolPreq_f32<P>(x0.val[1], x2.val[1]);
y1.val[0] = vPoolPreq_f32<P>(y2.val[0], y1.val[0]);
y1.val[1] = vPoolPreq_f32<P>(y2.val[1], y1.val[1]);
y0.val[0] = vPoolPreq_f32<P>(y2.val[0], y0.val[0]);
y0.val[1] = vPoolPreq_f32<P>(y2.val[1], y0.val[1]);
y0.val[0] = vPoolPostq_f32<P>(y0.val[0], post);
vst1q_f32(output_ptr0, y0.val[0]);
y0.val[1] = vPoolPostq_f32<P>(y0.val[1], post);
x0.val[0] = vld1q_f32(input_ptr3);
x0.val[1] = vld1q_f32(input_ptr3 + 4);
x1.val[0] = vextq_f32(x0.val[0], x0.val[1], 1);
x1.val[1] = vextq_f32(x0.val[1], x0.val[1], 1);
x2.val[0] = vextq_f32(x0.val[0], x0.val[1], 2);
x2.val[1] = vextq_f32(x0.val[1], x0.val[1], 2);
x0.val[0] = vPoolPreq_f32<P>(x0.val[0], x1.val[0]);
y0.val[0] = vPoolPreq_f32<P>(x0.val[0], x2.val[0]);
y1.val[0] = vPoolPreq_f32<P>(y0.val[0], y1.val[0]);
y2.val[0] = vPoolPreq_f32<P>(y0.val[0], y2.val[0]);
x0.val[1] = vPoolPreq_f32<P>(x0.val[1], x1.val[1]);
y3.val[0] = vPoolPreq_f32<P>(x0.val[0], x2.val[0]);
y3.val[1] = vPoolPreq_f32<P>(x0.val[1], x2.val[1]);
y1.val[0] = vPoolPreq_f32<P>(y3.val[0], y1.val[0]);
y1.val[1] = vPoolPreq_f32<P>(y3.val[1], y1.val[1]);
y2.val[0] = vPoolPreq_f32<P>(y3.val[0], y2.val[0]);
y2.val[1] = vPoolPreq_f32<P>(y3.val[1], y2.val[1]);
y1.val[0] = vPoolPostq_f32<P>(y1.val[0], post);
vst1q_f32(output_ptr1, y1.val[0]);
y1.val[1] = vPoolPostq_f32<P>(y1.val[1], post);
x0.val[0] = vld1q_f32(input_ptr4);
x0.val[1] = vld1q_f32(input_ptr4 + 4);
x1.val[0] = vextq_f32(x0.val[0], x0.val[1], 1);
x1.val[1] = vextq_f32(x0.val[1], x0.val[1], 1);
x2.val[0] = vextq_f32(x0.val[0], x0.val[1], 2);
x2.val[1] = vextq_f32(x0.val[1], x0.val[1], 2);
x0.val[0] = vPoolPreq_f32<P>(x0.val[0], x1.val[0]);
x0.val[1] = vPoolPreq_f32<P>(x0.val[1], x1.val[1]);
x0.val[0] = vPoolPreq_f32<P>(x0.val[0], x2.val[0]);
y0.val[0] = vPoolPreq_f32<P>(x0.val[0], y0.val[0]);
x0.val[1] = vPoolPreq_f32<P>(x0.val[1], x2.val[1]);
y3.val[0] = vPoolPreq_f32<P>(x0.val[0], y3.val[0]);
y3.val[1] = vPoolPreq_f32<P>(x0.val[1], y3.val[1]);
y2.val[0] = vPoolPreq_f32<P>(x0.val[0], y2.val[0]);
y2.val[1] = vPoolPreq_f32<P>(x0.val[1], y2.val[1]);
y2.val[0] = vPoolPostq_f32<P>(y2.val[0], post);
vst1q_f32(output_ptr2, y2.val[0]);
y2.val[1] = vPoolPostq_f32<P>(y2.val[1], post);
x0.val[0] = vld1q_f32(input_ptr5);
x0.val[1] = vld1q_f32(input_ptr5 + 4);
x1.val[0] = vextq_f32(x0.val[0], x0.val[1], 1);
x1.val[1] = vextq_f32(x0.val[1], x0.val[1], 1);
x2.val[0] = vextq_f32(x0.val[0], x0.val[1], 2);
x2.val[1] = vextq_f32(x0.val[1], x0.val[1], 2);
x0.val[0] = vPoolPreq_f32<P>(x0.val[0], x1.val[0]);
x0.val[1] = vPoolPreq_f32<P>(x0.val[1], x1.val[1]);
x0.val[0] = vPoolPreq_f32<P>(x0.val[0], x2.val[0]);
y0.val[0] = vPoolPreq_f32<P>(x0.val[0], y0.val[0]);
y0.val[0] = vPoolPostq_f32<P>(y0.val[0], post);
vst1q_f32(output_ptr3, y0.val[0]);
input_ptr0 += 4;
input_ptr1 += 4;
input_ptr2 += 4;
input_ptr3 += 4;
input_ptr4 += 4;
input_ptr5 += 4;
output_ptr0 += 4;
output_ptr1 += 4;
output_ptr2 += 4;
output_ptr3 += 4;
remain -= 4;
x0.val[1] = vPoolPreq_f32<P>(x0.val[1], x2.val[1]);
y3.val[0] = vPoolPreq_f32<P>(x0.val[0], y3.val[0]);
y3.val[1] = vPoolPreq_f32<P>(x0.val[1], y3.val[1]);
y3.val[0] = vPoolPostq_f32<P>(y3.val[0], post);
y3.val[1] = vPoolPostq_f32<P>(y3.val[1], post);
switch (output_w_remain) {
case 1:
vst1q_lane_f32(output_ptr0, y0.val[0], 0);
vst1q_lane_f32(output_ptr1, y1.val[0], 0);
vst1q_lane_f32(output_ptr2, y2.val[0], 0);
vst1q_lane_f32(output_ptr3, y3.val[0], 0);
break;
case 2:
vst1_f32(output_ptr0, vget_low_f32(y0.val[0]));
vst1_f32(output_ptr1, vget_low_f32(y1.val[0]));
vst1_f32(output_ptr2, vget_low_f32(y2.val[0]));
vst1_f32(output_ptr3, vget_low_f32(y3.val[0]));
break;
case 3:
vst1_f32(output_ptr0, vget_low_f32(y0.val[0]));
vst1_f32(output_ptr1, vget_low_f32(y1.val[0]));
vst1_f32(output_ptr2, vget_low_f32(y2.val[0]));
vst1_f32(output_ptr3, vget_low_f32(y3.val[0]));
vst1q_lane_f32(output_ptr0 + 2, y0.val[0], 2);
vst1q_lane_f32(output_ptr1 + 2, y1.val[0], 2);
vst1q_lane_f32(output_ptr2 + 2, y2.val[0], 2);
vst1q_lane_f32(output_ptr3 + 2, y3.val[0], 2);
break;
case 4:
vst1q_f32(output_ptr0, y0.val[0]);
vst1q_f32(output_ptr1, y1.val[0]);
vst1q_f32(output_ptr2, y2.val[0]);
vst1q_f32(output_ptr3, y3.val[0]);
break;
case 5:
vst1q_f32(output_ptr0, y0.val[0]);
vst1q_f32(output_ptr1, y1.val[0]);
vst1q_f32(output_ptr2, y2.val[0]);
vst1q_f32(output_ptr3, y3.val[0]);
vst1q_lane_f32(output_ptr0 + 4, y0.val[1], 0);
vst1q_lane_f32(output_ptr1 + 4, y1.val[1], 0);
vst1q_lane_f32(output_ptr2 + 4, y2.val[1], 0);
vst1q_lane_f32(output_ptr3 + 4, y3.val[1], 0);
break;
}
input_ptr0 += output_w_remain;
input_ptr1 += output_w_remain;
input_ptr2 += output_w_remain;
input_ptr3 += output_w_remain;
input_ptr4 += output_w_remain;
input_ptr5 += output_w_remain;
output_ptr0 += output_w_remain;
output_ptr1 += output_w_remain;
output_ptr2 += output_w_remain;
output_ptr3 += output_w_remain;
}
#endif // __ARM_NEON__
for (int r = 0; r < remain; ++r) {
float m0 = PoolPre<P>(input_ptr0[r], input_ptr0[r + 1]);
m0 = PoolPre<P>(m0, input_ptr0[r + 2]);
float m1 = PoolPre<P>(input_ptr1[r], input_ptr1[r + 1]);
m1 = PoolPre<P>(m1, input_ptr1[r + 2]);
float m2 = PoolPre<P>(input_ptr2[r], input_ptr2[r + 1]);
m2 = PoolPre<P>(m2, input_ptr2[r + 2]);
float m3 = PoolPre<P>(input_ptr3[r], input_ptr3[r + 1]);
m3 = PoolPre<P>(m3, input_ptr3[r + 2]);
float m4 = PoolPre<P>(input_ptr4[r], input_ptr4[r + 1]);
m4 = PoolPre<P>(m4, input_ptr4[r + 2]);
float m5 = PoolPre<P>(input_ptr5[r], input_ptr5[r + 1]);
m5 = PoolPre<P>(m5, input_ptr5[r + 2]);
m0 = PoolPre<P>(PoolPre<P>(m0, m1), m2);
m1 = PoolPre<P>(PoolPre<P>(m1, m2), m3);
m2 = PoolPre<P>(PoolPre<P>(m2, m3), m4);
m3 = PoolPre<P>(PoolPre<P>(m3, m4), m5);
output_ptr0[r] = PoolPost<P>(m0, avg);
output_ptr1[r] = PoolPost<P>(m1, avg);
output_ptr2[r] = PoolPost<P>(m2, avg);
output_ptr3[r] = PoolPost<P>(m3, avg);
// pad right
if (padding_w) {
float32x2_t row0 = vld1_f32(input_ptr0);
float32x2_t row1 = vld1_f32(input_ptr1);
float32x2_t row2 = vld1_f32(input_ptr2);
float32x2_t row3 = vld1_f32(input_ptr3);
float32x2_t row4 = vld1_f32(input_ptr4);
float32x2_t row5 = vld1_f32(input_ptr5);
float32x2_t pad0 = vPoolInit_f32<P>();
float32x2_t acc0, acc1, acc2, acc3, acc12, acc34, post;
for (int w = valid_w_end; w < output_w; ++w) {
int padding = w + 3 - (padding_w + input_w);
if (padding >= 3) {
*output_ptr0 = 0.f;
*output_ptr1 = 0.f;
*output_ptr2 = 0.f;
*output_ptr3 = 0.f;
} else {
post = vdup_n_f32(1.f / (3 * (3 - padding)));
acc12 = vPoolPre_f32<P>(row1, row2);
acc34 = vPoolPre_f32<P>(row3, row4);
acc0 = vPoolPre_f32<P>(row0, acc12);
acc1 = vPoolPre_f32<P>(row3, acc12);
acc2 = vPoolPre_f32<P>(row2, acc34);
acc3 = vPoolPre_f32<P>(row5, acc34);
acc0 = vpPoolPre_f32<P>(acc0, acc0);
acc1 = vpPoolPre_f32<P>(acc1, acc1);
acc2 = vpPoolPre_f32<P>(acc2, acc2);
acc3 = vpPoolPre_f32<P>(acc3, acc3);
acc0 = vPoolPost_f32<P>(acc0, post);
acc1 = vPoolPost_f32<P>(acc1, post);
acc2 = vPoolPost_f32<P>(acc2, post);
acc3 = vPoolPost_f32<P>(acc3, post);
vst1_lane_f32(output_ptr0, acc0, 0);
vst1_lane_f32(output_ptr1, acc1, 0);
vst1_lane_f32(output_ptr2, acc2, 0);
vst1_lane_f32(output_ptr3, acc3, 0);
row0 = vext_f32(row0, pad0, 1);
row1 = vext_f32(row1, pad0, 1);
row2 = vext_f32(row2, pad0, 1);
row3 = vext_f32(row3, pad0, 1);
row4 = vext_f32(row4, pad0, 1);
row5 = vext_f32(row5, pad0, 1);
}
output_ptr0++;
output_ptr1++;
output_ptr2++;
output_ptr3++;
}
}
}
// remain height
......@@ -548,9 +565,33 @@ struct Pooling3x3<P, 1> {
const float *input_ptr0 = input_ptr + (h - padding_h) * input_w;
const float *input_ptr1 = input_ptr0 + input_w;
const float *input_ptr2 = input_ptr1 + input_w;
float *output_ptr0 = output_ptr + h * output_w + valid_w_start;
int remain = output_w_remain;
#if defined(__ARM_NEON__) || defined(__ARM_NEON)
float *output_ptr0 = output_ptr + h * output_w;
// pad left
if (padding_w) {
float32x2_t row0 = vld1_f32(input_ptr0);
float32x2_t row1 = vld1_f32(input_ptr1);
float32x2_t row2 = vld1_f32(input_ptr2);
float32x2_t pad0 = vPoolInit_f32<P>();
float32x2_t acc0, post;
for (int w = valid_w_start - 1; w >= 0; --w) {
int padding = padding_w - w;
if (padding >= 3) {
output_ptr0[w] = 0.f;
} else {
post = vdup_n_f32(1.f / (3 * (3 - padding)));
acc0 = vPoolPre_f32<P>(row0, row1);
acc0 = vPoolPre_f32<P>(acc0, row2);
acc0 = vpPoolPre_f32<P>(acc0, acc0);
acc0 = vPoolPost_f32<P>(acc0, post);
vst1_lane_f32(output_ptr0 + w, acc0, 0);
row0 = vext_f32(pad0, row0, 1);
row1 = vext_f32(pad0, row1, 1);
row2 = vext_f32(pad0, row2, 1);
}
}
output_ptr0 += valid_w_start;
}
// valid
float32x4x2_t x0, x1, x2, y0;
float32x4_t post = vdupq_n_f32(1.f / 9);
for (int loop = 0; loop < output_w_tiles; ++loop) {
......@@ -601,51 +642,101 @@ struct Pooling3x3<P, 1> {
output_ptr0 += 6;
}
// remain width
if (remain >= 4) {
if (output_w_remain > 0) {
x0.val[0] = vld1q_f32(input_ptr0);
x0.val[1] = vld1q_f32(input_ptr0 + 4);
x1.val[0] = vextq_f32(x0.val[0], x0.val[1], 1);
x1.val[1] = vextq_f32(x0.val[1], x0.val[1], 1);
x2.val[0] = vextq_f32(x0.val[0], x0.val[1], 2);
x2.val[1] = vextq_f32(x0.val[1], x0.val[1], 2);
x0.val[0] = vPoolPreq_f32<P>(x0.val[0], x1.val[0]);
x0.val[1] = vPoolPreq_f32<P>(x0.val[1], x1.val[1]);
y0.val[0] = vPoolPreq_f32<P>(x0.val[0], x2.val[0]);
y0.val[1] = vPoolPreq_f32<P>(x0.val[1], x2.val[1]);
x0.val[0] = vld1q_f32(input_ptr1);
x0.val[1] = vld1q_f32(input_ptr1 + 4);
x1.val[0] = vextq_f32(x0.val[0], x0.val[1], 1);
x1.val[1] = vextq_f32(x0.val[1], x0.val[1], 1);
x2.val[0] = vextq_f32(x0.val[0], x0.val[1], 2);
x2.val[1] = vextq_f32(x0.val[1], x0.val[1], 2);
x0.val[0] = vPoolPreq_f32<P>(x0.val[0], x1.val[0]);
x0.val[1] = vPoolPreq_f32<P>(x0.val[1], x1.val[1]);
x0.val[0] = vPoolPreq_f32<P>(x0.val[0], x2.val[0]);
x0.val[1] = vPoolPreq_f32<P>(x0.val[1], x2.val[1]);
y0.val[0] = vPoolPreq_f32<P>(x0.val[0], y0.val[0]);
y0.val[1] = vPoolPreq_f32<P>(x0.val[1], y0.val[1]);
x0.val[0] = vld1q_f32(input_ptr2);
x0.val[1] = vld1q_f32(input_ptr2 + 4);
x1.val[0] = vextq_f32(x0.val[0], x0.val[1], 1);
x1.val[1] = vextq_f32(x0.val[1], x0.val[1], 1);
x2.val[0] = vextq_f32(x0.val[0], x0.val[1], 2);
x2.val[1] = vextq_f32(x0.val[1], x0.val[1], 2);
x0.val[0] = vPoolPreq_f32<P>(x0.val[0], x1.val[0]);
x0.val[1] = vPoolPreq_f32<P>(x0.val[1], x1.val[1]);
x0.val[0] = vPoolPreq_f32<P>(x0.val[0], x2.val[0]);
x0.val[1] = vPoolPreq_f32<P>(x0.val[1], x2.val[1]);
y0.val[0] = vPoolPreq_f32<P>(x0.val[0], y0.val[0]);
y0.val[1] = vPoolPreq_f32<P>(x0.val[1], y0.val[1]);
y0.val[0] = vPoolPostq_f32<P>(y0.val[0], post);
vst1q_f32(output_ptr0, y0.val[0]);
input_ptr0 += 4;
input_ptr1 += 4;
input_ptr2 += 4;
output_ptr0 += 4;
remain -= 4;
y0.val[1] = vPoolPostq_f32<P>(y0.val[1], post);
// restore
switch (output_w_remain) {
case 1:
vst1q_lane_f32(output_ptr0, y0.val[0], 0);
break;
case 2:
vst1_f32(output_ptr0, vget_low_f32(y0.val[0]));
break;
case 3:
vst1_f32(output_ptr0, vget_low_f32(y0.val[0]));
vst1q_lane_f32(output_ptr0 + 2, y0.val[0], 2);
break;
case 4:
vst1q_f32(output_ptr0, y0.val[0]);
break;
case 5:
vst1q_f32(output_ptr0, y0.val[0]);
vst1q_lane_f32(output_ptr0 + 4, y0.val[1], 0);
break;
}
input_ptr0 += output_w_remain;
input_ptr1 += output_w_remain;
input_ptr2 += output_w_remain;
output_ptr0 += output_w_remain;
}
#endif // __ARM_NEON__
for (int r = 0; r < remain; ++r) {
float m0 = PoolPre<P>(input_ptr0[r], input_ptr0[r + 1]);
m0 = PoolPre<P>(m0, input_ptr0[r + 2]);
float m1 = PoolPre<P>(input_ptr1[r], input_ptr1[r + 1]);
m1 = PoolPre<P>(m1, input_ptr1[r + 2]);
float m2 = PoolPre<P>(input_ptr2[r], input_ptr2[r + 1]);
m2 = PoolPre<P>(m2, input_ptr2[r + 2]);
m0 = PoolPre<P>(PoolPre<P>(m0, m1), m2);
output_ptr0[r] = PoolPost<P>(m0, avg);
// pad right
if (padding_w) {
float32x2_t row0 = vld1_f32(input_ptr0);
float32x2_t row1 = vld1_f32(input_ptr1);
float32x2_t row2 = vld1_f32(input_ptr2);
float32x2_t pad0 = vPoolInit_f32<P>();
float32x2_t acc0, post;
for (int w = valid_w_end; w < output_w; ++w) {
int padding = w + 3 - (padding_w + input_w);
if (padding >= 3) {
*output_ptr0 = 0.f;
} else {
post = vdup_n_f32(1.f / (3 * (3 - padding)));
acc0 = vPoolPre_f32<P>(row0, row1);
acc0 = vPoolPre_f32<P>(acc0, row2);
acc0 = vpPoolPre_f32<P>(acc0, acc0);
acc0 = vPoolPost_f32<P>(acc0, post);
vst1_lane_f32(output_ptr0, acc0, 0);
row0 = vext_f32(row0, pad0, 1);
row1 = vext_f32(row1, pad0, 1);
row2 = vext_f32(row2, pad0, 1);
}
output_ptr0++;
}
}
}
// pad bottom
for (int h = valid_h_end; h < output_h; ++h) {
Pooling3x3NormalRow<P, 1>(input_ptr, h, input_h, input_w, padding_h,
padding_w, output_w, output_ptr);
}
}
}
}
......@@ -667,12 +758,22 @@ struct Pooling3x3<P, 2> {
int image_size = input_h * input_w;
int out_image_size = output_h * output_w;
int valid_h_start = (padding_h + 1) / 2;
int valid_h = (input_h - 3) / 2 + 1;
int valid_h_end = valid_h_start + valid_h;
int valid_h_end = (input_h + padding_h - 1) / 2;
int valid_h = valid_h_end - valid_h_start;
int valid_w_start = (padding_w + 1) / 2;
int valid_w = (input_w - 3) / 2 + 1;
int valid_w_end = valid_w_start + valid_w;
float avg = 1.f / 9;
int valid_w_end = (input_w + padding_w - 1) / 2;
int valid_w = valid_w_end - valid_w_start;
int padding_height = input_h + 2 * padding_h;
int padding_width = input_w + 2 * padding_w;
bool ceil_mode = (((padding_height - 1) / 2) < output_h) ||
(((padding_width - 1) / 2) < output_w);
int padding_b =
padding_h + (ceil_mode ? 2 * output_h - (padding_height - 1) : 0);
int padding_r =
padding_w + (ceil_mode ? 2 * output_w - (padding_width - 1) : 0);
// for pad left
int valid_input_w_start = (valid_w_start << 1) - padding_w;
#pragma omp parallel for collapse(2)
for (int batch = 0; batch < output->dims()[0]; ++batch) {
......@@ -685,41 +786,70 @@ struct Pooling3x3<P, 2> {
Pooling3x3NormalRow<P, 2>(input_ptr, h, input_h, input_w, padding_h,
padding_w, output_w, output_ptr);
}
// left
for (int w = 0; w < valid_w_start; ++w) {
Pooling3x3ValidCol<P, 2>(input_ptr, valid_h_start, valid_h_end, w,
input_h, input_w, padding_h, padding_w,
output_w, output_ptr);
}
// right
for (int w = valid_w_end; w < output_w; ++w) {
Pooling3x3ValidCol<P, 2>(input_ptr, valid_h_start, valid_h_end, w,
input_h, input_w, padding_h, padding_w,
output_w, output_ptr);
}
// bottom
for (int h = valid_h_end; h < output_h; ++h) {
Pooling3x3NormalRow<P, 2>(input_ptr, h, input_h, input_w, padding_h,
padding_w, output_w, output_ptr);
}
// valid
int input_w_start = 2 * valid_w_start - padding_w;
int output_w_tiles = valid_w / 6;
int output_w_remain = valid_w - output_w_tiles * 6;
for (int h = valid_h_start; h < valid_h_end - 2; h += 3) {
size_t offset = (2 * h - padding_h) * input_w + input_w_start;
const float *input_ptr0 = input_ptr + offset;
const float *input_ptr0 = input_ptr + (2 * h - padding_h) * input_w;
const float *input_ptr1 = input_ptr0 + input_w;
const float *input_ptr2 = input_ptr1 + input_w;
const float *input_ptr3 = input_ptr2 + input_w;
const float *input_ptr4 = input_ptr3 + input_w;
const float *input_ptr5 = input_ptr4 + input_w;
const float *input_ptr6 = input_ptr5 + input_w;
float *output_ptr0 = output_ptr + h * output_w + valid_w_start;
float *output_ptr0 = output_ptr + h * output_w;
float *output_ptr1 = output_ptr0 + output_w;
float *output_ptr2 = output_ptr1 + output_w;
int remain = output_w_remain;
#if defined(__ARM_NEON__) || defined(__ARM_NEON)
// pad left
if (padding_w) {
float32x2_t row0 = vld1_f32(input_ptr0);
float32x2_t row1 = vld1_f32(input_ptr1);
float32x2_t row2 = vld1_f32(input_ptr2);
float32x2_t row3 = vld1_f32(input_ptr3);
float32x2_t row4 = vld1_f32(input_ptr4);
float32x2_t row5 = vld1_f32(input_ptr5);
float32x2_t row6 = vld1_f32(input_ptr6);
float32x2_t pad0 = vPoolInit_f32<P>();
float32x2_t acc0, acc1, acc2, post;
for (int w = valid_w_start - 1; w >= 0; --w) {
int padding = padding_w - (w << 1);
if (padding >= 3) {
output_ptr0[w] = 0.f;
output_ptr1[w] = 0.f;
output_ptr2[w] = 0.f;
} else {
post = vdup_n_f32(1.f / (3 * (3 - padding)));
acc0 = vPoolPre_f32<P>(row0, row1);
acc1 = vPoolPre_f32<P>(row2, row3);
acc2 = vPoolPre_f32<P>(row4, row5);
acc0 = vPoolPre_f32<P>(acc0, row2);
acc1 = vPoolPre_f32<P>(acc1, row4);
acc2 = vPoolPre_f32<P>(acc2, row6);
if (padding == 1) {
acc0 = vpPoolPre_f32<P>(acc0, acc0);
acc1 = vpPoolPre_f32<P>(acc1, acc1);
acc2 = vpPoolPre_f32<P>(acc2, acc2);
}
acc0 = vPoolPost_f32<P>(acc0, post);
acc1 = vPoolPost_f32<P>(acc1, post);
acc2 = vPoolPost_f32<P>(acc2, post);
vst1_lane_f32(output_ptr0 + w, acc0, 0);
vst1_lane_f32(output_ptr1 + w, acc1, 0);
vst1_lane_f32(output_ptr2 + w, acc2, 0);
}
}
input_ptr0 += valid_input_w_start;
input_ptr1 += valid_input_w_start;
input_ptr2 += valid_input_w_start;
input_ptr3 += valid_input_w_start;
input_ptr4 += valid_input_w_start;
input_ptr5 += valid_input_w_start;
input_ptr6 += valid_input_w_start;
output_ptr0 += valid_w_start;
output_ptr1 += valid_w_start;
output_ptr2 += valid_w_start;
}
// valid
float32x4x2_t x0, x1, x2;
float32x4x2_t y0, y1, y2;
float32x4_t post = vdupq_n_f32(1.f / 9);
......@@ -823,108 +953,210 @@ struct Pooling3x3<P, 2> {
output_ptr2 += 6;
}
// remain width
if (remain >= 4) {
if (output_w_remain > 0) {
x0 = vld2q_f32(input_ptr0);
x1.val[0] = vdupq_n_f32(input_ptr0[8]);
x1 = vld2q_f32(input_ptr0 + 8);
x2.val[0] = vextq_f32(x0.val[0], x1.val[0], 1);
x2.val[1] = vextq_f32(x1.val[0], x1.val[0], 1);
x0.val[0] = vPoolPreq_f32<P>(x0.val[0], x0.val[1]);
x0.val[1] = vPoolPreq_f32<P>(x1.val[0], x1.val[1]);
y0.val[0] = vPoolPreq_f32<P>(x0.val[0], x2.val[0]);
y0.val[1] = vPoolPreq_f32<P>(x0.val[1], x2.val[1]);
x0 = vld2q_f32(input_ptr1);
x1.val[0] = vdupq_n_f32(input_ptr1[8]);
x1 = vld2q_f32(input_ptr1 + 8);
x2.val[0] = vextq_f32(x0.val[0], x1.val[0], 1);
x2.val[1] = vextq_f32(x1.val[0], x1.val[0], 1);
x0.val[0] = vPoolPreq_f32<P>(x0.val[0], x0.val[1]);
x0.val[1] = vPoolPreq_f32<P>(x1.val[0], x1.val[1]);
x0.val[0] = vPoolPreq_f32<P>(x0.val[0], x2.val[0]);
x0.val[1] = vPoolPreq_f32<P>(x0.val[1], x2.val[1]);
y0.val[0] = vPoolPreq_f32<P>(x0.val[0], y0.val[0]);
y0.val[1] = vPoolPreq_f32<P>(x0.val[1], y0.val[1]);
x0 = vld2q_f32(input_ptr2);
x1.val[0] = vdupq_n_f32(input_ptr2[8]);
x1 = vld2q_f32(input_ptr2 + 8);
x2.val[0] = vextq_f32(x0.val[0], x1.val[0], 1);
x2.val[1] = vextq_f32(x1.val[0], x1.val[0], 1);
x0.val[0] = vPoolPreq_f32<P>(x0.val[0], x0.val[1]);
x0.val[1] = vPoolPreq_f32<P>(x1.val[0], x1.val[1]);
y1.val[0] = vPoolPreq_f32<P>(x0.val[0], x2.val[0]);
y1.val[1] = vPoolPreq_f32<P>(x0.val[1], x2.val[1]);
y0.val[0] = vPoolPreq_f32<P>(y1.val[0], y0.val[0]);
y0.val[1] = vPoolPreq_f32<P>(y1.val[1], y0.val[1]);
y0.val[0] = vPoolPostq_f32<P>(y0.val[0], post);
vst1q_f32(output_ptr0, y0.val[0]);
y0.val[1] = vPoolPostq_f32<P>(y0.val[1], post);
x0 = vld2q_f32(input_ptr3);
x1.val[0] = vdupq_n_f32(input_ptr3[8]);
x1 = vld2q_f32(input_ptr3 + 8);
x2.val[0] = vextq_f32(x0.val[0], x1.val[0], 1);
x2.val[1] = vextq_f32(x1.val[0], x1.val[0], 1);
x0.val[0] = vPoolPreq_f32<P>(x0.val[0], x0.val[1]);
x0.val[1] = vPoolPreq_f32<P>(x1.val[0], x1.val[1]);
x0.val[0] = vPoolPreq_f32<P>(x0.val[0], x2.val[0]);
x0.val[1] = vPoolPreq_f32<P>(x0.val[1], x2.val[1]);
y1.val[0] = vPoolPreq_f32<P>(x0.val[0], y1.val[0]);
y1.val[1] = vPoolPreq_f32<P>(x0.val[1], y1.val[1]);
x0 = vld2q_f32(input_ptr4);
x1.val[0] = vdupq_n_f32(input_ptr4[8]);
x1 = vld2q_f32(input_ptr4 + 8);
x2.val[0] = vextq_f32(x0.val[0], x1.val[0], 1);
x2.val[1] = vextq_f32(x1.val[0], x1.val[0], 1);
x0.val[0] = vPoolPreq_f32<P>(x0.val[0], x0.val[1]);
y0.val[0] = vPoolPreq_f32<P>(x0.val[0], x2.val[0]);
y1.val[0] = vPoolPreq_f32<P>(y0.val[0], y1.val[0]);
x0.val[1] = vPoolPreq_f32<P>(x1.val[0], x1.val[1]);
y2.val[0] = vPoolPreq_f32<P>(x0.val[0], x2.val[0]);
y2.val[1] = vPoolPreq_f32<P>(x0.val[1], x2.val[1]);
y1.val[0] = vPoolPreq_f32<P>(y2.val[0], y1.val[0]);
y1.val[1] = vPoolPreq_f32<P>(y2.val[1], y1.val[1]);
y1.val[0] = vPoolPostq_f32<P>(y1.val[0], post);
vst1q_f32(output_ptr1, y1.val[0]);
y1.val[1] = vPoolPostq_f32<P>(y1.val[1], post);
x0 = vld2q_f32(input_ptr5);
x1.val[0] = vdupq_n_f32(input_ptr5[8]);
x1 = vld2q_f32(input_ptr5 + 8);
x2.val[0] = vextq_f32(x0.val[0], x1.val[0], 1);
x2.val[1] = vextq_f32(x1.val[0], x1.val[0], 1);
x0.val[0] = vPoolPreq_f32<P>(x0.val[0], x0.val[1]);
x0.val[1] = vPoolPreq_f32<P>(x1.val[0], x1.val[1]);
x0.val[0] = vPoolPreq_f32<P>(x0.val[0], x2.val[0]);
y0.val[0] = vPoolPreq_f32<P>(x0.val[0], y0.val[0]);
x0.val[1] = vPoolPreq_f32<P>(x0.val[1], x2.val[1]);
y2.val[0] = vPoolPreq_f32<P>(x0.val[0], y2.val[0]);
y2.val[1] = vPoolPreq_f32<P>(x0.val[1], y2.val[1]);
x0 = vld2q_f32(input_ptr6);
x1.val[0] = vdupq_n_f32(input_ptr6[8]);
x1 = vld2q_f32(input_ptr6 + 8);
x2.val[0] = vextq_f32(x0.val[0], x1.val[0], 1);
x2.val[1] = vextq_f32(x1.val[0], x1.val[0], 1);
x0.val[0] = vPoolPreq_f32<P>(x0.val[0], x0.val[1]);
x0.val[1] = vPoolPreq_f32<P>(x1.val[0], x1.val[1]);
x0.val[0] = vPoolPreq_f32<P>(x0.val[0], x2.val[0]);
y0.val[0] = vPoolPreq_f32<P>(x0.val[0], y0.val[0]);
y0.val[0] = vPoolPostq_f32<P>(y0.val[0], post);
vst1q_f32(output_ptr2, y0.val[0]);
x0.val[1] = vPoolPreq_f32<P>(x0.val[1], x2.val[1]);
y2.val[0] = vPoolPreq_f32<P>(x0.val[0], y2.val[0]);
y2.val[1] = vPoolPreq_f32<P>(x0.val[1], y2.val[1]);
y2.val[0] = vPoolPostq_f32<P>(y2.val[0], post);
y2.val[1] = vPoolPostq_f32<P>(y2.val[1], post);
input_ptr0 += 8;
input_ptr1 += 8;
input_ptr2 += 8;
input_ptr3 += 8;
input_ptr4 += 8;
input_ptr5 += 8;
input_ptr6 += 8;
output_ptr0 += 4;
output_ptr1 += 4;
output_ptr2 += 4;
remain -= 4;
switch (output_w_remain) {
case 1:
vst1q_lane_f32(output_ptr0, y0.val[0], 0);
vst1q_lane_f32(output_ptr1, y1.val[0], 0);
vst1q_lane_f32(output_ptr2, y2.val[0], 0);
break;
case 2:
vst1_f32(output_ptr0, vget_low_f32(y0.val[0]));
vst1_f32(output_ptr1, vget_low_f32(y1.val[0]));
vst1_f32(output_ptr2, vget_low_f32(y2.val[0]));
break;
case 3:
vst1_f32(output_ptr0, vget_low_f32(y0.val[0]));
vst1_f32(output_ptr1, vget_low_f32(y1.val[0]));
vst1_f32(output_ptr2, vget_low_f32(y2.val[0]));
vst1q_lane_f32(output_ptr0 + 2, y0.val[0], 2);
vst1q_lane_f32(output_ptr1 + 2, y1.val[0], 2);
vst1q_lane_f32(output_ptr2 + 2, y2.val[0], 2);
break;
case 4:
vst1q_f32(output_ptr0, y0.val[0]);
vst1q_f32(output_ptr1, y1.val[0]);
vst1q_f32(output_ptr2, y2.val[0]);
break;
case 5:
vst1q_f32(output_ptr0, y0.val[0]);
vst1q_f32(output_ptr1, y1.val[0]);
vst1q_f32(output_ptr2, y2.val[0]);
vst1q_lane_f32(output_ptr0 + 4, y0.val[1], 0);
vst1q_lane_f32(output_ptr1 + 4, y1.val[1], 0);
vst1q_lane_f32(output_ptr2 + 4, y2.val[1], 0);
break;
}
input_ptr0 += (output_w_remain << 1);
input_ptr1 += (output_w_remain << 1);
input_ptr2 += (output_w_remain << 1);
input_ptr3 += (output_w_remain << 1);
input_ptr4 += (output_w_remain << 1);
input_ptr5 += (output_w_remain << 1);
input_ptr6 += (output_w_remain << 1);
output_ptr0 += output_w_remain;
output_ptr1 += output_w_remain;
output_ptr2 += output_w_remain;
}
#endif // __ARM_NEON__
for (int r = 0; r < remain; ++r) {
float m0 = PoolPre<P>(input_ptr0[2 * r], input_ptr0[2 * r + 1]);
m0 = PoolPre<P>(m0, input_ptr0[2 * r + 2]);
float m1 = PoolPre<P>(input_ptr1[2 * r], input_ptr1[2 * r + 1]);
m1 = PoolPre<P>(m1, input_ptr1[2 * r + 2]);
float m2 = PoolPre<P>(input_ptr2[2 * r], input_ptr2[2 * r + 1]);
m2 = PoolPre<P>(m2, input_ptr2[2 * r + 2]);
float m3 = PoolPre<P>(input_ptr3[2 * r], input_ptr3[2 * r + 1]);
m3 = PoolPre<P>(m3, input_ptr3[2 * r + 2]);
float m4 = PoolPre<P>(input_ptr4[2 * r], input_ptr4[2 * r + 1]);
m4 = PoolPre<P>(m4, input_ptr4[2 * r + 2]);
float m5 = PoolPre<P>(input_ptr5[2 * r], input_ptr5[2 * r + 1]);
m5 = PoolPre<P>(m5, input_ptr5[2 * r + 2]);
float m6 = PoolPre<P>(input_ptr6[2 * r], input_ptr6[2 * r + 1]);
m6 = PoolPre<P>(m6, input_ptr6[2 * r + 2]);
m0 = PoolPre<P>(PoolPre<P>(m0, m1), m2);
m1 = PoolPre<P>(PoolPre<P>(m2, m3), m4);
m2 = PoolPre<P>(PoolPre<P>(m4, m5), m6);
output_ptr0[r] = PoolPost<P>(m0, avg);
output_ptr1[r] = PoolPost<P>(m1, avg);
output_ptr2[r] = PoolPost<P>(m2, avg);
// pad right
if (padding_r > 0) {
float32x2_t row0 = vld1_f32(input_ptr0);
float32x2_t row1 = vld1_f32(input_ptr1);
float32x2_t row2 = vld1_f32(input_ptr2);
float32x2_t row3 = vld1_f32(input_ptr3);
float32x2_t row4 = vld1_f32(input_ptr4);
float32x2_t row5 = vld1_f32(input_ptr5);
float32x2_t row6 = vld1_f32(input_ptr6);
float32x2_t pad0 = vPoolInit_f32<P>();
float32x2_t acc0, acc1, acc2, post;
for (int w = valid_w_end; w < output_w; ++w) {
int padding = 2 * w + 3 - (padding_w + input_w);
if (padding >= 3) {
*output_ptr0 = 0.f;
*output_ptr1 = 0.f;
*output_ptr2 = 0.f;
} else {
post = vdup_n_f32(1.f / (3 * (3 - padding)));
acc0 = vPoolPre_f32<P>(row0, row1);
acc1 = vPoolPre_f32<P>(row2, row3);
acc2 = vPoolPre_f32<P>(row4, row5);
acc0 = vPoolPre_f32<P>(acc0, row2);
acc1 = vPoolPre_f32<P>(acc1, row4);
acc2 = vPoolPre_f32<P>(acc2, row6);
if (padding == 1) {
acc0 = vpPoolPre_f32<P>(acc0, acc0);
acc1 = vpPoolPre_f32<P>(acc1, acc1);
acc2 = vpPoolPre_f32<P>(acc2, acc2);
}
acc0 = vPoolPost_f32<P>(acc0, post);
acc1 = vPoolPost_f32<P>(acc1, post);
acc2 = vPoolPost_f32<P>(acc2, post);
vst1_lane_f32(output_ptr0, acc0, 0);
vst1_lane_f32(output_ptr1, acc1, 0);
vst1_lane_f32(output_ptr2, acc2, 0);
}
output_ptr0++;
output_ptr1++;
output_ptr2++;
}
}
}
// remain height
int start_h = valid_h_start + valid_h / 3 * 3;
for (int h = start_h; h < valid_h_end; ++h) {
size_t offset = (2 * h - padding_h) * input_w + input_w_start;
const float *input_ptr0 = input_ptr + offset;
const float *input_ptr0 = input_ptr + (2 * h - padding_h) * input_w;
const float *input_ptr1 = input_ptr0 + input_w;
const float *input_ptr2 = input_ptr1 + input_w;
float *output_ptr0 = output_ptr + h * output_w + valid_w_start;
int remain = output_w_remain;
#if defined(__ARM_NEON__) || defined(__ARM_NEON)
float *output_ptr0 = output_ptr + h * output_w;
// pad left
if (padding_w) {
float32x2_t row0 = vld1_f32(input_ptr0);
float32x2_t row1 = vld1_f32(input_ptr1);
float32x2_t row2 = vld1_f32(input_ptr2);
float32x2_t pad0 = vPoolInit_f32<P>();
float32x2_t acc0, post;
for (int w = valid_w_start - 1; w >= 0; --w) {
int padding = padding_w - (w << 1);
if (padding >= 3) {
output_ptr0[w] = 0.f;
} else {
post = vdup_n_f32(1.f / (3 * (3 - padding)));
acc0 = vPoolPre_f32<P>(row0, row1);
acc0 = vPoolPre_f32<P>(acc0, row2);
if (padding == 1) {
acc0 = vpPoolPre_f32<P>(acc0, acc0);
}
acc0 = vPoolPost_f32<P>(acc0, post);
vst1_lane_f32(output_ptr0 + w, acc0, 0);
}
}
input_ptr0 += valid_input_w_start;
input_ptr1 += valid_input_w_start;
input_ptr2 += valid_input_w_start;
output_ptr0 += valid_w_start;
}
// valid
float32x4x2_t x0, x1, x2, y0;
float32x4_t post = vdupq_n_f32(1.f / 9);
for (int loop = 0; loop < output_w_tiles; ++loop) {
......@@ -969,48 +1201,94 @@ struct Pooling3x3<P, 2> {
output_ptr0 += 6;
}
// remain width
if (remain >= 4) {
if (output_w_remain > 0) {
x0 = vld2q_f32(input_ptr0);
x1.val[0] = vdupq_n_f32(input_ptr0[8]);
x1 = vld2q_f32(input_ptr0 + 8);
x2.val[0] = vextq_f32(x0.val[0], x1.val[0], 1);
x2.val[1] = vextq_f32(x1.val[0], x1.val[0], 1);
x0.val[0] = vPoolPreq_f32<P>(x0.val[0], x0.val[1]);
x0.val[1] = vPoolPreq_f32<P>(x1.val[0], x1.val[1]);
y0.val[0] = vPoolPreq_f32<P>(x0.val[0], x2.val[0]);
y0.val[1] = vPoolPreq_f32<P>(x0.val[1], x2.val[1]);
x0 = vld2q_f32(input_ptr1);
x1.val[0] = vdupq_n_f32(input_ptr1[8]);
x1 = vld2q_f32(input_ptr1 + 8);
x2.val[0] = vextq_f32(x0.val[0], x1.val[0], 1);
x2.val[1] = vextq_f32(x1.val[0], x1.val[0], 1);
x0.val[0] = vPoolPreq_f32<P>(x0.val[0], x0.val[1]);
x0.val[1] = vPoolPreq_f32<P>(x1.val[0], x1.val[1]);
x0.val[0] = vPoolPreq_f32<P>(x0.val[0], x2.val[0]);
x0.val[1] = vPoolPreq_f32<P>(x0.val[1], x2.val[1]);
y0.val[0] = vPoolPreq_f32<P>(x0.val[0], y0.val[0]);
y0.val[1] = vPoolPreq_f32<P>(x0.val[1], y0.val[1]);
x0 = vld2q_f32(input_ptr2);
x1.val[0] = vdupq_n_f32(input_ptr2[8]);
x1 = vld2q_f32(input_ptr2 + 8);
x2.val[0] = vextq_f32(x0.val[0], x1.val[0], 1);
x2.val[1] = vextq_f32(x1.val[0], x1.val[0], 1);
x0.val[0] = vPoolPreq_f32<P>(x0.val[0], x0.val[1]);
x0.val[1] = vPoolPreq_f32<P>(x1.val[0], x1.val[1]);
x0.val[0] = vPoolPreq_f32<P>(x0.val[0], x2.val[0]);
x0.val[1] = vPoolPreq_f32<P>(x0.val[1], x2.val[1]);
y0.val[0] = vPoolPreq_f32<P>(x0.val[0], y0.val[0]);
y0.val[1] = vPoolPreq_f32<P>(x0.val[1], y0.val[1]);
y0.val[0] = vPoolPostq_f32<P>(y0.val[0], post);
vst1q_f32(output_ptr0, y0.val[0]);
input_ptr0 += 8;
input_ptr1 += 8;
input_ptr2 += 8;
output_ptr0 += 4;
remain -= 4;
y0.val[1] = vPoolPostq_f32<P>(y0.val[1], post);
// restore
switch (output_w_remain) {
case 1:
vst1q_lane_f32(output_ptr0, y0.val[0], 0);
break;
case 2:
vst1_f32(output_ptr0, vget_low_f32(y0.val[0]));
break;
case 3:
vst1_f32(output_ptr0, vget_low_f32(y0.val[0]));
vst1q_lane_f32(output_ptr0 + 2, y0.val[0], 2);
break;
case 4:
vst1q_f32(output_ptr0, y0.val[0]);
break;
case 5:
vst1q_f32(output_ptr0, y0.val[0]);
vst1q_lane_f32(output_ptr0 + 4, y0.val[1], 0);
break;
}
input_ptr0 += (output_w_remain << 1);
input_ptr1 += (output_w_remain << 1);
input_ptr2 += (output_w_remain << 1);
output_ptr0 += output_w_remain;
}
#endif // __ARM_NEON__
for (int r = 0; r < remain; ++r) {
float m0 = PoolPre<P>(input_ptr0[2 * r], input_ptr0[2 * r + 1]);
m0 = PoolPre<P>(m0, input_ptr0[2 * r + 2]);
float m1 = PoolPre<P>(input_ptr1[2 * r], input_ptr1[2 * r + 1]);
m1 = PoolPre<P>(m1, input_ptr1[2 * r + 2]);
float m2 = PoolPre<P>(input_ptr2[2 * r], input_ptr2[2 * r + 1]);
m2 = PoolPre<P>(m2, input_ptr2[2 * r + 2]);
m0 = PoolPre<P>(PoolPre<P>(m0, m1), m2);
output_ptr0[r] = PoolPost<P>(m0, avg);
// pad right
if (padding_r > 0) {
float32x2_t row0 = vld1_f32(input_ptr0);
float32x2_t row1 = vld1_f32(input_ptr1);
float32x2_t row2 = vld1_f32(input_ptr2);
float32x2_t pad0 = vPoolInit_f32<P>();
float32x2_t acc0, post;
for (int w = valid_w_end; w < output_w; ++w) {
int padding = 2 * w + 3 - (padding_w + input_w);
if (padding >= 3) {
*output_ptr0 = 0.f;
} else {
post = vdup_n_f32(1.f / (3 * (3 - padding)));
acc0 = vPoolPre_f32<P>(row0, row1);
acc0 = vPoolPre_f32<P>(acc0, row2);
if (padding == 1) {
acc0 = vpPoolPre_f32<P>(acc0, acc0);
}
acc0 = vPoolPost_f32<P>(acc0, post);
vst1_lane_f32(output_ptr0, acc0, 0);
}
output_ptr0++;
}
}
}
// bottom
for (int h = valid_h_end; h < output_h; ++h) {
Pooling3x3NormalRow<P, 2>(input_ptr, h, input_h, input_w, padding_h,
padding_w, output_w, output_ptr);
}
}
}
}
......@@ -1025,4 +1303,5 @@ template struct Pooling3x3<AVG, 2>;
} // namespace operators
} // namespace paddle_mobile
#endif // __ARM_NEON
#endif // POOL_OP
......@@ -56,6 +56,9 @@ inline int32x4_t vRoundq_f32(const float32x4_t &x) {
template <>
inline int32x4_t vRoundq_f32<ROUND_NEAREST_AWAY_ZERO>(const float32x4_t &x) {
#if __aarch64__
return vcvtaq_s32_f32(x);
#else
float32x4_t plus = vdupq_n_f32(0.5);
float32x4_t minus = vdupq_n_f32(-0.5);
float32x4_t zero = vdupq_n_f32(0);
......@@ -64,10 +67,14 @@ inline int32x4_t vRoundq_f32<ROUND_NEAREST_AWAY_ZERO>(const float32x4_t &x) {
temp = vaddq_f32(x, temp);
int32x4_t ret = vcvtq_s32_f32(temp);
return ret;
#endif
}
template <>
inline int32x4_t vRoundq_f32<ROUND_NEAREST_TO_EVEN>(const float32x4_t &x) {
#if __aarch64__
return vcvtnq_s32_f32(x);
#else
float32x4_t point5 = vdupq_n_f32(0.5);
int32x4_t one = vdupq_n_s32(1);
int32x4_t zero = vdupq_n_s32(0);
......@@ -90,6 +97,7 @@ inline int32x4_t vRoundq_f32<ROUND_NEAREST_TO_EVEN>(const float32x4_t &x) {
smask = vsubq_s32(smask, one);
rnd = vaddq_s32(rnd, smask);
return rnd;
#endif
}
#endif // __ARM_NEON__
......
......@@ -424,8 +424,10 @@ class ConvParam : public OpParam {
EXEC_DEPTHWISE3x3_FLOAT,
EXEC_WINOGRAD3X3_FLOAT,
EXEC_WINOGRAD5X5_FLOAT,
EXEC_DEPTHWISE5x5_FLOAT,
EXEC_GEMM_INT8,
EXEC_DEPTHWISE3x3_INT8,
EXEC_DEPTHWISE5x5_INT8,
};
ExecMode &ExecMode() const { return exec_mode_; }
......@@ -2605,8 +2607,8 @@ class QuantizeParam : public OpParam {
// if offine scale or not
bool offline_ = false;
// round method type
// RoundType round_type_ = ROUND_NEAREST_AWAY_ZERO;
RoundType round_type_ = ROUND_NEAREST_TOWARDS_ZERO;
RoundType round_type_ = ROUND_NEAREST_AWAY_ZERO;
// RoundType round_type_ = ROUND_NEAREST_TOWARDS_ZERO;
};
#endif
......
......@@ -165,14 +165,12 @@ int TestConvOp(int in_channels, int in_height, int in_width, int out_channels,
auto filter = filter_var->template GetMutable<framework::LoDTensor>();
SetupTensor<Itype>(filter, filter_shape, -20, 20);
for (int i = 0; i < input->numel(); ++i) {
DLOG << "input[" << i
<< "] = " << static_cast<int>(input->data<int8_t>()[i]);
}
for (int i = 0; i < filter->numel(); ++i) {
DLOG << "filter[" << i
<< "] = " << static_cast<int>(filter->data<int8_t>()[i]);
}
// for (int i = 0; i < input->numel(); ++i) {
// DLOG << "input[" << i << "] = " << float(input->data<Itype>()[i]);
// }
// for (int i = 0; i < filter->numel(); ++i) {
// DLOG << "filter[" << i << "] = " << float(filter->data<Itype>()[i]);
// }
auto output_var = scope.get()->Var("output");
framework::AttributeMap attrs;
......@@ -198,18 +196,12 @@ int TestConvOp(int in_channels, int in_height, int in_width, int out_channels,
// (ts_end.tv_nsec - ts_begin.tv_nsec) / 1e6;
// LOG(kLOG_INFO) << "elapsed: " << elapsed / 10.0 << " ms";
int kernel_extent_h = dilation_h * (kernel_h - 1) + 1;
int kernel_extent_w = dilation_w * (kernel_w - 1) + 1;
int output_h = (input_h + 2 * pad_h - kernel_extent_h) / stride_h + 1;
int output_w = (input_w + 2 * pad_w - kernel_extent_w) / stride_w + 1;
auto output_shape = framework::make_ddim(
std::vector<int>({batch_size, output_c, output_h, output_w}));
// compare results
auto *output = output_var->template Get<framework::LoDTensor>();
framework::Tensor output_cmp;
output_cmp.mutable_data<Otype>(output_shape);
output_cmp.mutable_data<Otype>(output->dims());
conv2d<Itype, Otype>(input, filter, attrs, &output_cmp);
// compare results
auto output = output_var->template Get<framework::LoDTensor>();
const Otype *output_data = output->data<Otype>();
Otype *output_cmp_data = output_cmp.data<Otype>();
for (int i = 0; i < output->numel(); ++i) {
......@@ -285,96 +277,39 @@ int main(int argc, char *argv[]) {
paddle_mobile::TestConvOp<int8_t, int32_t, 3, 5, 2>(
in_channels, in_height, in_width, out_channels, groups);
// // kernel = 7, pad = 0, stride = 2
// LOG(paddle_mobile::kLOG_INFO) << "int8, kernel=7, pad=0, stride=2";
// paddle_mobile::TestConvOp<int8_t, int32_t, 7, 0, 2>(in_channels,
// in_height,
// in_width,
// out_channels, groups);
// // kernel = 7, pad = 1, stride = 2
// LOG(paddle_mobile::kLOG_INFO) << "int8, kernel=7, pad=1, stride=2";
// paddle_mobile::TestConvOp<int8_t, int32_t, 7, 1, 2>(in_channels,
// in_height,
// in_width,
// out_channels, groups);
// // kernel = 7, pad = 3, stride = 2
// LOG(paddle_mobile::kLOG_INFO) << "int8, kernel=7, pad=3, stride=2";
// paddle_mobile::TestConvOp<int8_t, int32_t, 7, 3, 2>(in_channels,
// in_height,
// in_width,
// out_channels, groups);
// // kernel = 7, pad = 0, stride = 1
// LOG(paddle_mobile::kLOG_INFO) << "int8, kernel=7, pad=0, stride=1";
// paddle_mobile::TestConvOp<int8_t, int32_t, 7, 0, 1>(in_channels,
// in_height,
// in_width,
// out_channels, groups);
// // kernel = 7, pad = 1, stride = 1
// LOG(paddle_mobile::kLOG_INFO) << "int8, kernel=7, pad=1, stride=1";
// paddle_mobile::TestConvOp<int8_t, int32_t, 7, 1, 1>(in_channels,
// in_height,
// in_width,
// out_channels, groups);
// // kernel = 7, pad = 3, stride = 1
// LOG(paddle_mobile::kLOG_INFO) << "int8, kernel=7, pad=3, stride=1";
// paddle_mobile::TestConvOp<int8_t, int32_t, 7, 3, 1>(in_channels,
// in_height,
// in_width,
// out_channels, groups);
// // kernel = 7, pad = 5, stride = 3
// LOG(paddle_mobile::kLOG_INFO) << "int8, kernel=7, pad=5, stride=3";
// paddle_mobile::TestConvOp<int8_t, int32_t, 7, 5, 3>(in_channels,
// in_height,
// in_width,
// out_channels, groups);
// // kernel = 7, pad = 3, stride = 4
// LOG(paddle_mobile::kLOG_INFO) << "int8, kernel=7, pad=3, stride=4";
// paddle_mobile::TestConvOp<int8_t, int32_t, 7, 3, 4>(in_channels,
// in_height,
// in_width,
// out_channels, groups);
// // kernel = 3, pad = 0, stride = 1
// LOG(paddle_mobile::kLOG_INFO) << "int8, kernel=3, pad=0, stride=1";
// paddle_mobile::TestConvOp<int8_t, int32_t, 3, 0, 1>(in_channels,
// in_height,
// in_width,
// out_channels, groups);
// // kernel = 3, pad = 0, stride = 1
// LOG(paddle_mobile::kLOG_INFO) << "float, kernel=3, pad=0, stride=1";
// paddle_mobile::TestConvOp<float, float, 3, 0, 1>(in_channels, in_height,
// in_width, out_channels,
// groups);
// // kernel = 3, pad = 1, stride = 1
// LOG(paddle_mobile::kLOG_INFO) << "int8, kernel=3, pad=1, stride=1";
// paddle_mobile::TestConvOp<int8_t, int32_t, 3, 1, 1>(in_channels,
// in_height,
// in_width,
// out_channels, groups);
// // kernel = 3, pad = 1, stride = 1
// LOG(paddle_mobile::kLOG_INFO) << "float, kernel=3, pad=1, stride=1";
// paddle_mobile::TestConvOp<float, float, 3, 1, 1>(in_channels, in_height,
// in_width, out_channels,
// groups);
// // kernel = 5, pad = 0, stride = 1
// LOG(paddle_mobile::kLOG_INFO) << "int8, kernel=5, pad=0, stride=1";
// paddle_mobile::TestConvOp<int8_t, int32_t, 5, 0, 1>(in_channels,
// in_height,
// in_width,
// out_channels, groups);
// // kernel = 5, pad = 0, stride = 1
// LOG(paddle_mobile::kLOG_INFO) << "float, kernel=5, pad=0, stride=1";
// paddle_mobile::TestConvOp<float, float, 5, 0, 1>(in_channels, in_height,
// in_width, out_channels,
// groups);
// // kernel = 5, pad = 2, stride = 1
// LOG(paddle_mobile::kLOG_INFO) << "int8, kernel=5, pad=2, stride=1";
// paddle_mobile::TestConvOp<int8_t, int32_t, 5, 2, 1>(in_channels,
// in_height,
// in_width,
// out_channels, groups);
// // kernel = 5, pad = 2, stride = 1
// LOG(paddle_mobile::kLOG_INFO) << "float, kernel=5, pad=2, stride=1";
// paddle_mobile::TestConvOp<float, float, 5, 2, 1>(in_channels, in_height,
// in_width, out_channels,
// groups);
// kernel = 5, pad = 0, stride = 1
LOG(paddle_mobile::kLOG_INFO) << "float, kernel=5, pad=0, stride=1";
paddle_mobile::TestConvOp<float, float, 5, 0, 1>(
in_channels, in_height, in_width, out_channels, groups);
// kernel = 5, pad = 1, stride = 1
LOG(paddle_mobile::kLOG_INFO) << "float, kernel=5, pad=1, stride=1";
paddle_mobile::TestConvOp<float, float, 5, 1, 1>(
in_channels, in_height, in_width, out_channels, groups);
// kernel = 5, pad = 2, stride = 1
LOG(paddle_mobile::kLOG_INFO) << "float, kernel=5, pad=2, stride=1";
paddle_mobile::TestConvOp<float, float, 5, 2, 1>(
in_channels, in_height, in_width, out_channels, groups);
// kernel = 5, pad = 5, stride = 1
LOG(paddle_mobile::kLOG_INFO) << "float, kernel=5, pad=5, stride=1";
paddle_mobile::TestConvOp<float, float, 5, 5, 1>(
in_channels, in_height, in_width, out_channels, groups);
// kernel = 5, pad = 0, stride = 1
LOG(paddle_mobile::kLOG_INFO) << "int8, kernel=5, pad=0, stride=1";
paddle_mobile::TestConvOp<int8_t, int32_t, 5, 0, 1>(
in_channels, in_height, in_width, out_channels, groups);
// kernel = 5, pad = 1, stride = 1
LOG(paddle_mobile::kLOG_INFO) << "int8, kernel=5, pad=1, stride=1";
paddle_mobile::TestConvOp<int8_t, int32_t, 5, 1, 1>(
in_channels, in_height, in_width, out_channels, groups);
// kernel = 5, pad = 2, stride = 1
LOG(paddle_mobile::kLOG_INFO) << "int8, kernel=5, pad=2, stride=1";
paddle_mobile::TestConvOp<int8_t, int32_t, 5, 2, 1>(
in_channels, in_height, in_width, out_channels, groups);
// kernel = 5, pad = 5, stride = 1
LOG(paddle_mobile::kLOG_INFO) << "int8, kernel=5, pad=5, stride=1";
paddle_mobile::TestConvOp<int8_t, int32_t, 5, 5, 1>(
in_channels, in_height, in_width, out_channels, groups);
return 0;
}
......@@ -169,28 +169,55 @@ int main(int argc, char *argv[]) {
<< "float, pooling_type=avg, kernel=3, pad=5, stride=2";
paddle_mobile::TestPoolOp<1, 3, 5, 2>(in_channels, in_height, in_width);
// // kernel = 5, pad = 0, stride = 1
// LOG(paddle_mobile::kLOG_INFO)
// << "float, ceil_mode=false, pooling_type=avg, kernel=5, pad=0,
// stride=1";
// paddle_mobile::TestPoolOp<float, 0, 1, 5, 0, 1>(in_channels, in_height,
// in_width);
// // kernel = 5, pad = 0, stride = 2
// LOG(paddle_mobile::kLOG_INFO)
// << "float, ceil_mode=false, pooling_type=avg, kernel=5, pad=0,
// stride=1";
// paddle_mobile::TestPoolOp<float, 0, 1, 5, 0, 2>(in_channels, in_height,
// in_width);
// // kernel = 7, pad = 0, stride = 1
// LOG(paddle_mobile::kLOG_INFO)
// << "float, ceil_mode=false, pooling_type=avg, kernel=7, pad=0,
// stride=1";
// paddle_mobile::TestPoolOp<float, 0, 1, 7, 0, 1>(in_channels, in_height,
// in_width);
// // kernel = 7, pad = 0, stride = 4
// LOG(paddle_mobile::kLOG_INFO)
// << "float, ceil_mode=false, pooling_type=avg, kernel=7, pad=0,
// stride=4";
// paddle_mobile::TestPoolOp<float, 0, 1, 7, 0, 4>(in_channels, in_height,
// in_width);
// LOG(paddle_mobile::kLOG_INFO)
// << "float, pooling_type=max, kernel=2, pad=0, stride=1";
// paddle_mobile::TestPoolOp<0, 2, 0, 1>(in_channels, in_height, in_width);
// LOG(paddle_mobile::kLOG_INFO)
// << "float, pooling_type=max, kernel=2, pad=1, stride=1";
// paddle_mobile::TestPoolOp<0, 2, 1, 1>(in_channels, in_height, in_width);
// LOG(paddle_mobile::kLOG_INFO)
// << "float, pooling_type=max, kernel=2, pad=2, stride=1";
// paddle_mobile::TestPoolOp<0, 2, 2, 1>(in_channels, in_height, in_width);
// LOG(paddle_mobile::kLOG_INFO)
// << "float, pooling_type=max, kernel=2, pad=5, stride=1";
// paddle_mobile::TestPoolOp<0, 2, 5, 1>(in_channels, in_height, in_width);
//
// LOG(paddle_mobile::kLOG_INFO)
// << "float, pooling_type=avg, kernel=2, pad=0, stride=1";
// paddle_mobile::TestPoolOp<1, 2, 0, 1>(in_channels, in_height, in_width);
// LOG(paddle_mobile::kLOG_INFO)
// << "float, pooling_type=avg, kernel=2, pad=1, stride=1";
// paddle_mobile::TestPoolOp<1, 2, 1, 1>(in_channels, in_height, in_width);
// LOG(paddle_mobile::kLOG_INFO)
// << "float, pooling_type=avg, kernel=2, pad=2, stride=1";
// paddle_mobile::TestPoolOp<1, 2, 2, 1>(in_channels, in_height, in_width);
// LOG(paddle_mobile::kLOG_INFO)
// << "float, pooling_type=avg, kernel=2, pad=5, stride=1";
// paddle_mobile::TestPoolOp<1, 2, 5, 1>(in_channels, in_height, in_width);
//
// LOG(paddle_mobile::kLOG_INFO)
// << "float, pooling_type=max, kernel=2, pad=0, stride=2";
// paddle_mobile::TestPoolOp<0, 2, 0, 2>(in_channels, in_height, in_width);
// LOG(paddle_mobile::kLOG_INFO)
// << "float, pooling_type=max, kernel=2, pad=1, stride=2";
// paddle_mobile::TestPoolOp<0, 2, 1, 2>(in_channels, in_height, in_width);
// LOG(paddle_mobile::kLOG_INFO)
// << "float, pooling_type=max, kernel=2, pad=2, stride=2";
// paddle_mobile::TestPoolOp<0, 2, 2, 2>(in_channels, in_height, in_width);
// LOG(paddle_mobile::kLOG_INFO)
// << "float, pooling_type=max, kernel=2, pad=5, stride=2";
// paddle_mobile::TestPoolOp<0, 2, 5, 2>(in_channels, in_height, in_width);
//
// LOG(paddle_mobile::kLOG_INFO)
// << "float, pooling_type=avg, kernel=2, pad=0, stride=2";
// paddle_mobile::TestPoolOp<1, 2, 0, 2>(in_channels, in_height, in_width);
// LOG(paddle_mobile::kLOG_INFO)
// << "float, pooling_type=avg, kernel=2, pad=1, stride=2";
// paddle_mobile::TestPoolOp<1, 2, 1, 2>(in_channels, in_height, in_width);
// LOG(paddle_mobile::kLOG_INFO)
// << "float, pooling_type=avg, kernel=2, pad=2, stride=2";
// paddle_mobile::TestPoolOp<1, 2, 2, 2>(in_channels, in_height, in_width);
// LOG(paddle_mobile::kLOG_INFO)
// << "float, pooling_type=avg, kernel=2, pad=5, stride=2";
// paddle_mobile::TestPoolOp<1, 2, 5, 2>(in_channels, in_height, in_width);
}
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册