Commit f9e1a5e6 authored by StarryRain, committed by Yanzhan Yang

improve the performance of gemm1*1s1_conv_add and gemm1*1s1_conv_add_bn_relu (#1754)

* add CPU_ARCH info, improve the performance of GEMM1*1s1

* improve the performance of gemm1*1s1_conv_add and gemm1*1s1_conv_add_bn_relu

* improve the performance of slidingwindow_bn_relu, slidingwindow_add, slidingwindow_add_bn_relu, gemm1*1s1_bn_relu, gemm1*1s1_add_relu
Parent 8c89ef6c
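For readers skimming the diff: the core of this change is that the per-channel bias add (and optional ReLU) is now performed inside the GEMM / sliding-window convolution kernels themselves instead of in a separate AddChannelWise / ScaleAddChannelWise pass afterwards. The snippet below is a hypothetical, simplified sketch of that fused epilogue — the function and parameter names (`fused_epilogue`, `channels`, `channel_size`) are illustrative, not paddle-mobile API — showing only what the new `flag_bias` / `flag_relu` switches control.

```cpp
// Simplified sketch of the fused bias + ReLU epilogue (illustrative only).
#include <algorithm>
#include <cstddef>

void fused_epilogue(float *out, const float *bias, std::size_t channels,
                    std::size_t channel_size, bool flag_bias, bool flag_relu) {
  for (std::size_t c = 0; c < channels; ++c) {
    const float b = (flag_bias && bias != nullptr) ? bias[c] : 0.f;
    float *ptr = out + c * channel_size;
    for (std::size_t i = 0; i < channel_size; ++i) {
      float v = ptr[i] + b;                       // per-channel bias add
      ptr[i] = flag_relu ? std::max(v, 0.f) : v;  // optional fused ReLU
    }
  }
}
```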
@@ -16,9 +16,12 @@ limitations under the License. */
 #include "operators/kernel/conv_add_bn_relu_kernel.h"
 #include <cmath>
+#include "framework/context.h"
 #include "operators/kernel/arm/convolution/conv_common.h"
 #include "operators/kernel/central-arm-func/conv_arm_func.h"
 #include "operators/math/element_wise.h"
+#include "operators/math/gemm/gemm1x1s1.h"
+#include "operators/math/slidingwindow_utils.h"
 namespace paddle_mobile {
 namespace operators {
@@ -64,6 +67,13 @@ bool ConvAddBNReluKernel<CPU, float>::Init(
   // try to use faster depthwise conv
   switch (param->ExecMode()) {
+    case ConvParam<CPU>::EXEC_SLIDINGWINDOW3x3S1_FLOAT:
+    case ConvParam<CPU>::EXEC_SLIDINGWINDOW3x3S2_FLOAT:
+      use_slidingwindow_add_bn_relu = true;
+      break;
+    case ConvParam<CPU>::EXEC_GEMM1x1s1_FLOAT:
+      use_gemm_add_bn_relu = true;
+      break;
     case ConvParam<CPU>::EXEC_DEPTHWISE3x3S1_FLOAT:
     case ConvParam<CPU>::EXEC_DEPTHWISE3x3S2_FLOAT:
       const std::vector<int> &paddings = param->Paddings();
@@ -84,7 +94,8 @@ bool ConvAddBNReluKernel<CPU, float>::Init(
       break;
   }
-  if (could_use_faster_depthwise_conv_) {
+  if (could_use_faster_depthwise_conv_ || use_gemm_add_bn_relu ||
+      use_slidingwindow_add_bn_relu) {
     auto filter_data = param->Filter()->data<float>();
     auto filter_dim = param->Filter()->dims();
     int len = 1;
@@ -99,6 +110,16 @@ bool ConvAddBNReluKernel<CPU, float>::Init(
             filter_data[i * step + k] * new_scale_ptr[i];
       }
     }
+    if (use_gemm_add_bn_relu) {
+      ARMArch arch = framework::CPUContext::Context()->get_arch();
+      math::gemm1x1s1_transform_weight(*param->Filter(), *param->Output(),
+                                       param->transformed_filter_,
+                                       param->groups, arch);
+    }
+    if (use_slidingwindow_add_bn_relu) {
+      math::slidingwindow_transform_weight<float>(*param->Filter(),
+                                                  param->transformed_filter_);
+    }
   }
   return true;
@@ -129,11 +150,15 @@ void ConvAddBNReluKernel<CPU, float>::Compute(
       GemmConv<float, float>(param);
       break;
     case ConvParam<CPU>::EXEC_GEMM1x1s1_FLOAT:
-      GemmConv1x1s1<float, float>(param);
+      fusion_has_been_computed = true;
+      GemmConv1x1s1<float, float>(param, param.NewBias()->data<float>(), true,
+                                  true);
       break;
     case ConvParam<CPU>::EXEC_SLIDINGWINDOW3x3S1_FLOAT:
     case ConvParam<CPU>::EXEC_SLIDINGWINDOW3x3S2_FLOAT:
-      SlidingwindowConv3x3<float, float>(param);
+      SlidingwindowConv3x3<float, float>(param, param.NewBias()->data<float>(),
+                                         true, true);
+      fusion_has_been_computed = true;
       break;
     default:
       PADDLE_MOBILE_THROW_EXCEPTION("Invalid convolution execute mode %d",
......
@@ -30,6 +30,7 @@ bool ConvAddKernel<CPU, float>::Init(FusionConvAddParam<CPU> *param) {
 template <>
 void ConvAddKernel<CPU, float>::Compute(const FusionConvAddParam<CPU> &param) {
+  bool fusion_has_been_computed = false;
   switch (param.ExecMode()) {
     case ConvParam<CPU>::EXEC_DEPTHWISE3x3S1_FLOAT:
     case ConvParam<CPU>::EXEC_DEPTHWISE3x3S2_FLOAT:
@@ -45,16 +46,21 @@ void ConvAddKernel<CPU, float>::Compute(const FusionConvAddParam<CPU> &param) {
       GemmConv<float, float>(param);
       break;
     case ConvParam<CPU>::EXEC_GEMM1x1s1_FLOAT:
-      GemmConv1x1s1<float, float>(param);
+      fusion_has_been_computed = true;
+      GemmConv1x1s1<float, float>(param, param.Bias()->data<float>(), true,
+                                  false);
       break;
     case ConvParam<CPU>::EXEC_SLIDINGWINDOW3x3S1_FLOAT:
     case ConvParam<CPU>::EXEC_SLIDINGWINDOW3x3S2_FLOAT:
-      SlidingwindowConv3x3<float, float>(param);
+      SlidingwindowConv3x3<float, float>(param, param.Bias()->data<float>(),
+                                         true, false);
+      fusion_has_been_computed = true;
       break;
     default:
       PADDLE_MOBILE_THROW_EXCEPTION("Invalid convolution execute mode %d",
                                     param.ExecMode());
   }
+  if (!fusion_has_been_computed) {
     if (param.Bias()->dims() == param.Output()->dims()) {
       math::AddElememtWise<IDENTITY>(param.Output(), param.Bias(), param.Axis(),
                                      param.Output());
@@ -62,6 +68,7 @@ void ConvAddKernel<CPU, float>::Compute(const FusionConvAddParam<CPU> &param) {
       math::AddChannelWise<IDENTITY>(param.Output(), param.Bias(),
                                      param.Output());
     }
+  }
 }
 template class ConvAddKernel<CPU, float>;
......
@@ -31,6 +31,7 @@ bool ConvAddReluKernel<CPU, float>::Init(FusionConvAddReluParam<CPU> *param) {
 template <>
 void ConvAddReluKernel<CPU, float>::Compute(
     const FusionConvAddReluParam<CPU> &param) {
+  bool fusion_has_been_computed = false;
   switch (param.ExecMode()) {
     case ConvParam<CPU>::EXEC_DEPTHWISE3x3S1_FLOAT:
     case ConvParam<CPU>::EXEC_DEPTHWISE3x3S2_FLOAT:
@@ -46,22 +47,26 @@ void ConvAddReluKernel<CPU, float>::Compute(
       GemmConv<float, float>(param);
       break;
     case ConvParam<CPU>::EXEC_GEMM1x1s1_FLOAT:
-      GemmConv1x1s1<float, float>(param);
+      fusion_has_been_computed = true;
+      GemmConv1x1s1<float, float>(param, param.Bias()->data<float>(), true,
+                                  true);
       break;
     case ConvParam<CPU>::EXEC_SLIDINGWINDOW3x3S1_FLOAT:
     case ConvParam<CPU>::EXEC_SLIDINGWINDOW3x3S2_FLOAT:
-      SlidingwindowConv3x3<float, float>(param);
+      SlidingwindowConv3x3<float, float>(param, nullptr, false, false);
       break;
     default:
       PADDLE_MOBILE_THROW_EXCEPTION("Invalid convolution execute mode %d",
                                     param.ExecMode());
   }
+  if (!fusion_has_been_computed) {
     if (param.Bias()->dims() == param.Output()->dims()) {
       math::AddElememtWise<RELU>(param.Output(), param.Bias(), param.Axis(),
                                  param.Output());
     } else {
       math::AddChannelWise<RELU>(param.Output(), param.Bias(), param.Output());
     }
+  }
 }
 template class ConvAddReluKernel<CPU, float>;
......
@@ -65,11 +65,11 @@ void ConvBNAddReluKernel<CPU, float>::Compute(
       GemmConv<float, float>(param);
       break;
     case ConvParam<CPU>::EXEC_GEMM1x1s1_FLOAT:
-      GemmConv1x1s1<float, float>(param);
+      GemmConv1x1s1<float, float>(param, nullptr, false, false);
       break;
     case ConvParam<CPU>::EXEC_SLIDINGWINDOW3x3S1_FLOAT:
     case ConvParam<CPU>::EXEC_SLIDINGWINDOW3x3S2_FLOAT:
-      SlidingwindowConv3x3<float, float>(param);
+      SlidingwindowConv3x3<float, float>(param, nullptr, false, false);
       break;
     default:
       PADDLE_MOBILE_THROW_EXCEPTION("Invalid convolution execute mode %d",
......
@@ -16,9 +16,12 @@ limitations under the License. */
 #include "operators/kernel/conv_bn_relu_kernel.h"
 #include <cmath>
+#include "framework/context.h"
 #include "operators/kernel/arm/convolution/conv_common.h"
 #include "operators/kernel/central-arm-func/conv_arm_func.h"
 #include "operators/math/element_wise.h"
+#include "operators/math/gemm/gemm1x1s1.h"
+#include "operators/math/slidingwindow_utils.h"
 namespace paddle_mobile {
 namespace operators {
@@ -57,12 +60,50 @@ bool ConvBNReluKernel<CPU, float>::Init(FusionConvBNReluParam<CPU> *param) {
   param->SetNewBias(new_bias);
   InitBaseConvKernel(param);
+  switch (param->ExecMode()) {
+    case ConvParam<CPU>::EXEC_SLIDINGWINDOW3x3S1_FLOAT:
+    case ConvParam<CPU>::EXEC_SLIDINGWINDOW3x3S2_FLOAT:
+      use_slidingwindow_bn_relu = true;
+      break;
+    case ConvParam<CPU>::EXEC_GEMM1x1s1_FLOAT:
+      use_gemm_bn_relu = true;
+      break;
+  }
+  if (use_gemm_bn_relu || use_slidingwindow_bn_relu) {
+    auto filter_data = param->Filter()->data<float>();
+    auto filter_dim = param->Filter()->dims();
+    int len = 1;
+    for (int i = 0; i < filter_dim.size(); i++) {
+      len *= filter_dim[i];
+    }
+    int batch = filter_dim[0];
+    int step = len / batch;
+    for (int i = 0; i < batch; i++) {
+      for (int k = 0; k < step; k++) {
+        filter_data[i * step + k] =
+            filter_data[i * step + k] * new_scale_ptr[i];
+      }
+    }
+    if (use_gemm_bn_relu) {
+      ARMArch arch = framework::CPUContext::Context()->get_arch();
+      math::gemm1x1s1_transform_weight(*param->Filter(), *param->Output(),
+                                       param->transformed_filter_,
+                                       param->groups, arch);
+    }
+    if (use_slidingwindow_bn_relu) {
+      math::slidingwindow_transform_weight<float>(*param->Filter(),
+                                                  param->transformed_filter_);
+    }
+  }
   return true;
 }
 template <>
 void ConvBNReluKernel<CPU, float>::Compute(
     const FusionConvBNReluParam<CPU> &param) {
+  bool fusion_has_been_computed = false;
   switch (param.ExecMode()) {
     case ConvParam<CPU>::EXEC_DEPTHWISE3x3S1_FLOAT:
     case ConvParam<CPU>::EXEC_DEPTHWISE3x3S2_FLOAT:
@@ -78,18 +119,24 @@ void ConvBNReluKernel<CPU, float>::Compute(
       GemmConv<float, float>(param);
       break;
     case ConvParam<CPU>::EXEC_GEMM1x1s1_FLOAT:
-      GemmConv1x1s1<float, float>(param);
+      GemmConv1x1s1<float, float>(param, param.NewBias()->data<float>(), true,
+                                  true);
+      fusion_has_been_computed = true;
       break;
     case ConvParam<CPU>::EXEC_SLIDINGWINDOW3x3S1_FLOAT:
     case ConvParam<CPU>::EXEC_SLIDINGWINDOW3x3S2_FLOAT:
-      SlidingwindowConv3x3<float, float>(param);
+      SlidingwindowConv3x3<float, float>(param, param.NewBias()->data<float>(),
+                                         true, true);
+      fusion_has_been_computed = true;
       break;
     default:
       PADDLE_MOBILE_THROW_EXCEPTION("Invalid convolution execute mode %d",
                                     param.ExecMode());
   }
+  if (!fusion_has_been_computed) {
     math::ScaleAddChannelWise<RELU>(param.Output(), param.NewScale(),
                                     param.NewBias(), param.Output());
+  }
 }
 template class ConvBNReluKernel<CPU, float>;
......
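Context for the Init() changes in the two BN-fused kernels above: multiplying `filter_data` by `new_scale_ptr` is the standard trick of folding batch normalization into the convolution weights, so the fused kernel only needs the rescaled filter plus `NewBias()`. A minimal, self-contained sketch of that arithmetic follows; it is not code from this commit, and the function and parameter names are illustrative.

```cpp
// Standard batch-norm folding. With scale = gamma / sqrt(var + eps):
//   w'[c][k] = w[c][k] * scale[c]
//   b'[c]    = (b[c] - mean[c]) * scale[c] + beta[c]
#include <cmath>
#include <vector>

void fold_bn_into_conv(std::vector<float> &weights,           // [out_c * step]
                       std::vector<float> &bias,              // [out_c]
                       const std::vector<float> &gamma,
                       const std::vector<float> &beta,
                       const std::vector<float> &mean,
                       const std::vector<float> &variance,
                       float eps, int step) {
  const int out_c = static_cast<int>(bias.size());
  for (int c = 0; c < out_c; ++c) {
    const float scale = gamma[c] / std::sqrt(variance[c] + eps);
    for (int k = 0; k < step; ++k) {
      weights[c * step + k] *= scale;  // same per-channel scaling as Init()
    }
    bias[c] = (bias[c] - mean[c]) * scale + beta[c];
  }
}
```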
@@ -55,11 +55,11 @@ void ConvKernel<CPU, float>::Compute(const ConvParam<CPU> &param) {
       GemmConv<float, float>(param);
       break;
     case ConvParam<CPU>::EXEC_GEMM1x1s1_FLOAT:
-      GemmConv1x1s1<float, float>(param);
+      GemmConv1x1s1<float, float>(param, nullptr, false, false);
       break;
     case ConvParam<CPU>::EXEC_SLIDINGWINDOW3x3S1_FLOAT:
     case ConvParam<CPU>::EXEC_SLIDINGWINDOW3x3S2_FLOAT:
-      SlidingwindowConv3x3<float, float>(param);
+      SlidingwindowConv3x3<float, float>(param, nullptr, false, false);
       break;
     default:
       PADDLE_MOBILE_THROW_EXCEPTION("Invalid convolution execute mode %d",
......
@@ -46,11 +46,11 @@ void ConvReluKernel<CPU, float>::Compute(
       GemmConv<float, float>(param);
       break;
     case ConvParam<CPU>::EXEC_GEMM1x1s1_FLOAT:
-      GemmConv1x1s1<float, float>(param);
+      GemmConv1x1s1<float, float>(param, nullptr, false, false);
       break;
     case ConvParam<CPU>::EXEC_SLIDINGWINDOW3x3S1_FLOAT:
     case ConvParam<CPU>::EXEC_SLIDINGWINDOW3x3S2_FLOAT:
-      SlidingwindowConv3x3<float, float>(param);
+      SlidingwindowConv3x3<float, float>(param, nullptr, false, false);
       break;
     default:
       PADDLE_MOBILE_THROW_EXCEPTION("Invalid convolution execute mode %d",
......
@@ -77,7 +77,7 @@ void DWConvBNReluKernel<CPU, float>::Compute(
       GemmConv<float, float>(param);
       break;
     case ConvParam<CPU>::EXEC_GEMM1x1s1_FLOAT:
-      GemmConv1x1s1<float, float>(param);
+      GemmConv1x1s1<float, float>(param, nullptr, false, false);
       break;
     default:
       PADDLE_MOBILE_THROW_EXCEPTION("Invalid convolution execute mode %d",
......
@@ -140,7 +140,8 @@ void GemmConv(const ConvParam<CPU> &param) {
 }
 template <typename Itype, typename Otype>
-void GemmConv1x1s1(const ConvParam<CPU> &param) {
+void GemmConv1x1s1(const ConvParam<CPU> &param, const float *bias, bool is_bias,
+                   bool is_relu) {
   const Tensor *input = param.Input();
   Tensor filter = *param.transformed_filter_;
   Tensor *output = param.Output();
@@ -156,8 +157,6 @@ void GemmConv1x1s1(const ConvParam<CPU> &param) {
   const int hout = output->dims()[2];
   const int wout = output->dims()[3];
   const float *weights = filter.mutable_data<float>();
-  const float *bias = nullptr;
   int channel_size_out = wout * hout;
   int channel_size_in = win * hin;
   const int group = param.Groups();
@@ -165,8 +164,16 @@ void GemmConv1x1s1(const ConvParam<CPU> &param) {
   const int n = hout * wout;
   const int k = chin / group;
-  bool flag_relu = false;
-  bool flag_bias = false;
+  bool flag_relu = true;
+  bool flag_bias = true;
+  if (!is_bias) {
+    bias = nullptr;
+    flag_bias = false;
+  }
+  if (!is_relu) {
+    flag_relu = false;
+  }
   ARMArch arch = framework::CPUContext::Context()->get_arch();
   int hblock = math::get_hblock(arch);
@@ -322,7 +329,8 @@ void DepthwiseConv5x5(const ConvParam<CPU> &param) {
 }
 template <typename Itype, typename Otype>
-void SlidingwindowConv3x3(const ConvParam<CPU> &param) {
+void SlidingwindowConv3x3(const ConvParam<CPU> &param, const float *bias,
+                          bool is_bias, bool is_relu) {
   const Tensor *input = param.Input();
   const Tensor *filter = param.Filter();
   const std::vector<int> &paddings = param.Paddings();
@@ -334,23 +342,29 @@ void SlidingwindowConv3x3(const ConvParam<CPU> &param) {
     // math::SlidingwindowConv3x3s1<Itype, Otype>(input, filter, paddings,
     //                                            output);
     math::SlidingwindowConv3x3s1Faster<Itype, Otype>(
-        input, param.transformed_filter_, paddings, output);
+        input, param.transformed_filter_, paddings, output, bias, is_bias,
+        is_relu);
   } else if (strides[0] == 2) {
     // math::SlidingwindowConv3x3s2<Itype, Otype>(input, filter, paddings,
     //                                            output);
     math::SlidingwindowConv3x3s2Faster<Itype, Otype>(
-        input, param.transformed_filter_, paddings, output);
+        input, param.transformed_filter_, paddings, output, bias, is_bias,
+        is_relu);
   } else {
     GemmConv<Itype, Otype>(param);
   }
 }
 template void GemmConv<float, float>(const ConvParam<CPU> &param);
-template void GemmConv1x1s1<float, float>(const ConvParam<CPU> &param);
+template void GemmConv1x1s1<float, float>(const ConvParam<CPU> &param,
+                                          const float *bias, bool is_bias,
+                                          bool is_relu);
 template void WinogradConv3x3<8, 3>(const ConvParam<CPU> &param);
 template void DepthwiseConv3x3<float, float>(const ConvParam<CPU> &param);
 template void DepthwiseConv5x5<float, float>(const ConvParam<CPU> &param);
-template void SlidingwindowConv3x3<float, float>(const ConvParam<CPU> &param);
+template void SlidingwindowConv3x3<float, float>(const ConvParam<CPU> &param,
+                                                 const float *bias,
+                                                 bool is_bias, bool is_relu);
 template void GemmConv<int8_t, int32_t>(const ConvParam<CPU> &param);
 #ifndef __aarch64__
......
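The GemmConv1x1s1 path above rests on the fact that a 1x1, stride-1 convolution is exactly a matrix multiply with m = output channels per group, k = input channels per group, and n = hout * wout (the same n and k the function computes). A naive reference version is sketched below for illustration only; it is not the optimized prepacked-GEMM path used in the diff, and the names are hypothetical.

```cpp
// Naive reference: 1x1 stride-1 convolution over a CHW feature map as a GEMM.
#include <vector>

void conv1x1s1_reference(const std::vector<float> &weights,  // m x k
                         const std::vector<float> &input,    // k x n
                         std::vector<float> &output,         // m x n
                         int m, int k, int n) {
  for (int i = 0; i < m; ++i) {
    for (int j = 0; j < n; ++j) {
      float acc = 0.f;
      for (int p = 0; p < k; ++p) {
        acc += weights[i * k + p] * input[p * n + j];
      }
      output[i * n + j] = acc;  // bias / ReLU would be fused here per the diff
    }
  }
}
```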
@@ -33,7 +33,8 @@ template <typename Itype, typename Otype>
 void GemmConv(const ConvParam<CPU> &param);
 template <typename Itype, typename Otype>
-void GemmConv1x1s1(const ConvParam<CPU> &param);
+void GemmConv1x1s1(const ConvParam<CPU> &param, const float *bias, bool is_bias,
+                   bool is_relu);
 template <int tile, int kernel>
 void WinogradConv3x3(const ConvParam<CPU> &param);
@@ -45,7 +46,8 @@ template <typename Itype, typename Otype>
 void DepthwiseConv5x5(const ConvParam<CPU> &param);
 template <typename Itype, typename Otype>
-void SlidingwindowConv3x3(const ConvParam<CPU> &param);
+void SlidingwindowConv3x3(const ConvParam<CPU> &param, const float *bias,
+                          bool is_bias, bool is_relu);
 void FasterDepthwiseConv3x3_bias_relu(const ConvParam<CPU> &param,
                                       const float *bias, bool flag_relu);
......
@@ -39,6 +39,8 @@ class ConvAddBNReluKernel
  private:
   bool could_use_faster_depthwise_conv_ = false;
+  bool use_gemm_add_bn_relu = false;
+  bool use_slidingwindow_add_bn_relu = false;
 };
 }  // namespace operators
......
@@ -36,6 +36,10 @@ class ConvBNReluKernel
  public:
   void Compute(const FusionConvBNReluParam<DeviceType> &param);
   bool Init(FusionConvBNReluParam<DeviceType> *param);
+
+ private:
+  bool use_gemm_bn_relu = false;
+  bool use_slidingwindow_bn_relu = false;
 };
 }  // namespace operators
......
@@ -18,7 +18,6 @@ limitations under the License. */
 #include "operators/math/gemm/gemm1x1s1.h"
 #include <arm_neon.h>
 #include "framework/context.h"
-#include "iostream"
 namespace paddle_mobile {
 namespace operators {
......
@@ -3710,12 +3710,15 @@ void SlidingwindowConv3x3s2<float, float>(const framework::Tensor *input,
 template <>
 void SlidingwindowConv3x3s1Faster<float, float>(
     const framework::Tensor *input, framework::Tensor *filter,
-    const std::vector<int> &paddings, framework::Tensor *output) {
+    const std::vector<int> &paddings, framework::Tensor *output,
+    const float *bias, bool is_bias, bool is_relu) {
   const float *din = input->data<float>();
   float *dout = output->mutable_data<float>();
   const float *weights = filter->mutable_data<float>();
-  const float *bias = nullptr;
-  bool relu = false;
+  if (!is_bias) {
+    bias = nullptr;
+  }
+  bool relu = is_relu;
   const int num = input->dims()[0];
   const int chin = input->dims()[1];
   const int hin = input->dims()[2];
@@ -4623,12 +4626,15 @@ void SlidingwindowConv3x3s1Faster<float, float>(
 template <>
 void SlidingwindowConv3x3s2Faster<float, float>(
     const framework::Tensor *input, framework::Tensor *filter,
-    const std::vector<int> &paddings, framework::Tensor *output) {
+    const std::vector<int> &paddings, framework::Tensor *output,
+    const float *bias, bool is_bias, bool is_relu) {
   const float *din = input->data<float>();
   float *dout = output->mutable_data<float>();
   const float *weights = filter->mutable_data<float>();
-  const float *bias = nullptr;
-  bool relu = false;
+  if (!is_bias) {
+    bias = nullptr;
+  }
+  bool relu = is_relu;
   const int num = input->dims()[0];
   const int chin = input->dims()[1];
   const int hin = input->dims()[2];
......
@@ -37,13 +37,15 @@ template <typename Itype, typename Otype>
 void SlidingwindowConv3x3s1Faster(const framework::Tensor *input,
                                   framework::Tensor *filter,
                                   const std::vector<int> &paddings,
-                                  framework::Tensor *output);
+                                  framework::Tensor *output, const float *bias,
+                                  bool is_bias, bool is_relu);
 template <typename Itype, typename Otype>
 void SlidingwindowConv3x3s2Faster(const framework::Tensor *input,
                                   framework::Tensor *filter,
                                   const std::vector<int> &paddings,
-                                  framework::Tensor *output);
+                                  framework::Tensor *output, const float *bias,
+                                  bool is_bias, bool is_relu);
 }  // namespace math
 }  // namespace operators
 }  // namespace paddle_mobile