Unverified commit d14e57f7, authored by Leonardo-Ding, committed by GitHub

[ARM] optimize depthwise int8 f3s1 arm neon kernel,test=develop (#4125)

Parent 42cefe1b
@@ -106,6 +106,42 @@ void conv_depthwise_3x3s1_int8(Dtype* dout,
                                int padh,
                                ARMContext* ctx);
+void conv_depthwise_3x3s1_int8_int8_impl(int8_t* dout,
+                                         const int8_t* din,
+                                         const int8_t* weights,
+                                         const float* scale,
+                                         const float* bias,
+                                         bool flag_bias,
+                                         int flag_act,
+                                         float* alpha,
+                                         int num,
+                                         int chin,
+                                         int hin,
+                                         int win,
+                                         int hout,
+                                         int wout,
+                                         int padw,
+                                         int padh,
+                                         ARMContext* ctx);
+void conv_depthwise_3x3s1_int8_float_impl(float* dout,
+                                          const int8_t* din,
+                                          const int8_t* weights,
+                                          const float* scale,
+                                          const float* bias,
+                                          bool flag_bias,
+                                          int flag_act,
+                                          float* alpha,
+                                          int num,
+                                          int chin,
+                                          int hin,
+                                          int win,
+                                          int hout,
+                                          int wout,
+                                          int padw,
+                                          int padh,
+                                          ARMContext* ctx);
 template <typename Dtype>
 void conv_depthwise_3x3s2_int8(Dtype* dout,
                                const int8_t* din,
...
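The header now declares two non-template entry points for the stride-1 fast path, one per output precision, next to the existing conv_depthwise_3x3s1_int8 template. A toy sketch of that pairing follows; the stub names fast_path_impl and run_3x3s1 are mine, not the patch's, and the real entry points take the full argument list shown above:

    #include <cstdint>
    #include <cstdio>

    // Hypothetical stand-ins for the two specialized kernels: one writes
    // int8 output, the other float output, from the same int8 input/weights.
    void fast_path_impl(int8_t* dout) { (void)dout; std::printf("int8-out fast path\n"); }
    void fast_path_impl(float* dout) { (void)dout; std::printf("float-out fast path\n"); }

    // Generic front end, in the spirit of the conv_depthwise_3x3s1_int8
    // template: the output element type selects the specialized kernel.
    template <typename Dtype>
    void run_3x3s1(Dtype* dout) { fast_path_impl(dout); }

    int main() {
      int8_t out_i8[4] = {0};
      float out_f32[4] = {0.f};
      run_3x3s1(out_i8);   // resolves to the int8-out entry point
      run_3x3s1(out_f32);  // resolves to the float-out entry point
      return 0;
    }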
@@ -814,24 +814,52 @@ void conv_depthwise_3x3_int8_fp32(const void* din,
       alpha[3] = local_alpha;
     }
   }
+  bool support_act_type = flag_act <= 1;
+  bool support_pad_type =
+      (paddings[0] == paddings[1]) && (paddings[2] == paddings[3]) &&
+      (paddings[0] == paddings[2]) && (paddings[0] == 0 || paddings[0] == 1);
+  bool support_stride_type = (param.strides[0] == 1 && param.strides[1] == 1);
+  bool support_width_type = w_in > 9 ? true : false;
   if (stride == 1) {
-    conv_depthwise_3x3s1_int8(reinterpret_cast<float*>(dout),
-                              reinterpret_cast<const int8_t*>(din),
-                              reinterpret_cast<const int8_t*>(weights),
-                              scale,
-                              bias,
-                              flag_bias,
-                              flag_act,
-                              alpha,
-                              num,
-                              ch_in,
-                              h_in,
-                              w_in,
-                              h_out,
-                              w_out,
-                              pad_w,
-                              pad_h,
-                              ctx);
+    if (!support_act_type || !support_pad_type || !support_stride_type ||
+        !support_width_type) {
+      conv_depthwise_3x3s1_int8(reinterpret_cast<float*>(dout),
+                                reinterpret_cast<const int8_t*>(din),
+                                reinterpret_cast<const int8_t*>(weights),
+                                scale,
+                                bias,
+                                flag_bias,
+                                flag_act,
+                                alpha,
+                                num,
+                                ch_in,
+                                h_in,
+                                w_in,
+                                h_out,
+                                w_out,
+                                pad_w,
+                                pad_h,
+                                ctx);
+    } else {
+      conv_depthwise_3x3s1_int8_float_impl(
+          reinterpret_cast<float*>(dout),
+          reinterpret_cast<const int8_t*>(din),
+          reinterpret_cast<const int8_t*>(weights),
+          scale,
+          bias,
+          flag_bias,
+          flag_act,
+          alpha,
+          num,
+          ch_in,
+          h_in,
+          w_in,
+          h_out,
+          w_out,
+          pad_w,
+          pad_h,
+          ctx);
+    }
   } else if (stride == 2) {
     conv_depthwise_3x3s2_int8(reinterpret_cast<float*>(dout),
                               reinterpret_cast<const int8_t*>(din),
@@ -897,24 +925,52 @@ void conv_depthwise_3x3_int8_int8(const void* din,
       alpha[3] = local_alpha;
     }
   }
+  bool support_act_type = flag_act <= 1;
+  bool support_pad_type =
+      (paddings[0] == paddings[1]) && (paddings[2] == paddings[3]) &&
+      (paddings[0] == paddings[2]) && (paddings[0] == 0 || paddings[0] == 1);
+  bool support_stride_type = (param.strides[0] == 1 && param.strides[1] == 1);
+  bool support_width_type = w_in > 9 ? true : false;
   if (stride == 1) {
-    conv_depthwise_3x3s1_int8(reinterpret_cast<int8_t*>(dout),
-                              reinterpret_cast<const int8_t*>(din),
-                              reinterpret_cast<const int8_t*>(weights),
-                              scale,
-                              bias,
-                              flag_bias,
-                              flag_act,
-                              alpha,
-                              num,
-                              ch_in,
-                              h_in,
-                              w_in,
-                              h_out,
-                              w_out,
-                              pad_w,
-                              pad_h,
-                              ctx);
+    if (!support_act_type || !support_pad_type || !support_stride_type ||
+        !support_width_type) {
+      conv_depthwise_3x3s1_int8(reinterpret_cast<int8_t*>(dout),
+                                reinterpret_cast<const int8_t*>(din),
+                                reinterpret_cast<const int8_t*>(weights),
+                                scale,
+                                bias,
+                                flag_bias,
+                                flag_act,
+                                alpha,
+                                num,
+                                ch_in,
+                                h_in,
+                                w_in,
+                                h_out,
+                                w_out,
+                                pad_w,
+                                pad_h,
+                                ctx);
+    } else {
+      conv_depthwise_3x3s1_int8_int8_impl(
+          reinterpret_cast<int8_t*>(dout),
+          reinterpret_cast<const int8_t*>(din),
+          reinterpret_cast<const int8_t*>(weights),
+          scale,
+          bias,
+          flag_bias,
+          flag_act,
+          alpha,
+          num,
+          ch_in,
+          h_in,
+          w_in,
+          h_out,
+          w_out,
+          pad_w,
+          pad_h,
+          ctx);
+    }
   } else if (stride == 2) {
     conv_depthwise_3x3s2_int8(reinterpret_cast<int8_t*>(dout),
                               reinterpret_cast<const int8_t*>(din),
...
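Both dispatch sites above gate the new NEON fast path on the same four conditions: activation none or relu (flag_act <= 1), all four paddings equal and either 0 or 1, stride 1 in both dimensions, and input width greater than 9. A minimal, self-contained restatement of that predicate (the helper name supports_f3s1_opt and the padding-order comment are my assumptions, not the patch's):

    #include <cstdio>
    #include <vector>

    // Mirrors the gating logic in conv_depthwise_3x3_int8_fp32/int8_int8.
    // paddings is assumed to be ordered {top, bottom, left, right}.
    bool supports_f3s1_opt(int flag_act,
                           const std::vector<int>& paddings,
                           int stride_h, int stride_w,
                           int w_in) {
      bool support_act = flag_act <= 1;  // none (0) or relu (1)
      bool support_pad = (paddings[0] == paddings[1]) &&
                         (paddings[2] == paddings[3]) &&
                         (paddings[0] == paddings[2]) &&
                         (paddings[0] == 0 || paddings[0] == 1);
      bool support_stride = (stride_h == 1 && stride_w == 1);
      bool support_width = w_in > 9;
      return support_act && support_pad && support_stride && support_width;
    }

    int main() {
      // pad 1 on all sides, stride 1, relu, width 112 -> fast path
      std::printf("%d\n", supports_f3s1_opt(1, {1, 1, 1, 1}, 1, 1, 112));  // 1
      // width 8 is too narrow -> generic fallback
      std::printf("%d\n", supports_f3s1_opt(0, {0, 0, 0, 0}, 1, 1, 8));    // 0
      return 0;
    }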
@@ -31,7 +31,6 @@ void DepthwiseConv<PRECISION(kFloat), PRECISION(kFloat)>::PrepareForRun() {
   auto paddings = *param.paddings;
   // select dw conv kernel
   if (kw == 3) {
-    // VLOG(5) << "invoke 3x3 dw conv fp32";
     bool pads_less = ((paddings[1] < 2) && (paddings[3] < 2));
     if (pads_less && paddings[0] == paddings[2] &&
         (paddings[0] == 0 || paddings[0] == 1)) {
@@ -54,7 +53,6 @@ void DepthwiseConv<PRECISION(kFloat), PRECISION(kFloat)>::PrepareForRun() {
     kernel_func_name_ = "conv_depthwise_3x3_fp32";
 #endif
   } else if (kw == 5) {
-    // VLOG(5) << "invoke 5x5 dw conv fp32";
     auto strides = param.strides;
     if ((strides[0] == 1 && strides[1] == 1) ||
         (strides[0] == 2 && strides[1] == 2)) {
@@ -104,23 +102,44 @@ void DepthwiseConv<PRECISION(kInt8), PRECISION(kFloat)>::PrepareForRun() {
       w_scale_[i] = scale[i] * in_scale;
     }
   }
+  auto paddings = *param.paddings;
+  auto strides = param.strides;
+  auto x_dims = param.x->dims();
+  int iw = x_dims[3];
+  int ih = x_dims[2];
+  auto act_param = param.activation_param;
+  bool has_act = act_param.has_active;
+  lite_api::ActivationType act_type = act_param.active_type;
+  // no activation and relu activation is supported now
+  bool support_act_type =
+      (has_act == false) ||
+      (has_act == true && act_type == lite_api::ActivationType::kRelu);
+  bool support_pad_type =
+      (paddings[0] == paddings[1]) && (paddings[2] == paddings[3]) &&
+      (paddings[0] == paddings[2]) && (paddings[0] == 0 || paddings[0] == 1);
+  bool support_stride_type = (strides[0] == 1 && strides[1] == 1);
+  bool support_width_type = iw > 9 ? true : false;
   /// select dw conv kernel
   if (kw == 3) {
     // trans weights
-    // VLOG(5) << "invoke 3x3 dw conv int8 kernel fp32 out";
     impl_ = lite::arm::math::conv_depthwise_3x3_int8_fp32;
 #ifdef LITE_WITH_PROFILE
     kernel_func_name_ = "conv_depthwise_3x3_int8_fp32";
 #endif
-    int cround = ROUNDUP(w_dims[0], 8);
-    weights_.Resize({cround / 8, 1, kh * kw, 8});
-    auto wptr = param.filter->data<int8_t>();
-    auto wptr_new = weights_.mutable_data<int8_t>();
-    lite::arm::math::conv_trans_weights_numc(wptr, wptr_new, oc, 1, 8, 9);
-    flag_trans_weights_ = true;
+    if (!support_act_type || !support_pad_type || !support_stride_type ||
+        !support_width_type) {
+      int cround = ROUNDUP(w_dims[0], 8);
+      weights_.Resize({cround / 8, 1, kh * kw, 8});
+      auto wptr = param.filter->data<int8_t>();
+      auto wptr_new = weights_.mutable_data<int8_t>();
+      lite::arm::math::conv_trans_weights_numc(wptr, wptr_new, oc, 1, 8, 9);
+      flag_trans_weights_ = true;
+    } else {
+      flag_trans_weights_ = false;
+    }
   } else if (kw == 5) {
     // trans weights
-    // VLOG(5) << "invoke 5x5 dw conv int8 kernel fp32 out";
     impl_ = lite::arm::math::conv_depthwise_5x5_int8_fp32;
 #ifdef LITE_WITH_PROFILE
     kernel_func_name_ = "conv_depthwise_5x5_int8_fp32";
@@ -175,23 +194,45 @@ void DepthwiseConv<PRECISION(kInt8), PRECISION(kInt8)>::PrepareForRun() {
     param.activation_param.Relu_clipped_coef =
         param.activation_param.Relu_clipped_coef / param.output_scale;
   }
+  auto paddings = *param.paddings;
+  auto strides = param.strides;
+  auto x_dims = param.x->dims();
+  int iw = x_dims[3];
+  int ih = x_dims[2];
+  auto act_param = param.activation_param;
+  bool has_act = act_param.has_active;
+  lite_api::ActivationType act_type = act_param.active_type;
+  // no activation and relu activation is supported now
+  bool support_act_type =
+      (has_act == false) ||
+      (has_act == true && act_type == lite_api::ActivationType::kRelu);
+  bool support_pad_type =
+      (paddings[0] == paddings[1]) && (paddings[2] == paddings[3]) &&
+      (paddings[0] == paddings[2]) && (paddings[0] == 0 || paddings[0] == 1);
+  bool support_stride_type = (strides[0] == 1 && strides[1] == 1);
+  bool support_width_type = iw > 9 ? true : false;
   /// select dw conv kernel
   if (kw == 3) {
     // trans weights
-    // VLOG(5) << "invoke 3x3 dw conv int8 kernel int8 out";
     impl_ = lite::arm::math::conv_depthwise_3x3_int8_int8;
 #ifdef LITE_WITH_PROFILE
     kernel_func_name_ = "conv_depthwise_3x3_int8_int8";
 #endif
-    int cround = ROUNDUP(w_dims[0], 8);
-    weights_.Resize({cround / 8, 1, kh * kw, 8});
-    auto wptr = param.filter->data<int8_t>();
-    auto wptr_new = weights_.mutable_data<int8_t>();
-    lite::arm::math::conv_trans_weights_numc(wptr, wptr_new, oc, 1, 8, 9);
-    flag_trans_weights_ = true;
+    if (!support_act_type || !support_pad_type || !support_stride_type ||
+        !support_width_type) {
+      int cround = ROUNDUP(w_dims[0], 8);
+      weights_.Resize({cround / 8, 1, kh * kw, 8});
+      auto wptr = param.filter->data<int8_t>();
+      auto wptr_new = weights_.mutable_data<int8_t>();
+      lite::arm::math::conv_trans_weights_numc(wptr, wptr_new, oc, 1, 8, 9);
+      flag_trans_weights_ = true;
+    } else {
+      flag_trans_weights_ = false;
+    }
   } else if (kw == 5) {
     // trans weights
-    // VLOG(5) << "invoke 5x5 dw conv int8 kernel int8 out";
     impl_ = lite::arm::math::conv_depthwise_5x5_int8_int8;
 #ifdef LITE_WITH_PROFILE
     kernel_func_name_ = "conv_depthwise_5x5_int8_int8";
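In both PrepareForRun bodies, the generic int8 kernel needs the depthwise weights repacked into channel blocks of eight (hence weights_.Resize({cround / 8, 1, kh * kw, 8}) and conv_trans_weights_numc(wptr, wptr_new, oc, 1, 8, 9)), while the new fast path consumes the filter in its original layout, so the copy is skipped and flag_trans_weights_ is cleared. A sketch of what such a channel-blocked repack does; this is my re-implementation for illustration, and the real conv_trans_weights_numc may differ in details:

    #include <cstdint>
    #include <cstring>
    #include <vector>

    // Round n up to a multiple of m, as the ROUNDUP macro in the source does.
    static int round_up(int n, int m) { return ((n + m - 1) / m) * m; }

    // Illustrative channel-blocked repack: input is assumed to be [oc][ksize]
    // (one 3x3 filter, ksize = 9 taps, per depthwise channel); output is
    // [oc_round/8][ksize][8], so the 8 channels of a block sit contiguously
    // and can be fetched with a single NEON vector load per tap.
    void repack_weights_c8(const int8_t* src, int8_t* dst, int oc, int ksize) {
      int oc_round = round_up(oc, 8);
      std::memset(dst, 0, oc_round * ksize);  // zero-pad the tail channels
      for (int c = 0; c < oc; ++c) {
        int block = c / 8, lane = c % 8;
        for (int k = 0; k < ksize; ++k) {
          dst[(block * ksize + k) * 8 + lane] = src[c * ksize + k];
        }
      }
    }

    int main() {
      int oc = 10, ksize = 9;  // 10 channels of 3x3 taps -> 2 blocks of 8
      std::vector<int8_t> src(oc * ksize), dst(round_up(oc, 8) * ksize);
      for (int i = 0; i < oc * ksize; ++i) src[i] = static_cast<int8_t>(i);
      repack_weights_c8(src.data(), dst.data(), oc, ksize);
      return 0;
    }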
@@ -283,7 +324,7 @@ void DepthwiseConv<PRECISION(kInt8), PRECISION(kFloat)>::Run() {
   auto w_dims = param.filter->dims();
   auto o_dims = param.output->dims();
 
-  int iw = x_dims[3];  // nchw
+  int iw = x_dims[3];
   int ih = x_dims[2];
   int ic = x_dims[1];
   int bs = x_dims[0];
@@ -333,7 +374,7 @@ void DepthwiseConv<PRECISION(kInt8), PRECISION(kInt8)>::Run() {
   auto w_dims = param.filter->dims();
   auto o_dims = param.output->dims();
 
-  int iw = x_dims[3];  // nchw
+  int iw = x_dims[3];
   int ih = x_dims[2];
   int ic = x_dims[1];
   int bs = x_dims[0];
...