未验证 提交 856cb9c5 编写于 作者: J jakpiase 提交者: GitHub

Added matmul_v2+transpose+reshape fuse pass (#36481)

* added base changes for matmul_v2+trans+resh fuse pass

* added full matmul_v2+transpose+reshape pass

* removed a file added by mistake

* added reviewers suggestions

* Changed ops type in checking capatibility version

* Deteled one statement
上级 0ca2807c
...@@ -123,6 +123,7 @@ if(WITH_MKLDNN) ...@@ -123,6 +123,7 @@ if(WITH_MKLDNN)
pass_library(cpu_quantize_squash_pass inference DIR mkldnn) pass_library(cpu_quantize_squash_pass inference DIR mkldnn)
pass_library(reshape_transpose_matmul_mkldnn_fuse_pass inference DIR mkldnn) pass_library(reshape_transpose_matmul_mkldnn_fuse_pass inference DIR mkldnn)
pass_library(matmul_transpose_reshape_fuse_pass inference DIR mkldnn) pass_library(matmul_transpose_reshape_fuse_pass inference DIR mkldnn)
pass_library(matmul_v2_transpose_reshape_fuse_pass inference DIR mkldnn)
pass_library(batch_norm_act_fuse_pass inference DIR mkldnn) pass_library(batch_norm_act_fuse_pass inference DIR mkldnn)
pass_library(multi_gru_fuse_pass inference DIR mkldnn) pass_library(multi_gru_fuse_pass inference DIR mkldnn)
pass_library(multi_gru_seq_fuse_pass inference DIR mkldnn) pass_library(multi_gru_seq_fuse_pass inference DIR mkldnn)
...@@ -192,7 +193,7 @@ endif() ...@@ -192,7 +193,7 @@ endif()
cc_test(test_cpu_quantize_pass SRCS mkldnn/cpu_quantize_pass_tester.cc DEPS cpu_quantize_pass naive_executor) cc_test(test_cpu_quantize_pass SRCS mkldnn/cpu_quantize_pass_tester.cc DEPS cpu_quantize_pass naive_executor)
cc_test(test_cpu_quantize_squash_pass SRCS mkldnn/cpu_quantize_squash_pass_tester.cc DEPS cpu_quantize_squash_pass naive_executor) cc_test(test_cpu_quantize_squash_pass SRCS mkldnn/cpu_quantize_squash_pass_tester.cc DEPS cpu_quantize_squash_pass naive_executor)
cc_test(test_reshape_transpose_matmul_mkldnn_fuse_pass SRCS mkldnn/reshape_transpose_matmul_mkldnn_fuse_pass_tester.cc DEPS reshape_transpose_matmul_mkldnn_fuse_pass) cc_test(test_reshape_transpose_matmul_mkldnn_fuse_pass SRCS mkldnn/reshape_transpose_matmul_mkldnn_fuse_pass_tester.cc DEPS reshape_transpose_matmul_mkldnn_fuse_pass)
cc_test(test_matmul_transpose_reshape_fuse_pass SRCS mkldnn/matmul_transpose_reshape_fuse_pass_tester.cc DEPS matmul_transpose_reshape_fuse_pass) cc_test(test_matmul_transpose_reshape_fuse_pass SRCS mkldnn/matmul_transpose_reshape_fuse_pass_tester.cc DEPS matmul_transpose_reshape_fuse_pass matmul_v2_transpose_reshape_fuse_pass)
cc_test(test_cpu_bfloat16_placement_pass SRCS mkldnn/cpu_bfloat16_placement_pass_tester.cc DEPS cpu_bfloat16_placement_pass) cc_test(test_cpu_bfloat16_placement_pass SRCS mkldnn/cpu_bfloat16_placement_pass_tester.cc DEPS cpu_bfloat16_placement_pass)
cc_test(test_cpu_bfloat16_pass SRCS mkldnn/cpu_bfloat16_pass_tester.cc DEPS cpu_bfloat16_pass) cc_test(test_cpu_bfloat16_pass SRCS mkldnn/cpu_bfloat16_pass_tester.cc DEPS cpu_bfloat16_pass)
cc_test(test_multi_gru_fuse_pass SRCS mkldnn/multi_gru_fuse_pass_tester.cc DEPS multi_gru_fuse_pass) cc_test(test_multi_gru_fuse_pass SRCS mkldnn/multi_gru_fuse_pass_tester.cc DEPS multi_gru_fuse_pass)
......
...@@ -2697,16 +2697,18 @@ PDNode *patterns::ReshapeTransposeMatmulPattern::operator()( ...@@ -2697,16 +2697,18 @@ PDNode *patterns::ReshapeTransposeMatmulPattern::operator()(
return matmul_out; return matmul_out;
} }
PDNode *patterns::MatmulTransposeReshapePattern::operator()() { // shared function for matmul and matmul_v2
PDNode *patterns::MatmulTransposeReshapePattern::operator()(
const std::string &op_name) {
auto reshape_op = auto reshape_op =
pattern->NewNode(reshape_op_repr())->assert_is_op("reshape2"); pattern->NewNode(reshape_op_repr())->assert_is_op("reshape2");
auto transpose_op = auto transpose_op =
pattern->NewNode(transpose_op_repr())->assert_is_op("transpose2"); pattern->NewNode(transpose_op_repr())->assert_is_op("transpose2");
auto matmul_op = pattern->NewNode(matmul_op_repr())->assert_is_op("matmul"); auto matmul_op = pattern->NewNode(matmul_op_repr())->assert_is_op(op_name);
auto matmul_out = pattern->NewNode(matmul_out_repr()) auto matmul_out = pattern->NewNode(matmul_out_repr())
->AsInput() ->AsInput()
->assert_is_op_output("matmul", "Out") ->assert_is_op_output(op_name, "Out")
->assert_is_op_input("transpose2", "X"); ->assert_is_op_input("transpose2", "X");
auto transpose_out = pattern->NewNode(transpose_out_repr()) auto transpose_out = pattern->NewNode(transpose_out_repr())
......
...@@ -1546,7 +1546,7 @@ struct MatmulTransposeReshapePattern : public PatternBase { ...@@ -1546,7 +1546,7 @@ struct MatmulTransposeReshapePattern : public PatternBase {
const std::string& name_scope) const std::string& name_scope)
: PatternBase(pattern, name_scope, "matmul_transpose_reshape") {} : PatternBase(pattern, name_scope, "matmul_transpose_reshape") {}
PDNode* operator()(); PDNode* operator()(const std::string& op_name);
PATTERN_DECL_NODE(matmul_op); PATTERN_DECL_NODE(matmul_op);
PATTERN_DECL_NODE(matmul_out); PATTERN_DECL_NODE(matmul_out);
......
...@@ -23,7 +23,9 @@ namespace framework { ...@@ -23,7 +23,9 @@ namespace framework {
namespace ir { namespace ir {
MatmulTransposeReshapeMKLDNNPass::MatmulTransposeReshapeMKLDNNPass() { MatmulTransposeReshapeMKLDNNPass::MatmulTransposeReshapeMKLDNNPass() {
AddOpCompat(OpCompat("matmul")) op_name_ = "matmul";
AddOpCompat(OpCompat(op_name_))
.AddInput("X") .AddInput("X")
.IsTensor() .IsTensor()
.End() .End()
...@@ -89,7 +91,7 @@ void MatmulTransposeReshapeMKLDNNPass::ApplyImpl(ir::Graph *graph) const { ...@@ -89,7 +91,7 @@ void MatmulTransposeReshapeMKLDNNPass::ApplyImpl(ir::Graph *graph) const {
patterns::MatmulTransposeReshapePattern mtrp(gpd.mutable_pattern(), patterns::MatmulTransposeReshapePattern mtrp(gpd.mutable_pattern(),
name_scope_); name_scope_);
mtrp(); mtrp(op_name_);
int found_matmul_transpose_reshape_count = 0; int found_matmul_transpose_reshape_count = 0;
auto handler = [&](const GraphPatternDetector::subgraph_t &subgraph, auto handler = [&](const GraphPatternDetector::subgraph_t &subgraph,
...@@ -98,7 +100,7 @@ void MatmulTransposeReshapeMKLDNNPass::ApplyImpl(ir::Graph *graph) const { ...@@ -98,7 +100,7 @@ void MatmulTransposeReshapeMKLDNNPass::ApplyImpl(ir::Graph *graph) const {
LOG(WARNING) << "Pass in op compat failed."; LOG(WARNING) << "Pass in op compat failed.";
return; return;
} }
VLOG(4) << "handle matmul_transpose_reshape fuse"; VLOG(4) << "handle " + op_name_ + "_transpose_reshape fuse";
GET_IR_NODE_FROM_SUBGRAPH(matmul_op, matmul_op, mtrp); GET_IR_NODE_FROM_SUBGRAPH(matmul_op, matmul_op, mtrp);
GET_IR_NODE_FROM_SUBGRAPH(matmul_out, matmul_out, mtrp); GET_IR_NODE_FROM_SUBGRAPH(matmul_out, matmul_out, mtrp);
GET_IR_NODE_FROM_SUBGRAPH(transpose_op, transpose_op, mtrp); GET_IR_NODE_FROM_SUBGRAPH(transpose_op, transpose_op, mtrp);
...@@ -118,17 +120,17 @@ void MatmulTransposeReshapeMKLDNNPass::ApplyImpl(ir::Graph *graph) const { ...@@ -118,17 +120,17 @@ void MatmulTransposeReshapeMKLDNNPass::ApplyImpl(ir::Graph *graph) const {
const bool supported_transpose_axis = std::equal( const bool supported_transpose_axis = std::equal(
transpose_axis.begin(), transpose_axis.end(), supported_axis.begin()); transpose_axis.begin(), transpose_axis.end(), supported_axis.begin());
if (transpose_out_size != 4) { if (transpose_out_size != 4) {
VLOG(3) << "do not perform matmul_transpose_reshape fuse: " VLOG(3) << "do not perform " + op_name_ + "_transpose_reshape fuse: "
<< "supported rank is 4, received " << transpose_out_size; << "supported rank is 4, received " << transpose_out_size;
return; return;
} }
if (!supported_transpose_axis) { if (!supported_transpose_axis) {
VLOG(3) << "do not perform matmul_transpose_reshape fuse: " VLOG(3) << "do not perform " + op_name_ + "_transpose_reshape fuse: "
<< "supported transpose axis for the fuse are {0, 2, 1, 3}"; << "supported transpose axis for the fuse are {0, 2, 1, 3}";
return; return;
} }
if (reshape_out_size != 3) { if (reshape_out_size != 3) {
VLOG(3) << "do not perform matmul_transpose_reshape fuse: " VLOG(3) << "do not perform " + op_name_ + "_transpose_reshape fuse: "
<< "reshape_out supported rank is 3, received " << "reshape_out supported rank is 3, received "
<< reshape_out_size; << reshape_out_size;
return; return;
...@@ -152,7 +154,7 @@ void MatmulTransposeReshapeMKLDNNPass::ApplyImpl(ir::Graph *graph) const { ...@@ -152,7 +154,7 @@ void MatmulTransposeReshapeMKLDNNPass::ApplyImpl(ir::Graph *graph) const {
if (!Has("disable_logs") || !Get<bool>("disable_logs")) { if (!Has("disable_logs") || !Get<bool>("disable_logs")) {
std::stringstream msg_ss; std::stringstream msg_ss;
msg_ss << "--- Fused " << found_matmul_transpose_reshape_count msg_ss << "--- Fused " << found_matmul_transpose_reshape_count
<< " MatmulTransposeReshape patterns"; << " MatmulTransposeReshape patterns for " + op_name_ + " Op";
paddle::string::PrettyLogDetail(msg_ss.str().c_str()); paddle::string::PrettyLogDetail(msg_ss.str().c_str());
} }
} }
......
...@@ -31,6 +31,7 @@ class MatmulTransposeReshapeMKLDNNPass : public FusePassBase { ...@@ -31,6 +31,7 @@ class MatmulTransposeReshapeMKLDNNPass : public FusePassBase {
protected: protected:
void ApplyImpl(Graph* graph) const override; void ApplyImpl(Graph* graph) const override;
const std::string name_scope_{"matmul_transpose_reshape_fuse"}; const std::string name_scope_{"matmul_transpose_reshape_fuse"};
std::string op_name_;
}; };
} // namespace ir } // namespace ir
} // namespace framework } // namespace framework
......
...@@ -12,8 +12,8 @@ ...@@ -12,8 +12,8 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
#include "paddle/fluid/framework/ir/mkldnn/matmul_transpose_reshape_fuse_pass.h"
#include <gtest/gtest.h> #include <gtest/gtest.h>
#include "paddle/fluid/framework/ir/mkldnn/matmul_v2_transpose_reshape_fuse_pass.h"
namespace paddle { namespace paddle {
namespace framework { namespace framework {
...@@ -42,9 +42,15 @@ void SetOp(ProgramDesc *prog, const std::string &type, ...@@ -42,9 +42,15 @@ void SetOp(ProgramDesc *prog, const std::string &type,
op->SetAttr("transpose_X", true); op->SetAttr("transpose_X", true);
op->SetAttr("transpose_Y", true); op->SetAttr("transpose_Y", true);
} }
if (type == "matmul_v2") {
op->SetInput("Y", {inputs[1]});
op->SetAttr("use_mkldnn", true);
op->SetAttr("trans_x", true);
op->SetAttr("trans_y", true);
}
} }
ProgramDesc BuildProgramDesc() { ProgramDesc BuildProgramDesc(const std::string &op_name) {
ProgramDesc prog; ProgramDesc prog;
for (auto &v : std::initializer_list<std::string>( for (auto &v : std::initializer_list<std::string>(
{"a1", "a2", "b", "c", "cx", "d", "dx", "e"})) { {"a1", "a2", "b", "c", "cx", "d", "dx", "e"})) {
...@@ -52,7 +58,7 @@ ProgramDesc BuildProgramDesc() { ...@@ -52,7 +58,7 @@ ProgramDesc BuildProgramDesc() {
var->SetType(proto::VarType::SELECTED_ROWS); var->SetType(proto::VarType::SELECTED_ROWS);
} }
SetOp(&prog, "matmul", {"a1", "a2"}, {"b"}); SetOp(&prog, op_name, {"a1", "a2"}, {"b"});
SetOp(&prog, "transpose2", {"b"}, {"c", "cx"}); SetOp(&prog, "transpose2", {"b"}, {"c", "cx"});
SetOp(&prog, "reshape2", {"c"}, {"d", "dx"}); SetOp(&prog, "reshape2", {"c"}, {"d", "dx"});
SetOp(&prog, "fc", {"d"}, {"e"}); SetOp(&prog, "fc", {"d"}, {"e"});
...@@ -60,13 +66,13 @@ ProgramDesc BuildProgramDesc() { ...@@ -60,13 +66,13 @@ ProgramDesc BuildProgramDesc() {
return prog; return prog;
} }
void MainTest(const ProgramDesc &prog) { void MainTest(const ProgramDesc &prog, const std::string &op_name) {
std::unique_ptr<ir::Graph> graph(new ir::Graph(prog)); std::unique_ptr<ir::Graph> graph(new ir::Graph(prog));
int original_nodes_num = graph->Nodes().size(); int original_nodes_num = graph->Nodes().size();
auto pass = auto pass =
PassRegistry::Instance().Get("matmul_transpose_reshape_fuse_pass"); PassRegistry::Instance().Get(op_name + "_transpose_reshape_fuse_pass");
graph.reset(pass->Apply(graph.release())); graph.reset(pass->Apply(graph.release()));
int current_nodes_num = graph->Nodes().size(); int current_nodes_num = graph->Nodes().size();
...@@ -75,7 +81,7 @@ void MainTest(const ProgramDesc &prog) { ...@@ -75,7 +81,7 @@ void MainTest(const ProgramDesc &prog) {
for (auto *node : graph->Nodes()) { for (auto *node : graph->Nodes()) {
if (node->IsOp()) { if (node->IsOp()) {
auto *op = node->Op(); auto *op = node->Op();
if (op->Type() == "matmul") { if (op->Type() == op_name) {
EXPECT_EQ(op->GetAttrIfExists<std::vector<int>>("fused_reshape_Out"), EXPECT_EQ(op->GetAttrIfExists<std::vector<int>>("fused_reshape_Out"),
std::vector<int>({4, 5, 6})); std::vector<int>({4, 5, 6}));
EXPECT_EQ(op->GetAttrIfExists<std::vector<int>>("fused_transpose_Out"), EXPECT_EQ(op->GetAttrIfExists<std::vector<int>>("fused_transpose_Out"),
...@@ -85,12 +91,18 @@ void MainTest(const ProgramDesc &prog) { ...@@ -85,12 +91,18 @@ void MainTest(const ProgramDesc &prog) {
} }
} }
TEST(MatmulTransposeReshapeFusePass, matmul_inputs) { TEST(MatmulTransposeReshapeFusePass, matmul_fuse_pass) {
auto prog = BuildProgramDesc(); auto prog = BuildProgramDesc("matmul");
MainTest(prog); MainTest(prog, "matmul");
}
TEST(MatmulTransposeReshapeFusePass, matmul_v2_fuse_pass) {
auto prog = BuildProgramDesc("matmul_v2");
MainTest(prog, "matmul_v2");
} }
} // namespace ir } // namespace ir
} // namespace framework } // namespace framework
} // namespace paddle } // namespace paddle
USE_PASS(matmul_transpose_reshape_fuse_pass); USE_PASS(matmul_transpose_reshape_fuse_pass);
USE_PASS(matmul_v2_transpose_reshape_fuse_pass);
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/ir/mkldnn/matmul_v2_transpose_reshape_fuse_pass.h"
#include <vector>
#include "paddle/fluid/framework/op_version_registry.h"
#include "paddle/fluid/platform/enforce.h"
namespace paddle {
namespace framework {
namespace ir {
MatmulV2TransposeReshapeMKLDNNPass::MatmulV2TransposeReshapeMKLDNNPass() {
op_name_ = "matmul_v2";
AddOpCompat(OpCompat(op_name_))
.AddInput("X")
.IsTensor()
.End()
.AddInput("Y")
.IsTensor()
.End()
.AddOutput("Out")
.IsTensor()
.End()
.AddAttr("trans_x")
.IsType<bool>()
.End()
.AddAttr("trans_y")
.IsType<bool>()
.End();
AddOpCompat(OpCompat("transpose2"))
.AddInput("X")
.IsTensor()
.End()
.AddOutput("Out")
.IsTensor()
.End()
.AddOutput("XShape")
.IsTensor()
.End()
.AddAttr("axis")
.IsType<std::vector<int>>()
.End();
AddOpCompat(OpCompat("reshape2"))
.AddInput("X")
.IsTensor()
.End()
.AddInput("Shape")
.IsTensor()
.IsOptional()
.End()
.AddInput("ShapeTensor")
.IsTensor()
.IsOptional()
.End()
.AddOutput("Out")
.IsTensor()
.End()
.AddOutput("XShape")
.IsTensor()
.End()
.AddAttr("shape")
.IsType<std::vector<int>>()
.End();
}
} // namespace ir
} // namespace framework
} // namespace paddle
REGISTER_PASS(matmul_v2_transpose_reshape_fuse_pass,
paddle::framework::ir::MatmulV2TransposeReshapeMKLDNNPass);
REGISTER_PASS_CAPABILITY(matmul_v2_transpose_reshape_fuse_pass)
.AddCombination(
paddle::framework::compatible::OpVersionComparatorCombination()
.EQ("matmul_v2", 0)
.EQ("transpose2", 0)
.EQ("reshape2", 0));
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <string>
#include "paddle/fluid/framework/ir/mkldnn/matmul_transpose_reshape_fuse_pass.h"
namespace paddle {
namespace framework {
namespace ir {
class MatmulV2TransposeReshapeMKLDNNPass
: public MatmulTransposeReshapeMKLDNNPass {
public:
MatmulV2TransposeReshapeMKLDNNPass();
virtual ~MatmulV2TransposeReshapeMKLDNNPass() {}
protected:
const std::string name_scope_{"matmul_v2_transpose_reshape_fuse"};
};
} // namespace ir
} // namespace framework
} // namespace paddle
...@@ -249,6 +249,7 @@ void CpuPassStrategy::EnableMKLDNN() { ...@@ -249,6 +249,7 @@ void CpuPassStrategy::EnableMKLDNN() {
"scale_matmul_fuse_pass", // "scale_matmul_fuse_pass", //
"reshape_transpose_matmul_mkldnn_fuse_pass", // "reshape_transpose_matmul_mkldnn_fuse_pass", //
"matmul_transpose_reshape_fuse_pass", // "matmul_transpose_reshape_fuse_pass", //
"matmul_v2_transpose_reshape_fuse_pass", //
// Disabled due to topology-dependent speed-up // Disabled due to topology-dependent speed-up
// "fc_mkldnn_pass", // "fc_mkldnn_pass",
// "fc_act_mkldnn_fuse_pass", // "fc_act_mkldnn_fuse_pass",
......
...@@ -39,4 +39,12 @@ extra { ...@@ -39,4 +39,12 @@ extra {
name: "op_device" name: "op_device"
type: STRING type: STRING
} }
attrs {
name: "fused_reshape_Out"
type: INTS
}
attrs {
name: "fused_transpose_Out"
type: INTS
}
} }
...@@ -90,8 +90,62 @@ class MatMulV2Op : public framework::OperatorWithKernel { ...@@ -90,8 +90,62 @@ class MatMulV2Op : public framework::OperatorWithKernel {
new_dims.push_back(1); new_dims.push_back(1);
} }
auto out_dims = framework::make_ddim(new_dims); auto ddim_out = framework::make_ddim(new_dims);
ctx->SetOutputDim("Out", out_dims);
#ifdef PADDLE_WITH_MKLDNN
// if mkldnn matmul_v2+transpose+reshape fuse activated
auto reshape_out = ctx->Attrs().Get<std::vector<int>>("fused_reshape_Out");
auto transpose_out =
ctx->Attrs().Get<std::vector<int>>("fused_transpose_Out");
if (!reshape_out.empty() && !transpose_out.empty()) {
auto reshape_out_size = reshape_out.size();
auto transpose_out_size = transpose_out.size();
PADDLE_ENFORCE_EQ(transpose_out_size, 4,
platform::errors::InvalidArgument(
"transpose_out supported rank is 4, "
"received %d",
transpose_out_size));
const std::vector<int> supported_axis{0, 2, 1, 3};
const bool supported_transpose_axis = std::equal(
transpose_out.begin(), transpose_out.end(), supported_axis.begin());
PADDLE_ENFORCE_EQ(
supported_transpose_axis, true,
platform::errors::InvalidArgument(
"supported transpose axis for the fuse are {0, 2, 1, 3}"));
PADDLE_ENFORCE_EQ(
reshape_out_size, 3,
platform::errors::InvalidArgument("reshape_out supported rank is 3, "
"received %d",
reshape_out_size));
auto it = std::find(reshape_out.begin(), reshape_out.end(), -1);
// if "-1" is present then one of reshape dims must be infered
if (it != reshape_out.end()) {
int index = std::distance(reshape_out.begin(), it);
auto ddim_out_vec = framework::vectorize(ddim_out);
int ddim_out_product =
std::accumulate(ddim_out_vec.begin(), ddim_out_vec.end(), 1,
std::multiplies<int>());
int reshape_out_product = std::accumulate(
reshape_out.begin(), reshape_out.end(), -1, std::multiplies<int>());
reshape_out[index] = ddim_out_product / reshape_out_product;
}
framework::DDim shape_out =
ddim_out.transpose(transpose_out).reshape(reshape_out);
ctx->SetOutputDim("Out", shape_out);
} else {
ctx->SetOutputDim("Out", ddim_out);
}
#else
ctx->SetOutputDim("Out", ddim_out);
#endif
ctx->ShareLoD("X", /* --> */ "Out"); ctx->ShareLoD("X", /* --> */ "Out");
} }
...@@ -139,6 +193,18 @@ class MatMulV2OpMaker : public framework::OpProtoAndCheckerMaker { ...@@ -139,6 +193,18 @@ class MatMulV2OpMaker : public framework::OpProtoAndCheckerMaker {
"Set true to transpose the last two dimensions of Y before " "Set true to transpose the last two dimensions of Y before "
"doing multiplication") "doing multiplication")
.SetDefault(false); .SetDefault(false);
AddAttr<std::vector<int>>(
"fused_reshape_Out",
R"DOC(When MKLDNN matmul_v2_transpose_reshape fuse activated, "
"it's a shape atribute of fused reshape for `Out` output.)DOC")
.SetDefault({})
.AsExtra();
AddAttr<std::vector<int>>(
"fused_transpose_Out",
R"DOC(When MKLDNN matmul_v2_transpose_reshape fuse activated, "
"it's a axis atribute of fused transpose for `Out` output.)DOC")
.SetDefault({})
.AsExtra();
AddAttr<bool>("use_mkldnn", AddAttr<bool>("use_mkldnn",
"(bool, default false) Only used in mkldnn kernel") "(bool, default false) Only used in mkldnn kernel")
.SetDefault(false) .SetDefault(false)
......
...@@ -36,7 +36,8 @@ class MatMulV2MKLDNNHandler ...@@ -36,7 +36,8 @@ class MatMulV2MKLDNNHandler
MatMulV2MKLDNNHandler(const mkldnn::engine engine, MatMulV2MKLDNNHandler(const mkldnn::engine engine,
paddle::platform::Place cpu_place, paddle::platform::Place cpu_place,
const std::vector<int64_t>& x_org_dims, bool trans_x, const std::vector<int64_t>& x_org_dims, bool trans_x,
const std::vector<int64_t>& y_org_dims, bool trans_y) const std::vector<int64_t>& y_org_dims, bool trans_y,
bool is_output_fused)
: paddle::platform::MKLDNNHandlerNoCachingT<T, dnnl::matmul>(engine, : paddle::platform::MKLDNNHandlerNoCachingT<T, dnnl::matmul>(engine,
cpu_place) { cpu_place) {
// M X K * K X N // M X K * K X N
...@@ -86,6 +87,10 @@ class MatMulV2MKLDNNHandler ...@@ -86,6 +87,10 @@ class MatMulV2MKLDNNHandler
out_strides[i] = out_ddims[i + 1] * out_strides[i + 1]; out_strides[i] = out_ddims[i + 1] * out_strides[i + 1];
} }
if (is_output_fused) {
out_strides = FakeTransposeStrides(out_ddims);
}
auto x_md = memory::desc(x_dims, MKLDNNGetDataType<T>(), x_strides); auto x_md = memory::desc(x_dims, MKLDNNGetDataType<T>(), x_strides);
auto y_md = memory::desc(y_dims, MKLDNNGetDataType<T>(), y_strides); auto y_md = memory::desc(y_dims, MKLDNNGetDataType<T>(), y_strides);
auto out_md = memory::desc(out_ddims, MKLDNNGetDataType<T>(), out_strides); auto out_md = memory::desc(out_ddims, MKLDNNGetDataType<T>(), out_strides);
...@@ -93,6 +98,24 @@ class MatMulV2MKLDNNHandler ...@@ -93,6 +98,24 @@ class MatMulV2MKLDNNHandler
this->AcquireForwardPrimitiveDescriptor(x_md, y_md, out_md); this->AcquireForwardPrimitiveDescriptor(x_md, y_md, out_md);
} }
std::vector<int64_t> FakeTransposeStrides(
const std::vector<int64_t>& matmul_out_dims) const {
// fuse matmul_v2 + transpose + reshape guarantees that output is 4D and
// transpose axis are: {0, 2, 1, 3}
std::vector<int64_t> transpose_axis = {0, 2, 1, 3};
std::vector<int64_t> fake_strides(transpose_axis.size());
int ndims = static_cast<int>(transpose_axis.size());
int total_stride = 1;
for (int i = ndims - 1; i >= 0; --i) {
fake_strides[transpose_axis[i]] = total_stride;
total_stride *= matmul_out_dims[transpose_axis[i]];
}
return fake_strides;
}
std::shared_ptr<memory> AcquireWeightsMemory(const Tensor* input) { std::shared_ptr<memory> AcquireWeightsMemory(const Tensor* input) {
const T* input_data = input->data<T>(); const T* input_data = input->data<T>();
return this->AcquireMemoryFromPrimitive(this->fwd_pd_->weights_desc(), return this->AcquireMemoryFromPrimitive(this->fwd_pd_->weights_desc(),
...@@ -116,7 +139,8 @@ class MatMulV2MKLDNNKernel ...@@ -116,7 +139,8 @@ class MatMulV2MKLDNNKernel
bool trans_y, Tensor* out, std::vector<int64_t>& out_dims, bool trans_y, Tensor* out, std::vector<int64_t>& out_dims,
int execution_number = 0) const { int execution_number = 0) const {
MatMulV2MKLDNNHandler<T> handler(onednn_engine, ctx.GetPlace(), x_dims, MatMulV2MKLDNNHandler<T> handler(onednn_engine, ctx.GetPlace(), x_dims,
trans_x, y_dims, trans_y); trans_x, y_dims, trans_y,
IsOutputFused(ctx));
const auto src_memory_p = handler.AcquireSrcMemory(x); const auto src_memory_p = handler.AcquireSrcMemory(x);
const auto weights_memory_p = handler.AcquireWeightsMemory(y); const auto weights_memory_p = handler.AcquireWeightsMemory(y);
...@@ -133,9 +157,10 @@ class MatMulV2MKLDNNKernel ...@@ -133,9 +157,10 @@ class MatMulV2MKLDNNKernel
matmul_p->execute(astream, matmul_args); matmul_p->execute(astream, matmul_args);
astream.wait(); astream.wait();
auto format = paddle::platform::MKLDNNFormatForSize(
out->dims().size(), dnnl::memory::format_tag::nchw);
out->set_layout(paddle::framework::DataLayout::kMKLDNN); out->set_layout(paddle::framework::DataLayout::kMKLDNN);
out->set_format( out->set_format(format);
GetMKLDNNFormat(dst_memory_p->get_desc().reshape(out_dims)));
} }
private: private:
...@@ -166,7 +191,8 @@ class MatMulV2MKLDNNKernel ...@@ -166,7 +191,8 @@ class MatMulV2MKLDNNKernel
} }
} }
if ((y_dims.size() == x_dims.size()) && y_dims.size() > 2) { if ((y_dims.size() == x_dims.size()) && y_dims.size() > 2 &&
!IsOutputFused(ctx)) {
for (size_t i = 0; i < x_dims.size() - 2; ++i) { for (size_t i = 0; i < x_dims.size() - 2; ++i) {
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
x_dims[i] == y_dims[i] || x_dims[i] == 1 || y_dims[i] == 1, true, x_dims[i] == y_dims[i] || x_dims[i] == 1 || y_dims[i] == 1, true,
...@@ -181,6 +207,13 @@ class MatMulV2MKLDNNKernel ...@@ -181,6 +207,13 @@ class MatMulV2MKLDNNKernel
} }
} }
bool IsOutputFused(const ExecutionContext& ctx) const {
auto& fused_reshape_Out = ctx.Attr<std::vector<int>>("fused_reshape_Out");
auto& fused_transpose_Out =
ctx.Attr<std::vector<int>>("fused_transpose_Out");
return !fused_reshape_Out.empty() && !fused_transpose_Out.empty();
}
void RunKernel(const ExecutionContext& ctx) const { void RunKernel(const ExecutionContext& ctx) const {
const auto& dev_ctx = ctx.template device_context<MKLDNNDeviceContext>(); const auto& dev_ctx = ctx.template device_context<MKLDNNDeviceContext>();
const auto& onednn_engine = dev_ctx.GetEngine(); const auto& onednn_engine = dev_ctx.GetEngine();
......
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import unittest
import numpy as np
from inference_pass_test import InferencePassTest
import paddle
import paddle.fluid as fluid
import paddle.fluid.core as core
from paddle.fluid.core import AnalysisConfig
from paddle.fluid.core import PassVersionChecker
class TestMatmulV2OneDNNTransposeReshapeFusePass(InferencePassTest):
def setUp(self):
self.set_params()
self.tranpose_perm = [0, 2, 1, 3]
self.pass_name = 'matmul_v2_transpose_reshape_fuse_pass'
with fluid.program_guard(self.main_program, self.startup_program):
data = fluid.data(
name="data", shape=self.data_shape, dtype="float32")
weight = fluid.layers.create_parameter(
shape=self.weight_shape, dtype="float32")
matmul = paddle.matmul(
data,
weight,
transpose_x=self.transpose_x,
transpose_y=self.transpose_y)
transpose = fluid.layers.transpose(matmul, self.tranpose_perm)
reshape = fluid.layers.reshape(transpose, shape=self.reshape_shape)
self.fetch_list = [reshape]
self.enable_mkldnn = True
def set_params(self):
self.data_shape = [-1, 3, 100, 110]
self.weight_shape = [1, 3, 110, 100]
self.feeds = {
"data": np.random.random((1, 3, 100, 110)).astype("float32")
}
self.transpose_x = False
self.transpose_y = False
self.reshape_shape = [3, 100, 100]
def test_check_output(self):
use_gpu = False
self.check_output_with_option(use_gpu)
def test_pass_compatible(self):
self.assertTrue(PassVersionChecker.IsCompatible(self.pass_name))
class TestMatmulV2OneDNNTransposeReshapeFusePassDifferentDims(
TestMatmulV2OneDNNTransposeReshapeFusePass):
def set_params(self):
self.data_shape = [-1, 4, 100, 80]
self.weight_shape = [1, 4, 80, 100]
self.feeds = {
"data": np.random.random((1, 4, 100, 80)).astype("float32")
}
self.transpose_x = True
self.transpose_y = True
self.reshape_shape = [8, 40, 80]
if __name__ == "__main__":
paddle.enable_static()
unittest.main()
...@@ -440,9 +440,11 @@ class TestMatMulOpTransposeReshapeEmptyFloat(OpTest): ...@@ -440,9 +440,11 @@ class TestMatMulOpTransposeReshapeEmptyFloat(OpTest):
self.reshape_out = [] self.reshape_out = []
self.out = np.matmul(self.x, self.y) self.out = np.matmul(self.x, self.y)
def setUp(self): def set_op_type(self):
os.environ["DNNL_MAX_CPU_ISA"] = "AVX"
self.op_type = "matmul" self.op_type = "matmul"
def setUp(self):
self.set_op_type()
self._cpu_only = True self._cpu_only = True
self.use_mkldnn = True self.use_mkldnn = True
self.init_data_type() self.init_data_type()
......
...@@ -23,6 +23,13 @@ import paddle.fluid.core as core ...@@ -23,6 +23,13 @@ import paddle.fluid.core as core
import paddle import paddle
import paddle.fluid as fluid import paddle.fluid as fluid
import paddle.fluid.framework as framework import paddle.fluid.framework as framework
from paddle.fluid.tests.unittests.mkldnn.test_matmul_mkldnn_op import (
TestMatMulOpTransposeReshapeEmptyFloat,
TestMatMulOpTransposeReshapeBasicFloat,
TestMatMulOpTransposeReshapeOtherDimFloat,
TestMatMulOpTransposeReshapeTransposeAxisNotSupportedException,
TestMatMulOpTransposeReshapeTransposeRankNotSupportedException,
TestMatMulOpTransposeReshapeRankOfReshapeNotSupportedException)
def reference_matmul(X, Y, transpose_x=False, transpose_y=False): def reference_matmul(X, Y, transpose_x=False, transpose_y=False):
...@@ -390,6 +397,43 @@ create_bf16_test_class(TestMatMulV2MatrixXMatrix5DTranposeYOneDNNOp) ...@@ -390,6 +397,43 @@ create_bf16_test_class(TestMatMulV2MatrixXMatrix5DTranposeYOneDNNOp)
create_bf16_test_class(TestMatMulV2MatrixXMatrix6Dx2DOneDNNOp) create_bf16_test_class(TestMatMulV2MatrixXMatrix6Dx2DOneDNNOp)
create_bf16_test_class(TestMatMulV2MatrixXMatrix2Dx5DOneDNNOp) create_bf16_test_class(TestMatMulV2MatrixXMatrix2Dx5DOneDNNOp)
class TestMatMulV2OpTransposeReshapeEmptyFloat(
TestMatMulOpTransposeReshapeEmptyFloat):
def set_op_type(self):
self.op_type = "matmul_v2"
class TestMatMulV2OpTransposeReshapeBasicFloat(
TestMatMulOpTransposeReshapeBasicFloat):
def set_op_type(self):
self.op_type = "matmul_v2"
class TestMatMulV2OpTransposeReshapeOtherDimFloat(
TestMatMulOpTransposeReshapeOtherDimFloat):
def set_op_type(self):
self.op_type = "matmul_v2"
class TestMatMulV2OpTransposeReshapeTransposeAxisNotSupportedException(
TestMatMulOpTransposeReshapeTransposeAxisNotSupportedException):
def set_op_type(self):
self.op_type = "matmul_v2"
class TestMatMulV2OpTransposeReshapeRankOfReshapeNotSupportedException(
TestMatMulOpTransposeReshapeRankOfReshapeNotSupportedException):
def set_op_type(self):
self.op_type = "matmul_v2"
class TestMatMulV2OpTransposeReshapeTransposeRankNotSupportedException(
TestMatMulOpTransposeReshapeTransposeRankNotSupportedException):
def set_op_type(self):
self.op_type = "matmul_v2"
if __name__ == "__main__": if __name__ == "__main__":
paddle.enable_static() paddle.enable_static()
unittest.main() unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册