diff --git a/src/common/types.cpp b/src/common/types.cpp index 312e491a35681e2fc75584106160a4c79e22e372..14c4d86868eb6aac8ed19120e2bdced87fd85277 100644 --- a/src/common/types.cpp +++ b/src/common/types.cpp @@ -39,6 +39,7 @@ const char *G_OP_TYPE_POLYGON_BOX_TRANSFORM = "polygon_box_transform"; const char *G_OP_TYPE_POOL2D = "pool2d"; const char *G_OP_TYPE_PRIOR_BOX = "prior_box"; const char *G_OP_TYPE_RELU = "relu"; +const char *G_OP_TYPE_RELU6 = "relu6"; const char *G_OP_TYPE_RESHAPE = "reshape"; const char *G_OP_TYPE_RESHAPE2 = "reshape2"; const char *G_OP_TYPE_SIGMOID = "sigmoid"; @@ -74,6 +75,10 @@ const char *G_OP_TYPE_DEQUANTIZE = "dequantize"; const char *G_OP_TYPE_FUSION_DEQUANT_ADD_BN = "fusion_dequant_add_bn"; const char *G_OP_TYPE_FUSION_DEQUANT_BN_RELU = "fusion_dequant_bn_relu"; const char *G_OP_TYPE_FUSION_DEQUANT_ADD_BN_RELU = "fusion_dequant_add_bn_relu"; +const char *G_OP_TYPE_FUSION_DEQUANT_ADD_BN_QUANT = + "fusion_dequant_add_bn_quant"; +const char *G_OP_TYPE_FUSION_DEQUANT_ADD_BN_RELU_QUANT = + "fusion_dequant_add_bn_relu_quant"; const char *G_OP_TYPE_TANH = "tanh"; const char *G_OP_TYPE_FUSION_DECONV_RELU = "fusion_deconv_relu"; @@ -89,6 +94,7 @@ std::unordered_map< {G_OP_TYPE_PRELU, {{"X", "Alpha"}, {"Out"}}}, {G_OP_TYPE_FUSION_CONV_ADD, {{"Input"}, {"Out"}}}, {G_OP_TYPE_RELU, {{"X"}, {"Out"}}}, + {G_OP_TYPE_RELU6, {{"X"}, {"Out"}}}, {G_OP_TYPE_SOFTMAX, {{"X"}, {"Out"}}}, {G_OP_TYPE_SIGMOID, {{"X"}, {"Out"}}}, {G_OP_TYPE_MUL, {{"X"}, {"Out"}}}, @@ -141,6 +147,10 @@ std::unordered_map< {G_OP_TYPE_FUSION_DEQUANT_ADD_BN, {{"X", "Scale"}, {"Y"}}}, {G_OP_TYPE_FUSION_DEQUANT_BN_RELU, {{"X", "Scale"}, {"Out"}}}, {G_OP_TYPE_FUSION_DEQUANT_ADD_BN_RELU, {{"X", "Scale"}, {"Out"}}}, + {G_OP_TYPE_FUSION_DEQUANT_ADD_BN_RELU_QUANT, + {{"X", "Scale"}, {"Out", "OutScale"}}}, + {G_OP_TYPE_FUSION_DEQUANT_ADD_BN_QUANT, + {{"X", "Scale"}, {"Out", "OutScale"}}}, {G_OP_TYPE_TANH, {{"X"}, {"Out"}}}, {G_OP_TYPE_FUSION_DECONV_RELU, {{"Input"}, {"Out"}}}, {G_OP_TYPE_FUSION_DECONV_ADD, {{"Input"}, {"Out"}}}, diff --git a/src/common/types.h b/src/common/types.h index 16ed1aef57432249b14c415b3a23042ca295b600..1c0dd27d5c91b201301befda5017d95a8bc36b61 100644 --- a/src/common/types.h +++ b/src/common/types.h @@ -114,6 +114,7 @@ extern const char *G_OP_TYPE_MULTICLASS_NMS; extern const char *G_OP_TYPE_POOL2D; extern const char *G_OP_TYPE_PRIOR_BOX; extern const char *G_OP_TYPE_RELU; +extern const char *G_OP_TYPE_RELU6; extern const char *G_OP_TYPE_RESHAPE; extern const char *G_OP_TYPE_SIGMOID; extern const char *G_OP_TYPE_SOFTMAX; @@ -141,6 +142,8 @@ extern const char *G_OP_TYPE_DEQUANTIZE; extern const char *G_OP_TYPE_FUSION_DEQUANT_ADD_BN; extern const char *G_OP_TYPE_FUSION_DEQUANT_BN_RELU; extern const char *G_OP_TYPE_FUSION_DEQUANT_ADD_BN_RELU; +extern const char *G_OP_TYPE_FUSION_DEQUANT_ADD_BN_QUANT; +extern const char *G_OP_TYPE_FUSION_DEQUANT_ADD_BN_RELU_QUANT; extern const char *G_OP_TYPE_TANH; extern const char *G_OP_TYPE_FUSION_DECONV_RELU; diff --git a/src/framework/executor.cpp b/src/framework/executor.cpp index 0e00585f9042124e1a62e6ad8ce01ebfbfd541a0..6dc8f8c30d8259e0ae71bebb75f618342ec5f841 100644 --- a/src/framework/executor.cpp +++ b/src/framework/executor.cpp @@ -302,7 +302,15 @@ std::shared_ptr Executor::Predict( for (int i = 0; i < profile.size(); i++) { const auto &pInfo = profile[i]; uint64_t timeCost = pInfo.runEnd - pInfo.runBegin; - _tp[ops[i]->Type()] += timeCost; + if (ops[i]->Type() == "conv2d") { + auto inputs = ops[i]->Inputs(); + auto *filter = framework::GetVarValue( 
+ "Filter", inputs, *(program_.scope)); + int kernel_size = filter->dims()[2]; + _tp[ops[i]->Type() + "_" + std::to_string(kernel_size)] += timeCost; + } else { + _tp[ops[i]->Type()] += timeCost; + } } printf("====================[ profile ]======================\n"); using prof_t = std::pair; @@ -372,6 +380,14 @@ std::shared_ptr Executor::PredictLod( for (int i = 0; i < profile.size(); i++) { const auto &pInfo = profile[i]; uint64_t timeCost = pInfo.runEnd - pInfo.runBegin; + if (ops[i]->Type() == "conv2d") { + auto inputs = ops[i]->Inputs(); + auto input_keys = ops[i]->GetInputKeys(); + auto *filter = framework::GetVarValue( + input_keys[1], inputs, *(program_.scope)); + int kernel_size = filter->dims()[2]; + printf("kernel size: %d\n", kernel_size); + } _tp[ops[i]->Type()] += timeCost; } printf("====================[ profile ]======================\n"); diff --git a/src/framework/load_ops.h b/src/framework/load_ops.h index 2534217d58674f912f0e5da741dfcae41827edf1..ad0abc692ccc64caa0d3cf88846e0ecc651cfaae 100644 --- a/src/framework/load_ops.h +++ b/src/framework/load_ops.h @@ -191,6 +191,7 @@ LOAD_OP2(mul, CPU, MALI_GPU); #endif #ifdef RELU_OP LOAD_OP2(relu, CPU, MALI_GPU); +LOAD_OP1(relu6, CPU); #endif #ifdef IM2SEQUENCE_OP LOAD_OP1(im2sequence, CPU); @@ -245,3 +246,11 @@ LOAD_FUSION_MATCHER(fusion_dequant_bn_relu); LOAD_OP1(fusion_dequant_add_bn_relu, CPU); LOAD_FUSION_MATCHER(fusion_dequant_add_bn_relu); #endif +#ifdef FUSION_DEQUANT_ADD_BN_QUANT_OP +LOAD_OP1(fusion_dequant_add_bn_quant, CPU); +LOAD_FUSION_MATCHER(fusion_dequant_add_bn_quant); +#endif +#ifdef FUSION_DEQUANT_ADD_BN_RELU_QUANT_OP +LOAD_OP1(fusion_dequant_add_bn_relu_quant, CPU); +LOAD_FUSION_MATCHER(fusion_dequant_add_bn_relu_quant); +#endif diff --git a/src/operators/fusion_dequant_add_bn_relu_quant_op.cpp b/src/operators/fusion_dequant_add_bn_relu_quant_op.cpp new file mode 100644 index 0000000000000000000000000000000000000000..82eacd7f47dbfad56b99467d91ec849720ec787c --- /dev/null +++ b/src/operators/fusion_dequant_add_bn_relu_quant_op.cpp @@ -0,0 +1,62 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
*/ + +#include "operators/fusion_dequant_add_bn_relu_quant_op.h" + +#ifdef FUSION_DEQUANT_ADD_BN_RELU_QUANT_OP +namespace paddle_mobile { +namespace operators { + +template +void FusionDequantAddBNReluQuantOp::InferShape() const { + const auto& input_dims = this->param_.input_->dims(); + this->param_.output_->Resize(input_dims); +} + +} // namespace operators +} // namespace paddle_mobile + +namespace ops = paddle_mobile::operators; +REGISTER_FUSION_MATCHER(fusion_dequant_add_bn_relu_quant, + ops::FusionDequantAddBNReluQuantMatcher); + +#ifdef PADDLE_MOBILE_CPU +REGISTER_OPERATOR_CPU(fusion_dequant_add_bn_relu_quant, + ops::FusionDequantAddBNReluQuantOp); +#endif +#endif // FUSION_DEQUANT_ADD_BN_RELU_QUANT_OP + +#ifdef FUSION_DEQUANT_ADD_BN_QUANT_OP +namespace paddle_mobile { +namespace operators { + +template +void FusionDequantAddBNQuantOp::InferShape() const { + const auto& input_dims = this->param_.input_->dims(); + this->param_.output_->Resize(input_dims); +} + +} // namespace operators +} // namespace paddle_mobile + +namespace ops = paddle_mobile::operators; +REGISTER_FUSION_MATCHER(fusion_dequant_add_bn_quant, + ops::FusionDequantAddBNQuantMatcher); + +#ifdef PADDLE_MOBILE_CPU +REGISTER_OPERATOR_CPU(fusion_dequant_add_bn_quant, + ops::FusionDequantAddBNQuantOp); +#endif + +#endif // FUSION_DEQUANT_ADD_BN_QUANT_OP diff --git a/src/operators/fusion_dequant_add_bn_relu_quant_op.h b/src/operators/fusion_dequant_add_bn_relu_quant_op.h new file mode 100644 index 0000000000000000000000000000000000000000..23e548ea8157850b2f889e3503e56435d01fbb2b --- /dev/null +++ b/src/operators/fusion_dequant_add_bn_relu_quant_op.h @@ -0,0 +1,121 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
*/ + +#pragma once + +#include +#include +#include "framework/operator.h" +#include "framework/program/program-optimize/fusion_op_register.h" +#include "operators/kernel/dequant_bn_relu_kernel.h" +#include "operators/op_param.h" + +namespace paddle_mobile { +namespace operators { + +#ifdef FUSION_DEQUANT_ADD_BN_RELU_QUANT_OP +class FusionDequantAddBNReluQuantMatcher : public framework::FusionOpMatcher { + public: + FusionDequantAddBNReluQuantMatcher() { + node_ = framework::Node(G_OP_TYPE_DEQUANTIZE); + node_ > std::make_shared(G_OP_TYPE_ELEMENTWISE_ADD) > + std::make_shared(G_OP_TYPE_BATCHNORM) > + std::make_shared(G_OP_TYPE_RELU) > + std::make_shared(G_OP_TYPE_QUANTIZE); + } + + void FolderNodes( + framework::Node *node, + std::vector> *removed_nodes) { + node->Folder(node_.Depth(), Type(), + {{G_OP_TYPE_ELEMENTWISE_ADD, {{"Y", "Y"}}}, + {G_OP_TYPE_BATCHNORM, + {{"Scale", "BNScale"}, + {"Mean", "BNMean"}, + {"Bias", "BNBias"}, + {"Variance", "BNVariance"}}}}, + removed_nodes); + } + + std::string Type() { return G_OP_TYPE_FUSION_DEQUANT_ADD_BN_RELU_QUANT; } +}; + +template +class FusionDequantAddBNReluQuantOp + : public framework::OperatorWithKernel< + DeviceType, FusionDequantAddBNReluQuantParam, + operators::FusionDequantAddBNReluQuantKernel> { + public: + FusionDequantAddBNReluQuantOp(const std::string &type, + const VariableNameMap &inputs, + const VariableNameMap &outputs, + const framework::AttributeMap &attrs, + std::shared_ptr scope) + : framework::OperatorWithKernel< + DeviceType, FusionDequantAddBNReluQuantParam, + operators::FusionDequantAddBNReluQuantKernel>( + type, inputs, outputs, attrs, scope) {} + // inference output shape + void InferShape() const override; +}; +#endif // FUSION_DEQUANT_ADD_BN_RELU_QUANT_OP + +#ifdef FUSION_DEQUANT_ADD_BN_QUANT_OP +class FusionDequantAddBNQuantMatcher : public framework::FusionOpMatcher { + public: + FusionDequantAddBNQuantMatcher() { + node_ = framework::Node(G_OP_TYPE_DEQUANTIZE); + node_ > std::make_shared(G_OP_TYPE_ELEMENTWISE_ADD) > + std::make_shared(G_OP_TYPE_BATCHNORM) > + std::make_shared(G_OP_TYPE_QUANTIZE); + } + + void FolderNodes( + framework::Node *node, + std::vector> *removed_nodes) { + node->Folder(node_.Depth(), Type(), + {{G_OP_TYPE_ELEMENTWISE_ADD, {{"Y", "Y"}}}, + {G_OP_TYPE_BATCHNORM, + {{"Scale", "BNScale"}, + {"Mean", "BNMean"}, + {"Bias", "BNBias"}, + {"Variance", "BNVariance"}}}}, + removed_nodes); + } + + std::string Type() { return G_OP_TYPE_FUSION_DEQUANT_ADD_BN_QUANT; } +}; + +template +class FusionDequantAddBNQuantOp + : public framework::OperatorWithKernel< + DeviceType, FusionDequantAddBNQuantParam, + operators::FusionDequantAddBNQuantKernel> { + public: + FusionDequantAddBNQuantOp(const std::string &type, + const VariableNameMap &inputs, + const VariableNameMap &outputs, + const framework::AttributeMap &attrs, + std::shared_ptr scope) + : framework::OperatorWithKernel< + DeviceType, FusionDequantAddBNQuantParam, + operators::FusionDequantAddBNQuantKernel>( + type, inputs, outputs, attrs, scope) {} + // inference output shape + void InferShape() const override; +}; +#endif // FUSION_DEQUANT_ADD_BN_QUANT_OP + +} // namespace operators +} // namespace paddle_mobile diff --git a/src/operators/kernel/arm/dequant_add_bn_kernel.cpp b/src/operators/kernel/arm/dequant_add_bn_kernel.cpp index 65fb0190f76a34a584d065bd43841567e9658bb8..d8ccb75126d9be2569037be483c24a7d9cb7a407 100644 --- a/src/operators/kernel/arm/dequant_add_bn_kernel.cpp +++ b/src/operators/kernel/arm/dequant_add_bn_kernel.cpp @@ -67,7 +67,9 @@ 
void FusionDequantAddBNKernel<CPU, float>::Compute(
 #pragma omp parallel for collapse(2)
   for (int batch = 0; batch < batch_size; ++batch) {
     for (int c = 0; c < channels; ++c) {
-      float scale = bn_scale[c] * dequant_scale;
+      // do not fold the dequant scale into the bn scale, to minimize precision loss
+      // float scale = bn_scale[c] * dequant_scale;
+      float scale = bn_scale[c];
       float bias = bn_bias[c];
       size_t offset = (batch * channels + c) * spatial_size;
       const int32_t *x = input + offset;
@@ -76,9 +78,9 @@ void FusionDequantAddBNKernel<CPU, float>::Compute(
 #if defined(__ARM_NEON__) || defined(__ARM_NEON)
       int loop = spatial_size >> 4;
       remain = spatial_size & 0xF;
+      float32x4_t __dequant_scale = vdupq_n_f32(dequant_scale);
       float32x4_t __scale = vdupq_n_f32(scale);
       float32x4_t __bias = vdupq_n_f32(bias);
-
       for (int k = 0; k < loop; ++k, x += 16, y += 16) {
         int32x4_t r0 = vld1q_s32(x);
         int32x4_t r1 = vld1q_s32(x + 4);
@@ -88,6 +90,10 @@ void FusionDequantAddBNKernel<CPU, float>::Compute(
         float32x4_t f1 = vcvtq_f32_s32(r1);
         float32x4_t f2 = vcvtq_f32_s32(r2);
         float32x4_t f3 = vcvtq_f32_s32(r3);
+        f0 = vmulq_f32(__dequant_scale, f0);
+        f1 = vmulq_f32(__dequant_scale, f1);
+        f2 = vmulq_f32(__dequant_scale, f2);
+        f3 = vmulq_f32(__dequant_scale, f3);
         f0 = vmlaq_f32(__bias, __scale, f0);
         f1 = vmlaq_f32(__bias, __scale, f1);
         f2 = vmlaq_f32(__bias, __scale, f2);
@@ -99,7 +105,7 @@ void FusionDequantAddBNKernel<CPU, float>::Compute(
       }
 #endif  // __ARM_NEON__
       for (int k = 0; k < remain; ++k) {
-        y[k] = scale * x[k] + bias;
+        y[k] = scale * (dequant_scale * x[k]) + bias;
       }
     }
   }
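Annotation: the hunks above replace the single pre-folded multiplier (bn_scale * dequant_scale) with a two-step computation: the int32 accumulator is first rescaled to float by dequant_scale, and only then passed through the folded batch-norm transform. A standalone scalar sketch of the two variants, using hypothetical scale/bias values (not taken from the patch), showing that they agree up to float rounding:

#include <cstdint>
#include <cstdio>

int main() {
  // hypothetical values: one int32 accumulator from an int8 conv,
  // a dequantization scale and one channel's folded bn scale/bias
  const int32_t x = 12345;
  const float dequant_scale = 0.25f / 127;  // activation_scale / weight_scale
  const float bn_scale = 0.03f;             // gamma / sqrt(var + eps)
  const float bn_bias = -0.5f;              // folded bn bias

  // before: dequant scale folded into the per-channel multiplier
  float fused = (bn_scale * dequant_scale) * x + bn_bias;
  // after: dequantize first, then apply the bn affine transform
  float two_step = bn_scale * (dequant_scale * x) + bn_bias;

  // algebraically identical; they differ only in float rounding,
  // which is what the "minimize precision loss" comment refers to
  printf("fused: %.9f, two-step: %.9f, diff: %g\n", fused, two_step,
         fused - two_step);
  return 0;
}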
diff --git a/src/operators/kernel/arm/dequant_bn_relu_kernel.cpp b/src/operators/kernel/arm/dequant_bn_relu_kernel.cpp
index 4d656712c193aa81a8be11c53856c868e2b82483..5f13333b54f305d0237c01e921d12c2e3bed0733 100644
--- a/src/operators/kernel/arm/dequant_bn_relu_kernel.cpp
+++ b/src/operators/kernel/arm/dequant_bn_relu_kernel.cpp
@@ -14,6 +14,7 @@ limitations under the License. */

 #include "operators/kernel/dequant_bn_relu_kernel.h"
 #include <cmath>
+#include "operators/math/quantize.h"
 #if defined(__ARM_NEON__) || defined(__ARM_NEON)
 #include <arm_neon.h>
 #endif
@@ -21,6 +22,31 @@ limitations under the License. */
 namespace paddle_mobile {
 namespace operators {

+#if defined(FUSION_DEQUANT_BN_RELU_OP) || \
+    defined(FUSION_DEQUANT_ADD_BN_RELU_OP) || \
+    defined(FUSION_DEQUANT_ADD_BN_RELU_QUANT_OP)
+void PublicFusionDequantBNInitParam(FusionDequantBNParam<CPU> *param,
+                                    const framework::Tensor *bias) {
+  // batch norm params
+  const Tensor *bn_mean = param->bn_mean_;
+  const Tensor *bn_variance = param->bn_variance_;
+  Tensor *bn_scale = param->bn_scale_;
+  Tensor *bn_bias = param->bn_bias_;
+  const float epsilon = param->epsilon_;
+
+  const float *mean_ptr = bn_mean->data<float>();
+  const float *var_ptr = bn_variance->data<float>();
+  float *bn_scale_ptr = bn_scale->mutable_data<float>();
+  float *bn_bias_ptr = bn_bias->mutable_data<float>();
+  for (int c = 0; c < bn_scale->numel(); ++c) {
+    float inv_scale = bn_scale_ptr[c] / (std::sqrt(var_ptr[c] + epsilon));
+    bn_scale_ptr[c] = inv_scale;
+    float val = bias ? bias->data<float>()[c] : 0;
+    bn_bias_ptr[c] = inv_scale * (val - mean_ptr[c]) + bn_bias_ptr[c];
+  }
+}
+#endif
+
 #if defined(FUSION_DEQUANT_BN_RELU_OP) || defined(FUSION_DEQUANT_ADD_BN_RELU_OP)
 void DequantBNReluCompute(const FusionDequantBNParam<CPU> *param) {
   const int32_t *input = param->input_->data<int32_t>();
@@ -39,7 +65,9 @@ void DequantBNReluCompute(const FusionDequantBNParam<CPU> *param) {
 #pragma omp parallel for collapse(2)
   for (int batch = 0; batch < batch_size; ++batch) {
     for (int c = 0; c < channels; ++c) {
-      float scale = bn_scale[c] * dequant_scale;
+      // do not fold the dequant scale into the bn scale, to minimize precision loss
+      // float scale = bn_scale[c] * dequant_scale;
+      float scale = bn_scale[c];
       float bias = bn_bias[c];
       size_t offset = (batch * channels + c) * spatial_size;
       const int32_t *x = input + offset;
@@ -48,10 +76,10 @@ void DequantBNReluCompute(const FusionDequantBNParam<CPU> *param) {
 #if defined(__ARM_NEON__) || defined(__ARM_NEON)
       int loop = spatial_size >> 4;
      remain = spatial_size & 0xF;
+      float32x4_t __dequant_scale = vdupq_n_f32(dequant_scale);
       float32x4_t __scale = vdupq_n_f32(scale);
       float32x4_t __bias = vdupq_n_f32(bias);
       float32x4_t __zero = vdupq_n_f32(0.f);
-
       for (int k = 0; k < loop; ++k, x += 16, y += 16) {
         int32x4_t r0 = vld1q_s32(x);
         int32x4_t r1 = vld1q_s32(x + 4);
@@ -61,6 +89,10 @@ void DequantBNReluCompute(const FusionDequantBNParam<CPU> *param) {
         float32x4_t f1 = vcvtq_f32_s32(r1);
         float32x4_t f2 = vcvtq_f32_s32(r2);
         float32x4_t f3 = vcvtq_f32_s32(r3);
+        f0 = vmulq_f32(__dequant_scale, f0);
+        f1 = vmulq_f32(__dequant_scale, f1);
+        f2 = vmulq_f32(__dequant_scale, f2);
+        f3 = vmulq_f32(__dequant_scale, f3);
         f0 = vmlaq_f32(__bias, __scale, f0);
         f1 = vmlaq_f32(__bias, __scale, f1);
         f2 = vmlaq_f32(__bias, __scale, f2);
@@ -76,7 +108,7 @@ void DequantBNReluCompute(const FusionDequantBNParam<CPU> *param) {
       }
 #endif  // __ARM_NEON__
       for (int k = 0; k < remain; ++k) {
-        y[k] = std::max(scale * x[k] + bias, 0.f);
+        y[k] = std::max(scale * (dequant_scale * x[k]) + bias, 0.f);
       }
     }
   }
@@ -87,22 +119,7 @@ void DequantBNReluCompute(const FusionDequantBNParam<CPU> *param) {
 template <>
 bool FusionDequantBNReluKernel<CPU, float>::Init(
     FusionDequantBNReluParam<CPU> *param) {
-  // batch norm params
-  const Tensor *bn_mean = param->bn_mean_;
-  const Tensor *bn_variance = param->bn_variance_;
-  Tensor *bn_scale = param->bn_scale_;
-  Tensor *bn_bias = param->bn_bias_;
-  const float epsilon = param->epsilon_;
-
-  const float *mean_ptr = bn_mean->data();
-  const float *var_ptr = bn_variance->data();
-  float *bn_scale_ptr = bn_scale->mutable_data();
-  float *bn_bias_ptr = bn_bias->mutable_data();
-  for (int c = 0; c < bn_scale->numel(); ++c) {
-    float inv_scale = bn_scale_ptr[c] / (std::sqrt(var_ptr[c] + epsilon));
-    bn_scale_ptr[c] = inv_scale;
-    bn_bias_ptr[c] = bn_bias_ptr[c] - inv_scale * mean_ptr[c];
-  }
+  PublicFusionDequantBNInitParam(param, nullptr);
   return true;
 }

@@ -117,25 +134,8 @@
 template <>
 bool FusionDequantAddBNReluKernel<CPU, float>::Init(
     FusionDequantAddBNReluParam<CPU> *param) {
-  // elementwise add params
-  const Tensor *bias = param->bias_;
-  // batch norm params
-  const Tensor *bn_mean = param->bn_mean_;
-  const Tensor *bn_variance = param->bn_variance_;
-  Tensor *bn_scale = param->bn_scale_;
-  Tensor *bn_bias = param->bn_bias_;
-  const float epsilon = param->epsilon_;
-
-  const float *bias_ptr = bias->data();
-  const float *mean_ptr = bn_mean->data();
-  const float *var_ptr = bn_variance->data();
-  float *bn_scale_ptr = bn_scale->mutable_data();
-  float *bn_bias_ptr =
bn_bias->mutable_data(); - for (int c = 0; c < bn_scale->numel(); ++c) { - float inv_scale = bn_scale_ptr[c] / (std::sqrt(var_ptr[c] + epsilon)); - bn_scale_ptr[c] = inv_scale; - bn_bias_ptr[c] = inv_scale * (bias_ptr[c] - mean_ptr[c]) + bn_bias_ptr[c]; - } + const framework::Tensor *bias = param->bias_; + PublicFusionDequantBNInitParam(param, bias); return true; } @@ -146,5 +146,248 @@ void FusionDequantAddBNReluKernel::Compute( } #endif // FUSION_DEQUANT_ADD_BN_RELU_OP +#ifdef FUSION_DEQUANT_ADD_BN_QUANT_OP +template +void DequantBNQuantCompute(const FusionDequantAddBNQuantParam *param) { + const int32_t *input = param->input_->data(); + const float *bn_scale = param->bn_scale_->data(); + const float *bn_bias = param->bn_bias_->data(); + // dequantize params + const float activation_scale = param->activation_scale_->data()[0]; + const float weight_scale = param->weight_scale_; + const float dequant_scale = activation_scale / weight_scale; + // quantize params + Tensor *output_scale = param->online_scale_; + float max_abs = 0.f; + + int8_t *output = param->output_->mutable_data(); + int batch_size = param->input_->dims()[0]; + int channels = param->input_->dims()[1]; + size_t spatial_size = param->input_->dims()[2] * param->input_->dims()[3]; + + // if (param->is_static_) { + if (true) { + max_abs = param->static_scale_; + float quant_scale = 127.f / max_abs; + #pragma omp parallel for collapse(2) + for (int batch = 0; batch < batch_size; ++batch) { + for (int c = 0; c < channels; ++c) { + // not fuse bn and dequant scale to minimize precision difference + // float scale = bn_scale[c] * dequant_scale; + float scale = bn_scale[c]; + float bias = bn_bias[c]; + size_t offset = (batch * channels + c) * spatial_size; + const int32_t *x = input + offset; + int8_t *y = output + offset; + size_t remain = spatial_size; +#if defined(__ARM_NEON__) || defined(__ARM_NEON) + int loop = spatial_size >> 4; + remain = spatial_size & 0xF; + float32x4_t __dequant_scale = vdupq_n_f32(dequant_scale); + float32x4_t __scale = vdupq_n_f32(scale); + float32x4_t __bias = vdupq_n_f32(bias); + float32x4_t __quant_scale = vdupq_n_f32(quant_scale); + for (int k = 0; k < loop; ++k, x += 16, y += 16) { + int32x4_t r0 = vld1q_s32(x); + int32x4_t r1 = vld1q_s32(x + 4); + int32x4_t r2 = vld1q_s32(x + 8); + int32x4_t r3 = vld1q_s32(x + 12); + float32x4_t f0 = vcvtq_f32_s32(r0); + float32x4_t f1 = vcvtq_f32_s32(r1); + float32x4_t f2 = vcvtq_f32_s32(r2); + float32x4_t f3 = vcvtq_f32_s32(r3); + f0 = vmulq_f32(__dequant_scale, f0); + f1 = vmulq_f32(__dequant_scale, f1); + f2 = vmulq_f32(__dequant_scale, f2); + f3 = vmulq_f32(__dequant_scale, f3); + f0 = vmlaq_f32(__bias, __scale, f0); + f1 = vmlaq_f32(__bias, __scale, f1); + f2 = vmlaq_f32(__bias, __scale, f2); + f3 = vmlaq_f32(__bias, __scale, f3); + f0 = vmulq_f32(__quant_scale, f0); + f1 = vmulq_f32(__quant_scale, f1); + f2 = vmulq_f32(__quant_scale, f2); + f3 = vmulq_f32(__quant_scale, f3); + int32x4_t q0 = math::vround_f32(f0); + int32x4_t q1 = math::vround_f32(f1); + int32x4_t q2 = math::vround_f32(f2); + int32x4_t q3 = math::vround_f32(f3); + int16x4_t d0 = vmovn_s32(q0); + int16x4_t d1 = vmovn_s32(q1); + int16x4_t d2 = vmovn_s32(q2); + int16x4_t d3 = vmovn_s32(q3); + int16x8_t q5 = vcombine_s16(d0, d1); + int16x8_t q6 = vcombine_s16(d2, d3); + int8x8_t d5 = vmovn_s16(q5); + int8x8_t d6 = vmovn_s16(q6); + vst1_s8(y, d5); + vst1_s8(y + 8, d6); + } +#endif // __ARM_NEON__ + for (int k = 0; k < remain; ++k) { + float x_temp = scale * (dequant_scale * x[k]) + bias; + y[k] = 
math::Round(x_temp * quant_scale); + } + } + } + } else { + // TODO(hjchen2) + max_abs = std::max(max_abs, 1e-6f); + } + param->online_scale_->mutable_data()[0] = max_abs; +} + +template <> +bool FusionDequantAddBNQuantKernel::Init( + FusionDequantAddBNQuantParam *param) { + const framework::Tensor *bias = param->bias_; + PublicFusionDequantBNInitParam(param, bias); + return true; +} + +template <> +void FusionDequantAddBNQuantKernel::Compute( + const FusionDequantAddBNQuantParam ¶m) { + switch (param.round_type_) { + case ROUND_NEAREST_TO_EVEN: + DequantBNQuantCompute(¶m); + break; + case ROUND_NEAREST_TOWARDS_ZERO: + DequantBNQuantCompute(¶m); + break; + case ROUND_NEAREST_AWAY_ZERO: + DequantBNQuantCompute(¶m); + break; + default: + LOG(kLOG_ERROR) << "round type is not supported."; + break; + } +} +#endif // FUSION_DEQUANT_ADD_BN_QUANT_OP + +#ifdef FUSION_DEQUANT_ADD_BN_RELU_QUANT_OP +template +void DequantBNReluQuantCompute( + const FusionDequantAddBNReluQuantParam *param) { + const int32_t *input = param->input_->data(); + const float *bn_scale = param->bn_scale_->data(); + const float *bn_bias = param->bn_bias_->data(); + // dequantize params + const float activation_scale = param->activation_scale_->data()[0]; + const float weight_scale = param->weight_scale_; + const float dequant_scale = activation_scale / weight_scale; + // quantize params + Tensor *output_scale = param->online_scale_; + float max_abs = 0.f; + + int8_t *output = param->output_->mutable_data(); + int batch_size = param->input_->dims()[0]; + int channels = param->input_->dims()[1]; + size_t spatial_size = param->input_->dims()[2] * param->input_->dims()[3]; + + // if (param->is_static_) { + if (true) { + max_abs = param->static_scale_; + float quant_scale = 127.f / max_abs; + #pragma omp parallel for collapse(2) + for (int batch = 0; batch < batch_size; ++batch) { + for (int c = 0; c < channels; ++c) { + // float scale = bn_scale[c] * dequant_scale; + float scale = bn_scale[c]; + float bias = bn_bias[c]; + size_t offset = (batch * channels + c) * spatial_size; + const int32_t *x = input + offset; + int8_t *y = output + offset; + size_t remain = spatial_size; +#if defined(__ARM_NEON__) || defined(__ARM_NEON) + int loop = spatial_size >> 4; + remain = spatial_size & 0xF; + float32x4_t __dequant_scale = vdupq_n_f32(dequant_scale); + float32x4_t __scale = vdupq_n_f32(scale); + float32x4_t __bias = vdupq_n_f32(bias); + float32x4_t __zero = vdupq_n_f32(0.f); + float32x4_t __quant_scale = vdupq_n_f32(quant_scale); + for (int k = 0; k < loop; ++k, x += 16, y += 16) { + int32x4_t r0 = vld1q_s32(x); + int32x4_t r1 = vld1q_s32(x + 4); + int32x4_t r2 = vld1q_s32(x + 8); + int32x4_t r3 = vld1q_s32(x + 12); + float32x4_t f0 = vcvtq_f32_s32(r0); + float32x4_t f1 = vcvtq_f32_s32(r1); + float32x4_t f2 = vcvtq_f32_s32(r2); + float32x4_t f3 = vcvtq_f32_s32(r3); + f0 = vmulq_f32(__dequant_scale, f0); + f1 = vmulq_f32(__dequant_scale, f1); + f2 = vmulq_f32(__dequant_scale, f2); + f3 = vmulq_f32(__dequant_scale, f3); + f0 = vmlaq_f32(__bias, __scale, f0); + f1 = vmlaq_f32(__bias, __scale, f1); + f2 = vmlaq_f32(__bias, __scale, f2); + f3 = vmlaq_f32(__bias, __scale, f3); + f0 = vmaxq_f32(__zero, f0); + f1 = vmaxq_f32(__zero, f1); + f2 = vmaxq_f32(__zero, f2); + f3 = vmaxq_f32(__zero, f3); + f0 = vmulq_f32(__quant_scale, f0); + f1 = vmulq_f32(__quant_scale, f1); + f2 = vmulq_f32(__quant_scale, f2); + f3 = vmulq_f32(__quant_scale, f3); + int32x4_t q0 = math::vround_f32(f0); + int32x4_t q1 = math::vround_f32(f1); + int32x4_t q2 = 
math::vround_f32(f2); + int32x4_t q3 = math::vround_f32(f3); + int16x4_t d0 = vmovn_s32(q0); + int16x4_t d1 = vmovn_s32(q1); + int16x4_t d2 = vmovn_s32(q2); + int16x4_t d3 = vmovn_s32(q3); + int16x8_t q5 = vcombine_s16(d0, d1); + int16x8_t q6 = vcombine_s16(d2, d3); + int8x8_t d5 = vmovn_s16(q5); + int8x8_t d6 = vmovn_s16(q6); + vst1_s8(y, d5); + vst1_s8(y + 8, d6); + } +#endif // __ARM_NEON__ + for (int k = 0; k < remain; ++k) { + float x_temp = std::max(scale * (dequant_scale * x[k]) + bias, 0.f); + y[k] = math::Round(x_temp * quant_scale); + } + } + } + } else { + // TODO(hjchen2) + max_abs = std::max(max_abs, 1e-6f); + } + param->online_scale_->mutable_data()[0] = max_abs; +} + +template <> +bool FusionDequantAddBNReluQuantKernel::Init( + FusionDequantAddBNReluQuantParam *param) { + const framework::Tensor *bias = param->bias_; + PublicFusionDequantBNInitParam(param, bias); + return true; +} + +template <> +void FusionDequantAddBNReluQuantKernel::Compute( + const FusionDequantAddBNReluQuantParam ¶m) { + switch (param.round_type_) { + case ROUND_NEAREST_TO_EVEN: + DequantBNReluQuantCompute(¶m); + break; + case ROUND_NEAREST_TOWARDS_ZERO: + DequantBNReluQuantCompute(¶m); + break; + case ROUND_NEAREST_AWAY_ZERO: + DequantBNReluQuantCompute(¶m); + break; + default: + LOG(kLOG_ERROR) << "round type is not supported."; + break; + } +} +#endif // FUSION_DEQUANT_ADD_BN_RELU_QUANT_OP + } // namespace operators } // namespace paddle_mobile diff --git a/src/operators/kernel/arm/quantize_kernel.cpp b/src/operators/kernel/arm/quantize_kernel.cpp index ca3fa71f98f778752ac9dd7728385f5525696a02..ce5d952517e1c665faa42793bef2b273b779d5cb 100644 --- a/src/operators/kernel/arm/quantize_kernel.cpp +++ b/src/operators/kernel/arm/quantize_kernel.cpp @@ -16,6 +16,7 @@ limitations under the License. 
*/ #include "operators/kernel/quantize_kernel.h" #include +#include "operators/math/quantize.h" #if defined(__ARM_NEON__) || defined(__ARM_NEON) #include @@ -30,72 +31,6 @@ inline float32_t vmaxvq_f32(float32x4_t r) { } #endif -template -inline int32x4_t vround_f32(float32x4_t r) { - return vcvtq_s32_f32(r); -} - -template <> -inline int32x4_t vround_f32(float32x4_t r) { - float32x4_t plus = vdupq_n_f32(0.5); - float32x4_t minus = vdupq_n_f32(-0.5); - float32x4_t zero = vdupq_n_f32(0); - uint32x4_t more_than_zero = vcgtq_f32(r, zero); - float32x4_t temp = vbslq_f32(more_than_zero, plus, minus); - temp = vaddq_f32(r, temp); - int32x4_t ret = vcvtq_s32_f32(temp); - return ret; -} - -template <> -inline int32x4_t vround_f32(float32x4_t r) { - float32x4_t point5 = vdupq_n_f32(0.5); - int32x4_t one = vdupq_n_s32(1); - int32x4_t zero = vdupq_n_s32(0); - - int32x4_t rnd = vround_f32(r); - float32x4_t frnd = vcvtq_f32_s32(rnd); - frnd = vsubq_f32(frnd, r); - frnd = vabsq_f32(frnd); - uint32x4_t equal_point5 = vceqq_f32(frnd, point5); - int32x4_t abs_rnd = vabsq_s32(rnd); - abs_rnd = vandq_s32(abs_rnd, one); - uint32x4_t not_mod2 = vreinterpretq_u32_s32(abs_rnd); - uint32x4_t mask = vandq_u32(equal_point5, not_mod2); - uint32x4_t more_than_zero = vcgtq_s32(rnd, zero); - more_than_zero = vandq_u32(more_than_zero, vreinterpretq_u32_s32(one)); - mask = veorq_u32(more_than_zero, mask); - more_than_zero = veorq_u32(more_than_zero, vreinterpretq_u32_s32(one)); - mask = vaddq_u32(more_than_zero, mask); - int32x4_t smask = vreinterpretq_s32_u32(mask); - smask = vsubq_s32(smask, one); - rnd = vaddq_s32(rnd, smask); - return rnd; -} -#endif - -template -inline int8_t Round(const float &x) { - return static_cast(x); -} - -template <> -inline int8_t Round(const float &x) { - return std::round(x); -} - -template <> -inline int8_t Round(const float &x) { - float v = std::round(x); - int32_t q = static_cast(v); - if (std::abs(std::abs(q - v) - 0.5) <= 0) { - if (std::abs(q) % 2 != 0) { - q = q + ((q > 0) ? 
-1 : 1); - } - } - return static_cast(q); -} - template static void Quantize(const Tensor *input, const float scale, Tensor *output) { const float *x = input->data(); @@ -105,6 +40,7 @@ static void Quantize(const Tensor *input, const float scale, Tensor *output) { size_t loop = remain >> 4; remain = remain & 0xF; + float32x4_t __scale = vdupq_n_f32(scale); #pragma omp parallel for for (size_t i = 0; i < loop; ++i) { const float *local_x = x + (i << 4); @@ -113,14 +49,14 @@ static void Quantize(const Tensor *input, const float scale, Tensor *output) { float32x4_t r1 = vld1q_f32(local_x + 4); float32x4_t r2 = vld1q_f32(local_x + 8); float32x4_t r3 = vld1q_f32(local_x + 12); - r0 = vmulq_n_f32(r0, scale); - r1 = vmulq_n_f32(r1, scale); - r2 = vmulq_n_f32(r2, scale); - r3 = vmulq_n_f32(r3, scale); - int32x4_t q0 = vround_f32(r0); - int32x4_t q1 = vround_f32(r1); - int32x4_t q2 = vround_f32(r2); - int32x4_t q3 = vround_f32(r3); + r0 = vmulq_f32(r0, __scale); + r1 = vmulq_f32(r1, __scale); + r2 = vmulq_f32(r2, __scale); + r3 = vmulq_f32(r3, __scale); + int32x4_t q0 = math::vround_f32(r0); + int32x4_t q1 = math::vround_f32(r1); + int32x4_t q2 = math::vround_f32(r2); + int32x4_t q3 = math::vround_f32(r3); int16x4_t d0 = vmovn_s32(q0); int16x4_t d1 = vmovn_s32(q1); int16x4_t d2 = vmovn_s32(q2); @@ -136,7 +72,7 @@ static void Quantize(const Tensor *input, const float scale, Tensor *output) { y += (loop << 4); #endif for (size_t i = 0; i < remain; ++i) { - y[i] = Round(x[i] * scale); + y[i] = math::Round(x[i] * scale); } } @@ -171,6 +107,13 @@ float find_abs_max(const Tensor *input) { return max_abs; } +} // namespace operators +} // namespace paddle_mobile +#endif // __ARM_NEON__ + +namespace paddle_mobile { +namespace operators { + template <> bool QuantizeKernel::Init(QuantizeParam *param) { return true; @@ -182,8 +125,8 @@ void QuantizeKernel::Compute(const QuantizeParam ¶m) { Tensor *output = param.output_; Tensor *output_scale = param.online_scale_; float max_abs = 0.f; - if (param.is_static_) { - max_abs = param.static_scale_; + if (param.offline_) { + max_abs = param.offline_scale_->data()[0]; } else { max_abs = find_abs_max(input); } @@ -210,4 +153,4 @@ void QuantizeKernel::Compute(const QuantizeParam ¶m) { } // namespace operators } // namespace paddle_mobile -#endif +#endif // QUANT_OP diff --git a/src/operators/kernel/arm/relu_kernel.cpp b/src/operators/kernel/arm/relu_kernel.cpp index 8ee103484eb753913e5554b64d6dac523066322a..f1de895008678342bc7a6f6db7809fac4ba29982 100644 --- a/src/operators/kernel/arm/relu_kernel.cpp +++ b/src/operators/kernel/arm/relu_kernel.cpp @@ -15,11 +15,86 @@ limitations under the License. 
*/ #ifdef RELU_OP #include "operators/kernel/relu_kernel.h" -#include "operators/kernel/central-arm-func/relu_arm_func.h" +#if defined(__ARM_NEON__) || defined(__ARM_NEON) +#include +#endif namespace paddle_mobile { namespace operators { +enum ReluMode { + Relu = 0, + Relu6 = 1, + PRelu = 2, + LeakyRelu = 3, +}; + +#if defined(__ARM_NEON__) || defined(__ARM_NEON) +template +inline float32x4_t vRelu_f32(const float32x4_t &x) { + float32x4_t __zero = vdupq_n_f32(0.f); + return vmaxq_f32(__zero, x); +} + +template <> +inline float32x4_t vRelu_f32(const float32x4_t &x) { + float32x4_t __zero = vdupq_n_f32(0.f); + float32x4_t __six = vdupq_n_f32(6.f); + return vminq_f32(__six, vmaxq_f32(__zero, x)); +} +#endif + +template +inline float ReluFunc(const float &x) { + return std::max(x, 0.f); +} + +template <> +inline float ReluFunc(const float &x) { + return std::min(std::max(x, 0.f), 6.f); +} + +template +struct ReluCompute { + void operator()(const Tensor *input, Tensor *output) {} +}; + +template +struct ReluCompute { + void operator()(const Tensor *input, Tensor *output) { + const float *x = input->data(); + float *y = output->mutable_data(); + size_t remain = input->numel(); +#if defined(__ARM_NEON__) || defined(__ARM_NEON) + size_t loop = remain >> 4; + remain = remain & 0xF; + + #pragma omp parallel for + for (size_t i = 0; i < loop; ++i) { + const float *local_x = x + (i << 4); + float *local_y = y + (i << 4); + float32x4_t r0 = vld1q_f32(local_x); + float32x4_t r1 = vld1q_f32(local_x + 4); + float32x4_t r2 = vld1q_f32(local_x + 8); + float32x4_t r3 = vld1q_f32(local_x + 12); + r0 = vRelu_f32(r0); + r1 = vRelu_f32(r1); + r2 = vRelu_f32(r2); + r3 = vRelu_f32(r3); + vst1q_f32(local_y, r0); + vst1q_f32(local_y + 4, r1); + vst1q_f32(local_y + 8, r2); + vst1q_f32(local_y + 12, r3); + } + x += (loop << 4); + y += (loop << 4); +#endif + for (size_t i = 0; i < remain; ++i) { + y[i] = ReluFunc(x[i]); + } + } +}; + template <> bool ReluKernel::Init(ReluParam *param) { return true; @@ -27,7 +102,21 @@ bool ReluKernel::Init(ReluParam *param) { template <> void ReluKernel::Compute(const ReluParam ¶m) { - ReluCompute(param); + const Tensor *input = param.InputX(); + Tensor *output = param.Out(); + ReluCompute()(input, output); +} + +template <> +bool Relu6Kernel::Init(ReluParam *param) { + return true; +} + +template <> +void Relu6Kernel::Compute(const ReluParam ¶m) { + const Tensor *input = param.InputX(); + Tensor *output = param.Out(); + ReluCompute()(input, output); } } // namespace operators diff --git a/src/operators/kernel/arm/transpose2_kernel.cpp b/src/operators/kernel/arm/transpose2_kernel.cpp index 228f210ea1c52f1bfe601bd46f741347dabd6cce..6928df71e65dfe0aa28ceddda8f379f60fb65908 100644 --- a/src/operators/kernel/arm/transpose2_kernel.cpp +++ b/src/operators/kernel/arm/transpose2_kernel.cpp @@ -11,14 +11,111 @@ distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/
+
 #ifdef TRANSPOSE2_OP

 #include "operators/kernel/transpose2_kernel.h"
-#include "operators/kernel/central-arm-func/transpose2_arm_func.h"

 namespace paddle_mobile {
 namespace operators {

+bool IsShuffleChannel(const std::vector<int> &axis) {
+  bool is_shuffle_channel = true;
+  if (axis.size() > 2 && axis[0] == 0 && axis[1] == 2 && axis[2] == 1) {
+    for (int i = 3; i < axis.size(); ++i) {
+      if (axis[i] != i) {
+        is_shuffle_channel = false;
+        break;
+      }
+    }
+  } else {
+    return false;
+  }
+  return is_shuffle_channel;
+}
+
+template <typename Dtype>
+void ShuffleChannelCompute(const Transpose2Param<CPU> &param) {
+  const std::vector<int> &axis = param.Axis();
+  const Tensor *input = param.InputX();
+  const Dtype *input_ptr = input->data<Dtype>();
+  Tensor *output = param.Out();
+  Dtype *output_ptr = output->mutable_data<Dtype>();
+  // the rank of input and output must be >= 2 and <= 6.
+  const framework::DDim &in_dim = input->dims();
+  const framework::DDim &out_dim = output->dims();
+  size_t offset = 1;
+  for (int i = 3; i < axis.size(); ++i) {
+    offset *= in_dim[i];
+  }
+
+  #pragma omp parallel for collapse(3)
+  for (int batch = 0; batch < out_dim[0]; ++batch) {
+    for (int c1 = 0; c1 < out_dim[1]; ++c1) {
+      for (int c2 = 0; c2 < out_dim[2]; ++c2) {
+        size_t out_offset =
+            ((batch * out_dim[1] + c1) * out_dim[2] + c2) * offset;
+        size_t in_offset = ((batch * in_dim[1] + c2) * in_dim[2] + c1) * offset;
+        memcpy(output_ptr + out_offset, input_ptr + in_offset,
+               offset * sizeof(Dtype));
+      }
+    }
+  }
+}
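Annotation: IsShuffleChannel above recognizes permutations of the form (0, 2, 1, 3, ...), which only swap dims 1 and 2 and keep every trailing dim in place, so each inner block of `offset` contiguous elements can be moved with a single memcpy. A self-contained demo of the same index arithmetic on a hypothetical 1x2x3 tensor:

#include <cstdio>
#include <cstring>

int main() {
  // hypothetical input: shape [1, 2, 3] flattened, axis = {0, 2, 1},
  // i.e. the channel-shuffle pattern handled by ShuffleChannelCompute
  const int in[1 * 2 * 3] = {0, 1, 2, 3, 4, 5};  // in[c1][c2], 2x3
  int out[1 * 3 * 2] = {0};                      // out[c2][c1], 3x2
  const int d1 = 2, d2 = 3, offset = 1;  // offset = product of trailing dims

  for (int c1 = 0; c1 < d2; ++c1) {    // output dim 1 == input dim 2
    for (int c2 = 0; c2 < d1; ++c2) {  // output dim 2 == input dim 1
      // same index arithmetic as the kernel: swap c1/c2 between src and dst
      memcpy(out + (c1 * d1 + c2) * offset, in + (c2 * d2 + c1) * offset,
             offset * sizeof(int));
    }
  }
  for (int i = 0; i < 6; ++i) printf("%d ", out[i]);  // prints: 0 3 1 4 2 5
  printf("\n");
  return 0;
}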
+
+template <typename Dtype>
+void Transpose2Compute(const Transpose2Param<CPU> &param) {
+  const std::vector<int> &axis = param.Axis();
+  const Tensor *input = param.InputX();
+  const Dtype *input_ptr = input->data<Dtype>();
+  Tensor *output = param.Out();
+  Dtype *output_ptr = output->mutable_data<Dtype>();
+  // the rank of input and output must be >= 2 and <= 6.
+  const framework::DDim &in_dim = input->dims();
+  const framework::DDim &out_dim = output->dims();
+
+  // precompute the inverted output dims and strides
+  size_t rout_dim[6], strides[6];
+  int permute = axis.size();  // the permutation length must be >= 2 and <= 6.
+  for (int i = 0; i < permute; ++i) {
+    int k = permute - 1 - i;
+    strides[k] = 1;
+    for (int j = axis[i] + 1; j < permute; ++j) {
+      strides[k] *= in_dim[j];
+    }
+    rout_dim[k] = out_dim[i];
+  }
+  // unroll the first 2 dimensions
+  int remain_dim = 1;
+  for (int i = 2; i < out_dim.size(); ++i) {
+    remain_dim *= out_dim[i];
+  }
+
+  #pragma omp parallel for collapse(2)
+  for (int batch = 0; batch < out_dim[0]; ++batch) {
+    for (int j = 0; j < out_dim[1]; ++j) {
+      size_t offset = batch * strides[permute - 1] + j * strides[permute - 2];
+      Dtype *out_ptr = output_ptr + (batch * out_dim[1] + j) * remain_dim;
+      int indices[4] = {0, 0, 0, 0};
+      for (int k = 0; k < remain_dim; ++k) {
+        out_ptr[k] = input_ptr[offset];
+        indices[0] += 1;
+        offset += strides[0];
+        for (int p = 0; p < permute - 3; ++p) {
+          if (indices[p] == rout_dim[p]) {
+            indices[p + 1] += 1;
+            indices[p] = 0;
+            offset += strides[p + 1];
+            offset -= rout_dim[p] * strides[p];
+          } else {
+            break;
+          }
+        }
+      }
+    }
+  }
+}
+
 template <>
 bool Transpose2Kernel<CPU, float>::Init(Transpose2Param<CPU> *param) {
   return true;
@@ -26,10 +123,24 @@ bool Transpose2Kernel<CPU, float>::Init(Transpose2Param<CPU> *param) {

 template <>
 void Transpose2Kernel<CPU, float>::Compute(const Transpose2Param<CPU> &param) {
-  Transpose2Compute(param);
+  const std::vector<int> &axis = param.Axis();
+  bool shuffle_channel = IsShuffleChannel(axis);
+  if (shuffle_channel) {
+    if (param.InputX()->type() == typeid(int8_t)) {
+      ShuffleChannelCompute<int8_t>(param);
+    } else {
+      ShuffleChannelCompute<float>(param);
+    }
+  } else {
+    if (param.InputX()->type() == typeid(int8_t)) {
+      Transpose2Compute<int8_t>(param);
+    } else {
+      Transpose2Compute<float>(param);
+    }
+  }
 }

 }  // namespace operators
 }  // namespace paddle_mobile

-#endif
+#endif  // TRANSPOSE2_OP
diff --git a/src/operators/kernel/central-arm-func/relu_arm_func.h b/src/operators/kernel/central-arm-func/relu_arm_func.h
deleted file mode 100644
index 38b2e6f334b4b24460f72450b01e4bdc2a6ff616..0000000000000000000000000000000000000000
--- a/src/operators/kernel/central-arm-func/relu_arm_func.h
+++ /dev/null
@@ -1,141 +0,0 @@
-/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License. */
-
-#ifdef RELU_OP
-#pragma once
-
-#include
-#include "operators/op_param.h"
-#if defined(__ARM_NEON__) || defined(__ARM_NEON)
-#include
-#endif
-
-namespace paddle_mobile {
-namespace operators {
-
-template
-struct ReluFunctor {
-  inline T operator()(T in) const { return in > 0 ?
in : 0; } -}; - -/* - * @b 特化到具体平台的实现, param 从 op 层传入 - * */ -template -void ReluCompute(const ReluParam ¶m) { - const auto *input_x = param.InputX(); - auto *input_x_ptr = input_x->data(); - auto *out = param.Out(); - auto *out_ptr = out->mutable_data(); - - int numel = input_x->numel(); -#if defined(__ARM_NEON__) || defined(__ARM_NEON) -#if __aarch64__ - if (numel > 0) { - int loop = numel >> 0x4; - int remain = numel & 0xF; - float32x4_t zero = vdupq_n_f32(0.f); - for (int i = 0; i < loop; ++i) { - float32x4_t r0 = vld1q_f32(input_x_ptr); - float32x4_t r1 = vld1q_f32(input_x_ptr + 4); - float32x4_t r2 = vld1q_f32(input_x_ptr + 8); - float32x4_t r3 = vld1q_f32(input_x_ptr + 12); - r0 = vmaxq_f32(r0, zero); - r1 = vmaxq_f32(r1, zero); - r2 = vmaxq_f32(r2, zero); - r3 = vmaxq_f32(r3, zero); - vst1q_f32(out_ptr, r0); - vst1q_f32(out_ptr + 4, r1); - vst1q_f32(out_ptr + 8, r2); - vst1q_f32(out_ptr + 12, r3); - input_x_ptr += 16; - out_ptr += 16; - } - for (int i = 0; i < remain; ++i) { - out_ptr[i] = (input_x_ptr[i] > 0) * input_x_ptr[i]; - } -#else - if (numel > 64) { - asm volatile( - "pld [%[input_x_ptr], #0] \n\t" - "vmov.f32 q8, #0.0 \n\t" - "subs %[num], %[num], #32 \n\t" - "blt end_num_%= \n\t" - "loop_num_%=: \n\t" - "pld [%[input_x_ptr], #1024] \n\t" - - "vld1.32 {q0, q1}, [%[input_x_ptr]]! \n\t" - "vld1.32 {q2, q3}, [%[input_x_ptr]]! \n\t" - "vld1.32 {q4, q5}, [%[input_x_ptr]]! \n\t" - "vld1.32 {q6, q7}, [%[input_x_ptr]]! \n\t" - - "vmax.f32 q0, q0, q8 \n\t" - "vmax.f32 q1, q1, q8 \n\t" - "vmax.f32 q2, q2, q8 \n\t" - "vmax.f32 q3, q3, q8 \n\t" - "vmax.f32 q4, q4, q8 \n\t" - "vmax.f32 q5, q5, q8 \n\t" - "vmax.f32 q6, q6, q8 \n\t" - "vmax.f32 q7, q7, q8 \n\t" - - "vst1.32 {q0, q1}, [%[out_ptr]]! \n\t" - "vst1.32 {q2, q3}, [%[out_ptr]]! \n\t" - "vst1.32 {q4, q5}, [%[out_ptr]]! \n\t" - "vst1.32 {q6, q7}, [%[out_ptr]]! \n\t" - - "subs %[num], %[num], #32 \n\t" - "bge loop_num_%= \n\t" - "end_num_%=: \n\t" - "cmp %[num], #0 \n\t" - "bge end_%= \n\t" - "mov r6, #4 \n\t" - "mul r5, %[num], r6 \n\t" - "add %[input_x_ptr], %[input_x_ptr], r5 \n\t" - "vld1.32 {q0, q1}, [%[input_x_ptr]]! \n\t" - "vld1.32 {q2, q3}, [%[input_x_ptr]]! \n\t" - "vld1.32 {q4, q5}, [%[input_x_ptr]]! \n\t" - "vld1.32 {q6, q7}, [%[input_x_ptr]]! \n\t" - "vmax.f32 q0, q0, q8 \n\t" - "vmax.f32 q1, q1, q8 \n\t" - "vmax.f32 q2, q2, q8 \n\t" - "vmax.f32 q3, q3, q8 \n\t" - "vmax.f32 q4, q4, q8 \n\t" - "vmax.f32 q5, q5, q8 \n\t" - "vmax.f32 q6, q6, q8 \n\t" - "vmax.f32 q7, q7, q8 \n\t" - "add %[out_ptr], %[out_ptr], r5 \n\t" - "vst1.32 {q0, q1}, [%[out_ptr]]! \n\t" - "vst1.32 {q2, q3}, [%[out_ptr]]! \n\t" - "vst1.32 {q4, q5}, [%[out_ptr]]! \n\t" - "vst1.32 {q6, q7}, [%[out_ptr]]! \n\t" - "end_%=: \n\t" - : - : - [out_ptr] "r"(out_ptr), [input_x_ptr] "r"(input_x_ptr), [num] "r"(numel) - : "memory", "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7", "q8", "r5", - "r6"); -#endif - } else { -#endif - ReluFunctor func_; - math::Transform trans; - trans(input_x_ptr, input_x_ptr + numel, out_ptr, func_); -#if defined(__ARM_NEON__) || defined(__ARM_NEON) - } -#endif -} -} // namespace operators -} // namespace paddle_mobile - -#endif diff --git a/src/operators/kernel/central-arm-func/transpose2_arm_func.h b/src/operators/kernel/central-arm-func/transpose2_arm_func.h deleted file mode 100644 index dea90e863b20f19820d60d9cce67b6849d3c467b..0000000000000000000000000000000000000000 --- a/src/operators/kernel/central-arm-func/transpose2_arm_func.h +++ /dev/null @@ -1,70 +0,0 @@ -/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. 
- -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ - -#ifdef TRANSPOSE2_OP -#pragma once - -#include -#include "operators/op_param.h" - -namespace paddle_mobile { -namespace operators { - -template -void Transpose2Compute(const Transpose2Param& param) { - const auto* input_x = param.InputX(); - const auto input_x_dims = input_x->dims(); - auto* out = param.Out(); - const auto axis = param.Axis(); - const auto* input_x_data = input_x->data(); - auto* out_data = out->mutable_data(); - - size_t ndim = axis.size(); - std::vector xdim(ndim); - std::vector xstride(ndim); - std::vector xout(ndim); - for (int i = 0; i < ndim; i++) { - int j = ndim - 1 - i; - xdim[j] = input_x_dims[axis[i]]; - xstride[j] = 1; - for (int k = axis[i] + 1; k < ndim; k++) { - xstride[j] *= input_x_dims[k]; - } - xout[j] = xstride[j] * xdim[j]; - } - - auto numel = input_x->numel(); - size_t pind = 0; - std::vector ind(ndim); - for (int i = 0; i < numel; i++) { - out_data[i] = input_x_data[pind]; - ind[0]++; - pind += xstride[0]; - for (int j = 0; j < ndim - 1; j++) { - if (ind[j] == xdim[j]) { - ind[j + 1]++; - ind[j] = 0; - pind += xstride[j + 1]; - pind -= xout[j]; - } else { - break; - } - } - } -} - -} // namespace operators -} // namespace paddle_mobile - -#endif diff --git a/src/operators/kernel/central-arm-func/transpose_arm_func.h b/src/operators/kernel/central-arm-func/transpose_arm_func.h index 1bd2e11a3405abc99c5a33be4ec9b61855f77b08..ef3d38eff23a44accc7ab71eb2095ff4a78c1571 100644 --- a/src/operators/kernel/central-arm-func/transpose_arm_func.h +++ b/src/operators/kernel/central-arm-func/transpose_arm_func.h @@ -21,23 +21,6 @@ limitations under the License. 
*/ namespace paddle_mobile { namespace operators { -// vector pos; -// template -// void TransposeFunc(const int numel, const T* input, const vector axis, -// const vector old_strides, const vector -// new_strides, T* output) { -// for (int i = 0; i < numel; ++i) { -// int old_idx = 0; -// int idx = i; -// for (int j = 0; j < axis.size(); ++j) { -// int order = axis[j]; -// old_idx += (idx / new_strides[j]) * old_strides[order]; -// idx %= new_strides[j]; -// } -// output[i] = input[old_idx]; -// } -// } - template void TransposeCompute(const TransposeParam& param) { const auto* input_x = param.InputX(); diff --git a/src/operators/kernel/dequant_bn_relu_kernel.h b/src/operators/kernel/dequant_bn_relu_kernel.h index edea449dd68db474b14b02304bbdf63768e1bfb0..8cc8419dc61a6367959014a196d316b779bc9392 100644 --- a/src/operators/kernel/dequant_bn_relu_kernel.h +++ b/src/operators/kernel/dequant_bn_relu_kernel.h @@ -42,5 +42,27 @@ class FusionDequantAddBNReluKernel }; #endif +#ifdef FUSION_DEQUANT_ADD_BN_RELU_QUANT_OP +template +class FusionDequantAddBNReluQuantKernel + : public framework::OpKernelBase< + DeviceType, FusionDequantAddBNReluQuantParam> { + public: + void Compute(const FusionDequantAddBNReluQuantParam ¶m); + bool Init(FusionDequantAddBNReluQuantParam *param); +}; +#endif + +#ifdef FUSION_DEQUANT_ADD_BN_QUANT_OP +template +class FusionDequantAddBNQuantKernel + : public framework::OpKernelBase> { + public: + void Compute(const FusionDequantAddBNQuantParam ¶m); + bool Init(FusionDequantAddBNQuantParam *param); +}; +#endif + } // namespace operators } // namespace paddle_mobile diff --git a/src/operators/kernel/feed_kernel.h b/src/operators/kernel/feed_kernel.h index 2b1220fee534040e5ccae5aee84adf3b4b6290b9..2f6fb6b31d9f9d29aa50104fe217869380cfb7ad 100644 --- a/src/operators/kernel/feed_kernel.h +++ b/src/operators/kernel/feed_kernel.h @@ -19,7 +19,7 @@ limitations under the License. */ namespace paddle_mobile { namespace operators { -using namespace framework; + template class FeedKernel : public framework::OpKernelBase> { diff --git a/src/operators/kernel/relu_kernel.h b/src/operators/kernel/relu_kernel.h index 48f47c2de6df8d3aa9461fba915fd1a6406d4b9f..e9473ee63bdc297d0789c15f2fcad79fb29c143f 100644 --- a/src/operators/kernel/relu_kernel.h +++ b/src/operators/kernel/relu_kernel.h @@ -17,7 +17,6 @@ limitations under the License. */ #pragma once #include "framework/operator.h" - #include "operators/op_param.h" namespace paddle_mobile { @@ -30,6 +29,15 @@ class ReluKernel void Compute(const ReluParam& param); bool Init(ReluParam* param); }; + +template +class Relu6Kernel + : public framework::OpKernelBase> { + public: + void Compute(const ReluParam& param); + bool Init(ReluParam* param); +}; + } // namespace operators } // namespace paddle_mobile diff --git a/src/operators/math/quantize.h b/src/operators/math/quantize.h new file mode 100644 index 0000000000000000000000000000000000000000..be88d1d9ca889294d19b5712c2403e768f926137 --- /dev/null +++ b/src/operators/math/quantize.h @@ -0,0 +1,100 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#ifdef QUANT_OP
+
+#pragma once
+
+#include <cmath>
+#include "common/types.h"
+#if defined(__ARM_NEON__) || defined(__ARM_NEON)
+#include <arm_neon.h>
+#endif
+
+namespace paddle_mobile {
+namespace operators {
+namespace math {
+
+template <RoundType R>
+inline int8_t Round(const float &x) {
+  return static_cast<int8_t>(x);
+}
+
+template <>
+inline int8_t Round<ROUND_NEAREST_AWAY_ZERO>(const float &x) {
+  return std::round(x);
+}
+
+template <>
+inline int8_t Round<ROUND_NEAREST_TO_EVEN>(const float &x) {
+  float v = std::round(x);
+  int32_t q = static_cast<int32_t>(v);
+  if (std::abs(std::abs(q - x) - 0.5) <= 0) {
+    if (std::abs(q) % 2 != 0) {
+      q = q + ((q > 0) ? -1 : 1);
+    }
+  }
+  return static_cast<int8_t>(q);
+}
+
+#if defined(__ARM_NEON__) || defined(__ARM_NEON)
+template <RoundType R>
+inline int32x4_t vround_f32(float32x4_t r) {
+  return vcvtq_s32_f32(r);
+}
+
+template <>
+inline int32x4_t vround_f32<ROUND_NEAREST_AWAY_ZERO>(float32x4_t r) {
+  float32x4_t plus = vdupq_n_f32(0.5);
+  float32x4_t minus = vdupq_n_f32(-0.5);
+  float32x4_t zero = vdupq_n_f32(0);
+  uint32x4_t more_than_zero = vcgtq_f32(r, zero);
+  float32x4_t temp = vbslq_f32(more_than_zero, plus, minus);
+  temp = vaddq_f32(r, temp);
+  int32x4_t ret = vcvtq_s32_f32(temp);
+  return ret;
+}
+
+template <>
+inline int32x4_t vround_f32<ROUND_NEAREST_TO_EVEN>(float32x4_t r) {
+  float32x4_t point5 = vdupq_n_f32(0.5);
+  int32x4_t one = vdupq_n_s32(1);
+  int32x4_t zero = vdupq_n_s32(0);
+
+  int32x4_t rnd = vround_f32<ROUND_NEAREST_AWAY_ZERO>(r);
+  float32x4_t frnd = vcvtq_f32_s32(rnd);
+  frnd = vsubq_f32(frnd, r);
+  frnd = vabsq_f32(frnd);
+  uint32x4_t equal_point5 = vceqq_f32(frnd, point5);
+  int32x4_t abs_rnd = vabsq_s32(rnd);
+  abs_rnd = vandq_s32(abs_rnd, one);
+  uint32x4_t not_mod2 = vreinterpretq_u32_s32(abs_rnd);
+  uint32x4_t mask = vandq_u32(equal_point5, not_mod2);
+  uint32x4_t more_than_zero = vcgtq_s32(rnd, zero);
+  more_than_zero = vandq_u32(more_than_zero, vreinterpretq_u32_s32(one));
+  mask = veorq_u32(more_than_zero, mask);
+  more_than_zero = veorq_u32(more_than_zero, vreinterpretq_u32_s32(one));
+  mask = vaddq_u32(more_than_zero, mask);
+  int32x4_t smask = vreinterpretq_s32_u32(mask);
+  smask = vsubq_s32(smask, one);
+  rnd = vaddq_s32(rnd, smask);
+  return rnd;
+}
+#endif  // __ARM_NEON__
+
+}  // namespace math
+}  // namespace operators
+}  // namespace paddle_mobile
+
+#endif  // QUANT_OP
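Annotation: the three RoundType policies above differ only on exact halfway values. (Note the `q - x` comparison in the to-even specialization: comparing q against v, the already-rounded value, would always give zero, and the tie adjustment would never run.) Below is a scalar, NEON-free sketch that mirrors the specializations and can be used to sanity-check them; the sample inputs are arbitrary:

#include <cmath>
#include <cstdint>
#include <cstdio>
#include <cstdlib>

// mirrors Round<ROUND_NEAREST_TOWARDS_ZERO>: plain truncation
static int8_t round_towards_zero(float x) { return static_cast<int8_t>(x); }

// mirrors Round<ROUND_NEAREST_AWAY_ZERO>: std::round, ties go away from zero
static int8_t round_away_zero(float x) { return std::round(x); }

// mirrors Round<ROUND_NEAREST_TO_EVEN>: ties go to the even neighbor
static int8_t round_to_even(float x) {
  float v = std::round(x);
  int32_t q = static_cast<int32_t>(v);
  if (std::abs(std::abs(q - x) - 0.5f) <= 0) {  // exactly halfway
    if (std::abs(q) % 2 != 0) {                 // odd: step back toward zero
      q = q + ((q > 0) ? -1 : 1);
    }
  }
  return static_cast<int8_t>(q);
}

int main() {
  const float samples[] = {2.4f, 2.5f, 3.5f, -2.5f, -2.6f};
  for (float x : samples) {
    printf("%5.1f -> towards_zero: %3d, away_zero: %3d, to_even: %3d\n", x,
           round_towards_zero(x), round_away_zero(x), round_to_even(x));
  }
  // e.g. 2.5 -> 2 (towards zero), 3 (away from zero), 2 (to even)
  //      3.5 -> 3, 4, 4
  return 0;
}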
diff --git a/src/operators/op_param.h b/src/operators/op_param.h
index b1c3028fb089894e641bde4d015b13b5dc351db2..c00369cec7a17ef742420d20bcad786665992136 100644
--- a/src/operators/op_param.h
+++ b/src/operators/op_param.h
@@ -2530,18 +2530,13 @@ class QuantizeParam : public OpParam {
     // scale = max(abs(x))
     online_scale_ = OpParam::GetVarValue<GType>("OutScale", outputs, scope);
     // offline
-    if (HasAttr("static_scale", attrs)) {
-      is_static_ = true;
-      static_scale_ = GetAttr("static_scale", attrs);
+    if (OpParam::HasAttr("InScale", attrs)) {
+      offline_ = true;
+      offline_scale_ = OpParam::GetVarValue<GType>("InScale", inputs, scope);
     }
     // x = round(scale * x)
-    if (HasAttr("round_type", attrs)) {
-      round_type_ = GetAttr("round_type", attrs);
-    }
-    // get paddings
-    paddings_ = std::vector({0, 0});
-    if (HasAttr("paddings", attrs)) {
-      paddings_ = GetAttr>("paddings", attrs);
+    if (OpParam::HasAttr("round_type", attrs)) {
+      round_type_ = OpParam::GetAttr<RoundType>("round_type", attrs);
     }
   }

@@ -2551,17 +2546,13 @@
   // op output
   RType *output_;
   RType *online_scale_;
-  // if static scale or not
-  bool is_static_ = false;
-  // quantize scale
-  float static_scale_ = 1.0f;
+  // quantize offline scale
+  RType *offline_scale_;
+  // if offline scale or not
+  bool offline_ = false;
   // round method type
-  // nearest_zero and nearest_even is valid currently
   // RoundType round_type_ = ROUND_NEAREST_AWAY_ZERO;
   RoundType round_type_ = ROUND_NEAREST_TOWARDS_ZERO;
-  // optional paddings
-  std::vector paddings_;
-  int8_t padding_val_;
 };
 #endif

@@ -2580,10 +2571,10 @@ class DequantizeParam : public OpParam {
     }
     activation_scale_ = OpParam::GetVarValue<GType>("Scale", inputs, scope);
     // dequantization is performed as x = x / static_scale / online_scale
-    if (HasAttr("weight_scale", attrs)) {
-      weight_scale_ = GetAttr("weight_scale", attrs);
+    if (OpParam::HasAttr("weight_scale", attrs)) {
+      weight_scale_ = OpParam::GetAttr<float>("weight_scale", attrs);
     } else {
-      weight_scale_ = GetAttr("max_range", attrs);
+      weight_scale_ = OpParam::GetAttr<float>("max_range", attrs);
     }
   }

@@ -2597,9 +2588,11 @@
 };
 #endif

-#if defined(FUSION_DEQUANT_ADD_BN_OP) || \
-    defined(FUSION_DEQUANT_ADD_BN_RELU_OP) || \
-    defined(FUSION_DEQUANT_BN_RELU_OP) || defined(FUSION_DEQUANT_BN_OP)
+#if defined(FUSION_DEQUANT_ADD_BN_OP) || \
+    defined(FUSION_DEQUANT_ADD_BN_RELU_OP) || \
+    defined(FUSION_DEQUANT_BN_RELU_OP) || defined(FUSION_DEQUANT_BN_OP) || \
+    defined(FUSION_DEQUANT_ADD_BN_QUANT_OP) || \
+    defined(FUSION_DEQUANT_ADD_BN_RELU_QUANT_OP)
 template <typename Dtype>
 class FusionDequantBNParam : public DequantizeParam<Dtype> {
   typedef typename DtypeTensorTrait<Dtype>::gtype GType;
@@ -2632,7 +2625,10 @@ class FusionDequantBNParam : public DequantizeParam<Dtype> {
 };
 #endif

-#if defined(FUSION_DEQUANT_ADD_BN_RELU_OP) || defined(FUSION_DEQUANT_ADD_BN_OP)
+#if defined(FUSION_DEQUANT_ADD_BN_RELU_OP) || \
+    defined(FUSION_DEQUANT_ADD_BN_OP) || \
+    defined(FUSION_DEQUANT_ADD_BN_QUANT_OP) || \
+    defined(FUSION_DEQUANT_ADD_BN_RELU_QUANT_OP)
 template <typename Dtype>
 class FusionDequantAddBNParam : public FusionDequantBNParam<Dtype> {
   typedef typename DtypeTensorTrait<Dtype>::gtype GType;
@@ -2697,5 +2693,79 @@ class FusionDequantAddBNReluParam : public FusionDequantAddBNParam<Dtype> {
 };
 #endif

+#ifdef FUSION_DEQUANT_ADD_BN_QUANT_OP
+template <typename Dtype>
+class FusionDequantAddBNQuantParam : public FusionDequantAddBNParam<Dtype> {
+  typedef typename DtypeTensorTrait<Dtype>::gtype GType;
+  typedef typename DtypeTensorTrait<Dtype>::rtype RType;
+
+ public:
+  FusionDequantAddBNQuantParam(const VariableNameMap &inputs,
+                               const VariableNameMap &outputs,
+                               const AttributeMap &attrs, const Scope &scope)
+      : FusionDequantAddBNParam<Dtype>(inputs, outputs, attrs, scope) {
+    // scale output
+    online_scale_ = OpParam::GetVarValue<GType>("OutScale", outputs, scope);
+    // offline
+    if (OpParam::HasAttr("static_scale", attrs)) {
+      is_static_ = true;
+      static_scale_ = OpParam::GetAttr<float>("static_scale", attrs);
+    }
+    // x = round(scale * x)
+    if (OpParam::HasAttr("round_type", attrs)) {
+      round_type_ = OpParam::GetAttr<RoundType>("round_type", attrs);
+    }
+  }
+
+ public:
+  RType *online_scale_;
+  // if static scale or not
+  bool is_static_ = false;
+  // quantize scale
+  float static_scale_ = 1.0f;
+  // round method type
+  // RoundType round_type_ = ROUND_NEAREST_AWAY_ZERO;
+  RoundType round_type_ = ROUND_NEAREST_TOWARDS_ZERO;
+};
+#endif
+
+#ifdef FUSION_DEQUANT_ADD_BN_RELU_QUANT_OP
+template <typename Dtype>
+class FusionDequantAddBNReluQuantParam
+    : public FusionDequantAddBNReluParam<Dtype> {
+  typedef typename DtypeTensorTrait<Dtype>::gtype GType;
+  typedef typename DtypeTensorTrait<Dtype>::rtype RType;
+
+ public:
+  FusionDequantAddBNReluQuantParam(const VariableNameMap &inputs,
+                                   const VariableNameMap &outputs,
+                                   const AttributeMap &attrs,
+                                   const Scope &scope)
+      : FusionDequantAddBNReluParam<Dtype>(inputs,
outputs, attrs, scope) {
+    // scale output
+    online_scale_ = OpParam::GetVarValue<GType>("OutScale", outputs, scope);
+    // offline
+    if (OpParam::HasAttr("static_scale", attrs)) {
+      is_static_ = true;
+      static_scale_ = OpParam::GetAttr<float>("static_scale", attrs);
+    }
+    // x = round(scale * x)
+    if (OpParam::HasAttr("round_type", attrs)) {
+      round_type_ = OpParam::GetAttr<RoundType>("round_type", attrs);
+    }
+  }
+
+ public:
+  RType *online_scale_;
+  // if static scale or not
+  bool is_static_ = false;
+  // quantize scale
+  float static_scale_ = 1.0f;
+  // round method type
+  // RoundType round_type_ = ROUND_NEAREST_AWAY_ZERO;
+  RoundType round_type_ = ROUND_NEAREST_TOWARDS_ZERO;
+};
+#endif
+
 }  // namespace operators
 }  // namespace paddle_mobile
diff --git a/src/operators/quantize_op.cpp b/src/operators/quantize_op.cpp
index 6dd9d75af463753008b273b93253cb986eb90e80..bde99cfd5ab4b1c2f86c55cdb39dbf559c66d576 100644
--- a/src/operators/quantize_op.cpp
+++ b/src/operators/quantize_op.cpp
@@ -22,10 +22,7 @@ namespace operators {

 template <typename DeviceType, typename T>
 void QuantizeOp<DeviceType, T>::InferShape() const {
-  auto input_dims = this->param_.input_->dims();
-  const std::vector &paddings = this->param_.paddings_;
-  input_dims[2] += 2 * paddings[0];
-  input_dims[3] += 2 * paddings[1];
+  const auto &input_dims = this->param_.input_->dims();
   this->param_.output_->Resize(input_dims);
   auto scale_dims = framework::make_ddim(std::vector<int>{1});
   this->param_.online_scale_->Resize(scale_dims);
diff --git a/src/operators/relu_op.cpp b/src/operators/relu_op.cpp
index d6d83475ee7879f8bc967439dac2094df12c8617..7ceaa815cfb554be9fd2feccb2cc05c6bfa1aa33 100644
--- a/src/operators/relu_op.cpp
+++ b/src/operators/relu_op.cpp
@@ -24,17 +24,19 @@ void ReluOp<DeviceType, T>::InferShape() const {
   this->param_.Out()->Resize(input_dims);
 }

+template <typename DeviceType, typename T>
+void Relu6Op<DeviceType, T>::InferShape() const {
+  auto input_dims = this->param_.InputX()->dims();
+  this->param_.Out()->Resize(input_dims);
+}
+
 }  // namespace operators
 }  // namespace paddle_mobile

-/*
- * @b every op needs to be registered;
- * the argument of USE_OP and the first argument of REGISTER_OPERATOR
- * must match the op type stored in the model
- * */
 namespace ops = paddle_mobile::operators;
 #ifdef PADDLE_MOBILE_CPU
 REGISTER_OPERATOR_CPU(relu, ops::ReluOp);
+REGISTER_OPERATOR_CPU(relu6, ops::Relu6Op);
 #endif
 #ifdef PADDLE_MOBILE_MALI_GPU
 REGISTER_OPERATOR_MALI_GPU(relu, ops::ReluOp);
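Annotation: the relu6 op registered above clamps activations to [0, 6]; the Arm kernel in relu_kernel.cpp vectorizes it with vmaxq_f32/vminq_f32 over 16 floats per iteration. A scalar reference sketch of the two activation modes (not the kernel itself):

#include <algorithm>
#include <cstdio>

// scalar references for the Relu and Relu6 paths in relu_kernel.cpp
static float relu(float x) { return std::max(x, 0.f); }
static float relu6(float x) { return std::min(std::max(x, 0.f), 6.f); }

int main() {
  const float xs[] = {-3.f, 0.5f, 6.f, 9.f};
  for (float x : xs) {
    printf("x=%4.1f  relu=%4.1f  relu6=%4.1f\n", x, relu(x), relu6(x));
  }
  // relu6 differs from relu only above the saturation point:
  // x=9.0 -> relu=9.0, relu6=6.0
  return 0;
}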
diff --git a/src/operators/relu_op.h b/src/operators/relu_op.h
index 1c94a7f6d71484d0a4bd14e89d8518f6e73a660b..4bb67933db6ac1c174e267259df52b8eb79dbb35 100644
--- a/src/operators/relu_op.h
+++ b/src/operators/relu_op.h
@@ -25,25 +25,34 @@ limitations under the License. */

 namespace paddle_mobile {
 namespace operators {

-using paddle_mobile::framework::Tensor;
-
 template <typename DeviceType, typename T>
 class ReluOp
     : public framework::OperatorWithKernel<
           DeviceType, ReluParam<DeviceType>,
           operators::ReluKernel<DeviceType, T>> {
 public:
-  /*
-   * @b the op constructor: it calls the parent constructor and instantiates the op's own param struct
-   * */
  ReluOp(const std::string &type, const VariableNameMap &inputs,
         const VariableNameMap &outputs, const framework::AttributeMap &attrs,
         std::shared_ptr<framework::Scope> scope)
      : framework::OperatorWithKernel<DeviceType, ReluParam<DeviceType>,
                                      operators::ReluKernel<DeviceType, T>>(
            type, inputs, outputs, attrs, scope) {}
+  void InferShape() const override;
+};

- protected:
+template <typename DeviceType, typename T>
+class Relu6Op : public framework::OperatorWithKernel<
+                    DeviceType, ReluParam<DeviceType>,
+                    operators::Relu6Kernel<DeviceType, T>> {
+ public:
+  Relu6Op(const std::string &type, const VariableNameMap &inputs,
+          const VariableNameMap &outputs, const framework::AttributeMap &attrs,
+          std::shared_ptr<framework::Scope> scope)
+      : framework::OperatorWithKernel<DeviceType, ReluParam<DeviceType>,
+                                      operators::Relu6Kernel<DeviceType, T>>(
+            type, inputs, outputs, attrs, scope) {}
+
+  void InferShape() const override;
 };

 }  // namespace operators
diff --git a/test/net/test_benchmark.cpp b/test/net/test_benchmark.cpp
index 3378229d0fb95745fb7b779f3ce043198d77681b..5ee7f363e5d254c6e9b363593e20a0286cbe9000 100644
--- a/test/net/test_benchmark.cpp
+++ b/test/net/test_benchmark.cpp
@@ -59,6 +59,14 @@ int main(int argc, char* argv[]) {
   }
   auto time4 = time();
   std::cout << "predict cost :" << time_diff(time3, time4) / 10 << "ms\n";
+  std::ostringstream os;
+  os << "output tensor size: " << output->numel() << "\n"
+     << output->data<float>()[0];
+  for (int i = 1; i < output->numel(); ++i) {
+    os << ", " << output->data<float>()[i];
+  }
+  std::string output_str = os.str();
+  std::cout << output_str << std::endl;
 }
 return 0;
}
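Annotation on the test_benchmark.cpp hunk above: std::ostringstream keeps its write position at offset 0 when constructed from a string, so passing the label to the constructor would let the first `<<` overwrite it; the label is therefore streamed instead. A minimal demonstration with a hypothetical tensor size:

#include <iostream>
#include <sstream>

int main() {
  // write position starts at 0, so the prefix gets overwritten
  std::ostringstream broken("output tensor size: ");
  broken << 1000;
  std::cout << broken.str() << "\n";  // prints "1000ut tensor size: "

  // stream the prefix (or construct with std::ios_base::ate) instead
  std::ostringstream ok;
  ok << "output tensor size: " << 1000;
  std::cout << ok.str() << "\n";  // prints "output tensor size: 1000"
  return 0;
}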
*/ #include "../test_helper.h" #include "../test_include.h" -int main() { +int main(int argc, char* argv[]) { + if (argc < 2) { + std::cout << "Usage: ./test_benchmark feed_shape [thread_num] [use_fuse]\n" + << "feed_shape: input tensor shape, such as 1,3,224,224.\n" + << "thread_num: optional int, threads count, default is 1.\n" + << "use_fuse: optional bool, default is 0.\n"; + return 1; + } + int thread_num = 1; + bool optimize = false; + char* feed_shape = argv[1]; + if (argc >= 3) { + thread_num = atoi(argv[2]); + } + if (argc >= 4) { + optimize = atoi(argv[3]); + } #ifdef PADDLE_MOBILE_FPGA paddle_mobile::PaddleMobile paddle_mobile; #endif #ifdef PADDLE_MOBILE_CPU paddle_mobile::PaddleMobile paddle_mobile; #endif - - paddle_mobile.SetThreadNum(1); - bool optimize = true; + paddle_mobile.SetThreadNum(thread_num); auto time1 = time(); if (paddle_mobile.Load(g_googlenet, optimize)) { auto time2 = paddle_mobile::time(); @@ -34,6 +48,11 @@ int main() { std::vector input; std::vector output; std::vector dims{1, 3, 224, 224}; + if (feed_shape) { + sscanf(feed_shape, "%d,%d,%d", &dims[1], &dims[2], &dims[3]); + } + std::cout << "feed shape: [" << dims[0] << ", " << dims[1] << ", " + << dims[2] << ", " << dims[3] << "]\n"; GetInput(g_test_image_1x3x224x224, &input, dims); // warmup for (int i = 0; i < 10; ++i) { @@ -44,7 +63,6 @@ int main() { output = paddle_mobile.Predict(input, dims); } auto time4 = time(); - std::cout << "predict cost: " << time_diff(time3, time4) / 10 << "ms\n"; } return 0; diff --git a/tools/op.cmake b/tools/op.cmake index e2254c3261d53d142e77f09c001d9cbebb5f85ff..ad9ffe95a3f0c9409f44b2621e1ea16595ec916e 100644 --- a/tools/op.cmake +++ b/tools/op.cmake @@ -252,6 +252,8 @@ if(NOT FOUND_MATCH) set(FUSION_DEQUANT_ADD_BN_OP ON) set(FUSION_DEQUANT_BN_RELU_OP ON) set(FUSION_DEQUANT_ADD_BN_RELU_OP ON) + set(FUSION_DEQUANT_ADD_BN_QUANT_OP ON) + set(FUSION_DEQUANT_ADD_BN_RELU_QUANT_OP ON) endif() # option(BATCHNORM_OP "" ON) @@ -462,6 +464,12 @@ endif() if (FUSION_DEQUANT_ADD_BN_RELU_OP) add_definitions(-DFUSION_DEQUANT_ADD_BN_RELU_OP) endif() +if (FUSION_DEQUANT_ADD_BN_QUANT_OP) +# add_definitions(-DFUSION_DEQUANT_ADD_BN_QUANT_OP) +endif() +if (FUSION_DEQUANT_ADD_BN_RELU_QUANT_OP) +# add_definitions(-DFUSION_DEQUANT_ADD_BN_RELU_QUANT_OP) +endif() if (TANH_OP) @@ -476,4 +484,3 @@ endif() if (FUSION_DECONVADDRELU_OP) add_definitions(-DFUSION_DECONVADDRELU_OP) endif() -