提交 8b36f2aa 编写于 作者: L lijianshe02 提交者: GitHub

fix conv2d kernel bugs that results in precision diff test=develop (#2420)


* fix conv kernel bugs and open mobilenet ci test=develop
上级 553f314d
......@@ -65,7 +65,7 @@ class Conv2dCompute : public KernelLite<TARGET(kX86), PRECISION(kFloat)> {
col_shape_vec[j + 1 + data_dim] = output_shape_vec[j + 2];
}
lite::DDim col_shape(col_shape_vec);
lite::DDim col_matrix_shape = col_shape.Flatten2D(data_dim + 1);
lite::DDim col_matrix_shape = col_shape.Flatten2D(data_dim);
bool is_expand = IsExpand(
filter_shape_vec, param.strides, param.paddings, param.dilations);
lite::Tensor col;
......@@ -95,19 +95,14 @@ class Conv2dCompute : public KernelLite<TARGET(kX86), PRECISION(kFloat)> {
auto blas =
paddle::lite::x86::math::GetBlas<lite::TargetType::kX86, T>(context);
for (int i = 0; i < batch_size; i++) {
lite::Tensor in_batch;
lite::Tensor tmp_in_batch = param.x->Slice<T>(i, i + 1);
tmp_in_batch.Resize(input_shape);
in_batch.ShareDataWith(tmp_in_batch);
lite::Tensor out_batch;
lite::Tensor tmp_out_batch = param.output->Slice<T>(i, i + 1);
tmp_out_batch.Resize(output_matrix_shape);
out_batch.ShareDataWith(tmp_out_batch);
lite::Tensor in_batch = param.x->Slice<T>(i, i + 1);
in_batch.Resize(input_shape);
lite::Tensor out_batch = param.output->Slice<T>(i, i + 1);
out_batch.Resize(output_matrix_shape);
for (int g = 0; g < param.groups; g++) {
lite::Tensor in_slice;
in_slice.ShareDataWith(
lite::Tensor in_slice =
in_batch.Slice<T>(static_cast<int64_t>(g * in_step),
static_cast<int64_t>((g + 1) * in_step)));
static_cast<int64_t>((g + 1) * in_step));
if (!is_expand) {
col.ShareDataWith(in_slice);
......@@ -136,13 +131,13 @@ class Conv2dCompute : public KernelLite<TARGET(kX86), PRECISION(kFloat)> {
// gemm
lite::Tensor out_slice;
out_slice.ShareDataWith(
out_slice =
out_batch.Slice<T>(static_cast<int64_t>(g * out_step),
static_cast<int64_t>((g + 1) * out_step)));
static_cast<int64_t>((g + 1) * out_step));
lite::Tensor filter_slice;
filter_slice.ShareDataWith(
filter_slice =
filter.Slice<T>(static_cast<int64_t>(g * out_step),
static_cast<int64_t>((g + 1) * out_step)));
static_cast<int64_t>((g + 1) * out_step));
blas.MatMul(filter_slice,
false,
col_matrix,
......
......@@ -195,7 +195,6 @@ function test_server {
# Due to the missing of x86 kernels, we skip the following tests temporarily.
# TODO(xxx) clear the skip list latter
local skip_list=("test_paddle_api" "test_cxx_api"
"test_mobilenetv1_lite_x86" "test_mobilenetv2_lite_x86"
"test_light_api"
"test_apis" "test_model_bin"
)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册