未验证 提交 f07b8cbe 编写于 作者: F feng_shuai 提交者: GitHub

fix:the axis must be 1(channel), when the dims of bias is 1 (#39052)

上级 1e515aa8
......@@ -36,7 +36,7 @@ class TrtConvertElementwiseTest_one_input(TrtLayerAutoScanTest):
for shape in [[32], [batch, 32], [batch, 32, 32],
[batch, 32, 16, 32]]:
for op_type in ["elementwise_add", "elementwise_mul"]:
for axis in [len(shape) - 1, -1]:
for axis in [-1 if len(shape) == 1 else 1]:
self.dims = len(shape)
dics = [{"axis": axis}]
ops_config = [{
......@@ -129,33 +129,7 @@ class TrtConvertElementwiseTest_one_input(TrtLayerAutoScanTest):
True), 1e-5
def add_skip_trt_case(self):
def teller1(program_config, predictor_config):
if self.dims == 2 and len(self.dynamic_shape.max_input_shape) == 0:
return True
return False
self.add_skip_case(
teller1, SkipReasons.TRT_NOT_IMPLEMENTED,
"The output shape are not equal between gpu and tensorrt when input dim is 2."
)
def teller2(program_config, predictor_config):
if self.dims == 3:
return True
return False
self.add_skip_case(
teller2, SkipReasons.TRT_NOT_IMPLEMENTED,
"The output has diff between gpu and tensorrt when input dim is 3.")
def teller3(program_config, predictor_config):
if self.dims == 4:
return True
return False
self.add_skip_case(
teller3, SkipReasons.TRT_NOT_IMPLEMENTED,
"The output has diff between gpu and tensorrt when input dim is 4.")
pass
def test(self):
self.add_skip_trt_case()
......@@ -287,15 +261,7 @@ class TrtConvertElementwiseTest_two_input_without_broadcast(
yield self.create_inference_config(), (1, 3), 1e-5
def add_skip_trt_case(self):
def teller1(program_config, predictor_config):
if self.dims == 2:
return True
return False
self.add_skip_case(
teller1, SkipReasons.TRT_NOT_IMPLEMENTED,
"The output shape are not equal between gpu and tensorrt when input dim is 2."
)
pass
def test(self):
self.add_skip_trt_case()
......@@ -418,15 +384,7 @@ class TrtConvertElementwiseTest_two_input_with_broadcast(TrtLayerAutoScanTest):
yield self.create_inference_config(), (1, 3), 1e-5
def add_skip_trt_case(self):
def teller1(program_config, predictor_config):
if len(self.shape1) == 2:
return True
return False
self.add_skip_case(
teller1, SkipReasons.TRT_NOT_IMPLEMENTED,
"The output shape are not equal between gpu and tensorrt when input dim is 2."
)
pass
def test(self):
self.add_skip_trt_case()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册