diff --git a/paddle/fluid/framework/ir/CMakeLists.txt b/paddle/fluid/framework/ir/CMakeLists.txt
index a34c2e9aa87348807f89c9c0b2c3ff633f6a7b7f..0d9c460628e17186152462c313937aff5490e723 100755
--- a/paddle/fluid/framework/ir/CMakeLists.txt
+++ b/paddle/fluid/framework/ir/CMakeLists.txt
@@ -97,6 +97,7 @@ pass_library(adaptive_pool2d_convert_global_pass inference)
 pass_library(unsqueeze2_eltwise_fuse_pass inference)
 pass_library(layer_norm_fuse_pass inference)
 pass_library(add_support_int8_pass inference)
+pass_library(matmul_scale_fuse_pass inference)
 pass_library(generate_pass DEPS pass_desc_proto)
 target_link_libraries(generate_pass pass_desc_proto)
 if(WITH_GPU OR WITH_ROCM)
diff --git a/paddle/fluid/framework/ir/graph_pattern_detector.cc b/paddle/fluid/framework/ir/graph_pattern_detector.cc
index 6a5bca7bde47ecb2856b46922f7232dcde02045d..314f791da4f462416f076168ae9f75fc4bd9a37a 100755
--- a/paddle/fluid/framework/ir/graph_pattern_detector.cc
+++ b/paddle/fluid/framework/ir/graph_pattern_detector.cc
@@ -1699,6 +1699,49 @@ PDNode *patterns::MatmulV2::operator()() {
   return matmul_v2_out;
 }
 
+PDNode *patterns::MatmulScale::operator()() {
+  auto matmul_op = pattern->NewNode(matmul_op_repr())->assert_is_op("matmul");
+  auto matmul_in_x = pattern->NewNode(matmul_in_x_repr())
+                         ->AsInput()
+                         ->assert_is_op_input("matmul", "X");
+  auto matmul_in_y = pattern->NewNode(matmul_in_y_repr())
+                         ->AsInput()
+                         ->assert_is_op_input("matmul", "Y");
+  auto scale_op = pattern->NewNode(scale_op_repr())->assert_is_op("scale");
+  auto scale_in_x = pattern->NewNode(scale_in_x_repr())
+                        ->assert_is_op_output("matmul", "Out")
+                        ->assert_is_op_input("scale", "X");
+  auto scale_out = pattern->NewNode(scale_out_repr())
+                       ->AsOutput()
+                       ->assert_is_op_output("scale", "Out");
+  matmul_op->LinksFrom({matmul_in_x, matmul_in_y}).LinksTo({scale_in_x});
+  scale_op->LinksFrom({scale_in_x}).LinksTo({scale_out});
+  return scale_out;
+}
+
+PDNode *patterns::MatmulV2Scale::operator()() {
+  auto matmul_v2_op =
+      pattern->NewNode(matmul_v2_op_repr())->assert_is_op("matmul_v2");
+  auto matmul_v2_in_x = pattern->NewNode(matmul_v2_in_x_repr())
+                            ->AsInput()
+                            ->assert_is_op_input("matmul_v2", "X");
+  auto matmul_v2_in_y = pattern->NewNode(matmul_v2_in_y_repr())
+                            ->AsInput()
+                            ->assert_is_persistable_var()  // Y is weight
+                            ->assert_is_op_input("matmul_v2", "Y");
+  auto scale_op = pattern->NewNode(scale_op_repr())->assert_is_op("scale");
+  auto scale_in_x = pattern->NewNode(scale_in_x_repr())
+                        ->assert_is_op_output("matmul_v2", "Out")
+                        ->assert_is_op_input("scale", "X");
+  auto scale_out = pattern->NewNode(scale_out_repr())
+                       ->AsOutput()
+                       ->assert_is_op_output("scale", "Out");
+  matmul_v2_op->LinksFrom({matmul_v2_in_x, matmul_v2_in_y})
+      .LinksTo({scale_in_x});
+  scale_op->LinksFrom({scale_in_x}).LinksTo({scale_out});
+  return scale_out;
+}
+
 PDNode *patterns::Squeeze2Matmul::operator()() {
   auto squeeze2_in_x = pattern->NewNode(squeeze2_in_x_repr())
                            ->assert_is_op_input("squeeze2", "X")
diff --git a/paddle/fluid/framework/ir/graph_pattern_detector.h b/paddle/fluid/framework/ir/graph_pattern_detector.h
index 7d143129ebd346f6af1c2637566094076306d63d..deaba36ba5da2475fcec02f8adee5486a67aaccd 100644
--- a/paddle/fluid/framework/ir/graph_pattern_detector.h
+++ b/paddle/fluid/framework/ir/graph_pattern_detector.h
@@ -1032,6 +1032,36 @@ struct MatmulV2 : public PatternBase {
   PATTERN_DECL_NODE(matmul_v2_out);
 };
 
+// Matmul + scale
+// Forward pass.
+struct MatmulScale : public PatternBase {
+  MatmulScale(PDPattern* pattern, const std::string& name_scope)
+      : PatternBase(pattern, name_scope, "matmul_scale") {}
+
+  PDNode* operator()();
+  PATTERN_DECL_NODE(matmul_in_x);
+  PATTERN_DECL_NODE(matmul_in_y);
+  PATTERN_DECL_NODE(matmul_op);
+  PATTERN_DECL_NODE(scale_in_x);
+  PATTERN_DECL_NODE(scale_op);
+  PATTERN_DECL_NODE(scale_out);
+};
+
+// Matmul_v2 + scale
+// Forward pass.
+struct MatmulV2Scale : public PatternBase {
+  MatmulV2Scale(PDPattern* pattern, const std::string& name_scope)
+      : PatternBase(pattern, name_scope, "matmul_v2_scale") {}
+
+  PDNode* operator()();
+  PATTERN_DECL_NODE(matmul_v2_in_x);
+  PATTERN_DECL_NODE(matmul_v2_in_y);
+  PATTERN_DECL_NODE(matmul_v2_op);
+  PATTERN_DECL_NODE(scale_in_x);
+  PATTERN_DECL_NODE(scale_op);
+  PATTERN_DECL_NODE(scale_out);
+};
+
 // Squeeze2 + Matmul
 // Forward pass.
 struct Squeeze2Matmul : public PatternBase {
diff --git a/paddle/fluid/framework/ir/matmul_scale_fuse_pass.cc b/paddle/fluid/framework/ir/matmul_scale_fuse_pass.cc
new file mode 100644
index 0000000000000000000000000000000000000000..2335e5eee01dbe234dd2673895245bdb5e9f5898
--- /dev/null
+++ b/paddle/fluid/framework/ir/matmul_scale_fuse_pass.cc
@@ -0,0 +1,258 @@
+// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/fluid/framework/ir/matmul_scale_fuse_pass.h"
+
+#include <cmath>
+#include <string>
+#include "paddle/fluid/framework/ir/graph_pattern_detector.h"
+#include "paddle/fluid/framework/op_proto_maker.h"
+
+#include "paddle/fluid/framework/op_version_registry.h"
+#include "paddle/fluid/platform/enforce.h"
+
+namespace paddle {
+namespace framework {
+namespace ir {
+
+class Node;
+
+MatmulScaleFusePass::MatmulScaleFusePass() {
+  AddOpCompat(OpCompat("matmul"))
+      .AddInput("X")
+      .IsTensor()
+      .End()
+      .AddInput("Y")
+      .IsTensor()
+      .End()
+      .AddOutput("Out")
+      .IsTensor()
+      .End()
+      .AddAttr("transpose_X")
+      .IsType<bool>()
+      .End()
+      .AddAttr("transpose_Y")
+      .IsType<bool>()
+      .End()
+      .AddAttr("alpha")
+      .IsType<float>()
+      .End();
+
+  AddOpCompat(OpCompat("scale"))
+      .AddInput("X")
+      .IsTensor()
+      .End()
+      .AddInput("ScaleTensor")
+      .IsTensor()
+      .IsOptional()
+      .End()
+      .AddOutput("Out")
+      .IsTensor()
+      .End()
+      .AddAttr("bias_after_scale")
+      .IsType<bool>()
+      .End()
+      .AddAttr("scale")
+      .End()
+      .AddAttr("bias")
+      .IsNumEQ(0.0f)
+      .End();
+}
+
+MatmulV2ScaleFusePass::MatmulV2ScaleFusePass() {
+  AddOpCompat(OpCompat("matmul_v2"))
+      .AddInput("X")
+      .IsTensor()
+      .End()
+      .AddInput("Y")
+      .IsTensor()
+      .End()
+      .AddOutput("Out")
+      .IsTensor()
+      .End()
+      .AddAttr("trans_x")
+      .IsType<bool>()
+      .End()
+      .AddAttr("trans_y")
+      .IsType<bool>()
+      .End();
+
+  AddOpCompat(OpCompat("scale"))
+      .AddInput("X")
+      .IsTensor()
+      .End()
+      .AddInput("ScaleTensor")
+      .IsTensor()
+      .IsOptional()
+      .End()
+      .AddOutput("Out")
+      .IsTensor()
+      .End()
+      .AddAttr("bias_after_scale")
+      .IsType<bool>()
+      .End()
+      .AddAttr("scale")
+      .End()
+      .AddAttr("bias")
+      .IsNumEQ(0.0f)
+      .End();
+}
+
+void MatmulScaleFusePass::ApplyImpl(ir::Graph* graph) const {
+  PADDLE_ENFORCE_NOT_NULL(
+      graph, platform::errors::InvalidArgument("Graph cannot be nullptr."));
+  std::string name_scope = "matmul_scale_fuse";
+  FusePassBase::Init(name_scope, graph);
+
+  GraphPatternDetector gpd;
+  patterns::MatmulScale matmul_scale_pattern(gpd.mutable_pattern(),
+                                             name_scope);
+  matmul_scale_pattern();
+
+  int found_count = 0;
+  auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
+                     Graph* g) {
+    VLOG(4) << "matmul_scale_fuse pass";
+    GET_IR_NODE_FROM_SUBGRAPH(matmul_in_x, matmul_in_x, matmul_scale_pattern);
+    GET_IR_NODE_FROM_SUBGRAPH(matmul_in_y, matmul_in_y, matmul_scale_pattern);
+    GET_IR_NODE_FROM_SUBGRAPH(matmul_op, matmul_op, matmul_scale_pattern);
+    GET_IR_NODE_FROM_SUBGRAPH(scale_in_x, scale_in_x, matmul_scale_pattern);
+    GET_IR_NODE_FROM_SUBGRAPH(scale_op, scale_op, matmul_scale_pattern);
+    GET_IR_NODE_FROM_SUBGRAPH(scale_out, scale_out, matmul_scale_pattern);
+
+    auto* scope = param_scope();
+    float bias = BOOST_GET_CONST(float, scale_op->Op()->GetAttr("bias"));
+    if (std::abs(bias) > 1e-5) return;
+    if (!IsCompat(subgraph, g)) {
+      LOG(WARNING) << "matmul_scale_fuse_pass in op compat failed.";
+      return;
+    }
+
+    float scale = BOOST_GET_CONST(float, scale_op->Op()->GetAttr("scale"));
+    float matmul_alpha =
+        BOOST_GET_CONST(float, matmul_op->Op()->GetAttr("alpha"));
+    auto const& names = scale_op->Op()->InputNames();
+    bool has_scale_tensor =
+        std::find(names.begin(), names.end(), "ScaleTensor") != names.end();
+    if (has_scale_tensor && scale_op->Op()->Input("ScaleTensor").size() > 0) {
+      std::string scale_var_name =
+          scale_op->Op()->Input("ScaleTensor").front();
+      auto* scale_var = scope->FindVar(scale_var_name);
+      // ScaleTensor must be weight
+      if (scale_var == nullptr) return;
+      auto* scale_tensor = scale_var->GetMutable<LoDTensor>();
+      scale = *(scale_tensor->data<float>());
+    }
+
+    OpDesc* matmul_desc = matmul_op->Op();
+    matmul_desc->SetAttr("alpha", scale * matmul_alpha);
+    matmul_desc->SetOutput("Out", {scale_out->Name()});
+    if (!IsCompat(*matmul_desc)) {
+      LOG(WARNING) << "matmul_scale_fuse_pass in out mul op compat failed.";
+      return;
+    }
+    IR_NODE_LINK_TO(matmul_op, scale_out);
+    GraphSafeRemoveNodes(graph, {scale_in_x, scale_op});
+    ++found_count;
+  };
+
+  gpd(graph, handler);
+  AddStatis(found_count);
+}
+
+void MatmulV2ScaleFusePass::ApplyImpl(ir::Graph* graph) const {
+  PADDLE_ENFORCE_NOT_NULL(
+      graph, platform::errors::InvalidArgument("Graph cannot be nullptr."));
+  std::string name_scope = "matmul_v2_scale_fuse";
+  FusePassBase::Init(name_scope, graph);
+
+  GraphPatternDetector gpd;
+  patterns::MatmulV2Scale matmul_v2_scale_pattern(gpd.mutable_pattern(),
+                                                  name_scope);
+  matmul_v2_scale_pattern();
+
+  int found_count = 0;
+  auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
+                     Graph* g) {
+    VLOG(4) << "matmul_v2_scale_fuse pass";
+    GET_IR_NODE_FROM_SUBGRAPH(matmul_v2_in_x, matmul_v2_in_x,
+                              matmul_v2_scale_pattern);
+    GET_IR_NODE_FROM_SUBGRAPH(matmul_v2_in_y, matmul_v2_in_y,
+                              matmul_v2_scale_pattern);
+    GET_IR_NODE_FROM_SUBGRAPH(matmul_v2_op, matmul_v2_op,
+                              matmul_v2_scale_pattern);
+    GET_IR_NODE_FROM_SUBGRAPH(scale_in_x, scale_in_x, matmul_v2_scale_pattern);
+    GET_IR_NODE_FROM_SUBGRAPH(scale_op, scale_op, matmul_v2_scale_pattern);
+    GET_IR_NODE_FROM_SUBGRAPH(scale_out, scale_out, matmul_v2_scale_pattern);
+
+    auto* scope = param_scope();
+    float bias = BOOST_GET_CONST(float, scale_op->Op()->GetAttr("bias"));
+    if (std::abs(bias) > 1e-5) return;
+    if (!IsCompat(subgraph, g)) {
+      LOG(WARNING) << "matmul_v2_scale_fuse_pass in op compat failed.";
+      return;
+    }
+
+    float scale = BOOST_GET_CONST(float, scale_op->Op()->GetAttr("scale"));
+    auto const& names = scale_op->Op()->InputNames();
+    bool has_scale_tensor =
+        std::find(names.begin(), names.end(), "ScaleTensor") != names.end();
+    if (has_scale_tensor && scale_op->Op()->Input("ScaleTensor").size() > 0) {
+      std::string scale_var_name =
+          scale_op->Op()->Input("ScaleTensor").front();
+      auto* scale_var = scope->FindVar(scale_var_name);
+      // ScaleTensor must be weight
+      if (scale_var == nullptr) return;
+      auto* scale_tensor = scale_var->GetMutable<LoDTensor>();
+      scale = *(scale_tensor->data<float>());
+    }
+
+    auto* matmul_y =
+        scope->FindVar(matmul_v2_in_y->Name())->GetMutable<LoDTensor>();
+    auto y_data = matmul_y->mutable_data<float>(platform::CPUPlace());
+    for (int i = 0; i < matmul_y->numel(); ++i) {
+      y_data[i] *= scale;
+    }
+
+    OpDesc* matmul_v2_desc = matmul_v2_op->Op();
+    matmul_v2_desc->SetOutput("Out", {scale_out->Name()});
+    if (!IsCompat(*matmul_v2_desc)) {
+      LOG(WARNING) << "matmul_v2_scale_fuse_pass in out mul op compat failed.";
+      return;
+    }
+    IR_NODE_LINK_TO(matmul_v2_op, scale_out);
+    GraphSafeRemoveNodes(graph, {scale_in_x, scale_op});
+    ++found_count;
+  };
+
+  gpd(graph, handler);
+  AddStatis(found_count);
+}
+
+}  // namespace ir
+}  // namespace framework
+}  // namespace paddle
+
+REGISTER_PASS(matmul_scale_fuse_pass,
+              paddle::framework::ir::MatmulScaleFusePass);
+REGISTER_PASS_CAPABILITY(matmul_scale_fuse_pass)
+    .AddCombination(
+        paddle::framework::compatible::OpVersionComparatorCombination()
+            .LE("matmul", 1)
+            .EQ("scale", 0));
+
+REGISTER_PASS(matmul_v2_scale_fuse_pass,
+              paddle::framework::ir::MatmulV2ScaleFusePass);
+REGISTER_PASS_CAPABILITY(matmul_v2_scale_fuse_pass)
+    .AddCombination(
+        paddle::framework::compatible::OpVersionComparatorCombination()
+            .EQ("matmul_v2", 0)
.EQ("scale", 0)); diff --git a/paddle/fluid/framework/ir/matmul_scale_fuse_pass.h b/paddle/fluid/framework/ir/matmul_scale_fuse_pass.h new file mode 100644 index 0000000000000000000000000000000000000000..bdab7121f1b12b99451bebc04161ce7e36dce080 --- /dev/null +++ b/paddle/fluid/framework/ir/matmul_scale_fuse_pass.h @@ -0,0 +1,54 @@ +// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/fluid/framework/ir/fuse_pass_base.h" +#include "paddle/fluid/framework/ir/graph.h" +#include "paddle/fluid/framework/ir/graph_pattern_detector.h" +#include "paddle/fluid/framework/ir/pass.h" + +namespace paddle { +namespace framework { +namespace ir { + +class Graph; + +/* + * Fuse the matmul and scale to a matmul. + */ +class MatmulScaleFusePass : public FusePassBase { + public: + MatmulScaleFusePass(); + virtual ~MatmulScaleFusePass() {} + + protected: + void ApplyImpl(Graph* graph) const override; +}; + +/* + * Fuse the matmul_v2 and scale to a matmul_v2. + */ +class MatmulV2ScaleFusePass : public FusePassBase { + public: + MatmulV2ScaleFusePass(); + virtual ~MatmulV2ScaleFusePass() {} + + protected: + void ApplyImpl(Graph* graph) const override; +}; + +} // namespace ir +} // namespace framework +} // namespace paddle diff --git a/paddle/fluid/inference/api/paddle_pass_builder.cc b/paddle/fluid/inference/api/paddle_pass_builder.cc index e96f182e3b1b0c1f3803d24686f19def657b1240..674b7cdd6993519d685a8d114e0e0b931c7fe3cc 100644 --- a/paddle/fluid/inference/api/paddle_pass_builder.cc +++ b/paddle/fluid/inference/api/paddle_pass_builder.cc @@ -205,8 +205,10 @@ CpuPassStrategy::CpuPassStrategy() : PassStrategy({}) { "squeeze2_matmul_fuse_pass", // "reshape2_matmul_fuse_pass", // "flatten2_matmul_fuse_pass", // + "matmul_v2_scale_fuse_pass", // "map_matmul_v2_to_mul_pass", // "map_matmul_v2_to_matmul_pass", // + "matmul_scale_fuse_pass", // "map_matmul_to_mul_pass", // "fc_fuse_pass", // "repeated_fc_relu_fuse_pass", // @@ -258,8 +260,8 @@ void CpuPassStrategy::EnableMKLDNN() { "matmul_transpose_reshape_fuse_pass", // "matmul_v2_transpose_reshape_fuse_pass", // // Disabled due to topology-dependent speed-up - // "fc_mkldnn_pass", - // "fc_act_mkldnn_fuse_pass", + // "fc_mkldnn_pass", + // "fc_act_mkldnn_fuse_pass", "batch_norm_act_fuse_pass", // "softplus_activation_mkldnn_fuse_pass", // // TODO(intel): Please fix the bug on windows. 
diff --git a/python/paddle/fluid/tests/unittests/ir/inference/CMakeLists.txt b/python/paddle/fluid/tests/unittests/ir/inference/CMakeLists.txt
index 0b4f963ef88f1f95ca3adf790516278f16daaf8a..e2002765b1fd9a130ee86ddc8ee0f4065d0db35c 100755
--- a/python/paddle/fluid/tests/unittests/ir/inference/CMakeLists.txt
+++ b/python/paddle/fluid/tests/unittests/ir/inference/CMakeLists.txt
@@ -83,7 +83,10 @@ if (WITH_MKLDNN AND TENSORRT_FOUND AND WITH_GPU)
     set_tests_properties(test_conv_act_mkldnn_fuse_pass PROPERTIES TIMEOUT 120)
     set_tests_properties(test_conv_eltwiseadd_bn_fuse_pass PROPERTIES TIMEOUT 120)
     set_tests_properties(test_conv_elementwise_add2_act_fuse_pass PROPERTIES TIMEOUT 120)
+    set_tests_properties(test_conv_elementwise_add_act_fuse_pass PROPERTIES TIMEOUT 120)
     set_tests_properties(test_conv_elementwise_add_act_fuse_pass PROPERTIES TIMEOUT 90)
+    set_tests_properties(test_matmul_scale_fuse_pass PROPERTIES TIMEOUT 60)
+    set_tests_properties(test_matmul_v2_scale_fuse_pass PROPERTIES TIMEOUT 60)
 endif()
 
 if (WITH_MKLDNN)
diff --git a/python/paddle/fluid/tests/unittests/ir/inference/test_matmul_scale_fuse_pass.py b/python/paddle/fluid/tests/unittests/ir/inference/test_matmul_scale_fuse_pass.py
new file mode 100644
index 0000000000000000000000000000000000000000..9c10ff18fa1f1f37db3bf2fca04d164799449629
--- /dev/null
+++ b/python/paddle/fluid/tests/unittests/ir/inference/test_matmul_scale_fuse_pass.py
@@ -0,0 +1,138 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from auto_scan_test import PassAutoScanTest, IgnoreReasons
+from program_config import TensorConfig, ProgramConfig, OpConfig
+import numpy as np
+import paddle.inference as paddle_infer
+from functools import partial
+from typing import Optional, List, Callable, Dict, Any, Set
+import unittest
+
+import hypothesis
+from hypothesis import given, settings, seed, example, assume, reproduce_failure
+import hypothesis.strategies as st
+
+
+class TestMatmulScaleFusePass(PassAutoScanTest):
+    """
+    x_var   y_var(persistable)
+       \      /
+        matmul
+          |
+        scale
+    """
+
+    def sample_predictor_configs(self, program_config):
+        # cpu
+        config = self.create_inference_config(use_gpu=False)
+        yield config, ["matmul", ], (1e-5, 1e-5)
+
+        # mkldnn
+        config = self.create_inference_config(use_mkldnn=True)
+        yield config, ["matmul", ], (1e-5, 1e-5)
+
+    def sample_program_config(self, draw):
+        # 1. Generate shape and attr of matmul
+        x_shape = draw(
+            st.lists(
+                st.integers(
+                    min_value=1, max_value=8), min_size=2, max_size=5))
+        x_shape_rank = len(x_shape)
+        y_shape = draw(
+            st.lists(
+                st.integers(
+                    min_value=1, max_value=8),
+                min_size=x_shape_rank,
+                max_size=x_shape_rank))
+        y_shape_rank = len(y_shape)
+        y_shape[-2] = x_shape[-1]
+        for i in range(y_shape_rank - 3, -1, -1):
+            j = x_shape_rank - (y_shape_rank - i)
+            if j < 0 or j >= x_shape_rank:
+                break
+            y_shape[i] = x_shape[j]
+
+        transpose_X = False
+        transpose_Y = False
+        alpha = draw(st.floats(min_value=-2.0, max_value=2.0, width=32))
+        # scale tensor
+        scale_shape = [1]
+        scale_value = draw(st.floats(min_value=-5.0, max_value=5.0, width=32))
+
+        matmul_op = OpConfig(
+            "matmul",
+            inputs={"X": ["matmul_x"],
+                    "Y": ["matmul_y"]},
+            outputs={"Out": ["matmul_out"]},
+            transpose_X=transpose_X,
+            transpose_Y=transpose_Y,
+            alpha=alpha,
+            fused_reshape_X=[],
+            fused_reshape_Y=[],
+            fused_transpose_X=[],
+            fused_transpose_Y=[],
+            fused_reshape_Out=[],
+            fused_transpose_Out=[],
+            head_number=1, )
+        is_scale_tensor = draw(st.booleans())
+        if is_scale_tensor:
+            scale_op = OpConfig(
+                "scale",
+                inputs={"X": ["matmul_out"],
+                        "ScaleTensor": ["scale_tensor"]},
+                outputs={"Out": ["scale_out"]},
+                scale=scale_value,
+                bias=0.0,
+                bias_after_scale=draw(st.booleans()), )
+        else:
+            scale_op = OpConfig(
+                "scale",
+                inputs={"X": ["matmul_out"], },
+                outputs={"Out": ["scale_out"]},
+                scale=scale_value,
+                bias=0.0,
+                bias_after_scale=draw(st.booleans()), )
+
+        ops = [matmul_op, scale_op]
+        weights = {}
+        inputs = {}
+        if is_scale_tensor:
+            weights = {
+                "matmul_y": TensorConfig(shape=y_shape),
+                "scale_tensor": TensorConfig(shape=scale_shape)
+            }
+            inputs = {"matmul_x": TensorConfig(shape=x_shape), }
+        else:
+            inputs = {
+                "matmul_x": TensorConfig(shape=x_shape),
+                "matmul_y": TensorConfig(shape=y_shape),
+            }
+
+        program_config = ProgramConfig(
+            ops=ops,
+            weights=weights,
+            inputs=inputs,
+            outputs=ops[-1].outputs["Out"], )
+        return program_config
+
+    def test(self):
+        self.run_and_statis(
+            quant=False,
+            max_examples=100,
+            passes=["matmul_scale_fuse_pass"], )
+
+
+if __name__ == "__main__":
+    unittest.main()
diff --git a/python/paddle/fluid/tests/unittests/ir/inference/test_matmul_v2_scale_fuse_pass.py b/python/paddle/fluid/tests/unittests/ir/inference/test_matmul_v2_scale_fuse_pass.py
new file mode 100644
index 0000000000000000000000000000000000000000..47bd5623646a7e880758eeae728f19f2680b8119
--- /dev/null
+++ b/python/paddle/fluid/tests/unittests/ir/inference/test_matmul_v2_scale_fuse_pass.py
@@ -0,0 +1,127 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from auto_scan_test import PassAutoScanTest, IgnoreReasons
+from program_config import TensorConfig, ProgramConfig, OpConfig
+import numpy as np
+import paddle.inference as paddle_infer
+from functools import partial
+from typing import Optional, List, Callable, Dict, Any, Set
+import unittest
+
+import hypothesis
+from hypothesis import given, settings, seed, example, assume, reproduce_failure
+import hypothesis.strategies as st
+
+
+class TestMatmulV2ScaleFusePass(PassAutoScanTest):
+    """
+    x_var   y_var(persistable)        x_var   y_var*scale(persistable)
+       \      /                          \      /
+       matmul_v2           =>            matmul_v2
+          |                                  |
+        scale                            scale_out
+          |
+      scale_out
+    """
+
+    def sample_predictor_configs(self, program_config):
+        # for cpu
+        # config = self.create_inference_config(use_gpu=False)
+        # yield config, ["matmul_v2", ], (1e-5, 1e-5)
+
+        # mkldnn
+        config = self.create_inference_config(use_mkldnn=True)
+        yield config, ["matmul_v2", ], (1e-5, 1e-5)
+
+    def sample_program_config(self, draw):
+        # 1. Generate shape and attr of matmul
+        x_shape = draw(
+            st.lists(
+                st.integers(
+                    min_value=1, max_value=8), min_size=2, max_size=5))
+        x_shape_rank = len(x_shape)
+        y_shape = draw(
+            st.lists(
+                st.integers(
+                    min_value=1, max_value=8),
+                min_size=x_shape_rank,
+                max_size=x_shape_rank))
+        y_shape_rank = len(y_shape)
+        y_shape[-2] = x_shape[-1]
+        for i in range(y_shape_rank - 3, -1, -1):
+            j = x_shape_rank - (y_shape_rank - i)
+            if j < 0 or j >= x_shape_rank:
+                break
+            y_shape[i] = x_shape[j]
+
+        transpose_X = False
+        transpose_Y = False
+        # scale tensor
+        scale_shape = [1]
+        scale_value = draw(st.floats(min_value=-5.0, max_value=5.0, width=32))
+
+        matmul_v2_op = OpConfig(
+            "matmul_v2",
+            inputs={"X": ["matmul_x"],
+                    "Y": ["matmul_y"]},
+            outputs={"Out": ["matmul_out"]},
+            trans_x=transpose_X,
+            trans_y=transpose_Y,
+            fused_reshape_X=[],
+            fused_reshape_Y=[],
+            fused_transpose_X=[],
+            fused_transpose_Y=[],
+            fused_reshape_Out=[],
+            fused_transpose_Out=[], )
+        is_scale_tensor = draw(st.booleans())
+        if is_scale_tensor:
+            scale_op = OpConfig(
+                "scale",
+                inputs={"X": ["matmul_out"],
+                        "ScaleTensor": ["scale_tensor"]},
+                outputs={"Out": ["scale_out"]},
+                scale=scale_value,
+                bias=0.0,
+                bias_after_scale=draw(st.booleans()), )
+        else:
+            scale_op = OpConfig(
+                "scale",
+                inputs={"X": ["matmul_out"], },
+                outputs={"Out": ["scale_out"]},
+                scale=scale_value,
+                bias=0.0,
+                bias_after_scale=draw(st.booleans()), )
+
+        ops = [matmul_v2_op, scale_op]
+        weights = {"matmul_y": TensorConfig(shape=y_shape), }
+        if is_scale_tensor:
+            weights["scale_tensor"] = TensorConfig(shape=scale_shape)
+        inputs = {"matmul_x": TensorConfig(shape=x_shape), }
+        program_config = ProgramConfig(
+            ops=ops,
+            weights=weights,
+            inputs=inputs,
+            outputs=ops[-1].outputs["Out"], )
+        return program_config
+
+    def test(self):
+        self.run_and_statis(
+            quant=False,
+            max_examples=100,
+            passes=["matmul_v2_scale_fuse_pass"], )
+
+
+if __name__ == "__main__":
+    unittest.main()