From aadc867467adfb16044c35c0ca6dcac4344d1e90 Mon Sep 17 00:00:00 2001 From: baoachun <962571062@qq.com> Date: Tue, 21 Dec 2021 19:51:59 +0800 Subject: [PATCH] update squared_mat_sub_fuse_pass ut (#37838) * update squared_mat_sub_fuse_pass ut * update ut * update ut --- .../framework/ir/squared_mat_sub_fuse_pass.cc | 7 +- .../test_squared_mat_sub_fuse_pass.py | 181 ++++++++++++++---- 2 files changed, 149 insertions(+), 39 deletions(-) diff --git a/paddle/fluid/framework/ir/squared_mat_sub_fuse_pass.cc b/paddle/fluid/framework/ir/squared_mat_sub_fuse_pass.cc index 62f1db426c4..7c43b022182 100644 --- a/paddle/fluid/framework/ir/squared_mat_sub_fuse_pass.cc +++ b/paddle/fluid/framework/ir/squared_mat_sub_fuse_pass.cc @@ -398,8 +398,7 @@ SquaredMatSubFusePass::SquaredMatSubFusePass() { .IsTensor() .End() .AddAttr("alpha") - .IsNumGE(0.99f) - .IsNumLE(1.01f) + .IsNumEQ(1.0f) .End() .AddAttr("transpose_X") .IsBoolEQ(false) @@ -465,6 +464,10 @@ SquaredMatSubFusePass::SquaredMatSubFusePass() { .End() // type:float,there is no restriction .AddAttr("value") + .End() + .AddAttr("str_value") + .IsStringEQ("") + .IsOptional() .End(); } diff --git a/python/paddle/fluid/tests/unittests/ir/inference/test_squared_mat_sub_fuse_pass.py b/python/paddle/fluid/tests/unittests/ir/inference/test_squared_mat_sub_fuse_pass.py index 69a9ae3c0ad..64166daa91f 100644 --- a/python/paddle/fluid/tests/unittests/ir/inference/test_squared_mat_sub_fuse_pass.py +++ b/python/paddle/fluid/tests/unittests/ir/inference/test_squared_mat_sub_fuse_pass.py @@ -1,4 +1,4 @@ -# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -12,53 +12,160 @@ # See the License for the specific language governing permissions and # limitations under the License. -from __future__ import print_function - -import unittest +from auto_scan_test import PassAutoScanTest, SkipReasons +from program_config import TensorConfig, ProgramConfig, OpConfig import numpy as np -from inference_pass_test import InferencePassTest -import paddle -import paddle.fluid as fluid -import paddle.fluid.core as core -from paddle.fluid.core import AnalysisConfig -from paddle.fluid.core import PassVersionChecker +import paddle.inference as paddle_infer +from functools import partial +from typing import Optional, List, Callable, Dict, Any, Set +import unittest + +import hypothesis +from hypothesis import given, settings, seed, example, assume +import hypothesis.strategies as st + + +class TestSquaredMatSubFusePass(PassAutoScanTest): + def is_program_valid(self, program_config: ProgramConfig) -> bool: + return True + + def sample_program_config(self, draw): + transpose_X = False + transpose_Y = False + alpha1 = 1.0 + alpha2 = 1.0 + axis1 = draw(st.sampled_from([-1, 0])) + place_type = draw(st.sampled_from([-1, 0])) + has_str_value = draw(st.booleans()) + str_value = '' + value = draw(st.floats(min_value=-10, max_value=10)) + shape = draw(st.sampled_from([[1]])) + axis2 = draw(st.sampled_from([-1, 0])) + input_dim = draw(st.sampled_from([32, 64])) + + def generate_input(type): + shape_x = [32, input_dim] + shape_y = [input_dim, 16] + + if type == "x": + return np.random.random(shape_x).astype(np.float32) + else: + return np.random.random(shape_y).astype(np.float32) + + matmul_op1 = OpConfig( + type="matmul", + inputs={"X": ["input_data1"], + "Y": ["input_data2"]}, + outputs={"Out": ["matmul1_output"]}, + attrs={ + "transpose_X": transpose_X, + "transpose_Y": transpose_Y, + "alpha": alpha1, + "fused_reshape_X": [], + "fused_reshape_Y": [], + "fused_transpose_X": [], + "fused_transpose_Y": [], + "fused_reshape_Out": [], + "fused_transpose_Out": [] + }) + + square_op1 = OpConfig( + type="square", + inputs={"X": ["matmul1_output"]}, + outputs={"Out": ["square1_output"]}, + attrs={}) + + square_op2 = OpConfig( + type="square", + inputs={"X": ["input_data1"]}, + outputs={"Out": ["square2_output"]}, + attrs={}) + square_op3 = OpConfig( + type="square", + inputs={"X": ["input_data2"]}, + outputs={"Out": ["square3_output"]}, + attrs={}) -class SquaredMatSubFusePassTest(InferencePassTest): - def setUp(self): - with fluid.program_guard(self.main_program, self.startup_program): - data_a = fluid.data(name="data_a", shape=[128, 1], dtype="float32") - data_b = fluid.data(name="data_b", shape=[256, 1], dtype="float32") + matmul_op2 = OpConfig( + type="matmul", + inputs={"X": ["square2_output"], + "Y": ["square3_output"]}, + outputs={"Out": ["matmul2_output"]}, + attrs={ + "transpose_X": transpose_X, + "transpose_Y": transpose_Y, + "alpha": alpha2, + "fused_reshape_X": [], + "fused_reshape_Y": [], + "fused_transpose_X": [], + "fused_transpose_Y": [], + "fused_reshape_Out": [], + "fused_transpose_Out": [] + }) - fc_a = fluid.layers.fc(data_a, size=256) - fc_b = fluid.layers.fc(data_b, size=64) + elt_sub_op = OpConfig( + type="elementwise_sub", + inputs={"X": ["square1_output"], + "Y": ["matmul2_output"]}, + outputs={"Out": ["sub_out"]}, + attrs={"axis": axis1}) - data_a_square = paddle.square(fc_a) - data_b_square = paddle.square(fc_b) + if has_str_value: + fill_constant_op = OpConfig( + type="fill_constant", + inputs={}, + outputs={"Out": ["constant_out"]}, + attrs={ + "dtype": 5, + "place_type": place_type, + "str_value": str_value, + "value": value, + "shape": shape + }) + else: + fill_constant_op = OpConfig( + type="fill_constant", + inputs={}, + outputs={"Out": ["constant_out"]}, + attrs={ + "dtype": 5, + "place_type": place_type, + "value": value, + "shape": shape + }) - matmul_ab = paddle.matmul(fc_a, fc_b) - matmul_ab_square = paddle.square(matmul_ab) - matmul_square_ab = paddle.matmul(data_a_square, data_b_square) + elt_mul_op = OpConfig( + type="elementwise_mul", + inputs={"X": ["sub_out"], + "Y": ["constant_out"]}, + outputs={"Out": ["mul_out"]}, + attrs={"axis": axis2}) - scale = paddle.fluid.layers.fill_constant( - shape=[1], value=0.5, dtype='float32') + model_net = [ + matmul_op1, square_op1, square_op2, square_op3, matmul_op2, + elt_sub_op, fill_constant_op, elt_mul_op + ] - sub_val = paddle.fluid.layers.elementwise_sub(matmul_ab_square, - matmul_square_ab) - squared_mat_sub_out = fluid.layers.elementwise_mul(sub_val, scale) + program_config = ProgramConfig( + ops=model_net, + weights={}, + inputs={ + "input_data1": + TensorConfig(data_gen=partial(generate_input, "x")), + "input_data2": + TensorConfig(data_gen=partial(generate_input, "y")) + }, + outputs=["mul_out"]) - self.feeds = { - "data_a": np.random.random((128, 1)).astype("float32"), - "data_b": np.random.random((256, 1)).astype("float32") - } - self.fetch_list = [squared_mat_sub_out] + return program_config - def test_check_output(self): - use_gpu = False - self.check_output_with_option(use_gpu) + def sample_predictor_configs(self, program_config): + config = self.create_inference_config() + yield config, ["fusion_squared_mat_sub"], (1e-5, 1e-5) - self.assertTrue( - PassVersionChecker.IsCompatible('squared_mat_sub_fuse_pass')) + def test(self): + self.run_and_statis(quant=False, passes=["squared_mat_sub_fuse_pass"]) if __name__ == "__main__": -- GitLab