Unverified commit d0e19af3, authored by jakpiase, committed by GitHub

[CHERRY-PICK] Added caching to oneDNN FC and op+unsqueeze2 and op+reshape2 fuse passes (#47690)

* fc cherrypick

* more files added

* added transpose cherrypick

* reverted somebody's fc changes

* minor fix

* minor fix

* cherry-pick of fc+act changes

* minor fix

* fix
Parent cf668ab3
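This diff touches three areas: caching of the oneDNN FC primitive and its reordered weights/bias, two new graph passes that fold a trailing unsqueeze2 or reshape2 into the preceding oneDNN op, and the plumbing that registers those passes. The FC caching is keyed by thread info plus the input/weight names and the input shape (see platform::CreateKey in the kernel below), so repeated runs with unchanged shapes reuse the cached objects. A minimal, standalone sketch of that keying idea follows; it is plain C++, not Paddle code, and the struct contents and key format are invented for illustration:

#include <cstdint>
#include <iostream>
#include <map>
#include <memory>
#include <sstream>
#include <string>
#include <vector>

// Stand-in for the real InnerProductCache, which holds the dnnl primitive and
// its src/weights/bias/dst memories.
struct InnerProductCacheSketch {
  std::string note;
};

// Hypothetical key format: input name + weight name + input dims.
std::string MakeCacheKey(const std::string& input_name,
                         const std::string& weights_name,
                         const std::vector<int64_t>& input_dims) {
  std::ostringstream key;
  key << input_name << '@' << weights_name;
  for (auto d : input_dims) key << '-' << d;
  return key.str();
}

int main() {
  std::map<std::string, std::shared_ptr<InnerProductCacheSketch>> blobs;

  const auto key = MakeCacheKey("x", "fc_w", {2, 128});
  if (blobs.count(key) == 0) {
    // First run: build the primitive and reordered weights, then cache them.
    auto entry = std::make_shared<InnerProductCacheSketch>();
    entry->note = "fc primitive for " + key;
    blobs[key] = entry;
  }
  // A later run with identical names/shapes hits the cache and skips rebuild.
  std::cout << "cached: " << blobs.at(key)->note << '\n';
  return 0;
}

In the actual kernel below, the cached object is an InnerProductCache holding the dnnl::inner_product_forward primitive and its memories, stored via dev_ctx.SetBlob(cache_key, ip_cache).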
@@ -219,6 +219,8 @@ if(WITH_MKLDNN)
   pass_library(matmul_elementwise_add_mkldnn_fuse_pass inference DIR mkldnn)
   pass_library(matmul_activation_mkldnn_fuse_pass inference DIR mkldnn)
   pass_library(operator_scale_onednn_fuse_pass inference DIR mkldnn)
+  pass_library(operator_unsqueeze2_onednn_fuse_pass inference DIR mkldnn)
+  pass_library(operator_reshape2_onednn_fuse_pass inference DIR mkldnn)
   pass_library(cpu_quantize_placement_pass base DIR mkldnn)
   pass_library(cpu_quantize_pass inference DIR mkldnn)
   pass_library(cpu_quantize_squash_pass inference DIR mkldnn)
......
@@ -958,6 +958,44 @@ PDNode *patterns::OperatorActivation::operator()(
  return activation_out;
}
PDNode *patterns::OperatorUnsqueeze2::operator()(
const std::string &operator_type, const int num_of_operator_outs) {
auto *preceding_op = pattern->NewNode(preceding_op_repr())
->assert_is_op(operator_type)
->assert_has_n_outputs(num_of_operator_outs);
auto *preceding_op_out = pattern->NewNode(preceding_op_out_repr())
->AsIntermediate()
->assert_is_op_output(operator_type, "Out")
->assert_is_op_input("unsqueeze2");
auto *unsqueeze2_op =
pattern->NewNode(unsqueeze2_op_repr())->assert_is_op("unsqueeze2");
auto *unsqueeze2_out = pattern->NewNode(unsqueeze2_out_repr())
->AsOutput()
->assert_is_op_output("unsqueeze2");
preceding_op->LinksTo({preceding_op_out});
unsqueeze2_op->LinksFrom({preceding_op_out}).LinksTo({unsqueeze2_out});
return unsqueeze2_out;
}
PDNode *patterns::OperatorReshape2::operator()(const std::string &operator_type,
const int num_of_operator_outs) {
auto *preceding_op = pattern->NewNode(preceding_op_repr())
->assert_is_op(operator_type)
->assert_has_n_outputs(num_of_operator_outs);
auto *preceding_op_out = pattern->NewNode(preceding_op_out_repr())
->AsIntermediate()
->assert_is_op_output(operator_type, "Out")
->assert_is_op_input("reshape2");
auto *reshape2_op =
pattern->NewNode(reshape2_op_repr())->assert_is_op("reshape2");
auto *reshape2_out = pattern->NewNode(reshape2_out_repr())
->AsOutput()
->assert_is_op_output("reshape2");
preceding_op->LinksTo({preceding_op_out});
reshape2_op->LinksFrom({preceding_op_out}).LinksTo({reshape2_out});
return reshape2_out;
}
PDNode *patterns::SeqConvEltAddRelu::operator()(
    paddle::framework::ir::PDNode *seqconv_input) {
  // Create Operators
......
@@ -539,6 +539,32 @@ struct OperatorActivation : public PatternBase {
  PATTERN_DECL_NODE(activation_out);
};
struct OperatorUnsqueeze2 : public PatternBase {
OperatorUnsqueeze2(PDPattern* pattern, const std::string& name_scope)
: PatternBase(pattern, name_scope, "operator_unsqueeze2") {}
PDNode* operator()(const std::string& operator_type,
const int num_of_outputs);
PATTERN_DECL_NODE(preceding_op);
PATTERN_DECL_NODE(preceding_op_out);
PATTERN_DECL_NODE(unsqueeze2_op);
PATTERN_DECL_NODE(unsqueeze2_out);
};
struct OperatorReshape2 : public PatternBase {
OperatorReshape2(PDPattern* pattern, const std::string& name_scope)
: PatternBase(pattern, name_scope, "operator_reshape2") {}
PDNode* operator()(const std::string& operator_type,
const int num_of_outputs);
PATTERN_DECL_NODE(preceding_op);
PATTERN_DECL_NODE(preceding_op_out);
PATTERN_DECL_NODE(reshape2_op);
PATTERN_DECL_NODE(reshape2_out);
};
// SEQCONV with Elementwise_Add ReLU
// op: seqconv + elementwise_add + relu
// named nodes:
......
-// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
 //
 // Licensed under the Apache License, Version 2.0 (the "License");
 // you may not use this file except in compliance with the License.
@@ -14,9 +14,8 @@
 #include "paddle/fluid/framework/ir/mkldnn/fc_act_mkldnn_fuse_pass.h"
-#include "paddle/fluid/framework/ir/graph_pattern_detector.h"
 #include "paddle/fluid/framework/op_version_registry.h"
-#include "paddle/fluid/platform/enforce.h"
+#include "paddle/fluid/platform/mkldnn_reuse.h"
 #include "paddle/fluid/string/pretty_log.h"

 namespace paddle {
@@ -26,20 +25,20 @@ namespace ir {
 using string::PrettyLogDetail;

 void FuseFCActOneDNNPass::ApplyImpl(Graph *graph) const {
-  std::vector<std::string> act_types = {
-      "gelu", "tanh", "sigmoid", "mish", "hard_swish"};
+  auto act_types = paddle::platform::GetSupportedActivations();

-  for (std::string act_type : act_types) FuseFCAct(graph, act_type);
+  for (auto act_type : act_types) FuseFCAct(graph, act_type);
 }

 void FuseFCActOneDNNPass::FuseFCAct(Graph *graph,
                                     const std::string &act_type) const {
   PADDLE_ENFORCE_NOT_NULL(
       graph, platform::errors::InvalidArgument("Graph cannot be nullptr."));
-  FusePassBase::Init("fc_act", graph);
+  FusePassBase::Init("fc_" + act_type + "_mkldnn_fuse_pass", graph);

   GraphPatternDetector gpd;
-  patterns::OperatorActivation fc_act_pattern(gpd.mutable_pattern(), "fc_act");
+  patterns::OperatorActivation fc_act_pattern(
+      gpd.mutable_pattern(), "fc_" + act_type + "_mkldnn_fuse_pass");
   fc_act_pattern("fc", act_type);

   int found_fc_act_count = 0;
@@ -62,15 +61,23 @@ void FuseFCActOneDNNPass::FuseFCAct(Graph *graph,
                                         "is used."));
     }

+    auto attr_map = paddle::platform::GetAttributeMap(act_type);
+    for (const auto &attr : attr_map) {
+      if (act_op->HasAttr(attr.first)) {
+        fc_op->SetAttr(attr.second, act_op->GetAttr(attr.first));
+      }
+    }
+
     if (act_type == "gelu" && act_op->HasAttr("approximate")) {
-      bool approximate = PADDLE_GET_CONST(bool, act_op->GetAttr("approximate"));
-      std::string type = approximate ? "_tanh" : "_erf";
-      fc_op->SetAttr("activation_type", act_type + type);
+      std::string gelu_act_type =
+          PADDLE_GET_CONST(bool, act_op->GetAttr("approximate")) ? "gelu_tanh"
+                                                                 : "gelu_erf";
+      fc_op->SetAttr("fuse_activation", gelu_act_type);
     } else {
-      fc_op->SetAttr("activation_type", act_type);
+      fc_op->SetAttr("fuse_activation", act_type);
     }
     fc_op->SetAttr("use_mkldnn", true);
     fc_op->SetOutput("Out", {act_out->Name()});

     IR_OP_VAR_LINK(fc, act_out);
@@ -80,7 +87,8 @@ void FuseFCActOneDNNPass::FuseFCAct(Graph *graph,
   gpd(graph, handler);
   AddStatis(found_fc_act_count);

-  if (!Has("disable_logs") || !Get<bool>("disable_logs"))
+  if ((!Has("disable_logs") || !Get<bool>("disable_logs")) &&
+      found_fc_act_count > 0)
     PrettyLogDetail(
         "--- fused %d fc with %s activation", found_fc_act_count, act_type);
 }
@@ -95,8 +103,16 @@ REGISTER_PASS_CAPABILITY(fc_act_mkldnn_fuse_pass)
     .AddCombination(
         paddle::framework::compatible::OpVersionComparatorCombination()
             .LE("fc", 0)
-            .LE("gelu", 0)
-            .LE("sigmoid", 0)
-            .LE("mish", 1)
+            .EQ("abs", 0)
+            .LE("clip", 1)
+            .EQ("gelu", 0)
+            .EQ("hard_sigmoid", 0)
             .LE("hard_swish", 0)
-            .LE("tanh", 0));
+            .LE("leaky_relu", 1)
+            .LE("mish", 1)
+            .EQ("relu", 0)
+            .EQ("relu6", 0)
+            .EQ("sigmoid", 0)
+            .EQ("sqrt", 0)
+            .EQ("swish", 0)
+            .EQ("tanh", 0));
@@ -23,21 +23,14 @@ namespace paddle {
 namespace framework {
 namespace ir {

-/*
- * \brief Fuse the FC and activation operators into single OneDNN's
- *        FC with post-op.
- *
- * \note Currently only GeLU, hardswish, sigmoid, mish and tanh are supported
- *       as an activation function.
- */
 class FuseFCActOneDNNPass : public FusePassBase {
  public:
   virtual ~FuseFCActOneDNNPass() {}

  protected:
-  void ApplyImpl(ir::Graph *graph) const override;
-  void FuseFCAct(ir::Graph *graph, const std::string &act_types) const;
+  void ApplyImpl(Graph *graph) const override;
+  void FuseFCAct(Graph *graph, const std::string &act_types) const;
 };

 }  // namespace ir
......
@@ -78,9 +78,9 @@ TEST(FuseFCActOneDNNPass, FuseWithGeluTanh) {
       const auto* op = node->Op();
       ASSERT_TRUE(op->HasAttr("use_mkldnn"));
       EXPECT_TRUE(PADDLE_GET_CONST(bool, op->GetAttr("use_mkldnn")));
-      ASSERT_TRUE(op->HasAttr("activation_type"));
+      ASSERT_TRUE(op->HasAttr("fuse_activation"));
       auto act_type =
-          PADDLE_GET_CONST(std::string, op->GetAttr("activation_type"));
+          PADDLE_GET_CONST(std::string, op->GetAttr("fuse_activation"));
       EXPECT_EQ(act_type.compare("gelu_tanh"), 0);
     }
   }
@@ -113,9 +113,9 @@ TEST(FuseFCActOneDNNPass, FuseWithGeluErf) {
       const auto* op = node->Op();
       ASSERT_TRUE(op->HasAttr("use_mkldnn"));
       EXPECT_TRUE(PADDLE_GET_CONST(bool, op->GetAttr("use_mkldnn")));
-      ASSERT_TRUE(op->HasAttr("activation_type"));
+      ASSERT_TRUE(op->HasAttr("fuse_activation"));
       auto act_type =
-          PADDLE_GET_CONST(std::string, op->GetAttr("activation_type"));
+          PADDLE_GET_CONST(std::string, op->GetAttr("fuse_activation"));
       EXPECT_EQ(act_type.compare("gelu_erf"), 0);
     }
   }
@@ -146,9 +146,9 @@ TEST(FuseFCActOneDNNPass, FuseWithGeluAuto) {
       const auto* op = node->Op();
       ASSERT_TRUE(op->HasAttr("use_mkldnn"));
       EXPECT_TRUE(PADDLE_GET_CONST(bool, op->GetAttr("use_mkldnn")));
-      ASSERT_TRUE(op->HasAttr("activation_type"));
+      ASSERT_TRUE(op->HasAttr("fuse_activation"));
       auto act_type =
-          PADDLE_GET_CONST(std::string, op->GetAttr("activation_type"));
+          PADDLE_GET_CONST(std::string, op->GetAttr("fuse_activation"));
       EXPECT_EQ(act_type.compare("gelu"), 0);
     }
   }
@@ -179,9 +179,9 @@ TEST(FuseFCActOneDNNPass, FuseWithTanh) {
       const auto* op = node->Op();
       ASSERT_TRUE(op->HasAttr("use_mkldnn"));
       EXPECT_TRUE(PADDLE_GET_CONST(bool, op->GetAttr("use_mkldnn")));
-      ASSERT_TRUE(op->HasAttr("activation_type"));
+      ASSERT_TRUE(op->HasAttr("fuse_activation"));
       auto act_type =
-          PADDLE_GET_CONST(std::string, op->GetAttr("activation_type"));
+          PADDLE_GET_CONST(std::string, op->GetAttr("fuse_activation"));
       EXPECT_EQ(act_type.compare("tanh"), 0);
     }
   }
@@ -213,9 +213,9 @@ TEST(FuseFCActOneDNNPass, FuseWithSigmoid) {
       const auto* op = node->Op();
       ASSERT_TRUE(op->HasAttr("use_mkldnn"));
       EXPECT_TRUE(PADDLE_GET_CONST(bool, op->GetAttr("use_mkldnn")));
-      ASSERT_TRUE(op->HasAttr("activation_type"));
+      ASSERT_TRUE(op->HasAttr("fuse_activation"));
       auto act_type =
-          PADDLE_GET_CONST(std::string, op->GetAttr("activation_type"));
+          PADDLE_GET_CONST(std::string, op->GetAttr("fuse_activation"));
       EXPECT_EQ(act_type.compare("sigmoid"), 0);
     }
   }
@@ -246,9 +246,9 @@ TEST(FuseFCActOneDNNPass, FuseWithMish) {
       const auto* op = node->Op();
       ASSERT_TRUE(op->HasAttr("use_mkldnn"));
       EXPECT_TRUE(PADDLE_GET_CONST(bool, op->GetAttr("use_mkldnn")));
-      ASSERT_TRUE(op->HasAttr("activation_type"));
+      ASSERT_TRUE(op->HasAttr("fuse_activation"));
       auto act_type =
-          PADDLE_GET_CONST(std::string, op->GetAttr("activation_type"));
+          PADDLE_GET_CONST(std::string, op->GetAttr("fuse_activation"));
       EXPECT_EQ(act_type.compare("mish"), 0);
     }
   }
@@ -280,9 +280,9 @@ TEST(FuseFCActOneDNNPass, FuseWithHardSwish) {
       const auto* op = node->Op();
       ASSERT_TRUE(op->HasAttr("use_mkldnn"));
       EXPECT_TRUE(PADDLE_GET_CONST(bool, op->GetAttr("use_mkldnn")));
-      ASSERT_TRUE(op->HasAttr("activation_type"));
+      ASSERT_TRUE(op->HasAttr("fuse_activation"));
       auto act_type =
-          PADDLE_GET_CONST(std::string, op->GetAttr("activation_type"));
+          PADDLE_GET_CONST(std::string, op->GetAttr("fuse_activation"));
       EXPECT_EQ(act_type.compare("hard_swish"), 0);
     }
   }
......
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/ir/mkldnn/operator_reshape2_onednn_fuse_pass.h"
#include "paddle/fluid/framework/op_version_registry.h"
#include "paddle/fluid/platform/mkldnn_reuse.h"
#include "paddle/fluid/string/pretty_log.h"
namespace paddle {
namespace framework {
namespace ir {
using string::PrettyLogDetail;
void FuseOperatorReshape2OneDNNPass::ApplyImpl(Graph *graph) const {
// THIS FUSE WILL WORK ONLY WITH OPERATORS THAT OUTPUT PLAIN MEMORY, E.G.
// ABCD FOR 4D! BE AWARE OF THAT!
std::vector<std::pair<std::string, int>> ops_and_outputs = {
{"fc", 1}, {"transpose2", 2}};
for (const auto &op_and_outputs : ops_and_outputs)
FuseReshape2(graph, op_and_outputs.first, op_and_outputs.second);
}
void FuseOperatorReshape2OneDNNPass::FuseReshape2(Graph *graph,
const std::string &op_type,
int num_of_outputs) const {
PADDLE_ENFORCE_NOT_NULL(
graph, platform::errors::InvalidArgument("Graph cannot be nullptr."));
FusePassBase::Init(op_type + "_reshape2_onednn_fuse_pass", graph);
GraphPatternDetector gpd;
patterns::OperatorReshape2 op_reshape2_pattern(
gpd.mutable_pattern(), op_type + "_reshape2_onednn_fuse_pass");
op_reshape2_pattern(op_type, num_of_outputs);
int found_operator_reshape2_count = 0;
auto handler = [&](const GraphPatternDetector::subgraph_t &subgraph,
Graph *g) {
GET_IR_NODE_FROM_SUBGRAPH(operator_op, preceding_op, op_reshape2_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
operator_out, preceding_op_out, op_reshape2_pattern);
GET_IR_NODE_FROM_SUBGRAPH(reshape2_op, reshape2_op, op_reshape2_pattern);
GET_IR_NODE_FROM_SUBGRAPH(reshape2_out, reshape2_out, op_reshape2_pattern);
if (!operator_op->Op()->HasAttr("use_mkldnn") ||
(operator_op->Op()->HasAttr("use_mkldnn") &&
!(PADDLE_GET_CONST(bool, operator_op->Op()->GetAttr("use_mkldnn"))))) {
VLOG(4) << "Only oneDNN version of " << op_type
<< "can be fused with reshape2.";
return;
}
if (operator_op->Op()->HasAttr("fused_unsqueeze2_axes")) {
VLOG(4) << "Cannot do " << op_type << " + reshape2 fuse, because "
<< op_type << " is already fused with unsqueeze2!";
return;
}
std::vector<int> reshape2_shape =
PADDLE_GET_CONST(std::vector<int>, reshape2_op->Op()->GetAttr("shape"));
int num_of_minus_ones = 0;
for (size_t i = 0; i < reshape2_shape.size(); ++i) {
if (reshape2_shape[i] == 0) {
VLOG(4) << "OneDNN op+reshape2 fuse pass does not support zero dims, "
"skipping";
return;
} else if (reshape2_shape[i] == -1) {
++num_of_minus_ones;
}
}
if (num_of_minus_ones > 1) {
VLOG(4) << "Number of -1 values inside of reshape2 shouldn't be greater "
"than one in op+reshape2 oneDNN fuse pass, skipping";
return;
}
auto const &names = reshape2_op->Op()->InputNames();
bool has_shape_tensor =
std::find(names.begin(), names.end(), "ShapeTensor") != names.end();
bool has_shape_tensor_list =
std::find(names.begin(), names.end(), "ShapeTensorList") != names.end();
if (has_shape_tensor &&
reshape2_op->Op()->Input("ShapeTensor").size() > 0) {
VLOG(4) << "Cannot fuse " << op_type
<< " and reshape2 because reshape2 dims are specified by "
"ShapeTensor!";
return;
}
if (has_shape_tensor_list &&
reshape2_op->Op()->Input("ShapeTensorList").size() > 0) {
VLOG(4) << "Cannot fuse " << op_type
<< " and reshape2 because reshape2 dims are specified by "
"ShapeTensorList!";
return;
}
operator_op->Op()->SetAttr("fused_reshape2_shape", reshape2_shape);
operator_op->Op()->SetOutput("Out", {reshape2_out->Name()});
IR_OP_VAR_LINK(operator_op, reshape2_out);
GraphSafeRemoveNodes(g, {reshape2_op, operator_out});
found_operator_reshape2_count++;
};
gpd(graph, handler);
AddStatis(found_operator_reshape2_count);
if ((!Has("disable_logs") || !Get<bool>("disable_logs")) &&
found_operator_reshape2_count > 0)
PrettyLogDetail("--- fused %d %s with reshape2",
found_operator_reshape2_count,
op_type);
}
} // namespace ir
} // namespace framework
} // namespace paddle
REGISTER_PASS(operator_reshape2_onednn_fuse_pass,
paddle::framework::ir::FuseOperatorReshape2OneDNNPass);
REGISTER_PASS_CAPABILITY(operator_reshape2_onednn_fuse_pass)
.AddCombination(
paddle::framework::compatible::OpVersionComparatorCombination()
.GE("reshape2", 0)
.GE("fc", 0));
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/fluid/framework/ir/fuse_pass_base.h"
#include "paddle/fluid/framework/ir/graph.h"
namespace paddle {
namespace framework {
namespace ir {
class FuseOperatorReshape2OneDNNPass : public FusePassBase {
public:
virtual ~FuseOperatorReshape2OneDNNPass() {}
protected:
void ApplyImpl(Graph *graph) const override;
void FuseReshape2(Graph *graph,
const std::string &op_type,
int num_of_outputs) const;
};
} // namespace ir
} // namespace framework
} // namespace paddle
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/ir/mkldnn/operator_unsqueeze2_onednn_fuse_pass.h"
#include "paddle/fluid/framework/op_version_registry.h"
#include "paddle/fluid/platform/mkldnn_reuse.h"
#include "paddle/fluid/string/pretty_log.h"
namespace paddle {
namespace framework {
namespace ir {
using string::PrettyLogDetail;
void FuseOperatorUnsqueeze2OneDNNPass::ApplyImpl(Graph *graph) const {
std::vector<std::pair<std::string, int>> ops_and_outputs = {
{"transpose2", 2}, {"elementwise_mul", 1}};
for (const auto &op_and_outputs : ops_and_outputs)
FuseUnsqueeze2(graph, op_and_outputs.first, op_and_outputs.second);
}
void FuseOperatorUnsqueeze2OneDNNPass::FuseUnsqueeze2(
Graph *graph, const std::string &op_type, int num_of_outputs) const {
PADDLE_ENFORCE_NOT_NULL(
graph, platform::errors::InvalidArgument("Graph cannot be nullptr."));
FusePassBase::Init(op_type + "_unsqueeze2_onednn_fuse_pass", graph);
GraphPatternDetector gpd;
patterns::OperatorUnsqueeze2 op_unsqueeze2_pattern(
gpd.mutable_pattern(), op_type + "_unsqueeze2_onednn_fuse_pass");
op_unsqueeze2_pattern(op_type, num_of_outputs);
int found_operator_unsqueeze2_count = 0;
auto handler = [&](const GraphPatternDetector::subgraph_t &subgraph,
Graph *g) {
GET_IR_NODE_FROM_SUBGRAPH(operator_op, preceding_op, op_unsqueeze2_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
operator_out, preceding_op_out, op_unsqueeze2_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
unsqueeze2_op, unsqueeze2_op, op_unsqueeze2_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
unsqueeze2_out, unsqueeze2_out, op_unsqueeze2_pattern);
if (!operator_op->Op()->HasAttr("use_mkldnn") ||
(operator_op->Op()->HasAttr("use_mkldnn") &&
!(PADDLE_GET_CONST(bool, operator_op->Op()->GetAttr("use_mkldnn"))))) {
VLOG(4) << "Only oneDNN version of " << op_type
<< "can be fused with unsqueeze2.";
return;
}
std::vector<int> unsqueeze2_axes = PADDLE_GET_CONST(
std::vector<int>, unsqueeze2_op->Op()->GetAttr("axes"));
auto const &names = unsqueeze2_op->Op()->InputNames();
bool has_axes_tensor =
std::find(names.begin(), names.end(), "AxesTensor") != names.end();
bool has_axes_tensor_list =
std::find(names.begin(), names.end(), "AxesTensorList") != names.end();
if (has_axes_tensor &&
unsqueeze2_op->Op()->Input("AxesTensor").size() > 0) {
VLOG(4) << "Cannot fuse " << op_type
<< " and unsqueeze2 because unsqueeze2 dims are specified by "
"AxesTensor!";
return;
}
if (has_axes_tensor_list &&
unsqueeze2_op->Op()->Input("AxesTensorList").size() > 0) {
VLOG(4) << "Cannot fuse " << op_type
<< " and unsqueeze2 because unsqueeze2 dims are specified by "
"AxesTensorList!";
return;
}
operator_op->Op()->SetAttr("fused_unsqueeze2_axes", unsqueeze2_axes);
operator_op->Op()->SetOutput("Out", {unsqueeze2_out->Name()});
IR_OP_VAR_LINK(operator_op, unsqueeze2_out);
GraphSafeRemoveNodes(g, {unsqueeze2_op, operator_out});
found_operator_unsqueeze2_count++;
};
gpd(graph, handler);
AddStatis(found_operator_unsqueeze2_count);
if ((!Has("disable_logs") || !Get<bool>("disable_logs")) &&
found_operator_unsqueeze2_count > 0)
PrettyLogDetail("--- fused %d %s with unsqueeze2",
found_operator_unsqueeze2_count,
op_type);
}
} // namespace ir
} // namespace framework
} // namespace paddle
REGISTER_PASS(operator_unsqueeze2_onednn_fuse_pass,
paddle::framework::ir::FuseOperatorUnsqueeze2OneDNNPass);
REGISTER_PASS_CAPABILITY(operator_unsqueeze2_onednn_fuse_pass)
.AddCombination(
paddle::framework::compatible::OpVersionComparatorCombination()
.GE("unsqueeze2", 0)
.GE("transpose2", 0));
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/fluid/framework/ir/fuse_pass_base.h"
#include "paddle/fluid/framework/ir/graph.h"
namespace paddle {
namespace framework {
namespace ir {
class FuseOperatorUnsqueeze2OneDNNPass : public FusePassBase {
public:
virtual ~FuseOperatorUnsqueeze2OneDNNPass() {}
protected:
void ApplyImpl(Graph *graph) const override;
void FuseUnsqueeze2(Graph *graph,
const std::string &op_type,
int num_of_outputs) const;
};
} // namespace ir
} // namespace framework
} // namespace paddle
@@ -327,6 +327,8 @@ void CpuPassStrategy::EnableMKLDNN() {
           "shuffle_channel_mkldnn_detect_pass",  //
           "elt_act_mkldnn_fuse_pass",            //
           "operator_scale_onednn_fuse_pass",     //
+          "operator_unsqueeze2_onednn_fuse_pass",  //
+          "operator_reshape2_onednn_fuse_pass",    //
           // TODO(intel): Please fix the bug on windows.
           // https://github.com/PaddlePaddle/Paddle/issues/29710
           // "mkldnn_inplace_pass",  // This pass should be activated after
@@ -421,6 +423,8 @@ void CpuPassStrategy::EnableMkldnnInt8() {
     passes_.push_back("reshape_transpose_matmul_mkldnn_fuse_pass");
     passes_.push_back("matmul_elementwise_add_mkldnn_fuse_pass");
     passes_.push_back("operator_scale_onednn_fuse_pass");
+    passes_.push_back("operator_unsqueeze2_onednn_fuse_pass");
+    passes_.push_back("operator_reshape2_onednn_fuse_pass");
     passes_.push_back("cpu_quantize_placement_pass");
     passes_.push_back("cpu_quantize_pass");
     passes_.push_back("cpu_quantize_squash_pass");
......
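A hedged usage note: with the additions above, the two new fuse passes are scheduled automatically whenever oneDNN is enabled on an inference config; nothing has to be registered by the user. A minimal sketch under that assumption (model paths are placeholders; API names as in the Paddle Inference C++ API of this release line):

#include "paddle_inference_api.h"  // paddle_infer::Config / CreatePredictor

int main() {
  paddle_infer::Config config("./model.pdmodel", "./model.pdiparams");
  config.EnableMKLDNN();  // CpuPassStrategy::EnableMKLDNN() now queues
                          // operator_unsqueeze2_onednn_fuse_pass and
                          // operator_reshape2_onednn_fuse_pass as well
  auto predictor = paddle_infer::CreatePredictor(config);
  return predictor != nullptr ? 0 : 1;
}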
@@ -129,12 +129,14 @@ class EltwiseMKLDNNKernel : public framework::OpKernel<T> {
     astream.wait();

     if (handler.use_broadcasting_hack == false) {
-      z->set_mem_desc(dst_memory->get_desc());
+      platform::SetOutMemDescWithLogicalLayoutFusesSupport(
+          ctx, z, dst_memory->get_desc());
     } else {
       auto dims = dst_memory->get_desc().dims();
       dims.insert(dims.begin(), x->dims()[0]);
       dims[1] /= dims[0];
-      z->set_mem_desc(dst_memory->get_desc().reshape(dims));
+      platform::SetOutMemDescWithLogicalLayoutFusesSupport(
+          ctx, z, dst_memory->get_desc().reshape(dims));
     }
   }
 };
......
...@@ -16,10 +16,7 @@ limitations under the License. */ ...@@ -16,10 +16,7 @@ limitations under the License. */
#include "paddle/fluid/operators/fc_op.h" #include "paddle/fluid/operators/fc_op.h"
#include "paddle/fluid/platform/mkldnn_helper.h" #include "paddle/fluid/platform/mkldnn_helper.h"
#include "paddle/fluid/platform/mkldnn_reuse.h"
namespace phi {
class DenseTensor;
} // namespace phi
namespace paddle { namespace paddle {
namespace operators { namespace operators {
...@@ -29,393 +26,131 @@ using dnnl::memory; ...@@ -29,393 +26,131 @@ using dnnl::memory;
using dnnl::primitive; using dnnl::primitive;
using dnnl::prop_kind; using dnnl::prop_kind;
using dnnl::stream; using dnnl::stream;
using framework::DataLayout;
using framework::DDim; using framework::DDim;
using framework::ExecutionContext; using framework::ExecutionContext;
using framework::LoDTensor; using LoDTensor = phi::DenseTensor;
using framework::Tensor;
using platform::GetMKLDNNFormat; using platform::GetMKLDNNFormat;
using platform::MKLDNNDeviceContext; using platform::MKLDNNDeviceContext;
using platform::MKLDNNGetDataType;
using platform::to_void_cast; using platform::to_void_cast;
template <typename T>
constexpr bool IsInt8() {
return std::is_same<T, int8_t>::value || std::is_same<T, uint8_t>::value;
}
struct InnerProductCache {
dnnl::inner_product_forward inner_product_p;
dnnl::memory src_mem;
dnnl::memory weights_mem;
dnnl::memory bias_mem;
dnnl::memory dst_mem;
};
template <typename T_in, typename T_w, typename T_out> template <typename T_in, typename T_w, typename T_out>
class FCPrimitiveFactory { class FCMKLDNNHandler
: public platform::MKLDNNHandlerNoCachingT<T_in,
dnnl::inner_product_forward> {
public: public:
explicit FCPrimitiveFactory(const dnnl::engine& engine) : engine_(engine) {} FCMKLDNNHandler(const paddle::framework::ExecutionContext& ctx,
const platform::MKLDNNDeviceContext& dev_ctx,
void ExecuteFcPrimitive(const LoDTensor* input, const phi::DenseTensor* x,
const Tensor* weights, const phi::DenseTensor* weights,
const Tensor* bias, const phi::DenseTensor* bias,
LoDTensor* output, phi::DenseTensor* out,
const MKLDNNDeviceContext& dev_ctx, const int in_num_col_dims,
const ExecutionContext& ctx) { dnnl::engine mkldnn_engine,
RecomputeOutputDims(ctx, input, weights, output); platform::Place cpu_place)
// If primitive has already been created and cached, don't create new one, : platform::MKLDNNHandlerNoCachingT<T_in, dnnl::inner_product_forward>(
// but update input and output data pointers and return it. mkldnn_engine, cpu_place),
if (fc_) { dev_ctx_(dev_ctx) {
UpdateDataPointers(ctx, output, input); this->memory_key_ = ctx.InputName("W");
this->Execute();
return; auto x_vec_dims = phi::vectorize(x->dims());
} // Otherwise, create a new one. auto weights_vec_dims = phi::vectorize(weights->dims());
auto in_col_dims = ctx.Attr<int>("in_num_col_dims"); int MB = 1;
PADDLE_ENFORCE_LE( for (int i = 0; i < in_num_col_dims; ++i) {
in_col_dims, MB *= x_vec_dims[i];
2,
platform::errors::Unimplemented(
"DNNL FC doesn't support in_num_col_dims parameter to "
"be higher than "
"2."));
if (in_col_dims == 2) {
PADDLE_ENFORCE_EQ(
input->dims().size(),
3,
platform::errors::Unimplemented(
"DNNL FC only supports in_num_col_dims equal to 2 when "
"3 dim input is provided."));
PADDLE_ENFORCE_EQ(
input->format(),
MKLDNNMemoryFormat::ncw,
platform::errors::Unimplemented(
"DNNL FC only supports in_num_col_dims equal to 2 when "
"input format is equal to ncw."));
} }
weights_ = CreateWeightsMemory(weights); int IC = 1;
for (size_t i = in_num_col_dims; i < x_vec_dims.size(); ++i) {
// Since MKL-DNN has a lot of limitations on what the input/weights/output IC *= x_vec_dims[i];
// dimensions should be, to simplify the code, the creation of primitive
// descriptor has been divided into separate cases, based on the number
// of input dimensions.
size_t input_dim_num = input->dims().size();
paddle::optional<dnnl::inner_product_forward::primitive_desc> fc_prim_desc;
memory::desc usr_weights_desc = {};
switch (input_dim_num) {
case 2:
fc_prim_desc =
Create2DFcPrimDescriptor(input, weights, bias, output, ctx);
usr_weights_desc = Create2DUserWeightsDesc();
break;
case 3:
fc_prim_desc =
Create3DFcPrimDescriptor(input, weights, bias, output, ctx);
usr_weights_desc = Create3DUserWeightsDesc(weights);
break;
case 4:
fc_prim_desc =
Create4DFcPrimDescriptor(input, weights, bias, output, ctx);
usr_weights_desc = Create4DUserWeightsDesc(input, weights);
break;
default:
PADDLE_THROW(platform::errors::Unimplemented(
"DNNL FC doesn't support input dims different than 2, 3, 4."));
break;
} }
input_ = CreateMemory<T_in>(fc_prim_desc->src_desc(), input);
// Update weights format inside of its memory
weights_ = Reorder(
usr_weights_desc, usr_weights_desc, weights_->get_data_handle());
// Quantize weights and reorder to format chosen by FC primitive descriptor.
QuantizeWeights(ctx, fc_prim_desc->weights_desc());
bias_ = CreateMemoryToBeCached<float>(fc_prim_desc->bias_desc(), bias);
// If int8 is desired, quantize bias into 32-bit signed int
QuantizeBias(*fc_prim_desc, ctx);
// Store weights and bias in the mkldnn cache
CacheWeightsAndBias(dev_ctx, ctx);
// Based on format determined by inner_product, create output in desired
// memory format
output_ = CreateDstMemory(*fc_prim_desc, ctx, output);
// Return MKL-DNN primitive ready to be fed into pipeline and executed
fc_ = inner_product_forward(*fc_prim_desc);
this->Execute();
}
void Execute() { int OC = weights_vec_dims[1];
auto& astream = platform::MKLDNNDeviceContext::tls().get_stream();
if (bias_) {
fc_->execute(astream,
{{DNNL_ARG_SRC, *input_},
{DNNL_ARG_WEIGHTS, *weights_},
{DNNL_ARG_BIAS, *bias_},
{DNNL_ARG_DST, *output_}});
} else {
fc_->execute(astream,
{{DNNL_ARG_SRC, *input_},
{DNNL_ARG_WEIGHTS, *weights_},
{DNNL_ARG_DST, *output_}});
}
astream.wait();
}
private: dnnl::memory::desc bias_md;
// DNNL always returns 2-dimensional data block as a result of computing
// inner product. Hence the format 'nc' is always set for its output
// primitive. Therefore, function SetOutputFormat is needed to choose
// an appropriate format based on the number of input dimensions and
// format of an input tensor.
void SetOutputFormat(MKLDNNMemoryFormat in_format, Tensor* out) {
int dim_num = out->dims().size();
// In case of 2 dims, we set the only possible format, nc
if (dim_num == 2) {
out->set_format(MKLDNNMemoryFormat::nc);
out->set_mem_desc({phi::vectorize(out->dims()),
platform::MKLDNNGetDataType<T_out>(),
out->format()});
// In case of 3 dims, we generate a format that is based on number
// of output dims and the layout of input format (nchw or nhwc).
} else if (dim_num == 3) {
if (in_format == MKLDNNMemoryFormat::nwc ||
in_format == MKLDNNMemoryFormat::nhwc) {
out->set_format(
platform::MKLDNNFormatForSize(dim_num, MKLDNNMemoryFormat::nhwc));
} else {
out->set_format(
platform::MKLDNNFormatForSize(dim_num, MKLDNNMemoryFormat::nchw));
}
// In any other case we overwrite the output format with the input one.
} else {
out->set_format(in_format);
}
}
void UpdateDataPointers(const ExecutionContext& ctx, auto src_md = dnnl::memory::desc(
Tensor* out, {MB, IC}, MKLDNNGetDataType<T_in>(), dnnl::memory::format_tag::any);
const Tensor* in) { auto weights_md = dnnl::memory::desc(
input_->set_data_handle(to_void_cast(in->data<T_in>())); {OC, IC}, MKLDNNGetDataType<T_w>(), dnnl::memory::format_tag::any);
output_->set_data_handle(out->mutable_data<T_out>(ctx.GetPlace())); auto dst_md = dnnl::memory::desc(
// If the primitive exists, but the output tensor has changed its {MB, OC}, MKLDNNGetDataType<T_out>(), dnnl::memory::format_tag::any);
// variable, update its format to what has been determined in first if (bias) {
// call to CreateFcPrimitive method. bias_md = dnnl::memory::desc({bias->numel()},
if (out->format() == MKLDNNMemoryFormat::undef) { MKLDNNGetDataType<float>(),
SetOutputFormat(in->format(), out); dnnl::memory::format_tag::a);
} }
}
dnnl::inner_product_forward::primitive_desc Create2DFcPrimDescriptor(
const LoDTensor* input,
const Tensor* weights,
const Tensor* bias,
LoDTensor* output,
const ExecutionContext& ctx) {
auto src_desc = CreateMemDescriptor<T_in>(input, MKLDNNMemoryFormat::any);
auto weight_dims = Get2DWeightDimsForDNNL(weights);
auto weights_desc =
CreateMemDescriptor<T_w>(weight_dims, MKLDNNMemoryFormat::any);
auto bias_desc = CreateMemDescriptor<float>(bias, MKLDNNMemoryFormat::x);
auto dst_desc = CreateMemDescriptor<T_out>(output, MKLDNNMemoryFormat::any);
const auto attrs = CreateFCAttrs(ctx);
return CreateFcPrimDesc(src_desc, weights_desc, bias_desc, dst_desc, attrs);
}
std::vector<int64_t> Get2DWeightDimsForDNNL(const Tensor* weights) {
auto dims = phi::vectorize(weights->dims());
std::swap(dims[0], dims[1]); // swap input dim with output dim
return dims;
}
memory::desc Create2DUserWeightsDesc() { return weights_->get_desc(); }
dnnl::inner_product_forward::primitive_desc Create3DFcPrimDescriptor(
const LoDTensor* input,
const Tensor* weights,
const Tensor* bias,
LoDTensor* output,
const ExecutionContext& ctx) {
auto input_dims = phi::vectorize(input->dims());
std::vector<int64_t> new_input_dims = {
input_dims[0] * input_dims[1], input_dims[2], 1};
auto src_desc =
CreateMemDescriptor<T_in>(new_input_dims, MKLDNNMemoryFormat::any);
auto weight_dims = Get3DWeightDimsForDNNL(weights);
auto weights_desc =
CreateMemDescriptor<T_w>(weight_dims, MKLDNNMemoryFormat::any);
auto bias_desc = CreateMemDescriptor<float>(bias, MKLDNNMemoryFormat::x);
auto dst_dims = {input_dims[0] * input_dims[1], weight_dims[0]};
auto dst_desc =
CreateMemDescriptor<T_out>(dst_dims, MKLDNNMemoryFormat::any);
const auto attrs = CreateFCAttrs(ctx);
return CreateFcPrimDesc(src_desc, weights_desc, bias_desc, dst_desc, attrs);
}
std::vector<int64_t> Get3DWeightDimsForDNNL(const Tensor* weights) {
auto paddle_w_dims = phi::vectorize(weights->dims());
return {paddle_w_dims[1], paddle_w_dims[0], 1};
}
memory::desc Create3DUserWeightsDesc(const Tensor* weights) {
auto dims = Get3DWeightDimsForDNNL(weights);
return CreateMemDescriptor<float>(dims, MKLDNNMemoryFormat::oiw);
}
dnnl::inner_product_forward::primitive_desc Create4DFcPrimDescriptor(
const LoDTensor* input,
const Tensor* weights,
const Tensor* bias,
LoDTensor* output,
const ExecutionContext& ctx) {
auto src_desc = CreateMemDescriptor<T_in>(input, MKLDNNMemoryFormat::any);
// Since MKL-DNN doesn't support 4D column-major data formats in
// inner_product primitive, transpose the weights to be in
// row-major format
auto dims = Get4DWeightDimsForDNNL(input, weights);
auto weights_desc = CreateMemDescriptor<T_w>(dims, MKLDNNMemoryFormat::any);
auto bias_desc = CreateMemDescriptor<float>(bias, MKLDNNMemoryFormat::x);
auto dst_desc = CreateMemDescriptor<T_out>(output, MKLDNNMemoryFormat::any);
const auto attrs = CreateFCAttrs(ctx); const auto attrs = CreateFCAttrs(ctx);
return CreateFcPrimDesc(src_desc, weights_desc, bias_desc, dst_desc, attrs);
}
std::vector<int64_t> Get4DWeightDimsForDNNL(const LoDTensor* input,
const Tensor* weights) {
auto old_w_dims = phi::vectorize(weights->dims());
auto old_in_dims = phi::vectorize(input->dims());
auto dims = {old_w_dims[1], old_in_dims[1], old_in_dims[2], old_in_dims[3]};
return dims;
}
memory::desc Create4DUserWeightsDesc(const LoDTensor* input,
const Tensor* weights) {
auto dims = Get4DWeightDimsForDNNL(input, weights);
return CreateMemDescriptor<float>(dims, MKLDNNMemoryFormat::oihw);
}
// Convert data from one data format to another this->AcquireForwardPrimitiveDescriptor(attrs,
std::shared_ptr<dnnl::memory> Reorder(const memory::desc& src_desc, prop_kind::forward_inference,
const memory::desc& dst_desc, src_md,
void* src_data) { weights_md,
auto src_mem = memory(src_desc, engine_, src_data); bias_md,
auto dst_mem = std::make_shared<memory>(dst_desc, engine_); dst_md);
auto reorder = dnnl::reorder(src_mem, *dst_mem);
auto& astream = platform::MKLDNNDeviceContext::tls().get_stream();
{
platform::RecordEvent record_reorder(
"int_reorder",
platform::TracerEventType::UserDefined,
2,
platform::EventRole::kUniqueOp);
reorder.execute(astream, src_mem, *dst_mem);
astream.wait();
}
return dst_mem;
} }
// Convert data from one data format to another and rescale it. private:
// If the desired data type is (un)signed int8, quantization occurs here. dnnl::primitive_attr CreateFCAttrs(const ExecutionContext& ctx) {
std::shared_ptr<dnnl::memory> ReorderWithScale(
const std::shared_ptr<memory> src_mem,
const memory::desc& dst_md,
const std::vector<float>& scale_data) {
auto dst_mem = std::make_shared<dnnl::memory>(dst_md, engine_);
dnnl::primitive_attr attributes; dnnl::primitive_attr attributes;
// According to MKL-DNN's documentation mask determines along which dnnl::post_ops post_operations;
// dimensions should the scale be applied.
// 0 - Single scale applied to whole tensor
// 1 - Apply Scale along a slice of each dimension which index is 1.
// In case of weights quantization, that dimension is output,
// becuase we perform per-output-channel quantization
int mask = CreateMask(0, scale_data.size() > 1);
attributes.set_output_scales(mask, scale_data);
auto reorder = dnnl::reorder(*src_mem, *dst_mem, attributes);
auto& astream = platform::MKLDNNDeviceContext::tls().get_stream(); std::vector<float> output_shift_scale;
{ float scale = 1.0f;
platform::RecordEvent record_reorder( if (IsInt8<T_w>()) {
"int_reorder", std::tie(output_shift_scale, scale) = ComputeOutputShiftScale(ctx);
platform::TracerEventType::UserDefined, int mask = CreateMask(1, output_shift_scale.size() > 1);
2, attributes.set_output_scales(mask, output_shift_scale);
platform::EventRole::kUniqueOp);
reorder.execute(astream,
{{DNNL_ARG_FROM, *src_mem}, {DNNL_ARG_TO, *dst_mem}});
astream.wait();
} }
return dst_mem; float sum_scale = 1.0f;
} if (ctx.HasAttr("fuse_residual_connection") &&
ctx.Attr<bool>("fuse_residual_connection")) {
template <typename T> post_operations.append_sum(sum_scale);
static dnnl::memory::desc CreateMemDescriptor( }
const std::vector<int64_t>& dims, MKLDNNMemoryFormat format) {
return platform::MKLDNNMemDesc(
dims, platform::MKLDNNGetDataType<T>(), format);
}
template <typename T>
static dnnl::memory::desc CreateMemDescriptor(const Tensor* tensor,
MKLDNNMemoryFormat format) {
auto dims = phi::vectorize(tensor->dims());
return CreateMemDescriptor<T>(dims, format);
}
template <typename T>
dnnl::memory CreateMemory(const dnnl::memory::desc& desc,
const Tensor* tensor) {
return CreateMemory(desc, platform::to_void_cast<T>(tensor->data<T>()));
}
dnnl::memory CreateMemory(const dnnl::memory::desc& desc, void* data) {
return memory(desc, engine_, data);
}
template <typename T>
std::shared_ptr<dnnl::memory> CreateMemoryToBeCached(
const dnnl::memory::desc& desc, const Tensor* tensor) {
return CreateMemoryToBeCached(desc,
platform::to_void_cast<T>(tensor->data<T>()));
}
std::shared_ptr<dnnl::memory> CreateMemoryToBeCached(
const dnnl::memory::desc& desc, void* data) {
return std::make_shared<memory>(desc, engine_, data);
}
// Create weights memory and transform to default MKL-DNN format // ReLU from "fc_fuse_pass"
std::shared_ptr<dnnl::memory> CreateWeightsMemory(const Tensor* weights) { if (ctx.Attr<std::string>("activation_type") == "relu") {
auto dims = phi::vectorize(weights->dims()); post_operations.append_eltwise(
std::swap(dims[0], dims[1]); // Correct output dimensions scale, dnnl::algorithm::eltwise_relu, 0.0f, 0.0f);
auto src_desc = CreateMemDescriptor<float>(dims, MKLDNNMemoryFormat::io); }
auto dst_desc = CreateMemDescriptor<float>(dims, MKLDNNMemoryFormat::oi); platform::AppendActivation(ctx, post_operations, scale);
// Transpose weights through MKL-DNN's reorder from io to oi format.
return Reorder(src_desc,
dst_desc,
platform::to_void_cast<float>(weights->data<float>()));
}
void CacheWeightsAndBias(const MKLDNNDeviceContext& dev_ctx, if (ctx.HasAttr("fused_output_scale")) {
const ExecutionContext& ctx) { float scale_alpha = ctx.Attr<float>("fused_output_scale");
std::string key = platform::CreateKey(dev_ctx); post_operations.append_eltwise(
key = platform::ExtendKeyWithThreadInfoIfNeeded(dev_ctx, key); 1.0, dnnl::algorithm::eltwise_linear, scale_alpha, 0.0f);
}
const std::string weights_key = key + ctx.InputName("W"); attributes.set_post_ops(post_operations);
const std::string bias_key = key + ctx.InputName("Bias"); return attributes;
dev_ctx.SetBlob(weights_key, weights_);
dev_ctx.SetBlob(bias_key, bias_);
} }
// Compute the bias scales so that its values correspond to the // Compute the bias scales so that its values correspond to the
// scale of data being an output of weights and input multiplication // scale of data being an output of weights and input multiplication
std::vector<float> ComputeBiasScales(const ExecutionContext& ctx) { std::vector<float> ComputeBiasScales(
auto scale_in_data = ctx.Attr<float>("Scale_in"); const float scale_in, const std::vector<float>& scale_weights) {
auto scale_weights_data = ctx.Attr<std::vector<float>>("Scale_weights"); std::vector<float> bias_scales(scale_weights.size());
const size_t weight_scales_num = scale_weights_data.size();
std::vector<float> bias_scales(weight_scales_num);
#pragma omp parallel for for (size_t i = 0; i < bias_scales.size(); ++i) {
for (size_t i = 0; i < weight_scales_num; i++) { if (scale_weights[i] == 0.0)
if (scale_weights_data[i] == 0.0)
bias_scales[i] = 1.0f; bias_scales[i] = 1.0f;
else else
bias_scales[i] = scale_in_data * scale_weights_data[i]; bias_scales[i] = scale_in * scale_weights[i];
} }
return bias_scales; return bias_scales;
...@@ -442,18 +177,16 @@ class FCPrimitiveFactory { ...@@ -442,18 +177,16 @@ class FCPrimitiveFactory {
? 1.0f ? 1.0f
: ctx.Attr<float>("Scale_out"); : ctx.Attr<float>("Scale_out");
const size_t weight_scales_num = scale_weights_data.size(); const size_t weight_scales_num = scale_weights_data.size();
std::vector<float> output_shift_scale(weight_scales_num);
#pragma omp parallel for for (size_t i = 0; i < weight_scales_num; ++i) {
for (size_t i = 0; i < weight_scales_num; i++) {
if (scale_weights_data[i] == 0.0) if (scale_weights_data[i] == 0.0)
output_shift_scale[i] = inner_scale; scale_weights_data[i] = inner_scale;
else else
output_shift_scale[i] = scale_weights_data[i] =
inner_scale / (scale_in_data * scale_weights_data[i]); inner_scale / (scale_in_data * scale_weights_data[i]);
} }
return make_tuple(output_shift_scale, scale); return make_tuple(scale_weights_data, scale);
} }
// Computing MKL-DNN's scaling mask which determines along which dimension // Computing MKL-DNN's scaling mask which determines along which dimension
...@@ -464,137 +197,300 @@ class FCPrimitiveFactory { ...@@ -464,137 +197,300 @@ class FCPrimitiveFactory {
return is_multi_channel_quantizied ? 1 << slice_dimension : 0; return is_multi_channel_quantizied ? 1 << slice_dimension : 0;
} }
void QuantizeWeights(const ExecutionContext& ctx, memory::desc dst) { std::shared_ptr<dnnl::memory> AcquireMemoryWithReorderAndAttrs(
weights_ = ReorderWithScale( const dnnl::memory::desc& user_md,
weights_, dst, ctx.Attr<std::vector<float>>("Scale_weights")); const dnnl::memory::desc& target_md,
} void* ptr,
const dnnl::primitive_attr& attrs) {
std::shared_ptr<dnnl::memory> target_memory_p;
void QuantizeBias(const inner_product_forward::primitive_desc& fc_prim_desc, auto user_memory_p =
const ExecutionContext& ctx) { std::make_shared<dnnl::memory>(user_md, this->engine_, ptr);
auto bias_scales = ComputeBiasScales(ctx); target_memory_p = std::make_shared<dnnl::memory>(target_md, this->engine_);
bias_ = ReorderWithScale(bias_, fc_prim_desc.bias_desc(), bias_scales); auto reorder_p = std::make_shared<dnnl::reorder>(
} *user_memory_p, *target_memory_p, attrs);
dnnl::primitive_attr CreateFCAttrs(const ExecutionContext& ctx) { auto& astream = platform::MKLDNNDeviceContext::tls().get_stream();
dnnl::primitive_attr attributes; reorder_p->execute(
dnnl::post_ops post_operations; astream,
{{DNNL_ARG_FROM, *user_memory_p}, {DNNL_ARG_TO, *target_memory_p}});
astream.wait();
std::vector<float> output_shift_scale; return target_memory_p;
float scale; }
std::tie(output_shift_scale, scale) = ComputeOutputShiftScale(ctx);
int mask = CreateMask(1, output_shift_scale.size() > 1);
attributes.set_output_scales(mask, output_shift_scale);
float sum_scale = 1.0f; std::string memory_key_;
if (ctx.HasAttr("fuse_residual_connection") && const platform::MKLDNNDeviceContext& dev_ctx_;
ctx.Attr<bool>("fuse_residual_connection")) {
post_operations.append_sum(sum_scale);
}
if (ctx.Attr<std::string>("activation_type") == "relu") { public:
constexpr float negative_slope = 0.0f; std::shared_ptr<dnnl::memory> AcquireSrcMemoryWithReorder(
constexpr float placeholder = 1.0f; // beta const phi::DenseTensor* x) {
post_operations.append_eltwise( const T_in* x_data = x->data<T_in>();
scale, dnnl::algorithm::eltwise_relu, negative_slope, placeholder);
} else if (ctx.Attr<std::string>("activation_type") == "gelu") { auto user_md = x->mem_desc();
constexpr float alpha = 0.0f; if (x->dims().size() != 2) {
constexpr float beta = 0.0f; // reshape restrictions are always satisfied because in case of 3 or 4 dim
post_operations.append_eltwise( // input, plain layout is enforced
scale, dnnl::algorithm::eltwise_gelu, alpha, beta); user_md = user_md.reshape(this->fwd_pd_->src_desc().dims());
} else if (ctx.Attr<std::string>("activation_type") == "gelu_tanh") {
constexpr float alpha = 0.0f;
constexpr float beta = 0.0f;
post_operations.append_eltwise(
scale, dnnl::algorithm::eltwise_gelu_tanh, alpha, beta);
} else if (ctx.Attr<std::string>("activation_type") == "gelu_erf") {
constexpr float alpha = 0.0f;
constexpr float beta = 0.0f;
post_operations.append_eltwise(
scale, dnnl::algorithm::eltwise_gelu_erf, alpha, beta);
} else if (ctx.Attr<std::string>("activation_type") == "tanh") {
constexpr float alpha = 0.0f;
constexpr float beta = 0.0f;
post_operations.append_eltwise(
scale, dnnl::algorithm::eltwise_tanh, alpha, beta);
} else if (ctx.Attr<std::string>("activation_type") == "sigmoid") {
constexpr float alpha = 0.0f;
constexpr float beta = 0.0f;
post_operations.append_eltwise(
scale, dnnl::algorithm::eltwise_logistic, alpha, beta);
} else if (ctx.Attr<std::string>("activation_type") == "mish") {
constexpr float alpha = 0.0f;
constexpr float beta = 0.0f;
post_operations.append_eltwise(
scale, dnnl::algorithm::eltwise_mish, alpha, beta);
} else if (ctx.Attr<std::string>("activation_type") == "hard_swish") {
constexpr float alpha = 0.0f;
constexpr float beta = 0.0f;
post_operations.append_eltwise(
scale, dnnl::algorithm::eltwise_hardswish, alpha, beta);
} }
if (ctx.HasAttr("fused_output_scale")) { return this->AcquireMemoryWithReorder(
float scale_alpha = ctx.Attr<float>("fused_output_scale"); user_md, this->fwd_pd_->src_desc(), to_void_cast<T_in>(x_data));
post_operations.append_eltwise( }
1.0, dnnl::algorithm::eltwise_linear, scale_alpha, 0.0f);
}
attributes.set_post_ops(post_operations); std::shared_ptr<dnnl::memory> AcquireBiasMemoryWithReorder(
return attributes; const phi::DenseTensor* bias,
const float scale_in,
const std::vector<float>& scale_weights) {
const float* bias_data = bias->data<float>();
if (IsInt8<T_w>() == false) {
// for BF16/FP32 bias is 1D and has no scales, so reorder is not needed
return this->AcquireMemoryFromPrimitive(this->fwd_pd_->bias_desc(),
to_void_cast<float>(bias_data));
} else {
const std::string bias_key = this->memory_key_ + "@bias";
auto memory_p = std::static_pointer_cast<dnnl::memory>(
this->dev_ctx_.GetBlob(bias_key));
if (!memory_p) {
const auto& scale_data = ComputeBiasScales(scale_in, scale_weights);
dnnl::primitive_attr attrs;
int mask = CreateMask(0, scale_data.size() > 1);
attrs.set_output_scales(mask, scale_data);
auto user_md = dnnl::memory::desc({bias->dims()[0]},
MKLDNNGetDataType<float>(),
dnnl::memory::format_tag::a);
memory_p = this->AcquireMemoryWithReorderAndAttrs(
user_md,
this->fwd_pd_->bias_desc(),
to_void_cast<float>(bias_data),
attrs);
this->dev_ctx_.SetBlob(bias_key, memory_p);
}
return memory_p;
}
} }
dnnl::inner_product_forward::primitive_desc CreateFcPrimDesc( std::shared_ptr<dnnl::memory> AcquireWeightsMemoryWithReorder(
const dnnl::memory::desc& input_desc, const phi::DenseTensor* weights, const std::vector<float>& scale_data) {
const dnnl::memory::desc& weights_desc, const std::string weights_key = this->memory_key_ + "@weights";
const dnnl::memory::desc& bias_desc, auto memory_p = std::static_pointer_cast<dnnl::memory>(
const dnnl::memory::desc& dst_desc, this->dev_ctx_.GetBlob(weights_key));
const dnnl::primitive_attr& attrs) {
auto fc_desc = inner_product_forward::desc(prop_kind::forward_scoring, if (!memory_p) {
input_desc, const float* weights_data = weights->data<float>();
weights_desc, auto weights_dims = this->fwd_pd_->weights_desc().dims();
bias_desc,
dst_desc);
return inner_product_forward::primitive_desc(fc_desc, attrs, engine_); auto user_md = dnnl::memory::desc(weights_dims,
MKLDNNGetDataType<float>(),
dnnl::memory::format_tag::io);
if (IsInt8<T_w>()) {
dnnl::primitive_attr attrs;
int mask = CreateMask(0, scale_data.size() > 1);
attrs.set_output_scales(mask, scale_data);
memory_p = this->AcquireMemoryWithReorderAndAttrs(
user_md,
this->fwd_pd_->weights_desc(),
to_void_cast<float>(weights_data),
attrs);
} else {
memory_p =
this->AcquireMemoryWithReorder(user_md,
this->fwd_pd_->weights_desc(),
to_void_cast<float>(weights_data));
}
this->dev_ctx_.SetBlob(weights_key, memory_p);
}
return memory_p;
} }
// Create output memory based on output tensor and inner_product std::shared_ptr<dnnl::memory> AcquireCustomDstMemory(
// primitive descriptor format chosen for output const ExecutionContext& ctx, phi::DenseTensor* out) {
dnnl::memory CreateDstMemory(
const dnnl::inner_product_forward::primitive_desc& fc_prim_desc,
const ExecutionContext& ctx,
Tensor* output) {
if (ctx.HasAttr("fuse_residual_connection") && if (ctx.HasAttr("fuse_residual_connection") &&
ctx.Attr<bool>("fuse_residual_connection")) { ctx.Attr<bool>("fuse_residual_connection")) {
auto* residual_param = ctx.Output<Tensor>("ResidualData"); auto* residual_param = ctx.Output<phi::DenseTensor>("ResidualData");
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
output->dims(), out->dims(),
residual_param->dims(), residual_param->dims(),
platform::errors::InvalidArgument( platform::errors::InvalidArgument(
"Output and elementwise parameter need to have the " "Output and elementwise parameter need to have the "
"same dimension sizes, but got output's dimension = %d" "same dimension sizes, but got output's dimension = %d"
" and residual param's dimension =%d .", " and residual param's dimension =%d .",
output->dims().size(), out->dims().size(),
residual_param->dims().size())); residual_param->dims().size()));
output->ShareDataWith(*residual_param); out->ShareDataWith(*residual_param);
} }
return this->template AcquireDstMemory<T_out>(out);
} // namespace operators
}; // namespace paddle
auto dst_desc = fc_prim_desc.dst_desc(); template <typename T_in, typename T_w>
auto buffer_size = dst_desc.get_size(); class FCMKLDNNKernel : public framework::OpKernel<T_in> {
T_out* output_data = public:
output->mutable_data<T_out>(ctx.GetPlace(), buffer_size); void Compute(const framework::ExecutionContext& ctx) const override {
memory dst_mem(dst_desc, engine_, to_void_cast<T_out>(output_data)); bool force_fp32_output = ctx.Attr<bool>("force_fp32_output");
SetOutputFormat(ctx.Input<LoDTensor>("Input")->format(), output); bool fuse_relu = ctx.Attr<std::string>("activation_type") == "relu";
return dst_mem; if (force_fp32_output) {
this->RunKernel<float>(ctx);
} else if (IsInt8<T_in>()) {
if (fuse_relu) {
this->RunKernel<uint8_t>(ctx);
} else {
this->RunKernel<int8_t>(ctx);
}
} else {
this->RunKernel<T_in>(ctx);
}
}
void PrepareSrcMem(const std::shared_ptr<inner_product_forward>& fc_p,
const std::shared_ptr<dnnl::memory>& src_mem,
const LoDTensor* x,
const dnnl::engine& engine) const {
auto x_md = x->mem_desc().reshape(src_mem->get_desc().dims());
if (x_md != src_mem->get_desc()) {
dnnl::memory x_mem(x_md, engine, to_void_cast<T_in>(x->data<T_in>()));
auto reorder_p = dnnl::reorder(x_mem, *src_mem);
auto& astream = paddle::platform::MKLDNNDeviceContext::tls().get_stream();
reorder_p.execute(astream, x_mem, *src_mem);
astream.wait();
} else {
src_mem->set_data_handle(to_void_cast<T_in>(x->data<T_in>()));
}
}
template <typename T_out = T_w>
void RunKernel(const framework::ExecutionContext& ctx) const {
const auto& dev_ctx =
ctx.template device_context<platform::MKLDNNDeviceContext>();
const auto& mkldnn_engine = dev_ctx.GetEngine();
const auto* x = ctx.Input<LoDTensor>("Input");
const auto* weights = ctx.Input<phi::DenseTensor>("W");
const auto* bias = ctx.Input<phi::DenseTensor>("Bias");
auto out = ctx.Output<LoDTensor>("Out");
const float scale_in = ctx.Attr<float>("Scale_in");
const auto& scale_weights = ctx.Attr<std::vector<float>>("Scale_weights");
std::shared_ptr<dnnl::inner_product_forward> fc_p;
std::shared_ptr<dnnl::memory> src_memory_p;
std::shared_ptr<dnnl::memory> weights_memory_p;
std::shared_ptr<dnnl::memory> bias_memory_p;
std::shared_ptr<dnnl::memory> dst_memory_p;
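    // The cache key combines the input/weights variable names and the input
    // shape, so the cached primitive is reused only for matching shapes.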
std::string cache_key;
cache_key.reserve(64);
cache_key = platform::ExtendKeyWithThreadInfoIfNeeded(
dev_ctx,
platform::CreateKey(dev_ctx,
ctx.InputName("Input"),
ctx.InputName("W"),
phi::vectorize(x->dims())));
auto inner_product_cache =
std::static_pointer_cast<InnerProductCache>(dev_ctx.GetBlob(cache_key));
RecomputeOutputDims(ctx, x, weights, out);
if (inner_product_cache) {
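      // Cache hit: reuse the stored primitive and memory objects, rebinding
      // only the data handles to the tensors of the current execution.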
fc_p = std::make_shared<dnnl::inner_product_forward>(
inner_product_cache->inner_product_p);
src_memory_p =
std::make_shared<dnnl::memory>(inner_product_cache->src_mem);
PrepareSrcMem(fc_p, src_memory_p, x, mkldnn_engine);
weights_memory_p =
std::make_shared<dnnl::memory>(inner_product_cache->weights_mem);
dst_memory_p =
std::make_shared<dnnl::memory>(inner_product_cache->dst_mem);
if (ctx.HasAttr("fuse_residual_connection") &&
ctx.Attr<bool>("fuse_residual_connection")) {
auto* residual_param = ctx.Output<phi::DenseTensor>("ResidualData");
out->ShareDataWith(*residual_param);
}
auto out_ptr = out->mutable_data<T_out>(
ctx.GetPlace(), dst_memory_p->get_desc().get_size());
dst_memory_p->set_data_handle(out_ptr);
if (bias) {
bias_memory_p =
std::make_shared<dnnl::memory>(inner_product_cache->bias_mem);
}
} else {
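      // Cache miss: build the descriptors and memory objects through the
      // handler; they are stored under cache_key after the first execution.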
auto in_col_dims = ctx.Attr<int>("in_num_col_dims");
FCMKLDNNHandler<T_in, T_w, T_out> handler(ctx,
dev_ctx,
x,
weights,
bias,
out,
in_col_dims,
mkldnn_engine,
ctx.GetPlace());
src_memory_p = handler.AcquireSrcMemoryWithReorder(x);
weights_memory_p =
handler.AcquireWeightsMemoryWithReorder(weights, scale_weights);
dst_memory_p = handler.AcquireCustomDstMemory(ctx, out);
if (bias) {
bias_memory_p =
handler.AcquireBiasMemoryWithReorder(bias, scale_in, scale_weights);
}
fc_p = handler.AcquireForwardPrimitive();
}
auto& astream = paddle::platform::MKLDNNDeviceContext::tls().get_stream();
std::unordered_map<int, dnnl::memory> fc_args = {
{DNNL_ARG_SRC, *src_memory_p},
{DNNL_ARG_WEIGHTS, *weights_memory_p},
{DNNL_ARG_DST, *dst_memory_p}};
if (bias) {
fc_args.insert({DNNL_ARG_BIAS, *bias_memory_p});
}
fc_p->execute(astream, fc_args);
astream.wait();
if (!inner_product_cache) {
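      // First execution for this key: store the primitive and its memory
      // objects so subsequent runs with the same shapes can skip creation.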
auto ip_cache = std::make_shared<InnerProductCache>();
ip_cache->inner_product_p = *fc_p;
ip_cache->src_mem = *src_memory_p;
ip_cache->weights_mem = *weights_memory_p;
ip_cache->dst_mem = *dst_memory_p;
if (bias) {
ip_cache->bias_mem = *bias_memory_p;
}
dev_ctx.SetBlob(cache_key, ip_cache);
}
platform::SetOutMemDescWithLogicalLayoutFusesSupport(
ctx,
out,
dst_memory_p->get_desc().reshape(phi::vectorize(out->dims())));
  }
  void RecomputeOutputDims(const ExecutionContext& ctx,
                           const LoDTensor* x,
                           const phi::DenseTensor* weights,
                           LoDTensor* out) const {
    int in_num_col_dims = ctx.Attr<int>("in_num_col_dims");
    bool padding_weights = ctx.Attr<bool>("padding_weights");
    PADDLE_ENFORCE_EQ(padding_weights,
                      false,
                      platform::errors::PermissionDenied(
                          "Weight padding in fc can not be used in MKLDNN."));
    std::vector<int64_t> output_dims;
    FCOutputSize(x->dims(),
                 weights->dims(),
                 output_dims,
                 in_num_col_dims,
                 padding_weights);
    out->Resize(phi::make_ddim(output_dims));
    out->set_lod(x->lod());
  }

@@ -602,102 +498,16 @@ class FCPrimitiveFactory {
private:
const dnnl::engine& engine_;
paddle::optional<memory> input_;
paddle::optional<memory> output_;
std::shared_ptr<memory> bias_;
std::shared_ptr<memory> weights_;
paddle::optional<inner_product_forward> fc_;
};
// Attempt to fetch cached primitive factory based on provided parameters
// of input format, weight dimensions and output name.
// If not cached, create a new one.
template <typename T_in, typename T_w, typename T_out>
static std::shared_ptr<FCPrimitiveFactory<T_in, T_w, T_out>>
GetPrimitiveFactory(const MKLDNNDeviceContext& dev_ctx,
const std::string& key) {
auto prim_creator =
std::static_pointer_cast<FCPrimitiveFactory<T_in, T_w, T_out>>(
dev_ctx.GetBlob(key));
if (prim_creator == nullptr) {
prim_creator = std::make_shared<FCPrimitiveFactory<T_in, T_w, T_out>>(
dev_ctx.GetEngine());
dev_ctx.SetBlob(key, prim_creator);
}
return prim_creator;
}
// Choose appropriate primitive factory implementation based on inferred
// output type (uint8, int8 or float).
template <typename T_in, typename T_w>
static void ExecuteFc(const ExecutionContext& ctx,
const LoDTensor* input,
const Tensor* w,
const Tensor* bias,
LoDTensor* output,
bool fuse_relu,
bool force_fp32_output) {
auto& dev_ctx = ctx.template device_context<MKLDNNDeviceContext>();
std::string prim_key = platform::CreateKey(dev_ctx,
input->format(),
input->dims()[0],
phi::vectorize<int>(w->dims()),
ctx.OutputName("Out"));
prim_key = platform::ExtendKeyWithThreadInfoIfNeeded(dev_ctx, prim_key);
constexpr bool is_int8 =
std::is_same<T_in, int8_t>::value || std::is_same<T_in, uint8_t>::value;
bool is_bfloat16 = std::is_same<T_in, paddle::platform::bfloat16>::value;
if ((!is_int8 && !is_bfloat16) || force_fp32_output) {
GetPrimitiveFactory<T_in, T_w, float>(dev_ctx, prim_key)
->ExecuteFcPrimitive(input, w, bias, output, dev_ctx, ctx);
} else if (is_bfloat16) {
GetPrimitiveFactory<T_in, T_w, platform::bfloat16>(dev_ctx, prim_key)
->ExecuteFcPrimitive(input, w, bias, output, dev_ctx, ctx);
} else if (fuse_relu) {
GetPrimitiveFactory<T_in, T_w, uint8_t>(dev_ctx, prim_key)
->ExecuteFcPrimitive(input, w, bias, output, dev_ctx, ctx);
} else {
GetPrimitiveFactory<T_in, T_w, int8_t>(dev_ctx, prim_key)
->ExecuteFcPrimitive(input, w, bias, output, dev_ctx, ctx);
}
}
template <typename T_in, typename T_w>
class FCMKLDNNOpKernel : public framework::OpKernel<T_in> {
public:
void Compute(const paddle::framework::ExecutionContext& ctx) const override {
PADDLE_ENFORCE_EQ(
platform::is_cpu_place(ctx.GetPlace()),
true,
platform::errors::PreconditionNotMet("FC MKL-DNN must use CPUPlace."));
platform::MKLDNNDeviceContext::tls().log_lib_version();
auto input = ctx.Input<LoDTensor>("Input");
auto w = ctx.Input<Tensor>("W");
auto bias = ctx.Input<Tensor>("Bias");
auto output = ctx.Output<LoDTensor>("Out");
bool fuse_relu = ctx.Attr<std::string>("activation_type") == "relu";
bool force_fp32_output = ctx.Attr<bool>("force_fp32_output");
ExecuteFc<T_in, T_w>(
ctx, input, w, bias, output, fuse_relu, force_fp32_output);
output->set_layout(DataLayout::kMKLDNN);
}
};
}  // namespace operators
}  // namespace paddle
@@ -710,7 +520,7 @@ REGISTER_OP_KERNEL_WITH_CUSTOM_TYPE(fc,
                                    ::paddle::platform::CPUPlace,
                                    FP32,
                                    ops::kFCMKLDNNFP32,
                                    ops::FCMKLDNNKernel<float, float>);

REGISTER_OP_KERNEL_WITH_CUSTOM_TYPE(
    fc,
@@ -718,19 +528,19 @@ REGISTER_OP_KERNEL_WITH_CUSTOM_TYPE(
    ::paddle::platform::CPUPlace,
    BF16,
    ops::kFCMKLDNNFP32,
    ops::FCMKLDNNKernel<paddle::platform::bfloat16,
                        paddle::platform::bfloat16>);

REGISTER_OP_KERNEL_WITH_CUSTOM_TYPE(fc,
                                    MKLDNN,
                                    ::paddle::platform::CPUPlace,
                                    U8,
                                    ops::kFCMKLDNNINT8,
                                    ops::FCMKLDNNKernel<uint8_t, int8_t>);

REGISTER_OP_KERNEL_WITH_CUSTOM_TYPE(fc,
                                    MKLDNN,
                                    ::paddle::platform::CPUPlace,
                                    S8,
                                    ops::kFCMKLDNNINT8,
                                    ops::FCMKLDNNKernel<int8_t, int8_t>);
@@ -21,72 +21,8 @@
namespace paddle {
namespace operators {

using Tensor = phi::DenseTensor;
using phi::DataLayout;
template <typename T>
class TransposeMKLDNNHandler {
public:
TransposeMKLDNNHandler(std::vector<int64_t>& dims, // NOLINT
std::vector<int>& axis, // NOLINT
dnnl::engine engine)
: dims_(dims),
axis_(axis),
logical_axis_(dims.size(), 0),
engine_(engine) {}
std::shared_ptr<dnnl::memory> AcquireSrcMemory(const MKLDNNMemoryFormat& fmt,
void* ptr) {
// Make memory descriptor using input format, unless it
// cannot be trusted (nchw) then make up memory fmt manually
for (size_t i = 0; i < this->logical_axis_.size(); ++i) {
this->logical_axis_[i] = i;
}
auto src_md = fmt != MKLDNNMemoryFormat::nchw
? platform::MKLDNNMemDesc(
dims_, platform::MKLDNNGetDataType<T>(), fmt)
: Axis2MemoryDesc(dims_, logical_axis_);
return std::make_shared<dnnl::memory>(src_md, engine_, ptr);
}
std::shared_ptr<dnnl::memory> AcquireDstMemory(framework::Tensor* output,
platform::Place place) {
auto dst_md = Axis2MemoryDesc(dims_, axis_);
auto dst_data = output->mutable_data<T>(place, dst_md.get_size());
return std::make_shared<dnnl::memory>(dst_md, engine_, dst_data);
}
std::shared_ptr<dnnl::reorder> AcquireTranspose(
std::shared_ptr<dnnl::memory> dst_memory_p,
std::shared_ptr<dnnl::memory> src_memory_p) {
return std::make_shared<dnnl::reorder>(*(src_memory_p), *(dst_memory_p));
}
protected:
dnnl::memory::desc Axis2MemoryDesc(std::vector<int64_t>& nchw_tz, // NOLINT
std::vector<int>& axis // NOLINT
) {
size_t ndims = axis.size();
std::vector<int64_t> strides(ndims);
unsigned int total_stride = 1;
for (int i = ndims - 1; i >= 0; --i) {
strides[axis[i]] = total_stride;
total_stride *= nchw_tz[axis[i]];
}
dnnl::memory::desc mem_d(
nchw_tz, platform::MKLDNNGetDataType<T>(), strides);
return mem_d;
}
private:
std::vector<int64_t> dims_;
std::vector<int> axis_;
std::vector<int> logical_axis_;
dnnl::engine engine_;
};
template <typename T>
class TransposeMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
@@ -98,37 +34,87 @@ class TransposeMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
"Operator DNNL Transpose must use CPUPlace")); "Operator DNNL Transpose must use CPUPlace"));
auto& dev_ctx = auto& dev_ctx =
ctx.template device_context<paddle::platform::MKLDNNDeviceContext>(); ctx.template device_context<paddle::platform::MKLDNNDeviceContext>();
const auto& mkldnn_engine = dev_ctx.GetEngine(); const auto& dnnl_engine = dev_ctx.GetEngine();
std::vector<int> axis = ctx.Attr<std::vector<int>>("axis"); std::vector<int> transpose_axis = ctx.Attr<std::vector<int>>("axis");
int ndims = axis.size(); int ndims = transpose_axis.size();
auto* input = ctx.Input<Tensor>("X"); auto* x = ctx.Input<Tensor>("X");
auto* output = ctx.Output<Tensor>("Out"); auto* out = ctx.Output<Tensor>("Out");
const T* input_data = input->data<T>();
auto& astream = platform::MKLDNNDeviceContext::tls().get_stream();
if (ndims == 1) { if (ndims == 1) {
framework::TensorCopy(*input, input->place(), output); framework::TensorCopy(*x, x->place(), out);
output->set_format(input->format()); out->set_mem_desc(x->mem_desc());
return; return;
} }
auto nchw_tz = phi::vectorize<int64_t>(input->dims()); auto x_vec_dims = phi::vectorize(x->dims());
TransposeMKLDNNHandler<T> handler(nchw_tz, axis, mkldnn_engine); framework::proto::VarType::Type x_paddle_type =
framework::TransToProtoVarType(x->dtype());
dnnl::memory::data_type x_type = framework::ToMKLDNNDataType(x_paddle_type);
platform::ReorderMKLDNNHandler reorder_handler(
x_vec_dims, x_paddle_type, x_type, dnnl_engine);
auto transpose_src_memory_p = handler.AcquireSrcMemory( auto reorder_src_memory_p = reorder_handler.AcquireSrcMemory(
input->format(), platform::to_void_cast<T>(input_data)); x->mem_desc(), platform::to_void_cast(x->data<T>()));
auto transpose_dst_memory_p =
handler.AcquireDstMemory(output, ctx.GetPlace());
auto transpose_p = handler.AcquireTranspose(transpose_dst_memory_p,
transpose_src_memory_p);
auto& astream = platform::MKLDNNDeviceContext::tls().get_stream(); auto dst_md =
transpose_p->execute( dnnl::memory::desc(x_vec_dims,
astream, *transpose_src_memory_p, *transpose_dst_memory_p); x->mem_desc().data_type(),
platform::GetPlainMKLDNNFormat(x_vec_dims.size()));
// a trick is used here to fake transpose of out_md, so later it will be
// "untransposed", leaving output data in plain format tag
auto dst_strides = FakeTranposeStrides(dst_md, transpose_axis);
dst_md =
dnnl::memory::desc(x_vec_dims, x->mem_desc().data_type(), dst_strides);
auto dst_data =
out->mutable_data(ctx.GetPlace(), x->type(), dst_md.get_size());
auto reorder_dst_memory_p =
std::make_shared<dnnl::memory>(dst_md, dnnl_engine, dst_data);
auto reorder_p = reorder_handler.AcquireReorder(reorder_dst_memory_p,
reorder_src_memory_p);
reorder_p->execute(astream, *reorder_src_memory_p, *reorder_dst_memory_p);
astream.wait(); astream.wait();
output->set_layout(DataLayout::kNCHW); platform::SetOutMemDescWithLogicalLayoutFusesSupport(
output->set_format(MKLDNNMemoryFormat::undef); ctx,
out,
reorder_dst_memory_p->get_desc().permute_axes(
TransposeToPermuteAxis(transpose_axis)));
}
private:
  // This is needed because oneDNN's permute_axes describes the axis order
  // differently than PaddlePaddle's transpose axis attribute.
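  // Example (illustrative values): transpose_axis = {2, 0, 1} yields
  // permute_axis = {1, 2, 0}.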
std::vector<int> TransposeToPermuteAxis(
const std::vector<int>& transpose_axis) const {
std::vector<int> permute_axis(transpose_axis.size());
for (size_t i = 0; i < transpose_axis.size(); ++i) {
permute_axis[transpose_axis[i]] = i;
}
return permute_axis;
}
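  // Builds strides over the input dims that describe a dense layout of the
  // transposed output. Example (illustrative values): dims = {2, 3, 4, 5}
  // with transpose_axis = {0, 2, 1, 3} gives fake_strides = {60, 5, 15, 1}.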
std::vector<int64_t> FakeTranposeStrides(
const dnnl::memory::desc& dst_md,
const std::vector<int>& transpose_axis) const {
std::vector<int64_t> fake_strides(transpose_axis.size());
auto dims = dst_md.dims();
int total_stride = 1;
int ndims = static_cast<int>(dims.size());
for (int i = ndims - 1; i >= 0; --i) {
fake_strides[transpose_axis[i]] = total_stride;
total_stride *= dims[transpose_axis[i]];
}
return fake_strides;
  }
};
@@ -140,44 +126,47 @@ class TransposeMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> {
        true,
        paddle::platform::errors::PreconditionNotMet(
            "Operator DNNL TransposeGrad must use CPUPlace"));

    const auto* dout = ctx.Input<Tensor>(framework::GradVarName("Out"));
    auto* dx = ctx.Output<Tensor>(framework::GradVarName("X"));
    if (!dx) return;

    auto& dev_ctx =
        ctx.template device_context<paddle::platform::MKLDNNDeviceContext>();
    const auto& dnnl_engine = dev_ctx.GetEngine();

    std::vector<int> transpose_axis = ctx.Attr<std::vector<int>>("axis");

    auto& astream = platform::MKLDNNDeviceContext::tls().get_stream();

    int ndims = transpose_axis.size();
    if (ndims == 1) {
      framework::TensorCopy(*dout, dout->place(), dx);
      dx->set_mem_desc(dout->mem_desc());
      return;
    }

    auto dout_vec_dims = phi::vectorize(dout->dims());

    framework::proto::VarType::Type dout_paddle_type =
        framework::TransToProtoVarType(dout->dtype());
    dnnl::memory::data_type dout_type =
        framework::ToMKLDNNDataType(dout_paddle_type);

    platform::ReorderMKLDNNHandler reorder_handler(
        dout_vec_dims, dout_paddle_type, dout_type, dnnl_engine);

    auto reorder_src_memory_p = reorder_handler.AcquireSrcMemory(
        dout->mem_desc(), platform::to_void_cast(dout->data<T>()));

    auto reorder_dst_memory_p =
        reorder_handler.AcquireDstMemory(dx, dout->mem_desc(), ctx.GetPlace());

    auto reorder_p = reorder_handler.AcquireReorder(reorder_dst_memory_p,
                                                    reorder_src_memory_p);

    reorder_p->execute(astream, *reorder_src_memory_p, *reorder_dst_memory_p);
    astream.wait();

    dx->set_mem_desc(
        reorder_dst_memory_p->get_desc().permute_axes(transpose_axis));
  }
};
...
@@ -111,6 +111,69 @@ static void AppendActivation(const framework::ExecutionContext& ctx,
  }
}
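// Restores the unsqueeze2 folded into the preceding op: every axis listed in
// fused_unsqueeze2_axes becomes a size-1 dim. Example (illustrative values):
// out_md dims {6, 8} with axes {0, 2} become {1, 6, 1, 8}.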
static void SetOutMemDescWithUnsqueeze2FuseSupport(
const framework::ExecutionContext& ctx,
phi::DenseTensor* out,
const dnnl::memory::desc& out_md) {
const std::vector<int>& fused_unsqueeze2_axes =
ctx.Attr<std::vector<int>>("fused_unsqueeze2_axes");
const std::vector<int64_t>& op_tz = out_md.dims();
std::vector<int64_t> unsqueezed_op_tz(
op_tz.size() + fused_unsqueeze2_axes.size(), 0);
for (const auto& axis : fused_unsqueeze2_axes) {
int positive_axis = axis < 0 ? unsqueezed_op_tz.size() + axis : axis;
unsqueezed_op_tz[positive_axis] = 1;
}
int j = 0;
for (size_t i = 0; i < unsqueezed_op_tz.size(); ++i) {
if (unsqueezed_op_tz[i] == 0) {
unsqueezed_op_tz[i] = op_tz[j++];
}
}
out->set_mem_desc(out_md.reshape(unsqueezed_op_tz));
out->Resize(phi::make_ddim(unsqueezed_op_tz));
}
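// Applies the reshape2 shape folded into the preceding op; a single -1 entry
// is inferred from the output's element count. Example (illustrative values):
// numel = 24 with fused_reshape2_shape = {2, -1, 3} becomes {2, 4, 3}.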
static void SetOutMemDescWithReshape2FuseSupport(
const framework::ExecutionContext& ctx,
phi::DenseTensor* out,
const dnnl::memory::desc& out_md) {
std::vector<int64_t> fused_reshape2_shape(
ctx.Attr<std::vector<int>>("fused_reshape2_shape").begin(),
ctx.Attr<std::vector<int>>("fused_reshape2_shape").end());
const int out_shape_numel = out->numel();
const int new_shape_numel = std::accumulate(fused_reshape2_shape.begin(),
fused_reshape2_shape.end(),
1,
std::multiplies<int64_t>());
for (size_t i = 0; i < fused_reshape2_shape.size(); ++i) {
if (fused_reshape2_shape[i] == -1) {
fused_reshape2_shape[i] = -out_shape_numel / new_shape_numel;
break;
}
}
out->set_mem_desc(out_md.reshape(fused_reshape2_shape));
out->Resize(phi::make_ddim(fused_reshape2_shape));
}
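// Dispatches on the attributes attached by the operator+unsqueeze2 and
// operator+reshape2 fuse passes; if neither is present, the memory
// descriptor is set unchanged.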
static void SetOutMemDescWithLogicalLayoutFusesSupport(
const framework::ExecutionContext& ctx,
phi::DenseTensor* out,
const dnnl::memory::desc& out_md) {
if (ctx.HasAttr("fused_unsqueeze2_axes")) {
SetOutMemDescWithUnsqueeze2FuseSupport(ctx, out, out_md);
} else if (ctx.HasAttr("fused_reshape2_shape")) {
SetOutMemDescWithReshape2FuseSupport(ctx, out, out_md);
} else {
out->set_mem_desc(out_md);
}
}
template <typename T>
constexpr bool IsInt8() {
  return std::is_same<T, int8_t>::value || std::is_same<T, uint8_t>::value;
...