Commit 0314aa1e authored by HappyAngel, committed by yiicy

[lite][arm] add conv+relu6/leakyRelu fusion (#2599)

Parent e5cbdb7b
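The fusion works by threading the activation parameters into the ARM conv kernels: the direct, winograd and depthwise 3x3 paths now receive `param.activation_param` (or a pointer to it) and apply the activation while writing the output block, instead of relying only on the single `flag_relu` switch. A minimal scalar sketch of such a fused epilogue, using illustrative stand-in names (`ActInfo`, `apply_fused_act`) rather than the real fields of Paddle-Lite's `ActivationParam`:

```cpp
#include <algorithm>
#include <cstddef>

// Stand-in for the activation info carried alongside the conv params.
// Only the pieces relevant to this fusion: type, relu6 clip, leaky alpha.
enum class ActType { kRelu, kRelu6, kLeakyRelu };
struct ActInfo {
  bool has_active{false};
  ActType type{ActType::kRelu};
  float relu6_clip{6.f};
  float leaky_alpha{0.01f};
};

// Apply the fused activation to a conv output tile right before write-back,
// rather than running a separate activation kernel afterwards.
inline void apply_fused_act(float* out, std::size_t len, const ActInfo& act) {
  if (!act.has_active) return;
  for (std::size_t i = 0; i < len; ++i) {
    float v = out[i];
    switch (act.type) {
      case ActType::kRelu:
        v = std::max(v, 0.f);
        break;
      case ActType::kRelu6:
        v = std::min(std::max(v, 0.f), act.relu6_clip);
        break;
      case ActType::kLeakyRelu:
        v = v > 0.f ? v : v * act.leaky_alpha;
        break;
    }
    out[i] = v;
  }
}
```

In the actual kernels the same clamp/select logic would typically be done with NEON on the output registers before the final store, which is why the write-back helpers below gain an activation-parameter argument.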
@@ -295,7 +295,8 @@ void conv_compute_6x6_3x3(const float* input,
                hout,
                wout,
                false,
-               zero_ptr);
+               zero_ptr,
+               nullptr);
        }
      } else {
        for (int ci = 0; ci < oc_4; ++ci) {
@@ -341,7 +342,8 @@ void conv_compute_6x6_3x3(const float* input,
                hout,
                wout,
                false,
-               zero_ptr);
+               zero_ptr,
+               nullptr);
        }
      }
    }
@@ -562,7 +564,8 @@ void conv_compute_2x2_3x3(const float* input,
                hout,
                wout,
                false,
-               zero_ptr);
+               zero_ptr,
+               nullptr);
        }
      } else {
        for (int ci = 0; ci < oc_4; ++ci) {
@@ -602,7 +605,8 @@ void conv_compute_2x2_3x3(const float* input,
                hout,
                wout,
                false,
-               zero_ptr);
+               zero_ptr,
+               nullptr);
        }
      }
    }
@@ -814,7 +818,8 @@ void conv_compute_2x2_3x3_small(const float* input,
                hout,
                wout,
                false,
-               zero_ptr);
+               zero_ptr,
+               nullptr);
        }
      } else {
        for (int ci = 0; ci < oc_4; ++ci) {
@@ -854,7 +859,8 @@ void conv_compute_2x2_3x3_small(const float* input,
                hout,
                wout,
                false,
-               zero_ptr);
+               zero_ptr,
+               nullptr);
        }
      }
    }
......
@@ -76,6 +76,7 @@ void conv_3x3s1_direct_fp32(const float* i_data,
   const int threads = ctx->threads();
   int l2_size = ctx->llc_size() / sizeof(float);
   auto paddings = *param.paddings;
+  auto act_param = param.activation_param;
   const int pad_h = paddings[0];
   const int pad_w = paddings[2];
@@ -469,7 +470,8 @@ void conv_3x3s1_direct_fp32(const float* i_data,
                            oh,
                            ow,
                            flag_relu,
-                           ptr_write);
+                           ptr_write,
+                           &act_param);
      }
    const float* weight_remain_ptr = weights + c_round_down * w_stride;
#pragma omp parallel for num_threads(threads)
@@ -780,7 +782,8 @@ void conv_3x3s1_direct_fp32(const float* i_data,
                            oh,
                            ow,
                            flag_relu,
-                           ptr_write);
+                           ptr_write,
+                           &act_param);
      }
    }
  }
......
@@ -75,6 +75,7 @@ void conv_3x3s2_direct_fp32(const float* i_data,
   //! prepack input to tmp buffer
   //! write output to tmp buffer
   auto paddings = *param.paddings;
+  auto act_param = param.activation_param;
   const int threads = ctx->threads();
   int l2_size = ctx->llc_size() / sizeof(float);
   const int pad_w = paddings[2];
@@ -510,7 +511,8 @@ void conv_3x3s2_direct_fp32(const float* i_data,
                            oh,
                            ow,
                            flag_relu,
-                           ptr_write);
+                           ptr_write,
+                           &act_param);
      }
#pragma omp parallel for num_threads(threads)
@@ -839,7 +841,8 @@ void conv_3x3s2_direct_fp32(const float* i_data,
                            oh,
                            ow,
                            flag_relu,
-                           ptr_write);
+                           ptr_write,
+                           &act_param);
      }
    }
  }
......
@@ -205,14 +205,12 @@ void conv_depthwise_3x3s2_fp32(const float* din,
                                                   \
  "ext v10.16b, %[vzero].16b, v9.16b, #12     \n" \
  "fadd v16.4s, v16.4s, v11.4s                \n" \
- "fadd v16.4s, v16.4s, v12.4s                \n"
-
-#define LEFT_RESULT_S2                            \
-  /* r4 */                                        \
+ "fadd v16.4s, v16.4s, v12.4s                \n" /* r4 */ \
  "fmla v13.4s, v8.4s, %[w2].s[1]             \n" \
  "fmla v14.4s, v9.4s, %[w2].s[2]             \n" \
- "fmla v17.4s, v10.4s, %[w2].s[0]            \n" \
-                                                  \
+ "fmla v17.4s, v10.4s, %[w2].s[0]            \n"
+
+#define LEFT_RESULT_S2                            \
  "st1 {v16.4s}, [%[outptr0]], #16            \n" \
                                                   \
  "ld2 {v0.4s, v1.4s}, [%[inptr0]], #32       \n" \
@@ -280,17 +278,16 @@ void conv_depthwise_3x3s2_fp32(const float* din,
  "ld2 {v6.4s, v7.4s}, [%[inptr3]], #32       \n" \
                                                   \
  "fadd v16.4s, v16.4s, v11.4s                \n" \
- "fadd v16.4s, v16.4s, v12.4s                \n"
-
-#define MID_RESULT_S2                             \
-  /* r4 */                                        \
+ "fadd v16.4s, v16.4s, v12.4s                \n" /* r4 */ \
  "fmla v13.4s, v8.4s, %[w2].s[0]             \n" \
  "fmla v14.4s, v9.4s, %[w2].s[1]             \n" \
  "fmla v17.4s, v10.4s, %[w2].s[2]            \n" \
                                                   \
  "ld2 {v8.4s, v9.4s}, [%[inptr4]], #32       \n" \
  "ld1 {v15.4s}, [%[inptr0]]                  \n" \
- "ld1 {v18.4s}, [%[inptr1]]                  \n" \
+ "ld1 {v18.4s}, [%[inptr1]]                  \n"
+
+#define MID_RESULT_S2                             \
  "st1 {v16.4s}, [%[outptr0]], #16            \n" \
                                                   \
  "fadd v17.4s, v17.4s, v13.4s                \n" \
@@ -360,14 +357,12 @@ void conv_depthwise_3x3s2_fp32(const float* din,
                                                   \
  "fadd v16.4s, v16.4s, v11.4s                \n" \
  "fadd v16.4s, v16.4s, v12.4s                \n" \
- "ld1 {v1.4s}, [%[outptr1]]                  \n"
-
-#define RIGHT_RESULT_S2                           \
-  /* r4 */                                        \
+ "ld1 {v1.4s}, [%[outptr1]]                  \n" /* r4 */ \
  "fmla v13.4s, v8.4s, %[w2].s[0]             \n" \
  "fmla v14.4s, v9.4s, %[w2].s[1]             \n" \
- "fmla v17.4s, v10.4s, %[w2].s[2]            \n" \
-                                                  \
+ "fmla v17.4s, v10.4s, %[w2].s[2]            \n"
+
+#define RIGHT_RESULT_S2                           \
  "bif v16.16b, v0.16b, %[wmask].16b          \n" \
                                                   \
  "fadd v17.4s, v17.4s, v13.4s                \n" \
@@ -382,11 +377,6 @@ void conv_depthwise_3x3s2_fp32(const float* din,
  "4:                                         \n"

#define LEFT_RESULT_S2_RELU                       \
-  /* r4 */                                        \
- "fmla v13.4s, v8.4s, %[w2].s[1]             \n" \
- "fmla v14.4s, v9.4s, %[w2].s[2]             \n" \
- "fmla v17.4s, v10.4s, %[w2].s[0]            \n" \
-                                                  \
  "fmax v16.4s, v16.4s, %[vzero].4s           \n" \
                                                   \
  "ld2 {v0.4s, v1.4s}, [%[inptr0]], #32       \n" \
@@ -424,14 +414,6 @@ void conv_depthwise_3x3s2_fp32(const float* din,
  "blt 1f                                     \n"

#define MID_RESULT_S2_RELU                        \
-  /* r4 */                                        \
- "fmla v13.4s, v8.4s, %[w2].s[0]             \n" \
- "fmla v14.4s, v9.4s, %[w2].s[1]             \n" \
- "fmla v17.4s, v10.4s, %[w2].s[2]            \n" \
-                                                  \
- "ld2 {v8.4s, v9.4s}, [%[inptr4]], #32       \n" \
- "ld1 {v15.4s}, [%[inptr0]]                  \n" \
- "ld1 {v18.4s}, [%[inptr1]]                  \n" \
  "fmax v16.4s, v16.4s, %[vzero].4s           \n" /* relu */ \
                                                   \
  "fadd v17.4s, v17.4s, v13.4s                \n" \
@@ -457,11 +439,6 @@ void conv_depthwise_3x3s2_fp32(const float* din,
  "bne 2b                                     \n"

#define RIGHT_RESULT_S2_RELU                      \
-  /* r4 */                                        \
- "fmla v13.4s, v8.4s, %[w2].s[0]             \n" \
- "fmla v14.4s, v9.4s, %[w2].s[1]             \n" \
- "fmla v17.4s, v10.4s, %[w2].s[2]            \n" \
-                                                  \
  "fmax v16.4s, v16.4s, %[vzero].4s           \n" /* relu */ \
                                                   \
  "fadd v17.4s, v17.4s, v13.4s                \n" \
......
@@ -37,6 +37,7 @@ void conv_3x3s1_depthwise_fp32(const float* i_data,
                                const float* weights,
                                const float* bias,
                                const operators::ConvParam& param,
+                               const operators::ActivationParam act_param,
                                ARMContext* ctx);

void conv_3x3s2_depthwise_fp32(const float* i_data,
@@ -67,6 +68,7 @@ void conv_depthwise_3x3s1_fp32(const float* din,
                               int pad,
                               bool flag_bias,
                               bool flag_relu,
+                              const operators::ActivationParam act_param,
                               ARMContext* ctx);

void conv_depthwise_3x3s2_fp32(const float* din,
......
@@ -579,6 +579,7 @@ void conv_depthwise_3x3_fp32(const void* din,
                              ARMContext* ctx,
                              const float* scale) {
   auto paddings = *param.paddings;
+  auto act_param = param.activation_param;
   const int pad_h = paddings[0];
   const int pad_w = paddings[2];
   int stride = param.strides[1];
@@ -603,6 +604,7 @@ void conv_depthwise_3x3_fp32(const void* din,
                              pad,
                              flag_bias,
                              flag_relu,
+                             act_param,
                              ctx);
   } else {
     conv_3x3s1_depthwise_fp32(reinterpret_cast<const float*>(din),
@@ -617,6 +619,7 @@ void conv_depthwise_3x3_fp32(const void* din,
                              reinterpret_cast<const float*>(weights),
                              bias,
                              param,
+                             act_param,
                              ctx);
   }
......
@@ -67,7 +67,7 @@ void ConvCompute<PRECISION(kFloat), PRECISION(kFloat)>::PrepareForRun() {
     impl_ = new DepthwiseConv<PRECISION(kFloat), PRECISION(kFloat)>;
     VLOG(3) << "invoking dw conv";
   } else if (param.groups == 1 && kw == 3 && stride == 1 && kps_equal &&
-             no_dilation) {
+             no_dilation && pads_all_equal) {
     /// winograd conv impl
     impl_ = new WinogradConv<PRECISION(kFloat), PRECISION(kFloat)>;
     VLOG(3) << "invoking winograd conv";
......
@@ -52,6 +52,34 @@ inline int ConvOutputSize(int input_size,
   return output_size;
 }

+inline void UpdatePaddingAndDilation(std::vector<int>* paddings,
+                                     std::vector<int>* dilations,
+                                     const std::vector<int>& strides,
+                                     const std::string padding_algorithm,
+                                     const lite::DDim data_dims,
+                                     const lite::DDim& ksize) {
+  // when padding_algorithm is "VALID" or "SAME"
+  if (padding_algorithm == "SAME") {
+    for (size_t i = 0; i < strides.size(); ++i) {
+      int out_size = (data_dims[i + 2] + strides[i] - 1) / strides[i];
+      int pad_sum = std::max(
+          (out_size - 1) * strides[i] + ksize[i + 2] - data_dims[i + 2],
+          (int64_t)0);
+      int pad_0 = pad_sum / 2;
+      int pad_1 = pad_sum - pad_0;
+      // pad
+      *(paddings->begin() + i * 2) = pad_0;
+      *(paddings->begin() + i * 2 + 1) = pad_1;
+      // dilation
+      *(dilations->begin() + i) = 1;
+    }
+  } else if (padding_algorithm == "VALID") {
+    for (auto& it : *paddings) {
+      it = 0;
+    }
+  }
+}
+
 bool ConvOpLite::InferShape() const {
   const auto in_dims = param_.x->dims();
   const auto filter_dims = param_.filter->dims();
......
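A side note on the relocated `UpdatePaddingAndDilation` helper above: with "SAME" padding it sizes each pad pair so that the output keeps `ceil(input / stride)` elements per spatial dimension and resets the dilation to 1. A small self-contained check of that arithmetic, with example values chosen purely for illustration:

```cpp
#include <algorithm>
#include <cassert>

int main() {
  // Same arithmetic as the "SAME" branch above, for one spatial dimension:
  // input 224, stride 2, kernel 3.
  int data_dim = 224, stride = 2, ksize = 3;
  int out_size = (data_dim + stride - 1) / stride;                        // 112
  int pad_sum = std::max((out_size - 1) * stride + ksize - data_dim, 0);  // 1
  int pad_0 = pad_sum / 2;                                                // 0
  int pad_1 = pad_sum - pad_0;                                            // 1
  assert(out_size == 112 && pad_0 == 0 && pad_1 == 1);
  return 0;
}
```

The asymmetric result (pad 0 on one side, 1 on the other) is why the four-element padding vector and the `pads_all_equal` check in the winograd dispatch above matter.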
@@ -137,34 +137,6 @@ class ConvOpLite : public OpLite {
   std::string padding_algorithm_{""};
 };

-inline void UpdatePaddingAndDilation(std::vector<int>* paddings,
-                                     std::vector<int>* dilations,
-                                     const std::vector<int>& strides,
-                                     const std::string padding_algorithm,
-                                     const lite::DDim data_dims,
-                                     const lite::DDim& ksize) {
-  // when padding_algorithm is "VALID" or "SAME"
-  if (padding_algorithm == "SAME") {
-    for (size_t i = 0; i < strides.size(); ++i) {
-      int out_size = (data_dims[i + 2] + strides[i] - 1) / strides[i];
-      int pad_sum = std::max(
-          (out_size - 1) * strides[i] + ksize[i + 2] - data_dims[i + 2],
-          (int64_t)0);
-      int pad_0 = pad_sum / 2;
-      int pad_1 = pad_sum - pad_0;
-      // pad
-      *(paddings->begin() + i * 2) = pad_0;
-      *(paddings->begin() + i * 2 + 1) = pad_1;
-      // dilation
-      *(dilations->begin() + i) = 1;
-    }
-  } else if (padding_algorithm == "VALID") {
-    for (auto& it : *paddings) {
-      it = 0;
-    }
-  }
-}
-
 }  // namespace operators
 }  // namespace lite
 }  // namespace paddle
@@ -59,6 +59,8 @@ DEFINE_bool(flag_bias, true, "with bias");
 typedef paddle::lite::DDim DDim;
 typedef paddle::lite::Tensor Tensor;
 typedef paddle::lite::operators::ConvParam ConvParam;
+typedef paddle::lite::operators::ActivationParam ActivationParam;
+
 using paddle::lite::profile::Timer;

 DDim compute_out_dim(const DDim& dim_in,
@@ -118,6 +120,13 @@ void test_conv_fp32(const std::vector<DDim>& input_dims,
   param.dilations = std::make_shared<std::vector<int>>(dilas);
   param.fuse_relu = flag_relu;
   param.groups = group;
+  if (flag_relu) {
+    ActivationParam act_param;
+    act_param.has_active = true;
+    act_param.active_type =
+        (paddle::lite_api::ActivationType)1;  // 2-relu6 4-leakyrelu
+    param.activation_param = act_param;
+  }
   param.output = new Tensor;
   param.output->set_precision(PRECISION(kFloat));
@@ -243,6 +252,7 @@ void test_conv_fp32(const std::vector<DDim>& input_dims,
             << pads[2] << ", " << pads[3]
             << ", stride: " << strides[0] << ", " << strides[1]
             << ", dila_: " << dilas[0] << ", " << dilas[1]
+            << ", group: " << group
             << ", bias: " << (flag_bias ? "true" : "false")
             << ", relu: " << (flag_relu ? "true" : "false")
             << ", threads: " << th << ", power_mode: " << cls
@@ -255,6 +265,7 @@ void test_conv_fp32(const std::vector<DDim>& input_dims,
             << ", pad: " << pads[0] << ", " << pads[1]
             << ", stride: " << strides[0] << ", " << strides[1]
             << ", dila_: " << dilas[0] << ", " << dilas[1]
+            << ", group: " << group
             << ", bias: " << (flag_bias ? "true" : "false")
             << ", relu: " << (flag_relu ? "true" : "false")
             << ", threads: " << th << ", power_mode: " << cls
......
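The test hook above only wires in plain relu (`ActivationType` cast from 1); per its inline comment, casting from 2 or 4 would select relu6 or leakyRelu instead. A stand-in sketch of that mapping, using a stub enum rather than the real `paddle::lite_api::ActivationType`, whose full definition and coefficient fields are not part of this diff:

```cpp
#include <iostream>

// Stub mirroring only the values referenced by the test's comment:
// 1 -> relu, 2 -> relu6, 4 -> leakyRelu. Illustrative, not the real API.
enum class ActivationTypeStub : int { kRelu = 1, kRelu6 = 2, kLeakyRelu = 4 };

int main() {
  // The test currently performs the equivalent of this cast with 1 (relu);
  // switching the integer to 2 or 4 would exercise the new fused paths,
  // provided the matching clip/alpha coefficients are also set on the
  // activation parameters.
  auto act = static_cast<ActivationTypeStub>(2);  // relu6
  std::cout << "fused activation id = " << static_cast<int>(act) << "\n";
  return 0;
}
```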