Unverified · Commit d14e57f7 · authored by Leonardo-Ding · committed by GitHub

[ARM] optimize depthwise int8 f3s1 arm neon kernel,test=develop (#4125)
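Summary (read off the diff below): this change adds two dedicated ARM NEON entry points, `conv_depthwise_3x3s1_int8_int8_impl` and `conv_depthwise_3x3s1_int8_float_impl`, for depthwise 3x3 stride-1 int8 convolution. The optimized path is taken only when the activation is none or relu, the four paddings are equal and either 0 or 1, both strides are 1, and the input width is greater than 9; every other configuration falls back to the existing generic `conv_depthwise_3x3s1_int8` kernel. Kernel preparation now skips the 8-channel weight re-layout when the optimized path will run, and the int8 conv unit test is reworked to validate one input shape per call.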

Parent 42cefe1b
@@ -106,6 +106,42 @@ void conv_depthwise_3x3s1_int8(Dtype* dout,
                                int padh,
                                ARMContext* ctx);
+void conv_depthwise_3x3s1_int8_int8_impl(int8_t* dout,
+                                         const int8_t* din,
+                                         const int8_t* weights,
+                                         const float* scale,
+                                         const float* bias,
+                                         bool flag_bias,
+                                         int flag_act,
+                                         float* alpha,
+                                         int num,
+                                         int chin,
+                                         int hin,
+                                         int win,
+                                         int hout,
+                                         int wout,
+                                         int padw,
+                                         int padh,
+                                         ARMContext* ctx);
+void conv_depthwise_3x3s1_int8_float_impl(float* dout,
+                                          const int8_t* din,
+                                          const int8_t* weights,
+                                          const float* scale,
+                                          const float* bias,
+                                          bool flag_bias,
+                                          int flag_act,
+                                          float* alpha,
+                                          int num,
+                                          int chin,
+                                          int hin,
+                                          int win,
+                                          int hout,
+                                          int wout,
+                                          int padw,
+                                          int padh,
+                                          ARMContext* ctx);
 template <typename Dtype>
 void conv_depthwise_3x3s2_int8(Dtype* dout,
                                const int8_t* din,
......
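For orientation, a minimal sketch of calling the new float-output entry point. The argument meanings are inferred from the parameter names and the call sites below, so treat the comments as assumptions rather than documented semantics:

```cpp
// Illustrative call; all variables are assumed to be prepared by the caller.
conv_depthwise_3x3s1_int8_float_impl(dout,      // float* output, NCHW
                                     din,       // const int8_t* input, NCHW
                                     weights,   // 3x3 depthwise filters, int8
                                     scale,     // per-channel dequant scales
                                     bias,
                                     flag_bias,
                                     flag_act,  // 0: none, 1: relu
                                     alpha,     // extra activation params
                                     num, chin, hin, win,  // batch + input shape
                                     hout, wout,           // output shape
                                     padw, padh,           // symmetric paddings
                                     ctx);  // ARM thread/workspace context
```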
@@ -814,7 +814,15 @@ void conv_depthwise_3x3_int8_fp32(const void* din,
       alpha[3] = local_alpha;
     }
   }
+  bool support_act_type = flag_act <= 1;
+  bool support_pad_type =
+      (paddings[0] == paddings[1]) && (paddings[2] == paddings[3]) &&
+      (paddings[0] == paddings[2]) && (paddings[0] == 0 || paddings[0] == 1);
+  bool support_stride_type = (param.strides[0] == 1 && param.strides[1] == 1);
+  bool support_width_type = w_in > 9 ? true : false;
   if (stride == 1) {
+    if (!support_act_type || !support_pad_type || !support_stride_type ||
+        !support_width_type) {
     conv_depthwise_3x3s1_int8(reinterpret_cast<float*>(dout),
                               reinterpret_cast<const int8_t*>(din),
                               reinterpret_cast<const int8_t*>(weights),
@@ -832,6 +840,26 @@ void conv_depthwise_3x3_int8_fp32(const void* din,
                               pad_w,
                               pad_h,
                               ctx);
+    } else {
+      conv_depthwise_3x3s1_int8_float_impl(
+          reinterpret_cast<float*>(dout),
+          reinterpret_cast<const int8_t*>(din),
+          reinterpret_cast<const int8_t*>(weights),
+          scale,
+          bias,
+          flag_bias,
+          flag_act,
+          alpha,
+          num,
+          ch_in,
+          h_in,
+          w_in,
+          h_out,
+          w_out,
+          pad_w,
+          pad_h,
+          ctx);
+    }
   } else if (stride == 2) {
     conv_depthwise_3x3s2_int8(reinterpret_cast<float*>(dout),
                               reinterpret_cast<const int8_t*>(din),
@@ -897,7 +925,15 @@ void conv_depthwise_3x3_int8_int8(const void* din,
       alpha[3] = local_alpha;
     }
   }
+  bool support_act_type = flag_act <= 1;
+  bool support_pad_type =
+      (paddings[0] == paddings[1]) && (paddings[2] == paddings[3]) &&
+      (paddings[0] == paddings[2]) && (paddings[0] == 0 || paddings[0] == 1);
+  bool support_stride_type = (param.strides[0] == 1 && param.strides[1] == 1);
+  bool support_width_type = w_in > 9 ? true : false;
   if (stride == 1) {
+    if (!support_act_type || !support_pad_type || !support_stride_type ||
+        !support_width_type) {
     conv_depthwise_3x3s1_int8(reinterpret_cast<int8_t*>(dout),
                               reinterpret_cast<const int8_t*>(din),
                               reinterpret_cast<const int8_t*>(weights),
@@ -915,6 +951,26 @@ void conv_depthwise_3x3_int8_int8(const void* din,
                               pad_w,
                               pad_h,
                               ctx);
+    } else {
+      conv_depthwise_3x3s1_int8_int8_impl(
+          reinterpret_cast<int8_t*>(dout),
+          reinterpret_cast<const int8_t*>(din),
+          reinterpret_cast<const int8_t*>(weights),
+          scale,
+          bias,
+          flag_bias,
+          flag_act,
+          alpha,
+          num,
+          ch_in,
+          h_in,
+          w_in,
+          h_out,
+          w_out,
+          pad_w,
+          pad_h,
+          ctx);
+    }
   } else if (stride == 2) {
     conv_depthwise_3x3s2_int8(reinterpret_cast<int8_t*>(dout),
                               reinterpret_cast<const int8_t*>(din),
......
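Both dispatchers above apply the same eligibility test before choosing the optimized kernel. As a reading aid, the gate condenses to a predicate like the following; the helper name is illustrative, not code from the patch:

```cpp
#include <vector>

// Hypothetical helper equivalent to the support_* checks above.
// paddings order follows the tests below: {top, bottom, left, right}.
static bool use_optimized_f3s1(int flag_act,
                               const std::vector<int>& paddings,
                               const std::vector<int>& strides,
                               int w_in) {
  bool support_act = flag_act <= 1;  // 0: no activation, 1: relu
  bool support_pad = paddings[0] == paddings[1] &&
                     paddings[2] == paddings[3] &&
                     paddings[0] == paddings[2] &&
                     (paddings[0] == 0 || paddings[0] == 1);
  bool support_stride = strides[0] == 1 && strides[1] == 1;
  bool support_width = w_in > 9;  // kernel needs more than 9 input columns
  return support_act && support_pad && support_stride && support_width;
}
```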
@@ -31,7 +31,6 @@ void DepthwiseConv<PRECISION(kFloat), PRECISION(kFloat)>::PrepareForRun() {
   auto paddings = *param.paddings;
   // select dw conv kernel
   if (kw == 3) {
-    // VLOG(5) << "invoke 3x3 dw conv fp32";
     bool pads_less = ((paddings[1] < 2) && (paddings[3] < 2));
     if (pads_less && paddings[0] == paddings[2] &&
         (paddings[0] == 0 || paddings[0] == 1)) {
@@ -54,7 +53,6 @@ void DepthwiseConv<PRECISION(kFloat), PRECISION(kFloat)>::PrepareForRun() {
     kernel_func_name_ = "conv_depthwise_3x3_fp32";
 #endif
   } else if (kw == 5) {
-    // VLOG(5) << "invoke 5x5 dw conv fp32";
     auto strides = param.strides;
     if ((strides[0] == 1 && strides[1] == 1) ||
         (strides[0] == 2 && strides[1] == 2)) {
@@ -104,23 +102,44 @@ void DepthwiseConv<PRECISION(kInt8), PRECISION(kFloat)>::PrepareForRun() {
       w_scale_[i] = scale[i] * in_scale;
     }
   }
+  auto paddings = *param.paddings;
+  auto strides = param.strides;
+  auto x_dims = param.x->dims();
+  int iw = x_dims[3];
+  int ih = x_dims[2];
+  auto act_param = param.activation_param;
+  bool has_act = act_param.has_active;
+  lite_api::ActivationType act_type = act_param.active_type;
+  // no activation and relu activation is supported now
+  bool support_act_type =
+      (has_act == false) ||
+      (has_act == true && act_type == lite_api::ActivationType::kRelu);
+  bool support_pad_type =
+      (paddings[0] == paddings[1]) && (paddings[2] == paddings[3]) &&
+      (paddings[0] == paddings[2]) && (paddings[0] == 0 || paddings[0] == 1);
+  bool support_stride_type = (strides[0] == 1 && strides[1] == 1);
+  bool support_width_type = iw > 9 ? true : false;
   /// select dw conv kernel
   if (kw == 3) {
     // trans weights
-    // VLOG(5) << "invoke 3x3 dw conv int8 kernel fp32 out";
     impl_ = lite::arm::math::conv_depthwise_3x3_int8_fp32;
 #ifdef LITE_WITH_PROFILE
     kernel_func_name_ = "conv_depthwise_3x3_int8_fp32";
 #endif
+    if (!support_act_type || !support_pad_type || !support_stride_type ||
+        !support_width_type) {
       int cround = ROUNDUP(w_dims[0], 8);
       weights_.Resize({cround / 8, 1, kh * kw, 8});
       auto wptr = param.filter->data<int8_t>();
       auto wptr_new = weights_.mutable_data<int8_t>();
       lite::arm::math::conv_trans_weights_numc(wptr, wptr_new, oc, 1, 8, 9);
       flag_trans_weights_ = true;
+    } else {
+      flag_trans_weights_ = false;
+    }
   } else if (kw == 5) {
     // trans weights
-    // VLOG(5) << "invoke 5x5 dw conv int8 kernel fp32 out";
     impl_ = lite::arm::math::conv_depthwise_5x5_int8_fp32;
 #ifdef LITE_WITH_PROFILE
     kernel_func_name_ = "conv_depthwise_5x5_int8_fp32";
@@ -175,23 +194,45 @@ void DepthwiseConv<PRECISION(kInt8), PRECISION(kInt8)>::PrepareForRun() {
     param.activation_param.Relu_clipped_coef =
         param.activation_param.Relu_clipped_coef / param.output_scale;
   }
+  auto paddings = *param.paddings;
+  auto strides = param.strides;
+  auto x_dims = param.x->dims();
+  int iw = x_dims[3];
+  int ih = x_dims[2];
+  auto act_param = param.activation_param;
+  bool has_act = act_param.has_active;
+  lite_api::ActivationType act_type = act_param.active_type;
+  // no activation and relu activation is supported now
+  bool support_act_type =
+      (has_act == false) ||
+      (has_act == true && act_type == lite_api::ActivationType::kRelu);
+  bool support_pad_type =
+      (paddings[0] == paddings[1]) && (paddings[2] == paddings[3]) &&
+      (paddings[0] == paddings[2]) && (paddings[0] == 0 || paddings[0] == 1);
+  bool support_stride_type = (strides[0] == 1 && strides[1] == 1);
+  bool support_width_type = iw > 9 ? true : false;
   /// select dw conv kernel
   if (kw == 3) {
     // trans weights
-    // VLOG(5) << "invoke 3x3 dw conv int8 kernel int8 out";
     impl_ = lite::arm::math::conv_depthwise_3x3_int8_int8;
 #ifdef LITE_WITH_PROFILE
     kernel_func_name_ = "conv_depthwise_3x3_int8_int8";
 #endif
+    if (!support_act_type || !support_pad_type || !support_stride_type ||
+        !support_width_type) {
       int cround = ROUNDUP(w_dims[0], 8);
       weights_.Resize({cround / 8, 1, kh * kw, 8});
       auto wptr = param.filter->data<int8_t>();
       auto wptr_new = weights_.mutable_data<int8_t>();
       lite::arm::math::conv_trans_weights_numc(wptr, wptr_new, oc, 1, 8, 9);
       flag_trans_weights_ = true;
+    } else {
+      flag_trans_weights_ = false;
+    }
   } else if (kw == 5) {
     // trans weights
-    // VLOG(5) << "invoke 5x5 dw conv int8 kernel int8 out";
     impl_ = lite::arm::math::conv_depthwise_5x5_int8_int8;
 #ifdef LITE_WITH_PROFILE
     kernel_func_name_ = "conv_depthwise_5x5_int8_int8";
@@ -283,7 +324,7 @@ void DepthwiseConv<PRECISION(kInt8), PRECISION(kFloat)>::Run() {
   auto w_dims = param.filter->dims();
   auto o_dims = param.output->dims();
-  int iw = x_dims[3];  // nchw
+  int iw = x_dims[3];
   int ih = x_dims[2];
   int ic = x_dims[1];
   int bs = x_dims[0];
@@ -333,7 +374,7 @@ void DepthwiseConv<PRECISION(kInt8), PRECISION(kInt8)>::Run() {
   auto w_dims = param.filter->dims();
   auto o_dims = param.output->dims();
-  int iw = x_dims[3];  // nchw
+  int iw = x_dims[3];
   int ih = x_dims[2];
   int ic = x_dims[1];
   int bs = x_dims[0];
......
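Note that `PrepareForRun` now performs the NUMC weight re-layout only on the fallback path; the optimized kernels consume the original filter layout (`flag_trans_weights_ = false`). As a rough sketch of the kind of 8-channel-blocked repack `conv_trans_weights_numc` performs on an `{oc, 1, 3, 3}` depthwise filter — the exact layout and the zero-padding of tail channels are assumptions, matching only the `weights_.Resize({cround / 8, 1, kh * kw, 8})` shape above:

```cpp
#include <cstdint>

// Sketch: repack [oc][ksize] depthwise int8 weights into [oc/8][ksize][8]
// channel blocks (ksize = kh * kw, e.g. 9 for a 3x3 filter).
void repack_numc8(const int8_t* src, int8_t* dst, int oc, int ksize) {
  int cround = (oc + 7) / 8 * 8;  // ROUNDUP(oc, 8)
  for (int c = 0; c < cround; ++c) {
    for (int k = 0; k < ksize; ++k) {
      // Zero-pad channels past oc so every block holds exactly 8 channels.
      int8_t v = (c < oc) ? src[c * ksize + k] : 0;
      dst[(c / 8) * ksize * 8 + k * 8 + (c % 8)] = v;
    }
  }
}
```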
@@ -125,7 +125,7 @@ void release_param(ConvParam* param) {
 #ifdef LITE_WITH_ARM
 #include "lite/backends/arm/math/funcs.h"
-void test_conv_int8(const std::vector<DDim>& input_dims,
+void test_conv_int8(const DDim& dim_in,
                     const DDim& weight_dim,
                     int group,
                     const std::vector<int>& strides,
@@ -237,24 +237,21 @@ void test_conv_int8(const std::vector<DDim>& input_dims,
   conv_int8_fp32.SetContext(std::move(ctx2));
   /// set param and context
-  for (auto& dim_in : input_dims) {
-    param_int8_out.x->Resize(dim_in);
-    DDim out_tmp_dims = compute_out_dim(dim_in, param_int8_out);
-    if (out_tmp_dims[2] < 1 || out_tmp_dims[3] < 1) {
-      continue;
-    }
-    param_fp32_out.x->Resize(dim_in);
-    param_int8_out.output->Resize(out_tmp_dims);
-    param_fp32_out.output->Resize(out_tmp_dims);
-    break;
-  }
+  param_int8_out.x->Resize(dim_in);
+  DDim out_tmp_dims = compute_out_dim(dim_in, param_int8_out);
+  if (out_tmp_dims[2] < 1 || out_tmp_dims[3] < 1) {
+    return;
+  }
+  param_fp32_out.x->Resize(dim_in);
+  param_int8_out.output->Resize(out_tmp_dims);
+  param_fp32_out.output->Resize(out_tmp_dims);
   conv_int8_int8.SetParam(param_int8_out);
   conv_int8_fp32.SetParam(param_fp32_out);
   /// prepare for run
   conv_int8_int8.PrepareForRun();
   conv_int8_fp32.PrepareForRun();
-  for (auto& dim_in : input_dims) {
   CHECK_EQ(weight_dim[1] * group, dim_in[1])
       << "input channel must equal to weights channel";
   DDim dim_out = compute_out_dim(dim_in, param_int8_out);
@@ -333,7 +330,7 @@ void test_conv_int8(const std::vector<DDim>& input_dims,
                      weight_dim[3] / group;
   /// warm up
   for (int i = 0; i < FLAGS_warmup; ++i) {
-    conv_int8_int8.Launch();
+    conv_int8_fp32.Launch();
   }
   /// compute fp32 output
   Timer t0;
@@ -343,13 +340,13 @@ void test_conv_int8(const std::vector<DDim>& input_dims,
     t0.Stop();
   }
   LOG(INFO) << "int8 conv, fp32 output: output shape" << dim_out
-            << ",running time, avg: " << t0.LapTimes().Avg()
-            << ", min time: " << t0.LapTimes().Min()
+            << ",running time, avg: " << t0.LapTimes().Avg() << " ms"
+            << ", min time: " << t0.LapTimes().Min() << " ms"
             << ", total GOPS: " << 1e-9 * gops
             << " GOPS, avg GOPs: " << 1e-6 * gops / t0.LapTimes().Avg()
             << " GOPs, max GOPs: " << 1e-6 * gops / t0.LapTimes().Min();
-  /// compute int8 output
+  // compute int8 output
   t0.Reset();
   for (int i = 0; i < FLAGS_repeats; ++i) {
     t0.Start();
@@ -386,9 +383,8 @@ void test_conv_int8(const std::vector<DDim>& input_dims,
       release_param(&param_fp32_out);
       LOG(FATAL) << "test int8 conv, fp32 out: input: " << dim_in
                  << ", output: " << dim_out
-                 << ", weight dim: " << weight_dim
-                 << ", pad: " << pads[0] << ", " << pads[1] << ", "
-                 << pads[2] << ", " << pads[3]
+                 << ", weight dim: " << weight_dim << ", pad: " << pads[0]
+                 << ", " << pads[1] << ", " << pads[2] << ", " << pads[3]
                  << ", stride: " << strides[0] << ", " << strides[1]
                  << ", dila_: " << dilas[0] << ", " << dilas[1]
                  << ", group: " << group
@@ -398,7 +394,7 @@ void test_conv_int8(const std::vector<DDim>& input_dims,
       }
     }
   }
-  /// compare result int8 output
+  // compare result int8 output
   if (FLAGS_check_result) {
     double max_ratio = 0;
     double max_diff = 0;
@@ -431,8 +427,7 @@ void test_conv_int8(const std::vector<DDim>& input_dims,
         count += 1;
       }
     }
-    check =
-        check &&
+    check = check &&
         count < std::max(10, static_cast<int>(0.01 * tdiff.numel()));
     if (!check) {
       LOG(WARNING) << "int8 basic result";
@@ -445,9 +440,8 @@ void test_conv_int8(const std::vector<DDim>& input_dims,
       release_param(&param_fp32_out);
       LOG(FATAL) << "test int8 conv, int8 out: input: " << dim_in
                  << ", output: " << dim_out
-                 << ", weight dim: " << weight_dim
-                 << ", pad: " << pads[0] << ", " << pads[1] << ", "
-                 << pads[2] << ", " << pads[3]
+                 << ", weight dim: " << weight_dim << ", pad: " << pads[0]
+                 << ", " << pads[1] << ", " << pads[2] << ", " << pads[3]
                  << ", stride: " << strides[0] << ", " << strides[1]
                  << ", dila_: " << dilas[0] << ", " << dilas[1]
                  << ", bias: " << (flag_bias ? "true" : "false")
@@ -466,12 +460,11 @@ void test_conv_int8(const std::vector<DDim>& input_dims,
                 << ", power_mode: " << cls << " successed!!\n";
       }
     }
-  }
   release_param(&param_int8_out);
   release_param(&param_fp32_out);
 }
 #else
-void test_conv_int8(const std::vector<DDim>& input_dims,
+void test_conv_int8(const DDim& dims_in,
                     const DDim& weight_dim,
                     int group,
                     const std::vector<int>& strides,
@@ -493,13 +486,10 @@ TEST(TestConv3x3DWInt8, test_conv3x3_depthwise) {
   for (auto& flag_bias : {false, true}) {
     for (auto& flag_act : {0, 1, 2, 4}) {
       for (auto& c : {1, 3, 5, 8, 16, 32}) {
-        std::vector<DDim> dims;
         DDim weights_dim({c, 1, 3, 3});
         for (auto& batch : {1, 2}) {
           for (auto& h : {1, 3, 15, 33}) {
-            dims.push_back(DDim({batch, c, h, h}));
-          }
-        }
+            DDim dims({batch, c, h, h});
             test_conv_int8(dims,
                            weights_dim,
                            c,
@@ -508,7 +498,7 @@ TEST(TestConv3x3DWInt8, test_conv3x3_depthwise) {
                            {1, 1},
                            flag_bias,
                            flag_act,
-                           {4},
+                           {FLAGS_threads},
                            {FLAGS_power_mode},
                            FLAGS_clipped_coef,
                            FLAGS_leakey_relu_alpha);
@@ -518,6 +508,8 @@ TEST(TestConv3x3DWInt8, test_conv3x3_depthwise) {
           }
         }
       }
+    }
+  }
 }
 #endif  /// 3x3dw
@@ -529,13 +521,10 @@ TEST(TestConv5x5DWInt8, test_conv5x5_depthwise) {
   for (auto& flag_bias : {false, true}) {
     for (auto& flag_act : {0, 1, 2, 4}) {
      for (auto& c : {1, 5, 15, 33}) {
-        std::vector<DDim> dims;
         DDim weights_dim({c, 1, 5, 5});
         for (auto& batch : {1, 2}) {
           for (auto& h : {1, 3, 15, 33, 112, 224}) {
-            dims.push_back(DDim({batch, c, h, h}));
-          }
-        }
+            DDim dims({batch, c, h, h});
             test_conv_int8(dims,
                            weights_dim,
                            c,
@@ -554,6 +543,8 @@ TEST(TestConv5x5DWInt8, test_conv5x5_depthwise) {
           }
         }
       }
+    }
+  }
 }
 #endif  /// 5x5dw
@@ -565,16 +556,13 @@ TEST(TestConv1x1s1Int8, test_conv1x1s1) {
   for (auto& g : {1, 2}) {
     for (auto& flag_bias : {false, true}) {
       for (auto& flag_act : {0, 1, 2, 4}) {
-        std::vector<DDim> dims;
         if (cin % g != 0 || cout % g != 0) {
           continue;
         }
         DDim weights_dim({cout, cin / g, 1, 1});
         for (auto& batch : {1, 2}) {
           for (auto& h : {1, 9, 16, 33}) {
-            dims.push_back(DDim({batch, cin, h, h}));
-          }
-        }
+            DDim dims({batch, cin, h, h});
             test_conv_int8(dims,
                            weights_dim,
                            g,
@@ -593,6 +581,8 @@ TEST(TestConv1x1s1Int8, test_conv1x1s1) {
          }
        }
      }
+    }
+  }
 }
 #endif  /// conv1x1s1
@@ -606,18 +596,16 @@ TEST(TestConv3x3s1Int8, test_conv_3x3s1) {
   for (auto& pad_left : {1, 2}) {
     for (auto& pad_right : {1, 2}) {
       for (auto& flag_bias : {false, true}) {
-        for (auto& flag_act : {0, 1, 2, 4}) {
-          std::vector<DDim> dims;
+        for (auto& flag_act : {0, 1}) {
           DDim weights_dim({cout, cin, 3, 3});
           for (auto& batch : {1, 2}) {
             for (auto& h : {1, 7, 17, 33}) {
-              dims.push_back(DDim({batch, cin, h, h}));
-            }
-          }
+              DDim dims({batch, cin, h, h});
           if (cin == 1 && cout == 1) {
             continue;
           }
-          test_conv_int8(dims,
+          test_conv_int8(
+              dims,
               weights_dim,
               1,
               {1, 1},
@@ -638,6 +626,8 @@ TEST(TestConv3x3s1Int8, test_conv_3x3s1) {
           }
         }
       }
+    }
+  }
 }
 #endif  /// conv3x3s1
@@ -652,14 +642,12 @@ TEST(TestConv3x3s2Int8, test_conv_3x3s2) {
     for (auto& pad_right : {1, 2}) {
       for (auto& flag_bias : {false, true}) {
         for (auto& flag_act : {0, 1, 2, 4}) {
-          std::vector<DDim> dims;
           DDim weights_dim({cout, cin, 3, 3});
           for (auto& batch : {1, 2}) {
             for (auto& h : {1, 7, 19, 33}) {
-              dims.push_back(DDim({batch, cin, h, h}));
-            }
-          }
-          test_conv_int8(dims,
+              DDim dims({batch, cin, h, h});
+              test_conv_int8(
+                  dims,
                   weights_dim,
                   1,
                   {2, 2},
@@ -680,6 +668,8 @@ TEST(TestConv3x3s2Int8, test_conv_3x3s2) {
           }
         }
       }
+    }
+  }
 }
 #endif  /// conv3x3s2
@@ -702,19 +692,18 @@ TEST(TestConvRandInt8, test_conv_rand) {
                     if (cin % g != 0 || cout % g != 0) {
                       break;
                     }
-                    std::vector<DDim> dims;
                     DDim weights_dim({cout, cin / g, kh, kw});
                     for (auto& batch : {1, 2}) {
                       for (auto& h : {1, 3, 5, 19}) {
-                        dims.push_back(DDim({batch, cin, h, h}));
-                      }
-                    }
-                    test_conv_int8(
-                        dims,
+                        DDim dims({batch, cin, h, h});
+                        test_conv_int8(dims,
                                        weights_dim,
                                        g,
                                        {stride, stride},
-                        {pad_top, pad_bottom, pad_left, pad_right},
+                                       {pad_top,
+                                        pad_bottom,
+                                        pad_left,
+                                        pad_right},
                                        {dila, dila},
                                        flag_bias,
                                        flag_act,
@@ -736,6 +725,8 @@ TEST(TestConvRandInt8, test_conv_rand) {
           }
         }
       }
+    }
+  }
 }
 #endif  /// random param conv
......
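With the signature change, `test_conv_int8` validates one input shape per invocation, so `PrepareForRun` sees the real input dims before `Run`, and the shape loops move into the `TEST` bodies. A call now looks like this; the concrete values are illustrative, with the argument order taken from the calls above:

```cpp
// One shape per call; PrepareForRun runs after the actual input dims are set.
test_conv_int8(DDim({1, 32, 33, 33}),  // input, NCHW
               DDim({32, 1, 3, 3}),    // depthwise 3x3 weights
               /*group=*/32,
               /*strides=*/{1, 1},
               /*pads=*/{1, 1, 1, 1},  // top, bottom, left, right
               /*dilations=*/{1, 1},
               /*flag_bias=*/true,
               /*flag_act=*/1,         // relu
               {FLAGS_threads},
               {FLAGS_power_mode},
               FLAGS_clipped_coef,
               FLAGS_leakey_relu_alpha);
```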