diff --git a/lite/tests/kernels/cast_compute_test.cc b/lite/tests/kernels/cast_compute_test.cc index e0edb3c54e38b2e4387a5886ae6f74facd5752ba..a80bc0d0720f6341a62239ac263b351b46cf3fec 100644 --- a/lite/tests/kernels/cast_compute_test.cc +++ b/lite/tests/kernels/cast_compute_test.cc @@ -130,7 +130,6 @@ void TestCast(Place place, float abs_error, int in_dtype, int out_dtype) { } TEST(Cast, precision) { - LOG(INFO) << "test cast op"; Place place; float abs_error = 2e-5; #if defined(LITE_WITH_ARM) @@ -150,7 +149,7 @@ TEST(Cast, precision) { TestCast(place, abs_error, 20, 5); #endif TestCast(place, abs_error, 2, 5); -#if defined(LITE_WITH_XPU) || defined(LITE_WITH_HUAWEI_ASCEND_NPU) +#if defined(LITE_WITH_HUAWEI_ASCEND_NPU) TestCast(place, abs_error, 3, 5); TestCast(place, abs_error, 5, 3); #endif diff --git a/lite/tests/kernels/gather_compute_test.cc b/lite/tests/kernels/gather_compute_test.cc index 3f93627c03e3edfce3fc2511ae320571f68b8598..11165d335f05851bb3549da6fdf4296d65860257 100644 --- a/lite/tests/kernels/gather_compute_test.cc +++ b/lite/tests/kernels/gather_compute_test.cc @@ -21,6 +21,7 @@ namespace paddle { namespace lite { +template class GatherComputeTest : public arena::TestCase { protected: // common attributes for this op. @@ -53,9 +54,9 @@ class GatherComputeTest : public arena::TestCase { out_dims[0] = batch_size; out->Resize(out_dims); - auto x_data = x->data(); - auto index_data = index->data(); - auto out_data = out->mutable_data(); + auto x_data = x->template data(); + auto index_data = index->template data(); + auto out_data = out->template mutable_data(); auto slice_num = x_dims[0]; auto slice_size = x_dims.Slice(1, x_dims.size()).production(); @@ -66,7 +67,7 @@ class GatherComputeTest : public arena::TestCase { CHECK_GE(index, 0) << "gather ids[i] expected >= 0 but got " << index; memcpy(out_data + i * slice_size, x_data + index * slice_size, - slice_size * sizeof(int64_t)); + slice_size * sizeof(T)); } } @@ -78,11 +79,12 @@ class GatherComputeTest : public arena::TestCase { } void PrepareData() override { - std::vector x(x_dims_.production()); - fill_data_rand(x.data(), int64_t(-1), int64_t(1), x_dims_.production()); + std::vector x(x_dims_.production()); + fill_data_rand( + x.data(), static_cast(-1), static_cast(1), x_dims_.production()); - std::vector index(index_dims_.production()); - fill_data_rand( + std::vector index(index_dims_.production()); + fill_data_rand( index.data(), 0, x_dims_[0] - 1, index_dims_.production()); SetCommonTensor(x_, x_dims_, x.data()); @@ -90,8 +92,20 @@ class GatherComputeTest : public arena::TestCase { } }; +template +void TestGather(const std::vector& x_dims, + const std::vector& index_dims, + Place place, + float abs_error = 1e-5, + const std::string& alias = "def") { + std::unique_ptr tester(new GatherComputeTest( + place, alias, DDim(x_dims), DDim(index_dims))); + arena::Arena arena(std::move(tester), place, abs_error); + arena.TestPrecision(); +} + TEST(Gather, precision) { - float abs_error = 2e-5; + float abs_error = 1e-5; Place place; #if defined(LITE_WITH_NPU) place = TARGET(kNPU); @@ -110,10 +124,14 @@ TEST(Gather, precision) { for (auto x_dims : std::vector>{{5, 2, 3, 4}, {8, 3, 5}, {12, 3}}) { for (auto index_dims : std::vector>{{3}, {7}, {10}}) { - std::unique_ptr tester(new GatherComputeTest( - place, "int64", DDim(x_dims), DDim(index_dims))); - arena::Arena arena(std::move(tester), place, abs_error); - arena.TestPrecision(); +#if defined(LITE_WITH_XPU) || defined(LITE_WITH_NPU) + TestGather(x_dims, index_dims, place, abs_error, "def"); +#else + TestGather(x_dims, index_dims, place, abs_error, "int64"); + TestGather( + x_dims, index_dims, place, abs_error, "int64"); + TestGather(x_dims, index_dims, place, abs_error, "int32"); +#endif } } }