From f664a533f022de6890d14b7b14e8f874a9e4dcd2 Mon Sep 17 00:00:00 2001 From: baoachun <962571062@qq.com> Date: Mon, 27 Dec 2021 17:12:12 +0800 Subject: [PATCH] add matmulv2_transpose_reshape_pass ut (#37416) * update mkldnn matmul_v2_transpose_reshape_fuse_pass ut * update mkldnn matmul_v2_transpose_reshape_fuse_pass ut * update ut * update ut --- paddle/fluid/operators/matmul_v2_op.cc | 25 ++- .../unittests/ir/inference/CMakeLists.txt | 1 + ...n_matmul_v2_transpose_reshape_fuse_pass.py | 197 ++++++++++++------ 3 files changed, 159 insertions(+), 64 deletions(-) diff --git a/paddle/fluid/operators/matmul_v2_op.cc b/paddle/fluid/operators/matmul_v2_op.cc index 24201b1ba8..5add86f5b3 100644 --- a/paddle/fluid/operators/matmul_v2_op.cc +++ b/paddle/fluid/operators/matmul_v2_op.cc @@ -194,9 +194,32 @@ class MatMulV2Op : public framework::OperatorWithKernel { "received %d", reshape_out_size)); - auto it = std::find(reshape_out.begin(), reshape_out.end(), -1); + // int num_negative = std::count(reshape_out.begin(), reshape_out.end(), + // -1); + // PADDLE_ENFORCE_LE(num_negative, 1, + // platform::errors::InvalidArgument( + // "The max number of -1 in fused_reshape_Out is 1 " + // "but received %d.", + // num_negative)); + + // auto it_zero = std::find(reshape_out.begin(), reshape_out.end(), 0); + // if (it_zero != reshape_out.end()) { + // for (uint64_t i = 0; i < reshape_out.size(); i++) { + // if (reshape_out[i] == 0) { + // PADDLE_ENFORCE_LT( + // i, ddim_out.size(), + // platform::errors::InvalidArgument( + // "The index of 0 in fused_reshape_Out ", + // "should be less than output dim size, ", + // "but the index is %d and output dim size is %d", i, + // ddim_out.size())); + // reshape_out[i] = ddim_out.at(i); + // } + // } + // } // if "-1" is present then one of reshape dims must be infered + auto it = std::find(reshape_out.begin(), reshape_out.end(), -1); if (it != reshape_out.end()) { int index = std::distance(reshape_out.begin(), it); diff --git a/python/paddle/fluid/tests/unittests/ir/inference/CMakeLists.txt b/python/paddle/fluid/tests/unittests/ir/inference/CMakeLists.txt index e69328f5fc..347293b4a4 100755 --- a/python/paddle/fluid/tests/unittests/ir/inference/CMakeLists.txt +++ b/python/paddle/fluid/tests/unittests/ir/inference/CMakeLists.txt @@ -101,6 +101,7 @@ if (WITH_MKLDNN) set_tests_properties(test_mkldnn_conv_hard_sigmoid_fuse_pass PROPERTIES TIMEOUT 300) set_tests_properties(test_mkldnn_conv_hard_swish_fuse_pass PROPERTIES TIMEOUT 300) set_tests_properties(test_mkldnn_batch_norm_act_fuse_pass PROPERTIES TIMEOUT 100) + set_tests_properties(test_mkldnn_matmul_v2_transpose_reshape_fuse_pass PROPERTIES TIMEOUT 100) set_tests_properties(test_mkldnn_conv_transpose_bias_fuse_pass PROPERTIES TIMEOUT 100) set_tests_properties(test_conv_eltwiseadd_bn_fuse_pass PROPERTIES TIMEOUT 300) endif() diff --git a/python/paddle/fluid/tests/unittests/ir/inference/test_mkldnn_matmul_v2_transpose_reshape_fuse_pass.py b/python/paddle/fluid/tests/unittests/ir/inference/test_mkldnn_matmul_v2_transpose_reshape_fuse_pass.py index 698e399c71..ffdc84b8bd 100644 --- a/python/paddle/fluid/tests/unittests/ir/inference/test_mkldnn_matmul_v2_transpose_reshape_fuse_pass.py +++ b/python/paddle/fluid/tests/unittests/ir/inference/test_mkldnn_matmul_v2_transpose_reshape_fuse_pass.py @@ -12,71 +12,142 @@ # See the License for the specific language governing permissions and # limitations under the License. -from __future__ import print_function - -import unittest +from auto_scan_test import PassAutoScanTest, SkipReasons +from program_config import TensorConfig, ProgramConfig, OpConfig import numpy as np -from inference_pass_test import InferencePassTest -import paddle -import paddle.fluid as fluid -import paddle.fluid.core as core -from paddle.fluid.core import AnalysisConfig -from paddle.fluid.core import PassVersionChecker - - -class TestMatmulV2OneDNNTransposeReshapeFusePass(InferencePassTest): - def setUp(self): - self.set_params() - self.tranpose_perm = [0, 2, 1, 3] - self.pass_name = 'matmul_v2_transpose_reshape_fuse_pass' - - with fluid.program_guard(self.main_program, self.startup_program): - data = fluid.data( - name="data", shape=self.data_shape, dtype="float32") - weight = fluid.layers.create_parameter( - shape=self.weight_shape, dtype="float32") - matmul = paddle.matmul( - data, - weight, - transpose_x=self.transpose_x, - transpose_y=self.transpose_y) - transpose = fluid.layers.transpose(matmul, self.tranpose_perm) - reshape = fluid.layers.reshape(transpose, shape=self.reshape_shape) - - self.fetch_list = [reshape] - self.enable_mkldnn = True - - def set_params(self): - self.data_shape = [-1, 3, 100, 110] - self.weight_shape = [1, 3, 110, 100] - self.feeds = { - "data": np.random.random((1, 3, 100, 110)).astype("float32") - } - self.transpose_x = False - self.transpose_y = False - self.reshape_shape = [3, 100, 100] - - def test_check_output(self): - use_gpu = False - self.check_output_with_option(use_gpu) - - def test_pass_compatible(self): - self.assertTrue(PassVersionChecker.IsCompatible(self.pass_name)) - - -class TestMatmulV2OneDNNTransposeReshapeFusePassDifferentDims( - TestMatmulV2OneDNNTransposeReshapeFusePass): - def set_params(self): - self.data_shape = [-1, 4, 100, 80] - self.weight_shape = [1, 4, 80, 100] - self.feeds = { - "data": np.random.random((1, 4, 100, 80)).astype("float32") - } - self.transpose_x = True - self.transpose_y = True - self.reshape_shape = [8, 40, 80] +import paddle.inference as paddle_infer +from functools import partial +from typing import Optional, List, Callable, Dict, Any, Set +import unittest + +import hypothesis +from hypothesis import given, settings, seed, example, assume +import hypothesis.strategies as st + + +class TestMatmulv2TransposeReshapeMkldnnFusePass(PassAutoScanTest): + def is_program_valid(self, program_config: ProgramConfig) -> bool: + if program_config.inputs["input_data1"].shape[ + -4] != 1 and program_config.inputs["input_data2"].shape[ + -4] != 1: + if program_config.inputs["input_data1"].shape[ + -4] != program_config.inputs["input_data2"].shape[-4]: + return False + + if program_config.inputs["input_data1"].shape[ + -3] != 1 and program_config.inputs["input_data2"].shape[ + -3] != 1: + if program_config.inputs["input_data1"].shape[ + -3] != program_config.inputs["input_data2"].shape[-3]: + return False + + attrs = [ + program_config.ops[i].attrs + for i in range(len(program_config.ops)) + ] + # If the problem has been fixed, the judgment + # needs to be deleted!!! + if 0 in attrs[2]['shape']: + return False + + return True + + def sample_program_config(self, draw): + transpose_X = draw(st.booleans()) + transpose_Y = draw(st.booleans()) + axis = draw(st.sampled_from([[0, 2, 1, 3]])) + shape = draw(st.sampled_from([[0, -1, 128], [-1, 1, 64], [1, -1, 32]])) + batch_size1 = draw(st.integers(min_value=1, max_value=4)) + batch_size2 = draw(st.integers(min_value=1, max_value=4)) + channel1 = draw(st.sampled_from([1, 16, 32, 64])) + channel2 = draw(st.sampled_from([1, 16, 32, 64])) + input_dim = draw(st.sampled_from([16, 32, 64])) + + def generate_input(type): + if transpose_X and transpose_Y: + shape_x = [batch_size1, channel1, input_dim, 32] + shape_y = [batch_size2, channel2, 64, input_dim] + elif transpose_X: + shape_x = [batch_size1, channel1, input_dim, 32] + shape_y = [batch_size2, channel2, input_dim, 64] + elif transpose_Y: + shape_x = [batch_size1, channel1, 32, input_dim] + shape_y = [batch_size2, channel2, 8, input_dim] + else: + shape_x = [batch_size1, channel1, 32, input_dim] + shape_y = [batch_size2, channel2, input_dim, 16] + + if type == "x": + return np.random.random(shape_x).astype(np.float32) + else: + return np.random.random(shape_y).astype(np.float32) + + matmul_op = OpConfig( + type="matmul_v2", + inputs={"X": ["input_data1"], + "Y": ["input_data2"]}, + outputs={"Out": ["matmul_output"]}, + attrs={ + "trans_x": transpose_X, + "trans_y": transpose_Y, + "fused_reshape_X": [], + "fused_reshape_Y": [], + "fused_transpose_X": [], + "fused_transpose_Y": [], + "fused_reshape_Out": [], + "fused_transpose_Out": [] + }) + + transpose2_op = OpConfig( + type="transpose2", + inputs={"X": ["matmul_output"]}, + outputs={ + "Out": ["transpose2_output"], + "XShape": ["transpose2_xshape"] + }, + attrs={'axis': axis}) + + reshape2_op = OpConfig( + type="reshape2", + inputs={"X": ["transpose2_output"]}, + outputs={ + "Out": ["reshape2_output"], + "XShape": ["reshape2_xshape"] + }, + attrs={'shape': shape}) + + model_net = [matmul_op, transpose2_op, reshape2_op] + + program_config = ProgramConfig( + ops=model_net, + weights={}, + inputs={ + "input_data1": + TensorConfig(data_gen=partial(generate_input, "x")), + "input_data2": + TensorConfig(data_gen=partial(generate_input, "y")) + }, + outputs=["reshape2_output"]) + + return program_config + + def sample_predictor_configs(self, program_config): + # map_matmul_v2_to_matmul_pass will affect the type of final fused op + fused_op = "matmul_v2" + input1_dim1 = program_config.inputs["input_data1"].shape[0] + input2_dim1 = program_config.inputs["input_data2"].shape[0] + input1_dim2 = program_config.inputs["input_data1"].shape[1] + input2_dim2 = program_config.inputs["input_data2"].shape[1] + if input1_dim1 == input2_dim1 and input1_dim2 == input2_dim2: + fused_op = "matmul" + + config = self.create_inference_config(use_mkldnn=True) + yield config, [fused_op], (1e-5, 1e-5) + + def test(self): + self.run_and_statis( + quant=False, passes=["matmul_v2_transpose_reshape_fuse_pass"]) if __name__ == "__main__": - paddle.enable_static() unittest.main() -- GitLab