Unverified · Commit d0e19af3 authored by J jakpiase, committed by GitHub

[CHERRY-PICK] Added caching to oneDNN FC and op+unsqueeze2 and op+reshape2 fuse passes (#47690)

* fc cherrypick

* more files added

* added transpose cherrypick

* reverted somebody's fc changes

* minor fix

* minor fix

* cherry-pick of fc+act changes

* minor fix

* fix
Parent: cf668ab3
@@ -219,6 +219,8 @@ if(WITH_MKLDNN)
   pass_library(matmul_elementwise_add_mkldnn_fuse_pass inference DIR mkldnn)
   pass_library(matmul_activation_mkldnn_fuse_pass inference DIR mkldnn)
   pass_library(operator_scale_onednn_fuse_pass inference DIR mkldnn)
+  pass_library(operator_unsqueeze2_onednn_fuse_pass inference DIR mkldnn)
+  pass_library(operator_reshape2_onednn_fuse_pass inference DIR mkldnn)
   pass_library(cpu_quantize_placement_pass base DIR mkldnn)
   pass_library(cpu_quantize_pass inference DIR mkldnn)
   pass_library(cpu_quantize_squash_pass inference DIR mkldnn)
......
@@ -958,6 +958,44 @@ PDNode *patterns::OperatorActivation::operator()(
   return activation_out;
 }

+PDNode *patterns::OperatorUnsqueeze2::operator()(
+    const std::string &operator_type, const int num_of_operator_outs) {
+  auto *preceding_op = pattern->NewNode(preceding_op_repr())
+                           ->assert_is_op(operator_type)
+                           ->assert_has_n_outputs(num_of_operator_outs);
+  auto *preceding_op_out = pattern->NewNode(preceding_op_out_repr())
+                               ->AsIntermediate()
+                               ->assert_is_op_output(operator_type, "Out")
+                               ->assert_is_op_input("unsqueeze2");
+  auto *unsqueeze2_op =
+      pattern->NewNode(unsqueeze2_op_repr())->assert_is_op("unsqueeze2");
+  auto *unsqueeze2_out = pattern->NewNode(unsqueeze2_out_repr())
+                             ->AsOutput()
+                             ->assert_is_op_output("unsqueeze2");
+  preceding_op->LinksTo({preceding_op_out});
+  unsqueeze2_op->LinksFrom({preceding_op_out}).LinksTo({unsqueeze2_out});
+  return unsqueeze2_out;
+}
+
+PDNode *patterns::OperatorReshape2::operator()(const std::string &operator_type,
+                                               const int num_of_operator_outs) {
+  auto *preceding_op = pattern->NewNode(preceding_op_repr())
+                           ->assert_is_op(operator_type)
+                           ->assert_has_n_outputs(num_of_operator_outs);
+  auto *preceding_op_out = pattern->NewNode(preceding_op_out_repr())
+                               ->AsIntermediate()
+                               ->assert_is_op_output(operator_type, "Out")
+                               ->assert_is_op_input("reshape2");
+  auto *reshape2_op =
+      pattern->NewNode(reshape2_op_repr())->assert_is_op("reshape2");
+  auto *reshape2_out = pattern->NewNode(reshape2_out_repr())
+                           ->AsOutput()
+                           ->assert_is_op_output("reshape2");
+  preceding_op->LinksTo({preceding_op_out});
+  reshape2_op->LinksFrom({preceding_op_out}).LinksTo({reshape2_out});
+  return reshape2_out;
+}
+
 PDNode *patterns::SeqConvEltAddRelu::operator()(
     paddle::framework::ir::PDNode *seqconv_input) {
   // Create Operators
......
@@ -539,6 +539,32 @@ struct OperatorActivation : public PatternBase {
   PATTERN_DECL_NODE(activation_out);
 };

+struct OperatorUnsqueeze2 : public PatternBase {
+  OperatorUnsqueeze2(PDPattern* pattern, const std::string& name_scope)
+      : PatternBase(pattern, name_scope, "operator_unsqueeze2") {}
+
+  PDNode* operator()(const std::string& operator_type,
+                     const int num_of_outputs);
+
+  PATTERN_DECL_NODE(preceding_op);
+  PATTERN_DECL_NODE(preceding_op_out);
+  PATTERN_DECL_NODE(unsqueeze2_op);
+  PATTERN_DECL_NODE(unsqueeze2_out);
+};
+
+struct OperatorReshape2 : public PatternBase {
+  OperatorReshape2(PDPattern* pattern, const std::string& name_scope)
+      : PatternBase(pattern, name_scope, "operator_reshape2") {}
+
+  PDNode* operator()(const std::string& operator_type,
+                     const int num_of_outputs);
+
+  PATTERN_DECL_NODE(preceding_op);
+  PATTERN_DECL_NODE(preceding_op_out);
+  PATTERN_DECL_NODE(reshape2_op);
+  PATTERN_DECL_NODE(reshape2_out);
+};
+
 // SEQCONV with Elementwise_Add ReLU
 // op: seqconv + elementwise_add + relu
 // named nodes:
......
-// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
 //
 // Licensed under the Apache License, Version 2.0 (the "License");
 // you may not use this file except in compliance with the License.
@@ -14,9 +14,8 @@
 #include "paddle/fluid/framework/ir/mkldnn/fc_act_mkldnn_fuse_pass.h"

-#include "paddle/fluid/framework/ir/graph_pattern_detector.h"
 #include "paddle/fluid/framework/op_version_registry.h"
-#include "paddle/fluid/platform/enforce.h"
+#include "paddle/fluid/platform/mkldnn_reuse.h"
 #include "paddle/fluid/string/pretty_log.h"

 namespace paddle {
@@ -26,20 +25,20 @@ namespace ir {

 using string::PrettyLogDetail;

 void FuseFCActOneDNNPass::ApplyImpl(Graph *graph) const {
-  std::vector<std::string> act_types = {
-      "gelu", "tanh", "sigmoid", "mish", "hard_swish"};
+  auto act_types = paddle::platform::GetSupportedActivations();

-  for (std::string act_type : act_types) FuseFCAct(graph, act_type);
+  for (auto act_type : act_types) FuseFCAct(graph, act_type);
 }

 void FuseFCActOneDNNPass::FuseFCAct(Graph *graph,
                                     const std::string &act_type) const {
   PADDLE_ENFORCE_NOT_NULL(
       graph, platform::errors::InvalidArgument("Graph cannot be nullptr."));
-  FusePassBase::Init("fc_act", graph);
+  FusePassBase::Init("fc_" + act_type + "_mkldnn_fuse_pass", graph);

   GraphPatternDetector gpd;
-  patterns::OperatorActivation fc_act_pattern(gpd.mutable_pattern(), "fc_act");
+  patterns::OperatorActivation fc_act_pattern(
+      gpd.mutable_pattern(), "fc_" + act_type + "_mkldnn_fuse_pass");
   fc_act_pattern("fc", act_type);

   int found_fc_act_count = 0;
@@ -62,15 +61,23 @@ void FuseFCActOneDNNPass::FuseFCAct(Graph *graph,
                             "is used."));
     }

+    auto attr_map = paddle::platform::GetAttributeMap(act_type);
+    for (const auto &attr : attr_map) {
+      if (act_op->HasAttr(attr.first)) {
+        fc_op->SetAttr(attr.second, act_op->GetAttr(attr.first));
+      }
+    }
+
     if (act_type == "gelu" && act_op->HasAttr("approximate")) {
-      bool approximate = PADDLE_GET_CONST(bool, act_op->GetAttr("approximate"));
-      std::string type = approximate ? "_tanh" : "_erf";
-      fc_op->SetAttr("activation_type", act_type + type);
+      std::string gelu_act_type =
+          PADDLE_GET_CONST(bool, act_op->GetAttr("approximate")) ? "gelu_tanh"
+                                                                 : "gelu_erf";
+      fc_op->SetAttr("fuse_activation", gelu_act_type);
     } else {
-      fc_op->SetAttr("activation_type", act_type);
+      fc_op->SetAttr("fuse_activation", act_type);
     }
     fc_op->SetAttr("use_mkldnn", true);

     fc_op->SetOutput("Out", {act_out->Name()});

     IR_OP_VAR_LINK(fc, act_out);
@@ -80,7 +87,8 @@ void FuseFCActOneDNNPass::FuseFCAct(Graph *graph,
   gpd(graph, handler);
   AddStatis(found_fc_act_count);

-  if (!Has("disable_logs") || !Get<bool>("disable_logs"))
+  if ((!Has("disable_logs") || !Get<bool>("disable_logs")) &&
+      found_fc_act_count > 0)
     PrettyLogDetail(
         "--- fused %d fc with %s activation", found_fc_act_count, act_type);
 }
@@ -95,8 +103,16 @@ REGISTER_PASS_CAPABILITY(fc_act_mkldnn_fuse_pass)
     .AddCombination(
         paddle::framework::compatible::OpVersionComparatorCombination()
             .LE("fc", 0)
-            .LE("gelu", 0)
-            .LE("sigmoid", 0)
-            .LE("mish", 1)
+            .EQ("abs", 0)
+            .LE("clip", 1)
+            .EQ("gelu", 0)
+            .EQ("hard_sigmoid", 0)
             .LE("hard_swish", 0)
-            .LE("tanh", 0));
+            .LE("leaky_relu", 1)
+            .LE("mish", 1)
+            .EQ("relu", 0)
+            .EQ("relu6", 0)
+            .EQ("sigmoid", 0)
+            .EQ("sqrt", 0)
+            .EQ("swish", 0)
+            .EQ("tanh", 0));
@@ -23,21 +23,14 @@ namespace paddle {
 namespace framework {
 namespace ir {

-/*
- * \brief Fuse the FC and activation operators into single OneDNN's
- * FC with post-op.
- *
- * \note Currently only GeLU, hardswish, sigmoid, mish and tanh are supported
- * as an activation function.
- */
 class FuseFCActOneDNNPass : public FusePassBase {
  public:
  virtual ~FuseFCActOneDNNPass() {}

  protected:
-  void ApplyImpl(ir::Graph *graph) const override;
+  void ApplyImpl(Graph *graph) const override;
-  void FuseFCAct(ir::Graph *graph, const std::string &act_types) const;
+  void FuseFCAct(Graph *graph, const std::string &act_types) const;
 };

 }  // namespace ir
......
@@ -78,9 +78,9 @@ TEST(FuseFCActOneDNNPass, FuseWithGeluTanh) {
       const auto* op = node->Op();
       ASSERT_TRUE(op->HasAttr("use_mkldnn"));
       EXPECT_TRUE(PADDLE_GET_CONST(bool, op->GetAttr("use_mkldnn")));
-      ASSERT_TRUE(op->HasAttr("activation_type"));
+      ASSERT_TRUE(op->HasAttr("fuse_activation"));
       auto act_type =
-          PADDLE_GET_CONST(std::string, op->GetAttr("activation_type"));
+          PADDLE_GET_CONST(std::string, op->GetAttr("fuse_activation"));
       EXPECT_EQ(act_type.compare("gelu_tanh"), 0);
     }
   }
@@ -113,9 +113,9 @@ TEST(FuseFCActOneDNNPass, FuseWithGeluErf) {
       const auto* op = node->Op();
       ASSERT_TRUE(op->HasAttr("use_mkldnn"));
       EXPECT_TRUE(PADDLE_GET_CONST(bool, op->GetAttr("use_mkldnn")));
-      ASSERT_TRUE(op->HasAttr("activation_type"));
+      ASSERT_TRUE(op->HasAttr("fuse_activation"));
       auto act_type =
-          PADDLE_GET_CONST(std::string, op->GetAttr("activation_type"));
+          PADDLE_GET_CONST(std::string, op->GetAttr("fuse_activation"));
       EXPECT_EQ(act_type.compare("gelu_erf"), 0);
     }
   }
@@ -146,9 +146,9 @@ TEST(FuseFCActOneDNNPass, FuseWithGeluAuto) {
       const auto* op = node->Op();
       ASSERT_TRUE(op->HasAttr("use_mkldnn"));
       EXPECT_TRUE(PADDLE_GET_CONST(bool, op->GetAttr("use_mkldnn")));
-      ASSERT_TRUE(op->HasAttr("activation_type"));
+      ASSERT_TRUE(op->HasAttr("fuse_activation"));
       auto act_type =
-          PADDLE_GET_CONST(std::string, op->GetAttr("activation_type"));
+          PADDLE_GET_CONST(std::string, op->GetAttr("fuse_activation"));
       EXPECT_EQ(act_type.compare("gelu"), 0);
     }
   }
@@ -179,9 +179,9 @@ TEST(FuseFCActOneDNNPass, FuseWithTanh) {
       const auto* op = node->Op();
       ASSERT_TRUE(op->HasAttr("use_mkldnn"));
       EXPECT_TRUE(PADDLE_GET_CONST(bool, op->GetAttr("use_mkldnn")));
-      ASSERT_TRUE(op->HasAttr("activation_type"));
+      ASSERT_TRUE(op->HasAttr("fuse_activation"));
       auto act_type =
-          PADDLE_GET_CONST(std::string, op->GetAttr("activation_type"));
+          PADDLE_GET_CONST(std::string, op->GetAttr("fuse_activation"));
       EXPECT_EQ(act_type.compare("tanh"), 0);
     }
   }
@@ -213,9 +213,9 @@ TEST(FuseFCActOneDNNPass, FuseWithSigmoid) {
       const auto* op = node->Op();
       ASSERT_TRUE(op->HasAttr("use_mkldnn"));
       EXPECT_TRUE(PADDLE_GET_CONST(bool, op->GetAttr("use_mkldnn")));
-      ASSERT_TRUE(op->HasAttr("activation_type"));
+      ASSERT_TRUE(op->HasAttr("fuse_activation"));
       auto act_type =
-          PADDLE_GET_CONST(std::string, op->GetAttr("activation_type"));
+          PADDLE_GET_CONST(std::string, op->GetAttr("fuse_activation"));
       EXPECT_EQ(act_type.compare("sigmoid"), 0);
     }
   }
@@ -246,9 +246,9 @@ TEST(FuseFCActOneDNNPass, FuseWithMish) {
       const auto* op = node->Op();
       ASSERT_TRUE(op->HasAttr("use_mkldnn"));
       EXPECT_TRUE(PADDLE_GET_CONST(bool, op->GetAttr("use_mkldnn")));
-      ASSERT_TRUE(op->HasAttr("activation_type"));
+      ASSERT_TRUE(op->HasAttr("fuse_activation"));
       auto act_type =
-          PADDLE_GET_CONST(std::string, op->GetAttr("activation_type"));
+          PADDLE_GET_CONST(std::string, op->GetAttr("fuse_activation"));
       EXPECT_EQ(act_type.compare("mish"), 0);
     }
   }
@@ -280,9 +280,9 @@ TEST(FuseFCActOneDNNPass, FuseWithHardSwish) {
       const auto* op = node->Op();
       ASSERT_TRUE(op->HasAttr("use_mkldnn"));
       EXPECT_TRUE(PADDLE_GET_CONST(bool, op->GetAttr("use_mkldnn")));
-      ASSERT_TRUE(op->HasAttr("activation_type"));
+      ASSERT_TRUE(op->HasAttr("fuse_activation"));
       auto act_type =
-          PADDLE_GET_CONST(std::string, op->GetAttr("activation_type"));
+          PADDLE_GET_CONST(std::string, op->GetAttr("fuse_activation"));
       EXPECT_EQ(act_type.compare("hard_swish"), 0);
     }
   }
......
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/ir/mkldnn/operator_reshape2_onednn_fuse_pass.h"
#include "paddle/fluid/framework/op_version_registry.h"
#include "paddle/fluid/platform/mkldnn_reuse.h"
#include "paddle/fluid/string/pretty_log.h"
namespace paddle {
namespace framework {
namespace ir {
using string::PrettyLogDetail;
void FuseOperatorReshape2OneDNNPass::ApplyImpl(Graph *graph) const {
  // THIS FUSE WILL WORK ONLY WITH OPERATORS THAT OUTPUT PLAIN MEMORY, E.G.
  // ABCD FOR 4D! BE AWARE OF THAT!
std::vector<std::pair<std::string, int>> ops_and_outputs = {
{"fc", 1}, {"transpose2", 2}};
for (const auto &op_and_outputs : ops_and_outputs)
FuseReshape2(graph, op_and_outputs.first, op_and_outputs.second);
}
void FuseOperatorReshape2OneDNNPass::FuseReshape2(Graph *graph,
const std::string &op_type,
int num_of_outputs) const {
PADDLE_ENFORCE_NOT_NULL(
graph, platform::errors::InvalidArgument("Graph cannot be nullptr."));
FusePassBase::Init(op_type + "_reshape2_onednn_fuse_pass", graph);
GraphPatternDetector gpd;
patterns::OperatorReshape2 op_reshape2_pattern(
gpd.mutable_pattern(), op_type + "_reshape2_onednn_fuse_pass");
op_reshape2_pattern(op_type, num_of_outputs);
int found_operator_reshape2_count = 0;
auto handler = [&](const GraphPatternDetector::subgraph_t &subgraph,
Graph *g) {
GET_IR_NODE_FROM_SUBGRAPH(operator_op, preceding_op, op_reshape2_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
operator_out, preceding_op_out, op_reshape2_pattern);
GET_IR_NODE_FROM_SUBGRAPH(reshape2_op, reshape2_op, op_reshape2_pattern);
GET_IR_NODE_FROM_SUBGRAPH(reshape2_out, reshape2_out, op_reshape2_pattern);
if (!operator_op->Op()->HasAttr("use_mkldnn") ||
(operator_op->Op()->HasAttr("use_mkldnn") &&
!(PADDLE_GET_CONST(bool, operator_op->Op()->GetAttr("use_mkldnn"))))) {
VLOG(4) << "Only oneDNN version of " << op_type
<< "can be fused with reshape2.";
return;
}
if (operator_op->Op()->HasAttr("fused_unsqueeze2_axes")) {
VLOG(4) << "Cannot do " << op_type << " + reshape2 fuse, because "
<< op_type << " is already fused with unsqueeze2!";
return;
}
std::vector<int> reshape2_shape =
PADDLE_GET_CONST(std::vector<int>, reshape2_op->Op()->GetAttr("shape"));
int num_of_minus_ones = 0;
for (size_t i = 0; i < reshape2_shape.size(); ++i) {
if (reshape2_shape[i] == 0) {
VLOG(4) << "OneDNN op+reshape2 fuse pass does not support zero dims, "
"skipping";
return;
} else if (reshape2_shape[i] == -1) {
++num_of_minus_ones;
}
}
if (num_of_minus_ones > 1) {
VLOG(4) << "Number of -1 values inside of reshape2 shouldn't be greater "
"than one in op+reshape2 oneDNN fuse pass, skipping";
return;
}
auto const &names = reshape2_op->Op()->InputNames();
bool has_shape_tensor =
std::find(names.begin(), names.end(), "ShapeTensor") != names.end();
bool has_shape_tensor_list =
std::find(names.begin(), names.end(), "ShapeTensorList") != names.end();
if (has_shape_tensor &&
reshape2_op->Op()->Input("ShapeTensor").size() > 0) {
VLOG(4) << "Cannot fuse " << op_type
<< " and reshape2 because reshape2 dims are specified by "
"ShapeTensor!";
return;
}
if (has_shape_tensor_list &&
reshape2_op->Op()->Input("ShapeTensorList").size() > 0) {
VLOG(4) << "Cannot fuse " << op_type
<< " and reshape2 because reshape2 dims are specified by "
"ShapeTensorList!";
return;
}
operator_op->Op()->SetAttr("fused_reshape2_shape", reshape2_shape);
operator_op->Op()->SetOutput("Out", {reshape2_out->Name()});
IR_OP_VAR_LINK(operator_op, reshape2_out);
GraphSafeRemoveNodes(g, {reshape2_op, operator_out});
found_operator_reshape2_count++;
};
gpd(graph, handler);
AddStatis(found_operator_reshape2_count);
if ((!Has("disable_logs") || !Get<bool>("disable_logs")) &&
found_operator_reshape2_count > 0)
PrettyLogDetail("--- fused %d %s with reshape2",
found_operator_reshape2_count,
op_type);
}
} // namespace ir
} // namespace framework
} // namespace paddle
REGISTER_PASS(operator_reshape2_onednn_fuse_pass,
paddle::framework::ir::FuseOperatorReshape2OneDNNPass);
REGISTER_PASS_CAPABILITY(operator_reshape2_onednn_fuse_pass)
.AddCombination(
paddle::framework::compatible::OpVersionComparatorCombination()
.GE("reshape2", 0)
.GE("fc", 0));
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/fluid/framework/ir/fuse_pass_base.h"
#include "paddle/fluid/framework/ir/graph.h"
namespace paddle {
namespace framework {
namespace ir {
class FuseOperatorReshape2OneDNNPass : public FusePassBase {
public:
virtual ~FuseOperatorReshape2OneDNNPass() {}
protected:
void ApplyImpl(Graph *graph) const override;
void FuseReshape2(Graph *graph,
const std::string &op_type,
int num_of_outputs) const;
};
} // namespace ir
} // namespace framework
} // namespace paddle
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/ir/mkldnn/operator_unsqueeze2_onednn_fuse_pass.h"
#include "paddle/fluid/framework/op_version_registry.h"
#include "paddle/fluid/platform/mkldnn_reuse.h"
#include "paddle/fluid/string/pretty_log.h"
namespace paddle {
namespace framework {
namespace ir {
using string::PrettyLogDetail;
void FuseOperatorUnsqueeze2OneDNNPass::ApplyImpl(Graph *graph) const {
std::vector<std::pair<std::string, int>> ops_and_outputs = {
{"transpose2", 2}, {"elementwise_mul", 1}};
for (const auto &op_and_outputs : ops_and_outputs)
FuseUnsqueeze2(graph, op_and_outputs.first, op_and_outputs.second);
}
void FuseOperatorUnsqueeze2OneDNNPass::FuseUnsqueeze2(
Graph *graph, const std::string &op_type, int num_of_outputs) const {
PADDLE_ENFORCE_NOT_NULL(
graph, platform::errors::InvalidArgument("Graph cannot be nullptr."));
FusePassBase::Init(op_type + "_unsqueeze2_onednn_fuse_pass", graph);
GraphPatternDetector gpd;
patterns::OperatorUnsqueeze2 op_unsqueeze2_pattern(
gpd.mutable_pattern(), op_type + "_unsqueeze2_onednn_fuse_pass");
op_unsqueeze2_pattern(op_type, num_of_outputs);
int found_operator_unsqueeze2_count = 0;
auto handler = [&](const GraphPatternDetector::subgraph_t &subgraph,
Graph *g) {
GET_IR_NODE_FROM_SUBGRAPH(operator_op, preceding_op, op_unsqueeze2_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
operator_out, preceding_op_out, op_unsqueeze2_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
unsqueeze2_op, unsqueeze2_op, op_unsqueeze2_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
unsqueeze2_out, unsqueeze2_out, op_unsqueeze2_pattern);
if (!operator_op->Op()->HasAttr("use_mkldnn") ||
(operator_op->Op()->HasAttr("use_mkldnn") &&
!(PADDLE_GET_CONST(bool, operator_op->Op()->GetAttr("use_mkldnn"))))) {
VLOG(4) << "Only oneDNN version of " << op_type
<< "can be fused with unsqueeze2.";
return;
}
std::vector<int> unsqueeze2_axes = PADDLE_GET_CONST(
std::vector<int>, unsqueeze2_op->Op()->GetAttr("axes"));
auto const &names = unsqueeze2_op->Op()->InputNames();
bool has_axes_tensor =
std::find(names.begin(), names.end(), "AxesTensor") != names.end();
bool has_axes_tensor_list =
std::find(names.begin(), names.end(), "AxesTensorList") != names.end();
if (has_axes_tensor &&
unsqueeze2_op->Op()->Input("AxesTensor").size() > 0) {
VLOG(4) << "Cannot fuse " << op_type
<< " and unsqueeze2 because unsqueeze2 dims are specified by "
"AxesTensor!";
return;
}
if (has_axes_tensor_list &&
unsqueeze2_op->Op()->Input("AxesTensorList").size() > 0) {
VLOG(4) << "Cannot fuse " << op_type
<< " and unsqueeze2 because unsqueeze2 dims are specified by "
"AxesTensorList!";
return;
}
operator_op->Op()->SetAttr("fused_unsqueeze2_axes", unsqueeze2_axes);
operator_op->Op()->SetOutput("Out", {unsqueeze2_out->Name()});
IR_OP_VAR_LINK(operator_op, unsqueeze2_out);
GraphSafeRemoveNodes(g, {unsqueeze2_op, operator_out});
found_operator_unsqueeze2_count++;
};
gpd(graph, handler);
AddStatis(found_operator_unsqueeze2_count);
if ((!Has("disable_logs") || !Get<bool>("disable_logs")) &&
found_operator_unsqueeze2_count > 0)
PrettyLogDetail("--- fused %d %s with unsqueeze2",
found_operator_unsqueeze2_count,
op_type);
}
} // namespace ir
} // namespace framework
} // namespace paddle
REGISTER_PASS(operator_unsqueeze2_onednn_fuse_pass,
paddle::framework::ir::FuseOperatorUnsqueeze2OneDNNPass);
REGISTER_PASS_CAPABILITY(operator_unsqueeze2_onednn_fuse_pass)
.AddCombination(
paddle::framework::compatible::OpVersionComparatorCombination()
.GE("unsqueeze2", 0)
.GE("transpose2", 0));
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/fluid/framework/ir/fuse_pass_base.h"
#include "paddle/fluid/framework/ir/graph.h"
namespace paddle {
namespace framework {
namespace ir {
class FuseOperatorUnsqueeze2OneDNNPass : public FusePassBase {
public:
virtual ~FuseOperatorUnsqueeze2OneDNNPass() {}
protected:
void ApplyImpl(Graph *graph) const override;
void FuseUnsqueeze2(Graph *graph,
const std::string &op_type,
int num_of_outputs) const;
};
} // namespace ir
} // namespace framework
} // namespace paddle
@@ -327,6 +327,8 @@ void CpuPassStrategy::EnableMKLDNN() {
           "shuffle_channel_mkldnn_detect_pass",    //
           "elt_act_mkldnn_fuse_pass",              //
           "operator_scale_onednn_fuse_pass",       //
+          "operator_unsqueeze2_onednn_fuse_pass",  //
+          "operator_reshape2_onednn_fuse_pass",    //
           // TODO(intel): Please fix the bug on windows.
           // https://github.com/PaddlePaddle/Paddle/issues/29710
           // "mkldnn_inplace_pass",  // This pass should be activated after
@@ -421,6 +423,8 @@ void CpuPassStrategy::EnableMkldnnInt8() {
     passes_.push_back("reshape_transpose_matmul_mkldnn_fuse_pass");
     passes_.push_back("matmul_elementwise_add_mkldnn_fuse_pass");
     passes_.push_back("operator_scale_onednn_fuse_pass");
+    passes_.push_back("operator_unsqueeze2_onednn_fuse_pass");
+    passes_.push_back("operator_reshape2_onednn_fuse_pass");
     passes_.push_back("cpu_quantize_placement_pass");
     passes_.push_back("cpu_quantize_pass");
     passes_.push_back("cpu_quantize_squash_pass");
......
@@ -129,12 +129,14 @@ class EltwiseMKLDNNKernel : public framework::OpKernel<T> {
     astream.wait();

     if (handler.use_broadcasting_hack == false) {
-      z->set_mem_desc(dst_memory->get_desc());
+      platform::SetOutMemDescWithLogicalLayoutFusesSupport(
+          ctx, z, dst_memory->get_desc());
     } else {
       auto dims = dst_memory->get_desc().dims();
       dims.insert(dims.begin(), x->dims()[0]);
       dims[1] /= dims[0];
-      z->set_mem_desc(dst_memory->get_desc().reshape(dims));
+      platform::SetOutMemDescWithLogicalLayoutFusesSupport(
+          ctx, z, dst_memory->get_desc().reshape(dims));
     }
   }
 };
......
@@ -21,72 +21,8 @@
 namespace paddle {
 namespace operators {

-using Tensor = framework::Tensor;
+using Tensor = phi::DenseTensor;
-using framework::DataLayout;
+using phi::DataLayout;

-template <typename T>
-class TransposeMKLDNNHandler {
- public:
-  TransposeMKLDNNHandler(std::vector<int64_t>& dims,  // NOLINT
-                         std::vector<int>& axis,      // NOLINT
-                         dnnl::engine engine)
-      : dims_(dims),
-        axis_(axis),
-        logical_axis_(dims.size(), 0),
-        engine_(engine) {}
-
-  std::shared_ptr<dnnl::memory> AcquireSrcMemory(const MKLDNNMemoryFormat& fmt,
-                                                 void* ptr) {
-    // Make memory descriptor using input format, unless it
-    // cannot be trusted (nchw) then make up memory fmt manually
-    for (size_t i = 0; i < this->logical_axis_.size(); ++i) {
-      this->logical_axis_[i] = i;
-    }
-
-    auto src_md = fmt != MKLDNNMemoryFormat::nchw
-                      ? platform::MKLDNNMemDesc(
-                            dims_, platform::MKLDNNGetDataType<T>(), fmt)
-                      : Axis2MemoryDesc(dims_, logical_axis_);
-    return std::make_shared<dnnl::memory>(src_md, engine_, ptr);
-  }
-
-  std::shared_ptr<dnnl::memory> AcquireDstMemory(framework::Tensor* output,
-                                                 platform::Place place) {
-    auto dst_md = Axis2MemoryDesc(dims_, axis_);
-    auto dst_data = output->mutable_data<T>(place, dst_md.get_size());
-    return std::make_shared<dnnl::memory>(dst_md, engine_, dst_data);
-  }
-
-  std::shared_ptr<dnnl::reorder> AcquireTranspose(
-      std::shared_ptr<dnnl::memory> dst_memory_p,
-      std::shared_ptr<dnnl::memory> src_memory_p) {
-    return std::make_shared<dnnl::reorder>(*(src_memory_p), *(dst_memory_p));
-  }
-
- protected:
-  dnnl::memory::desc Axis2MemoryDesc(std::vector<int64_t>& nchw_tz,  // NOLINT
-                                     std::vector<int>& axis         // NOLINT
-  ) {
-    size_t ndims = axis.size();
-    std::vector<int64_t> strides(ndims);
-    unsigned int total_stride = 1;
-    for (int i = ndims - 1; i >= 0; --i) {
-      strides[axis[i]] = total_stride;
-      total_stride *= nchw_tz[axis[i]];
-    }
-    dnnl::memory::desc mem_d(
-        nchw_tz, platform::MKLDNNGetDataType<T>(), strides);
-    return mem_d;
-  }
-
- private:
-  std::vector<int64_t> dims_;
-  std::vector<int> axis_;
-  std::vector<int> logical_axis_;
-  dnnl::engine engine_;
-};
-
 template <typename T>
 class TransposeMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
@@ -98,37 +34,87 @@ class TransposeMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
             "Operator DNNL Transpose must use CPUPlace"));
     auto& dev_ctx =
         ctx.template device_context<paddle::platform::MKLDNNDeviceContext>();
-    const auto& mkldnn_engine = dev_ctx.GetEngine();
+    const auto& dnnl_engine = dev_ctx.GetEngine();
-    std::vector<int> axis = ctx.Attr<std::vector<int>>("axis");
+    std::vector<int> transpose_axis = ctx.Attr<std::vector<int>>("axis");
-    int ndims = axis.size();
+    int ndims = transpose_axis.size();
-    auto* input = ctx.Input<Tensor>("X");
+    auto* x = ctx.Input<Tensor>("X");
-    auto* output = ctx.Output<Tensor>("Out");
+    auto* out = ctx.Output<Tensor>("Out");
-    const T* input_data = input->data<T>();
+    auto& astream = platform::MKLDNNDeviceContext::tls().get_stream();

     if (ndims == 1) {
-      framework::TensorCopy(*input, input->place(), output);
+      framework::TensorCopy(*x, x->place(), out);
-      output->set_format(input->format());
+      out->set_mem_desc(x->mem_desc());
       return;
     }

-    auto nchw_tz = phi::vectorize<int64_t>(input->dims());
+    auto x_vec_dims = phi::vectorize(x->dims());

-    TransposeMKLDNNHandler<T> handler(nchw_tz, axis, mkldnn_engine);
+    framework::proto::VarType::Type x_paddle_type =
+        framework::TransToProtoVarType(x->dtype());
+    dnnl::memory::data_type x_type = framework::ToMKLDNNDataType(x_paddle_type);
+    platform::ReorderMKLDNNHandler reorder_handler(
+        x_vec_dims, x_paddle_type, x_type, dnnl_engine);

-    auto transpose_src_memory_p = handler.AcquireSrcMemory(
-        input->format(), platform::to_void_cast<T>(input_data));
+    auto reorder_src_memory_p = reorder_handler.AcquireSrcMemory(
+        x->mem_desc(), platform::to_void_cast(x->data<T>()));
-    auto transpose_dst_memory_p =
-        handler.AcquireDstMemory(output, ctx.GetPlace());
-    auto transpose_p = handler.AcquireTranspose(transpose_dst_memory_p,
-                                                transpose_src_memory_p);

-    auto& astream = platform::MKLDNNDeviceContext::tls().get_stream();
-    transpose_p->execute(
-        astream, *transpose_src_memory_p, *transpose_dst_memory_p);
+    auto dst_md =
+        dnnl::memory::desc(x_vec_dims,
+                           x->mem_desc().data_type(),
+                           platform::GetPlainMKLDNNFormat(x_vec_dims.size()));
+    // A trick is used here to fake the transpose of dst_md, so later it will
+    // be "untransposed", leaving the output data in a plain format tag
+    auto dst_strides = FakeTranposeStrides(dst_md, transpose_axis);
+    dst_md =
+        dnnl::memory::desc(x_vec_dims, x->mem_desc().data_type(), dst_strides);
+    auto dst_data =
+        out->mutable_data(ctx.GetPlace(), x->type(), dst_md.get_size());
+
+    auto reorder_dst_memory_p =
+        std::make_shared<dnnl::memory>(dst_md, dnnl_engine, dst_data);
+
+    auto reorder_p = reorder_handler.AcquireReorder(reorder_dst_memory_p,
+                                                    reorder_src_memory_p);
+
+    reorder_p->execute(astream, *reorder_src_memory_p, *reorder_dst_memory_p);
     astream.wait();

-    output->set_layout(DataLayout::kNCHW);
-    output->set_format(MKLDNNMemoryFormat::undef);
+    platform::SetOutMemDescWithLogicalLayoutFusesSupport(
+        ctx,
+        out,
+        reorder_dst_memory_p->get_desc().permute_axes(
+            TransposeToPermuteAxis(transpose_axis)));
   }

+ private:
+  // This is needed because oneDNN's permute_axes understands the axis order
+  // differently than PaddlePaddle's transpose
+  std::vector<int> TransposeToPermuteAxis(
+      const std::vector<int>& transpose_axis) const {
+    std::vector<int> permute_axis(transpose_axis.size());
+    for (size_t i = 0; i < transpose_axis.size(); ++i) {
+      permute_axis[transpose_axis[i]] = i;
+    }
+    return permute_axis;
+  }
+
+  std::vector<int64_t> FakeTranposeStrides(
+      const dnnl::memory::desc& dst_md,
+      const std::vector<int>& transpose_axis) const {
+    std::vector<int64_t> fake_strides(transpose_axis.size());
+    auto dims = dst_md.dims();
+    int total_stride = 1;
+    int ndims = static_cast<int>(dims.size());
+    for (int i = ndims - 1; i >= 0; --i) {
+      fake_strides[transpose_axis[i]] = total_stride;
+      total_stride *= dims[transpose_axis[i]];
+    }
+    return fake_strides;
+  }
 };
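
The two private helpers above carry the whole trick: the strides helper lays the output out as if it were already transposed, and the axis helper inverts Paddle's transpose convention into oneDNN's permute convention. A self-contained sketch of the arithmetic (free functions with hypothetical names, independent of the kernel class):

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// PaddlePaddle's transpose axis says "output dim i comes from input dim
// axis[i]"; oneDNN's permute_axes wants the inverse mapping.
std::vector<int> ToPermuteAxis(const std::vector<int>& transpose_axis) {
  std::vector<int> permute_axis(transpose_axis.size());
  for (size_t i = 0; i < transpose_axis.size(); ++i)
    permute_axis[transpose_axis[i]] = static_cast<int>(i);
  return permute_axis;
}

// Strides that make a buffer of `dims` look as if it had already been
// transposed by `transpose_axis` (mirrors the kernel's stride computation).
std::vector<int64_t> FakeStrides(const std::vector<int64_t>& dims,
                                 const std::vector<int>& transpose_axis) {
  std::vector<int64_t> strides(transpose_axis.size());
  int64_t total_stride = 1;
  for (int i = static_cast<int>(dims.size()) - 1; i >= 0; --i) {
    strides[transpose_axis[i]] = total_stride;
    total_stride *= dims[transpose_axis[i]];
  }
  return strides;
}

int main() {
  // transpose axis {1, 2, 0} inverts to permute axis {2, 0, 1}
  assert((ToPermuteAxis({1, 2, 0}) == std::vector<int>{2, 0, 1}));
  // dims {2, 3, 4} with transpose axis {0, 2, 1} yield strides {12, 1, 3}
  assert((FakeStrides({2, 3, 4}, {0, 2, 1}) ==
          std::vector<int64_t>{12, 1, 3}));
  return 0;
}
```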
@@ -140,44 +126,47 @@ class TransposeMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> {
         true,
         paddle::platform::errors::PreconditionNotMet(
             "Operator DNNL TransposeGrad must use CPUPlace"));

-    auto* out_grad =
-        ctx.Input<framework::Tensor>(framework::GradVarName("Out"));
+    const auto* dout = ctx.Input<Tensor>(framework::GradVarName("Out"));
-    auto* x_grad = ctx.Output<framework::Tensor>(framework::GradVarName("X"));
+    auto* dx = ctx.Output<Tensor>(framework::GradVarName("X"));
-    if (!x_grad) return;
+    if (!dx) return;

     auto& dev_ctx =
         ctx.template device_context<paddle::platform::MKLDNNDeviceContext>();
-    const auto& mkldnn_engine = dev_ctx.GetEngine();
+    const auto& dnnl_engine = dev_ctx.GetEngine();

-    std::vector<int> axis = ctx.Attr<std::vector<int>>("axis");
+    std::vector<int> transpose_axis = ctx.Attr<std::vector<int>>("axis");
-    std::vector<int> reversed_axis(axis);
-    int ndims = axis.size();
+    auto& astream = platform::MKLDNNDeviceContext::tls().get_stream();
+    int ndims = transpose_axis.size();

     if (ndims == 1) {
-      framework::TensorCopy(*out_grad, out_grad->place(), x_grad);
+      framework::TensorCopy(*dout, dout->place(), dx);
-      x_grad->set_format(out_grad->format());
+      dx->set_mem_desc(dout->mem_desc());
       return;
     }

-    for (size_t i = 0; i < axis.size(); i++) {
-      reversed_axis[axis[i]] = i;
-    }
+    auto dout_vec_dims = phi::vectorize(dout->dims());

-    const T* out_grad_data = out_grad->data<T>();
-    x_grad->mutable_data<T>(ctx.GetPlace());
+    framework::proto::VarType::Type dout_paddle_type =
+        framework::TransToProtoVarType(dout->dtype());
+    dnnl::memory::data_type dout_type =
+        framework::ToMKLDNNDataType(dout_paddle_type);

-    auto nchw_tz = phi::vectorize<int64_t>(out_grad->dims());
+    platform::ReorderMKLDNNHandler reorder_handler(
+        dout_vec_dims, dout_paddle_type, dout_type, dnnl_engine);

-    TransposeMKLDNNHandler<T> handler(nchw_tz, reversed_axis, mkldnn_engine);
+    auto reorder_src_memory_p = reorder_handler.AcquireSrcMemory(
+        dout->mem_desc(), platform::to_void_cast(dout->data<T>()));

-    auto transpose_src_memory_p = handler.AcquireSrcMemory(
-        out_grad->format(), platform::to_void_cast<T>(out_grad_data));
-    auto transpose_dst_memory_p =
-        handler.AcquireDstMemory(x_grad, ctx.GetPlace());
-    auto transpose_p = handler.AcquireTranspose(transpose_dst_memory_p,
-                                                transpose_src_memory_p);
+    auto reorder_dst_memory_p =
+        reorder_handler.AcquireDstMemory(dx, dout->mem_desc(), ctx.GetPlace());

-    auto& astream = platform::MKLDNNDeviceContext::tls().get_stream();
-    transpose_p->execute(
-        astream, *transpose_src_memory_p, *transpose_dst_memory_p);
+    auto reorder_p = reorder_handler.AcquireReorder(reorder_dst_memory_p,
+                                                    reorder_src_memory_p);
+    reorder_p->execute(astream, *reorder_src_memory_p, *reorder_dst_memory_p);
     astream.wait();
+
+    dx->set_mem_desc(
+        reorder_dst_memory_p->get_desc().permute_axes(transpose_axis));
   }
 };
......
@@ -111,6 +111,69 @@ static void AppendActivation(const framework::ExecutionContext& ctx,
   }
 }

+static void SetOutMemDescWithUnsqueeze2FuseSupport(
+    const framework::ExecutionContext& ctx,
+    phi::DenseTensor* out,
+    const dnnl::memory::desc& out_md) {
+  const std::vector<int>& fused_unsqueeze2_axes =
+      ctx.Attr<std::vector<int>>("fused_unsqueeze2_axes");
+  const std::vector<int64_t>& op_tz = out_md.dims();
+  std::vector<int64_t> unsqueezed_op_tz(
+      op_tz.size() + fused_unsqueeze2_axes.size(), 0);
+
+  for (const auto& axis : fused_unsqueeze2_axes) {
+    int positive_axis = axis < 0 ? unsqueezed_op_tz.size() + axis : axis;
+    unsqueezed_op_tz[positive_axis] = 1;
+  }
+
+  int j = 0;
+  for (size_t i = 0; i < unsqueezed_op_tz.size(); ++i) {
+    if (unsqueezed_op_tz[i] == 0) {
+      unsqueezed_op_tz[i] = op_tz[j++];
+    }
+  }
+
+  out->set_mem_desc(out_md.reshape(unsqueezed_op_tz));
+  out->Resize(phi::make_ddim(unsqueezed_op_tz));
+}
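
The axes loop above first marks every (possibly negative) unsqueeze axis with a `1`, then streams the original dims into the remaining slots. A standalone sketch of that expansion with worked examples, under a hypothetical helper name:

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Mark each unsqueeze axis with a 1, then fill the remaining slots
// with the original dims in order (mirrors the function above).
std::vector<int64_t> UnsqueezeDims(const std::vector<int64_t>& op_tz,
                                   const std::vector<int>& axes) {
  std::vector<int64_t> out(op_tz.size() + axes.size(), 0);
  for (int axis : axes) {
    int positive = axis < 0 ? static_cast<int>(out.size()) + axis : axis;
    out[positive] = 1;
  }
  size_t j = 0;
  for (auto& d : out)
    if (d == 0) d = op_tz[j++];
  return out;
}

int main() {
  // {2, 3} unsqueezed at axes {0, 3} becomes {1, 2, 3, 1}
  assert((UnsqueezeDims({2, 3}, {0, 3}) == std::vector<int64_t>{1, 2, 3, 1}));
  // negative axes count from the back: {2, 3} at axis -1 -> {2, 3, 1}
  assert((UnsqueezeDims({2, 3}, {-1}) == std::vector<int64_t>{2, 3, 1}));
  return 0;
}
```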
+
+static void SetOutMemDescWithReshape2FuseSupport(
+    const framework::ExecutionContext& ctx,
+    phi::DenseTensor* out,
+    const dnnl::memory::desc& out_md) {
+  std::vector<int64_t> fused_reshape2_shape(
+      ctx.Attr<std::vector<int>>("fused_reshape2_shape").begin(),
+      ctx.Attr<std::vector<int>>("fused_reshape2_shape").end());
+
+  const int out_shape_numel = out->numel();
+  const int new_shape_numel = std::accumulate(fused_reshape2_shape.begin(),
+                                              fused_reshape2_shape.end(),
+                                              1,
+                                              std::multiplies<int64_t>());
+
+  for (size_t i = 0; i < fused_reshape2_shape.size(); ++i) {
+    if (fused_reshape2_shape[i] == -1) {
+      fused_reshape2_shape[i] = -out_shape_numel / new_shape_numel;
+      break;
+    }
+  }
+
+  out->set_mem_desc(out_md.reshape(fused_reshape2_shape));
+  out->Resize(phi::make_ddim(fused_reshape2_shape));
+}
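
The `-1` resolution above leans on sign arithmetic: the earlier pass guarantees at most one `-1`, so the product of the requested dims is negative exactly when a `-1` is present, and dividing the negated element count by that product recovers the missing dimension. A standalone sketch, with a hypothetical helper name:

```cpp
#include <cassert>
#include <cstdint>
#include <functional>
#include <numeric>
#include <vector>

// With exactly one -1 present, the product of the requested dims is
// negative, so -out_numel / product yields the missing positive dim.
std::vector<int64_t> InferReshape2Shape(std::vector<int64_t> shape,
                                        int64_t out_numel) {
  const int64_t new_numel = std::accumulate(
      shape.begin(), shape.end(), int64_t{1}, std::multiplies<int64_t>());
  for (auto& d : shape) {
    if (d == -1) {
      d = -out_numel / new_numel;  // e.g. -24 / -4 == 6
      break;
    }
  }
  return shape;
}

int main() {
  // a 24-element output reshaped to {-1, 4} resolves to {6, 4}
  assert((InferReshape2Shape({-1, 4}, 24) == std::vector<int64_t>{6, 4}));
  return 0;
}
```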
+
+static void SetOutMemDescWithLogicalLayoutFusesSupport(
+    const framework::ExecutionContext& ctx,
+    phi::DenseTensor* out,
+    const dnnl::memory::desc& out_md) {
+  if (ctx.HasAttr("fused_unsqueeze2_axes")) {
+    SetOutMemDescWithUnsqueeze2FuseSupport(ctx, out, out_md);
+  } else if (ctx.HasAttr("fused_reshape2_shape")) {
+    SetOutMemDescWithReshape2FuseSupport(ctx, out, out_md);
+  } else {
+    out->set_mem_desc(out_md);
+  }
+}
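
This dispatcher is the single point through which the oneDNN kernels route their output descriptors instead of calling `set_mem_desc` directly; the elementwise and transpose kernels changed earlier in this commit both call it, which is how an operator that absorbed an unsqueeze2 or reshape2 still reports the fused logical shape.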

 template <typename T>
 constexpr bool IsInt8() {
   return std::is_same<T, int8_t>::value || std::is_same<T, uint8_t>::value;
......