From f1bc89c777c70014abcdb4393a6057d86c4dbfae Mon Sep 17 00:00:00 2001 From: huzhiqiang <912790387@qq.com> Date: Mon, 30 Dec 2019 16:31:38 +0800 Subject: [PATCH] fix op inputs and outputs type (#2647) (#2702) * fix op inputs and outputs type, test=develop Co-authored-by: juncaipeng <52520497+juncaipeng@users.noreply.github.com> --- lite/kernels/arm/unsqueeze_compute.cc | 8 ++++---- lite/kernels/arm/yolo_box_compute.cc | 3 ++- lite/tests/kernels/unsqueeze_compute_test.cc | 2 ++ 3 files changed, 8 insertions(+), 5 deletions(-) diff --git a/lite/kernels/arm/unsqueeze_compute.cc b/lite/kernels/arm/unsqueeze_compute.cc index e623407c2e..91c8c0423b 100644 --- a/lite/kernels/arm/unsqueeze_compute.cc +++ b/lite/kernels/arm/unsqueeze_compute.cc @@ -54,12 +54,12 @@ REGISTER_LITE_KERNEL(unsqueeze, kNCHW, paddle::lite::kernels::host::UnsqueezeCompute, def) - .BindInput("X", {LiteType::GetTensorTy(TARGET(kARM))}) + .BindInput("X", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kAny))}) .BindInput("AxesTensor", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kInt32))}) .BindInput("AxesTensorList", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kInt32))}) - .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kARM))}) + .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kAny))}) .Finalize(); REGISTER_LITE_KERNEL(unsqueeze2, @@ -68,11 +68,11 @@ REGISTER_LITE_KERNEL(unsqueeze2, kNCHW, paddle::lite::kernels::host::Unsqueeze2Compute, def) - .BindInput("X", {LiteType::GetTensorTy(TARGET(kARM))}) + .BindInput("X", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kAny))}) .BindInput("AxesTensor", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kInt32))}) .BindInput("AxesTensorList", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kInt32))}) - .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kARM))}) + .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kAny))}) .BindOutput("XShape", {LiteType::GetTensorTy(TARGET(kARM))}) .Finalize(); diff --git a/lite/kernels/arm/yolo_box_compute.cc b/lite/kernels/arm/yolo_box_compute.cc index 1336e5e1e0..ad8a630b8c 100644 --- a/lite/kernels/arm/yolo_box_compute.cc +++ b/lite/kernels/arm/yolo_box_compute.cc @@ -54,7 +54,8 @@ REGISTER_LITE_KERNEL(yolo_box, paddle::lite::kernels::arm::YoloBoxCompute, def) .BindInput("X", {LiteType::GetTensorTy(TARGET(kARM))}) - .BindInput("ImgSize", {LiteType::GetTensorTy(TARGET(kARM))}) + .BindInput("ImgSize", + {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kInt32))}) .BindOutput("Boxes", {LiteType::GetTensorTy(TARGET(kARM))}) .BindOutput("Scores", {LiteType::GetTensorTy(TARGET(kARM))}) .Finalize(); diff --git a/lite/tests/kernels/unsqueeze_compute_test.cc b/lite/tests/kernels/unsqueeze_compute_test.cc index 22e475672a..590d3fd29c 100644 --- a/lite/tests/kernels/unsqueeze_compute_test.cc +++ b/lite/tests/kernels/unsqueeze_compute_test.cc @@ -107,6 +107,7 @@ class UnsqueezeComputeTester : public arena::TestCase { } void PrepareData() override { + SetPrecisionType(out_, PRECISION(kFloat)); std::vector in_data(dims_.production()); for (int i = 0; i < dims_.production(); ++i) { in_data[i] = i; @@ -213,6 +214,7 @@ class Unsqueeze2ComputeTester : public arena::TestCase { } void PrepareData() override { + SetPrecisionType(out_, PRECISION(kFloat)); std::vector in_data(dims_.production()); for (int i = 0; i < dims_.production(); ++i) { in_data[i] = i; -- GitLab