提交 e54bf8c5 编写于 作者: H hjchen2

Add dequant_add_bn_relu fusion kernel

上级 ce9cc318
......@@ -71,6 +71,8 @@ const char *G_OP_TYPE_SUM = "sum";
const char *G_OP_TYPE_QUANTIZE = "quantize";
const char *G_OP_TYPE_DEQUANTIZE = "dequantize";
const char *G_OP_TYPE_FUSION_DEQUANT_ADD_BN_RELU = "fusion_dequant_add_bn_relu";
const char *G_OP_TYPE_TANH = "tanh";
const char *G_OP_TYPE_FUSION_DECONV_RELU = "fusion_deconv_relu";
const char *G_OP_TYPE_FUSION_DECONV_ADD = "fusion_deconv_add";
......@@ -134,6 +136,7 @@ std::unordered_map<
{G_OP_TYPE_ELEMENTWISE_MUL, {{"X", "Y"}, {"Out"}}},
{G_OP_TYPE_QUANTIZE, {{"X"}, {"Out", "OutScale"}}},
{G_OP_TYPE_DEQUANTIZE, {{"X", "Scale"}, {"Out"}}},
{G_OP_TYPE_FUSION_DEQUANT_ADD_BN_RELU, {{"X", "Scale"}, {"Out"}}},
{G_OP_TYPE_TANH, {{"X"}, {"Out"}}},
{G_OP_TYPE_FUSION_DECONV_RELU, {{"Input"}, {"Out"}}},
{G_OP_TYPE_FUSION_DECONV_ADD, {{"Input"}, {"Out"}}},
......
......@@ -138,6 +138,7 @@ extern const char *G_OP_TYPE_ELEMENTWISE_MUL;
extern const char *G_OP_TYPE_QUANTIZE;
extern const char *G_OP_TYPE_DEQUANTIZE;
extern const char *G_OP_TYPE_FUSION_DEQUANT_ADD_BN_RELU;
extern const char *G_OP_TYPE_TANH;
extern const char *G_OP_TYPE_FUSION_DECONV_RELU;
......
......@@ -233,3 +233,7 @@ LOAD_OP1(quantize, CPU);
#ifdef DEQUANT_OP
LOAD_OP1(dequantize, CPU);
#endif
#ifdef FUSION_DEQUANT_ADD_BN_RELU_OP
LOAD_OP1(fusion_dequant_add_bn_relu, CPU);
LOAD_FUSION_MATCHER(fusion_dequant_add_bn_relu);
#endif
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#ifdef FUSION_DEQUANT_ADD_BN_RELU_OP
#include "operators/fusion_dequant_add_bn_relu_op.h"
namespace paddle_mobile {
namespace operators {
template <typename Dtype, typename T>
void FusionDequantAddBNReluOp<Dtype, T>::InferShape() const {
const auto& input_dims = this->param_.input_->dims();
this->param_.output_->Resize(input_dims);
}
} // namespace operators
} // namespace paddle_mobile
namespace ops = paddle_mobile::operators;
REGISTER_FUSION_MATCHER(fusion_dequant_add_bn_relu,
ops::FusionDequantAddBNReluMatcher);
#ifdef PADDLE_MOBILE_CPU
REGISTER_OPERATOR_CPU(fusion_dequant_add_bn_relu,
ops::FusionDequantAddBNReluOp);
#endif
#endif
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#ifdef FUSION_DEQUANT_ADD_BN_RELU_OP
#pragma once
#include <string>
#include <vector>
#include "framework/operator.h"
#include "framework/program/program-optimize/fusion_op_register.h"
#include "operators/kernel/dequant_add_bn_relu_kernel.h"
#include "operators/op_param.h"
namespace paddle_mobile {
namespace operators {
class FusionDequantAddBNReluMatcher : public framework::FusionOpMatcher {
public:
FusionDequantAddBNReluMatcher() {
node_ = framework::Node(G_OP_TYPE_DEQUANTIZE);
node_ > std::make_shared<framework::Node>(G_OP_TYPE_ELEMENTWISE_ADD) >
std::make_shared<framework::Node>(G_OP_TYPE_BATCHNORM) >
std::make_shared<framework::Node>(G_OP_TYPE_RELU);
}
void FolderNodes(
framework::Node *node,
std::vector<std::shared_ptr<framework::Node>> *removed_nodes) {
node->Folder(node_.Depth(), Type(),
{{G_OP_TYPE_ELEMENTWISE_ADD, {{"Y", "Y"}}},
{G_OP_TYPE_BATCHNORM,
{{"Scale", "BNScale"},
{"Mean", "BNMean"},
{"Bias", "BNBias"},
{"Variance", "BNVariance"}}}},
removed_nodes);
}
std::string Type() { return G_OP_TYPE_FUSION_DEQUANT_ADD_BN_RELU; }
};
template <typename DeviceType, typename T>
class FusionDequantAddBNReluOp
: public framework::OperatorWithKernel<
DeviceType, FusionDequantAddBNReluParam<DeviceType>,
operators::FusionDequantAddBNReluKernel<DeviceType, T>> {
public:
FusionDequantAddBNReluOp(const std::string &type,
const VariableNameMap &inputs,
const VariableNameMap &outputs,
const framework::AttributeMap &attrs,
std::shared_ptr<framework::Scope> scope)
: framework::OperatorWithKernel<
DeviceType, FusionDequantAddBNReluParam<DeviceType>,
operators::FusionDequantAddBNReluKernel<DeviceType, T>>(
type, inputs, outputs, attrs, scope) {}
// inference output shape
void InferShape() const override;
};
} // namespace operators
} // namespace paddle_mobile
#endif
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#ifdef FUSION_DEQUANT_ADD_BN_RELU_OP
#include "operators/kernel/dequant_add_bn_relu_kernel.h"
#include <cmath>
#if defined(__ARM_NEON__) || defined(__ARM_NEON)
#include <arm_neon.h>
#endif
namespace paddle_mobile {
namespace operators {
template <>
bool FusionDequantAddBNReluKernel<CPU, float>::Init(
FusionDequantAddBNReluParam<CPU> *param) {
// elementwise add params
const Tensor *bias = param->bias_;
// batch norm params
const Tensor *bn_mean = param->bn_mean_;
const Tensor *bn_variance = param->bn_variance_;
Tensor *bn_scale = param->bn_scale_;
Tensor *bn_bias = param->bn_bias_;
const float epsilon = param->epsilon_;
const float *bias_ptr = bias->data<float>();
const float *mean_ptr = bn_mean->data<float>();
const float *var_ptr = bn_variance->data<float>();
float *bn_scale_ptr = bn_scale->mutable_data<float>();
float *bn_bias_ptr = bn_bias->mutable_data<float>();
for (int c = 0; c < bn_scale->numel(); ++c) {
float inv_scale = bn_scale_ptr[c] / (std::sqrt(var_ptr[c] + epsilon));
bn_scale_ptr[c] = inv_scale;
bn_bias_ptr[c] = inv_scale * (bias_ptr[c] - mean_ptr[c]) + bn_bias_ptr[c];
}
return true;
}
template <>
void FusionDequantAddBNReluKernel<CPU, float>::Compute(
const FusionDequantAddBNReluParam<CPU> &param) {
const int32_t *input = param.input_->data<int32_t>();
const float *bn_scale = param.bn_scale_->data<float>();
const float *bn_bias = param.bn_bias_->data<float>();
// dequantize params
const float activation_scale = param.activation_scale_->data<float>()[0];
const float weight_scale = param.weight_scale_;
const float dequant_scale = activation_scale / weight_scale;
float *output = param.output_->mutable_data<float>();
int batch_size = param.input_->dims()[0];
int channels = param.input_->dims()[1];
size_t spatial_size = param.input_->dims()[2] * param.input_->dims()[3];
#pragma omp parallel for collapse(2)
for (int batch = 0; batch < batch_size; ++batch) {
for (int c = 0; c < channels; ++c) {
float scale = bn_scale[c] * dequant_scale;
float bias = bn_bias[c];
size_t offset = (batch * channels + c) * spatial_size;
const int32_t *x = input + offset;
float *y = output + offset;
size_t remain = spatial_size;
#if defined(__ARM_NEON__) || defined(__ARM_NEON)
int loop = spatial_size >> 4;
remain = spatial_size & 0xF;
float32x4_t __scale = vdupq_n_f32(scale);
float32x4_t __bias = vdupq_n_f32(bias);
float32x4_t __zero = vdupq_n_f32(0.f);
for (int k = 0; k < loop; ++k, x += 16, y += 16) {
int32x4_t r0 = vld1q_s32(x);
int32x4_t r1 = vld1q_s32(x + 4);
int32x4_t r2 = vld1q_s32(x + 8);
int32x4_t r3 = vld1q_s32(x + 12);
float32x4_t f0 = vcvtq_f32_s32(r0);
float32x4_t f1 = vcvtq_f32_s32(r1);
float32x4_t f2 = vcvtq_f32_s32(r2);
float32x4_t f3 = vcvtq_f32_s32(r3);
f0 = vmlaq_f32(__bias, __scale, f0);
f1 = vmlaq_f32(__bias, __scale, f1);
f2 = vmlaq_f32(__bias, __scale, f2);
f3 = vmlaq_f32(__bias, __scale, f3);
f0 = vmaxq_f32(__zero, f0);
f1 = vmaxq_f32(__zero, f1);
f2 = vmaxq_f32(__zero, f2);
f3 = vmaxq_f32(__zero, f3);
vst1q_f32(y, f0);
vst1q_f32(y + 4, f1);
vst1q_f32(y + 8, f2);
vst1q_f32(y + 12, f3);
}
#endif // __ARM_NEON__
for (int k = 0; k < remain; ++k) {
y[k] = std::max(scale * x[k] + bias, 0.f);
}
}
}
}
} // namespace operators
} // namespace paddle_mobile
#endif // FUSION_DEQUANT_ADD_BN_RELU_OP
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#ifdef FUSION_DEQUANT_ADD_BN_RELU_OP
#include "framework/operator.h"
#include "operators/op_param.h"
namespace paddle_mobile {
namespace operators {
template <typename DeviceType, typename T>
class FusionDequantAddBNReluKernel
: public framework::OpKernelBase<DeviceType,
FusionDequantAddBNReluParam<DeviceType>> {
public:
void Compute(const FusionDequantAddBNReluParam<DeviceType> &param);
bool Init(FusionDequantAddBNReluParam<DeviceType> *param);
};
} // namespace operators
} // namespace paddle_mobile
#endif
......@@ -2526,7 +2526,7 @@ class QuantizeParam : public OpParam {
output_ = OutFrom<GType>(outputs, scope);
// online
// scale = max(abs(x))
online_scale_ = GetVarValue<GType>("OutScale", outputs, scope);
online_scale_ = OpParam::GetVarValue<GType>("OutScale", outputs, scope);
// offline
if (HasAttr("static_scale", attrs)) {
is_static_ = true;
......@@ -2574,7 +2574,7 @@ class DequantizeParam : public OpParam {
const AttributeMap &attrs, const Scope &scope) {
input_ = InputXFrom<GType>(inputs, scope);
output_ = OutFrom<GType>(outputs, scope);
activation_scale_ = GetVarValue<GType>("Scale", inputs, scope);
activation_scale_ = OpParam::GetVarValue<GType>("Scale", inputs, scope);
// dequantization is performed as x = x / static_scale / online_scale
if (HasAttr("weight_scale", attrs)) {
weight_scale_ = GetAttr<float>("weight_scale", attrs);
......@@ -2593,5 +2593,44 @@ class DequantizeParam : public OpParam {
};
#endif
#ifdef FUSION_DEQUANT_ADD_BN_RELU_OP
template <typename Dtype>
class FusionDequantAddBNReluParam : public DequantizeParam<Dtype> {
typedef typename DtypeTensorTrait<Dtype>::gtype GType;
typedef typename DtypeTensorTrait<Dtype>::rtype RType;
public:
FusionDequantAddBNReluParam(const VariableNameMap &inputs,
const VariableNameMap &outputs,
const AttributeMap &attrs, const Scope &scope)
: DequantizeParam<Dtype>(inputs, outputs, attrs, scope) {
// element wise add params
axis_ = OpParam::GetAttr<int>("axis", attrs);
bias_ = OpParam::InputYFrom<GType>(inputs, scope);
// batch norm params
bn_mean_ = OpParam::GetVarValue<GType>("BNMean", inputs, scope);
bn_variance_ = OpParam::GetVarValue<GType>("BNVariance", inputs, scope);
bn_scale_ = OpParam::GetVarValue<GType>("BNScale", inputs, scope);
bn_bias_ = OpParam::GetVarValue<GType>("BNBias", inputs, scope);
epsilon_ = OpParam::GetAttr<float>("epsilon", attrs);
// output
output_ = OpParam::OutFrom<GType>(outputs, scope);
}
public:
// elementwise add
int axis_;
RType *bias_;
// batch norm
RType *bn_mean_;
RType *bn_variance_;
RType *bn_scale_;
RType *bn_bias_;
float epsilon_;
// output
RType *output_;
};
#endif
} // namespace operators
} // namespace paddle_mobile
......@@ -249,6 +249,7 @@ if(NOT FOUND_MATCH)
set(SUM_OP ON)
set(QUANT_OP ON)
set(DEQUANT_OP ON)
set(FUSION_DEQUANT_ADD_BN_RELU ON)
endif()
# option(BATCHNORM_OP "" ON)
......@@ -450,6 +451,9 @@ endif()
if (DEQUANT_OP)
add_definitions(-DDEQUANT_OP)
endif()
if (FUSION_DEQUANT_ADD_BN_RELU)
add_definitions(-DFUSION_DEQUANT_ADD_BN_RELU_OP)
endif()
if (TANH_OP)
add_definitions(-DTANH_OP)
......@@ -462,4 +466,4 @@ if (FUSION_DECONVADD_OP)
endif()
if (FUSION_DECONVADDRELU_OP)
add_definitions(-DFUSION_DECONVADDRELU_OP)
endif()
\ No newline at end of file
endif()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册