diff --git a/src/common/types.cpp b/src/common/types.cpp index ba00f639d76ae7c928f5b7484c08cec0b0926904..36c93046c1c09d0ec5043ef9a7514dedf212e738 100644 --- a/src/common/types.cpp +++ b/src/common/types.cpp @@ -71,6 +71,8 @@ const char *G_OP_TYPE_SUM = "sum"; const char *G_OP_TYPE_QUANTIZE = "quantize"; const char *G_OP_TYPE_DEQUANTIZE = "dequantize"; +const char *G_OP_TYPE_FUSION_DEQUANT_ADD_BN_RELU = "fusion_dequant_add_bn_relu"; + const char *G_OP_TYPE_TANH = "tanh"; const char *G_OP_TYPE_FUSION_DECONV_RELU = "fusion_deconv_relu"; const char *G_OP_TYPE_FUSION_DECONV_ADD = "fusion_deconv_add"; @@ -134,6 +136,7 @@ std::unordered_map< {G_OP_TYPE_ELEMENTWISE_MUL, {{"X", "Y"}, {"Out"}}}, {G_OP_TYPE_QUANTIZE, {{"X"}, {"Out", "OutScale"}}}, {G_OP_TYPE_DEQUANTIZE, {{"X", "Scale"}, {"Out"}}}, + {G_OP_TYPE_FUSION_DEQUANT_ADD_BN_RELU, {{"X", "Scale"}, {"Out"}}}, {G_OP_TYPE_TANH, {{"X"}, {"Out"}}}, {G_OP_TYPE_FUSION_DECONV_RELU, {{"Input"}, {"Out"}}}, {G_OP_TYPE_FUSION_DECONV_ADD, {{"Input"}, {"Out"}}}, diff --git a/src/common/types.h b/src/common/types.h index e9c0f81232dab7583c57fb036b58601aa26ec3c9..5704618c9e475781b22df4ae3a0ac3a994eb8c90 100644 --- a/src/common/types.h +++ b/src/common/types.h @@ -138,6 +138,7 @@ extern const char *G_OP_TYPE_ELEMENTWISE_MUL; extern const char *G_OP_TYPE_QUANTIZE; extern const char *G_OP_TYPE_DEQUANTIZE; +extern const char *G_OP_TYPE_FUSION_DEQUANT_ADD_BN_RELU; extern const char *G_OP_TYPE_TANH; extern const char *G_OP_TYPE_FUSION_DECONV_RELU; diff --git a/src/framework/load_ops.h b/src/framework/load_ops.h index 982f1c0f3525afde8475866c0121343fafc9d5a0..135ef9083e42271fe63cdc29ee53e876f532c287 100644 --- a/src/framework/load_ops.h +++ b/src/framework/load_ops.h @@ -233,3 +233,7 @@ LOAD_OP1(quantize, CPU); #ifdef DEQUANT_OP LOAD_OP1(dequantize, CPU); #endif +#ifdef FUSION_DEQUANT_ADD_BN_RELU_OP +LOAD_OP1(fusion_dequant_add_bn_relu, CPU); +LOAD_FUSION_MATCHER(fusion_dequant_add_bn_relu); +#endif diff --git a/src/operators/fusion_dequant_add_bn_relu_op.cpp b/src/operators/fusion_dequant_add_bn_relu_op.cpp new file mode 100644 index 0000000000000000000000000000000000000000..80d9040afb29b7a42c742b821e9d7522c1a12827 --- /dev/null +++ b/src/operators/fusion_dequant_add_bn_relu_op.cpp @@ -0,0 +1,40 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#ifdef FUSION_DEQUANT_ADD_BN_RELU_OP + +#include "operators/fusion_dequant_add_bn_relu_op.h" + +namespace paddle_mobile { +namespace operators { + +template +void FusionDequantAddBNReluOp::InferShape() const { + const auto& input_dims = this->param_.input_->dims(); + this->param_.output_->Resize(input_dims); +} + +} // namespace operators +} // namespace paddle_mobile + +namespace ops = paddle_mobile::operators; +REGISTER_FUSION_MATCHER(fusion_dequant_add_bn_relu, + ops::FusionDequantAddBNReluMatcher); + +#ifdef PADDLE_MOBILE_CPU +REGISTER_OPERATOR_CPU(fusion_dequant_add_bn_relu, + ops::FusionDequantAddBNReluOp); +#endif + +#endif diff --git a/src/operators/fusion_dequant_add_bn_relu_op.h b/src/operators/fusion_dequant_add_bn_relu_op.h new file mode 100644 index 0000000000000000000000000000000000000000..dbd9ad0de2ece751ffd4da05cb09f0091a5755aa --- /dev/null +++ b/src/operators/fusion_dequant_add_bn_relu_op.h @@ -0,0 +1,76 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#ifdef FUSION_DEQUANT_ADD_BN_RELU_OP + +#pragma once + +#include +#include +#include "framework/operator.h" +#include "framework/program/program-optimize/fusion_op_register.h" +#include "operators/kernel/dequant_add_bn_relu_kernel.h" +#include "operators/op_param.h" + +namespace paddle_mobile { +namespace operators { + +class FusionDequantAddBNReluMatcher : public framework::FusionOpMatcher { + public: + FusionDequantAddBNReluMatcher() { + node_ = framework::Node(G_OP_TYPE_DEQUANTIZE); + node_ > std::make_shared(G_OP_TYPE_ELEMENTWISE_ADD) > + std::make_shared(G_OP_TYPE_BATCHNORM) > + std::make_shared(G_OP_TYPE_RELU); + } + + void FolderNodes( + framework::Node *node, + std::vector> *removed_nodes) { + node->Folder(node_.Depth(), Type(), + {{G_OP_TYPE_ELEMENTWISE_ADD, {{"Y", "Y"}}}, + {G_OP_TYPE_BATCHNORM, + {{"Scale", "BNScale"}, + {"Mean", "BNMean"}, + {"Bias", "BNBias"}, + {"Variance", "BNVariance"}}}}, + removed_nodes); + } + + std::string Type() { return G_OP_TYPE_FUSION_DEQUANT_ADD_BN_RELU; } +}; + +template +class FusionDequantAddBNReluOp + : public framework::OperatorWithKernel< + DeviceType, FusionDequantAddBNReluParam, + operators::FusionDequantAddBNReluKernel> { + public: + FusionDequantAddBNReluOp(const std::string &type, + const VariableNameMap &inputs, + const VariableNameMap &outputs, + const framework::AttributeMap &attrs, + std::shared_ptr scope) + : framework::OperatorWithKernel< + DeviceType, FusionDequantAddBNReluParam, + operators::FusionDequantAddBNReluKernel>( + type, inputs, outputs, attrs, scope) {} + // inference output shape + void InferShape() const override; +}; + +} // namespace operators +} // namespace paddle_mobile + +#endif diff --git a/src/operators/kernel/arm/dequant_add_bn_relu_kernel.cpp b/src/operators/kernel/arm/dequant_add_bn_relu_kernel.cpp new file mode 100644 index 0000000000000000000000000000000000000000..bfe1935c216f94d660997b1bfa42f18e63295992 --- /dev/null +++ b/src/operators/kernel/arm/dequant_add_bn_relu_kernel.cpp @@ -0,0 +1,116 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#ifdef FUSION_DEQUANT_ADD_BN_RELU_OP + +#include "operators/kernel/dequant_add_bn_relu_kernel.h" +#include +#if defined(__ARM_NEON__) || defined(__ARM_NEON) +#include +#endif + +namespace paddle_mobile { +namespace operators { + +template <> +bool FusionDequantAddBNReluKernel::Init( + FusionDequantAddBNReluParam *param) { + // elementwise add params + const Tensor *bias = param->bias_; + // batch norm params + const Tensor *bn_mean = param->bn_mean_; + const Tensor *bn_variance = param->bn_variance_; + Tensor *bn_scale = param->bn_scale_; + Tensor *bn_bias = param->bn_bias_; + const float epsilon = param->epsilon_; + + const float *bias_ptr = bias->data(); + const float *mean_ptr = bn_mean->data(); + const float *var_ptr = bn_variance->data(); + float *bn_scale_ptr = bn_scale->mutable_data(); + float *bn_bias_ptr = bn_bias->mutable_data(); + for (int c = 0; c < bn_scale->numel(); ++c) { + float inv_scale = bn_scale_ptr[c] / (std::sqrt(var_ptr[c] + epsilon)); + bn_scale_ptr[c] = inv_scale; + bn_bias_ptr[c] = inv_scale * (bias_ptr[c] - mean_ptr[c]) + bn_bias_ptr[c]; + } + return true; +} + +template <> +void FusionDequantAddBNReluKernel::Compute( + const FusionDequantAddBNReluParam ¶m) { + const int32_t *input = param.input_->data(); + const float *bn_scale = param.bn_scale_->data(); + const float *bn_bias = param.bn_bias_->data(); + // dequantize params + const float activation_scale = param.activation_scale_->data()[0]; + const float weight_scale = param.weight_scale_; + const float dequant_scale = activation_scale / weight_scale; + + float *output = param.output_->mutable_data(); + int batch_size = param.input_->dims()[0]; + int channels = param.input_->dims()[1]; + size_t spatial_size = param.input_->dims()[2] * param.input_->dims()[3]; + + #pragma omp parallel for collapse(2) + for (int batch = 0; batch < batch_size; ++batch) { + for (int c = 0; c < channels; ++c) { + float scale = bn_scale[c] * dequant_scale; + float bias = bn_bias[c]; + size_t offset = (batch * channels + c) * spatial_size; + const int32_t *x = input + offset; + float *y = output + offset; + size_t remain = spatial_size; +#if defined(__ARM_NEON__) || defined(__ARM_NEON) + int loop = spatial_size >> 4; + remain = spatial_size & 0xF; + float32x4_t __scale = vdupq_n_f32(scale); + float32x4_t __bias = vdupq_n_f32(bias); + float32x4_t __zero = vdupq_n_f32(0.f); + + for (int k = 0; k < loop; ++k, x += 16, y += 16) { + int32x4_t r0 = vld1q_s32(x); + int32x4_t r1 = vld1q_s32(x + 4); + int32x4_t r2 = vld1q_s32(x + 8); + int32x4_t r3 = vld1q_s32(x + 12); + float32x4_t f0 = vcvtq_f32_s32(r0); + float32x4_t f1 = vcvtq_f32_s32(r1); + float32x4_t f2 = vcvtq_f32_s32(r2); + float32x4_t f3 = vcvtq_f32_s32(r3); + f0 = vmlaq_f32(__bias, __scale, f0); + f1 = vmlaq_f32(__bias, __scale, f1); + f2 = vmlaq_f32(__bias, __scale, f2); + f3 = vmlaq_f32(__bias, __scale, f3); + f0 = vmaxq_f32(__zero, f0); + f1 = vmaxq_f32(__zero, f1); + f2 = vmaxq_f32(__zero, f2); + f3 = vmaxq_f32(__zero, f3); + vst1q_f32(y, f0); + vst1q_f32(y + 4, f1); + vst1q_f32(y + 8, f2); + vst1q_f32(y + 12, f3); + } +#endif // __ARM_NEON__ + for (int k = 0; k < remain; ++k) { + y[k] = std::max(scale * x[k] + bias, 0.f); + } + } + } +} + +} // namespace operators +} // namespace paddle_mobile + +#endif // FUSION_DEQUANT_ADD_BN_RELU_OP diff --git a/src/operators/kernel/dequant_add_bn_relu_kernel.h b/src/operators/kernel/dequant_add_bn_relu_kernel.h new file mode 100644 index 0000000000000000000000000000000000000000..7138e5c415caca6766913f9959bd41def0943d34 --- /dev/null +++ b/src/operators/kernel/dequant_add_bn_relu_kernel.h @@ -0,0 +1,37 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +#ifdef FUSION_DEQUANT_ADD_BN_RELU_OP + +#include "framework/operator.h" +#include "operators/op_param.h" + +namespace paddle_mobile { +namespace operators { + +template +class FusionDequantAddBNReluKernel + : public framework::OpKernelBase> { + public: + void Compute(const FusionDequantAddBNReluParam ¶m); + bool Init(FusionDequantAddBNReluParam *param); +}; + +} // namespace operators +} // namespace paddle_mobile + +#endif diff --git a/src/operators/op_param.h b/src/operators/op_param.h index e864c9d631e384a52171223814166ea6709df8ad..3593ecc9831f6bf627273b0abb5e75cf8a168dbf 100644 --- a/src/operators/op_param.h +++ b/src/operators/op_param.h @@ -2526,7 +2526,7 @@ class QuantizeParam : public OpParam { output_ = OutFrom(outputs, scope); // online // scale = max(abs(x)) - online_scale_ = GetVarValue("OutScale", outputs, scope); + online_scale_ = OpParam::GetVarValue("OutScale", outputs, scope); // offline if (HasAttr("static_scale", attrs)) { is_static_ = true; @@ -2574,7 +2574,7 @@ class DequantizeParam : public OpParam { const AttributeMap &attrs, const Scope &scope) { input_ = InputXFrom(inputs, scope); output_ = OutFrom(outputs, scope); - activation_scale_ = GetVarValue("Scale", inputs, scope); + activation_scale_ = OpParam::GetVarValue("Scale", inputs, scope); // dequantization is performed as x = x / static_scale / online_scale if (HasAttr("weight_scale", attrs)) { weight_scale_ = GetAttr("weight_scale", attrs); @@ -2593,5 +2593,44 @@ class DequantizeParam : public OpParam { }; #endif +#ifdef FUSION_DEQUANT_ADD_BN_RELU_OP +template +class FusionDequantAddBNReluParam : public DequantizeParam { + typedef typename DtypeTensorTrait::gtype GType; + typedef typename DtypeTensorTrait::rtype RType; + + public: + FusionDequantAddBNReluParam(const VariableNameMap &inputs, + const VariableNameMap &outputs, + const AttributeMap &attrs, const Scope &scope) + : DequantizeParam(inputs, outputs, attrs, scope) { + // element wise add params + axis_ = OpParam::GetAttr("axis", attrs); + bias_ = OpParam::InputYFrom(inputs, scope); + // batch norm params + bn_mean_ = OpParam::GetVarValue("BNMean", inputs, scope); + bn_variance_ = OpParam::GetVarValue("BNVariance", inputs, scope); + bn_scale_ = OpParam::GetVarValue("BNScale", inputs, scope); + bn_bias_ = OpParam::GetVarValue("BNBias", inputs, scope); + epsilon_ = OpParam::GetAttr("epsilon", attrs); + // output + output_ = OpParam::OutFrom(outputs, scope); + } + + public: + // elementwise add + int axis_; + RType *bias_; + // batch norm + RType *bn_mean_; + RType *bn_variance_; + RType *bn_scale_; + RType *bn_bias_; + float epsilon_; + // output + RType *output_; +}; +#endif + } // namespace operators } // namespace paddle_mobile diff --git a/tools/op.cmake b/tools/op.cmake index 3a4a0597a44694c4edea8173af47627cb5680df2..98a5ce437ae6520a4cc27f9fceeadaeb30ba6e99 100644 --- a/tools/op.cmake +++ b/tools/op.cmake @@ -249,6 +249,7 @@ if(NOT FOUND_MATCH) set(SUM_OP ON) set(QUANT_OP ON) set(DEQUANT_OP ON) + set(FUSION_DEQUANT_ADD_BN_RELU ON) endif() # option(BATCHNORM_OP "" ON) @@ -450,6 +451,9 @@ endif() if (DEQUANT_OP) add_definitions(-DDEQUANT_OP) endif() +if (FUSION_DEQUANT_ADD_BN_RELU) + add_definitions(-DFUSION_DEQUANT_ADD_BN_RELU_OP) +endif() if (TANH_OP) add_definitions(-DTANH_OP) @@ -462,4 +466,4 @@ if (FUSION_DECONVADD_OP) endif() if (FUSION_DECONVADDRELU_OP) add_definitions(-DFUSION_DECONVADDRELU_OP) -endif() \ No newline at end of file +endif()