未验证 提交 ea5f44b8 编写于 作者: H Hui Zhang 提交者: GitHub

[cherry-pick] Squeeze2 and transpose2 fuse using oneDNN(#47712)

* suqeeze2 + transpose2 fuse onednn cherrypick 2.4

* format

* fix merge
上级 34f67a88
......@@ -219,6 +219,7 @@ if(WITH_MKLDNN)
pass_library(matmul_elementwise_add_mkldnn_fuse_pass inference DIR mkldnn)
pass_library(matmul_activation_mkldnn_fuse_pass inference DIR mkldnn)
pass_library(operator_scale_onednn_fuse_pass inference DIR mkldnn)
pass_library(squeeze2_transpose2_onednn_fuse_pass inference DIR mkldnn)
pass_library(operator_unsqueeze2_onednn_fuse_pass inference DIR mkldnn)
pass_library(operator_reshape2_onednn_fuse_pass inference DIR mkldnn)
pass_library(cpu_quantize_placement_pass base DIR mkldnn)
......
......@@ -1042,6 +1042,25 @@ PDNode *patterns::SeqConvEltAddRelu::operator()(
return relu_out_var;
}
PDNode *patterns::Squeeze2Transpose2::operator()() {
auto *squeeze2_op_in = pattern->NewNode(squeeze2_op_in_repr())
->AsInput()
->assert_is_op_input("squeeze2", "X");
auto *squeeze2_op = pattern->NewNode(squeeze2_op_repr())
->assert_is_op("squeeze2")
->assert_has_n_outputs(2);
auto *squeeze2_op_out = pattern->NewNode(squeeze2_op_out_repr())
->AsIntermediate()
->assert_is_op_output("squeeze2", "Out")
->assert_is_op_input("transpose2", "X");
auto *transpose2_op =
pattern->NewNode(transpose2_op_repr())->assert_is_op("transpose2");
squeeze2_op->LinksFrom({squeeze2_op_in}).LinksTo({squeeze2_op_out});
transpose2_op->LinksFrom({squeeze2_op_out});
return transpose2_op;
}
PDNode *patterns::FC::operator()(paddle::framework::ir::PDNode *x,
bool with_bias,
bool with_relu) {
......
......@@ -634,6 +634,20 @@ struct FCMKLDNN : public PatternBase {
PATTERN_DECL_NODE(output);
};
// Squeeze2 + Transpose2
// Forward pass
struct Squeeze2Transpose2 : public PatternBase {
Squeeze2Transpose2(PDPattern* pattern, const std::string& name_scope)
: PatternBase(pattern, name_scope, "squeeze2_transpose2") {}
PDNode* operator()();
PATTERN_DECL_NODE(squeeze2_op_in);
PATTERN_DECL_NODE(squeeze2_op);
PATTERN_DECL_NODE(squeeze2_op_out);
PATTERN_DECL_NODE(transpose2_op);
};
// Embedding
struct Embedding : public PatternBase {
Embedding(PDPattern* pattern, const std::string& name_scope)
......@@ -2002,6 +2016,12 @@ struct AddSupportInt8 : public PatternBase {
out_var->inputs.clear(); \
out_var->inputs.push_back(op);
// Set the in_var as the input of the op
#define IR_VAR_OP_LINK(in_var, op) \
in_var->outputs.clear(); \
in_var->outputs.push_back(op); \
op->inputs.push_back(in_var);
} // namespace ir
} // namespace framework
} // namespace paddle
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/ir/mkldnn/squeeze2_transpose2_onednn_fuse_pass.h"
#include "paddle/fluid/framework/op_version_registry.h"
#include "paddle/fluid/platform/mkldnn_reuse.h"
#include "paddle/fluid/string/pretty_log.h"
namespace paddle {
namespace framework {
namespace ir {
using string::PrettyLogDetail;
void FuseSqueeze2Transpose2OneDNNPass::ApplyImpl(Graph *graph) const {
PADDLE_ENFORCE_NOT_NULL(graph,
platform::errors::InvalidArgument(
"Pointer to graph argument should not be NULL."));
FusePassBase::Init("squeeze2_transpose2_onednn_fuse_pass", graph);
GraphPatternDetector gpd;
patterns::Squeeze2Transpose2 squeeze2_transpose2_pattern(
gpd.mutable_pattern(), "squeeze2_transpose2_onednn_fuse_pass");
squeeze2_transpose2_pattern();
int found_count = 0;
auto handler = [&](const GraphPatternDetector::subgraph_t &subgraph,
Graph *g) {
GET_IR_NODE_FROM_SUBGRAPH(
squeeze2_op_in, squeeze2_op_in, squeeze2_transpose2_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
squeeze2_op, squeeze2_op, squeeze2_transpose2_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
squeeze2_op_out, squeeze2_op_out, squeeze2_transpose2_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
transpose2_op, transpose2_op, squeeze2_transpose2_pattern);
if (!transpose2_op->Op()->HasAttr("use_mkldnn") ||
(transpose2_op->Op()->HasAttr("use_mkldnn") &&
!(PADDLE_GET_CONST(bool,
transpose2_op->Op()->GetAttr("use_mkldnn"))))) {
VLOG(4) << "Only oneDNN version of transpose2 can be fused after with "
"squeeze2.";
return;
}
std::vector<int> squeeze2_axes =
PADDLE_GET_CONST(std::vector<int>, squeeze2_op->Op()->GetAttr("axes"));
transpose2_op->Op()->SetAttr("fused_squeeze2_axes", squeeze2_axes);
transpose2_op->Op()->SetInput("X", {squeeze2_op_in->Name()});
IR_VAR_OP_LINK(squeeze2_op_in, transpose2_op);
GraphSafeRemoveNodes(g, {squeeze2_op, squeeze2_op_out});
found_count++;
};
gpd(graph, handler);
AddStatis(found_count);
if ((!Has("disable_logs") || !Get<bool>("disable_logs"))) {
PrettyLogDetail("--- fused %d squeeze2 with transpose2", found_count);
}
}
} // namespace ir
} // namespace framework
} // namespace paddle
REGISTER_PASS(squeeze2_transpose2_onednn_fuse_pass,
paddle::framework::ir::FuseSqueeze2Transpose2OneDNNPass);
REGISTER_PASS_CAPABILITY(squeeze2_transpose2_onednn_fuse_pass)
.AddCombination(
paddle::framework::compatible::OpVersionComparatorCombination()
.GE("squeeze2", 0)
.GE("transpose2", 0));
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/fluid/framework/ir/fuse_pass_base.h"
#include "paddle/fluid/framework/ir/graph.h"
namespace paddle {
namespace framework {
namespace ir {
class FuseSqueeze2Transpose2OneDNNPass : public FusePassBase {
public:
virtual ~FuseSqueeze2Transpose2OneDNNPass() {}
protected:
void ApplyImpl(Graph *graph) const override;
};
} // namespace ir
} // namespace framework
} // namespace paddle
......@@ -307,6 +307,7 @@ void CpuPassStrategy::EnableMKLDNN() {
passes_.insert(passes_.begin(), "mkldnn_placement_pass");
for (auto &pass : std::vector<std::string>({
"squeeze2_transpose2_onednn_fuse_pass",
"depthwise_conv_mkldnn_pass", //
"conv_bn_fuse_pass", // Execute BN passes again to
"conv_eltwiseadd_bn_fuse_pass", // preserve correct pass order
......@@ -386,6 +387,7 @@ void CpuPassStrategy::EnableMkldnnInt8() {
passes_.push_back("mkldnn_placement_pass");
passes_.push_back("simplify_with_basic_ops_pass");
passes_.push_back("constant_folding_pass");
passes_.push_back("squeeze2_transpose2_onednn_fuse_pass");
passes_.push_back("layer_norm_fuse_pass");
passes_.push_back("attention_lstm_fuse_pass");
passes_.push_back("seqconv_eltadd_relu_fuse_pass");
......
......@@ -42,6 +42,9 @@ class TransposeMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
auto& astream = platform::MKLDNNDeviceContext::tls().get_stream();
platform::SetInMemDescWithLogicalLayoutFusesSupport(
ctx, const_cast<phi::DenseTensor*>(x), x->mem_desc());
if (ndims == 1) {
framework::TensorCopy(*x, x->place(), out);
out->set_mem_desc(x->mem_desc());
......
......@@ -39,11 +39,13 @@ class TransposeOp : public framework::OperatorWithKernel {
size_t x_rank = x_dims.size();
size_t axis_size = axis.size();
PADDLE_ENFORCE_EQ(x_rank,
// Note: x_rank > axis_size when fuse squeeze2 + transpose2, else x_rank ==
// axis_size
PADDLE_ENFORCE_GE(x_rank,
axis_size,
platform::errors::InvalidArgument(
"The input tensor's dimension "
"should be equal to the axis's size. "
"should be equal to or greater than the axis's size. "
"But received input tensor's dimension is %d, "
"axis's size is %d",
x_rank,
......
......@@ -120,12 +120,10 @@ static void SetOutMemDescWithUnsqueeze2FuseSupport(
const std::vector<int64_t>& op_tz = out_md.dims();
std::vector<int64_t> unsqueezed_op_tz(
op_tz.size() + fused_unsqueeze2_axes.size(), 0);
for (const auto& axis : fused_unsqueeze2_axes) {
int positive_axis = axis < 0 ? unsqueezed_op_tz.size() + axis : axis;
unsqueezed_op_tz[positive_axis] = 1;
}
int j = 0;
for (size_t i = 0; i < unsqueezed_op_tz.size(); ++i) {
if (unsqueezed_op_tz[i] == 0) {
......@@ -143,20 +141,17 @@ static void SetOutMemDescWithReshape2FuseSupport(
std::vector<int64_t> fused_reshape2_shape(
ctx.Attr<std::vector<int>>("fused_reshape2_shape").begin(),
ctx.Attr<std::vector<int>>("fused_reshape2_shape").end());
const int out_shape_numel = out->numel();
const int new_shape_numel = std::accumulate(fused_reshape2_shape.begin(),
fused_reshape2_shape.end(),
1,
std::multiplies<int64_t>());
for (size_t i = 0; i < fused_reshape2_shape.size(); ++i) {
if (fused_reshape2_shape[i] == -1) {
fused_reshape2_shape[i] = -out_shape_numel / new_shape_numel;
break;
}
}
out->set_mem_desc(out_md.reshape(fused_reshape2_shape));
out->Resize(phi::make_ddim(fused_reshape2_shape));
}
......@@ -169,11 +164,58 @@ static void SetOutMemDescWithLogicalLayoutFusesSupport(
SetOutMemDescWithUnsqueeze2FuseSupport(ctx, out, out_md);
} else if (ctx.HasAttr("fused_reshape2_shape")) {
SetOutMemDescWithReshape2FuseSupport(ctx, out, out_md);
} else if (ctx.HasAttr("fused_squeeze2_axes")) {
out->set_mem_desc(out_md);
out->Resize(phi::make_ddim(out_md.dims()));
} else {
out->set_mem_desc(out_md);
}
}
static void SetInMemDescWithSqueeze2FuseSupport(
const framework::ExecutionContext& ctx,
phi::DenseTensor* in,
const dnnl::memory::desc& in_md) {
const std::vector<int> fused_squeeze2_axes =
ctx.Attr<std::vector<int>>("fused_squeeze2_axes");
const std::set<int64_t> squeeze2_axes_set(fused_squeeze2_axes.begin(),
fused_squeeze2_axes.end());
const std::vector<int64_t>& x_vec_dims = in_md.dims();
std::vector<int64_t> squeezed_op_tz(
x_vec_dims.size() - fused_squeeze2_axes.size(), 0);
int j = 0;
for (size_t i = 0; i < x_vec_dims.size(); ++i) {
if (squeeze2_axes_set.count(i) ||
squeeze2_axes_set.count(i - x_vec_dims.size())) {
PADDLE_ENFORCE_EQ(
x_vec_dims[i],
1,
platform::errors::InvalidArgument(
"Squeeze2 input '%d' dim should be equal to one, but get '%d'.",
i,
x_vec_dims[i]));
continue;
}
squeezed_op_tz[j++] = x_vec_dims[i];
}
in->set_mem_desc(in_md.reshape(squeezed_op_tz));
in->Resize(phi::make_ddim(squeezed_op_tz));
}
static void SetInMemDescWithLogicalLayoutFusesSupport(
const framework::ExecutionContext& ctx,
phi::DenseTensor* in,
const dnnl::memory::desc& in_md) {
if (ctx.HasAttr("fused_squeeze2_axes")) {
SetInMemDescWithSqueeze2FuseSupport(ctx, in, in_md);
} else {
in->set_mem_desc(in_md);
in->Resize(phi::make_ddim(in_md.dims()));
}
}
template <typename T>
constexpr bool IsInt8() {
return std::is_same<T, int8_t>::value || std::is_same<T, uint8_t>::value;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册