未验证 提交 02938b3d 编写于 作者: S Sławomir Siwek 提交者: GitHub

[UT] mish op, conv+mish, fc+mish fuse passes (#39340)

* mish unit tests

* code format

* remove unused imports

* code format

* remove hard-coded shape values

* remove timeouts

* remove timeouts v2

* restore timeouts
上级 66b5348e
...@@ -441,6 +441,7 @@ class Quant2Int8MkldnnPass(object): ...@@ -441,6 +441,7 @@ class Quant2Int8MkldnnPass(object):
graph = self._apply_pass(graph, 'conv_relu6_mkldnn_fuse_pass') graph = self._apply_pass(graph, 'conv_relu6_mkldnn_fuse_pass')
graph = self._apply_pass(graph, 'conv_swish_mkldnn_fuse_pass') graph = self._apply_pass(graph, 'conv_swish_mkldnn_fuse_pass')
graph = self._apply_pass(graph, 'conv_hard_swish_mkldnn_fuse_pass') graph = self._apply_pass(graph, 'conv_hard_swish_mkldnn_fuse_pass')
graph = self._apply_pass(graph, 'conv_mish_mkldnn_fuse_pass')
graph = self._apply_pass(graph, 'conv_hard_sigmoid_mkldnn_fuse_pass') graph = self._apply_pass(graph, 'conv_hard_sigmoid_mkldnn_fuse_pass')
graph = self._apply_pass(graph, 'conv_gelu_mkldnn_fuse_pass') graph = self._apply_pass(graph, 'conv_gelu_mkldnn_fuse_pass')
graph = self._apply_pass(graph, 'fc_fuse_pass', graph = self._apply_pass(graph, 'fc_fuse_pass',
......
...@@ -123,6 +123,7 @@ if (WITH_MKLDNN) ...@@ -123,6 +123,7 @@ if (WITH_MKLDNN)
set_tests_properties(test_mkldnn_conv_elementwise_add_fuse_pass PROPERTIES TIMEOUT 120) set_tests_properties(test_mkldnn_conv_elementwise_add_fuse_pass PROPERTIES TIMEOUT 120)
set_tests_properties(test_mkldnn_depthwise_conv_pass PROPERTIES TIMEOUT 120) set_tests_properties(test_mkldnn_depthwise_conv_pass PROPERTIES TIMEOUT 120)
set_tests_properties(test_mkldnn_reshape_transpose_matmul_fuse_pass PROPERTIES TIMEOUT 100) set_tests_properties(test_mkldnn_reshape_transpose_matmul_fuse_pass PROPERTIES TIMEOUT 100)
set_tests_properties(test_mkldnn_mish_op PROPERTIES TIMEOUT 300)
set_tests_properties(test_mkldnn_prelu_op PROPERTIES TIMEOUT 300) set_tests_properties(test_mkldnn_prelu_op PROPERTIES TIMEOUT 300)
set_tests_properties(test_conv_act_mkldnn_fuse_pass PROPERTIES TIMEOUT 120) set_tests_properties(test_conv_act_mkldnn_fuse_pass PROPERTIES TIMEOUT 120)
set_tests_properties(test_conv_transpose_eltwiseadd_bn_fuse_pass PROPERTIES TIMEOUT 250) set_tests_properties(test_conv_transpose_eltwiseadd_bn_fuse_pass PROPERTIES TIMEOUT 250)
...@@ -134,5 +135,7 @@ if (WITH_MKLDNN) ...@@ -134,5 +135,7 @@ if (WITH_MKLDNN)
set_tests_properties(test_mkldnn_matmul_v2_transpose_reshape_fuse_pass PROPERTIES TIMEOUT 100) set_tests_properties(test_mkldnn_matmul_v2_transpose_reshape_fuse_pass PROPERTIES TIMEOUT 100)
set_tests_properties(test_mkldnn_conv_transpose_bias_fuse_pass PROPERTIES TIMEOUT 100) set_tests_properties(test_mkldnn_conv_transpose_bias_fuse_pass PROPERTIES TIMEOUT 100)
set_tests_properties(test_conv_eltwiseadd_bn_fuse_pass PROPERTIES TIMEOUT 300) set_tests_properties(test_conv_eltwiseadd_bn_fuse_pass PROPERTIES TIMEOUT 300)
set_tests_properties(test_mkldnn_conv_mish_fuse_pass PROPERTIES TIMEOUT 300)
set_tests_properties(test_mkldnn_fc_mish_fuse_pass PROPERTIES TIMEOUT 300)
endif() endif()
endif() endif()
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from auto_scan_test import PassAutoScanTest
from program_config import TensorConfig, ProgramConfig
import numpy as np
from functools import partial
import unittest
import hypothesis.strategies as st
class TestConvMishMkldnnFusePass(PassAutoScanTest):
def is_program_valid(self, program_config: ProgramConfig) -> bool:
attrs = [op.attrs for op in program_config.ops]
# If the problem has been fixed, the judgment
# needs to be deleted!!!
if attrs[0]['data_format'] == "NHWC":
return False
return True
def sample_program_config(self, draw):
data_format = draw(st.sampled_from(["NCHW", "NHWC"]))
dilations = draw(st.sampled_from([[1, 1], [2, 2], [1, 2]]))
padding_algorithm = draw(st.sampled_from(["EXPLICIT", "SAME", "VALID"]))
groups = draw(st.sampled_from([1, 2, 4]))
paddings = draw(st.sampled_from([[0, 3], [1, 2, 3, 4]]))
strides = draw(st.sampled_from([[1, 1], [2, 2], [1, 2]]))
batch_size = draw(st.integers(min_value=1, max_value=4))
def generate_input():
if data_format == "NCHW":
return np.random.random(
[batch_size, 48, 64, 64]).astype(np.float32)
else:
return np.random.random(
[batch_size, 64, 64, 48]).astype(np.float32)
def generate_weight():
return np.random.random(
[16, int(48 / groups), 3, 3]).astype(np.float32)
ops_config = [{
"op_type": "conv2d",
"op_inputs": {
"Input": ["input_data"],
"Filter": ["input_weight"]
},
"op_outputs": {
"Output": ["conv_output"]
},
"op_attrs": {
"data_format": data_format,
"dilations": dilations,
"padding_algorithm": padding_algorithm,
"groups": groups,
"paddings": paddings,
"strides": strides
}
}, {
"op_type": "mish",
"op_inputs": {
"X": ["conv_output"]
},
"op_outputs": {
"Out": ["mish_output"]
},
"op_attrs": {},
}]
ops = self.generate_op_config(ops_config)
program_config = ProgramConfig(
ops=ops,
weights={
"input_weight": TensorConfig(data_gen=partial(generate_weight))
},
inputs={
"input_data": TensorConfig(data_gen=partial(generate_input)),
},
outputs=["mish_output"])
return program_config
def sample_predictor_configs(self, program_config):
config = self.create_inference_config(use_mkldnn=True)
yield config, ["conv2d"], (1e-5, 1e-5)
def test(self):
self.run_and_statis(quant=False, passes=["conv_mish_mkldnn_fuse_pass"])
if __name__ == "__main__":
unittest.main()
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from auto_scan_test import PassAutoScanTest
from program_config import TensorConfig, ProgramConfig
import numpy as np
import unittest
import hypothesis.strategies as st
class TestFCMishMkldnnFusePass(PassAutoScanTest):
def sample_program_config(self, draw):
x_shape = draw(
st.lists(
st.integers(
min_value=1, max_value=128), min_size=2, max_size=3))
in_num_col_dims = len(x_shape) - 1
w_shape = draw(
st.lists(
st.integers(
min_value=1, max_value=128), min_size=2, max_size=2))
w_shape[0] = int(np.prod(x_shape[in_num_col_dims:]))
fc_bias_shape = [w_shape[1]]
ops_config = [{
"op_type": "fc",
"op_inputs": {
"Input": ["fc_x"],
"W": ["fc_w"],
"Bias": ["fc_bias"]
},
"op_outputs": {
"Out": ["fc_out"]
},
"op_attrs": {
"activation_type": "",
"padding_weights": False,
"in_num_col_dims": in_num_col_dims,
"use_mkldnn": True
}
}, {
"op_type": "mish",
"op_inputs": {
"X": ["fc_out"]
},
"op_outputs": {
"Out": ["mish_output"]
},
"op_attrs": {},
}]
ops = self.generate_op_config(ops_config)
program_config = ProgramConfig(
ops=ops,
weights={
"fc_w": TensorConfig(shape=w_shape),
"fc_bias": TensorConfig(shape=fc_bias_shape),
},
inputs={"fc_x": TensorConfig(shape=x_shape), },
outputs=["mish_output"])
return program_config
def sample_predictor_configs(self, program_config):
config = self.create_inference_config(
use_mkldnn=True, passes=["fc_act_mkldnn_fuse_pass"])
yield config, ["fc"], (1e-5, 1e-5)
def test(self):
self.run_and_statis(quant=False, passes=["fc_act_mkldnn_fuse_pass"])
if __name__ == "__main__":
unittest.main()
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from auto_scan_test import MkldnnAutoScanTest
from program_config import TensorConfig, ProgramConfig, OpConfig
import numpy as np
from functools import partial
import unittest
from hypothesis import given
import hypothesis.strategies as st
class TestMkldnnMishOp(MkldnnAutoScanTest):
def is_program_valid(self, program_config: ProgramConfig) -> bool:
# if mode is channel, and in_shape is 1 rank
if len(program_config.inputs['input_data'].
shape) == 1 and program_config.ops[0].attrs['mode'] == 'channel':
return False
return True
def sample_program_configs(self, *args, **kwargs):
def generate_input(*args, **kwargs):
return np.random.random(kwargs['in_shape']).astype(np.float32)
mish_op = OpConfig(
type="mish",
inputs={"X": ["input_data"]},
outputs={"Out": ["output_data"]},
attrs={
"mode": kwargs['mode'],
"data_format": kwargs['data_format']
})
program_config = ProgramConfig(
ops=[mish_op],
weights={},
inputs={
"input_data": TensorConfig(data_gen=partial(generate_input,
*args, **kwargs)),
},
outputs=["output_data"])
yield program_config
def sample_predictor_configs(self, program_config):
config = self.create_inference_config(use_mkldnn=True)
yield config, (1e-5, 1e-5)
@given(
mode=st.sampled_from(['all', 'channel', 'element']),
data_format=st.sampled_from(['NCHW', 'NHWC']),
in_shape=st.lists(
st.integers(
min_value=1, max_value=32), min_size=1, max_size=4))
def test(self, *args, **kwargs):
self.run_test(quant=False, *args, **kwargs)
if __name__ == "__main__":
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册