Commit 0314aa1e authored by HappyAngel, committed by yiicy

[lite][arm] add conv+relu6/leakyRelu fusion (#2599)

Parent e5cbdb7b
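For orientation, the two activations being fused into the conv kernels in this commit are ReLU6 and LeakyReLU. Below is a minimal sketch of their element-wise definitions as they would be applied to each convolution output value; this is illustrative C++ only, not Paddle-Lite code, and the default `six`/`alpha` values are assumptions.

```cpp
#include <algorithm>

// relu6(x) = min(max(x, 0), clip)      -- relu clipped at a configurable ceiling
// leaky(x) = x           if x > 0
//          = alpha * x   otherwise     -- alpha is a small configurable slope
inline float relu6(float x, float six = 6.f) {
  return std::min(std::max(x, 0.f), six);
}
inline float leaky_relu(float x, float alpha = 0.01f) {
  return x > 0.f ? x : alpha * x;
}
```

Fusing these into the conv output write-back avoids a second pass over the output tensor that a standalone activation op would need.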
......@@ -295,7 +295,8 @@ void conv_compute_6x6_3x3(const float* input,
hout,
wout,
false,
zero_ptr);
zero_ptr,
nullptr);
}
} else {
for (int ci = 0; ci < oc_4; ++ci) {
......@@ -341,7 +342,8 @@ void conv_compute_6x6_3x3(const float* input,
hout,
wout,
false,
zero_ptr);
zero_ptr,
nullptr);
}
}
}
......@@ -562,7 +564,8 @@ void conv_compute_2x2_3x3(const float* input,
hout,
wout,
false,
zero_ptr);
zero_ptr,
nullptr);
}
} else {
for (int ci = 0; ci < oc_4; ++ci) {
......@@ -602,7 +605,8 @@ void conv_compute_2x2_3x3(const float* input,
hout,
wout,
false,
zero_ptr);
zero_ptr,
nullptr);
}
}
}
......@@ -814,7 +818,8 @@ void conv_compute_2x2_3x3_small(const float* input,
hout,
wout,
false,
zero_ptr);
zero_ptr,
nullptr);
}
} else {
for (int ci = 0; ci < oc_4; ++ci) {
......@@ -854,7 +859,8 @@ void conv_compute_2x2_3x3_small(const float* input,
hout,
wout,
false,
zero_ptr);
zero_ptr,
nullptr);
}
}
}
......
......@@ -76,6 +76,7 @@ void conv_3x3s1_direct_fp32(const float* i_data,
const int threads = ctx->threads();
int l2_size = ctx->llc_size() / sizeof(float);
auto paddings = *param.paddings;
auto act_param = param.activation_param;
const int pad_h = paddings[0];
const int pad_w = paddings[2];
......@@ -469,7 +470,8 @@ void conv_3x3s1_direct_fp32(const float* i_data,
oh,
ow,
flag_relu,
ptr_write);
ptr_write,
&act_param);
}
const float* weight_remain_ptr = weights + c_round_down * w_stride;
#pragma omp parallel for num_threads(threads)
......@@ -780,7 +782,8 @@ void conv_3x3s1_direct_fp32(const float* i_data,
oh,
ow,
flag_relu,
ptr_write);
ptr_write,
&act_param);
}
}
}
......
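In the direct-conv hunks above, the new trailing argument is `&act_param`, while the winograd calls earlier in this commit pass `nullptr`, presumably meaning no fused activation on that path yet. A hedged sketch of how a write-back routine might consume such a pointer is below; the struct and its field names are hypothetical stand-ins for `operators::ActivationParam`, not the real definition.

```cpp
#include <algorithm>
#include <cstddef>

// Hypothetical stand-in for operators::ActivationParam (field names are
// illustrative, not the real struct's).
struct ActSketch {
  bool is_relu6 = false;   // otherwise treat as leaky relu
  float relu6_clip = 6.f;
  float leaky_alpha = 0.01f;
};

// Sketch of a write-back taking the activation by pointer, as the calls
// above do; a null pointer would mean "write raw conv results, no fused
// activation".
inline void write_out(float* dst, const float* src, size_t n,
                      const ActSketch* act) {
  for (size_t i = 0; i < n; ++i) {
    float v = src[i];
    if (act) {
      v = act->is_relu6 ? std::min(std::max(v, 0.f), act->relu6_clip)
                        : (v > 0.f ? v : act->leaky_alpha * v);
    }
    dst[i] = v;
  }
}
```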
......@@ -75,6 +75,7 @@ void conv_3x3s2_direct_fp32(const float* i_data,
//! prepack input to tmp buffer
//! write output to tmp buffer
auto paddings = *param.paddings;
auto act_param = param.activation_param;
const int threads = ctx->threads();
int l2_size = ctx->llc_size() / sizeof(float);
const int pad_w = paddings[2];
......@@ -510,7 +511,8 @@ void conv_3x3s2_direct_fp32(const float* i_data,
oh,
ow,
flag_relu,
ptr_write);
ptr_write,
&act_param);
}
#pragma omp parallel for num_threads(threads)
......@@ -839,7 +841,8 @@ void conv_3x3s2_direct_fp32(const float* i_data,
oh,
ow,
flag_relu,
ptr_write);
ptr_write,
&act_param);
}
}
}
......
......@@ -205,14 +205,12 @@ void conv_depthwise_3x3s2_fp32(const float* din,
\
"ext v10.16b, %[vzero].16b, v9.16b, #12 \n" \
"fadd v16.4s, v16.4s, v11.4s \n" \
"fadd v16.4s, v16.4s, v12.4s \n"
"fadd v16.4s, v16.4s, v12.4s \n" /* r4 */ \
"fmla v13.4s, v8.4s, %[w2].s[1] \n" \
"fmla v14.4s, v9.4s, %[w2].s[2] \n" \
"fmla v17.4s, v10.4s, %[w2].s[0] \n"
#define LEFT_RESULT_S2 \
/* r4 */ \
"fmla v13.4s, v8.4s, %[w2].s[1] \n" \
"fmla v14.4s, v9.4s, %[w2].s[2] \n" \
"fmla v17.4s, v10.4s, %[w2].s[0] \n" \
\
"st1 {v16.4s}, [%[outptr0]], #16 \n" \
\
"ld2 {v0.4s, v1.4s}, [%[inptr0]], #32 \n" \
......@@ -244,53 +242,52 @@ void conv_depthwise_3x3s2_fp32(const float* din,
\
"blt 1f \n"
#define MID_COMPUTE_S2 \
"2: \n" /* r0 */ \
"fmul v11.4s, v0.4s, %[w0].s[0] \n" \
"fmul v12.4s, v1.4s, %[w0].s[1] \n" \
"fmla v16.4s, v10.4s, %[w0].s[2] \n" \
\
"ext v10.16b, v2.16b, v18.16b, #4 \n" \
"ld2 {v0.4s, v1.4s}, [%[inptr0]], #32 \n" /* r1 */ \
"fmla v11.4s, v2.4s, %[w1].s[0] \n" \
"fmla v12.4s, v3.4s, %[w1].s[1] \n" \
"fmla v16.4s, v10.4s, %[w1].s[2] \n" \
\
"ext v10.16b, v4.16b, v19.16b, #4 \n" \
\
"ld2 {v2.4s, v3.4s}, [%[inptr1]], #32 \n" /* r2 */ \
"fmul v13.4s, v4.4s, %[w0].s[0] \n" \
"fmla v11.4s, v4.4s, %[w2].s[0] \n" \
\
"fmul v14.4s, v5.4s, %[w0].s[1] \n" \
"fmla v12.4s, v5.4s, %[w2].s[1] \n" \
\
"fmla v17.4s, v10.4s, %[w0].s[2] \n" \
"fmla v16.4s, v10.4s, %[w2].s[2] \n" \
\
"ext v10.16b, v6.16b, v20.16b, #4 \n" \
\
"ld2 {v4.4s, v5.4s}, [%[inptr2]], #32 \n" /* r3 */ \
"fmla v13.4s, v6.4s, %[w1].s[0] \n" \
"fmla v14.4s, v7.4s, %[w1].s[1] \n" \
"fmla v17.4s, v10.4s, %[w1].s[2] \n" \
\
"ext v10.16b, v8.16b, v21.16b, #4 \n" \
\
"ld2 {v6.4s, v7.4s}, [%[inptr3]], #32 \n" \
\
"fadd v16.4s, v16.4s, v11.4s \n" \
"fadd v16.4s, v16.4s, v12.4s \n"
#define MID_COMPUTE_S2 \
"2: \n" /* r0 */ \
"fmul v11.4s, v0.4s, %[w0].s[0] \n" \
"fmul v12.4s, v1.4s, %[w0].s[1] \n" \
"fmla v16.4s, v10.4s, %[w0].s[2] \n" \
\
"ext v10.16b, v2.16b, v18.16b, #4 \n" \
"ld2 {v0.4s, v1.4s}, [%[inptr0]], #32 \n" /* r1 */ \
"fmla v11.4s, v2.4s, %[w1].s[0] \n" \
"fmla v12.4s, v3.4s, %[w1].s[1] \n" \
"fmla v16.4s, v10.4s, %[w1].s[2] \n" \
\
"ext v10.16b, v4.16b, v19.16b, #4 \n" \
\
"ld2 {v2.4s, v3.4s}, [%[inptr1]], #32 \n" /* r2 */ \
"fmul v13.4s, v4.4s, %[w0].s[0] \n" \
"fmla v11.4s, v4.4s, %[w2].s[0] \n" \
\
"fmul v14.4s, v5.4s, %[w0].s[1] \n" \
"fmla v12.4s, v5.4s, %[w2].s[1] \n" \
\
"fmla v17.4s, v10.4s, %[w0].s[2] \n" \
"fmla v16.4s, v10.4s, %[w2].s[2] \n" \
\
"ext v10.16b, v6.16b, v20.16b, #4 \n" \
\
"ld2 {v4.4s, v5.4s}, [%[inptr2]], #32 \n" /* r3 */ \
"fmla v13.4s, v6.4s, %[w1].s[0] \n" \
"fmla v14.4s, v7.4s, %[w1].s[1] \n" \
"fmla v17.4s, v10.4s, %[w1].s[2] \n" \
\
"ext v10.16b, v8.16b, v21.16b, #4 \n" \
\
"ld2 {v6.4s, v7.4s}, [%[inptr3]], #32 \n" \
\
"fadd v16.4s, v16.4s, v11.4s \n" \
"fadd v16.4s, v16.4s, v12.4s \n" /* r4 */ \
"fmla v13.4s, v8.4s, %[w2].s[0] \n" \
"fmla v14.4s, v9.4s, %[w2].s[1] \n" \
"fmla v17.4s, v10.4s, %[w2].s[2] \n" \
\
"ld2 {v8.4s, v9.4s}, [%[inptr4]], #32 \n" \
"ld1 {v15.4s}, [%[inptr0]] \n" \
"ld1 {v18.4s}, [%[inptr1]] \n"
#define MID_RESULT_S2 \
/* r4 */ \
"fmla v13.4s, v8.4s, %[w2].s[0] \n" \
"fmla v14.4s, v9.4s, %[w2].s[1] \n" \
"fmla v17.4s, v10.4s, %[w2].s[2] \n" \
\
"ld2 {v8.4s, v9.4s}, [%[inptr4]], #32 \n" \
"ld1 {v15.4s}, [%[inptr0]] \n" \
"ld1 {v18.4s}, [%[inptr1]] \n" \
"st1 {v16.4s}, [%[outptr0]], #16 \n" \
\
"fadd v17.4s, v17.4s, v13.4s \n" \
......@@ -360,14 +357,12 @@ void conv_depthwise_3x3s2_fp32(const float* din,
\
"fadd v16.4s, v16.4s, v11.4s \n" \
"fadd v16.4s, v16.4s, v12.4s \n" \
"ld1 {v1.4s}, [%[outptr1]] \n"
"ld1 {v1.4s}, [%[outptr1]] \n" /* r4 */ \
"fmla v13.4s, v8.4s, %[w2].s[0] \n" \
"fmla v14.4s, v9.4s, %[w2].s[1] \n" \
"fmla v17.4s, v10.4s, %[w2].s[2] \n"
#define RIGHT_RESULT_S2 \
/* r4 */ \
"fmla v13.4s, v8.4s, %[w2].s[0] \n" \
"fmla v14.4s, v9.4s, %[w2].s[1] \n" \
"fmla v17.4s, v10.4s, %[w2].s[2] \n" \
\
"bif v16.16b, v0.16b, %[wmask].16b \n" \
\
"fadd v17.4s, v17.4s, v13.4s \n" \
......@@ -382,11 +377,6 @@ void conv_depthwise_3x3s2_fp32(const float* din,
"4: \n"
#define LEFT_RESULT_S2_RELU \
/* r4 */ \
"fmla v13.4s, v8.4s, %[w2].s[1] \n" \
"fmla v14.4s, v9.4s, %[w2].s[2] \n" \
"fmla v17.4s, v10.4s, %[w2].s[0] \n" \
\
"fmax v16.4s, v16.4s, %[vzero].4s \n" \
\
"ld2 {v0.4s, v1.4s}, [%[inptr0]], #32 \n" \
......@@ -424,14 +414,6 @@ void conv_depthwise_3x3s2_fp32(const float* din,
"blt 1f \n"
#define MID_RESULT_S2_RELU \
/* r4 */ \
"fmla v13.4s, v8.4s, %[w2].s[0] \n" \
"fmla v14.4s, v9.4s, %[w2].s[1] \n" \
"fmla v17.4s, v10.4s, %[w2].s[2] \n" \
\
"ld2 {v8.4s, v9.4s}, [%[inptr4]], #32 \n" \
"ld1 {v15.4s}, [%[inptr0]] \n" \
"ld1 {v18.4s}, [%[inptr1]] \n" \
"fmax v16.4s, v16.4s, %[vzero].4s \n" /* relu */ \
\
"fadd v17.4s, v17.4s, v13.4s \n" \
......@@ -457,11 +439,6 @@ void conv_depthwise_3x3s2_fp32(const float* din,
"bne 2b \n"
#define RIGHT_RESULT_S2_RELU \
/* r4 */ \
"fmla v13.4s, v8.4s, %[w2].s[0] \n" \
"fmla v14.4s, v9.4s, %[w2].s[1] \n" \
"fmla v17.4s, v10.4s, %[w2].s[2] \n" \
\
"fmax v16.4s, v16.4s, %[vzero].4s \n" /* relu */ \
\
"fadd v17.4s, v17.4s, v13.4s \n" \
......
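The macro reshuffling above moves the shared r4 `fmla` instructions from the `*_RESULT_S2*` tails into the `*_COMPUTE_S2` bodies, so that the result tails differ only in the activation and store steps (the existing `_RELU` tails are just an extra `fmax` against `vzero`). For reference, here is an intrinsic-level sketch of what ReLU6 and LeakyReLU tails look like at that granularity; this is illustrative C using standard NEON intrinsics, not the kernel's actual assembly.

```cpp
#include <arm_neon.h>  // requires an ARM/AArch64 build

// relu6: fmax against zero, then fmin against the clip value.
static inline float32x4_t act_relu6(float32x4_t v, float32x4_t vclip,
                                     float32x4_t vzero) {
  return vminq_f32(vmaxq_f32(v, vzero), vclip);
}

// leaky relu: compare against zero, scale the negative lane values by alpha,
// then select per lane (the asm analogue would be fcmgt/fmul plus bif/bsl).
static inline float32x4_t act_leaky(float32x4_t v, float32x4_t valpha,
                                     float32x4_t vzero) {
  uint32x4_t pos = vcgtq_f32(v, vzero);
  return vbslq_f32(pos, v, vmulq_f32(v, valpha));
}
```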
......@@ -37,6 +37,7 @@ void conv_3x3s1_depthwise_fp32(const float* i_data,
const float* weights,
const float* bias,
const operators::ConvParam& param,
const operators::ActivationParam act_param,
ARMContext* ctx);
void conv_3x3s2_depthwise_fp32(const float* i_data,
......@@ -67,6 +68,7 @@ void conv_depthwise_3x3s1_fp32(const float* din,
int pad,
bool flag_bias,
bool flag_relu,
const operators::ActivationParam act_param,
ARMContext* ctx);
void conv_depthwise_3x3s2_fp32(const float* din,
......
......@@ -579,6 +579,7 @@ void conv_depthwise_3x3_fp32(const void* din,
ARMContext* ctx,
const float* scale) {
auto paddings = *param.paddings;
auto act_param = param.activation_param;
const int pad_h = paddings[0];
const int pad_w = paddings[2];
int stride = param.strides[1];
......@@ -603,6 +604,7 @@ void conv_depthwise_3x3_fp32(const void* din,
pad,
flag_bias,
flag_relu,
act_param,
ctx);
} else {
conv_3x3s1_depthwise_fp32(reinterpret_cast<const float*>(din),
......@@ -617,6 +619,7 @@ void conv_depthwise_3x3_fp32(const void* din,
reinterpret_cast<const float*>(weights),
bias,
param,
act_param,
ctx);
}
......
......@@ -67,7 +67,7 @@ void ConvCompute<PRECISION(kFloat), PRECISION(kFloat)>::PrepareForRun() {
impl_ = new DepthwiseConv<PRECISION(kFloat), PRECISION(kFloat)>;
VLOG(3) << "invoking dw conv";
} else if (param.groups == 1 && kw == 3 && stride == 1 && kps_equal &&
no_dilation) {
no_dilation && pads_all_equal) {
/// winograd conv impl
impl_ = new WinogradConv<PRECISION(kFloat), PRECISION(kFloat)>;
VLOG(3) << "invoking winograd conv";
......
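The winograd branch above is now additionally gated on `pads_all_equal`. One plausible form of that check, written as a standalone helper purely for illustration (the surrounding code's exact definition is not shown in this diff), would be:

```cpp
#include <vector>

// Hypothetical helper mirroring the new gate: only pick the winograd kernel
// when the four paddings {top, bottom, left, right} are all the same value.
inline bool PadsAllEqual(const std::vector<int>& pads) {
  return pads.size() == 4 &&
         pads[0] == pads[1] && pads[1] == pads[2] && pads[2] == pads[3];
}
```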
......@@ -52,6 +52,34 @@ inline int ConvOutputSize(int input_size,
return output_size;
}
inline void UpdatePaddingAndDilation(std::vector<int>* paddings,
std::vector<int>* dilations,
const std::vector<int>& strides,
const std::string padding_algorithm,
const lite::DDim data_dims,
const lite::DDim& ksize) {
// when padding_desc is "VALID" or "SAME"
if (padding_algorithm == "SAME") {
for (size_t i = 0; i < strides.size(); ++i) {
int out_size = (data_dims[i + 2] + strides[i] - 1) / strides[i];
int pad_sum = std::max(
(out_size - 1) * strides[i] + ksize[i + 2] - data_dims[i + 2],
(int64_t)0);
int pad_0 = pad_sum / 2;
int pad_1 = pad_sum - pad_0;
// pad
*(paddings->begin() + i * 2) = pad_0;
*(paddings->begin() + i * 2 + 1) = pad_1;
// dilation
*(dilations->begin() + i) = 1;
}
} else if (padding_algorithm == "VALID") {
for (auto& it : *paddings) {
it = 0;
}
}
}
bool ConvOpLite::InferShape() const {
const auto in_dims = param_.x->dims();
const auto filter_dims = param_.filter->dims();
......
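`UpdatePaddingAndDilation` (moved here from conv_op.h, see the removal hunk below) derives SAME/VALID padding from the data and kernel shapes. A standalone rerun of its SAME-padding arithmetic with assumed values (input extent 224, stride 2, kernel 3) shows how asymmetric padding arises:

```cpp
#include <algorithm>
#include <cstdio>

int main() {
  int data_dim = 224, stride = 2, ksize = 3;
  int out_size = (data_dim + stride - 1) / stride;                         // 112
  int pad_sum = std::max((out_size - 1) * stride + ksize - data_dim, 0);   // 1
  int pad_0 = pad_sum / 2;                                                 // 0
  int pad_1 = pad_sum - pad_0;                                             // 1
  std::printf("pads: %d %d\n", pad_0, pad_1);  // SAME padding of {0, 1}
  return 0;
}
```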
......@@ -137,34 +137,6 @@ class ConvOpLite : public OpLite {
std::string padding_algorithm_{""};
};
inline void UpdatePaddingAndDilation(std::vector<int>* paddings,
std::vector<int>* dilations,
const std::vector<int>& strides,
const std::string padding_algorithm,
const lite::DDim data_dims,
const lite::DDim& ksize) {
// when padding_desc is "VALID" or "SAME"
if (padding_algorithm == "SAME") {
for (size_t i = 0; i < strides.size(); ++i) {
int out_size = (data_dims[i + 2] + strides[i] - 1) / strides[i];
int pad_sum = std::max(
(out_size - 1) * strides[i] + ksize[i + 2] - data_dims[i + 2],
(int64_t)0);
int pad_0 = pad_sum / 2;
int pad_1 = pad_sum - pad_0;
// pad
*(paddings->begin() + i * 2) = pad_0;
*(paddings->begin() + i * 2 + 1) = pad_1;
// dilation
*(dilations->begin() + i) = 1;
}
} else if (padding_algorithm == "VALID") {
for (auto& it : *paddings) {
it = 0;
}
}
}
} // namespace operators
} // namespace lite
} // namespace paddle