未验证 提交 a922168a 编写于 作者: S Sylwester Fraczek 提交者: GitHub

add reshape+transpose+matmul_v2 only (#37847)

* reshape+transpose+matmul_v2

* in_name->input_name

* fix pr-ci-static-check
上级 6a852536
......@@ -123,6 +123,7 @@ if(WITH_MKLDNN)
pass_library(cpu_quantize_pass inference DIR mkldnn)
pass_library(cpu_quantize_squash_pass inference DIR mkldnn)
pass_library(reshape_transpose_matmul_mkldnn_fuse_pass inference DIR mkldnn)
pass_library(reshape_transpose_matmul_v2_mkldnn_fuse_pass inference DIR mkldnn)
pass_library(matmul_transpose_reshape_fuse_pass inference DIR mkldnn)
pass_library(matmul_v2_transpose_reshape_fuse_pass inference DIR mkldnn)
pass_library(batch_norm_act_fuse_pass inference DIR mkldnn)
......@@ -190,7 +191,7 @@ endif()
cc_test(test_cpu_quantize_placement_pass SRCS mkldnn/cpu_quantize_placement_pass_tester.cc DEPS cpu_quantize_placement_pass)
cc_test(test_cpu_quantize_pass SRCS mkldnn/cpu_quantize_pass_tester.cc DEPS cpu_quantize_pass naive_executor)
cc_test(test_cpu_quantize_squash_pass SRCS mkldnn/cpu_quantize_squash_pass_tester.cc DEPS cpu_quantize_squash_pass naive_executor)
cc_test(test_reshape_transpose_matmul_mkldnn_fuse_pass SRCS mkldnn/reshape_transpose_matmul_mkldnn_fuse_pass_tester.cc DEPS reshape_transpose_matmul_mkldnn_fuse_pass)
cc_test(test_reshape_transpose_matmul_mkldnn_fuse_pass SRCS mkldnn/reshape_transpose_matmul_mkldnn_fuse_pass_tester.cc DEPS reshape_transpose_matmul_mkldnn_fuse_pass reshape_transpose_matmul_v2_mkldnn_fuse_pass)
cc_test(test_matmul_transpose_reshape_fuse_pass SRCS mkldnn/matmul_transpose_reshape_fuse_pass_tester.cc DEPS matmul_transpose_reshape_fuse_pass matmul_v2_transpose_reshape_fuse_pass)
cc_test(test_cpu_bfloat16_placement_pass SRCS mkldnn/cpu_bfloat16_placement_pass_tester.cc DEPS cpu_bfloat16_placement_pass)
cc_test(test_cpu_bfloat16_pass SRCS mkldnn/cpu_bfloat16_pass_tester.cc DEPS cpu_bfloat16_pass)
......
......@@ -2711,12 +2711,13 @@ void patterns::DeleteQuantDequantFilterOpPattern::operator()() {
}
PDNode *patterns::ReshapeTransposeMatmulPattern::operator()(
bool with_reshape_xshape, bool with_transpose_xshape) {
const std::string &op_name, bool with_reshape_xshape,
bool with_transpose_xshape) {
auto reshape_op =
pattern->NewNode(reshape_op_repr())->assert_is_op("reshape2");
auto transpose_op =
pattern->NewNode(transpose_op_repr())->assert_is_op("transpose2");
auto matmul_op = pattern->NewNode(matmul_op_repr())->assert_is_op("matmul");
auto matmul_op = pattern->NewNode(matmul_op_repr())->assert_is_op(op_name);
auto reshape_in = pattern->NewNode(reshape_in_repr())
->AsInput()
......@@ -2737,7 +2738,7 @@ PDNode *patterns::ReshapeTransposeMatmulPattern::operator()(
auto transpose_out = pattern->NewNode(transpose_out_repr())
->AsIntermediate()
->assert_is_op_input("matmul")
->assert_is_op_input(op_name)
->assert_is_op_output("transpose2", "Out");
if (!with_transpose_xshape)
transpose_out->assert_is_only_output_of_op("transpose2");
......@@ -2751,7 +2752,7 @@ PDNode *patterns::ReshapeTransposeMatmulPattern::operator()(
auto matmul_out = pattern->NewNode(matmul_out_repr())
->AsOutput()
->assert_is_op_output("matmul", "Out");
->assert_is_op_output(op_name, "Out");
reshape_op->LinksFrom({reshape_in}).LinksTo({reshape_out});
if (with_reshape_xshape) reshape_op->LinksTo({reshape_xshape});
......
......@@ -1570,7 +1570,8 @@ struct ReshapeTransposeMatmulPattern : public PatternBase {
const std::string& name_scope)
: PatternBase(pattern, name_scope, "reshape_transpose_matmul") {}
PDNode* operator()(bool with_reshape_xshape, bool with_transpose_xshape);
PDNode* operator()(const std::string& op_name, bool with_reshape_xshape,
bool with_transpose_xshape);
PATTERN_DECL_NODE(reshape_in);
PATTERN_DECL_NODE(reshape_op);
......
......@@ -24,6 +24,8 @@ namespace framework {
namespace ir {
ReshapeTransposeMatmulMkldnnFusePass::ReshapeTransposeMatmulMkldnnFusePass() {
op_name_ = "matmul";
AddOpCompat(OpCompat("reshape2"))
.AddInput("X")
.IsTensor()
......@@ -55,7 +57,7 @@ ReshapeTransposeMatmulMkldnnFusePass::ReshapeTransposeMatmulMkldnnFusePass() {
.IsType<std::vector<int>>()
.End();
AddOpCompat(OpCompat("matmul"))
AddOpCompat(OpCompat(op_name_))
.AddInput("X")
.IsTensor()
.End()
......@@ -82,17 +84,17 @@ void ReshapeTransposeMatmulMkldnnFusePass::Fuse(
patterns::ReshapeTransposeMatmulPattern rtm_pattern(gpd.mutable_pattern(),
name_scope_);
rtm_pattern(with_reshape_xshape, with_transpose_xshape);
rtm_pattern(op_name_, with_reshape_xshape, with_transpose_xshape);
int found_reshape_transpose_matmul_count = 0;
auto handler = [&](const GraphPatternDetector::subgraph_t &subgraph,
Graph *g) {
if (!IsCompat(subgraph, g)) {
LOG(WARNING) << "Op compatible check in "
"reshape_transpose_matmul_mkldnn_fuse_pass failed.";
LOG(WARNING) << "Op compatible check in reshape_transpose_" << op_name_
<< "_mkldnn_fuse_pass failed.";
return;
}
VLOG(4) << "handle ReshapeTransposeMatmulMkldnn fuse";
VLOG(4) << "handle reshape_transpose_" << op_name_ << " fuse";
GET_IR_NODE_FROM_SUBGRAPH(reshape_in, reshape_in, rtm_pattern);
GET_IR_NODE_FROM_SUBGRAPH(reshape_op, reshape_op, rtm_pattern);
GET_IR_NODE_FROM_SUBGRAPH(reshape_out, reshape_out, rtm_pattern);
......@@ -131,8 +133,8 @@ void ReshapeTransposeMatmulMkldnnFusePass::Fuse(
} else if (matmul_desc->Inputs().at("Y").at(0) == input_var_name) {
UpdateMatmul("Y");
} else {
throw platform::errors::InvalidArgument(
"Unexpected input to MatMul encountered.");
throw platform::errors::InvalidArgument("Unexpected input to " +
op_name_ + " encountered.");
}
std::unordered_set<const ir::Node *> nodes_to_remove{
......@@ -151,7 +153,7 @@ void ReshapeTransposeMatmulMkldnnFusePass::Fuse(
if (!Has("disable_logs") || !Get<bool>("disable_logs")) {
std::stringstream msg_ss;
msg_ss << "--- Fused " << found_reshape_transpose_matmul_count
<< " ReshapeTransposeMatmulMkldnn patterns";
<< " ReshapeTransposeMatmul patterns for " << op_name_ << " Op";
if (with_reshape_xshape) msg_ss << " with reshape's xshape";
if (with_transpose_xshape) msg_ss << " with transpose's xshape";
string::PrettyLogDetail(msg_ss.str().c_str());
......
......@@ -28,6 +28,7 @@ namespace ir {
class ReshapeTransposeMatmulMkldnnFusePass : public FusePassBase {
public:
ReshapeTransposeMatmulMkldnnFusePass();
virtual ~ReshapeTransposeMatmulMkldnnFusePass() {}
protected:
void ApplyImpl(ir::Graph* graph) const override;
......@@ -35,6 +36,7 @@ class ReshapeTransposeMatmulMkldnnFusePass : public FusePassBase {
void Fuse(Graph* graph, bool with_reshape_xshape,
bool with_transpose_xshape) const;
std::string op_name_;
};
} // namespace ir
} // namespace framework
......
......@@ -13,6 +13,7 @@
// limitations under the License.
#include "paddle/fluid/framework/ir/mkldnn/reshape_transpose_matmul_mkldnn_fuse_pass.h"
#include "paddle/fluid/framework/ir/mkldnn/reshape_transpose_matmul_v2_mkldnn_fuse_pass.h"
#include <gtest/gtest.h>
#include "paddle/fluid/framework/ir/pass_tester_helper.h"
......@@ -37,7 +38,7 @@ Scope* CreateParamScope() {
return param_scope;
}
void TestMain(bool with_xshapes) {
void TestMain(const std::string& op_name, bool with_xshapes) {
// inputs operator output
// -----------------------------------------------
// a1,w1,bias1 fc -> b1
......@@ -46,7 +47,7 @@ void TestMain(bool with_xshapes) {
// a2,w2,bias2 fc -> b2
// b2 reshape -> c2
// c2 transpose -> d2
// (d1, d2) matmul -> (...)
// (d1, d2) matmul(_v2) -> (...)
Layers layers;
auto* a1 = layers.data("a1", {-1, 128, 768});
auto* w1 = layers.data("w1", {768, 768}, true);
......@@ -66,7 +67,11 @@ void TestMain(bool with_xshapes) {
c2->SetShape({-1, 128, 12, 64});
auto* d2 = layers.transpose2(c2, {0, 2, 1, 3});
d2->SetShape({-1, 12, 128, 64});
if (op_name == "matmul_v2") {
layers.matmul_v2(d1, d2);
} else {
layers.matmul(d1, d2);
}
std::unique_ptr<ir::Graph> graph(new ir::Graph(layers.main_program()));
graph->Set("__param_scope__", CreateParamScope());
......@@ -76,8 +81,8 @@ void TestMain(bool with_xshapes) {
int total_nodes_before = graph->Nodes().size();
VLOG(3) << DebugString(graph);
auto pass =
PassRegistry::Instance().Get("reshape_transpose_matmul_mkldnn_fuse_pass");
auto pass = PassRegistry::Instance().Get("reshape_transpose_" + op_name +
"_mkldnn_fuse_pass");
graph.reset(pass->Apply(graph.release()));
int num_reshape_nodes_after = GetNumOpNodes(graph, "reshape2");
......@@ -92,7 +97,7 @@ void TestMain(bool with_xshapes) {
int removed = 8; // 2* reshape, reshape_out, transpose, transpose_out
if (with_xshapes) removed += 2; // transpose_xshape, reshape_xshape
EXPECT_EQ(total_nodes_before - removed, total_nodes_after);
auto* matmul_op_desc = GetOpNodes(graph, "matmul").at(0)->Op();
auto* matmul_op_desc = GetOpNodes(graph, op_name).at(0)->Op();
auto check = [&matmul_op_desc](std::string a) {
std::string shape_str = "fused_reshape_" + a;
......@@ -108,12 +113,22 @@ void TestMain(bool with_xshapes) {
TEST(ReshapeTransposeMatmulMkldnnFusePass,
both_matmul_inputs_reshape_transpose) {
TestMain(false);
TestMain("matmul", false);
}
TEST(ReshapeTransposeMatmulMkldnnFusePass,
both_matmul_inputs_reshape_transpose_one_with_xshapes) {
TestMain(true);
TestMain("matmul", true);
}
TEST(ReshapeTransposeMatmulV2MkldnnFusePass,
both_matmulv2_inputs_reshape_transpose) {
TestMain("matmul_v2", false);
}
TEST(ReshapeTransposeMatmulV2MkldnnFusePass,
both_matmulv2_inputs_reshape_transpose_one_with_xshapes) {
TestMain("matmul_v2", true);
}
} // namespace ir
......@@ -121,3 +136,4 @@ TEST(ReshapeTransposeMatmulMkldnnFusePass,
} // namespace paddle
USE_PASS(reshape_transpose_matmul_mkldnn_fuse_pass);
USE_PASS(reshape_transpose_matmul_v2_mkldnn_fuse_pass);
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/ir/mkldnn/reshape_transpose_matmul_v2_mkldnn_fuse_pass.h"
#include <string>
#include <unordered_set>
#include <vector>
#include "paddle/fluid/framework/op_version_registry.h"
#include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/string/pretty_log.h"
namespace paddle {
namespace framework {
namespace ir {
ReshapeTransposeMatmulV2MkldnnFusePass::
ReshapeTransposeMatmulV2MkldnnFusePass() {
op_name_ = "matmul_v2";
AddOpCompat(OpCompat("reshape2"))
.AddInput("X")
.IsTensor()
.End()
// The reshape2 op for this pass should not have "Shape" and "ShapeTensor"
.AddOutput("Out")
.IsTensor()
.End()
.AddOutput("XShape")
.IsOptional()
.IsTensor()
.End()
.AddAttr("shape")
.IsType<std::vector<int>>()
.End();
AddOpCompat(OpCompat("transpose2"))
.AddInput("X")
.IsTensor()
.End()
.AddOutput("Out")
.IsTensor()
.End()
.AddOutput("XShape")
.IsOptional()
.IsTensor()
.End()
.AddAttr("axis")
.IsType<std::vector<int>>()
.End();
AddOpCompat(OpCompat(op_name_))
.AddInput("X")
.IsTensor()
.End()
.AddInput("Y")
.IsTensor()
.End()
.AddOutput("Out")
.IsTensor()
.End()
.AddAttr("trans_x")
.IsType<bool>()
.End()
.AddAttr("trans_y")
.IsType<bool>()
.End();
}
} // namespace ir
} // namespace framework
} // namespace paddle
REGISTER_PASS(reshape_transpose_matmul_v2_mkldnn_fuse_pass,
paddle::framework::ir::ReshapeTransposeMatmulV2MkldnnFusePass);
REGISTER_PASS_CAPABILITY(reshape_transpose_matmul_v2_mkldnn_fuse_pass)
.AddCombination(
paddle::framework::compatible::OpVersionComparatorCombination()
.EQ("matmul_v2", 0)
.EQ("transpose2", 0)
.EQ("reshape2", 0));
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <string>
#include "paddle/fluid/framework/ir/mkldnn/reshape_transpose_matmul_mkldnn_fuse_pass.h"
namespace paddle {
namespace framework {
namespace ir {
/*
* Fuse Reshape->Transpose->MatMulV2 when MatMulV2 uses mkldnn.
*/
class ReshapeTransposeMatmulV2MkldnnFusePass
: public ReshapeTransposeMatmulMkldnnFusePass {
public:
ReshapeTransposeMatmulV2MkldnnFusePass();
virtual ~ReshapeTransposeMatmulV2MkldnnFusePass() {}
protected:
const std::string name_scope_{"reshape_transpose_matmul_v2_fuse"};
};
} // namespace ir
} // namespace framework
} // namespace paddle
......@@ -307,6 +307,19 @@ struct Layers {
return out;
}
VarDesc* matmul_v2(VarDesc* x, VarDesc* y, VarDesc* alpha = nullptr,
bool trans_x = false, bool trans_y = false) {
VarDesc* out = lod_tensor(unique_name());
OpDesc* op = program_.MutableBlock(0)->AppendOp();
op->SetType("matmul_v2");
op->SetInput("X", {x->Name()});
op->SetInput("Y", {y->Name()});
op->SetOutput("Out", {out->Name()});
op->SetAttr("trans_x", trans_x);
op->SetAttr("trans_y", trans_y);
return out;
}
VarDesc* transpose2(VarDesc* x, std::vector<int> axis,
bool with_xshape = false) {
VarDesc* out = lod_tensor(unique_name());
......
......@@ -252,6 +252,7 @@ void CpuPassStrategy::EnableMKLDNN() {
"conv_hard_sigmoid_mkldnn_fuse_pass", //
"scale_matmul_fuse_pass", //
"reshape_transpose_matmul_mkldnn_fuse_pass", //
"reshape_transpose_matmul_v2_mkldnn_fuse_pass", //
"matmul_transpose_reshape_fuse_pass", //
"matmul_v2_transpose_reshape_fuse_pass", //
// Disabled due to topology-dependent speed-up
......
......@@ -39,6 +39,22 @@ extra {
name: "op_device"
type: STRING
}
attrs {
name: "fused_reshape_X"
type: INTS
}
attrs {
name: "fused_reshape_Y"
type: INTS
}
attrs {
name: "fused_transpose_X"
type: INTS
}
attrs {
name: "fused_transpose_Y"
type: INTS
}
attrs {
name: "fused_reshape_Out"
type: INTS
......
......@@ -19,6 +19,81 @@
namespace paddle {
namespace operators {
static framework::DDim GetDimForInput(const framework::InferShapeContext& ctx,
const std::string input_name) {
auto shape = ctx.Attrs().Get<std::vector<int>>("fused_reshape_" + input_name);
auto axis =
ctx.Attrs().Get<std::vector<int>>("fused_transpose_" + input_name);
auto dim = ctx.GetInputDim(input_name);
PADDLE_ENFORCE_GT(dim.size(), 0,
platform::errors::InvalidArgument(
"The Input(%s) has not been initialized properly. The "
"shape of Input(%s) = [%s].",
dim));
// if mkldnn reshape+transpose+matmul fuse activated
if (!shape.empty() && !axis.empty()) {
PADDLE_ENFORCE_GE(
shape.size(), 2,
platform::errors::InvalidArgument(
"shape_%s attribute of MatMulOp was implemented for 2, 3 "
"or 4 dimensions.",
input_name));
PADDLE_ENFORCE_LE(
shape.size(), 4,
platform::errors::InvalidArgument(
"shape_%s attribute of MatMulOp was implemented for 2, 3 "
"or 4 dimensions.",
input_name));
PADDLE_ENFORCE_EQ(
shape.size(), axis.size(),
platform::errors::InvalidArgument(
"Ranks of shape_%s and axis_%s attributes of MatMulOp "
"must be equal.",
input_name, input_name));
int num_negative = std::count(shape.begin(), shape.end(), -1);
PADDLE_ENFORCE_LE(num_negative, 1,
platform::errors::InvalidArgument(
"The max number of -1 in fused_reshape_%s is 1 "
"but received %d.",
input_name, num_negative));
auto it_zero = std::find(shape.begin(), shape.end(), 0);
if (it_zero != shape.end()) {
for (uint64_t i = 0; i < shape.size(); i++) {
if (shape[i] == 0) {
PADDLE_ENFORCE_LT(i, dim.size(),
platform::errors::InvalidArgument(
"The index of 0 in fused_reshape_%s ",
"should be less than output dim size, ",
"but the index is %d and output dim size is %d",
input_name, i, dim.size()));
shape[i] = dim.at(i);
}
}
}
// if "-1" is present then one of reshape dims must be infered
auto it_negative = std::find(shape.begin(), shape.end(), -1);
if (it_negative != shape.end()) {
int64_t dim_product = 1;
for (int i = 0; i < dim.size(); i++) {
dim_product *= dim.at(i);
}
int64_t shape_product = std::accumulate(shape.begin(), shape.end(), -1,
std::multiplies<int>());
int index = std::distance(shape.begin(), it_negative);
shape[index] = dim_product / shape_product;
}
dim = dim.reshape(shape).transpose(axis);
}
return dim;
}
class MatMulV2Op : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
......@@ -30,9 +105,9 @@ class MatMulV2Op : public framework::OperatorWithKernel {
bool trans_y = ctx->Attrs().Get<bool>("trans_y");
std::vector<int64_t> dims_x =
paddle::framework::vectorize(ctx->GetInputDim("X"));
framework::vectorize(GetDimForInput(*ctx, "X"));
std::vector<int64_t> dims_y =
paddle::framework::vectorize(ctx->GetInputDim("Y"));
framework::vectorize(GetDimForInput(*ctx, "Y"));
auto ndims_x = dims_x.size();
auto ndims_y = dims_y.size();
PADDLE_ENFORCE_GT(ndims_x, 0,
......@@ -215,6 +290,22 @@ class MatMulV2OpMaker : public framework::OpProtoAndCheckerMaker {
.SetDefault("float32")
.InEnum({"float32", "bfloat16"})
.AsExtra();
AddAttr<std::vector<int>>("fused_reshape_X",
R"DOC(Shape of fused reshape of `X` input.)DOC")
.SetDefault({})
.AsExtra();
AddAttr<std::vector<int>>("fused_reshape_Y",
R"DOC(Shape of fused reshape of `Y` input.)DOC")
.SetDefault({})
.AsExtra();
AddAttr<std::vector<int>>("fused_transpose_X",
R"DOC(Axis of fused transpose of `X` input.)DOC")
.SetDefault({})
.AsExtra();
AddAttr<std::vector<int>>("fused_transpose_Y",
R"DOC(Axis of fused transpose of `Y` input.)DOC")
.SetDefault({})
.AsExtra();
AddComment(
R"DOC(Matrix multiplication Out = X * Y. A has shape (d0, d1 ... M, K),
B has shape (d0, d1 ... K, N), Out has shape ((d0, d1 ... M, N)).
......
......@@ -12,6 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/operators/math/blas.h"
......
......@@ -25,10 +25,88 @@ using paddle::platform::MKLDNNDeviceContext;
using paddle::platform::MKLDNNGetDataType;
using paddle::platform::to_void_cast;
using Tensor = paddle::framework::Tensor;
using paddle::framework::DDim;
using paddle::framework::GradVarName;
using paddle::framework::make_ddim;
using paddle::framework::vectorize;
// Get row matrix shape from a vector shape. If the rank of x_dim > 1, the
// original x_dim is returned.
static DDim RowMatrixDimsFromVector(const DDim& x_dim) {
return x_dim.size() > 1 ? x_dim : paddle::framework::make_ddim({1, x_dim[0]});
}
// Get column matrix shape from a vector shape. If the ran of y_dim > 1, the
// original y_dim is returned.
static DDim ColumnMatrixDimsFromVector(const DDim& y_dim) {
return y_dim.size() > 1 ? y_dim : paddle::framework::make_ddim({y_dim[0], 1});
}
static std::vector<int64_t> Transpose(const std::vector<int64_t>& x,
const std::vector<int>& axis) {
size_t in_rank = x.size();
size_t axis_size = axis.size();
auto axis_set = std::set<int>(axis.begin(), axis.end());
PADDLE_ENFORCE_EQ(axis_set.size(), axis_size,
paddle::platform::errors::InvalidArgument(
"In an axis array, elements must be unique."));
PADDLE_ENFORCE_EQ(in_rank, axis_size,
paddle::platform::errors::InvalidArgument(
"The input dimension's size "
"should be equal to the axis's size. "
"But received dimension is %d, "
"axis's size is %d",
in_rank, axis_size));
PADDLE_ENFORCE_LT(*std::max_element(axis.begin(), axis.end()), axis_size,
paddle::platform::errors::InvalidArgument(
"Axis values must be ranging from 0 to (dims - 1)."));
std::vector<int64_t> new_x(x.size());
for (size_t i = 0; i < x.size(); i++) {
new_x[i] = x[axis[i]];
}
return new_x;
}
std::vector<int64_t> GetInputStrides(const ExecutionContext& ctx,
const std::string input_name) {
auto shape = ctx.Attr<std::vector<int>>("fused_reshape_" + input_name);
auto axis = ctx.Attr<std::vector<int>>("fused_transpose_" + input_name);
auto input_dims = ctx.Input<Tensor>(input_name)->dims();
auto new_dims = input_dims;
if (!shape.empty() && !axis.empty()) {
new_dims = input_dims.reshape(shape).transpose(axis);
}
auto& MatrixDimsFromVector =
input_name == "X" ? RowMatrixDimsFromVector : ColumnMatrixDimsFromVector;
paddle::operators::math::MatDescriptor mat_dim =
paddle::operators::math::CreateMatrixDescriptor(
MatrixDimsFromVector(new_dims), 0,
ctx.Attr<bool>(std::string("trans_") +
static_cast<char>(std::tolower(input_name[0]))));
std::vector<int64_t> strides;
if (!shape.empty()) {
auto shape2 = input_dims.reshape(shape);
strides.push_back(1);
for (auto i = shape2.size() - 1; i > 0; --i) {
strides.insert(strides.begin(),
strides.front() * static_cast<int64_t>(shape2[i]));
}
strides = Transpose(strides, axis);
if (shape.size() == 2)
strides.insert(strides.begin(),
static_cast<int64_t>(shape[0] * shape[1]));
mat_dim.stride_ = strides[0];
if (mat_dim.trans_) std::swap(*strides.rbegin(), *(++strides.rbegin()));
}
return strides;
}
template <typename T>
class MatMulV2MKLDNNHandler
: public paddle::platform::MKLDNNHandlerNoCachingT<T, dnnl::matmul> {
......@@ -37,7 +115,9 @@ class MatMulV2MKLDNNHandler
paddle::platform::Place cpu_place,
const std::vector<int64_t>& x_org_dims, bool trans_x,
const std::vector<int64_t>& y_org_dims, bool trans_y,
bool is_output_fused)
bool is_output_fused,
const std::vector<int64_t>& x_strides_override,
const std::vector<int64_t>& y_strides_override)
: paddle::platform::MKLDNNHandlerNoCachingT<T, dnnl::matmul>(engine,
cpu_place) {
// M X K * K X N
......@@ -64,17 +144,25 @@ class MatMulV2MKLDNNHandler
y_strides.reserve(x_dims.size());
out_strides.reserve(x_dims.size());
if (!x_strides_override.empty()) {
x_strides = x_strides_override;
} else {
if (!trans_x) {
x_strides.insert(x_strides.end(), {M * K, K, 1});
} else {
x_strides.insert(x_strides.end(), {M * K, 1, M});
}
}
if (!y_strides_override.empty()) {
y_strides = y_strides_override;
} else {
if (!trans_y) {
y_strides.insert(y_strides.end(), {N * K, N, 1});
} else {
y_strides.insert(y_strides.end(), {N * K, 1, K});
}
}
out_strides.insert(out_strides.end(), {M * N, N, 1});
out_ddims.insert(out_ddims.end(),
......@@ -82,8 +170,12 @@ class MatMulV2MKLDNNHandler
for (int i = x_dims.size() - 4; i >= 0; --i) {
out_ddims[i] = std::max(x_dims[i], y_dims[i]);
if (x_strides_override.empty()) {
x_strides[i] = x_dims[i + 1] * x_strides[i + 1];
}
if (y_strides_override.empty()) {
y_strides[i] = y_dims[i + 1] * y_strides[i + 1];
}
out_strides[i] = out_ddims[i + 1] * out_strides[i + 1];
}
......@@ -146,9 +238,11 @@ void ExecuteMatMulV2(const ExecutionContext& ctx,
const Tensor* y, std::vector<int64_t>& y_dims,
bool trans_y, Tensor* out, std::vector<int64_t>& out_dims,
int execution_number = 0) {
std::vector<int64_t> x_strides_override = GetInputStrides(ctx, "X");
std::vector<int64_t> y_strides_override = GetInputStrides(ctx, "Y");
MatMulV2MKLDNNHandler<T> handler(onednn_engine, ctx.GetPlace(), x_dims,
trans_x, y_dims, trans_y,
IsOutputFused(ctx));
trans_x, y_dims, trans_y, IsOutputFused(ctx),
x_strides_override, y_strides_override);
const auto src_memory_p = handler.AcquireSrcMemory(x);
const auto weights_memory_p = handler.AcquireWeightsMemory(y);
......@@ -171,6 +265,17 @@ void ExecuteMatMulV2(const ExecutionContext& ctx,
out->set_format(format);
}
DDim GetDimForInput(const paddle::framework::ExecutionContext& ctx,
const std::string& input_name) {
auto shape = ctx.Attr<std::vector<int>>("fused_reshape_" + input_name);
auto axis = ctx.Attr<std::vector<int>>("fused_transpose_" + input_name);
auto dim = ctx.Input<paddle::framework::Tensor>(input_name)->dims();
if (!shape.empty() && !axis.empty()) {
dim = dim.reshape(shape).transpose(axis);
}
return dim;
}
template <typename T>
class MatMulV2MKLDNNKernel : public paddle::framework::OpKernel<T> {
public:
......@@ -230,11 +335,11 @@ class MatMulV2MKLDNNKernel : public paddle::framework::OpKernel<T> {
bool trans_x = ctx.Attr<bool>("trans_x");
bool trans_y = ctx.Attr<bool>("trans_y");
auto x_dims = vectorize(x->dims());
auto y_dims = vectorize(y->dims());
auto x_dims = vectorize(GetDimForInput(ctx, "X"));
auto y_dims = vectorize(GetDimForInput(ctx, "Y"));
auto out_dims = vectorize(out->dims());
int ndims = std::max(x->dims().size(), y->dims().size());
int ndims = std::max(x_dims.size(), y_dims.size());
ndims = std::max(ndims, 3);
std::vector<int64_t> x_bd_dims(ndims, 1);
......@@ -398,8 +503,6 @@ class MatMulV2GradMKLDNNKernel : public paddle::framework::OpKernel<T> {
};
} // anonymous namespace
namespace ops = paddle::operators;
REGISTER_OP_KERNEL(matmul_v2, MKLDNN, ::paddle::platform::CPUPlace,
MatMulV2MKLDNNKernel<float>,
MatMulV2MKLDNNKernel<paddle::platform::bfloat16>);
......
......@@ -638,6 +638,8 @@ class Quant2Int8MkldnnPass(object):
graph = self._apply_pass(graph, 'scale_matmul_fuse_pass')
graph = self._apply_pass(graph,
'reshape_transpose_matmul_mkldnn_fuse_pass')
graph = self._apply_pass(graph,
'reshape_transpose_matmul_v2_mkldnn_fuse_pass')
graph = self._apply_pass(
graph, 'cpu_quantize_pass', ['quant_var_scales', 'data_layout'],
[self._var_quant_scales, self._get_data_layout(graph)])
......
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import unittest
import numpy as np
from inference_pass_test import InferencePassTest
import paddle
import paddle.fluid as fluid
import paddle.fluid.core as core
from paddle.fluid.core import AnalysisConfig
from paddle.fluid.core import PassVersionChecker
class TestReshapeTransposeMatmulV2OneDNNFusePass(InferencePassTest):
def setUp(self):
self.set_params()
self.tranpose_perm = [0, 2, 1, 3]
self.pass_name = 'reshape_transpose_matmul_v2_mkldnn_fuse_pass'
with fluid.program_guard(self.main_program, self.startup_program):
data = fluid.data(
name="data", shape=self.data_shape, dtype="float32")
weight = fluid.layers.create_parameter(
shape=self.weight_shape, dtype="float32")
reshape = fluid.layers.reshape(data, shape=self.reshape_shape)
transpose = fluid.layers.transpose(reshape, self.tranpose_perm)
matmul = paddle.matmul(
transpose,
weight,
transpose_x=self.transpose_x,
transpose_y=self.transpose_y)
self.fetch_list = [matmul]
self.enable_mkldnn = True
def set_params(self):
self.data_shape = [-1, 128, 768]
self.weight_shape = [1, 12, 64, 128]
self.feeds = {"data": np.random.random((1, 128, 768)).astype("float32")}
self.transpose_x = False
self.transpose_y = False
self.reshape_shape = [0, 0, 12, 64]
def test_check_output(self):
use_gpu = False
self.check_output_with_option(use_gpu)
def test_pass_compatible(self):
self.assertTrue(PassVersionChecker.IsCompatible(self.pass_name))
class TestReshapeTransposeMatmulV2OneDNNFusePassBroadcast(
TestReshapeTransposeMatmulV2OneDNNFusePass):
def set_params(self):
self.data_shape = [2, 64, 16]
self.weight_shape = [1, 2, 8, 64]
self.feeds = {"data": np.random.random((2, 64, 16)).astype("float32")}
self.transpose_x = True
self.transpose_y = True
self.reshape_shape = [0, 0, 2, 8]
if __name__ == "__main__":
paddle.enable_static()
unittest.main()
......@@ -252,7 +252,7 @@ class TestDnnlMatMulOpInt8ForceFP32BasicScales(TestDnnlMatMulOp):
@skip_check_grad_ci(reason="DNNL's MatMul doesn't implement grad kernel.")
class TestMatMulOpReshapeTranspose(OpTest):
class TestReshapeTransposeMatMulOp(OpTest):
def init_data_type(self):
self.data_type_ = 'float32'
......@@ -267,10 +267,12 @@ class TestMatMulOpReshapeTranspose(OpTest):
self.fused_reshape_Y = []
self.fused_transpose_Y = []
def setUp(self):
# Set max isa, otherwise fails on SKX and earlier
os.environ["DNNL_MAX_CPU_ISA"] = "AVX"
def set_op_type_and_transpose_y_name(self):
self.op_type = "matmul"
self.transpose_y_name = "transpose_Y"
def setUp(self):
self.set_op_type_and_transpose_y_name()
self._cpu_only = True
self.use_mkldnn = True
self.transpose_y = True
......@@ -280,7 +282,7 @@ class TestMatMulOpReshapeTranspose(OpTest):
self.inputs = {'X': self.x, 'Y': self.y}
self.attrs = {
'use_mkldnn': self.use_mkldnn,
'transpose_Y': self.transpose_y
self.transpose_y_name: self.transpose_y
}
if len(self.fused_transpose_X) > 0:
self.attrs['fused_transpose_X'] = self.fused_transpose_X
......@@ -297,7 +299,7 @@ class TestMatMulOpReshapeTranspose(OpTest):
self.check_output()
class TestMatMulOpReshapeTranspose4DXFloat(TestMatMulOpReshapeTranspose):
class TestReshapeTransposeMatMulOp4DXFloat(TestReshapeTransposeMatMulOp):
def generate_data(self):
self.x = np.random.random([2, 128, 768]).astype("float32")
self.y = np.random.random([2, 128, 768]).astype("float32").reshape(
......@@ -311,12 +313,12 @@ class TestMatMulOpReshapeTranspose4DXFloat(TestMatMulOpReshapeTranspose):
self.y.transpose([0, 1, 3, 2]))
class TestMatMulOpReshapeTranspose4DXInt8(TestMatMulOpReshapeTranspose4DXFloat):
class TestReshapeTransposeMatMulOp4DXInt8(TestReshapeTransposeMatMulOp4DXFloat):
def init_data_type(self):
self.data_type_ = 'int8'
class TestMatMulOpReshapeTranspose4DYFloat(TestMatMulOpReshapeTranspose):
class TestReshapeTransposeMatMulOp4DYFloat(TestReshapeTransposeMatMulOp):
def generate_data(self):
self.x = np.random.random([2, 128, 768]).astype("float32").reshape(
[2, 128, 12, 64]).transpose([0, 2, 1, 3])
......@@ -329,12 +331,12 @@ class TestMatMulOpReshapeTranspose4DYFloat(TestMatMulOpReshapeTranspose):
self.x, self.y.reshape([2, 128, 12, 64]).transpose([0, 2, 3, 1]))
class TestMatMulOpReshapeTranspose4DYInt8(TestMatMulOpReshapeTranspose4DYFloat):
class TestReshapeTransposeMatMulOp4DYInt8(TestReshapeTransposeMatMulOp4DYFloat):
def init_data_type(self):
self.data_type_ = 'int8'
class TestMatMulOpReshapeTranspose4DXYFloat(TestMatMulOpReshapeTranspose):
class TestReshapeTransposeMatMulOp4DXYFloat(TestReshapeTransposeMatMulOp):
def generate_data(self):
self.x = np.random.random([2, 128, 768]).astype("float32")
self.y = np.random.random([2, 128, 768]).astype("float32")
......@@ -347,13 +349,13 @@ class TestMatMulOpReshapeTranspose4DXYFloat(TestMatMulOpReshapeTranspose):
self.y.reshape([2, 128, 12, 64]).transpose([0, 2, 3, 1]))
class TestMatMulOpReshapeTranspose4DXYInt8(
TestMatMulOpReshapeTranspose4DXYFloat):
class TestReshapeTransposeMatMulOp4DXYInt8(
TestReshapeTransposeMatMulOp4DXYFloat):
def init_data_type(self):
self.data_type_ = 'int8'
class TestMatMulOpReshapeTranspose2DXFloat(TestMatMulOpReshapeTranspose):
class TestReshapeTransposeMatMulOp2DXFloat(TestReshapeTransposeMatMulOp):
def generate_data(self):
self.x = np.random.random([2, 5, 10]).astype("float32")
self.y = np.random.random([2, 5, 10]).astype("float32").reshape(
......@@ -367,12 +369,12 @@ class TestMatMulOpReshapeTranspose2DXFloat(TestMatMulOpReshapeTranspose):
self.y.transpose([1, 0]))
class TestMatMulOpReshapeTranspose2DXInt8(TestMatMulOpReshapeTranspose2DXFloat):
class TestReshapeTransposeMatMulOp2DXInt8(TestReshapeTransposeMatMulOp2DXFloat):
def init_data_type(self):
self.data_type_ = 'int8'
class TestMatMulOpReshapeTranspose2DYFloat(TestMatMulOpReshapeTranspose):
class TestReshapeTransposeMatMulOp2DYFloat(TestReshapeTransposeMatMulOp):
def generate_data(self):
self.x = np.random.random([2, 5, 10]).astype("float32").reshape(
[10, 10]).transpose([1, 0])
......@@ -384,12 +386,12 @@ class TestMatMulOpReshapeTranspose2DYFloat(TestMatMulOpReshapeTranspose):
self.out = np.matmul(self.x, self.y.reshape([10, 10]))
class TestMatMulOpReshapeTranspose2DYInt8(TestMatMulOpReshapeTranspose2DYFloat):
class TestReshapeTransposeMatMulOp2DYInt8(TestReshapeTransposeMatMulOp2DYFloat):
def init_data_type(self):
self.data_type_ = 'int8'
class TestMatMulOpReshapeTranspose3DXFloat(TestMatMulOpReshapeTranspose):
class TestReshapeTransposeMatMulOp3DXFloat(TestReshapeTransposeMatMulOp):
def generate_data(self):
self.x = np.random.random([2, 2, 5, 5]).astype("float32")
self.y = np.random.random([2, 2, 5, 5]).astype("float32").reshape(
......@@ -403,12 +405,12 @@ class TestMatMulOpReshapeTranspose3DXFloat(TestMatMulOpReshapeTranspose):
self.y.transpose(0, 2, 1))
class TestMatMulOpReshapeTranspose3DXInt8(TestMatMulOpReshapeTranspose3DXFloat):
class TestReshapeTransposeMatMulOp3DXInt8(TestReshapeTransposeMatMulOp3DXFloat):
def init_data_type(self):
self.data_type_ = 'int8'
class TestMatMulOpReshapeTranspose3DYFloat(TestMatMulOpReshapeTranspose):
class TestReshapeTransposeMatMulOp3DYFloat(TestReshapeTransposeMatMulOp):
def generate_data(self):
self.x = np.random.random([2, 2, 5, 5]).astype(self.data_type_).reshape(
[2, 10, 5]).transpose([0, 2, 1])
......@@ -420,7 +422,7 @@ class TestMatMulOpReshapeTranspose3DYFloat(TestMatMulOpReshapeTranspose):
self.out = np.matmul(self.x, self.y.reshape([2, 10, 5]))
class TestMatMulOpReshapeTranspose3DYInt8(TestMatMulOpReshapeTranspose3DYFloat):
class TestReshapeTransposeMatMulOp3DYInt8(TestReshapeTransposeMatMulOp3DYFloat):
def init_data_type(self):
self.data_type_ = 'int8'
......
......@@ -29,7 +29,11 @@ from paddle.fluid.tests.unittests.mkldnn.test_matmul_mkldnn_op import (
TestMatMulOpTransposeReshapeOtherDimFloat,
TestMatMulOpTransposeReshapeTransposeAxisNotSupportedException,
TestMatMulOpTransposeReshapeTransposeRankNotSupportedException,
TestMatMulOpTransposeReshapeRankOfReshapeNotSupportedException)
TestMatMulOpTransposeReshapeRankOfReshapeNotSupportedException,
TestReshapeTransposeMatMulOp, TestReshapeTransposeMatMulOp4DXFloat,
TestReshapeTransposeMatMulOp4DYFloat, TestReshapeTransposeMatMulOp4DXYFloat,
TestReshapeTransposeMatMulOp2DXFloat, TestReshapeTransposeMatMulOp2DYFloat,
TestReshapeTransposeMatMulOp3DXFloat, TestReshapeTransposeMatMulOp3DYFloat)
def reference_matmul(X, Y, transpose_x=False, transpose_y=False):
......@@ -434,6 +438,61 @@ class TestMatMulV2OpTransposeReshapeTransposeRankNotSupportedException(
self.op_type = "matmul_v2"
class TestMatMulV2OpReshapeTranspose(TestReshapeTransposeMatMulOp):
def set_op_type_and_transpose_y_name(self):
self.op_type = "matmul_v2"
self.transpose_y_name = "trans_y"
class TestMatMulV2OpReshapeTranspose4DXFloat(
TestReshapeTransposeMatMulOp4DXFloat):
def set_op_type_and_transpose_y_name(self):
self.op_type = "matmul_v2"
self.transpose_y_name = "trans_y"
class TestMatMulV2OpReshapeTranspose4DYFloat(
TestReshapeTransposeMatMulOp4DYFloat):
def set_op_type_and_transpose_y_name(self):
self.op_type = "matmul_v2"
self.transpose_y_name = "trans_y"
class TestMatMulV2OpReshapeTranspose4DXYFloat(
TestReshapeTransposeMatMulOp4DXYFloat):
def set_op_type_and_transpose_y_name(self):
self.op_type = "matmul_v2"
self.transpose_y_name = "trans_y"
class TestMatMulV2OpReshapeTranspose2DXFloat(
TestReshapeTransposeMatMulOp2DXFloat):
def set_op_type_and_transpose_y_name(self):
self.op_type = "matmul_v2"
self.transpose_y_name = "trans_y"
class TestMatMulV2OpReshapeTranspose2DYFloat(
TestReshapeTransposeMatMulOp2DYFloat):
def set_op_type_and_transpose_y_name(self):
self.op_type = "matmul_v2"
self.transpose_y_name = "trans_y"
class TestMatMulV2OpReshapeTranspose3DXFloat(
TestReshapeTransposeMatMulOp3DXFloat):
def set_op_type_and_transpose_y_name(self):
self.op_type = "matmul_v2"
self.transpose_y_name = "trans_y"
class TestMatMulV2OpReshapeTranspose3DYFloat(
TestReshapeTransposeMatMulOp3DYFloat):
def set_op_type_and_transpose_y_name(self):
self.op_type = "matmul_v2"
self.transpose_y_name = "trans_y"
if __name__ == "__main__":
paddle.enable_static()
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册