diff --git a/src/common/types.cpp b/src/common/types.cpp index 36c93046c1c09d0ec5043ef9a7514dedf212e738..312e491a35681e2fc75584106160a4c79e22e372 100644 --- a/src/common/types.cpp +++ b/src/common/types.cpp @@ -71,6 +71,8 @@ const char *G_OP_TYPE_SUM = "sum"; const char *G_OP_TYPE_QUANTIZE = "quantize"; const char *G_OP_TYPE_DEQUANTIZE = "dequantize"; +const char *G_OP_TYPE_FUSION_DEQUANT_ADD_BN = "fusion_dequant_add_bn"; +const char *G_OP_TYPE_FUSION_DEQUANT_BN_RELU = "fusion_dequant_bn_relu"; const char *G_OP_TYPE_FUSION_DEQUANT_ADD_BN_RELU = "fusion_dequant_add_bn_relu"; const char *G_OP_TYPE_TANH = "tanh"; @@ -136,6 +138,8 @@ std::unordered_map< {G_OP_TYPE_ELEMENTWISE_MUL, {{"X", "Y"}, {"Out"}}}, {G_OP_TYPE_QUANTIZE, {{"X"}, {"Out", "OutScale"}}}, {G_OP_TYPE_DEQUANTIZE, {{"X", "Scale"}, {"Out"}}}, + {G_OP_TYPE_FUSION_DEQUANT_ADD_BN, {{"X", "Scale"}, {"Y"}}}, + {G_OP_TYPE_FUSION_DEQUANT_BN_RELU, {{"X", "Scale"}, {"Out"}}}, {G_OP_TYPE_FUSION_DEQUANT_ADD_BN_RELU, {{"X", "Scale"}, {"Out"}}}, {G_OP_TYPE_TANH, {{"X"}, {"Out"}}}, {G_OP_TYPE_FUSION_DECONV_RELU, {{"Input"}, {"Out"}}}, diff --git a/src/common/types.h b/src/common/types.h index 5704618c9e475781b22df4ae3a0ac3a994eb8c90..16ed1aef57432249b14c415b3a23042ca295b600 100644 --- a/src/common/types.h +++ b/src/common/types.h @@ -138,6 +138,8 @@ extern const char *G_OP_TYPE_ELEMENTWISE_MUL; extern const char *G_OP_TYPE_QUANTIZE; extern const char *G_OP_TYPE_DEQUANTIZE; +extern const char *G_OP_TYPE_FUSION_DEQUANT_ADD_BN; +extern const char *G_OP_TYPE_FUSION_DEQUANT_BN_RELU; extern const char *G_OP_TYPE_FUSION_DEQUANT_ADD_BN_RELU; extern const char *G_OP_TYPE_TANH; diff --git a/src/framework/load_ops.h b/src/framework/load_ops.h index 135ef9083e42271fe63cdc29ee53e876f532c287..2534217d58674f912f0e5da741dfcae41827edf1 100644 --- a/src/framework/load_ops.h +++ b/src/framework/load_ops.h @@ -233,6 +233,14 @@ LOAD_OP1(quantize, CPU); #ifdef DEQUANT_OP LOAD_OP1(dequantize, CPU); #endif +#ifdef FUSION_DEQUANT_ADD_BN_OP +LOAD_OP1(fusion_dequant_add_bn, CPU); +LOAD_FUSION_MATCHER(fusion_dequant_add_bn); +#endif +#ifdef FUSION_DEQUANT_BN_RELU_OP +LOAD_OP1(fusion_dequant_bn_relu, CPU); +LOAD_FUSION_MATCHER(fusion_dequant_bn_relu); +#endif #ifdef FUSION_DEQUANT_ADD_BN_RELU_OP LOAD_OP1(fusion_dequant_add_bn_relu, CPU); LOAD_FUSION_MATCHER(fusion_dequant_add_bn_relu); diff --git a/src/operators/depthwise_conv_op.h b/src/operators/depthwise_conv_op.h index 102d65670d3e50acd15745e95b85d7b843994ed7..26253e0e0a7d3c52808a691d4257e7074e1da6e2 100644 --- a/src/operators/depthwise_conv_op.h +++ b/src/operators/depthwise_conv_op.h @@ -18,7 +18,7 @@ limitations under the License. 
*/ #include #include "framework/operator.h" -#include "operators/kernel/depthwise_conv_kernel.h" +#include "operators/kernel/conv_kernel.h" namespace paddle_mobile { namespace operators { @@ -26,19 +26,16 @@ namespace operators { template class DepthwiseConvOp : public framework::OperatorWithKernel< DeviceType, ConvParam, - operators::DepthwiseConvKernel> { + operators::ConvKernel> { public: DepthwiseConvOp(const std::string &type, const VariableNameMap &inputs, const VariableNameMap &outputs, const framework::AttributeMap &attrs, std::shared_ptr scope) - : framework::OperatorWithKernel< - DeviceType, ConvParam, - operators::DepthwiseConvKernel>( + : framework::OperatorWithKernel, + operators::ConvKernel>( type, inputs, outputs, attrs, scope) {} void InferShape() const override; - - private: }; } // namespace operators diff --git a/src/operators/kernel/arm/depthwise_conv_kernel.cpp b/src/operators/fusion_dequant_add_bn_op.cpp similarity index 58% rename from src/operators/kernel/arm/depthwise_conv_kernel.cpp rename to src/operators/fusion_dequant_add_bn_op.cpp index 000d59baa8c804201cbd2e2a731c2077196b698f..4df50af22b0dc9e214b0cabe303bf70edf50c307 100644 --- a/src/operators/kernel/arm/depthwise_conv_kernel.cpp +++ b/src/operators/fusion_dequant_add_bn_op.cpp @@ -12,27 +12,27 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ -#ifdef DEPTHWISECONV_OP +#ifdef FUSION_DEQUANT_ADD_BN_OP -#include "operators/kernel/depthwise_conv_kernel.h" -#include "operators/kernel/central-arm-func/depthwise_conv_arm_func.h" +#include "operators/fusion_dequant_add_bn_op.h" namespace paddle_mobile { namespace operators { -template <> -bool DepthwiseConvKernel::Init(ConvParam *param) { - return true; +template +void FusionDequantAddBNOp::InferShape() const { + const auto& input_dims = this->param_.input_->dims(); + this->param_.output_->Resize(input_dims); } -template <> -void DepthwiseConvKernel::Compute(const ConvParam ¶m) { - DepthwiseConvCompute(param); -} - -template class DepthwiseConvKernel; - } // namespace operators } // namespace paddle_mobile +namespace ops = paddle_mobile::operators; +REGISTER_FUSION_MATCHER(fusion_dequant_add_bn, ops::FusionDequantAddBNMatcher); + +#ifdef PADDLE_MOBILE_CPU +REGISTER_OPERATOR_CPU(fusion_dequant_add_bn, ops::FusionDequantAddBNOp); +#endif + #endif diff --git a/src/operators/fusion_dequant_add_bn_op.h b/src/operators/fusion_dequant_add_bn_op.h new file mode 100644 index 0000000000000000000000000000000000000000..8c4f353a81705c41c75a5aff92f2637b92755a2c --- /dev/null +++ b/src/operators/fusion_dequant_add_bn_op.h @@ -0,0 +1,74 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
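
For reference, the new fusion_dequant_add_bn operator collapses dequantize -> elementwise_add -> batch_norm into a single pass over the int32 convolution output. A minimal scalar sketch of the per-channel math it is meant to reproduce, using the same parameter meanings as the ARM kernel later in this diff (activation_scale / weight_scale for dequantization, gamma/beta/mean/var/epsilon for batch norm); the function and argument names here are illustrative only, not the framework API:

#include <cmath>
#include <cstdint>
#include <vector>

// Illustrative scalar reference, not the optimized kernel: dequantize the
// int32 accumulators of one channel, add that channel's elementwise bias,
// then apply batch norm. The fused ARM kernel folds all of this into a
// single (scale, bias) pair per channel ahead of time.
std::vector<float> DequantAddBNReference(
    const std::vector<int32_t> &x,   // one channel of the conv output
    float activation_scale, float weight_scale,
    float bias,                      // elementwise_add bias for this channel
    float gamma, float beta, float mean, float var, float epsilon) {
  const float dequant_scale = activation_scale / weight_scale;
  std::vector<float> y(x.size());
  for (size_t i = 0; i < x.size(); ++i) {
    const float v = x[i] * dequant_scale + bias;                   // dequantize + add
    y[i] = gamma * (v - mean) / std::sqrt(var + epsilon) + beta;   // batch norm
  }
  return y;
}
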
*/ + +#ifdef FUSION_DEQUANT_ADD_BN_OP + +#pragma once + +#include +#include +#include "framework/operator.h" +#include "framework/program/program-optimize/fusion_op_register.h" +#include "operators/kernel/dequant_add_bn_kernel.h" +#include "operators/op_param.h" + +namespace paddle_mobile { +namespace operators { + +class FusionDequantAddBNMatcher : public framework::FusionOpMatcher { + public: + FusionDequantAddBNMatcher() { + node_ = framework::Node(G_OP_TYPE_DEQUANTIZE); + node_ > std::make_shared(G_OP_TYPE_ELEMENTWISE_ADD) > + std::make_shared(G_OP_TYPE_BATCHNORM); + } + + void FolderNodes( + framework::Node *node, + std::vector> *removed_nodes) { + node->Folder(node_.Depth(), Type(), + {{G_OP_TYPE_ELEMENTWISE_ADD, {{"Y", "Y"}}}, + {G_OP_TYPE_BATCHNORM, + {{"Scale", "BNScale"}, + {"Mean", "BNMean"}, + {"Bias", "BNBias"}, + {"Variance", "BNVariance"}}}}, + removed_nodes); + } + + std::string Type() { return G_OP_TYPE_FUSION_DEQUANT_ADD_BN; } +}; + +template +class FusionDequantAddBNOp + : public framework::OperatorWithKernel< + DeviceType, FusionDequantAddBNParam, + operators::FusionDequantAddBNKernel> { + public: + FusionDequantAddBNOp(const std::string &type, const VariableNameMap &inputs, + const VariableNameMap &outputs, + const framework::AttributeMap &attrs, + std::shared_ptr scope) + : framework::OperatorWithKernel< + DeviceType, FusionDequantAddBNParam, + operators::FusionDequantAddBNKernel>( + type, inputs, outputs, attrs, scope) {} + // inference output shape + void InferShape() const override; +}; + +} // namespace operators +} // namespace paddle_mobile + +#endif diff --git a/src/operators/fusion_dequant_add_bn_relu_op.h b/src/operators/fusion_dequant_add_bn_relu_op.h index dbd9ad0de2ece751ffd4da05cb09f0091a5755aa..b33d3c210ca56f27b769789fee08023ebb8c80de 100644 --- a/src/operators/fusion_dequant_add_bn_relu_op.h +++ b/src/operators/fusion_dequant_add_bn_relu_op.h @@ -20,7 +20,7 @@ limitations under the License. */ #include #include "framework/operator.h" #include "framework/program/program-optimize/fusion_op_register.h" -#include "operators/kernel/dequant_add_bn_relu_kernel.h" +#include "operators/kernel/dequant_bn_relu_kernel.h" #include "operators/op_param.h" namespace paddle_mobile { diff --git a/src/operators/kernel/depthwise_conv_kernel.h b/src/operators/fusion_dequant_bn_relu_op.cpp similarity index 56% rename from src/operators/kernel/depthwise_conv_kernel.h rename to src/operators/fusion_dequant_bn_relu_op.cpp index 3ee5bf86e97baa3970239e32b7fd5fc341e09f92..c843889a61a128c86915b14b0229ed172df2325b 100644 --- a/src/operators/kernel/depthwise_conv_kernel.h +++ b/src/operators/fusion_dequant_bn_relu_op.cpp @@ -12,29 +12,28 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
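
The matcher above declares the pattern dequantize -> elementwise_add -> batch_norm and, when it fires, folds the three nodes into one fusion_dequant_add_bn node, renaming the absorbed inputs so they do not collide. A toy sketch of that renaming step, under hypothetical stand-in types (the real folding is done by framework::Node::Folder; nothing here is the framework's API):

#include <map>
#include <string>

// Hypothetical stand-in: each op is just a map from input slot name to the
// variable feeding it. The fused node keeps the dequantize inputs ("X",
// "Scale") and absorbs the rest under the prefixed names used in the diff.
using Slots = std::map<std::string, std::string>;

Slots FoldDequantAddBN(const Slots &dequant, const Slots &add,
                       const Slots &bn) {
  Slots fused = dequant;              // "X", "Scale" pass through unchanged
  fused["Y"] = add.at("Y");           // elementwise_add bias
  fused["BNScale"] = bn.at("Scale");  // batch_norm params, renamed
  fused["BNMean"] = bn.at("Mean");
  fused["BNBias"] = bn.at("Bias");
  fused["BNVariance"] = bn.at("Variance");
  return fused;
}
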
*/ -#ifdef DEPTHWISECONV_OP +#ifdef FUSION_DEQUANT_BN_RELU_OP -#pragma once - -#include "framework/operator.h" -#include "operators/math/im2col.h" -#include "operators/math/math_function.h" -#include "operators/math/vol2col.h" -#include "operators/op_param.h" +#include "operators/fusion_dequant_bn_relu_op.h" namespace paddle_mobile { namespace operators { -using framework::OpKernelBase; +template +void FusionDequantBNReluOp::InferShape() const { + const auto& input_dims = this->param_.input_->dims(); + this->param_.output_->Resize(input_dims); +} -template -class DepthwiseConvKernel - : public OpKernelBase> { - public: - void Compute(const ConvParam ¶m); - bool Init(ConvParam *param); -}; } // namespace operators } // namespace paddle_mobile +namespace ops = paddle_mobile::operators; +REGISTER_FUSION_MATCHER(fusion_dequant_bn_relu, + ops::FusionDequantBNReluMatcher); + +#ifdef PADDLE_MOBILE_CPU +REGISTER_OPERATOR_CPU(fusion_dequant_bn_relu, ops::FusionDequantBNReluOp); +#endif + #endif diff --git a/src/operators/fusion_dequant_bn_relu_op.h b/src/operators/fusion_dequant_bn_relu_op.h new file mode 100644 index 0000000000000000000000000000000000000000..b556df1e3707736be0eaf58eb8323cdbb64cbd74 --- /dev/null +++ b/src/operators/fusion_dequant_bn_relu_op.h @@ -0,0 +1,73 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
*/ + +#ifdef FUSION_DEQUANT_BN_RELU_OP + +#pragma once + +#include +#include +#include "framework/operator.h" +#include "framework/program/program-optimize/fusion_op_register.h" +#include "operators/kernel/dequant_bn_relu_kernel.h" +#include "operators/op_param.h" + +namespace paddle_mobile { +namespace operators { + +class FusionDequantBNReluMatcher : public framework::FusionOpMatcher { + public: + FusionDequantBNReluMatcher() { + node_ = framework::Node(G_OP_TYPE_DEQUANTIZE); + node_ > std::make_shared(G_OP_TYPE_BATCHNORM) > + std::make_shared(G_OP_TYPE_RELU); + } + + void FolderNodes( + framework::Node *node, + std::vector> *removed_nodes) { + node->Folder(node_.Depth(), Type(), + {{G_OP_TYPE_BATCHNORM, + {{"Scale", "BNScale"}, + {"Mean", "BNMean"}, + {"Bias", "BNBias"}, + {"Variance", "BNVariance"}}}}, + removed_nodes); + } + + std::string Type() { return G_OP_TYPE_FUSION_DEQUANT_BN_RELU; } +}; + +template +class FusionDequantBNReluOp + : public framework::OperatorWithKernel< + DeviceType, FusionDequantBNReluParam, + operators::FusionDequantBNReluKernel> { + public: + FusionDequantBNReluOp(const std::string &type, const VariableNameMap &inputs, + const VariableNameMap &outputs, + const framework::AttributeMap &attrs, + std::shared_ptr scope) + : framework::OperatorWithKernel< + DeviceType, FusionDequantBNReluParam, + operators::FusionDequantBNReluKernel>( + type, inputs, outputs, attrs, scope) {} + // inference output shape + void InferShape() const override; +}; + +} // namespace operators +} // namespace paddle_mobile + +#endif diff --git a/src/operators/kernel/arm/conv_kernel.cpp b/src/operators/kernel/arm/conv_kernel.cpp index 840be6c67d2e350c914a7d8aa8e9a32acdd00fb1..6b9bd5c970590d2405c5f58a5f7016be5949a511 100644 --- a/src/operators/kernel/arm/conv_kernel.cpp +++ b/src/operators/kernel/arm/conv_kernel.cpp @@ -22,33 +22,35 @@ namespace operators { template <> bool ConvKernel::Init(ConvParam *param) { + bool conv3x3 = param->Filter()->dims()[2] == param->Filter()->dims()[3] && + param->Filter()->dims()[2] == 3; + bool depth3x3 = conv3x3 && param->Groups() == param->Input()->dims()[1] && + param->Input()->dims()[1] == param->Output()->dims()[1]; if (param->Filter()->type() == typeid(int8_t)) { - if (param->Groups() == param->Input()->dims()[1] && - param->Input()->dims()[1] == param->Output()->dims()[1] && - param->Filter()->dims()[2] == param->Filter()->dims()[3] && - param->Filter()->dims()[2] == 3 && param->Strides()[0] < 3 && + if (depth3x3 && param->Strides()[0] < 3 && param->Strides()[0] == param->Strides()[1]) { param->ExecMode() = ConvParam::EXEC_DEPTHWISE3x3_INT8; } else { param->ExecMode() = ConvParam::EXEC_GEMM_INT8; } } else { - if (param->Groups() == param->Input()->dims()[1] && - param->Input()->dims()[1] == param->Output()->dims()[1] && - param->Filter()->dims()[2] == param->Filter()->dims()[3] && - param->Filter()->dims()[2] == 3 && param->Strides()[0] == 1) { + if (depth3x3 && param->Strides()[0] == param->Strides()[1] && + param->Strides()[0] == 1 && param->Paddings()[0] == 1 && + param->Paddings()[0] == param->Paddings()[1]) { param->ExecMode() = ConvParam::EXEC_DEPTHWISE3x3S1P1_FLOAT; - } else if (param->Groups() == param->Input()->dims()[1] && - param->Input()->dims()[1] == param->Output()->dims()[1] && - param->Filter()->dims()[2] == param->Filter()->dims()[3] && - param->Filter()->dims()[2] == 3) { - param->ExecMode() = ConvParam::EXEC_DEPTHWISE3x3_FLOAT; + } else if (depth3x3 && param->Strides()[0] == param->Strides()[1] && + param->Strides()[0] == 2 && 
param->Paddings()[0] == 0 && + param->Paddings()[0] == param->Paddings()[1]) { + param->ExecMode() = ConvParam::EXEC_DEPTHWISE3x3S2P0_FLOAT; + } else if (depth3x3 && param->Strides()[0] == param->Strides()[1] && + param->Strides()[0] == 2 && param->Paddings()[0] == 1 && + param->Paddings()[0] == param->Paddings()[1]) { + param->ExecMode() = ConvParam::EXEC_DEPTHWISE3x3S2P1_FLOAT; #ifndef __aarch64__ - } else if (param->Filter()->dims()[2] == param->Filter()->dims()[3] && - param->Strides()[0] == param->Strides()[1] && + } else if (conv3x3 && param->Strides()[0] == param->Strides()[1] && param->Dilations()[0] == param->Dilations()[1] && - param->Filter()->dims()[2] == 3 && param->Strides()[0] == 1 && - param->Dilations()[0] == 1 && param->Output()->dims()[1] >= 16 && + param->Strides()[0] == 1 && param->Dilations()[0] == 1 && + param->Output()->dims()[1] >= 16 && param->Input()->dims()[1] >= 16 && param->Input()->dims()[2] <= 140 /* refered from ncnn */) { param->ExecMode() = ConvParam::EXEC_WINOGRAD3X3_FLOAT; @@ -78,9 +80,13 @@ void ConvKernel::Compute(const ConvParam ¶m) { math::DepthwiseConv3x3s1p1(param.Input(), param.Filter(), param.Output(), nullptr, false); break; - case ConvParam::EXEC_DEPTHWISE3x3_FLOAT: - math::DepthwiseConv3x3(param.Input(), param.Strides(), param.Paddings(), - param.Filter(), nullptr, param.Output(), false); + case ConvParam::EXEC_DEPTHWISE3x3S2P1_FLOAT: + math::DepthwiseConv3x3s2p1v2(param.Input(), param.Filter(), + param.Output(), nullptr, false); + break; + case ConvParam::EXEC_DEPTHWISE3x3S2P0_FLOAT: + math::DepthwiseConv3x3s2p0(param.Input(), param.Filter(), param.Output(), + nullptr, false); break; case ConvParam::EXEC_WINOGRAD3X3_FLOAT: WinogradConv3x3<8, 3>(param); diff --git a/src/operators/kernel/arm/dequant_add_bn_relu_kernel.cpp b/src/operators/kernel/arm/dequant_add_bn_kernel.cpp similarity index 86% rename from src/operators/kernel/arm/dequant_add_bn_relu_kernel.cpp rename to src/operators/kernel/arm/dequant_add_bn_kernel.cpp index bfe1935c216f94d660997b1bfa42f18e63295992..65fb0190f76a34a584d065bd43841567e9658bb8 100644 --- a/src/operators/kernel/arm/dequant_add_bn_relu_kernel.cpp +++ b/src/operators/kernel/arm/dequant_add_bn_kernel.cpp @@ -12,9 +12,9 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
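
The reworked ConvKernel::Init above replaces the dedicated depthwise kernel with a single dispatch over execution modes, keyed on filter shape, group count, stride and padding. A condensed sketch of that decision tree for the float path (enum and struct names are illustrative; the Winograd and GEMM fallbacks keep the extra thresholds shown in the diff):

#include <cstdint>

enum class ConvExec {  // illustrative mirror of ConvParam::ExecMode
  kDepthwise3x3S1P1, kDepthwise3x3S2P0, kDepthwise3x3S2P1,
  kWinograd3x3, kGemm
};

struct ConvShape {     // hypothetical flattened view of ConvParam
  int in_c, out_c, groups;
  int kh, kw;
  int stride_h, stride_w;
  int pad_h, pad_w;
};

ConvExec SelectFloatConvExec(const ConvShape &p) {
  const bool conv3x3 = (p.kh == p.kw) && (p.kh == 3);
  const bool depth3x3 = conv3x3 && p.groups == p.in_c && p.in_c == p.out_c;
  const bool same_stride = (p.stride_h == p.stride_w);
  const bool same_pad = (p.pad_h == p.pad_w);
  if (depth3x3 && same_stride && same_pad && p.stride_h == 1 && p.pad_h == 1)
    return ConvExec::kDepthwise3x3S1P1;
  if (depth3x3 && same_stride && same_pad && p.stride_h == 2 && p.pad_h == 0)
    return ConvExec::kDepthwise3x3S2P0;
  if (depth3x3 && same_stride && same_pad && p.stride_h == 2 && p.pad_h == 1)
    return ConvExec::kDepthwise3x3S2P1;
  // 3x3, stride 1, dilation 1, enough channels -> Winograd (see the exact
  // channel/size thresholds in the diff); everything else -> im2col + GEMM.
  return ConvExec::kGemm;
}
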
*/ -#ifdef FUSION_DEQUANT_ADD_BN_RELU_OP +#ifdef FUSION_DEQUANT_ADD_BN_OP -#include "operators/kernel/dequant_add_bn_relu_kernel.h" +#include "operators/kernel/dequant_add_bn_kernel.h" #include #if defined(__ARM_NEON__) || defined(__ARM_NEON) #include @@ -24,8 +24,8 @@ namespace paddle_mobile { namespace operators { template <> -bool FusionDequantAddBNReluKernel::Init( - FusionDequantAddBNReluParam *param) { +bool FusionDequantAddBNKernel::Init( + FusionDequantAddBNParam *param) { // elementwise add params const Tensor *bias = param->bias_; // batch norm params @@ -49,8 +49,8 @@ bool FusionDequantAddBNReluKernel::Init( } template <> -void FusionDequantAddBNReluKernel::Compute( - const FusionDequantAddBNReluParam ¶m) { +void FusionDequantAddBNKernel::Compute( + const FusionDequantAddBNParam ¶m) { const int32_t *input = param.input_->data(); const float *bn_scale = param.bn_scale_->data(); const float *bn_bias = param.bn_bias_->data(); @@ -78,7 +78,6 @@ void FusionDequantAddBNReluKernel::Compute( remain = spatial_size & 0xF; float32x4_t __scale = vdupq_n_f32(scale); float32x4_t __bias = vdupq_n_f32(bias); - float32x4_t __zero = vdupq_n_f32(0.f); for (int k = 0; k < loop; ++k, x += 16, y += 16) { int32x4_t r0 = vld1q_s32(x); @@ -93,10 +92,6 @@ void FusionDequantAddBNReluKernel::Compute( f1 = vmlaq_f32(__bias, __scale, f1); f2 = vmlaq_f32(__bias, __scale, f2); f3 = vmlaq_f32(__bias, __scale, f3); - f0 = vmaxq_f32(__zero, f0); - f1 = vmaxq_f32(__zero, f1); - f2 = vmaxq_f32(__zero, f2); - f3 = vmaxq_f32(__zero, f3); vst1q_f32(y, f0); vst1q_f32(y + 4, f1); vst1q_f32(y + 8, f2); @@ -104,7 +99,7 @@ void FusionDequantAddBNReluKernel::Compute( } #endif // __ARM_NEON__ for (int k = 0; k < remain; ++k) { - y[k] = std::max(scale * x[k] + bias, 0.f); + y[k] = scale * x[k] + bias; } } } @@ -113,4 +108,4 @@ void FusionDequantAddBNReluKernel::Compute( } // namespace operators } // namespace paddle_mobile -#endif // FUSION_DEQUANT_ADD_BN_RELU_OP +#endif // FUSION_DEQUANT_ADD_BN_OP diff --git a/src/operators/kernel/arm/dequant_bn_relu_kernel.cpp b/src/operators/kernel/arm/dequant_bn_relu_kernel.cpp new file mode 100644 index 0000000000000000000000000000000000000000..4d656712c193aa81a8be11c53856c868e2b82483 --- /dev/null +++ b/src/operators/kernel/arm/dequant_bn_relu_kernel.cpp @@ -0,0 +1,150 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
*/ + +#include "operators/kernel/dequant_bn_relu_kernel.h" +#include +#if defined(__ARM_NEON__) || defined(__ARM_NEON) +#include +#endif + +namespace paddle_mobile { +namespace operators { + +#if defined(FUSION_DEQUANT_BN_RELU_OP) || defined(FUSION_DEQUANT_ADD_BN_RELU_OP) +void DequantBNReluCompute(const FusionDequantBNParam *param) { + const int32_t *input = param->input_->data(); + const float *bn_scale = param->bn_scale_->data(); + const float *bn_bias = param->bn_bias_->data(); + // dequantize params + const float activation_scale = param->activation_scale_->data()[0]; + const float weight_scale = param->weight_scale_; + const float dequant_scale = activation_scale / weight_scale; + + float *output = param->output_->mutable_data(); + int batch_size = param->input_->dims()[0]; + int channels = param->input_->dims()[1]; + size_t spatial_size = param->input_->dims()[2] * param->input_->dims()[3]; + + #pragma omp parallel for collapse(2) + for (int batch = 0; batch < batch_size; ++batch) { + for (int c = 0; c < channels; ++c) { + float scale = bn_scale[c] * dequant_scale; + float bias = bn_bias[c]; + size_t offset = (batch * channels + c) * spatial_size; + const int32_t *x = input + offset; + float *y = output + offset; + size_t remain = spatial_size; +#if defined(__ARM_NEON__) || defined(__ARM_NEON) + int loop = spatial_size >> 4; + remain = spatial_size & 0xF; + float32x4_t __scale = vdupq_n_f32(scale); + float32x4_t __bias = vdupq_n_f32(bias); + float32x4_t __zero = vdupq_n_f32(0.f); + + for (int k = 0; k < loop; ++k, x += 16, y += 16) { + int32x4_t r0 = vld1q_s32(x); + int32x4_t r1 = vld1q_s32(x + 4); + int32x4_t r2 = vld1q_s32(x + 8); + int32x4_t r3 = vld1q_s32(x + 12); + float32x4_t f0 = vcvtq_f32_s32(r0); + float32x4_t f1 = vcvtq_f32_s32(r1); + float32x4_t f2 = vcvtq_f32_s32(r2); + float32x4_t f3 = vcvtq_f32_s32(r3); + f0 = vmlaq_f32(__bias, __scale, f0); + f1 = vmlaq_f32(__bias, __scale, f1); + f2 = vmlaq_f32(__bias, __scale, f2); + f3 = vmlaq_f32(__bias, __scale, f3); + f0 = vmaxq_f32(__zero, f0); + f1 = vmaxq_f32(__zero, f1); + f2 = vmaxq_f32(__zero, f2); + f3 = vmaxq_f32(__zero, f3); + vst1q_f32(y, f0); + vst1q_f32(y + 4, f1); + vst1q_f32(y + 8, f2); + vst1q_f32(y + 12, f3); + } +#endif // __ARM_NEON__ + for (int k = 0; k < remain; ++k) { + y[k] = std::max(scale * x[k] + bias, 0.f); + } + } + } +} +#endif + +#ifdef FUSION_DEQUANT_BN_RELU_OP +template <> +bool FusionDequantBNReluKernel::Init( + FusionDequantBNReluParam *param) { + // batch norm params + const Tensor *bn_mean = param->bn_mean_; + const Tensor *bn_variance = param->bn_variance_; + Tensor *bn_scale = param->bn_scale_; + Tensor *bn_bias = param->bn_bias_; + const float epsilon = param->epsilon_; + + const float *mean_ptr = bn_mean->data(); + const float *var_ptr = bn_variance->data(); + float *bn_scale_ptr = bn_scale->mutable_data(); + float *bn_bias_ptr = bn_bias->mutable_data(); + for (int c = 0; c < bn_scale->numel(); ++c) { + float inv_scale = bn_scale_ptr[c] / (std::sqrt(var_ptr[c] + epsilon)); + bn_scale_ptr[c] = inv_scale; + bn_bias_ptr[c] = bn_bias_ptr[c] - inv_scale * mean_ptr[c]; + } + return true; +} + +template <> +void FusionDequantBNReluKernel::Compute( + const FusionDequantBNReluParam ¶m) { + DequantBNReluCompute(¶m); +} +#endif // FUSION_DEQUANT_BN_RELU_OP + +#ifdef FUSION_DEQUANT_ADD_BN_RELU_OP +template <> +bool FusionDequantAddBNReluKernel::Init( + FusionDequantAddBNReluParam *param) { + // elementwise add params + const Tensor *bias = param->bias_; + // batch norm params + const Tensor *bn_mean = 
param->bn_mean_; + const Tensor *bn_variance = param->bn_variance_; + Tensor *bn_scale = param->bn_scale_; + Tensor *bn_bias = param->bn_bias_; + const float epsilon = param->epsilon_; + + const float *bias_ptr = bias->data(); + const float *mean_ptr = bn_mean->data(); + const float *var_ptr = bn_variance->data(); + float *bn_scale_ptr = bn_scale->mutable_data(); + float *bn_bias_ptr = bn_bias->mutable_data(); + for (int c = 0; c < bn_scale->numel(); ++c) { + float inv_scale = bn_scale_ptr[c] / (std::sqrt(var_ptr[c] + epsilon)); + bn_scale_ptr[c] = inv_scale; + bn_bias_ptr[c] = inv_scale * (bias_ptr[c] - mean_ptr[c]) + bn_bias_ptr[c]; + } + return true; +} + +template <> +void FusionDequantAddBNReluKernel::Compute( + const FusionDequantAddBNReluParam ¶m) { + DequantBNReluCompute(¶m); +} +#endif // FUSION_DEQUANT_ADD_BN_RELU_OP + +} // namespace operators +} // namespace paddle_mobile diff --git a/src/operators/kernel/central-arm-func/conv_add_arm_func.h b/src/operators/kernel/central-arm-func/conv_add_arm_func.h index 3b5924ecbf886159d129212cc36c8630cb8cce2f..988f0b0f03b84c25a2e17e9d14054f99dcce4916 100644 --- a/src/operators/kernel/central-arm-func/conv_add_arm_func.h +++ b/src/operators/kernel/central-arm-func/conv_add_arm_func.h @@ -132,10 +132,10 @@ void ConvAddCompute(const FusionConvAddParam ¶m) { // param.Output(), false); if (param.Paddings()[0] == 0) { math::DepthwiseConv3x3s2p0(param.Input(), param.Filter(), param.Output(), - *param.Bias(), true); + param.Bias(), true); } else { math::DepthwiseConv3x3s2p1v2(param.Input(), param.Filter(), - param.Output(), *param.Bias(), true); + param.Output(), param.Bias(), true); } } else { ConvAddBasic(param); diff --git a/src/operators/kernel/central-arm-func/conv_arm_func.h b/src/operators/kernel/central-arm-func/conv_arm_func.h index b01a654c713f2328d62714f23af68d606380d203..ce111ed78f7b81affffc646b49a00e6d15cbb697 100644 --- a/src/operators/kernel/central-arm-func/conv_arm_func.h +++ b/src/operators/kernel/central-arm-func/conv_arm_func.h @@ -107,9 +107,15 @@ inline void GemmConv(const ConvParam ¶m) { Tensor out_slice = out_batch.Slice(g * out_step, (g + 1) * out_step); Tensor filter_slice = filter.Slice(g * out_step, (g + 1) * out_step); - math::matmul(filter_slice, false, col_matrix, false, + if (param.Input()->type() == typeid(int8_t)) { + math::matmul_int8(filter_slice, false, col_matrix, false, static_cast(1), &out_slice, static_cast(0)); + } else { + math::matmul(filter_slice, false, col_matrix, false, + static_cast(1), &out_slice, + static_cast(0)); + } } } } diff --git a/src/operators/kernel/central-arm-func/depthwise_conv_arm_func.h b/src/operators/kernel/central-arm-func/depthwise_conv_arm_func.h deleted file mode 100644 index b48b03491bab9594f36cad0b21485ae72c8c3c31..0000000000000000000000000000000000000000 --- a/src/operators/kernel/central-arm-func/depthwise_conv_arm_func.h +++ /dev/null @@ -1,53 +0,0 @@ -/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
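
Both Init methods above precompute the batch-norm affine transform once per channel, so the per-element work in Compute reduces to y = scale * x + bias (with the dequantization factor activation_scale / weight_scale multiplied into scale at compute time). A compact sketch of that folding, covering both the plain and the elementwise-add variants; the helper name and signature are illustrative:

#include <cmath>
#include <cstddef>

// Fold gamma/beta/mean/var (+ optional elementwise-add bias) into a single
// per-channel (scale, bias) pair:
//   y = gamma * (x + add_bias - mean) / sqrt(var + eps) + beta
//     = inv_scale * x + (inv_scale * (add_bias - mean) + beta)
void FoldBatchNorm(std::size_t channels, float epsilon,
                   const float *mean, const float *var,
                   const float *add_bias,  // nullptr when there is no add
                   float *gamma_inout,     // in: gamma, out: folded scale
                   float *beta_inout) {    // in: beta,  out: folded bias
  for (std::size_t c = 0; c < channels; ++c) {
    const float inv_scale = gamma_inout[c] / std::sqrt(var[c] + epsilon);
    const float extra = add_bias ? add_bias[c] : 0.f;
    gamma_inout[c] = inv_scale;
    beta_inout[c] = inv_scale * (extra - mean[c]) + beta_inout[c];
  }
}
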
*/ - -#ifdef DEPTHWISECONV_OP - -#pragma once -#include -#include "operators/kernel/central-arm-func/conv_arm_func.h" -#include "operators/math/depthwise_conv3x3.h" -#include "operators/op_param.h" - -namespace paddle_mobile { -namespace operators { - -template -void DepthwiseConvCompute(const ConvParam ¶m) { - Tensor Bias; - Bias.mutable_data({param.Groups()}); - if (param.Groups() == param.Input()->dims()[1] && - param.Filter()->dims()[2] == param.Filter()->dims()[3] && - param.Filter()->dims()[2] == 3 && param.Strides()[0] == 1) { - math::DepthwiseConv3x3s1p1(param.Input(), param.Filter(), param.Output(), - &Bias, false); - } else if (param.Groups() == param.Input()->dims()[1] && - param.Input()->dims()[1] == param.Output()->dims()[1] && - param.Filter()->dims()[2] == param.Filter()->dims()[3] && - param.Filter()->dims()[2] == 3 && param.Strides()[0] == 2) { - // math::DepthwiseConv3x3(param.Input(), param.Strides(), - // param.Paddings(), - // param.Filter(), &Bias, param.Output(), false); - math::DepthwiseConv3x3s2p1v2(param.Input(), param.Filter(), param.Output(), - Bias, false); - - } else { - GemmConv(param); - } -} - -} // namespace operators -} // namespace paddle_mobile - -#endif diff --git a/src/operators/kernel/central-arm-func/mul_arm_func.h b/src/operators/kernel/central-arm-func/mul_arm_func.h index 07e634e3be9648520357871d91d6677aec6b5c0e..62e8ae03d9119cafc3c5716042569a90f077325c 100644 --- a/src/operators/kernel/central-arm-func/mul_arm_func.h +++ b/src/operators/kernel/central-arm-func/mul_arm_func.h @@ -73,8 +73,8 @@ void MulCompute(const MulParam ¶m) { } if (param.InputX()->type() == typeid(int8_t)) { out->mutable_data(); - math::matmul(x_matrix, false, y_matrix, false, - static_cast(1), out, static_cast(0)); + math::matmul_int8(x_matrix, false, y_matrix, false, static_cast(1), + out, static_cast(0)); } else { out->mutable_data(); diff --git a/src/operators/kernel/dequant_add_bn_relu_kernel.h b/src/operators/kernel/dequant_add_bn_kernel.h similarity index 75% rename from src/operators/kernel/dequant_add_bn_relu_kernel.h rename to src/operators/kernel/dequant_add_bn_kernel.h index 7138e5c415caca6766913f9959bd41def0943d34..2fcdad6903e378121c265080f68c35c451714e30 100644 --- a/src/operators/kernel/dequant_add_bn_relu_kernel.h +++ b/src/operators/kernel/dequant_add_bn_kernel.h @@ -14,7 +14,7 @@ limitations under the License. */ #pragma once -#ifdef FUSION_DEQUANT_ADD_BN_RELU_OP +#ifdef FUSION_DEQUANT_ADD_BN_OP #include "framework/operator.h" #include "operators/op_param.h" @@ -23,12 +23,12 @@ namespace paddle_mobile { namespace operators { template -class FusionDequantAddBNReluKernel +class FusionDequantAddBNKernel : public framework::OpKernelBase> { + FusionDequantAddBNParam> { public: - void Compute(const FusionDequantAddBNReluParam ¶m); - bool Init(FusionDequantAddBNReluParam *param); + void Compute(const FusionDequantAddBNParam ¶m); + bool Init(FusionDequantAddBNParam *param); }; } // namespace operators diff --git a/src/operators/kernel/dequant_bn_relu_kernel.h b/src/operators/kernel/dequant_bn_relu_kernel.h new file mode 100644 index 0000000000000000000000000000000000000000..edea449dd68db474b14b02304bbdf63768e1bfb0 --- /dev/null +++ b/src/operators/kernel/dequant_bn_relu_kernel.h @@ -0,0 +1,46 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +#include "framework/operator.h" +#include "operators/op_param.h" + +namespace paddle_mobile { +namespace operators { + +#ifdef FUSION_DEQUANT_BN_RELU_OP +template +class FusionDequantBNReluKernel + : public framework::OpKernelBase> { + public: + void Compute(const FusionDequantBNReluParam ¶m); + bool Init(FusionDequantBNReluParam *param); +}; +#endif + +#ifdef FUSION_DEQUANT_ADD_BN_RELU_OP +template +class FusionDequantAddBNReluKernel + : public framework::OpKernelBase> { + public: + void Compute(const FusionDequantAddBNReluParam ¶m); + bool Init(FusionDequantAddBNReluParam *param); +}; +#endif + +} // namespace operators +} // namespace paddle_mobile diff --git a/src/operators/math/depthwise_conv3x3.cpp b/src/operators/math/depthwise_conv3x3.cpp index 39b9b8d3f1c5c2bf09a3db5de5216dd1a08b491a..e74659ab4f0cd86c5a6f742a8313bbfb06dc51d3 100644 --- a/src/operators/math/depthwise_conv3x3.cpp +++ b/src/operators/math/depthwise_conv3x3.cpp @@ -1272,13 +1272,13 @@ void DepthwiseConvAddBNRelu3x3s2p1(const framework::Tensor *input, void DepthwiseConv3x3s2p1v2(const framework::Tensor *input, const framework::Tensor *filter, - framework::Tensor *output, framework::Tensor bias, + framework::Tensor *output, framework::Tensor *bias, bool if_bias) { #if __ARM_NEON const float *input_data = input->data(); const float *filter_data = filter->data(); float *output_data = output->data(); - const float *bias_data = bias.data(); + const float *bias_data = bias->data(); const int in_h = static_cast(input->dims()[2]); const int in_w = static_cast(input->dims()[3]); @@ -1905,7 +1905,7 @@ void DepthwiseConvAddBNRelu3x3s2p1v2(const framework::Tensor *input, void DepthwiseConv3x3s2p0(const framework::Tensor *input, const framework::Tensor *filter, - framework::Tensor *output, framework::Tensor bias, + framework::Tensor *output, framework::Tensor *bias, bool if_bias) { #if __ARM_NEON @@ -1925,7 +1925,7 @@ void DepthwiseConv3x3s2p0(const framework::Tensor *input, for (int c = 0; c < input_channel; c++) { const float *filter_data = filter->data() + c * 9; const float *input_data = input->data() + c * inhxw; - const float *bias_data = bias.data() + c; + const float *bias_data = bias->data() + c; float *output_data = output->data() + c * outhxw; float w00 = filter_data[0]; float w01 = filter_data[1]; diff --git a/src/operators/math/depthwise_conv3x3.h b/src/operators/math/depthwise_conv3x3.h index 72cadaf21553a428e1479d5548d2aa5f4fcdf90c..34e68e42664a65f9203a30562c2780210c05a42e 100644 --- a/src/operators/math/depthwise_conv3x3.h +++ b/src/operators/math/depthwise_conv3x3.h @@ -50,7 +50,7 @@ void DepthwiseConvAddBNRelu3x3s2p1(const framework::Tensor *input, void DepthwiseConv3x3s2p1v2(const framework::Tensor *input, const framework::Tensor *filter, - framework::Tensor *output, framework::Tensor bias, + framework::Tensor *output, framework::Tensor *bias, bool if_bias); void DepthwiseConvAddBNRelu3x3s2p1v2(const framework::Tensor *input, @@ -62,7 +62,7 @@ void DepthwiseConvAddBNRelu3x3s2p1v2(const framework::Tensor *input, void DepthwiseConv3x3s2p0(const framework::Tensor *input, const framework::Tensor 
*filter, - framework::Tensor *output, framework::Tensor bias, + framework::Tensor *output, framework::Tensor *bias, bool if_bias); // TODO(hjchen2) need to be implemented diff --git a/src/operators/math/gemm.h b/src/operators/math/gemm.h index 8498992fcecbcb2c9a773fba874e108c013a04fc..e409fe07dc55bcf68748f0f25b3b63480d25cd56 100644 --- a/src/operators/math/gemm.h +++ b/src/operators/math/gemm.h @@ -23,10 +23,12 @@ limitations under the License. */ #if __aarch64__ #define MR_INT8 4 +#define NR_INT8 2 #define MR 6 #define NR 16 #else #define MR_INT8 4 +#define NR_INT8 2 #define MR 6 #define NR 8 #endif @@ -193,52 +195,58 @@ void PackMatrixB(int k, int n, int n_tail, const float *B, int ldb, // 8 bits int small block inner product void AddDot4x8(int32_t k, const int8_t *a, const int8_t *b, int32_t *c, int32_t ldc); + void AddDot4x2(int32_t k, const int8_t *a, const int8_t *b, int32_t *c, + int32_t ldc); void AddDot6x8(int32_t k, const int8_t *a, const int8_t *b, int32_t *c, int32_t ldc); // 8 bits int inner product - void InnerKernelWithBias(int32_t mc, int32_t nc, int8_t alpha, - const int8_t *a, const int8_t *b, int8_t beta, - int32_t *c, int32_t *C, int32_t ldc, bool relu, - int8_t *bias); + void InnerKernel(int32_t mc, int32_t nc, float alpha, const int8_t *a, + const int8_t *b, float beta, int32_t *c, int32_t *C, + int32_t ldc, bool relu); + void InnerKernelWithBias(int32_t mc, int32_t nc, float alpha, const int8_t *a, + const int8_t *b, float beta, int32_t *c, int8_t *C, + int32_t ldc, bool relu, int32_t *bias); // 8 bits int pack function void PackMatrixA_4r(int32_t m, int32_t k, int32_t m_tail, const int8_t *A, int32_t lda, int8_t *buffer); + void PackMatrixA_4r_16(int32_t m, int32_t k, int32_t m_tail, const int8_t *A, + int32_t lda, int8_t *buffer); void PackMatrixA_6r(int32_t m, int32_t k, int32_t m_tail, const int8_t *A, int32_t lda, int8_t *buffer); + void PackMatrixB_2c_16(int32_t k, int32_t n, int32_t n_tail, const int8_t *B, + int32_t ldb, int8_t *buffer); void PackMatrixB_8c(int32_t k, int32_t n, int32_t n_tail, const int8_t *B, int32_t ldb, int8_t *buffer); void PackMatrixA_omp_4r(int32_t m, int32_t k, int32_t m_tail, const int8_t *A, int32_t lda, int8_t *buffer); void PackMatrixB_omp_8c(int32_t k, int32_t n, int32_t n_tail, const int8_t *B, int32_t ldb, int8_t *buffer); + void PackMatrixA_omp_4r_16(int32_t m, int32_t k, int32_t m_tail, + const int8_t *A, int32_t lda, int8_t *buffer); + void PackMatrixB_omp_2c_16(int32_t k, int32_t n, int32_t n_tail, + const int8_t *B, int32_t ldb, int8_t *buffer); // 8 bits int matrix product - void Sgemm(int32_t m, int32_t n, int32_t k, int8_t alpha, const int8_t *A, - int32_t lda, const int8_t *B, int32_t ldb, int8_t beta, int32_t *C, - int32_t ldc, bool relu, int8_t *bias); - void Sgemm_omp(int32_t m, int32_t n, int32_t k, int8_t alpha, const int8_t *A, - int32_t lda, const int8_t *B, int32_t ldb, int8_t beta, - int32_t *C, int32_t ldc, bool relu, int8_t *bias); + void Sgemm(int32_t m, int32_t n, int32_t k, float alpha, const int8_t *A, + int32_t lda, const int8_t *B, int32_t ldb, float beta, int32_t *C, + int32_t ldc, bool relu, int32_t *bias); + void Sgemm(int32_t m, int32_t n, int32_t k, float alpha, const int8_t *A, + int32_t lda, const int8_t *B, int32_t ldb, float beta, int8_t *C, + int32_t ldc, bool relu, int32_t *bias); + void Sgemm_omp(int32_t m, int32_t n, int32_t k, float alpha, const int8_t *A, + int32_t lda, const int8_t *B, int32_t ldb, float beta, + int32_t *C, int32_t ldc, bool relu, int32_t *bias); // 8 bits int write 
back - // C = alpha * A * B + beta * C - void WriteWithAlphaBeta(int32_t mc, int32_t nc, int32_t *c, int32_t *C, - int32_t ldc); // C = A * B void WriteBasic(int32_t mc, int32_t nc, int32_t *c, int32_t *C, int32_t ldc); - // C = A * B + C - void WriteWithAdd(int32_t mc, int32_t nc, int32_t *c, int32_t *C, - int32_t ldc); - // C = A * B + bias - void WriteWithAddV1(int32_t mc, int32_t nc, int32_t *c, int32_t *C, - int32_t ldc, int8_t *bias); - // C = A * B + C, relu(C) - void WriteWithAddRelu(int32_t mc, int32_t nc, int32_t *c, int32_t *C, - int32_t ldc); - // C = A * B + bias, relu(C) - void WriteWithAddReluV1(int32_t mc, int32_t nc, int32_t *c, int32_t *C, - int32_t ldc, int8_t *bias); + // C = A * B + bias, scale * relu(C) + void WriteWithAddReluScale(int32_t mc, int32_t nc, int32_t *c, int8_t *C, + int32_t ldc, int32_t *bias, float scale); + // C = A * B + bias, scale * C + void WriteWithAddScale(int32_t mc, int32_t nc, int32_t *c, int8_t *C, + int32_t ldc, int32_t *bias, float scale); private: int MC = 0; @@ -254,7 +262,7 @@ void PackMatrixB(int k, int n, int n_tail, const float *B, int ldb, // 8 bits int int8_t *packedA_int8; int8_t *packedB_int8; - int32_t *packedC_int8; + int32_t *packedC_int32; int8_t *zero_int8; }; diff --git a/src/operators/math/gemm_int8.cpp b/src/operators/math/gemm_int8.cpp index b16db7fe6acf0c3c7fb2902c9fb3f6e3dc81a65f..555672720f2be51631ea10808ce6891b08df0721 100644 --- a/src/operators/math/gemm_int8.cpp +++ b/src/operators/math/gemm_int8.cpp @@ -18,6 +18,8 @@ limitations under the License. */ #include "operators/math/gemm.h" #if __ARM_NEON #include +#include + #endif #ifdef _OPENMP #include @@ -62,7 +64,7 @@ void Gemm::AddDot4x8(int32_t k, const int8_t *a, const int8_t *b, int32_t *c, "pld [%[b_ptr], #128] \n\t" "vld1.s8 {d0-d3}, [%[a_ptr]]! \n\t" // load A 8 cols "vld1.s8 {d8-d11}, [%[b_ptr]]! 
\n\t" // load B first 4 rows - "vmovl.s8 q2, d0 \n\t" // process B first 4 + "vmovl.s8 q2, d0 \n\t" // process B first // rows "vmovl.s8 q3, d8 \n\t" "vmlal.s16 q8, d6, d4[0]\n\t" @@ -241,6 +243,132 @@ void Gemm::AddDot4x8(int32_t k, const int8_t *a, const int8_t *b, int32_t *c, #endif // __ARM_NEON } +void Gemm::AddDot4x2(int32_t k, const int8_t *a, const int8_t *b, int32_t *c, + int32_t ldc) { +#if __ARM_NEON +#if __aarch64__ +// TODO(wzzju) +#else +#define PADDLE_LABEL_LOOP "1" +#define PADDLE_LABEL_AFTER_LOOP "2" + asm volatile( + "lsl %[ldc], %[ldc], #2 \n\t" // sizeof(int32) == 4 + "vldr d0, [%[b], #0] \n\t" + "vmov.s32 q8, #0 \n\t" + "vldr d4, [%[a], #0] \n\t" + "vmov.s32 q9, q8 \n\t" + "vldr d2, [%[b], #16] \n\t" + "vmov.s32 q10, q8 \n\t" + "vldr d6, [%[a], #16] \n\t" + "vmov.s32 q11, q8 \n\t" + "vldr d1, [%[b], #8]\n\t" + "vmov.s32 q12, q8 \n\t" + "vldr d5, [%[a], #8]\n" + "vmov.s32 q13, q8 \n\t" + "vldr d3, [%[b], #24]\n\t" + "vmov.s32 q14, q8 \n\t" + "vldr d7, [%[a], #24]\n" + "vmov.s32 q15, q8 \n\t" + + PADDLE_LABEL_LOOP + ": \n\t" + "vmull.s8 q4, d0, d4 \n\t" // first half + "add %[b], %[b], #32 \n\t" + "vmull.s8 q5, d2, d4 \n\t" + "vldr d4, [%[a], #32] \n\t" + "vmull.s8 q6, d0, d6 \n\t" + "vmull.s8 q7, d2, d6 \n\t" + "vldr d6, [%[a], #48] \n\t" + + "vmlal.s8 q4, d1, d5 \n\t" // second half + "vmlal.s8 q5, d3, d5 \n\t" + "vldr d5, [%[a], #40] \n\t" + "vmlal.s8 q6, d1, d7 \n\t" + "vmlal.s8 q7, d3, d7 \n\t" + "vldr d7, [%[a], #56] \n\t" + + "vpadal.s16 q8, q4 \n\t" // pairwise-add + "add %[a], %[a], #64 \n\t" + "vpadal.s16 q9, q5 \n\t" + "subs %[k], %[k], #16 \n\t" + "vpadal.s16 q10, q6 \n\t" + "vpadal.s16 q11, q7 \n\t" + + "beq " PADDLE_LABEL_AFTER_LOOP + "f \n\t" + + "vmull.s8 q4, d0, d4 \n\t" // first half + "vmull.s8 q5, d2, d4 \n\t" + "vldr d4, [%[a], #0] \n\t" + "vmull.s8 q6, d0, d6 \n\t" + "vldr d0, [%[b], #0] \n\t" + "vmull.s8 q7, d2, d6 \n\t" + "vldr d2, [%[b], #16] \n\t" + + "vmlal.s8 q4, d1, d5 \n\t" // second half + "vldr d6, [%[a], #16] \n\t" + "vmlal.s8 q5, d3, d5 \n\t" + "vldr d5, [%[a], #8] \n\t" + "vmlal.s8 q6, d1, d7 \n\t" + "vldr d1, [%[b], #8] \n\t" + "vmlal.s8 q7, d3, d7 \n\t" + "vldr d3, [%[b], #24] \n\t" + + "vpadal.s16 q12, q4 \n\t" // pairwise-add + "vldr d7, [%[a], #24] \n\t" + "vpadal.s16 q13, q5 \n\t" + "vpadal.s16 q14, q6 \n\t" + "vpadal.s16 q15, q7 \n\t" + + "b " PADDLE_LABEL_LOOP "b \n\t" + + PADDLE_LABEL_AFTER_LOOP + ": \n\t" + "vmull.s8 q4, d0, d4 \n\t" // first half + "vmull.s8 q5, d2, d4 \n\t" + "vmull.s8 q6, d0, d6 \n\t" + "vmull.s8 q7, d2, d6 \n\t" + + "vmlal.s8 q4, d1, d5 \n\t" // second half + "vmlal.s8 q5, d3, d5 \n\t" + "vmlal.s8 q6, d1, d7 \n\t" + "vmlal.s8 q7, d3, d7 \n\t" + + "vpadal.s16 q12, q4 \n\t" // pairwise-add + "vpadal.s16 q13, q5 \n\t" + "vpadal.s16 q14, q6 \n\t" + "vpadal.s16 q15, q7 \n\t" + + "vpadd.s32 d0, d16, d17 \n\t" // reduce to int32 + "vpadd.s32 d1, d18, d19 \n\t" + "vpadd.s32 d2, d20, d21 \n\t" + "vpadd.s32 d3, d22, d23 \n\t" + "vpadd.s32 d4, d24, d25 \n\t" + "vpadd.s32 d5, d26, d27 \n\t" + "vpadd.s32 d6, d28, d29 \n\t" + "vpadd.s32 d7, d30, d31 \n\t" + + "vpadd.s32 d8, d0, d1 \n\t" // reduce to int32 again + "vpadd.s32 d9, d2, d3 \n\t" + "vpadd.s32 d10, d4, d5 \n\t" + "vpadd.s32 d11, d6, d7 \n\t" + + "vst1.32 {d8}, [%[c]], %[ldc] \n\t" + "vst1.32 {d9}, [%[c]], %[ldc] \n\t" + "vst1.32 {d10}, [%[c]], %[ldc] \n\t" + "vst1.32 {d11}, [%[c]] \n\t" + + : [k] "+r"(k), [a] "+r"(a), [b] "+r"(b), [c] "+r"(c) + : [ldc] "r"(ldc) + : "cc", "memory", "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7", "q8", + "q9", "q10", "q11", 
"q12", "q13", "q14", "q15"); +#undef PADDLE_LABEL_AFTER_LOOP +#undef PADDLE_LABEL_LOOP + +#endif // __aarch64__ +#endif // __ARM_NEON +} + // 8 bits int small block inner product void Gemm::AddDot6x8(int32_t k, const int8_t *a, const int8_t *b, int32_t *c, int32_t ldc) { @@ -539,51 +667,213 @@ void Gemm::AddDot6x8(int32_t k, const int8_t *a, const int8_t *b, int32_t *c, } // 8 bits int inner product -void Gemm::InnerKernelWithBias(int32_t mc, int32_t nc, int8_t alpha, - const int8_t *a, const int8_t *b, int8_t beta, - int32_t *c, int32_t *C, int32_t ldc, bool relu, - int8_t *bias) { +void Gemm::InnerKernel(int32_t mc, int32_t nc, float alpha, const int8_t *a, + const int8_t *b, float beta, int32_t *c, int32_t *C, + int32_t ldc, bool relu) { #pragma omp parallel for - for (int32_t j = 0; j < nc; j += NR) { + for (int32_t j = 0; j < nc; j += NR_INT8) { for (int32_t i = 0; i < mc; i += MR_INT8) { #if __aarch64__ // TODO(wzzju) #else // AddDot6x8(KC, a + i * KC, b + j * KC, c + i * NC + j, NC); - AddDot4x8(KC, a + i * KC, b + j * KC, c + i * NC + j, NC); + // AddDot4x8(KC, a + i * KC, b + j * KC, c + i * NC + j, NC); + AddDot4x2(KC, a + i * KC, b + j * KC, c + i * NC + j, NC); #endif // __aarch64__ } } - if (alpha != 1) { - WriteWithAlphaBeta(mc, nc, c, C, ldc); - return; - } - if (beta == 0) { + if (!relu) { WriteBasic(mc, nc, c, C, ldc); return; } - if (beta == 1 && !relu) { - if (bias == nullptr) { - WriteWithAdd(mc, nc, c, C, ldc); - } else { - WriteWithAddV1(mc, nc, c, C, ldc, bias); +} + +void Gemm::InnerKernelWithBias(int32_t mc, int32_t nc, float alpha, + const int8_t *a, const int8_t *b, float beta, + int32_t *c, int8_t *C, int32_t ldc, bool relu, + int32_t *bias) { +#pragma omp parallel for + for (int32_t j = 0; j < nc; j += NR_INT8) { + for (int32_t i = 0; i < mc; i += MR_INT8) { +#if __aarch64__ + // TODO(wzzju) +#else + // AddDot6x8(KC, a + i * KC, b + j * KC, c + i * NC + j, NC); + // AddDot4x8(KC, a + i * KC, b + j * KC, c + i * NC + j, NC); + AddDot4x2(KC, a + i * KC, b + j * KC, c + i * NC + j, NC); +#endif // __aarch64__ } + } + if (relu) { + WriteWithAddReluScale(mc, nc, c, C, ldc, bias, alpha); return; + } else { + WriteWithAddScale(mc, nc, c, C, ldc, bias, alpha); } - if (beta == 1 && relu) { - if (bias == nullptr) { - WriteWithAddRelu(mc, nc, c, C, ldc); - } else { - WriteWithAddReluV1(mc, nc, c, C, ldc, bias); +} + +// 8 bits int PackMatrixA_4r +void Gemm::PackMatrixA_4r_16(int32_t m, int32_t k, int32_t m_tail, + const int8_t *A, int32_t lda, int8_t *buffer) { + const int32_t i_length = m - m_tail; + const int32_t k_count = k >> 4; + const int32_t k_tail = k & 15; + + for (int32_t i = 0; i < i_length; i += 4) { + const int8_t *a0 = A + i * lda; + const int8_t *a1 = A + (i + 1) * lda; + const int8_t *a2 = A + (i + 2) * lda; + const int8_t *a3 = A + (i + 3) * lda; + int8_t *local_buffer = buffer + i * KC; + for (int32_t j = 0; j < k_count; ++j) { +#if __ARM_NEON +#if __aarch64__ + // TODO(wzzju) +#else + asm volatile( + "vld1.s8 {d0, d1}, [%[a0]]! \n\t" + "vld1.s8 {d2, d3}, [%[a1]]! \n\t" + "vld1.s8 {d4, d5}, [%[a2]]! \n\t" + "vld1.s8 {d6, d7}, [%[a3]]! \n\t" + "vst1.s8 {d0, d1}, [%[local_buffer]]! \n\t" + "vst1.s8 {d2, d3}, [%[local_buffer]]! \n\t" + "vst1.s8 {d4, d5}, [%[local_buffer]]! \n\t" + "vst1.s8 {d6, d7}, [%[local_buffer]]! 
\n\t" + : [local_buffer] "+r"(local_buffer), [a0] "+r"(a0), [a1] "+r"(a1), + [a2] "+r"(a2), [a3] "+r"(a3) + : + : "memory", "q0", "q1", "q2", "q3"); +#endif // __aarch64__ +#else + for (int32_t l = 0; l < 16; ++l) { + *local_buffer++ = *a0++; + } + for (int32_t l = 0; l < 16; ++l) { + *local_buffer++ = *a1++; + } + for (int32_t l = 0; l < 16; ++l) { + *local_buffer++ = *a2++; + } + for (int32_t l = 0; l < 16; ++l) { + *local_buffer++ = *a3++; + } +#endif // __ARM_NEON + } + if (k_tail != 0) { + for (int32_t j = k_count << 4; j < k; ++j) { + *local_buffer++ = *a0++; + } + for (int32_t j = k; j < KC; ++j) { + *local_buffer++ = 0; + } + + for (int32_t j = k_count << 4; j < k; ++j) { + *local_buffer++ = *a1++; + } + for (int32_t j = k; j < KC; ++j) { + *local_buffer++ = 0; + } + for (int32_t j = k_count << 4; j < k; ++j) { + *local_buffer++ = *a2++; + } + for (int32_t j = k; j < KC; ++j) { + *local_buffer++ = 0; + } + for (int32_t j = k_count << 4; j < k; ++j) { + *local_buffer++ = *a3++; + } + for (int32_t j = k; j < KC; ++j) { + *local_buffer++ = 0; + } + } + } + + if (m_tail != 0) { + const int8_t *a0 = &A(i_length, 0); + const int8_t *a1 = a0 + lda; + const int8_t *a2 = a0 + 2 * lda; + const int8_t *a3 = a0 + 3 * lda; + int8_t *local_buffer = buffer + i_length * KC; + switch (m_tail) { + case 1: + a1 = zero_int8; + case 2: + a2 = zero_int8; + case 3: + a3 = zero_int8; + break; + default: + break; + } + for (int32_t j = 0; j < k_count; ++j) { +#if __ARM_NEON +#if __aarch64__ + // TODO(wzzju) +#else + asm volatile( + "vld1.s8 {d0, d1}, [%[a0]]! \n\t" + "vld1.s8 {d2, d3}, [%[a1]]! \n\t" + "vld1.s8 {d4, d5}, [%[a2]]! \n\t" + "vld1.s8 {d6, d7}, [%[a3]]! \n\t" + "vst1.s8 {d0, d1}, [%[local_buffer]]! \n\t" + "vst1.s8 {d2, d3}, [%[local_buffer]]! \n\t" + "vst1.s8 {d4, d5}, [%[local_buffer]]! \n\t" + "vst1.s8 {d6, d7}, [%[local_buffer]]! 
\n\t" + : [local_buffer] "+r"(local_buffer), [a0] "+r"(a0), [a1] "+r"(a1), + [a2] "+r"(a2), [a3] "+r"(a3) + : + : "memory", "q0", "q1", "q2", "q3"); +#endif // __aarch64__ +#else + for (int32_t l = 0; l < 16; ++l) { + *local_buffer++ = *a0++; + } + for (int32_t l = 0; l < 16; ++l) { + *local_buffer++ = *a1++; + } + for (int32_t l = 0; l < 16; ++l) { + *local_buffer++ = *a2++; + } + for (int32_t l = 0; l < 16; ++l) { + *local_buffer++ = *a3++; + } +#endif // __ARM_NEON + } + if (k_tail != 0) { + for (int32_t j = k_count << 4; j < k; ++j) { + *local_buffer++ = *a0++; + } + for (int32_t j = k; j < KC; ++j) { + *local_buffer++ = 0; + } + + for (int32_t j = k_count << 4; j < k; ++j) { + *local_buffer++ = *a1++; + } + for (int32_t j = k; j < KC; ++j) { + *local_buffer++ = 0; + } + for (int32_t j = k_count << 4; j < k; ++j) { + *local_buffer++ = *a2++; + } + for (int32_t j = k; j < KC; ++j) { + *local_buffer++ = 0; + } + for (int32_t j = k_count << 4; j < k; ++j) { + *local_buffer++ = *a3++; + } + for (int32_t j = k; j < KC; ++j) { + *local_buffer++ = 0; + } } - return; } } + // 8 bits int PackMatrixA_4r void Gemm::PackMatrixA_4r(int32_t m, int32_t k, int32_t m_tail, const int8_t *A, int32_t lda, int8_t *buffer) { const int8_t *a0, *a1, *a2, *a3; - for (int32_t i = 0; i < m - m_tail; i += MR_INT8) { + for (int32_t i = 0; i < m - m_tail; i += 4) { a0 = A + i * lda; a1 = A + (i + 1) * lda; a2 = A + (i + 2) * lda; @@ -625,7 +915,7 @@ void Gemm::PackMatrixA_4r(int32_t m, int32_t k, int32_t m_tail, const int8_t *A, void Gemm::PackMatrixA_6r(int32_t m, int32_t k, int32_t m_tail, const int8_t *A, int32_t lda, int8_t *buffer) { const int32_t i_length = m - m_tail; - for (int32_t i = 0; i < i_length; i += MR_INT8) { + for (int32_t i = 0; i < i_length; i += 6) { const int8_t *a0 = A + i * lda; const int8_t *a1 = A + (i + 1) * lda; const int8_t *a2 = A + (i + 2) * lda; @@ -676,11 +966,79 @@ void Gemm::PackMatrixA_6r(int32_t m, int32_t k, int32_t m_tail, const int8_t *A, } } +// 8 bits int PackMatrixB +void Gemm::PackMatrixB_2c_16(int32_t k, int32_t n, int32_t n_tail, + const int8_t *B, int32_t ldb, int8_t *buffer) { + const int32_t j_length = n - n_tail; + const int32_t k_count = k >> 4; + const int32_t k_tail = k & 15; + for (int32_t j = 0; j < j_length; j += 2) { + int8_t *local_buffer = buffer + j * KC; + for (int32_t i = 0; i < k_count; ++i) { + const int8_t *b0 = &B((i << 4), j); + const int8_t *b1 = &B((i << 4), j + 1); + for (int m = 0; m < 16; ++m) { + *local_buffer++ = *b0; + b0 += ldb; + } + for (int m = 0; m < 16; ++m) { + *local_buffer++ = *b1; + b1 += ldb; + } + } + if (k_tail != 0) { + const int8_t *b0 = &B((k_count << 4), j); + const int8_t *b1 = &B((k_count << 4), j + 1); + for (int32_t j = k_count << 4; j < k; ++j) { + *local_buffer++ = *b0; + b0 += ldb; + } + for (int32_t j = k; j < KC; ++j) { + *local_buffer++ = 0; + } + + for (int32_t j = k_count << 4; j < k; ++j) { + *local_buffer++ = *b1; + b1 += ldb; + } + for (int32_t j = k; j < KC; ++j) { + *local_buffer++ = 0; + } + } + } + if (n_tail != 0) { + int8_t *local_buffer = buffer + j_length * KC; + for (int32_t i = 0; i < k_count; ++i) { + const int8_t *b0 = &B((i << 4), j_length); + for (int m = 0; m < 16; ++m) { + *local_buffer++ = *b0; + b0 += ldb; + } + for (int m = 0; m < 16; ++m) { + *local_buffer++ = 0; + } + } + if (k_tail != 0) { + const int8_t *b0 = &B((k_count << 4), j_length); + for (int32_t j = k_count << 4; j < k; ++j) { + *local_buffer++ = *b0; + b0 += ldb; + } + for (int32_t j = k; j < KC; ++j) { + *local_buffer++ = 0; + 
} + for (int32_t j = k_count << 4; j < KC; ++j) { + *local_buffer++ = 0; + } + } + } +} + // 8 bits int PackMatrixB void Gemm::PackMatrixB_8c(int32_t k, int32_t n, int32_t n_tail, const int8_t *B, int32_t ldb, int8_t *buffer) { const int32_t j_length = n - n_tail; - for (int32_t j = 0; j < j_length; j += NR) { + for (int32_t j = 0; j < j_length; j += 8) { int8_t *local_buffer = buffer + j * k; for (int32_t i = 0; i < k; ++i) { const int8_t *b0 = &B(i, j); @@ -715,7 +1073,7 @@ void Gemm::PackMatrixB_8c(int32_t k, int32_t n, int32_t n_tail, const int8_t *B, for (int32_t j = j_length; j < n; ++j) { *local_buffer++ = *b0++; } - for (int32_t j = n; j < j_length + NR; ++j) { + for (int32_t j = n; j < j_length + 8; ++j) { *local_buffer++ = 0; } } @@ -723,19 +1081,20 @@ void Gemm::PackMatrixB_8c(int32_t k, int32_t n, int32_t n_tail, const int8_t *B, } // 8 bits int matrix product (m*k x k*n) -void Gemm::Sgemm(int32_t m, int32_t n, int32_t k, int8_t alpha, const int8_t *A, - int32_t lda, const int8_t *B, int32_t ldb, int8_t beta, - int32_t *C, int32_t ldc, bool relu, int8_t *bias) { +void Gemm::Sgemm(int32_t m, int32_t n, int32_t k, float alpha, const int8_t *A, + int32_t lda, const int8_t *B, int32_t ldb, float beta, + int32_t *C, int32_t ldc, bool relu, int32_t *bias) { // L1 data cache is 32 kib (Per Contex-A57, Contex-A72, Contex-A73) // L2 cache is 0.5~4 Mib (Contex-A72 cluster) int32_t L1 = 32 * 1024; int32_t L2 = 512 * 1024; - KC = k; + const int32_t k_complete = (k + 15) - ((k + 15) & 15); + KC = k_complete; MC = L1 / (KC * sizeof(int8_t)); NC = L2 / (KC * sizeof(int8_t)); - // make sure MC is multiple of MR_INT8, and NC is multiple of NR + // make sure MC is multiple of MR_INT8, and NC is multiple of NR_INT8 if (MC == 0) { MC = MR_INT8; } else { @@ -745,52 +1104,106 @@ void Gemm::Sgemm(int32_t m, int32_t n, int32_t k, int8_t alpha, const int8_t *A, } // DLOG << "mblock_num = " << mblock_num << ", MC = " << MC << "\n"; if (NC == 0) { - NC = NR; + NC = NR_INT8; } else { int32_t nblock_num = (n + NC - 1) / NC; NC = (n + nblock_num - 1) / nblock_num; - NC = (NC + NR - 1) / NR * NR; + NC = (NC + NR_INT8 - 1) / NR_INT8 * NR_INT8; } // DLOG << "nblock_num = " << nblock_num << ", NC = " << NC << "\n"; packedA_int8 = static_cast( paddle_mobile::memory::Alloc(sizeof(int8_t) * MC * KC)); packedB_int8 = static_cast( paddle_mobile::memory::Alloc(sizeof(int8_t) * KC * NC)); - packedC_int8 = static_cast( + packedC_int32 = static_cast( paddle_mobile::memory::Alloc(sizeof(int32_t) * MC * NC)); zero_int8 = - static_cast(paddle_mobile::memory::Alloc(sizeof(int8_t) * KC)); + static_cast(paddle_mobile::memory::Alloc(sizeof(int8_t) * k)); - memset(static_cast(zero_int8), 0, sizeof(int8_t) * KC); + memset(static_cast(zero_int8), 0, sizeof(int8_t) * k); int32_t mc, nc; for (int32_t j = 0; j < n; j += NC) { nc = s_min(n - j, NC); - PackMatrixB_8c(KC, nc, nc % NR, &B(0, j), ldb, packedB_int8); + PackMatrixB_2c_16(k, nc, nc % NR_INT8, &B(0, j), ldb, packedB_int8); for (int32_t i = 0; i < m; i += MC) { mc = s_min(m - i, MC); - // PackMatrixA_6r(mc, KC, mc % MR_INT8, &A(i, 0), lda, packedA_int8); - PackMatrixA_4r(mc, KC, mc % MR_INT8, &A(i, 0), lda, packedA_int8); + PackMatrixA_4r_16(mc, k, mc % MR_INT8, &A(i, 0), lda, packedA_int8); if (bias == nullptr) { + InnerKernel(mc, nc, alpha, packedA_int8, packedB_int8, beta, + packedC_int32, &C(i, j), ldc, relu); + } + } + } + + paddle_mobile::memory::Free(packedA_int8); + paddle_mobile::memory::Free(packedB_int8); + paddle_mobile::memory::Free(packedC_int32); + 
paddle_mobile::memory::Free(zero_int8); +} + +// 8 bits int matrix product (m*k x k*n) +void Gemm::Sgemm(int32_t m, int32_t n, int32_t k, float alpha, const int8_t *A, + int32_t lda, const int8_t *B, int32_t ldb, float beta, + int8_t *C, int32_t ldc, bool relu, int32_t *bias) { + // L1 data cache is 32 kib (Per Contex-A57, Contex-A72, Contex-A73) + // L2 cache is 0.5~4 Mib (Contex-A72 cluster) + int32_t L1 = 32 * 1024; + int32_t L2 = 512 * 1024; + + const int32_t k_complete = (k + 15) - ((k + 15) & 15); + KC = k_complete; + MC = L1 / (KC * sizeof(int8_t)); + NC = L2 / (KC * sizeof(int8_t)); + + // make sure MC is multiple of MR_INT8, and NC is multiple of NR_INT8 + if (MC == 0) { + MC = MR_INT8; + } else { + int32_t mblock_num = (m + MC - 1) / MC; + MC = (m + mblock_num - 1) / mblock_num; + MC = (MC + MR_INT8 - 1) / MR_INT8 * MR_INT8; + } + // DLOG << "mblock_num = " << mblock_num << ", MC = " << MC << "\n"; + if (NC == 0) { + NC = NR_INT8; + } else { + int32_t nblock_num = (n + NC - 1) / NC; + NC = (n + nblock_num - 1) / nblock_num; + NC = (NC + NR_INT8 - 1) / NR_INT8 * NR_INT8; + } + // DLOG << "nblock_num = " << nblock_num << ", NC = " << NC << "\n"; + packedA_int8 = static_cast( + paddle_mobile::memory::Alloc(sizeof(int8_t) * MC * KC)); + packedB_int8 = static_cast( + paddle_mobile::memory::Alloc(sizeof(int8_t) * KC * NC)); + packedC_int32 = static_cast( + paddle_mobile::memory::Alloc(sizeof(int32_t) * MC * NC)); + zero_int8 = + static_cast(paddle_mobile::memory::Alloc(sizeof(int8_t) * k)); + + memset(static_cast(zero_int8), 0, sizeof(int8_t) * k); + int32_t mc, nc; + for (int32_t j = 0; j < n; j += NC) { + nc = s_min(n - j, NC); + PackMatrixB_2c_16(k, nc, nc % NR_INT8, &B(0, j), ldb, packedB_int8); + for (int32_t i = 0; i < m; i += MC) { + mc = s_min(m - i, MC); + PackMatrixA_4r_16(mc, k, mc % MR_INT8, &A(i, 0), lda, packedA_int8); + if (bias != nullptr) { InnerKernelWithBias(mc, nc, alpha, packedA_int8, packedB_int8, beta, - packedC_int8, &C(i, j), ldc, relu, nullptr); - } else { - InnerKernelWithBias(mc, nc, alpha, packedA_int8, packedB_int8, beta, - packedC_int8, &C(i, j), ldc, relu, bias + i); + packedC_int32, &C(i, j), ldc, relu, bias + i); } } } paddle_mobile::memory::Free(packedA_int8); paddle_mobile::memory::Free(packedB_int8); - paddle_mobile::memory::Free(packedC_int8); + paddle_mobile::memory::Free(packedC_int32); paddle_mobile::memory::Free(zero_int8); } // 8 bits int write back -// C = alpha * A * B + beta * C -void Gemm::WriteWithAlphaBeta(int32_t mc, int32_t nc, int32_t *c, int32_t *C, - int32_t ldc) {} -// C = A * B, 8位 int32_t +// C = A * B void Gemm::WriteBasic(int32_t mc, int32_t nc, int32_t *c, int32_t *C, int32_t ldc) { #if __ARM_NEON @@ -802,7 +1215,7 @@ void Gemm::WriteBasic(int32_t mc, int32_t nc, int32_t *c, int32_t *C, int32_t step = sizeof(int32_t) * ldc; int32_t step1 = sizeof(int32_t) * (NC - (nc1 << 4)); int32_t volatile m = mc; - + int32_t volatile n = nc1; int32_t *volatile c_ptr, *volatile C_ptr; int32_t *C0, *c0; c_ptr = c; @@ -836,7 +1249,7 @@ void Gemm::WriteBasic(int32_t mc, int32_t nc, int32_t *c, int32_t *C, "end_mc_%=: \n\t" : - : [C_ptr] "r"(C_ptr), [c_ptr] "r"(c_ptr), [mc] "r"(m), [nc1] "r"(nc1), + : [C_ptr] "r"(C_ptr), [c_ptr] "r"(c_ptr), [mc] "r"(m), [nc1] "r"(n), [step] "r"(step), [step1] "r"(step1) : "memory", "r5", "r6", "q0", "q1", "q2", "q3"); } @@ -854,20 +1267,254 @@ void Gemm::WriteBasic(int32_t mc, int32_t nc, int32_t *c, int32_t *C, #endif // __ARM_NEON } -// C = A * B + C -void Gemm::WriteWithAdd(int32_t mc, int32_t nc, int32_t 
*c, int32_t *C, - int32_t ldc) {} +// C = A * B + bias, scale * C +void Gemm::WriteWithAddScale(int32_t mc, int32_t nc, int32_t *c, int8_t *C, + int32_t ldc, int32_t *bias, float scale) { +#if __ARM_NEON +#if __aarch64__ +// TODO(wzzju) +#else + int32_t zero = 0; + int8_t narrow = -128; + int32_t nc1 = nc >> 3; + int32_t _nc1 = nc & 7; + int32_t step = sizeof(int8_t) * ldc; + int32_t step1 = sizeof(int32_t) * (NC - (nc1 << 3)); + int32_t volatile m = mc; + int32_t volatile n = nc1; + int32_t *volatile c_ptr, *volatile bias_ptr; + int8_t *volatile C_ptr; + c_ptr = c; + C_ptr = C; + bias_ptr = bias; + if (nc1 > 0) { + asm volatile( + "subs %[mc], %[mc], #1 \n\t" + "blt end_mc_%= \n\t" + "vdup.32 q15, %[scale] \n\t" + "vdup.32 q14, %[zero] \n\t" + "vdup.8 d24, %[narrow] \n\t" + "loop_mc_%=: \n\t" + "vld1.32 {d26[0]}, [%[bias_ptr]]!\n\t" + "vdup.32 q13, d26[0] \n\t" + "mov r6, %[C_ptr] \n\t" + "mov r5, %[nc1] \n\t" + "subs r5, r5, #1 \n\t" + "blt end_nc1_%= \n\t" + "loop_nc1_%=: \n\t" + "vld1.32 {q0, q1}, [%[c_ptr]]! \n\t" + "vqadd.s32 q0, q0, q13 \n\t" + "vqadd.s32 q1, q1, q13 \n\t" + "vcvt.f32.s32 q2, q0 \n\t" + "vcvt.f32.s32 q3, q1 \n\t" + "vmul.f32 q2, q2, q15 \n\t" + "vmul.f32 q3, q3, q15 \n\t" + "vcvt.s32.f32 q4, q2 \n\t" + "vcvt.s32.f32 q5, q3 \n\t" + "vqmovn.s32 d12, q4 \n\t" + "vqmovn.s32 d13, q5 \n\t" + "vqmovn.s16 d14, q6 \n\t" + "vceq.s8 d15, d14, d24 \n\t" + "vsub.s8 d14, d14, d15 \n\t" + "vst1.8 {d14}, [r6]! \n\t" + "subs r5, r5, #1 \n\t" + "bge loop_nc1_%= \n\t" + "end_nc1_%=: \n\t" + + "add %[C_ptr], %[C_ptr], %[step] \n\t" + "add %[c_ptr], %[c_ptr], %[step1] \n\t" + "subs %[mc], %[mc], #1 \n\t" + "bge loop_mc_%= \n\t" + "end_mc_%=: \n\t" -// C = A * B + bias -void Gemm::WriteWithAddV1(int32_t mc, int32_t nc, int32_t *c, int32_t *C, - int32_t ldc, int8_t *bias) {} -// C = A * B + C, relu(C) -void Gemm::WriteWithAddRelu(int32_t mc, int32_t nc, int32_t *c, int32_t *C, - int32_t ldc) {} + : + : [C_ptr] "r"(C_ptr), [c_ptr] "r"(c_ptr), [mc] "r"(m), [nc1] "r"(n), + [step] "r"(step), [step1] "r"(step1), [bias_ptr] "r"(bias_ptr), + [scale] "r"(scale), [zero] "r"(zero), [narrow] "r"(narrow) + : "cc", "memory", "r5", "r6", "q0", "q1", "q2", "q3", "q4", "q5", "q6", + "q7", "q12", "q13", "q14", "q15"); + } -// C = A * B + bias, relu(C) -void Gemm::WriteWithAddReluV1(int32_t mc, int32_t nc, int32_t *c, int32_t *C, - int32_t ldc, int8_t *bias) {} + int32_t nc_left; + int32_t *c0; + int8_t *C0; + int32_t bias_v; + if (_nc1 != 0) { + for (int32_t i = 0; i < mc; i++) { + C0 = C_ptr + nc1 * 8 + i * ldc; + c0 = c_ptr + nc1 * 8 + i * NC; + bias_v = *(bias_ptr + i); + nc_left = _nc1; + asm volatile( + "vdup.32 q15, %[scale] \n\t" + "vdup.32 q14, %[zero] \n\t" + "vdup.8 d24, %[narrow] \n\t" + "vdup.32 q13, %[bias_v] \n\t" + "cmp %[_nc1], #4 \n\t" + "blt less_four_%= \n\t" + "vld1.32 {q0}, [%[c0]]! \n\t" + "vqadd.s32 q0, q0, q13 \n\t" + "vcvt.f32.s32 q1, q0 \n\t" + "vmul.f32 q1, q1, q15 \n\t" + "vcvt.s32.f32 q2, q1 \n\t" + "vqmovn.s32 d6, q2 \n\t" + "vqmovn.s16 d8, q3 \n\t" + "vceq.s8 d9, d8, d24 \n\t" + "vsub.s8 d8, d8, d9 \n\t" + "vst1.8 {d8[0]}, [%[C0]]! \n\t" + "vst1.8 {d8[1]}, [%[C0]]! \n\t" + "vst1.8 {d8[2]}, [%[C0]]! \n\t" + "vst1.8 {d8[3]}, [%[C0]]! \n\t" + "subs %[_nc1], %[_nc1], #4 \n\t" + "beq process_over_%= \n\t" + "less_four_%=: \n\t" + "vld1.32 {q0}, [%[c0]]! 
\n\t" + "vqadd.s32 q0, q0, q13 \n\t" + "vcvt.f32.s32 q1, q0 \n\t" + "vmul.f32 q1, q1, q15 \n\t" + "vcvt.s32.f32 q2, q1 \n\t" + "vqmovn.s32 d6, q2 \n\t" + "vqmovn.s16 d8, q3 \n\t" + "vceq.s8 d9, d8, d24 \n\t" + "vsub.s8 d8, d8, d9 \n\t" + "loop_save_%=: \n\t" + "vst1.8 {d8[0]}, [%[C0]]! \n\t" + "vext.8 d8, d8, d8, #1 \n\t" + "subs %[_nc1], %[_nc1], #1 \n\t" + "bgt loop_save_%= \n\t" + "process_over_%=: \n\t" + : + : [_nc1] "r"(nc_left), [C0] "r"(C0), [c0] "r"(c0), + [bias_v] "r"(bias_v), [scale] "r"(scale), [zero] "r"(zero), + [narrow] "r"(narrow) + : "cc", "memory", "q0", "q1", "q2", "q3", "q4", "q12", "q13", "q14", + "q15"); + } + } +#endif // __aarch64__ +#endif // __ARM_NEON +} + +// C = A * B + bias, scale * relu(C) +void Gemm::WriteWithAddReluScale(int32_t mc, int32_t nc, int32_t *c, int8_t *C, + int32_t ldc, int32_t *bias, float scale) { +#if __ARM_NEON +#if __aarch64__ +// TODO(wzzju) +#else + int32_t zero = 0; + int32_t nc1 = nc >> 3; + int32_t _nc1 = nc & 7; + int32_t step = sizeof(int8_t) * ldc; + int32_t step1 = sizeof(int32_t) * (NC - (nc1 << 3)); + int32_t volatile m = mc; + int32_t volatile n = nc1; + int32_t *volatile c_ptr, *volatile bias_ptr; + int8_t *volatile C_ptr; + c_ptr = c; + C_ptr = C; + bias_ptr = bias; + if (nc1 > 0) { + asm volatile( + "subs %[mc], %[mc], #1 \n\t" + "blt end_mc_%= \n\t" + "vdup.32 q15, %[scale] \n\t" + "vdup.32 q14, %[zero] \n\t" + "loop_mc_%=: \n\t" + "vld1.32 {d26[0]}, [%[bias_ptr]]!\n\t" + "vdup.32 q13, d26[0] \n\t" + "mov r6, %[C_ptr] \n\t" + "mov r5, %[nc1] \n\t" + "subs r5, r5, #1 \n\t" + "blt end_nc1_%= \n\t" + "loop_nc1_%=: \n\t" + "vld1.32 {q0, q1}, [%[c_ptr]]! \n\t" + "vqadd.s32 q0, q0, q13 \n\t" + "vqadd.s32 q1, q1, q13 \n\t" + "vmax.s32 q0, q0, q14 \n\t" + "vmax.s32 q1, q1, q14 \n\t" + "vcvt.f32.s32 q2, q0 \n\t" + "vcvt.f32.s32 q3, q1 \n\t" + "vmul.f32 q2, q2, q15 \n\t" + "vmul.f32 q3, q3, q15 \n\t" + "vcvt.s32.f32 q4, q2 \n\t" + "vcvt.s32.f32 q5, q3 \n\t" + "vqmovn.s32 d12, q4 \n\t" + "vqmovn.s32 d13, q5 \n\t" + "vqmovn.s16 d14, q6 \n\t" + "vst1.8 {d14}, [r6]! \n\t" + "subs r5, r5, #1 \n\t" + "bge loop_nc1_%= \n\t" + "end_nc1_%=: \n\t" + + "add %[C_ptr], %[C_ptr], %[step] \n\t" + "add %[c_ptr], %[c_ptr], %[step1] \n\t" + "subs %[mc], %[mc], #1 \n\t" + "bge loop_mc_%= \n\t" + "end_mc_%=: \n\t" + + : + : [C_ptr] "r"(C_ptr), [c_ptr] "r"(c_ptr), [mc] "r"(m), [nc1] "r"(n), + [step] "r"(step), [step1] "r"(step1), [bias_ptr] "r"(bias_ptr), + [scale] "r"(scale), [zero] "r"(zero) + : "cc", "memory", "r5", "r6", "q0", "q1", "q2", "q3", "q4", "q5", "q6", + "q7", "q13", "q14", "q15"); + } + + int32_t nc_left; + int32_t *c0; + int8_t *C0; + int32_t bias_v; + if (_nc1 != 0) { + for (int32_t i = 0; i < mc; i++) { + C0 = C_ptr + nc1 * 8 + i * ldc; + c0 = c_ptr + nc1 * 8 + i * NC; + bias_v = *(bias_ptr + i); + nc_left = _nc1; + asm volatile( + "vdup.32 q15, %[scale] \n\t" + "vdup.32 q14, %[zero] \n\t" + "vdup.32 q13, %[bias_v] \n\t" + "cmp %[_nc1], #4 \n\t" + "blt less_four_%= \n\t" + "vld1.32 {q0}, [%[c0]]! \n\t" + "vqadd.s32 q0, q0, q13 \n\t" + "vmax.s32 q0, q0, q14 \n\t" + "vcvt.f32.s32 q1, q0 \n\t" + "vmul.f32 q1, q1, q15 \n\t" + "vcvt.s32.f32 q2, q1 \n\t" + "vqmovn.s32 d6, q2 \n\t" + "vqmovn.s16 d8, q3 \n\t" + "vst1.8 {d8[0]}, [%[C0]]! \n\t" + "vst1.8 {d8[1]}, [%[C0]]! \n\t" + "vst1.8 {d8[2]}, [%[C0]]! \n\t" + "vst1.8 {d8[3]}, [%[C0]]! \n\t" + "subs %[_nc1], %[_nc1], #4 \n\t" + "beq process_over_%= \n\t" + "less_four_%=: \n\t" + "vld1.32 {q0}, [%[c0]]! 
\n\t" + "vqadd.s32 q0, q0, q13 \n\t" + "vmax.s32 q0, q0, q14 \n\t" + "vcvt.f32.s32 q1, q0 \n\t" + "vmul.f32 q1, q1, q15 \n\t" + "vcvt.s32.f32 q2, q1 \n\t" + "vqmovn.s32 d6, q2 \n\t" + "vqmovn.s16 d8, q3 \n\t" + "loop_save_%=: \n\t" + "vst1.8 {d8[0]}, [%[C0]]! \n\t" + "vext.8 d8, d8, d8, #1 \n\t" + "subs %[_nc1], %[_nc1], #1 \n\t" + "bgt loop_save_%= \n\t" + "process_over_%=: \n\t" + : + : [_nc1] "r"(nc_left), [C0] "r"(C0), [c0] "r"(c0), + [bias_v] "r"(bias_v), [scale] "r"(scale), [zero] "r"(zero) + : "cc", "memory", "q0", "q1", "q2", "q3", "q4", "q13", "q14", "q15"); + } + } +#endif // __aarch64__ +#endif // __ARM_NEON +} } // namespace math } // namespace operators diff --git a/src/operators/math/gemm_omp_int8.cpp b/src/operators/math/gemm_omp_int8.cpp index 21256cccfcc6dcc647f34a2129616b70804d398f..d4d4c294934191ba6717716486bf857477d73b55 100644 --- a/src/operators/math/gemm_omp_int8.cpp +++ b/src/operators/math/gemm_omp_int8.cpp @@ -28,10 +28,10 @@ namespace operators { namespace math { // 8 bits int matrix product (m*k x k*n) -void Gemm::Sgemm_omp(int32_t m, int32_t n, int32_t k, int8_t alpha, +void Gemm::Sgemm_omp(int32_t m, int32_t n, int32_t k, float alpha, const int8_t *A, int32_t lda, const int8_t *B, int32_t ldb, - int8_t beta, int32_t *C, int32_t ldc, bool relu, - int8_t *bias) { + float beta, int32_t *C, int32_t ldc, bool relu, + int32_t *bias) { #ifdef _OPENMP int32_t max_threads = omp_get_max_threads(); #else @@ -39,10 +39,11 @@ void Gemm::Sgemm_omp(int32_t m, int32_t n, int32_t k, int8_t alpha, #endif int32_t L1 = 64 / max_threads * 1024; - KC = k; + const int32_t k_complete = (k + 15) - ((k + 15) & 15); + KC = k_complete; zero_int8 = - static_cast(paddle_mobile::memory::Alloc(sizeof(int8_t) * KC)); - memset(static_cast(zero_int8), 0, sizeof(int8_t) * KC); + static_cast(paddle_mobile::memory::Alloc(sizeof(int8_t) * k)); + memset(static_cast(zero_int8), 0, sizeof(int8_t) * k); if (m > n) { // 对 A 分块 MC = L1 / (KC * sizeof(int8_t)); @@ -54,14 +55,14 @@ void Gemm::Sgemm_omp(int32_t m, int32_t n, int32_t k, int8_t alpha, MC = (MC + MR_INT8 - 1) / MR_INT8 * MR_INT8; } // 补齐 B - NC = (n + NR - 1) / NR * NR; + NC = (n + NR_INT8 - 1) / NR_INT8 * NR_INT8; packedB_int8 = static_cast( paddle_mobile::memory::Alloc(sizeof(int8_t) * KC * NC)); #if __aarch64__ // TODO(wzzju) #else - PackMatrixB_omp_8c(KC, n, n % NR, B, ldb, packedB_int8); + PackMatrixB_omp_2c_16(k, n, n % NR_INT8, B, ldb, packedB_int8); #endif packedA_int8 = static_cast( paddle_mobile::memory::Alloc(sizeof(int8_t) * MC * KC * max_threads)); @@ -69,11 +70,11 @@ void Gemm::Sgemm_omp(int32_t m, int32_t n, int32_t k, int8_t alpha, // 对 B 分块 NC = L1 / (KC * sizeof(int8_t)); if (NC == 0) { - NC = NR; + NC = NR_INT8; } else { int32_t nblock_num = (n + NC - 1) / NC; NC = (n + nblock_num - 1) / nblock_num; - NC = (NC + NR - 1) / NR * NR; + NC = (NC + NR_INT8 - 1) / NR_INT8 * NR_INT8; } // 补齐 A MC = (m + MR_INT8 - 1) / MR_INT8 * MR_INT8; @@ -83,12 +84,12 @@ void Gemm::Sgemm_omp(int32_t m, int32_t n, int32_t k, int8_t alpha, #if __aarch64__ // TODO(wzzju) #else - PackMatrixA_omp_4r(m, KC, m % MR_INT8, A, lda, packedA_int8); + PackMatrixA_omp_4r_16(m, k, m % MR_INT8, A, lda, packedA_int8); #endif packedB_int8 = static_cast( paddle_mobile::memory::Alloc(sizeof(int8_t) * KC * NC * max_threads)); } - packedC_int8 = static_cast( + packedC_int32 = static_cast( paddle_mobile::memory::Alloc(sizeof(int32_t) * MC * NC * max_threads)); if (m > n) { @@ -103,14 +104,19 @@ void Gemm::Sgemm_omp(int32_t m, int32_t n, int32_t k, int8_t alpha, int32_t 
mc; mc = s_min(m - i, MC); int8_t *local_A = packedA_int8 + MC * KC * local_threads; - int32_t *local_C = packedC_int8 + MC * NC * local_threads; + int32_t *local_C = packedC_int32 + MC * NC * local_threads; #if __aarch64__ // TODO(wzzju) #else - PackMatrixA_4r(mc, KC, mc % MR_INT8, &A(i, 0), lda, local_A); + PackMatrixA_4r_16(mc, k, mc % MR_INT8, &A(i, 0), lda, local_A); #endif - InnerKernelWithBias(mc, n, alpha, local_A, packedB_int8, beta, local_C, - &C(i, 0), ldc, relu, bias + i); + // InnerKernelWithBias(mc, n, alpha, local_A, packedB_int8, beta, + // local_C, + // &C(i, 0), ldc, relu, bias + i); + if (bias == nullptr) { + InnerKernel(mc, n, alpha, local_A, packedB_int8, beta, local_C, + &C(i, 0), ldc, relu); + } } } else { #pragma omp parallel for @@ -123,20 +129,25 @@ void Gemm::Sgemm_omp(int32_t m, int32_t n, int32_t k, int8_t alpha, int32_t nc; nc = s_min(n - j, NC); int8_t *local_B = packedB_int8 + KC * NC * local_threads; - int32_t *local_C = packedC_int8 + MC * NC * local_threads; + int32_t *local_C = packedC_int32 + MC * NC * local_threads; #if __aarch64__ // TODO(wzzju) #else - PackMatrixB_8c(KC, nc, nc % NR, &B(0, j), ldb, local_B); + PackMatrixB_2c_16(k, nc, nc % NR_INT8, &B(0, j), ldb, local_B); #endif - InnerKernelWithBias(m, nc, alpha, packedA_int8, local_B, beta, local_C, - &C(0, j), ldc, relu, bias); + // InnerKernelWithBias(m, nc, alpha, packedA_int8, local_B, beta, + // local_C, + // &C(0, j), ldc, relu, bias); + if (bias == nullptr) { + InnerKernel(m, nc, alpha, packedA_int8, local_B, beta, local_C, + &C(0, j), ldc, relu); + } } } paddle_mobile::memory::Free(packedA_int8); paddle_mobile::memory::Free(packedB_int8); - paddle_mobile::memory::Free(packedC_int8); + paddle_mobile::memory::Free(packedC_int32); paddle_mobile::memory::Free(zero_int8); } @@ -144,7 +155,7 @@ void Gemm::PackMatrixB_omp_8c(int32_t k, int32_t n, int32_t n_tail, const int8_t *B, int32_t ldb, int8_t *buffer) { const int32_t j_length = n - n_tail; #pragma omp parallel for - for (int32_t j = 0; j < j_length; j += NR) { + for (int32_t j = 0; j < j_length; j += 8) { int8_t *local_buffer = buffer + j * k; for (int32_t i = 0; i < k; ++i) { const int8_t *b0 = &B(i, j); @@ -179,7 +190,7 @@ void Gemm::PackMatrixB_omp_8c(int32_t k, int32_t n, int32_t n_tail, for (int32_t j = j_length; j < n; ++j) { *local_buffer++ = *b0++; } - for (int32_t j = n; j < j_length + NR; ++j) { + for (int32_t j = n; j < j_length + 8; ++j) { *local_buffer++ = 0; } } @@ -188,9 +199,9 @@ void Gemm::PackMatrixB_omp_8c(int32_t k, int32_t n, int32_t n_tail, void Gemm::PackMatrixA_omp_4r(int32_t m, int32_t k, int32_t m_tail, const int8_t *A, int32_t lda, int8_t *buffer) { - const int i_length = m - m_tail; + const int32_t i_length = m - m_tail; #pragma omp parallel for - for (int32_t i = 0; i < i_length; i += MR_INT8) { + for (int32_t i = 0; i < i_length; i += 4) { const int8_t *a0 = A + i * lda; const int8_t *a1 = A + (i + 1) * lda; const int8_t *a2 = A + (i + 2) * lda; @@ -221,7 +232,7 @@ void Gemm::PackMatrixA_omp_4r(int32_t m, int32_t k, int32_t m_tail, default: break; } - for (int j = 0; j < k; ++j) { + for (int32_t j = 0; j < k; ++j) { *local_buffer++ = *a0++; *local_buffer++ = *a1++; *local_buffer++ = *a2++; @@ -230,6 +241,232 @@ void Gemm::PackMatrixA_omp_4r(int32_t m, int32_t k, int32_t m_tail, } } +// 8 bits int PackMatrixA_4r +void Gemm::PackMatrixA_omp_4r_16(int32_t m, int32_t k, int32_t m_tail, + const int8_t *A, int32_t lda, int8_t *buffer) { + const int32_t i_length = m - m_tail; + const int32_t k_count = k >> 4; + const 
int32_t k_tail = k & 15; +#pragma omp parallel for + for (int32_t i = 0; i < i_length; i += 4) { + const int8_t *a0 = A + i * lda; + const int8_t *a1 = A + (i + 1) * lda; + const int8_t *a2 = A + (i + 2) * lda; + const int8_t *a3 = A + (i + 3) * lda; + int8_t *local_buffer = buffer + i * KC; + for (int32_t j = 0; j < k_count; ++j) { +#if __ARM_NEON +#if __aarch64__ + // TODO(wzzju) +#else + asm volatile( + "vld1.s8 {d0, d1}, [%[a0]]! \n\t" + "vld1.s8 {d2, d3}, [%[a1]]! \n\t" + "vld1.s8 {d4, d5}, [%[a2]]! \n\t" + "vld1.s8 {d6, d7}, [%[a3]]! \n\t" + "vst1.s8 {d0, d1}, [%[local_buffer]]! \n\t" + "vst1.s8 {d2, d3}, [%[local_buffer]]! \n\t" + "vst1.s8 {d4, d5}, [%[local_buffer]]! \n\t" + "vst1.s8 {d6, d7}, [%[local_buffer]]! \n\t" + : [local_buffer] "+r"(local_buffer), [a0] "+r"(a0), [a1] "+r"(a1), + [a2] "+r"(a2), [a3] "+r"(a3) + : + : "memory", "q0", "q1", "q2", "q3"); +#endif // __aarch64__ +#else + for (int32_t l = 0; l < 16; ++l) { + *local_buffer++ = *a0++; + } + for (int32_t l = 0; l < 16; ++l) { + *local_buffer++ = *a1++; + } + for (int32_t l = 0; l < 16; ++l) { + *local_buffer++ = *a2++; + } + for (int32_t l = 0; l < 16; ++l) { + *local_buffer++ = *a3++; + } +#endif // __ARM_NEON + } + if (k_tail != 0) { + for (int32_t j = k_count << 4; j < k; ++j) { + *local_buffer++ = *a0++; + } + for (int32_t j = k; j < KC; ++j) { + *local_buffer++ = 0; + } + + for (int32_t j = k_count << 4; j < k; ++j) { + *local_buffer++ = *a1++; + } + for (int32_t j = k; j < KC; ++j) { + *local_buffer++ = 0; + } + for (int32_t j = k_count << 4; j < k; ++j) { + *local_buffer++ = *a2++; + } + for (int32_t j = k; j < KC; ++j) { + *local_buffer++ = 0; + } + for (int32_t j = k_count << 4; j < k; ++j) { + *local_buffer++ = *a3++; + } + for (int32_t j = k; j < KC; ++j) { + *local_buffer++ = 0; + } + } + } + + if (m_tail != 0) { + const int8_t *a0 = &A(i_length, 0); + const int8_t *a1 = a0 + lda; + const int8_t *a2 = a0 + 2 * lda; + const int8_t *a3 = a0 + 3 * lda; + int8_t *local_buffer = buffer + i_length * KC; + switch (m_tail) { + case 1: + a1 = zero_int8; + case 2: + a2 = zero_int8; + case 3: + a3 = zero_int8; + break; + default: + break; + } + for (int32_t j = 0; j < k_count; ++j) { +#if __ARM_NEON +#if __aarch64__ + // TODO(wzzju) +#else + asm volatile( + "vld1.s8 {d0, d1}, [%[a0]]! \n\t" + "vld1.s8 {d2, d3}, [%[a1]]! \n\t" + "vld1.s8 {d4, d5}, [%[a2]]! \n\t" + "vld1.s8 {d6, d7}, [%[a3]]! \n\t" + "vst1.s8 {d0, d1}, [%[local_buffer]]! \n\t" + "vst1.s8 {d2, d3}, [%[local_buffer]]! \n\t" + "vst1.s8 {d4, d5}, [%[local_buffer]]! \n\t" + "vst1.s8 {d6, d7}, [%[local_buffer]]! 
\n\t" + : [local_buffer] "+r"(local_buffer), [a0] "+r"(a0), [a1] "+r"(a1), + [a2] "+r"(a2), [a3] "+r"(a3) + : + : "memory", "q0", "q1", "q2", "q3"); +#endif // __aarch64__ +#else + for (int32_t l = 0; l < 16; ++l) { + *local_buffer++ = *a0++; + } + for (int32_t l = 0; l < 16; ++l) { + *local_buffer++ = *a1++; + } + for (int32_t l = 0; l < 16; ++l) { + *local_buffer++ = *a2++; + } + for (int32_t l = 0; l < 16; ++l) { + *local_buffer++ = *a3++; + } +#endif // __ARM_NEON + } + if (k_tail != 0) { + for (int32_t j = k_count << 4; j < k; ++j) { + *local_buffer++ = *a0++; + } + for (int32_t j = k; j < KC; ++j) { + *local_buffer++ = 0; + } + + for (int32_t j = k_count << 4; j < k; ++j) { + *local_buffer++ = *a1++; + } + for (int32_t j = k; j < KC; ++j) { + *local_buffer++ = 0; + } + for (int32_t j = k_count << 4; j < k; ++j) { + *local_buffer++ = *a2++; + } + for (int32_t j = k; j < KC; ++j) { + *local_buffer++ = 0; + } + for (int32_t j = k_count << 4; j < k; ++j) { + *local_buffer++ = *a3++; + } + for (int32_t j = k; j < KC; ++j) { + *local_buffer++ = 0; + } + } + } +} + +// 8 bits int PackMatrixB +void Gemm::PackMatrixB_omp_2c_16(int32_t k, int32_t n, int32_t n_tail, + const int8_t *B, int32_t ldb, int8_t *buffer) { + const int32_t j_length = n - n_tail; + const int32_t k_count = k >> 4; + const int32_t k_tail = k & 15; +#pragma omp parallel for + for (int32_t j = 0; j < j_length; j += 2) { + int8_t *local_buffer = buffer + j * KC; + for (int32_t i = 0; i < k_count; ++i) { + const int8_t *b0 = &B((i << 4), j); + const int8_t *b1 = &B((i << 4), j + 1); + for (int m = 0; m < 16; ++m) { + *local_buffer++ = *b0; + b0 += ldb; + } + for (int m = 0; m < 16; ++m) { + *local_buffer++ = *b1; + b1 += ldb; + } + } + if (k_tail != 0) { + const int8_t *b0 = &B((k_count << 4), j); + const int8_t *b1 = &B((k_count << 4), j + 1); + for (int32_t j = k_count << 4; j < k; ++j) { + *local_buffer++ = *b0; + b0 += ldb; + } + for (int32_t j = k; j < KC; ++j) { + *local_buffer++ = 0; + } + + for (int32_t j = k_count << 4; j < k; ++j) { + *local_buffer++ = *b1; + b1 += ldb; + } + for (int32_t j = k; j < KC; ++j) { + *local_buffer++ = 0; + } + } + } + if (n_tail != 0) { + int8_t *local_buffer = buffer + j_length * KC; + for (int32_t i = 0; i < k_count; ++i) { + const int8_t *b0 = &B((i << 4), j_length); + for (int m = 0; m < 16; ++m) { + *local_buffer++ = *b0; + b0 += ldb; + } + for (int m = 0; m < 16; ++m) { + *local_buffer++ = 0; + } + } + if (k_tail != 0) { + const int8_t *b0 = &B((k_count << 4), j_length); + for (int32_t j = k_count << 4; j < k; ++j) { + *local_buffer++ = *b0; + b0 += ldb; + } + for (int32_t j = k; j < KC; ++j) { + *local_buffer++ = 0; + } + for (int32_t j = k_count << 4; j < KC; ++j) { + *local_buffer++ = 0; + } + } + } +} + } // namespace math } // namespace operators } // namespace paddle_mobile diff --git a/src/operators/math/math_function.h b/src/operators/math/math_function.h index b91242c1868398e4541c3727567a905e5b0c8714..9661b2d4c22ed49ef0c078fac0872c7643057430 100644 --- a/src/operators/math/math_function.h +++ b/src/operators/math/math_function.h @@ -28,7 +28,12 @@ template void matmul(const framework::Tensor &matrix_a, bool trans_a, const framework::Tensor &matrix_b, bool trans_b, T alpha, framework::Tensor *matrix_out, T beta, bool relu = false, - T *bias = nullptr); + float *bias = nullptr); + +void matmul_int8(const framework::Tensor &matrix_a, bool trans_a, + const framework::Tensor &matrix_b, bool trans_b, float alpha, + framework::Tensor *matrix_out, float beta, bool relu = false, + 
int32_t *bias = nullptr); template void matmulWithBn(const framework::Tensor &matrix_a, bool trans_a, diff --git a/src/operators/math/math_function_int8.cpp b/src/operators/math/math_function_int8.cpp index e02824b290ebc0080613e2ae2365626d79576c9e..ba0e5578cd32ff45620ddaa6feda9b31b2bcd68e 100644 --- a/src/operators/math/math_function_int8.cpp +++ b/src/operators/math/math_function_int8.cpp @@ -20,11 +20,10 @@ limitations under the License. */ namespace paddle_mobile { namespace operators { namespace math { -template <> -void matmul(const framework::Tensor &matrix_a, bool trans_a, - const framework::Tensor &matrix_b, bool trans_b, - int8_t alpha, framework::Tensor *matrix_out, int8_t beta, - bool relu, int8_t *bias) { +void matmul_int8(const framework::Tensor &matrix_a, bool trans_a, + const framework::Tensor &matrix_b, bool trans_b, float alpha, + framework::Tensor *matrix_out, float beta, bool relu, + int32_t *bias) { auto dim_a = matrix_a.dims(); auto dim_b = matrix_b.dims(); auto dim_out = matrix_out->dims(); @@ -52,21 +51,45 @@ void matmul(const framework::Tensor &matrix_a, bool trans_a, } #ifdef _OPENMP - gemm.Sgemm_omp(M, N, K, alpha, a, K, matrix_b.data(), N, beta, - matrix_out->data(), N, relu, bias); + if (bias != nullptr) { + // TODO(wzzju): gemm.Sgemm_omp_with_bias, now use single thread instead. + gemm.Sgemm(M, N, K, alpha, a, K, matrix_b.data(), N, beta, + matrix_out->data(), N, relu, bias); + } else { + gemm.Sgemm_omp(M, N, K, alpha, a, K, matrix_b.data(), N, beta, + matrix_out->data(), N, relu, bias); + } #else - gemm.Sgemm(M, N, K, alpha, a, K, matrix_b.data(), N, beta, - matrix_out->data(), N, relu, bias); + if (bias != nullptr) { + gemm.Sgemm(M, N, K, alpha, a, K, matrix_b.data(), N, beta, + matrix_out->data(), N, relu, bias); + } else { + gemm.Sgemm(M, N, K, alpha, a, K, matrix_b.data(), N, beta, + matrix_out->data(), N, relu, bias); + } #endif } else { #ifdef _OPENMP - gemm.Sgemm_omp(M, N, K, alpha, matrix_a.data(), K, - matrix_b.data(), N, beta, - matrix_out->data(), N, relu, bias); + if (bias != nullptr) { + // TODO(wzzju): gemm.Sgemm_omp_with_bias, now use single thread instead. 
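+      // Sgemm_omp has no bias-aware path yet (see the TODO above), so when a
+      // bias is supplied the single-threaded int8 Sgemm is used instead.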
+ gemm.Sgemm(M, N, K, alpha, matrix_a.data(), K, + matrix_b.data(), N, beta, matrix_out->data(), + N, relu, bias); + } else { + gemm.Sgemm_omp(M, N, K, alpha, matrix_a.data(), K, + matrix_b.data(), N, beta, + matrix_out->data(), N, relu, bias); + } #else - gemm.Sgemm(M, N, K, alpha, matrix_a.data(), K, - matrix_b.data(), N, beta, matrix_out->data(), N, - relu, bias); + if (bias != nullptr) { + gemm.Sgemm(M, N, K, alpha, matrix_a.data(), K, + matrix_b.data(), N, beta, matrix_out->data(), + N, relu, bias); + } else { + gemm.Sgemm(M, N, K, alpha, matrix_a.data(), K, + matrix_b.data(), N, beta, matrix_out->data(), + N, relu, bias); + } #endif } } diff --git a/src/operators/op_param.h b/src/operators/op_param.h index 3593ecc9831f6bf627273b0abb5e75cf8a168dbf..7be35f81f2e7052e32a93531c325d716ed81c2ec 100644 --- a/src/operators/op_param.h +++ b/src/operators/op_param.h @@ -419,6 +419,8 @@ class ConvParam : public OpParam { EXEC_INVALID = 0, EXEC_GEMM_FLOAT, EXEC_DEPTHWISE3x3S1P1_FLOAT, + EXEC_DEPTHWISE3x3S2P0_FLOAT, + EXEC_DEPTHWISE3x3S2P1_FLOAT, EXEC_DEPTHWISE3x3_FLOAT, EXEC_WINOGRAD3X3_FLOAT, EXEC_WINOGRAD5X5_FLOAT, @@ -2573,7 +2575,9 @@ class DequantizeParam : public OpParam { DequantizeParam(const VariableNameMap &inputs, const VariableNameMap &outputs, const AttributeMap &attrs, const Scope &scope) { input_ = InputXFrom(inputs, scope); - output_ = OutFrom(outputs, scope); + if (outputs.count("Out")) { + output_ = OutFrom(outputs, scope); + } activation_scale_ = OpParam::GetVarValue("Scale", inputs, scope); // dequantization is performed as x = x / static_scale / online_scale if (HasAttr("weight_scale", attrs)) { @@ -2593,20 +2597,19 @@ class DequantizeParam : public OpParam { }; #endif -#ifdef FUSION_DEQUANT_ADD_BN_RELU_OP +#if defined(FUSION_DEQUANT_ADD_BN_OP) || \ + defined(FUSION_DEQUANT_ADD_BN_RELU_OP) || \ + defined(FUSION_DEQUANT_BN_RELU_OP) || defined(FUSION_DEQUANT_BN_OP) template -class FusionDequantAddBNReluParam : public DequantizeParam { +class FusionDequantBNParam : public DequantizeParam { typedef typename DtypeTensorTrait::gtype GType; typedef typename DtypeTensorTrait::rtype RType; public: - FusionDequantAddBNReluParam(const VariableNameMap &inputs, - const VariableNameMap &outputs, - const AttributeMap &attrs, const Scope &scope) + FusionDequantBNParam(const VariableNameMap &inputs, + const VariableNameMap &outputs, + const AttributeMap &attrs, const Scope &scope) : DequantizeParam(inputs, outputs, attrs, scope) { - // element wise add params - axis_ = OpParam::GetAttr("axis", attrs); - bias_ = OpParam::InputYFrom(inputs, scope); // batch norm params bn_mean_ = OpParam::GetVarValue("BNMean", inputs, scope); bn_variance_ = OpParam::GetVarValue("BNVariance", inputs, scope); @@ -2614,21 +2617,83 @@ class FusionDequantAddBNReluParam : public DequantizeParam { bn_bias_ = OpParam::GetVarValue("BNBias", inputs, scope); epsilon_ = OpParam::GetAttr("epsilon", attrs); // output - output_ = OpParam::OutFrom(outputs, scope); + if (outputs.count("Y")) { + this->output_ = OpParam::OutputYFrom(outputs, scope); + } } public: - // elementwise add - int axis_; - RType *bias_; // batch norm RType *bn_mean_; RType *bn_variance_; RType *bn_scale_; RType *bn_bias_; float epsilon_; - // output - RType *output_; +}; +#endif + +#if defined(FUSION_DEQUANT_ADD_BN_RELU_OP) || defined(FUSION_DEQUANT_ADD_BN_OP) +template +class FusionDequantAddBNParam : public FusionDequantBNParam { + typedef typename DtypeTensorTrait::gtype GType; + typedef typename DtypeTensorTrait::rtype RType; + + public: + 
FusionDequantAddBNParam(const VariableNameMap &inputs, + const VariableNameMap &outputs, + const AttributeMap &attrs, const Scope &scope) + : FusionDequantBNParam(inputs, outputs, attrs, scope) { + // element wise add params + axis_ = OpParam::GetAttr("axis", attrs); + bias_ = OpParam::InputYFrom(inputs, scope); + // output + if (outputs.count("Y")) { + this->output_ = OpParam::OutputYFrom(outputs, scope); + } + } + + public: + // elementwise add + int axis_; + RType *bias_; +}; +#endif + +#ifdef FUSION_DEQUANT_BN_RELU_OP +template +class FusionDequantBNReluParam : public FusionDequantBNParam { + typedef typename DtypeTensorTrait::gtype GType; + typedef typename DtypeTensorTrait::rtype RType; + + public: + FusionDequantBNReluParam(const VariableNameMap &inputs, + const VariableNameMap &outputs, + const AttributeMap &attrs, const Scope &scope) + : FusionDequantBNParam(inputs, outputs, attrs, scope) { + // output + if (outputs.count("Out")) { + this->output_ = OpParam::OutFrom(outputs, scope); + } + } +}; +#endif + +#ifdef FUSION_DEQUANT_ADD_BN_RELU_OP +template +class FusionDequantAddBNReluParam : public FusionDequantAddBNParam { + typedef typename DtypeTensorTrait::gtype GType; + typedef typename DtypeTensorTrait::rtype RType; + + public: + FusionDequantAddBNReluParam(const VariableNameMap &inputs, + const VariableNameMap &outputs, + const AttributeMap &attrs, const Scope &scope) + : FusionDequantAddBNParam(inputs, outputs, attrs, scope) { + // output + if (outputs.count("Out")) { + this->output_ = OpParam::OutFrom(outputs, scope); + } + } }; #endif diff --git a/test/common/test_gemm_perf.cpp b/test/common/test_gemm_perf.cpp index 14da4ba284b5ac7b0660bd15de871fdf5ed04cdd..5ca0b40cfcb20786ad69d1bbfbaca103b3e426e3 100644 --- a/test/common/test_gemm_perf.cpp +++ b/test/common/test_gemm_perf.cpp @@ -28,7 +28,7 @@ limitations under the License. 
*/ int main() { paddle_mobile::PaddleMobile paddle_mobile; - paddle_mobile.SetThreadNum(8); + paddle_mobile.SetThreadNum(4); Tensor aa, bb, cc; auto aaptr = aa.mutable_data({m, k}); auto bbptr = bb.mutable_data({k, n}); @@ -44,10 +44,12 @@ int main() { ccptr[i] = 2; } - Tensor aa_int8, bb_int8, cc_int8; + Tensor aa_int8, bb_int8, cc_int32, cc_int8; auto aaptr_int8 = aa_int8.mutable_data({m, k}); auto bbptr_int8 = bb_int8.mutable_data({k, n}); - auto ccptr_int8 = cc_int8.mutable_data({m, n}); + auto ccptr_int32 = cc_int32.mutable_data({m, n}); + auto ccptr_int8 = cc_int8.mutable_data({m, n}); + int32_t* bias_data = new int32_t[m]; for (int i = 0; i < m * k; ++i) { aaptr_int8[i] = static_cast(2); @@ -56,7 +58,11 @@ int main() { bbptr_int8[i] = static_cast(2); } for (int i = 0; i < m * n; ++i) { - ccptr_int8[i] = static_cast(2); + ccptr_int32[i] = static_cast(2); + } + + for (int i = 0; i < m; ++i) { + bias_data[i] = 2; } // float @@ -76,22 +82,41 @@ int main() { auto time2 = time(); std::cout << "float gemm cost :" << time_diff(time1, time2) / 10 << "ms\n"; - // int8_t + // int8_t without bias // warm-up 10 times for (int j = 0; j < 10; ++j) { - paddle_mobile::operators::math::matmul( - aa_int8, false, bb_int8, false, static_cast(1), &cc_int8, - static_cast(0), false, nullptr); + paddle_mobile::operators::math::matmul_int8( + aa_int8, false, bb_int8, false, static_cast(1), &cc_int32, + static_cast(0), false, nullptr); } auto time3 = time(); for (int j = 0; j < 10; ++j) { - paddle_mobile::operators::math::matmul( - aa_int8, false, bb_int8, false, static_cast(1), &cc_int8, - static_cast(0), false, nullptr); + paddle_mobile::operators::math::matmul_int8( + aa_int8, false, bb_int8, false, static_cast(1), &cc_int32, + static_cast(0), false, nullptr); } auto time4 = time(); std::cout << "int8_t gemm cost :" << time_diff(time3, time4) / 10 << "ms\n"; + // int8_t with bias&relu + // warm-up 10 times + for (int j = 0; j < 10; ++j) { + paddle_mobile::operators::math::matmul_int8( + aa_int8, false, bb_int8, false, static_cast(1), &cc_int8, + static_cast(0), true, &bias_data[0]); + } + auto time5 = time(); + for (int j = 0; j < 10; ++j) { + paddle_mobile::operators::math::matmul_int8( + aa_int8, false, bb_int8, false, static_cast(1), &cc_int8, + static_cast(0), true, &bias_data[0]); + } + auto time6 = time(); + std::cout << "int8_t gemm_with_bias_relu cost :" + << time_diff(time5, time6) / 10 << "ms\n"; + + delete[] bias_data; + return 0; } diff --git a/tools/op.cmake b/tools/op.cmake index 98a5ce437ae6520a4cc27f9fceeadaeb30ba6e99..e2254c3261d53d142e77f09c001d9cbebb5f85ff 100644 --- a/tools/op.cmake +++ b/tools/op.cmake @@ -249,7 +249,9 @@ if(NOT FOUND_MATCH) set(SUM_OP ON) set(QUANT_OP ON) set(DEQUANT_OP ON) - set(FUSION_DEQUANT_ADD_BN_RELU ON) + set(FUSION_DEQUANT_ADD_BN_OP ON) + set(FUSION_DEQUANT_BN_RELU_OP ON) + set(FUSION_DEQUANT_ADD_BN_RELU_OP ON) endif() # option(BATCHNORM_OP "" ON) @@ -451,10 +453,17 @@ endif() if (DEQUANT_OP) add_definitions(-DDEQUANT_OP) endif() -if (FUSION_DEQUANT_ADD_BN_RELU) +if (FUSION_DEQUANT_ADD_BN_OP) + add_definitions(-DFUSION_DEQUANT_ADD_BN_OP) +endif() +if (FUSION_DEQUANT_BN_RELU_OP) + add_definitions(-DFUSION_DEQUANT_BN_RELU_OP) +endif() +if (FUSION_DEQUANT_ADD_BN_RELU_OP) add_definitions(-DFUSION_DEQUANT_ADD_BN_RELU_OP) endif() + if (TANH_OP) add_definitions(-DTANH_OP) endif() @@ -467,3 +476,4 @@ endif() if (FUSION_DECONVADDRELU_OP) add_definitions(-DFUSION_DECONVADDRELU_OP) endif() +
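
The NEON write-back routines added above, WriteWithAddScale and WriteWithAddReluScale, reduce per element to: add the per-row int32 bias with saturation, optionally clamp negatives to zero (the ReLU variant), multiply by the float scale, truncate back to int32, narrow with saturation to int8, and, in the non-ReLU variant, map -128 to -127 via the vceq/vsub pair. Below is a scalar sketch of that post-processing, useful for checking the assembly against; the function name is illustrative, while the row strides NC and ldc follow the conventions of the code above.

#include <algorithm>
#include <cstdint>

// Scalar reference for the int8 write-back performed after the int32 GEMM.
//   c    : mc x nc accumulator block, row stride NC (int32 elements)
//   C    : destination block in the int8 output matrix, row stride ldc
//   bias : one int32 bias per output row
//   scale: requantization factor applied after the bias (and optional relu)
void WriteBackInt8Reference(int32_t mc, int32_t nc, const int32_t *c,
                            int32_t NC, int8_t *C, int32_t ldc,
                            const int32_t *bias, float scale, bool relu) {
  for (int32_t i = 0; i < mc; ++i) {
    for (int32_t j = 0; j < nc; ++j) {
      // vqadd.s32: add the per-row bias (int64 here sidesteps int32 overflow).
      int64_t acc = static_cast<int64_t>(c[i * NC + j]) + bias[i];
      // vmax.s32 against zero in the relu variant, applied before scaling.
      if (relu && acc < 0) acc = 0;
      // vcvt.f32.s32 + vmul.f32 + vcvt.s32.f32: scale, then truncate toward zero.
      int32_t q = static_cast<int32_t>(static_cast<float>(acc) * scale);
      // vqmovn saturation to int8, plus the vceq/vsub fix-up mapping -128 to -127.
      q = std::min<int32_t>(127, std::max<int32_t>(-127, q));
      C[i * ldc + j] = static_cast<int8_t>(q);
    }
  }
}

With relu == false this mirrors WriteWithAddScale, and with relu == true it mirrors WriteWithAddReluScale, up to the int64 accumulation noted in the comments.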