未验证 提交 ebc0e39c 编写于 作者: H huzhiqiang 提交者: GitHub

[ARM] Add int64 implementation for `gather` and `greater_than` (#4342)

上级 aa228ed2
...@@ -73,10 +73,10 @@ void GatherCompute<IndexType>::Run() { ...@@ -73,10 +73,10 @@ void GatherCompute<IndexType>::Run() {
REGISTER_LITE_KERNEL(gather, REGISTER_LITE_KERNEL(gather,
kARM, kARM,
kAny, kFloat,
kNCHW, kNCHW,
paddle::lite::kernels::arm::GatherCompute<int32_t>, paddle::lite::kernels::arm::GatherCompute<int32_t>,
def) int32)
.BindInput("X", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kAny))}) .BindInput("X", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kAny))})
.BindInput("Index", .BindInput("Index",
{LiteType::GetTensorTy(TARGET(kARM), PRECISION(kInt32))}) {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kInt32))})
...@@ -85,10 +85,10 @@ REGISTER_LITE_KERNEL(gather, ...@@ -85,10 +85,10 @@ REGISTER_LITE_KERNEL(gather,
REGISTER_LITE_KERNEL(gather, REGISTER_LITE_KERNEL(gather,
kARM, kARM,
kAny, kFloat,
kNCHW, kNCHW,
paddle::lite::kernels::arm::GatherCompute<int64_t>, paddle::lite::kernels::arm::GatherCompute<int64_t>,
def_int64_idx) int64)
.BindInput("X", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kAny))}) .BindInput("X", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kAny))})
.BindInput("Index", .BindInput("Index",
{LiteType::GetTensorTy(TARGET(kARM), PRECISION(kInt64))}) {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kInt64))})
......
...@@ -24,7 +24,7 @@ namespace kernels { ...@@ -24,7 +24,7 @@ namespace kernels {
namespace arm { namespace arm {
template <typename IndexType> template <typename IndexType>
class GatherCompute : public KernelLite<TARGET(kARM), PRECISION(kAny)> { class GatherCompute : public KernelLite<TARGET(kARM), PRECISION(kFloat)> {
public: public:
void Run() override; void Run() override;
......
...@@ -230,6 +230,21 @@ REGISTER_LITE_KERNEL(greater_than, kHost, kFloat, kAny, greater_than_float, def) ...@@ -230,6 +230,21 @@ REGISTER_LITE_KERNEL(greater_than, kHost, kFloat, kAny, greater_than_float, def)
TARGET(kHost), PRECISION(kBool), DATALAYOUT(kAny), -1)}) TARGET(kHost), PRECISION(kBool), DATALAYOUT(kAny), -1)})
.Finalize(); .Finalize();
// Alias for the host CompareCompute kernel instantiated with
// PRECISION(kInt64) and the int64_t greater-than functor.
using greater_than_int64 = paddle::lite::kernels::host::CompareCompute<
PRECISION(kInt64),
paddle::lite::kernels::host::_GreaterThanFunctor<int64_t>>;
// Registers the alias above for op `greater_than` with target kHost,
// precision kInt64, layout kAny, under the kernel alias `def`.
// Inputs "X" and "Y" are bound as kInt64 tensors and output "Out" as a
// kBool tensor, all on kHost with layout kAny.
// NOTE(review): the trailing -1 passed to GetTensorTy is presumably a
// device id meaning "any device" — confirm against LiteType::GetTensorTy.
REGISTER_LITE_KERNEL(greater_than, kHost, kInt64, kAny, greater_than_int64, def)
.BindInput("X",
{LiteType::GetTensorTy(
TARGET(kHost), PRECISION(kInt64), DATALAYOUT(kAny), -1)})
.BindInput("Y",
{LiteType::GetTensorTy(
TARGET(kHost), PRECISION(kInt64), DATALAYOUT(kAny), -1)})
.BindOutput("Out",
{LiteType::GetTensorTy(
TARGET(kHost), PRECISION(kBool), DATALAYOUT(kAny), -1)})
.Finalize();
using greater_equal_float = paddle::lite::kernels::host::CompareCompute< using greater_equal_float = paddle::lite::kernels::host::CompareCompute<
PRECISION(kFloat), PRECISION(kFloat),
paddle::lite::kernels::host::_GreaterEqualFunctor<float>>; paddle::lite::kernels::host::_GreaterEqualFunctor<float>>;
...@@ -245,3 +260,19 @@ REGISTER_LITE_KERNEL( ...@@ -245,3 +260,19 @@ REGISTER_LITE_KERNEL(
{LiteType::GetTensorTy( {LiteType::GetTensorTy(
TARGET(kHost), PRECISION(kBool), DATALAYOUT(kAny), -1)}) TARGET(kHost), PRECISION(kBool), DATALAYOUT(kAny), -1)})
.Finalize(); .Finalize();
// Alias for the host CompareCompute kernel instantiated with
// PRECISION(kInt64) and the int64_t greater-or-equal functor.
using greater_equal_int64 = paddle::lite::kernels::host::CompareCompute<
    PRECISION(kInt64),
    paddle::lite::kernels::host::_GreaterEqualFunctor<int64_t>>;
// Registers op `greater_equal` with target kHost, precision kInt64, layout
// kAny, under the kernel alias `def`.
// The kernel class must be the int64 instantiation declared above: passing
// greater_equal_float here would back the kInt64 registration with a kernel
// instantiated for float and leave greater_equal_int64 unused. This mirrors
// the `greater_than` int64 registration, which passes its own int64 alias.
// Inputs "X" and "Y" are bound as kInt64 tensors and output "Out" as a kBool
// tensor, all on kHost with layout kAny.
REGISTER_LITE_KERNEL(
    greater_equal, kHost, kInt64, kAny, greater_equal_int64, def)
    .BindInput("X",
               {LiteType::GetTensorTy(
                   TARGET(kHost), PRECISION(kInt64), DATALAYOUT(kAny), -1)})
    .BindInput("Y",
               {LiteType::GetTensorTy(
                   TARGET(kHost), PRECISION(kInt64), DATALAYOUT(kAny), -1)})
    .BindOutput("Out",
                {LiteType::GetTensorTy(
                    TARGET(kHost), PRECISION(kBool), DATALAYOUT(kAny), -1)})
    .Finalize();
...@@ -53,9 +53,9 @@ class GatherComputeTest : public arena::TestCase { ...@@ -53,9 +53,9 @@ class GatherComputeTest : public arena::TestCase {
out_dims[0] = batch_size; out_dims[0] = batch_size;
out->Resize(out_dims); out->Resize(out_dims);
auto x_data = x->data<float>(); auto x_data = x->data<int64_t>();
auto index_data = index->data<int>(); auto index_data = index->data<int64_t>();
auto out_data = out->mutable_data<float>(); auto out_data = out->mutable_data<int64_t>();
auto slice_num = x_dims[0]; auto slice_num = x_dims[0];
auto slice_size = x_dims.Slice(1, x_dims.size()).production(); auto slice_size = x_dims.Slice(1, x_dims.size()).production();
...@@ -66,7 +66,7 @@ class GatherComputeTest : public arena::TestCase { ...@@ -66,7 +66,7 @@ class GatherComputeTest : public arena::TestCase {
CHECK_GE(index, 0) << "gather ids[i] expected >= 0 but got " << index; CHECK_GE(index, 0) << "gather ids[i] expected >= 0 but got " << index;
memcpy(out_data + i * slice_size, memcpy(out_data + i * slice_size,
x_data + index * slice_size, x_data + index * slice_size,
slice_size * sizeof(float)); slice_size * sizeof(int64_t));
} }
} }
...@@ -78,11 +78,11 @@ class GatherComputeTest : public arena::TestCase { ...@@ -78,11 +78,11 @@ class GatherComputeTest : public arena::TestCase {
} }
void PrepareData() override { void PrepareData() override {
std::vector<float> x(x_dims_.production()); std::vector<int64_t> x(x_dims_.production());
fill_data_rand(x.data(), -1.f, 1.f, x_dims_.production()); fill_data_rand(x.data(), int64_t(-1), int64_t(1), x_dims_.production());
std::vector<int32_t> index(index_dims_.production()); std::vector<int64_t> index(index_dims_.production());
fill_data_rand<int32_t>( fill_data_rand<int64_t>(
index.data(), 0, x_dims_[0] - 1, index_dims_.production()); index.data(), 0, x_dims_[0] - 1, index_dims_.production());
SetCommonTensor(x_, x_dims_, x.data()); SetCommonTensor(x_, x_dims_, x.data());
...@@ -110,8 +110,8 @@ TEST(Gather, precision) { ...@@ -110,8 +110,8 @@ TEST(Gather, precision) {
for (auto x_dims : for (auto x_dims :
std::vector<std::vector<int64_t>>{{5, 2, 3, 4}, {8, 3, 5}, {12, 3}}) { std::vector<std::vector<int64_t>>{{5, 2, 3, 4}, {8, 3, 5}, {12, 3}}) {
for (auto index_dims : std::vector<std::vector<int64_t>>{{3}, {7}, {10}}) { for (auto index_dims : std::vector<std::vector<int64_t>>{{3}, {7}, {10}}) {
std::unique_ptr<arena::TestCase> tester( std::unique_ptr<arena::TestCase> tester(new GatherComputeTest(
new GatherComputeTest(place, "def", DDim(x_dims), DDim(index_dims))); place, "int64", DDim(x_dims), DDim(index_dims)));
arena::Arena arena(std::move(tester), place, abs_error); arena::Arena arena(std::move(tester), place, abs_error);
arena.TestPrecision(); arena.TestPrecision();
} }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册