未验证 提交 94947acf 作者: Yanzhan Yang 提交者: GitHub

prune faster depthwise (#1752)

* prune faster depthwise

* fix style
上级 74a309cb
...@@ -72,13 +72,15 @@ bool ConvAddBNReluKernel<CPU, float>::Init( ...@@ -72,13 +72,15 @@ bool ConvAddBNReluKernel<CPU, float>::Init(
strides.size() == 2 && strides[0] == strides[1]) { strides.size() == 2 && strides[0] == strides[1]) {
int pad = paddings[0]; int pad = paddings[0];
int stride = strides[0]; int stride = strides[0];
const int hin = param->Input()->dims()[2]; const int win = param->Input()->dims()[3];
if (pad == 0 && hin > 2) { if (pad == 1) {
if (stride == 1) {
could_use_faster_depthwise_conv_ = true; could_use_faster_depthwise_conv_ = true;
} else if (pad == 1) { } else if (stride == 2 && win > 7) {
could_use_faster_depthwise_conv_ = true; could_use_faster_depthwise_conv_ = true;
} }
} }
}
break; break;
} }
......
...@@ -292,11 +292,7 @@ void FasterDepthwiseConv3x3_bias_relu(const ConvParam<CPU> &param, ...@@ -292,11 +292,7 @@ void FasterDepthwiseConv3x3_bias_relu(const ConvParam<CPU> &param,
const int hout = output->dims()[2]; const int hout = output->dims()[2];
const int wout = output->dims()[3]; const int wout = output->dims()[3];
bool flag_bias = bias != nullptr; bool flag_bias = bias != nullptr;
if (pad == 0 && hin > 2) { if (pad == 1) {
math::depthwise::conv_depthwise_3x3p0(din, dout, num, chout, hout, wout,
chin, hin, win, weights, bias, stride,
flag_bias, flag_relu);
} else if (pad == 1) {
math::depthwise::conv_depthwise_3x3p1(din, dout, num, chout, hout, wout, math::depthwise::conv_depthwise_3x3p1(din, dout, num, chout, hout, wout,
chin, hin, win, weights, bias, stride, chin, hin, win, weights, bias, stride,
flag_bias, flag_relu); flag_bias, flag_relu);
......
...@@ -23,11 +23,6 @@ namespace operators { ...@@ -23,11 +23,6 @@ namespace operators {
namespace math { namespace math {
namespace depthwise { namespace depthwise {
void conv_depthwise_3x3p0(const float* din, float* dout, int num, int ch_out,
int h_out, int w_out, int ch_in, int h_in, int w_in,
const float* weights, const float* bias, int stride,
bool flag_bias, bool flag_relu);
void conv_depthwise_3x3p1(const float* din, float* dout, int num, int ch_out, void conv_depthwise_3x3p1(const float* din, float* dout, int num, int ch_out,
int h_out, int w_out, int ch_in, int h_in, int w_in, int h_out, int w_out, int ch_in, int h_in, int w_in,
const float* weights, const float* bias, int stride, const float* weights, const float* bias, int stride,
......
...@@ -295,7 +295,8 @@ def check_mobile_results(args, fuse, mem_opt): ...@@ -295,7 +295,8 @@ def check_mobile_results(args, fuse, mem_opt):
args = "{} {} {}".format("1" if fuse else "0", "1" if mem_opt else "0", args) args = "{} {} {}".format("1" if fuse else "0", "1" if mem_opt else "0", args)
res = sh("adb shell \"cd {} && export LD_LIBRARY_PATH=. && ./test-net {}\"".format(mobile_exec_root, args)) res = sh("adb shell \"cd {} && export LD_LIBRARY_PATH=. && ./test-net {}\"".format(mobile_exec_root, args))
lines = res.split("\n") lines = res.split("\n")
# print(lines) # for line in lines:
# print(line)
for line in lines: for line in lines:
if line.startswith("auto-test-debug"): if line.startswith("auto-test-debug"):
print(line) print(line)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请注册