未验证 提交 4b5b5403 编写于 作者: H huzhiqiang 提交者: GitHub

[cherry-pick] [ARM] Add int64 implementation for `gather` and `greater_than` (#4444)

上级 038c07f3
......@@ -73,10 +73,10 @@ void GatherCompute<IndexType>::Run() {
REGISTER_LITE_KERNEL(gather,
kARM,
kAny,
kFloat,
kNCHW,
paddle::lite::kernels::arm::GatherCompute<int32_t>,
def)
int32)
.BindInput("X", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kAny))})
.BindInput("Index",
{LiteType::GetTensorTy(TARGET(kARM), PRECISION(kInt32))})
......@@ -85,10 +85,10 @@ REGISTER_LITE_KERNEL(gather,
REGISTER_LITE_KERNEL(gather,
kARM,
kAny,
kFloat,
kNCHW,
paddle::lite::kernels::arm::GatherCompute<int64_t>,
def_int64_idx)
int64)
.BindInput("X", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kAny))})
.BindInput("Index",
{LiteType::GetTensorTy(TARGET(kARM), PRECISION(kInt64))})
......
......@@ -24,7 +24,7 @@ namespace kernels {
namespace arm {
template <typename IndexType>
class GatherCompute : public KernelLite<TARGET(kARM), PRECISION(kAny)> {
class GatherCompute : public KernelLite<TARGET(kARM), PRECISION(kFloat)> {
public:
void Run() override;
......
......@@ -230,6 +230,21 @@ REGISTER_LITE_KERNEL(greater_than, kHost, kFloat, kAny, greater_than_float, def)
TARGET(kHost), PRECISION(kBool), DATALAYOUT(kAny), -1)})
.Finalize();
// Kernel type for the int64 flavour of greater_than: CompareCompute
// instantiated with PRECISION(kInt64) and _GreaterThanFunctor<int64_t>.
using greater_than_int64 = paddle::lite::kernels::host::CompareCompute<
PRECISION(kInt64),
paddle::lite::kernels::host::_GreaterThanFunctor<int64_t>>;
// Register greater_than_int64 for the greater_than op with kHost / kInt64 /
// kAny and `def` as the final macro argument (presumably the kernel alias —
// confirm against the REGISTER_LITE_KERNEL definition).
// Both inputs ("X", "Y") are bound as kInt64 host tensors; the output ("Out")
// is bound as a kBool host tensor.
// NOTE(review): the trailing -1 passed to GetTensorTy looks like a
// "no specific device" id — confirm.
REGISTER_LITE_KERNEL(greater_than, kHost, kInt64, kAny, greater_than_int64, def)
.BindInput("X",
{LiteType::GetTensorTy(
TARGET(kHost), PRECISION(kInt64), DATALAYOUT(kAny), -1)})
.BindInput("Y",
{LiteType::GetTensorTy(
TARGET(kHost), PRECISION(kInt64), DATALAYOUT(kAny), -1)})
.BindOutput("Out",
{LiteType::GetTensorTy(
TARGET(kHost), PRECISION(kBool), DATALAYOUT(kAny), -1)})
.Finalize();
using greater_equal_float = paddle::lite::kernels::host::CompareCompute<
PRECISION(kFloat),
paddle::lite::kernels::host::_GreaterEqualFunctor<float>>;
......@@ -245,3 +260,19 @@ REGISTER_LITE_KERNEL(
{LiteType::GetTensorTy(
TARGET(kHost), PRECISION(kBool), DATALAYOUT(kAny), -1)})
.Finalize();
// Kernel type for the int64 flavour of greater_equal: CompareCompute
// instantiated with PRECISION(kInt64) and _GreaterEqualFunctor<int64_t>.
using greater_equal_int64 = paddle::lite::kernels::host::CompareCompute<
PRECISION(kInt64),
paddle::lite::kernels::host::_GreaterEqualFunctor<int64_t>>;
// Register the int64 kernel for the greater_equal op with kHost / kInt64 /
// kAny and `def` as the final macro argument (presumably the kernel alias —
// confirm against the REGISTER_LITE_KERNEL definition).
// The kernel class must be greater_equal_int64: the registration declares
// kInt64 precision and int64 inputs, so it has to use the int64 functor
// rather than the float compare class.
// Both inputs ("X", "Y") are bound as kInt64 host tensors; the output ("Out")
// is bound as a kBool host tensor.
REGISTER_LITE_KERNEL(
greater_equal, kHost, kInt64, kAny, greater_equal_int64, def)
.BindInput("X",
{LiteType::GetTensorTy(
TARGET(kHost), PRECISION(kInt64), DATALAYOUT(kAny), -1)})
.BindInput("Y",
{LiteType::GetTensorTy(
TARGET(kHost), PRECISION(kInt64), DATALAYOUT(kAny), -1)})
.BindOutput("Out",
{LiteType::GetTensorTy(
TARGET(kHost), PRECISION(kBool), DATALAYOUT(kAny), -1)})
.Finalize();
......@@ -53,9 +53,9 @@ class GatherComputeTest : public arena::TestCase {
out_dims[0] = batch_size;
out->Resize(out_dims);
auto x_data = x->data<float>();
auto index_data = index->data<int>();
auto out_data = out->mutable_data<float>();
auto x_data = x->data<int64_t>();
auto index_data = index->data<int64_t>();
auto out_data = out->mutable_data<int64_t>();
auto slice_num = x_dims[0];
auto slice_size = x_dims.Slice(1, x_dims.size()).production();
......@@ -66,7 +66,7 @@ class GatherComputeTest : public arena::TestCase {
CHECK_GE(index, 0) << "gather ids[i] expected >= 0 but got " << index;
memcpy(out_data + i * slice_size,
x_data + index * slice_size,
slice_size * sizeof(float));
slice_size * sizeof(int64_t));
}
}
......@@ -78,11 +78,11 @@ class GatherComputeTest : public arena::TestCase {
}
void PrepareData() override {
std::vector<float> x(x_dims_.production());
fill_data_rand(x.data(), -1.f, 1.f, x_dims_.production());
std::vector<int64_t> x(x_dims_.production());
fill_data_rand(x.data(), int64_t(-1), int64_t(1), x_dims_.production());
std::vector<int32_t> index(index_dims_.production());
fill_data_rand<int32_t>(
std::vector<int64_t> index(index_dims_.production());
fill_data_rand<int64_t>(
index.data(), 0, x_dims_[0] - 1, index_dims_.production());
SetCommonTensor(x_, x_dims_, x.data());
......@@ -110,8 +110,8 @@ TEST(Gather, precision) {
for (auto x_dims :
std::vector<std::vector<int64_t>>{{5, 2, 3, 4}, {8, 3, 5}, {12, 3}}) {
for (auto index_dims : std::vector<std::vector<int64_t>>{{3}, {7}, {10}}) {
std::unique_ptr<arena::TestCase> tester(
new GatherComputeTest(place, "def", DDim(x_dims), DDim(index_dims)));
std::unique_ptr<arena::TestCase> tester(new GatherComputeTest(
place, "int64", DDim(x_dims), DDim(index_dims)));
arena::Arena arena(std::move(tester), place, abs_error);
arena.TestPrecision();
}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册