未验证 提交 fa874a46 编写于 作者: H Hui Zhang 提交者: GitHub

suqeeze2 + transpose2 fuse onednn (#47592)

* suqeeze2 transpose2 fuse onednn

* format

* fix output shape

* fix conflict

* format

* format

* remove useless

* remove log

* simply pass

* fix comment

* fix

* fix msg

* fix error msg

* format
上级 45bc4542
...@@ -177,6 +177,7 @@ if(WITH_MKLDNN) ...@@ -177,6 +177,7 @@ if(WITH_MKLDNN)
pass_library(matmul_elementwise_add_mkldnn_fuse_pass inference DIR mkldnn) pass_library(matmul_elementwise_add_mkldnn_fuse_pass inference DIR mkldnn)
pass_library(matmul_activation_mkldnn_fuse_pass inference DIR mkldnn) pass_library(matmul_activation_mkldnn_fuse_pass inference DIR mkldnn)
pass_library(operator_scale_onednn_fuse_pass inference DIR mkldnn) pass_library(operator_scale_onednn_fuse_pass inference DIR mkldnn)
pass_library(squeeze2_transpose2_onednn_fuse_pass inference DIR mkldnn)
pass_library(operator_unsqueeze2_onednn_fuse_pass inference DIR mkldnn) pass_library(operator_unsqueeze2_onednn_fuse_pass inference DIR mkldnn)
pass_library(operator_reshape2_onednn_fuse_pass inference DIR mkldnn) pass_library(operator_reshape2_onednn_fuse_pass inference DIR mkldnn)
pass_library(cpu_quantize_placement_pass base DIR mkldnn) pass_library(cpu_quantize_placement_pass base DIR mkldnn)
......
...@@ -958,6 +958,25 @@ PDNode *patterns::OperatorActivation::operator()( ...@@ -958,6 +958,25 @@ PDNode *patterns::OperatorActivation::operator()(
return activation_out; return activation_out;
} }
PDNode *patterns::Squeeze2Transpose2::operator()() {
auto *squeeze2_op_in = pattern->NewNode(squeeze2_op_in_repr())
->AsInput()
->assert_is_op_input("squeeze2", "X");
auto *squeeze2_op = pattern->NewNode(squeeze2_op_repr())
->assert_is_op("squeeze2")
->assert_has_n_outputs(2);
auto *squeeze2_op_out = pattern->NewNode(squeeze2_op_out_repr())
->AsIntermediate()
->assert_is_op_output("squeeze2", "Out")
->assert_is_op_input("transpose2", "X");
auto *transpose2_op =
pattern->NewNode(transpose2_op_repr())->assert_is_op("transpose2");
squeeze2_op->LinksFrom({squeeze2_op_in}).LinksTo({squeeze2_op_out});
transpose2_op->LinksFrom({squeeze2_op_out});
return transpose2_op;
}
PDNode *patterns::OperatorUnsqueeze2::operator()( PDNode *patterns::OperatorUnsqueeze2::operator()(
const std::string &operator_type, const int num_of_operator_outs) { const std::string &operator_type, const int num_of_operator_outs) {
auto *preceding_op = pattern->NewNode(preceding_op_repr()) auto *preceding_op = pattern->NewNode(preceding_op_repr())
......
...@@ -539,6 +539,18 @@ struct OperatorActivation : public PatternBase { ...@@ -539,6 +539,18 @@ struct OperatorActivation : public PatternBase {
PATTERN_DECL_NODE(activation_out); PATTERN_DECL_NODE(activation_out);
}; };
struct Squeeze2Transpose2 : public PatternBase {
Squeeze2Transpose2(PDPattern* pattern, const std::string& name_scope)
: PatternBase(pattern, name_scope, "squeeze2_transpose2") {}
PDNode* operator()();
PATTERN_DECL_NODE(squeeze2_op_in);
PATTERN_DECL_NODE(squeeze2_op);
PATTERN_DECL_NODE(squeeze2_op_out);
PATTERN_DECL_NODE(transpose2_op);
};
struct OperatorUnsqueeze2 : public PatternBase { struct OperatorUnsqueeze2 : public PatternBase {
OperatorUnsqueeze2(PDPattern* pattern, const std::string& name_scope) OperatorUnsqueeze2(PDPattern* pattern, const std::string& name_scope)
: PatternBase(pattern, name_scope, "operator_unsqueeze2") {} : PatternBase(pattern, name_scope, "operator_unsqueeze2") {}
...@@ -2030,6 +2042,12 @@ struct AddSupportInt8 : public PatternBase { ...@@ -2030,6 +2042,12 @@ struct AddSupportInt8 : public PatternBase {
out_var->inputs.clear(); \ out_var->inputs.clear(); \
out_var->inputs.push_back(op); out_var->inputs.push_back(op);
// Set the in_var as the input of the op
#define IR_VAR_OP_LINK(in_var, op) \
in_var->outputs.clear(); \
in_var->outputs.push_back(op); \
op->inputs.push_back(in_var);
} // namespace ir } // namespace ir
} // namespace framework } // namespace framework
} // namespace paddle } // namespace paddle
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/ir/mkldnn/squeeze2_transpose2_onednn_fuse_pass.h"
#include "paddle/fluid/framework/op_version_registry.h"
#include "paddle/fluid/platform/mkldnn_reuse.h"
#include "paddle/fluid/string/pretty_log.h"
namespace paddle {
namespace framework {
namespace ir {
using string::PrettyLogDetail;
void FuseSqueeze2Transpose2OneDNNPass::ApplyImpl(Graph *graph) const {
PADDLE_ENFORCE_NOT_NULL(
graph,
platform::errors::InvalidArgument(
"Input graph pointer argument should not be nullptr."));
FusePassBase::Init("squeeze2_transpose2_onednn_fuse_pass", graph);
GraphPatternDetector gpd;
patterns::Squeeze2Transpose2 squeeze2_transpose2_pattern(
gpd.mutable_pattern(), "squeeze2_transpose2_onednn_fuse_pass");
squeeze2_transpose2_pattern();
int found_count = 0;
auto handler = [&](const GraphPatternDetector::subgraph_t &subgraph,
Graph *g) {
GET_IR_NODE_FROM_SUBGRAPH(
squeeze2_op_in, squeeze2_op_in, squeeze2_transpose2_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
squeeze2_op, squeeze2_op, squeeze2_transpose2_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
squeeze2_op_out, squeeze2_op_out, squeeze2_transpose2_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
transpose2_op, transpose2_op, squeeze2_transpose2_pattern);
if (!transpose2_op->Op()->HasAttr("use_mkldnn") ||
(transpose2_op->Op()->HasAttr("use_mkldnn") &&
!(PADDLE_GET_CONST(bool,
transpose2_op->Op()->GetAttr("use_mkldnn"))))) {
VLOG(4) << "Only oneDNN version of transpose2 can be fused after with "
"squeeze2.";
return;
}
std::vector<int> squeeze2_axes =
PADDLE_GET_CONST(std::vector<int>, squeeze2_op->Op()->GetAttr("axes"));
transpose2_op->Op()->SetAttr("fused_squeeze2_axes", squeeze2_axes);
transpose2_op->Op()->SetInput("X", {squeeze2_op_in->Name()});
IR_VAR_OP_LINK(squeeze2_op_in, transpose2_op);
GraphSafeRemoveNodes(g, {squeeze2_op, squeeze2_op_out});
found_count++;
};
gpd(graph, handler);
AddStatis(found_count);
if ((!Has("disable_logs") || !Get<bool>("disable_logs"))) {
PrettyLogDetail("--- fused %d squeeze2 with transpose2", found_count);
}
}
} // namespace ir
} // namespace framework
} // namespace paddle
REGISTER_PASS(squeeze2_transpose2_onednn_fuse_pass,
paddle::framework::ir::FuseSqueeze2Transpose2OneDNNPass);
REGISTER_PASS_CAPABILITY(squeeze2_transpose2_onednn_fuse_pass)
.AddCombination(
paddle::framework::compatible::OpVersionComparatorCombination()
.GE("squeeze2", 0)
.GE("transpose2", 0));
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/fluid/framework/ir/fuse_pass_base.h"
#include "paddle/fluid/framework/ir/graph.h"
namespace paddle {
namespace framework {
namespace ir {
class FuseSqueeze2Transpose2OneDNNPass : public FusePassBase {
public:
virtual ~FuseSqueeze2Transpose2OneDNNPass() {}
protected:
void ApplyImpl(Graph *graph) const override;
};
} // namespace ir
} // namespace framework
} // namespace paddle
...@@ -308,6 +308,7 @@ void CpuPassStrategy::EnableMKLDNN() { ...@@ -308,6 +308,7 @@ void CpuPassStrategy::EnableMKLDNN() {
passes_.insert(passes_.begin(), "mkldnn_placement_pass"); passes_.insert(passes_.begin(), "mkldnn_placement_pass");
for (auto &pass : std::vector<std::string>({ for (auto &pass : std::vector<std::string>({
"squeeze2_transpose2_onednn_fuse_pass",
"depthwise_conv_mkldnn_pass", // "depthwise_conv_mkldnn_pass", //
"conv_bn_fuse_pass", // Execute BN passes again to "conv_bn_fuse_pass", // Execute BN passes again to
"conv_eltwiseadd_bn_fuse_pass", // preserve correct pass order "conv_eltwiseadd_bn_fuse_pass", // preserve correct pass order
...@@ -387,6 +388,7 @@ void CpuPassStrategy::EnableMkldnnInt8() { ...@@ -387,6 +388,7 @@ void CpuPassStrategy::EnableMkldnnInt8() {
passes_.push_back("mkldnn_placement_pass"); passes_.push_back("mkldnn_placement_pass");
passes_.push_back("simplify_with_basic_ops_pass"); passes_.push_back("simplify_with_basic_ops_pass");
passes_.push_back("constant_folding_pass"); passes_.push_back("constant_folding_pass");
passes_.push_back("squeeze2_transpose2_onednn_fuse_pass");
passes_.push_back("layer_norm_fuse_pass"); passes_.push_back("layer_norm_fuse_pass");
passes_.push_back("attention_lstm_fuse_pass"); passes_.push_back("attention_lstm_fuse_pass");
passes_.push_back("seqconv_eltadd_relu_fuse_pass"); passes_.push_back("seqconv_eltadd_relu_fuse_pass");
......
...@@ -37,11 +37,14 @@ class TransposeMKLDNNOpKernel : public paddle::framework::OpKernel<T> { ...@@ -37,11 +37,14 @@ class TransposeMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
const auto& dnnl_engine = dev_ctx.GetEngine(); const auto& dnnl_engine = dev_ctx.GetEngine();
std::vector<int> transpose_axis = ctx.Attr<std::vector<int>>("axis"); std::vector<int> transpose_axis = ctx.Attr<std::vector<int>>("axis");
int ndims = transpose_axis.size(); int ndims = transpose_axis.size();
auto* x = ctx.Input<Tensor>("X"); const phi::DenseTensor* x = ctx.Input<Tensor>("X");
auto* out = ctx.Output<Tensor>("Out"); auto* out = ctx.Output<Tensor>("Out");
auto& astream = platform::MKLDNNDeviceContext::tls().get_stream(); auto& astream = platform::MKLDNNDeviceContext::tls().get_stream();
platform::SetInMemDescWithLogicalLayoutFusesSupport(
ctx, const_cast<phi::DenseTensor*>(x), x->mem_desc());
if (ndims == 1) { if (ndims == 1) {
framework::TensorCopy(*x, x->place(), out); framework::TensorCopy(*x, x->place(), out);
out->set_mem_desc(x->mem_desc()); out->set_mem_desc(x->mem_desc());
......
...@@ -34,14 +34,17 @@ class TransposeOp : public framework::OperatorWithKernel { ...@@ -34,14 +34,17 @@ class TransposeOp : public framework::OperatorWithKernel {
OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "Transpose"); OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "Transpose");
auto x_dims = ctx->GetInputDim("X"); auto x_dims = ctx->GetInputDim("X");
std::vector<int> axis = ctx->Attrs().Get<std::vector<int>>("axis"); std::vector<int> axis = ctx->Attrs().Get<std::vector<int>>("axis");
size_t x_rank = x_dims.size(); size_t x_rank = x_dims.size();
size_t axis_size = axis.size(); size_t axis_size = axis.size();
PADDLE_ENFORCE_EQ(x_rank, // Note: x_rank > axis_size when fuse squeeze2 + transpose2, else x_rank ==
// axis_size
PADDLE_ENFORCE_GE(x_rank,
axis_size, axis_size,
platform::errors::InvalidArgument( platform::errors::InvalidArgument(
"The input tensor's dimension " "The input tensor's dimension "
"should be equal to the axis's size. " "should be equal to or greater than the axis's size. "
"But received input tensor's dimension is %d, " "But received input tensor's dimension is %d, "
"axis's size is %d", "axis's size is %d",
x_rank, x_rank,
......
...@@ -167,11 +167,58 @@ static void SetOutMemDescWithLogicalLayoutFusesSupport( ...@@ -167,11 +167,58 @@ static void SetOutMemDescWithLogicalLayoutFusesSupport(
SetOutMemDescWithUnsqueeze2FuseSupport(ctx, out, out_md); SetOutMemDescWithUnsqueeze2FuseSupport(ctx, out, out_md);
} else if (ctx.HasAttr("fused_reshape2_shape")) { } else if (ctx.HasAttr("fused_reshape2_shape")) {
SetOutMemDescWithReshape2FuseSupport(ctx, out, out_md); SetOutMemDescWithReshape2FuseSupport(ctx, out, out_md);
} else if (ctx.HasAttr("fused_squeeze2_axes")) {
out->set_mem_desc(out_md);
out->Resize(phi::make_ddim(out_md.dims()));
} else { } else {
out->set_mem_desc(out_md); out->set_mem_desc(out_md);
} }
} }
static void SetInMemDescWithSqueeze2FuseSupport(
const framework::ExecutionContext& ctx,
phi::DenseTensor* in,
const dnnl::memory::desc& in_md) {
const std::vector<int> fused_squeeze2_axes =
ctx.Attr<std::vector<int>>("fused_squeeze2_axes");
const std::set<int64_t> squeeze2_axes_set(fused_squeeze2_axes.begin(),
fused_squeeze2_axes.end());
const std::vector<int64_t>& x_vec_dims = in_md.dims();
std::vector<int64_t> squeezed_op_tz(
x_vec_dims.size() - fused_squeeze2_axes.size(), 0);
int j = 0;
for (size_t i = 0; i < x_vec_dims.size(); ++i) {
if (squeeze2_axes_set.count(i) ||
squeeze2_axes_set.count(i - x_vec_dims.size())) {
PADDLE_ENFORCE_EQ(
x_vec_dims[i],
1,
platform::errors::InvalidArgument(
"Squeeze2 input dim %d should be equal to one, but get %d.",
i,
x_vec_dims[i]));
continue;
}
squeezed_op_tz[j++] = x_vec_dims[i];
}
in->set_mem_desc(in_md.reshape(squeezed_op_tz));
in->Resize(phi::make_ddim(squeezed_op_tz));
}
static void SetInMemDescWithLogicalLayoutFusesSupport(
const framework::ExecutionContext& ctx,
phi::DenseTensor* in,
const dnnl::memory::desc& in_md) {
if (ctx.HasAttr("fused_squeeze2_axes")) {
SetInMemDescWithSqueeze2FuseSupport(ctx, in, in_md);
} else {
in->set_mem_desc(in_md);
in->Resize(phi::make_ddim(in_md.dims()));
}
}
template <typename T> template <typename T>
constexpr bool IsInt8() { constexpr bool IsInt8() {
return std::is_same<T, int8_t>::value || std::is_same<T, uint8_t>::value; return std::is_same<T, int8_t>::value || std::is_same<T, uint8_t>::value;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册