未验证 提交 645e81f0 编写于 作者: F Frank Lin 提交者: GitHub

Improve stablity of Paddle-TensorRT FP16 UT GitHub (1) (#51554)

* Improve Readability and Overall Clarity of Logging

* Adds the set_input_type API for specifying input data types

* Specifying input data types
上级 4b85e5db
此差异已折叠。
......@@ -54,8 +54,8 @@ class TensorConfig:
if data_gen is not None:
self.data_gen = data_gen
self.data = data_gen()
self.dtype = data_gen().dtype
self.shape = data_gen().shape
self.dtype = self.data.dtype
self.shape = self.data.shape
else:
assert (
shape is not None
......@@ -67,6 +67,11 @@ class TensorConfig:
def __repr__(self):
return str({'shape': self.shape, 'lod': self.lod, 'dtype': self.dtype})
def astype(self, type: np.dtype):
self.data = self.data.astype(type)
self.dtype = self.data.dtype
return self
class VarType(enum.Enum):
LOD_TENSOR = 1
......@@ -270,6 +275,16 @@ class ProgramConfig:
return log_str
def set_input_type(self, type: np.dtype):
for inp in self.inputs.values():
inp.astype(type)
for weight in self.weights.values():
weight.astype(type)
return self
def get_input_type(self) -> np.dtype:
return next(iter(self.inputs.values())).dtype
def create_fake_model(program_config):
'''Create a Paddle model(in memory) according to the given config.'''
......
......@@ -33,6 +33,9 @@ class TestMultiheadMatmulFusePass(PassAutoScanTest):
def generate_elewise_input():
return np.random.random([1, 12, 128, 128]).astype(np.float32)
def generate_weight(shape):
return np.random.random(shape).astype(np.float32)
mul_0 = OpConfig(
"mul",
inputs={"X": ["mul_x"], "Y": ["mul_0_w"]},
......@@ -195,13 +198,27 @@ class TestMultiheadMatmulFusePass(PassAutoScanTest):
),
},
weights={
"mul_0_w": TensorConfig(shape=[768, 768]),
"mul_1_w": TensorConfig(shape=[768, 768]),
"mul_2_w": TensorConfig(shape=[768, 768]),
"mul_3_w": TensorConfig(shape=[768, 768]),
"ele_0_w": TensorConfig(shape=[768]),
"ele_1_w": TensorConfig(shape=[768]),
"ele_2_w": TensorConfig(shape=[768]),
"mul_0_w": TensorConfig(
data_gen=partial(generate_weight, [768, 768])
),
"mul_1_w": TensorConfig(
data_gen=partial(generate_weight, [768, 768])
),
"mul_2_w": TensorConfig(
data_gen=partial(generate_weight, [768, 768])
),
"mul_3_w": TensorConfig(
data_gen=partial(generate_weight, [768, 768])
),
"ele_0_w": TensorConfig(
data_gen=partial(generate_weight, [768])
),
"ele_1_w": TensorConfig(
data_gen=partial(generate_weight, [768])
),
"ele_2_w": TensorConfig(
data_gen=partial(generate_weight, [768])
),
},
outputs=[ops[-1].outputs["Out"][0]],
)
......
......@@ -103,11 +103,11 @@ class TrtConvertActivationTest(TrtLayerAutoScanTest):
ver = paddle_infer.get_trt_compile_version()
trt_version = ver[0] * 1000 + ver[1] * 100 + ver[2] * 10
if trt_version >= 8400:
if self.dims == 1 and not dynamic_shape:
if self.dims == 1:
return 0, 3
return 1, 2
else:
if (self.dims == 1 and not dynamic_shape) or (
if self.dims <= 2 or (
program_config.inputs['input_data'].dtype
in ['bool', 'int8', 'uint8']
):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册