未验证 提交 b4d3597a 编写于 作者: B baoachun 提交者: GitHub

update inference ut to support nhwc format (#39551)

* update inference ut to support nhwc format

* update ut and pass OpCompat

* update ut

* update ut
上级 1354652b
...@@ -157,7 +157,7 @@ ConvActivationFusePass::ConvActivationFusePass() { ...@@ -157,7 +157,7 @@ ConvActivationFusePass::ConvActivationFusePass() {
// IsStringIn({"NHWC", "NCHW"}) MobileNetV2 has no this attribute // IsStringIn({"NHWC", "NCHW"}) MobileNetV2 has no this attribute
.AddAttr("data_format") .AddAttr("data_format")
.IsOptional() .IsOptional()
.IsStringIn({"NCHW", "AnyLayout"}) .IsStringIn({"NCHW", "NHWC", "AnyLayout"})
.End(); .End();
AddOpCompat(OpCompat("relu")) AddOpCompat(OpCompat("relu"))
......
...@@ -115,7 +115,7 @@ Conv2DTransposeBiasFusePass::Conv2DTransposeBiasFusePass() { ...@@ -115,7 +115,7 @@ Conv2DTransposeBiasFusePass::Conv2DTransposeBiasFusePass() {
.IsStringIn({"EXPLICIT", "SAME", "VALID"}) .IsStringIn({"EXPLICIT", "SAME", "VALID"})
.End() .End()
.AddAttr("data_format") .AddAttr("data_format")
.IsStringIn({"NCHW"}) .IsStringIn({"NCHW", "NHWC", "AnyLayout"})
.End(); .End();
AddOpCompat(OpCompat("elementwise_add")) AddOpCompat(OpCompat("elementwise_add"))
...@@ -129,7 +129,7 @@ Conv2DTransposeBiasFusePass::Conv2DTransposeBiasFusePass() { ...@@ -129,7 +129,7 @@ Conv2DTransposeBiasFusePass::Conv2DTransposeBiasFusePass() {
.IsTensor() .IsTensor()
.End() .End()
.AddAttr("axis") .AddAttr("axis")
.IsIntIn({1}) .IsIntIn({1, 3})
.End(); .End();
} }
......
...@@ -59,7 +59,7 @@ ConvConcatReLUFusePass::ConvConcatReLUFusePass() { ...@@ -59,7 +59,7 @@ ConvConcatReLUFusePass::ConvConcatReLUFusePass() {
.IsType<std::vector<int>>() .IsType<std::vector<int>>()
.End() .End()
.AddAttr("data_format") .AddAttr("data_format")
.IsStringIn({"NCHW"}) .IsStringIn({"NCHW", "NHWC", "AnyLayout"})
.End(); .End();
AddOpCompat(OpCompat("concat")) AddOpCompat(OpCompat("concat"))
......
...@@ -53,8 +53,6 @@ class TestConvActMkldnnFusePass(PassAutoScanTest): ...@@ -53,8 +53,6 @@ class TestConvActMkldnnFusePass(PassAutoScanTest):
data_format = prog_config.ops[0].attrs["data_format"] data_format = prog_config.ops[0].attrs["data_format"]
filter_shape = prog_config.weights["filter"].shape filter_shape = prog_config.weights["filter"].shape
input_shape = prog_config.inputs["input_x"].shape input_shape = prog_config.inputs["input_x"].shape
if data_format != "NCHW":
return False
if padding_algorithm == "VALID": if padding_algorithm == "VALID":
if ((input_shape[2] - (dilations[0] * (filter_shape[2] - 1) + 1)) / strides[0] + 1) <= 1 or \ if ((input_shape[2] - (dilations[0] * (filter_shape[2] - 1) + 1)) / strides[0] + 1) <= 1 or \
((input_shape[3] - (dilations[1] * (filter_shape[3] - 1) + 1)) / strides[1] + 1) <= 1: ((input_shape[3] - (dilations[1] * (filter_shape[3] - 1) + 1)) / strides[1] + 1) <= 1:
...@@ -80,8 +78,8 @@ class TestConvActMkldnnFusePass(PassAutoScanTest): ...@@ -80,8 +78,8 @@ class TestConvActMkldnnFusePass(PassAutoScanTest):
x_shape = draw( x_shape = draw(
st.lists( st.lists(
st.integers( st.integers(
min_value=1, max_value=100), min_size=4, max_size=4)) min_value=5, max_value=100), min_size=4, max_size=4))
x_shape[1] = draw(st.integers(min_value=1, max_value=10)) x_shape[1] = draw(st.integers(min_value=5, max_value=10))
# 2. Generate legal attr:data_format of conv2d # 2. Generate legal attr:data_format of conv2d
data_format = draw(st.sampled_from(["NCHW", "NHWC"])) data_format = draw(st.sampled_from(["NCHW", "NHWC"]))
...@@ -90,7 +88,7 @@ class TestConvActMkldnnFusePass(PassAutoScanTest): ...@@ -90,7 +88,7 @@ class TestConvActMkldnnFusePass(PassAutoScanTest):
f_shape = draw( f_shape = draw(
st.lists( st.lists(
st.integers( st.integers(
min_value=1, max_value=7), min_size=4, max_size=4)) min_value=1, max_value=5), min_size=4, max_size=4))
if data_format == "NCHW": if data_format == "NCHW":
f_shape[1] = x_shape[1] f_shape[1] = x_shape[1]
else: else:
......
...@@ -53,8 +53,6 @@ class TestConvBiasMkldnnFusePass(PassAutoScanTest): ...@@ -53,8 +53,6 @@ class TestConvBiasMkldnnFusePass(PassAutoScanTest):
data_format = prog_config.ops[0].attrs["data_format"] data_format = prog_config.ops[0].attrs["data_format"]
filter_shape = prog_config.weights["filter"].shape filter_shape = prog_config.weights["filter"].shape
input_shape = prog_config.inputs["input_x"].shape input_shape = prog_config.inputs["input_x"].shape
if data_format != "NCHW":
return False
if padding_algorithm == "VALID": if padding_algorithm == "VALID":
if ((input_shape[2] - (dilations[0] * (filter_shape[2] - 1) + 1)) / strides[0] + 1) <= 1 or \ if ((input_shape[2] - (dilations[0] * (filter_shape[2] - 1) + 1)) / strides[0] + 1) <= 1 or \
((input_shape[3] - (dilations[1] * (filter_shape[3] - 1) + 1)) / strides[1] + 1) <= 1: ((input_shape[3] - (dilations[1] * (filter_shape[3] - 1) + 1)) / strides[1] + 1) <= 1:
...@@ -80,8 +78,8 @@ class TestConvBiasMkldnnFusePass(PassAutoScanTest): ...@@ -80,8 +78,8 @@ class TestConvBiasMkldnnFusePass(PassAutoScanTest):
x_shape = draw( x_shape = draw(
st.lists( st.lists(
st.integers( st.integers(
min_value=1, max_value=100), min_size=4, max_size=4)) min_value=5, max_value=100), min_size=4, max_size=4))
x_shape[1] = draw(st.integers(min_value=1, max_value=10)) x_shape[1] = draw(st.integers(min_value=5, max_value=10))
# 2. Generate legal attr:data_format of conv2d # 2. Generate legal attr:data_format of conv2d
data_format = draw(st.sampled_from(["NCHW", "NHWC"])) data_format = draw(st.sampled_from(["NCHW", "NHWC"]))
...@@ -90,7 +88,7 @@ class TestConvBiasMkldnnFusePass(PassAutoScanTest): ...@@ -90,7 +88,7 @@ class TestConvBiasMkldnnFusePass(PassAutoScanTest):
f_shape = draw( f_shape = draw(
st.lists( st.lists(
st.integers( st.integers(
min_value=1, max_value=7), min_size=4, max_size=4)) min_value=1, max_value=4), min_size=4, max_size=4))
if data_format == "NCHW": if data_format == "NCHW":
f_shape[1] = x_shape[1] f_shape[1] = x_shape[1]
else: else:
...@@ -100,7 +98,7 @@ class TestConvBiasMkldnnFusePass(PassAutoScanTest): ...@@ -100,7 +98,7 @@ class TestConvBiasMkldnnFusePass(PassAutoScanTest):
strides = draw( strides = draw(
st.lists( st.lists(
st.integers( st.integers(
min_value=1, max_value=5), min_size=2, max_size=2)) min_value=1, max_value=4), min_size=2, max_size=2))
# 5. Generate legal attr:padding_algorithm of conv2d # 5. Generate legal attr:padding_algorithm of conv2d
padding_algorithm = draw(st.sampled_from(["EXPLICIT", "SAME", "VALID"])) padding_algorithm = draw(st.sampled_from(["EXPLICIT", "SAME", "VALID"]))
...@@ -109,7 +107,7 @@ class TestConvBiasMkldnnFusePass(PassAutoScanTest): ...@@ -109,7 +107,7 @@ class TestConvBiasMkldnnFusePass(PassAutoScanTest):
padding = draw( padding = draw(
st.lists( st.lists(
st.integers( st.integers(
min_value=1, max_value=5), min_size=4, max_size=4)) min_value=1, max_value=4), min_size=4, max_size=4))
# 7. Generate legal attr:groups of conv2d # 7. Generate legal attr:groups of conv2d
groups = draw(st.integers(min_value=1, max_value=3)) groups = draw(st.integers(min_value=1, max_value=3))
...@@ -118,7 +116,7 @@ class TestConvBiasMkldnnFusePass(PassAutoScanTest): ...@@ -118,7 +116,7 @@ class TestConvBiasMkldnnFusePass(PassAutoScanTest):
dilations = draw( dilations = draw(
st.lists( st.lists(
st.integers( st.integers(
min_value=1, max_value=5), min_size=2, max_size=2)) min_value=1, max_value=4), min_size=2, max_size=2))
# 9. Generate legal shape of input:bias of elementwise_add # 9. Generate legal shape of input:bias of elementwise_add
bias_shape = [f_shape[0]] bias_shape = [f_shape[0]]
......
...@@ -27,15 +27,6 @@ import hypothesis.strategies as st ...@@ -27,15 +27,6 @@ import hypothesis.strategies as st
class TestConvConcatReluMkldnnFusePass(PassAutoScanTest): class TestConvConcatReluMkldnnFusePass(PassAutoScanTest):
def is_program_valid(self, program_config: ProgramConfig) -> bool: def is_program_valid(self, program_config: ProgramConfig) -> bool:
attrs = [
program_config.ops[i].attrs
for i in range(len(program_config.ops))
]
# If the problem has been fixed, the judgment
# needs to be deleted!!!
if attrs[0]['data_format'] == "NHWC":
return False
return True return True
def sample_program_config(self, draw): def sample_program_config(self, draw):
......
...@@ -27,15 +27,6 @@ import hypothesis.strategies as st ...@@ -27,15 +27,6 @@ import hypothesis.strategies as st
class TestConvGeluMkldnnFusePass(PassAutoScanTest): class TestConvGeluMkldnnFusePass(PassAutoScanTest):
def is_program_valid(self, program_config: ProgramConfig) -> bool: def is_program_valid(self, program_config: ProgramConfig) -> bool:
attrs = [
program_config.ops[i].attrs
for i in range(len(program_config.ops))
]
# If the problem has been fixed, the judgment
# needs to be deleted!!!
if attrs[0]['data_format'] == "NHWC":
return False
return True return True
def sample_program_config(self, draw): def sample_program_config(self, draw):
...@@ -108,19 +99,6 @@ class TestConvGeluMkldnnFusePass(PassAutoScanTest): ...@@ -108,19 +99,6 @@ class TestConvGeluMkldnnFusePass(PassAutoScanTest):
config = self.create_inference_config(use_mkldnn=True) config = self.create_inference_config(use_mkldnn=True)
yield config, ["conv2d"], (1e-5, 1e-5) yield config, ["conv2d"], (1e-5, 1e-5)
# If the problem has been fixed, the judgment
# needs to be deleted!!!
def add_ignore_pass_case(self):
def teller1(program_config, predictor_config):
if program_config.ops[0].attrs['data_format'] == "NHWC":
return True
return False
self.add_ignore_check_case(
teller1, SkipReasons.PASS_ACCURACY_ERROR,
"The output format of conv2d is wrong when data_format attribute is NHWC"
)
def test(self): def test(self):
self.run_and_statis(quant=False, passes=["conv_gelu_mkldnn_fuse_pass"]) self.run_and_statis(quant=False, passes=["conv_gelu_mkldnn_fuse_pass"])
......
...@@ -27,15 +27,6 @@ import hypothesis.strategies as st ...@@ -27,15 +27,6 @@ import hypothesis.strategies as st
class TestConvHardSigmoidMkldnnFusePass(PassAutoScanTest): class TestConvHardSigmoidMkldnnFusePass(PassAutoScanTest):
def is_program_valid(self, program_config: ProgramConfig) -> bool: def is_program_valid(self, program_config: ProgramConfig) -> bool:
attrs = [
program_config.ops[i].attrs
for i in range(len(program_config.ops))
]
# If the problem has been fixed, the judgment
# needs to be deleted!!!
if attrs[0]['data_format'] == "NHWC":
return False
return True return True
def sample_program_config(self, draw): def sample_program_config(self, draw):
......
...@@ -27,15 +27,6 @@ import hypothesis.strategies as st ...@@ -27,15 +27,6 @@ import hypothesis.strategies as st
class TestConvHardSwishMkldnnFusePass(PassAutoScanTest): class TestConvHardSwishMkldnnFusePass(PassAutoScanTest):
def is_program_valid(self, program_config: ProgramConfig) -> bool: def is_program_valid(self, program_config: ProgramConfig) -> bool:
attrs = [
program_config.ops[i].attrs
for i in range(len(program_config.ops))
]
# If the problem has been fixed, the judgment
# needs to be deleted!!!
if attrs[0]['data_format'] == "NHWC":
return False
return True return True
def sample_program_config(self, draw): def sample_program_config(self, draw):
......
...@@ -32,9 +32,9 @@ class TestConvTransposeMkldnnFusePass(PassAutoScanTest): ...@@ -32,9 +32,9 @@ class TestConvTransposeMkldnnFusePass(PassAutoScanTest):
for i in range(len(program_config.ops)) for i in range(len(program_config.ops))
] ]
# If the problem has been fixed, the judgment if attrs[0]['data_format'] == "NCHW" and attrs[1]["axis"] == 3:
# needs to be deleted!!! return False
if attrs[0]['data_format'] == "NHWC": if attrs[0]['data_format'] == "NHWC" and attrs[1]["axis"] == 1:
return False return False
return True return True
...@@ -46,7 +46,7 @@ class TestConvTransposeMkldnnFusePass(PassAutoScanTest): ...@@ -46,7 +46,7 @@ class TestConvTransposeMkldnnFusePass(PassAutoScanTest):
groups = draw(st.sampled_from([1, 2, 4, 8])) groups = draw(st.sampled_from([1, 2, 4, 8]))
paddings = draw(st.sampled_from([[0, 3], [1, 2, 3, 4]])) paddings = draw(st.sampled_from([[0, 3], [1, 2, 3, 4]]))
strides = draw(st.sampled_from([[1, 1], [2, 2], [1, 2]])) strides = draw(st.sampled_from([[1, 1], [2, 2], [1, 2]]))
axis = draw(st.sampled_from([1])) axis = draw(st.sampled_from([1, 3]))
batch_size = draw(st.integers(min_value=1, max_value=4)) batch_size = draw(st.integers(min_value=1, max_value=4))
def generate_input(): def generate_input():
...@@ -110,7 +110,9 @@ class TestConvTransposeMkldnnFusePass(PassAutoScanTest): ...@@ -110,7 +110,9 @@ class TestConvTransposeMkldnnFusePass(PassAutoScanTest):
def test(self): def test(self):
self.run_and_statis( self.run_and_statis(
quant=False, passes=["conv_transpose_bias_mkldnn_fuse_pass"]) quant=False,
max_duration=300,
passes=["conv_transpose_bias_mkldnn_fuse_pass"])
if __name__ == "__main__": if __name__ == "__main__":
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册