diff --git a/src/operators/kernel/arm/convolution/conv_add_bn_relu_kernel.cpp b/src/operators/kernel/arm/convolution/conv_add_bn_relu_kernel.cpp index e55ebc4ace6c6a465bba7d2ee9e2d06d87dea347..fc6e6b03172021f0cca9d5c3d97ef79f54a150b6 100644 --- a/src/operators/kernel/arm/convolution/conv_add_bn_relu_kernel.cpp +++ b/src/operators/kernel/arm/convolution/conv_add_bn_relu_kernel.cpp @@ -61,16 +61,61 @@ bool ConvAddBNReluKernel<CPU, float>::Init( param->SetNewBias(new_bias); InitBaseConvKernel(param); + + // try to use faster depthwise conv + switch (param->ExecMode()) { + case ConvParam<CPU>::EXEC_DEPTHWISE3x3S1_FLOAT: + case ConvParam<CPU>::EXEC_DEPTHWISE3x3S2_FLOAT: + const std::vector<int> &paddings = param->Paddings(); + const std::vector<int> &strides = param->Strides(); + if (paddings.size() == 2 && paddings[0] == paddings[1] && + strides.size() == 2 && strides[0] == strides[1]) { + int pad = paddings[0]; + int stride = strides[0]; + const int hin = param->Input()->dims()[2]; + if (pad == 0 && hin > 2) { + could_use_faster_depthwise_conv_ = true; + } else if (pad == 1) { + could_use_faster_depthwise_conv_ = true; + } + } + break; + } + + if (could_use_faster_depthwise_conv_) { + auto filter_data = param->Filter()->data<float>(); + auto filter_dim = param->Filter()->dims(); + int len = 1; + for (int i = 0; i < filter_dim.size(); i++) { + len *= filter_dim[i]; + } + int batch = filter_dim[0]; + int step = len / batch; + for (int i = 0; i < batch; i++) { + for (int k = 0; k < step; k++) { + filter_data[i * step + k] = + filter_data[i * step + k] * new_scale_ptr[i]; + } + } + } + return true; } template <> void ConvAddBNReluKernel<CPU, float>::Compute( const FusionConvAddBNReluParam<CPU> &param) { + bool fusion_has_been_computed = false; switch (param.ExecMode()) { case ConvParam<CPU>::EXEC_DEPTHWISE3x3S1_FLOAT: case ConvParam<CPU>::EXEC_DEPTHWISE3x3S2_FLOAT: - DepthwiseConv3x3<float, float>(param); + if (could_use_faster_depthwise_conv_) { + FasterDepthwiseConv3x3_bias_relu(param, param.NewBias()->data<float>(), + true); + fusion_has_been_computed = 
true; + } else { + DepthwiseConv3x3<float, float>(param); + } break; case ConvParam<CPU>::EXEC_DEPTHWISE5x5_FLOAT: DepthwiseConv5x5<float, float>(param); @@ -89,8 +134,10 @@ void ConvAddBNReluKernel<CPU, float>::Compute( PADDLE_MOBILE_THROW_EXCEPTION("Invalid convolution execute mode %d", param.ExecMode()); } - math::ScaleAddChannelWise<RELU>(param.Output(), param.NewScale(), - param.NewBias(), param.Output()); + if (!fusion_has_been_computed) { + math::ScaleAddChannelWise<RELU>(param.Output(), param.NewScale(), + param.NewBias(), param.Output()); + } } template class ConvAddBNReluKernel<CPU, float>; diff --git a/src/operators/kernel/central-arm-func/conv_arm_func.cpp b/src/operators/kernel/central-arm-func/conv_arm_func.cpp index b69ef51a9fb09c5fc38131549e6eb830e22d1987..9cd7cff4a4a9ccca41df932513c68db2baf8c6b6 100644 --- a/src/operators/kernel/central-arm-func/conv_arm_func.cpp +++ b/src/operators/kernel/central-arm-func/conv_arm_func.cpp @@ -212,8 +212,8 @@ void DepthwiseConv3x3<float, float>(const ConvParam<CPU> &param) { } } -template <> -void DepthwiseConv3x3<float, float>(const ConvParam<CPU> &param) { +void FasterDepthwiseConv3x3_bias_relu(const ConvParam<CPU> &param, + const float *bias, bool flag_relu) { const Tensor *input = param.Input(); const Tensor *filter = param.Filter(); const std::vector<int> &paddings = param.Paddings(); @@ -222,52 +222,27 @@ void DepthwiseConv3x3<float, float>(const ConvParam<CPU> &param) { Tensor *output = param.Output(); output->mutable_data<float>(); - if (paddings.size() == 2 && paddings[0] == paddings[1] && - strides.size() == 2 && strides[0] == strides[1]) { - int pad = paddings[0]; - int stride = strides[0]; - const float *din = input->data<float>(); - float *dout = output->mutable_data<float>(); - const float *weights = filter->data<float>(); - const float *bias = nullptr; - const int num = input->dims()[0]; - const int chin = input->dims()[1]; - const int hin = input->dims()[2]; - const int win = input->dims()[3]; - const int chout = output->dims()[1]; - const int hout = output->dims()[2]; - const int wout = output->dims()[3]; - bool flag_relu = false; - bool flag_bias = bias != nullptr; - if 
(pad == 0 && hin > 2) { - math::depthwise::conv_depthwise_3x3p0(din, dout, num, chout, hout, wout, - chin, hin, win, weights, bias, - stride, flag_bias, flag_relu); - } else if (pad == 1) { - math::depthwise::conv_depthwise_3x3p1(din, dout, num, chout, hout, wout, - chin, hin, win, weights, bias, - stride, flag_bias, flag_relu); - } else { - GemmConv<float, float>(param); - } - } else { - if (strides[0] == 1) { - for (int i = 0; i < batch_size; i++) { - Tensor in_batch = input->Slice(i, i + 1); - Tensor out_batch = output->Slice(i, i + 1); - math::DepthwiseConv3x3S1<float, float>(in_batch, *filter, paddings, - &out_batch); - } - } else if (strides[0] == 2) { - for (int i = 0; i < batch_size; i++) { - Tensor in_batch = input->Slice(i, i + 1); - Tensor out_batch = output->Slice(i, i + 1); - math::DepthwiseConv3x3S2<float, float>(in_batch, *filter, paddings, - &out_batch); - } - } else { - GemmConv<float, float>(param); - } + int pad = paddings[0]; + int stride = strides[0]; + const float *din = input->data<float>(); + float *dout = output->mutable_data<float>(); + const float *weights = filter->data<float>(); + const int num = input->dims()[0]; + const int chin = input->dims()[1]; + const int hin = input->dims()[2]; + const int win = input->dims()[3]; + const int chout = output->dims()[1]; + const int hout = output->dims()[2]; + const int wout = output->dims()[3]; + bool flag_bias = bias != nullptr; + if (pad == 0 && hin > 2) { + math::depthwise::conv_depthwise_3x3p0(din, dout, num, chout, hout, wout, + chin, hin, win, weights, bias, stride, + flag_bias, flag_relu); + } else if (pad == 1) { + math::depthwise::conv_depthwise_3x3p1(din, dout, num, chout, hout, wout, + chin, hin, win, weights, bias, stride, + flag_bias, flag_relu); } } diff --git a/src/operators/kernel/central-arm-func/conv_arm_func.h b/src/operators/kernel/central-arm-func/conv_arm_func.h index 2fa06f7cee1ffdf97448964da04e95ddeb27aedf..8cd99aba4603e77ac95f60e90fd0cc28415837c6 100644 --- a/src/operators/kernel/central-arm-func/conv_arm_func.h +++ 
b/src/operators/kernel/central-arm-func/conv_arm_func.h @@ -44,6 +44,9 @@ void DepthwiseConv5x5(const ConvParam<CPU> &param); template <typename Itype, typename Otype> void SlidingwindowConv3x3(const ConvParam<CPU> &param); +void FasterDepthwiseConv3x3_bias_relu(const ConvParam<CPU> &param, + const float *bias, bool flag_relu); + } // namespace operators } // namespace paddle_mobile diff --git a/src/operators/kernel/conv_add_bn_relu_kernel.h b/src/operators/kernel/conv_add_bn_relu_kernel.h index 919c66106eda1159f14c40e768325f1f5dcf5ff6..d7ec6d2933e6f3ecfafd28e18a5b9b7633399e8d 100644 --- a/src/operators/kernel/conv_add_bn_relu_kernel.h +++ b/src/operators/kernel/conv_add_bn_relu_kernel.h @@ -36,6 +36,9 @@ class ConvAddBNReluKernel public: void Compute(const FusionConvAddBNReluParam<DeviceType> &param); bool Init(FusionConvAddBNReluParam<DeviceType> *param); + + private: + bool could_use_faster_depthwise_conv_ = false; }; } // namespace operators