提交 c5cd78ab 编写于 作者: H HappyAngel 提交者: Xiaoyang LI

speedup fp32 depthwise conv

* update con_dw

* update

* add conv_depthwise_3x3s1.cc and conv_depthwise_3x3s2.cc

* add conv_depthwise_3x3s1_fp32 and conv_depthwise_3x3s2_fp32

* add new conv_dw

* only support conv_dw pad=0, 1

* add conv_dw_s1 conv_dw_s2 fp32

*     //conv2_func _impl2{nullptr};
update conv_dw, add conv_3x3s1 and conv_3x3s2, pad=[0,1]

* fix format, test=develop

* fix formmat, test=develop
上级 95372548
...@@ -78,6 +78,8 @@ if (NOT HAS_ARM_MATH_LIB_DIR) ...@@ -78,6 +78,8 @@ if (NOT HAS_ARM_MATH_LIB_DIR)
conv5x5s2_depthwise_fp32.cc conv5x5s2_depthwise_fp32.cc
conv_depthwise_3x3p0.cc conv_depthwise_3x3p0.cc
conv_depthwise_3x3p1.cc conv_depthwise_3x3p1.cc
conv_depthwise_3x3s1.cc
conv_depthwise_3x3s2.cc
conv_winograd_3x3.cc conv_winograd_3x3.cc
conv_impl.cc conv_impl.cc
softmax.cc softmax.cc
......
...@@ -25,7 +25,6 @@ namespace paddle { ...@@ -25,7 +25,6 @@ namespace paddle {
namespace lite { namespace lite {
namespace arm { namespace arm {
namespace math { namespace math {
void conv_3x3s1_depthwise_fp32(const float* i_data, void conv_3x3s1_depthwise_fp32(const float* i_data,
float* o_data, float* o_data,
int bs, int bs,
......
...@@ -53,6 +53,38 @@ void conv_3x3s2_depthwise_fp32(const float* i_data, ...@@ -53,6 +53,38 @@ void conv_3x3s2_depthwise_fp32(const float* i_data,
const operators::ConvParam& param, const operators::ConvParam& param,
ARMContext* ctx); ARMContext* ctx);
// Depthwise 3x3, stride-1, fp32 convolution kernel (declaration).
// din / dout: input / output tensor data; num: batch size.
// ch_out, h_out, w_out / ch_in, h_in, w_in: output / input channel and
// spatial dimensions. weights / bias: filter and per-channel bias data.
// pad: symmetric padding — the dispatcher in this change only routes
// pad = 0 or 1 here (it guards with `pad < 2`).
// flag_bias / flag_relu: enable fused bias-add / fused ReLU.
// ctx: ARM runtime context (threading / workspace).
void conv_depthwise_3x3s1_fp32(const float* din,
float* dout,
int num,
int ch_out,
int h_out,
int w_out,
int ch_in,
int h_in,
int w_in,
const float* weights,
const float* bias,
int pad,
bool flag_bias,
bool flag_relu,
ARMContext* ctx);
// Depthwise 3x3, stride-2, fp32 convolution kernel (declaration).
// Same parameter contract as the stride-1 variant: din / dout are input /
// output tensor data; num is the batch size; ch_out, h_out, w_out and
// ch_in, h_in, w_in are the output / input channel and spatial dimensions;
// weights / bias hold filter and per-channel bias data.
// pad: symmetric padding — the dispatcher in this change only routes
// pad = 0 or 1 here (it guards with `pad < 2`).
// flag_bias / flag_relu: enable fused bias-add / fused ReLU.
// ctx: ARM runtime context (threading / workspace).
void conv_depthwise_3x3s2_fp32(const float* din,
float* dout,
int num,
int ch_out,
int h_out,
int w_out,
int ch_in,
int h_in,
int w_in,
const float* weights,
const float* bias,
int pad,
bool flag_bias,
bool flag_relu,
ARMContext* ctx);
void conv_depthwise_3x3p0_fp32(const float* din, void conv_depthwise_3x3p0_fp32(const float* din,
float* dout, float* dout,
int num, int num,
......
此差异已折叠。
此差异已折叠。
...@@ -562,9 +562,19 @@ void conv_depthwise_3x3_fp32(const void* din, ...@@ -562,9 +562,19 @@ void conv_depthwise_3x3_fp32(const void* din,
const operators::ConvParam& param, const operators::ConvParam& param,
ARMContext* ctx, ARMContext* ctx,
const float* scale) { const float* scale) {
const int pad_h = param.paddings[0];
const int pad_w = param.paddings[1];
if (pad_w != pad_h) {
LOG(FATAL) << "fp32 depthwise conv3x3 pad_w: " << pad_w
<< ", pad_h: " << pad_h << " must be equal";
return;
}
int stride = param.strides[1]; int stride = param.strides[1];
if (stride == 1) { int pad = pad_w;
conv_3x3s1_depthwise_fp32(reinterpret_cast<const float*>(din), bool flag_relu = param.fuse_relu;
bool flag_bias = param.bias != nullptr;
if (stride == 1 && pad < 2) { // support pad = [0, 1]
conv_depthwise_3x3s1_fp32(reinterpret_cast<const float*>(din),
reinterpret_cast<float*>(dout), reinterpret_cast<float*>(dout),
num, num,
ch_out, ch_out,
...@@ -575,10 +585,12 @@ void conv_depthwise_3x3_fp32(const void* din, ...@@ -575,10 +585,12 @@ void conv_depthwise_3x3_fp32(const void* din,
w_in, w_in,
reinterpret_cast<const float*>(weights), reinterpret_cast<const float*>(weights),
bias, bias,
param, pad,
flag_bias,
flag_relu,
ctx); ctx);
} else if (stride == 2) { } else if (stride == 2 && pad < 2) { // support pad = [0, 1]
conv_3x3s2_depthwise_fp32(reinterpret_cast<const float*>(din), conv_depthwise_3x3s2_fp32(reinterpret_cast<const float*>(din),
reinterpret_cast<float*>(dout), reinterpret_cast<float*>(dout),
num, num,
ch_out, ch_out,
...@@ -589,10 +601,13 @@ void conv_depthwise_3x3_fp32(const void* din, ...@@ -589,10 +601,13 @@ void conv_depthwise_3x3_fp32(const void* din,
w_in, w_in,
reinterpret_cast<const float*>(weights), reinterpret_cast<const float*>(weights),
bias, bias,
param, pad,
flag_bias,
flag_relu,
ctx); ctx);
} else { } else {
LOG(FATAL) << "fp32 depthwise conv3x3 stride: " << stride << " unsupported"; LOG(FATAL) << "fp32 depthwise conv3x3 stride: " << stride
<< " or pad(<2): " << pad << " unsupported";
} }
#if 0 #if 0
if (pad == 1) { if (pad == 1) {
......
...@@ -37,12 +37,13 @@ void DepthwiseConv<PRECISION(kFloat), PRECISION(kFloat)>::PrepareForRun() { ...@@ -37,12 +37,13 @@ void DepthwiseConv<PRECISION(kFloat), PRECISION(kFloat)>::PrepareForRun() {
auto kh = w_dims[2]; auto kh = w_dims[2];
auto cround = ROUNDUP(oc, cblock); auto cround = ROUNDUP(oc, cblock);
weights_.Resize({cround, 1, kh, kw}); weights_.Resize({cround, 1, kh, kw});
auto w_data = weights_.mutable_data<float>(); // auto w_data = weights_.mutable_data<float>();
auto w_data_in = param.filter->data<float>(); // auto w_data_in = param.filter->data<float>();
lite::arm::math::conv_trans_weights_numc( // lite::arm::math::conv_trans_weights_numc(
w_data_in, w_data, oc, 1, cblock, kh * kw); // w_data_in, w_data, oc, 1, cblock, kh * kw);
impl_ = lite::arm::math::conv_depthwise_3x3_fp32; impl_ = lite::arm::math::conv_depthwise_3x3_fp32;
flag_trans_weights_ = true; flag_trans_weights_ = false;
// flag_trans_weights_ = true;
} else if (kw == 5) { } else if (kw == 5) {
VLOG(5) << "invoke 5x5 dw conv fp32"; VLOG(5) << "invoke 5x5 dw conv fp32";
impl_ = lite::arm::math::conv_depthwise_5x5_fp32; impl_ = lite::arm::math::conv_depthwise_5x5_fp32;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册