未验证 提交 a922168a 编写于 作者: S Sylwester Fraczek 提交者: GitHub

add reshape+transpose+matmul_v2 only (#37847)

* reshape+transpose+matmul_v2

* in_name->input_name

* fix pr-ci-static-check
上级 6a852536
...@@ -123,6 +123,7 @@ if(WITH_MKLDNN) ...@@ -123,6 +123,7 @@ if(WITH_MKLDNN)
pass_library(cpu_quantize_pass inference DIR mkldnn) pass_library(cpu_quantize_pass inference DIR mkldnn)
pass_library(cpu_quantize_squash_pass inference DIR mkldnn) pass_library(cpu_quantize_squash_pass inference DIR mkldnn)
pass_library(reshape_transpose_matmul_mkldnn_fuse_pass inference DIR mkldnn) pass_library(reshape_transpose_matmul_mkldnn_fuse_pass inference DIR mkldnn)
pass_library(reshape_transpose_matmul_v2_mkldnn_fuse_pass inference DIR mkldnn)
pass_library(matmul_transpose_reshape_fuse_pass inference DIR mkldnn) pass_library(matmul_transpose_reshape_fuse_pass inference DIR mkldnn)
pass_library(matmul_v2_transpose_reshape_fuse_pass inference DIR mkldnn) pass_library(matmul_v2_transpose_reshape_fuse_pass inference DIR mkldnn)
pass_library(batch_norm_act_fuse_pass inference DIR mkldnn) pass_library(batch_norm_act_fuse_pass inference DIR mkldnn)
...@@ -190,7 +191,7 @@ endif() ...@@ -190,7 +191,7 @@ endif()
cc_test(test_cpu_quantize_placement_pass SRCS mkldnn/cpu_quantize_placement_pass_tester.cc DEPS cpu_quantize_placement_pass) cc_test(test_cpu_quantize_placement_pass SRCS mkldnn/cpu_quantize_placement_pass_tester.cc DEPS cpu_quantize_placement_pass)
cc_test(test_cpu_quantize_pass SRCS mkldnn/cpu_quantize_pass_tester.cc DEPS cpu_quantize_pass naive_executor) cc_test(test_cpu_quantize_pass SRCS mkldnn/cpu_quantize_pass_tester.cc DEPS cpu_quantize_pass naive_executor)
cc_test(test_cpu_quantize_squash_pass SRCS mkldnn/cpu_quantize_squash_pass_tester.cc DEPS cpu_quantize_squash_pass naive_executor) cc_test(test_cpu_quantize_squash_pass SRCS mkldnn/cpu_quantize_squash_pass_tester.cc DEPS cpu_quantize_squash_pass naive_executor)
cc_test(test_reshape_transpose_matmul_mkldnn_fuse_pass SRCS mkldnn/reshape_transpose_matmul_mkldnn_fuse_pass_tester.cc DEPS reshape_transpose_matmul_mkldnn_fuse_pass) cc_test(test_reshape_transpose_matmul_mkldnn_fuse_pass SRCS mkldnn/reshape_transpose_matmul_mkldnn_fuse_pass_tester.cc DEPS reshape_transpose_matmul_mkldnn_fuse_pass reshape_transpose_matmul_v2_mkldnn_fuse_pass)
cc_test(test_matmul_transpose_reshape_fuse_pass SRCS mkldnn/matmul_transpose_reshape_fuse_pass_tester.cc DEPS matmul_transpose_reshape_fuse_pass matmul_v2_transpose_reshape_fuse_pass) cc_test(test_matmul_transpose_reshape_fuse_pass SRCS mkldnn/matmul_transpose_reshape_fuse_pass_tester.cc DEPS matmul_transpose_reshape_fuse_pass matmul_v2_transpose_reshape_fuse_pass)
cc_test(test_cpu_bfloat16_placement_pass SRCS mkldnn/cpu_bfloat16_placement_pass_tester.cc DEPS cpu_bfloat16_placement_pass) cc_test(test_cpu_bfloat16_placement_pass SRCS mkldnn/cpu_bfloat16_placement_pass_tester.cc DEPS cpu_bfloat16_placement_pass)
cc_test(test_cpu_bfloat16_pass SRCS mkldnn/cpu_bfloat16_pass_tester.cc DEPS cpu_bfloat16_pass) cc_test(test_cpu_bfloat16_pass SRCS mkldnn/cpu_bfloat16_pass_tester.cc DEPS cpu_bfloat16_pass)
......
...@@ -2711,12 +2711,13 @@ void patterns::DeleteQuantDequantFilterOpPattern::operator()() { ...@@ -2711,12 +2711,13 @@ void patterns::DeleteQuantDequantFilterOpPattern::operator()() {
} }
PDNode *patterns::ReshapeTransposeMatmulPattern::operator()( PDNode *patterns::ReshapeTransposeMatmulPattern::operator()(
bool with_reshape_xshape, bool with_transpose_xshape) { const std::string &op_name, bool with_reshape_xshape,
bool with_transpose_xshape) {
auto reshape_op = auto reshape_op =
pattern->NewNode(reshape_op_repr())->assert_is_op("reshape2"); pattern->NewNode(reshape_op_repr())->assert_is_op("reshape2");
auto transpose_op = auto transpose_op =
pattern->NewNode(transpose_op_repr())->assert_is_op("transpose2"); pattern->NewNode(transpose_op_repr())->assert_is_op("transpose2");
auto matmul_op = pattern->NewNode(matmul_op_repr())->assert_is_op("matmul"); auto matmul_op = pattern->NewNode(matmul_op_repr())->assert_is_op(op_name);
auto reshape_in = pattern->NewNode(reshape_in_repr()) auto reshape_in = pattern->NewNode(reshape_in_repr())
->AsInput() ->AsInput()
...@@ -2737,7 +2738,7 @@ PDNode *patterns::ReshapeTransposeMatmulPattern::operator()( ...@@ -2737,7 +2738,7 @@ PDNode *patterns::ReshapeTransposeMatmulPattern::operator()(
auto transpose_out = pattern->NewNode(transpose_out_repr()) auto transpose_out = pattern->NewNode(transpose_out_repr())
->AsIntermediate() ->AsIntermediate()
->assert_is_op_input("matmul") ->assert_is_op_input(op_name)
->assert_is_op_output("transpose2", "Out"); ->assert_is_op_output("transpose2", "Out");
if (!with_transpose_xshape) if (!with_transpose_xshape)
transpose_out->assert_is_only_output_of_op("transpose2"); transpose_out->assert_is_only_output_of_op("transpose2");
...@@ -2751,7 +2752,7 @@ PDNode *patterns::ReshapeTransposeMatmulPattern::operator()( ...@@ -2751,7 +2752,7 @@ PDNode *patterns::ReshapeTransposeMatmulPattern::operator()(
auto matmul_out = pattern->NewNode(matmul_out_repr()) auto matmul_out = pattern->NewNode(matmul_out_repr())
->AsOutput() ->AsOutput()
->assert_is_op_output("matmul", "Out"); ->assert_is_op_output(op_name, "Out");
reshape_op->LinksFrom({reshape_in}).LinksTo({reshape_out}); reshape_op->LinksFrom({reshape_in}).LinksTo({reshape_out});
if (with_reshape_xshape) reshape_op->LinksTo({reshape_xshape}); if (with_reshape_xshape) reshape_op->LinksTo({reshape_xshape});
......
...@@ -1570,7 +1570,8 @@ struct ReshapeTransposeMatmulPattern : public PatternBase { ...@@ -1570,7 +1570,8 @@ struct ReshapeTransposeMatmulPattern : public PatternBase {
const std::string& name_scope) const std::string& name_scope)
: PatternBase(pattern, name_scope, "reshape_transpose_matmul") {} : PatternBase(pattern, name_scope, "reshape_transpose_matmul") {}
PDNode* operator()(bool with_reshape_xshape, bool with_transpose_xshape); PDNode* operator()(const std::string& op_name, bool with_reshape_xshape,
bool with_transpose_xshape);
PATTERN_DECL_NODE(reshape_in); PATTERN_DECL_NODE(reshape_in);
PATTERN_DECL_NODE(reshape_op); PATTERN_DECL_NODE(reshape_op);
......
...@@ -24,6 +24,8 @@ namespace framework { ...@@ -24,6 +24,8 @@ namespace framework {
namespace ir { namespace ir {
ReshapeTransposeMatmulMkldnnFusePass::ReshapeTransposeMatmulMkldnnFusePass() { ReshapeTransposeMatmulMkldnnFusePass::ReshapeTransposeMatmulMkldnnFusePass() {
op_name_ = "matmul";
AddOpCompat(OpCompat("reshape2")) AddOpCompat(OpCompat("reshape2"))
.AddInput("X") .AddInput("X")
.IsTensor() .IsTensor()
...@@ -55,7 +57,7 @@ ReshapeTransposeMatmulMkldnnFusePass::ReshapeTransposeMatmulMkldnnFusePass() { ...@@ -55,7 +57,7 @@ ReshapeTransposeMatmulMkldnnFusePass::ReshapeTransposeMatmulMkldnnFusePass() {
.IsType<std::vector<int>>() .IsType<std::vector<int>>()
.End(); .End();
AddOpCompat(OpCompat("matmul")) AddOpCompat(OpCompat(op_name_))
.AddInput("X") .AddInput("X")
.IsTensor() .IsTensor()
.End() .End()
...@@ -82,17 +84,17 @@ void ReshapeTransposeMatmulMkldnnFusePass::Fuse( ...@@ -82,17 +84,17 @@ void ReshapeTransposeMatmulMkldnnFusePass::Fuse(
patterns::ReshapeTransposeMatmulPattern rtm_pattern(gpd.mutable_pattern(), patterns::ReshapeTransposeMatmulPattern rtm_pattern(gpd.mutable_pattern(),
name_scope_); name_scope_);
rtm_pattern(with_reshape_xshape, with_transpose_xshape); rtm_pattern(op_name_, with_reshape_xshape, with_transpose_xshape);
int found_reshape_transpose_matmul_count = 0; int found_reshape_transpose_matmul_count = 0;
auto handler = [&](const GraphPatternDetector::subgraph_t &subgraph, auto handler = [&](const GraphPatternDetector::subgraph_t &subgraph,
Graph *g) { Graph *g) {
if (!IsCompat(subgraph, g)) { if (!IsCompat(subgraph, g)) {
LOG(WARNING) << "Op compatible check in " LOG(WARNING) << "Op compatible check in reshape_transpose_" << op_name_
"reshape_transpose_matmul_mkldnn_fuse_pass failed."; << "_mkldnn_fuse_pass failed.";
return; return;
} }
VLOG(4) << "handle ReshapeTransposeMatmulMkldnn fuse"; VLOG(4) << "handle reshape_transpose_" << op_name_ << " fuse";
GET_IR_NODE_FROM_SUBGRAPH(reshape_in, reshape_in, rtm_pattern); GET_IR_NODE_FROM_SUBGRAPH(reshape_in, reshape_in, rtm_pattern);
GET_IR_NODE_FROM_SUBGRAPH(reshape_op, reshape_op, rtm_pattern); GET_IR_NODE_FROM_SUBGRAPH(reshape_op, reshape_op, rtm_pattern);
GET_IR_NODE_FROM_SUBGRAPH(reshape_out, reshape_out, rtm_pattern); GET_IR_NODE_FROM_SUBGRAPH(reshape_out, reshape_out, rtm_pattern);
...@@ -131,8 +133,8 @@ void ReshapeTransposeMatmulMkldnnFusePass::Fuse( ...@@ -131,8 +133,8 @@ void ReshapeTransposeMatmulMkldnnFusePass::Fuse(
} else if (matmul_desc->Inputs().at("Y").at(0) == input_var_name) { } else if (matmul_desc->Inputs().at("Y").at(0) == input_var_name) {
UpdateMatmul("Y"); UpdateMatmul("Y");
} else { } else {
throw platform::errors::InvalidArgument( throw platform::errors::InvalidArgument("Unexpected input to " +
"Unexpected input to MatMul encountered."); op_name_ + " encountered.");
} }
std::unordered_set<const ir::Node *> nodes_to_remove{ std::unordered_set<const ir::Node *> nodes_to_remove{
...@@ -151,7 +153,7 @@ void ReshapeTransposeMatmulMkldnnFusePass::Fuse( ...@@ -151,7 +153,7 @@ void ReshapeTransposeMatmulMkldnnFusePass::Fuse(
if (!Has("disable_logs") || !Get<bool>("disable_logs")) { if (!Has("disable_logs") || !Get<bool>("disable_logs")) {
std::stringstream msg_ss; std::stringstream msg_ss;
msg_ss << "--- Fused " << found_reshape_transpose_matmul_count msg_ss << "--- Fused " << found_reshape_transpose_matmul_count
<< " ReshapeTransposeMatmulMkldnn patterns"; << " ReshapeTransposeMatmul patterns for " << op_name_ << " Op";
if (with_reshape_xshape) msg_ss << " with reshape's xshape"; if (with_reshape_xshape) msg_ss << " with reshape's xshape";
if (with_transpose_xshape) msg_ss << " with transpose's xshape"; if (with_transpose_xshape) msg_ss << " with transpose's xshape";
string::PrettyLogDetail(msg_ss.str().c_str()); string::PrettyLogDetail(msg_ss.str().c_str());
......
...@@ -28,6 +28,7 @@ namespace ir { ...@@ -28,6 +28,7 @@ namespace ir {
class ReshapeTransposeMatmulMkldnnFusePass : public FusePassBase { class ReshapeTransposeMatmulMkldnnFusePass : public FusePassBase {
public: public:
ReshapeTransposeMatmulMkldnnFusePass(); ReshapeTransposeMatmulMkldnnFusePass();
virtual ~ReshapeTransposeMatmulMkldnnFusePass() {}
protected: protected:
void ApplyImpl(ir::Graph* graph) const override; void ApplyImpl(ir::Graph* graph) const override;
...@@ -35,6 +36,7 @@ class ReshapeTransposeMatmulMkldnnFusePass : public FusePassBase { ...@@ -35,6 +36,7 @@ class ReshapeTransposeMatmulMkldnnFusePass : public FusePassBase {
void Fuse(Graph* graph, bool with_reshape_xshape, void Fuse(Graph* graph, bool with_reshape_xshape,
bool with_transpose_xshape) const; bool with_transpose_xshape) const;
std::string op_name_;
}; };
} // namespace ir } // namespace ir
} // namespace framework } // namespace framework
......
...@@ -13,6 +13,7 @@ ...@@ -13,6 +13,7 @@
// limitations under the License. // limitations under the License.
#include "paddle/fluid/framework/ir/mkldnn/reshape_transpose_matmul_mkldnn_fuse_pass.h" #include "paddle/fluid/framework/ir/mkldnn/reshape_transpose_matmul_mkldnn_fuse_pass.h"
#include "paddle/fluid/framework/ir/mkldnn/reshape_transpose_matmul_v2_mkldnn_fuse_pass.h"
#include <gtest/gtest.h> #include <gtest/gtest.h>
#include "paddle/fluid/framework/ir/pass_tester_helper.h" #include "paddle/fluid/framework/ir/pass_tester_helper.h"
...@@ -37,7 +38,7 @@ Scope* CreateParamScope() { ...@@ -37,7 +38,7 @@ Scope* CreateParamScope() {
return param_scope; return param_scope;
} }
void TestMain(bool with_xshapes) { void TestMain(const std::string& op_name, bool with_xshapes) {
// inputs operator output // inputs operator output
// ----------------------------------------------- // -----------------------------------------------
// a1,w1,bias1 fc -> b1 // a1,w1,bias1 fc -> b1
...@@ -46,7 +47,7 @@ void TestMain(bool with_xshapes) { ...@@ -46,7 +47,7 @@ void TestMain(bool with_xshapes) {
// a2,w2,bias2 fc -> b2 // a2,w2,bias2 fc -> b2
// b2 reshape -> c2 // b2 reshape -> c2
// c2 transpose -> d2 // c2 transpose -> d2
// (d1, d2) matmul -> (...) // (d1, d2) matmul(_v2) -> (...)
Layers layers; Layers layers;
auto* a1 = layers.data("a1", {-1, 128, 768}); auto* a1 = layers.data("a1", {-1, 128, 768});
auto* w1 = layers.data("w1", {768, 768}, true); auto* w1 = layers.data("w1", {768, 768}, true);
...@@ -66,7 +67,11 @@ void TestMain(bool with_xshapes) { ...@@ -66,7 +67,11 @@ void TestMain(bool with_xshapes) {
c2->SetShape({-1, 128, 12, 64}); c2->SetShape({-1, 128, 12, 64});
auto* d2 = layers.transpose2(c2, {0, 2, 1, 3}); auto* d2 = layers.transpose2(c2, {0, 2, 1, 3});
d2->SetShape({-1, 12, 128, 64}); d2->SetShape({-1, 12, 128, 64});
layers.matmul(d1, d2); if (op_name == "matmul_v2") {
layers.matmul_v2(d1, d2);
} else {
layers.matmul(d1, d2);
}
std::unique_ptr<ir::Graph> graph(new ir::Graph(layers.main_program())); std::unique_ptr<ir::Graph> graph(new ir::Graph(layers.main_program()));
graph->Set("__param_scope__", CreateParamScope()); graph->Set("__param_scope__", CreateParamScope());
...@@ -76,8 +81,8 @@ void TestMain(bool with_xshapes) { ...@@ -76,8 +81,8 @@ void TestMain(bool with_xshapes) {
int total_nodes_before = graph->Nodes().size(); int total_nodes_before = graph->Nodes().size();
VLOG(3) << DebugString(graph); VLOG(3) << DebugString(graph);
auto pass = auto pass = PassRegistry::Instance().Get("reshape_transpose_" + op_name +
PassRegistry::Instance().Get("reshape_transpose_matmul_mkldnn_fuse_pass"); "_mkldnn_fuse_pass");
graph.reset(pass->Apply(graph.release())); graph.reset(pass->Apply(graph.release()));
int num_reshape_nodes_after = GetNumOpNodes(graph, "reshape2"); int num_reshape_nodes_after = GetNumOpNodes(graph, "reshape2");
...@@ -92,7 +97,7 @@ void TestMain(bool with_xshapes) { ...@@ -92,7 +97,7 @@ void TestMain(bool with_xshapes) {
int removed = 8; // 2* reshape, reshape_out, transpose, transpose_out int removed = 8; // 2* reshape, reshape_out, transpose, transpose_out
if (with_xshapes) removed += 2; // transpose_xshape, reshape_xshape if (with_xshapes) removed += 2; // transpose_xshape, reshape_xshape
EXPECT_EQ(total_nodes_before - removed, total_nodes_after); EXPECT_EQ(total_nodes_before - removed, total_nodes_after);
auto* matmul_op_desc = GetOpNodes(graph, "matmul").at(0)->Op(); auto* matmul_op_desc = GetOpNodes(graph, op_name).at(0)->Op();
auto check = [&matmul_op_desc](std::string a) { auto check = [&matmul_op_desc](std::string a) {
std::string shape_str = "fused_reshape_" + a; std::string shape_str = "fused_reshape_" + a;
...@@ -108,12 +113,22 @@ void TestMain(bool with_xshapes) { ...@@ -108,12 +113,22 @@ void TestMain(bool with_xshapes) {
TEST(ReshapeTransposeMatmulMkldnnFusePass, TEST(ReshapeTransposeMatmulMkldnnFusePass,
both_matmul_inputs_reshape_transpose) { both_matmul_inputs_reshape_transpose) {
TestMain(false); TestMain("matmul", false);
} }
TEST(ReshapeTransposeMatmulMkldnnFusePass, TEST(ReshapeTransposeMatmulMkldnnFusePass,
both_matmul_inputs_reshape_transpose_one_with_xshapes) { both_matmul_inputs_reshape_transpose_one_with_xshapes) {
TestMain(true); TestMain("matmul", true);
}
TEST(ReshapeTransposeMatmulV2MkldnnFusePass,
both_matmulv2_inputs_reshape_transpose) {
TestMain("matmul_v2", false);
}
TEST(ReshapeTransposeMatmulV2MkldnnFusePass,
both_matmulv2_inputs_reshape_transpose_one_with_xshapes) {
TestMain("matmul_v2", true);
} }
} // namespace ir } // namespace ir
...@@ -121,3 +136,4 @@ TEST(ReshapeTransposeMatmulMkldnnFusePass, ...@@ -121,3 +136,4 @@ TEST(ReshapeTransposeMatmulMkldnnFusePass,
} // namespace paddle } // namespace paddle
USE_PASS(reshape_transpose_matmul_mkldnn_fuse_pass); USE_PASS(reshape_transpose_matmul_mkldnn_fuse_pass);
USE_PASS(reshape_transpose_matmul_v2_mkldnn_fuse_pass);
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/ir/mkldnn/reshape_transpose_matmul_v2_mkldnn_fuse_pass.h"
#include <string>
#include <unordered_set>
#include <vector>
#include "paddle/fluid/framework/op_version_registry.h"
#include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/string/pretty_log.h"
namespace paddle {
namespace framework {
namespace ir {
ReshapeTransposeMatmulV2MkldnnFusePass::
ReshapeTransposeMatmulV2MkldnnFusePass() {
op_name_ = "matmul_v2";
AddOpCompat(OpCompat("reshape2"))
.AddInput("X")
.IsTensor()
.End()
// The reshape2 op for this pass should not have "Shape" and "ShapeTensor"
.AddOutput("Out")
.IsTensor()
.End()
.AddOutput("XShape")
.IsOptional()
.IsTensor()
.End()
.AddAttr("shape")
.IsType<std::vector<int>>()
.End();
AddOpCompat(OpCompat("transpose2"))
.AddInput("X")
.IsTensor()
.End()
.AddOutput("Out")
.IsTensor()
.End()
.AddOutput("XShape")
.IsOptional()
.IsTensor()
.End()
.AddAttr("axis")
.IsType<std::vector<int>>()
.End();
AddOpCompat(OpCompat(op_name_))
.AddInput("X")
.IsTensor()
.End()
.AddInput("Y")
.IsTensor()
.End()
.AddOutput("Out")
.IsTensor()
.End()
.AddAttr("trans_x")
.IsType<bool>()
.End()
.AddAttr("trans_y")
.IsType<bool>()
.End();
}
} // namespace ir
} // namespace framework
} // namespace paddle
REGISTER_PASS(reshape_transpose_matmul_v2_mkldnn_fuse_pass,
paddle::framework::ir::ReshapeTransposeMatmulV2MkldnnFusePass);
REGISTER_PASS_CAPABILITY(reshape_transpose_matmul_v2_mkldnn_fuse_pass)
.AddCombination(
paddle::framework::compatible::OpVersionComparatorCombination()
.EQ("matmul_v2", 0)
.EQ("transpose2", 0)
.EQ("reshape2", 0));
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <string>
#include "paddle/fluid/framework/ir/mkldnn/reshape_transpose_matmul_mkldnn_fuse_pass.h"
namespace paddle {
namespace framework {
namespace ir {
/*
* Fuse Reshape->Transpose->MatMulV2 when MatMulV2 uses mkldnn.
*/
class ReshapeTransposeMatmulV2MkldnnFusePass
: public ReshapeTransposeMatmulMkldnnFusePass {
public:
ReshapeTransposeMatmulV2MkldnnFusePass();
virtual ~ReshapeTransposeMatmulV2MkldnnFusePass() {}
protected:
const std::string name_scope_{"reshape_transpose_matmul_v2_fuse"};
};
} // namespace ir
} // namespace framework
} // namespace paddle
...@@ -307,6 +307,19 @@ struct Layers { ...@@ -307,6 +307,19 @@ struct Layers {
return out; return out;
} }
VarDesc* matmul_v2(VarDesc* x, VarDesc* y, VarDesc* alpha = nullptr,
bool trans_x = false, bool trans_y = false) {
VarDesc* out = lod_tensor(unique_name());
OpDesc* op = program_.MutableBlock(0)->AppendOp();
op->SetType("matmul_v2");
op->SetInput("X", {x->Name()});
op->SetInput("Y", {y->Name()});
op->SetOutput("Out", {out->Name()});
op->SetAttr("trans_x", trans_x);
op->SetAttr("trans_y", trans_y);
return out;
}
VarDesc* transpose2(VarDesc* x, std::vector<int> axis, VarDesc* transpose2(VarDesc* x, std::vector<int> axis,
bool with_xshape = false) { bool with_xshape = false) {
VarDesc* out = lod_tensor(unique_name()); VarDesc* out = lod_tensor(unique_name());
......
...@@ -244,16 +244,17 @@ void CpuPassStrategy::EnableMKLDNN() { ...@@ -244,16 +244,17 @@ void CpuPassStrategy::EnableMKLDNN() {
"conv3d_bias_mkldnn_fuse_pass", // "conv3d_bias_mkldnn_fuse_pass", //
"conv_elementwise_add_mkldnn_fuse_pass", "conv_elementwise_add_mkldnn_fuse_pass",
"conv_concat_relu_mkldnn_fuse_pass", "conv_concat_relu_mkldnn_fuse_pass",
"conv_relu_mkldnn_fuse_pass", // "conv_relu_mkldnn_fuse_pass", //
"conv_leaky_relu_mkldnn_fuse_pass", // "conv_leaky_relu_mkldnn_fuse_pass", //
"conv_relu6_mkldnn_fuse_pass", // "conv_relu6_mkldnn_fuse_pass", //
"conv_swish_mkldnn_fuse_pass", // "conv_swish_mkldnn_fuse_pass", //
"conv_hard_swish_mkldnn_fuse_pass", // "conv_hard_swish_mkldnn_fuse_pass", //
"conv_hard_sigmoid_mkldnn_fuse_pass", // "conv_hard_sigmoid_mkldnn_fuse_pass", //
"scale_matmul_fuse_pass", // "scale_matmul_fuse_pass", //
"reshape_transpose_matmul_mkldnn_fuse_pass", // "reshape_transpose_matmul_mkldnn_fuse_pass", //
"matmul_transpose_reshape_fuse_pass", // "reshape_transpose_matmul_v2_mkldnn_fuse_pass", //
"matmul_v2_transpose_reshape_fuse_pass", // "matmul_transpose_reshape_fuse_pass", //
"matmul_v2_transpose_reshape_fuse_pass", //
// Disabled due to topology-dependent speed-up // Disabled due to topology-dependent speed-up
// "fc_mkldnn_pass", // "fc_mkldnn_pass",
// "fc_act_mkldnn_fuse_pass", // "fc_act_mkldnn_fuse_pass",
......
...@@ -39,6 +39,22 @@ extra { ...@@ -39,6 +39,22 @@ extra {
name: "op_device" name: "op_device"
type: STRING type: STRING
} }
attrs {
name: "fused_reshape_X"
type: INTS
}
attrs {
name: "fused_reshape_Y"
type: INTS
}
attrs {
name: "fused_transpose_X"
type: INTS
}
attrs {
name: "fused_transpose_Y"
type: INTS
}
attrs { attrs {
name: "fused_reshape_Out" name: "fused_reshape_Out"
type: INTS type: INTS
......
...@@ -19,6 +19,81 @@ ...@@ -19,6 +19,81 @@
namespace paddle { namespace paddle {
namespace operators { namespace operators {
static framework::DDim GetDimForInput(const framework::InferShapeContext& ctx,
const std::string input_name) {
auto shape = ctx.Attrs().Get<std::vector<int>>("fused_reshape_" + input_name);
auto axis =
ctx.Attrs().Get<std::vector<int>>("fused_transpose_" + input_name);
auto dim = ctx.GetInputDim(input_name);
PADDLE_ENFORCE_GT(dim.size(), 0,
platform::errors::InvalidArgument(
"The Input(%s) has not been initialized properly. The "
"shape of Input(%s) = [%s].",
dim));
// if mkldnn reshape+transpose+matmul fuse activated
if (!shape.empty() && !axis.empty()) {
PADDLE_ENFORCE_GE(
shape.size(), 2,
platform::errors::InvalidArgument(
"shape_%s attribute of MatMulOp was implemented for 2, 3 "
"or 4 dimensions.",
input_name));
PADDLE_ENFORCE_LE(
shape.size(), 4,
platform::errors::InvalidArgument(
"shape_%s attribute of MatMulOp was implemented for 2, 3 "
"or 4 dimensions.",
input_name));
PADDLE_ENFORCE_EQ(
shape.size(), axis.size(),
platform::errors::InvalidArgument(
"Ranks of shape_%s and axis_%s attributes of MatMulOp "
"must be equal.",
input_name, input_name));
int num_negative = std::count(shape.begin(), shape.end(), -1);
PADDLE_ENFORCE_LE(num_negative, 1,
platform::errors::InvalidArgument(
"The max number of -1 in fused_reshape_%s is 1 "
"but received %d.",
input_name, num_negative));
auto it_zero = std::find(shape.begin(), shape.end(), 0);
if (it_zero != shape.end()) {
for (uint64_t i = 0; i < shape.size(); i++) {
if (shape[i] == 0) {
PADDLE_ENFORCE_LT(i, dim.size(),
platform::errors::InvalidArgument(
"The index of 0 in fused_reshape_%s ",
"should be less than output dim size, ",
"but the index is %d and output dim size is %d",
input_name, i, dim.size()));
shape[i] = dim.at(i);
}
}
}
// if "-1" is present then one of reshape dims must be infered
auto it_negative = std::find(shape.begin(), shape.end(), -1);
if (it_negative != shape.end()) {
int64_t dim_product = 1;
for (int i = 0; i < dim.size(); i++) {
dim_product *= dim.at(i);
}
int64_t shape_product = std::accumulate(shape.begin(), shape.end(), -1,
std::multiplies<int>());
int index = std::distance(shape.begin(), it_negative);
shape[index] = dim_product / shape_product;
}
dim = dim.reshape(shape).transpose(axis);
}
return dim;
}
class MatMulV2Op : public framework::OperatorWithKernel { class MatMulV2Op : public framework::OperatorWithKernel {
public: public:
using framework::OperatorWithKernel::OperatorWithKernel; using framework::OperatorWithKernel::OperatorWithKernel;
...@@ -30,9 +105,9 @@ class MatMulV2Op : public framework::OperatorWithKernel { ...@@ -30,9 +105,9 @@ class MatMulV2Op : public framework::OperatorWithKernel {
bool trans_y = ctx->Attrs().Get<bool>("trans_y"); bool trans_y = ctx->Attrs().Get<bool>("trans_y");
std::vector<int64_t> dims_x = std::vector<int64_t> dims_x =
paddle::framework::vectorize(ctx->GetInputDim("X")); framework::vectorize(GetDimForInput(*ctx, "X"));
std::vector<int64_t> dims_y = std::vector<int64_t> dims_y =
paddle::framework::vectorize(ctx->GetInputDim("Y")); framework::vectorize(GetDimForInput(*ctx, "Y"));
auto ndims_x = dims_x.size(); auto ndims_x = dims_x.size();
auto ndims_y = dims_y.size(); auto ndims_y = dims_y.size();
PADDLE_ENFORCE_GT(ndims_x, 0, PADDLE_ENFORCE_GT(ndims_x, 0,
...@@ -215,6 +290,22 @@ class MatMulV2OpMaker : public framework::OpProtoAndCheckerMaker { ...@@ -215,6 +290,22 @@ class MatMulV2OpMaker : public framework::OpProtoAndCheckerMaker {
.SetDefault("float32") .SetDefault("float32")
.InEnum({"float32", "bfloat16"}) .InEnum({"float32", "bfloat16"})
.AsExtra(); .AsExtra();
AddAttr<std::vector<int>>("fused_reshape_X",
R"DOC(Shape of fused reshape of `X` input.)DOC")
.SetDefault({})
.AsExtra();
AddAttr<std::vector<int>>("fused_reshape_Y",
R"DOC(Shape of fused reshape of `Y` input.)DOC")
.SetDefault({})
.AsExtra();
AddAttr<std::vector<int>>("fused_transpose_X",
R"DOC(Axis of fused transpose of `X` input.)DOC")
.SetDefault({})
.AsExtra();
AddAttr<std::vector<int>>("fused_transpose_Y",
R"DOC(Axis of fused transpose of `Y` input.)DOC")
.SetDefault({})
.AsExtra();
AddComment( AddComment(
R"DOC(Matrix multiplication Out = X * Y. A has shape (d0, d1 ... M, K), R"DOC(Matrix multiplication Out = X * Y. A has shape (d0, d1 ... M, K),
B has shape (d0, d1 ... K, N), Out has shape ((d0, d1 ... M, N)). B has shape (d0, d1 ... K, N), Out has shape ((d0, d1 ... M, N)).
......
...@@ -12,6 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. ...@@ -12,6 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#pragma once
#include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/tensor.h" #include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/operators/math/blas.h" #include "paddle/fluid/operators/math/blas.h"
......
...@@ -25,10 +25,88 @@ using paddle::platform::MKLDNNDeviceContext; ...@@ -25,10 +25,88 @@ using paddle::platform::MKLDNNDeviceContext;
using paddle::platform::MKLDNNGetDataType; using paddle::platform::MKLDNNGetDataType;
using paddle::platform::to_void_cast; using paddle::platform::to_void_cast;
using Tensor = paddle::framework::Tensor; using Tensor = paddle::framework::Tensor;
using paddle::framework::DDim;
using paddle::framework::GradVarName; using paddle::framework::GradVarName;
using paddle::framework::make_ddim; using paddle::framework::make_ddim;
using paddle::framework::vectorize; using paddle::framework::vectorize;
// Get row matrix shape from a vector shape. If the rank of x_dim > 1, the
// original x_dim is returned.
static DDim RowMatrixDimsFromVector(const DDim& x_dim) {
return x_dim.size() > 1 ? x_dim : paddle::framework::make_ddim({1, x_dim[0]});
}
// Get column matrix shape from a vector shape. If the ran of y_dim > 1, the
// original y_dim is returned.
static DDim ColumnMatrixDimsFromVector(const DDim& y_dim) {
return y_dim.size() > 1 ? y_dim : paddle::framework::make_ddim({y_dim[0], 1});
}
static std::vector<int64_t> Transpose(const std::vector<int64_t>& x,
const std::vector<int>& axis) {
size_t in_rank = x.size();
size_t axis_size = axis.size();
auto axis_set = std::set<int>(axis.begin(), axis.end());
PADDLE_ENFORCE_EQ(axis_set.size(), axis_size,
paddle::platform::errors::InvalidArgument(
"In an axis array, elements must be unique."));
PADDLE_ENFORCE_EQ(in_rank, axis_size,
paddle::platform::errors::InvalidArgument(
"The input dimension's size "
"should be equal to the axis's size. "
"But received dimension is %d, "
"axis's size is %d",
in_rank, axis_size));
PADDLE_ENFORCE_LT(*std::max_element(axis.begin(), axis.end()), axis_size,
paddle::platform::errors::InvalidArgument(
"Axis values must be ranging from 0 to (dims - 1)."));
std::vector<int64_t> new_x(x.size());
for (size_t i = 0; i < x.size(); i++) {
new_x[i] = x[axis[i]];
}
return new_x;
}
std::vector<int64_t> GetInputStrides(const ExecutionContext& ctx,
const std::string input_name) {
auto shape = ctx.Attr<std::vector<int>>("fused_reshape_" + input_name);
auto axis = ctx.Attr<std::vector<int>>("fused_transpose_" + input_name);
auto input_dims = ctx.Input<Tensor>(input_name)->dims();
auto new_dims = input_dims;
if (!shape.empty() && !axis.empty()) {
new_dims = input_dims.reshape(shape).transpose(axis);
}
auto& MatrixDimsFromVector =
input_name == "X" ? RowMatrixDimsFromVector : ColumnMatrixDimsFromVector;
paddle::operators::math::MatDescriptor mat_dim =
paddle::operators::math::CreateMatrixDescriptor(
MatrixDimsFromVector(new_dims), 0,
ctx.Attr<bool>(std::string("trans_") +
static_cast<char>(std::tolower(input_name[0]))));
std::vector<int64_t> strides;
if (!shape.empty()) {
auto shape2 = input_dims.reshape(shape);
strides.push_back(1);
for (auto i = shape2.size() - 1; i > 0; --i) {
strides.insert(strides.begin(),
strides.front() * static_cast<int64_t>(shape2[i]));
}
strides = Transpose(strides, axis);
if (shape.size() == 2)
strides.insert(strides.begin(),
static_cast<int64_t>(shape[0] * shape[1]));
mat_dim.stride_ = strides[0];
if (mat_dim.trans_) std::swap(*strides.rbegin(), *(++strides.rbegin()));
}
return strides;
}
template <typename T> template <typename T>
class MatMulV2MKLDNNHandler class MatMulV2MKLDNNHandler
: public paddle::platform::MKLDNNHandlerNoCachingT<T, dnnl::matmul> { : public paddle::platform::MKLDNNHandlerNoCachingT<T, dnnl::matmul> {
...@@ -37,7 +115,9 @@ class MatMulV2MKLDNNHandler ...@@ -37,7 +115,9 @@ class MatMulV2MKLDNNHandler
paddle::platform::Place cpu_place, paddle::platform::Place cpu_place,
const std::vector<int64_t>& x_org_dims, bool trans_x, const std::vector<int64_t>& x_org_dims, bool trans_x,
const std::vector<int64_t>& y_org_dims, bool trans_y, const std::vector<int64_t>& y_org_dims, bool trans_y,
bool is_output_fused) bool is_output_fused,
const std::vector<int64_t>& x_strides_override,
const std::vector<int64_t>& y_strides_override)
: paddle::platform::MKLDNNHandlerNoCachingT<T, dnnl::matmul>(engine, : paddle::platform::MKLDNNHandlerNoCachingT<T, dnnl::matmul>(engine,
cpu_place) { cpu_place) {
// M X K * K X N // M X K * K X N
...@@ -64,16 +144,24 @@ class MatMulV2MKLDNNHandler ...@@ -64,16 +144,24 @@ class MatMulV2MKLDNNHandler
y_strides.reserve(x_dims.size()); y_strides.reserve(x_dims.size());
out_strides.reserve(x_dims.size()); out_strides.reserve(x_dims.size());
if (!trans_x) { if (!x_strides_override.empty()) {
x_strides.insert(x_strides.end(), {M * K, K, 1}); x_strides = x_strides_override;
} else { } else {
x_strides.insert(x_strides.end(), {M * K, 1, M}); if (!trans_x) {
x_strides.insert(x_strides.end(), {M * K, K, 1});
} else {
x_strides.insert(x_strides.end(), {M * K, 1, M});
}
} }
if (!trans_y) { if (!y_strides_override.empty()) {
y_strides.insert(y_strides.end(), {N * K, N, 1}); y_strides = y_strides_override;
} else { } else {
y_strides.insert(y_strides.end(), {N * K, 1, K}); if (!trans_y) {
y_strides.insert(y_strides.end(), {N * K, N, 1});
} else {
y_strides.insert(y_strides.end(), {N * K, 1, K});
}
} }
out_strides.insert(out_strides.end(), {M * N, N, 1}); out_strides.insert(out_strides.end(), {M * N, N, 1});
...@@ -82,8 +170,12 @@ class MatMulV2MKLDNNHandler ...@@ -82,8 +170,12 @@ class MatMulV2MKLDNNHandler
for (int i = x_dims.size() - 4; i >= 0; --i) { for (int i = x_dims.size() - 4; i >= 0; --i) {
out_ddims[i] = std::max(x_dims[i], y_dims[i]); out_ddims[i] = std::max(x_dims[i], y_dims[i]);
x_strides[i] = x_dims[i + 1] * x_strides[i + 1]; if (x_strides_override.empty()) {
y_strides[i] = y_dims[i + 1] * y_strides[i + 1]; x_strides[i] = x_dims[i + 1] * x_strides[i + 1];
}
if (y_strides_override.empty()) {
y_strides[i] = y_dims[i + 1] * y_strides[i + 1];
}
out_strides[i] = out_ddims[i + 1] * out_strides[i + 1]; out_strides[i] = out_ddims[i + 1] * out_strides[i + 1];
} }
...@@ -146,9 +238,11 @@ void ExecuteMatMulV2(const ExecutionContext& ctx, ...@@ -146,9 +238,11 @@ void ExecuteMatMulV2(const ExecutionContext& ctx,
const Tensor* y, std::vector<int64_t>& y_dims, const Tensor* y, std::vector<int64_t>& y_dims,
bool trans_y, Tensor* out, std::vector<int64_t>& out_dims, bool trans_y, Tensor* out, std::vector<int64_t>& out_dims,
int execution_number = 0) { int execution_number = 0) {
std::vector<int64_t> x_strides_override = GetInputStrides(ctx, "X");
std::vector<int64_t> y_strides_override = GetInputStrides(ctx, "Y");
MatMulV2MKLDNNHandler<T> handler(onednn_engine, ctx.GetPlace(), x_dims, MatMulV2MKLDNNHandler<T> handler(onednn_engine, ctx.GetPlace(), x_dims,
trans_x, y_dims, trans_y, trans_x, y_dims, trans_y, IsOutputFused(ctx),
IsOutputFused(ctx)); x_strides_override, y_strides_override);
const auto src_memory_p = handler.AcquireSrcMemory(x); const auto src_memory_p = handler.AcquireSrcMemory(x);
const auto weights_memory_p = handler.AcquireWeightsMemory(y); const auto weights_memory_p = handler.AcquireWeightsMemory(y);
...@@ -171,6 +265,17 @@ void ExecuteMatMulV2(const ExecutionContext& ctx, ...@@ -171,6 +265,17 @@ void ExecuteMatMulV2(const ExecutionContext& ctx,
out->set_format(format); out->set_format(format);
} }
DDim GetDimForInput(const paddle::framework::ExecutionContext& ctx,
const std::string& input_name) {
auto shape = ctx.Attr<std::vector<int>>("fused_reshape_" + input_name);
auto axis = ctx.Attr<std::vector<int>>("fused_transpose_" + input_name);
auto dim = ctx.Input<paddle::framework::Tensor>(input_name)->dims();
if (!shape.empty() && !axis.empty()) {
dim = dim.reshape(shape).transpose(axis);
}
return dim;
}
template <typename T> template <typename T>
class MatMulV2MKLDNNKernel : public paddle::framework::OpKernel<T> { class MatMulV2MKLDNNKernel : public paddle::framework::OpKernel<T> {
public: public:
...@@ -230,11 +335,11 @@ class MatMulV2MKLDNNKernel : public paddle::framework::OpKernel<T> { ...@@ -230,11 +335,11 @@ class MatMulV2MKLDNNKernel : public paddle::framework::OpKernel<T> {
bool trans_x = ctx.Attr<bool>("trans_x"); bool trans_x = ctx.Attr<bool>("trans_x");
bool trans_y = ctx.Attr<bool>("trans_y"); bool trans_y = ctx.Attr<bool>("trans_y");
auto x_dims = vectorize(x->dims()); auto x_dims = vectorize(GetDimForInput(ctx, "X"));
auto y_dims = vectorize(y->dims()); auto y_dims = vectorize(GetDimForInput(ctx, "Y"));
auto out_dims = vectorize(out->dims()); auto out_dims = vectorize(out->dims());
int ndims = std::max(x->dims().size(), y->dims().size()); int ndims = std::max(x_dims.size(), y_dims.size());
ndims = std::max(ndims, 3); ndims = std::max(ndims, 3);
std::vector<int64_t> x_bd_dims(ndims, 1); std::vector<int64_t> x_bd_dims(ndims, 1);
...@@ -398,8 +503,6 @@ class MatMulV2GradMKLDNNKernel : public paddle::framework::OpKernel<T> { ...@@ -398,8 +503,6 @@ class MatMulV2GradMKLDNNKernel : public paddle::framework::OpKernel<T> {
}; };
} // anonymous namespace } // anonymous namespace
namespace ops = paddle::operators;
REGISTER_OP_KERNEL(matmul_v2, MKLDNN, ::paddle::platform::CPUPlace, REGISTER_OP_KERNEL(matmul_v2, MKLDNN, ::paddle::platform::CPUPlace,
MatMulV2MKLDNNKernel<float>, MatMulV2MKLDNNKernel<float>,
MatMulV2MKLDNNKernel<paddle::platform::bfloat16>); MatMulV2MKLDNNKernel<paddle::platform::bfloat16>);
......
...@@ -638,6 +638,8 @@ class Quant2Int8MkldnnPass(object): ...@@ -638,6 +638,8 @@ class Quant2Int8MkldnnPass(object):
graph = self._apply_pass(graph, 'scale_matmul_fuse_pass') graph = self._apply_pass(graph, 'scale_matmul_fuse_pass')
graph = self._apply_pass(graph, graph = self._apply_pass(graph,
'reshape_transpose_matmul_mkldnn_fuse_pass') 'reshape_transpose_matmul_mkldnn_fuse_pass')
graph = self._apply_pass(graph,
'reshape_transpose_matmul_v2_mkldnn_fuse_pass')
graph = self._apply_pass( graph = self._apply_pass(
graph, 'cpu_quantize_pass', ['quant_var_scales', 'data_layout'], graph, 'cpu_quantize_pass', ['quant_var_scales', 'data_layout'],
[self._var_quant_scales, self._get_data_layout(graph)]) [self._var_quant_scales, self._get_data_layout(graph)])
......
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import unittest
import numpy as np
from inference_pass_test import InferencePassTest
import paddle
import paddle.fluid as fluid
import paddle.fluid.core as core
from paddle.fluid.core import AnalysisConfig
from paddle.fluid.core import PassVersionChecker
class TestReshapeTransposeMatmulV2OneDNNFusePass(InferencePassTest):
def setUp(self):
self.set_params()
self.tranpose_perm = [0, 2, 1, 3]
self.pass_name = 'reshape_transpose_matmul_v2_mkldnn_fuse_pass'
with fluid.program_guard(self.main_program, self.startup_program):
data = fluid.data(
name="data", shape=self.data_shape, dtype="float32")
weight = fluid.layers.create_parameter(
shape=self.weight_shape, dtype="float32")
reshape = fluid.layers.reshape(data, shape=self.reshape_shape)
transpose = fluid.layers.transpose(reshape, self.tranpose_perm)
matmul = paddle.matmul(
transpose,
weight,
transpose_x=self.transpose_x,
transpose_y=self.transpose_y)
self.fetch_list = [matmul]
self.enable_mkldnn = True
def set_params(self):
self.data_shape = [-1, 128, 768]
self.weight_shape = [1, 12, 64, 128]
self.feeds = {"data": np.random.random((1, 128, 768)).astype("float32")}
self.transpose_x = False
self.transpose_y = False
self.reshape_shape = [0, 0, 12, 64]
def test_check_output(self):
use_gpu = False
self.check_output_with_option(use_gpu)
def test_pass_compatible(self):
self.assertTrue(PassVersionChecker.IsCompatible(self.pass_name))
class TestReshapeTransposeMatmulV2OneDNNFusePassBroadcast(
TestReshapeTransposeMatmulV2OneDNNFusePass):
def set_params(self):
self.data_shape = [2, 64, 16]
self.weight_shape = [1, 2, 8, 64]
self.feeds = {"data": np.random.random((2, 64, 16)).astype("float32")}
self.transpose_x = True
self.transpose_y = True
self.reshape_shape = [0, 0, 2, 8]
if __name__ == "__main__":
paddle.enable_static()
unittest.main()
...@@ -252,7 +252,7 @@ class TestDnnlMatMulOpInt8ForceFP32BasicScales(TestDnnlMatMulOp): ...@@ -252,7 +252,7 @@ class TestDnnlMatMulOpInt8ForceFP32BasicScales(TestDnnlMatMulOp):
@skip_check_grad_ci(reason="DNNL's MatMul doesn't implement grad kernel.") @skip_check_grad_ci(reason="DNNL's MatMul doesn't implement grad kernel.")
class TestMatMulOpReshapeTranspose(OpTest): class TestReshapeTransposeMatMulOp(OpTest):
def init_data_type(self): def init_data_type(self):
self.data_type_ = 'float32' self.data_type_ = 'float32'
...@@ -267,10 +267,12 @@ class TestMatMulOpReshapeTranspose(OpTest): ...@@ -267,10 +267,12 @@ class TestMatMulOpReshapeTranspose(OpTest):
self.fused_reshape_Y = [] self.fused_reshape_Y = []
self.fused_transpose_Y = [] self.fused_transpose_Y = []
def setUp(self): def set_op_type_and_transpose_y_name(self):
# Set max isa, otherwise fails on SKX and earlier
os.environ["DNNL_MAX_CPU_ISA"] = "AVX"
self.op_type = "matmul" self.op_type = "matmul"
self.transpose_y_name = "transpose_Y"
def setUp(self):
self.set_op_type_and_transpose_y_name()
self._cpu_only = True self._cpu_only = True
self.use_mkldnn = True self.use_mkldnn = True
self.transpose_y = True self.transpose_y = True
...@@ -280,7 +282,7 @@ class TestMatMulOpReshapeTranspose(OpTest): ...@@ -280,7 +282,7 @@ class TestMatMulOpReshapeTranspose(OpTest):
self.inputs = {'X': self.x, 'Y': self.y} self.inputs = {'X': self.x, 'Y': self.y}
self.attrs = { self.attrs = {
'use_mkldnn': self.use_mkldnn, 'use_mkldnn': self.use_mkldnn,
'transpose_Y': self.transpose_y self.transpose_y_name: self.transpose_y
} }
if len(self.fused_transpose_X) > 0: if len(self.fused_transpose_X) > 0:
self.attrs['fused_transpose_X'] = self.fused_transpose_X self.attrs['fused_transpose_X'] = self.fused_transpose_X
...@@ -297,7 +299,7 @@ class TestMatMulOpReshapeTranspose(OpTest): ...@@ -297,7 +299,7 @@ class TestMatMulOpReshapeTranspose(OpTest):
self.check_output() self.check_output()
class TestMatMulOpReshapeTranspose4DXFloat(TestMatMulOpReshapeTranspose): class TestReshapeTransposeMatMulOp4DXFloat(TestReshapeTransposeMatMulOp):
def generate_data(self): def generate_data(self):
self.x = np.random.random([2, 128, 768]).astype("float32") self.x = np.random.random([2, 128, 768]).astype("float32")
self.y = np.random.random([2, 128, 768]).astype("float32").reshape( self.y = np.random.random([2, 128, 768]).astype("float32").reshape(
...@@ -311,12 +313,12 @@ class TestMatMulOpReshapeTranspose4DXFloat(TestMatMulOpReshapeTranspose): ...@@ -311,12 +313,12 @@ class TestMatMulOpReshapeTranspose4DXFloat(TestMatMulOpReshapeTranspose):
self.y.transpose([0, 1, 3, 2])) self.y.transpose([0, 1, 3, 2]))
class TestMatMulOpReshapeTranspose4DXInt8(TestMatMulOpReshapeTranspose4DXFloat): class TestReshapeTransposeMatMulOp4DXInt8(TestReshapeTransposeMatMulOp4DXFloat):
def init_data_type(self): def init_data_type(self):
self.data_type_ = 'int8' self.data_type_ = 'int8'
class TestMatMulOpReshapeTranspose4DYFloat(TestMatMulOpReshapeTranspose): class TestReshapeTransposeMatMulOp4DYFloat(TestReshapeTransposeMatMulOp):
def generate_data(self): def generate_data(self):
self.x = np.random.random([2, 128, 768]).astype("float32").reshape( self.x = np.random.random([2, 128, 768]).astype("float32").reshape(
[2, 128, 12, 64]).transpose([0, 2, 1, 3]) [2, 128, 12, 64]).transpose([0, 2, 1, 3])
...@@ -329,12 +331,12 @@ class TestMatMulOpReshapeTranspose4DYFloat(TestMatMulOpReshapeTranspose): ...@@ -329,12 +331,12 @@ class TestMatMulOpReshapeTranspose4DYFloat(TestMatMulOpReshapeTranspose):
self.x, self.y.reshape([2, 128, 12, 64]).transpose([0, 2, 3, 1])) self.x, self.y.reshape([2, 128, 12, 64]).transpose([0, 2, 3, 1]))
class TestMatMulOpReshapeTranspose4DYInt8(TestMatMulOpReshapeTranspose4DYFloat): class TestReshapeTransposeMatMulOp4DYInt8(TestReshapeTransposeMatMulOp4DYFloat):
def init_data_type(self): def init_data_type(self):
self.data_type_ = 'int8' self.data_type_ = 'int8'
class TestMatMulOpReshapeTranspose4DXYFloat(TestMatMulOpReshapeTranspose): class TestReshapeTransposeMatMulOp4DXYFloat(TestReshapeTransposeMatMulOp):
def generate_data(self): def generate_data(self):
self.x = np.random.random([2, 128, 768]).astype("float32") self.x = np.random.random([2, 128, 768]).astype("float32")
self.y = np.random.random([2, 128, 768]).astype("float32") self.y = np.random.random([2, 128, 768]).astype("float32")
...@@ -347,13 +349,13 @@ class TestMatMulOpReshapeTranspose4DXYFloat(TestMatMulOpReshapeTranspose): ...@@ -347,13 +349,13 @@ class TestMatMulOpReshapeTranspose4DXYFloat(TestMatMulOpReshapeTranspose):
self.y.reshape([2, 128, 12, 64]).transpose([0, 2, 3, 1])) self.y.reshape([2, 128, 12, 64]).transpose([0, 2, 3, 1]))
class TestMatMulOpReshapeTranspose4DXYInt8( class TestReshapeTransposeMatMulOp4DXYInt8(
TestMatMulOpReshapeTranspose4DXYFloat): TestReshapeTransposeMatMulOp4DXYFloat):
def init_data_type(self): def init_data_type(self):
self.data_type_ = 'int8' self.data_type_ = 'int8'
class TestMatMulOpReshapeTranspose2DXFloat(TestMatMulOpReshapeTranspose): class TestReshapeTransposeMatMulOp2DXFloat(TestReshapeTransposeMatMulOp):
def generate_data(self): def generate_data(self):
self.x = np.random.random([2, 5, 10]).astype("float32") self.x = np.random.random([2, 5, 10]).astype("float32")
self.y = np.random.random([2, 5, 10]).astype("float32").reshape( self.y = np.random.random([2, 5, 10]).astype("float32").reshape(
...@@ -367,12 +369,12 @@ class TestMatMulOpReshapeTranspose2DXFloat(TestMatMulOpReshapeTranspose): ...@@ -367,12 +369,12 @@ class TestMatMulOpReshapeTranspose2DXFloat(TestMatMulOpReshapeTranspose):
self.y.transpose([1, 0])) self.y.transpose([1, 0]))
class TestMatMulOpReshapeTranspose2DXInt8(TestMatMulOpReshapeTranspose2DXFloat): class TestReshapeTransposeMatMulOp2DXInt8(TestReshapeTransposeMatMulOp2DXFloat):
def init_data_type(self): def init_data_type(self):
self.data_type_ = 'int8' self.data_type_ = 'int8'
class TestMatMulOpReshapeTranspose2DYFloat(TestMatMulOpReshapeTranspose): class TestReshapeTransposeMatMulOp2DYFloat(TestReshapeTransposeMatMulOp):
def generate_data(self): def generate_data(self):
self.x = np.random.random([2, 5, 10]).astype("float32").reshape( self.x = np.random.random([2, 5, 10]).astype("float32").reshape(
[10, 10]).transpose([1, 0]) [10, 10]).transpose([1, 0])
...@@ -384,12 +386,12 @@ class TestMatMulOpReshapeTranspose2DYFloat(TestMatMulOpReshapeTranspose): ...@@ -384,12 +386,12 @@ class TestMatMulOpReshapeTranspose2DYFloat(TestMatMulOpReshapeTranspose):
self.out = np.matmul(self.x, self.y.reshape([10, 10])) self.out = np.matmul(self.x, self.y.reshape([10, 10]))
class TestMatMulOpReshapeTranspose2DYInt8(TestMatMulOpReshapeTranspose2DYFloat): class TestReshapeTransposeMatMulOp2DYInt8(TestReshapeTransposeMatMulOp2DYFloat):
def init_data_type(self): def init_data_type(self):
self.data_type_ = 'int8' self.data_type_ = 'int8'
class TestMatMulOpReshapeTranspose3DXFloat(TestMatMulOpReshapeTranspose): class TestReshapeTransposeMatMulOp3DXFloat(TestReshapeTransposeMatMulOp):
def generate_data(self): def generate_data(self):
self.x = np.random.random([2, 2, 5, 5]).astype("float32") self.x = np.random.random([2, 2, 5, 5]).astype("float32")
self.y = np.random.random([2, 2, 5, 5]).astype("float32").reshape( self.y = np.random.random([2, 2, 5, 5]).astype("float32").reshape(
...@@ -403,12 +405,12 @@ class TestMatMulOpReshapeTranspose3DXFloat(TestMatMulOpReshapeTranspose): ...@@ -403,12 +405,12 @@ class TestMatMulOpReshapeTranspose3DXFloat(TestMatMulOpReshapeTranspose):
self.y.transpose(0, 2, 1)) self.y.transpose(0, 2, 1))
class TestMatMulOpReshapeTranspose3DXInt8(TestMatMulOpReshapeTranspose3DXFloat): class TestReshapeTransposeMatMulOp3DXInt8(TestReshapeTransposeMatMulOp3DXFloat):
def init_data_type(self): def init_data_type(self):
self.data_type_ = 'int8' self.data_type_ = 'int8'
class TestMatMulOpReshapeTranspose3DYFloat(TestMatMulOpReshapeTranspose): class TestReshapeTransposeMatMulOp3DYFloat(TestReshapeTransposeMatMulOp):
def generate_data(self): def generate_data(self):
self.x = np.random.random([2, 2, 5, 5]).astype(self.data_type_).reshape( self.x = np.random.random([2, 2, 5, 5]).astype(self.data_type_).reshape(
[2, 10, 5]).transpose([0, 2, 1]) [2, 10, 5]).transpose([0, 2, 1])
...@@ -420,7 +422,7 @@ class TestMatMulOpReshapeTranspose3DYFloat(TestMatMulOpReshapeTranspose): ...@@ -420,7 +422,7 @@ class TestMatMulOpReshapeTranspose3DYFloat(TestMatMulOpReshapeTranspose):
self.out = np.matmul(self.x, self.y.reshape([2, 10, 5])) self.out = np.matmul(self.x, self.y.reshape([2, 10, 5]))
class TestMatMulOpReshapeTranspose3DYInt8(TestMatMulOpReshapeTranspose3DYFloat): class TestReshapeTransposeMatMulOp3DYInt8(TestReshapeTransposeMatMulOp3DYFloat):
def init_data_type(self): def init_data_type(self):
self.data_type_ = 'int8' self.data_type_ = 'int8'
......
...@@ -29,7 +29,11 @@ from paddle.fluid.tests.unittests.mkldnn.test_matmul_mkldnn_op import ( ...@@ -29,7 +29,11 @@ from paddle.fluid.tests.unittests.mkldnn.test_matmul_mkldnn_op import (
TestMatMulOpTransposeReshapeOtherDimFloat, TestMatMulOpTransposeReshapeOtherDimFloat,
TestMatMulOpTransposeReshapeTransposeAxisNotSupportedException, TestMatMulOpTransposeReshapeTransposeAxisNotSupportedException,
TestMatMulOpTransposeReshapeTransposeRankNotSupportedException, TestMatMulOpTransposeReshapeTransposeRankNotSupportedException,
TestMatMulOpTransposeReshapeRankOfReshapeNotSupportedException) TestMatMulOpTransposeReshapeRankOfReshapeNotSupportedException,
TestReshapeTransposeMatMulOp, TestReshapeTransposeMatMulOp4DXFloat,
TestReshapeTransposeMatMulOp4DYFloat, TestReshapeTransposeMatMulOp4DXYFloat,
TestReshapeTransposeMatMulOp2DXFloat, TestReshapeTransposeMatMulOp2DYFloat,
TestReshapeTransposeMatMulOp3DXFloat, TestReshapeTransposeMatMulOp3DYFloat)
def reference_matmul(X, Y, transpose_x=False, transpose_y=False): def reference_matmul(X, Y, transpose_x=False, transpose_y=False):
...@@ -434,6 +438,61 @@ class TestMatMulV2OpTransposeReshapeTransposeRankNotSupportedException( ...@@ -434,6 +438,61 @@ class TestMatMulV2OpTransposeReshapeTransposeRankNotSupportedException(
self.op_type = "matmul_v2" self.op_type = "matmul_v2"
class TestMatMulV2OpReshapeTranspose(TestReshapeTransposeMatMulOp):
def set_op_type_and_transpose_y_name(self):
self.op_type = "matmul_v2"
self.transpose_y_name = "trans_y"
class TestMatMulV2OpReshapeTranspose4DXFloat(
TestReshapeTransposeMatMulOp4DXFloat):
def set_op_type_and_transpose_y_name(self):
self.op_type = "matmul_v2"
self.transpose_y_name = "trans_y"
class TestMatMulV2OpReshapeTranspose4DYFloat(
TestReshapeTransposeMatMulOp4DYFloat):
def set_op_type_and_transpose_y_name(self):
self.op_type = "matmul_v2"
self.transpose_y_name = "trans_y"
class TestMatMulV2OpReshapeTranspose4DXYFloat(
TestReshapeTransposeMatMulOp4DXYFloat):
def set_op_type_and_transpose_y_name(self):
self.op_type = "matmul_v2"
self.transpose_y_name = "trans_y"
class TestMatMulV2OpReshapeTranspose2DXFloat(
TestReshapeTransposeMatMulOp2DXFloat):
def set_op_type_and_transpose_y_name(self):
self.op_type = "matmul_v2"
self.transpose_y_name = "trans_y"
class TestMatMulV2OpReshapeTranspose2DYFloat(
TestReshapeTransposeMatMulOp2DYFloat):
def set_op_type_and_transpose_y_name(self):
self.op_type = "matmul_v2"
self.transpose_y_name = "trans_y"
class TestMatMulV2OpReshapeTranspose3DXFloat(
TestReshapeTransposeMatMulOp3DXFloat):
def set_op_type_and_transpose_y_name(self):
self.op_type = "matmul_v2"
self.transpose_y_name = "trans_y"
class TestMatMulV2OpReshapeTranspose3DYFloat(
TestReshapeTransposeMatMulOp3DYFloat):
def set_op_type_and_transpose_y_name(self):
self.op_type = "matmul_v2"
self.transpose_y_name = "trans_y"
if __name__ == "__main__": if __name__ == "__main__":
paddle.enable_static() paddle.enable_static()
unittest.main() unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册