未验证 提交 645e81f0 编写于 作者: F Frank Lin 提交者: GitHub

Improve stablity of Paddle-TensorRT FP16 UT GitHub (1) (#51554)

* Improve Readability and Overall Clarity of Logging

* Adds the set_input_type API for specifying input data types

* Specifying input data types
上级 4b85e5db
此差异已折叠。
...@@ -54,8 +54,8 @@ class TensorConfig: ...@@ -54,8 +54,8 @@ class TensorConfig:
if data_gen is not None: if data_gen is not None:
self.data_gen = data_gen self.data_gen = data_gen
self.data = data_gen() self.data = data_gen()
self.dtype = data_gen().dtype self.dtype = self.data.dtype
self.shape = data_gen().shape self.shape = self.data.shape
else: else:
assert ( assert (
shape is not None shape is not None
...@@ -67,6 +67,11 @@ class TensorConfig: ...@@ -67,6 +67,11 @@ class TensorConfig:
def __repr__(self): def __repr__(self):
return str({'shape': self.shape, 'lod': self.lod, 'dtype': self.dtype}) return str({'shape': self.shape, 'lod': self.lod, 'dtype': self.dtype})
def astype(self, type: np.dtype):
self.data = self.data.astype(type)
self.dtype = self.data.dtype
return self
class VarType(enum.Enum): class VarType(enum.Enum):
LOD_TENSOR = 1 LOD_TENSOR = 1
...@@ -270,6 +275,16 @@ class ProgramConfig: ...@@ -270,6 +275,16 @@ class ProgramConfig:
return log_str return log_str
def set_input_type(self, type: np.dtype):
for inp in self.inputs.values():
inp.astype(type)
for weight in self.weights.values():
weight.astype(type)
return self
def get_input_type(self) -> np.dtype:
return next(iter(self.inputs.values())).dtype
def create_fake_model(program_config): def create_fake_model(program_config):
'''Create a Paddle model(in memory) according to the given config.''' '''Create a Paddle model(in memory) according to the given config.'''
......
...@@ -33,6 +33,9 @@ class TestMultiheadMatmulFusePass(PassAutoScanTest): ...@@ -33,6 +33,9 @@ class TestMultiheadMatmulFusePass(PassAutoScanTest):
def generate_elewise_input(): def generate_elewise_input():
return np.random.random([1, 12, 128, 128]).astype(np.float32) return np.random.random([1, 12, 128, 128]).astype(np.float32)
def generate_weight(shape):
return np.random.random(shape).astype(np.float32)
mul_0 = OpConfig( mul_0 = OpConfig(
"mul", "mul",
inputs={"X": ["mul_x"], "Y": ["mul_0_w"]}, inputs={"X": ["mul_x"], "Y": ["mul_0_w"]},
...@@ -195,13 +198,27 @@ class TestMultiheadMatmulFusePass(PassAutoScanTest): ...@@ -195,13 +198,27 @@ class TestMultiheadMatmulFusePass(PassAutoScanTest):
), ),
}, },
weights={ weights={
"mul_0_w": TensorConfig(shape=[768, 768]), "mul_0_w": TensorConfig(
"mul_1_w": TensorConfig(shape=[768, 768]), data_gen=partial(generate_weight, [768, 768])
"mul_2_w": TensorConfig(shape=[768, 768]), ),
"mul_3_w": TensorConfig(shape=[768, 768]), "mul_1_w": TensorConfig(
"ele_0_w": TensorConfig(shape=[768]), data_gen=partial(generate_weight, [768, 768])
"ele_1_w": TensorConfig(shape=[768]), ),
"ele_2_w": TensorConfig(shape=[768]), "mul_2_w": TensorConfig(
data_gen=partial(generate_weight, [768, 768])
),
"mul_3_w": TensorConfig(
data_gen=partial(generate_weight, [768, 768])
),
"ele_0_w": TensorConfig(
data_gen=partial(generate_weight, [768])
),
"ele_1_w": TensorConfig(
data_gen=partial(generate_weight, [768])
),
"ele_2_w": TensorConfig(
data_gen=partial(generate_weight, [768])
),
}, },
outputs=[ops[-1].outputs["Out"][0]], outputs=[ops[-1].outputs["Out"][0]],
) )
......
...@@ -103,11 +103,11 @@ class TrtConvertActivationTest(TrtLayerAutoScanTest): ...@@ -103,11 +103,11 @@ class TrtConvertActivationTest(TrtLayerAutoScanTest):
ver = paddle_infer.get_trt_compile_version() ver = paddle_infer.get_trt_compile_version()
trt_version = ver[0] * 1000 + ver[1] * 100 + ver[2] * 10 trt_version = ver[0] * 1000 + ver[1] * 100 + ver[2] * 10
if trt_version >= 8400: if trt_version >= 8400:
if self.dims == 1 and not dynamic_shape: if self.dims == 1:
return 0, 3 return 0, 3
return 1, 2 return 1, 2
else: else:
if (self.dims == 1 and not dynamic_shape) or ( if self.dims <= 2 or (
program_config.inputs['input_data'].dtype program_config.inputs['input_data'].dtype
in ['bool', 'int8', 'uint8'] in ['bool', 'int8', 'uint8']
): ):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册