Commit 8137d199 authored by qnqinan

Merge remote-tracking branch 'upstream/develop' into develop

...@@ -5,6 +5,7 @@ option(DEBUGING "enable debug mode" ON)
option(USE_EXCEPTION "use std exception" ON)
option(SYMBOL_HIDDEN "symbol hidden" OFF) # on when use jni or ios io
option(LOG_PROFILE "log profile" OFF)
# select the platform to build
option(CPU "armv7 with neon" ON)
option(GPU_MALI "mali gpu" OFF)
...@@ -15,7 +16,6 @@ if(FPGA)
option(FPGAV2 "fpga v2" OFF)
endif()
project(paddle-mobile)
file(GLOB_RECURSE PADDLE_MOBILE_CC src/*.cc src/*.cpp src/*.c src/*.mm)
...@@ -247,6 +247,3 @@ elseif(FPGA)
add_subdirectory(test)
endif()
...@@ -71,6 +71,8 @@ const char *G_OP_TYPE_SUM = "sum";
const char *G_OP_TYPE_QUANTIZE = "quantize";
const char *G_OP_TYPE_DEQUANTIZE = "dequantize";
const char *G_OP_TYPE_FUSION_DEQUANT_ADD_BN_RELU = "fusion_dequant_add_bn_relu";
const char *G_OP_TYPE_TANH = "tanh";
const char *G_OP_TYPE_FUSION_DECONV_RELU = "fusion_deconv_relu";
const char *G_OP_TYPE_FUSION_DECONV_ADD = "fusion_deconv_add";
...@@ -134,6 +136,7 @@ std::unordered_map<
{G_OP_TYPE_ELEMENTWISE_MUL, {{"X", "Y"}, {"Out"}}},
{G_OP_TYPE_QUANTIZE, {{"X"}, {"Out", "OutScale"}}},
{G_OP_TYPE_DEQUANTIZE, {{"X", "Scale"}, {"Out"}}},
{G_OP_TYPE_FUSION_DEQUANT_ADD_BN_RELU, {{"X", "Scale"}, {"Out"}}},
{G_OP_TYPE_TANH, {{"X"}, {"Out"}}},
{G_OP_TYPE_FUSION_DECONV_RELU, {{"Input"}, {"Out"}}},
{G_OP_TYPE_FUSION_DECONV_ADD, {{"Input"}, {"Out"}}},
......
...@@ -138,6 +138,7 @@ extern const char *G_OP_TYPE_ELEMENTWISE_MUL;
extern const char *G_OP_TYPE_QUANTIZE;
extern const char *G_OP_TYPE_DEQUANTIZE;
extern const char *G_OP_TYPE_FUSION_DEQUANT_ADD_BN_RELU;
extern const char *G_OP_TYPE_TANH;
extern const char *G_OP_TYPE_FUSION_DECONV_RELU;
......
...@@ -30,7 +30,6 @@ limitations under the License. */
#ifdef PADDLE_EXECUTOR_MULTITHREAD
#include <queue>
#include <utility>
#include "common/threadpool.h"
#endif
...@@ -73,7 +72,7 @@ Executor<Dtype, P>::Executor(const framework::Program<Dtype> p, int batch_size,
op->Type(), op->GetInputs(), op->GetOutputs(), op->GetAttrMap(),
program_.scope);
// infer shape to reshape tensor before predict,
// but for lod tensor, it will still need to reshape in runtime
if (!loddable_) {
op_base->InferShape();
}
......
...@@ -233,3 +233,7 @@ LOAD_OP1(quantize, CPU);
#ifdef DEQUANT_OP
LOAD_OP1(dequantize, CPU);
#endif
#ifdef FUSION_DEQUANT_ADD_BN_RELU_OP
LOAD_OP1(fusion_dequant_add_bn_relu, CPU);
LOAD_FUSION_MATCHER(fusion_dequant_add_bn_relu);
#endif
...@@ -127,11 +127,6 @@ class OperatorWithKernel : public OperatorBase<Dtype> {
virtual void InferShape() const = 0;
void Init() {
// for (auto i : this->inputs_) {
// DLOG << i.first;
// DLOG << i.second;
// }
PADDLE_MOBILE_ENFORCE(kernel_.Init(&param_), " %s kernel init failed",
this->type_.c_str());
}
......
...@@ -54,22 +54,6 @@ class Tensor : public TensorBase {
this->offset_ = inTensor.offset_;
}
#ifdef PADDLE_MOBILE_DEBUG
template <typename T>
inline void dump(std::string filename) const {
const T *dataptr = data<T>();
std::ofstream out(filename.c_str());
for (int i = 0; i < numel(); ++i) {
out << dataptr[i] << " ";
}
out << "shape: ";
for (int j = 0; j < dims_.size(); ++j) {
out << dims_[j] << " ";
}
out.close();
}
#endif
/*! Resize the dimensions of the memory block. */
inline Tensor &Resize(const DDim &dims) {
dims_ = dims;
......
...@@ -22,7 +22,7 @@ namespace operators {
template <typename DeviceType, typename T>
void DequantizeOp<DeviceType, T>::InferShape() const {
const auto& input_dims = this->param_.input_->dims();
this->param_.output_->Resize(input_dims);
}
} // namespace operators
......
...@@ -12,33 +12,29 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "../test_include.h"
#include "operators/conv_op.h"
int main() {
paddle_mobile::framework::Loader<paddle_mobile::GPU_MALI> loader;
// ../models/image_classification_resnet.inference.model
auto program = loader.Load(g_googlenet);
PADDLE_MOBILE_ENFORCE(program.originProgram != nullptr,
"program file read fail");
Executor4Test<paddle_mobile::GPU_MALI, paddle_mobile::operators::ConvOp<
paddle_mobile::GPU_MALI, float>>
executor(program, "conv2d");
paddle_mobile::framework::Tensor input;
GetInput<float>(g_test_image_1x3x224x224, &input, {1, 3, 224, 224});
// // use SetupTensor if not has local input image .
// SetupTensor<float>(&input, {1, 3, 224, 224}, static_cast<float>(0),
// static_cast<float>(1));
auto out_ddim = paddle_mobile::framework::make_ddim({1, 64, 112, 112});
auto output = executor.Predict(input, "data", "conv2d_0.tmp_0", out_ddim);
auto output_ptr = output->data<float>();
for (int j = 0; j < 20; ++j) {
DLOG << " value of output: " << output_ptr[j];
}
return 0;
}
#ifdef FUSION_DEQUANT_ADD_BN_RELU_OP
#include "operators/fusion_dequant_add_bn_relu_op.h"
namespace paddle_mobile {
namespace operators {
template <typename Dtype, typename T>
void FusionDequantAddBNReluOp<Dtype, T>::InferShape() const {
const auto& input_dims = this->param_.input_->dims();
this->param_.output_->Resize(input_dims);
}
} // namespace operators
} // namespace paddle_mobile
namespace ops = paddle_mobile::operators;
REGISTER_FUSION_MATCHER(fusion_dequant_add_bn_relu,
ops::FusionDequantAddBNReluMatcher);
#ifdef PADDLE_MOBILE_CPU
REGISTER_OPERATOR_CPU(fusion_dequant_add_bn_relu,
ops::FusionDequantAddBNReluOp);
#endif
#endif
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#ifdef FUSION_DEQUANT_ADD_BN_RELU_OP
#pragma once
#include <string>
#include <vector>
#include "framework/operator.h"
#include "framework/program/program-optimize/fusion_op_register.h"
#include "operators/kernel/dequant_add_bn_relu_kernel.h"
#include "operators/op_param.h"
namespace paddle_mobile {
namespace operators {
class FusionDequantAddBNReluMatcher : public framework::FusionOpMatcher {
public:
FusionDequantAddBNReluMatcher() {
node_ = framework::Node(G_OP_TYPE_DEQUANTIZE);
node_ > std::make_shared<framework::Node>(G_OP_TYPE_ELEMENTWISE_ADD) >
std::make_shared<framework::Node>(G_OP_TYPE_BATCHNORM) >
std::make_shared<framework::Node>(G_OP_TYPE_RELU);
}
void FolderNodes(
framework::Node *node,
std::vector<std::shared_ptr<framework::Node>> *removed_nodes) {
node->Folder(node_.Depth(), Type(),
{{G_OP_TYPE_ELEMENTWISE_ADD, {{"Y", "Y"}}},
{G_OP_TYPE_BATCHNORM,
{{"Scale", "BNScale"},
{"Mean", "BNMean"},
{"Bias", "BNBias"},
{"Variance", "BNVariance"}}}},
removed_nodes);
}
std::string Type() { return G_OP_TYPE_FUSION_DEQUANT_ADD_BN_RELU; }
};
template <typename DeviceType, typename T>
class FusionDequantAddBNReluOp
: public framework::OperatorWithKernel<
DeviceType, FusionDequantAddBNReluParam<DeviceType>,
operators::FusionDequantAddBNReluKernel<DeviceType, T>> {
public:
FusionDequantAddBNReluOp(const std::string &type,
const VariableNameMap &inputs,
const VariableNameMap &outputs,
const framework::AttributeMap &attrs,
std::shared_ptr<framework::Scope> scope)
: framework::OperatorWithKernel<
DeviceType, FusionDequantAddBNReluParam<DeviceType>,
operators::FusionDequantAddBNReluKernel<DeviceType, T>>(
type, inputs, outputs, attrs, scope) {}
// inference output shape
void InferShape() const override;
};
} // namespace operators
} // namespace paddle_mobile
#endif
...@@ -22,12 +22,76 @@ namespace operators {
template <>
bool ConvKernel<CPU, float>::Init(ConvParam<CPU> *param) {
if (param->Filter()->type() == typeid(int8_t)) {
if (param->Groups() == param->Input()->dims()[1] &&
param->Input()->dims()[1] == param->Output()->dims()[1] &&
param->Filter()->dims()[2] == param->Filter()->dims()[3] &&
param->Filter()->dims()[2] == 3 && param->Strides()[0] < 3 &&
param->Strides()[0] == param->Strides()[1]) {
param->ExecMode() = ConvParam<CPU>::EXEC_DEPTHWISE3x3_INT8;
} else {
param->ExecMode() = ConvParam<CPU>::EXEC_GEMM_INT8;
}
} else {
if (param->Groups() == param->Input()->dims()[1] &&
param->Input()->dims()[1] == param->Output()->dims()[1] &&
param->Filter()->dims()[2] == param->Filter()->dims()[3] &&
param->Filter()->dims()[2] == 3 && param->Strides()[0] == 1) {
param->ExecMode() = ConvParam<CPU>::EXEC_DEPTHWISE3x3S1P1_FLOAT;
} else if (param->Groups() == param->Input()->dims()[1] &&
param->Input()->dims()[1] == param->Output()->dims()[1] &&
param->Filter()->dims()[2] == param->Filter()->dims()[3] &&
param->Filter()->dims()[2] == 3) {
param->ExecMode() = ConvParam<CPU>::EXEC_DEPTHWISE3x3_FLOAT;
#ifndef __aarch64__
} else if (param->Filter()->dims()[2] == param->Filter()->dims()[3] &&
param->Strides()[0] == param->Strides()[1] &&
param->Dilations()[0] == param->Dilations()[1] &&
param->Filter()->dims()[2] == 3 && param->Strides()[0] == 1 &&
param->Dilations()[0] == 1 && param->Output()->dims()[1] >= 16 &&
param->Input()->dims()[1] >= 16 &&
param->Input()->dims()[2] <= 140 /* referred from ncnn */) {
param->ExecMode() = ConvParam<CPU>::EXEC_WINOGRAD3X3_FLOAT;
// transform weight
framework::Tensor *transformed_weight = new framework::Tensor;
operators::math::winograd_transform_weight<8, 3>(*param->Filter(),
transformed_weight);
param->Filter() = transformed_weight;
#endif
} else {
param->ExecMode() = ConvParam<CPU>::EXEC_GEMM_FLOAT;
}
}
return true;
}
template <>
void ConvKernel<CPU, float>::Compute(const ConvParam<CPU> &param) {
switch (param.ExecMode()) {
case ConvParam<CPU>::EXEC_GEMM_INT8:
GemmConv<int8_t, int32_t>(param);
break;
case ConvParam<CPU>::EXEC_DEPTHWISE3x3_INT8:
DepthwiseConv3x3<int8_t, int32_t>(param);
break;
case ConvParam<CPU>::EXEC_DEPTHWISE3x3S1P1_FLOAT:
math::DepthwiseConv3x3s1p1(param.Input(), param.Filter(), param.Output(),
nullptr, false);
break;
case ConvParam<CPU>::EXEC_DEPTHWISE3x3_FLOAT:
math::DepthwiseConv3x3(param.Input(), param.Strides(), param.Paddings(),
param.Filter(), nullptr, param.Output(), false);
break;
case ConvParam<CPU>::EXEC_WINOGRAD3X3_FLOAT:
WinogradConv3x3<8, 3>(param);
break;
case ConvParam<CPU>::EXEC_GEMM_FLOAT:
GemmConv<float, float>(param);
break;
default:
PADDLE_MOBILE_THROW_EXCEPTION("Invalid convolution execute mode %d",
param.ExecMode());
}
}
template class ConvKernel<CPU, float>;
......
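For readers scanning the dispatch above: the depthwise test in Init() reduces to a handful of shape checks. The helper below is an illustrative restatement only, not part of this patch; the function name and the NCHW layout assumption are mine.

// Illustrative sketch, not part of the patch: the shape conditions
// ConvKernel<CPU, float>::Init() uses to select a depthwise 3x3 path,
// assuming NCHW dims [batch, channels, height, width].
inline bool ChoosesDepthwise3x3(int groups, int in_channels, int out_channels,
                                int filter_h, int filter_w, int stride_h,
                                int stride_w) {
  return groups == in_channels &&        // one filter group per input channel
         in_channels == out_channels &&  // depthwise keeps the channel count
         filter_h == filter_w && filter_h == 3 &&
         stride_h == stride_w;  // the int8 path additionally requires stride < 3
}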
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#ifdef FUSION_DEQUANT_ADD_BN_RELU_OP
#include "operators/kernel/dequant_add_bn_relu_kernel.h"
#include <cmath>
#if defined(__ARM_NEON__) || defined(__ARM_NEON)
#include <arm_neon.h>
#endif
namespace paddle_mobile {
namespace operators {
template <>
bool FusionDequantAddBNReluKernel<CPU, float>::Init(
FusionDequantAddBNReluParam<CPU> *param) {
// elementwise add params
const Tensor *bias = param->bias_;
// batch norm params
const Tensor *bn_mean = param->bn_mean_;
const Tensor *bn_variance = param->bn_variance_;
Tensor *bn_scale = param->bn_scale_;
Tensor *bn_bias = param->bn_bias_;
const float epsilon = param->epsilon_;
const float *bias_ptr = bias->data<float>();
const float *mean_ptr = bn_mean->data<float>();
const float *var_ptr = bn_variance->data<float>();
float *bn_scale_ptr = bn_scale->mutable_data<float>();
float *bn_bias_ptr = bn_bias->mutable_data<float>();
for (int c = 0; c < bn_scale->numel(); ++c) {
float inv_scale = bn_scale_ptr[c] / (std::sqrt(var_ptr[c] + epsilon));
bn_scale_ptr[c] = inv_scale;
bn_bias_ptr[c] = inv_scale * (bias_ptr[c] - mean_ptr[c]) + bn_bias_ptr[c];
}
return true;
}
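// Folding summary: with gamma = bn_scale, beta = bn_bias, mu = bn_mean,
// var = bn_variance and b = the elementwise-add bias, the loop above rewrites
//   scale[c] = gamma[c] / sqrt(var[c] + epsilon)
//   bias[c]  = beta[c] + scale[c] * (b[c] - mu[c])
// so Compute() only needs one per-channel multiply-add followed by ReLU.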
template <>
void FusionDequantAddBNReluKernel<CPU, float>::Compute(
const FusionDequantAddBNReluParam<CPU> &param) {
const int32_t *input = param.input_->data<int32_t>();
const float *bn_scale = param.bn_scale_->data<float>();
const float *bn_bias = param.bn_bias_->data<float>();
// dequantize params
const float activation_scale = param.activation_scale_->data<float>()[0];
const float weight_scale = param.weight_scale_;
const float dequant_scale = activation_scale / weight_scale;
float *output = param.output_->mutable_data<float>();
int batch_size = param.input_->dims()[0];
int channels = param.input_->dims()[1];
size_t spatial_size = param.input_->dims()[2] * param.input_->dims()[3];
#pragma omp parallel for collapse(2)
for (int batch = 0; batch < batch_size; ++batch) {
for (int c = 0; c < channels; ++c) {
float scale = bn_scale[c] * dequant_scale;
float bias = bn_bias[c];
size_t offset = (batch * channels + c) * spatial_size;
const int32_t *x = input + offset;
float *y = output + offset;
size_t remain = spatial_size;
#if defined(__ARM_NEON__) || defined(__ARM_NEON)
int loop = spatial_size >> 4;
remain = spatial_size & 0xF;
float32x4_t __scale = vdupq_n_f32(scale);
float32x4_t __bias = vdupq_n_f32(bias);
float32x4_t __zero = vdupq_n_f32(0.f);
for (int k = 0; k < loop; ++k, x += 16, y += 16) {
int32x4_t r0 = vld1q_s32(x);
int32x4_t r1 = vld1q_s32(x + 4);
int32x4_t r2 = vld1q_s32(x + 8);
int32x4_t r3 = vld1q_s32(x + 12);
float32x4_t f0 = vcvtq_f32_s32(r0);
float32x4_t f1 = vcvtq_f32_s32(r1);
float32x4_t f2 = vcvtq_f32_s32(r2);
float32x4_t f3 = vcvtq_f32_s32(r3);
f0 = vmlaq_f32(__bias, __scale, f0);
f1 = vmlaq_f32(__bias, __scale, f1);
f2 = vmlaq_f32(__bias, __scale, f2);
f3 = vmlaq_f32(__bias, __scale, f3);
f0 = vmaxq_f32(__zero, f0);
f1 = vmaxq_f32(__zero, f1);
f2 = vmaxq_f32(__zero, f2);
f3 = vmaxq_f32(__zero, f3);
vst1q_f32(y, f0);
vst1q_f32(y + 4, f1);
vst1q_f32(y + 8, f2);
vst1q_f32(y + 12, f3);
}
#endif // __ARM_NEON__
for (int k = 0; k < remain; ++k) {
y[k] = std::max(scale * x[k] + bias, 0.f);
}
}
}
}
} // namespace operators
} // namespace paddle_mobile
#endif // FUSION_DEQUANT_ADD_BN_RELU_OP
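As a cross-check for the NEON loop above, a plain scalar version of the same per-element computation might look like the sketch below. It is illustrative only and not part of the patch; the function name is made up.

#include <algorithm>
#include <cstddef>
#include <cstdint>

// Scalar reference for the vectorized Compute() above: dequantize an int32
// NCHW tensor, apply the folded per-channel scale/bias, then ReLU.
void DequantAddBNReluReference(const int32_t *input, const float *bn_scale,
                               const float *bn_bias, float dequant_scale,
                               int batch, int channels, size_t spatial_size,
                               float *output) {
  for (int n = 0; n < batch; ++n) {
    for (int c = 0; c < channels; ++c) {
      const float scale = bn_scale[c] * dequant_scale;
      const float bias = bn_bias[c];
      const size_t offset = (n * channels + c) * spatial_size;
      for (size_t i = 0; i < spatial_size; ++i) {
        output[offset + i] = std::max(scale * input[offset + i] + bias, 0.f);
      }
    }
  }
}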
...@@ -31,7 +31,7 @@ bool DequantizeKernel<CPU, float>::Init(DequantizeParam<CPU> *param) {
template <>
void DequantizeKernel<CPU, float>::Compute(const DequantizeParam<CPU> &param) {
const Tensor *input = param.input_;
Tensor *output = param.output_;
float activation_scale = param.activation_scale_->data<float>()[0];
float weight_scale = param.weight_scale_;
const int32_t *x = input->data<const int32_t>();
...@@ -43,11 +43,15 @@ void DequantizeKernel<CPU, float>::Compute(const DequantizeParam<CPU> &param) {
size_t loop = size >> 4;
size_t remain = size & 0xF;
float32x4_t s = vdupq_n_f32(scale);
#pragma omp parallel for
for (size_t i = 0; i < loop; ++i) {
const int32_t *local_x = x + (i << 4);
float *local_y = y + (i << 4);
int32x4_t r0 = vld1q_s32(local_x);
int32x4_t r1 = vld1q_s32(local_x + 4);
int32x4_t r2 = vld1q_s32(local_x + 8);
int32x4_t r3 = vld1q_s32(local_x + 12);
float32x4_t f0 = vcvtq_f32_s32(r0);
float32x4_t f1 = vcvtq_f32_s32(r1);
float32x4_t f2 = vcvtq_f32_s32(r2);
...@@ -56,14 +60,14 @@ void DequantizeKernel<CPU, float>::Compute(const DequantizeParam<CPU> &param) {
f1 = vmulq_f32(f1, s);
f2 = vmulq_f32(f2, s);
f3 = vmulq_f32(f3, s);
vst1q_f32(local_y, f0);
vst1q_f32(local_y + 4, f1);
vst1q_f32(local_y + 8, f2);
vst1q_f32(local_y + 12, f3);
}
size = remain;
x += (loop << 4);
y += (loop << 4);
#endif
for (size_t i = 0; i < size; ++i) {
y[i] = x[i] * scale;
......
...@@ -21,15 +21,15 @@ limitations under the License. */
#include <arm_neon.h>
#ifndef __aarch64__
inline float32_t vmaxvq_f32(float32x4_t r) {
float32x2_t v = vmax_f32(vget_high_f32(r), vget_low_f32(r));
return vget_lane_f32(vpmax_f32(v, v), 0);
}
#endif
inline int32x4_t vrnd_towards_zero(float32x4_t r) { return vcvtq_s32_f32(r); }
inline int32x4_t vrnd_away_zero(float32x4_t r) {
float32x4_t plus = vdupq_n_f32(0.5);
float32x4_t minus = vdupq_n_f32(-0.5);
float32x4_t zero = vdupq_n_f32(0);
...@@ -40,7 +40,7 @@ int32x4_t vrnd_away_zero(float32x4_t r) {
return ret;
}
inline int32x4_t vrnd_to_even(float32x4_t r) {
#if 0
int32x4_t ret;
float value[4];
...@@ -84,7 +84,6 @@ int32x4_t vrnd_to_even(float32x4_t r) {
return rnd;
#endif
}
#endif
namespace paddle_mobile {
namespace operators {
...@@ -127,6 +126,7 @@ static float find_abs_max(const Tensor *input) {
return max_abs;
}
#ifdef __aarch64__
static void quantize_round_to_even(const Tensor *input, const float scale,
Tensor *output) {
const float *x = input->data<const float>();
...@@ -188,7 +188,7 @@ static void quantize_round_to_zero(const Tensor *input, const float scale,
const float *x = input->data<const float>();
int8_t *y = output->mutable_data<int8_t>();
size_t size = input->numel();
#if defined(__ARM_NEON__) || defined(__ARM_NEON)
size_t loop = size >> 4;
size_t remain = size & 0xF;
...@@ -224,7 +224,7 @@ static void quantize_round_to_zero(const Tensor *input, const float scale,
y += (loop << 4);
#endif
for (size_t i = 0; i < size; ++i) {
y[i] = static_cast<int8_t>(x[i] * scale);
}
}
...@@ -272,6 +272,508 @@ static void quantize_round_to_nearest(const Tensor *input, const float scale,
y[i] = round(x[i] * scale);
}
}
#else // __aarch64__
static void quantize_round_to_even(const Tensor *input, const float scale,
const std::vector<int> &paddings,
const int8_t padding_val, Tensor *output) {}
static void quantize_round_to_nearest(const Tensor *input, const float scale,
const std::vector<int> &paddings,
const int8_t padding_val,
Tensor *output) {}
static void quantize_round_to_zero(const Tensor *input, const float scale,
const std::vector<int> &paddings,
const int8_t padding_val, Tensor *output) {
int channels = input->dims()[1];
int input_h = input->dims()[2];
int input_w = input->dims()[3];
int output_h = output->dims()[2];
int output_w = output->dims()[3];
int input_spatial_size = input_h * input_w;
int output_spatial_size = output_h * output_w;
const float *x = input->data<float>();
int8_t *y = output->mutable_data<int8_t>();
// valid area start
int start = paddings[0] * output_w + paddings[1];
for (int batch = 0; batch < input->dims()[0]; ++batch) {
#pragma omp parallel for
for (int c = 0; c < channels - 3; c += 4) {
const float *input0 = x + (batch * channels + c) * input_spatial_size;
const float *input1 = input0 + input_spatial_size;
const float *input2 = input1 + input_spatial_size;
const float *input3 = input2 + input_spatial_size;
size_t offset = (batch * channels + c) * output_spatial_size;
for (int h = 0; h < 2; ++h) {
int8_t *y0 =
y + offset + h * ((input_h + paddings[0]) * output_w - paddings[1]);
int8_t *y1 = y0 + output_spatial_size;
int8_t *y2 = y1 + output_spatial_size;
int8_t *y3 = y2 + output_spatial_size;
int loop = start >> 4;
int remain = start & 0xF;
asm volatile(
"vdup.s8 q0, %[val] \n"
"cmp %[loop], #0 \n"
"ble start_remain_%= \n"
"store_16w_%=: \n"
"vst1.32 {q0}, [%[y0]]! \n"
"vst1.32 {q0}, [%[y1]]! \n"
"vst1.32 {q0}, [%[y2]]! \n"
"vst1.32 {q0}, [%[y3]]! \n"
"subs %[loop], #1 \n"
"bne store_16w_%= \n"
"start_remain_%=: \n"
"cmp %[remain], #8 \n"
"blt store_4w_%= \n"
"vst1.32 {d0}, [%[y0]]! \n"
"vst1.32 {d0}, [%[y1]]! \n"
"vst1.32 {d0}, [%[y2]]! \n"
"vst1.32 {d0}, [%[y3]]! \n"
"sub %[remain], #8 \n"
"store_4w_%=: \n"
"cmp %[remain], #4 \n"
"blt store_2w_%= \n"
"vst1.32 {d0[0]}, [%[y0]]! \n"
"vst1.32 {d0[0]}, [%[y1]]! \n"
"vst1.32 {d0[0]}, [%[y2]]! \n"
"vst1.32 {d0[0]}, [%[y3]]! \n"
"sub %[remain], #4 \n"
"store_2w_%=: \n"
"cmp %[remain], #4 \n"
"blt store_1w_%= \n"
"vst1.16 {d0[0]}, [%[y0]]! \n"
"vst1.16 {d0[0]}, [%[y1]]! \n"
"vst1.16 {d0[0]}, [%[y2]]! \n"
"vst1.16 {d0[0]}, [%[y3]]! \n"
"sub %[remain], #2 \n"
"store_1w_%=: \n"
"cmp %[remain], #1 \n"
"blt end_%= \n"
"vst1.8 {d0[0]}, [%[y0]]! \n"
"vst1.8 {d0[0]}, [%[y1]]! \n"
"vst1.8 {d0[0]}, [%[y2]]! \n"
"vst1.8 {d0[0]}, [%[y3]]! \n"
"end_%=: \n"
: [y0] "+r"(y0), [y1] "+r"(y1), [y2] "+r"(y2), [y3] "+r"(y3),
[loop] "+r"(loop), [remain] "+r"(remain)
: [val] "r"(padding_val)
: "cc", "memory", "q0");
}
// quantize valid area
int8_t *y0 = y + offset + start;
int8_t *y1 = y0 + output_spatial_size;
int8_t *y2 = y1 + output_spatial_size;
int8_t *y3 = y2 + output_spatial_size;
for (int h = 0; h < input_h; ++h) {
const float *x0 = input0 + h * input_w;
const float *x1 = input1 + h * input_w;
const float *x2 = input2 + h * input_w;
const float *x3 = input3 + h * input_w;
int loop = input_w >> 4;
int remain = input_w & 0xF;
int pad_loop = paddings[1] >> 1; // (paddings[1] << 1) >> 2
int pad_remain = (paddings[1] << 1) & 0x3;
int remain_steps = remain;
asm volatile(
"vdup.f32 q0, %[scale] \n"
"cmp %[loop], #0 \n"
"ble quantize_remain_%= \n"
"loop_quantize_%=: \n"
"vld1.32 {q1, q2}, [%[x0]]! \n"
"vld1.32 {q3, q4}, [%[x1]]! \n"
"vld1.32 {q5, q6}, [%[x2]]! \n"
"vld1.32 {q7, q8}, [%[x3]]! \n"
"vmul.f32 q1, q1, q0 \n"
"vmul.f32 q2, q2, q0 \n"
"vmul.f32 q3, q3, q0 \n"
"vmul.f32 q4, q4, q0 \n"
"vmul.f32 q5, q5, q0 \n"
"vmul.f32 q6, q6, q0 \n"
"vmul.f32 q7, q7, q0 \n"
"vmul.f32 q8, q8, q0 \n"
"vcvt.s32.f32 q1, q1 \n"
"vcvt.s32.f32 q2, q2 \n"
"vcvt.s32.f32 q3, q3 \n"
"vcvt.s32.f32 q4, q4 \n"
"vcvt.s32.f32 q5, q5 \n"
"vcvt.s32.f32 q6, q6 \n"
"vcvt.s32.f32 q7, q7 \n"
"vcvt.s32.f32 q8, q8 \n"
"vmovn.s32 d2, q1 \n"
"vmovn.s32 d3, q2 \n"
"vmovn.s32 d4, q3 \n"
"vmovn.s32 d5, q4 \n"
"vmovn.s32 d6, q5 \n"
"vmovn.s32 d7, q6 \n"
"vmovn.s32 d8, q7 \n"
"vmovn.s32 d9, q8 \n"
"vmovn.s16 d18, q1 \n"
"vmovn.s16 d20, q2 \n"
"vmovn.s16 d22, q3 \n"
"vmovn.s16 d24, q4 \n"
"vld1.32 {q1, q2}, [%[x0]]! \n"
"vld1.32 {q3, q4}, [%[x1]]! \n"
"vld1.32 {q5, q6}, [%[x2]]! \n"
"vld1.32 {q7, q8}, [%[x3]]! \n"
"vmul.f32 q1, q1, q0 \n"
"vmul.f32 q2, q2, q0 \n"
"vmul.f32 q3, q3, q0 \n"
"vmul.f32 q4, q4, q0 \n"
"vmul.f32 q5, q5, q0 \n"
"vmul.f32 q6, q6, q0 \n"
"vmul.f32 q7, q7, q0 \n"
"vmul.f32 q8, q8, q0 \n"
"vcvt.s32.f32 q1, q1 \n"
"vcvt.s32.f32 q2, q2 \n"
"vcvt.s32.f32 q3, q3 \n"
"vcvt.s32.f32 q4, q4 \n"
"vcvt.s32.f32 q5, q5 \n"
"vcvt.s32.f32 q6, q6 \n"
"vcvt.s32.f32 q7, q7 \n"
"vcvt.s32.f32 q8, q8 \n"
"vmovn.s32 d2, q1 \n"
"vmovn.s32 d3, q2 \n"
"vmovn.s32 d4, q3 \n"
"vmovn.s32 d5, q4 \n"
"vmovn.s32 d6, q5 \n"
"vmovn.s32 d7, q6 \n"
"vmovn.s32 d8, q7 \n"
"vmovn.s32 d9, q8 \n"
"vmovn.s16 d19, q1 \n"
"vmovn.s16 d21, q2 \n"
"vmovn.s16 d23, q3 \n"
"vmovn.s16 d25, q4 \n"
"vst1.32 {q9}, [%[y0]]! \n"
"vst1.32 {q10}, [%[y1]]! \n"
"vst1.32 {q11}, [%[y2]]! \n"
"vst1.32 {q12}, [%[y3]]! \n"
"subs %[loop], #1 \n"
"bne loop_quantize_%= \n"
"quantize_remain_%=: \n"
"cmp %[remain], #0 \n"
"ble end_%= \n"
"vld1.32 {q1, q2}, [%[x0]]! \n"
"vld1.32 {q3, q4}, [%[x1]]! \n"
"vld1.32 {q5, q6}, [%[x2]]! \n"
"vld1.32 {q7, q8}, [%[x3]]! \n"
"vmul.f32 q1, q1, q0 \n"
"vmul.f32 q2, q2, q0 \n"
"vmul.f32 q3, q3, q0 \n"
"vmul.f32 q4, q4, q0 \n"
"vmul.f32 q5, q5, q0 \n"
"vmul.f32 q6, q6, q0 \n"
"vmul.f32 q7, q7, q0 \n"
"vmul.f32 q8, q8, q0 \n"
"vcvt.s32.f32 q1, q1 \n"
"vcvt.s32.f32 q2, q2 \n"
"vcvt.s32.f32 q3, q3 \n"
"vcvt.s32.f32 q4, q4 \n"
"vcvt.s32.f32 q5, q5 \n"
"vcvt.s32.f32 q6, q6 \n"
"vcvt.s32.f32 q7, q7 \n"
"vcvt.s32.f32 q8, q8 \n"
"vmovn.s32 d2, q1 \n"
"vmovn.s32 d3, q2 \n"
"vmovn.s32 d4, q3 \n"
"vmovn.s32 d5, q4 \n"
"vmovn.s32 d6, q5 \n"
"vmovn.s32 d7, q6 \n"
"vmovn.s32 d8, q7 \n"
"vmovn.s32 d9, q8 \n"
"vmovn.s16 d18, q1 \n"
"vmovn.s16 d20, q2 \n"
"vmovn.s16 d22, q3 \n"
"vmovn.s16 d24, q4 \n"
"vld1.32 {q1, q2}, [%[x0]] \n"
"vld1.32 {q3, q4}, [%[x1]] \n"
"vld1.32 {q5, q6}, [%[x2]] \n"
"vld1.32 {q7, q8}, [%[x3]] \n"
"vmul.f32 q1, q1, q0 \n"
"vmul.f32 q2, q2, q0 \n"
"vmul.f32 q3, q3, q0 \n"
"vmul.f32 q4, q4, q0 \n"
"vmul.f32 q5, q5, q0 \n"
"vmul.f32 q6, q6, q0 \n"
"vmul.f32 q7, q7, q0 \n"
"vmul.f32 q8, q8, q0 \n"
"vcvt.s32.f32 q1, q1 \n"
"vcvt.s32.f32 q2, q2 \n"
"vcvt.s32.f32 q3, q3 \n"
"vcvt.s32.f32 q4, q4 \n"
"vcvt.s32.f32 q5, q5 \n"
"vcvt.s32.f32 q6, q6 \n"
"vcvt.s32.f32 q7, q7 \n"
"vcvt.s32.f32 q8, q8 \n"
"vmovn.s32 d2, q1 \n"
"vmovn.s32 d3, q2 \n"
"vmovn.s32 d4, q3 \n"
"vmovn.s32 d5, q4 \n"
"vmovn.s32 d6, q5 \n"
"vmovn.s32 d7, q6 \n"
"vmovn.s32 d8, q7 \n"
"vmovn.s32 d9, q8 \n"
"vmovn.s16 d19, q1 \n"
"vmovn.s16 d21, q2 \n"
"vmovn.s16 d23, q3 \n"
"vmovn.s16 d25, q4 \n"
"cmp %[remain], #8 \n"
"blt store_4w_%= \n"
"vst1.32 {d18}, [%[y0]]! \n"
"vst1.32 {d20}, [%[y1]]! \n"
"vst1.32 {d22}, [%[y2]]! \n"
"vst1.32 {d24}, [%[y3]]! \n"
"vmov.32 d18, d19 \n"
"vmov.32 d20, d21 \n"
"vmov.32 d22, d23 \n"
"vmov.32 d24, d25 \n"
"sub %[remain], #8 \n"
"store_4w_%=: \n"
"cmp %[remain], #4 \n"
"blt store_2w_%= \n"
"vst1.32 {d18[0]}, [%[y0]]! \n"
"vst1.32 {d20[0]}, [%[y1]]! \n"
"vst1.32 {d22[0]}, [%[y2]]! \n"
"vst1.32 {d24[0]}, [%[y3]]! \n"
"vext.32 d18, d18, d18, #1 \n"
"vext.32 d20, d20, d20, #1 \n"
"vext.32 d22, d22, d22, #1 \n"
"vext.32 d24, d24, d24, #1 \n"
"sub %[remain], #4 \n"
"store_2w_%=: \n"
"cmp %[remain], #2 \n"
"blt store_1w_%= \n"
"vst1.16 {d18[0]}, [%[y0]]! \n"
"vst1.16 {d20[0]}, [%[y1]]! \n"
"vst1.16 {d22[0]}, [%[y2]]! \n"
"vst1.16 {d24[0]}, [%[y3]]! \n"
"vext.16 d18, d18, d18, #1 \n"
"vext.16 d20, d20, d20, #1 \n"
"vext.16 d22, d22, d22, #1 \n"
"vext.16 d24, d24, d24, #1 \n"
"sub %[remain], #2 \n"
"store_1w_%=:"
"cmp %[remain], #1 \n"
"blt end_%= \n"
"vst1.8 {d18[0]}, [%[y0]]! \n"
"vst1.8 {d20[0]}, [%[y1]]! \n"
"vst1.8 {d22[0]}, [%[y2]]! \n"
"vst1.8 {d24[0]}, [%[y3]]! \n"
"end_%=: \n"
: [x0] "+r"(x0), [x1] "+r"(x1), [x2] "+r"(x2), [x3] "+r"(x3),
[y0] "+r"(y0), [y1] "+r"(y1), [y2] "+r"(y2), [y3] "+r"(y3),
[loop] "+r"(loop), [remain] "+r"(remain)
: [scale] "r"(scale)
: "cc", "memory", "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7",
"q8", "q9", "q10", "q11", "q12");
asm volatile(
"vdup.s8 d0, %[val] \n"
"cmp %[pad_loop], #0 \n"
"ble store_pad_2w_%= \n"
"loop_pad_4w_%=: \n"
"vst1.32 {d0[0]}, [%[y0]]! \n"
"vst1.32 {d0[0]}, [%[y1]]! \n"
"vst1.32 {d0[0]}, [%[y2]]! \n"
"vst1.32 {d0[0]}, [%[y3]]! \n"
"subs %[pad_loop], #1 \n"
"bne loop_pad_4w_%= \n"
"store_pad_2w_%=: \n"
"cmp %[pad_remain], #2 \n"
"blt store_pad_1w_%= \n"
"vst1.16 {d0[0]}, [%[y0]]! \n"
"vst1.16 {d0[0]}, [%[y1]]! \n"
"vst1.16 {d0[0]}, [%[y2]]! \n"
"vst1.16 {d0[0]}, [%[y3]]! \n"
"sub %[pad_remain], #2 \n"
"store_pad_1w_%=: \n"
"cmp %[pad_remain], #1 \n"
"blt end_%= \n"
"vst1.8 {d0[0]}, [%[y0]]! \n"
"vst1.8 {d0[0]}, [%[y1]]! \n"
"vst1.8 {d0[0]}, [%[y2]]! \n"
"vst1.8 {d0[0]}, [%[y3]]! \n"
"end_%=: \n"
: [y0] "+r"(y0), [y1] "+r"(y1), [y2] "+r"(y2), [y3] "+r"(y3),
[pad_loop] "+r"(pad_loop), [pad_remain] "+r"(pad_remain)
: [val] "r"(padding_val)
: "cc", "memory", "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7",
"q8", "q9", "q10", "q11", "q12");
}
}
for (int c = (channels & 0xFFFC); c < channels; ++c) {
const float *input0 = x + (batch * channels + c) * input_spatial_size;
size_t offset = (batch * channels + c) * output_spatial_size;
for (int h = 0; h < 2; ++h) {
int8_t *y0 =
y + offset + h * ((input_h + paddings[0]) * output_w - paddings[1]);
int loop = start >> 4;
int remain = start & 0xF;
asm volatile(
"vdup.s8 q0, %[val] \n"
"cmp %[loop], #0 \n"
"ble start_remain_%= \n"
"store_16w_%=: \n"
"vst1.32 {q0}, [%[y0]]! \n"
"subs %[loop], #1 \n"
"bne store_16w_%= \n"
"start_remain_%=: \n"
"cmp %[remain], #8 \n"
"blt store_4w_%= \n"
"vst1.32 {d0}, [%[y0]]! \n"
"sub %[remain], #8 \n"
"store_4w_%=: \n"
"cmp %[remain], #4 \n"
"blt store_2w_%= \n"
"vst1.32 {d0[0]}, [%[y0]]! \n"
"sub %[remain], #4 \n"
"store_2w_%=: \n"
"cmp %[remain], #4 \n"
"blt store_1w_%= \n"
"vst1.16 {d0[0]}, [%[y0]]! \n"
"sub %[remain], #2 \n"
"store_1w_%=: \n"
"cmp %[remain], #1 \n"
"blt end_%= \n"
"vst1.8 {d0[0]}, [%[y0]]! \n"
"end_%=: \n"
: [y0] "+r"(y0), [loop] "+r"(loop), [remain] "+r"(remain)
: [val] "r"(padding_val)
: "cc", "memory", "q0");
}
// quantize valid area
int8_t *y0 = y + offset + start;
for (int h = 0; h < input_h; ++h) {
const float *x0 = input0 + h * input_w;
int loop = input_w >> 4;
int remain = input_w & 0xF;
int pad_loop = paddings[1] >> 1; // (paddings[1] << 1) >> 2
int pad_remain = (paddings[1] << 1) & 0x3;
asm volatile(
"vdup.f32 q0, %[scale] \n"
"cmp %[loop], #0 \n"
"ble quantize_remain_%= \n"
"loop_quantize_%=: \n"
"vld1.32 {q1, q2}, [%[x0]]! \n"
"vmul.f32 q1, q1, q0 \n"
"vmul.f32 q2, q2, q0 \n"
"vcvt.s32.f32 q1, q1 \n"
"vcvt.s32.f32 q2, q2 \n"
"vmovn.s32 d2, q1 \n"
"vmovn.s32 d3, q2 \n"
"vmovn.s16 d18, q1 \n"
"vld1.32 {q1, q2}, [%[x0]]! \n"
"vmul.f32 q1, q1, q0 \n"
"vmul.f32 q2, q2, q0 \n"
"vcvt.s32.f32 q1, q1 \n"
"vcvt.s32.f32 q2, q2 \n"
"vmovn.s32 d2, q1 \n"
"vmovn.s32 d3, q2 \n"
"vmovn.s16 d19, q1 \n"
"vst1.32 {q9}, [%[y0]]! \n"
"subs %[loop], #1 \n"
"bne loop_quantize_%= \n"
"quantize_remain_%=: \n"
"cmp %[remain], #0 \n"
"ble start_pad_%= \n"
"vldm %[x0], {d2-d9} \n"
"vmul.f32 q1, q1, q0 \n"
"vmul.f32 q2, q2, q0 \n"
"vcvt.s32.f32 q1, q1 \n"
"vcvt.s32.f32 q2, q2 \n"
"vmovn.s32 d2, q1 \n"
"vmovn.s32 d3, q2 \n"
"vmovn.s16 d18, q1 \n"
"vmul.f32 q3, q3, q0 \n"
"vmul.f32 q4, q4, q0 \n"
"vcvt.s32.f32 q1, q3 \n"
"vcvt.s32.f32 q2, q4 \n"
"vmovn.s32 d2, q1 \n"
"vmovn.s32 d3, q2 \n"
"vmovn.s16 d19, q1 \n"
"cmp %[remain], #8 \n"
"blt store_4w_%= \n"
"vst1.32 {d18}, [%[y0]]! \n"
"vmov.32 d18, d19 \n"
"sub %[remain], #8 \n"
"store_4w_%=: \n"
"cmp %[remain], #4 \n"
"blt store_2w_%= \n"
"vst1.32 {d18[0]}, [%[y0]]! \n"
"vext.32 d18, d18, d18, #1 \n"
"sub %[remain], #4 \n"
"store_2w_%=: \n"
"cmp %[remain], #2 \n"
"blt store_1w_%= \n"
"vst1.16 {d18[0]}, [%[y0]]! \n"
"vext.16 d18, d18, d18, #1 \n"
"sub %[remain], #2 \n"
"store_1w_%=:"
"cmp %[remain], #1 \n"
"blt start_pad_%= \n"
"vst1.8 {d18[0]}, [%[y0]]! \n"
"start_pad_%=: \n"
"vdup.s8 d0, %[val] \n"
"cmp %[pad_loop], #0 \n"
"ble pad_remain_%= \n"
"loop_pad_4w_%=: \n"
"vst1.32 {d0[0]}, [%[y0]]! \n"
"subs %[pad_loop], #1 \n"
"bne loop_pad_4w_%= \n"
"pad_remain_%=: \n"
"cmp %[pad_remain], #2 \n"
"blt store_pad_1w_%= \n"
"vst1.16 {d0[0]}, [%[y0]]! \n"
"sub %[pad_remain], #2 \n"
"store_pad_1w_%=: \n"
"cmp %[pad_remain], #1 \n"
"blt end_%= \n"
"vst1.8 {d0[0]}, [%[y0]]! \n"
"end_%=: \n"
: [x0] "+r"(x0), [y0] "+r"(y0), [loop] "+r"(loop),
[remain] "+r"(remain), [pad_loop] "+r"(pad_loop),
[pad_remain] "+r"(pad_remain)
: [scale] "r"(scale), [val] "r"(padding_val)
: "cc", "memory", "q0", "q1", "q2", "q3", "q4", "q9");
}
}
}
}
#endif // __aarch64__
#endif // ARM_NEON
template <>
bool QuantizeKernel<CPU, float>::Init(QuantizeParam<CPU> *param) {
...@@ -280,10 +782,10 @@ bool QuantizeKernel<CPU, float>::Init(QuantizeParam<CPU> *param) {
template <>
void QuantizeKernel<CPU, float>::Compute(const QuantizeParam<CPU> &param) {
const Tensor *input = param.input_;
Tensor *output = param.output_;
Tensor *output_scale = param.online_scale_;
float max_abs = 0.f;
if (param.is_static_) {
max_abs = param.static_scale_;
} else {
...@@ -293,15 +795,19 @@ void QuantizeKernel<CPU, float>::Compute(const QuantizeParam<CPU> &param) {
// only support int8 currently
float scale = 127 / max_abs;
param.online_scale_->mutable_data<float>()[0] = max_abs;
const auto &paddings = param.paddings_;
// std::vector<int> paddings = {0, 0};
// const auto padding_val = param.padding_val_;
int8_t padding_val = 0;
switch (param.round_type_) {
case ROUND_NEAREST_TO_EVEN:
quantize_round_to_even(input, scale, paddings, padding_val, output);
break;
case ROUND_NEAREST_TOWARDS_ZERO:
quantize_round_to_zero(input, scale, paddings, padding_val, output);
break;
case ROUND_NEAREST_AWAY_ZERO:
quantize_round_to_nearest(input, scale, paddings, padding_val, output);
break;
default:
LOG(kLOG_ERROR) << "round type is not supported.";
......
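The quantizer above maps floats to int8 with scale = 127 / max_abs and one of three rounding modes. The scalar sketch below is illustrative only and not part of the patch; the helper names are made up.

#include <cmath>
#include <cstdint>

// Truncate toward zero, as in quantize_round_to_zero / vrnd_towards_zero.
inline int8_t QuantizeTowardsZero(float x, float scale) {
  return static_cast<int8_t>(x * scale);
}

// Round halfway cases away from zero, as in ROUND_NEAREST_AWAY_ZERO.
inline int8_t QuantizeAwayFromZero(float x, float scale) {
  const float v = x * scale;
  return static_cast<int8_t>(v > 0.f ? v + 0.5f : v - 0.5f);
}

// Round to nearest, ties to even, as in ROUND_NEAREST_TO_EVEN; this relies on
// the default FE_TONEAREST floating-point rounding mode.
inline int8_t QuantizeToNearestEven(float x, float scale) {
  return static_cast<int8_t>(std::nearbyint(x * scale));
}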
...@@ -17,7 +17,7 @@ limitations under the License. */
#include <vector>
#include "operators/math/conv_func.h"
#include "operators/math/depthwise_conv3x3.h"
#include "operators/math/im2col.h"
#include "operators/math/math_function.h"
#include "operators/math/vol2col.h"
......
...@@ -17,7 +17,7 @@ limitations under the License. */
#pragma once
#include <vector>
#include "operators/math/depthwise_conv3x3.h"
#include "operators/math/im2col.h"
#include "operators/math/math_function.h"
#include "operators/math/vol2col.h"
......
...@@ -17,18 +17,19 @@ limitations under the License. */
#pragma once
#include <vector>
#include "operators/math/conv_func.h"
#include "operators/math/depthwise_conv3x3.h"
#include "operators/math/im2col.h"
#include "operators/math/math_function.h"
#include "operators/math/pad.h"
#include "operators/math/vol2col.h"
#include "operators/math/winograd/winograd_transform.h"
#include "operators/op_param.h"
namespace paddle_mobile {
namespace operators {
template <typename Itype, typename Otype>
inline void GemmConv(const ConvParam<CPU> &param) {
const Tensor *input = param.Input();
Tensor filter = *param.Filter();
Tensor *output = param.Output();
...@@ -38,10 +39,7 @@ inline void ConvBasic(const ConvParam<CPU> &param) {
const std::vector<int> paddings = param.Paddings();
const std::vector<int> dilations = param.Dilations();
std::vector<int64_t> filter_shape_vec(framework::vectorize(filter.dims()));
std::vector<int64_t> output_shape_vec(framework::vectorize(output->dims()));
size_t data_dim = filter_shape_vec.size() - 2;
std::vector<int64_t> col_shape_vec(1 + 2 * data_dim);
...@@ -82,6 +80,7 @@ inline void ConvBasic(const ConvParam<CPU> &param) {
math::Vol2ColFunctor<CPU, Itype> vol2col;
math::Im2ColFunctor<math::ColFormat::kCFO, CPU, Itype> im2col;
const int batch_size = static_cast<int>(input->dims()[0]);
for (int i = 0; i < batch_size; i++) {
Tensor in_batch = input->Slice(i, i + 1).Resize(input_shape);
Tensor out_batch = output->Slice(i, i + 1).Resize(output_matrix_shape);
...@@ -99,7 +98,6 @@ inline void ConvBasic(const ConvParam<CPU> &param) {
std::vector<int>{paddings[0], paddings[1], paddings[0],
paddings[1]},
&col);
} else if (data_dim == 3U) {
// vol2col
vol2col(in_slice, dilations, strides, paddings, &col);
...@@ -116,25 +114,86 @@ inline void ConvBasic(const ConvParam<CPU> &param) {
}
}
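GemmConv lowers convolution to a matrix multiply: im2col (or vol2col) unfolds each receptive field into a column, and the filter acts as a [out_c, in_c * k_h * k_w] matrix. The dilation-free sketch below is illustrative only, not part of the patch, and the function name is made up.

#include <cstddef>
#include <vector>

// Naive im2col for one NCHW image: column (oh, ow) holds the receptive field
// that produces output pixel (oh, ow), so convolution becomes a GEMM of shape
// [out_c, in_c*k_h*k_w] x [in_c*k_h*k_w, out_h*out_w].
void Im2ColNaive(const float *im, int channels, int height, int width,
                 int kernel_h, int kernel_w, int stride_h, int stride_w,
                 int pad_h, int pad_w, std::vector<float> *col) {
  const int out_h = (height + 2 * pad_h - kernel_h) / stride_h + 1;
  const int out_w = (width + 2 * pad_w - kernel_w) / stride_w + 1;
  col->assign(static_cast<size_t>(channels) * kernel_h * kernel_w * out_h * out_w, 0.f);
  for (int c = 0; c < channels; ++c) {
    for (int kh = 0; kh < kernel_h; ++kh) {
      for (int kw = 0; kw < kernel_w; ++kw) {
        const int row = (c * kernel_h + kh) * kernel_w + kw;
        for (int oh = 0; oh < out_h; ++oh) {
          for (int ow = 0; ow < out_w; ++ow) {
            const int ih = oh * stride_h - pad_h + kh;
            const int iw = ow * stride_w - pad_w + kw;
            if (ih >= 0 && ih < height && iw >= 0 && iw < width) {
              (*col)[static_cast<size_t>(row) * out_h * out_w + oh * out_w + ow] =
                  im[(c * height + ih) * width + iw];
            }
          }
        }
      }
    }
  }
}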
template <typename P>
void ConvCompute(const ConvParam<CPU> &param) {
if (param.Input()->type() == typeid(int8_t)) {
ConvBasic<int8_t, int32_t>(param);
} else {
if (param.Groups() == param.Input()->dims()[1] &&
param.Input()->dims()[1] == param.Output()->dims()[1] &&
param.Filter()->dims()[2] == param.Filter()->dims()[3] &&
param.Filter()->dims()[2] == 3 && param.Strides()[0] == 1) {
math::DepthwiseConv3x3s1p1(param.Input(), param.Filter(), param.Output(),
nullptr, false);
} else if (param.Groups() == param.Input()->dims()[1] &&
param.Input()->dims()[1] == param.Output()->dims()[1] &&
param.Filter()->dims()[2] == param.Filter()->dims()[3] &&
param.Filter()->dims()[2] == 3) {
math::DepthwiseConv3x3(param.Input(), param.Strides(), param.Paddings(),
param.Filter(), nullptr, param.Output(), false);
} else {
ConvBasic<float, float>(param);
}
}
}
template <int tile, int kernel>
inline void WinogradConv3x3(const ConvParam<CPU> &param) {
const Tensor *input = param.Input();
const Tensor *filter = param.Filter();
Tensor *output = param.Output();
output->mutable_data<float>();
int batch_size = input->dims()[0];
int groups = param.Groups();
const std::vector<int> &paddings = param.Paddings();
auto winograd_pad = [&](int width, int pad) {
int output_tile = tile - kernel + 1;
// int tiles = (width + pad - kernel) / output_tile + 1;
// return (tiles - 1) * output_tile + tile - width;
int pad_width = (width + 2 * pad - kernel) / output_tile * output_tile;
return pad_width + tile - width;
};
math::PadFunctor<CPU, float> pad;
Tensor input_pad;
framework::Tensor transformed_input;
for (int i = 0; i < batch_size; ++i) {
Tensor in_batch = input->Slice(i, i + 1);
Tensor out_batch = output->Slice(i, i + 1);
// int pad_bottom = winograd_pad(in_batch.dims()[2], paddings[0]);
// int pad_right = winograd_pad(in_batch.dims()[3], paddings[1]);
int pad_bottom = paddings[0];
int pad_right = paddings[1];
if (paddings[0] || paddings[1] || pad_bottom || pad_right) {
framework::DDim pad_shape = in_batch.dims();
pad_shape[2] += paddings[0] + pad_bottom;
pad_shape[3] += paddings[1] + pad_right;
input_pad.mutable_data<float>(pad_shape);
pad(in_batch, paddings[0], pad_bottom, paddings[1], pad_right,
&input_pad);
} else {
input_pad = in_batch;
}
// tile input and transform
math::winograd_transform_input<tile, kernel>(input_pad, &transformed_input);
// calculate output
math::winograd_transform_output<tile, kernel>(transformed_input, *filter,
output);
}
}
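// Tile arithmetic for the <8, 3> instantiation used by ConvKernel<CPU, float>:
// output_tile = tile - kernel + 1 = 8 - 3 + 1 = 6, so every 8x8 input tile
// produces a 6x6 output tile. For intuition, the (currently unused)
// winograd_pad helper with width = 28 and pad = 1 gives
// (28 + 2 - 3) / 6 * 6 = 24 and returns 24 + 8 - 28 = 4 extra rows/columns
// of bottom/right padding.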
template <typename Itype, typename Otype>
inline void DepthwiseConv3x3(const ConvParam<CPU> &param) {
const Tensor *input = param.Input();
const Tensor *filter = param.Filter();
Tensor *output = param.Output();
output->mutable_data<Otype>();
const std::vector<int> &paddings = param.Paddings();
const std::vector<int> &strides = param.Strides();
const int batch_size = static_cast<int>(input->dims()[0]);
Tensor input_pad;
math::PadFunctor<CPU, Itype> pad;
for (int i = 0; i < batch_size; i++) {
Tensor in_batch = input->Slice(i, i + 1);
Tensor out_batch = output->Slice(i, i + 1);
if (paddings[0] || paddings[1]) {
framework::DDim pad_shape = in_batch.dims();
pad_shape[2] += 2 * paddings[0];
pad_shape[3] += 2 * paddings[1];
input_pad.mutable_data<float>(pad_shape);
pad(in_batch, paddings[0], paddings[0], paddings[1], paddings[1],
&input_pad);
} else {
input_pad = in_batch;
}
if (strides[0] == 1) {
math::DepthwiseConv3x3s1<Itype, Otype>(input_pad, *filter, &out_batch);
} else if (strides[0] == 2) {
math::DepthwiseConv3x3s2<Itype, Otype>(input_pad, *filter, &out_batch);
} else {
// math::DepthwiseConv3x3<Itype, Otype>(input_pad, *filter,
// &out_batch);
PADDLE_MOBILE_THROW_EXCEPTION(
"Depthwise conv with generic strides has not been implemented.");
}
}
}
......
...@@ -17,7 +17,7 @@ limitations under the License. */
#pragma once
#include <vector>
#include "operators/math/depthwise_conv3x3.h"
#include "operators/math/im2col.h"
#include "operators/math/math_function.h"
#include "operators/math/vol2col.h"
......
...@@ -16,13 +16,15 @@ limitations under the License. */
#pragma once
#include <vector>
#include "operators/math/depthwise_conv3x3.h"
#include "operators/math/im2col.h"
#include "operators/math/math_function.h"
#include "operators/math/vol2col.h"
#include "operators/op_param.h"
namespace paddle_mobile {
namespace operators {
void ConvBNReluBasic(const FusionConvBNReluParam<CPU> &param) {
const Tensor *input = param.Input();
Tensor filter = *param.Filter();
......
...@@ -15,10 +15,9 @@ limitations under the License. */
#ifdef DEPTHWISECONV_OP
#pragma once
#include <vector>
#include "operators/kernel/central-arm-func/conv_arm_func.h"
#include "operators/math/depthwise_conv3x3.h"
#include "operators/op_param.h"
namespace paddle_mobile {
...@@ -44,7 +43,7 @@ void DepthwiseConvCompute(const ConvParam<CPU> &param) {
Bias, false);
} else {
GemmConv<float, float>(param);
}
}
......
...@@ -16,13 +16,15 @@ limitations under the License. */
#pragma once
#include <vector>
#include "operators/math/depthwise_conv3x3.h"
#include "operators/math/im2col.h"
#include "operators/math/math_function.h"
#include "operators/math/vol2col.h"
#include "operators/op_param.h"
namespace paddle_mobile {
namespace operators {
void DWConvBNReluBasic(const FusionDWConvBNReluParam<CPU> &param) {
const Tensor *input = param.Input();
Tensor filter = *param.Filter();
......
...@@ -24,7 +24,7 @@ limitations under the License. */
#include "framework/ddim.h"
#include "framework/operator.h"
#include "operators/math/conv_func.h"
#include "operators/math/depthwise_conv3x3.h"
#include "operators/math/im2col.h"
#include "operators/math/math_function.h"
#include "operators/math/vol2col.h"
......
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#ifdef FUSION_DEQUANT_ADD_BN_RELU_OP
#include "framework/operator.h"
#include "operators/op_param.h"
namespace paddle_mobile {
namespace operators {
template <typename DeviceType, typename T>
class FusionDequantAddBNReluKernel
: public framework::OpKernelBase<DeviceType,
FusionDequantAddBNReluParam<DeviceType>> {
public:
void Compute(const FusionDequantAddBNReluParam<DeviceType> &param);
bool Init(FusionDequantAddBNReluParam<DeviceType> *param);
};
} // namespace operators
} // namespace paddle_mobile
#endif
...@@ -11,18 +11,22 @@ distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "operators/math/depthwise_conv3x3.h"
#include <vector>
#if __ARM_NEON
#include <arm_neon.h>
#endif
namespace paddle_mobile {
namespace operators {
namespace math {
void DepthwiseConv3x3(const framework::Tensor *input,
const std::vector<int> &strides,
const std::vector<int> &paddings,
const framework::Tensor *filter, framework::Tensor *bias,
framework::Tensor *output, bool if_bias) {
const int batch_size = input->dims()[0];
const int input_height = input->dims()[2];
...@@ -67,12 +71,12 @@ void DepthwiseConv3x3(const Tensor *input, vector<int> strides,
for (int pw = 0; pw < output_width; pw++) {
hstart = ph * stride_height - padding_height;
wstart = pw * stride_width - padding_width;
hend = std::min(hstart + _kernel_size, input_height + padding_height);
wend = std::min(wstart + _kernel_size, input_width + padding_width);
hstart = std::max(hstart, 0);
wstart = std::max(wstart, 0);
hend = std::min(hend, input_height);
wend = std::min(wend, input_width);
pos1 = input_data + hstart * input_width + wstart;
pos2 = input_data + (hstart + 1) * input_width + wstart;
pos3 = input_data + (hstart + 2) * input_width + wstart;
...@@ -244,12 +248,14 @@ void DepthwiseConv3x3(const Tensor *input, vector<int> strides,
}
}
void DepthwiseConv3x3s1p1(const framework::Tensor *input,
const framework::Tensor *filter,
framework::Tensor *output, framework::Tensor *bias,
bool if_bias) {
#if __ARM_NEON
const float *input_data = input->data<float>();
const float *filter_data = filter->data<float>();
float *output_data = output->mutable_data<float>();
const float *bias_data;
if (if_bias) {
bias_data = bias->data<float>();
...@@ -517,9 +523,12 @@ void DepthwiseConv3x3s1p1(const Tensor *input, const Tensor *filter,
#endif
}
void DepthwiseConvAddBNRelu3x3s1p1(const framework::Tensor *input,
const framework::Tensor *filter,
framework::Tensor *output,
const framework::Tensor *new_scale,
const framework::Tensor *new_bias,
bool if_relu) {
#if __ARM_NEON
const float *input_data = input->data<float>();
const float *filter_data = filter->data<float>();
...@@ -1059,9 +1068,12 @@ void DepthwiseConvAddBNRelu3x3s1p1(const Tensor *input, const Tensor *filter,
}
/// w!=h not fix
void DepthwiseConvAddBNRelu3x3s2p1(const framework::Tensor *input,
const framework::Tensor *filter,
framework::Tensor *output,
const framework::Tensor *new_scale,
const framework::Tensor *new_bias,
bool if_relu) {
#if __ARM_NEON
const int batch_size = input->dims()[0];
...@@ -1107,12 +1119,12 @@ void DepthwiseConvAddBNRelu3x3s2p1(const Tensor *input, const Tensor *filter,
for (int pw = 0; pw < output_width; pw++) {
hstart = ph * stride_height - padding_height;
wstart = pw * stride_width - padding_width;
hend = std::min(hstart + _kernel_size, input_height + padding_height);
wend = std::min(wstart + _kernel_size, input_width + padding_width);
hstart = std::max(hstart, 0);
wstart = std::max(wstart, 0);
hend = std::min(hend, input_height);
wend = std::min(wend, input_width);
pos1 = input_data + hstart * input_width + wstart;
pos2 = input_data + (hstart + 1) * input_width + wstart;
pos3 = input_data + (hstart + 2) * input_width + wstart;
...@@ -1258,8 +1270,10 @@ void DepthwiseConvAddBNRelu3x3s2p1(const Tensor *input, const Tensor *filter,
#endif
}
void DepthwiseConv3x3s2p1v2(const Tensor *input, const Tensor *filter, void DepthwiseConv3x3s2p1v2(const framework::Tensor *input,
Tensor *output, Tensor bias, bool if_bias) { const framework::Tensor *filter,
framework::Tensor *output, framework::Tensor bias,
bool if_bias) {
#if __ARM_NEON #if __ARM_NEON
const float *input_data = input->data<float>(); const float *input_data = input->data<float>();
const float *filter_data = filter->data<float>(); const float *filter_data = filter->data<float>();
...@@ -1463,9 +1477,12 @@ void DepthwiseConv3x3s2p1v2(const Tensor *input, const Tensor *filter, ...@@ -1463,9 +1477,12 @@ void DepthwiseConv3x3s2p1v2(const Tensor *input, const Tensor *filter,
#endif #endif
} }
void DepthwiseConvAddBNRelu3x3s2p1v2(const Tensor *input, const Tensor *filter, void DepthwiseConvAddBNRelu3x3s2p1v2(const framework::Tensor *input,
Tensor *output, const Tensor *new_scale, const framework::Tensor *filter,
const Tensor *new_bias, bool if_relu) { framework::Tensor *output,
const framework::Tensor *new_scale,
const framework::Tensor *new_bias,
bool if_relu) {
#if __ARM_NEON #if __ARM_NEON
// #ifdef _OPENMP // #ifdef _OPENMP
// const float *newscale_data = new_scale->data<float>(); // const float *newscale_data = new_scale->data<float>();
...@@ -1886,8 +1903,10 @@ void DepthwiseConvAddBNRelu3x3s2p1v2(const Tensor *input, const Tensor *filter, ...@@ -1886,8 +1903,10 @@ void DepthwiseConvAddBNRelu3x3s2p1v2(const Tensor *input, const Tensor *filter,
#endif #endif
} }
void DepthwiseConv3x3s2p0(const Tensor *input, const Tensor *filter, void DepthwiseConv3x3s2p0(const framework::Tensor *input,
Tensor *output, Tensor bias, bool if_bias) { const framework::Tensor *filter,
framework::Tensor *output, framework::Tensor bias,
bool if_bias) {
#if __ARM_NEON #if __ARM_NEON
const int batch_size = static_cast<int>(input->dims()[0]); const int batch_size = static_cast<int>(input->dims()[0]);
......
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <algorithm>
#include <vector>
#include "framework/tensor.h"
#include "operators/math/conv_func.h"
namespace paddle_mobile {
namespace operators {
namespace math {
void DepthwiseConv3x3(const framework::Tensor *input,
const std::vector<int> &strides,
const std::vector<int> &paddings,
const framework::Tensor *filter, framework::Tensor *bias,
framework::Tensor *output, bool if_bias);
void DepthwiseConv3x3s1p1(const framework::Tensor *input,
const framework::Tensor *filter,
framework::Tensor *output, framework::Tensor *bias,
bool if_bias);
void DepthwiseConvAddBNRelu3x3s1p1(const framework::Tensor *input,
const framework::Tensor *filter,
framework::Tensor *output,
const framework::Tensor *new_scale,
const framework::Tensor *new_bias,
bool if_relu);
void DepthwiseConvAddBNRelu3x3s2p1(const framework::Tensor *input,
const framework::Tensor *filter,
framework::Tensor *output,
const framework::Tensor *new_scale,
const framework::Tensor *new_bias,
bool if_relu);
void DepthwiseConv3x3s2p1v2(const framework::Tensor *input,
const framework::Tensor *filter,
framework::Tensor *output, framework::Tensor bias,
bool if_bias);
void DepthwiseConvAddBNRelu3x3s2p1v2(const framework::Tensor *input,
const framework::Tensor *filter,
framework::Tensor *output,
const framework::Tensor *new_scale,
const framework::Tensor *new_bias,
bool if_relu);
void DepthwiseConv3x3s2p0(const framework::Tensor *input,
const framework::Tensor *filter,
framework::Tensor *output, framework::Tensor bias,
bool if_bias);
// TODO(hjchen2) need to be implemented
// template<typename Itype, typename Otype>
// void DepthwiseConv3x3(const framework::Tensor *input,
// const framework::Tensor *filter,
// const std::vector<int> &strides,
// framework::Tensor *output);
template <typename Itype, typename Otype>
void DepthwiseConv3x3s1(const framework::Tensor &input,
const framework::Tensor &filter,
framework::Tensor *output);
template <typename Itype, typename Otype>
void DepthwiseConv3x3s2(const framework::Tensor &input,
const framework::Tensor &filter,
framework::Tensor *output);
} // namespace math
} // namespace operators
} // namespace paddle_mobile
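// A hedged usage sketch for the int8 fast path declared above (assumptions:
// batch size 1, stride 1, no padding, and the caller has already resized
// `output` to [1, C, H - 2, W - 2]); the setup code is illustrative and not
// part of this diff:
//
//   framework::Tensor input, filter, output;  // NCHW int8 tensors
//   paddle_mobile::operators::math::DepthwiseConv3x3s1<int8_t, int32_t>(
//       input, filter, &output);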
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "operators/math/depthwise_conv3x3.h"
namespace paddle_mobile {
namespace operators {
namespace math {
// template<>
// void DepthwiseConv3x3<int8_t, int32_t>(
// const framework::Tensor *input, const framework::Tensor *filter,
// const std::vector<int> &strides, framework::Tensor *output) {
// PADDLE_MOBILE_THROW_EXCEPTION(
// "Depthwise conv with generic strides has not been implemented.");
// }
template <>
void DepthwiseConv3x3s1<int8_t, int32_t>(const framework::Tensor &input,
const framework::Tensor &filter,
framework::Tensor *output) {
const int8_t *input_data = input.data<int8_t>();
const int8_t *filter_data = filter.data<int8_t>();
int32_t *out_data = output->mutable_data<int32_t>();
// make sure that batch size is 1
int input_c = input.dims()[1];
int input_h = input.dims()[2];
int input_w = input.dims()[3];
int output_c = output->dims()[1];
int output_h = output->dims()[2];
int output_w = output->dims()[3];
int image_size = input_h * input_w;
int out_image_size = output_h * output_w;
#if __aarch64__
// TODO(hjchen2)
#else
#pragma omp parallel for
for (int g = 0; g < input_c; ++g) {
const int8_t* input_ptr = input_data + g * image_size;
const int8_t* filter_ptr = filter_data + g * 9;
int32_t* output_ptr = out_data + g * out_image_size;
int loops = (input_w - 2) / 6;
int remain = input_w - 2 - loops * 6;
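    // Tiling scheme (a reading of the assembly below, not original
    // documentation): each pass of the loop reads 6 consecutive input rows
    // and writes 4 output rows; the inner assembly loop advances the input
    // pointers by 6 bytes per step to produce 6 output columns per row, and
    // the trailing `remain` columns are handled after `start_remain`.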
for (int h = 0; h < input_h - 5 /*(input_h - 2) - 3*/; h += 4) {
const int8_t* input_ptr0 = input_ptr + h * input_w;
const int8_t* input_ptr1 = input_ptr0 + input_w;
const int8_t* input_ptr2 = input_ptr1 + input_w;
const int8_t* input_ptr3 = input_ptr2 + input_w;
const int8_t* input_ptr4 = input_ptr3 + input_w;
const int8_t* input_ptr5 = input_ptr4 + input_w;
int32_t* output_ptr0 = output_ptr + h * output_w;
int32_t* output_ptr1 = output_ptr0 + output_w;
int32_t* output_ptr2 = output_ptr1 + output_w;
int32_t* output_ptr3 = output_ptr2 + output_w;
int loop = loops;
asm volatile(
"vld1.32 {q0}, [%[filter_ptr]] \n"
"vmovl.s8 q14, d0 \n"
"vmovl.s8 q15, d1 \n"
"vdup.s16 d0, d28[0] \n"
"vdup.s16 d1, d28[1] \n"
"vdup.s16 d2, d28[2] \n"
"vdup.s16 d3, d28[3] \n"
"vdup.s16 d4, d29[0] \n"
"vdup.s16 d5, d29[1] \n"
"vdup.s16 d6, d29[2] \n"
"vdup.s16 d7, d29[3] \n"
"vdup.s16 d8, d30[0] \n"
:
: [filter_ptr] "r"(filter_ptr)
: "cc", "memory", "q0", "q1", "q2", "q3", "q4", "q14", "q15");
asm volatile(
"mov r0, #6 \n"
"cmp %[loop], #0 \n"
"ble start_remain_%= \n"
// loop 6 widths
"loop_4h6w_%=: \n"
"vld1.32 {d9}, [%[input_ptr0]], r0 \n"
"vld1.32 {d10}, [%[input_ptr1]], r0 \n"
"vld1.32 {d11}, [%[input_ptr2]], r0 \n"
"vext.s8 d12, d9, d9, #1 \n"
"vext.s8 d13, d9, d9, #2 \n"
"vmovl.s8 q7, d9 \n"
"vmovl.s8 q8, d12 \n"
"vmovl.s8 q9, d13 \n"
"vmull.s16 q10, d14, d0 \n"
"vmlal.s16 q10, d16, d1 \n"
"vmlal.s16 q10, d18, d2 \n"
"vmull.s16 q11, d15, d0 \n"
"vmlal.s16 q11, d17, d1 \n"
"vmlal.s16 q11, d19, d2 \n"
"vext.s8 d12, d10, d10, #1 \n"
"vext.s8 d13, d10, d10, #2 \n"
"vmovl.s8 q7, d10 \n"
"vmovl.s8 q8, d12 \n"
"vmovl.s8 q9, d13 \n"
"vmlal.s16 q10, d14, d3 \n"
"vmlal.s16 q10, d16, d4 \n"
"vmlal.s16 q10, d18, d5 \n"
"vmlal.s16 q11, d15, d3 \n"
"vmlal.s16 q11, d17, d4 \n"
"vmlal.s16 q11, d19, d5 \n"
"vmull.s16 q12, d14, d0 \n"
"vmlal.s16 q12, d16, d1 \n"
"vmlal.s16 q12, d18, d2 \n"
"vmull.s16 q13, d15, d0 \n"
"vmlal.s16 q13, d17, d1 \n"
"vmlal.s16 q13, d19, d2 \n"
"vext.s8 d12, d11, d11, #1 \n"
"vext.s8 d13, d11, d11, #2 \n"
"vmovl.s8 q7, d11 \n"
"vmovl.s8 q8, d12 \n"
"vmovl.s8 q9, d13 \n"
"vmlal.s16 q10, d14, d6 \n"
"vmlal.s16 q10, d16, d7 \n"
"vmlal.s16 q10, d18, d8 \n"
"vmlal.s16 q11, d15, d6 \n"
"vmlal.s16 q11, d17, d7 \n"
"vmlal.s16 q11, d19, d8 \n"
// store row 0, reuse q10/q11
"vst1.32 {d20-d22}, [%[output_ptr0]]! \n"
"vmlal.s16 q12, d14, d3 \n"
"vmlal.s16 q12, d16, d4 \n"
"vmlal.s16 q12, d18, d5 \n"
"vmlal.s16 q13, d15, d3 \n"
"vmlal.s16 q13, d17, d4 \n"
"vmlal.s16 q13, d19, d5 \n"
"vmull.s16 q14, d14, d0 \n"
"vmlal.s16 q14, d16, d1 \n"
"vmlal.s16 q14, d18, d2 \n"
"vmull.s16 q15, d15, d0 \n"
"vmlal.s16 q15, d17, d1 \n"
"vmlal.s16 q15, d19, d2 \n"
"vld1.32 {d9}, [%[input_ptr3]], r0 \n"
"vld1.32 {d10}, [%[input_ptr4]], r0 \n"
"vld1.32 {d11}, [%[input_ptr5]], r0 \n"
"vext.s8 d12, d9, d9, #1 \n"
"vext.s8 d13, d9, d9, #2 \n"
"vmovl.s8 q7, d9 \n"
"vmovl.s8 q8, d12 \n"
"vmovl.s8 q9, d13 \n"
"vmlal.s16 q12, d14, d6 \n"
"vmlal.s16 q12, d16, d7 \n"
"vmlal.s16 q12, d18, d8 \n"
"vmlal.s16 q13, d15, d6 \n"
"vmlal.s16 q13, d17, d7 \n"
"vmlal.s16 q13, d19, d8 \n"
// store row 1
"vst1.32 {d24-d26}, [%[output_ptr1]]! \n"
"vmlal.s16 q14, d14, d3 \n"
"vmlal.s16 q14, d16, d4 \n"
"vmlal.s16 q14, d18, d5 \n"
"vmlal.s16 q15, d15, d3 \n"
"vmlal.s16 q15, d17, d4 \n"
"vmlal.s16 q15, d19, d5 \n"
"vmull.s16 q10, d14, d0 \n"
"vmlal.s16 q10, d16, d1 \n"
"vmlal.s16 q10, d18, d2 \n"
"vmull.s16 q11, d15, d0 \n"
"vmlal.s16 q11, d17, d1 \n"
"vmlal.s16 q11, d19, d2 \n"
"vext.s8 d12, d10, d10, #1 \n"
"vext.s8 d13, d10, d10, #2 \n"
"vmovl.s8 q7, d10 \n"
"vmovl.s8 q8, d12 \n"
"vmovl.s8 q9, d13 \n"
"vmlal.s16 q14, d14, d6 \n"
"vmlal.s16 q14, d16, d7 \n"
"vmlal.s16 q14, d18, d8 \n"
"vmlal.s16 q15, d15, d6 \n"
"vmlal.s16 q15, d17, d7 \n"
"vmlal.s16 q15, d19, d8 \n"
// store row 2
"vst1.32 {d28-d30}, [%[output_ptr2]]! \n"
"vmlal.s16 q10, d14, d3 \n"
"vmlal.s16 q10, d16, d4 \n"
"vmlal.s16 q10, d18, d5 \n"
"vmlal.s16 q11, d15, d3 \n"
"vmlal.s16 q11, d17, d4 \n"
"vmlal.s16 q11, d19, d5 \n"
"vext.s8 d12, d11, d11, #1 \n"
"vext.s8 d13, d11, d11, #2 \n"
"vmovl.s8 q7, d11 \n"
"vmovl.s8 q8, d12 \n"
"vmovl.s8 q9, d13 \n"
"vmlal.s16 q10, d14, d6 \n"
"vmlal.s16 q10, d16, d7 \n"
"vmlal.s16 q10, d18, d8 \n"
"vmlal.s16 q11, d15, d6 \n"
"vmlal.s16 q11, d17, d7 \n"
"vmlal.s16 q11, d19, d8 \n"
// store row 3
"vst1.32 {d20-d22}, [%[output_ptr3]]! \n"
"subs %[loop], #1 \n"
"bne loop_4h6w_%= \n"
"start_remain_%=: \n"
"cmp %[remain], #0 \n"
"ble end_%= \n"
"vld1.32 {d9}, [%[input_ptr0]] \n"
"vmovl.s8 q7, d9 \n"
"vext.s8 d9, d9, d9, #1 \n"
"vmovl.s8 q8, d9 \n"
"vext.s8 d9, d9, d9, #1 \n"
"vmovl.s8 q9, d9 \n"
"vmull.s16 q10, d14, d0 \n"
"vmlal.s16 q10, d16, d1 \n"
"vmlal.s16 q10, d18, d2 \n"
"vld1.32 {d9}, [%[input_ptr1]] \n"
"vmull.s16 q11, d15, d0 \n"
"vmlal.s16 q11, d17, d1 \n"
"vmlal.s16 q11, d19, d2 \n"
"vmovl.s8 q7, d9 \n"
"vext.s8 d9, d9, d9, #1 \n"
"vmovl.s8 q8, d9 \n"
"vext.s8 d9, d9, d9, #1 \n"
"vmovl.s8 q9, d9 \n"
"vmlal.s16 q10, d14, d3 \n"
"vmlal.s16 q10, d16, d4 \n"
"vmlal.s16 q10, d18, d5 \n"
"vmlal.s16 q11, d15, d3 \n"
"vmlal.s16 q11, d17, d4 \n"
"vmlal.s16 q11, d19, d5 \n"
"vmull.s16 q12, d14, d0 \n"
"vmlal.s16 q12, d16, d1 \n"
"vmlal.s16 q12, d18, d2 \n"
"vld1.32 {d9}, [%[input_ptr2]] \n"
"vmull.s16 q13, d15, d0 \n"
"vmlal.s16 q13, d17, d1 \n"
"vmlal.s16 q13, d19, d2 \n"
"vmovl.s8 q7, d9 \n"
"vext.s8 d9, d9, d9, #1 \n"
"vmovl.s8 q8, d9 \n"
"vext.s8 d9, d9, d9, #1 \n"
"vmovl.s8 q9, d9 \n"
"vmlal.s16 q10, d14, d6 \n"
"vmlal.s16 q10, d16, d7 \n"
"vmlal.s16 q10, d18, d8 \n"
"vmlal.s16 q11, d15, d6 \n"
"vmlal.s16 q11, d17, d7 \n"
"vmlal.s16 q11, d19, d8 \n"
"vmlal.s16 q12, d14, d3 \n"
"vmlal.s16 q12, d16, d4 \n"
"vmlal.s16 q12, d18, d5 \n"
"vmlal.s16 q13, d15, d3 \n"
"vmlal.s16 q13, d17, d4 \n"
"vmlal.s16 q13, d19, d5 \n"
"vmull.s16 q14, d14, d0 \n"
"vmlal.s16 q14, d16, d1 \n"
"vmlal.s16 q14, d18, d2 \n"
"vld1.32 {d9}, [%[input_ptr3]] \n"
"vmull.s16 q15, d15, d0 \n"
"vmlal.s16 q15, d17, d1 \n"
"vmlal.s16 q15, d19, d2 \n"
"vmovl.s8 q7, d9 \n"
"vext.s8 d9, d9, d9, #1 \n"
"vmovl.s8 q8, d9 \n"
"vext.s8 d9, d9, d9, #1 \n"
"vmovl.s8 q9, d9 \n"
"vmlal.s16 q12, d14, d6 \n"
"vmlal.s16 q12, d16, d7 \n"
"vmlal.s16 q12, d18, d8 \n"
"vmlal.s16 q13, d15, d6 \n"
"vmlal.s16 q13, d17, d7 \n"
"vmlal.s16 q13, d19, d8 \n"
"vmlal.s16 q14, d14, d3 \n"
"vmlal.s16 q14, d16, d4 \n"
"vmlal.s16 q14, d18, d5 \n"
"vmlal.s16 q15, d15, d3 \n"
"vmlal.s16 q15, d17, d4 \n"
"vmlal.s16 q15, d19, d5 \n"
"vmull.s16 q5, d14, d0 \n"
"vmlal.s16 q5, d16, d1 \n"
"vmlal.s16 q5, d18, d2 \n"
"vld1.32 {d9}, [%[input_ptr4]] \n"
"vmull.s16 q6, d15, d0 \n"
"vmlal.s16 q6, d17, d1 \n"
"vmlal.s16 q6, d19, d2 \n"
"vmovl.s8 q7, d9 \n"
"vext.s8 d9, d9, d9, #1 \n"
"vmovl.s8 q8, d9 \n"
"vext.s8 d9, d9, d9, #1 \n"
"vmovl.s8 q9, d9 \n"
"vmlal.s16 q14, d14, d6 \n"
"vmlal.s16 q14, d16, d7 \n"
"vmlal.s16 q14, d18, d8 \n"
"vmlal.s16 q15, d15, d6 \n"
"vmlal.s16 q15, d17, d7 \n"
"vmlal.s16 q15, d19, d8 \n"
"vmlal.s16 q5, d14, d3 \n"
"vmlal.s16 q5, d16, d4 \n"
"vmlal.s16 q5, d18, d5 \n"
"vld1.32 {d9}, [%[input_ptr5]] \n"
"vmlal.s16 q6, d15, d3 \n"
"vmlal.s16 q6, d17, d4 \n"
"vmlal.s16 q6, d19, d5 \n"
"vmovl.s8 q7, d9 \n"
"vext.s8 d9, d9, d9, #1 \n"
"vmovl.s8 q8, d9 \n"
"vext.s8 d9, d9, d9, #1 \n"
"vmovl.s8 q9, d9 \n"
"vmlal.s16 q5, d14, d6 \n"
"vmlal.s16 q5, d16, d7 \n"
"vmlal.s16 q5, d18, d8 \n"
"vmlal.s16 q6, d15, d6 \n"
"vmlal.s16 q6, d17, d7 \n"
"vmlal.s16 q6, d19, d8 \n"
"cmp %[remain], #4 \n"
"blt store_4h2w_%= \n"
"vst1.32 {q10}, [%[output_ptr0]]! \n"
"vst1.32 {q12}, [%[output_ptr1]]! \n"
"vst1.32 {q14}, [%[output_ptr2]]! \n"
"vst1.32 {q5}, [%[output_ptr3]]! \n"
"cmp %[remain], #5 \n"
"blt end_%= \n"
"vst1.32 {d22[0]}, [%[output_ptr0]]! \n"
"vst1.32 {d26[0]}, [%[output_ptr1]]! \n"
"vst1.32 {d30[0]}, [%[output_ptr2]]! \n"
"vst1.32 {d12[0]}, [%[output_ptr3]]! \n"
"b end_%= \n"
"store_4h2w_%=: \n"
"cmp %[remain], #2 \n"
"blt store_4h1w_%= \n"
"vst1.32 {d20}, [%[output_ptr0]]! \n"
"vst1.32 {d24}, [%[output_ptr1]]! \n"
"vst1.32 {d28}, [%[output_ptr2]]! \n"
"vst1.32 {d10}, [%[output_ptr3]]! \n"
"cmp %[remain], #3 \n"
"blt end_%= \n"
"vst1.32 {d21[0]}, [%[output_ptr0]]! \n"
"vst1.32 {d25[0]}, [%[output_ptr1]]! \n"
"vst1.32 {d29[0]}, [%[output_ptr2]]! \n"
"vst1.32 {d11[0]}, [%[output_ptr3]]! \n"
"b end_%= \n"
"store_4h1w_%=: \n"
"cmp %[remain], #1 \n"
"blt end_%= \n"
"vst1.32 {d20[0]}, [%[output_ptr0]]! \n"
"vst1.32 {d24[0]}, [%[output_ptr1]]! \n"
"vst1.32 {d28[0]}, [%[output_ptr2]]! \n"
"vst1.32 {d10[0]}, [%[output_ptr3]]! \n"
"end_%=: \n"
: [output_ptr0] "+r"(output_ptr0), [output_ptr1] "+r"(output_ptr1),
[output_ptr2] "+r"(output_ptr2), [output_ptr3] "+r"(output_ptr3),
[input_ptr0] "+r"(input_ptr0), [input_ptr1] "+r"(input_ptr1),
[input_ptr2] "+r"(input_ptr2), [input_ptr3] "+r"(input_ptr3),
[input_ptr4] "+r"(input_ptr4), [input_ptr5] "+r"(input_ptr5),
[loop] "+r"(loop)
: [remain] "r"(remain)
: "cc", "memory", "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7",
"q8", "q9", "q10", "q11", "q12", "q13", "q14", "q15", "r0");
}
// remain height
int start_h = (input_h - 2) & 0xFFFC;
for (int h = start_h; h < input_h - 3 /*(input_h - 2) - 1*/; h += 2) {
const int8_t* input_ptr0 = input_ptr + h * input_w;
const int8_t* input_ptr1 = input_ptr0 + input_w;
const int8_t* input_ptr2 = input_ptr1 + input_w;
const int8_t* input_ptr3 = input_ptr2 + input_w;
int32_t* output_ptr0 = output_ptr + h * output_w;
int32_t* output_ptr1 = output_ptr0 + output_w;
int loop = loops;
asm volatile(
"vld1.32 {q0}, [%[filter_ptr]] \n"
"vmovl.s8 q14, d0 \n"
"vmovl.s8 q15, d1 \n"
"vdup.s16 d0, d28[0] \n"
"vdup.s16 d1, d28[1] \n"
"vdup.s16 d2, d28[2] \n"
"vdup.s16 d3, d28[3] \n"
"vdup.s16 d4, d29[0] \n"
"vdup.s16 d5, d29[1] \n"
"vdup.s16 d6, d29[2] \n"
"vdup.s16 d7, d29[3] \n"
"vdup.s16 d8, d30[0] \n"
:
: [filter_ptr] "r"(filter_ptr)
: "cc", "memory", "q0", "q1", "q2", "q3", "q4", "q14", "q15");
asm volatile(
"mov r0, #6 \n"
"cmp %[loop], #0 \n"
"ble start_remain_%= \n"
// loop 6 widths
"loop_2h6w_%=: \n"
"vld1.32 {d9}, [%[input_ptr0]], r0 \n"
"vld1.32 {d10}, [%[input_ptr1]], r0 \n"
"vld1.32 {d11}, [%[input_ptr2]], r0 \n"
"vext.s8 d12, d9, d9, #1 \n"
"vext.s8 d13, d9, d9, #2 \n"
"vmovl.s8 q7, d9 \n"
"vmovl.s8 q8, d12 \n"
"vmovl.s8 q9, d13 \n"
"vmull.s16 q10, d14, d0 \n"
"vmlal.s16 q10, d16, d1 \n"
"vmlal.s16 q10, d18, d2 \n"
"vmull.s16 q11, d15, d0 \n"
"vmlal.s16 q11, d17, d1 \n"
"vmlal.s16 q11, d19, d2 \n"
"vext.s8 d12, d10, d10, #1 \n"
"vext.s8 d13, d10, d10, #2 \n"
"vmovl.s8 q7, d10 \n"
"vmovl.s8 q8, d12 \n"
"vmovl.s8 q9, d13 \n"
"vmlal.s16 q10, d14, d3 \n"
"vmlal.s16 q10, d16, d4 \n"
"vmlal.s16 q10, d18, d5 \n"
"vmlal.s16 q11, d15, d3 \n"
"vmlal.s16 q11, d17, d4 \n"
"vmlal.s16 q11, d19, d5 \n"
"vmull.s16 q12, d14, d0 \n"
"vmlal.s16 q12, d16, d1 \n"
"vmlal.s16 q12, d18, d2 \n"
"vmull.s16 q13, d15, d0 \n"
"vmlal.s16 q13, d17, d1 \n"
"vmlal.s16 q13, d19, d2 \n"
"vext.s8 d12, d11, d11, #1 \n"
"vext.s8 d13, d11, d11, #2 \n"
"vmovl.s8 q7, d11 \n"
"vmovl.s8 q8, d12 \n"
"vmovl.s8 q9, d13 \n"
"vmlal.s16 q10, d14, d6 \n"
"vmlal.s16 q10, d16, d7 \n"
"vmlal.s16 q10, d18, d8 \n"
"vmlal.s16 q11, d15, d6 \n"
"vmlal.s16 q11, d17, d7 \n"
"vmlal.s16 q11, d19, d8 \n"
// store row 0, reuse q10/q11
"vst1.32 {d20-d22}, [%[output_ptr0]]! \n"
"vmlal.s16 q12, d14, d3 \n"
"vmlal.s16 q12, d16, d4 \n"
"vmlal.s16 q12, d18, d5 \n"
"vmlal.s16 q13, d15, d3 \n"
"vmlal.s16 q13, d17, d4 \n"
"vmlal.s16 q13, d19, d5 \n"
"vld1.32 {d9}, [%[input_ptr3]], r0 \n"
"vext.s8 d12, d9, d9, #1 \n"
"vext.s8 d13, d9, d9, #2 \n"
"vmovl.s8 q7, d9 \n"
"vmovl.s8 q8, d12 \n"
"vmovl.s8 q9, d13 \n"
"vmlal.s16 q12, d14, d6 \n"
"vmlal.s16 q12, d16, d7 \n"
"vmlal.s16 q12, d18, d8 \n"
"vmlal.s16 q13, d15, d6 \n"
"vmlal.s16 q13, d17, d7 \n"
"vmlal.s16 q13, d19, d8 \n"
// store row 1
"vst1.32 {d24-d26}, [%[output_ptr1]]! \n"
"subs %[loop], #1 \n"
"bne loop_2h6w_%= \n"
"start_remain_%=: \n"
"cmp %[remain], #0 \n"
"ble end_%= \n"
"vld1.32 {d9}, [%[input_ptr0]] \n"
"vld1.32 {d10}, [%[input_ptr1]] \n"
"vld1.32 {d11}, [%[input_ptr2]] \n"
"vext.s8 d12, d9, d9, #1 \n"
"vext.s8 d13, d9, d9, #2 \n"
"vmovl.s8 q7, d9 \n"
"vmovl.s8 q8, d12 \n"
"vmovl.s8 q9, d13 \n"
"vmull.s16 q10, d14, d0 \n"
"vmlal.s16 q10, d16, d1 \n"
"vmlal.s16 q10, d18, d2 \n"
"vmull.s16 q11, d15, d0 \n"
"vmlal.s16 q11, d17, d1 \n"
"vmlal.s16 q11, d19, d2 \n"
"vext.s8 d12, d10, d10, #1 \n"
"vext.s8 d13, d10, d10, #2 \n"
"vmovl.s8 q7, d10 \n"
"vmovl.s8 q8, d12 \n"
"vmovl.s8 q9, d13 \n"
"vmlal.s16 q10, d14, d3 \n"
"vmlal.s16 q10, d16, d4 \n"
"vmlal.s16 q10, d18, d5 \n"
"vmlal.s16 q11, d15, d3 \n"
"vmlal.s16 q11, d17, d4 \n"
"vmlal.s16 q11, d19, d5 \n"
"vmull.s16 q12, d14, d0 \n"
"vmlal.s16 q12, d16, d1 \n"
"vmlal.s16 q12, d18, d2 \n"
"vmull.s16 q13, d15, d0 \n"
"vmlal.s16 q13, d17, d1 \n"
"vmlal.s16 q13, d19, d2 \n"
"vext.s8 d12, d11, d11, #1 \n"
"vext.s8 d13, d11, d11, #2 \n"
"vmovl.s8 q7, d11 \n"
"vmovl.s8 q8, d12 \n"
"vmovl.s8 q9, d13 \n"
"vmlal.s16 q10, d14, d6 \n"
"vmlal.s16 q10, d16, d7 \n"
"vmlal.s16 q10, d18, d8 \n"
"vmlal.s16 q11, d15, d6 \n"
"vmlal.s16 q11, d17, d7 \n"
"vmlal.s16 q11, d19, d8 \n"
"vmlal.s16 q12, d14, d3 \n"
"vmlal.s16 q12, d16, d4 \n"
"vmlal.s16 q12, d18, d5 \n"
"vmlal.s16 q13, d15, d3 \n"
"vmlal.s16 q13, d17, d4 \n"
"vmlal.s16 q13, d19, d5 \n"
"vld1.32 {d9}, [%[input_ptr3]] \n"
"vext.s8 d12, d9, d9, #1 \n"
"vext.s8 d13, d9, d9, #2 \n"
"vmovl.s8 q7, d9 \n"
"vmovl.s8 q8, d12 \n"
"vmovl.s8 q9, d13 \n"
"vmlal.s16 q12, d14, d6 \n"
"vmlal.s16 q12, d16, d7 \n"
"vmlal.s16 q12, d18, d8 \n"
"vmlal.s16 q13, d15, d6 \n"
"vmlal.s16 q13, d17, d7 \n"
"vmlal.s16 q13, d19, d8 \n"
"cmp %[remain], #4 \n"
"blt store_2h2w_%= \n"
"vst1.32 {q10}, [%[output_ptr0]]! \n"
"vst1.32 {q12}, [%[output_ptr1]]! \n"
"cmp %[remain], #5 \n"
"blt end_%= \n"
"vst1.32 {d22[0]}, [%[output_ptr0]]! \n"
"vst1.32 {d26[0]}, [%[output_ptr1]]! \n"
"b end_%= \n"
"store_2h2w_%=: \n"
"cmp %[remain], #2 \n"
"blt store_2h1w_%= \n"
"vst1.32 {d20}, [%[output_ptr0]]! \n"
"vst1.32 {d24}, [%[output_ptr1]]! \n"
"cmp %[remain], #3 \n"
"blt end_%= \n"
"vst1.32 {d21[0]}, [%[output_ptr0]]! \n"
"vst1.32 {d25[0]}, [%[output_ptr1]]! \n"
"b end_%= \n"
"store_2h1w_%=: \n"
"cmp %[remain], #1 \n"
"blt end_%= \n"
"vst1.32 {d20[0]}, [%[output_ptr0]]! \n"
"vst1.32 {d24[0]}, [%[output_ptr1]]! \n"
"end_%=: \n"
: [output_ptr0] "+r"(output_ptr0), [output_ptr1] "+r"(output_ptr1),
[input_ptr0] "+r"(input_ptr0), [input_ptr1] "+r"(input_ptr1),
[input_ptr2] "+r"(input_ptr2), [input_ptr3] "+r"(input_ptr3),
[loop] "+r"(loop)
: [remain] "r"(remain)
: "cc", "memory", "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7",
"q8", "q9", "q10", "q11", "q12", "q13", "r0");
}
start_h = (input_h - 2) & 0xFFFE;
if (start_h < input_h - 2) {
const int8_t* input_ptr0 = input_ptr + start_h * input_w;
const int8_t* input_ptr1 = input_ptr0 + input_w;
const int8_t* input_ptr2 = input_ptr1 + input_w;
int32_t* output_ptr0 = output_ptr + start_h * output_w;
int loop = loops;
asm volatile(
"vld1.32 {q0}, [%[filter_ptr]] \n"
"vmovl.s8 q14, d0 \n"
"vmovl.s8 q15, d1 \n"
"vdup.s16 d0, d28[0] \n"
"vdup.s16 d1, d28[1] \n"
"vdup.s16 d2, d28[2] \n"
"vdup.s16 d3, d28[3] \n"
"vdup.s16 d4, d29[0] \n"
"vdup.s16 d5, d29[1] \n"
"vdup.s16 d6, d29[2] \n"
"vdup.s16 d7, d29[3] \n"
"vdup.s16 d8, d30[0] \n"
:
: [filter_ptr] "r"(filter_ptr)
: "memory", "q0", "q1", "q2", "q3", "q4", "q14", "q15");
asm volatile(
"mov r0, #6 \n"
"cmp %[loop], #0 \n"
"ble start_remain_%= \n"
// loop 6 widths
"loop_1h6w_%=: \n"
"vld1.32 {d9}, [%[input_ptr0]], r0 \n"
"vld1.32 {d10}, [%[input_ptr1]], r0 \n"
"vld1.32 {d11}, [%[input_ptr2]], r0 \n"
"vext.s8 d12, d9, d9, #1 \n"
"vext.s8 d13, d9, d9, #2 \n"
"vmovl.s8 q7, d9 \n"
"vmovl.s8 q8, d12 \n"
"vmovl.s8 q9, d13 \n"
"vmull.s16 q10, d14, d0 \n"
"vmlal.s16 q10, d16, d1 \n"
"vmlal.s16 q10, d18, d2 \n"
"vmull.s16 q11, d15, d0 \n"
"vmlal.s16 q11, d17, d1 \n"
"vmlal.s16 q11, d19, d2 \n"
"vext.s8 d12, d10, d10, #1 \n"
"vext.s8 d13, d10, d10, #2 \n"
"vmovl.s8 q7, d10 \n"
"vmovl.s8 q8, d12 \n"
"vmovl.s8 q9, d13 \n"
"vmlal.s16 q10, d14, d3 \n"
"vmlal.s16 q10, d16, d4 \n"
"vmlal.s16 q10, d18, d5 \n"
"vmlal.s16 q11, d15, d3 \n"
"vmlal.s16 q11, d17, d4 \n"
"vmlal.s16 q11, d19, d5 \n"
"vext.s8 d12, d11, d11, #1 \n"
"vext.s8 d13, d11, d11, #2 \n"
"vmovl.s8 q7, d11 \n"
"vmovl.s8 q8, d12 \n"
"vmovl.s8 q9, d13 \n"
"vmlal.s16 q10, d14, d6 \n"
"vmlal.s16 q10, d16, d7 \n"
"vmlal.s16 q10, d18, d8 \n"
"vmlal.s16 q11, d15, d6 \n"
"vmlal.s16 q11, d17, d7 \n"
"vmlal.s16 q11, d19, d8 \n"
// store row 0, reuse q10/q11
"vst1.32 {d20-d22}, [%[output_ptr0]]! \n"
"subs %[loop], #1 \n"
"bne loop_1h6w_%= \n"
"start_remain_%=: \n"
"cmp %[remain], #0 \n"
"ble end_%= \n"
"vld1.32 {d9}, [%[input_ptr0]] \n"
"vld1.32 {d10}, [%[input_ptr1]] \n"
"vld1.32 {d11}, [%[input_ptr2]] \n"
"vext.s8 d12, d9, d9, #1 \n"
"vext.s8 d13, d9, d9, #2 \n"
"vmovl.s8 q7, d9 \n"
"vmovl.s8 q8, d12 \n"
"vmovl.s8 q9, d13 \n"
"vmull.s16 q10, d14, d0 \n"
"vmlal.s16 q10, d16, d1 \n"
"vmlal.s16 q10, d18, d2 \n"
"vmull.s16 q11, d15, d0 \n"
"vmlal.s16 q11, d17, d1 \n"
"vmlal.s16 q11, d19, d2 \n"
"vext.s8 d12, d10, d10, #1 \n"
"vext.s8 d13, d10, d10, #2 \n"
"vmovl.s8 q7, d10 \n"
"vmovl.s8 q8, d12 \n"
"vmovl.s8 q9, d13 \n"
"vmlal.s16 q10, d14, d3 \n"
"vmlal.s16 q10, d16, d4 \n"
"vmlal.s16 q10, d18, d5 \n"
"vmlal.s16 q11, d15, d3 \n"
"vmlal.s16 q11, d17, d4 \n"
"vmlal.s16 q11, d19, d5 \n"
"vext.s8 d12, d11, d11, #1 \n"
"vext.s8 d13, d11, d11, #2 \n"
"vmovl.s8 q7, d11 \n"
"vmovl.s8 q8, d12 \n"
"vmovl.s8 q9, d13 \n"
"vmlal.s16 q10, d14, d6 \n"
"vmlal.s16 q10, d16, d7 \n"
"vmlal.s16 q10, d18, d8 \n"
"vmlal.s16 q11, d15, d6 \n"
"vmlal.s16 q11, d17, d7 \n"
"vmlal.s16 q11, d19, d8 \n"
"cmp %[remain], #4 \n"
"blt store_1h2w_%= \n"
"vst1.32 {q10}, [%[output_ptr0]]! \n"
"cmp %[remain], #5 \n"
"blt end_%= \n"
"vst1.32 {d22[0]}, [%[output_ptr0]]! \n"
"b end_%= \n"
"store_1h2w_%=: \n"
"cmp %[remain], #2 \n"
"blt store_1h1w_%= \n"
"vst1.32 {d20}, [%[output_ptr0]]! \n"
"cmp %[remain], #3 \n"
"blt end_%= \n"
"vst1.32 {d21[0]}, [%[output_ptr0]]! \n"
"b end_%= \n"
"store_1h1w_%=: \n"
"cmp %[remain], #1 \n"
"blt end_%= \n"
"vst1.32 {d20[0]}, [%[output_ptr0]]! \n"
"end_%=: \n"
: [output_ptr0] "+r"(output_ptr0), [input_ptr0] "+r"(input_ptr0),
[input_ptr1] "+r"(input_ptr1), [input_ptr2] "+r"(input_ptr2),
[loop] "+r"(loop)
: [remain] "r"(remain)
: "cc", "memory", "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7",
"q8", "q9", "q10", "q11", "r0");
}
}
#endif // __aarch64__
}
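// A naive scalar reference for what the NEON kernel above computes: a valid
// (no padding) 3x3 depthwise convolution with stride 1 on a single batch.
// The function name and its standalone form are illustrative assumptions,
// not part of paddle-mobile; it is included only to document the expected
// arithmetic of the int8 fast path.
static void NaiveDepthwiseConv3x3s1(const int8_t *in, const int8_t *w,
                                    int32_t *out, int channels, int in_h,
                                    int in_w) {
  const int out_h = in_h - 2;
  const int out_w = in_w - 2;
  for (int c = 0; c < channels; ++c) {
    const int8_t *in_c = in + c * in_h * in_w;
    const int8_t *w_c = w + c * 9;
    int32_t *out_c = out + c * out_h * out_w;
    for (int oh = 0; oh < out_h; ++oh) {
      for (int ow = 0; ow < out_w; ++ow) {
        int32_t acc = 0;
        for (int kh = 0; kh < 3; ++kh) {
          for (int kw = 0; kw < 3; ++kw) {
            acc += static_cast<int32_t>(in_c[(oh + kh) * in_w + (ow + kw)]) *
                   static_cast<int32_t>(w_c[kh * 3 + kw]);
          }
        }
        out_c[oh * out_w + ow] = acc;
      }
    }
  }
}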
template <>
void DepthwiseConv3x3s2<int8_t, int32_t>(const framework::Tensor &input,
const framework::Tensor &filter,
framework::Tensor *output) {
const int8_t *input_data = input.data<int8_t>();
const int8_t *filter_data = filter.data<int8_t>();
int32_t *out_data = output->mutable_data<int32_t>();
// make sure that batch size is 1
int input_c = input.dims()[1];
int input_h = input.dims()[2];
int input_w = input.dims()[3];
int output_c = output->dims()[1];
int output_h = output->dims()[2];
int output_w = output->dims()[3];
int image_size = input_h * input_w;
int out_image_size = output_h * output_w;
#if __aarch64__
// TODO(hjchen2)
#else
#pragma omp parallel for
for (int g = 0; g < input_c; ++g) {
const int8_t* input_ptr = input_data + g * image_size;
const int8_t* filter_ptr = filter_data + g * 9;
int32_t* output_ptr = out_data + g * out_image_size;
int loops = output_w / 6;
int remain = output_w - loops * 6;
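    // Tiling scheme for stride 2 (a reading of the assembly below): each pass
    // of the loop reads 7 input rows (h .. h+6) and writes 3 output rows; the
    // inner assembly loop uses vld2.8 to de-interleave even and odd input
    // columns and advances the input pointers by 12 bytes per step, producing
    // 6 output columns per row.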
for (int h = 0; h < input_h - 6 /*(input_h - 1) - 5*/; h += 6) {
const int8_t* input_ptr0 = input_ptr + h * input_w;
const int8_t* input_ptr1 = input_ptr0 + input_w;
const int8_t* input_ptr2 = input_ptr1 + input_w;
const int8_t* input_ptr3 = input_ptr2 + input_w;
const int8_t* input_ptr4 = input_ptr3 + input_w;
const int8_t* input_ptr5 = input_ptr4 + input_w;
const int8_t* input_ptr6 = input_ptr5 + input_w;
int32_t* output_ptr0 = output_ptr + (h >> 1) * output_w;
int32_t* output_ptr1 = output_ptr0 + output_w;
int32_t* output_ptr2 = output_ptr1 + output_w;
int loop = loops;
asm volatile(
"vld1.32 {q0}, [%[filter_ptr]] \n"
"vmovl.s8 q14, d0 \n"
"vmovl.s8 q15, d1 \n"
"vdup.s16 d0, d28[0] \n"
"vdup.s16 d1, d28[1] \n"
"vdup.s16 d2, d28[2] \n"
"vdup.s16 d3, d28[3] \n"
"vdup.s16 d4, d29[0] \n"
"vdup.s16 d5, d29[1] \n"
"vdup.s16 d6, d29[2] \n"
"vdup.s16 d7, d29[3] \n"
"vdup.s16 d8, d30[0] \n"
:
: [filter_ptr] "r"(filter_ptr)
: "memory", "q0", "q1", "q2", "q3", "q4", "q14", "q15");
asm volatile(
"mov r0, #12 \n"
"cmp %[loop], #0 \n"
"ble start_remain_%= \n"
// loop 6 widths
"loop_3h6w_%=: \n"
"vld2.8 {d10, d11}, [%[input_ptr0]], r0 \n"
"vld2.8 {d12, d13}, [%[input_ptr1]], r0 \n"
"vld2.8 {d14, d15}, [%[input_ptr2]], r0 \n"
"vext.s8 d9, d10, d10, #1 \n"
"vmovl.s8 q10, d9 \n"
"vmovl.s8 q8, d10 \n"
"vmovl.s8 q9, d11 \n"
"vmull.s16 q11, d16, d0 \n"
"vmlal.s16 q11, d18, d1 \n"
"vmlal.s16 q11, d20, d2 \n"
"vmull.s16 q12, d17, d0 \n"
"vmlal.s16 q12, d19, d1 \n"
"vmlal.s16 q12, d21, d2 \n"
"vext.s8 d9, d12, d12, #1 \n"
"vmovl.s8 q10, d9 \n"
"vmovl.s8 q8, d12 \n"
"vmovl.s8 q9, d13 \n"
"vmlal.s16 q11, d16, d3 \n"
"vmlal.s16 q11, d18, d4 \n"
"vmlal.s16 q11, d20, d5 \n"
"vmlal.s16 q12, d17, d3 \n"
"vmlal.s16 q12, d19, d4 \n"
"vmlal.s16 q12, d21, d5 \n"
"vext.s8 d9, d14, d14, #1 \n"
"vmovl.s8 q10, d9 \n"
"vmovl.s8 q8, d14 \n"
"vmovl.s8 q9, d15 \n"
"vmlal.s16 q11, d16, d6 \n"
"vmlal.s16 q11, d18, d7 \n"
"vmlal.s16 q11, d20, d8 \n"
"vmlal.s16 q12, d17, d6 \n"
"vmlal.s16 q12, d19, d7 \n"
"vmlal.s16 q12, d21, d8 \n"
// store row 0, reuse q11/q12
"vst1.32 {d22-d24}, [%[output_ptr0]]! \n"
"vmull.s16 q13, d16, d0 \n"
"vmlal.s16 q13, d18, d1 \n"
"vmlal.s16 q13, d20, d2 \n"
"vmull.s16 q14, d17, d0 \n"
"vmlal.s16 q14, d19, d1 \n"
"vmlal.s16 q14, d21, d2 \n"
"vld2.8 {d10, d11}, [%[input_ptr3]], r0 \n"
"vld2.8 {d12, d13}, [%[input_ptr4]], r0 \n"
"vld2.8 {d14, d15}, [%[input_ptr5]], r0 \n"
"vext.s8 d9, d10, d10, #1 \n"
"vmovl.s8 q10, d9 \n"
"vmovl.s8 q8, d10 \n"
"vmovl.s8 q9, d11 \n"
"vmlal.s16 q13, d16, d3 \n"
"vmlal.s16 q13, d18, d4 \n"
"vmlal.s16 q13, d20, d5 \n"
"vmlal.s16 q14, d17, d3 \n"
"vmlal.s16 q14, d19, d4 \n"
"vmlal.s16 q14, d21, d5 \n"
"vext.s8 d9, d12, d12, #1 \n"
"vmovl.s8 q10, d9 \n"
"vmovl.s8 q8, d12 \n"
"vmovl.s8 q9, d13 \n"
"vmlal.s16 q13, d16, d6 \n"
"vmlal.s16 q13, d18, d7 \n"
"vmlal.s16 q13, d20, d8 \n"
"vmlal.s16 q14, d17, d6 \n"
"vmlal.s16 q14, d19, d7 \n"
"vmlal.s16 q14, d21, d8 \n"
// store row 1
"vst1.32 {d26-d28}, [%[output_ptr1]]! \n"
"vmull.s16 q11, d16, d0 \n"
"vmlal.s16 q11, d18, d1 \n"
"vmlal.s16 q11, d20, d2 \n"
"vmull.s16 q12, d17, d0 \n"
"vmlal.s16 q12, d19, d1 \n"
"vmlal.s16 q12, d21, d2 \n"
"vext.s8 d9, d14, d14, #1 \n"
"vmovl.s8 q10, d9 \n"
"vmovl.s8 q8, d14 \n"
"vmovl.s8 q9, d15 \n"
"vmlal.s16 q11, d16, d3 \n"
"vmlal.s16 q11, d18, d4 \n"
"vmlal.s16 q11, d20, d5 \n"
"vmlal.s16 q12, d17, d3 \n"
"vmlal.s16 q12, d19, d4 \n"
"vmlal.s16 q12, d21, d5 \n"
"vld2.8 {d10, d11}, [%[input_ptr6]], r0 \n"
"vext.s8 d9, d10, d10, #1 \n"
"vmovl.s8 q10, d9 \n"
"vmovl.s8 q8, d10 \n"
"vmovl.s8 q9, d11 \n"
"vmlal.s16 q11, d16, d6 \n"
"vmlal.s16 q11, d18, d7 \n"
"vmlal.s16 q11, d20, d8 \n"
"vmlal.s16 q12, d17, d6 \n"
"vmlal.s16 q12, d19, d7 \n"
"vmlal.s16 q12, d21, d8 \n"
// store row 2
"vst1.32 {d22-d24}, [%[output_ptr2]]! \n"
"subs %[loop], #1 \n"
"bne loop_3h6w_%= \n"
"start_remain_%=: \n"
"cmp %[remain], #0 \n"
"ble end_%= \n"
"vld2.8 {d10, d11}, [%[input_ptr0]] \n"
"vld2.8 {d12, d13}, [%[input_ptr1]] \n"
"vext.s8 d9, d10, d10, #1 \n"
"vmovl.s8 q9, d9 \n"
"vmovl.s8 q7, d10 \n"
"vmovl.s8 q8, d11 \n"
"vmull.s16 q10, d14, d0 \n"
"vmlal.s16 q10, d16, d1 \n"
"vmlal.s16 q10, d18, d2 \n"
"vmull.s16 q11, d15, d0 \n"
"vmlal.s16 q11, d17, d1 \n"
"vmlal.s16 q11, d19, d2 \n"
"vext.s8 d9, d12, d12, #1 \n"
"vmovl.s8 q9, d9 \n"
"vmovl.s8 q7, d12 \n"
"vmovl.s8 q8, d13 \n"
"vmlal.s16 q10, d14, d3 \n"
"vmlal.s16 q10, d16, d4 \n"
"vmlal.s16 q10, d18, d5 \n"
"vmlal.s16 q11, d15, d3 \n"
"vmlal.s16 q11, d17, d4 \n"
"vmlal.s16 q11, d19, d5 \n"
"vld2.8 {d10, d11}, [%[input_ptr2]] \n"
"vld2.8 {d12, d13}, [%[input_ptr3]] \n"
"vext.s8 d9, d10, d10, #1 \n"
"vmovl.s8 q9, d9 \n"
"vmovl.s8 q7, d10 \n"
"vmovl.s8 q8, d11 \n"
"vmlal.s16 q10, d14, d6 \n"
"vmlal.s16 q10, d16, d7 \n"
"vmlal.s16 q10, d18, d8 \n"
"vmlal.s16 q11, d15, d6 \n"
"vmlal.s16 q11, d17, d7 \n"
"vmlal.s16 q11, d19, d8 \n"
"vmull.s16 q12, d14, d0 \n"
"vmlal.s16 q12, d16, d1 \n"
"vmlal.s16 q12, d18, d2 \n"
"vmull.s16 q13, d15, d0 \n"
"vmlal.s16 q13, d17, d1 \n"
"vmlal.s16 q13, d19, d2 \n"
"vext.s8 d9, d12, d12, #1 \n"
"vmovl.s8 q9, d9 \n"
"vmovl.s8 q7, d12 \n"
"vmovl.s8 q8, d13 \n"
"vmlal.s16 q12, d14, d3 \n"
"vmlal.s16 q12, d16, d4 \n"
"vmlal.s16 q12, d18, d5 \n"
"vmlal.s16 q13, d15, d3 \n"
"vmlal.s16 q13, d17, d4 \n"
"vmlal.s16 q13, d19, d5 \n"
"vld2.8 {d10, d11}, [%[input_ptr4]] \n"
"vld2.8 {d12, d13}, [%[input_ptr5]] \n"
"vext.s8 d9, d10, d10, #1 \n"
"vmovl.s8 q9, d9 \n"
"vmovl.s8 q7, d10 \n"
"vmovl.s8 q8, d11 \n"
"vmlal.s16 q12, d14, d6 \n"
"vmlal.s16 q12, d16, d7 \n"
"vmlal.s16 q12, d18, d8 \n"
"vmlal.s16 q13, d15, d6 \n"
"vmlal.s16 q13, d17, d7 \n"
"vmlal.s16 q13, d19, d8 \n"
"vmull.s16 q14, d14, d0 \n"
"vmlal.s16 q14, d16, d1 \n"
"vmlal.s16 q14, d18, d2 \n"
"vmull.s16 q15, d15, d0 \n"
"vmlal.s16 q15, d17, d1 \n"
"vmlal.s16 q15, d19, d2 \n"
"vext.s8 d9, d12, d12, #1 \n"
"vmovl.s8 q9, d9 \n"
"vmovl.s8 q7, d12 \n"
"vmovl.s8 q8, d13 \n"
"vmlal.s16 q14, d14, d3 \n"
"vmlal.s16 q14, d16, d4 \n"
"vmlal.s16 q14, d18, d5 \n"
"vmlal.s16 q15, d15, d3 \n"
"vmlal.s16 q15, d17, d4 \n"
"vmlal.s16 q15, d19, d5 \n"
"vld2.8 {d10, d11}, [%[input_ptr6]] \n"
"vext.s8 d9, d10, d10, #1 \n"
"vmovl.s8 q9, d9 \n"
"vmovl.s8 q7, d10 \n"
"vmovl.s8 q8, d11 \n"
"vmlal.s16 q14, d14, d6 \n"
"vmlal.s16 q14, d16, d7 \n"
"vmlal.s16 q14, d18, d8 \n"
"vmlal.s16 q15, d15, d6 \n"
"vmlal.s16 q15, d17, d7 \n"
"vmlal.s16 q15, d19, d8 \n"
"cmp %[remain], #4 \n"
"blt store_3h2w_%= \n"
"vst1.32 {q10}, [%[output_ptr0]]! \n"
"vst1.32 {q12}, [%[output_ptr1]]! \n"
"vst1.32 {q14}, [%[output_ptr2]]! \n"
"cmp %[remain], #5 \n"
"blt end_%= \n"
"vst1.32 {d22[0]}, [%[output_ptr0]]! \n"
"vst1.32 {d26[0]}, [%[output_ptr1]]! \n"
"vst1.32 {d30[0]}, [%[output_ptr2]]! \n"
"b end_%= \n"
"store_3h2w_%=: \n"
"cmp %[remain], #2 \n"
"blt store_3h1w_%= \n"
"vst1.32 {d20}, [%[output_ptr0]]! \n"
"vst1.32 {d24}, [%[output_ptr1]]! \n"
"vst1.32 {d28}, [%[output_ptr2]]! \n"
"cmp %[remain], #3 \n"
"blt end_%= \n"
"vst1.32 {d21[0]}, [%[output_ptr0]]! \n"
"vst1.32 {d25[0]}, [%[output_ptr1]]! \n"
"vst1.32 {d29[0]}, [%[output_ptr2]]! \n"
"b end_%= \n"
"store_3h1w_%=: \n"
"cmp %[remain], #1 \n"
"blt end_%= \n"
"vst1.32 {d20[0]}, [%[output_ptr0]]! \n"
"vst1.32 {d24[0]}, [%[output_ptr1]]! \n"
"vst1.32 {d28[0]}, [%[output_ptr2]]! \n"
"end_%=: \n"
: [output_ptr0] "+r"(output_ptr0), [output_ptr1] "+r"(output_ptr1),
[output_ptr2] "+r"(output_ptr2), [input_ptr6] "+r"(input_ptr6),
[input_ptr0] "+r"(input_ptr0), [input_ptr1] "+r"(input_ptr1),
[input_ptr2] "+r"(input_ptr2), [input_ptr3] "+r"(input_ptr3),
[input_ptr4] "+r"(input_ptr4), [input_ptr5] "+r"(input_ptr5),
[loop] "+r"(loop)
: [remain] "r"(remain)
: "cc", "memory", "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7",
"q8", "q9", "q10", "q11", "q12", "q13", "q14", "q15", "r0");
}
int start_h = (output_h / 3) * 6;
for (int h = start_h; h < input_h - 2 /*(input_h - 1) - 1*/; h += 2) {
const int8_t* input_ptr0 = input_ptr + h * input_w;
const int8_t* input_ptr1 = input_ptr0 + input_w;
const int8_t* input_ptr2 = input_ptr1 + input_w;
int32_t* output_ptr0 = output_ptr + (h >> 1) * output_w;
int loop = loops;
asm volatile(
"vld1.32 {q0}, [%[filter_ptr]] \n"
"vmovl.s8 q14, d0 \n"
"vmovl.s8 q15, d1 \n"
"vdup.s16 d0, d28[0] \n"
"vdup.s16 d1, d28[1] \n"
"vdup.s16 d2, d28[2] \n"
"vdup.s16 d3, d28[3] \n"
"vdup.s16 d4, d29[0] \n"
"vdup.s16 d5, d29[1] \n"
"vdup.s16 d6, d29[2] \n"
"vdup.s16 d7, d29[3] \n"
"vdup.s16 d8, d30[0] \n"
:
: [filter_ptr] "r"(filter_ptr)
: "memory", "q0", "q1", "q2", "q3", "q4", "q14", "q15");
asm volatile(
"mov r0, #12 \n"
"cmp %[loop], #0 \n"
"ble start_remain_%= \n"
// loop 6 widths
"loop_1h6w_%=: \n"
"vld2.8 {d10, d11}, [%[input_ptr0]], r0 \n"
"vld2.8 {d12, d13}, [%[input_ptr1]], r0 \n"
"vld2.8 {d14, d15}, [%[input_ptr2]], r0 \n"
"vext.s8 d9, d10, d10, #1 \n"
"vmovl.s8 q10, d9 \n"
"vmovl.s8 q8, d10 \n"
"vmovl.s8 q9, d11 \n"
"vmull.s16 q11, d16, d0 \n"
"vmlal.s16 q11, d18, d1 \n"
"vmlal.s16 q11, d20, d2 \n"
"vmull.s16 q12, d17, d0 \n"
"vmlal.s16 q12, d19, d1 \n"
"vmlal.s16 q12, d21, d2 \n"
"vext.s8 d9, d12, d12, #1 \n"
"vmovl.s8 q10, d9 \n"
"vmovl.s8 q8, d12 \n"
"vmovl.s8 q9, d13 \n"
"vmlal.s16 q11, d16, d3 \n"
"vmlal.s16 q11, d18, d4 \n"
"vmlal.s16 q11, d20, d5 \n"
"vmlal.s16 q12, d17, d3 \n"
"vmlal.s16 q12, d19, d4 \n"
"vmlal.s16 q12, d21, d5 \n"
"vext.s8 d9, d14, d14, #1 \n"
"vmovl.s8 q10, d9 \n"
"vmovl.s8 q8, d14 \n"
"vmovl.s8 q9, d15 \n"
"vmlal.s16 q11, d16, d6 \n"
"vmlal.s16 q11, d18, d7 \n"
"vmlal.s16 q11, d20, d8 \n"
"vmlal.s16 q12, d17, d6 \n"
"vmlal.s16 q12, d19, d7 \n"
"vmlal.s16 q12, d21, d8 \n"
// store row 0
"vst1.32 {d22-d24}, [%[output_ptr0]]! \n"
"subs %[loop], #1 \n"
"bne loop_1h6w_%= \n"
"start_remain_%=: \n"
"cmp %[remain], #0 \n"
"ble end_%= \n"
"vld2.8 {d10, d11}, [%[input_ptr0]] \n"
"vld2.8 {d12, d13}, [%[input_ptr1]] \n"
"vld2.8 {d14, d15}, [%[input_ptr2]] \n"
"vext.s8 d9, d10, d10, #1 \n"
"vmovl.s8 q10, d9 \n"
"vmovl.s8 q8, d10 \n"
"vmovl.s8 q9, d11 \n"
"vmull.s16 q11, d16, d0 \n"
"vmlal.s16 q11, d18, d1 \n"
"vmlal.s16 q11, d20, d2 \n"
"vmull.s16 q12, d17, d0 \n"
"vmlal.s16 q12, d19, d1 \n"
"vmlal.s16 q12, d21, d2 \n"
"vext.s8 d9, d12, d12, #1 \n"
"vmovl.s8 q10, d9 \n"
"vmovl.s8 q8, d12 \n"
"vmovl.s8 q9, d13 \n"
"vmlal.s16 q11, d16, d3 \n"
"vmlal.s16 q11, d18, d4 \n"
"vmlal.s16 q11, d20, d5 \n"
"vmlal.s16 q12, d17, d3 \n"
"vmlal.s16 q12, d19, d4 \n"
"vmlal.s16 q12, d21, d5 \n"
"vext.s8 d9, d14, d14, #1 \n"
"vmovl.s8 q10, d9 \n"
"vmovl.s8 q8, d14 \n"
"vmovl.s8 q9, d15 \n"
"vmlal.s16 q11, d16, d6 \n"
"vmlal.s16 q11, d18, d7 \n"
"vmlal.s16 q11, d20, d8 \n"
"vmlal.s16 q12, d17, d6 \n"
"vmlal.s16 q12, d19, d7 \n"
"vmlal.s16 q12, d21, d8 \n"
"cmp %[remain], #4 \n"
"blt store_1h2w_%= \n"
"vst1.32 {q11}, [%[output_ptr0]]! \n"
"cmp %[remain], #5 \n"
"blt end_%= \n"
"vst1.32 {d24[0]}, [%[output_ptr0]]! \n"
"b end_%= \n"
"store_1h2w_%=: \n"
"cmp %[remain], #2 \n"
"blt store_1h1w_%= \n"
"vst1.32 {d22}, [%[output_ptr0]]! \n"
"cmp %[remain], #3 \n"
"blt end_%= \n"
"vst1.32 {d23[0]}, [%[output_ptr0]]! \n"
"b end_%= \n"
"store_1h1w_%=: \n"
"cmp %[remain], #1 \n"
"blt end_%= \n"
"vst1.32 {d22[0]}, [%[output_ptr0]]! \n"
"end_%=: \n"
: [output_ptr0] "+r"(output_ptr0), [input_ptr0] "+r"(input_ptr0),
[input_ptr1] "+r"(input_ptr1), [input_ptr2] "+r"(input_ptr2),
[loop] "+r"(loop)
: [remain] "r"(remain)
: "cc", "memory", "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7",
"q8", "q9", "q10", "q11", "q12", "r0");
}
}
#endif // __aarch64__
}
} // namespace math
} // namespace operators
} // namespace paddle_mobile
...@@ -26,79 +26,6 @@ limitations under the License. */ ...@@ -26,79 +26,6 @@ limitations under the License. */
namespace paddle_mobile { namespace paddle_mobile {
namespace operators { namespace operators {
namespace math { namespace math {
/*int MC = 0;
int KC = 0;
int NC = 0;
float *packedA;
float *packedB;
float *packedC;
float *zero;
typedef void (*FnPack)(int, int, int, const float *, int, float *);
typedef void (*FnAddDot)(int, const float *, const float *, float *, int);
FnPack procPackA;
FnPack procPackB;
FnAddDot procAddDot;*/
/*
// Copy blocks of matrix A into contiguous memory (ColMajor)
void PackMatrixA(int m, int k, int m_tail, const float *A, int lda,
float *buffer) {
int i, j;
const float *Aij;
for (i = 0; i < m - m_tail; i += MR) {
for (j = 0; j < k; ++j) {
Aij = &A(i, j);
*buffer++ = *Aij;
*buffer++ = *(Aij + 1);
*buffer++ = *(Aij + 2);
*buffer++ = *(Aij + 3);
}
}
if (m_tail != 0) {
for (j = 0; j < k; ++j) {
Aij = &A(m - m_tail, j);
for (i = 0; i < m_tail; ++i) {
*buffer++ = *(Aij + i);
}
for (i = m_tail; i < MR; ++i) {
*buffer++ = 0;
}
}
}
}
// Copy blocks of matrix B into contiguous memory (ColMajor)
void PackMatrixB(int k, int n, int n_tail, const float *B, int ldb,
float *buffer) {
int i, j;
const float *Bj, *Bj1, *Bj2, *Bj3;
for (j = 0; j < n - n_tail; j += NR) {
Bj = &B(0, j);
Bj1 = &B(0, j + 1);
Bj2 = &B(0, j + 2);
Bj3 = &B(0, j + 3);
for (i = 0; i < k; ++i) {
*buffer++ = *Bj++;
*buffer++ = *Bj1++;
*buffer++ = *Bj2++;
*buffer++ = *Bj3++;
}
}
if (n_tail != 0) {
for (i = 0; i < k; ++i) {
for (int j = n - n_tail; j < n; ++j) {
*buffer++ = B(i, j);
}
for (int j = n; j < n + (NR - n_tail); ++j) {
*buffer++ = 0;
}
}
}
}
*/
// Copy blocks of matrix A into contiguous memory (RowMajor) // Copy blocks of matrix A into contiguous memory (RowMajor)
void Gemm::PackMatrixA_4r(int m, int k, int m_tail, const float *A, int lda, void Gemm::PackMatrixA_4r(int m, int k, int m_tail, const float *A, int lda,
......
...@@ -22,6 +22,70 @@ namespace paddle_mobile { ...@@ -22,6 +22,70 @@ namespace paddle_mobile {
namespace operators { namespace operators {
namespace math { namespace math {
void ExtractToImg(const float *im_data, float *col_data, const int im_height,
const int im_width, const int col_height, const int col_width,
const int padding_h, const int padding_w, const int stride_h,
const int stride_w, const int kh, const int kw) {
int h = padding_h - kh;
int w = padding_w - kw;
int col_start_height = h > 0 ? (h + stride_h - 1) / stride_h : 0;
int col_start_width = w > 0 ? (w + stride_w - 1) / stride_w : 0;
int start_height = kh + col_start_height * stride_h - padding_h;
int start_width = kw + col_start_width * stride_w - padding_w;
int end_height = (col_height - col_start_height) * stride_h + start_height;
end_height = end_height > im_height ? im_height : end_height;
int end_width = (col_width - col_start_width) * stride_w + start_width;
end_width = end_width > im_width ? im_width : end_width;
int extract = (end_width - start_width + stride_w - 1) / stride_w;
im_data += start_height * im_width + start_width;
col_data += col_start_height * col_width + col_start_width;
for (int i = start_height; i < end_height; i += stride_h) {
if (stride_w == 1) {
memcpy(col_data, im_data, extract * sizeof(float));
} else if (stride_w == 2) {
int s = 0;
#if __ARM_NEON
for (; s < extract - 3; s += 4) {
float32x4x2_t img = vld2q_f32(im_data + s * 2);
vst1q_f32(col_data + s, img.val[0]);
}
#endif
for (; s < extract; ++s) {
col_data[s] = im_data[s * 2];
}
} else if (stride_w == 3) {
int s = 0;
#if __ARM_NEON
for (; s < extract - 3; s += 4) {
float32x4x3_t img = vld3q_f32(im_data + s * 3);
vst1q_f32(col_data + s, img.val[0]);
}
#endif
for (; s < extract; ++s) {
col_data[s] = im_data[s * 3];
}
} else if (stride_w == 4) {
int s = 0;
#if __ARM_NEON
for (; s < extract - 3; s += 4) {
float32x4x4_t img = vld4q_f32(im_data + s * 4);
vst1q_f32(col_data + s, img.val[0]);
}
#endif
for (; s < extract; ++s) {
col_data[s] = im_data[s * 4];
}
} else {
PADDLE_MOBILE_THROW_EXCEPTION("stride_w must be one of 1, 2, 3 and 4.");
}
im_data += im_width * stride_h;
col_data += col_width;
}
}
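// Worked example for the helper above: with a 5x5 image, 3x3 kernel, stride 2
// and padding 1 (so col_height = col_width = 3), the call for (kh, kw) = (0, 0)
// clamps the start position to (1, 1) and extracts 2 pixels per row at stride
// 2, copying input pixels (1,1), (1,3), (3,1), (3,3) into positions (1,1),
// (1,2), (2,1), (2,2) of that kernel offset's col plane; out-of-range output
// positions keep the zeros written by the caller's memset.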
/* /*
* im = [input_channels, input_height, input_width] * im = [input_channels, input_height, input_width]
* col = * col =
...@@ -363,7 +427,27 @@ void Im2ColFunctor<ColFormat::kCFO, CPU, float>::operator()( ...@@ -363,7 +427,27 @@ void Im2ColFunctor<ColFormat::kCFO, CPU, float>::operator()(
col_data += 9 * oosize; col_data += 9 * oosize;
im_data += isize * isize; im_data += isize * isize;
} }
} else if (stride[0] <= 4 && dilation[0] == 1 && dilation[0] == dilation[1]) {
int im_spatial_size = im_height * im_width;
int col_spatial_size = col_height * col_width;
// pad 0
memset(col_data, 0, col->numel() * sizeof(float));
#pragma omp parallel for
for (int ic = 0; ic < im_channels; ++ic) {
const float *local_im_data = im_data + ic * im_spatial_size;
float *local_col_data =
col_data + ic * filter_height * filter_width * col_spatial_size;
for (int kh = 0; kh < filter_height; ++kh) {
for (int kw = 0; kw < filter_width; ++kw) {
ExtractToImg(local_im_data, local_col_data, im_height, im_width,
col_height, col_width, padding[0], padding[1], stride[0],
stride[1], kh, kw);
local_col_data += col_spatial_size;
}
}
}
} else { } else {
#endif
for (int c = 0; c < channels_col; ++c) { for (int c = 0; c < channels_col; ++c) {
int w_offset = c % filter_width; int w_offset = c % filter_width;
int h_offset = (c / filter_width) % filter_height; int h_offset = (c / filter_width) % filter_height;
...@@ -382,25 +466,7 @@ void Im2ColFunctor<ColFormat::kCFO, CPU, float>::operator()( ...@@ -382,25 +466,7 @@ void Im2ColFunctor<ColFormat::kCFO, CPU, float>::operator()(
} }
} }
} }
} #if __ARM_NEON
#else
for (int c = 0; c < channels_col; ++c) {
int w_offset = c % filter_width;
int h_offset = (c / filter_width) % filter_height;
int c_im = c / (filter_width * filter_height);
for (int h = 0; h < col_height; ++h) {
int im_row_idx = h * stride[0] - padding[0] + h_offset * dilation[0];
for (int w = 0; w < col_width; ++w) {
int im_col_idx = w * stride[1] - padding[1] + w_offset * dilation[1];
int col_idx = (c * col_height + h) * col_width + w;
int im_idx = (im_row_idx + c_im * im_height) * im_width + im_col_idx;
col_data[col_idx] = (im_row_idx < 0 || im_row_idx >= im_height ||
im_col_idx < 0 || im_col_idx >= im_width)
? static_cast<float>(0)
: im_data[im_idx];
}
}
} }
#endif #endif
} }
...@@ -489,21 +555,26 @@ void Im2ColFunctor<ColFormat::kCFO, CPU, int8_t>::operator()( ...@@ -489,21 +555,26 @@ void Im2ColFunctor<ColFormat::kCFO, CPU, int8_t>::operator()(
int channels_col = im_channels * filter_height * filter_width; int channels_col = im_channels * filter_height * filter_width;
const int8_t *im_data = im.data<int8_t>(); const int8_t *im_data = im.data<int8_t>();
int8_t *col_data = col->data<int8_t>(); int8_t *col_data = col->mutable_data<int8_t>();
#if defined(__ARM_NEON__) || defined(__ARM_NEON) #if defined(__ARM_NEON__) || defined(__ARM_NEON)
if (stride[0] <= 4 && dilation[0] == 1 && dilation[0] == dilation[1]) { if (stride[0] <= 4 && dilation[0] == 1 && dilation[0] == dilation[1]) {
int im_spatial_size = im_height * im_width;
int col_spatial_size = col_height * col_width;
// pad 0 // pad 0
memset(col_data, 0, col->numel() * sizeof(int8_t)); memset(col_data, 0, col->numel() * sizeof(int8_t));
#pragma omp parallel for
for (int ic = 0; ic < im_channels; ++ic) { for (int ic = 0; ic < im_channels; ++ic) {
const int8_t *local_im_data = im_data + ic * im_spatial_size;
int8_t *local_col_data =
col_data + ic * filter_height * filter_width * col_spatial_size;
for (int kh = 0; kh < filter_height; ++kh) { for (int kh = 0; kh < filter_height; ++kh) {
for (int kw = 0; kw < filter_width; ++kw) { for (int kw = 0; kw < filter_width; ++kw) {
ExtractToImg(im_data, col_data, im_height, im_width, col_height, ExtractToImg(local_im_data, local_col_data, im_height, im_width,
col_width, padding[0], padding[1], stride[0], stride[1], col_height, col_width, padding[0], padding[1], stride[0],
kh, kw); stride[1], kh, kw);
col_data += col_height * col_width; local_col_data += col_spatial_size;
} }
} }
im_data += im_height * im_width;
} }
} else { } else {
#endif #endif
......
...@@ -21,10 +21,12 @@ namespace math { ...@@ -21,10 +21,12 @@ namespace math {
template <typename T> template <typename T>
class PadFunctor<CPU, T> { class PadFunctor<CPU, T> {
public: public:
void operator()(const framework::Tensor &input, const int pad_h, void operator()(const framework::Tensor &input, const int pad_top,
const int pad_w, framework::Tensor *output) { const int pad_bottom, const int pad_left, const int pad_right,
framework::Tensor *output) {
const T *in_data = input.data<T>(); const T *in_data = input.data<T>();
T *out_data = output->mutable_data<T>(); T *out_data = output->mutable_data<T>();
// should check output shape is valid for such pad parameters
const framework::DDim &input_shape = input.dims(); const framework::DDim &input_shape = input.dims();
const framework::DDim &output_shape = output->dims(); const framework::DDim &output_shape = output->dims();
// fill output with 0 // fill output with 0
...@@ -32,13 +34,13 @@ class PadFunctor<CPU, T> { ...@@ -32,13 +34,13 @@ class PadFunctor<CPU, T> {
// should make sure the shape of output is match with input // should make sure the shape of output is match with input
for (int i = 0; i < input_shape[0]; ++i) { for (int i = 0; i < input_shape[0]; ++i) {
for (int c = 0; c < input_shape[1]; ++c) { for (int c = 0; c < input_shape[1]; ++c) {
out_data += pad_h * output_shape[3]; out_data += pad_top * output_shape[3];
for (int h = 0; h < input_shape[2]; ++h) { for (int h = 0; h < input_shape[2]; ++h) {
memcpy(out_data + pad_w, in_data, sizeof(T) * input_shape[3]); memcpy(out_data + pad_left, in_data, sizeof(T) * input_shape[3]);
out_data += output_shape[3]; out_data += output_shape[3];
in_data += input_shape[3]; in_data += input_shape[3];
} }
out_data += pad_h * output_shape[3]; out_data += pad_bottom * output_shape[3];
} }
} }
} }
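// A hedged usage sketch for the new asymmetric interface (the surrounding
// setup is an assumption, not part of this diff): the caller first resizes
// `output` to [N, C, H + pad_top + pad_bottom, W + pad_left + pad_right],
// then invokes the functor, e.g.
//   math::PadFunctor<CPU, float> pad;
//   pad(input, /*pad_top=*/1, /*pad_bottom=*/0, /*pad_left=*/1,
//       /*pad_right=*/0, &output);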
......
...@@ -22,8 +22,9 @@ namespace math { ...@@ -22,8 +22,9 @@ namespace math {
template <typename DeviceType, typename T> template <typename DeviceType, typename T>
class PadFunctor { class PadFunctor {
public: public:
void operator()(const framework::Tensor &input, const int pad_h, void operator()(const framework::Tensor &input, const int pad_top,
const int pad_w, framework::Tensor *output); const int pad_bottom, const int pad_left, const int pad_right,
framework::Tensor *output);
}; };
} // namespace math } // namespace math
......
...@@ -12,40 +12,31 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. ...@@ -12,40 +12,31 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#ifdef CONV_OP
#pragma once #pragma once
#include <algorithm>
#include <vector>
#include "framework/tensor.h" #include "framework/tensor.h"
#include "operators/math/conv_func.h"
namespace paddle_mobile { namespace paddle_mobile {
namespace operators { namespace operators {
namespace math { namespace math {
using framework::Tensor;
using std::max; template <int tile, int kernel>
using std::min; void winograd_transform_weight(const framework::Tensor &weight,
using std::vector; framework::Tensor *output);
void DepthwiseConv3x3(const Tensor *input, vector<int> strides, template <int tile, int kernel>
vector<int> paddings, const Tensor *filter, Tensor *bias, void winograd_transform_input(const framework::Tensor &input,
Tensor *output, bool if_bias); framework::Tensor *output);
void DepthwiseConv3x3s1p1(const Tensor *input, const Tensor *filter,
Tensor *output, Tensor *bias, bool if_bias); template <int tile, int kernel>
void DepthwiseConvAddBNRelu3x3s1p1(const Tensor *input, const Tensor *filter, void winograd_transform_output(const framework::Tensor &input,
Tensor *output, const Tensor *new_scale, const framework::Tensor &weight,
const Tensor *new_bias, bool if_relu); framework::Tensor *output);
void DepthwiseConvAddBNRelu3x3s2p1(const Tensor *input, const Tensor *filter,
Tensor *output, const Tensor *new_scale,
const Tensor *new_bias, bool if_relu);
void DepthwiseConv3x3s2p1v2(const Tensor *input, const Tensor *filter,
Tensor *output, Tensor bias, bool if_bias);
void DepthwiseConvAddBNRelu3x3s2p1v2(const Tensor *input, const Tensor *filter,
Tensor *output, const Tensor *new_scale,
const Tensor *new_bias, bool if_relu);
void DepthwiseConv3x3s2p0(const Tensor *input, const Tensor *filter,
Tensor *output, Tensor bias, bool if_bias);
} // namespace math } // namespace math
} // namespace operators } // namespace operators
} // namespace paddle_mobile } // namespace paddle_mobile
#endif
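// Pipeline sketch for the three transforms declared above (inferred from the
// function names and the standard F(6x6, 3x3) Winograd algorithm, so treat it
// as a hedged reading rather than a specification): the weight is transformed
// once by winograd_transform_weight, each input is tiled and transformed by
// winograd_transform_input, and winograd_transform_output combines the two
// (a channel-wise multiply-accumulate in the transform domain) and applies
// the inverse transform to produce the spatial-domain result.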
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
// Inspired by https://arxiv.org/abs/1509.09308 and referenced from the nnpack
// and ncnn projects.
#ifdef CONV_OP
#ifndef __aarch64__
#include "operators/math/pad.h"
#include "operators/math/winograd/winograd_transform.h"
namespace paddle_mobile {
namespace operators {
namespace math {
template <>
void winograd_transform_weight<8, 3>(const framework::Tensor &weight,
framework::Tensor *output) {
/*
* w0 = g0
* w1 = ((g0 + g2) + g1) * (-2.0 / 9)
* w2 = ((g0 + g2) - g1) * (-2.0 / 9)
* w3 = ((g0 + 4 * g2) + 2 * g1) * (1.0 / 90)
* w4 = ((g0 + 4 * g2) - 2 * g1) * (1.0 / 90)
* w5 = ((g2 + 4 * g0) + 2 * g1) * (1.0 / 180)
* w6 = ((g2 + 4 * g0) - 2 * g1) * (1.0 / 180)
* w7 = g2
*/
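  // Equivalently, each w_k above is the dot product of row k of the transform
  // matrix G with [g0, g1, g2]^T, where
  //   G = [  1       0       0
  //         -2/9    -2/9    -2/9
  //         -2/9     2/9    -2/9
  //          1/90    1/45    2/45
  //          1/90   -1/45    2/45
  //          1/45    1/90    1/180
  //          1/45   -1/90    1/180
  //          0       0       1     ]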
// weight shape is [out_channel, in_channel, kernel_h, kernel_w]
  // pack the weight into [roundup(out_channel/4), 64, in_channel, 4] tiles
int out_channel = weight.dims()[0];
int in_channel = weight.dims()[1];
// reshape and alloc transformed weight
framework::DDim transformed_shape = framework::make_ddim(
std::vector<int>{(out_channel + 3) / 4, 64, in_channel, 4});
float *trans_outptr = output->mutable_data<float>(transformed_shape);
memset(trans_outptr, 0, output->numel() * sizeof(float));
const float transform_matrix[8] = {2.f, -2.f / 9, 1.f / 90, 1.f / 180};
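  // These four constants are the only factors the assembly needs: 2 (to form
  // 2*g1 and, applied twice, 4*g0 or 4*g2), -2/9, 1/90 and 1/180.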
const float *inptr = weight.data<float>();
int remain_start = out_channel & 0xFFFC;
#if 0
remain_start = 0;
#else
#pragma omp parallel for
for (int oc = 0; oc < out_channel - 3; oc += 4) {
float gw[96]; // gw[3][8][4]
const float *inptr0 = inptr + oc * in_channel * 9;
const float *inptr1 = inptr + (oc + 1) * in_channel * 9;
const float *inptr2 = inptr + (oc + 2) * in_channel * 9;
const float *inptr3 = inptr + (oc + 3) * in_channel * 9;
// oc * 64 * in_channel
float *outptr = trans_outptr + ((oc * in_channel) << 6);
for (int ic = 0; ic < in_channel; ++ic) {
float *gw_ptr = gw;
asm volatile(
"vld1.32 {d0-d1}, [%[tm_ptr]] \n"
"mov r0, #24 \n"
"vld1.32 {d2-d5}, [%[inptr0]], r0 \n"
"vld1.32 {d6-d9}, [%[inptr1]], r0 \n"
"vld1.32 {d10-d13}, [%[inptr2]], r0 \n"
"vld1.32 {d14-d17}, [%[inptr3]], r0 \n"
"vtrn.32 q1, q3 \n"
"vtrn.32 q2, q4 \n"
"vtrn.32 q5, q7 \n"
"vtrn.32 q6, q8 \n"
"vswp.32 d3, d10 \n"
"vswp.32 d7, d14 \n"
"vswp.32 d5, d12 \n"
"vswp.32 d9, d16 \n"
// q1: g0, q3: g1, q5: g2
"vst1.32 {d2-d3}, [%[gw_ptr]]! \n"
"vadd.f32 q9, q1, q5 \n"
"vadd.f32 q10, q9, q3 \n"
"vsub.f32 q11, q9, q3 \n"
"vmul.f32 q10, q10, d0[1] \n"
"vst1.32 {d20-d21}, [%[gw_ptr]]! \n"
"vmul.f32 q11, q11, d0[1] \n"
"vst1.32 {d22-d23}, [%[gw_ptr]]! \n"
"vmul.f32 q9, q1, d0[0] \n"
"vmul.f32 q9, q9, d0[0] \n" // 4 * g0
"vmul.f32 q10, q3, d0[0] \n" // 2 * g1
"vmul.f32 q11, q5, d0[0] \n"
"vmul.f32 q11, q11, d0[0] \n" // 4 * g2
"vadd.f32 q12, q1, q11 \n"
"vadd.f32 q13, q12, q10 \n"
"vmul.f32 q13, q13, d1[0] \n"
"vst1.32 {d26-d27}, [%[gw_ptr]]! \n"
"vsub.f32 q13, q12, q10 \n"
"vmul.f32 q13, q13, d1[0] \n"
"vst1.32 {d26-d27}, [%[gw_ptr]]! \n"
"vadd.f32 q12, q5, q9 \n"
"vadd.f32 q13, q12, q10 \n"
"vmul.f32 q13, q13, d1[1] \n"
"vst1.32 {d26-d27}, [%[gw_ptr]]! \n"
"vsub.f32 q13, q12, q10 \n"
"vmul.f32 q13, q13, d1[1] \n"
"vst1.32 {d26-d27}, [%[gw_ptr]]! \n"
"vst1.32 {d10-d11}, [%[gw_ptr]]! \n"
// q7: g0, q2: g1, q4: g2
"vst1.32 {d14-d15}, [%[gw_ptr]]! \n"
"vadd.f32 q9, q7, q4 \n"
"vadd.f32 q10, q9, q2 \n"
"vsub.f32 q11, q9, q2 \n"
"vmul.f32 q10, q10, d0[1] \n"
"vst1.32 {d20-d21}, [%[gw_ptr]]! \n"
"vmul.f32 q11, q11, d0[1] \n"
"vst1.32 {d22-d23}, [%[gw_ptr]]! \n"
"vmul.f32 q9, q7, d0[0] \n"
"vmul.f32 q9, q9, d0[0] \n" // 4 * g0
"vmul.f32 q10, q2, d0[0] \n" // 2 * g1
"vmul.f32 q11, q4, d0[0] \n"
"vmul.f32 q11, q11, d0[0] \n" // 4 * g2
"vadd.f32 q12, q7, q11 \n"
"vadd.f32 q13, q12, q10 \n"
"vmul.f32 q13, q13, d1[0] \n"
"vst1.32 {d26-d27}, [%[gw_ptr]]! \n"
"vsub.f32 q13, q12, q10 \n"
"vmul.f32 q13, q13, d1[0] \n"
"vst1.32 {d26-d27}, [%[gw_ptr]]! \n"
"vadd.f32 q12, q4, q9 \n"
"vadd.f32 q13, q12, q10 \n"
"vmul.f32 q13, q13, d1[1] \n"
"vst1.32 {d26-d27}, [%[gw_ptr]]! \n"
"vsub.f32 q13, q12, q10 \n"
"vmul.f32 q13, q13, d1[1] \n"
"vst1.32 {d26-d27}, [%[gw_ptr]]! \n"
"vst1.32 {d8-d9}, [%[gw_ptr]]! \n"
"mov r0, #12 \n"
"vld1.32 {d2-d3}, [%[inptr0]], r0 \n"
"vld1.32 {d6-d7}, [%[inptr1]], r0 \n"
"vld1.32 {d10-d11}, [%[inptr2]], r0 \n"
"vld1.32 {d14-d15}, [%[inptr3]], r0 \n"
"vtrn.32 q1, q3 \n"
"vtrn.32 q5, q7 \n"
"vswp.32 d3, d10 \n"
"vswp.32 d7, d14 \n"
// q1: g0, q3: g1, q5: g2
"vst1.32 {d2-d3}, [%[gw_ptr]]! \n"
"vadd.f32 q9, q1, q5 \n"
"vadd.f32 q10, q9, q3 \n"
"vsub.f32 q11, q9, q3 \n"
"vmul.f32 q10, q10, d0[1] \n"
"vst1.32 {d20-d21}, [%[gw_ptr]]! \n"
"vmul.f32 q11, q11, d0[1] \n"
"vst1.32 {d22-d23}, [%[gw_ptr]]! \n"
"vmul.f32 q9, q1, d0[0] \n"
"vmul.f32 q9, q9, d0[0] \n" // 4 * g0
"vmul.f32 q10, q3, d0[0] \n" // 2 * g1
"vmul.f32 q11, q5, d0[0] \n"
"vmul.f32 q11, q11, d0[0] \n" // 4 * g2
"vadd.f32 q12, q1, q11 \n"
"vadd.f32 q13, q12, q10 \n"
"vmul.f32 q13, q13, d1[0] \n"
"vst1.32 {d26-d27}, [%[gw_ptr]]! \n"
"vsub.f32 q13, q12, q10 \n"
"vmul.f32 q13, q13, d1[0] \n"
"vst1.32 {d26-d27}, [%[gw_ptr]]! \n"
"vadd.f32 q12, q5, q9 \n"
"vadd.f32 q13, q12, q10 \n"
"vmul.f32 q13, q13, d1[1] \n"
"vst1.32 {d26-d27}, [%[gw_ptr]]! \n"
"vsub.f32 q13, q12, q10 \n"
"vmul.f32 q13, q13, d1[1] \n"
"vst1.32 {d26-d27}, [%[gw_ptr]]! \n"
"vst1.32 {d10-d11}, [%[gw_ptr]]! \n"
: [gw_ptr] "+r"(gw_ptr), [inptr0] "+r"(inptr0), [inptr1] "+r"(inptr1),
[inptr2] "+r"(inptr2), [inptr3] "+r"(inptr3)
: [tm_ptr] "r"((float *)transform_matrix)
: "cc", "memory", "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7",
"q8", "q9", "q10", "q11", "q12", "q13", "r0");
float *gw_ptr0 = gw;
float *gw_ptr1 = gw + 32;
float *gw_ptr2 = gw + 64;
float *outptr0 = outptr + (ic << 2); // ic * 4
int steps = (in_channel << 2) * sizeof(float); // in_channel * 4
asm volatile(
"vld1.32 {d0-d1}, [%[tm_ptr]] \n"
"mov r0, #8 \n"
"loop_8_%=: \n"
"vld1.32 {d2-d3}, [%[gw_ptr0]]! \n"
"vld1.32 {d4-d5}, [%[gw_ptr1]]! \n"
"vld1.32 {d6-d7}, [%[gw_ptr2]]! \n"
// q1: g0, q2: g1, q3: g2
"vst1.32 {d2-d3}, [%[outptr0]], %[steps] \n"
"vadd.f32 q9, q1, q3 \n"
"vadd.f32 q10, q9, q2 \n"
"vsub.f32 q11, q9, q2 \n"
"vmul.f32 q10, q10, d0[1] \n"
"vst1.32 {d20-d21}, [%[outptr0]], %[steps] \n"
"vmul.f32 q11, q11, d0[1] \n"
"vst1.32 {d22-d23}, [%[outptr0]], %[steps] \n"
"vmul.f32 q9, q1, d0[0] \n"
"vmul.f32 q9, q9, d0[0] \n" // 4 * g0
"vmul.f32 q10, q2, d0[0] \n" // 2 * g1
"vmul.f32 q11, q3, d0[0] \n"
"vmul.f32 q11, q11, d0[0] \n" // 4 * g2
"vadd.f32 q12, q1, q11 \n"
"vadd.f32 q13, q12, q10 \n"
"vmul.f32 q13, q13, d1[0] \n"
"vst1.32 {d26-d27}, [%[outptr0]], %[steps] \n"
"vsub.f32 q13, q12, q10 \n"
"vmul.f32 q13, q13, d1[0] \n"
"vst1.32 {d26-d27}, [%[outptr0]], %[steps] \n"
// w5 = ((g2 + 4 * g0) + 2 * g1) * (1.0 / 180)
"vadd.f32 q12, q3, q9 \n"
"vadd.f32 q13, q12, q10 \n"
"vmul.f32 q13, q13, d1[1] \n"
"vst1.32 {d26-d27}, [%[outptr0]], %[steps] \n"
"vsub.f32 q13, q12, q10 \n"
"vmul.f32 q13, q13, d1[1] \n"
"vst1.32 {d26-d27}, [%[outptr0]], %[steps] \n"
"vst1.32 {d6-d7}, [%[outptr0]], %[steps] \n"
"subs r0, #1 \n"
"bne loop_8_%= \n"
: [outptr0] "+r"(outptr0), [gw_ptr0] "+r"(gw_ptr0),
[gw_ptr1] "+r"(gw_ptr1), [gw_ptr2] "+r"(gw_ptr2)
: [tm_ptr] "r"((float *)transform_matrix), [steps] "r"(steps)
: "cc", "memory", "q0", "q1", "q2", "q3", "q9", "q10", "q11", "q12",
"q13", "r0");
}
}
#endif
// remaining output channels
#pragma omp parallel for
for (int oc = remain_start; oc < out_channel; ++oc) {
float gw[3][8]; // gw[3][8]
const float *inptr0 = inptr + oc * in_channel * 9;
// (oc / 4) * 64 * in_channel * 4 + oc % 4
int offset = ((oc & 0xFFFC) << 6) * in_channel + (oc & 0x3);
int steps = (in_channel << 2); // in_channel * 4
float *outptr = trans_outptr + offset;
for (int ic = 0; ic < in_channel; ++ic) {
for (int i = 0; i < 3; ++i, inptr0 += 3) {
float g0 = inptr0[0];
float g1 = inptr0[1];
float g2 = inptr0[2];
float d0 = g0 + g2;
float d1 = g0 + 4 * g2;
float d2 = g2 + 4 * g0;
float d3 = 2 * g1;
gw[i][0] = g0;
gw[i][1] = -2.f / 9 * (d0 + g1); // -2.f/9 * (g0 + g1 + g2)
gw[i][2] = -2.f / 9 * (d0 - g1); // -2.f/9 * (g0 - g1 + g2)
gw[i][3] = 1.f / 90 * (d1 + d3); // 1.f/90 * (g0 + 2 * g1 + 4 * g2)
gw[i][4] = 1.f / 90 * (d1 - d3); // 1.f/90 * (g0 - 2 * g1 + 4 * g2)
gw[i][5] = 1.f / 180 * (d2 + d3); // 1.f/180 * (4 * g0 + 2 * g1 + g2)
gw[i][6] = 1.f / 180 * (d2 - d3); // 1.f/180 * (4 * g0 - 2 * g1 + g2)
gw[i][7] = g2;
}
for (int i = 0; i < 8; ++i) {
float g0 = gw[0][i];
float g1 = gw[1][i];
float g2 = gw[2][i];
float d0 = g0 + g2;
float d1 = g0 + 4 * g2;
float d2 = g2 + 4 * g0;
float d3 = 2 * g1;
int offset = i * 8 * steps;
outptr[offset] = g0;
outptr[offset + 1 * steps] = -2.f / 9 * (d0 + g1);
outptr[offset + 2 * steps] = -2.f / 9 * (d0 - g1);
outptr[offset + 3 * steps] = 1.f / 90 * (d1 + d3);
outptr[offset + 4 * steps] = 1.f / 90 * (d1 - d3);
outptr[offset + 5 * steps] = 1.f / 180 * (d2 + d3);
outptr[offset + 6 * steps] = 1.f / 180 * (d2 - d3);
outptr[offset + 7 * steps] = g2;
}
outptr += 4;
}
}
}
template <>
void winograd_transform_input<8, 3>(const framework::Tensor &input,
framework::Tensor *output) {
/*
* x0 = (d0 - d6) + (d4 - d2) * 5.25
* x1 = (d2 + d6) - 4.25 * (d4 + d3) + (d1 + d5)
* x2 = (d2 + d6) - 4.25 * (d4 - d3) - (d1 + d5)
* x3 = (0.25 * d2 - 1.25 * d4 + d6) + (0.5 * d1 - 2.5 * d3 + 2 * d5)
* x4 = (0.25 * d2 - 1.25 * d4 + d6) - (0.5 * d1 - 2.5 * d3 + 2 * d5)
* x5 = (4 * d2 - 5 * d4 + d6) + (2 * d1 - 2.5 * d3 + 0.5 * d5)
* x6 = (4 * d2 - 5 * d4 + d6) - (2 * d1 - 2.5 * d3 + 0.5 * d5)
* x7 = (d7 - d1) + (d3 - d5) * 5.25
*/
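// The 8x8 tile transform B_T * d * B is separable: a row pass applies the
// formulas above to the eight rows of a tile (into d_bt), then a column pass
// applies them again while scattering results into the packed output layout.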
// package input into [roundup(tiles/8), 64, channel, 8] tiles
int channel = input.dims()[1];
int height = input.dims()[2];
int width = input.dims()[3];
int h_tiles = (height + 3) / 6; // (height - 8 + 5 + 6) / 6
int w_tiles = (width + 3) / 6; // (width - 8 + 5 + 6) / 6
int tiles = (h_tiles * w_tiles + 7) / 8;
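// tiles counts blocks of eight tiles, i.e. the tile count rounded up to a
// multiple of 8 and divided by 8; it becomes the leading output dimension.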
framework::DDim transformed_shape =
framework::make_ddim(std::vector<int>{tiles, 64, channel, 8});
float *outptr = output->mutable_data<float>(transformed_shape);
memset(outptr, 0, output->numel() * sizeof(float));
const float *inptr = input.data<float>();
int inter_h = (height - 2) / 6;
int inter_w = (width - 2) / 6;
int remain_h = height - (inter_h * 6);
int remain_w = width - (inter_w * 6);
framework::Tensor input_pad;
if (remain_h > 2 || remain_w > 2) {
inter_h += (remain_h > 2);
inter_w += (remain_w > 2);
height = (inter_h - 1) * 6 + 8;
width = (inter_w - 1) * 6 + 8;
framework::DDim input_shape =
framework::make_ddim(std::vector<int>{1, channel, height, width});
PadFunctor<CPU, float> pad;
inptr = input_pad.mutable_data<float>(input_shape);
pad(input, 0, height - input.dims()[2], 0, width - input.dims()[3],
&input_pad);
}
size_t image_size = height * width;
const float transform_matrix[8] = {5.25f, -5.f, -4.25f, -2.5f,
2.f, -1.25f, 0.5f, 0.25f};
int remain_c_start = channel & 0xFFFC;
#if 1
remain_c_start = 0;
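// With "#if 1" in effect, remain_c_start stays 0 and the batched 4-channel
// path below is compiled out, so every channel takes the per-channel path
// after the #endif.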
#else
#pragma omp parallel for
for (int c = 0; c < channel - 3; c += 4) {
const float *in = inptr + c * image_size;
float d_bt[64 * 4]; // d * B_t
for (int h = 0; h < h_tiles; ++h) {
for (int w = 0; w < w_tiles; ++w) {
const float *in0 = in + (h * width + w) * 6;
const float *in1 = in0 + image_size;
const float *in2 = in1 + image_size;
const float *in3 = in2 + image_size;
int steps = width * sizeof(float);
float *d_bt_ptr = d_bt;
asm volatile(
"mov r0, #8 \n"
"vld1.32 {d0-d3}, [%[tm_ptr]] \n"
// row loop
"loop_r_%=: \n"
"vld1.32 {d4-d7}, [%[in0]], %[steps] \n"
"vld1.32 {d8-d11}, [%[in1]], %[steps] \n"
"vld1.32 {d12-d15}, [%[in2]], %[steps] \n"
"vld1.32 {d16-d19}, [%[in3]], %[steps] \n"
"vtrn.32 q2, q4 \n" // d0: q2
"vtrn.32 q3, q5 \n" // d1: q4
"vtrn.32 q6, q8 \n" // d2: q6
"vtrn.32 q7, q9 \n" // d3: q8
"vswp.32 d5, d12 \n" // d4: q3
"vswp.32 d9, d16 \n" // d5: q5
"vswp.32 d7, d14 \n" // d6: q7
"vswp.32 d11, d18 \n" // d7: q9
"vsub.f32 q10, q2, q7 \n"
"vsub.f32 q11, q3, q6 \n"
"vmla.f32 q10, q11, d0[0] \n" // d0 - d6 + (d4 -
// d2) * 5.25
"vst1.32 {d20-d21}, [%[d_bt]]! \n"
"vadd.f32 q10, q6, q7 \n"
"vadd.f32 q11, q4, q5 \n"
"vmla.f32 q10, q3, d1[0] \n" // d2 - 4.25 * d4 +
// d6
"vmla.f32 q11, q8, d1[0] \n" // d1 - 4.25 * d3 +
// d5
"vadd.f32 q12, q10, q11 \n"
"vsub.f32 q13, q10, q11 \n"
"vst1.32 {d24-d27}, [%[d_bt]]! \n"
"vmul.f32 q10, q6, d3[1] \n" // 0.25 * d2
"vmul.f32 q11, q4, d3[0] \n" // 0.5 * d1
"vadd.f32 q10, q10, q7 \n" // 0.25 * d2 + d6
"vmla.f32 q11, q5, d2[0] \n" // 0.5 * d1 + 2 *
// d5
"vmla.f32 q10, q3, d2[1] \n" // 0.25 * d2 + d6
// - 1.25 * d4
"vmla.f32 q11, q8, d1[1] \n" // 0.5 * d1 + 2 *
// d5 - 2.5 * d3
"vadd.f32 q12, q10, q11 \n"
"vsub.f32 q13, q10, q11 \n"
"vst1.32 {d24-d27}, [%[d_bt]]! \n"
"vmul.f32 q10, q6, d2[0] \n" // 2 * d2
"vmul.f32 q11, q4, d2[0] \n" // 2 * d1
"vmla.f32 q10, q3, d1[1] \n" // 2 * d2 - 2.5 *
// d4
"vmla.f32 q11, q8, d1[1] \n" // 2 * d1 - 2.5 *
// d3
"vmla.f32   q10, q7, d3[0]              \n"  // 2 * d2 - 2.5 *
                                             // d4 + 0.5 * d6
"vmla.f32   q11, q5, d3[0]              \n"  // 2 * d1 - 2.5 *
                                             // d3 + 0.5 * d5
"vmul.f32   q10, q10, d2[0]             \n"  // 4 * d2 - 5 * d4
                                             // + d6
"vadd.f32 q12, q10, q11 \n"
"vsub.f32 q13, q10, q11 \n"
"vst1.32 {d24-d27}, [%[d_bt]]! \n"
"vsub.f32 q10, q9, q4 \n"
"vsub.f32 q11, q8, q5 \n"
"vmla.f32 q10, q11, d0[0] \n"
"vst1.32 {d20-d21}, [%[d_bt]]! \n"
"subs r0, #1 \n"
"bne loop_r_%= \n"
: [d_bt] "+r"(d_bt_ptr), [in0] "+r"(in0), [in1] "+r"(in1),
[in2] "+r"(in2), [in3] "+r"(in3)
: [tm_ptr] "r"((float *)transform_matrix), [steps] "r"(steps)
: "cc", "memory", "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7",
"q8", "q9", "q10", "q11", "q12", "q13", "r0");
float *ptr0 = d_bt;
float *ptr1 = ptr0 + 32;
float *ptr2 = ptr1 + 32;
float *ptr3 = ptr2 + 32;
float *ptr4 = ptr3 + 32;
float *ptr5 = ptr4 + 32;
float *ptr6 = ptr5 + 32;
float *ptr7 = ptr6 + 32;
int tile_indics = h * w_tiles + w;
int tile_block = tile_indics >> 3;
int block_indics = tile_indics & 0x7;
// (tiles / 8, 64, channel, 8)
float *out0 =
outptr + (tile_block * 64 * channel + c) * 8 + block_indics;
steps = (channel - 3) * 8 * sizeof(float);
asm volatile(
"vld1.32 {d0-d3}, [%[tm_ptr]] \n"
"mov r0, 4 \n"
"mov r1, 32 \n"
"loop_col_%=: \n"
// col 0:
"vld1.32 {d4-d5}, [%[ptr0]]! \n" // q2: d0
"vld1.32 {d6-d7}, [%[ptr1]]! \n" // q3: d1
"vld1.32 {d8-d9}, [%[ptr2]]! \n" // q4: d2
"vld1.32 {d10-d11}, [%[ptr3]]! \n" // q5: d3
"vld1.32 {d12-d13}, [%[ptr4]]! \n" // q6: d4
"vld1.32 {d14-d15}, [%[ptr5]]! \n" // q7: d5
"vld1.32 {d16-d17}, [%[ptr6]]! \n" // q8: d6
"vld1.32 {d18-d19}, [%[ptr7]]! \n" // q9: d7
"vsub.f32 q10, q2, q8 \n" // d0 - d6
"vsub.f32 q11, q6, q4 \n" // d4 - d2
"vmla.f32 q10, q11, d0[0] \n" // d0 - d6 + (d4 -
// d2) * 5.25
"vst1.32 {d20[0]}, [%[out0]], r1 \n"
"vst1.32 {d20[1]}, [%[out0]], r1 \n"
"vst1.32 {d21[0]}, [%[out0]], r1 \n"
"vst1.32 {d21[1]}, [%[out0]], %[steps] \n"
"vadd.f32 q10, q4, q8 \n"
"vadd.f32 q11, q3, q7 \n"
"vmla.f32 q10, q6, d1[0] \n" // d2 - 4.25 * d4 +
// d6
"vmla.f32 q11, q5, d1[0] \n" // d1 - 4.25 * d3 +
// d5
"vadd.f32 q12, q10, q11 \n"
"vst1.32 {d24[0]}, [%[out0]], r1 \n"
"vst1.32 {d24[1]}, [%[out0]], r1 \n"
"vst1.32 {d25[0]}, [%[out0]], r1 \n"
"vst1.32 {d25[1]}, [%[out0]], %[steps] \n"
"vsub.f32 q12, q10, q11 \n"
"vst1.32 {d24[0]}, [%[out0]], r1 \n"
"vst1.32 {d24[1]}, [%[out0]], r1 \n"
"vst1.32 {d25[0]}, [%[out0]], r1 \n"
"vst1.32 {d25[1]}, [%[out0]], %[steps] \n"
"vmul.f32 q10, q4, d3[1] \n" // 0.25 * d2
"vmul.f32 q11, q3, d3[0] \n" // 0.5 * d1
"vadd.f32 q10, q10, q8 \n" // 0.25 * d2 + d6
"vmla.f32 q11, q7, d2[0] \n" // 0.5 * d1 + 2 *
// d5
"vmla.f32 q10, q6, d2[1] \n" // 0.25 * d2 + d6
// - 1.25 * d4
"vmla.f32 q11, q5, d1[1] \n" // 0.5 * d1 + 2 *
// d5 - 2.5 * d3
"vadd.f32 q12, q10, q11 \n"
"vst1.32 {d24[0]}, [%[out0]], r1 \n"
"vst1.32 {d24[1]}, [%[out0]], r1 \n"
"vst1.32 {d25[0]}, [%[out0]], r1 \n"
"vst1.32 {d25[1]}, [%[out0]], %[steps] \n"
"vsub.f32 q12, q10, q11 \n"
"vst1.32 {d24[0]}, [%[out0]], r1 \n"
"vst1.32 {d24[1]}, [%[out0]], r1 \n"
"vst1.32 {d25[0]}, [%[out0]], r1 \n"
"vst1.32 {d25[1]}, [%[out0]], %[steps] \n"
"vmul.f32 q10, q4, d2[0] \n" // 2 * d2
"vmul.f32 q11, q3, d2[0] \n" // 2 * d1
"vmla.f32 q10, q6, d1[1] \n" // 2 * d2 - 2.5 *
// d4
"vmla.f32 q11, q5, d1[1] \n" // 2 * d1 - 2.5 *
// d3
"vmla.f32   q10, q8, d3[0]              \n"  // 2 * d2 - 2.5 *
                                             // d4 + 0.5 * d6
"vmla.f32   q11, q7, d3[0]              \n"  // 2 * d1 - 2.5 *
                                             // d3 + 0.5 * d5
"vmul.f32   q10, q10, d2[0]             \n"  // 4 * d2 - 5 * d4
                                             // + d6
"vadd.f32 q12, q10, q11 \n"
"vst1.32 {d24[0]}, [%[out0]], r1 \n"
"vst1.32 {d24[1]}, [%[out0]], r1 \n"
"vst1.32 {d25[0]}, [%[out0]], r1 \n"
"vst1.32 {d25[1]}, [%[out0]], %[steps] \n"
"vsub.f32 q12, q10, q11 \n"
"vst1.32 {d24[0]}, [%[out0]], r1 \n"
"vst1.32 {d24[1]}, [%[out0]], r1 \n"
"vst1.32 {d25[0]}, [%[out0]], r1 \n"
"vst1.32 {d25[1]}, [%[out0]], %[steps] \n"
"vsub.f32 q10, q9, q3 \n"
"vsub.f32 q11, q5, q7 \n"
"vmla.f32 q10, q11, d0[0] \n"
"vst1.32 {d20[0]}, [%[out0]], r1 \n"
"vst1.32 {d20[1]}, [%[out0]], r1 \n"
"vst1.32 {d21[0]}, [%[out0]], r1 \n"
"vst1.32 {d21[1]}, [%[out0]], %[steps] \n"
// col 1:
"vld1.32 {d4-d5}, [%[ptr0]]! \n" // q2: d0
"vld1.32 {d6-d7}, [%[ptr1]]! \n" // q3: d1
"vld1.32 {d8-d9}, [%[ptr2]]! \n" // q4: d2
"vld1.32 {d10-d11}, [%[ptr3]]! \n" // q5: d3
"vld1.32 {d12-d13}, [%[ptr4]]! \n" // q6: d4
"vld1.32 {d14-d15}, [%[ptr5]]! \n" // q7: d5
"vld1.32 {d16-d17}, [%[ptr6]]! \n" // q8: d6
"vld1.32 {d18-d19}, [%[ptr7]]! \n" // q9: d7
"vsub.f32 q10, q2, q8 \n" // d0 - d6
"vsub.f32 q11, q6, q4 \n" // d4 - d2
"vmla.f32 q10, q11, d0[0] \n" // d0 - d6 + (d4 -
// d2) * 5.25
"vst1.32 {d20[0]}, [%[out0]], r1 \n"
"vst1.32 {d20[1]}, [%[out0]], r1 \n"
"vst1.32 {d21[0]}, [%[out0]], r1 \n"
"vst1.32 {d21[1]}, [%[out0]], %[steps] \n"
"vadd.f32 q10, q4, q8 \n"
"vadd.f32 q11, q3, q7 \n"
"vmla.f32 q10, q6, d1[0] \n" // d2 - 4.25 * d4 +
// d6
"vmla.f32 q11, q5, d1[0] \n" // d1 - 4.25 * d3 +
// d5
"vadd.f32 q12, q10, q11 \n"
"vst1.32 {d24[0]}, [%[out0]], r1 \n"
"vst1.32 {d24[1]}, [%[out0]], r1 \n"
"vst1.32 {d25[0]}, [%[out0]], r1 \n"
"vst1.32 {d25[1]}, [%[out0]], %[steps] \n"
"vsub.f32 q12, q10, q11 \n"
"vst1.32 {d24[0]}, [%[out0]], r1 \n"
"vst1.32 {d24[1]}, [%[out0]], r1 \n"
"vst1.32 {d25[0]}, [%[out0]], r1 \n"
"vst1.32 {d25[1]}, [%[out0]], %[steps] \n"
"vmul.f32 q10, q4, d3[1] \n" // 0.25 * d2
"vmul.f32 q11, q3, d3[0] \n" // 0.5 * d1
"vadd.f32 q10, q10, q8 \n" // 0.25 * d2 + d6
"vmla.f32 q11, q7, d2[0] \n" // 0.5 * d1 + 2 *
// d5
"vmla.f32 q10, q6, d2[1] \n" // 0.25 * d2 + d6
// - 1.25 * d4
"vmla.f32 q11, q5, d1[1] \n" // 0.5 * d1 + 2 *
// d5 - 2.5 * d3
"vadd.f32 q12, q10, q11 \n"
"vst1.32 {d24[0]}, [%[out0]], r1 \n"
"vst1.32 {d24[1]}, [%[out0]], r1 \n"
"vst1.32 {d25[0]}, [%[out0]], r1 \n"
"vst1.32 {d25[1]}, [%[out0]], %[steps] \n"
"vsub.f32 q12, q10, q11 \n"
"vst1.32 {d24[0]}, [%[out0]], r1 \n"
"vst1.32 {d24[1]}, [%[out0]], r1 \n"
"vst1.32 {d25[0]}, [%[out0]], r1 \n"
"vst1.32 {d25[1]}, [%[out0]], %[steps] \n"
"vmul.f32 q10, q4, d2[0] \n" // 2 * d2
"vmul.f32 q11, q3, d2[0] \n" // 2 * d1
"vmla.f32 q10, q6, d1[1] \n" // 2 * d2 - 2.5 *
// d4
"vmla.f32 q11, q5, d1[1] \n" // 2 * d1 - 2.5 *
// d3
"vmla.f32   q10, q8, d3[0]              \n"  // 2 * d2 - 2.5 *
                                             // d4 + 0.5 * d6
"vmla.f32   q11, q7, d3[0]              \n"  // 2 * d1 - 2.5 *
                                             // d3 + 0.5 * d5
"vmul.f32   q10, q10, d2[0]             \n"  // 4 * d2 - 5 * d4
                                             // + d6
"vadd.f32 q12, q10, q11 \n"
"vst1.32 {d24[0]}, [%[out0]], r1 \n"
"vst1.32 {d24[1]}, [%[out0]], r1 \n"
"vst1.32 {d25[0]}, [%[out0]], r1 \n"
"vst1.32 {d25[1]}, [%[out0]], %[steps] \n"
"vsub.f32 q12, q10, q11 \n"
"vst1.32 {d24[0]}, [%[out0]], r1 \n"
"vst1.32 {d24[1]}, [%[out0]], r1 \n"
"vst1.32 {d25[0]}, [%[out0]], r1 \n"
"vst1.32 {d25[1]}, [%[out0]], %[steps] \n"
"vsub.f32 q10, q9, q3 \n"
"vsub.f32 q11, q5, q7 \n"
"vmla.f32 q10, q11, d0[0] \n"
"vst1.32 {d20[0]}, [%[out0]], r1 \n"
"vst1.32 {d20[1]}, [%[out0]], r1 \n"
"vst1.32 {d21[0]}, [%[out0]], r1 \n"
"vst1.32 {d21[1]}, [%[out0]], %[steps] \n"
"subs r0, #1 \n"
"bne loop_col_%= \n"
: [out0] "+r"(out0), [ptr0] "+r"(ptr0), [ptr1] "+r"(ptr1),
[ptr2] "+r"(ptr2), [ptr3] "+r"(ptr3), [ptr4] "+r"(ptr4),
[ptr5] "+r"(ptr5), [ptr6] "+r"(ptr6), [ptr7] "+r"(ptr7)
: [tm_ptr] "r"((float *)transform_matrix), [steps] "r"(steps)
: "cc", "memory", "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7",
"q8", "q9", "q10", "q11", "q12", "q13", "r0", "r1");
}
}
}
#endif
// remaining channels
#pragma omp parallel for
for (int c = remain_c_start; c < channel; ++c) {
const float *in = inptr + c * image_size;
float d_bt[64]; // d * B_t
for (int h = 0; h < h_tiles; ++h) {
for (int w = 0; w < w_tiles; ++w) {
const float *in0 = in + (h * width + w) * 6;
const float *in1 = in0 + width;
const float *in2 = in1 + width;
const float *in3 = in2 + width;
float *d_bt_ptr = d_bt;
int steps = 4 * width * sizeof(float);
asm volatile(
"vld1.32 {d0-d3}, [%[tm_ptr]] \n"
"mov r0, #2 \n"
// row loop
"loop_r_%=: \n"
"vld1.32 {d4-d7}, [%[in0]], %[steps] \n"
"vld1.32 {d8-d11}, [%[in1]], %[steps] \n"
"vld1.32 {d12-d15}, [%[in2]], %[steps] \n"
"vld1.32 {d16-d19}, [%[in3]], %[steps] \n"
"vtrn.32 q2, q4 \n" // d0: q2
"vtrn.32 q3, q5 \n" // d1: q4
"vtrn.32 q6, q8 \n" // d2: q6
"vtrn.32 q7, q9 \n" // d3: q8
"vswp.32 d5, d12 \n" // d4: q3
"vswp.32 d9, d16 \n" // d5: q5
"vswp.32 d7, d14 \n" // d6: q7
"vswp.32 d11, d18 \n" // d7: q9
"vsub.f32 q10, q2, q7 \n"
"vsub.f32 q11, q3, q6 \n"
"vmla.f32 q10, q11, d0[0] \n" // d0 - d6 + (d4 -
// d2) * 5.25"
"vst1.32 {d20-d21}, [%[d_bt]]! \n"
"vadd.f32 q10, q6, q7 \n"
"vadd.f32 q11, q4, q5 \n"
"vmla.f32 q10, q3, d1[0] \n" // d2 - 4.25 * d4 +
// d6
"vmla.f32 q11, q8, d1[0] \n" // d1 - 4.25 * d3 +
// d5
"vadd.f32 q12, q10, q11 \n"
"vsub.f32 q13, q10, q11 \n"
"vst1.32 {d24-d27}, [%[d_bt]]! \n"
"vmul.f32 q10, q6, d3[1] \n" // 0.25 * d2
"vmul.f32 q11, q4, d3[0] \n" // 0.5 * d1
"vadd.f32 q10, q10, q7 \n" // 0.25 * d2 + d6
"vmla.f32 q11, q5, d2[0] \n" // 0.5 * d1 + 2 *
// d5
"vmla.f32 q10, q3, d2[1] \n" // 0.25 * d2 + d6
// - 1.25 * d4
"vmla.f32 q11, q8, d1[1] \n" // 0.5 * d1 + 2 *
// d5 - 2.5 * d3
"vadd.f32 q12, q10, q11 \n"
"vsub.f32 q13, q10, q11 \n"
"vst1.32 {d24-d27}, [%[d_bt]]! \n"
"vmul.f32 q10, q6, d2[0] \n" // 2 * d2
"vmul.f32 q11, q4, d2[0] \n" // 2 * d1
"vmla.f32 q10, q3, d1[1] \n" // 2 * d2 - 2.5 *
// d4
"vmla.f32 q11, q8, d1[1] \n" // 2 * d1 - 2.5 *
// d3
"vmla.f32   q10, q7, d3[0]              \n"  // 2 * d2 - 2.5 *
                                             // d4 + 0.5 * d6
"vmla.f32   q11, q5, d3[0]              \n"  // 2 * d1 - 2.5 *
                                             // d3 + 0.5 * d5
"vmul.f32   q10, q10, d2[0]             \n"  // 4 * d2 - 5 * d4
                                             // + d6
"vadd.f32 q12, q10, q11 \n"
"vsub.f32 q13, q10, q11 \n"
"vst1.32 {d24-d27}, [%[d_bt]]! \n"
"vsub.f32 q10, q9, q4 \n"
"vsub.f32 q11, q8, q5 \n"
"vmla.f32 q10, q11, d0[0] \n"
"vst1.32 {d20-d21}, [%[d_bt]]! \n"
"subs r0, #1 \n"
"bne loop_r_%= \n"
: [d_bt] "+r"(d_bt_ptr), [in0] "+r"(in0), [in1] "+r"(in1),
[in2] "+r"(in2), [in3] "+r"(in3)
: [tm_ptr] "r"((float *)transform_matrix), [steps] "r"(steps)
: "cc", "memory", "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7",
"q8", "q9", "q10", "q11", "q12", "q13", "r0");
float *ptr0 = d_bt;
float *ptr1 = ptr0 + 32;
int tile_indics = h * w_tiles + w;
int tile_block = tile_indics >> 3;
int block_indics = tile_indics & 0x7;
// (tiles / 8, 64, channel, 8)
float *out0 =
outptr + (tile_block * 64 * channel + c) * 8 + block_indics;
float *out1 = out0 + channel * 8;
float *out2 = out1 + channel * 8;
float *out3 = out2 + channel * 8;
float *out4 = out3 + channel * 8;
float *out5 = out4 + channel * 8;
float *out6 = out5 + channel * 8;
float *out7 = out6 + channel * 8;
steps = 8 * channel * 8 * sizeof(float);
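// out0..out7 each cover eight of the 64 transform positions (stride 8 apart);
// consecutive positions are channel * 8 floats apart in the packed layout,
// hence the 8 * channel * 8 step.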
asm volatile(
"mov r0, #2 \n"
"vld1.32 {d0-d3}, [%[tm_ptr]] \n"
// row loop
"loop_r_%=: \n"
"vld1.32 {d4-d7}, [%[ptr0]]! \n" // q2: d0, q3: d1
"vld1.32 {d8-d11}, [%[ptr0]]! \n" // q4: d2, q5: d3
"vld1.32 {d12-d15}, [%[ptr1]]! \n" // q6: d4, q7: d5
"vld1.32 {d16-d19}, [%[ptr1]]! \n" // q8: d6, q9: d7
"vtrn.32 q2, q3 \n"
"vtrn.32 q4, q5 \n"
"vtrn.32 q6, q7 \n"
"vtrn.32 q8, q9 \n"
"vswp.32 d5, d8 \n"
"vswp.32 d7, d10 \n"
"vswp.32 d13, d16 \n"
"vswp.32 d15, d18 \n"
"vsub.f32 q10, q2, q8 \n" // d0 - d6
"vsub.f32 q11, q6, q4 \n" // d4 - d2
"vmla.f32 q10, q11, d0[0] \n" // d0 - d6 + (d4 -
// d2) * 5.25
"vst1.32 {d20[0]}, [%[out0]], %[steps] \n"
"vst1.32 {d20[1]}, [%[out0]], %[steps] \n"
"vst1.32 {d21[0]}, [%[out0]], %[steps] \n"
"vst1.32 {d21[1]}, [%[out0]], %[steps] \n"
"vadd.f32 q10, q4, q8 \n"
"vadd.f32 q11, q3, q7 \n"
"vmla.f32 q10, q6, d1[0] \n" // d2 - 4.25 * d4 +
// d6
"vmla.f32 q11, q5, d1[0] \n" // d1 - 4.25 * d3 +
// d5
"vadd.f32 q12, q10, q11 \n"
"vst1.32 {d24[0]}, [%[out1]], %[steps] \n"
"vst1.32 {d24[1]}, [%[out1]], %[steps] \n"
"vst1.32 {d25[0]}, [%[out1]], %[steps] \n"
"vst1.32 {d25[1]}, [%[out1]], %[steps] \n"
"vsub.f32 q12, q10, q11 \n"
"vst1.32 {d24[0]}, [%[out2]], %[steps] \n"
"vst1.32 {d24[1]}, [%[out2]], %[steps] \n"
"vst1.32 {d25[0]}, [%[out2]], %[steps] \n"
"vst1.32 {d25[1]}, [%[out2]], %[steps] \n"
"vmul.f32 q10, q4, d3[1] \n" // 0.25 * d2
"vmul.f32 q11, q3, d3[0] \n" // 0.5 * d1
"vadd.f32 q10, q10, q8 \n" // 0.25 * d2 + d6
"vmla.f32 q11, q7, d2[0] \n" // 0.5 * d1 + 2 *
// d5
"vmla.f32 q10, q6, d2[1] \n" // 0.25 * d2 + d6
// - 1.25 * d4
"vmla.f32 q11, q5, d1[1] \n" // 0.5 * d1 + 2 *
// d5 - 2.5 * d3
"vadd.f32 q12, q10, q11 \n"
"vst1.32 {d24[0]}, [%[out3]], %[steps] \n"
"vst1.32 {d24[1]}, [%[out3]], %[steps] \n"
"vst1.32 {d25[0]}, [%[out3]], %[steps] \n"
"vst1.32 {d25[1]}, [%[out3]], %[steps] \n"
"vsub.f32 q12, q10, q11 \n"
"vst1.32 {d24[0]}, [%[out4]], %[steps] \n"
"vst1.32 {d24[1]}, [%[out4]], %[steps] \n"
"vst1.32 {d25[0]}, [%[out4]], %[steps] \n"
"vst1.32 {d25[1]}, [%[out4]], %[steps] \n"
"vmul.f32 q10, q4, d2[0] \n" // 2 * d2
"vmul.f32 q11, q3, d2[0] \n" // 2 * d1
"vmla.f32 q10, q6, d1[1] \n" // 2 * d2 - 2.5 *
// d4
"vmla.f32 q11, q5, d1[1] \n" // 2 * d1 - 2.5 *
// d3
"vmla.f32   q10, q8, d3[0]              \n"  // 2 * d2 - 2.5 *
                                             // d4 + 0.5 * d6
"vmla.f32   q11, q7, d3[0]              \n"  // 2 * d1 - 2.5 *
                                             // d3 + 0.5 * d5
"vmul.f32   q10, q10, d2[0]             \n"  // 4 * d2 - 5 * d4
                                             // + d6
"vadd.f32 q12, q10, q11 \n"
"vst1.32 {d24[0]}, [%[out5]], %[steps] \n"
"vst1.32 {d24[1]}, [%[out5]], %[steps] \n"
"vst1.32 {d25[0]}, [%[out5]], %[steps] \n"
"vst1.32 {d25[1]}, [%[out5]], %[steps] \n"
"vsub.f32 q12, q10, q11 \n"
"vst1.32 {d24[0]}, [%[out6]], %[steps] \n"
"vst1.32 {d24[1]}, [%[out6]], %[steps] \n"
"vst1.32 {d25[0]}, [%[out6]], %[steps] \n"
"vst1.32 {d25[1]}, [%[out6]], %[steps] \n"
"vsub.f32 q10, q9, q3 \n"
"vsub.f32 q11, q5, q7 \n"
"vmla.f32 q10, q11, d0[0] \n"
"vst1.32 {d20[0]}, [%[out7]], %[steps] \n"
"vst1.32 {d20[1]}, [%[out7]], %[steps] \n"
"vst1.32 {d21[0]}, [%[out7]], %[steps] \n"
"vst1.32 {d21[1]}, [%[out7]], %[steps] \n"
"subs r0, #1 \n"
"bne loop_r_%= \n"
: [out0] "+r"(out0), [out1] "+r"(out1), [out2] "+r"(out2),
[out3] "+r"(out3), [out4] "+r"(out4), [out5] "+r"(out5),
[out6] "+r"(out6), [out7] "+r"(out7), [ptr0] "+r"(ptr0),
[ptr1] "+r"(ptr1)
: [tm_ptr] "r"((float *)transform_matrix), [steps] "r"(steps)
: "cc", "memory", "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7",
"q8", "q9", "q10", "q11", "q12", "q13", "r0");
}
}
}
}
template <>
void winograd_transform_output<8, 3>(const framework::Tensor &input,
const framework::Tensor &weight,
framework::Tensor *output) {
// weight shape is [out_channel/4, 64, in_channel, 4],
// input shape is [hw/8, 64, in_channel, 8]
int in_channel = input.dims()[2];
int tiles = input.dims()[0];
int out_channel = weight.dims()[0];
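// Note: weight.dims()[0] is the number of packed 4-filter groups (see the
// shape comment above), i.e. ceil(real out_channel / 4), not the raw output
// channel count.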
// compute U*V first
framework::Tensor uv_trans;
framework::DDim shape =
framework::make_ddim(std::vector<int>{out_channel, tiles, 64, 32});
float *uv_trans_ptr = uv_trans.mutable_data<float>(shape);
memset(uv_trans_ptr, 0, uv_trans.numel() * sizeof(float));
const float *input_ptr = input.data<float>();
const float *weight_ptr = weight.data<float>();
#pragma omp parallel for
for (int i = 0; i < out_channel; ++i) {
float *uv_ptr = uv_trans_ptr + (i * tiles * 64 * 32);
for (int j = 0; j < tiles; ++j) {
for (int k = 0; k < 64; ++k) {
const float *w_ptr = weight_ptr + (i * 64 + k) * in_channel * 4;
const float *in_ptr = input_ptr + (j * 64 + k) * in_channel * 8;
int inter_channel = in_channel >> 1;
int remain_channel = in_channel & 0x1;
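// q8-q15 below accumulate a 4 (filter lanes) x 8 (tiles) block. The asm uses
// lr/pc as a manual call/return: store_res_ calls loop_2c_ (two input
// channels per iteration) and loop_c_ (odd remainder) via blgt, each returns
// with "mov pc, lr", and the 32 partial sums are then stored to uv_ptr.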
asm volatile(
"veor q8, q8, q8 \n"
"veor q9, q9, q9 \n"
"veor q10, q10, q10 \n"
"veor q11, q11, q11 \n"
"veor q12, q12, q12 \n"
"veor q13, q13, q13 \n"
"veor q14, q14, q14 \n"
"veor q15, q15, q15 \n"
"b store_res_%= \n"
// loop 2 channels
"loop_2c_%=: \n"
"vld1.32 {d0-d3}, [%[w_ptr]]! \n"
"vld1.32 {d4-d7}, [%[in_ptr]]! \n"
"vld1.32 {d8-d11}, [%[in_ptr]]! \n"
"vmla.f32 q8, q2, d0[0] \n"
"vmla.f32 q9, q3, d0[0] \n"
"vmla.f32 q10, q2, d0[1] \n"
"vmla.f32 q11, q3, d0[1] \n"
"vmla.f32 q12, q2, d1[0] \n"
"vmla.f32 q13, q3, d1[0] \n"
"vmla.f32 q14, q2, d1[1] \n"
"vmla.f32 q15, q3, d1[1] \n"
"vmla.f32 q8, q4, d2[0] \n"
"vmla.f32 q9, q5, d2[0] \n"
"vmla.f32 q10, q4, d2[1] \n"
"vmla.f32 q11, q5, d2[1] \n"
"vmla.f32 q12, q4, d3[0] \n"
"vmla.f32 q13, q5, d3[0] \n"
"vmla.f32 q14, q4, d3[1] \n"
"vmla.f32 q15, q5, d3[1] \n"
"subs %[inter_channel], #1 \n"
"bne loop_2c_%= \n"
"mov pc, lr \n"
// loop 1 channel
"loop_c_%=: \n"
"vld1.32 {d0-d1}, [%[w_ptr]]! \n"
"vld1.32 {d4-d7}, [%[in_ptr]]! \n"
"vmla.f32 q8, q2, d0[0] \n"
"vmla.f32 q9, q3, d0[0] \n"
"vmla.f32 q10, q2, d0[1] \n"
"vmla.f32 q11, q3, d0[1] \n"
"vmla.f32 q12, q2, d1[0] \n"
"vmla.f32 q13, q3, d1[0] \n"
"vmla.f32 q14, q2, d1[1] \n"
"vmla.f32 q15, q3, d1[1] \n"
"subs %[remain_channel], #1 \n"
"bne loop_c_%= \n"
"mov pc, lr \n"
"store_res_%=: \n"
"cmp %[inter_channel], #0 \n"
"it gt \n"
"blgt loop_2c_%= \n"
"cmp %[remain_channel], #0 \n"
"it gt \n"
"blgt loop_c_%= \n"
"vst1.32 {d16-d19}, [%[uv_ptr]]! \n"
"vst1.32 {d20-d23}, [%[uv_ptr]]! \n"
"vst1.32 {d24-d27}, [%[uv_ptr]]! \n"
"vst1.32 {d28-d31}, [%[uv_ptr]]! \n"
: [w_ptr] "+r"(w_ptr), [in_ptr] "+r"(in_ptr), [uv_ptr] "+r"(uv_ptr),
[remain_channel] "+r"(remain_channel),
[inter_channel] "+r"(inter_channel)
:
: "cc", "memory", "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7",
"q8", "q9", "q10", "q11", "q12", "q13", "q14", "q15", "pc", "lr");
}
}
}
/*
* s0 = m0 + (m1 + m2) + (m3 + m4) + 32 * (m5 + m6)
* s1 = (m1 - m2) + 2 * (m3 - m4) + 16 * (m5 - m6)
* s2 = (m1 + m2) + 4 * (m3 + m4) + 8 * (m5 + m6)
* s3 = (m1 - m2) + 8 * (m3 - m4) + 4 * (m5 - m6)
* s4 = (m1 + m2) + 16 * (m3 + m4) + 2 * (m5 + m6)
* s5 = (m1 - m2) + 32 * (m3 - m4) + (m5 - m6) + m7
*/
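// Like the input transform, A_T * M * A is applied separably: a row pass into
// at_m, then a column pass that writes the final 6x6 output tile.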
int out_h = output->dims()[2];
int out_w = output->dims()[3];
int h_tiles = (out_h + 5) / 6;
int w_tiles = (out_w + 5) / 6;
int remain_h = out_h - out_h / 6 * 6;
int remain_w = out_w - out_w / 6 * 6;
float *output_ptr = output->mutable_data<float>();
float transform_matrix[8] = {2.f, 4.f, 8.f, 16.f};
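// Coefficients of the formulas above: 2, 4, 8 and 16 (the factor 32 is formed
// as 16 * 2 inside the asm); only the first four entries are used.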
#pragma omp parallel for
for (int oc = 0; oc < output->dims()[1]; ++oc) {
float at_m[48]; // [6][8]
float output_tmp[36]; // [6][6], temporarily stores the result tile
// (oc / 4) * tiles * 64 * 32 + (oc & 0x3) * 8
const float *uv_ptr =
uv_trans_ptr + (oc >> 2) * tiles * 64 * 32 + (oc & 0x3) * 8;
for (int tile_h = 0; tile_h < h_tiles; ++tile_h) {
for (int tile_w = 0; tile_w < w_tiles; ++tile_w) {
float *at_m_ptr = at_m;
int tile_indics = tile_h * w_tiles + tile_w;
int tile_block = tile_indics >> 3;
int block_indics = tile_indics & 0x7;
const float *uv_ptr0 = uv_ptr + tile_block * 64 * 32 + block_indics;
int steps = 32 * sizeof(float);
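// Consecutive transform positions of one tile are 32 floats apart
// (4 filter lanes x 8 tiles), hence the 32-float step between loads.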
asm volatile(
"vld1.32 {d0-d1}, [%[tm_ptr]] \n"
"mov r0, #2 \n"
"loop_%=: \n"
"vld1.32 {d2[0]}, [%[uv_ptr0]], %[steps] \n"
"vld1.32 {d6[0]}, [%[uv_ptr0]], %[steps] \n"
"vld1.32 {d10[0]}, [%[uv_ptr0]], %[steps] \n"
"vld1.32 {d14[0]}, [%[uv_ptr0]], %[steps] \n"
"vld1.32 {d4[0]}, [%[uv_ptr0]], %[steps] \n"
"vld1.32 {d8[0]}, [%[uv_ptr0]], %[steps] \n"
"vld1.32 {d12[0]}, [%[uv_ptr0]], %[steps] \n"
"vld1.32 {d16[0]}, [%[uv_ptr0]], %[steps] \n"
"vld1.32 {d2[1]}, [%[uv_ptr0]], %[steps] \n"
"vld1.32 {d6[1]}, [%[uv_ptr0]], %[steps] \n"
"vld1.32 {d10[1]}, [%[uv_ptr0]], %[steps] \n"
"vld1.32 {d14[1]}, [%[uv_ptr0]], %[steps] \n"
"vld1.32 {d4[1]}, [%[uv_ptr0]], %[steps] \n"
"vld1.32 {d8[1]}, [%[uv_ptr0]], %[steps] \n"
"vld1.32 {d12[1]}, [%[uv_ptr0]], %[steps] \n"
"vld1.32 {d16[1]}, [%[uv_ptr0]], %[steps] \n"
"vld1.32 {d3[0]}, [%[uv_ptr0]], %[steps] \n"
"vld1.32 {d7[0]}, [%[uv_ptr0]], %[steps] \n"
"vld1.32 {d11[0]}, [%[uv_ptr0]], %[steps] \n"
"vld1.32 {d15[0]}, [%[uv_ptr0]], %[steps] \n"
"vld1.32 {d5[0]}, [%[uv_ptr0]], %[steps] \n"
"vld1.32 {d9[0]}, [%[uv_ptr0]], %[steps] \n"
"vld1.32 {d13[0]}, [%[uv_ptr0]], %[steps] \n"
"vld1.32 {d17[0]}, [%[uv_ptr0]], %[steps] \n"
"vld1.32 {d3[1]}, [%[uv_ptr0]], %[steps] \n"
"vld1.32 {d7[1]}, [%[uv_ptr0]], %[steps] \n"
"vld1.32 {d11[1]}, [%[uv_ptr0]], %[steps] \n"
"vld1.32 {d15[1]}, [%[uv_ptr0]], %[steps] \n"
"vld1.32 {d5[1]}, [%[uv_ptr0]], %[steps] \n"
"vld1.32 {d9[1]}, [%[uv_ptr0]], %[steps] \n"
"vld1.32 {d13[1]}, [%[uv_ptr0]], %[steps] \n"
"vld1.32 {d17[1]}, [%[uv_ptr0]], %[steps] \n"
"vadd.f32 q9, q3, q5 \n" // m1 + m2
"vadd.f32 q10, q7, q2 \n" // m3 + m4
"vadd.f32 q11, q4, q6 \n" // m5 + m6
"vsub.f32 q12, q3, q5 \n" // m1 - m2
"vsub.f32 q13, q7, q2 \n" // m3 - m4
"vsub.f32 q14, q4, q6 \n" // m5 - m6
"vmul.f32 q2, q13, d0[0] \n" // 2 * (m3 - m4)
"vmul.f32 q3, q11, d0[0] \n" // 2 * (m5 + m6)
"vadd.f32 q15, q1, q9 \n"
"vadd.f32 q15, q15, q10 \n"
"vmla.f32 q15, q3, d1[1] \n"
"vst1.32 {d30-d31}, [%[at_m_ptr]]! \n"
"vadd.f32 q15, q12, q2 \n"
"vmla.f32 q15, q14, d1[1] \n"
"vst1.32 {d30-d31}, [%[at_m_ptr]]! \n"
"vmov.32 q15, q9 \n"
"vmla.f32 q15, q10, d0[1] \n"
"vmla.f32 q15, q11, d1[0] \n"
"vst1.32 {d30-d31}, [%[at_m_ptr]]! \n"
"vmov.32 q15, q12 \n"
"vmla.f32 q15, q13, d1[0] \n"
"vmla.f32 q15, q14, d0[1] \n"
"vst1.32 {d30-d31}, [%[at_m_ptr]]! \n"
"vadd.f32 q15, q9, q3 \n"
"vmla.f32 q15, q10, d1[1] \n"
"vst1.32 {d30-d31}, [%[at_m_ptr]]! \n"
"vadd.f32 q15, q12, q8 \n"
"vadd.f32 q15, q15, q14 \n"
"vmla.f32 q15, q2, d1[1] \n"
"vst1.32 {d30-d31}, [%[at_m_ptr]]! \n"
"subs r0, #1 \n"
"bne loop_%= \n"
: [uv_ptr0] "+r"(uv_ptr0), [at_m_ptr] "+r"(at_m_ptr)
: [tm_ptr] "r"((float *)transform_matrix), [steps] "r"(steps)
: "cc", "memory", "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7",
"q8", "q9", "q10", "q11", "q12", "q13", "q14", "q15", "r0");
float *at_m_ptr0 = at_m;
float *at_m_ptr1 = at_m + 24;
if ((remain_w > 0 && tile_w == w_tiles - 1) ||
(remain_h > 0 && tile_h == h_tiles - 1)) {
float *out_ptr0 = output_tmp;
float *out_ptr1 = output_tmp + 6;
float *out_ptr2 = output_tmp + 12;
float *out_ptr3 = output_tmp + 18;
float *out_ptr4 = output_tmp + 24;
float *out_ptr5 = output_tmp + 30;
asm volatile(
"vld1.32 {d0-d1}, [%[tm_ptr]] \n"
// process 4 rows
"vld1.32 {d2-d5}, [%[at_m_ptr0]]! \n" // q1: m0, q2: m1
"vld1.32 {d6-d9}, [%[at_m_ptr0]]! \n" // q3: m2, q4: m3
"vld1.32 {d10-d13}, [%[at_m_ptr1]]! \n" // q5: m4, q6: m5
"vld1.32 {d14-d17}, [%[at_m_ptr1]]! \n" // q7: m6, q8: m7
"vtrn.32 q1, q2 \n"
"vtrn.32 q3, q4 \n"
"vtrn.32 q5, q6 \n"
"vtrn.32 q7, q8 \n"
"vswp.32 d3, d6 \n"
"vswp.32 d5, d8 \n"
"vswp.32 d11, d14 \n"
"vswp.32 d13, d16 \n"
"vadd.f32 q9, q2, q3 \n" // m1 + m2
"vadd.f32 q10, q4, q5 \n" // m3 + m4
"vadd.f32 q11, q6, q7 \n" // m5 + m6
"vsub.f32 q12, q2, q3 \n" // m1 - m2
"vsub.f32 q13, q4, q5 \n" // m3 - m4
"vsub.f32 q14, q6, q7 \n" // m5 - m6
"vmul.f32 q6, q13, d0[0] \n" // 2 * (m3 - m4)
"vmul.f32 q7, q11, d0[0] \n" // 2 * (m5 + m6)
"vadd.f32 q1, q1, q9 \n"
"vadd.f32 q1, q1, q10 \n"
"vmla.f32 q1, q7, d1[1] \n"
"vadd.f32 q2, q12, q6 \n"
"vmla.f32 q2, q14, d1[1] \n"
"vmov.32 q3, q9 \n"
"vmla.f32 q3, q10, d0[1] \n"
"vmla.f32 q3, q11, d1[0] \n"
"vmov.32 q4, q12 \n"
"vmla.f32 q4, q13, d1[0] \n"
"vmla.f32 q4, q14, d0[1] \n"
"vtrn.32 q1, q2 \n"
"vtrn.32 q3, q4 \n"
"vswp.32 d3, d6 \n"
"vswp.32 d5, d8 \n"
"vst1.32 {d2-d3}, [%[out_ptr0]]! \n"
"vst1.32 {d4-d5}, [%[out_ptr1]]! \n"
"vst1.32 {d6-d7}, [%[out_ptr2]]! \n"
"vst1.32 {d8-d9}, [%[out_ptr3]]! \n"
"vadd.f32 q1, q9, q7 \n"
"vmla.f32 q1, q10, d1[1] \n"
"vadd.f32 q2, q12, q8 \n"
"vadd.f32 q2, q2, q14 \n"
"vmla.f32 q2, q6, d1[1] \n"
"vtrn.32 q1, q2 \n"
"vst1.32 {d2}, [%[out_ptr0]]! \n"
"vst1.32 {d4}, [%[out_ptr1]]! \n"
"vst1.32 {d3}, [%[out_ptr2]]! \n"
"vst1.32 {d5}, [%[out_ptr3]]! \n"
// remain 2 rows
"vld1.32 {d2-d5}, [%[at_m_ptr0]]! \n" // d2: m0, d3: m2,
// d4: m1, d5: m3
"vld1.32 {d6-d9}, [%[at_m_ptr1]]! \n" // d6: m4, d7: m6,
// d8: m5, d9: m7
"vtrn.32 q1, q2 \n"
"vtrn.32 q3, q4 \n"
"vadd.f32 d10, d4, d3 \n" // m1 + m2
"vadd.f32 d11, d5, d6 \n" // m3 + m4
"vadd.f32 d12, d8, d7 \n" // m5 + m6
"vsub.f32 d13, d4, d3 \n" // m1 - m2
"vsub.f32 d14, d5, d6 \n" // m3 - m4
"vsub.f32 d15, d8, d7 \n" // m5 - m6
"vmul.f32 d16, d14, d0[0] \n" // 2 * (m3 - m4)
"vmul.f32 d17, d12, d0[0] \n" // 2 * (m5 + m6)
"vadd.f32 d18, d2, d10 \n"
"vadd.f32 d18, d18, d11 \n"
"vmla.f32 d18, d17, d1[1] \n"
"vadd.f32 d20, d13, d16 \n"
"vmla.f32 d20, d15, d1[1] \n"
"vmov.32 d19, d10 \n"
"vmla.f32 d19, d11, d0[1] \n"
"vmla.f32 d19, d12, d1[0] \n"
"vmov.32 d21, d13 \n"
"vmla.f32 d21, d14, d1[0] \n"
"vmla.f32 d21, d15, d0[1] \n"
"vtrn.32 d18, d20 \n"
"vtrn.32 d19, d21 \n"
"vst1.32 {d18-d19}, [%[out_ptr4]]! \n"
"vst1.32 {d20-d21}, [%[out_ptr5]]! \n"
"vadd.f32 d18, d10, d17 \n"
"vmla.f32 d18, d11, d1[1] \n"
"vadd.f32 d19, d13, d9 \n"
"vadd.f32 d19, d19, d15 \n"
"vmla.f32 d19, d16, d1[1] \n"
"vtrn.32 d18, d19 \n"
"vst1.32 {d18}, [%[out_ptr4]]! \n"
"vst1.32 {d19}, [%[out_ptr5]]! \n"
: [out_ptr0] "+r"(out_ptr0), [out_ptr1] "+r"(out_ptr1),
[out_ptr2] "+r"(out_ptr2), [out_ptr3] "+r"(out_ptr3),
[out_ptr4] "+r"(out_ptr4), [out_ptr5] "+r"(out_ptr5),
[at_m_ptr0] "+r"(at_m_ptr0), [at_m_ptr1] "+r"(at_m_ptr1)
: [tm_ptr] "r"((float *)transform_matrix)
: "cc", "memory", "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7",
"q8", "q9", "q10", "q11", "q12", "q13", "q14", "q15");
size_t offset = (oc * out_h + 6 * tile_h) * out_w + 6 * tile_w;
float *out_ptr = output_ptr + offset;
int remain_row = (tile_h < h_tiles - 1) ? 6 : remain_h;
int remain_col = (tile_w < w_tiles - 1) ? 6 : remain_w;
for (int i = 0; i < remain_row; ++i, out_ptr += out_w) {
memcpy(out_ptr, output_tmp + i * 6, remain_col * sizeof(float));
}
} else {
size_t offset = (oc * out_h + 6 * tile_h) * out_w + 6 * tile_w;
float *out_ptr0 = output_ptr + offset;
float *out_ptr1 = out_ptr0 + out_w;
float *out_ptr2 = out_ptr1 + out_w;
float *out_ptr3 = out_ptr2 + out_w;
float *out_ptr4 = out_ptr3 + out_w;
float *out_ptr5 = out_ptr4 + out_w;
asm volatile(
"vld1.32 {d0-d1}, [%[tm_ptr]] \n"
// process 4 rows
"vld1.32 {d2-d5}, [%[at_m_ptr0]]! \n" // q1: m0, q2: m1
"vld1.32 {d6-d9}, [%[at_m_ptr0]]! \n" // q3: m2, q4: m3
"vld1.32 {d10-d13}, [%[at_m_ptr1]]! \n" // q5: m4, q6: m5
"vld1.32 {d14-d17}, [%[at_m_ptr1]]! \n" // q7: m6, q8: m7
"vtrn.32 q1, q2 \n"
"vtrn.32 q3, q4 \n"
"vtrn.32 q5, q6 \n"
"vtrn.32 q7, q8 \n"
"vswp.32 d3, d6 \n"
"vswp.32 d5, d8 \n"
"vswp.32 d11, d14 \n"
"vswp.32 d13, d16 \n"
"vadd.f32 q9, q2, q3 \n" // m1 + m2
"vadd.f32 q10, q4, q5 \n" // m3 + m4
"vadd.f32 q11, q6, q7 \n" // m5 + m6
"vsub.f32 q12, q2, q3 \n" // m1 - m2
"vsub.f32 q13, q4, q5 \n" // m3 - m4
"vsub.f32 q14, q6, q7 \n" // m5 - m6
"vmul.f32 q6, q13, d0[0] \n" // 2 * (m3 - m4)
"vmul.f32 q7, q11, d0[0] \n" // 2 * (m5 + m6)
"vadd.f32 q1, q1, q9 \n"
"vadd.f32 q1, q1, q10 \n"
"vmla.f32 q1, q7, d1[1] \n"
"vadd.f32 q2, q12, q6 \n"
"vmla.f32 q2, q14, d1[1] \n"
"vmov.32 q3, q9 \n"
"vmla.f32 q3, q10, d0[1] \n"
"vmla.f32 q3, q11, d1[0] \n"
"vmov.32 q4, q12 \n"
"vmla.f32 q4, q13, d1[0] \n"
"vmla.f32 q4, q14, d0[1] \n"
"vtrn.32 q1, q2 \n"
"vtrn.32 q3, q4 \n"
"vswp.32 d3, d6 \n"
"vswp.32 d5, d8 \n"
"vst1.32 {d2-d3}, [%[out_ptr0]]! \n"
"vst1.32 {d4-d5}, [%[out_ptr1]]! \n"
"vst1.32 {d6-d7}, [%[out_ptr2]]! \n"
"vst1.32 {d8-d9}, [%[out_ptr3]]! \n"
"vadd.f32 q1, q9, q7 \n"
"vmla.f32 q1, q10, d1[1] \n"
"vadd.f32 q2, q12, q8 \n"
"vadd.f32 q2, q2, q14 \n"
"vmla.f32 q2, q6, d1[1] \n"
"vtrn.32 q1, q2 \n"
"vst1.32 {d2}, [%[out_ptr0]]! \n"
"vst1.32 {d4}, [%[out_ptr1]]! \n"
"vst1.32 {d3}, [%[out_ptr2]]! \n"
"vst1.32 {d5}, [%[out_ptr3]]! \n"
// remain 2 rows
"vld1.32 {d2-d5}, [%[at_m_ptr0]]! \n" // d2: m0, d3: m2,
// d4: m1, d5: m3
"vld1.32 {d6-d9}, [%[at_m_ptr1]]! \n" // d6: m4, d7: m6,
// d8: m5, d9: m7
"vtrn.32 q1, q2 \n"
"vtrn.32 q3, q4 \n"
"vadd.f32 d10, d4, d3 \n" // m1 + m2
"vadd.f32 d11, d5, d6 \n" // m3 + m4
"vadd.f32 d12, d8, d7 \n" // m5 + m6
"vsub.f32 d13, d4, d3 \n" // m1 - m2
"vsub.f32 d14, d5, d6 \n" // m3 - m4
"vsub.f32 d15, d8, d7 \n" // m5 - m6
"vmul.f32 d16, d14, d0[0] \n" // 2 * (m3 - m4)
"vmul.f32 d17, d12, d0[0] \n" // 2 * (m5 + m6)
"vadd.f32 d18, d2, d10 \n"
"vadd.f32 d18, d18, d11 \n"
"vmla.f32 d18, d17, d1[1] \n"
"vadd.f32 d20, d13, d16 \n"
"vmla.f32 d20, d15, d1[1] \n"
"vmov.32 d19, d10 \n"
"vmla.f32 d19, d11, d0[1] \n"
"vmla.f32 d19, d12, d1[0] \n"
"vmov.32 d21, d13 \n"
"vmla.f32 d21, d14, d1[0] \n"
"vmla.f32 d21, d15, d0[1] \n"
"vtrn.32 d18, d20 \n"
"vtrn.32 d19, d21 \n"
"vst1.32 {d18-d19}, [%[out_ptr4]]! \n"
"vst1.32 {d20-d21}, [%[out_ptr5]]! \n"
"vadd.f32 d18, d10, d17 \n"
"vmla.f32 d18, d11, d1[1] \n"
"vadd.f32 d19, d13, d9 \n"
"vadd.f32 d19, d19, d15 \n"
"vmla.f32 d19, d16, d1[1] \n"
"vtrn.32 d18, d19 \n"
"vst1.32 {d18}, [%[out_ptr4]]! \n"
"vst1.32 {d19}, [%[out_ptr5]]! \n"
: [out_ptr0] "+r"(out_ptr0), [out_ptr1] "+r"(out_ptr1),
[out_ptr2] "+r"(out_ptr2), [out_ptr3] "+r"(out_ptr3),
[out_ptr4] "+r"(out_ptr4), [out_ptr5] "+r"(out_ptr5),
[at_m_ptr0] "+r"(at_m_ptr0), [at_m_ptr1] "+r"(at_m_ptr1)
: [tm_ptr] "r"((float *)transform_matrix)
: "cc", "memory", "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7",
"q8", "q9", "q10", "q11", "q12", "q13", "q14", "q15");
}
}
}
}
}
} // namespace math
} // namespace operators
} // namespace paddle_mobile
#endif // __aarch64__
#endif // CONV_OP
...@@ -405,9 +405,9 @@ class ConvParam : public OpParam { ...@@ -405,9 +405,9 @@ class ConvParam : public OpParam {
const RType *Input() const { return input_; } const RType *Input() const { return input_; }
RType *Filter() const { return filter_; } RType *&Filter() const { return filter_; }
RType *Output() const { return output_; } RType *&Output() const { return output_; }
const vector<int> &Strides() const { return strides_; } const vector<int> &Strides() const { return strides_; }
...@@ -415,6 +415,19 @@ class ConvParam : public OpParam { ...@@ -415,6 +415,19 @@ class ConvParam : public OpParam {
const vector<int> &Dilations() const { return dilations_; } const vector<int> &Dilations() const { return dilations_; }
enum ExecMode {
EXEC_INVALID = 0,
EXEC_GEMM_FLOAT,
EXEC_DEPTHWISE3x3S1P1_FLOAT,
EXEC_DEPTHWISE3x3_FLOAT,
EXEC_WINOGRAD3X3_FLOAT,
EXEC_WINOGRAD5X5_FLOAT,
EXEC_GEMM_INT8,
EXEC_DEPTHWISE3x3_INT8,
};
ExecMode &ExecMode() const { return exec_mode_; }
const int &Groups() const { return groups; } const int &Groups() const { return groups; }
#ifdef PADDLE_MOBILE_CL #ifdef PADDLE_MOBILE_CL
...@@ -426,11 +439,12 @@ class ConvParam : public OpParam { ...@@ -426,11 +439,12 @@ class ConvParam : public OpParam {
private: private:
RType *input_; RType *input_;
RType *output_; mutable RType *output_;
RType *filter_; mutable RType *filter_;
vector<int> strides_; vector<int> strides_;
vector<int> paddings_; vector<int> paddings_;
vector<int> dilations_; vector<int> dilations_;
mutable enum ExecMode exec_mode_;
int groups; int groups;
#ifdef PADDLE_MOBILE_CL #ifdef PADDLE_MOBILE_CL
...@@ -2509,10 +2523,10 @@ class QuantizeParam : public OpParam { ...@@ -2509,10 +2523,10 @@ class QuantizeParam : public OpParam {
QuantizeParam(const VariableNameMap &inputs, const VariableNameMap &outputs, QuantizeParam(const VariableNameMap &inputs, const VariableNameMap &outputs,
const AttributeMap &attrs, const Scope &scope) { const AttributeMap &attrs, const Scope &scope) {
input_ = InputXFrom<GType>(inputs, scope); input_ = InputXFrom<GType>(inputs, scope);
out_ = OutFrom<GType>(outputs, scope); output_ = OutFrom<GType>(outputs, scope);
// online // online
// scale = max(abs(x)) // scale = max(abs(x))
online_scale_ = GetVarValue<GType>("OutScale", outputs, scope); online_scale_ = OpParam::GetVarValue<GType>("OutScale", outputs, scope);
// offline // offline
if (HasAttr("static_scale", attrs)) { if (HasAttr("static_scale", attrs)) {
is_static_ = true; is_static_ = true;
...@@ -2522,14 +2536,18 @@ class QuantizeParam : public OpParam { ...@@ -2522,14 +2536,18 @@ class QuantizeParam : public OpParam {
if (HasAttr("round_type", attrs)) { if (HasAttr("round_type", attrs)) {
round_type_ = GetAttr<RoundType>("round_type", attrs); round_type_ = GetAttr<RoundType>("round_type", attrs);
} }
// get paddings
paddings_ = std::vector<int>({0, 0});
if (HasAttr("paddings", attrs)) {
paddings_ = GetAttr<vector<int>>("paddings", attrs);
}
} }
public: public:
// op input // op input
RType *input_; RType *input_;
// op output // op output
RType *out_; RType *output_;
//
RType *online_scale_; RType *online_scale_;
// if static scale or not // if static scale or not
bool is_static_ = false; bool is_static_ = false;
...@@ -2537,7 +2555,11 @@ class QuantizeParam : public OpParam { ...@@ -2537,7 +2555,11 @@ class QuantizeParam : public OpParam {
float static_scale_ = 1.0f; float static_scale_ = 1.0f;
// round method type // round method type
// nearest_zero and nearest_even is valid currently // nearest_zero and nearest_even is valid currently
RoundType round_type_ = ROUND_NEAREST_AWAY_ZERO; // RoundType round_type_ = ROUND_NEAREST_AWAY_ZERO;
RoundType round_type_ = ROUND_NEAREST_TOWARDS_ZERO;
// optional paddings
std::vector<int> paddings_;
int8_t padding_val_;
}; };
#endif #endif
...@@ -2551,8 +2573,8 @@ class DequantizeParam : public OpParam { ...@@ -2551,8 +2573,8 @@ class DequantizeParam : public OpParam {
DequantizeParam(const VariableNameMap &inputs, const VariableNameMap &outputs, DequantizeParam(const VariableNameMap &inputs, const VariableNameMap &outputs,
const AttributeMap &attrs, const Scope &scope) { const AttributeMap &attrs, const Scope &scope) {
input_ = InputXFrom<GType>(inputs, scope); input_ = InputXFrom<GType>(inputs, scope);
out_ = OutFrom<GType>(outputs, scope); output_ = OutFrom<GType>(outputs, scope);
activation_scale_ = GetVarValue<GType>("Scale", inputs, scope); activation_scale_ = OpParam::GetVarValue<GType>("Scale", inputs, scope);
// dequantization is performed as x = x / static_scale / online_scale // dequantization is performed as x = x / static_scale / online_scale
if (HasAttr("weight_scale", attrs)) { if (HasAttr("weight_scale", attrs)) {
weight_scale_ = GetAttr<float>("weight_scale", attrs); weight_scale_ = GetAttr<float>("weight_scale", attrs);
...@@ -2565,11 +2587,50 @@ class DequantizeParam : public OpParam { ...@@ -2565,11 +2587,50 @@ class DequantizeParam : public OpParam {
// op input // op input
RType *input_; RType *input_;
// op output // op output
RType *out_; RType *output_;
RType *activation_scale_; RType *activation_scale_;
float weight_scale_; float weight_scale_;
}; };
#endif #endif
#ifdef FUSION_DEQUANT_ADD_BN_RELU_OP
template <typename Dtype>
class FusionDequantAddBNReluParam : public DequantizeParam<Dtype> {
typedef typename DtypeTensorTrait<Dtype>::gtype GType;
typedef typename DtypeTensorTrait<Dtype>::rtype RType;
public:
FusionDequantAddBNReluParam(const VariableNameMap &inputs,
const VariableNameMap &outputs,
const AttributeMap &attrs, const Scope &scope)
: DequantizeParam<Dtype>(inputs, outputs, attrs, scope) {
// element wise add params
axis_ = OpParam::GetAttr<int>("axis", attrs);
bias_ = OpParam::InputYFrom<GType>(inputs, scope);
// batch norm params
bn_mean_ = OpParam::GetVarValue<GType>("BNMean", inputs, scope);
bn_variance_ = OpParam::GetVarValue<GType>("BNVariance", inputs, scope);
bn_scale_ = OpParam::GetVarValue<GType>("BNScale", inputs, scope);
bn_bias_ = OpParam::GetVarValue<GType>("BNBias", inputs, scope);
epsilon_ = OpParam::GetAttr<float>("epsilon", attrs);
// output
output_ = OpParam::OutFrom<GType>(outputs, scope);
}
public:
// elementwise add
int axis_;
RType *bias_;
// batch norm
RType *bn_mean_;
RType *bn_variance_;
RType *bn_scale_;
RType *bn_bias_;
float epsilon_;
// output
RType *output_;
};
#endif
} // namespace operators } // namespace operators
} // namespace paddle_mobile } // namespace paddle_mobile
...@@ -22,8 +22,11 @@ namespace operators { ...@@ -22,8 +22,11 @@ namespace operators {
template <typename DeviceType, typename T> template <typename DeviceType, typename T>
void QuantizeOp<DeviceType, T>::InferShape() const { void QuantizeOp<DeviceType, T>::InferShape() const {
const auto& input_dims = this->param_.input_->dims(); auto input_dims = this->param_.input_->dims();
this->param_.out_->Resize(input_dims); const std::vector<int> &paddings = this->param_.paddings_;
input_dims[2] += 2 * paddings[0];
input_dims[3] += 2 * paddings[1];
this->param_.output_->Resize(input_dims);
auto scale_dims = framework::make_ddim(std::vector<int>{1}); auto scale_dims = framework::make_ddim(std::vector<int>{1});
this->param_.online_scale_->Resize(scale_dims); this->param_.online_scale_->Resize(scale_dims);
} }
......
...@@ -155,7 +155,7 @@ if (NOT FOUND_MATCH) ...@@ -155,7 +155,7 @@ if (NOT FOUND_MATCH)
target_link_libraries(test-googlenet-quali paddle-mobile) target_link_libraries(test-googlenet-quali paddle-mobile)
# gen test # gen test
ADD_EXECUTABLE(test-conv-op operators/test_cov_op.cpp test_helper.h test_include.h executor_for_test.h) ADD_EXECUTABLE(test-conv-op operators/test_conv_op.cpp test_helper.h test_include.h executor_for_test.h)
target_link_libraries(test-conv-op paddle-mobile) target_link_libraries(test-conv-op paddle-mobile)
# gen test # gen test
...@@ -242,10 +242,6 @@ if (NOT FOUND_MATCH) ...@@ -242,10 +242,6 @@ if (NOT FOUND_MATCH)
ADD_EXECUTABLE(test-dequantize-op operators/test_dequantize_op.cpp test_helper.h test_include.h) ADD_EXECUTABLE(test-dequantize-op operators/test_dequantize_op.cpp test_helper.h test_include.h)
target_link_libraries(test-dequantize-op paddle-mobile) target_link_libraries(test-dequantize-op paddle-mobile)
# test int8 conv op
ADD_EXECUTABLE(test-int8-conv-op operators/test_int8_conv_op.cpp test_helper.h test_include.h)
target_link_libraries(test-int8-conv-op paddle-mobile)
# gen test log # gen test log
ADD_EXECUTABLE(test-log common/test_log.cpp) ADD_EXECUTABLE(test-log common/test_log.cpp)
target_link_libraries(test-log paddle-mobile) target_link_libraries(test-log paddle-mobile)
...@@ -368,6 +364,10 @@ if (NOT FOUND_MATCH) ...@@ -368,6 +364,10 @@ if (NOT FOUND_MATCH)
ADD_EXECUTABLE(test-multi-process net/test_multi_inference_predict.cpp test_helper.h test_include.h) ADD_EXECUTABLE(test-multi-process net/test_multi_inference_predict.cpp test_helper.h test_include.h)
target_link_libraries(test-multi-process paddle-mobile) target_link_libraries(test-multi-process paddle-mobile)
# gen test benchmark
ADD_EXECUTABLE(test-benchmark net/test_benchmark.cpp)
target_link_libraries(test-benchmark paddle-mobile)
# gen test # gen test
ADD_EXECUTABLE(test-eng net/test_eng.cpp test_helper.h test_include.h) ADD_EXECUTABLE(test-eng net/test_eng.cpp test_helper.h test_include.h)
target_link_libraries(test-eng paddle-mobile) target_link_libraries(test-eng paddle-mobile)
......
...@@ -12,10 +12,11 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. ...@@ -12,10 +12,11 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include <iostream>
#include <string> #include <string>
#include "../test_helper.h" #include "../test_helper.h"
#include "../test_include.h" #include "../test_include.h"
static size_t ReadBuffer(const char *file_name, uint8_t **out) { static size_t ReadBuffer(const char *file_name, uint8_t **out) {
FILE *fp; FILE *fp;
fp = fopen(file_name, "rb"); fp = fopen(file_name, "rb");
......
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <iostream>
#include "../test_helper.h"
#include "../test_include.h"
int main(int argc, char* argv[]) {
if (argc < 4) {
std::cout << "Usage: " << std::endl
<< "./test_benchmark fluid_model feed_shape thread_num [use_fuse]"
<< std::endl;
std::cout << "use_fuse: optional, bool, default is 1\n";
return 1;
}
bool optimize = true;
char* fluid_model = argv[1];
char* feed_shape = argv[2];
int thread_num = atoi(argv[3]);
if (argc == 5) {
optimize = atoi(argv[4]);
}
paddle_mobile::PaddleMobile<paddle_mobile::CPU> paddle_mobile;
paddle_mobile.SetThreadNum(thread_num);
auto time1 = time();
if (paddle_mobile.Load(fluid_model, optimize)) {
auto time2 = time();
std::cout << "load cost :" << time_diff(time1, time2) << "ms\n";
paddle_mobile::framework::Tensor input;
std::shared_ptr<paddle_mobile::framework::Tensor> output;
std::vector<int64_t> dims{1, 3, 224, 224};
if (feed_shape) {
// parse into ints first: dims holds int64_t, which "%d" must not write into
int n = 1, c = 3, h = 224, w = 224;
sscanf(feed_shape, "%d,%d,%d,%d", &n, &c, &h, &w);
dims = {n, c, h, w};
}
std::cout << "feed shape: [" << dims[0] << ", " << dims[1] << ", "
<< dims[2] << ", " << dims[3] << "]\n";
paddle_mobile::framework::DDim in_shape =
paddle_mobile::framework::make_ddim(dims);
SetupTensor<float>(&input, in_shape, 0.f, 255.f);
// warmup
for (int i = 0; i < 10; ++i) {
output = paddle_mobile.Predict(input);
}
auto time3 = time();
for (int i = 0; i < 10; ++i) {
output = paddle_mobile.Predict(input);
}
auto time4 = time();
std::cout << "predict cost :" << time_diff(time3, time4) / 10 << "ms\n";
}
return 0;
}
...@@ -20,12 +20,11 @@ int main() { ...@@ -20,12 +20,11 @@ int main() {
#ifdef PADDLE_MOBILE_FPGA #ifdef PADDLE_MOBILE_FPGA
paddle_mobile::PaddleMobile<paddle_mobile::FPGA> paddle_mobile; paddle_mobile::PaddleMobile<paddle_mobile::FPGA> paddle_mobile;
#endif #endif
#ifdef PADDLE_MOBILE_CPU #ifdef PADDLE_MOBILE_CPU
paddle_mobile::PaddleMobile<paddle_mobile::CPU> paddle_mobile; paddle_mobile::PaddleMobile<paddle_mobile::CPU> paddle_mobile;
#endif #endif
paddle_mobile.SetThreadNum(4); paddle_mobile.SetThreadNum(1);
bool optimize = true; bool optimize = true;
auto time1 = time(); auto time1 = time();
if (paddle_mobile.Load(g_googlenet, optimize)) { if (paddle_mobile.Load(g_googlenet, optimize)) {
...@@ -36,7 +35,7 @@ int main() { ...@@ -36,7 +35,7 @@ int main() {
std::vector<float> output; std::vector<float> output;
std::vector<int64_t> dims{1, 3, 224, 224}; std::vector<int64_t> dims{1, 3, 224, 224};
GetInput<float>(g_test_image_1x3x224x224, &input, dims); GetInput<float>(g_test_image_1x3x224x224, &input, dims);
// warm up ten times // warmup
for (int i = 0; i < 10; ++i) { for (int i = 0; i < 10; ++i) {
output = paddle_mobile.Predict(input, dims); output = paddle_mobile.Predict(input, dims);
} }
...@@ -46,8 +45,7 @@ int main() { ...@@ -46,8 +45,7 @@ int main() {
} }
auto time4 = time(); auto time4 = time();
std::cout << "predict cost :" << time_diff(time3, time4) / 10 << "ms" std::cout << "predict cost: " << time_diff(time3, time4) / 10 << "ms\n";
<< std::endl;
} }
return 0; return 0;
} }
...@@ -18,7 +18,7 @@ limitations under the License. */ ...@@ -18,7 +18,7 @@ limitations under the License. */
namespace paddle_mobile { namespace paddle_mobile {
// Reference convolution for checking results: // Reference convolution from Caffe for checking results.
// accumulate through explicit loops over input, output, and filters. // accumulate through explicit loops over input, output, and filters.
template <typename Itype, typename Otype> template <typename Itype, typename Otype>
void conv2d(const framework::Tensor *input, const framework::Tensor *filter, void conv2d(const framework::Tensor *input, const framework::Tensor *filter,
...@@ -129,7 +129,7 @@ void conv2d(const framework::Tensor *input, const framework::Tensor *filter, ...@@ -129,7 +129,7 @@ void conv2d(const framework::Tensor *input, const framework::Tensor *filter,
} }
template <typename Itype, typename Otype, int Kernel, int Pad, int Stride> template <typename Itype, typename Otype, int Kernel, int Pad, int Stride>
int TestConvOp() { int TestConvOp(int in_channels, int in_height, int in_width, int out_channels) {
int kernel_h = Kernel; int kernel_h = Kernel;
int kernel_w = Kernel; int kernel_w = Kernel;
int pad_h = Pad; int pad_h = Pad;
...@@ -140,10 +140,10 @@ int TestConvOp() { ...@@ -140,10 +140,10 @@ int TestConvOp() {
int dilation_w = 1; int dilation_w = 1;
int batch_size = 1; int batch_size = 1;
int input_c = 3; int input_c = in_channels;
int input_h = 100; int input_h = in_height;
int input_w = 100; int input_w = in_width;
int output_c = 10; int output_c = out_channels;
framework::DDim input_shape = framework::DDim input_shape =
framework::make_ddim({batch_size, input_c, input_h, input_w}); framework::make_ddim({batch_size, input_c, input_h, input_w});
framework::DDim filter_shape = framework::DDim filter_shape =
...@@ -158,7 +158,7 @@ int TestConvOp() { ...@@ -158,7 +158,7 @@ int TestConvOp() {
auto input_var = scope.get()->Var("input"); auto input_var = scope.get()->Var("input");
auto input = input_var->template GetMutable<framework::LoDTensor>(); auto input = input_var->template GetMutable<framework::LoDTensor>();
SetupTensor<Itype>(input, input_shape, -20, 20); SetupTensor<Itype>(input, input_shape, -20.0, 20.0);
auto filter_var = scope.get()->Var("filter"); auto filter_var = scope.get()->Var("filter");
auto filter = filter_var->template GetMutable<framework::LoDTensor>(); auto filter = filter_var->template GetMutable<framework::LoDTensor>();
...@@ -174,8 +174,9 @@ int TestConvOp() { ...@@ -174,8 +174,9 @@ int TestConvOp() {
auto *op = new operators::ConvOp<CPU, float>("conv2d", inputs, outputs, attrs, auto *op = new operators::ConvOp<CPU, float>("conv2d", inputs, outputs, attrs,
scope); scope);
// struct timespec ts_begin, ts_end;
op->InferShape(); op->InferShape();
op->Init();
// struct timespec ts_begin, ts_end;
// warmup // warmup
// op->Run(); // op->Run();
// clock_gettime(CLOCK_MONOTONIC, &ts_begin); // clock_gettime(CLOCK_MONOTONIC, &ts_begin);
...@@ -202,9 +203,16 @@ int TestConvOp() { ...@@ -202,9 +203,16 @@ int TestConvOp() {
const Otype *output_data = output->data<Otype>(); const Otype *output_data = output->data<Otype>();
Otype *output_cmp_data = output_cmp.data<Otype>(); Otype *output_cmp_data = output_cmp.data<Otype>();
for (int i = 0; i < output->numel(); ++i) { for (int i = 0; i < output->numel(); ++i) {
PADDLE_MOBILE_ENFORCE(output_data[i] == output_cmp_data[i], float gap = output_data[i] - output_cmp_data[i];
PADDLE_MOBILE_ENFORCE(std::abs(gap / (output_data[i] + 1e-5)) < 1e-3,
"output[%d] = %d, output_cmp[%d] = %d", i, "output[%d] = %d, output_cmp[%d] = %d", i,
output_data[i], i, output_cmp_data[i]); output_data[i], i, output_cmp_data[i]);
// if (std::abs(gap / (output_data[i] + 1e-5)) > 1e-3) {
// LOG(kLOG_INFO) << "output_data[" << i << "] = " << output_data[i]
// << ", output_cmp_data[" << i << "] = " <<
// output_cmp_data[i];
// return 1;
// }
} }
delete op; delete op;
return 0; return 0;
...@@ -212,68 +220,88 @@ int TestConvOp() { ...@@ -212,68 +220,88 @@ int TestConvOp() {
} // namespace paddle_mobile } // namespace paddle_mobile
int main() { int main(int argc, char *argv[]) {
if (argc < 5) {
LOG(paddle_mobile::kLOG_INFO)
<< "Usage:\n"
<< " ./test-int8-conv-op in_channels in_height in_width out_channels\n"
<< " params:\n"
<< " -in_channels: int, input image's channels\n"
<< " -in_height: int, input image's height\n"
<< " -in_width: int, input image's width\n"
<< " -out_channels: int, conv output channels\n";
return 1;
}
int in_channels = atoi(argv[1]);
int in_height = atoi(argv[2]);
int in_width = atoi(argv[3]);
int out_channels = atoi(argv[4]);
// kernel = 3, pad = 1, stride = 1
LOG(paddle_mobile::kLOG_INFO) << "float, kernel=3, pad=1, stride=1";
paddle_mobile::TestConvOp<float, float, 3, 1, 1>(in_channels, in_height,
in_width, out_channels);
// kernel = 7, pad = 0, stride = 2 // kernel = 7, pad = 0, stride = 2
LOG(paddle_mobile::kLOG_INFO) << "int8, kernel=7, pad=0, stride=2"; LOG(paddle_mobile::kLOG_INFO) << "int8, kernel=7, pad=0, stride=2";
paddle_mobile::TestConvOp<int8_t, int32_t, 7, 0, 2>(); paddle_mobile::TestConvOp<int8_t, int32_t, 7, 0, 2>(in_channels, in_height,
in_width, out_channels);
// kernel = 7, pad = 1, stride = 2 // kernel = 7, pad = 1, stride = 2
LOG(paddle_mobile::kLOG_INFO) << "int8, kernel=7, pad=1, stride=2"; LOG(paddle_mobile::kLOG_INFO) << "int8, kernel=7, pad=1, stride=2";
paddle_mobile::TestConvOp<int8_t, int32_t, 7, 1, 2>(); paddle_mobile::TestConvOp<int8_t, int32_t, 7, 1, 2>(in_channels, in_height,
in_width, out_channels);
// kernel = 7, pad = 3, stride = 2 // kernel = 7, pad = 3, stride = 2
LOG(paddle_mobile::kLOG_INFO) << "int8, kernel=7, pad=3, stride=2"; LOG(paddle_mobile::kLOG_INFO) << "int8, kernel=7, pad=3, stride=2";
paddle_mobile::TestConvOp<int8_t, int32_t, 7, 3, 2>(); paddle_mobile::TestConvOp<int8_t, int32_t, 7, 3, 2>(in_channels, in_height,
in_width, out_channels);
// kernel = 7, pad = 0, stride = 1 // kernel = 7, pad = 0, stride = 1
LOG(paddle_mobile::kLOG_INFO) << "int8, kernel=7, pad=0, stride=1"; LOG(paddle_mobile::kLOG_INFO) << "int8, kernel=7, pad=0, stride=1";
paddle_mobile::TestConvOp<int8_t, int32_t, 7, 0, 1>(); paddle_mobile::TestConvOp<int8_t, int32_t, 7, 0, 1>(in_channels, in_height,
in_width, out_channels);
// kernel = 7, pad = 1, stride = 1 // kernel = 7, pad = 1, stride = 1
LOG(paddle_mobile::kLOG_INFO) << "int8, kernel=7, pad=1, stride=1"; LOG(paddle_mobile::kLOG_INFO) << "int8, kernel=7, pad=1, stride=1";
paddle_mobile::TestConvOp<int8_t, int32_t, 7, 1, 1>(); paddle_mobile::TestConvOp<int8_t, int32_t, 7, 1, 1>(in_channels, in_height,
in_width, out_channels);
// kernel = 7, pad = 3, stride = 1 // kernel = 7, pad = 3, stride = 1
LOG(paddle_mobile::kLOG_INFO) << "int8, kernel=7, pad=3, stride=1"; LOG(paddle_mobile::kLOG_INFO) << "int8, kernel=7, pad=3, stride=1";
paddle_mobile::TestConvOp<int8_t, int32_t, 7, 3, 1>(); paddle_mobile::TestConvOp<int8_t, int32_t, 7, 3, 1>(in_channels, in_height,
in_width, out_channels);
// kernel = 7, pad = 5, stride = 3 // kernel = 7, pad = 5, stride = 3
LOG(paddle_mobile::kLOG_INFO) << "int8, kernel=7, pad=5, stride=3"; LOG(paddle_mobile::kLOG_INFO) << "int8, kernel=7, pad=5, stride=3";
paddle_mobile::TestConvOp<int8_t, int32_t, 7, 5, 3>(); paddle_mobile::TestConvOp<int8_t, int32_t, 7, 5, 3>(in_channels, in_height,
in_width, out_channels);
// kernel = 7, pad = 3, stride = 4 // kernel = 7, pad = 3, stride = 4
LOG(paddle_mobile::kLOG_INFO) << "int8, kernel=7, pad=3, stride=4"; LOG(paddle_mobile::kLOG_INFO) << "int8, kernel=7, pad=3, stride=4";
paddle_mobile::TestConvOp<int8_t, int32_t, 7, 3, 4>(); paddle_mobile::TestConvOp<int8_t, int32_t, 7, 3, 4>(in_channels, in_height,
LOG(paddle_mobile::kLOG_INFO) << "\n"; in_width, out_channels);
// kernel = 3, pad = 0, stride = 1 // kernel = 3, pad = 0, stride = 1
LOG(paddle_mobile::kLOG_INFO) << "int8, kernel=3, pad=0, stride=1"; LOG(paddle_mobile::kLOG_INFO) << "int8, kernel=3, pad=0, stride=1";
paddle_mobile::TestConvOp<int8_t, int32_t, 3, 0, 1>(); paddle_mobile::TestConvOp<int8_t, int32_t, 3, 0, 1>(in_channels, in_height,
in_width, out_channels);
// kernel = 3, pad = 0, stride = 1 // kernel = 3, pad = 0, stride = 1
LOG(paddle_mobile::kLOG_INFO) << "float, kernel=3, pad=0, stride=1"; LOG(paddle_mobile::kLOG_INFO) << "float, kernel=3, pad=0, stride=1";
paddle_mobile::TestConvOp<float, float, 3, 0, 1>(); paddle_mobile::TestConvOp<float, float, 3, 0, 1>(in_channels, in_height,
LOG(paddle_mobile::kLOG_INFO) << "\n"; in_width, out_channels);
// kernel = 3, pad = 1, stride = 1 // kernel = 3, pad = 1, stride = 1
LOG(paddle_mobile::kLOG_INFO) << "int8, kernel=3, pad=1, stride=1"; LOG(paddle_mobile::kLOG_INFO) << "int8, kernel=3, pad=1, stride=1";
paddle_mobile::TestConvOp<int8_t, int32_t, 3, 1, 1>(); paddle_mobile::TestConvOp<int8_t, int32_t, 3, 1, 1>(in_channels, in_height,
in_width, out_channels);
// kernel = 3, pad = 1, stride = 1 // kernel = 3, pad = 1, stride = 1
LOG(paddle_mobile::kLOG_INFO) << "float, kernel=3, pad=1, stride=1"; LOG(paddle_mobile::kLOG_INFO) << "float, kernel=3, pad=1, stride=1";
paddle_mobile::TestConvOp<float, float, 3, 1, 1>(); paddle_mobile::TestConvOp<float, float, 3, 1, 1>(in_channels, in_height,
LOG(paddle_mobile::kLOG_INFO) << "\n"; in_width, out_channels);
// kernel = 5, pad = 0, stride = 1 // kernel = 5, pad = 0, stride = 1
LOG(paddle_mobile::kLOG_INFO) << "int8, kernel=5, pad=0, stride=1"; LOG(paddle_mobile::kLOG_INFO) << "int8, kernel=5, pad=0, stride=1";
paddle_mobile::TestConvOp<int8_t, int32_t, 5, 0, 1>(); paddle_mobile::TestConvOp<int8_t, int32_t, 5, 0, 1>(in_channels, in_height,
in_width, out_channels);
// kernel = 5, pad = 0, stride = 1 // kernel = 5, pad = 0, stride = 1
LOG(paddle_mobile::kLOG_INFO) << "float, kernel=5, pad=0, stride=1"; LOG(paddle_mobile::kLOG_INFO) << "float, kernel=5, pad=0, stride=1";
paddle_mobile::TestConvOp<float, float, 5, 0, 1>(); paddle_mobile::TestConvOp<float, float, 5, 0, 1>(in_channels, in_height,
LOG(paddle_mobile::kLOG_INFO) << "\n"; in_width, out_channels);
// kernel = 5, pad = 2, stride = 1 // kernel = 5, pad = 2, stride = 1
LOG(paddle_mobile::kLOG_INFO) << "int8, kernel=5, pad=2, stride=1"; LOG(paddle_mobile::kLOG_INFO) << "int8, kernel=5, pad=2, stride=1";
paddle_mobile::TestConvOp<int8_t, int32_t, 5, 2, 1>(); paddle_mobile::TestConvOp<int8_t, int32_t, 5, 2, 1>(in_channels, in_height,
in_width, out_channels);
// kernel = 5, pad = 2, stride = 1 // kernel = 5, pad = 2, stride = 1
LOG(paddle_mobile::kLOG_INFO) << "float, kernel=5, pad=2, stride=1"; LOG(paddle_mobile::kLOG_INFO) << "float, kernel=5, pad=2, stride=1";
paddle_mobile::TestConvOp<float, float, 5, 2, 1>(); paddle_mobile::TestConvOp<float, float, 5, 2, 1>(in_channels, in_height,
in_width, out_channels);
} }
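[Reviewer note] main() now reads in_channels, in_height, in_width and out_channels from the command line instead of hard-coding the shape, and runs every kernel/pad/stride combination on that one shape. For reference, the expected output spatial size for each combination follows the usual convolution arithmetic; a standalone sketch (dilation 1 assumed):

#include <cstdio>

// Standard convolution output-size formula for one spatial dimension.
static int conv_out_size(int in, int kernel, int pad, int stride) {
  return (in + 2 * pad - kernel) / stride + 1;
}

int main() {
  printf("%d\n", conv_out_size(224, 7, 3, 2));  // 112: halves a 224 input
  printf("%d\n", conv_out_size(224, 3, 1, 1));  // 224: keeps the spatial size
  return 0;
}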
...@@ -12,58 +12,131 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. ...@@ -12,58 +12,131 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include <iostream>
#include "../test_helper.h" #include "../test_helper.h"
#include "../test_include.h" #include "../test_include.h"
#include "operators/quantize_op.h" #include "operators/quantize_op.h"
namespace paddle_mobile { namespace paddle_mobile {
namespace round {
static float find_abs_max(const Tensor *input) { enum RoundType {
float max_abs = 0.f; RoundToEven = 0,
const float *x = input->data<const float>(); RoundAwayZero = 1,
size_t size = input->numel(); RoundTowardsZero = 2,
for (size_t i = 0; i < size; ++i) { };
float value = std::abs(x[i]);
if (value > max_abs) {
max_abs = value;
}
}
return max_abs;
} }
static void quantize_round_to_even(const Tensor *input, const float scale, template <round::RoundType T>
Tensor *output) { struct Round {
const float *x = input->data<const float>(); int8_t operator()(float x);
int8_t *y = output->mutable_data<int8_t>(); };
size_t size = input->numel();
for (size_t i = 0; i < size; ++i) { template <>
float value = x[i] * scale; struct Round<round::RoundAwayZero> {
float v = round(value); int8_t operator()(float x) { return std::round(x); }
};
template <>
struct Round<round::RoundTowardsZero> {
int8_t operator()(float x) { return int8_t(x); }
};
template <>
struct Round<round::RoundToEven> {
int8_t operator()(float x) {
int8_t ret = 0;
float v = std::round(x);
int32_t q = (int32_t)v; int32_t q = (int32_t)v;
if (abs(abs(q - value) - 0.5) > 0) { if (abs(abs(q - x) - 0.5) > 0) {
y[i] = q; ret = q;
} else { } else {
if (abs(q) % 2 == 0) { if (abs(q) % 2 == 0) {
y[i] = q; ret = q;
} else { } else {
y[i] = q + ((q > 0) ? -1 : 1); ret = q + ((q > 0) ? -1 : 1);
}
}
return ret;
}
};
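[Reviewer note] The three Round specializations above implement round-half-to-even, round-half-away-from-zero (std::round) and truncation toward zero; they differ only in how the fractional part is handled. A hedged usage sketch, assuming it sits in the same translation unit as the structs above:

#include <cstdio>

void print_rounding_examples() {
  printf("%d\n", Round<round::RoundToEven>()(2.5f));       // 2: tie goes to the even neighbour
  printf("%d\n", Round<round::RoundToEven>()(3.5f));       // 4
  printf("%d\n", Round<round::RoundAwayZero>()(2.5f));     // 3: std::round ties away from zero
  printf("%d\n", Round<round::RoundTowardsZero>()(2.7f));  // 2: plain truncation
}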
template <round::RoundType T>
static void quantize(const Tensor *input, const float scale, const int pad,
const int8_t pad_val, Tensor *output) {
int batch_size = input->dims()[0];
int channels = input->dims()[1];
int input_h = input->dims()[2];
int input_w = input->dims()[3];
int output_h = output->dims()[2];
int output_w = output->dims()[3];
size_t input_spatial = input_h * input_w;
size_t output_spatial = output_h * output_w;
const float *x = input->data<const float>();
int8_t *y = output->mutable_data<int8_t>();
for (int nc = 0; nc < batch_size * channels; ++nc) {
const float *xh = x + nc * input_spatial;
int8_t *yh = y + nc * output_spatial;
// pad top
for (int h = 0; h < pad; ++h, yh += output_w) {
for (int w = 0; w < output_w; ++w) {
yh[w] = pad_val;
}
}
for (int h = 0; h < input_h; ++h, yh += output_w, xh += input_w) {
// pad left
for (int w = 0; w < pad; ++w) {
yh[w] = pad_val;
}
for (int w = 0; w < input_w; ++w) {
yh[w + pad] = Round<T>()(xh[w] * scale);
}
// pad right
for (int w = 0; w < pad; ++w) {
yh[pad + input_w + w] = pad_val;
}
}
// pad bottom
for (int h = 0; h < pad; ++h, yh += output_w) {
for (int w = 0; w < output_w; ++w) {
yh[w] = pad_val;
} }
} }
} }
} }
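[Reviewer note] The reference quantize<T> above fuses scaling/rounding with border padding, so the output tensor must already be resized to the padded shape (h + 2*pad, w + 2*pad) before the call. A hedged usage sketch with illustrative values, assuming the SetupTensor helper from test_helper.h:

int pad = 1;
framework::Tensor input;
framework::Tensor expected;
input.Resize(framework::make_ddim({1, 3, 32, 32}));
SetupTensor<float>(&input, input.dims(), -100.f, 100.f);  // fill with random test data
float scale = 127.f / find_abs_max(&input);
expected.Resize(framework::make_ddim({1, 3, 32 + 2 * pad, 32 + 2 * pad}));
quantize<round::RoundTowardsZero>(&input, scale, pad, 0, &expected);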
static void quantize_round_to_nearest(const Tensor *input, const float scale, static float find_abs_max(const Tensor *input) {
Tensor *output) { float max_abs = 0.f;
const float *x = input->data<const float>(); const float *x = input->data<const float>();
int8_t *y = output->mutable_data<int8_t>();
size_t size = input->numel(); size_t size = input->numel();
for (size_t i = 0; i < size; ++i) { for (size_t i = 0; i < size; ++i) {
y[i] = round(x[i] * scale); float value = std::abs(x[i]);
if (value > max_abs) {
max_abs = value;
}
} }
return max_abs;
} }
int TestQuqntizeOp() { int TestQuqntizeOp(int argc, char *argv[]) {
framework::DDim dim = framework::make_ddim({1, 3, 224, 224}); if (argc < 5) {
std::cout
<< "Usage: ./test-quantize-op batch_size channel height width [pad]"
<< std::endl;
return 1;
}
int pad = 0;
int batch_size = atoi(argv[1]);
int channel = atoi(argv[2]);
int height = atoi(argv[3]);
int width = atoi(argv[4]);
if (argc == 6) {
pad = atoi(argv[5]);
}
std::cout << "batch_size: " << batch_size << ", channel: " << channel
<< ", height: " << height << ", width: " << width << std::endl;
framework::DDim dim =
framework::make_ddim({batch_size, channel, height, width});
VariableNameMap inputs; VariableNameMap inputs;
VariableNameMap outputs; VariableNameMap outputs;
...@@ -80,6 +153,7 @@ int TestQuqntizeOp() { ...@@ -80,6 +153,7 @@ int TestQuqntizeOp() {
auto output_scale_var = scope.get()->Var("output_scale"); auto output_scale_var = scope.get()->Var("output_scale");
framework::AttributeMap attrs; framework::AttributeMap attrs;
attrs["paddings"].Set<vector<int>>(std::vector<int>({pad, pad}));
auto *op = new operators::QuantizeOp<CPU, float>("quantize", inputs, outputs, auto *op = new operators::QuantizeOp<CPU, float>("quantize", inputs, outputs,
attrs, scope); attrs, scope);
op->InferShape(); op->InferShape();
...@@ -96,10 +170,11 @@ int TestQuqntizeOp() { ...@@ -96,10 +170,11 @@ int TestQuqntizeOp() {
output_scale_cmp, output_scale_data[0]); output_scale_cmp, output_scale_data[0]);
framework::Tensor output_cmp; framework::Tensor output_cmp;
output_cmp.Resize(dim); output_cmp.Resize(output->dims());
float scale = 127 / output_scale_cmp; float scale = 127 / output_scale_cmp;
// quantize_round_to_even(input, scale, &output_cmp); // quantize<round::RoundToEven>(input, scale, pad, 0, &output_cmp);
quantize_round_to_nearest(input, scale, &output_cmp); // quantize<round::RoundAwayZero>(input, scale, pad, 0, &output_cmp);
quantize<round::RoundTowardsZero>(input, scale, pad, 0, &output_cmp);
int8_t *output_cmp_data = output_cmp.data<int8_t>(); int8_t *output_cmp_data = output_cmp.data<int8_t>();
for (int i = 0; i < output->numel(); ++i) { for (int i = 0; i < output->numel(); ++i) {
PADDLE_MOBILE_ENFORCE(output_data[i] == output_cmp_data[i], PADDLE_MOBILE_ENFORCE(output_data[i] == output_cmp_data[i],
...@@ -113,4 +188,6 @@ int TestQuqntizeOp() { ...@@ -113,4 +188,6 @@ int TestQuqntizeOp() {
} // namespace paddle_mobile } // namespace paddle_mobile
int main() { return paddle_mobile::TestQuqntizeOp(); } int main(int argc, char *argv[]) {
return paddle_mobile::TestQuqntizeOp(argc, argv);
}
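[Reviewer note] The reference path derives the quantization scale as 127 / max_abs, so the largest-magnitude float in the tensor maps onto the int8 limit. A short standalone example with numbers chosen to be exact in binary floating point:

#include <cstdio>

int main() {
  float max_abs = 63.5f;                   // suppose the largest |value| in the tensor
  float scale = 127.f / max_abs;           // 2.0
  printf("%d\n", (int)(63.5f * scale));    // 127: the extreme value hits the int8 limit
  printf("%d\n", (int)(-31.75f * scale));  // -63: -63.5 truncated toward zero
  return 0;
}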
...@@ -212,4 +212,4 @@ else ...@@ -212,4 +212,4 @@ else
else else
build_error "$1" build_error "$1"
fi fi
fi fi
\ No newline at end of file
...@@ -249,6 +249,7 @@ if(NOT FOUND_MATCH) ...@@ -249,6 +249,7 @@ if(NOT FOUND_MATCH)
set(SUM_OP ON) set(SUM_OP ON)
set(QUANT_OP ON) set(QUANT_OP ON)
set(DEQUANT_OP ON) set(DEQUANT_OP ON)
set(FUSION_DEQUANT_ADD_BN_RELU ON)
endif() endif()
# option(BATCHNORM_OP "" ON) # option(BATCHNORM_OP "" ON)
...@@ -450,6 +451,9 @@ endif() ...@@ -450,6 +451,9 @@ endif()
if (DEQUANT_OP) if (DEQUANT_OP)
add_definitions(-DDEQUANT_OP) add_definitions(-DDEQUANT_OP)
endif() endif()
if (FUSION_DEQUANT_ADD_BN_RELU)
add_definitions(-DFUSION_DEQUANT_ADD_BN_RELU_OP)
endif()
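[Reviewer note] The new FUSION_DEQUANT_ADD_BN_RELU switch follows the existing pattern: CMake only emits a compile definition, and the operator sources are expected to compile themselves in under that macro. A hedged sketch of the usual guard (the header path is illustrative, not quoted from this patch):

// Somewhere in the fused operator's source files:
#ifdef FUSION_DEQUANT_ADD_BN_RELU_OP
#include "operators/fusion_dequant_add_bn_relu_op.h"  // hypothetical path
// ... operator definition and registration go here ...
#endif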
if (TANH_OP) if (TANH_OP)
add_definitions(-DTANH_OP) add_definitions(-DTANH_OP)
...@@ -462,4 +466,4 @@ if (FUSION_DECONVADD_OP) ...@@ -462,4 +466,4 @@ if (FUSION_DECONVADD_OP)
endif() endif()
if (FUSION_DECONVADDRELU_OP) if (FUSION_DECONVADDRELU_OP)
add_definitions(-DFUSION_DECONVADDRELU_OP) add_definitions(-DFUSION_DECONVADDRELU_OP)
endif() endif()
\ No newline at end of file