未验证 提交 d5e7e73e 编写于 作者: H HappyAngel 提交者: GitHub

delete ch_four control in conv_3x3_dw. test=develop (#4435)

上级 3e7359e4
...@@ -645,7 +645,6 @@ void conv_3x3s1_depthwise_fp32_bias(const float* i_data, ...@@ -645,7 +645,6 @@ void conv_3x3s1_depthwise_fp32_bias(const float* i_data,
bool flag_bias = param.bias != nullptr; bool flag_bias = param.bias != nullptr;
/// get workspace /// get workspace
LOG(INFO) << "conv_3x3s1_depthwise_fp32_bias: ";
float* ptr_zero = ctx->workspace_data<float>(); float* ptr_zero = ctx->workspace_data<float>();
memset(ptr_zero, 0, sizeof(float) * win_round); memset(ptr_zero, 0, sizeof(float) * win_round);
float* ptr_write = ptr_zero + win_round; float* ptr_write = ptr_zero + win_round;
......
...@@ -620,10 +620,8 @@ void conv_depthwise_3x3_fp32(const void* din, ...@@ -620,10 +620,8 @@ void conv_depthwise_3x3_fp32(const void* din,
int pad = pad_w; int pad = pad_w;
bool flag_bias = param.bias != nullptr; bool flag_bias = param.bias != nullptr;
bool pads_less = ((paddings[1] < 2) && (paddings[3] < 2)); bool pads_less = ((paddings[1] < 2) && (paddings[3] < 2));
bool ch_four = ch_in <= 4 * w_in;
if (stride == 1) { if (stride == 1) {
if (ch_four && pads_less && (pad_h == pad_w) && if (pads_less && (pad_h == pad_w) && (pad < 2)) { // support pad = [0, 1]
(pad < 2)) { // support pad = [0, 1]
conv_depthwise_3x3s1_fp32(reinterpret_cast<const float*>(din), conv_depthwise_3x3s1_fp32(reinterpret_cast<const float*>(din),
reinterpret_cast<float*>(dout), reinterpret_cast<float*>(dout),
num, num,
...@@ -656,8 +654,7 @@ void conv_depthwise_3x3_fp32(const void* din, ...@@ -656,8 +654,7 @@ void conv_depthwise_3x3_fp32(const void* din,
ctx); ctx);
} }
} else if (stride == 2) { } else if (stride == 2) {
if (ch_four && pads_less && pad_h == pad_w && if (pads_less && pad_h == pad_w && (pad < 2)) { // support pad = [0, 1]
(pad < 2)) { // support pad = [0, 1]
conv_depthwise_3x3s2_fp32(reinterpret_cast<const float*>(din), conv_depthwise_3x3s2_fp32(reinterpret_cast<const float*>(din),
reinterpret_cast<float*>(dout), reinterpret_cast<float*>(dout),
num, num,
......
...@@ -32,11 +32,10 @@ void DepthwiseConv<PRECISION(kFloat), PRECISION(kFloat)>::PrepareForRun() { ...@@ -32,11 +32,10 @@ void DepthwiseConv<PRECISION(kFloat), PRECISION(kFloat)>::PrepareForRun() {
auto hin = param.x->dims()[2]; auto hin = param.x->dims()[2];
auto win = param.x->dims()[3]; auto win = param.x->dims()[3];
auto paddings = *param.paddings; auto paddings = *param.paddings;
bool ch_four = channel <= 4 * win;
// select dw conv kernel // select dw conv kernel
if (kw == 3) { if (kw == 3) {
bool pads_less = ((paddings[1] < 2) && (paddings[3] < 2)); bool pads_less = ((paddings[1] < 2) && (paddings[3] < 2));
if (ch_four && pads_less && paddings[0] == paddings[2] && if (pads_less && paddings[0] == paddings[2] &&
(paddings[0] == 0 || paddings[0] == 1)) { (paddings[0] == 0 || paddings[0] == 1)) {
flag_trans_weights_ = false; flag_trans_weights_ = false;
} else { } else {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册