未验证 提交 e9b61fd4 编写于 作者: Z zhupengyang 提交者: GitHub

[xpu] fix gather and cast unittests (#4396) (#4463)

上级 27144a7f
......@@ -130,7 +130,6 @@ void TestCast(Place place, float abs_error, int in_dtype, int out_dtype) {
}
TEST(Cast, precision) {
LOG(INFO) << "test cast op";
Place place;
float abs_error = 2e-5;
#if defined(LITE_WITH_ARM)
......@@ -150,7 +149,7 @@ TEST(Cast, precision) {
TestCast(place, abs_error, 20, 5);
#endif
TestCast(place, abs_error, 2, 5);
#if defined(LITE_WITH_XPU) || defined(LITE_WITH_HUAWEI_ASCEND_NPU)
#if defined(LITE_WITH_HUAWEI_ASCEND_NPU)
TestCast(place, abs_error, 3, 5);
TestCast(place, abs_error, 5, 3);
#endif
......
......@@ -21,6 +21,7 @@
namespace paddle {
namespace lite {
template <class T = float, class R = int64_t>
class GatherComputeTest : public arena::TestCase {
protected:
// common attributes for this op.
......@@ -53,9 +54,9 @@ class GatherComputeTest : public arena::TestCase {
out_dims[0] = batch_size;
out->Resize(out_dims);
auto x_data = x->data<int64_t>();
auto index_data = index->data<int64_t>();
auto out_data = out->mutable_data<int64_t>();
auto x_data = x->template data<T>();
auto index_data = index->template data<R>();
auto out_data = out->template mutable_data<T>();
auto slice_num = x_dims[0];
auto slice_size = x_dims.Slice(1, x_dims.size()).production();
......@@ -66,7 +67,7 @@ class GatherComputeTest : public arena::TestCase {
CHECK_GE(index, 0) << "gather ids[i] expected >= 0 but got " << index;
memcpy(out_data + i * slice_size,
x_data + index * slice_size,
slice_size * sizeof(int64_t));
slice_size * sizeof(T));
}
}
......@@ -78,11 +79,12 @@ class GatherComputeTest : public arena::TestCase {
}
void PrepareData() override {
std::vector<int64_t> x(x_dims_.production());
fill_data_rand(x.data(), int64_t(-1), int64_t(1), x_dims_.production());
std::vector<T> x(x_dims_.production());
fill_data_rand(
x.data(), static_cast<T>(-1), static_cast<T>(1), x_dims_.production());
std::vector<int64_t> index(index_dims_.production());
fill_data_rand<int64_t>(
std::vector<R> index(index_dims_.production());
fill_data_rand<R>(
index.data(), 0, x_dims_[0] - 1, index_dims_.production());
SetCommonTensor(x_, x_dims_, x.data());
......@@ -90,8 +92,20 @@ class GatherComputeTest : public arena::TestCase {
}
};
template <class T = float, class R = int64_t>
void TestGather(const std::vector<int64_t>& x_dims,
const std::vector<int64_t>& index_dims,
Place place,
float abs_error = 1e-5,
const std::string& alias = "def") {
std::unique_ptr<arena::TestCase> tester(new GatherComputeTest<T, R>(
place, alias, DDim(x_dims), DDim(index_dims)));
arena::Arena arena(std::move(tester), place, abs_error);
arena.TestPrecision();
}
TEST(Gather, precision) {
float abs_error = 2e-5;
float abs_error = 1e-5;
Place place;
#if defined(LITE_WITH_NPU)
place = TARGET(kNPU);
......@@ -110,10 +124,14 @@ TEST(Gather, precision) {
for (auto x_dims :
std::vector<std::vector<int64_t>>{{5, 2, 3, 4}, {8, 3, 5}, {12, 3}}) {
for (auto index_dims : std::vector<std::vector<int64_t>>{{3}, {7}, {10}}) {
std::unique_ptr<arena::TestCase> tester(new GatherComputeTest(
place, "int64", DDim(x_dims), DDim(index_dims)));
arena::Arena arena(std::move(tester), place, abs_error);
arena.TestPrecision();
#if defined(LITE_WITH_XPU) || defined(LITE_WITH_NPU)
TestGather<float, int>(x_dims, index_dims, place, abs_error, "def");
#else
TestGather<float, int64_t>(x_dims, index_dims, place, abs_error, "int64");
TestGather<int64_t, int64_t>(
x_dims, index_dims, place, abs_error, "int64");
TestGather<float, int>(x_dims, index_dims, place, abs_error, "int32");
#endif
}
}
}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册