未验证 提交 b4d3597a 编写于 作者: B baoachun 提交者: GitHub

update inference ut to support nhwc format (#39551)

* update inference ut to support nhwc format

* update ut and pass OpCompat

* update ut

* update ut
上级 1354652b
......@@ -157,7 +157,7 @@ ConvActivationFusePass::ConvActivationFusePass() {
// IsStringIn({"NHWC", "NCHW"}) MobileNetV2 has no this attribute
.AddAttr("data_format")
.IsOptional()
.IsStringIn({"NCHW", "AnyLayout"})
.IsStringIn({"NCHW", "NHWC", "AnyLayout"})
.End();
AddOpCompat(OpCompat("relu"))
......
......@@ -115,7 +115,7 @@ Conv2DTransposeBiasFusePass::Conv2DTransposeBiasFusePass() {
.IsStringIn({"EXPLICIT", "SAME", "VALID"})
.End()
.AddAttr("data_format")
.IsStringIn({"NCHW"})
.IsStringIn({"NCHW", "NHWC", "AnyLayout"})
.End();
AddOpCompat(OpCompat("elementwise_add"))
......@@ -129,7 +129,7 @@ Conv2DTransposeBiasFusePass::Conv2DTransposeBiasFusePass() {
.IsTensor()
.End()
.AddAttr("axis")
.IsIntIn({1})
.IsIntIn({1, 3})
.End();
}
......
......@@ -59,7 +59,7 @@ ConvConcatReLUFusePass::ConvConcatReLUFusePass() {
.IsType<std::vector<int>>()
.End()
.AddAttr("data_format")
.IsStringIn({"NCHW"})
.IsStringIn({"NCHW", "NHWC", "AnyLayout"})
.End();
AddOpCompat(OpCompat("concat"))
......
......@@ -53,8 +53,6 @@ class TestConvActMkldnnFusePass(PassAutoScanTest):
data_format = prog_config.ops[0].attrs["data_format"]
filter_shape = prog_config.weights["filter"].shape
input_shape = prog_config.inputs["input_x"].shape
if data_format != "NCHW":
return False
if padding_algorithm == "VALID":
if ((input_shape[2] - (dilations[0] * (filter_shape[2] - 1) + 1)) / strides[0] + 1) <= 1 or \
((input_shape[3] - (dilations[1] * (filter_shape[3] - 1) + 1)) / strides[1] + 1) <= 1:
......@@ -80,8 +78,8 @@ class TestConvActMkldnnFusePass(PassAutoScanTest):
x_shape = draw(
st.lists(
st.integers(
min_value=1, max_value=100), min_size=4, max_size=4))
x_shape[1] = draw(st.integers(min_value=1, max_value=10))
min_value=5, max_value=100), min_size=4, max_size=4))
x_shape[1] = draw(st.integers(min_value=5, max_value=10))
# 2. Generate legal attr:data_format of conv2d
data_format = draw(st.sampled_from(["NCHW", "NHWC"]))
......@@ -90,7 +88,7 @@ class TestConvActMkldnnFusePass(PassAutoScanTest):
f_shape = draw(
st.lists(
st.integers(
min_value=1, max_value=7), min_size=4, max_size=4))
min_value=1, max_value=5), min_size=4, max_size=4))
if data_format == "NCHW":
f_shape[1] = x_shape[1]
else:
......
......@@ -53,8 +53,6 @@ class TestConvBiasMkldnnFusePass(PassAutoScanTest):
data_format = prog_config.ops[0].attrs["data_format"]
filter_shape = prog_config.weights["filter"].shape
input_shape = prog_config.inputs["input_x"].shape
if data_format != "NCHW":
return False
if padding_algorithm == "VALID":
if ((input_shape[2] - (dilations[0] * (filter_shape[2] - 1) + 1)) / strides[0] + 1) <= 1 or \
((input_shape[3] - (dilations[1] * (filter_shape[3] - 1) + 1)) / strides[1] + 1) <= 1:
......@@ -80,8 +78,8 @@ class TestConvBiasMkldnnFusePass(PassAutoScanTest):
x_shape = draw(
st.lists(
st.integers(
min_value=1, max_value=100), min_size=4, max_size=4))
x_shape[1] = draw(st.integers(min_value=1, max_value=10))
min_value=5, max_value=100), min_size=4, max_size=4))
x_shape[1] = draw(st.integers(min_value=5, max_value=10))
# 2. Generate legal attr:data_format of conv2d
data_format = draw(st.sampled_from(["NCHW", "NHWC"]))
......@@ -90,7 +88,7 @@ class TestConvBiasMkldnnFusePass(PassAutoScanTest):
f_shape = draw(
st.lists(
st.integers(
min_value=1, max_value=7), min_size=4, max_size=4))
min_value=1, max_value=4), min_size=4, max_size=4))
if data_format == "NCHW":
f_shape[1] = x_shape[1]
else:
......@@ -100,7 +98,7 @@ class TestConvBiasMkldnnFusePass(PassAutoScanTest):
strides = draw(
st.lists(
st.integers(
min_value=1, max_value=5), min_size=2, max_size=2))
min_value=1, max_value=4), min_size=2, max_size=2))
# 5. Generate legal attr:padding_algorithm of conv2d
padding_algorithm = draw(st.sampled_from(["EXPLICIT", "SAME", "VALID"]))
......@@ -109,7 +107,7 @@ class TestConvBiasMkldnnFusePass(PassAutoScanTest):
padding = draw(
st.lists(
st.integers(
min_value=1, max_value=5), min_size=4, max_size=4))
min_value=1, max_value=4), min_size=4, max_size=4))
# 7. Generate legal attr:groups of conv2d
groups = draw(st.integers(min_value=1, max_value=3))
......@@ -118,7 +116,7 @@ class TestConvBiasMkldnnFusePass(PassAutoScanTest):
dilations = draw(
st.lists(
st.integers(
min_value=1, max_value=5), min_size=2, max_size=2))
min_value=1, max_value=4), min_size=2, max_size=2))
# 9. Generate legal shape of input:bias of elementwise_add
bias_shape = [f_shape[0]]
......
......@@ -27,15 +27,6 @@ import hypothesis.strategies as st
class TestConvConcatReluMkldnnFusePass(PassAutoScanTest):
def is_program_valid(self, program_config: ProgramConfig) -> bool:
attrs = [
program_config.ops[i].attrs
for i in range(len(program_config.ops))
]
# If the problem has been fixed, the judgment
# needs to be deleted!!!
if attrs[0]['data_format'] == "NHWC":
return False
return True
def sample_program_config(self, draw):
......
......@@ -27,15 +27,6 @@ import hypothesis.strategies as st
class TestConvGeluMkldnnFusePass(PassAutoScanTest):
def is_program_valid(self, program_config: ProgramConfig) -> bool:
attrs = [
program_config.ops[i].attrs
for i in range(len(program_config.ops))
]
# If the problem has been fixed, the judgment
# needs to be deleted!!!
if attrs[0]['data_format'] == "NHWC":
return False
return True
def sample_program_config(self, draw):
......@@ -108,19 +99,6 @@ class TestConvGeluMkldnnFusePass(PassAutoScanTest):
config = self.create_inference_config(use_mkldnn=True)
yield config, ["conv2d"], (1e-5, 1e-5)
# If the problem has been fixed, the judgment
# needs to be deleted!!!
def add_ignore_pass_case(self):
def teller1(program_config, predictor_config):
if program_config.ops[0].attrs['data_format'] == "NHWC":
return True
return False
self.add_ignore_check_case(
teller1, SkipReasons.PASS_ACCURACY_ERROR,
"The output format of conv2d is wrong when data_format attribute is NHWC"
)
def test(self):
self.run_and_statis(quant=False, passes=["conv_gelu_mkldnn_fuse_pass"])
......
......@@ -27,15 +27,6 @@ import hypothesis.strategies as st
class TestConvHardSigmoidMkldnnFusePass(PassAutoScanTest):
def is_program_valid(self, program_config: ProgramConfig) -> bool:
attrs = [
program_config.ops[i].attrs
for i in range(len(program_config.ops))
]
# If the problem has been fixed, the judgment
# needs to be deleted!!!
if attrs[0]['data_format'] == "NHWC":
return False
return True
def sample_program_config(self, draw):
......
......@@ -27,15 +27,6 @@ import hypothesis.strategies as st
class TestConvHardSwishMkldnnFusePass(PassAutoScanTest):
def is_program_valid(self, program_config: ProgramConfig) -> bool:
attrs = [
program_config.ops[i].attrs
for i in range(len(program_config.ops))
]
# If the problem has been fixed, the judgment
# needs to be deleted!!!
if attrs[0]['data_format'] == "NHWC":
return False
return True
def sample_program_config(self, draw):
......
......@@ -32,9 +32,9 @@ class TestConvTransposeMkldnnFusePass(PassAutoScanTest):
for i in range(len(program_config.ops))
]
# If the problem has been fixed, the judgment
# needs to be deleted!!!
if attrs[0]['data_format'] == "NHWC":
if attrs[0]['data_format'] == "NCHW" and attrs[1]["axis"] == 3:
return False
if attrs[0]['data_format'] == "NHWC" and attrs[1]["axis"] == 1:
return False
return True
......@@ -46,7 +46,7 @@ class TestConvTransposeMkldnnFusePass(PassAutoScanTest):
groups = draw(st.sampled_from([1, 2, 4, 8]))
paddings = draw(st.sampled_from([[0, 3], [1, 2, 3, 4]]))
strides = draw(st.sampled_from([[1, 1], [2, 2], [1, 2]]))
axis = draw(st.sampled_from([1]))
axis = draw(st.sampled_from([1, 3]))
batch_size = draw(st.integers(min_value=1, max_value=4))
def generate_input():
......@@ -110,7 +110,9 @@ class TestConvTransposeMkldnnFusePass(PassAutoScanTest):
def test(self):
self.run_and_statis(
quant=False, passes=["conv_transpose_bias_mkldnn_fuse_pass"])
quant=False,
max_duration=300,
passes=["conv_transpose_bias_mkldnn_fuse_pass"])
if __name__ == "__main__":
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册