提交 8c89ef6c 编写于 作者: Y Yanzhan Yang 提交者: GitHub

prune faster depthwise (#1752)

* prune faster depthwise

* fix style
上级 5b7993e1
......@@ -72,11 +72,13 @@ bool ConvAddBNReluKernel<CPU, float>::Init(
strides.size() == 2 && strides[0] == strides[1]) {
int pad = paddings[0];
int stride = strides[0];
const int hin = param->Input()->dims()[2];
if (pad == 0 && hin > 2) {
could_use_faster_depthwise_conv_ = true;
} else if (pad == 1) {
could_use_faster_depthwise_conv_ = true;
const int win = param->Input()->dims()[3];
if (pad == 1) {
if (stride == 1) {
could_use_faster_depthwise_conv_ = true;
} else if (stride == 2 && win > 7) {
could_use_faster_depthwise_conv_ = true;
}
}
}
break;
......
......@@ -292,11 +292,7 @@ void FasterDepthwiseConv3x3_bias_relu(const ConvParam<CPU> &param,
const int hout = output->dims()[2];
const int wout = output->dims()[3];
bool flag_bias = bias != nullptr;
if (pad == 0 && hin > 2) {
math::depthwise::conv_depthwise_3x3p0(din, dout, num, chout, hout, wout,
chin, hin, win, weights, bias, stride,
flag_bias, flag_relu);
} else if (pad == 1) {
if (pad == 1) {
math::depthwise::conv_depthwise_3x3p1(din, dout, num, chout, hout, wout,
chin, hin, win, weights, bias, stride,
flag_bias, flag_relu);
......
......@@ -23,11 +23,6 @@ namespace operators {
namespace math {
namespace depthwise {
void conv_depthwise_3x3p0(const float* din, float* dout, int num, int ch_out,
int h_out, int w_out, int ch_in, int h_in, int w_in,
const float* weights, const float* bias, int stride,
bool flag_bias, bool flag_relu);
void conv_depthwise_3x3p1(const float* din, float* dout, int num, int ch_out,
int h_out, int w_out, int ch_in, int h_in, int w_in,
const float* weights, const float* bias, int stride,
......
......@@ -295,7 +295,8 @@ def check_mobile_results(args, fuse, mem_opt):
args = "{} {} {}".format("1" if fuse else "0", "1" if mem_opt else "0", args)
res = sh("adb shell \"cd {} && export LD_LIBRARY_PATH=. && ./test-net {}\"".format(mobile_exec_root, args))
lines = res.split("\n")
# print(lines)
# for line in lines:
# print(line)
for line in lines:
if line.startswith("auto-test-debug"):
print(line)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册