From c3ba8056bd6b6308a74736e88037bbca8c0b1f88 Mon Sep 17 00:00:00 2001 From: zhoutianzi666 <39978853+zhoutianzi666@users.noreply.github.com> Date: Thu, 21 Jul 2022 19:06:24 +0800 Subject: [PATCH] [Paddle-TRT] fix_fill_constant (#44481) * fix_fill_constant * fix_fill_constant * fix_ernie --- paddle/fluid/inference/tensorrt/convert/fill_constant_op.cc | 4 ++++ .../ir/inference/test_trt_convert_fill_constant.py | 6 ++++++ 2 files changed, 10 insertions(+) diff --git a/paddle/fluid/inference/tensorrt/convert/fill_constant_op.cc b/paddle/fluid/inference/tensorrt/convert/fill_constant_op.cc index 3bbbfe03743..4d524c01b78 100644 --- a/paddle/fluid/inference/tensorrt/convert/fill_constant_op.cc +++ b/paddle/fluid/inference/tensorrt/convert/fill_constant_op.cc @@ -32,6 +32,10 @@ class FillConstantOpConverter : public OpConverter { PADDLE_GET_CONST(std::string, op_desc.GetAttr("str_value")); std::vector shape = PADDLE_GET_CONST(std::vector, op_desc.GetAttr("shape")); + if (str_value == "") { + float value = PADDLE_GET_CONST(float, op_desc.GetAttr("value")); + str_value = std::to_string(value); + } std::unique_ptr out_tensor(new framework::Tensor()); out_tensor->Resize(phi::make_ddim(shape)); nvinfer1::DataType trt_dtype = nvinfer1::DataType::kFLOAT; diff --git a/python/paddle/fluid/tests/unittests/ir/inference/test_trt_convert_fill_constant.py b/python/paddle/fluid/tests/unittests/ir/inference/test_trt_convert_fill_constant.py index 84ee70782ac..cc686be6d8a 100644 --- a/python/paddle/fluid/tests/unittests/ir/inference/test_trt_convert_fill_constant.py +++ b/python/paddle/fluid/tests/unittests/ir/inference/test_trt_convert_fill_constant.py @@ -42,8 +42,14 @@ class TrtConvertSplitTest(TrtLayerAutoScanTest): for dtype in [5, 2, 3]: for str_value in ["2", "23", "-1"]: self.num_input = num_input + value = float(str_value) + if np.random.choice([False, True]): + str_value = str_value + else: + str_value = "" dics = [{ "str_value": str_value, + "value": value, "shape": shape, "dtype": dtype }, { -- GitLab