Commit 0c4be5a4 authored by xiebaiyuan, committed by GitHub

Merge pull request #1324 from hjchen2/dev-latest

Fix iOS cross compile, revert quantize kernel, and add other dequantize fusion patterns
......@@ -71,6 +71,8 @@ const char *G_OP_TYPE_SUM = "sum";
const char *G_OP_TYPE_QUANTIZE = "quantize";
const char *G_OP_TYPE_DEQUANTIZE = "dequantize";
const char *G_OP_TYPE_FUSION_DEQUANT_ADD_BN = "fusion_dequant_add_bn";
const char *G_OP_TYPE_FUSION_DEQUANT_BN_RELU = "fusion_dequant_bn_relu";
const char *G_OP_TYPE_FUSION_DEQUANT_ADD_BN_RELU = "fusion_dequant_add_bn_relu";
const char *G_OP_TYPE_TANH = "tanh";
......@@ -136,6 +138,8 @@ std::unordered_map<
{G_OP_TYPE_ELEMENTWISE_MUL, {{"X", "Y"}, {"Out"}}},
{G_OP_TYPE_QUANTIZE, {{"X"}, {"Out", "OutScale"}}},
{G_OP_TYPE_DEQUANTIZE, {{"X", "Scale"}, {"Out"}}},
{G_OP_TYPE_FUSION_DEQUANT_ADD_BN, {{"X", "Scale"}, {"Y"}}},
{G_OP_TYPE_FUSION_DEQUANT_BN_RELU, {{"X", "Scale"}, {"Out"}}},
{G_OP_TYPE_FUSION_DEQUANT_ADD_BN_RELU, {{"X", "Scale"}, {"Out"}}},
{G_OP_TYPE_TANH, {{"X"}, {"Out"}}},
{G_OP_TYPE_FUSION_DECONV_RELU, {{"Input"}, {"Out"}}},
......
......@@ -138,6 +138,8 @@ extern const char *G_OP_TYPE_ELEMENTWISE_MUL;
extern const char *G_OP_TYPE_QUANTIZE;
extern const char *G_OP_TYPE_DEQUANTIZE;
extern const char *G_OP_TYPE_FUSION_DEQUANT_ADD_BN;
extern const char *G_OP_TYPE_FUSION_DEQUANT_BN_RELU;
extern const char *G_OP_TYPE_FUSION_DEQUANT_ADD_BN_RELU;
extern const char *G_OP_TYPE_TANH;
......
......@@ -233,6 +233,14 @@ LOAD_OP1(quantize, CPU);
#ifdef DEQUANT_OP
LOAD_OP1(dequantize, CPU);
#endif
#ifdef FUSION_DEQUANT_ADD_BN_OP
LOAD_OP1(fusion_dequant_add_bn, CPU);
LOAD_FUSION_MATCHER(fusion_dequant_add_bn);
#endif
#ifdef FUSION_DEQUANT_BN_RELU_OP
LOAD_OP1(fusion_dequant_bn_relu, CPU);
LOAD_FUSION_MATCHER(fusion_dequant_bn_relu);
#endif
#ifdef FUSION_DEQUANT_ADD_BN_RELU_OP
LOAD_OP1(fusion_dequant_add_bn_relu, CPU);
LOAD_FUSION_MATCHER(fusion_dequant_add_bn_relu);
......
......@@ -95,7 +95,8 @@ static std::mutex shared_mutex;
andModelParamsLen:(size_t)combinedParamsLen
andCombinedParamsBuf:(const uint8_t *)combinedParamsBuf {
pam_->SetThreadNum(2);
return loaded_ = pam_->LoadCombinedMemory(modelLen, modelBuf, combinedParamsLen, combinedParamsBuf);
return loaded_ = pam_->LoadCombinedMemory(modelLen, modelBuf, combinedParamsLen,
const_cast<uint8_t*>(combinedParamsBuf));
}
- (BOOL)load:(NSString *)modelAndWeightPath{
......
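The iOS fix above only changes how the combined-parameter buffer is handed over: the Objective-C entry point receives it as a const buffer, while LoadCombinedMemory is assumed (from the cast, not shown in this diff) to take a mutable pointer, so the call bridges the two with a const_cast instead of copying. A minimal sketch of that mismatch, with the callee signature treated as an assumption:

```cpp
#include <cstddef>
#include <cstdint>

// Assumed callee signature: takes a mutable uint8_t* even though it only reads from it.
bool LoadCombinedMemory(size_t model_len, const uint8_t *model_buf,
                        size_t params_len, uint8_t *params_buf);

bool LoadFromConstBuffers(size_t model_len, const uint8_t *model_buf,
                          size_t params_len, const uint8_t *params_buf) {
  // Casting away const is safe here only as long as the callee never writes
  // through the pointer; otherwise the buffer would have to be copied instead.
  return LoadCombinedMemory(model_len, model_buf, params_len,
                            const_cast<uint8_t *>(params_buf));
}
```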
......@@ -18,7 +18,7 @@ limitations under the License. */
#include <string>
#include "framework/operator.h"
#include "operators/kernel/depthwise_conv_kernel.h"
#include "operators/kernel/conv_kernel.h"
namespace paddle_mobile {
namespace operators {
......@@ -26,19 +26,16 @@ namespace operators {
template <typename DeviceType, typename T>
class DepthwiseConvOp : public framework::OperatorWithKernel<
DeviceType, ConvParam<DeviceType>,
operators::DepthwiseConvKernel<DeviceType, T>> {
operators::ConvKernel<DeviceType, T>> {
public:
DepthwiseConvOp(const std::string &type, const VariableNameMap &inputs,
const VariableNameMap &outputs,
const framework::AttributeMap &attrs,
std::shared_ptr<framework::Scope> scope)
: framework::OperatorWithKernel<
DeviceType, ConvParam<DeviceType>,
operators::DepthwiseConvKernel<DeviceType, T>>(
: framework::OperatorWithKernel<DeviceType, ConvParam<DeviceType>,
operators::ConvKernel<DeviceType, T>>(
type, inputs, outputs, attrs, scope) {}
void InferShape() const override;
private:
};
} // namespace operators
......
......@@ -12,27 +12,27 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#ifdef DEPTHWISECONV_OP
#ifdef FUSION_DEQUANT_ADD_BN_OP
#include "operators/kernel/depthwise_conv_kernel.h"
#include "operators/kernel/central-arm-func/depthwise_conv_arm_func.h"
#include "operators/fusion_dequant_add_bn_op.h"
namespace paddle_mobile {
namespace operators {
template <>
bool DepthwiseConvKernel<CPU, float>::Init(ConvParam<CPU> *param) {
return true;
template <typename Dtype, typename T>
void FusionDequantAddBNOp<Dtype, T>::InferShape() const {
const auto& input_dims = this->param_.input_->dims();
this->param_.output_->Resize(input_dims);
}
template <>
void DepthwiseConvKernel<CPU, float>::Compute(const ConvParam<CPU> &param) {
DepthwiseConvCompute<float>(param);
}
template class DepthwiseConvKernel<CPU, float>;
} // namespace operators
} // namespace paddle_mobile
namespace ops = paddle_mobile::operators;
REGISTER_FUSION_MATCHER(fusion_dequant_add_bn, ops::FusionDequantAddBNMatcher);
#ifdef PADDLE_MOBILE_CPU
REGISTER_OPERATOR_CPU(fusion_dequant_add_bn, ops::FusionDequantAddBNOp);
#endif
#endif
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#ifdef FUSION_DEQUANT_ADD_BN_OP
#pragma once
#include <string>
#include <vector>
#include "framework/operator.h"
#include "framework/program/program-optimize/fusion_op_register.h"
#include "operators/kernel/dequant_add_bn_kernel.h"
#include "operators/op_param.h"
namespace paddle_mobile {
namespace operators {
class FusionDequantAddBNMatcher : public framework::FusionOpMatcher {
public:
FusionDequantAddBNMatcher() {
node_ = framework::Node(G_OP_TYPE_DEQUANTIZE);
node_ > std::make_shared<framework::Node>(G_OP_TYPE_ELEMENTWISE_ADD) >
std::make_shared<framework::Node>(G_OP_TYPE_BATCHNORM);
}
void FolderNodes(
framework::Node *node,
std::vector<std::shared_ptr<framework::Node>> *removed_nodes) {
node->Folder(node_.Depth(), Type(),
{{G_OP_TYPE_ELEMENTWISE_ADD, {{"Y", "Y"}}},
{G_OP_TYPE_BATCHNORM,
{{"Scale", "BNScale"},
{"Mean", "BNMean"},
{"Bias", "BNBias"},
{"Variance", "BNVariance"}}}},
removed_nodes);
}
std::string Type() { return G_OP_TYPE_FUSION_DEQUANT_ADD_BN; }
};
template <typename DeviceType, typename T>
class FusionDequantAddBNOp
: public framework::OperatorWithKernel<
DeviceType, FusionDequantAddBNParam<DeviceType>,
operators::FusionDequantAddBNKernel<DeviceType, T>> {
public:
FusionDequantAddBNOp(const std::string &type, const VariableNameMap &inputs,
const VariableNameMap &outputs,
const framework::AttributeMap &attrs,
std::shared_ptr<framework::Scope> scope)
: framework::OperatorWithKernel<
DeviceType, FusionDequantAddBNParam<DeviceType>,
operators::FusionDequantAddBNKernel<DeviceType, T>>(
type, inputs, outputs, attrs, scope) {}
// infer the output shape
void InferShape() const override;
};
} // namespace operators
} // namespace paddle_mobile
#endif
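The matcher above folds the dequantize, elementwise_add, batch_norm chain into a single op and remaps the batch-norm inputs to BNScale/BNMean/BNBias/BNVariance. As a reading aid, here is a scalar sketch of what the fused operator is expected to compute per element, inferred from the CPU kernel later in this diff (the dequantization scale there is activation_scale / weight_scale); it is a reference for the semantics, not the actual implementation:

```cpp
#include <cmath>
#include <cstdint>

// Unfused reference for one int32 conv output x in channel c (sketch):
//   dequantize -> add elementwise bias -> batch norm
float DequantAddBNReference(int32_t x, float activation_scale, float weight_scale,
                            float elt_bias, float gamma, float mean, float variance,
                            float beta, float epsilon) {
  float dequantized = static_cast<float>(x) * activation_scale / weight_scale;
  float inv_std = 1.f / std::sqrt(variance + epsilon);
  // Algebraically equal to the folded form y = scale * x + bias used by the kernel,
  // with scale = gamma * inv_std * (activation_scale / weight_scale)
  // and  bias  = gamma * inv_std * (elt_bias - mean) + beta.
  return gamma * (dequantized + elt_bias - mean) * inv_std + beta;
}
```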
......@@ -20,7 +20,7 @@ limitations under the License. */
#include <vector>
#include "framework/operator.h"
#include "framework/program/program-optimize/fusion_op_register.h"
#include "operators/kernel/dequant_add_bn_relu_kernel.h"
#include "operators/kernel/dequant_bn_relu_kernel.h"
#include "operators/op_param.h"
namespace paddle_mobile {
......
......@@ -12,29 +12,28 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#ifdef DEPTHWISECONV_OP
#ifdef FUSION_DEQUANT_BN_RELU_OP
#pragma once
#include "framework/operator.h"
#include "operators/math/im2col.h"
#include "operators/math/math_function.h"
#include "operators/math/vol2col.h"
#include "operators/op_param.h"
#include "operators/fusion_dequant_bn_relu_op.h"
namespace paddle_mobile {
namespace operators {
using framework::OpKernelBase;
template <typename Dtype, typename T>
void FusionDequantBNReluOp<Dtype, T>::InferShape() const {
const auto& input_dims = this->param_.input_->dims();
this->param_.output_->Resize(input_dims);
}
template <typename DeviceType, typename T>
class DepthwiseConvKernel
: public OpKernelBase<DeviceType, ConvParam<DeviceType>> {
public:
void Compute(const ConvParam<DeviceType> &param);
bool Init(ConvParam<DeviceType> *param);
};
} // namespace operators
} // namespace paddle_mobile
namespace ops = paddle_mobile::operators;
REGISTER_FUSION_MATCHER(fusion_dequant_bn_relu,
ops::FusionDequantBNReluMatcher);
#ifdef PADDLE_MOBILE_CPU
REGISTER_OPERATOR_CPU(fusion_dequant_bn_relu, ops::FusionDequantBNReluOp);
#endif
#endif
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#ifdef FUSION_DEQUANT_BN_RELU_OP
#pragma once
#include <string>
#include <vector>
#include "framework/operator.h"
#include "framework/program/program-optimize/fusion_op_register.h"
#include "operators/kernel/dequant_bn_relu_kernel.h"
#include "operators/op_param.h"
namespace paddle_mobile {
namespace operators {
class FusionDequantBNReluMatcher : public framework::FusionOpMatcher {
public:
FusionDequantBNReluMatcher() {
node_ = framework::Node(G_OP_TYPE_DEQUANTIZE);
node_ > std::make_shared<framework::Node>(G_OP_TYPE_BATCHNORM) >
std::make_shared<framework::Node>(G_OP_TYPE_RELU);
}
void FolderNodes(
framework::Node *node,
std::vector<std::shared_ptr<framework::Node>> *removed_nodes) {
node->Folder(node_.Depth(), Type(),
{{G_OP_TYPE_BATCHNORM,
{{"Scale", "BNScale"},
{"Mean", "BNMean"},
{"Bias", "BNBias"},
{"Variance", "BNVariance"}}}},
removed_nodes);
}
std::string Type() { return G_OP_TYPE_FUSION_DEQUANT_BN_RELU; }
};
template <typename DeviceType, typename T>
class FusionDequantBNReluOp
: public framework::OperatorWithKernel<
DeviceType, FusionDequantBNReluParam<DeviceType>,
operators::FusionDequantBNReluKernel<DeviceType, T>> {
public:
FusionDequantBNReluOp(const std::string &type, const VariableNameMap &inputs,
const VariableNameMap &outputs,
const framework::AttributeMap &attrs,
std::shared_ptr<framework::Scope> scope)
: framework::OperatorWithKernel<
DeviceType, FusionDequantBNReluParam<DeviceType>,
operators::FusionDequantBNReluKernel<DeviceType, T>>(
type, inputs, outputs, attrs, scope) {}
// infer the output shape
void InferShape() const override;
};
} // namespace operators
} // namespace paddle_mobile
#endif
......@@ -22,41 +22,43 @@ namespace operators {
template <>
bool ConvKernel<CPU, float>::Init(ConvParam<CPU> *param) {
bool conv3x3 = param->Filter()->dims()[2] == param->Filter()->dims()[3] &&
param->Filter()->dims()[2] == 3;
bool depth3x3 = conv3x3 && param->Groups() == param->Input()->dims()[1] &&
param->Input()->dims()[1] == param->Output()->dims()[1];
if (param->Filter()->type() == typeid(int8_t)) {
if (param->Groups() == param->Input()->dims()[1] &&
param->Input()->dims()[1] == param->Output()->dims()[1] &&
param->Filter()->dims()[2] == param->Filter()->dims()[3] &&
param->Filter()->dims()[2] == 3 && param->Strides()[0] < 3 &&
if (depth3x3 && param->Strides()[0] < 3 &&
param->Strides()[0] == param->Strides()[1]) {
param->ExecMode() = ConvParam<CPU>::EXEC_DEPTHWISE3x3_INT8;
} else {
param->ExecMode() = ConvParam<CPU>::EXEC_GEMM_INT8;
}
} else {
if (param->Groups() == param->Input()->dims()[1] &&
param->Input()->dims()[1] == param->Output()->dims()[1] &&
param->Filter()->dims()[2] == param->Filter()->dims()[3] &&
param->Filter()->dims()[2] == 3 && param->Strides()[0] == 1) {
if (depth3x3 && param->Strides()[0] == param->Strides()[1] &&
param->Strides()[0] == 1 && param->Paddings()[0] == 1 &&
param->Paddings()[0] == param->Paddings()[1]) {
param->ExecMode() = ConvParam<CPU>::EXEC_DEPTHWISE3x3S1P1_FLOAT;
} else if (param->Groups() == param->Input()->dims()[1] &&
param->Input()->dims()[1] == param->Output()->dims()[1] &&
param->Filter()->dims()[2] == param->Filter()->dims()[3] &&
param->Filter()->dims()[2] == 3) {
param->ExecMode() = ConvParam<CPU>::EXEC_DEPTHWISE3x3_FLOAT;
} else if (depth3x3 && param->Strides()[0] == param->Strides()[1] &&
param->Strides()[0] == 2 && param->Paddings()[0] == 0 &&
param->Paddings()[0] == param->Paddings()[1]) {
param->ExecMode() = ConvParam<CPU>::EXEC_DEPTHWISE3x3S2P0_FLOAT;
} else if (depth3x3 && param->Strides()[0] == param->Strides()[1] &&
param->Strides()[0] == 2 && param->Paddings()[0] == 1 &&
param->Paddings()[0] == param->Paddings()[1]) {
param->ExecMode() = ConvParam<CPU>::EXEC_DEPTHWISE3x3S2P1_FLOAT;
#ifndef __aarch64__
} else if (param->Filter()->dims()[2] == param->Filter()->dims()[3] &&
param->Strides()[0] == param->Strides()[1] &&
} else if (conv3x3 && param->Strides()[0] == param->Strides()[1] &&
param->Dilations()[0] == param->Dilations()[1] &&
param->Filter()->dims()[2] == 3 && param->Strides()[0] == 1 &&
param->Dilations()[0] == 1 && param->Output()->dims()[1] >= 16 &&
param->Strides()[0] == 1 && param->Dilations()[0] == 1 &&
param->Output()->dims()[1] >= 16 &&
param->Input()->dims()[1] >= 16 &&
param->Input()->dims()[2] <= 140 /* referred from ncnn */) {
param->ExecMode() = ConvParam<CPU>::EXEC_WINOGRAD3X3_FLOAT;
// transform weight
framework::Tensor *transformed_weight = new framework::Tensor;
framework::Tensor transformed_weight;
operators::math::winograd_transform_weight<8, 3>(*param->Filter(),
transformed_weight);
param->Filter() = transformed_weight;
&transformed_weight);
framework::TensorCopy(transformed_weight, param->Filter());
#endif
} else {
param->ExecMode() = ConvParam<CPU>::EXEC_GEMM_FLOAT;
......@@ -78,9 +80,13 @@ void ConvKernel<CPU, float>::Compute(const ConvParam<CPU> &param) {
math::DepthwiseConv3x3s1p1(param.Input(), param.Filter(), param.Output(),
nullptr, false);
break;
case ConvParam<CPU>::EXEC_DEPTHWISE3x3_FLOAT:
math::DepthwiseConv3x3(param.Input(), param.Strides(), param.Paddings(),
param.Filter(), nullptr, param.Output(), false);
case ConvParam<CPU>::EXEC_DEPTHWISE3x3S2P1_FLOAT:
math::DepthwiseConv3x3s2p1v2(param.Input(), param.Filter(),
param.Output(), nullptr, false);
break;
case ConvParam<CPU>::EXEC_DEPTHWISE3x3S2P0_FLOAT:
math::DepthwiseConv3x3s2p0(param.Input(), param.Filter(), param.Output(),
nullptr, false);
break;
case ConvParam<CPU>::EXEC_WINOGRAD3X3_FLOAT:
WinogradConv3x3<8, 3>(param);
......
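The ConvKernel changes above replace the single EXEC_DEPTHWISE3x3_FLOAT case with explicit stride-2/pad-0 and stride-2/pad-1 variants, and route the Winograd weight transform through a stack tensor plus TensorCopy instead of a leaked heap tensor. The float-path dispatch after this change roughly reduces to the sketch below; parameter names are simplified stand-ins for the ConvParam accessors, and strides, paddings and dilations are assumed square, as in the checks above:

```cpp
enum class FloatConvExec {
  kDepthwise3x3S1P1,
  kDepthwise3x3S2P0,
  kDepthwise3x3S2P1,
  kWinograd3x3,
  kGemm,
};

// depth3x3: 3x3 depthwise (groups == in_channels == out_channels); conv3x3: any 3x3 conv.
FloatConvExec SelectFloatExecMode(bool depth3x3, bool conv3x3, int stride, int padding,
                                  int dilation, int in_channels, int out_channels,
                                  int in_height) {
  if (depth3x3 && stride == 1 && padding == 1) return FloatConvExec::kDepthwise3x3S1P1;
  if (depth3x3 && stride == 2 && padding == 0) return FloatConvExec::kDepthwise3x3S2P0;
  if (depth3x3 && stride == 2 && padding == 1) return FloatConvExec::kDepthwise3x3S2P1;
#ifndef __aarch64__
  if (conv3x3 && stride == 1 && dilation == 1 && out_channels >= 16 &&
      in_channels >= 16 && in_height <= 140 /* heuristic referred from ncnn */) {
    return FloatConvExec::kWinograd3x3;
  }
#endif
  return FloatConvExec::kGemm;
}
```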
......@@ -12,9 +12,9 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#ifdef FUSION_DEQUANT_ADD_BN_RELU_OP
#ifdef FUSION_DEQUANT_ADD_BN_OP
#include "operators/kernel/dequant_add_bn_relu_kernel.h"
#include "operators/kernel/dequant_add_bn_kernel.h"
#include <cmath>
#if defined(__ARM_NEON__) || defined(__ARM_NEON)
#include <arm_neon.h>
......@@ -24,8 +24,8 @@ namespace paddle_mobile {
namespace operators {
template <>
bool FusionDequantAddBNReluKernel<CPU, float>::Init(
FusionDequantAddBNReluParam<CPU> *param) {
bool FusionDequantAddBNKernel<CPU, float>::Init(
FusionDequantAddBNParam<CPU> *param) {
// elementwise add params
const Tensor *bias = param->bias_;
// batch norm params
......@@ -49,8 +49,8 @@ bool FusionDequantAddBNReluKernel<CPU, float>::Init(
}
template <>
void FusionDequantAddBNReluKernel<CPU, float>::Compute(
const FusionDequantAddBNReluParam<CPU> &param) {
void FusionDequantAddBNKernel<CPU, float>::Compute(
const FusionDequantAddBNParam<CPU> &param) {
const int32_t *input = param.input_->data<int32_t>();
const float *bn_scale = param.bn_scale_->data<float>();
const float *bn_bias = param.bn_bias_->data<float>();
......@@ -78,7 +78,6 @@ void FusionDequantAddBNReluKernel<CPU, float>::Compute(
remain = spatial_size & 0xF;
float32x4_t __scale = vdupq_n_f32(scale);
float32x4_t __bias = vdupq_n_f32(bias);
float32x4_t __zero = vdupq_n_f32(0.f);
for (int k = 0; k < loop; ++k, x += 16, y += 16) {
int32x4_t r0 = vld1q_s32(x);
......@@ -93,10 +92,6 @@ void FusionDequantAddBNReluKernel<CPU, float>::Compute(
f1 = vmlaq_f32(__bias, __scale, f1);
f2 = vmlaq_f32(__bias, __scale, f2);
f3 = vmlaq_f32(__bias, __scale, f3);
f0 = vmaxq_f32(__zero, f0);
f1 = vmaxq_f32(__zero, f1);
f2 = vmaxq_f32(__zero, f2);
f3 = vmaxq_f32(__zero, f3);
vst1q_f32(y, f0);
vst1q_f32(y + 4, f1);
vst1q_f32(y + 8, f2);
......@@ -104,7 +99,7 @@ void FusionDequantAddBNReluKernel<CPU, float>::Compute(
}
#endif // __ARM_NEON__
for (int k = 0; k < remain; ++k) {
y[k] = std::max(scale * x[k] + bias, 0.f);
y[k] = scale * x[k] + bias;
}
}
}
......@@ -113,4 +108,4 @@ void FusionDequantAddBNReluKernel<CPU, float>::Compute(
} // namespace operators
} // namespace paddle_mobile
#endif // FUSION_DEQUANT_ADD_BN_RELU_OP
#endif // FUSION_DEQUANT_ADD_BN_OP
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "operators/kernel/dequant_bn_relu_kernel.h"
#include <cmath>
#if defined(__ARM_NEON__) || defined(__ARM_NEON)
#include <arm_neon.h>
#endif
namespace paddle_mobile {
namespace operators {
#if defined(FUSION_DEQUANT_BN_RELU_OP) || defined(FUSION_DEQUANT_ADD_BN_RELU_OP)
void DequantBNReluCompute(const FusionDequantBNParam<CPU> *param) {
const int32_t *input = param->input_->data<int32_t>();
const float *bn_scale = param->bn_scale_->data<float>();
const float *bn_bias = param->bn_bias_->data<float>();
// dequantize params
const float activation_scale = param->activation_scale_->data<float>()[0];
const float weight_scale = param->weight_scale_;
const float dequant_scale = activation_scale / weight_scale;
float *output = param->output_->mutable_data<float>();
int batch_size = param->input_->dims()[0];
int channels = param->input_->dims()[1];
size_t spatial_size = param->input_->dims()[2] * param->input_->dims()[3];
#pragma omp parallel for collapse(2)
for (int batch = 0; batch < batch_size; ++batch) {
for (int c = 0; c < channels; ++c) {
float scale = bn_scale[c] * dequant_scale;
float bias = bn_bias[c];
size_t offset = (batch * channels + c) * spatial_size;
const int32_t *x = input + offset;
float *y = output + offset;
size_t remain = spatial_size;
#if defined(__ARM_NEON__) || defined(__ARM_NEON)
int loop = spatial_size >> 4;
remain = spatial_size & 0xF;
float32x4_t __scale = vdupq_n_f32(scale);
float32x4_t __bias = vdupq_n_f32(bias);
float32x4_t __zero = vdupq_n_f32(0.f);
for (int k = 0; k < loop; ++k, x += 16, y += 16) {
int32x4_t r0 = vld1q_s32(x);
int32x4_t r1 = vld1q_s32(x + 4);
int32x4_t r2 = vld1q_s32(x + 8);
int32x4_t r3 = vld1q_s32(x + 12);
float32x4_t f0 = vcvtq_f32_s32(r0);
float32x4_t f1 = vcvtq_f32_s32(r1);
float32x4_t f2 = vcvtq_f32_s32(r2);
float32x4_t f3 = vcvtq_f32_s32(r3);
f0 = vmlaq_f32(__bias, __scale, f0);
f1 = vmlaq_f32(__bias, __scale, f1);
f2 = vmlaq_f32(__bias, __scale, f2);
f3 = vmlaq_f32(__bias, __scale, f3);
f0 = vmaxq_f32(__zero, f0);
f1 = vmaxq_f32(__zero, f1);
f2 = vmaxq_f32(__zero, f2);
f3 = vmaxq_f32(__zero, f3);
vst1q_f32(y, f0);
vst1q_f32(y + 4, f1);
vst1q_f32(y + 8, f2);
vst1q_f32(y + 12, f3);
}
#endif // __ARM_NEON__
for (int k = 0; k < remain; ++k) {
y[k] = std::max(scale * x[k] + bias, 0.f);
}
}
}
}
#endif
#ifdef FUSION_DEQUANT_BN_RELU_OP
template <>
bool FusionDequantBNReluKernel<CPU, float>::Init(
FusionDequantBNReluParam<CPU> *param) {
// batch norm params
const Tensor *bn_mean = param->bn_mean_;
const Tensor *bn_variance = param->bn_variance_;
Tensor *bn_scale = param->bn_scale_;
Tensor *bn_bias = param->bn_bias_;
const float epsilon = param->epsilon_;
const float *mean_ptr = bn_mean->data<float>();
const float *var_ptr = bn_variance->data<float>();
float *bn_scale_ptr = bn_scale->mutable_data<float>();
float *bn_bias_ptr = bn_bias->mutable_data<float>();
for (int c = 0; c < bn_scale->numel(); ++c) {
float inv_scale = bn_scale_ptr[c] / (std::sqrt(var_ptr[c] + epsilon));
bn_scale_ptr[c] = inv_scale;
bn_bias_ptr[c] = bn_bias_ptr[c] - inv_scale * mean_ptr[c];
}
return true;
}
template <>
void FusionDequantBNReluKernel<CPU, float>::Compute(
const FusionDequantBNReluParam<CPU> &param) {
DequantBNReluCompute(&param);
}
#endif // FUSION_DEQUANT_BN_RELU_OP
#ifdef FUSION_DEQUANT_ADD_BN_RELU_OP
template <>
bool FusionDequantAddBNReluKernel<CPU, float>::Init(
FusionDequantAddBNReluParam<CPU> *param) {
// elementwise add params
const Tensor *bias = param->bias_;
// batch norm params
const Tensor *bn_mean = param->bn_mean_;
const Tensor *bn_variance = param->bn_variance_;
Tensor *bn_scale = param->bn_scale_;
Tensor *bn_bias = param->bn_bias_;
const float epsilon = param->epsilon_;
const float *bias_ptr = bias->data<float>();
const float *mean_ptr = bn_mean->data<float>();
const float *var_ptr = bn_variance->data<float>();
float *bn_scale_ptr = bn_scale->mutable_data<float>();
float *bn_bias_ptr = bn_bias->mutable_data<float>();
for (int c = 0; c < bn_scale->numel(); ++c) {
float inv_scale = bn_scale_ptr[c] / (std::sqrt(var_ptr[c] + epsilon));
bn_scale_ptr[c] = inv_scale;
bn_bias_ptr[c] = inv_scale * (bias_ptr[c] - mean_ptr[c]) + bn_bias_ptr[c];
}
return true;
}
template <>
void FusionDequantAddBNReluKernel<CPU, float>::Compute(
const FusionDequantAddBNReluParam<CPU> &param) {
DequantBNReluCompute(&param);
}
#endif // FUSION_DEQUANT_ADD_BN_RELU_OP
} // namespace operators
} // namespace paddle_mobile
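Both fused ReLU kernels above fold batch norm into a per-channel affine transform at Init time, so Compute only has to apply y = max(scale * x + bias, 0) to the dequantized values. A scalar sketch of the folding and of how Compute uses it, following the same algebra as the Init loops (the with_add flag just marks where the elementwise-add variant differs):

```cpp
#include <algorithm>
#include <cmath>
#include <cstdint>

// Fold gamma/beta/mean/var (plus an optional elementwise bias) into a per-channel
// scale and bias, mirroring the Init() loops above.
void FoldBNChannel(float gamma, float beta, float mean, float variance, float epsilon,
                   float elt_bias, bool with_add, float *scale, float *bias) {
  float inv_scale = gamma / std::sqrt(variance + epsilon);
  *scale = inv_scale;
  *bias = with_add ? inv_scale * (elt_bias - mean) + beta : beta - inv_scale * mean;
}

// Compute() then applies the folded affine plus ReLU to each int32 input, where
// dequant_scale = activation_scale / weight_scale.
float ApplyFoldedBNRelu(int32_t x, float dequant_scale, float scale, float bias) {
  return std::max(scale * dequant_scale * static_cast<float>(x) + bias, 0.f);
}
```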
......@@ -20,6 +20,9 @@ limitations under the License. */
#if defined(__ARM_NEON__) || defined(__ARM_NEON)
#include <arm_neon.h>
namespace paddle_mobile {
namespace operators {
#ifndef __aarch64__
inline float32_t vmaxvq_f32(float32x4_t r) {
float32x2_t v = vmax_f32(vget_high_f32(r), vget_low_f32(r));
......@@ -27,9 +30,13 @@ inline float32_t vmaxvq_f32(float32x4_t r) {
}
#endif
inline int32x4_t vrnd_towards_zero(float32x4_t r) { return vcvtq_s32_f32(r); }
template <RoundType R = ROUND_NEAREST_TOWARDS_ZERO>
inline int32x4_t vround_f32(float32x4_t r) {
return vcvtq_s32_f32(r);
}
inline int32x4_t vrnd_away_zero(float32x4_t r) {
template <>
inline int32x4_t vround_f32<ROUND_NEAREST_AWAY_ZERO>(float32x4_t r) {
float32x4_t plus = vdupq_n_f32(0.5);
float32x4_t minus = vdupq_n_f32(-0.5);
float32x4_t zero = vdupq_n_f32(0);
......@@ -40,31 +47,13 @@ inline int32x4_t vrnd_away_zero(float32x4_t r) {
return ret;
}
inline int32x4_t vrnd_to_even(float32x4_t r) {
#if 0
int32x4_t ret;
float value[4];
vst1q_f32(value, r);
for (int i = 0; i < 4; ++i) {
float v = round(value[i]);
int32_t q = (int32_t)v;
if (abs(abs(v - value[i]) - 0.5) > 0) {
ret[i] = q;
} else {
if (abs(q) % 2 == 0) {
ret[i] = q;
} else {
ret[i] = q + ((q > 0) ? -1 : 1);
}
}
}
return ret;
#else
template <>
inline int32x4_t vround_f32<ROUND_NEAREST_TO_EVEN>(float32x4_t r) {
float32x4_t point5 = vdupq_n_f32(0.5);
int32x4_t one = vdupq_n_s32(1);
int32x4_t zero = vdupq_n_s32(0);
int32x4_t rnd = vrnd_away_zero(r);
int32x4_t rnd = vround_f32<ROUND_NEAREST_AWAY_ZERO>(r);
float32x4_t frnd = vcvtq_f32_s32(rnd);
frnd = vsubq_f32(frnd, r);
frnd = vabsq_f32(frnd);
......@@ -82,115 +71,39 @@ inline int32x4_t vrnd_to_even(float32x4_t r) {
smask = vsubq_s32(smask, one);
rnd = vaddq_s32(rnd, smask);
return rnd;
#endif
}
namespace paddle_mobile {
namespace operators {
static float find_abs_max(const Tensor *input) {
float max_abs = 0.f;
const float *x = input->data<const float>();
size_t size = input->numel();
#if defined(__ARM_NEON__) || defined(__ARM_NEON)
size_t loop = size >> 4;
size_t remain = size & 0xF;
for (size_t i = 0; i < loop; ++i) {
float32x4_t max;
float32x4_t r0 = vld1q_f32(x);
float32x4_t r1 = vld1q_f32(x + 4);
float32x4_t r2 = vld1q_f32(x + 8);
float32x4_t r3 = vld1q_f32(x + 12);
r0 = vabsq_f32(r0);
r1 = vabsq_f32(r1);
r2 = vabsq_f32(r2);
r3 = vabsq_f32(r3);
max[0] = vmaxvq_f32(r0);
max[1] = vmaxvq_f32(r1);
max[2] = vmaxvq_f32(r2);
max[3] = vmaxvq_f32(r3);
max[0] = vmaxvq_f32(max);
if (max[0] > max_abs) {
max_abs = max[0];
}
x += 16;
}
size = remain;
#endif
for (size_t i = 0; i < size; ++i) {
float value = std::abs(x[i]);
if (value > max_abs) {
max_abs = value;
}
}
return max_abs;
template <RoundType R = ROUND_NEAREST_TOWARDS_ZERO>
inline int8_t Round(const float &x) {
return static_cast<int8_t>(x);
}
#ifdef __aarch64__
static void quantize_round_to_even(const Tensor *input, const float scale,
Tensor *output) {
const float *x = input->data<const float>();
int8_t *y = output->mutable_data<int8_t>();
size_t size = input->numel();
#if defined(__ARM_NEON__) || defined(__ARM_NEON)
size_t loop = size >> 4;
size_t remain = size & 0xF;
template <>
inline int8_t Round<ROUND_NEAREST_AWAY_ZERO>(const float &x) {
return std::round(x);
}
#pragma omp parallel for
for (size_t i = 0; i < loop; ++i) {
const float *local_x = x + (i << 4);
int8_t *local_y = y + (i << 4);
float32x4_t r0 = vld1q_f32(local_x);
float32x4_t r1 = vld1q_f32(local_x + 4);
float32x4_t r2 = vld1q_f32(local_x + 8);
float32x4_t r3 = vld1q_f32(local_x + 12);
r0 = vmulq_n_f32(r0, scale);
r1 = vmulq_n_f32(r1, scale);
r2 = vmulq_n_f32(r2, scale);
r3 = vmulq_n_f32(r3, scale);
int32x4_t q0 = vrnd_to_even(r0);
int32x4_t q1 = vrnd_to_even(r1);
int32x4_t q2 = vrnd_to_even(r2);
int32x4_t q3 = vrnd_to_even(r3);
int16x4_t d0 = vmovn_s32(q0);
int16x4_t d1 = vmovn_s32(q1);
int16x4_t d2 = vmovn_s32(q2);
int16x4_t d3 = vmovn_s32(q3);
int16x8_t q5 = vcombine_s16(d0, d1);
int16x8_t q6 = vcombine_s16(d2, d3);
int8x8_t d5 = vmovn_s16(q5);
int8x8_t d6 = vmovn_s16(q6);
vst1_s8(local_y, d5);
vst1_s8(local_y + 8, d6);
}
size = remain;
x += (loop << 4);
y += (loop << 4);
#endif
for (size_t i = 0; i < size; ++i) {
float value = x[i] * scale;
float v = round(value);
int32_t q = (int32_t)v;
if (abs(abs(q - value) - 0.5) > 0) {
y[i] = q;
} else {
if (abs(q) % 2 == 0) {
y[i] = q;
} else {
y[i] = q + ((q > 0) ? -1 : 1);
}
template <>
inline int8_t Round<ROUND_NEAREST_TO_EVEN>(const float &x) {
float v = std::round(x);
int32_t q = static_cast<int32_t>(v);
if (std::abs(std::abs(x - v) - 0.5) <= 0) {  // tie: x lies exactly halfway
if (std::abs(q) % 2 != 0) {
q = q + ((q > 0) ? -1 : 1);
}
}
return static_cast<int8_t>(q);
}
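The three Round<> specializations above replace the previous free rounding helpers. A quick sanity check of what each policy does with halfway and non-halfway inputs; this is a standalone sketch that assumes the RoundType enum and the specializations above are in scope:

```cpp
#include <cassert>

void CheckRoundPolicies() {
  assert(Round<ROUND_NEAREST_TOWARDS_ZERO>(2.7f) == 2);    // plain truncation toward zero
  assert(Round<ROUND_NEAREST_TOWARDS_ZERO>(-2.7f) == -2);
  assert(Round<ROUND_NEAREST_AWAY_ZERO>(2.5f) == 3);       // std::round: halves move away from zero
  assert(Round<ROUND_NEAREST_AWAY_ZERO>(-2.5f) == -3);
  assert(Round<ROUND_NEAREST_TO_EVEN>(2.5f) == 2);         // halves go to the even neighbour
  assert(Round<ROUND_NEAREST_TO_EVEN>(3.5f) == 4);
}
```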
static void quantize_round_to_zero(const Tensor *input, const float scale,
Tensor *output) {
template <RoundType R>
static void Quantize(const Tensor *input, const float scale, Tensor *output) {
const float *x = input->data<const float>();
int8_t *y = output->mutable_data<int8_t>();
size_t size = input->numel();
size_t remain = input->numel();
#if defined(__ARM_NEON__) || defined(__ARM_NEON)
size_t loop = size >> 4;
size_t remain = size & 0xF;
size_t loop = remain >> 4;
remain = remain & 0xF;
#pragma omp parallel for
for (size_t i = 0; i < loop; ++i) {
......@@ -204,10 +117,10 @@ static void quantize_round_to_zero(const Tensor *input, const float scale,
r1 = vmulq_n_f32(r1, scale);
r2 = vmulq_n_f32(r2, scale);
r3 = vmulq_n_f32(r3, scale);
int32x4_t q0 = vrnd_towards_zero(r0);
int32x4_t q1 = vrnd_towards_zero(r1);
int32x4_t q2 = vrnd_towards_zero(r2);
int32x4_t q3 = vrnd_towards_zero(r3);
int32x4_t q0 = vround_f32<R>(r0);
int32x4_t q1 = vround_f32<R>(r1);
int32x4_t q2 = vround_f32<R>(r2);
int32x4_t q3 = vround_f32<R>(r3);
int16x4_t d0 = vmovn_s32(q0);
int16x4_t d1 = vmovn_s32(q1);
int16x4_t d2 = vmovn_s32(q2);
......@@ -219,561 +132,44 @@ static void quantize_round_to_zero(const Tensor *input, const float scale,
vst1_s8(local_y, d5);
vst1_s8(local_y + 8, d6);
}
size = remain;
x += (loop << 4);
y += (loop << 4);
#endif
for (size_t i = 0; i < size; ++i) {
y[i] = static_cast<int8_t>(x[i] * scale);
for (size_t i = 0; i < remain; ++i) {
y[i] = Round<R>(x[i] * scale);
}
}
static void quantize_round_to_nearest(const Tensor *input, const float scale,
Tensor *output) {
float find_abs_max(const Tensor *input) {
float max_abs = 0.f;
const float *x = input->data<const float>();
int8_t *y = output->mutable_data<int8_t>();
size_t size = input->numel();
size_t remain = input->numel();
#if defined(__ARM_NEON__) || defined(__ARM_NEON)
size_t loop = size >> 4;
size_t remain = size & 0xF;
size_t loop = remain >> 4;
remain = remain & 0xF;
float32x4_t __max = {0.f, 0.f, 0.f, 0.f};
#pragma omp parallel for
for (size_t i = 0; i < loop; ++i) {
const float *local_x = x + (i << 4);
int8_t *local_y = y + (i << 4);
float32x4_t r0 = vld1q_f32(local_x);
float32x4_t r1 = vld1q_f32(local_x + 4);
float32x4_t r2 = vld1q_f32(local_x + 8);
float32x4_t r3 = vld1q_f32(local_x + 12);
r0 = vmulq_n_f32(r0, scale);
r1 = vmulq_n_f32(r1, scale);
r2 = vmulq_n_f32(r2, scale);
r3 = vmulq_n_f32(r3, scale);
int32x4_t q0 = vrnd_away_zero(r0);
int32x4_t q1 = vrnd_away_zero(r1);
int32x4_t q2 = vrnd_away_zero(r2);
int32x4_t q3 = vrnd_away_zero(r3);
int16x4_t d0 = vmovn_s32(q0);
int16x4_t d1 = vmovn_s32(q1);
int16x4_t d2 = vmovn_s32(q2);
int16x4_t d3 = vmovn_s32(q3);
int16x8_t q5 = vcombine_s16(d0, d1);
int16x8_t q6 = vcombine_s16(d2, d3);
int8x8_t d5 = vmovn_s16(q5);
int8x8_t d6 = vmovn_s16(q6);
vst1_s8(local_y, d5);
vst1_s8(local_y + 8, d6);
for (size_t i = 0; i < loop; ++i, x += 16) {
float32x4_t r0 = vld1q_f32(x);
float32x4_t r1 = vld1q_f32(x + 4);
float32x4_t r2 = vld1q_f32(x + 8);
float32x4_t r3 = vld1q_f32(x + 12);
r0 = vabsq_f32(r0);
r1 = vabsq_f32(r1);
r2 = vabsq_f32(r2);
r3 = vabsq_f32(r3);
r0 = vmaxq_f32(r0, r1);
r1 = vmaxq_f32(r2, r3);
r0 = vmaxq_f32(r0, r1);
__max = vmaxq_f32(r0, __max);
}
size = remain;
x += (loop << 4);
y += (loop << 4);
max_abs = vmaxvq_f32(__max);
#endif
for (size_t i = 0; i < size; ++i) {
y[i] = round(x[i] * scale);
}
}
#else // __aarch64__
static void quantize_round_to_even(const Tensor *input, const float scale,
const std::vector<int> &paddings,
const int8_t padding_val, Tensor *output) {}
static void quantize_round_to_nearest(const Tensor *input, const float scale,
const std::vector<int> &paddings,
const int8_t padding_val,
Tensor *output) {}
static void quantize_round_to_zero(const Tensor *input, const float scale,
const std::vector<int> &paddings,
const int8_t padding_val, Tensor *output) {
int channels = input->dims()[1];
int input_h = input->dims()[2];
int input_w = input->dims()[3];
int output_h = output->dims()[2];
int output_w = output->dims()[3];
int input_spatial_size = input_h * input_w;
int output_spatial_size = output_h * output_w;
const float *x = input->data<float>();
int8_t *y = output->mutable_data<int8_t>();
// valid area start
int start = paddings[0] * output_w + paddings[1];
for (int batch = 0; batch < input->dims()[0]; ++batch) {
#pragma omp parallel for
for (int c = 0; c < channels - 3; c += 4) {
const float *input0 = x + (batch * channels + c) * input_spatial_size;
const float *input1 = input0 + input_spatial_size;
const float *input2 = input1 + input_spatial_size;
const float *input3 = input2 + input_spatial_size;
size_t offset = (batch * channels + c) * output_spatial_size;
for (int h = 0; h < 2; ++h) {
int8_t *y0 =
y + offset + h * ((input_h + paddings[0]) * output_w - paddings[1]);
int8_t *y1 = y0 + output_spatial_size;
int8_t *y2 = y1 + output_spatial_size;
int8_t *y3 = y2 + output_spatial_size;
int loop = start >> 4;
int remain = start & 0xF;
asm volatile(
"vdup.s8 q0, %[val] \n"
"cmp %[loop], #0 \n"
"ble start_remain_%= \n"
"store_16w_%=: \n"
"vst1.32 {q0}, [%[y0]]! \n"
"vst1.32 {q0}, [%[y1]]! \n"
"vst1.32 {q0}, [%[y2]]! \n"
"vst1.32 {q0}, [%[y3]]! \n"
"subs %[loop], #1 \n"
"bne store_16w_%= \n"
"start_remain_%=: \n"
"cmp %[remain], #8 \n"
"blt store_4w_%= \n"
"vst1.32 {d0}, [%[y0]]! \n"
"vst1.32 {d0}, [%[y1]]! \n"
"vst1.32 {d0}, [%[y2]]! \n"
"vst1.32 {d0}, [%[y3]]! \n"
"sub %[remain], #8 \n"
"store_4w_%=: \n"
"cmp %[remain], #4 \n"
"blt store_2w_%= \n"
"vst1.32 {d0[0]}, [%[y0]]! \n"
"vst1.32 {d0[0]}, [%[y1]]! \n"
"vst1.32 {d0[0]}, [%[y2]]! \n"
"vst1.32 {d0[0]}, [%[y3]]! \n"
"sub %[remain], #4 \n"
"store_2w_%=: \n"
"cmp %[remain], #4 \n"
"blt store_1w_%= \n"
"vst1.16 {d0[0]}, [%[y0]]! \n"
"vst1.16 {d0[0]}, [%[y1]]! \n"
"vst1.16 {d0[0]}, [%[y2]]! \n"
"vst1.16 {d0[0]}, [%[y3]]! \n"
"sub %[remain], #2 \n"
"store_1w_%=: \n"
"cmp %[remain], #1 \n"
"blt end_%= \n"
"vst1.8 {d0[0]}, [%[y0]]! \n"
"vst1.8 {d0[0]}, [%[y1]]! \n"
"vst1.8 {d0[0]}, [%[y2]]! \n"
"vst1.8 {d0[0]}, [%[y3]]! \n"
"end_%=: \n"
: [y0] "+r"(y0), [y1] "+r"(y1), [y2] "+r"(y2), [y3] "+r"(y3),
[loop] "+r"(loop), [remain] "+r"(remain)
: [val] "r"(padding_val)
: "cc", "memory", "q0");
}
// quantize valid area
int8_t *y0 = y + offset + start;
int8_t *y1 = y0 + output_spatial_size;
int8_t *y2 = y1 + output_spatial_size;
int8_t *y3 = y2 + output_spatial_size;
for (int h = 0; h < input_h; ++h) {
const float *x0 = input0 + h * input_w;
const float *x1 = input1 + h * input_w;
const float *x2 = input2 + h * input_w;
const float *x3 = input3 + h * input_w;
int loop = input_w >> 4;
int remain = input_w & 0xF;
int pad_loop = paddings[1] >> 1; // (paddings[1] << 1) >> 2
int pad_remain = (paddings[1] << 1) & 0x3;
int remain_steps = remain;
asm volatile(
"vdup.f32 q0, %[scale] \n"
"cmp %[loop], #0 \n"
"ble quantize_remain_%= \n"
"loop_quantize_%=: \n"
"vld1.32 {q1, q2}, [%[x0]]! \n"
"vld1.32 {q3, q4}, [%[x1]]! \n"
"vld1.32 {q5, q6}, [%[x2]]! \n"
"vld1.32 {q7, q8}, [%[x3]]! \n"
"vmul.f32 q1, q1, q0 \n"
"vmul.f32 q2, q2, q0 \n"
"vmul.f32 q3, q3, q0 \n"
"vmul.f32 q4, q4, q0 \n"
"vmul.f32 q5, q5, q0 \n"
"vmul.f32 q6, q6, q0 \n"
"vmul.f32 q7, q7, q0 \n"
"vmul.f32 q8, q8, q0 \n"
"vcvt.s32.f32 q1, q1 \n"
"vcvt.s32.f32 q2, q2 \n"
"vcvt.s32.f32 q3, q3 \n"
"vcvt.s32.f32 q4, q4 \n"
"vcvt.s32.f32 q5, q5 \n"
"vcvt.s32.f32 q6, q6 \n"
"vcvt.s32.f32 q7, q7 \n"
"vcvt.s32.f32 q8, q8 \n"
"vmovn.s32 d2, q1 \n"
"vmovn.s32 d3, q2 \n"
"vmovn.s32 d4, q3 \n"
"vmovn.s32 d5, q4 \n"
"vmovn.s32 d6, q5 \n"
"vmovn.s32 d7, q6 \n"
"vmovn.s32 d8, q7 \n"
"vmovn.s32 d9, q8 \n"
"vmovn.s16 d18, q1 \n"
"vmovn.s16 d20, q2 \n"
"vmovn.s16 d22, q3 \n"
"vmovn.s16 d24, q4 \n"
"vld1.32 {q1, q2}, [%[x0]]! \n"
"vld1.32 {q3, q4}, [%[x1]]! \n"
"vld1.32 {q5, q6}, [%[x2]]! \n"
"vld1.32 {q7, q8}, [%[x3]]! \n"
"vmul.f32 q1, q1, q0 \n"
"vmul.f32 q2, q2, q0 \n"
"vmul.f32 q3, q3, q0 \n"
"vmul.f32 q4, q4, q0 \n"
"vmul.f32 q5, q5, q0 \n"
"vmul.f32 q6, q6, q0 \n"
"vmul.f32 q7, q7, q0 \n"
"vmul.f32 q8, q8, q0 \n"
"vcvt.s32.f32 q1, q1 \n"
"vcvt.s32.f32 q2, q2 \n"
"vcvt.s32.f32 q3, q3 \n"
"vcvt.s32.f32 q4, q4 \n"
"vcvt.s32.f32 q5, q5 \n"
"vcvt.s32.f32 q6, q6 \n"
"vcvt.s32.f32 q7, q7 \n"
"vcvt.s32.f32 q8, q8 \n"
"vmovn.s32 d2, q1 \n"
"vmovn.s32 d3, q2 \n"
"vmovn.s32 d4, q3 \n"
"vmovn.s32 d5, q4 \n"
"vmovn.s32 d6, q5 \n"
"vmovn.s32 d7, q6 \n"
"vmovn.s32 d8, q7 \n"
"vmovn.s32 d9, q8 \n"
"vmovn.s16 d19, q1 \n"
"vmovn.s16 d21, q2 \n"
"vmovn.s16 d23, q3 \n"
"vmovn.s16 d25, q4 \n"
"vst1.32 {q9}, [%[y0]]! \n"
"vst1.32 {q10}, [%[y1]]! \n"
"vst1.32 {q11}, [%[y2]]! \n"
"vst1.32 {q12}, [%[y3]]! \n"
"subs %[loop], #1 \n"
"bne loop_quantize_%= \n"
"quantize_remain_%=: \n"
"cmp %[remain], #0 \n"
"ble end_%= \n"
"vld1.32 {q1, q2}, [%[x0]]! \n"
"vld1.32 {q3, q4}, [%[x1]]! \n"
"vld1.32 {q5, q6}, [%[x2]]! \n"
"vld1.32 {q7, q8}, [%[x3]]! \n"
"vmul.f32 q1, q1, q0 \n"
"vmul.f32 q2, q2, q0 \n"
"vmul.f32 q3, q3, q0 \n"
"vmul.f32 q4, q4, q0 \n"
"vmul.f32 q5, q5, q0 \n"
"vmul.f32 q6, q6, q0 \n"
"vmul.f32 q7, q7, q0 \n"
"vmul.f32 q8, q8, q0 \n"
"vcvt.s32.f32 q1, q1 \n"
"vcvt.s32.f32 q2, q2 \n"
"vcvt.s32.f32 q3, q3 \n"
"vcvt.s32.f32 q4, q4 \n"
"vcvt.s32.f32 q5, q5 \n"
"vcvt.s32.f32 q6, q6 \n"
"vcvt.s32.f32 q7, q7 \n"
"vcvt.s32.f32 q8, q8 \n"
"vmovn.s32 d2, q1 \n"
"vmovn.s32 d3, q2 \n"
"vmovn.s32 d4, q3 \n"
"vmovn.s32 d5, q4 \n"
"vmovn.s32 d6, q5 \n"
"vmovn.s32 d7, q6 \n"
"vmovn.s32 d8, q7 \n"
"vmovn.s32 d9, q8 \n"
"vmovn.s16 d18, q1 \n"
"vmovn.s16 d20, q2 \n"
"vmovn.s16 d22, q3 \n"
"vmovn.s16 d24, q4 \n"
"vld1.32 {q1, q2}, [%[x0]] \n"
"vld1.32 {q3, q4}, [%[x1]] \n"
"vld1.32 {q5, q6}, [%[x2]] \n"
"vld1.32 {q7, q8}, [%[x3]] \n"
"vmul.f32 q1, q1, q0 \n"
"vmul.f32 q2, q2, q0 \n"
"vmul.f32 q3, q3, q0 \n"
"vmul.f32 q4, q4, q0 \n"
"vmul.f32 q5, q5, q0 \n"
"vmul.f32 q6, q6, q0 \n"
"vmul.f32 q7, q7, q0 \n"
"vmul.f32 q8, q8, q0 \n"
"vcvt.s32.f32 q1, q1 \n"
"vcvt.s32.f32 q2, q2 \n"
"vcvt.s32.f32 q3, q3 \n"
"vcvt.s32.f32 q4, q4 \n"
"vcvt.s32.f32 q5, q5 \n"
"vcvt.s32.f32 q6, q6 \n"
"vcvt.s32.f32 q7, q7 \n"
"vcvt.s32.f32 q8, q8 \n"
"vmovn.s32 d2, q1 \n"
"vmovn.s32 d3, q2 \n"
"vmovn.s32 d4, q3 \n"
"vmovn.s32 d5, q4 \n"
"vmovn.s32 d6, q5 \n"
"vmovn.s32 d7, q6 \n"
"vmovn.s32 d8, q7 \n"
"vmovn.s32 d9, q8 \n"
"vmovn.s16 d19, q1 \n"
"vmovn.s16 d21, q2 \n"
"vmovn.s16 d23, q3 \n"
"vmovn.s16 d25, q4 \n"
"cmp %[remain], #8 \n"
"blt store_4w_%= \n"
"vst1.32 {d18}, [%[y0]]! \n"
"vst1.32 {d20}, [%[y1]]! \n"
"vst1.32 {d22}, [%[y2]]! \n"
"vst1.32 {d24}, [%[y3]]! \n"
"vmov.32 d18, d19 \n"
"vmov.32 d20, d21 \n"
"vmov.32 d22, d23 \n"
"vmov.32 d24, d25 \n"
"sub %[remain], #8 \n"
"store_4w_%=: \n"
"cmp %[remain], #4 \n"
"blt store_2w_%= \n"
"vst1.32 {d18[0]}, [%[y0]]! \n"
"vst1.32 {d20[0]}, [%[y1]]! \n"
"vst1.32 {d22[0]}, [%[y2]]! \n"
"vst1.32 {d24[0]}, [%[y3]]! \n"
"vext.32 d18, d18, d18, #1 \n"
"vext.32 d20, d20, d20, #1 \n"
"vext.32 d22, d22, d22, #1 \n"
"vext.32 d24, d24, d24, #1 \n"
"sub %[remain], #4 \n"
"store_2w_%=: \n"
"cmp %[remain], #2 \n"
"blt store_1w_%= \n"
"vst1.16 {d18[0]}, [%[y0]]! \n"
"vst1.16 {d20[0]}, [%[y1]]! \n"
"vst1.16 {d22[0]}, [%[y2]]! \n"
"vst1.16 {d24[0]}, [%[y3]]! \n"
"vext.16 d18, d18, d18, #1 \n"
"vext.16 d20, d20, d20, #1 \n"
"vext.16 d22, d22, d22, #1 \n"
"vext.16 d24, d24, d24, #1 \n"
"sub %[remain], #2 \n"
"store_1w_%=:"
"cmp %[remain], #1 \n"
"blt end_%= \n"
"vst1.8 {d18[0]}, [%[y0]]! \n"
"vst1.8 {d20[0]}, [%[y1]]! \n"
"vst1.8 {d22[0]}, [%[y2]]! \n"
"vst1.8 {d24[0]}, [%[y3]]! \n"
"end_%=: \n"
: [x0] "+r"(x0), [x1] "+r"(x1), [x2] "+r"(x2), [x3] "+r"(x3),
[y0] "+r"(y0), [y1] "+r"(y1), [y2] "+r"(y2), [y3] "+r"(y3),
[loop] "+r"(loop), [remain] "+r"(remain)
: [scale] "r"(scale)
: "cc", "memory", "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7",
"q8", "q9", "q10", "q11", "q12");
asm volatile(
"vdup.s8 d0, %[val] \n"
"cmp %[pad_loop], #0 \n"
"ble store_pad_2w_%= \n"
"loop_pad_4w_%=: \n"
"vst1.32 {d0[0]}, [%[y0]]! \n"
"vst1.32 {d0[0]}, [%[y1]]! \n"
"vst1.32 {d0[0]}, [%[y2]]! \n"
"vst1.32 {d0[0]}, [%[y3]]! \n"
"subs %[pad_loop], #1 \n"
"bne loop_pad_4w_%= \n"
"store_pad_2w_%=: \n"
"cmp %[pad_remain], #2 \n"
"blt store_pad_1w_%= \n"
"vst1.16 {d0[0]}, [%[y0]]! \n"
"vst1.16 {d0[0]}, [%[y1]]! \n"
"vst1.16 {d0[0]}, [%[y2]]! \n"
"vst1.16 {d0[0]}, [%[y3]]! \n"
"sub %[pad_remain], #2 \n"
"store_pad_1w_%=: \n"
"cmp %[pad_remain], #1 \n"
"blt end_%= \n"
"vst1.8 {d0[0]}, [%[y0]]! \n"
"vst1.8 {d0[0]}, [%[y1]]! \n"
"vst1.8 {d0[0]}, [%[y2]]! \n"
"vst1.8 {d0[0]}, [%[y3]]! \n"
"end_%=: \n"
: [y0] "+r"(y0), [y1] "+r"(y1), [y2] "+r"(y2), [y3] "+r"(y3),
[pad_loop] "+r"(pad_loop), [pad_remain] "+r"(pad_remain)
: [val] "r"(padding_val)
: "cc", "memory", "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7",
"q8", "q9", "q10", "q11", "q12");
}
}
for (int c = (channels & 0xFFFC); c < channels; ++c) {
const float *input0 = x + (batch * channels + c) * input_spatial_size;
size_t offset = (batch * channels + c) * output_spatial_size;
for (int h = 0; h < 2; ++h) {
int8_t *y0 =
y + offset + h * ((input_h + paddings[0]) * output_w - paddings[1]);
int loop = start >> 4;
int remain = start & 0xF;
asm volatile(
"vdup.s8 q0, %[val] \n"
"cmp %[loop], #0 \n"
"ble start_remain_%= \n"
"store_16w_%=: \n"
"vst1.32 {q0}, [%[y0]]! \n"
"subs %[loop], #1 \n"
"bne store_16w_%= \n"
"start_remain_%=: \n"
"cmp %[remain], #8 \n"
"blt store_4w_%= \n"
"vst1.32 {d0}, [%[y0]]! \n"
"sub %[remain], #8 \n"
"store_4w_%=: \n"
"cmp %[remain], #4 \n"
"blt store_2w_%= \n"
"vst1.32 {d0[0]}, [%[y0]]! \n"
"sub %[remain], #4 \n"
"store_2w_%=: \n"
"cmp %[remain], #4 \n"
"blt store_1w_%= \n"
"vst1.16 {d0[0]}, [%[y0]]! \n"
"sub %[remain], #2 \n"
"store_1w_%=: \n"
"cmp %[remain], #1 \n"
"blt end_%= \n"
"vst1.8 {d0[0]}, [%[y0]]! \n"
"end_%=: \n"
: [y0] "+r"(y0), [loop] "+r"(loop), [remain] "+r"(remain)
: [val] "r"(padding_val)
: "cc", "memory", "q0");
}
// quantize valid area
int8_t *y0 = y + offset + start;
for (int h = 0; h < input_h; ++h) {
const float *x0 = input0 + h * input_w;
int loop = input_w >> 4;
int remain = input_w & 0xF;
int pad_loop = paddings[1] >> 1; // (paddings[1] << 1) >> 2
int pad_remain = (paddings[1] << 1) & 0x3;
asm volatile(
"vdup.f32 q0, %[scale] \n"
"cmp %[loop], #0 \n"
"ble quantize_remain_%= \n"
"loop_quantize_%=: \n"
"vld1.32 {q1, q2}, [%[x0]]! \n"
"vmul.f32 q1, q1, q0 \n"
"vmul.f32 q2, q2, q0 \n"
"vcvt.s32.f32 q1, q1 \n"
"vcvt.s32.f32 q2, q2 \n"
"vmovn.s32 d2, q1 \n"
"vmovn.s32 d3, q2 \n"
"vmovn.s16 d18, q1 \n"
"vld1.32 {q1, q2}, [%[x0]]! \n"
"vmul.f32 q1, q1, q0 \n"
"vmul.f32 q2, q2, q0 \n"
"vcvt.s32.f32 q1, q1 \n"
"vcvt.s32.f32 q2, q2 \n"
"vmovn.s32 d2, q1 \n"
"vmovn.s32 d3, q2 \n"
"vmovn.s16 d19, q1 \n"
"vst1.32 {q9}, [%[y0]]! \n"
"subs %[loop], #1 \n"
"bne loop_quantize_%= \n"
"quantize_remain_%=: \n"
"cmp %[remain], #0 \n"
"ble start_pad_%= \n"
"vldm %[x0], {d2-d9} \n"
"vmul.f32 q1, q1, q0 \n"
"vmul.f32 q2, q2, q0 \n"
"vcvt.s32.f32 q1, q1 \n"
"vcvt.s32.f32 q2, q2 \n"
"vmovn.s32 d2, q1 \n"
"vmovn.s32 d3, q2 \n"
"vmovn.s16 d18, q1 \n"
"vmul.f32 q3, q3, q0 \n"
"vmul.f32 q4, q4, q0 \n"
"vcvt.s32.f32 q1, q3 \n"
"vcvt.s32.f32 q2, q4 \n"
"vmovn.s32 d2, q1 \n"
"vmovn.s32 d3, q2 \n"
"vmovn.s16 d19, q1 \n"
"cmp %[remain], #8 \n"
"blt store_4w_%= \n"
"vst1.32 {d18}, [%[y0]]! \n"
"vmov.32 d18, d19 \n"
"sub %[remain], #8 \n"
"store_4w_%=: \n"
"cmp %[remain], #4 \n"
"blt store_2w_%= \n"
"vst1.32 {d18[0]}, [%[y0]]! \n"
"vext.32 d18, d18, d18, #1 \n"
"sub %[remain], #4 \n"
"store_2w_%=: \n"
"cmp %[remain], #2 \n"
"blt store_1w_%= \n"
"vst1.16 {d18[0]}, [%[y0]]! \n"
"vext.16 d18, d18, d18, #1 \n"
"sub %[remain], #2 \n"
"store_1w_%=:"
"cmp %[remain], #1 \n"
"blt start_pad_%= \n"
"vst1.8 {d18[0]}, [%[y0]]! \n"
"start_pad_%=: \n"
"vdup.s8 d0, %[val] \n"
"cmp %[pad_loop], #0 \n"
"ble pad_remain_%= \n"
"loop_pad_4w_%=: \n"
"vst1.32 {d0[0]}, [%[y0]]! \n"
"subs %[pad_loop], #1 \n"
"bne loop_pad_4w_%= \n"
"pad_remain_%=: \n"
"cmp %[pad_remain], #2 \n"
"blt store_pad_1w_%= \n"
"vst1.16 {d0[0]}, [%[y0]]! \n"
"sub %[pad_remain], #2 \n"
"store_pad_1w_%=: \n"
"cmp %[pad_remain], #1 \n"
"blt end_%= \n"
"vst1.8 {d0[0]}, [%[y0]]! \n"
"end_%=: \n"
: [x0] "+r"(x0), [y0] "+r"(y0), [loop] "+r"(loop),
[remain] "+r"(remain), [pad_loop] "+r"(pad_loop),
[pad_remain] "+r"(pad_remain)
: [scale] "r"(scale), [val] "r"(padding_val)
: "cc", "memory", "q0", "q1", "q2", "q3", "q4", "q9");
}
}
for (size_t i = 0; i < remain; ++i) {
max_abs = std::max(max_abs, std::abs(x[i]));
}
return max_abs;
}
#endif // __aarch64__
#endif // ARM_NEON
template <>
bool QuantizeKernel<CPU, float>::Init(QuantizeParam<CPU> *param) {
......@@ -795,19 +191,15 @@ void QuantizeKernel<CPU, float>::Compute(const QuantizeParam<CPU> &param) {
// only support int8 currently
float scale = 127 / max_abs;
param.online_scale_->mutable_data<float>()[0] = max_abs;
const auto &paddings = param.paddings_;
// std::vector<int> paddings = {0, 0};
// const auto padding_val = param.padding_val_;
int8_t padding_val = 0;
switch (param.round_type_) {
case ROUND_NEAREST_TO_EVEN:
quantize_round_to_even(input, scale, paddings, padding_val, output);
Quantize<ROUND_NEAREST_TO_EVEN>(input, scale, output);
break;
case ROUND_NEAREST_TOWARDS_ZERO:
quantize_round_to_zero(input, scale, paddings, padding_val, output);
Quantize<ROUND_NEAREST_TOWARDS_ZERO>(input, scale, output);
break;
case ROUND_NEAREST_AWAY_ZERO:
quantize_round_to_nearest(input, scale, paddings, padding_val, output);
Quantize<ROUND_NEAREST_AWAY_ZERO>(input, scale, output);
break;
default:
LOG(kLOG_ERROR) << "round type is not supported.";
......
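With the padded quantization path reverted, the quantize kernel reduces to a per-tensor abs-max scale plus one of the Round policies. A condensed sketch of the resulting flow, using the Tensor/RoundType/find_abs_max/Quantize names from the file above; the zero-max guard is an assumption added for the sketch and is not visible in this diff:

```cpp
#include <algorithm>

void QuantizeToInt8(const Tensor *input, Tensor *output, Tensor *online_scale,
                    RoundType round_type) {
  float max_abs = find_abs_max(input);
  max_abs = std::max(max_abs, 1e-6f);  // assumed guard against an all-zero tensor
  online_scale->mutable_data<float>()[0] = max_abs;
  const float scale = 127.f / max_abs;  // int8 only, as noted in the kernel
  switch (round_type) {
    case ROUND_NEAREST_TO_EVEN:
      Quantize<ROUND_NEAREST_TO_EVEN>(input, scale, output);
      break;
    case ROUND_NEAREST_AWAY_ZERO:
      Quantize<ROUND_NEAREST_AWAY_ZERO>(input, scale, output);
      break;
    case ROUND_NEAREST_TOWARDS_ZERO:
      Quantize<ROUND_NEAREST_TOWARDS_ZERO>(input, scale, output);
      break;
    default:
      break;  // unsupported round type; the real kernel logs an error here
  }
}
```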
......@@ -132,10 +132,10 @@ void ConvAddCompute(const FusionConvAddParam<CPU> &param) {
// param.Output(), false);
if (param.Paddings()[0] == 0) {
math::DepthwiseConv3x3s2p0(param.Input(), param.Filter(), param.Output(),
*param.Bias(), true);
param.Bias(), true);
} else {
math::DepthwiseConv3x3s2p1v2(param.Input(), param.Filter(),
param.Output(), *param.Bias(), true);
param.Output(), param.Bias(), true);
}
} else {
ConvAddBasic(param);
......
......@@ -164,31 +164,21 @@ template <typename Itype, typename Otype>
inline void DepthwiseConv3x3(const ConvParam<CPU> &param) {
const Tensor *input = param.Input();
const Tensor *filter = param.Filter();
const std::vector<int> &paddings = param.Paddings();
const std::vector<int> &strides = param.Strides();
const int batch_size = input->dims()[0];
Tensor *output = param.Output();
output->mutable_data<Otype>();
const std::vector<int> &paddings = param.Paddings();
const std::vector<int> &strides = param.Strides();
const int batch_size = static_cast<int>(input->dims()[0]);
Tensor input_pad;
math::PadFunctor<CPU, Itype> pad;
for (int i = 0; i < batch_size; i++) {
Tensor in_batch = input->Slice(i, i + 1);
Tensor out_batch = output->Slice(i, i + 1);
if (paddings[0] || paddings[1]) {
framework::DDim pad_shape = in_batch.dims();
pad_shape[2] += 2 * paddings[0];
pad_shape[3] += 2 * paddings[1];
input_pad.mutable_data<float>(pad_shape);
pad(in_batch, paddings[0], paddings[0], paddings[1], paddings[1],
&input_pad);
} else {
input_pad = in_batch;
}
if (strides[0] == 1) {
math::DepthwiseConv3x3s1<Itype, Otype>(input_pad, *filter, &out_batch);
math::DepthwiseConv3x3S1<Itype, Otype>(in_batch, *filter, paddings,
&out_batch);
} else if (strides[0] == 2) {
math::DepthwiseConv3x3s2<Itype, Otype>(input_pad, *filter, &out_batch);
math::DepthwiseConv3x3S2<Itype, Otype>(in_batch, *filter, paddings,
&out_batch);
} else {
// math::DepthwiseConv3x3<Itype, Otype>(input_pad, *filter,
// &out_batch);
......
......@@ -14,7 +14,7 @@ limitations under the License. */
#pragma once
#ifdef FUSION_DEQUANT_ADD_BN_RELU_OP
#ifdef FUSION_DEQUANT_ADD_BN_OP
#include "framework/operator.h"
#include "operators/op_param.h"
......@@ -23,12 +23,12 @@ namespace paddle_mobile {
namespace operators {
template <typename DeviceType, typename T>
class FusionDequantAddBNReluKernel
class FusionDequantAddBNKernel
: public framework::OpKernelBase<DeviceType,
FusionDequantAddBNReluParam<DeviceType>> {
FusionDequantAddBNParam<DeviceType>> {
public:
void Compute(const FusionDequantAddBNReluParam<DeviceType> &param);
bool Init(FusionDequantAddBNReluParam<DeviceType> *param);
void Compute(const FusionDequantAddBNParam<DeviceType> &param);
bool Init(FusionDequantAddBNParam<DeviceType> *param);
};
} // namespace operators
......
......@@ -12,42 +12,35 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#ifdef DEPTHWISECONV_OP
#pragma once
#include <vector>
#include "operators/kernel/central-arm-func/conv_arm_func.h"
#include "operators/math/depthwise_conv3x3.h"
#include "framework/operator.h"
#include "operators/op_param.h"
namespace paddle_mobile {
namespace operators {
template <typename P>
void DepthwiseConvCompute(const ConvParam<CPU> &param) {
Tensor Bias;
Bias.mutable_data<float>({param.Groups()});
if (param.Groups() == param.Input()->dims()[1] &&
param.Filter()->dims()[2] == param.Filter()->dims()[3] &&
param.Filter()->dims()[2] == 3 && param.Strides()[0] == 1) {
math::DepthwiseConv3x3s1p1(param.Input(), param.Filter(), param.Output(),
&Bias, false);
} else if (param.Groups() == param.Input()->dims()[1] &&
param.Input()->dims()[1] == param.Output()->dims()[1] &&
param.Filter()->dims()[2] == param.Filter()->dims()[3] &&
param.Filter()->dims()[2] == 3 && param.Strides()[0] == 2) {
// math::DepthwiseConv3x3(param.Input(), param.Strides(),
// param.Paddings(),
// param.Filter(), &Bias, param.Output(), false);
math::DepthwiseConv3x3s2p1v2(param.Input(), param.Filter(), param.Output(),
Bias, false);
} else {
GemmConv<float, float>(param);
}
}
#ifdef FUSION_DEQUANT_BN_RELU_OP
template <typename DeviceType, typename T>
class FusionDequantBNReluKernel
: public framework::OpKernelBase<DeviceType,
FusionDequantBNReluParam<DeviceType>> {
public:
void Compute(const FusionDequantBNReluParam<DeviceType> &param);
bool Init(FusionDequantBNReluParam<DeviceType> *param);
};
#endif
#ifdef FUSION_DEQUANT_ADD_BN_RELU_OP
template <typename DeviceType, typename T>
class FusionDequantAddBNReluKernel
: public framework::OpKernelBase<DeviceType,
FusionDequantAddBNReluParam<DeviceType>> {
public:
void Compute(const FusionDequantAddBNReluParam<DeviceType> &param);
bool Init(FusionDequantAddBNReluParam<DeviceType> *param);
};
#endif
} // namespace operators
} // namespace paddle_mobile
#endif
......@@ -1272,13 +1272,16 @@ void DepthwiseConvAddBNRelu3x3s2p1(const framework::Tensor *input,
void DepthwiseConv3x3s2p1v2(const framework::Tensor *input,
const framework::Tensor *filter,
framework::Tensor *output, framework::Tensor bias,
framework::Tensor *output, framework::Tensor *bias,
bool if_bias) {
#if __ARM_NEON
const float *input_data = input->data<float>();
const float *filter_data = filter->data<float>();
float *output_data = output->data<float>();
const float *bias_data = bias.data<float>();
const float *bias_data = nullptr;
if (if_bias) {
bias_data = bias->data<float>();
}
const int in_h = static_cast<int>(input->dims()[2]);
const int in_w = static_cast<int>(input->dims()[3]);
......@@ -1905,7 +1908,7 @@ void DepthwiseConvAddBNRelu3x3s2p1v2(const framework::Tensor *input,
void DepthwiseConv3x3s2p0(const framework::Tensor *input,
const framework::Tensor *filter,
framework::Tensor *output, framework::Tensor bias,
framework::Tensor *output, framework::Tensor *bias,
bool if_bias) {
#if __ARM_NEON
......@@ -1925,7 +1928,7 @@ void DepthwiseConv3x3s2p0(const framework::Tensor *input,
for (int c = 0; c < input_channel; c++) {
const float *filter_data = filter->data<float>() + c * 9;
const float *input_data = input->data<float>() + c * inhxw;
const float *bias_data = bias.data<float>() + c;
const float *bias_data = if_bias ? bias->data<float>() + c : nullptr;
float *output_data = output->data<float>() + c * outhxw;
float w00 = filter_data[0];
float w01 = filter_data[1];
......
......@@ -50,7 +50,7 @@ void DepthwiseConvAddBNRelu3x3s2p1(const framework::Tensor *input,
void DepthwiseConv3x3s2p1v2(const framework::Tensor *input,
const framework::Tensor *filter,
framework::Tensor *output, framework::Tensor bias,
framework::Tensor *output, framework::Tensor *bias,
bool if_bias);
void DepthwiseConvAddBNRelu3x3s2p1v2(const framework::Tensor *input,
......@@ -62,7 +62,7 @@ void DepthwiseConvAddBNRelu3x3s2p1v2(const framework::Tensor *input,
void DepthwiseConv3x3s2p0(const framework::Tensor *input,
const framework::Tensor *filter,
framework::Tensor *output, framework::Tensor bias,
framework::Tensor *output, framework::Tensor *bias,
bool if_bias);
// TODO(hjchen2) need to be implemented
......@@ -70,16 +70,19 @@ void DepthwiseConv3x3s2p0(const framework::Tensor *input,
// void DepthwiseConv3x3(const framework::Tensor *input,
// const framework::Tensor *filter,
// const std::vector<int> &strides,
// const std::vector<int> &paddings,
// framework::Tensor *output);
template <typename Itype, typename Otype>
void DepthwiseConv3x3s1(const framework::Tensor &input,
void DepthwiseConv3x3S1(const framework::Tensor &input,
const framework::Tensor &filter,
const std::vector<int> &paddings,
framework::Tensor *output);
template <typename Itype, typename Otype>
void DepthwiseConv3x3s2(const framework::Tensor &input,
void DepthwiseConv3x3S2(const framework::Tensor &input,
const framework::Tensor &filter,
const std::vector<int> &paddings,
framework::Tensor *output);
} // namespace math
......
......@@ -12,12 +12,300 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#if defined(__ARM_NEON__) && !defined(__aarch64__)
#include "operators/math/depthwise_conv3x3.h"
#ifdef __ARM_NEON__
#include <arm_neon.h>
#endif
namespace paddle_mobile {
namespace operators {
namespace math {
template <int Stride>
inline void Depth3x3ValidColLoadInput(const int8_t *input, const int input_w,
const int valid_cols, int16x8_t *y0,
int16x8_t *y1, int16x8_t *y2) {
PADDLE_MOBILE_THROW_EXCEPTION("Stride %d is not supported.", Stride);
}
template <>
inline void Depth3x3ValidColLoadInput<1>(const int8_t *input, const int input_w,
const int valid_cols, int16x8_t *y0,
int16x8_t *y1, int16x8_t *y2) {
int8_t fake_input[3][8];
if (valid_cols == 1) {
for (int i = 0; i < 8; ++i, input += input_w) {
fake_input[0][i] = input[0];
}
} else if (valid_cols == 2) {
for (int i = 0; i < 8; ++i, input += input_w) {
fake_input[0][i] = input[0];
fake_input[1][i] = input[1];
}
} else {
for (int i = 0; i < 8; ++i, input += input_w) {
fake_input[0][i] = input[0];
fake_input[1][i] = input[1];
fake_input[2][i] = input[2];
}
}
int8x8_t input0 = vld1_s8(fake_input[0]);
int8x8_t input1 = vld1_s8(fake_input[1]);
int8x8_t input2 = vld1_s8(fake_input[2]);
y0[0] = vmovl_s8(input0);
y1[0] = vmovl_s8(input1);
y2[0] = vmovl_s8(input2);
y0[1] = vextq_s16(y0[0], y0[0], 1);
y0[2] = vextq_s16(y0[0], y0[0], 2);
y1[1] = vextq_s16(y1[0], y1[0], 1);
y1[2] = vextq_s16(y1[0], y1[0], 2);
y2[1] = vextq_s16(y2[0], y2[0], 1);
y2[2] = vextq_s16(y2[0], y2[0], 2);
}
template <>
inline void Depth3x3ValidColLoadInput<2>(const int8_t *input, const int input_w,
const int valid_cols, int16x8_t *y0,
int16x8_t *y1, int16x8_t *y2) {
int8_t fake_input[3][13];
if (valid_cols == 1) {
for (int i = 0; i < 13; ++i, input += input_w) {
fake_input[0][i] = input[0];
}
} else if (valid_cols == 2) {
for (int i = 0; i < 13; ++i, input += input_w) {
fake_input[0][i] = input[0];
fake_input[1][i] = input[1];
}
} else {
for (int i = 0; i < 13; ++i, input += input_w) {
fake_input[0][i] = input[0];
fake_input[1][i] = input[1];
fake_input[2][i] = input[2];
}
}
int8x8x2_t input0 = vld2_s8(fake_input[0]);
int8x8x2_t input1 = vld2_s8(fake_input[1]);
int8x8x2_t input2 = vld2_s8(fake_input[2]);
y0[0] = vmovl_s8(input0.val[0]);
y0[1] = vmovl_s8(input0.val[1]);
y0[2] = vextq_s16(y0[0], y0[0], 1);
y1[0] = vmovl_s8(input1.val[0]);
y1[1] = vmovl_s8(input1.val[1]);
y1[2] = vextq_s16(y1[0], y1[0], 1);
y2[0] = vmovl_s8(input2.val[0]);
y2[1] = vmovl_s8(input2.val[1]);
y2[2] = vextq_s16(y2[0], y2[0], 1);
}
template <int Stride_h, int Stride_w>
inline void DepthwiseConv3x3ValidCol(const int8_t *input, const int8_t *filter,
const int h_output, const int h_output_end,
const int w_output, const int input_h,
const int input_w, const int padding_h,
const int padding_w, const int output_w,
int32_t *output) {
const int w_in_start = -padding_w + w_output * Stride_w;
const int w_in_end = w_in_start + 3;
const int w_start = w_in_start > 0 ? w_in_start : 0;
const int w_end = w_in_end < input_w ? w_in_end : input_w;
int remain_start = h_output;
#ifdef __ARM_NEON__
int output_tiles = (h_output_end - h_output) / 6;
remain_start = h_output + output_tiles * 6;
int input_h_start = h_output * Stride_h - padding_h;
size_t input_offset = input_h_start * input_w + w_start;
size_t output_offset = h_output * output_w + w_output;
int16x8_t _input[3][3];
int16x4_t _kernel[3];
int32x4_t _sum0, _sum1;
const int8_t *filter_ptr = filter;
asm volatile(
"mov r0, #3 \n"
"vld1.s8 d10, [%[filter]], r0 \n"
"vld1.s8 d11, [%[filter]], r0 \n"
"vld1.s8 d12, [%[filter]] \n"
"vtrn.8 d10, d11 \n"
"vtrn.8 d12, d13 \n"
"vtrn.16 d10, d12 \n"
"vtrn.16 d11, d13 \n"
"vmovl.s8 q7, d10 \n"
"vmovl.s8 q8, d11 \n"
"vmovl.s8 q9, d12 \n"
"vmov.32 %[_kernel0], d14 \n"
"vmov.32 %[_kernel1], d16 \n"
"vmov.32 %[_kernel2], d18 \n"
: [_kernel0] "+w"(_kernel[0]), [_kernel1] "+w"(_kernel[1]),
[_kernel2] "+w"(_kernel[2])
: [filter] "r"(filter_ptr)
: "memory", "q5", "q6", "q7", "q8", "q9", "r0");
int valid_cols = w_end - w_start;
for (int h = 0; h < output_tiles * 6; h += 6) {
int32_t *output0 = output + output_offset;
int32_t *output1 = output0 + output_w;
int32_t *output2 = output1 + output_w;
int32_t *output3 = output2 + output_w;
int32_t *output4 = output3 + output_w;
int32_t *output5 = output4 + output_w;
Depth3x3ValidColLoadInput<Stride_w>(input + input_offset, input_w,
valid_cols, _input[0], _input[1],
_input[2]);
_sum0 = veorq_s32(_sum0, _sum0);
_sum1 = veorq_s32(_sum1, _sum1);
for (int w_in = 0; w_in < valid_cols; ++w_in) {
int index = w_in + w_start - w_in_start;
_sum0 = vmlal_lane_s16(_sum0, vget_low_s16(_input[w_in][0]),
_kernel[index], 0);
_sum0 = vmlal_lane_s16(_sum0, vget_low_s16(_input[w_in][1]),
_kernel[index], 1);
_sum0 = vmlal_lane_s16(_sum0, vget_low_s16(_input[w_in][2]),
_kernel[index], 2);
_sum1 = vmlal_lane_s16(_sum1, vget_high_s16(_input[w_in][0]),
_kernel[index], 0);
_sum1 = vmlal_lane_s16(_sum1, vget_high_s16(_input[w_in][1]),
_kernel[index], 1);
_sum1 = vmlal_lane_s16(_sum1, vget_high_s16(_input[w_in][2]),
_kernel[index], 2);
}
vst1q_lane_s32(output0, _sum0, 0);
vst1q_lane_s32(output1, _sum0, 1);
vst1q_lane_s32(output2, _sum0, 2);
vst1q_lane_s32(output3, _sum0, 3);
vst1q_lane_s32(output4, _sum1, 0);
vst1q_lane_s32(output5, _sum1, 1);
input_offset += 6 * Stride_h * input_w;
output_offset += 6 * output_w;
}
#endif
for (int h = remain_start; h < h_output_end; ++h) {
int32_t value = 0;
const int h_in_start = -padding_h + h * Stride_h;
for (int i = 0; i < 3; ++i) {
for (int w_in = w_start; w_in < w_end; ++w_in) {
value += filter[i * 3 + (w_in - w_in_start)] *
input[(h_in_start + i) * input_w + w_in];
}
}
output[h * output_w + w_output] = value;
}
}
#define DEPTHWISE_CONV_NORMAL_BORDER(start, end) \
for (int w = start; w < end; ++w) { \
const int w_in_start = -padding_w + w * Stride_w; \
const int w_in_end = w_in_start + 3; \
const int w_start = w_in_start > 0 ? w_in_start : 0; \
const int w_end = w_in_end < input_w ? w_in_end : input_w; \
int32_t value = 0; \
for (int h_in = h_start; h_in < h_end; ++h_in) { \
for (int w_in = w_start; w_in < w_end; ++w_in) { \
value += filter[(h_in - h_in_start) * 3 + (w_in - w_in_start)] * \
input[h_in * input_w + w_in]; \
} \
} \
output_ptr[w] = value; \
}
template <int Stride>
inline void Depth3x3NormalRowLoadInput(const int8_t *input,
int16x8_t &y0, // NOLINT
int16x8_t &y1, // NOLINT
int16x8_t &y2) { // NOLINT
PADDLE_MOBILE_THROW_EXCEPTION("Stride %d is not supported.", Stride);
}
template <>
inline void Depth3x3NormalRowLoadInput<1>(const int8_t *input,
int16x8_t &y0, // NOLINT
int16x8_t &y1, // NOLINT
int16x8_t &y2) { // NOLINT
int8x8_t x0 = vld1_s8(input);
y0 = vmovl_s8(x0);
y1 = vextq_s16(y0, y0, 1);
y2 = vextq_s16(y1, y1, 1);
}
template <>
inline void Depth3x3NormalRowLoadInput<2>(const int8_t *input,
int16x8_t &y0, // NOLINT
int16x8_t &y1, // NOLINT
int16x8_t &y2) { // NOLINT
int8x8x2_t x0 = vld2_s8(input);
y0 = vmovl_s8(x0.val[0]);
y1 = vmovl_s8(x0.val[1]);
y2 = vextq_s16(y0, y0, 1);
}
template <int Stride_h, int Stride_w>
inline void DepthwiseConv3x3NormalRow(const int8_t *input, const int8_t *filter,
const int h_output, const int input_h,
const int input_w, const int padding_h,
const int padding_w, const int output_w,
int32_t *output) {
const int h_in_start = -padding_h + h_output * Stride_h;
const int h_in_end = h_in_start + 3;
const int h_start = h_in_start > 0 ? h_in_start : 0;
const int h_end = h_in_end < input_h ? h_in_end : input_h;
int valid_w_start = (padding_w + Stride_w - 1) / Stride_w;
int valid_w_end = output_w - valid_w_start;
int32_t *output_ptr = output + h_output * output_w;
// border left
DEPTHWISE_CONV_NORMAL_BORDER(0, valid_w_start)
// middle
int remain_start = valid_w_start;
#ifdef __ARM_NEON__
int output_tiles = (valid_w_end - valid_w_start) / 6;
remain_start = valid_w_start + output_tiles * 6;
int32x4_t _sum0, _sum1;
int16x8_t y0, y1, y2;
int16x4_t _kernel[3];
for (int h_in = h_start; h_in < h_end; ++h_in) {
int index = h_in - h_in_start;
int8x8_t w0 = vld1_s8(filter + index * 3);
int16x8_t w1 = vmovl_s8(w0);
_kernel[index] = vget_low_s16(w1);
}
for (int w = 0; w < output_tiles * 6; w += 6) {
_sum0 = veorq_s32(_sum0, _sum0);
_sum1 = veorq_s32(_sum1, _sum1);
int output_offset = valid_w_start + w;
int input_w_offset = output_offset * Stride_w - padding_w;
for (int h_in = h_start; h_in < h_end; ++h_in) {
int index = h_in - h_in_start;
Depth3x3NormalRowLoadInput<Stride_w>(
input + h_in * input_w + input_w_offset, y0, y1, y2);
_sum0 = vmlal_lane_s16(_sum0, vget_low_s16(y0), _kernel[index], 0);
_sum0 = vmlal_lane_s16(_sum0, vget_low_s16(y1), _kernel[index], 1);
_sum0 = vmlal_lane_s16(_sum0, vget_low_s16(y2), _kernel[index], 2);
_sum1 = vmlal_lane_s16(_sum1, vget_high_s16(y0), _kernel[index], 0);
_sum1 = vmlal_lane_s16(_sum1, vget_high_s16(y1), _kernel[index], 1);
_sum1 = vmlal_lane_s16(_sum1, vget_high_s16(y2), _kernel[index], 2);
}
vst1q_s32(output_ptr + output_offset, _sum0);
vst1q_lane_s32(output_ptr + output_offset + 4, _sum1, 0);
vst1q_lane_s32(output_ptr + output_offset + 5, _sum1, 1);
}
#endif
for (int w = remain_start; w < valid_w_end; ++w) {
int32_t value = 0;
int input_start = -padding_w + w * Stride_w;
for (int h_in = h_start; h_in < h_end; ++h_in) {
for (int j = 0; j < 3; ++j) {
value += filter[(h_in - h_in_start) * 3 + j] *
input[h_in * input_w + j + input_start];
}
}
output_ptr[w] = value;
}
// border right
DEPTHWISE_CONV_NORMAL_BORDER(valid_w_end, output_w)
}
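// The row handled above is split into a left border, a vectorized middle and a
// right border. The split point is a ceiling division: the first output column
// whose 3-wide window no longer reaches into the left padding. A small sketch
// of that arithmetic (illustrative helper, not part of the kernel):
static inline int FirstUnpaddedColumn(int padding_w, int stride_w) {
  // smallest w with w * stride_w - padding_w >= 0, i.e. ceil(padding_w / stride_w)
  return (padding_w + stride_w - 1) / stride_w;
}
// e.g. padding_w = 1 gives column 1 for both strides; padding_w = 2 with
// stride 2 also gives column 1.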
// template<>
// void DepthwiseConv3x3<int8_t, int32_t>(
// const framework::Tensor *input, const framework::Tensor *filter,
......@@ -27,43 +315,72 @@ namespace math {
// }
template <>
void DepthwiseConv3x3s1<int8_t, int32_t>(const framework::Tensor &input,
void DepthwiseConv3x3S1<int8_t, int32_t>(const framework::Tensor &input,
const framework::Tensor &filter,
const std::vector<int> &paddings,
framework::Tensor *output) {
const int8_t *input_data = input.data<int8_t>();
const int8_t *filter_data = filter.data<int8_t>();
int32_t *out_data = output->mutable_data<int32_t>();
  // batch size is assumed to be 1
int input_c = input.dims()[1];
int input_h = input.dims()[2];
int input_w = input.dims()[3];
int output_c = output->dims()[1];
int output_h = output->dims()[2];
int output_w = output->dims()[3];
int padding_h = paddings[0];
int padding_w = paddings[1];
int image_size = input_h * input_w;
int out_image_size = output_h * output_w;
#if __aarch64__
// TODO(hjchen2)
#else
int valid_h_start = padding_h;
int valid_h_end = output_h - valid_h_start;
int valid_h = valid_h_end - valid_h_start;
int valid_w_start = padding_w;
int valid_w_end = output_w - valid_w_start;
int valid_w = valid_w_end - valid_w_start;
#pragma omp parallel for
for (int g = 0; g < input_c; ++g) {
const int8_t* input_ptr = input_data + g * image_size;
const int8_t* filter_ptr = filter_data + g * 9;
int32_t* output_ptr = out_data + g * out_image_size;
int loops = (input_w - 2) / 6;
int remain = input_w - 2 - loops * 6;
for (int h = 0; h < input_h - 5 /*(input_h - 2) - 3*/; h += 4) {
const int8_t* input_ptr0 = input_ptr + h * input_w;
const int8_t* input_ptr1 = input_ptr0 + input_w;
const int8_t* input_ptr2 = input_ptr1 + input_w;
const int8_t* input_ptr3 = input_ptr2 + input_w;
const int8_t* input_ptr4 = input_ptr3 + input_w;
const int8_t* input_ptr5 = input_ptr4 + input_w;
int32_t* output_ptr0 = output_ptr + h * output_w;
int32_t* output_ptr1 = output_ptr0 + output_w;
int32_t* output_ptr2 = output_ptr1 + output_w;
int32_t* output_ptr3 = output_ptr2 + output_w;
int loop = loops;
for (int g = 0; g < input.dims()[1]; ++g) {
const int8_t *input_ptr = input_data + g * image_size;
const int8_t *filter_ptr = filter_data + g * 9;
int32_t *output_ptr = out_data + g * out_image_size;
// top
for (int h = 0; h < valid_h_start; ++h) {
DepthwiseConv3x3NormalRow<1, 1>(input_ptr, filter_ptr, h, input_h,
input_w, padding_h, padding_w, output_w,
output_ptr);
}
// left
for (int w = 0; w < valid_w_start; ++w) {
DepthwiseConv3x3ValidCol<1, 1>(
input_ptr, filter_ptr, valid_h_start, valid_h_end, w, input_h,
input_w, padding_h, padding_w, output_w, output_ptr);
}
// right
for (int w = valid_w_end; w < output_w; ++w) {
DepthwiseConv3x3ValidCol<1, 1>(
input_ptr, filter_ptr, valid_h_start, valid_h_end, w, input_h,
input_w, padding_h, padding_w, output_w, output_ptr);
}
// bottom
for (int h = valid_h_end; h < output_h; ++h) {
DepthwiseConv3x3NormalRow<1, 1>(input_ptr, filter_ptr, h, input_h,
input_w, padding_h, padding_w, output_w,
output_ptr);
}
// valid
int output_w_tiles = valid_w / 6;
int output_w_remain = valid_w - output_w_tiles * 6;
for (int h = valid_h_start; h < valid_h_end - 3; h += 4) {
const int8_t *input_ptr0 = input_ptr + (h - padding_h) * input_w;
const int8_t *input_ptr1 = input_ptr0 + input_w;
const int8_t *input_ptr2 = input_ptr1 + input_w;
const int8_t *input_ptr3 = input_ptr2 + input_w;
const int8_t *input_ptr4 = input_ptr3 + input_w;
const int8_t *input_ptr5 = input_ptr4 + input_w;
int32_t *output_ptr0 = output_ptr + h * output_w + valid_w_start;
int32_t *output_ptr1 = output_ptr0 + output_w;
int32_t *output_ptr2 = output_ptr1 + output_w;
int32_t *output_ptr3 = output_ptr2 + output_w;
int loop = output_w_tiles;
asm volatile(
"vld1.32 {q0}, [%[filter_ptr]] \n"
"vmovl.s8 q14, d0 \n"
......@@ -377,27 +694,27 @@ void DepthwiseConv3x3s1<int8_t, int32_t>(const framework::Tensor &input,
"vst1.32 {d24[0]}, [%[output_ptr1]]! \n"
"vst1.32 {d28[0]}, [%[output_ptr2]]! \n"
"vst1.32 {d10[0]}, [%[output_ptr3]]! \n"
"end_%=: \n"
"end_%=: \n"
: [output_ptr0] "+r"(output_ptr0), [output_ptr1] "+r"(output_ptr1),
[output_ptr2] "+r"(output_ptr2), [output_ptr3] "+r"(output_ptr3),
[input_ptr0] "+r"(input_ptr0), [input_ptr1] "+r"(input_ptr1),
[input_ptr2] "+r"(input_ptr2), [input_ptr3] "+r"(input_ptr3),
[input_ptr4] "+r"(input_ptr4), [input_ptr5] "+r"(input_ptr5),
[loop] "+r"(loop)
: [remain] "r"(remain)
: [remain] "r"(output_w_remain)
: "cc", "memory", "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7",
"q8", "q9", "q10", "q11", "q12", "q13", "q14", "q15", "r0");
}
// remain height
int start_h = (input_h - 2) & 0xFFFC;
for (int h = start_h; h < input_h - 3 /*(input_h - 2) - 1*/; h += 2) {
const int8_t* input_ptr0 = input_ptr + h * input_w;
const int8_t* input_ptr1 = input_ptr0 + input_w;
const int8_t* input_ptr2 = input_ptr1 + input_w;
const int8_t* input_ptr3 = input_ptr2 + input_w;
int32_t* output_ptr0 = output_ptr + h * output_w;
int32_t* output_ptr1 = output_ptr0 + output_w;
int loop = loops;
int start_h = valid_h_start + (valid_h & 0xFFFC);
for (int h = start_h; h < valid_h_end - 1; h += 2) {
const int8_t *input_ptr0 = input_ptr + (h - padding_h) * input_w;
const int8_t *input_ptr1 = input_ptr0 + input_w;
const int8_t *input_ptr2 = input_ptr1 + input_w;
const int8_t *input_ptr3 = input_ptr2 + input_w;
int32_t *output_ptr0 = output_ptr + h * output_w + valid_w_start;
int32_t *output_ptr1 = output_ptr0 + output_w;
int loop = output_w_tiles;
asm volatile(
"vld1.32 {q0}, [%[filter_ptr]] \n"
"vmovl.s8 q14, d0 \n"
......@@ -415,9 +732,9 @@ void DepthwiseConv3x3s1<int8_t, int32_t>(const framework::Tensor &input,
: [filter_ptr] "r"(filter_ptr)
: "cc", "memory", "q0", "q1", "q2", "q3", "q4", "q14", "q15");
asm volatile(
"mov r0, #6 \n"
"cmp %[loop], #0 \n"
"ble start_remain_%= \n"
"mov r0, #6 \n"
// loop 6 widths
"loop_2h6w_%=: \n"
"vld1.32 {d9}, [%[input_ptr0]], r0 \n"
......@@ -589,23 +906,23 @@ void DepthwiseConv3x3s1<int8_t, int32_t>(const framework::Tensor &input,
"blt end_%= \n"
"vst1.32 {d20[0]}, [%[output_ptr0]]! \n"
"vst1.32 {d24[0]}, [%[output_ptr1]]! \n"
"end_%=: \n"
"end_%=: \n"
: [output_ptr0] "+r"(output_ptr0), [output_ptr1] "+r"(output_ptr1),
[input_ptr0] "+r"(input_ptr0), [input_ptr1] "+r"(input_ptr1),
[input_ptr2] "+r"(input_ptr2), [input_ptr3] "+r"(input_ptr3),
[loop] "+r"(loop)
: [remain] "r"(remain)
: [remain] "r"(output_w_remain)
: "cc", "memory", "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7",
"q8", "q9", "q10", "q11", "q12", "q13", "r0");
}
start_h = (input_h - 2) & 0xFFFE;
if (start_h < input_h - 2) {
const int8_t* input_ptr0 = input_ptr + start_h * input_w;
const int8_t* input_ptr1 = input_ptr0 + input_w;
const int8_t* input_ptr2 = input_ptr1 + input_w;
int32_t* output_ptr0 = output_ptr + start_h * output_w;
int loop = loops;
start_h = valid_h_start + (valid_h & 0xFFFE);
if (start_h < valid_h_end) {
const int8_t *input_ptr0 = input_ptr + (start_h - padding_h) * input_w;
const int8_t *input_ptr1 = input_ptr0 + input_w;
const int8_t *input_ptr2 = input_ptr1 + input_w;
int32_t *output_ptr0 = output_ptr + start_h * output_w + valid_w_start;
int loop = output_w_tiles;
asm volatile(
"vld1.32 {q0}, [%[filter_ptr]] \n"
"vmovl.s8 q14, d0 \n"
......@@ -623,9 +940,9 @@ void DepthwiseConv3x3s1<int8_t, int32_t>(const framework::Tensor &input,
: [filter_ptr] "r"(filter_ptr)
: "memory", "q0", "q1", "q2", "q3", "q4", "q14", "q15");
asm volatile(
"mov r0, #6 \n"
"cmp %[loop], #0 \n"
"ble start_remain_%= \n"
"mov r0, #6 \n"
// loop 6 widths
"loop_1h6w_%=: \n"
"vld1.32 {d9}, [%[input_ptr0]], r0 \n"
......@@ -736,56 +1053,91 @@ void DepthwiseConv3x3s1<int8_t, int32_t>(const framework::Tensor &input,
"cmp %[remain], #1 \n"
"blt end_%= \n"
"vst1.32 {d20[0]}, [%[output_ptr0]]! \n"
"end_%=: \n"
"end_%=: \n"
: [output_ptr0] "+r"(output_ptr0), [input_ptr0] "+r"(input_ptr0),
[input_ptr1] "+r"(input_ptr1), [input_ptr2] "+r"(input_ptr2),
[loop] "+r"(loop)
: [remain] "r"(remain)
: [remain] "r"(output_w_remain)
: "cc", "memory", "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7",
"q8", "q9", "q10", "q11", "r0");
}
}
#endif // __aarch64__
}
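// A plain scalar reference can be useful when checking the NEON stride-1 path
// above. The sketch below is illustrative only: it assumes batch size 1 and
// NCHW layout with per-channel 3x3 filters (matching the pointers used above),
// uses raw pointers instead of framework::Tensor, and the function name is
// invented.
static void DepthwiseConv3x3S1Ref(const int8_t *input, const int8_t *filter,
                                  int channels, int input_h, int input_w,
                                  int pad_h, int pad_w, int32_t *output) {
  const int output_h = input_h + 2 * pad_h - 2;  // stride 1, kernel 3
  const int output_w = input_w + 2 * pad_w - 2;
  for (int c = 0; c < channels; ++c) {
    const int8_t *in = input + c * input_h * input_w;
    const int8_t *flt = filter + c * 9;
    int32_t *out = output + c * output_h * output_w;
    for (int oh = 0; oh < output_h; ++oh) {
      for (int ow = 0; ow < output_w; ++ow) {
        int32_t acc = 0;
        for (int kh = 0; kh < 3; ++kh) {
          for (int kw = 0; kw < 3; ++kw) {
            int ih = oh - pad_h + kh;
            int iw = ow - pad_w + kw;
            if (ih >= 0 && ih < input_h && iw >= 0 && iw < input_w) {
              acc += static_cast<int32_t>(flt[kh * 3 + kw]) *
                     static_cast<int32_t>(in[ih * input_w + iw]);
            }
          }
        }
        out[oh * output_w + ow] = acc;
      }
    }
  }
}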
template <>
void DepthwiseConv3x3s2<int8_t, int32_t>(const framework::Tensor &input,
void DepthwiseConv3x3S2<int8_t, int32_t>(const framework::Tensor &input,
const framework::Tensor &filter,
const std::vector<int> &paddings,
framework::Tensor *output) {
const int8_t *input_data = input.data<int8_t>();
const int8_t *filter_data = filter.data<int8_t>();
int32_t *out_data = output->mutable_data<int32_t>();
  // batch size is assumed to be 1
int input_c = input.dims()[1];
int input_h = input.dims()[2];
int input_w = input.dims()[3];
int output_c = output->dims()[1];
int output_h = output->dims()[2];
int output_w = output->dims()[3];
int padding_h = paddings[0];
int padding_w = paddings[1];
int image_size = input_h * input_w;
int out_image_size = output_h * output_w;
#if __aarch64__
// TODO(hjchen2)
#else
int valid_h_start = (padding_h + 1) / 2;
int valid_h_end = output_h - valid_h_start;
int valid_h = valid_h_end - valid_h_start;
int valid_w_start = (padding_w + 1) / 2;
int valid_w_end = output_w - valid_w_start;
int valid_w = valid_w_end - valid_w_start;
// DLOG << "valid_h_start: " << valid_h_start;
// DLOG << "valid_h_end: " << valid_h_end;
// DLOG << "valid_w_start: " << valid_w_start;
// DLOG << "valid_w_end: " << valid_w_end;
#pragma omp parallel for
for (int g = 0; g < input_c; ++g) {
const int8_t* input_ptr = input_data + g * image_size;
const int8_t* filter_ptr = filter_data + g * 9;
int32_t* output_ptr = out_data + g * out_image_size;
int loops = output_w / 6;
int remain = output_w - loops * 6;
for (int h = 0; h < input_h - 6 /*(input_h - 1) - 5*/; h += 6) {
const int8_t* input_ptr0 = input_ptr + h * input_w;
const int8_t* input_ptr1 = input_ptr0 + input_w;
const int8_t* input_ptr2 = input_ptr1 + input_w;
const int8_t* input_ptr3 = input_ptr2 + input_w;
const int8_t* input_ptr4 = input_ptr3 + input_w;
const int8_t* input_ptr5 = input_ptr4 + input_w;
const int8_t* input_ptr6 = input_ptr5 + input_w;
int32_t* output_ptr0 = output_ptr + (h >> 1) * output_w;
int32_t* output_ptr1 = output_ptr0 + output_w;
int32_t* output_ptr2 = output_ptr1 + output_w;
int loop = loops;
for (int g = 0; g < input.dims()[1]; ++g) {
const int8_t *input_ptr = input_data + g * image_size;
const int8_t *filter_ptr = filter_data + g * 9;
int32_t *output_ptr = out_data + g * out_image_size;
// top
for (int h = 0; h < valid_h_start; ++h) {
DepthwiseConv3x3NormalRow<2, 2>(input_ptr, filter_ptr, h, input_h,
input_w, padding_h, padding_w, output_w,
output_ptr);
}
// left
for (int w = 0; w < valid_w_start; ++w) {
DepthwiseConv3x3ValidCol<2, 2>(
input_ptr, filter_ptr, valid_h_start, valid_h_end, w, input_h,
input_w, padding_h, padding_w, output_w, output_ptr);
}
// right
for (int w = valid_w_end; w < output_w; ++w) {
DepthwiseConv3x3ValidCol<2, 2>(
input_ptr, filter_ptr, valid_h_start, valid_h_end, w, input_h,
input_w, padding_h, padding_w, output_w, output_ptr);
}
// bottom
for (int h = valid_h_end; h < output_h; ++h) {
DepthwiseConv3x3NormalRow<2, 2>(input_ptr, filter_ptr, h, input_h,
input_w, padding_h, padding_w, output_w,
output_ptr);
}
// valid
int input_w_start = 2 * valid_w_start - padding_w;
int output_w_tiles = valid_w / 6;
int output_w_remain = valid_w - output_w_tiles * 6;
for (int h = valid_h_start; h < valid_h_end - 2; h += 3) {
size_t offset = (2 * h - padding_h) * input_w + input_w_start;
const int8_t *input_ptr0 = input_ptr + offset;
const int8_t *input_ptr1 = input_ptr0 + input_w;
const int8_t *input_ptr2 = input_ptr1 + input_w;
const int8_t *input_ptr3 = input_ptr2 + input_w;
const int8_t *input_ptr4 = input_ptr3 + input_w;
const int8_t *input_ptr5 = input_ptr4 + input_w;
const int8_t *input_ptr6 = input_ptr5 + input_w;
int32_t *output_ptr0 = output_ptr + h * output_w + valid_w_start;
int32_t *output_ptr1 = output_ptr0 + output_w;
int32_t *output_ptr2 = output_ptr1 + output_w;
int loop = output_w_tiles;
asm volatile(
"vld1.32 {q0}, [%[filter_ptr]] \n"
"vmovl.s8 q14, d0 \n"
......@@ -803,9 +1155,9 @@ void DepthwiseConv3x3s2<int8_t, int32_t>(const framework::Tensor &input,
: [filter_ptr] "r"(filter_ptr)
: "memory", "q0", "q1", "q2", "q3", "q4", "q14", "q15");
asm volatile(
"mov r0, #12 \n"
"cmp %[loop], #0 \n"
"ble start_remain_%= \n"
"mov r0, #12 \n"
// loop 6 widths
"loop_3h6w_%=: \n"
"vld2.8 {d10, d11}, [%[input_ptr0]], r0 \n"
......@@ -1048,25 +1400,26 @@ void DepthwiseConv3x3s2<int8_t, int32_t>(const framework::Tensor &input,
"vst1.32 {d20[0]}, [%[output_ptr0]]! \n"
"vst1.32 {d24[0]}, [%[output_ptr1]]! \n"
"vst1.32 {d28[0]}, [%[output_ptr2]]! \n"
"end_%=: \n"
"end_%=: \n"
: [output_ptr0] "+r"(output_ptr0), [output_ptr1] "+r"(output_ptr1),
[output_ptr2] "+r"(output_ptr2), [input_ptr6] "+r"(input_ptr6),
[input_ptr0] "+r"(input_ptr0), [input_ptr1] "+r"(input_ptr1),
[input_ptr2] "+r"(input_ptr2), [input_ptr3] "+r"(input_ptr3),
[input_ptr4] "+r"(input_ptr4), [input_ptr5] "+r"(input_ptr5),
[loop] "+r"(loop)
: [remain] "r"(remain)
: [remain] "r"(output_w_remain)
: "cc", "memory", "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7",
"q8", "q9", "q10", "q11", "q12", "q13", "q14", "q15", "r0");
}
int start_h = (output_h / 3) * 6;
for (int h = start_h; h < input_h - 2 /*(input_h - 1) - 1*/; h += 2) {
const int8_t* input_ptr0 = input_ptr + h * input_w;
const int8_t* input_ptr1 = input_ptr0 + input_w;
const int8_t* input_ptr2 = input_ptr1 + input_w;
int32_t* output_ptr0 = output_ptr + (h >> 1) * output_w;
int loop = loops;
int start_h = valid_h_start + valid_h / 3 * 3;
for (int h = start_h; h < valid_h_end; ++h) {
size_t offset = (2 * h - padding_h) * input_w + input_w_start;
const int8_t *input_ptr0 = input_ptr + offset;
const int8_t *input_ptr1 = input_ptr0 + input_w;
const int8_t *input_ptr2 = input_ptr1 + input_w;
int32_t *output_ptr0 = output_ptr + h * output_w + valid_w_start;
int loop = output_w_tiles;
asm volatile(
"vld1.32 {q0}, [%[filter_ptr]] \n"
"vmovl.s8 q14, d0 \n"
......@@ -1084,9 +1437,9 @@ void DepthwiseConv3x3s2<int8_t, int32_t>(const framework::Tensor &input,
: [filter_ptr] "r"(filter_ptr)
: "memory", "q0", "q1", "q2", "q3", "q4", "q14", "q15");
asm volatile(
"mov r0, #12 \n"
"cmp %[loop], #0 \n"
"ble start_remain_%= \n"
"mov r0, #12 \n"
// loop 6 widths
"loop_1h6w_%=: \n"
"vld2.8 {d10, d11}, [%[input_ptr0]], r0 \n"
......@@ -1190,18 +1543,19 @@ void DepthwiseConv3x3s2<int8_t, int32_t>(const framework::Tensor &input,
"cmp %[remain], #1 \n"
"blt end_%= \n"
"vst1.32 {d22[0]}, [%[output_ptr0]]! \n"
"end_%=: \n"
"end_%=: \n"
: [output_ptr0] "+r"(output_ptr0), [input_ptr0] "+r"(input_ptr0),
[input_ptr1] "+r"(input_ptr1), [input_ptr2] "+r"(input_ptr2),
[loop] "+r"(loop)
: [remain] "r"(remain)
: [remain] "r"(output_w_remain)
: "cc", "memory", "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7",
"q8", "q9", "q10", "q11", "q12", "r0");
}
}
#endif // __aarch64__
}
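// The stride-2 valid region above follows the same pattern: output index i
// reads input positions 2*i - padding .. 2*i - padding + 2, so the first fully
// in-bounds index is ceil(padding / 2), written as (padding + 1) / 2. A
// one-line sketch (illustration only):
static inline int ValidStartStride2(int padding) {
  return (padding + 1) / 2;  // e.g. padding 0 -> 0, padding 1 -> 1, padding 2 -> 1
}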
} // namespace math
} // namespace operators
} // namespace paddle_mobile
#endif
......@@ -405,9 +405,9 @@ class ConvParam : public OpParam {
const RType *Input() const { return input_; }
RType *&Filter() const { return filter_; }
RType *Filter() const { return filter_; }
RType *&Output() const { return output_; }
RType *Output() const { return output_; }
const vector<int> &Strides() const { return strides_; }
......@@ -419,6 +419,8 @@ class ConvParam : public OpParam {
EXEC_INVALID = 0,
EXEC_GEMM_FLOAT,
EXEC_DEPTHWISE3x3S1P1_FLOAT,
EXEC_DEPTHWISE3x3S2P0_FLOAT,
EXEC_DEPTHWISE3x3S2P1_FLOAT,
EXEC_DEPTHWISE3x3_FLOAT,
EXEC_WINOGRAD3X3_FLOAT,
EXEC_WINOGRAD5X5_FLOAT,
......@@ -439,8 +441,8 @@ class ConvParam : public OpParam {
private:
RType *input_;
mutable RType *output_;
mutable RType *filter_;
RType *output_;
RType *filter_;
vector<int> strides_;
vector<int> paddings_;
vector<int> dilations_;
......@@ -2573,7 +2575,9 @@ class DequantizeParam : public OpParam {
DequantizeParam(const VariableNameMap &inputs, const VariableNameMap &outputs,
const AttributeMap &attrs, const Scope &scope) {
input_ = InputXFrom<GType>(inputs, scope);
output_ = OutFrom<GType>(outputs, scope);
if (outputs.count("Out")) {
output_ = OutFrom<GType>(outputs, scope);
}
activation_scale_ = OpParam::GetVarValue<GType>("Scale", inputs, scope);
// dequantization is performed as x = x / static_scale / online_scale
if (HasAttr("weight_scale", attrs)) {
......@@ -2593,20 +2597,19 @@ class DequantizeParam : public OpParam {
};
#endif
#ifdef FUSION_DEQUANT_ADD_BN_RELU_OP
#if defined(FUSION_DEQUANT_ADD_BN_OP) || \
defined(FUSION_DEQUANT_ADD_BN_RELU_OP) || \
defined(FUSION_DEQUANT_BN_RELU_OP) || defined(FUSION_DEQUANT_BN_OP)
template <typename Dtype>
class FusionDequantAddBNReluParam : public DequantizeParam<Dtype> {
class FusionDequantBNParam : public DequantizeParam<Dtype> {
typedef typename DtypeTensorTrait<Dtype>::gtype GType;
typedef typename DtypeTensorTrait<Dtype>::rtype RType;
public:
FusionDequantAddBNReluParam(const VariableNameMap &inputs,
const VariableNameMap &outputs,
const AttributeMap &attrs, const Scope &scope)
FusionDequantBNParam(const VariableNameMap &inputs,
const VariableNameMap &outputs,
const AttributeMap &attrs, const Scope &scope)
: DequantizeParam<Dtype>(inputs, outputs, attrs, scope) {
// element wise add params
axis_ = OpParam::GetAttr<int>("axis", attrs);
bias_ = OpParam::InputYFrom<GType>(inputs, scope);
// batch norm params
bn_mean_ = OpParam::GetVarValue<GType>("BNMean", inputs, scope);
bn_variance_ = OpParam::GetVarValue<GType>("BNVariance", inputs, scope);
......@@ -2614,21 +2617,83 @@ class FusionDequantAddBNReluParam : public DequantizeParam<Dtype> {
bn_bias_ = OpParam::GetVarValue<GType>("BNBias", inputs, scope);
epsilon_ = OpParam::GetAttr<float>("epsilon", attrs);
// output
output_ = OpParam::OutFrom<GType>(outputs, scope);
if (outputs.count("Y")) {
this->output_ = OpParam::OutputYFrom<GType>(outputs, scope);
}
}
public:
// elementwise add
int axis_;
RType *bias_;
// batch norm
RType *bn_mean_;
RType *bn_variance_;
RType *bn_scale_;
RType *bn_bias_;
float epsilon_;
// output
RType *output_;
};
#endif
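// For context, the batch-norm constants carried above can be folded into a
// per-channel scale and bias once at load time. The sketch below is only an
// illustration of that arithmetic (std::sqrt from <cmath>); the names are
// invented and the real kernels may organize the computation differently.
static inline void FoldBatchNorm(const float *bn_scale, const float *bn_bias,
                                 const float *bn_mean, const float *bn_var,
                                 float epsilon, int channels,
                                 float *folded_scale, float *folded_bias) {
  for (int c = 0; c < channels; ++c) {
    float inv_std = 1.f / std::sqrt(bn_var[c] + epsilon);
    folded_scale[c] = bn_scale[c] * inv_std;                     // gamma / std
    folded_bias[c] = bn_bias[c] - bn_mean[c] * folded_scale[c];  // beta - mu * gamma / std
  }
}
// Combined with the dequantization above (x = q / static_scale / online_scale),
// each output element becomes x * folded_scale[c] + folded_bias[c], optionally
// followed by a ReLU for the *_relu variants.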
#if defined(FUSION_DEQUANT_ADD_BN_RELU_OP) || defined(FUSION_DEQUANT_ADD_BN_OP)
template <typename Dtype>
class FusionDequantAddBNParam : public FusionDequantBNParam<Dtype> {
typedef typename DtypeTensorTrait<Dtype>::gtype GType;
typedef typename DtypeTensorTrait<Dtype>::rtype RType;
public:
FusionDequantAddBNParam(const VariableNameMap &inputs,
const VariableNameMap &outputs,
const AttributeMap &attrs, const Scope &scope)
: FusionDequantBNParam<Dtype>(inputs, outputs, attrs, scope) {
// element wise add params
axis_ = OpParam::GetAttr<int>("axis", attrs);
bias_ = OpParam::InputYFrom<GType>(inputs, scope);
// output
if (outputs.count("Y")) {
this->output_ = OpParam::OutputYFrom<GType>(outputs, scope);
}
}
public:
// elementwise add
int axis_;
RType *bias_;
};
#endif
#ifdef FUSION_DEQUANT_BN_RELU_OP
template <typename Dtype>
class FusionDequantBNReluParam : public FusionDequantBNParam<Dtype> {
typedef typename DtypeTensorTrait<Dtype>::gtype GType;
typedef typename DtypeTensorTrait<Dtype>::rtype RType;
public:
FusionDequantBNReluParam(const VariableNameMap &inputs,
const VariableNameMap &outputs,
const AttributeMap &attrs, const Scope &scope)
: FusionDequantBNParam<Dtype>(inputs, outputs, attrs, scope) {
// output
if (outputs.count("Out")) {
this->output_ = OpParam::OutFrom<GType>(outputs, scope);
}
}
};
#endif
#ifdef FUSION_DEQUANT_ADD_BN_RELU_OP
template <typename Dtype>
class FusionDequantAddBNReluParam : public FusionDequantAddBNParam<Dtype> {
typedef typename DtypeTensorTrait<Dtype>::gtype GType;
typedef typename DtypeTensorTrait<Dtype>::rtype RType;
public:
FusionDequantAddBNReluParam(const VariableNameMap &inputs,
const VariableNameMap &outputs,
const AttributeMap &attrs, const Scope &scope)
: FusionDequantAddBNParam<Dtype>(inputs, outputs, attrs, scope) {
// output
if (outputs.count("Out")) {
this->output_ = OpParam::OutFrom<GType>(outputs, scope);
}
}
};
#endif
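// Orientation only: the dequant fusion params introduced above form this
// hierarchy (output keys as read from the constructors):
//   DequantizeParam
//     FusionDequantBNParam             - BN tensors; writes "Y" when present
//       FusionDequantBNReluParam       - writes "Out"
//       FusionDequantAddBNParam        - adds elementwise bias; writes "Y"
//         FusionDequantAddBNReluParam  - writes "Out"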
......
......@@ -44,25 +44,19 @@ struct Round<round::RoundTowardsZero> {
template <>
struct Round<round::RoundToEven> {
int8_t operator()(float x) {
int8_t ret = 0;
float v = std::round(x);
int32_t q = (int32_t)v;
if (abs(abs(q - x) - 0.5) > 0) {
ret = q;
} else {
if (abs(q) % 2 == 0) {
ret = q;
} else {
ret = q + ((q > 0) ? -1 : 1);
int32_t q = static_cast<int32_t>(v);
if (abs(abs(q - v) - 0.5) <= 0) {
if (abs(q) % 2 != 0) {
q = q + ((q > 0) ? -1 : 1);
}
}
return ret;
return static_cast<int8_t>(q);
}
};
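// For cross-checking the hand-rolled banker's rounding above: std::nearbyint
// (from <cmath>, like std::round used above) rounds ties to even under the
// default FE_TONEAREST mode, which is the rounding mode a program starts in.
// Sketch only; no saturation to the int8 range is applied here.
static inline int8_t RoundToEvenRef(float x) {
  return static_cast<int8_t>(std::nearbyint(x));  // e.g. 2.5f -> 2, 3.5f -> 4, -2.5f -> -2
}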
template <round::RoundType T>
static void quantize(const Tensor *input, const float scale, const int pad,
const int8_t pad_val, Tensor *output) {
static void quantize(const Tensor *input, const float scale, Tensor *output) {
int batch_size = input->dims()[0];
int channels = input->dims()[1];
int input_h = input->dims()[2];
......@@ -77,29 +71,9 @@ static void quantize(const Tensor *input, const float scale, const int pad,
for (int nc = 0; nc < batch_size * channels; ++nc) {
const float *xh = x + nc * input_spatial;
int8_t *yh = y + nc * output_spatial;
// pad top
for (int h = 0; h < pad; ++h, yh += output_w) {
for (int w = 0; w < output_w; ++w) {
yh[w] = pad_val;
}
}
for (int h = 0; h < input_h; ++h, yh += output_w, xh += input_w) {
// pad left
for (int w = 0; w < pad; ++w) {
yh[w] = pad_val;
}
for (int w = 0; w < input_w; ++w) {
yh[w + pad] = Round<T>()(xh[w] * scale);
}
// pad right
for (int w = 0; w < pad; ++w) {
yh[pad + input_w + w] = pad_val;
}
}
// pad bottom
for (int h = 0; h < pad; ++h, yh += output_w) {
for (int w = 0; w < output_w; ++w) {
yh[w] = pad_val;
yh[w] = Round<T>()(xh[w] * scale);
}
}
}
......@@ -120,19 +94,14 @@ static float find_abs_max(const Tensor *input) {
int TestQuqntizeOp(int argc, char *argv[]) {
if (argc < 5) {
std::cout
<< "Usage: ./test-quantize-op batch_size channel height width [pad]"
<< std::endl;
std::cout << "Usage: ./test-quantize-op batch_size channel height width"
<< std::endl;
return 1;
}
int pad = 0;
int batch_size = atoi(argv[1]);
int channel = atoi(argv[2]);
int height = atoi(argv[3]);
int width = atoi(argv[4]);
if (argc == 6) {
pad = atoi(argv[5]);
}
std::cout << "batch_size: " << batch_size << ", channel: " << channel
<< ", height: " << height << ", width: " << width << std::endl;
framework::DDim dim =
......@@ -153,7 +122,6 @@ int TestQuqntizeOp(int argc, char *argv[]) {
auto output_scale_var = scope.get()->Var("output_scale");
framework::AttributeMap attrs;
attrs["paddings"].Set<vector<int>>(std::vector<int>({pad, pad}));
auto *op = new operators::QuantizeOp<CPU, float>("quantize", inputs, outputs,
attrs, scope);
op->InferShape();
......@@ -172,9 +140,9 @@ int TestQuqntizeOp(int argc, char *argv[]) {
framework::Tensor output_cmp;
output_cmp.Resize(output->dims());
float scale = 127 / output_scale_cmp;
// quantize<round::RoundToEven>(input, scale, pad, 0, &output_cmp);
// quantize<round::RoundAwayZero>(input, scale, pad, 0, &output_cmp);
quantize<round::RoundTowardsZero>(input, scale, pad, 0, &output_cmp);
// quantize<round::RoundToEven>(input, scale, &output_cmp);
// quantize<round::RoundAwayZero>(input, scale, &output_cmp);
quantize<round::RoundTowardsZero>(input, scale, &output_cmp);
int8_t *output_cmp_data = output_cmp.data<int8_t>();
for (int i = 0; i < output->numel(); ++i) {
PADDLE_MOBILE_ENFORCE(output_data[i] == output_cmp_data[i],
......
......@@ -249,7 +249,9 @@ if(NOT FOUND_MATCH)
set(SUM_OP ON)
set(QUANT_OP ON)
set(DEQUANT_OP ON)
set(FUSION_DEQUANT_ADD_BN_RELU ON)
set(FUSION_DEQUANT_ADD_BN_OP ON)
set(FUSION_DEQUANT_BN_RELU_OP ON)
set(FUSION_DEQUANT_ADD_BN_RELU_OP ON)
endif()
# option(BATCHNORM_OP "" ON)
......@@ -451,10 +453,17 @@ endif()
if (DEQUANT_OP)
add_definitions(-DDEQUANT_OP)
endif()
if (FUSION_DEQUANT_ADD_BN_RELU)
if (FUSION_DEQUANT_ADD_BN_OP)
add_definitions(-DFUSION_DEQUANT_ADD_BN_OP)
endif()
if (FUSION_DEQUANT_BN_RELU_OP)
add_definitions(-DFUSION_DEQUANT_BN_RELU_OP)
endif()
if (FUSION_DEQUANT_ADD_BN_RELU_OP)
add_definitions(-DFUSION_DEQUANT_ADD_BN_RELU_OP)
endif()
if (TANH_OP)
add_definitions(-DTANH_OP)
endif()
......@@ -467,3 +476,4 @@ endif()
if (FUSION_DECONVADDRELU_OP)
add_definitions(-DFUSION_DECONVADDRELU_OP)
endif()