diff --git a/lite/kernels/arm/unsqueeze_compute.cc b/lite/kernels/arm/unsqueeze_compute.cc index e623407c2e718a51b51e880a4d81df4ee0d96f87..91c8c0423b6fcc5bade5751985f190b3395b0779 100644 --- a/lite/kernels/arm/unsqueeze_compute.cc +++ b/lite/kernels/arm/unsqueeze_compute.cc @@ -54,12 +54,12 @@ REGISTER_LITE_KERNEL(unsqueeze, kNCHW, paddle::lite::kernels::host::UnsqueezeCompute, def) - .BindInput("X", {LiteType::GetTensorTy(TARGET(kARM))}) + .BindInput("X", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kAny))}) .BindInput("AxesTensor", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kInt32))}) .BindInput("AxesTensorList", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kInt32))}) - .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kARM))}) + .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kAny))}) .Finalize(); REGISTER_LITE_KERNEL(unsqueeze2, @@ -68,11 +68,11 @@ REGISTER_LITE_KERNEL(unsqueeze2, kNCHW, paddle::lite::kernels::host::Unsqueeze2Compute, def) - .BindInput("X", {LiteType::GetTensorTy(TARGET(kARM))}) + .BindInput("X", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kAny))}) .BindInput("AxesTensor", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kInt32))}) .BindInput("AxesTensorList", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kInt32))}) - .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kARM))}) + .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kAny))}) .BindOutput("XShape", {LiteType::GetTensorTy(TARGET(kARM))}) .Finalize(); diff --git a/lite/kernels/arm/yolo_box_compute.cc b/lite/kernels/arm/yolo_box_compute.cc index 1336e5e1e0a6438a08f542d299eddc30d15dad15..ad8a630b8c0064af7358674d1b7424eff25a194a 100644 --- a/lite/kernels/arm/yolo_box_compute.cc +++ b/lite/kernels/arm/yolo_box_compute.cc @@ -54,7 +54,8 @@ REGISTER_LITE_KERNEL(yolo_box, paddle::lite::kernels::arm::YoloBoxCompute, def) .BindInput("X", {LiteType::GetTensorTy(TARGET(kARM))}) - .BindInput("ImgSize", {LiteType::GetTensorTy(TARGET(kARM))}) + .BindInput("ImgSize", + {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kInt32))}) .BindOutput("Boxes", {LiteType::GetTensorTy(TARGET(kARM))}) .BindOutput("Scores", {LiteType::GetTensorTy(TARGET(kARM))}) .Finalize(); diff --git a/lite/tests/kernels/unsqueeze_compute_test.cc b/lite/tests/kernels/unsqueeze_compute_test.cc index 22e475672a87dafee29d68a3824e4f8ac0c15615..590d3fd29c37e16cfeec53557a825a4acf9684ca 100644 --- a/lite/tests/kernels/unsqueeze_compute_test.cc +++ b/lite/tests/kernels/unsqueeze_compute_test.cc @@ -107,6 +107,7 @@ class UnsqueezeComputeTester : public arena::TestCase { } void PrepareData() override { + SetPrecisionType(out_, PRECISION(kFloat)); std::vector in_data(dims_.production()); for (int i = 0; i < dims_.production(); ++i) { in_data[i] = i; @@ -213,6 +214,7 @@ class Unsqueeze2ComputeTester : public arena::TestCase { } void PrepareData() override { + SetPrecisionType(out_, PRECISION(kFloat)); std::vector in_data(dims_.production()); for (int i = 0; i < dims_.production(); ++i) { in_data[i] = i;