From a842c1d0e0194146098314419c315d1b3e956e4a Mon Sep 17 00:00:00 2001 From: xiaoxiaohehe001 <49090790+xiaoxiaohehe001@users.noreply.github.com> Date: Mon, 5 Dec 2022 10:21:18 +0800 Subject: [PATCH] [Paddle Inference] Support fill_any_like bool input. (#48671) * fill_any_like_bool * fill_any_like_bool --- paddle/fluid/inference/tensorrt/op_teller.cc | 19 ++++++++++++++----- 1 file changed, 14 insertions(+), 5 deletions(-) diff --git a/paddle/fluid/inference/tensorrt/op_teller.cc b/paddle/fluid/inference/tensorrt/op_teller.cc index d88de415e8..d8801bd8f5 100644 --- a/paddle/fluid/inference/tensorrt/op_teller.cc +++ b/paddle/fluid/inference/tensorrt/op_teller.cc @@ -1226,17 +1226,26 @@ struct SimpleOpTypeSetTeller : public Teller { return false; } int dtype = PADDLE_GET_CONST(int, desc.GetAttr("dtype")); + auto* block = desc.Block(); + auto* x_var_desc = block->FindVar(desc.Input("X")[0]); + auto input_type = x_var_desc->GetDataType(); +#if IS_TRT_VERSION_GE(8400) + if (dtype == 0 || + (dtype == -1 && input_type == framework::proto::VarType::BOOL)) { + VLOG(3) << "the fill_any_like supports input of BOOL by trt8.4 above"; + return true; + } +#endif if (dtype != -1 && dtype != 2 && dtype != 5) { - VLOG(3) << "the fill_any_like only supports int32 and float32"; + VLOG(3) << "the fill_any_like only supports int32 and float32 by " + "trt8.4 below"; return false; } if (dtype == -1) { - auto* block = desc.Block(); - auto* x_var_desc = block->FindVar(desc.Input("X")[0]); - auto input_type = x_var_desc->GetDataType(); if (input_type != framework::proto::VarType::INT32 && input_type != framework::proto::VarType::FP32) { - VLOG(3) << "the fill_any_like only supports int32 and float32"; + VLOG(3) << "the fill_any_like only supports int32 and float32 by " + "trt8.4 below"; return false; } } -- GitLab