From a8dee3bbe6de6312d88393481b5cafde3c56a128 Mon Sep 17 00:00:00 2001 From: Wilber Date: Tue, 24 Aug 2021 20:25:38 +0800 Subject: [PATCH] fix convert ut framework problem (#35112) --- .../fluid/tests/unittests/ir/inference/program_config.py | 7 +++++-- .../unittests/ir/inference/trt_layer_auto_scan_test.py | 6 ++++++ 2 files changed, 11 insertions(+), 2 deletions(-) diff --git a/python/paddle/fluid/tests/unittests/ir/inference/program_config.py b/python/paddle/fluid/tests/unittests/ir/inference/program_config.py index 6b465f4b2af..e570796c36a 100644 --- a/python/paddle/fluid/tests/unittests/ir/inference/program_config.py +++ b/python/paddle/fluid/tests/unittests/ir/inference/program_config.py @@ -137,8 +137,11 @@ def create_fake_model(program_config): op_desc._set_attr(name, values) for name, values in op_config.outputs.items(): op_desc.set_output(name, values) - var_desc = main_block_desc.var(cpt.to_bytes(name)) - var_desc.set_type(core.VarDesc.VarType.LOD_TENSOR) + for v in values: + var_desc = main_block_desc.var(cpt.to_bytes(v)) + var_desc.set_type(core.VarDesc.VarType.LOD_TENSOR) + var_desc.set_dtype( + convert_np_dtype_to_dtype_(tensor_config.dtype)) op_desc.infer_var_type(main_block_desc) op_desc.infer_shape(main_block_desc) diff --git a/python/paddle/fluid/tests/unittests/ir/inference/trt_layer_auto_scan_test.py b/python/paddle/fluid/tests/unittests/ir/inference/trt_layer_auto_scan_test.py index bf6fc7a24a3..90e69a9fce3 100644 --- a/python/paddle/fluid/tests/unittests/ir/inference/trt_layer_auto_scan_test.py +++ b/python/paddle/fluid/tests/unittests/ir/inference/trt_layer_auto_scan_test.py @@ -100,6 +100,12 @@ class TrtLayerAutoScanTest(AutoScanTest): attrs=op_attr)) self.update_program_input_and_weight_with_attr(op_attr_list) + # if no weight need to save, we create a place_holder to help seriazlie params. + if not self.program_weights: + self.program_weights = { + "place_holder_weight": TensorConfig( + shape=[1], data=np.array([1]).astype(np.float32)) + } program_config = ProgramConfig( ops=ops, weights=self.program_weights, -- GitLab