diff --git a/lite/kernels/x86/conv_compute.h b/lite/kernels/x86/conv_compute.h index 063e66a1c19dc2250e1834a4d9dd2a22b0e5f6e4..e9f403059f90cf6635bc22db3e6890b86cbe85f6 100644 --- a/lite/kernels/x86/conv_compute.h +++ b/lite/kernels/x86/conv_compute.h @@ -65,7 +65,7 @@ class Conv2dCompute : public KernelLite { col_shape_vec[j + 1 + data_dim] = output_shape_vec[j + 2]; } lite::DDim col_shape(col_shape_vec); - lite::DDim col_matrix_shape = col_shape.Flatten2D(data_dim); + lite::DDim col_matrix_shape = col_shape.Flatten2D(data_dim + 1); bool is_expand = IsExpand( filter_shape_vec, param.strides, *param.paddings, *param.dilations); lite::Tensor col;