未验证 提交 5dd5ed67 编写于 作者: Z zhupengyang 提交者: GitHub

[xpu] fix gather and cast unittests (#4396)

上级 db44e8b4
...@@ -130,7 +130,6 @@ void TestCast(Place place, float abs_error, int in_dtype, int out_dtype) { ...@@ -130,7 +130,6 @@ void TestCast(Place place, float abs_error, int in_dtype, int out_dtype) {
} }
TEST(Cast, precision) { TEST(Cast, precision) {
LOG(INFO) << "test cast op";
Place place; Place place;
float abs_error = 2e-5; float abs_error = 2e-5;
#if defined(LITE_WITH_ARM) #if defined(LITE_WITH_ARM)
...@@ -150,7 +149,7 @@ TEST(Cast, precision) { ...@@ -150,7 +149,7 @@ TEST(Cast, precision) {
TestCast(place, abs_error, 20, 5); TestCast(place, abs_error, 20, 5);
#endif #endif
TestCast(place, abs_error, 2, 5); TestCast(place, abs_error, 2, 5);
#if defined(LITE_WITH_XPU) || defined(LITE_WITH_HUAWEI_ASCEND_NPU) #if defined(LITE_WITH_HUAWEI_ASCEND_NPU)
TestCast(place, abs_error, 3, 5); TestCast(place, abs_error, 3, 5);
TestCast(place, abs_error, 5, 3); TestCast(place, abs_error, 5, 3);
#endif #endif
......
...@@ -21,6 +21,7 @@ ...@@ -21,6 +21,7 @@
namespace paddle { namespace paddle {
namespace lite { namespace lite {
template <class T = float, class R = int64_t>
class GatherComputeTest : public arena::TestCase { class GatherComputeTest : public arena::TestCase {
protected: protected:
// common attributes for this op. // common attributes for this op.
...@@ -53,9 +54,9 @@ class GatherComputeTest : public arena::TestCase { ...@@ -53,9 +54,9 @@ class GatherComputeTest : public arena::TestCase {
out_dims[0] = batch_size; out_dims[0] = batch_size;
out->Resize(out_dims); out->Resize(out_dims);
auto x_data = x->data<int64_t>(); auto x_data = x->template data<T>();
auto index_data = index->data<int64_t>(); auto index_data = index->template data<R>();
auto out_data = out->mutable_data<int64_t>(); auto out_data = out->template mutable_data<T>();
auto slice_num = x_dims[0]; auto slice_num = x_dims[0];
auto slice_size = x_dims.Slice(1, x_dims.size()).production(); auto slice_size = x_dims.Slice(1, x_dims.size()).production();
...@@ -66,7 +67,7 @@ class GatherComputeTest : public arena::TestCase { ...@@ -66,7 +67,7 @@ class GatherComputeTest : public arena::TestCase {
CHECK_GE(index, 0) << "gather ids[i] expected >= 0 but got " << index; CHECK_GE(index, 0) << "gather ids[i] expected >= 0 but got " << index;
memcpy(out_data + i * slice_size, memcpy(out_data + i * slice_size,
x_data + index * slice_size, x_data + index * slice_size,
slice_size * sizeof(int64_t)); slice_size * sizeof(T));
} }
} }
...@@ -78,11 +79,12 @@ class GatherComputeTest : public arena::TestCase { ...@@ -78,11 +79,12 @@ class GatherComputeTest : public arena::TestCase {
} }
void PrepareData() override { void PrepareData() override {
std::vector<int64_t> x(x_dims_.production()); std::vector<T> x(x_dims_.production());
fill_data_rand(x.data(), int64_t(-1), int64_t(1), x_dims_.production()); fill_data_rand(
x.data(), static_cast<T>(-1), static_cast<T>(1), x_dims_.production());
std::vector<int64_t> index(index_dims_.production()); std::vector<R> index(index_dims_.production());
fill_data_rand<int64_t>( fill_data_rand<R>(
index.data(), 0, x_dims_[0] - 1, index_dims_.production()); index.data(), 0, x_dims_[0] - 1, index_dims_.production());
SetCommonTensor(x_, x_dims_, x.data()); SetCommonTensor(x_, x_dims_, x.data());
...@@ -90,8 +92,20 @@ class GatherComputeTest : public arena::TestCase { ...@@ -90,8 +92,20 @@ class GatherComputeTest : public arena::TestCase {
} }
}; };
template <class T = float, class R = int64_t>
void TestGather(const std::vector<int64_t>& x_dims,
const std::vector<int64_t>& index_dims,
Place place,
float abs_error = 1e-5,
const std::string& alias = "def") {
std::unique_ptr<arena::TestCase> tester(new GatherComputeTest<T, R>(
place, alias, DDim(x_dims), DDim(index_dims)));
arena::Arena arena(std::move(tester), place, abs_error);
arena.TestPrecision();
}
TEST(Gather, precision) { TEST(Gather, precision) {
float abs_error = 2e-5; float abs_error = 1e-5;
Place place; Place place;
#if defined(LITE_WITH_NPU) #if defined(LITE_WITH_NPU)
place = TARGET(kNPU); place = TARGET(kNPU);
...@@ -110,10 +124,14 @@ TEST(Gather, precision) { ...@@ -110,10 +124,14 @@ TEST(Gather, precision) {
for (auto x_dims : for (auto x_dims :
std::vector<std::vector<int64_t>>{{5, 2, 3, 4}, {8, 3, 5}, {12, 3}}) { std::vector<std::vector<int64_t>>{{5, 2, 3, 4}, {8, 3, 5}, {12, 3}}) {
for (auto index_dims : std::vector<std::vector<int64_t>>{{3}, {7}, {10}}) { for (auto index_dims : std::vector<std::vector<int64_t>>{{3}, {7}, {10}}) {
std::unique_ptr<arena::TestCase> tester(new GatherComputeTest( #if defined(LITE_WITH_XPU) || defined(LITE_WITH_NPU)
place, "int64", DDim(x_dims), DDim(index_dims))); TestGather<float, int>(x_dims, index_dims, place, abs_error, "def");
arena::Arena arena(std::move(tester), place, abs_error); #else
arena.TestPrecision(); TestGather<float, int64_t>(x_dims, index_dims, place, abs_error, "int64");
TestGather<int64_t, int64_t>(
x_dims, index_dims, place, abs_error, "int64");
TestGather<float, int>(x_dims, index_dims, place, abs_error, "int32");
#endif
} }
} }
} }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册