提交 f1840b53 编写于 作者: H hjchen2

Fix entering the winograd branch for depthwise convolution, fix ShareDataWith bug

上级 44689fc1
...@@ -57,6 +57,8 @@ bool ConvKernel<CPU, float>::Init(ConvParam<CPU> *param) { ...@@ -57,6 +57,8 @@ bool ConvKernel<CPU, float>::Init(ConvParam<CPU> *param) {
param->Strides()[0] == 2 && param->Paddings()[0] == 1 && param->Strides()[0] == 2 && param->Paddings()[0] == 1 &&
param->Paddings()[0] == param->Paddings()[1]) { param->Paddings()[0] == param->Paddings()[1]) {
param->ExecMode() = ConvParam<CPU>::EXEC_DEPTHWISE3x3S2P1_FLOAT; param->ExecMode() = ConvParam<CPU>::EXEC_DEPTHWISE3x3S2P1_FLOAT;
} else if (depth3x3) {
param->ExecMode() = ConvParam<CPU>::EXEC_DEPTHWISE3x3_FLOAT;
#ifndef __aarch64__ #ifndef __aarch64__
} else if (depth5x5 && param->Strides()[0] == param->Strides()[1] && } else if (depth5x5 && param->Strides()[0] == param->Strides()[1] &&
param->Strides()[0] == 1) { param->Strides()[0] == 1) {
...@@ -106,6 +108,10 @@ void ConvKernel<CPU, float>::Compute(const ConvParam<CPU> &param) { ...@@ -106,6 +108,10 @@ void ConvKernel<CPU, float>::Compute(const ConvParam<CPU> &param) {
math::DepthwiseConv3x3s2p0(param.Input(), param.Filter(), param.Output(), math::DepthwiseConv3x3s2p0(param.Input(), param.Filter(), param.Output(),
nullptr, false, false); nullptr, false, false);
break; break;
case ConvParam<CPU>::EXEC_DEPTHWISE3x3_FLOAT:
math::DepthwiseConv3x3(param.Input(), param.Strides(), param.Paddings(),
param.Filter(), nullptr, param.Output(), false);
break;
#ifndef __aarch64__ #ifndef __aarch64__
case ConvParam<CPU>::EXEC_DEPTHWISE5x5_FLOAT: case ConvParam<CPU>::EXEC_DEPTHWISE5x5_FLOAT:
DepthwiseConv5x5<float, float>(param); DepthwiseConv5x5<float, float>(param);
......
...@@ -90,8 +90,8 @@ inline void GemmConv(const ConvParam<CPU> &param) { ...@@ -90,8 +90,8 @@ inline void GemmConv(const ConvParam<CPU> &param) {
Tensor in_slice = in_batch.Slice(g * in_step, (g + 1) * in_step); Tensor in_slice = in_batch.Slice(g * in_step, (g + 1) * in_step);
if (!is_expand) { if (!is_expand) {
col.ShareDataWith(in_slice); // col_matrix.ShareDataWith(in_slice);
col_matrix.ShareDataWith(col); col_matrix = in_slice;
col_matrix.Resize(col_matrix_shape); col_matrix.Resize(col_matrix_shape);
} else if (data_dim == 2U) { } else if (data_dim == 2U) {
// im2col // im2col
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册