未验证 提交 d31a174f 编写于 作者: A arlesniak 提交者: GitHub

added fusing matmul-transpose-reshape pass (#23866)

上级 46f3139c
...@@ -131,5 +131,67 @@ DDim stride_numel(const DDim& ddim) { ...@@ -131,5 +131,67 @@ DDim stride_numel(const DDim& ddim) {
return strides; return strides;
} }
DDim DDim::reshape(const std::vector<int>& shape) const {
const int64_t copy_dim_val = 0;
const DDim& in_dims = *this;
DDim out_dims;
out_dims.rank_ = shape.size();
for (size_t i = 0; i < shape.size(); ++i) {
if (shape[i] == copy_dim_val) {
PADDLE_ENFORCE_LT(static_cast<int>(i), in_dims.size(),
platform::errors::InvalidArgument(
"Index %d of shape under which the value of 0 "
"is stored, must be lower than the number of "
"old dimensions. But received shape[%d] = 0, "
"dimensions = %d, shape = [%s].",
i, in_dims.size(), in_dims));
out_dims[i] = in_dims[i];
} else {
out_dims[i] = shape[i];
}
}
return out_dims;
}
DDim DDim::transpose(const std::vector<int>& axis) const {
const DDim& in_dims = *this;
size_t in_rank = in_dims.size();
size_t axis_size = axis.size();
PADDLE_ENFORCE_EQ(
in_rank, axis_size,
platform::errors::InvalidArgument("The input dimension's size "
"should be equal to the axis's size. "
"But received dimension is %d, "
"axis's size is %d",
in_rank, axis_size));
std::vector<int> count(axis_size, 0);
for (size_t i = 0; i < axis_size; i++) {
PADDLE_ENFORCE_LT(axis[i], static_cast<int>(axis_size),
platform::errors::InvalidArgument(
"ValueError: Each element of axis must appear "
"exactly once in the range from 0 to (dims - 1), "
"where the dims is the axis's size, "
"but received axis[%d] is %d, axis_size is %d",
i, axis[i], axis_size));
PADDLE_ENFORCE_EQ(
++count[axis[i]], 1,
platform::errors::InvalidArgument(
"ValueError: Each element of axis should "
"be a unique value range from 0 to (dims - 1), "
"where the dims is the axis's size, "
"unique value means this axis value can appear only once. "
"But received count[axis[%d]] is %d",
i, count[axis[i]]));
}
DDim out_dims(in_dims);
for (size_t i = 0; i < axis_size; i++) {
out_dims[i] = in_dims[axis[i]];
}
return out_dims;
}
} // namespace framework } // namespace framework
} // namespace paddle } // namespace paddle
...@@ -126,6 +126,10 @@ class DDim { ...@@ -126,6 +126,10 @@ class DDim {
std::string to_str() const; std::string to_str() const;
DDim reshape(const std::vector<int>& shape) const;
DDim transpose(const std::vector<int>& axis) const;
private: private:
template <int D> template <int D>
inline Dim<D>& UnsafeCast() { inline Dim<D>& UnsafeCast() {
......
...@@ -97,6 +97,7 @@ if(WITH_MKLDNN) ...@@ -97,6 +97,7 @@ if(WITH_MKLDNN)
pass_library(cpu_quantize_placement_pass base DIR mkldnn) pass_library(cpu_quantize_placement_pass base DIR mkldnn)
pass_library(cpu_quantize_pass inference DIR mkldnn) pass_library(cpu_quantize_pass inference DIR mkldnn)
pass_library(cpu_quantize_squash_pass inference DIR mkldnn) pass_library(cpu_quantize_squash_pass inference DIR mkldnn)
pass_library(matmul_transpose_reshape_fuse_pass inference DIR mkldnn)
endif() endif()
cc_library(fuse_bn_act_pass SRCS fuse_bn_act_pass.cc DEPS pass graph_pattern_detector ) cc_library(fuse_bn_act_pass SRCS fuse_bn_act_pass.cc DEPS pass graph_pattern_detector )
...@@ -144,4 +145,5 @@ if (WITH_MKLDNN) ...@@ -144,4 +145,5 @@ if (WITH_MKLDNN)
cc_test(test_cpu_quantize_placement_pass SRCS mkldnn/cpu_quantize_placement_pass_tester.cc DEPS cpu_quantize_placement_pass) cc_test(test_cpu_quantize_placement_pass SRCS mkldnn/cpu_quantize_placement_pass_tester.cc DEPS cpu_quantize_placement_pass)
cc_test(test_cpu_quantize_pass SRCS mkldnn/cpu_quantize_pass_tester.cc DEPS cpu_quantize_pass naive_executor) cc_test(test_cpu_quantize_pass SRCS mkldnn/cpu_quantize_pass_tester.cc DEPS cpu_quantize_pass naive_executor)
cc_test(test_cpu_quantize_squash_pass SRCS mkldnn/cpu_quantize_squash_pass_tester.cc DEPS cpu_quantize_squash_pass naive_executor) cc_test(test_cpu_quantize_squash_pass SRCS mkldnn/cpu_quantize_squash_pass_tester.cc DEPS cpu_quantize_squash_pass naive_executor)
cc_test(test_matmul_transpose_reshape_fuse_pass SRCS mkldnn/matmul_transpose_reshape_fuse_pass_tester.cc DEPS matmul_transpose_reshape_fuse_pass)
endif () endif ()
...@@ -2147,6 +2147,43 @@ void patterns::DeleteQuantDequantOpPattern::operator()() { ...@@ -2147,6 +2147,43 @@ void patterns::DeleteQuantDequantOpPattern::operator()() {
any_op2->LinksFrom({quant_dequant_out}); any_op2->LinksFrom({quant_dequant_out});
} }
PDNode *patterns::MatmulTransposeReshapePattern::operator()() {
auto reshape_op =
pattern->NewNode(reshape_op_repr())->assert_is_op("reshape2");
auto transpose_op =
pattern->NewNode(transpose_op_repr())->assert_is_op("transpose2");
auto matmul_op = pattern->NewNode(matmul_op_repr())->assert_is_op("matmul");
auto matmul_out = pattern->NewNode(matmul_out_repr())
->AsInput()
->assert_is_op_output("matmul", "Out")
->assert_is_op_input("transpose2", "X");
auto transpose_out = pattern->NewNode(transpose_out_repr())
->AsIntermediate()
->assert_is_op_output("transpose2", "Out")
->assert_is_op_input("reshape2", "X");
auto transpose_out_xshape = pattern->NewNode(transpose_out_xshape_repr())
->AsIntermediate()
->assert_is_op_output("transpose2", "XShape");
auto reshape_out = pattern->NewNode(reshape_out_repr())
->AsOutput()
->assert_is_op_output("reshape2");
auto reshape_out_xshape = pattern->NewNode(reshape_out_xshape_repr())
->AsIntermediate()
->assert_is_op_output("reshape2", "XShape");
matmul_op->LinksTo({matmul_out});
transpose_op->LinksTo({transpose_out_xshape});
reshape_op->LinksTo({reshape_out_xshape});
transpose_op->LinksFrom({matmul_out}).LinksTo({transpose_out});
reshape_op->LinksFrom({transpose_out}).LinksTo({reshape_out});
return reshape_out;
}
} // namespace ir } // namespace ir
} // namespace framework } // namespace framework
} // namespace paddle } // namespace paddle
...@@ -1210,6 +1210,24 @@ struct DeleteQuantDequantOpPattern : public PatternBase { ...@@ -1210,6 +1210,24 @@ struct DeleteQuantDequantOpPattern : public PatternBase {
PATTERN_DECL_NODE(any_op2); PATTERN_DECL_NODE(any_op2);
}; };
// Matmul + Transpose + Reshape
struct MatmulTransposeReshapePattern : public PatternBase {
MatmulTransposeReshapePattern(PDPattern* pattern,
const std::string& name_scope)
: PatternBase(pattern, name_scope, "matmul_transpose_reshape") {}
PDNode* operator()();
PATTERN_DECL_NODE(matmul_op);
PATTERN_DECL_NODE(matmul_out);
PATTERN_DECL_NODE(transpose_op);
PATTERN_DECL_NODE(transpose_out);
PATTERN_DECL_NODE(transpose_out_xshape);
PATTERN_DECL_NODE(reshape_op);
PATTERN_DECL_NODE(reshape_out);
PATTERN_DECL_NODE(reshape_out_xshape);
};
} // namespace patterns } // namespace patterns
// Link two ir::Nodes from each other. // Link two ir::Nodes from each other.
......
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/ir/mkldnn/matmul_transpose_reshape_fuse_pass.h"
#include <paddle/fluid/string/pretty_log.h>
#include <vector>
#include "paddle/fluid/platform/enforce.h"
namespace paddle {
namespace framework {
namespace ir {
void MatmulTransposeReshapeMKLDNNPass::ApplyImpl(ir::Graph *graph) const {
PADDLE_ENFORCE_NOT_NULL(graph,
platform::errors::InvalidArgument(
"Pointer to graph argument should not be NULL."));
FusePassBase::Init(name_scope_, graph);
GraphPatternDetector gpd;
patterns::MatmulTransposeReshapePattern mtrp(gpd.mutable_pattern(),
name_scope_);
mtrp();
int found_matmul_transpose_reshape_count = 0;
auto handler = [&](const GraphPatternDetector::subgraph_t &subgraph,
Graph *g) {
VLOG(4) << "handle matmul_transpose_reshape fuse";
GET_IR_NODE_FROM_SUBGRAPH(matmul_op, matmul_op, mtrp);
GET_IR_NODE_FROM_SUBGRAPH(matmul_out, matmul_out, mtrp);
GET_IR_NODE_FROM_SUBGRAPH(transpose_op, transpose_op, mtrp);
GET_IR_NODE_FROM_SUBGRAPH(transpose_out, transpose_out, mtrp);
GET_IR_NODE_FROM_SUBGRAPH(transpose_out_xshape, transpose_out_xshape, mtrp);
GET_IR_NODE_FROM_SUBGRAPH(reshape_op, reshape_op, mtrp);
GET_IR_NODE_FROM_SUBGRAPH(reshape_out, reshape_out, mtrp);
GET_IR_NODE_FROM_SUBGRAPH(reshape_out_xshape, reshape_out_xshape, mtrp);
auto reshape_shape =
boost::get<std::vector<int>>(reshape_op->Op()->GetAttr("shape"));
auto transpose_axis =
boost::get<std::vector<int>>(transpose_op->Op()->GetAttr("axis"));
auto reshape_out_size = reshape_shape.size();
auto transpose_out_size = transpose_axis.size();
const std::vector<int> supported_axis{0, 2, 1, 3};
const bool supported_transpose_axis = std::equal(
transpose_axis.begin(), transpose_axis.end(), supported_axis.begin());
if (transpose_out_size != 4) {
VLOG(3) << "do not perform matmul_transpose_reshape fuse: "
<< "supported rank is 4, received " << transpose_out_size;
return;
}
if (!supported_transpose_axis) {
VLOG(3) << "do not perform matmul_transpose_reshape fuse: "
<< "supported transpose axis for the fuse are {0, 2, 1, 3}";
return;
}
if (reshape_out_size != 3) {
VLOG(3) << "do not perform matmul_transpose_reshape fuse: "
<< "reshape_out supported rank is 3, received "
<< reshape_out_size;
return;
}
OpDesc *matmul_desc = matmul_op->Op();
matmul_desc->SetOutput("Out", {reshape_out->Name()});
matmul_desc->SetAttr("fused_reshape_Out", reshape_shape);
matmul_desc->SetAttr("fused_transpose_Out", transpose_axis);
GraphSafeRemoveNodes(graph,
{matmul_out, transpose_op, transpose_out, reshape_op,
transpose_out_xshape, reshape_out_xshape});
IR_OP_VAR_LINK(matmul_op, reshape_out);
found_matmul_transpose_reshape_count++;
};
gpd(graph, handler);
AddStatis(found_matmul_transpose_reshape_count);
std::stringstream msg_ss;
msg_ss << "--- Fused " << found_matmul_transpose_reshape_count
<< " MatmulTransposeReshape patterns";
paddle::string::PrettyLogDetail(msg_ss.str().c_str());
}
} // namespace ir
} // namespace framework
} // namespace paddle
REGISTER_PASS(matmul_transpose_reshape_fuse_pass,
paddle::framework::ir::MatmulTransposeReshapeMKLDNNPass);
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <string>
#include "paddle/fluid/framework/ir/fuse_pass_base.h"
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/graph_pattern_detector.h"
namespace paddle {
namespace framework {
namespace ir {
class MatmulTransposeReshapeMKLDNNPass : public FusePassBase {
public:
virtual ~MatmulTransposeReshapeMKLDNNPass() {}
protected:
void ApplyImpl(Graph* graph) const override;
const std::string name_scope_{"matmul_transpose_reshape_fuse"};
};
}
} // namespace framework
} // namespace paddle
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/ir/mkldnn/matmul_transpose_reshape_fuse_pass.h"
#include <gtest/gtest.h>
namespace paddle {
namespace framework {
namespace ir {
void SetOp(ProgramDesc *prog, const std::string &type,
const std::vector<std::string> &inputs,
const std::vector<std::string> &outputs) {
auto *op = prog->MutableBlock(0)->AppendOp();
op->SetType(type);
op->SetInput("X", {inputs[0]});
op->SetOutput("Out", {outputs[0]});
if (type == "transpose2") {
op->SetAttr("axis", std::vector<int>({0, 2, 1, 3}));
op->SetOutput("XShape", {outputs[1]});
}
if (type == "reshape2") {
op->SetAttr("shape", std::vector<int>({4, 5, 6}));
op->SetOutput("XShape", {outputs[1]});
}
if (type == "matmul") {
op->SetInput("Y", {inputs[1]});
op->SetAttr("use_mkldnn", true);
}
}
ProgramDesc BuildProgramDesc() {
ProgramDesc prog;
for (auto &v : std::initializer_list<std::string>(
{"a1", "a2", "b", "c", "cx", "d", "dx", "e"})) {
auto *var = prog.MutableBlock(0)->Var(v);
var->SetType(proto::VarType::SELECTED_ROWS);
}
SetOp(&prog, "matmul", {"a1", "a2"}, {"b"});
SetOp(&prog, "transpose2", {"b"}, {"c", "cx"});
SetOp(&prog, "reshape2", {"c"}, {"d", "dx"});
SetOp(&prog, "fc", {"d"}, {"e"});
return prog;
}
void MainTest(const ProgramDesc &prog) {
std::unique_ptr<ir::Graph> graph(new ir::Graph(prog));
int original_nodes_num = graph->Nodes().size();
auto pass =
PassRegistry::Instance().Get("matmul_transpose_reshape_fuse_pass");
graph.reset(pass->Apply(graph.release()));
int current_nodes_num = graph->Nodes().size();
EXPECT_EQ(original_nodes_num - 6, current_nodes_num);
for (auto *node : graph->Nodes()) {
if (node->IsOp()) {
auto *op = node->Op();
if (op->Type() == "matmul") {
EXPECT_EQ(op->GetAttrIfExists<std::vector<int>>("fused_reshape_Out"),
std::vector<int>({4, 5, 6}));
EXPECT_EQ(op->GetAttrIfExists<std::vector<int>>("fused_transpose_Out"),
std::vector<int>({0, 2, 1, 3}));
}
}
}
}
TEST(MatmulTransposeReshapeFusePass, matmul_inputs) {
auto prog = BuildProgramDesc();
MainTest(prog);
}
} // namespace ir
} // namespace framework
} // namespace paddle
USE_PASS(matmul_transpose_reshape_fuse_pass);
...@@ -196,6 +196,7 @@ void CpuPassStrategy::EnableMKLDNN() { ...@@ -196,6 +196,7 @@ void CpuPassStrategy::EnableMKLDNN() {
"conv_relu6_mkldnn_fuse_pass", // "conv_relu6_mkldnn_fuse_pass", //
"conv_swish_mkldnn_fuse_pass", // "conv_swish_mkldnn_fuse_pass", //
"scale_matmul_fuse_pass", // "scale_matmul_fuse_pass", //
"matmul_transpose_reshape_fuse_pass", //
// Disabled due to topology-dependent speed-up // Disabled due to topology-dependent speed-up
// "fc_mkldnn_pass", // "fc_mkldnn_pass",
"mkldnn_inplace_pass", // This pass should be activated after "mkldnn_inplace_pass", // This pass should be activated after
......
...@@ -407,7 +407,45 @@ class MatMulOp : public framework::OperatorWithKernel { ...@@ -407,7 +407,45 @@ class MatMulOp : public framework::OperatorWithKernel {
if (dim_out.empty()) { if (dim_out.empty()) {
dim_out = {1}; dim_out = {1};
} }
context->SetOutputDim("Out", framework::make_ddim(dim_out));
framework::DDim ddim_out = framework::make_ddim(dim_out);
#ifdef PADDLE_WITH_MKLDNN
// if mkldnn matmul+transpose+reshape fuse activated
auto reshape_out =
context->Attrs().Get<std::vector<int>>("fused_reshape_Out");
auto transpose_out =
context->Attrs().Get<std::vector<int>>("fused_transpose_Out");
if (!reshape_out.empty() && !transpose_out.empty()) {
auto reshape_out_size = reshape_out.size();
auto transpose_out_size = transpose_out.size();
PADDLE_ENFORCE_EQ(transpose_out_size, 4,
platform::errors::InvalidArgument(
"transpose_out supported rank is 4, "
"received %d",
transpose_out_size));
const std::vector<int> supported_axis{0, 2, 1, 3};
const bool supported_transpose_axis = std::equal(
transpose_out.begin(), transpose_out.end(), supported_axis.begin());
PADDLE_ENFORCE_EQ(
supported_transpose_axis, true,
platform::errors::InvalidArgument(
"supported transpose axis for the fuse are {0, 2, 1, 3}"));
PADDLE_ENFORCE_EQ(
reshape_out_size, 3,
platform::errors::InvalidArgument("reshape_out supported rank is 3, "
"received %d",
reshape_out_size));
framework::DDim shape_out =
ddim_out.transpose(transpose_out).reshape(reshape_out);
context->SetOutputDim("Out", shape_out);
} else {
context->SetOutputDim("Out", ddim_out);
}
#else
context->SetOutputDim("Out", ddim_out);
#endif
context->ShareLoD("X", /*->*/ "Out"); context->ShareLoD("X", /*->*/ "Out");
} }
...@@ -446,6 +484,16 @@ class MatMulOpMaker : public framework::OpProtoAndCheckerMaker { ...@@ -446,6 +484,16 @@ class MatMulOpMaker : public framework::OpProtoAndCheckerMaker {
"use_mkldnn", "use_mkldnn",
"(bool, default false) Indicates if MKL-DNN kernel will be used") "(bool, default false) Indicates if MKL-DNN kernel will be used")
.SetDefault(false); .SetDefault(false);
AddAttr<std::vector<int>>(
"fused_reshape_Out",
R"DOC(When MKLDNN MatMul_transpose_reshape fuse activated, "
"it's a shape atribute of fused reshape for `Out` output.)DOC")
.SetDefault({});
AddAttr<std::vector<int>>(
"fused_transpose_Out",
R"DOC(When MKLDNN MatMul_transpose_reshape fuse activated, "
"it's a axis atribute of fused transpose for `Out` output.)DOC")
.SetDefault({});
/* int8 parameters */ /* int8 parameters */
AddAttr<bool>("use_quantizer", AddAttr<bool>("use_quantizer",
"(bool, default false) " "(bool, default false) "
...@@ -466,6 +514,7 @@ class MatMulOpMaker : public framework::OpProtoAndCheckerMaker { ...@@ -466,6 +514,7 @@ class MatMulOpMaker : public framework::OpProtoAndCheckerMaker {
"(bool, default false) Force INT8 kernel output FP32, only " "(bool, default false) Force INT8 kernel output FP32, only "
"used in MKL-DNN INT8") "used in MKL-DNN INT8")
.SetDefault(false); .SetDefault(false);
#if defined(PADDLE_WITH_MKLML) && !defined(PADDLE_WITH_CUDA) #if defined(PADDLE_WITH_MKLML) && !defined(PADDLE_WITH_CUDA)
AddAttr<int>("head_number", "The number of heads of the matrix") AddAttr<int>("head_number", "The number of heads of the matrix")
.SetDefault(1); .SetDefault(1);
......
...@@ -31,6 +31,11 @@ using platform::MKLDNNDeviceContext; ...@@ -31,6 +31,11 @@ using platform::MKLDNNDeviceContext;
using framework::ExecutionContext; using framework::ExecutionContext;
using Tensor = framework::Tensor; using Tensor = framework::Tensor;
template <typename T>
constexpr bool IsInt8() {
return std::is_same<T, int8_t>::value || std::is_same<T, uint8_t>::value;
}
// Get row matrix shape from a vector shape. If the rank of x_dim > 1, the // Get row matrix shape from a vector shape. If the rank of x_dim > 1, the
// original x_dim is returned. // original x_dim is returned.
static framework::DDim RowMatrixDimsFromVector(const framework::DDim& x_dim) { static framework::DDim RowMatrixDimsFromVector(const framework::DDim& x_dim) {
...@@ -64,7 +69,8 @@ class MatMulFactory { ...@@ -64,7 +69,8 @@ class MatMulFactory {
private: private:
struct MatMulDims { struct MatMulDims {
const memory::dim BS, M, N, K; const memory::dims x_dims, y_dims, out_dims, x_strides, y_strides,
out_strides;
}; };
void SetDNNLEngine(const ExecutionContext& ctx) { void SetDNNLEngine(const ExecutionContext& ctx) {
...@@ -80,6 +86,19 @@ class MatMulFactory { ...@@ -80,6 +86,19 @@ class MatMulFactory {
return dnnl::memory(md, engine_, to_void_cast(data)); return dnnl::memory(md, engine_, to_void_cast(data));
} }
bool IsOutputFused(const ExecutionContext& ctx) const {
auto& fused_reshape_Out = ctx.Attr<std::vector<int>>("fused_reshape_Out");
auto& fused_transpose_Out =
ctx.Attr<std::vector<int>>("fused_transpose_Out");
return !fused_reshape_Out.empty() && !fused_transpose_Out.empty();
}
void CorrectStridesWhenFloatOutputFused(const ExecutionContext& ctx,
const memory::dim N, memory::dim b,
memory::dims* out_strides) const {
if (!IsInt8<OT>() && IsOutputFused(ctx)) *out_strides = {N, b * N, 1};
}
MatMulDims GetMatmulDims(const ExecutionContext& ctx) { MatMulDims GetMatmulDims(const ExecutionContext& ctx) {
auto mat_dim_x = math::CreateMatrixDescriptor( auto mat_dim_x = math::CreateMatrixDescriptor(
RowMatrixDimsFromVector(ctx.Input<Tensor>("X")->dims()), 0, RowMatrixDimsFromVector(ctx.Input<Tensor>("X")->dims()), 0,
...@@ -100,34 +119,45 @@ class MatMulFactory { ...@@ -100,34 +119,45 @@ class MatMulFactory {
const memory::dim M = mat_dim_x.height_; const memory::dim M = mat_dim_x.height_;
const memory::dim N = mat_dim_y.width_; const memory::dim N = mat_dim_y.width_;
const memory::dim K = mat_dim_x.width_; const memory::dim K = mat_dim_x.width_;
return {BS, M, N, K};
batch_size_ = 1;
auto b = BS;
if (BS > 1 && IsOutputFused(ctx)) {
batch_size_ = ctx.Input<Tensor>("X")->dims()[0];
b = BS / batch_size_;
}
memory::dims x_dims = {b, M, K};
memory::dims y_dims = {b, K, N};
memory::dims out_dims = {b, M, N};
size_t x_size = b * M * K * sizeof(XT);
size_t y_size = b * K * N * sizeof(YT);
size_t out_size = b * M * N * sizeof(OT);
offsets_ = {x_size, y_size, out_size};
// Translate transA and transB
memory::dims strides_x = !ctx.Attr<bool>("transpose_X")
? memory::dims{M * K, K, 1}
: memory::dims{M * K, 1, M};
memory::dims strides_y = !ctx.Attr<bool>("transpose_Y")
? memory::dims{N * K, N, 1}
: memory::dims{N * K, 1, K};
memory::dims out_strides = memory::dims{M * N, N, 1};
CorrectStridesWhenFloatOutputFused(ctx, N, b, &out_strides);
return {x_dims, y_dims, out_dims, strides_x, strides_y, out_strides};
} }
void CreateMemories(const ExecutionContext& ctx) { void CreateMemories(const ExecutionContext& ctx) {
auto matmul_dims = GetMatmulDims(ctx); auto matmul_dims = GetMatmulDims(ctx);
auto BS = matmul_dims.BS;
auto M = matmul_dims.M;
auto N = matmul_dims.N;
auto K = matmul_dims.K;
bool x_trans = ctx.Attr<bool>("transpose_X");
bool y_trans = ctx.Attr<bool>("transpose_Y");
typedef memory::dims dims;
dims x_dims = {BS, M, K};
dims y_dims = {BS, K, N};
dims out_dims = {BS, M, N};
// Translate transA and transB x_mem_ = CreateMemory<XT>(matmul_dims.x_dims, matmul_dims.x_strides,
dims x_strides = !x_trans ? dims{M * K, K, 1} : dims{M * K, 1, M}; ctx.Input<Tensor>("X")->data<XT>());
dims y_strides = !y_trans ? dims{N * K, N, 1} : dims{N * K, 1, K}; y_mem_ = CreateMemory<YT>(matmul_dims.y_dims, matmul_dims.y_strides,
dims out_strides = {M * N, N, 1}; ctx.Input<Tensor>("Y")->data<YT>());
x_mem_ =
CreateMemory<XT>(x_dims, x_strides, ctx.Input<Tensor>("X")->data<XT>());
y_mem_ =
CreateMemory<YT>(y_dims, y_strides, ctx.Input<Tensor>("Y")->data<YT>());
out_mem_ = CreateMemory<OT>( out_mem_ = CreateMemory<OT>(
out_dims, out_strides, matmul_dims.out_dims, matmul_dims.out_strides,
ctx.Output<Tensor>("Out")->mutable_data<OT>(ctx.GetPlace())); ctx.Output<Tensor>("Out")->mutable_data<OT>(ctx.GetPlace()));
} }
...@@ -156,11 +186,25 @@ class MatMulFactory { ...@@ -156,11 +186,25 @@ class MatMulFactory {
void Execute() { void Execute() {
dnnl::stream stream(engine_); dnnl::stream stream(engine_);
auto offsets = offsets_;
unsigned bs = batch_size_;
void* x_ptr = x_mem_.get_data_handle();
void* y_ptr = y_mem_.get_data_handle();
void* out_ptr = out_mem_.get_data_handle();
for (unsigned i = 0; i < bs; i++) {
x_mem_.set_data_handle(x_ptr);
y_mem_.set_data_handle(y_ptr);
out_mem_.set_data_handle(out_ptr);
matmul_prim_.execute(stream, { matmul_prim_.execute(stream, {
{MKLDNN_ARG_SRC, x_mem_}, {MKLDNN_ARG_SRC, x_mem_},
{MKLDNN_ARG_WEIGHTS, y_mem_}, {MKLDNN_ARG_WEIGHTS, y_mem_},
{MKLDNN_ARG_DST, out_mem_}, {MKLDNN_ARG_DST, out_mem_},
}); });
x_ptr = static_cast<char*>(x_ptr) + offsets.x_offset;
y_ptr = static_cast<char*>(y_ptr) + offsets.y_offset;
out_ptr = static_cast<char*>(out_ptr) + offsets.out_offset;
}
stream.wait(); stream.wait();
} }
...@@ -188,11 +232,19 @@ class MatMulFactory { ...@@ -188,11 +232,19 @@ class MatMulFactory {
void SetInitialized() { initialized_ = true; } void SetInitialized() { initialized_ = true; }
private: private:
struct memory_offsets {
size_t x_offset;
size_t y_offset;
size_t out_offset;
};
dnnl::engine engine_; dnnl::engine engine_;
dnnl::memory x_mem_; dnnl::memory x_mem_;
dnnl::memory y_mem_; dnnl::memory y_mem_;
dnnl::memory out_mem_; dnnl::memory out_mem_;
dnnl::matmul matmul_prim_; dnnl::matmul matmul_prim_;
memory_offsets offsets_;
unsigned batch_size_;
bool initialized_ = false; bool initialized_ = false;
}; };
...@@ -217,10 +269,6 @@ static std::shared_ptr<MatMulFactory<XT, YT, OT>> GetPrimitiveFactory( ...@@ -217,10 +269,6 @@ static std::shared_ptr<MatMulFactory<XT, YT, OT>> GetPrimitiveFactory(
return factory; return factory;
} }
template <typename T>
constexpr bool IsInt8() {
return std::is_same<T, int8_t>::value || std::is_same<T, uint8_t>::value;
}
// Choose appropriate primitive factory implementation based on inferred // Choose appropriate primitive factory implementation based on inferred
// output type (uint8, int8 or float). // output type (uint8, int8 or float).
template <typename XT, typename YT> template <typename XT, typename YT>
......
...@@ -371,6 +371,7 @@ class Qat2Int8MkldnnPass(object): ...@@ -371,6 +371,7 @@ class Qat2Int8MkldnnPass(object):
['use_gpu', 'use_fc_padding'], ['use_gpu', 'use_fc_padding'],
[False, False]) [False, False])
graph = self._apply_pass(graph, 'fc_mkldnn_pass') graph = self._apply_pass(graph, 'fc_mkldnn_pass')
graph = self._apply_pass(graph, 'matmul_transpose_reshape_fuse_pass')
return graph return graph
def _apply_pass(self, graph, pass_name, attrs=None, attr_values=None): def _apply_pass(self, graph, pass_name, attrs=None, attr_values=None):
......
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import unittest
import numpy as np
import paddle.fluid as fluid
from inference_pass_test import InferencePassTest
class TestMKLDNNMatmulFuseOp(InferencePassTest):
def init_data(self):
self.bs = 8
self.d_type = np.float32
self.shape_x = [12, 128, 128]
self.shape_y = [12, 128, 64]
self.enable_mkldnn = True
def make_network(self):
with fluid.program_guard(self.main_program, self.startup_program):
x = fluid.data(
name='x', shape=[-1] + self.shape_x, dtype=self.d_type)
y = fluid.data(
name='y', shape=[-1] + self.shape_y, dtype=self.d_type)
out = fluid.layers.matmul(x, y)
out = fluid.layers.transpose(out, perm=[0, 2, 1, 3])
out = fluid.layers.reshape(
out, [0, 0, self.shape_y[0] * self.shape_y[2]])
out = fluid.layers.fc(out, size=1)
return out
def setUp(self):
self.init_data()
out = self.make_network()
self.set_feeds(out)
def set_feeds(self, out):
self.feeds = {
"x": np.random.random([self.bs] + self.shape_x).astype(self.d_type),
"y": np.random.random([self.bs] + self.shape_y).astype(self.d_type)
}
self.fetch_list = [out]
def test_check_output(self):
use_gpu = False
self.check_output_with_option(use_gpu)
class TestMKLDNNMatmulOtherDimsFuseOp(TestMKLDNNMatmulFuseOp):
def init_data(self):
self.bs = 8
self.d_type = np.float32
self.shape_x = [12, 1, 1]
self.shape_y = [12, 1, 64]
self.enable_mkldnn = True
class TestMKLDNNMatmulOpNotFusedWrongTransposeAxis(TestMKLDNNMatmulFuseOp):
def make_network(self):
with fluid.program_guard(self.main_program, self.startup_program):
x = fluid.data(
name='x', shape=[-1] + self.shape_x, dtype=self.d_type)
y = fluid.data(
name='y', shape=[-1] + self.shape_y, dtype=self.d_type)
out = fluid.layers.matmul(x, y)
out = fluid.layers.transpose(out, perm=[0, 1, 2, 3])
out = fluid.layers.reshape(out, [0, 0, 0, 0])
out = fluid.layers.fc(out, size=1)
return out
class TestMKLDNNMatmulOpNotFusedBreakPattern(TestMKLDNNMatmulFuseOp):
def init_data(self):
self.bs = 7
self.d_type = np.float32
self.shape_x = [12, 128, 128]
self.shape_y = [12, 128, 64]
self.enable_mkldnn = True
def make_network(self):
with fluid.program_guard(self.main_program, self.startup_program):
x = fluid.data(
name='x', shape=[-1] + self.shape_x, dtype=self.d_type)
y = fluid.data(
name='y', shape=[-1] + self.shape_y, dtype=self.d_type)
out = fluid.layers.matmul(x, y)
out = fluid.layers.transpose(out, perm=[0, 2, 1, 3])
out = fluid.layers.transpose(
out, perm=[0, 1, 2, 3]) # breaks pattern
out = fluid.layers.reshape(
out, [0, 0, self.shape_y[0] * self.shape_y[2]])
out = fluid.layers.fc(out, size=1)
return out
if __name__ == '__main__':
unittest.main()
...@@ -161,5 +161,134 @@ class TestDnnlMatMulOpInt8ForceFP32BasicScales(TestDnnlMatMulOp): ...@@ -161,5 +161,134 @@ class TestDnnlMatMulOpInt8ForceFP32BasicScales(TestDnnlMatMulOp):
self.attrs = {'force_fp32_output': True} self.attrs = {'force_fp32_output': True}
@skip_check_grad_ci(reason="Tests inference only optimization.")
class TestMatMulOpTransposeReshapeEmptyFloat(OpTest):
def init_data_type(self):
self.data_type_ = np.float32
def generate_data(self):
self.bs = 1
self.x = np.random.random([self.bs, 128, 128]).astype(self.data_type_)
self.y = np.random.random([self.bs, 128, 64]).astype(self.data_type_)
def init_params_and_out(self):
self.transpose_out = []
self.reshape_out = []
self.out = np.matmul(self.x, self.y)
def setUp(self):
os.environ["DNNL_MAX_CPU_ISA"] = "AVX"
self.op_type = "matmul"
self._cpu_only = True
self.use_mkldnn = True
self.init_data_type()
self.generate_data()
self.init_params_and_out()
self.inputs = {'X': self.x, 'Y': self.y}
self.attrs = {'use_mkldnn': self.use_mkldnn}
if len(self.reshape_out) > 0:
self.attrs['fused_reshape_Out'] = self.reshape_out
if len(self.transpose_out) > 0:
self.attrs['fused_transpose_Out'] = self.transpose_out
self.inputs = {'X': self.x, 'Y': self.y}
self.outputs = {'Out': self.out}
def test_check_output(self):
self.check_output()
def check_raise_error(self, msg):
try:
self.check_output()
except Exception as e:
if msg in str(e):
raise AttributeError
else:
print(e)
class TestMatMulOpTransposeReshapeIntEmptyInt(
TestMatMulOpTransposeReshapeEmptyFloat):
def init_data_type(self):
self.data_type_ = np.int8
class TestMatMulOpTransposeReshapeBasicFloat(
TestMatMulOpTransposeReshapeEmptyFloat):
def generate_data(self):
self.bs = 8
self.x = np.random.random(
[self.bs, 12, 128, 128]).astype(self.data_type_)
self.y = np.random.random(
[self.bs, 12, 128, 64]).astype(self.data_type_)
def init_params_and_out(self):
self.transpose_out = [0, 2, 1, 3]
self.reshape_out = [0, 0, self.x.shape[1] * self.y.shape[-1]]
self.out = np.matmul(self.x, self.y).transpose([0, 2, 1, 3]).reshape(
[self.bs, -1, self.x.shape[1] * self.y.shape[-1]])
class TestMatMulOpTransposeReshapeBasicInt(
TestMatMulOpTransposeReshapeBasicFloat):
def init_data_type(self):
self.data_type_ = np.int8
class TestMatMulOpTransposeReshapeOtherDimFloat(
TestMatMulOpTransposeReshapeBasicFloat):
def generate_data(self):
self.bs = 11
self.x = np.random.random([self.bs, 12, 14, 18]).astype(self.data_type_)
self.y = np.random.random([self.bs, 12, 18, 13]).astype(self.data_type_)
class TestMatMulOpTransposeReshapeOtherDimInt(
TestMatMulOpTransposeReshapeOtherDimFloat):
def init_data_type(self):
self.data_type_ = np.int8
class TestMatMulOpTransposeReshapeTransposeAxisNotSupportedException(
TestMatMulOpTransposeReshapeBasicFloat):
def init_params_and_out(self):
self.transpose_out = [0, 1, 2, 3]
self.reshape_out = [0, 0, self.x.shape[1] * self.y.shape[-1]]
self.out = np.matmul(self.x, self.y)
def test_check_output(self):
self.assertRaises(AttributeError, self.check_raise_error,
'InvalidArgumentError: supported transpose axis '
'for the fuse are {0, 2, 1, 3}')
class TestMatMulOpTransposeReshapeTransposeRankNotSupportedException(
TestMatMulOpTransposeReshapeBasicFloat):
def init_params_and_out(self):
self.transpose_out = [0, 2, 1]
self.reshape_out = [0, 0, self.x.shape[1] * self.y.shape[-1]]
self.out = np.matmul(self.x, self.y)
def test_check_output(self):
self.assertRaises(
AttributeError, self.check_raise_error,
'InvalidArgumentError: transpose_out supported rank is 4')
class TestMatMulOpTransposeReshapeRankOfReshapeNotSupportedException(
TestMatMulOpTransposeReshapeBasicFloat):
def init_params_and_out(self):
self.transpose_out = [0, 2, 1, 3]
self.reshape_out = [0, 0]
self.out = np.matmul(self.x, self.y)
def test_check_output(self):
self.assertRaises(
AttributeError, self.check_raise_error,
'InvalidArgumentError: reshape_out supported rank is 3')
if __name__ == "__main__": if __name__ == "__main__":
unittest.main() unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册