提交 f1bc89c7 编写于 作者: H huzhiqiang 提交者: juncaipeng

fix op inputs and outputs type (#2647) (#2702)

* fix op inputs and outputs type, test=develop
Co-authored-by: Njuncaipeng <52520497+juncaipeng@users.noreply.github.com>
上级 9c8e4642
......@@ -54,12 +54,12 @@ REGISTER_LITE_KERNEL(unsqueeze,
kNCHW,
paddle::lite::kernels::host::UnsqueezeCompute,
def)
.BindInput("X", {LiteType::GetTensorTy(TARGET(kARM))})
.BindInput("X", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kAny))})
.BindInput("AxesTensor",
{LiteType::GetTensorTy(TARGET(kARM), PRECISION(kInt32))})
.BindInput("AxesTensorList",
{LiteType::GetTensorTy(TARGET(kARM), PRECISION(kInt32))})
.BindOutput("Out", {LiteType::GetTensorTy(TARGET(kARM))})
.BindOutput("Out", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kAny))})
.Finalize();
REGISTER_LITE_KERNEL(unsqueeze2,
......@@ -68,11 +68,11 @@ REGISTER_LITE_KERNEL(unsqueeze2,
kNCHW,
paddle::lite::kernels::host::Unsqueeze2Compute,
def)
.BindInput("X", {LiteType::GetTensorTy(TARGET(kARM))})
.BindInput("X", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kAny))})
.BindInput("AxesTensor",
{LiteType::GetTensorTy(TARGET(kARM), PRECISION(kInt32))})
.BindInput("AxesTensorList",
{LiteType::GetTensorTy(TARGET(kARM), PRECISION(kInt32))})
.BindOutput("Out", {LiteType::GetTensorTy(TARGET(kARM))})
.BindOutput("Out", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kAny))})
.BindOutput("XShape", {LiteType::GetTensorTy(TARGET(kARM))})
.Finalize();
......@@ -54,7 +54,8 @@ REGISTER_LITE_KERNEL(yolo_box,
paddle::lite::kernels::arm::YoloBoxCompute,
def)
.BindInput("X", {LiteType::GetTensorTy(TARGET(kARM))})
.BindInput("ImgSize", {LiteType::GetTensorTy(TARGET(kARM))})
.BindInput("ImgSize",
{LiteType::GetTensorTy(TARGET(kARM), PRECISION(kInt32))})
.BindOutput("Boxes", {LiteType::GetTensorTy(TARGET(kARM))})
.BindOutput("Scores", {LiteType::GetTensorTy(TARGET(kARM))})
.Finalize();
......@@ -107,6 +107,7 @@ class UnsqueezeComputeTester : public arena::TestCase {
}
void PrepareData() override {
SetPrecisionType(out_, PRECISION(kFloat));
std::vector<float> in_data(dims_.production());
for (int i = 0; i < dims_.production(); ++i) {
in_data[i] = i;
......@@ -213,6 +214,7 @@ class Unsqueeze2ComputeTester : public arena::TestCase {
}
void PrepareData() override {
SetPrecisionType(out_, PRECISION(kFloat));
std::vector<float> in_data(dims_.production());
for (int i = 0; i < dims_.production(); ++i) {
in_data[i] = i;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册