From d3fecf01a654223041c043255aa2c7345929c1e0 Mon Sep 17 00:00:00 2001 From: chenjiaoAngel Date: Mon, 10 Aug 2020 02:43:21 +0000 Subject: [PATCH] fix ttfnet bug. test=develop --- lite/kernels/host/where_index_compute.cc | 2 +- lite/operators/compare_op.cc | 9 ++++++++- 2 files changed, 9 insertions(+), 2 deletions(-) diff --git a/lite/kernels/host/where_index_compute.cc b/lite/kernels/host/where_index_compute.cc index d06be8d332..1bec0460b2 100644 --- a/lite/kernels/host/where_index_compute.cc +++ b/lite/kernels/host/where_index_compute.cc @@ -169,5 +169,5 @@ REGISTER_LITE_KERNEL(where_index, kHost, kAny, kAny, whereindex, def) .BindInput("Condition", {LiteType::GetTensorTy(TARGET(kHost), PRECISION(kAny))}) .BindOutput("Out", - {LiteType::GetTensorTy(TARGET(kHost), PRECISION(kInt64))}) + {LiteType::GetTensorTy(TARGET(kHost), PRECISION(kAny))}) .Finalize(); diff --git a/lite/operators/compare_op.cc b/lite/operators/compare_op.cc index f458eae71e..fb1c63275b 100644 --- a/lite/operators/compare_op.cc +++ b/lite/operators/compare_op.cc @@ -30,7 +30,14 @@ bool CompareOp::InferShapeImpl() const { CHECK_OR_FALSE(param_.Out); // TODO(Superjomn) Enable data sharing. auto input_dims = param_.X->dims(); - param_.Out->Resize(input_dims); + std::vector new_dims; + if (input_dims.size() == 2 && input_dims[1] == 1) { + new_dims.push_back(input_dims[0]); + param_.Out->Resize(new_dims); + } else { + param_.Out->Resize(input_dims); + } + // param_.Out->Resize(input_dims); return true; } -- GitLab