Unverified commit d0cf9d9d authored by Sławomir Siwek, committed by GitHub

Merge matmul_v1 and matmul_v2 fuse passes (#44870)

* remove v2_transpose_reshape

* matmul_transpose_reshape

* reshape_transpose_matmul

* restore ut

* adjust old ut

* restore parallel UT rules

* feedback from review
Parent 1f7e9546
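For context, this merge collapses the four per-op oneDNN fuse passes into two — `matmul_transpose_reshape_mkldnn_fuse_pass` and `reshape_transpose_matmul_mkldnn_fuse_pass` — each handling both `matmul` and `matmul_v2` internally. A minimal sketch of the caller-side rename, reusing the `PassRegistry` calls visible in the testers below (the helper name and the header path are assumptions, not part of this commit):

```cpp
#include "paddle/fluid/framework/ir/pass.h"  // assumed header for PassRegistry

// Hypothetical helper; only the pass names come from this commit.
void ApplyMatmulFusePasses(paddle::framework::ir::Graph *graph) {
  auto &registry = paddle::framework::ir::PassRegistry::Instance();
  // Before: "matmul_transpose_reshape_fuse_pass" and
  //         "matmul_v2_transpose_reshape_fuse_pass" ran as separate passes.
  // After: one pass fuses both matmul flavors.
  auto pass = registry.Get("matmul_transpose_reshape_mkldnn_fuse_pass");
  graph = pass->Apply(graph);
}
```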
......@@ -218,10 +218,7 @@ if(WITH_MKLDNN)
pass_library(cpu_quantize_pass inference DIR mkldnn)
pass_library(cpu_quantize_squash_pass inference DIR mkldnn)
pass_library(reshape_transpose_matmul_mkldnn_fuse_pass inference DIR mkldnn)
pass_library(reshape_transpose_matmul_v2_mkldnn_fuse_pass inference DIR
mkldnn)
pass_library(matmul_transpose_reshape_fuse_pass inference DIR mkldnn)
pass_library(matmul_v2_transpose_reshape_fuse_pass inference DIR mkldnn)
pass_library(matmul_transpose_reshape_mkldnn_fuse_pass inference DIR mkldnn)
pass_library(batch_norm_act_fuse_pass inference DIR mkldnn)
pass_library(multi_gru_fuse_pass inference DIR mkldnn)
pass_library(multi_gru_seq_fuse_pass inference DIR mkldnn)
......@@ -497,13 +494,11 @@ if(WITH_MKLDNN)
cc_test(
test_reshape_transpose_matmul_mkldnn_fuse_pass
SRCS mkldnn/reshape_transpose_matmul_mkldnn_fuse_pass_tester.cc
DEPS reshape_transpose_matmul_mkldnn_fuse_pass
reshape_transpose_matmul_v2_mkldnn_fuse_pass)
DEPS reshape_transpose_matmul_mkldnn_fuse_pass)
cc_test(
test_matmul_transpose_reshape_fuse_pass
SRCS mkldnn/matmul_transpose_reshape_fuse_pass_tester.cc
DEPS matmul_transpose_reshape_fuse_pass
matmul_v2_transpose_reshape_fuse_pass)
SRCS mkldnn/matmul_transpose_reshape_mkldnn_fuse_pass_tester.cc
DEPS matmul_transpose_reshape_mkldnn_fuse_pass)
cc_test(
test_shuffle_channel_mkldnn_detect_pass
SRCS mkldnn/shuffle_channel_mkldnn_detect_pass_tester.cc
......
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
......@@ -12,12 +12,8 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/ir/mkldnn/matmul_transpose_reshape_fuse_pass.h"
#include "paddle/fluid/framework/ir/mkldnn/matmul_transpose_reshape_mkldnn_fuse_pass.h"
#include <paddle/fluid/string/pretty_log.h>
#include <vector>
#include "paddle/fluid/framework/op_version_registry.h"
#include "paddle/fluid/platform/enforce.h"
......@@ -25,76 +21,28 @@ namespace paddle {
namespace framework {
namespace ir {
MatmulTransposeReshapeMKLDNNPass::MatmulTransposeReshapeMKLDNNPass() {
op_name_ = "matmul";
using string::PrettyLogDetail;
AddOpCompat(OpCompat(op_name_))
.AddInput("X")
.IsTensor()
.End()
.AddInput("Y")
.IsTensor()
.End()
.AddOutput("Out")
.IsTensor()
.End()
.AddAttr("alpha") // unconstrained. can be any float value.
.IsType<float>()
.End()
.AddAttr("transpose_X") // unconstrained. can be any bool value.
.IsType<bool>()
.End()
.AddAttr("transpose_Y") // unconstrained. can be any bool value.
.IsType<bool>()
.End();
void MatmulTransposeReshapeMKLDNNPass::ApplyImpl(Graph *graph) const {
auto matmul_types = {"matmul", "matmul_v2"};
AddOpCompat(OpCompat("transpose2"))
.AddInput("X")
.IsTensor()
.End()
.AddOutput("Out")
.IsTensor()
.End()
.AddOutput("XShape")
.IsTensor()
.End()
.AddAttr("axis") // ints
.IsType<std::vector<int>>()
.End();
AddOpCompat(OpCompat("reshape2"))
.AddInput("X")
.IsTensor()
.End()
.AddInput("Shape")
.IsTensor()
.IsOptional()
.End()
.AddInput("ShapeTensor")
.IsTensor()
.IsOptional()
.End()
.AddOutput("Out")
.IsTensor()
.End()
.AddOutput("XShape")
.IsTensor()
.End()
.AddAttr("shape") // ints
.IsType<std::vector<int>>()
.End();
for (const auto &matmul_type : matmul_types) {
Fuse(graph, matmul_type);
}
}
void MatmulTransposeReshapeMKLDNNPass::ApplyImpl(ir::Graph *graph) const {
void MatmulTransposeReshapeMKLDNNPass::Fuse(
Graph *graph, const std::string &matmul_type) const {
PADDLE_ENFORCE_NOT_NULL(graph,
platform::errors::InvalidArgument(
"Pointer to graph argument should not be NULL."));
FusePassBase::Init(name_scope_, graph);
FusePassBase::Init(matmul_type + "_transpose_reshape_mkldnn_fuse_pass",
graph);
GraphPatternDetector gpd;
patterns::MatmulTransposeReshapePattern mtrp(gpd.mutable_pattern(),
name_scope_);
mtrp(op_name_);
patterns::MatmulTransposeReshapePattern mtrp(
gpd.mutable_pattern(),
matmul_type + "_transpose_reshape_mkldnn_fuse_pass");
mtrp(matmul_type);
int found_matmul_transpose_reshape_count = 0;
auto handler = [&](const GraphPatternDetector::subgraph_t &subgraph,
......@@ -103,7 +51,7 @@ void MatmulTransposeReshapeMKLDNNPass::ApplyImpl(ir::Graph *graph) const {
LOG(WARNING) << "Pass in op compat failed.";
return;
}
VLOG(4) << "handle " + op_name_ + "_transpose_reshape fuse";
GET_IR_NODE_FROM_SUBGRAPH(matmul_op, matmul_op, mtrp);
GET_IR_NODE_FROM_SUBGRAPH(matmul_out, matmul_out, mtrp);
GET_IR_NODE_FROM_SUBGRAPH(transpose_op, transpose_op, mtrp);
......@@ -112,6 +60,7 @@ void MatmulTransposeReshapeMKLDNNPass::ApplyImpl(ir::Graph *graph) const {
GET_IR_NODE_FROM_SUBGRAPH(reshape_op, reshape_op, mtrp);
GET_IR_NODE_FROM_SUBGRAPH(reshape_out, reshape_out, mtrp);
GET_IR_NODE_FROM_SUBGRAPH(reshape_out_xshape, reshape_out_xshape, mtrp);
auto reshape_shape =
PADDLE_GET_CONST(std::vector<int>, reshape_op->Op()->GetAttr("shape"));
auto transpose_axis =
......@@ -123,17 +72,17 @@ void MatmulTransposeReshapeMKLDNNPass::ApplyImpl(ir::Graph *graph) const {
const bool supported_transpose_axis = std::equal(
transpose_axis.begin(), transpose_axis.end(), supported_axis.begin());
if (transpose_out_size != 4) {
VLOG(3) << "do not perform " + op_name_ + "_transpose_reshape fuse: "
VLOG(3) << "do not perform " + matmul_type + "_transpose_reshape fuse: "
<< "supported rank is 4, received " << transpose_out_size;
return;
}
if (!supported_transpose_axis) {
VLOG(3) << "do not perform " + op_name_ + "_transpose_reshape fuse: "
VLOG(3) << "do not perform " + matmul_type + "_transpose_reshape fuse: "
<< "supported transpose axis for the fuse are {0, 2, 1, 3}";
return;
}
if (reshape_out_size != 3) {
VLOG(3) << "do not perform " + op_name_ + "_transpose_reshape fuse: "
VLOG(3) << "do not perform " + matmul_type + "_transpose_reshape fuse: "
<< "reshape_out supported rank is 3, received "
<< reshape_out_size;
return;
......@@ -158,23 +107,93 @@ void MatmulTransposeReshapeMKLDNNPass::ApplyImpl(ir::Graph *graph) const {
gpd(graph, handler);
AddStatis(found_matmul_transpose_reshape_count);
if (!Has("disable_logs") || !Get<bool>("disable_logs")) {
std::stringstream msg_ss;
msg_ss << "--- Fused " << found_matmul_transpose_reshape_count
<< " MatmulTransposeReshape patterns for " + op_name_ + " Op";
paddle::string::PrettyLogDetail(msg_ss.str().c_str());
if ((!Has("disable_logs") || !Get<bool>("disable_logs")) &&
found_matmul_transpose_reshape_count > 0) {
PrettyLogDetail("--- fused %d %s + transpose + reshape patterns",
found_matmul_transpose_reshape_count,
matmul_type);
}
}
MatmulTransposeReshapeMKLDNNPass::MatmulTransposeReshapeMKLDNNPass() {
AddOpCompat(OpCompat("matmul"))
.AddInput("X")
.IsTensor()
.End()
.AddInput("Y")
.IsTensor()
.End()
.AddOutput("Out")
.IsTensor()
.End()
.AddAttr("alpha")
.IsType<float>()
.End()
.AddAttr("transpose_X")
.IsType<bool>()
.End()
.AddAttr("transpose_Y")
.IsType<bool>()
.End();
AddOpCompat(OpCompat("matmul_v2"))
.AddInput("X")
.IsTensor()
.End()
.AddInput("Y")
.IsTensor()
.End()
.AddOutput("Out")
.IsTensor()
.End()
.AddAttr("trans_x")
.IsType<bool>()
.End()
.AddAttr("trans_y")
.IsType<bool>()
.End();
AddOpCompat(OpCompat("transpose2"))
.AddInput("X")
.IsTensor()
.End()
.AddOutput("Out")
.IsTensor()
.End()
.AddOutput("XShape")
.IsTensor()
.End()
.AddAttr("axis")
.IsType<std::vector<int>>()
.End();
AddOpCompat(OpCompat("reshape2"))
.AddInput("X")
.IsTensor()
.End()
// The reshape2 op for this pass should not have "Shape" and "ShapeTensor"
.AddOutput("Out")
.IsTensor()
.End()
.AddOutput("XShape")
.IsTensor()
.End()
.AddAttr("shape")
.IsType<std::vector<int>>()
.End();
}
} // namespace ir
} // namespace framework
} // namespace paddle
REGISTER_PASS(matmul_transpose_reshape_fuse_pass,
REGISTER_PASS(matmul_transpose_reshape_mkldnn_fuse_pass,
paddle::framework::ir::MatmulTransposeReshapeMKLDNNPass);
REGISTER_PASS_CAPABILITY(matmul_transpose_reshape_fuse_pass)
REGISTER_PASS_CAPABILITY(matmul_transpose_reshape_mkldnn_fuse_pass)
.AddCombination(
paddle::framework::compatible::OpVersionComparatorCombination()
.LE("matmul", 1)
.EQ("transpose", 0)
.EQ("reshape", 0));
.EQ("matmul_v2", 0)
.EQ("transpose2", 0)
.EQ("reshape2", 0));
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
......@@ -14,14 +14,11 @@
#pragma once
#include <string>
#include "paddle/fluid/framework/ir/fuse_pass_base.h"
namespace paddle {
namespace framework {
namespace ir {
class Graph;
class MatmulTransposeReshapeMKLDNNPass : public FusePassBase {
public:
......@@ -29,10 +26,10 @@ class MatmulTransposeReshapeMKLDNNPass : public FusePassBase {
virtual ~MatmulTransposeReshapeMKLDNNPass() {}
protected:
void ApplyImpl(Graph* graph) const override;
const std::string name_scope_{"matmul_transpose_reshape_fuse"};
std::string op_name_;
void ApplyImpl(Graph *graph) const override;
void Fuse(Graph *graph, const std::string &matmul_type) const;
};
} // namespace ir
} // namespace framework
} // namespace paddle
......@@ -14,7 +14,7 @@
#include <gtest/gtest.h>
#include "paddle/fluid/framework/ir/mkldnn/matmul_v2_transpose_reshape_fuse_pass.h"
#include "paddle/fluid/framework/ir/mkldnn/matmul_transpose_reshape_mkldnn_fuse_pass.h"
namespace paddle {
namespace framework {
......@@ -74,7 +74,7 @@ void MainTest(const ProgramDesc &prog, const std::string &op_name) {
int original_nodes_num = graph->Nodes().size();
auto pass =
PassRegistry::Instance().Get(op_name + "_transpose_reshape_fuse_pass");
PassRegistry::Instance().Get("matmul_transpose_reshape_mkldnn_fuse_pass");
graph.reset(pass->Apply(graph.release()));
int current_nodes_num = graph->Nodes().size();
......@@ -106,5 +106,4 @@ TEST(MatmulTransposeReshapeFusePass, matmul_v2_fuse_pass) {
} // namespace framework
} // namespace paddle
USE_PASS(matmul_transpose_reshape_fuse_pass);
USE_PASS(matmul_v2_transpose_reshape_fuse_pass);
USE_PASS(matmul_transpose_reshape_mkldnn_fuse_pass);
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/ir/mkldnn/matmul_v2_transpose_reshape_fuse_pass.h"
#include <vector>
#include "paddle/fluid/framework/op_version_registry.h"
#include "paddle/fluid/platform/enforce.h"
namespace paddle {
namespace framework {
namespace ir {
MatmulV2TransposeReshapeMKLDNNPass::MatmulV2TransposeReshapeMKLDNNPass() {
op_name_ = "matmul_v2";
AddOpCompat(OpCompat(op_name_))
.AddInput("X")
.IsTensor()
.End()
.AddInput("Y")
.IsTensor()
.End()
.AddOutput("Out")
.IsTensor()
.End()
.AddAttr("trans_x")
.IsType<bool>()
.End()
.AddAttr("trans_y")
.IsType<bool>()
.End();
AddOpCompat(OpCompat("transpose2"))
.AddInput("X")
.IsTensor()
.End()
.AddOutput("Out")
.IsTensor()
.End()
.AddOutput("XShape")
.IsTensor()
.End()
.AddAttr("axis")
.IsType<std::vector<int>>()
.End();
AddOpCompat(OpCompat("reshape2"))
.AddInput("X")
.IsTensor()
.End()
.AddInput("Shape")
.IsTensor()
.IsOptional()
.End()
.AddInput("ShapeTensor")
.IsTensor()
.IsOptional()
.End()
.AddOutput("Out")
.IsTensor()
.End()
.AddOutput("XShape")
.IsTensor()
.End()
.AddAttr("shape")
.IsType<std::vector<int>>()
.End();
}
} // namespace ir
} // namespace framework
} // namespace paddle
REGISTER_PASS(matmul_v2_transpose_reshape_fuse_pass,
paddle::framework::ir::MatmulV2TransposeReshapeMKLDNNPass);
REGISTER_PASS_CAPABILITY(matmul_v2_transpose_reshape_fuse_pass)
.AddCombination(
paddle::framework::compatible::OpVersionComparatorCombination()
.EQ("matmul_v2", 0)
.EQ("transpose2", 0)
.EQ("reshape2", 0));
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <string>
#include "paddle/fluid/framework/ir/mkldnn/matmul_transpose_reshape_fuse_pass.h"
namespace paddle {
namespace framework {
namespace ir {
class MatmulV2TransposeReshapeMKLDNNPass
: public MatmulTransposeReshapeMKLDNNPass {
public:
MatmulV2TransposeReshapeMKLDNNPass();
virtual ~MatmulV2TransposeReshapeMKLDNNPass() {}
protected:
const std::string name_scope_{"matmul_v2_transpose_reshape_fuse"};
};
} // namespace ir
} // namespace framework
} // namespace paddle
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
......@@ -13,11 +13,6 @@
// limitations under the License.
#include "paddle/fluid/framework/ir/mkldnn/reshape_transpose_matmul_mkldnn_fuse_pass.h"
#include <string>
#include <unordered_set>
#include <vector>
#include "paddle/fluid/framework/op_version_registry.h"
#include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/string/pretty_log.h"
......@@ -26,78 +21,46 @@ namespace paddle {
namespace framework {
namespace ir {
ReshapeTransposeMatmulMkldnnFusePass::ReshapeTransposeMatmulMkldnnFusePass() {
op_name_ = "matmul";
AddOpCompat(OpCompat("reshape2"))
.AddInput("X")
.IsTensor()
.End()
// The reshape2 op for this pass should not have "Shape" and "ShapeTensor"
.AddOutput("Out")
.IsTensor()
.End()
.AddOutput("XShape")
.IsOptional()
.IsTensor()
.End()
.AddAttr("shape")
.IsType<std::vector<int>>()
.End();
AddOpCompat(OpCompat("transpose2"))
.AddInput("X")
.IsTensor()
.End()
.AddOutput("Out")
.IsTensor()
.End()
.AddOutput("XShape")
.IsOptional()
.IsTensor()
.End()
.AddAttr("axis")
.IsType<std::vector<int>>()
.End();
void ReshapeTransposeMatmulMkldnnFusePass::ApplyImpl(Graph *graph) const {
auto matmul_types = {"matmul", "matmul_v2"};
bool with_reshape_xshape = true;
bool with_transpose_xshape = true;
AddOpCompat(OpCompat(op_name_))
.AddInput("X")
.IsTensor()
.End()
.AddInput("Y")
.IsTensor()
.End()
.AddOutput("Out")
.IsTensor()
.End()
.AddAttr("alpha")
.IsType<float>()
.End()
.AddAttr("transpose_X")
.IsType<bool>()
.End()
.AddAttr("transpose_Y")
.IsType<bool>()
.End();
for (const auto &matmul_type : matmul_types) {
Fuse(graph, matmul_type, with_reshape_xshape, with_transpose_xshape);
Fuse(graph, matmul_type, with_reshape_xshape, !with_transpose_xshape);
Fuse(graph, matmul_type, !with_reshape_xshape, with_transpose_xshape);
Fuse(graph, matmul_type, !with_reshape_xshape, !with_transpose_xshape);
}
}
void ReshapeTransposeMatmulMkldnnFusePass::Fuse(
Graph *graph, bool with_reshape_xshape, bool with_transpose_xshape) const {
Graph *graph,
const std::string &matmul_type,
bool with_reshape_xshape,
bool with_transpose_xshape) const {
PADDLE_ENFORCE_NOT_NULL(graph,
platform::errors::InvalidArgument(
"Pointer to graph argument should not be NULL."));
FusePassBase::Init("reshape_transpose_" + matmul_type + "_mkldnn_fuse_pass",
graph);
GraphPatternDetector gpd;
patterns::ReshapeTransposeMatmulPattern rtm_pattern(gpd.mutable_pattern(),
name_scope_);
patterns::ReshapeTransposeMatmulPattern rtm_pattern(
gpd.mutable_pattern(),
"reshape_transpose_" + matmul_type + "_mkldnn_fuse_pass");
rtm_pattern(op_name_, with_reshape_xshape, with_transpose_xshape);
rtm_pattern(matmul_type, with_reshape_xshape, with_transpose_xshape);
int found_reshape_transpose_matmul_count = 0;
auto handler = [&](const GraphPatternDetector::subgraph_t &subgraph,
Graph *g) {
if (!IsCompat(subgraph, g)) {
LOG(WARNING) << "Op compatible check in reshape_transpose_" << op_name_
LOG(WARNING) << "Op compatible check in reshape_transpose_" << matmul_type
<< "_mkldnn_fuse_pass failed.";
return;
}
VLOG(4) << "handle reshape_transpose_" << op_name_ << " fuse";
GET_IR_NODE_FROM_SUBGRAPH(reshape_in, reshape_in, rtm_pattern);
GET_IR_NODE_FROM_SUBGRAPH(reshape_op, reshape_op, rtm_pattern);
GET_IR_NODE_FROM_SUBGRAPH(reshape_out, reshape_out, rtm_pattern);
......@@ -137,7 +100,7 @@ void ReshapeTransposeMatmulMkldnnFusePass::Fuse(
UpdateMatmul("Y");
} else {
throw platform::errors::InvalidArgument("Unexpected input to " +
op_name_ + " encountered.");
matmul_type + " encountered.");
}
std::unordered_set<const ir::Node *> nodes_to_remove{
......@@ -153,26 +116,85 @@ void ReshapeTransposeMatmulMkldnnFusePass::Fuse(
gpd(graph, handler);
AddStatis(found_reshape_transpose_matmul_count);
if (!Has("disable_logs") || !Get<bool>("disable_logs")) {
if ((!Has("disable_logs") || !Get<bool>("disable_logs")) &&
found_reshape_transpose_matmul_count > 0) {
std::stringstream msg_ss;
msg_ss << "--- Fused " << found_reshape_transpose_matmul_count
<< " ReshapeTransposeMatmul patterns for " << op_name_ << " Op";
msg_ss << "--- fused " << found_reshape_transpose_matmul_count
<< " reshape + transpose + " << matmul_type;
if (with_reshape_xshape) msg_ss << " with reshape's xshape";
if (with_transpose_xshape) msg_ss << " with transpose's xshape";
string::PrettyLogDetail(msg_ss.str().c_str());
}
}
void ReshapeTransposeMatmulMkldnnFusePass::ApplyImpl(ir::Graph *graph) const {
PADDLE_ENFORCE_NOT_NULL(graph,
platform::errors::InvalidArgument(
"Pointer to graph argument should not be NULL."));
FusePassBase::Init(name_scope_, graph);
ReshapeTransposeMatmulMkldnnFusePass::ReshapeTransposeMatmulMkldnnFusePass() {
AddOpCompat(OpCompat("reshape2"))
.AddInput("X")
.IsTensor()
.End()
// The reshape2 op for this pass should not have "Shape" and "ShapeTensor"
.AddOutput("Out")
.IsTensor()
.End()
.AddOutput("XShape")
.IsOptional()
.IsTensor()
.End()
.AddAttr("shape")
.IsType<std::vector<int>>()
.End();
AddOpCompat(OpCompat("transpose2"))
.AddInput("X")
.IsTensor()
.End()
.AddOutput("Out")
.IsTensor()
.End()
.AddOutput("XShape")
.IsOptional()
.IsTensor()
.End()
.AddAttr("axis")
.IsType<std::vector<int>>()
.End();
AddOpCompat(OpCompat("matmul"))
.AddInput("X")
.IsTensor()
.End()
.AddInput("Y")
.IsTensor()
.End()
.AddOutput("Out")
.IsTensor()
.End()
.AddAttr("alpha")
.IsType<float>()
.End()
.AddAttr("transpose_X")
.IsType<bool>()
.End()
.AddAttr("transpose_Y")
.IsType<bool>()
.End();
Fuse(graph, false, false);
Fuse(graph, false, true);
Fuse(graph, true, false);
Fuse(graph, true, true);
AddOpCompat(OpCompat("matmul_v2"))
.AddInput("X")
.IsTensor()
.End()
.AddInput("Y")
.IsTensor()
.End()
.AddOutput("Out")
.IsTensor()
.End()
.AddAttr("trans_x")
.IsType<bool>()
.End()
.AddAttr("trans_y")
.IsType<bool>()
.End();
}
} // namespace ir
......@@ -184,5 +206,8 @@ REGISTER_PASS(reshape_transpose_matmul_mkldnn_fuse_pass,
REGISTER_PASS_CAPABILITY(reshape_transpose_matmul_mkldnn_fuse_pass)
.AddCombination(
paddle::framework::compatible::OpVersionComparatorCombination().EQ(
"matmul", 1));
paddle::framework::compatible::OpVersionComparatorCombination()
.EQ("reshape2", 0)
.EQ("transpose2", 0)
.EQ("matmul", 1)
.EQ("matmul_v2", 0));
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
......@@ -13,17 +13,11 @@
// limitations under the License.
#pragma once
#include <string>
#include "paddle/fluid/framework/ir/fuse_pass_base.h"
namespace paddle {
namespace framework {
namespace ir {
/*
* Fuse Reshape->Transpose->MatMul when MatMul uses mkldnn.
*/
class ReshapeTransposeMatmulMkldnnFusePass : public FusePassBase {
public:
......@@ -31,13 +25,11 @@ class ReshapeTransposeMatmulMkldnnFusePass : public FusePassBase {
virtual ~ReshapeTransposeMatmulMkldnnFusePass() {}
protected:
void ApplyImpl(ir::Graph* graph) const override;
const std::string name_scope_{"reshape_transpose_matmul_fuse"};
void ApplyImpl(Graph* graph) const override;
void Fuse(Graph* graph,
const std::string& matmul_type,
bool with_reshape_xshape,
bool with_transpose_xshape) const;
std::string op_name_;
};
} // namespace ir
} // namespace framework
......
......@@ -15,7 +15,6 @@
#include <gtest/gtest.h>
#include "paddle/fluid/framework/ir/mkldnn/reshape_transpose_matmul_mkldnn_fuse_pass.h"
#include "paddle/fluid/framework/ir/mkldnn/reshape_transpose_matmul_v2_mkldnn_fuse_pass.h"
#include "paddle/fluid/framework/ir/pass_tester_helper.h"
namespace paddle {
......@@ -82,8 +81,8 @@ void TestMain(const std::string& op_name, bool with_xshapes) {
int total_nodes_before = graph->Nodes().size();
VLOG(3) << DebugString(graph);
auto pass = PassRegistry::Instance().Get("reshape_transpose_" + op_name +
"_mkldnn_fuse_pass");
auto pass =
PassRegistry::Instance().Get("reshape_transpose_matmul_mkldnn_fuse_pass");
graph.reset(pass->Apply(graph.release()));
int num_reshape_nodes_after = GetNumOpNodes(graph, "reshape2");
......@@ -137,4 +136,3 @@ TEST(ReshapeTransposeMatmulV2MkldnnFusePass,
} // namespace paddle
USE_PASS(reshape_transpose_matmul_mkldnn_fuse_pass);
USE_PASS(reshape_transpose_matmul_v2_mkldnn_fuse_pass);
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/ir/mkldnn/reshape_transpose_matmul_v2_mkldnn_fuse_pass.h"
#include <string>
#include <unordered_set>
#include <vector>
#include "paddle/fluid/framework/op_version_registry.h"
#include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/string/pretty_log.h"
namespace paddle {
namespace framework {
namespace ir {
ReshapeTransposeMatmulV2MkldnnFusePass::
ReshapeTransposeMatmulV2MkldnnFusePass() {
op_name_ = "matmul_v2";
AddOpCompat(OpCompat("reshape2"))
.AddInput("X")
.IsTensor()
.End()
// The reshape2 op for this pass should not have "Shape" and "ShapeTensor"
.AddOutput("Out")
.IsTensor()
.End()
.AddOutput("XShape")
.IsOptional()
.IsTensor()
.End()
.AddAttr("shape")
.IsType<std::vector<int>>()
.End();
AddOpCompat(OpCompat("transpose2"))
.AddInput("X")
.IsTensor()
.End()
.AddOutput("Out")
.IsTensor()
.End()
.AddOutput("XShape")
.IsOptional()
.IsTensor()
.End()
.AddAttr("axis")
.IsType<std::vector<int>>()
.End();
AddOpCompat(OpCompat(op_name_))
.AddInput("X")
.IsTensor()
.End()
.AddInput("Y")
.IsTensor()
.End()
.AddOutput("Out")
.IsTensor()
.End()
.AddAttr("trans_x")
.IsType<bool>()
.End()
.AddAttr("trans_y")
.IsType<bool>()
.End();
}
} // namespace ir
} // namespace framework
} // namespace paddle
REGISTER_PASS(reshape_transpose_matmul_v2_mkldnn_fuse_pass,
paddle::framework::ir::ReshapeTransposeMatmulV2MkldnnFusePass);
REGISTER_PASS_CAPABILITY(reshape_transpose_matmul_v2_mkldnn_fuse_pass)
.AddCombination(
paddle::framework::compatible::OpVersionComparatorCombination()
.EQ("matmul_v2", 0)
.EQ("transpose2", 0)
.EQ("reshape2", 0));
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <string>
#include "paddle/fluid/framework/ir/mkldnn/reshape_transpose_matmul_mkldnn_fuse_pass.h"
namespace paddle {
namespace framework {
namespace ir {
/*
* Fuse Reshape->Transpose->MatMulV2 when MatMulV2 uses mkldnn.
*/
class ReshapeTransposeMatmulV2MkldnnFusePass
: public ReshapeTransposeMatmulMkldnnFusePass {
public:
ReshapeTransposeMatmulV2MkldnnFusePass();
virtual ~ReshapeTransposeMatmulV2MkldnnFusePass() {}
protected:
const std::string name_scope_{"reshape_transpose_matmul_v2_fuse"};
};
} // namespace ir
} // namespace framework
} // namespace paddle
......@@ -302,9 +302,7 @@ void CpuPassStrategy::EnableMKLDNN() {
"conv_activation_mkldnn_fuse_pass", //
"scale_matmul_fuse_pass", //
"reshape_transpose_matmul_mkldnn_fuse_pass", //
"reshape_transpose_matmul_v2_mkldnn_fuse_pass", //
"matmul_transpose_reshape_fuse_pass", //
"matmul_v2_transpose_reshape_fuse_pass", //
"matmul_transpose_reshape_mkldnn_fuse_pass", //
"matmul_elementwise_add_mkldnn_fuse_pass", //
"matmul_activation_mkldnn_fuse_pass", //
// Disabled due to topology-dependent speed-up
......@@ -399,14 +397,12 @@ void CpuPassStrategy::EnableMkldnnInt8() {
passes_.push_back("repeated_fc_relu_fuse_pass");
passes_.push_back("fc_mkldnn_pass");
passes_.push_back("fc_act_mkldnn_fuse_pass");
passes_.push_back("matmul_transpose_reshape_fuse_pass");
passes_.push_back("matmul_v2_transpose_reshape_fuse_pass");
passes_.push_back("matmul_transpose_reshape_mkldnn_fuse_pass");
passes_.push_back("batch_norm_act_fuse_pass");
passes_.push_back("softplus_activation_mkldnn_fuse_pass");
passes_.push_back("compute_propagate_scales_mkldnn_pass");
passes_.push_back("scale_matmul_fuse_pass");
passes_.push_back("reshape_transpose_matmul_mkldnn_fuse_pass");
passes_.push_back("reshape_transpose_matmul_v2_mkldnn_fuse_pass");
passes_.push_back("cpu_quantize_placement_pass");
passes_.push_back("cpu_quantize_pass");
passes_.push_back("cpu_quantize_squash_pass");
......
......@@ -448,8 +448,8 @@ class Quant2Int8MkldnnPass(object):
# Disabled due to topology-dependent speed-up
graph = self._apply_pass(graph, 'fc_mkldnn_pass')
graph = self._apply_pass(graph, 'fc_act_mkldnn_fuse_pass')
graph = self._apply_pass(graph, 'matmul_transpose_reshape_fuse_pass')
graph = self._apply_pass(graph, 'matmul_v2_transpose_reshape_fuse_pass')
graph = self._apply_pass(graph,
'matmul_transpose_reshape_mkldnn_fuse_pass')
graph = self._apply_pass(graph, 'batch_norm_act_fuse_pass')
graph = self._apply_pass(graph, 'softplus_activation_mkldnn_fuse_pass')
# the following pass should be the last one since it will work on all fused ops.
......@@ -650,8 +650,6 @@ class Quant2Int8MkldnnPass(object):
graph = self._apply_pass(graph, 'scale_matmul_fuse_pass')
graph = self._apply_pass(graph,
'reshape_transpose_matmul_mkldnn_fuse_pass')
graph = self._apply_pass(
graph, 'reshape_transpose_matmul_v2_mkldnn_fuse_pass')
graph = self._apply_pass(graph, 'cpu_quantize_placement_pass',
['quantize_enabled_op_types'],
[self._ops_to_quantize])
......
......@@ -121,8 +121,8 @@ class TestMatmulTransposeReshapeMkldnnFusePass(PassAutoScanTest):
yield config, ["matmul"], (1e-5, 1e-5)
def test(self):
self.run_and_statis(quant=False,
passes=["matmul_transpose_reshape_fuse_pass"])
self.run_and_statis(
quant=False, passes=["matmul_transpose_reshape_mkldnn_fuse_pass"])
if __name__ == "__main__":
......
......@@ -142,8 +142,8 @@ class TestMatmulv2TransposeReshapeMkldnnFusePass(PassAutoScanTest):
yield config, [fused_op], (1e-5, 1e-5)
def test(self):
self.run_and_statis(quant=False,
passes=["matmul_v2_transpose_reshape_fuse_pass"])
self.run_and_statis(
quant=False, passes=["matmul_transpose_reshape_mkldnn_fuse_pass"])
if __name__ == "__main__":
......
......@@ -29,7 +29,7 @@ class TestReshapeTransposeMatmulV2OneDNNFusePass(InferencePassTest):
def setUp(self):
self.set_params()
self.tranpose_perm = [0, 2, 1, 3]
self.pass_name = 'reshape_transpose_matmul_v2_mkldnn_fuse_pass'
self.pass_name = 'reshape_transpose_matmul_mkldnn_fuse_pass'
with fluid.program_guard(self.main_program, self.startup_program):
data = fluid.data(name="data",
......