未验证 提交 bfd2a950 编写于 作者: L lijianshe02 提交者: GitHub

fix conv2d kernel bugs that result in precision diff test=develop (#2420)


* fix conv kernel bugs and open mobilenet ci test=develop
上级 c1837d76
...@@ -65,7 +65,7 @@ class Conv2dCompute : public KernelLite<TARGET(kX86), PRECISION(kFloat)> { ...@@ -65,7 +65,7 @@ class Conv2dCompute : public KernelLite<TARGET(kX86), PRECISION(kFloat)> {
col_shape_vec[j + 1 + data_dim] = output_shape_vec[j + 2]; col_shape_vec[j + 1 + data_dim] = output_shape_vec[j + 2];
} }
lite::DDim col_shape(col_shape_vec); lite::DDim col_shape(col_shape_vec);
lite::DDim col_matrix_shape = col_shape.Flatten2D(data_dim + 1); lite::DDim col_matrix_shape = col_shape.Flatten2D(data_dim);
bool is_expand = IsExpand( bool is_expand = IsExpand(
filter_shape_vec, param.strides, param.paddings, param.dilations); filter_shape_vec, param.strides, param.paddings, param.dilations);
lite::Tensor col; lite::Tensor col;
...@@ -95,19 +95,14 @@ class Conv2dCompute : public KernelLite<TARGET(kX86), PRECISION(kFloat)> { ...@@ -95,19 +95,14 @@ class Conv2dCompute : public KernelLite<TARGET(kX86), PRECISION(kFloat)> {
auto blas = auto blas =
paddle::lite::x86::math::GetBlas<lite::TargetType::kX86, T>(context); paddle::lite::x86::math::GetBlas<lite::TargetType::kX86, T>(context);
for (int i = 0; i < batch_size; i++) { for (int i = 0; i < batch_size; i++) {
lite::Tensor in_batch; lite::Tensor in_batch = param.x->Slice<T>(i, i + 1);
lite::Tensor tmp_in_batch = param.x->Slice<T>(i, i + 1); in_batch.Resize(input_shape);
tmp_in_batch.Resize(input_shape); lite::Tensor out_batch = param.output->Slice<T>(i, i + 1);
in_batch.ShareDataWith(tmp_in_batch); out_batch.Resize(output_matrix_shape);
lite::Tensor out_batch;
lite::Tensor tmp_out_batch = param.output->Slice<T>(i, i + 1);
tmp_out_batch.Resize(output_matrix_shape);
out_batch.ShareDataWith(tmp_out_batch);
for (int g = 0; g < param.groups; g++) { for (int g = 0; g < param.groups; g++) {
lite::Tensor in_slice; lite::Tensor in_slice =
in_slice.ShareDataWith(
in_batch.Slice<T>(static_cast<int64_t>(g * in_step), in_batch.Slice<T>(static_cast<int64_t>(g * in_step),
static_cast<int64_t>((g + 1) * in_step))); static_cast<int64_t>((g + 1) * in_step));
if (!is_expand) { if (!is_expand) {
col.ShareDataWith(in_slice); col.ShareDataWith(in_slice);
...@@ -136,13 +131,13 @@ class Conv2dCompute : public KernelLite<TARGET(kX86), PRECISION(kFloat)> { ...@@ -136,13 +131,13 @@ class Conv2dCompute : public KernelLite<TARGET(kX86), PRECISION(kFloat)> {
// gemm // gemm
lite::Tensor out_slice; lite::Tensor out_slice;
out_slice.ShareDataWith( out_slice =
out_batch.Slice<T>(static_cast<int64_t>(g * out_step), out_batch.Slice<T>(static_cast<int64_t>(g * out_step),
static_cast<int64_t>((g + 1) * out_step))); static_cast<int64_t>((g + 1) * out_step));
lite::Tensor filter_slice; lite::Tensor filter_slice;
filter_slice.ShareDataWith( filter_slice =
filter.Slice<T>(static_cast<int64_t>(g * out_step), filter.Slice<T>(static_cast<int64_t>(g * out_step),
static_cast<int64_t>((g + 1) * out_step))); static_cast<int64_t>((g + 1) * out_step));
blas.MatMul(filter_slice, blas.MatMul(filter_slice,
false, false,
col_matrix, col_matrix,
......
...@@ -195,7 +195,6 @@ function test_server { ...@@ -195,7 +195,6 @@ function test_server {
# Due to the missing of x86 kernels, we skip the following tests temporarily. # Due to the missing of x86 kernels, we skip the following tests temporarily.
# TODO(xxx) clear the skip list latter # TODO(xxx) clear the skip list latter
local skip_list=("test_paddle_api" "test_cxx_api" local skip_list=("test_paddle_api" "test_cxx_api"
"test_mobilenetv1_lite_x86" "test_mobilenetv2_lite_x86"
"test_light_api" "test_light_api"
"test_apis" "test_model_bin" "test_apis" "test_model_bin"
) )
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册