提交 0c4be5a4 编写于 作者: X xiebaiyuan 提交者: GitHub

Merge pull request #1324 from hjchen2/dev-latest

Fix ios cross compile, revert quantize kernel, and add other dequantize fusion pattern
...@@ -71,6 +71,8 @@ const char *G_OP_TYPE_SUM = "sum"; ...@@ -71,6 +71,8 @@ const char *G_OP_TYPE_SUM = "sum";
const char *G_OP_TYPE_QUANTIZE = "quantize"; const char *G_OP_TYPE_QUANTIZE = "quantize";
const char *G_OP_TYPE_DEQUANTIZE = "dequantize"; const char *G_OP_TYPE_DEQUANTIZE = "dequantize";
const char *G_OP_TYPE_FUSION_DEQUANT_ADD_BN = "fusion_dequant_add_bn";
const char *G_OP_TYPE_FUSION_DEQUANT_BN_RELU = "fusion_dequant_bn_relu";
const char *G_OP_TYPE_FUSION_DEQUANT_ADD_BN_RELU = "fusion_dequant_add_bn_relu"; const char *G_OP_TYPE_FUSION_DEQUANT_ADD_BN_RELU = "fusion_dequant_add_bn_relu";
const char *G_OP_TYPE_TANH = "tanh"; const char *G_OP_TYPE_TANH = "tanh";
...@@ -136,6 +138,8 @@ std::unordered_map< ...@@ -136,6 +138,8 @@ std::unordered_map<
{G_OP_TYPE_ELEMENTWISE_MUL, {{"X", "Y"}, {"Out"}}}, {G_OP_TYPE_ELEMENTWISE_MUL, {{"X", "Y"}, {"Out"}}},
{G_OP_TYPE_QUANTIZE, {{"X"}, {"Out", "OutScale"}}}, {G_OP_TYPE_QUANTIZE, {{"X"}, {"Out", "OutScale"}}},
{G_OP_TYPE_DEQUANTIZE, {{"X", "Scale"}, {"Out"}}}, {G_OP_TYPE_DEQUANTIZE, {{"X", "Scale"}, {"Out"}}},
{G_OP_TYPE_FUSION_DEQUANT_ADD_BN, {{"X", "Scale"}, {"Y"}}},
{G_OP_TYPE_FUSION_DEQUANT_BN_RELU, {{"X", "Scale"}, {"Out"}}},
{G_OP_TYPE_FUSION_DEQUANT_ADD_BN_RELU, {{"X", "Scale"}, {"Out"}}}, {G_OP_TYPE_FUSION_DEQUANT_ADD_BN_RELU, {{"X", "Scale"}, {"Out"}}},
{G_OP_TYPE_TANH, {{"X"}, {"Out"}}}, {G_OP_TYPE_TANH, {{"X"}, {"Out"}}},
{G_OP_TYPE_FUSION_DECONV_RELU, {{"Input"}, {"Out"}}}, {G_OP_TYPE_FUSION_DECONV_RELU, {{"Input"}, {"Out"}}},
......
...@@ -138,6 +138,8 @@ extern const char *G_OP_TYPE_ELEMENTWISE_MUL; ...@@ -138,6 +138,8 @@ extern const char *G_OP_TYPE_ELEMENTWISE_MUL;
extern const char *G_OP_TYPE_QUANTIZE; extern const char *G_OP_TYPE_QUANTIZE;
extern const char *G_OP_TYPE_DEQUANTIZE; extern const char *G_OP_TYPE_DEQUANTIZE;
extern const char *G_OP_TYPE_FUSION_DEQUANT_ADD_BN;
extern const char *G_OP_TYPE_FUSION_DEQUANT_BN_RELU;
extern const char *G_OP_TYPE_FUSION_DEQUANT_ADD_BN_RELU; extern const char *G_OP_TYPE_FUSION_DEQUANT_ADD_BN_RELU;
extern const char *G_OP_TYPE_TANH; extern const char *G_OP_TYPE_TANH;
......
...@@ -233,6 +233,14 @@ LOAD_OP1(quantize, CPU); ...@@ -233,6 +233,14 @@ LOAD_OP1(quantize, CPU);
#ifdef DEQUANT_OP #ifdef DEQUANT_OP
LOAD_OP1(dequantize, CPU); LOAD_OP1(dequantize, CPU);
#endif #endif
#ifdef FUSION_DEQUANT_ADD_BN_OP
LOAD_OP1(fusion_dequant_add_bn, CPU);
LOAD_FUSION_MATCHER(fusion_dequant_add_bn);
#endif
#ifdef FUSION_DEQUANT_BN_RELU_OP
LOAD_OP1(fusion_dequant_bn_relu, CPU);
LOAD_FUSION_MATCHER(fusion_dequant_bn_relu);
#endif
#ifdef FUSION_DEQUANT_ADD_BN_RELU_OP #ifdef FUSION_DEQUANT_ADD_BN_RELU_OP
LOAD_OP1(fusion_dequant_add_bn_relu, CPU); LOAD_OP1(fusion_dequant_add_bn_relu, CPU);
LOAD_FUSION_MATCHER(fusion_dequant_add_bn_relu); LOAD_FUSION_MATCHER(fusion_dequant_add_bn_relu);
......
...@@ -95,7 +95,8 @@ static std::mutex shared_mutex; ...@@ -95,7 +95,8 @@ static std::mutex shared_mutex;
andModelParamsLen:(size_t)combinedParamsLen andModelParamsLen:(size_t)combinedParamsLen
andCombinedParamsBuf:(const uint8_t *)combinedParamsBuf { andCombinedParamsBuf:(const uint8_t *)combinedParamsBuf {
pam_->SetThreadNum(2); pam_->SetThreadNum(2);
return loaded_ = pam_->LoadCombinedMemory(modelLen, modelBuf, combinedParamsLen, combinedParamsBuf); return loaded_ = pam_->LoadCombinedMemory(modelLen, modelBuf, combinedParamsLen,
const_cast<uint8_t*>(combinedParamsBuf));
} }
- (BOOL)load:(NSString *)modelAndWeightPath{ - (BOOL)load:(NSString *)modelAndWeightPath{
......
...@@ -18,7 +18,7 @@ limitations under the License. */ ...@@ -18,7 +18,7 @@ limitations under the License. */
#include <string> #include <string>
#include "framework/operator.h" #include "framework/operator.h"
#include "operators/kernel/depthwise_conv_kernel.h" #include "operators/kernel/conv_kernel.h"
namespace paddle_mobile { namespace paddle_mobile {
namespace operators { namespace operators {
...@@ -26,19 +26,16 @@ namespace operators { ...@@ -26,19 +26,16 @@ namespace operators {
template <typename DeviceType, typename T> template <typename DeviceType, typename T>
class DepthwiseConvOp : public framework::OperatorWithKernel< class DepthwiseConvOp : public framework::OperatorWithKernel<
DeviceType, ConvParam<DeviceType>, DeviceType, ConvParam<DeviceType>,
operators::DepthwiseConvKernel<DeviceType, T>> { operators::ConvKernel<DeviceType, T>> {
public: public:
DepthwiseConvOp(const std::string &type, const VariableNameMap &inputs, DepthwiseConvOp(const std::string &type, const VariableNameMap &inputs,
const VariableNameMap &outputs, const VariableNameMap &outputs,
const framework::AttributeMap &attrs, const framework::AttributeMap &attrs,
std::shared_ptr<framework::Scope> scope) std::shared_ptr<framework::Scope> scope)
: framework::OperatorWithKernel< : framework::OperatorWithKernel<DeviceType, ConvParam<DeviceType>,
DeviceType, ConvParam<DeviceType>, operators::ConvKernel<DeviceType, T>>(
operators::DepthwiseConvKernel<DeviceType, T>>(
type, inputs, outputs, attrs, scope) {} type, inputs, outputs, attrs, scope) {}
void InferShape() const override; void InferShape() const override;
private:
}; };
} // namespace operators } // namespace operators
......
...@@ -12,27 +12,27 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. ...@@ -12,27 +12,27 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#ifdef DEPTHWISECONV_OP #ifdef FUSION_DEQUANT_ADD_BN_OP
#include "operators/kernel/depthwise_conv_kernel.h" #include "operators/fusion_dequant_add_bn_op.h"
#include "operators/kernel/central-arm-func/depthwise_conv_arm_func.h"
namespace paddle_mobile { namespace paddle_mobile {
namespace operators { namespace operators {
template <> template <typename Dtype, typename T>
bool DepthwiseConvKernel<CPU, float>::Init(ConvParam<CPU> *param) { void FusionDequantAddBNOp<Dtype, T>::InferShape() const {
return true; const auto& input_dims = this->param_.input_->dims();
this->param_.output_->Resize(input_dims);
} }
template <>
void DepthwiseConvKernel<CPU, float>::Compute(const ConvParam<CPU> &param) {
DepthwiseConvCompute<float>(param);
}
template class DepthwiseConvKernel<CPU, float>;
} // namespace operators } // namespace operators
} // namespace paddle_mobile } // namespace paddle_mobile
namespace ops = paddle_mobile::operators;
REGISTER_FUSION_MATCHER(fusion_dequant_add_bn, ops::FusionDequantAddBNMatcher);
#ifdef PADDLE_MOBILE_CPU
REGISTER_OPERATOR_CPU(fusion_dequant_add_bn, ops::FusionDequantAddBNOp);
#endif
#endif #endif
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#ifdef FUSION_DEQUANT_ADD_BN_OP
#pragma once
#include <string>
#include <vector>
#include "framework/operator.h"
#include "framework/program/program-optimize/fusion_op_register.h"
#include "operators/kernel/dequant_add_bn_kernel.h"
#include "operators/op_param.h"
namespace paddle_mobile {
namespace operators {
class FusionDequantAddBNMatcher : public framework::FusionOpMatcher {
public:
FusionDequantAddBNMatcher() {
node_ = framework::Node(G_OP_TYPE_DEQUANTIZE);
node_ > std::make_shared<framework::Node>(G_OP_TYPE_ELEMENTWISE_ADD) >
std::make_shared<framework::Node>(G_OP_TYPE_BATCHNORM);
}
void FolderNodes(
framework::Node *node,
std::vector<std::shared_ptr<framework::Node>> *removed_nodes) {
node->Folder(node_.Depth(), Type(),
{{G_OP_TYPE_ELEMENTWISE_ADD, {{"Y", "Y"}}},
{G_OP_TYPE_BATCHNORM,
{{"Scale", "BNScale"},
{"Mean", "BNMean"},
{"Bias", "BNBias"},
{"Variance", "BNVariance"}}}},
removed_nodes);
}
std::string Type() { return G_OP_TYPE_FUSION_DEQUANT_ADD_BN; }
};
template <typename DeviceType, typename T>
class FusionDequantAddBNOp
: public framework::OperatorWithKernel<
DeviceType, FusionDequantAddBNParam<DeviceType>,
operators::FusionDequantAddBNKernel<DeviceType, T>> {
public:
FusionDequantAddBNOp(const std::string &type, const VariableNameMap &inputs,
const VariableNameMap &outputs,
const framework::AttributeMap &attrs,
std::shared_ptr<framework::Scope> scope)
: framework::OperatorWithKernel<
DeviceType, FusionDequantAddBNParam<DeviceType>,
operators::FusionDequantAddBNKernel<DeviceType, T>>(
type, inputs, outputs, attrs, scope) {}
// inference output shape
void InferShape() const override;
};
} // namespace operators
} // namespace paddle_mobile
#endif
...@@ -20,7 +20,7 @@ limitations under the License. */ ...@@ -20,7 +20,7 @@ limitations under the License. */
#include <vector> #include <vector>
#include "framework/operator.h" #include "framework/operator.h"
#include "framework/program/program-optimize/fusion_op_register.h" #include "framework/program/program-optimize/fusion_op_register.h"
#include "operators/kernel/dequant_add_bn_relu_kernel.h" #include "operators/kernel/dequant_bn_relu_kernel.h"
#include "operators/op_param.h" #include "operators/op_param.h"
namespace paddle_mobile { namespace paddle_mobile {
......
...@@ -12,29 +12,28 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. ...@@ -12,29 +12,28 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#ifdef DEPTHWISECONV_OP #ifdef FUSION_DEQUANT_BN_RELU_OP
#pragma once #include "operators/fusion_dequant_bn_relu_op.h"
#include "framework/operator.h"
#include "operators/math/im2col.h"
#include "operators/math/math_function.h"
#include "operators/math/vol2col.h"
#include "operators/op_param.h"
namespace paddle_mobile { namespace paddle_mobile {
namespace operators { namespace operators {
using framework::OpKernelBase; template <typename Dtype, typename T>
void FusionDequantBNReluOp<Dtype, T>::InferShape() const {
const auto& input_dims = this->param_.input_->dims();
this->param_.output_->Resize(input_dims);
}
template <typename DeviceType, typename T>
class DepthwiseConvKernel
: public OpKernelBase<DeviceType, ConvParam<DeviceType>> {
public:
void Compute(const ConvParam<DeviceType> &param);
bool Init(ConvParam<DeviceType> *param);
};
} // namespace operators } // namespace operators
} // namespace paddle_mobile } // namespace paddle_mobile
namespace ops = paddle_mobile::operators;
REGISTER_FUSION_MATCHER(fusion_dequant_bn_relu,
ops::FusionDequantBNReluMatcher);
#ifdef PADDLE_MOBILE_CPU
REGISTER_OPERATOR_CPU(fusion_dequant_bn_relu, ops::FusionDequantBNReluOp);
#endif
#endif #endif
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#ifdef FUSION_DEQUANT_BN_RELU_OP
#pragma once
#include <string>
#include <vector>
#include "framework/operator.h"
#include "framework/program/program-optimize/fusion_op_register.h"
#include "operators/kernel/dequant_bn_relu_kernel.h"
#include "operators/op_param.h"
namespace paddle_mobile {
namespace operators {
class FusionDequantBNReluMatcher : public framework::FusionOpMatcher {
public:
FusionDequantBNReluMatcher() {
node_ = framework::Node(G_OP_TYPE_DEQUANTIZE);
node_ > std::make_shared<framework::Node>(G_OP_TYPE_BATCHNORM) >
std::make_shared<framework::Node>(G_OP_TYPE_RELU);
}
void FolderNodes(
framework::Node *node,
std::vector<std::shared_ptr<framework::Node>> *removed_nodes) {
node->Folder(node_.Depth(), Type(),
{{G_OP_TYPE_BATCHNORM,
{{"Scale", "BNScale"},
{"Mean", "BNMean"},
{"Bias", "BNBias"},
{"Variance", "BNVariance"}}}},
removed_nodes);
}
std::string Type() { return G_OP_TYPE_FUSION_DEQUANT_BN_RELU; }
};
template <typename DeviceType, typename T>
class FusionDequantBNReluOp
: public framework::OperatorWithKernel<
DeviceType, FusionDequantBNReluParam<DeviceType>,
operators::FusionDequantBNReluKernel<DeviceType, T>> {
public:
FusionDequantBNReluOp(const std::string &type, const VariableNameMap &inputs,
const VariableNameMap &outputs,
const framework::AttributeMap &attrs,
std::shared_ptr<framework::Scope> scope)
: framework::OperatorWithKernel<
DeviceType, FusionDequantBNReluParam<DeviceType>,
operators::FusionDequantBNReluKernel<DeviceType, T>>(
type, inputs, outputs, attrs, scope) {}
// inference output shape
void InferShape() const override;
};
} // namespace operators
} // namespace paddle_mobile
#endif
...@@ -22,41 +22,43 @@ namespace operators { ...@@ -22,41 +22,43 @@ namespace operators {
template <> template <>
bool ConvKernel<CPU, float>::Init(ConvParam<CPU> *param) { bool ConvKernel<CPU, float>::Init(ConvParam<CPU> *param) {
bool conv3x3 = param->Filter()->dims()[2] == param->Filter()->dims()[3] &&
param->Filter()->dims()[2] == 3;
bool depth3x3 = conv3x3 && param->Groups() == param->Input()->dims()[1] &&
param->Input()->dims()[1] == param->Output()->dims()[1];
if (param->Filter()->type() == typeid(int8_t)) { if (param->Filter()->type() == typeid(int8_t)) {
if (param->Groups() == param->Input()->dims()[1] && if (depth3x3 && param->Strides()[0] < 3 &&
param->Input()->dims()[1] == param->Output()->dims()[1] &&
param->Filter()->dims()[2] == param->Filter()->dims()[3] &&
param->Filter()->dims()[2] == 3 && param->Strides()[0] < 3 &&
param->Strides()[0] == param->Strides()[1]) { param->Strides()[0] == param->Strides()[1]) {
param->ExecMode() = ConvParam<CPU>::EXEC_DEPTHWISE3x3_INT8; param->ExecMode() = ConvParam<CPU>::EXEC_DEPTHWISE3x3_INT8;
} else { } else {
param->ExecMode() = ConvParam<CPU>::EXEC_GEMM_INT8; param->ExecMode() = ConvParam<CPU>::EXEC_GEMM_INT8;
} }
} else { } else {
if (param->Groups() == param->Input()->dims()[1] && if (depth3x3 && param->Strides()[0] == param->Strides()[1] &&
param->Input()->dims()[1] == param->Output()->dims()[1] && param->Strides()[0] == 1 && param->Paddings()[0] == 1 &&
param->Filter()->dims()[2] == param->Filter()->dims()[3] && param->Paddings()[0] == param->Paddings()[1]) {
param->Filter()->dims()[2] == 3 && param->Strides()[0] == 1) {
param->ExecMode() = ConvParam<CPU>::EXEC_DEPTHWISE3x3S1P1_FLOAT; param->ExecMode() = ConvParam<CPU>::EXEC_DEPTHWISE3x3S1P1_FLOAT;
} else if (param->Groups() == param->Input()->dims()[1] && } else if (depth3x3 && param->Strides()[0] == param->Strides()[1] &&
param->Input()->dims()[1] == param->Output()->dims()[1] && param->Strides()[0] == 2 && param->Paddings()[0] == 0 &&
param->Filter()->dims()[2] == param->Filter()->dims()[3] && param->Paddings()[0] == param->Paddings()[1]) {
param->Filter()->dims()[2] == 3) { param->ExecMode() = ConvParam<CPU>::EXEC_DEPTHWISE3x3S2P0_FLOAT;
param->ExecMode() = ConvParam<CPU>::EXEC_DEPTHWISE3x3_FLOAT; } else if (depth3x3 && param->Strides()[0] == param->Strides()[1] &&
param->Strides()[0] == 2 && param->Paddings()[0] == 1 &&
param->Paddings()[0] == param->Paddings()[1]) {
param->ExecMode() = ConvParam<CPU>::EXEC_DEPTHWISE3x3S2P1_FLOAT;
#ifndef __aarch64__ #ifndef __aarch64__
} else if (param->Filter()->dims()[2] == param->Filter()->dims()[3] && } else if (conv3x3 && param->Strides()[0] == param->Strides()[1] &&
param->Strides()[0] == param->Strides()[1] &&
param->Dilations()[0] == param->Dilations()[1] && param->Dilations()[0] == param->Dilations()[1] &&
param->Filter()->dims()[2] == 3 && param->Strides()[0] == 1 && param->Strides()[0] == 1 && param->Dilations()[0] == 1 &&
param->Dilations()[0] == 1 && param->Output()->dims()[1] >= 16 && param->Output()->dims()[1] >= 16 &&
param->Input()->dims()[1] >= 16 && param->Input()->dims()[1] >= 16 &&
param->Input()->dims()[2] <= 140 /* refered from ncnn */) { param->Input()->dims()[2] <= 140 /* refered from ncnn */) {
param->ExecMode() = ConvParam<CPU>::EXEC_WINOGRAD3X3_FLOAT; param->ExecMode() = ConvParam<CPU>::EXEC_WINOGRAD3X3_FLOAT;
// transform weight // transform weight
framework::Tensor *transformed_weight = new framework::Tensor; framework::Tensor transformed_weight;
operators::math::winograd_transform_weight<8, 3>(*param->Filter(), operators::math::winograd_transform_weight<8, 3>(*param->Filter(),
transformed_weight); &transformed_weight);
param->Filter() = transformed_weight; framework::TensorCopy(transformed_weight, param->Filter());
#endif #endif
} else { } else {
param->ExecMode() = ConvParam<CPU>::EXEC_GEMM_FLOAT; param->ExecMode() = ConvParam<CPU>::EXEC_GEMM_FLOAT;
...@@ -78,9 +80,13 @@ void ConvKernel<CPU, float>::Compute(const ConvParam<CPU> &param) { ...@@ -78,9 +80,13 @@ void ConvKernel<CPU, float>::Compute(const ConvParam<CPU> &param) {
math::DepthwiseConv3x3s1p1(param.Input(), param.Filter(), param.Output(), math::DepthwiseConv3x3s1p1(param.Input(), param.Filter(), param.Output(),
nullptr, false); nullptr, false);
break; break;
case ConvParam<CPU>::EXEC_DEPTHWISE3x3_FLOAT: case ConvParam<CPU>::EXEC_DEPTHWISE3x3S2P1_FLOAT:
math::DepthwiseConv3x3(param.Input(), param.Strides(), param.Paddings(), math::DepthwiseConv3x3s2p1v2(param.Input(), param.Filter(),
param.Filter(), nullptr, param.Output(), false); param.Output(), nullptr, false);
break;
case ConvParam<CPU>::EXEC_DEPTHWISE3x3S2P0_FLOAT:
math::DepthwiseConv3x3s2p0(param.Input(), param.Filter(), param.Output(),
nullptr, false);
break; break;
case ConvParam<CPU>::EXEC_WINOGRAD3X3_FLOAT: case ConvParam<CPU>::EXEC_WINOGRAD3X3_FLOAT:
WinogradConv3x3<8, 3>(param); WinogradConv3x3<8, 3>(param);
......
...@@ -12,9 +12,9 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. ...@@ -12,9 +12,9 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#ifdef FUSION_DEQUANT_ADD_BN_RELU_OP #ifdef FUSION_DEQUANT_ADD_BN_OP
#include "operators/kernel/dequant_add_bn_relu_kernel.h" #include "operators/kernel/dequant_add_bn_kernel.h"
#include <cmath> #include <cmath>
#if defined(__ARM_NEON__) || defined(__ARM_NEON) #if defined(__ARM_NEON__) || defined(__ARM_NEON)
#include <arm_neon.h> #include <arm_neon.h>
...@@ -24,8 +24,8 @@ namespace paddle_mobile { ...@@ -24,8 +24,8 @@ namespace paddle_mobile {
namespace operators { namespace operators {
template <> template <>
bool FusionDequantAddBNReluKernel<CPU, float>::Init( bool FusionDequantAddBNKernel<CPU, float>::Init(
FusionDequantAddBNReluParam<CPU> *param) { FusionDequantAddBNParam<CPU> *param) {
// elementwise add params // elementwise add params
const Tensor *bias = param->bias_; const Tensor *bias = param->bias_;
// batch norm params // batch norm params
...@@ -49,8 +49,8 @@ bool FusionDequantAddBNReluKernel<CPU, float>::Init( ...@@ -49,8 +49,8 @@ bool FusionDequantAddBNReluKernel<CPU, float>::Init(
} }
template <> template <>
void FusionDequantAddBNReluKernel<CPU, float>::Compute( void FusionDequantAddBNKernel<CPU, float>::Compute(
const FusionDequantAddBNReluParam<CPU> &param) { const FusionDequantAddBNParam<CPU> &param) {
const int32_t *input = param.input_->data<int32_t>(); const int32_t *input = param.input_->data<int32_t>();
const float *bn_scale = param.bn_scale_->data<float>(); const float *bn_scale = param.bn_scale_->data<float>();
const float *bn_bias = param.bn_bias_->data<float>(); const float *bn_bias = param.bn_bias_->data<float>();
...@@ -78,7 +78,6 @@ void FusionDequantAddBNReluKernel<CPU, float>::Compute( ...@@ -78,7 +78,6 @@ void FusionDequantAddBNReluKernel<CPU, float>::Compute(
remain = spatial_size & 0xF; remain = spatial_size & 0xF;
float32x4_t __scale = vdupq_n_f32(scale); float32x4_t __scale = vdupq_n_f32(scale);
float32x4_t __bias = vdupq_n_f32(bias); float32x4_t __bias = vdupq_n_f32(bias);
float32x4_t __zero = vdupq_n_f32(0.f);
for (int k = 0; k < loop; ++k, x += 16, y += 16) { for (int k = 0; k < loop; ++k, x += 16, y += 16) {
int32x4_t r0 = vld1q_s32(x); int32x4_t r0 = vld1q_s32(x);
...@@ -93,10 +92,6 @@ void FusionDequantAddBNReluKernel<CPU, float>::Compute( ...@@ -93,10 +92,6 @@ void FusionDequantAddBNReluKernel<CPU, float>::Compute(
f1 = vmlaq_f32(__bias, __scale, f1); f1 = vmlaq_f32(__bias, __scale, f1);
f2 = vmlaq_f32(__bias, __scale, f2); f2 = vmlaq_f32(__bias, __scale, f2);
f3 = vmlaq_f32(__bias, __scale, f3); f3 = vmlaq_f32(__bias, __scale, f3);
f0 = vmaxq_f32(__zero, f0);
f1 = vmaxq_f32(__zero, f1);
f2 = vmaxq_f32(__zero, f2);
f3 = vmaxq_f32(__zero, f3);
vst1q_f32(y, f0); vst1q_f32(y, f0);
vst1q_f32(y + 4, f1); vst1q_f32(y + 4, f1);
vst1q_f32(y + 8, f2); vst1q_f32(y + 8, f2);
...@@ -104,7 +99,7 @@ void FusionDequantAddBNReluKernel<CPU, float>::Compute( ...@@ -104,7 +99,7 @@ void FusionDequantAddBNReluKernel<CPU, float>::Compute(
} }
#endif // __ARM_NEON__ #endif // __ARM_NEON__
for (int k = 0; k < remain; ++k) { for (int k = 0; k < remain; ++k) {
y[k] = std::max(scale * x[k] + bias, 0.f); y[k] = scale * x[k] + bias;
} }
} }
} }
...@@ -113,4 +108,4 @@ void FusionDequantAddBNReluKernel<CPU, float>::Compute( ...@@ -113,4 +108,4 @@ void FusionDequantAddBNReluKernel<CPU, float>::Compute(
} // namespace operators } // namespace operators
} // namespace paddle_mobile } // namespace paddle_mobile
#endif // FUSION_DEQUANT_ADD_BN_RELU_OP #endif // FUSION_DEQUANT_ADD_BN_OP
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "operators/kernel/dequant_bn_relu_kernel.h"
#include <cmath>
#if defined(__ARM_NEON__) || defined(__ARM_NEON)
#include <arm_neon.h>
#endif
namespace paddle_mobile {
namespace operators {
#if defined(FUSION_DEQUANT_BN_RELU_OP) || defined(FUSION_DEQUANT_ADD_BN_RELU_OP)
void DequantBNReluCompute(const FusionDequantBNParam<CPU> *param) {
const int32_t *input = param->input_->data<int32_t>();
const float *bn_scale = param->bn_scale_->data<float>();
const float *bn_bias = param->bn_bias_->data<float>();
// dequantize params
const float activation_scale = param->activation_scale_->data<float>()[0];
const float weight_scale = param->weight_scale_;
const float dequant_scale = activation_scale / weight_scale;
float *output = param->output_->mutable_data<float>();
int batch_size = param->input_->dims()[0];
int channels = param->input_->dims()[1];
size_t spatial_size = param->input_->dims()[2] * param->input_->dims()[3];
#pragma omp parallel for collapse(2)
for (int batch = 0; batch < batch_size; ++batch) {
for (int c = 0; c < channels; ++c) {
float scale = bn_scale[c] * dequant_scale;
float bias = bn_bias[c];
size_t offset = (batch * channels + c) * spatial_size;
const int32_t *x = input + offset;
float *y = output + offset;
size_t remain = spatial_size;
#if defined(__ARM_NEON__) || defined(__ARM_NEON)
int loop = spatial_size >> 4;
remain = spatial_size & 0xF;
float32x4_t __scale = vdupq_n_f32(scale);
float32x4_t __bias = vdupq_n_f32(bias);
float32x4_t __zero = vdupq_n_f32(0.f);
for (int k = 0; k < loop; ++k, x += 16, y += 16) {
int32x4_t r0 = vld1q_s32(x);
int32x4_t r1 = vld1q_s32(x + 4);
int32x4_t r2 = vld1q_s32(x + 8);
int32x4_t r3 = vld1q_s32(x + 12);
float32x4_t f0 = vcvtq_f32_s32(r0);
float32x4_t f1 = vcvtq_f32_s32(r1);
float32x4_t f2 = vcvtq_f32_s32(r2);
float32x4_t f3 = vcvtq_f32_s32(r3);
f0 = vmlaq_f32(__bias, __scale, f0);
f1 = vmlaq_f32(__bias, __scale, f1);
f2 = vmlaq_f32(__bias, __scale, f2);
f3 = vmlaq_f32(__bias, __scale, f3);
f0 = vmaxq_f32(__zero, f0);
f1 = vmaxq_f32(__zero, f1);
f2 = vmaxq_f32(__zero, f2);
f3 = vmaxq_f32(__zero, f3);
vst1q_f32(y, f0);
vst1q_f32(y + 4, f1);
vst1q_f32(y + 8, f2);
vst1q_f32(y + 12, f3);
}
#endif // __ARM_NEON__
for (int k = 0; k < remain; ++k) {
y[k] = std::max(scale * x[k] + bias, 0.f);
}
}
}
}
#endif
#ifdef FUSION_DEQUANT_BN_RELU_OP
template <>
bool FusionDequantBNReluKernel<CPU, float>::Init(
FusionDequantBNReluParam<CPU> *param) {
// batch norm params
const Tensor *bn_mean = param->bn_mean_;
const Tensor *bn_variance = param->bn_variance_;
Tensor *bn_scale = param->bn_scale_;
Tensor *bn_bias = param->bn_bias_;
const float epsilon = param->epsilon_;
const float *mean_ptr = bn_mean->data<float>();
const float *var_ptr = bn_variance->data<float>();
float *bn_scale_ptr = bn_scale->mutable_data<float>();
float *bn_bias_ptr = bn_bias->mutable_data<float>();
for (int c = 0; c < bn_scale->numel(); ++c) {
float inv_scale = bn_scale_ptr[c] / (std::sqrt(var_ptr[c] + epsilon));
bn_scale_ptr[c] = inv_scale;
bn_bias_ptr[c] = bn_bias_ptr[c] - inv_scale * mean_ptr[c];
}
return true;
}
template <>
void FusionDequantBNReluKernel<CPU, float>::Compute(
const FusionDequantBNReluParam<CPU> &param) {
DequantBNReluCompute(&param);
}
#endif // FUSION_DEQUANT_BN_RELU_OP
#ifdef FUSION_DEQUANT_ADD_BN_RELU_OP
template <>
bool FusionDequantAddBNReluKernel<CPU, float>::Init(
FusionDequantAddBNReluParam<CPU> *param) {
// elementwise add params
const Tensor *bias = param->bias_;
// batch norm params
const Tensor *bn_mean = param->bn_mean_;
const Tensor *bn_variance = param->bn_variance_;
Tensor *bn_scale = param->bn_scale_;
Tensor *bn_bias = param->bn_bias_;
const float epsilon = param->epsilon_;
const float *bias_ptr = bias->data<float>();
const float *mean_ptr = bn_mean->data<float>();
const float *var_ptr = bn_variance->data<float>();
float *bn_scale_ptr = bn_scale->mutable_data<float>();
float *bn_bias_ptr = bn_bias->mutable_data<float>();
for (int c = 0; c < bn_scale->numel(); ++c) {
float inv_scale = bn_scale_ptr[c] / (std::sqrt(var_ptr[c] + epsilon));
bn_scale_ptr[c] = inv_scale;
bn_bias_ptr[c] = inv_scale * (bias_ptr[c] - mean_ptr[c]) + bn_bias_ptr[c];
}
return true;
}
template <>
void FusionDequantAddBNReluKernel<CPU, float>::Compute(
const FusionDequantAddBNReluParam<CPU> &param) {
DequantBNReluCompute(&param);
}
#endif // FUSION_DEQUANT_ADD_BN_RELU_OP
} // namespace operators
} // namespace paddle_mobile
...@@ -132,10 +132,10 @@ void ConvAddCompute(const FusionConvAddParam<CPU> &param) { ...@@ -132,10 +132,10 @@ void ConvAddCompute(const FusionConvAddParam<CPU> &param) {
// param.Output(), false); // param.Output(), false);
if (param.Paddings()[0] == 0) { if (param.Paddings()[0] == 0) {
math::DepthwiseConv3x3s2p0(param.Input(), param.Filter(), param.Output(), math::DepthwiseConv3x3s2p0(param.Input(), param.Filter(), param.Output(),
*param.Bias(), true); param.Bias(), true);
} else { } else {
math::DepthwiseConv3x3s2p1v2(param.Input(), param.Filter(), math::DepthwiseConv3x3s2p1v2(param.Input(), param.Filter(),
param.Output(), *param.Bias(), true); param.Output(), param.Bias(), true);
} }
} else { } else {
ConvAddBasic(param); ConvAddBasic(param);
......
...@@ -164,31 +164,21 @@ template <typename Itype, typename Otype> ...@@ -164,31 +164,21 @@ template <typename Itype, typename Otype>
inline void DepthwiseConv3x3(const ConvParam<CPU> &param) { inline void DepthwiseConv3x3(const ConvParam<CPU> &param) {
const Tensor *input = param.Input(); const Tensor *input = param.Input();
const Tensor *filter = param.Filter(); const Tensor *filter = param.Filter();
const std::vector<int> &paddings = param.Paddings();
const std::vector<int> &strides = param.Strides();
const int batch_size = input->dims()[0];
Tensor *output = param.Output(); Tensor *output = param.Output();
output->mutable_data<Otype>(); output->mutable_data<Otype>();
const std::vector<int> &paddings = param.Paddings();
const std::vector<int> &strides = param.Strides();
const int batch_size = static_cast<int>(input->dims()[0]);
Tensor input_pad;
math::PadFunctor<CPU, Itype> pad;
for (int i = 0; i < batch_size; i++) { for (int i = 0; i < batch_size; i++) {
Tensor in_batch = input->Slice(i, i + 1); Tensor in_batch = input->Slice(i, i + 1);
Tensor out_batch = output->Slice(i, i + 1); Tensor out_batch = output->Slice(i, i + 1);
if (paddings[0] || paddings[1]) {
framework::DDim pad_shape = in_batch.dims();
pad_shape[2] += 2 * paddings[0];
pad_shape[3] += 2 * paddings[1];
input_pad.mutable_data<float>(pad_shape);
pad(in_batch, paddings[0], paddings[0], paddings[1], paddings[1],
&input_pad);
} else {
input_pad = in_batch;
}
if (strides[0] == 1) { if (strides[0] == 1) {
math::DepthwiseConv3x3s1<Itype, Otype>(input_pad, *filter, &out_batch); math::DepthwiseConv3x3S1<Itype, Otype>(in_batch, *filter, paddings,
&out_batch);
} else if (strides[0] == 2) { } else if (strides[0] == 2) {
math::DepthwiseConv3x3s2<Itype, Otype>(input_pad, *filter, &out_batch); math::DepthwiseConv3x3S2<Itype, Otype>(in_batch, *filter, paddings,
&out_batch);
} else { } else {
// math::DepthwiseConv3x3<Itype, Otype>(input_pad, *filter, // math::DepthwiseConv3x3<Itype, Otype>(input_pad, *filter,
// &out_batch); // &out_batch);
......
...@@ -14,7 +14,7 @@ limitations under the License. */ ...@@ -14,7 +14,7 @@ limitations under the License. */
#pragma once #pragma once
#ifdef FUSION_DEQUANT_ADD_BN_RELU_OP #ifdef FUSION_DEQUANT_ADD_BN_OP
#include "framework/operator.h" #include "framework/operator.h"
#include "operators/op_param.h" #include "operators/op_param.h"
...@@ -23,12 +23,12 @@ namespace paddle_mobile { ...@@ -23,12 +23,12 @@ namespace paddle_mobile {
namespace operators { namespace operators {
template <typename DeviceType, typename T> template <typename DeviceType, typename T>
class FusionDequantAddBNReluKernel class FusionDequantAddBNKernel
: public framework::OpKernelBase<DeviceType, : public framework::OpKernelBase<DeviceType,
FusionDequantAddBNReluParam<DeviceType>> { FusionDequantAddBNParam<DeviceType>> {
public: public:
void Compute(const FusionDequantAddBNReluParam<DeviceType> &param); void Compute(const FusionDequantAddBNParam<DeviceType> &param);
bool Init(FusionDequantAddBNReluParam<DeviceType> *param); bool Init(FusionDequantAddBNParam<DeviceType> *param);
}; };
} // namespace operators } // namespace operators
......
...@@ -12,42 +12,35 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. ...@@ -12,42 +12,35 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#ifdef DEPTHWISECONV_OP
#pragma once #pragma once
#include <vector>
#include "operators/kernel/central-arm-func/conv_arm_func.h" #include "framework/operator.h"
#include "operators/math/depthwise_conv3x3.h"
#include "operators/op_param.h" #include "operators/op_param.h"
namespace paddle_mobile { namespace paddle_mobile {
namespace operators { namespace operators {
template <typename P> #ifdef FUSION_DEQUANT_BN_RELU_OP
void DepthwiseConvCompute(const ConvParam<CPU> &param) { template <typename DeviceType, typename T>
Tensor Bias; class FusionDequantBNReluKernel
Bias.mutable_data<float>({param.Groups()}); : public framework::OpKernelBase<DeviceType,
if (param.Groups() == param.Input()->dims()[1] && FusionDequantBNReluParam<DeviceType>> {
param.Filter()->dims()[2] == param.Filter()->dims()[3] && public:
param.Filter()->dims()[2] == 3 && param.Strides()[0] == 1) { void Compute(const FusionDequantBNReluParam<DeviceType> &param);
math::DepthwiseConv3x3s1p1(param.Input(), param.Filter(), param.Output(), bool Init(FusionDequantBNReluParam<DeviceType> *param);
&Bias, false); };
} else if (param.Groups() == param.Input()->dims()[1] && #endif
param.Input()->dims()[1] == param.Output()->dims()[1] &&
param.Filter()->dims()[2] == param.Filter()->dims()[3] && #ifdef FUSION_DEQUANT_ADD_BN_RELU_OP
param.Filter()->dims()[2] == 3 && param.Strides()[0] == 2) { template <typename DeviceType, typename T>
// math::DepthwiseConv3x3(param.Input(), param.Strides(), class FusionDequantAddBNReluKernel
// param.Paddings(), : public framework::OpKernelBase<DeviceType,
// param.Filter(), &Bias, param.Output(), false); FusionDequantAddBNReluParam<DeviceType>> {
math::DepthwiseConv3x3s2p1v2(param.Input(), param.Filter(), param.Output(), public:
Bias, false); void Compute(const FusionDequantAddBNReluParam<DeviceType> &param);
bool Init(FusionDequantAddBNReluParam<DeviceType> *param);
} else { };
GemmConv<float, float>(param); #endif
}
}
} // namespace operators } // namespace operators
} // namespace paddle_mobile } // namespace paddle_mobile
#endif
...@@ -1272,13 +1272,16 @@ void DepthwiseConvAddBNRelu3x3s2p1(const framework::Tensor *input, ...@@ -1272,13 +1272,16 @@ void DepthwiseConvAddBNRelu3x3s2p1(const framework::Tensor *input,
void DepthwiseConv3x3s2p1v2(const framework::Tensor *input, void DepthwiseConv3x3s2p1v2(const framework::Tensor *input,
const framework::Tensor *filter, const framework::Tensor *filter,
framework::Tensor *output, framework::Tensor bias, framework::Tensor *output, framework::Tensor *bias,
bool if_bias) { bool if_bias) {
#if __ARM_NEON #if __ARM_NEON
const float *input_data = input->data<float>(); const float *input_data = input->data<float>();
const float *filter_data = filter->data<float>(); const float *filter_data = filter->data<float>();
float *output_data = output->data<float>(); float *output_data = output->data<float>();
const float *bias_data = bias.data<float>(); const float *bias_data;
if (if_bias) {
bias_data = bias->data<float>();
}
const int in_h = static_cast<int>(input->dims()[2]); const int in_h = static_cast<int>(input->dims()[2]);
const int in_w = static_cast<int>(input->dims()[3]); const int in_w = static_cast<int>(input->dims()[3]);
...@@ -1905,7 +1908,7 @@ void DepthwiseConvAddBNRelu3x3s2p1v2(const framework::Tensor *input, ...@@ -1905,7 +1908,7 @@ void DepthwiseConvAddBNRelu3x3s2p1v2(const framework::Tensor *input,
void DepthwiseConv3x3s2p0(const framework::Tensor *input, void DepthwiseConv3x3s2p0(const framework::Tensor *input,
const framework::Tensor *filter, const framework::Tensor *filter,
framework::Tensor *output, framework::Tensor bias, framework::Tensor *output, framework::Tensor *bias,
bool if_bias) { bool if_bias) {
#if __ARM_NEON #if __ARM_NEON
...@@ -1925,7 +1928,7 @@ void DepthwiseConv3x3s2p0(const framework::Tensor *input, ...@@ -1925,7 +1928,7 @@ void DepthwiseConv3x3s2p0(const framework::Tensor *input,
for (int c = 0; c < input_channel; c++) { for (int c = 0; c < input_channel; c++) {
const float *filter_data = filter->data<float>() + c * 9; const float *filter_data = filter->data<float>() + c * 9;
const float *input_data = input->data<float>() + c * inhxw; const float *input_data = input->data<float>() + c * inhxw;
const float *bias_data = bias.data<float>() + c; const float *bias_data = bias->data<float>() + c;
float *output_data = output->data<float>() + c * outhxw; float *output_data = output->data<float>() + c * outhxw;
float w00 = filter_data[0]; float w00 = filter_data[0];
float w01 = filter_data[1]; float w01 = filter_data[1];
......
...@@ -50,7 +50,7 @@ void DepthwiseConvAddBNRelu3x3s2p1(const framework::Tensor *input, ...@@ -50,7 +50,7 @@ void DepthwiseConvAddBNRelu3x3s2p1(const framework::Tensor *input,
void DepthwiseConv3x3s2p1v2(const framework::Tensor *input, void DepthwiseConv3x3s2p1v2(const framework::Tensor *input,
const framework::Tensor *filter, const framework::Tensor *filter,
framework::Tensor *output, framework::Tensor bias, framework::Tensor *output, framework::Tensor *bias,
bool if_bias); bool if_bias);
void DepthwiseConvAddBNRelu3x3s2p1v2(const framework::Tensor *input, void DepthwiseConvAddBNRelu3x3s2p1v2(const framework::Tensor *input,
...@@ -62,7 +62,7 @@ void DepthwiseConvAddBNRelu3x3s2p1v2(const framework::Tensor *input, ...@@ -62,7 +62,7 @@ void DepthwiseConvAddBNRelu3x3s2p1v2(const framework::Tensor *input,
void DepthwiseConv3x3s2p0(const framework::Tensor *input, void DepthwiseConv3x3s2p0(const framework::Tensor *input,
const framework::Tensor *filter, const framework::Tensor *filter,
framework::Tensor *output, framework::Tensor bias, framework::Tensor *output, framework::Tensor *bias,
bool if_bias); bool if_bias);
// TODO(hjchen2) need to be implemented // TODO(hjchen2) need to be implemented
...@@ -70,16 +70,19 @@ void DepthwiseConv3x3s2p0(const framework::Tensor *input, ...@@ -70,16 +70,19 @@ void DepthwiseConv3x3s2p0(const framework::Tensor *input,
// void DepthwiseConv3x3(const framework::Tensor *input, // void DepthwiseConv3x3(const framework::Tensor *input,
// const framework::Tensor *filter, // const framework::Tensor *filter,
// const std::vector<int> &strides, // const std::vector<int> &strides,
// const std::vector<int> &paddings,
// framework::Tensor *output); // framework::Tensor *output);
template <typename Itype, typename Otype> template <typename Itype, typename Otype>
void DepthwiseConv3x3s1(const framework::Tensor &input, void DepthwiseConv3x3S1(const framework::Tensor &input,
const framework::Tensor &filter, const framework::Tensor &filter,
const std::vector<int> &paddings,
framework::Tensor *output); framework::Tensor *output);
template <typename Itype, typename Otype> template <typename Itype, typename Otype>
void DepthwiseConv3x3s2(const framework::Tensor &input, void DepthwiseConv3x3S2(const framework::Tensor &input,
const framework::Tensor &filter, const framework::Tensor &filter,
const std::vector<int> &paddings,
framework::Tensor *output); framework::Tensor *output);
} // namespace math } // namespace math
......
...@@ -405,9 +405,9 @@ class ConvParam : public OpParam { ...@@ -405,9 +405,9 @@ class ConvParam : public OpParam {
const RType *Input() const { return input_; } const RType *Input() const { return input_; }
RType *&Filter() const { return filter_; } RType *Filter() const { return filter_; }
RType *&Output() const { return output_; } RType *Output() const { return output_; }
const vector<int> &Strides() const { return strides_; } const vector<int> &Strides() const { return strides_; }
...@@ -419,6 +419,8 @@ class ConvParam : public OpParam { ...@@ -419,6 +419,8 @@ class ConvParam : public OpParam {
EXEC_INVALID = 0, EXEC_INVALID = 0,
EXEC_GEMM_FLOAT, EXEC_GEMM_FLOAT,
EXEC_DEPTHWISE3x3S1P1_FLOAT, EXEC_DEPTHWISE3x3S1P1_FLOAT,
EXEC_DEPTHWISE3x3S2P0_FLOAT,
EXEC_DEPTHWISE3x3S2P1_FLOAT,
EXEC_DEPTHWISE3x3_FLOAT, EXEC_DEPTHWISE3x3_FLOAT,
EXEC_WINOGRAD3X3_FLOAT, EXEC_WINOGRAD3X3_FLOAT,
EXEC_WINOGRAD5X5_FLOAT, EXEC_WINOGRAD5X5_FLOAT,
...@@ -439,8 +441,8 @@ class ConvParam : public OpParam { ...@@ -439,8 +441,8 @@ class ConvParam : public OpParam {
private: private:
RType *input_; RType *input_;
mutable RType *output_; RType *output_;
mutable RType *filter_; RType *filter_;
vector<int> strides_; vector<int> strides_;
vector<int> paddings_; vector<int> paddings_;
vector<int> dilations_; vector<int> dilations_;
...@@ -2573,7 +2575,9 @@ class DequantizeParam : public OpParam { ...@@ -2573,7 +2575,9 @@ class DequantizeParam : public OpParam {
DequantizeParam(const VariableNameMap &inputs, const VariableNameMap &outputs, DequantizeParam(const VariableNameMap &inputs, const VariableNameMap &outputs,
const AttributeMap &attrs, const Scope &scope) { const AttributeMap &attrs, const Scope &scope) {
input_ = InputXFrom<GType>(inputs, scope); input_ = InputXFrom<GType>(inputs, scope);
if (outputs.count("Out")) {
output_ = OutFrom<GType>(outputs, scope); output_ = OutFrom<GType>(outputs, scope);
}
activation_scale_ = OpParam::GetVarValue<GType>("Scale", inputs, scope); activation_scale_ = OpParam::GetVarValue<GType>("Scale", inputs, scope);
// dequantization is performed as x = x / static_scale / online_scale // dequantization is performed as x = x / static_scale / online_scale
if (HasAttr("weight_scale", attrs)) { if (HasAttr("weight_scale", attrs)) {
...@@ -2593,20 +2597,19 @@ class DequantizeParam : public OpParam { ...@@ -2593,20 +2597,19 @@ class DequantizeParam : public OpParam {
}; };
#endif #endif
#ifdef FUSION_DEQUANT_ADD_BN_RELU_OP #if defined(FUSION_DEQUANT_ADD_BN_OP) || \
defined(FUSION_DEQUANT_ADD_BN_RELU_OP) || \
defined(FUSION_DEQUANT_BN_RELU_OP) || defined(FUSION_DEQUANT_BN_OP)
template <typename Dtype> template <typename Dtype>
class FusionDequantAddBNReluParam : public DequantizeParam<Dtype> { class FusionDequantBNParam : public DequantizeParam<Dtype> {
typedef typename DtypeTensorTrait<Dtype>::gtype GType; typedef typename DtypeTensorTrait<Dtype>::gtype GType;
typedef typename DtypeTensorTrait<Dtype>::rtype RType; typedef typename DtypeTensorTrait<Dtype>::rtype RType;
public: public:
FusionDequantAddBNReluParam(const VariableNameMap &inputs, FusionDequantBNParam(const VariableNameMap &inputs,
const VariableNameMap &outputs, const VariableNameMap &outputs,
const AttributeMap &attrs, const Scope &scope) const AttributeMap &attrs, const Scope &scope)
: DequantizeParam<Dtype>(inputs, outputs, attrs, scope) { : DequantizeParam<Dtype>(inputs, outputs, attrs, scope) {
// element wise add params
axis_ = OpParam::GetAttr<int>("axis", attrs);
bias_ = OpParam::InputYFrom<GType>(inputs, scope);
// batch norm params // batch norm params
bn_mean_ = OpParam::GetVarValue<GType>("BNMean", inputs, scope); bn_mean_ = OpParam::GetVarValue<GType>("BNMean", inputs, scope);
bn_variance_ = OpParam::GetVarValue<GType>("BNVariance", inputs, scope); bn_variance_ = OpParam::GetVarValue<GType>("BNVariance", inputs, scope);
...@@ -2614,21 +2617,83 @@ class FusionDequantAddBNReluParam : public DequantizeParam<Dtype> { ...@@ -2614,21 +2617,83 @@ class FusionDequantAddBNReluParam : public DequantizeParam<Dtype> {
bn_bias_ = OpParam::GetVarValue<GType>("BNBias", inputs, scope); bn_bias_ = OpParam::GetVarValue<GType>("BNBias", inputs, scope);
epsilon_ = OpParam::GetAttr<float>("epsilon", attrs); epsilon_ = OpParam::GetAttr<float>("epsilon", attrs);
// output // output
output_ = OpParam::OutFrom<GType>(outputs, scope); if (outputs.count("Y")) {
this->output_ = OpParam::OutputYFrom<GType>(outputs, scope);
}
} }
public: public:
// elementwise add
int axis_;
RType *bias_;
// batch norm // batch norm
RType *bn_mean_; RType *bn_mean_;
RType *bn_variance_; RType *bn_variance_;
RType *bn_scale_; RType *bn_scale_;
RType *bn_bias_; RType *bn_bias_;
float epsilon_; float epsilon_;
};
#endif
#if defined(FUSION_DEQUANT_ADD_BN_RELU_OP) || defined(FUSION_DEQUANT_ADD_BN_OP)
template <typename Dtype>
class FusionDequantAddBNParam : public FusionDequantBNParam<Dtype> {
typedef typename DtypeTensorTrait<Dtype>::gtype GType;
typedef typename DtypeTensorTrait<Dtype>::rtype RType;
public:
FusionDequantAddBNParam(const VariableNameMap &inputs,
const VariableNameMap &outputs,
const AttributeMap &attrs, const Scope &scope)
: FusionDequantBNParam<Dtype>(inputs, outputs, attrs, scope) {
// element wise add params
axis_ = OpParam::GetAttr<int>("axis", attrs);
bias_ = OpParam::InputYFrom<GType>(inputs, scope);
// output // output
RType *output_; if (outputs.count("Y")) {
this->output_ = OpParam::OutputYFrom<GType>(outputs, scope);
}
}
public:
// elementwise add
int axis_;
RType *bias_;
};
#endif
#ifdef FUSION_DEQUANT_BN_RELU_OP
template <typename Dtype>
class FusionDequantBNReluParam : public FusionDequantBNParam<Dtype> {
typedef typename DtypeTensorTrait<Dtype>::gtype GType;
typedef typename DtypeTensorTrait<Dtype>::rtype RType;
public:
FusionDequantBNReluParam(const VariableNameMap &inputs,
const VariableNameMap &outputs,
const AttributeMap &attrs, const Scope &scope)
: FusionDequantBNParam<Dtype>(inputs, outputs, attrs, scope) {
// output
if (outputs.count("Out")) {
this->output_ = OpParam::OutFrom<GType>(outputs, scope);
}
}
};
#endif
#ifdef FUSION_DEQUANT_ADD_BN_RELU_OP
template <typename Dtype>
class FusionDequantAddBNReluParam : public FusionDequantAddBNParam<Dtype> {
typedef typename DtypeTensorTrait<Dtype>::gtype GType;
typedef typename DtypeTensorTrait<Dtype>::rtype RType;
public:
FusionDequantAddBNReluParam(const VariableNameMap &inputs,
const VariableNameMap &outputs,
const AttributeMap &attrs, const Scope &scope)
: FusionDequantAddBNParam<Dtype>(inputs, outputs, attrs, scope) {
// output
if (outputs.count("Out")) {
this->output_ = OpParam::OutFrom<GType>(outputs, scope);
}
}
}; };
#endif #endif
......
...@@ -44,25 +44,19 @@ struct Round<round::RoundTowardsZero> { ...@@ -44,25 +44,19 @@ struct Round<round::RoundTowardsZero> {
template <> template <>
struct Round<round::RoundToEven> { struct Round<round::RoundToEven> {
int8_t operator()(float x) { int8_t operator()(float x) {
int8_t ret = 0;
float v = std::round(x); float v = std::round(x);
int32_t q = (int32_t)v; int32_t q = static_cast<int32_t>(v);
if (abs(abs(q - x) - 0.5) > 0) { if (abs(abs(q - v) - 0.5) <= 0) {
ret = q; if (abs(q) % 2 != 0) {
} else { q = q + ((q > 0) ? -1 : 1);
if (abs(q) % 2 == 0) {
ret = q;
} else {
ret = q + ((q > 0) ? -1 : 1);
} }
} }
return ret; return static_cast<int8_t>(q);
} }
}; };
template <round::RoundType T> template <round::RoundType T>
static void quantize(const Tensor *input, const float scale, const int pad, static void quantize(const Tensor *input, const float scale, Tensor *output) {
const int8_t pad_val, Tensor *output) {
int batch_size = input->dims()[0]; int batch_size = input->dims()[0];
int channels = input->dims()[1]; int channels = input->dims()[1];
int input_h = input->dims()[2]; int input_h = input->dims()[2];
...@@ -77,29 +71,9 @@ static void quantize(const Tensor *input, const float scale, const int pad, ...@@ -77,29 +71,9 @@ static void quantize(const Tensor *input, const float scale, const int pad,
for (int nc = 0; nc < batch_size * channels; ++nc) { for (int nc = 0; nc < batch_size * channels; ++nc) {
const float *xh = x + nc * input_spatial; const float *xh = x + nc * input_spatial;
int8_t *yh = y + nc * output_spatial; int8_t *yh = y + nc * output_spatial;
// pad top
for (int h = 0; h < pad; ++h, yh += output_w) {
for (int w = 0; w < output_w; ++w) {
yh[w] = pad_val;
}
}
for (int h = 0; h < input_h; ++h, yh += output_w, xh += input_w) { for (int h = 0; h < input_h; ++h, yh += output_w, xh += input_w) {
// pad left
for (int w = 0; w < pad; ++w) {
yh[w] = pad_val;
}
for (int w = 0; w < input_w; ++w) { for (int w = 0; w < input_w; ++w) {
yh[w + pad] = Round<T>()(xh[w] * scale); yh[w] = Round<T>()(xh[w] * scale);
}
// pad right
for (int w = 0; w < pad; ++w) {
yh[pad + input_w + w] = pad_val;
}
}
// pad bottom
for (int h = 0; h < pad; ++h, yh += output_w) {
for (int w = 0; w < output_w; ++w) {
yh[w] = pad_val;
} }
} }
} }
...@@ -120,19 +94,14 @@ static float find_abs_max(const Tensor *input) { ...@@ -120,19 +94,14 @@ static float find_abs_max(const Tensor *input) {
int TestQuqntizeOp(int argc, char *argv[]) { int TestQuqntizeOp(int argc, char *argv[]) {
if (argc < 5) { if (argc < 5) {
std::cout std::cout << "Usage: ./test-quantize-op batch_size channel height width"
<< "Usage: ./test-quantize-op batch_size channel height width [pad]"
<< std::endl; << std::endl;
return 1; return 1;
} }
int pad = 0;
int batch_size = atoi(argv[1]); int batch_size = atoi(argv[1]);
int channel = atoi(argv[2]); int channel = atoi(argv[2]);
int height = atoi(argv[3]); int height = atoi(argv[3]);
int width = atoi(argv[4]); int width = atoi(argv[4]);
if (argc == 6) {
pad = atoi(argv[5]);
}
std::cout << "batch_size: " << batch_size << ", channel: " << channel std::cout << "batch_size: " << batch_size << ", channel: " << channel
<< ", height: " << height << ", width: " << width << std::endl; << ", height: " << height << ", width: " << width << std::endl;
framework::DDim dim = framework::DDim dim =
...@@ -153,7 +122,6 @@ int TestQuqntizeOp(int argc, char *argv[]) { ...@@ -153,7 +122,6 @@ int TestQuqntizeOp(int argc, char *argv[]) {
auto output_scale_var = scope.get()->Var("output_scale"); auto output_scale_var = scope.get()->Var("output_scale");
framework::AttributeMap attrs; framework::AttributeMap attrs;
attrs["paddings"].Set<vector<int>>(std::vector<int>({pad, pad}));
auto *op = new operators::QuantizeOp<CPU, float>("quantize", inputs, outputs, auto *op = new operators::QuantizeOp<CPU, float>("quantize", inputs, outputs,
attrs, scope); attrs, scope);
op->InferShape(); op->InferShape();
...@@ -172,9 +140,9 @@ int TestQuqntizeOp(int argc, char *argv[]) { ...@@ -172,9 +140,9 @@ int TestQuqntizeOp(int argc, char *argv[]) {
framework::Tensor output_cmp; framework::Tensor output_cmp;
output_cmp.Resize(output->dims()); output_cmp.Resize(output->dims());
float scale = 127 / output_scale_cmp; float scale = 127 / output_scale_cmp;
// quantize<round::RoundToEven>(input, scale, pad, 0, &output_cmp); // quantize<round::RoundToEven>(input, scale, &output_cmp);
// quantize<round::RoundAwayZero>(input, scale, pad, 0, &output_cmp); // quantize<round::RoundAwayZero>(input, scale, &output_cmp);
quantize<round::RoundTowardsZero>(input, scale, pad, 0, &output_cmp); quantize<round::RoundTowardsZero>(input, scale, &output_cmp);
int8_t *output_cmp_data = output_cmp.data<int8_t>(); int8_t *output_cmp_data = output_cmp.data<int8_t>();
for (int i = 0; i < output->numel(); ++i) { for (int i = 0; i < output->numel(); ++i) {
PADDLE_MOBILE_ENFORCE(output_data[i] == output_cmp_data[i], PADDLE_MOBILE_ENFORCE(output_data[i] == output_cmp_data[i],
......
...@@ -249,7 +249,9 @@ if(NOT FOUND_MATCH) ...@@ -249,7 +249,9 @@ if(NOT FOUND_MATCH)
set(SUM_OP ON) set(SUM_OP ON)
set(QUANT_OP ON) set(QUANT_OP ON)
set(DEQUANT_OP ON) set(DEQUANT_OP ON)
set(FUSION_DEQUANT_ADD_BN_RELU ON) set(FUSION_DEQUANT_ADD_BN_OP ON)
set(FUSION_DEQUANT_BN_RELU_OP ON)
set(FUSION_DEQUANT_ADD_BN_RELU_OP ON)
endif() endif()
# option(BATCHNORM_OP "" ON) # option(BATCHNORM_OP "" ON)
...@@ -451,10 +453,17 @@ endif() ...@@ -451,10 +453,17 @@ endif()
if (DEQUANT_OP) if (DEQUANT_OP)
add_definitions(-DDEQUANT_OP) add_definitions(-DDEQUANT_OP)
endif() endif()
if (FUSION_DEQUANT_ADD_BN_RELU) if (FUSION_DEQUANT_ADD_BN_OP)
add_definitions(-DFUSION_DEQUANT_ADD_BN_OP)
endif()
if (FUSION_DEQUANT_BN_RELU_OP)
add_definitions(-DFUSION_DEQUANT_BN_RELU_OP)
endif()
if (FUSION_DEQUANT_ADD_BN_RELU_OP)
add_definitions(-DFUSION_DEQUANT_ADD_BN_RELU_OP) add_definitions(-DFUSION_DEQUANT_ADD_BN_RELU_OP)
endif() endif()
if (TANH_OP) if (TANH_OP)
add_definitions(-DTANH_OP) add_definitions(-DTANH_OP)
endif() endif()
...@@ -467,3 +476,4 @@ endif() ...@@ -467,3 +476,4 @@ endif()
if (FUSION_DECONVADDRELU_OP) if (FUSION_DECONVADDRELU_OP)
add_definitions(-DFUSION_DECONVADDRELU_OP) add_definitions(-DFUSION_DECONVADDRELU_OP)
endif() endif()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册