Unverified commit d0e19af3, authored by jakpiase, committed by GitHub

[CHERRY-PICK] Added caching to oneDNN FC and op+unsqueeze2 and op+reshape2 fuse passes (#47690)

* fc cherrypick

* more files added

* added transpose cherrypick

* reverted somebody's fc changes

* minor fix

* minor fix

* cherry-pick of fc+act changes

* minor fix

* fix
Parent cf668ab3
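The core of this change is that the oneDNN FC kernel now caches its compiled inner_product primitive and memory objects as a blob in the MKLDNNDeviceContext, keyed by the input/weights names and input dims, instead of rebuilding them on every run. Below is a minimal sketch of that look-up-or-create pattern; GetOrCreateFcCache is a hypothetical helper name, while InnerProductCache, GetBlob and SetBlob come from the diff itself (the real code inlines this logic in FCMKLDNNKernel::RunKernel further down).

std::shared_ptr<InnerProductCache> GetOrCreateFcCache(
    const platform::MKLDNNDeviceContext& dev_ctx, const std::string& key) {
  // Cache hit: reuse the previously compiled primitive and memory objects.
  auto cache =
      std::static_pointer_cast<InnerProductCache>(dev_ctx.GetBlob(key));
  if (!cache) {
    // Cache miss: build the primitive (via FCMKLDNNHandler in the real code)
    // and store it for subsequent iterations.
    cache = std::make_shared<InnerProductCache>();
    dev_ctx.SetBlob(key, cache);
  }
  return cache;
}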
......@@ -219,6 +219,8 @@ if(WITH_MKLDNN)
pass_library(matmul_elementwise_add_mkldnn_fuse_pass inference DIR mkldnn)
pass_library(matmul_activation_mkldnn_fuse_pass inference DIR mkldnn)
pass_library(operator_scale_onednn_fuse_pass inference DIR mkldnn)
pass_library(operator_unsqueeze2_onednn_fuse_pass inference DIR mkldnn)
pass_library(operator_reshape2_onednn_fuse_pass inference DIR mkldnn)
pass_library(cpu_quantize_placement_pass base DIR mkldnn)
pass_library(cpu_quantize_pass inference DIR mkldnn)
pass_library(cpu_quantize_squash_pass inference DIR mkldnn)
......
......@@ -958,6 +958,44 @@ PDNode *patterns::OperatorActivation::operator()(
return activation_out;
}
PDNode *patterns::OperatorUnsqueeze2::operator()(
const std::string &operator_type, const int num_of_operator_outs) {
auto *preceding_op = pattern->NewNode(preceding_op_repr())
->assert_is_op(operator_type)
->assert_has_n_outputs(num_of_operator_outs);
auto *preceding_op_out = pattern->NewNode(preceding_op_out_repr())
->AsIntermediate()
->assert_is_op_output(operator_type, "Out")
->assert_is_op_input("unsqueeze2");
auto *unsqueeze2_op =
pattern->NewNode(unsqueeze2_op_repr())->assert_is_op("unsqueeze2");
auto *unsqueeze2_out = pattern->NewNode(unsqueeze2_out_repr())
->AsOutput()
->assert_is_op_output("unsqueeze2");
preceding_op->LinksTo({preceding_op_out});
unsqueeze2_op->LinksFrom({preceding_op_out}).LinksTo({unsqueeze2_out});
return unsqueeze2_out;
}
PDNode *patterns::OperatorReshape2::operator()(const std::string &operator_type,
const int num_of_operator_outs) {
auto *preceding_op = pattern->NewNode(preceding_op_repr())
->assert_is_op(operator_type)
->assert_has_n_outputs(num_of_operator_outs);
auto *preceding_op_out = pattern->NewNode(preceding_op_out_repr())
->AsIntermediate()
->assert_is_op_output(operator_type, "Out")
->assert_is_op_input("reshape2");
auto *reshape2_op =
pattern->NewNode(reshape2_op_repr())->assert_is_op("reshape2");
auto *reshape2_out = pattern->NewNode(reshape2_out_repr())
->AsOutput()
->assert_is_op_output("reshape2");
preceding_op->LinksTo({preceding_op_out});
reshape2_op->LinksFrom({preceding_op_out}).LinksTo({reshape2_out});
return reshape2_out;
}
PDNode *patterns::SeqConvEltAddRelu::operator()(
paddle::framework::ir::PDNode *seqconv_input) {
// Create Operators
......
......@@ -539,6 +539,32 @@ struct OperatorActivation : public PatternBase {
PATTERN_DECL_NODE(activation_out);
};
struct OperatorUnsqueeze2 : public PatternBase {
OperatorUnsqueeze2(PDPattern* pattern, const std::string& name_scope)
: PatternBase(pattern, name_scope, "operator_unsqueeze2") {}
PDNode* operator()(const std::string& operator_type,
const int num_of_outputs);
PATTERN_DECL_NODE(preceding_op);
PATTERN_DECL_NODE(preceding_op_out);
PATTERN_DECL_NODE(unsqueeze2_op);
PATTERN_DECL_NODE(unsqueeze2_out);
};
struct OperatorReshape2 : public PatternBase {
OperatorReshape2(PDPattern* pattern, const std::string& name_scope)
: PatternBase(pattern, name_scope, "operator_reshape2") {}
PDNode* operator()(const std::string& operator_type,
const int num_of_outputs);
PATTERN_DECL_NODE(preceding_op);
PATTERN_DECL_NODE(preceding_op_out);
PATTERN_DECL_NODE(reshape2_op);
PATTERN_DECL_NODE(reshape2_out);
};
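// Usage sketch (mirrors the passes added in this commit): a fuse pass builds
// the pattern and pulls the declared nodes back out inside its handler, e.g.
//   patterns::OperatorReshape2 pattern(gpd.mutable_pattern(), name_scope);
//   pattern("fc", 1 /* num_of_operator_outs */);
//   GET_IR_NODE_FROM_SUBGRAPH(reshape2_op, reshape2_op, pattern);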
// SEQCONV with Elementwise_Add ReLU
// op: seqconv + elementwise_add + relu
// named nodes:
......
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
......@@ -14,9 +14,8 @@
#include "paddle/fluid/framework/ir/mkldnn/fc_act_mkldnn_fuse_pass.h"
#include "paddle/fluid/framework/ir/graph_pattern_detector.h"
#include "paddle/fluid/framework/op_version_registry.h"
#include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/platform/mkldnn_reuse.h"
#include "paddle/fluid/string/pretty_log.h"
namespace paddle {
......@@ -26,20 +25,20 @@ namespace ir {
using string::PrettyLogDetail;
void FuseFCActOneDNNPass::ApplyImpl(Graph *graph) const {
std::vector<std::string> act_types = {
"gelu", "tanh", "sigmoid", "mish", "hard_swish"};
auto act_types = paddle::platform::GetSupportedActivations();
for (std::string act_type : act_types) FuseFCAct(graph, act_type);
for (auto act_type : act_types) FuseFCAct(graph, act_type);
}
void FuseFCActOneDNNPass::FuseFCAct(Graph *graph,
const std::string &act_type) const {
PADDLE_ENFORCE_NOT_NULL(
graph, platform::errors::InvalidArgument("Graph cannot be nullptr."));
FusePassBase::Init("fc_act", graph);
FusePassBase::Init("fc_" + act_type + "_mkldnn_fuse_pass", graph);
GraphPatternDetector gpd;
patterns::OperatorActivation fc_act_pattern(gpd.mutable_pattern(), "fc_act");
patterns::OperatorActivation fc_act_pattern(
gpd.mutable_pattern(), "fc_" + act_type + "_mkldnn_fuse_pass");
fc_act_pattern("fc", act_type);
int found_fc_act_count = 0;
......@@ -62,15 +61,23 @@ void FuseFCActOneDNNPass::FuseFCAct(Graph *graph,
"is used."));
}
auto attr_map = paddle::platform::GetAttributeMap(act_type);
for (const auto &attr : attr_map) {
if (act_op->HasAttr(attr.first)) {
fc_op->SetAttr(attr.second, act_op->GetAttr(attr.first));
}
}
if (act_type == "gelu" && act_op->HasAttr("approximate")) {
bool approximate = PADDLE_GET_CONST(bool, act_op->GetAttr("approximate"));
std::string type = approximate ? "_tanh" : "_erf";
fc_op->SetAttr("activation_type", act_type + type);
std::string gelu_act_type =
PADDLE_GET_CONST(bool, act_op->GetAttr("approximate")) ? "gelu_tanh"
: "gelu_erf";
fc_op->SetAttr("fuse_activation", gelu_act_type);
} else {
fc_op->SetAttr("activation_type", act_type);
fc_op->SetAttr("fuse_activation", act_type);
}
fc_op->SetAttr("use_mkldnn", true);
fc_op->SetAttr("use_mkldnn", true);
fc_op->SetOutput("Out", {act_out->Name()});
IR_OP_VAR_LINK(fc, act_out);
......@@ -80,7 +87,8 @@ void FuseFCActOneDNNPass::FuseFCAct(Graph *graph,
gpd(graph, handler);
AddStatis(found_fc_act_count);
if (!Has("disable_logs") || !Get<bool>("disable_logs"))
if ((!Has("disable_logs") || !Get<bool>("disable_logs")) &&
found_fc_act_count > 0)
PrettyLogDetail(
"--- fused %d fc with %s activation", found_fc_act_count, act_type);
}
......@@ -95,8 +103,16 @@ REGISTER_PASS_CAPABILITY(fc_act_mkldnn_fuse_pass)
.AddCombination(
paddle::framework::compatible::OpVersionComparatorCombination()
.LE("fc", 0)
.LE("gelu", 0)
.LE("sigmoid", 0)
.LE("mish", 1)
.EQ("abs", 0)
.LE("clip", 1)
.EQ("gelu", 0)
.EQ("hard_sigmoid", 0)
.LE("hard_swish", 0)
.LE("tanh", 0));
.LE("leaky_relu", 1)
.LE("mish", 1)
.EQ("relu", 0)
.EQ("relu6", 0)
.EQ("sigmoid", 0)
.EQ("sqrt", 0)
.EQ("swish", 0)
.EQ("tanh", 0));
......@@ -23,21 +23,14 @@ namespace paddle {
namespace framework {
namespace ir {
/*
* \brief Fuse the FC and activation operators into single OneDNN's
* FC with post-op.
*
* \note Currently only GeLU, hardswish, sigmoid, mish and tanh are supported
* as an activation function.
*/
class FuseFCActOneDNNPass : public FusePassBase {
public:
virtual ~FuseFCActOneDNNPass() {}
protected:
void ApplyImpl(ir::Graph *graph) const override;
void ApplyImpl(Graph *graph) const override;
void FuseFCAct(ir::Graph *graph, const std::string &act_types) const;
void FuseFCAct(Graph *graph, const std::string &act_types) const;
};
} // namespace ir
......
......@@ -78,9 +78,9 @@ TEST(FuseFCActOneDNNPass, FuseWithGeluTanh) {
const auto* op = node->Op();
ASSERT_TRUE(op->HasAttr("use_mkldnn"));
EXPECT_TRUE(PADDLE_GET_CONST(bool, op->GetAttr("use_mkldnn")));
ASSERT_TRUE(op->HasAttr("activation_type"));
ASSERT_TRUE(op->HasAttr("fuse_activation"));
auto act_type =
PADDLE_GET_CONST(std::string, op->GetAttr("activation_type"));
PADDLE_GET_CONST(std::string, op->GetAttr("fuse_activation"));
EXPECT_EQ(act_type.compare("gelu_tanh"), 0);
}
}
......@@ -113,9 +113,9 @@ TEST(FuseFCActOneDNNPass, FuseWithGeluErf) {
const auto* op = node->Op();
ASSERT_TRUE(op->HasAttr("use_mkldnn"));
EXPECT_TRUE(PADDLE_GET_CONST(bool, op->GetAttr("use_mkldnn")));
ASSERT_TRUE(op->HasAttr("activation_type"));
ASSERT_TRUE(op->HasAttr("fuse_activation"));
auto act_type =
PADDLE_GET_CONST(std::string, op->GetAttr("activation_type"));
PADDLE_GET_CONST(std::string, op->GetAttr("fuse_activation"));
EXPECT_EQ(act_type.compare("gelu_erf"), 0);
}
}
......@@ -146,9 +146,9 @@ TEST(FuseFCActOneDNNPass, FuseWithGeluAuto) {
const auto* op = node->Op();
ASSERT_TRUE(op->HasAttr("use_mkldnn"));
EXPECT_TRUE(PADDLE_GET_CONST(bool, op->GetAttr("use_mkldnn")));
ASSERT_TRUE(op->HasAttr("activation_type"));
ASSERT_TRUE(op->HasAttr("fuse_activation"));
auto act_type =
PADDLE_GET_CONST(std::string, op->GetAttr("activation_type"));
PADDLE_GET_CONST(std::string, op->GetAttr("fuse_activation"));
EXPECT_EQ(act_type.compare("gelu"), 0);
}
}
......@@ -179,9 +179,9 @@ TEST(FuseFCActOneDNNPass, FuseWithTanh) {
const auto* op = node->Op();
ASSERT_TRUE(op->HasAttr("use_mkldnn"));
EXPECT_TRUE(PADDLE_GET_CONST(bool, op->GetAttr("use_mkldnn")));
ASSERT_TRUE(op->HasAttr("activation_type"));
ASSERT_TRUE(op->HasAttr("fuse_activation"));
auto act_type =
PADDLE_GET_CONST(std::string, op->GetAttr("activation_type"));
PADDLE_GET_CONST(std::string, op->GetAttr("fuse_activation"));
EXPECT_EQ(act_type.compare("tanh"), 0);
}
}
......@@ -213,9 +213,9 @@ TEST(FuseFCActOneDNNPass, FuseWithSigmoid) {
const auto* op = node->Op();
ASSERT_TRUE(op->HasAttr("use_mkldnn"));
EXPECT_TRUE(PADDLE_GET_CONST(bool, op->GetAttr("use_mkldnn")));
ASSERT_TRUE(op->HasAttr("activation_type"));
ASSERT_TRUE(op->HasAttr("fuse_activation"));
auto act_type =
PADDLE_GET_CONST(std::string, op->GetAttr("activation_type"));
PADDLE_GET_CONST(std::string, op->GetAttr("fuse_activation"));
EXPECT_EQ(act_type.compare("sigmoid"), 0);
}
}
......@@ -246,9 +246,9 @@ TEST(FuseFCActOneDNNPass, FuseWithMish) {
const auto* op = node->Op();
ASSERT_TRUE(op->HasAttr("use_mkldnn"));
EXPECT_TRUE(PADDLE_GET_CONST(bool, op->GetAttr("use_mkldnn")));
ASSERT_TRUE(op->HasAttr("activation_type"));
ASSERT_TRUE(op->HasAttr("fuse_activation"));
auto act_type =
PADDLE_GET_CONST(std::string, op->GetAttr("activation_type"));
PADDLE_GET_CONST(std::string, op->GetAttr("fuse_activation"));
EXPECT_EQ(act_type.compare("mish"), 0);
}
}
......@@ -280,9 +280,9 @@ TEST(FuseFCActOneDNNPass, FuseWithHardSwish) {
const auto* op = node->Op();
ASSERT_TRUE(op->HasAttr("use_mkldnn"));
EXPECT_TRUE(PADDLE_GET_CONST(bool, op->GetAttr("use_mkldnn")));
ASSERT_TRUE(op->HasAttr("activation_type"));
ASSERT_TRUE(op->HasAttr("fuse_activation"));
auto act_type =
PADDLE_GET_CONST(std::string, op->GetAttr("activation_type"));
PADDLE_GET_CONST(std::string, op->GetAttr("fuse_activation"));
EXPECT_EQ(act_type.compare("hard_swish"), 0);
}
}
......
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/ir/mkldnn/operator_reshape2_onednn_fuse_pass.h"
#include "paddle/fluid/framework/op_version_registry.h"
#include "paddle/fluid/platform/mkldnn_reuse.h"
#include "paddle/fluid/string/pretty_log.h"
namespace paddle {
namespace framework {
namespace ir {
using string::PrettyLogDetail;
void FuseOperatorReshape2OneDNNPass::ApplyImpl(Graph *graph) const {
// NOTE: This fuse works only with operators that output plain (non-blocked)
// memory layouts, e.g. abcd for a 4-D tensor.
std::vector<std::pair<std::string, int>> ops_and_outputs = {
{"fc", 1}, {"transpose2", 2}};
for (const auto &op_and_outputs : ops_and_outputs)
FuseReshape2(graph, op_and_outputs.first, op_and_outputs.second);
}
void FuseOperatorReshape2OneDNNPass::FuseReshape2(Graph *graph,
const std::string &op_type,
int num_of_outputs) const {
PADDLE_ENFORCE_NOT_NULL(
graph, platform::errors::InvalidArgument("Graph cannot be nullptr."));
FusePassBase::Init(op_type + "_reshape2_onednn_fuse_pass", graph);
GraphPatternDetector gpd;
patterns::OperatorReshape2 op_reshape2_pattern(
gpd.mutable_pattern(), op_type + "_reshape2_onednn_fuse_pass");
op_reshape2_pattern(op_type, num_of_outputs);
int found_operator_reshape2_count = 0;
auto handler = [&](const GraphPatternDetector::subgraph_t &subgraph,
Graph *g) {
GET_IR_NODE_FROM_SUBGRAPH(operator_op, preceding_op, op_reshape2_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
operator_out, preceding_op_out, op_reshape2_pattern);
GET_IR_NODE_FROM_SUBGRAPH(reshape2_op, reshape2_op, op_reshape2_pattern);
GET_IR_NODE_FROM_SUBGRAPH(reshape2_out, reshape2_out, op_reshape2_pattern);
if (!operator_op->Op()->HasAttr("use_mkldnn") ||
(operator_op->Op()->HasAttr("use_mkldnn") &&
!(PADDLE_GET_CONST(bool, operator_op->Op()->GetAttr("use_mkldnn"))))) {
VLOG(4) << "Only oneDNN version of " << op_type
<< "can be fused with reshape2.";
return;
}
if (operator_op->Op()->HasAttr("fused_unsqueeze2_axes")) {
VLOG(4) << "Cannot do " << op_type << " + reshape2 fuse, because "
<< op_type << " is already fused with unsqueeze2!";
return;
}
std::vector<int> reshape2_shape =
PADDLE_GET_CONST(std::vector<int>, reshape2_op->Op()->GetAttr("shape"));
int num_of_minus_ones = 0;
for (size_t i = 0; i < reshape2_shape.size(); ++i) {
if (reshape2_shape[i] == 0) {
VLOG(4) << "OneDNN op+reshape2 fuse pass does not support zero dims, "
"skipping";
return;
} else if (reshape2_shape[i] == -1) {
++num_of_minus_ones;
}
}
if (num_of_minus_ones > 1) {
VLOG(4) << "Number of -1 values inside of reshape2 shouldn't be greater "
"than one in op+reshape2 oneDNN fuse pass, skipping";
return;
}
auto const &names = reshape2_op->Op()->InputNames();
bool has_shape_tensor =
std::find(names.begin(), names.end(), "ShapeTensor") != names.end();
bool has_shape_tensor_list =
std::find(names.begin(), names.end(), "ShapeTensorList") != names.end();
if (has_shape_tensor &&
reshape2_op->Op()->Input("ShapeTensor").size() > 0) {
VLOG(4) << "Cannot fuse " << op_type
<< " and reshape2 because reshape2 dims are specified by "
"ShapeTensor!";
return;
}
if (has_shape_tensor_list &&
reshape2_op->Op()->Input("ShapeTensorList").size() > 0) {
VLOG(4) << "Cannot fuse " << op_type
<< " and reshape2 because reshape2 dims are specified by "
"ShapeTensorList!";
return;
}
operator_op->Op()->SetAttr("fused_reshape2_shape", reshape2_shape);
operator_op->Op()->SetOutput("Out", {reshape2_out->Name()});
IR_OP_VAR_LINK(operator_op, reshape2_out);
GraphSafeRemoveNodes(g, {reshape2_op, operator_out});
found_operator_reshape2_count++;
};
gpd(graph, handler);
AddStatis(found_operator_reshape2_count);
if ((!Has("disable_logs") || !Get<bool>("disable_logs")) &&
found_operator_reshape2_count > 0)
PrettyLogDetail("--- fused %d %s with reshape2",
found_operator_reshape2_count,
op_type);
}
} // namespace ir
} // namespace framework
} // namespace paddle
REGISTER_PASS(operator_reshape2_onednn_fuse_pass,
paddle::framework::ir::FuseOperatorReshape2OneDNNPass);
REGISTER_PASS_CAPABILITY(operator_reshape2_onednn_fuse_pass)
.AddCombination(
paddle::framework::compatible::OpVersionComparatorCombination()
.GE("reshape2", 0)
.GE("fc", 0));
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/fluid/framework/ir/fuse_pass_base.h"
#include "paddle/fluid/framework/ir/graph.h"
namespace paddle {
namespace framework {
namespace ir {
class FuseOperatorReshape2OneDNNPass : public FusePassBase {
public:
virtual ~FuseOperatorReshape2OneDNNPass() {}
protected:
void ApplyImpl(Graph *graph) const override;
void FuseReshape2(Graph *graph,
const std::string &op_type,
int num_of_outputs) const;
};
} // namespace ir
} // namespace framework
} // namespace paddle
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/ir/mkldnn/operator_unsqueeze2_onednn_fuse_pass.h"
#include "paddle/fluid/framework/op_version_registry.h"
#include "paddle/fluid/platform/mkldnn_reuse.h"
#include "paddle/fluid/string/pretty_log.h"
namespace paddle {
namespace framework {
namespace ir {
using string::PrettyLogDetail;
void FuseOperatorUnsqueeze2OneDNNPass::ApplyImpl(Graph *graph) const {
std::vector<std::pair<std::string, int>> ops_and_outputs = {
{"transpose2", 2}, {"elementwise_mul", 1}};
for (const auto &op_and_outputs : ops_and_outputs)
FuseUnsqueeze2(graph, op_and_outputs.first, op_and_outputs.second);
}
void FuseOperatorUnsqueeze2OneDNNPass::FuseUnsqueeze2(
Graph *graph, const std::string &op_type, int num_of_outputs) const {
PADDLE_ENFORCE_NOT_NULL(
graph, platform::errors::InvalidArgument("Graph cannot be nullptr."));
FusePassBase::Init(op_type + "_unsqueeze2_onednn_fuse_pass", graph);
GraphPatternDetector gpd;
patterns::OperatorUnsqueeze2 op_unsqueeze2_pattern(
gpd.mutable_pattern(), op_type + "_unsqueeze2_onednn_fuse_pass");
op_unsqueeze2_pattern(op_type, num_of_outputs);
int found_operator_unsqueeze2_count = 0;
auto handler = [&](const GraphPatternDetector::subgraph_t &subgraph,
Graph *g) {
GET_IR_NODE_FROM_SUBGRAPH(operator_op, preceding_op, op_unsqueeze2_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
operator_out, preceding_op_out, op_unsqueeze2_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
unsqueeze2_op, unsqueeze2_op, op_unsqueeze2_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
unsqueeze2_out, unsqueeze2_out, op_unsqueeze2_pattern);
if (!operator_op->Op()->HasAttr("use_mkldnn") ||
(operator_op->Op()->HasAttr("use_mkldnn") &&
!(PADDLE_GET_CONST(bool, operator_op->Op()->GetAttr("use_mkldnn"))))) {
VLOG(4) << "Only oneDNN version of " << op_type
<< "can be fused with unsqueeze2.";
return;
}
std::vector<int> unsqueeze2_axes = PADDLE_GET_CONST(
std::vector<int>, unsqueeze2_op->Op()->GetAttr("axes"));
auto const &names = unsqueeze2_op->Op()->InputNames();
bool has_axes_tensor =
std::find(names.begin(), names.end(), "AxesTensor") != names.end();
bool has_axes_tensor_list =
std::find(names.begin(), names.end(), "AxesTensorList") != names.end();
if (has_axes_tensor &&
unsqueeze2_op->Op()->Input("AxesTensor").size() > 0) {
VLOG(4) << "Cannot fuse " << op_type
<< " and unsqueeze2 because unsqueeze2 dims are specified by "
"AxesTensor!";
return;
}
if (has_axes_tensor_list &&
unsqueeze2_op->Op()->Input("AxesTensorList").size() > 0) {
VLOG(4) << "Cannot fuse " << op_type
<< " and unsqueeze2 because unsqueeze2 dims are specified by "
"AxesTensorList!";
return;
}
operator_op->Op()->SetAttr("fused_unsqueeze2_axes", unsqueeze2_axes);
operator_op->Op()->SetOutput("Out", {unsqueeze2_out->Name()});
IR_OP_VAR_LINK(operator_op, unsqueeze2_out);
GraphSafeRemoveNodes(g, {unsqueeze2_op, operator_out});
found_operator_unsqueeze2_count++;
};
gpd(graph, handler);
AddStatis(found_operator_unsqueeze2_count);
if ((!Has("disable_logs") || !Get<bool>("disable_logs")) &&
found_operator_unsqueeze2_count > 0)
PrettyLogDetail("--- fused %d %s with unsqueeze2",
found_operator_unsqueeze2_count,
op_type);
}
} // namespace ir
} // namespace framework
} // namespace paddle
REGISTER_PASS(operator_unsqueeze2_onednn_fuse_pass,
paddle::framework::ir::FuseOperatorUnsqueeze2OneDNNPass);
REGISTER_PASS_CAPABILITY(operator_unsqueeze2_onednn_fuse_pass)
.AddCombination(
paddle::framework::compatible::OpVersionComparatorCombination()
.GE("unsqueeze2", 0)
.GE("transpose2", 0));
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/fluid/framework/ir/fuse_pass_base.h"
#include "paddle/fluid/framework/ir/graph.h"
namespace paddle {
namespace framework {
namespace ir {
class FuseOperatorUnsqueeze2OneDNNPass : public FusePassBase {
public:
virtual ~FuseOperatorUnsqueeze2OneDNNPass() {}
protected:
void ApplyImpl(Graph *graph) const override;
void FuseUnsqueeze2(Graph *graph,
const std::string &op_type,
int num_of_outputs) const;
};
} // namespace ir
} // namespace framework
} // namespace paddle
......@@ -327,6 +327,8 @@ void CpuPassStrategy::EnableMKLDNN() {
"shuffle_channel_mkldnn_detect_pass", //
"elt_act_mkldnn_fuse_pass", //
"operator_scale_onednn_fuse_pass", //
"operator_unsqueeze2_onednn_fuse_pass", //
"operator_reshape2_onednn_fuse_pass", //
// TODO(intel): Please fix the bug on windows.
// https://github.com/PaddlePaddle/Paddle/issues/29710
// "mkldnn_inplace_pass", // This pass should be activated after
......@@ -421,6 +423,8 @@ void CpuPassStrategy::EnableMkldnnInt8() {
passes_.push_back("reshape_transpose_matmul_mkldnn_fuse_pass");
passes_.push_back("matmul_elementwise_add_mkldnn_fuse_pass");
passes_.push_back("operator_scale_onednn_fuse_pass");
passes_.push_back("operator_unsqueeze2_onednn_fuse_pass");
passes_.push_back("operator_reshape2_onednn_fuse_pass");
passes_.push_back("cpu_quantize_placement_pass");
passes_.push_back("cpu_quantize_pass");
passes_.push_back("cpu_quantize_squash_pass");
......
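Both new passes are appended to CpuPassStrategy, so they run automatically whenever oneDNN is enabled on an inference config. A hedged usage sketch with the standard Paddle Inference C++ API (the model directory string is a placeholder):

#include <memory>
#include "paddle_inference_api.h"

std::shared_ptr<paddle_infer::Predictor> BuildCpuOneDNNPredictor() {
  paddle_infer::Config config("./model_dir");  // placeholder model path
  config.EnableMKLDNN();  // picks up CpuPassStrategy::EnableMKLDNN(), which now
                          // includes the unsqueeze2/reshape2 oneDNN fuse passes
  return paddle_infer::CreatePredictor(config);
}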
......@@ -129,12 +129,14 @@ class EltwiseMKLDNNKernel : public framework::OpKernel<T> {
astream.wait();
if (handler.use_broadcasting_hack == false) {
z->set_mem_desc(dst_memory->get_desc());
platform::SetOutMemDescWithLogicalLayoutFusesSupport(
ctx, z, dst_memory->get_desc());
} else {
auto dims = dst_memory->get_desc().dims();
dims.insert(dims.begin(), x->dims()[0]);
dims[1] /= dims[0];
z->set_mem_desc(dst_memory->get_desc().reshape(dims));
platform::SetOutMemDescWithLogicalLayoutFusesSupport(
ctx, z, dst_memory->get_desc().reshape(dims));
}
}
};
......
......@@ -16,10 +16,7 @@ limitations under the License. */
#include "paddle/fluid/operators/fc_op.h"
#include "paddle/fluid/platform/mkldnn_helper.h"
namespace phi {
class DenseTensor;
} // namespace phi
#include "paddle/fluid/platform/mkldnn_reuse.h"
namespace paddle {
namespace operators {
......@@ -29,393 +26,131 @@ using dnnl::memory;
using dnnl::primitive;
using dnnl::prop_kind;
using dnnl::stream;
using framework::DataLayout;
using framework::DDim;
using framework::ExecutionContext;
using framework::LoDTensor;
using framework::Tensor;
using LoDTensor = phi::DenseTensor;
using platform::GetMKLDNNFormat;
using platform::MKLDNNDeviceContext;
using platform::MKLDNNGetDataType;
using platform::to_void_cast;
template <typename T>
constexpr bool IsInt8() {
return std::is_same<T, int8_t>::value || std::is_same<T, uint8_t>::value;
}
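// Cached in the MKLDNNDeviceContext blob map under a key derived from the
// "Input" and "W" names, the input dims and the thread info: the compiled
// primitive and its memory objects are reused on later iterations, which only
// swap data handles instead of rebuilding descriptors and reordering weights.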
struct InnerProductCache {
dnnl::inner_product_forward inner_product_p;
dnnl::memory src_mem;
dnnl::memory weights_mem;
dnnl::memory bias_mem;
dnnl::memory dst_mem;
};
template <typename T_in, typename T_w, typename T_out>
class FCPrimitiveFactory {
class FCMKLDNNHandler
: public platform::MKLDNNHandlerNoCachingT<T_in,
dnnl::inner_product_forward> {
public:
explicit FCPrimitiveFactory(const dnnl::engine& engine) : engine_(engine) {}
void ExecuteFcPrimitive(const LoDTensor* input,
const Tensor* weights,
const Tensor* bias,
LoDTensor* output,
const MKLDNNDeviceContext& dev_ctx,
const ExecutionContext& ctx) {
RecomputeOutputDims(ctx, input, weights, output);
// If primitive has already been created and cached, don't create new one,
// but update input and output data pointers and return it.
if (fc_) {
UpdateDataPointers(ctx, output, input);
this->Execute();
return;
} // Otherwise, create a new one.
auto in_col_dims = ctx.Attr<int>("in_num_col_dims");
PADDLE_ENFORCE_LE(
in_col_dims,
2,
platform::errors::Unimplemented(
"DNNL FC doesn't support in_num_col_dims parameter to "
"be higher than "
"2."));
if (in_col_dims == 2) {
PADDLE_ENFORCE_EQ(
input->dims().size(),
3,
platform::errors::Unimplemented(
"DNNL FC only supports in_num_col_dims equal to 2 when "
"3 dim input is provided."));
PADDLE_ENFORCE_EQ(
input->format(),
MKLDNNMemoryFormat::ncw,
platform::errors::Unimplemented(
"DNNL FC only supports in_num_col_dims equal to 2 when "
"input format is equal to ncw."));
FCMKLDNNHandler(const paddle::framework::ExecutionContext& ctx,
const platform::MKLDNNDeviceContext& dev_ctx,
const phi::DenseTensor* x,
const phi::DenseTensor* weights,
const phi::DenseTensor* bias,
phi::DenseTensor* out,
const int in_num_col_dims,
dnnl::engine mkldnn_engine,
platform::Place cpu_place)
: platform::MKLDNNHandlerNoCachingT<T_in, dnnl::inner_product_forward>(
mkldnn_engine, cpu_place),
dev_ctx_(dev_ctx) {
this->memory_key_ = ctx.InputName("W");
auto x_vec_dims = phi::vectorize(x->dims());
auto weights_vec_dims = phi::vectorize(weights->dims());
int MB = 1;
for (int i = 0; i < in_num_col_dims; ++i) {
MB *= x_vec_dims[i];
}
weights_ = CreateWeightsMemory(weights);
// Since MKL-DNN has a lot of limitations on what the input/weights/output
// dimensions should be, to simplify the code, the creation of primitive
// descriptor has been divided into separate cases, based on the number
// of input dimensions.
size_t input_dim_num = input->dims().size();
paddle::optional<dnnl::inner_product_forward::primitive_desc> fc_prim_desc;
memory::desc usr_weights_desc = {};
switch (input_dim_num) {
case 2:
fc_prim_desc =
Create2DFcPrimDescriptor(input, weights, bias, output, ctx);
usr_weights_desc = Create2DUserWeightsDesc();
break;
case 3:
fc_prim_desc =
Create3DFcPrimDescriptor(input, weights, bias, output, ctx);
usr_weights_desc = Create3DUserWeightsDesc(weights);
break;
case 4:
fc_prim_desc =
Create4DFcPrimDescriptor(input, weights, bias, output, ctx);
usr_weights_desc = Create4DUserWeightsDesc(input, weights);
break;
default:
PADDLE_THROW(platform::errors::Unimplemented(
"DNNL FC doesn't support input dims different than 2, 3, 4."));
break;
int IC = 1;
for (size_t i = in_num_col_dims; i < x_vec_dims.size(); ++i) {
IC *= x_vec_dims[i];
}
input_ = CreateMemory<T_in>(fc_prim_desc->src_desc(), input);
// Update weights format inside of its memory
weights_ = Reorder(
usr_weights_desc, usr_weights_desc, weights_->get_data_handle());
// Quantize weights and reorder to format chosen by FC primitive descriptor.
QuantizeWeights(ctx, fc_prim_desc->weights_desc());
bias_ = CreateMemoryToBeCached<float>(fc_prim_desc->bias_desc(), bias);
// If int8 is desired, quantize bias into 32-bit signed int
QuantizeBias(*fc_prim_desc, ctx);
// Store weights and bias in the mkldnn cache
CacheWeightsAndBias(dev_ctx, ctx);
// Based on format determined by inner_product, create output in desired
// memory format
output_ = CreateDstMemory(*fc_prim_desc, ctx, output);
// Return MKL-DNN primitive ready to be fed into pipeline and executed
fc_ = inner_product_forward(*fc_prim_desc);
this->Execute();
}
void Execute() {
auto& astream = platform::MKLDNNDeviceContext::tls().get_stream();
if (bias_) {
fc_->execute(astream,
{{DNNL_ARG_SRC, *input_},
{DNNL_ARG_WEIGHTS, *weights_},
{DNNL_ARG_BIAS, *bias_},
{DNNL_ARG_DST, *output_}});
} else {
fc_->execute(astream,
{{DNNL_ARG_SRC, *input_},
{DNNL_ARG_WEIGHTS, *weights_},
{DNNL_ARG_DST, *output_}});
}
astream.wait();
}
int OC = weights_vec_dims[1];
private:
// DNNL always returns 2-dimensional data block as a result of computing
// inner product. Hence the format 'nc' is always set for its output
// primitive. Therefore, function SetOutputFormat is needed to choose
// an appropriate format based on the number of input dimensions and
// format of an input tensor.
void SetOutputFormat(MKLDNNMemoryFormat in_format, Tensor* out) {
int dim_num = out->dims().size();
// In case of 2 dims, we set the only possible format, nc
if (dim_num == 2) {
out->set_format(MKLDNNMemoryFormat::nc);
out->set_mem_desc({phi::vectorize(out->dims()),
platform::MKLDNNGetDataType<T_out>(),
out->format()});
// In case of 3 dims, we generate a format that is based on number
// of output dims and the layout of input format (nchw or nhwc).
} else if (dim_num == 3) {
if (in_format == MKLDNNMemoryFormat::nwc ||
in_format == MKLDNNMemoryFormat::nhwc) {
out->set_format(
platform::MKLDNNFormatForSize(dim_num, MKLDNNMemoryFormat::nhwc));
} else {
out->set_format(
platform::MKLDNNFormatForSize(dim_num, MKLDNNMemoryFormat::nchw));
}
// In any other case we overwrite the output format with the input one.
} else {
out->set_format(in_format);
}
}
dnnl::memory::desc bias_md;
void UpdateDataPointers(const ExecutionContext& ctx,
Tensor* out,
const Tensor* in) {
input_->set_data_handle(to_void_cast(in->data<T_in>()));
output_->set_data_handle(out->mutable_data<T_out>(ctx.GetPlace()));
// If the primitive exists, but the output tensor has changed its
// variable, update its format to what has been determined in first
// call to CreateFcPrimitive method.
if (out->format() == MKLDNNMemoryFormat::undef) {
SetOutputFormat(in->format(), out);
auto src_md = dnnl::memory::desc(
{MB, IC}, MKLDNNGetDataType<T_in>(), dnnl::memory::format_tag::any);
auto weights_md = dnnl::memory::desc(
{OC, IC}, MKLDNNGetDataType<T_w>(), dnnl::memory::format_tag::any);
auto dst_md = dnnl::memory::desc(
{MB, OC}, MKLDNNGetDataType<T_out>(), dnnl::memory::format_tag::any);
if (bias) {
bias_md = dnnl::memory::desc({bias->numel()},
MKLDNNGetDataType<float>(),
dnnl::memory::format_tag::a);
}
}
dnnl::inner_product_forward::primitive_desc Create2DFcPrimDescriptor(
const LoDTensor* input,
const Tensor* weights,
const Tensor* bias,
LoDTensor* output,
const ExecutionContext& ctx) {
auto src_desc = CreateMemDescriptor<T_in>(input, MKLDNNMemoryFormat::any);
auto weight_dims = Get2DWeightDimsForDNNL(weights);
auto weights_desc =
CreateMemDescriptor<T_w>(weight_dims, MKLDNNMemoryFormat::any);
auto bias_desc = CreateMemDescriptor<float>(bias, MKLDNNMemoryFormat::x);
auto dst_desc = CreateMemDescriptor<T_out>(output, MKLDNNMemoryFormat::any);
const auto attrs = CreateFCAttrs(ctx);
return CreateFcPrimDesc(src_desc, weights_desc, bias_desc, dst_desc, attrs);
}
std::vector<int64_t> Get2DWeightDimsForDNNL(const Tensor* weights) {
auto dims = phi::vectorize(weights->dims());
std::swap(dims[0], dims[1]); // swap input dim with output dim
return dims;
}
memory::desc Create2DUserWeightsDesc() { return weights_->get_desc(); }
dnnl::inner_product_forward::primitive_desc Create3DFcPrimDescriptor(
const LoDTensor* input,
const Tensor* weights,
const Tensor* bias,
LoDTensor* output,
const ExecutionContext& ctx) {
auto input_dims = phi::vectorize(input->dims());
std::vector<int64_t> new_input_dims = {
input_dims[0] * input_dims[1], input_dims[2], 1};
auto src_desc =
CreateMemDescriptor<T_in>(new_input_dims, MKLDNNMemoryFormat::any);
auto weight_dims = Get3DWeightDimsForDNNL(weights);
auto weights_desc =
CreateMemDescriptor<T_w>(weight_dims, MKLDNNMemoryFormat::any);
auto bias_desc = CreateMemDescriptor<float>(bias, MKLDNNMemoryFormat::x);
auto dst_dims = {input_dims[0] * input_dims[1], weight_dims[0]};
auto dst_desc =
CreateMemDescriptor<T_out>(dst_dims, MKLDNNMemoryFormat::any);
const auto attrs = CreateFCAttrs(ctx);
return CreateFcPrimDesc(src_desc, weights_desc, bias_desc, dst_desc, attrs);
}
std::vector<int64_t> Get3DWeightDimsForDNNL(const Tensor* weights) {
auto paddle_w_dims = phi::vectorize(weights->dims());
return {paddle_w_dims[1], paddle_w_dims[0], 1};
}
memory::desc Create3DUserWeightsDesc(const Tensor* weights) {
auto dims = Get3DWeightDimsForDNNL(weights);
return CreateMemDescriptor<float>(dims, MKLDNNMemoryFormat::oiw);
}
dnnl::inner_product_forward::primitive_desc Create4DFcPrimDescriptor(
const LoDTensor* input,
const Tensor* weights,
const Tensor* bias,
LoDTensor* output,
const ExecutionContext& ctx) {
auto src_desc = CreateMemDescriptor<T_in>(input, MKLDNNMemoryFormat::any);
// Since MKL-DNN doesn't support 4D column-major data formats in
// inner_product primitive, transpose the weights to be in
// row-major format
auto dims = Get4DWeightDimsForDNNL(input, weights);
auto weights_desc = CreateMemDescriptor<T_w>(dims, MKLDNNMemoryFormat::any);
auto bias_desc = CreateMemDescriptor<float>(bias, MKLDNNMemoryFormat::x);
auto dst_desc = CreateMemDescriptor<T_out>(output, MKLDNNMemoryFormat::any);
const auto attrs = CreateFCAttrs(ctx);
return CreateFcPrimDesc(src_desc, weights_desc, bias_desc, dst_desc, attrs);
}
std::vector<int64_t> Get4DWeightDimsForDNNL(const LoDTensor* input,
const Tensor* weights) {
auto old_w_dims = phi::vectorize(weights->dims());
auto old_in_dims = phi::vectorize(input->dims());
auto dims = {old_w_dims[1], old_in_dims[1], old_in_dims[2], old_in_dims[3]};
return dims;
}
memory::desc Create4DUserWeightsDesc(const LoDTensor* input,
const Tensor* weights) {
auto dims = Get4DWeightDimsForDNNL(input, weights);
return CreateMemDescriptor<float>(dims, MKLDNNMemoryFormat::oihw);
}
// Convert data from one data format to another
std::shared_ptr<dnnl::memory> Reorder(const memory::desc& src_desc,
const memory::desc& dst_desc,
void* src_data) {
auto src_mem = memory(src_desc, engine_, src_data);
auto dst_mem = std::make_shared<memory>(dst_desc, engine_);
auto reorder = dnnl::reorder(src_mem, *dst_mem);
auto& astream = platform::MKLDNNDeviceContext::tls().get_stream();
{
platform::RecordEvent record_reorder(
"int_reorder",
platform::TracerEventType::UserDefined,
2,
platform::EventRole::kUniqueOp);
reorder.execute(astream, src_mem, *dst_mem);
astream.wait();
}
return dst_mem;
this->AcquireForwardPrimitiveDescriptor(attrs,
prop_kind::forward_inference,
src_md,
weights_md,
bias_md,
dst_md);
}
// Convert data from one data format to another and rescale it.
// If the desired data type is (un)signed int8, quantization occurs here.
std::shared_ptr<dnnl::memory> ReorderWithScale(
const std::shared_ptr<memory> src_mem,
const memory::desc& dst_md,
const std::vector<float>& scale_data) {
auto dst_mem = std::make_shared<dnnl::memory>(dst_md, engine_);
private:
dnnl::primitive_attr CreateFCAttrs(const ExecutionContext& ctx) {
dnnl::primitive_attr attributes;
// According to MKL-DNN's documentation mask determines along which
// dimensions should the scale be applied.
// 0 - Single scale applied to whole tensor
// 1 - Apply the scale along the slice of each dimension whose mask bit is 1.
// In case of weights quantization, that dimension is the output channel,
// because we perform per-output-channel quantization.
int mask = CreateMask(0, scale_data.size() > 1);
attributes.set_output_scales(mask, scale_data);
auto reorder = dnnl::reorder(*src_mem, *dst_mem, attributes);
dnnl::post_ops post_operations;
auto& astream = platform::MKLDNNDeviceContext::tls().get_stream();
{
platform::RecordEvent record_reorder(
"int_reorder",
platform::TracerEventType::UserDefined,
2,
platform::EventRole::kUniqueOp);
reorder.execute(astream,
{{DNNL_ARG_FROM, *src_mem}, {DNNL_ARG_TO, *dst_mem}});
astream.wait();
std::vector<float> output_shift_scale;
float scale = 1.0f;
if (IsInt8<T_w>()) {
std::tie(output_shift_scale, scale) = ComputeOutputShiftScale(ctx);
int mask = CreateMask(1, output_shift_scale.size() > 1);
attributes.set_output_scales(mask, output_shift_scale);
}
return dst_mem;
}
template <typename T>
static dnnl::memory::desc CreateMemDescriptor(
const std::vector<int64_t>& dims, MKLDNNMemoryFormat format) {
return platform::MKLDNNMemDesc(
dims, platform::MKLDNNGetDataType<T>(), format);
}
template <typename T>
static dnnl::memory::desc CreateMemDescriptor(const Tensor* tensor,
MKLDNNMemoryFormat format) {
auto dims = phi::vectorize(tensor->dims());
return CreateMemDescriptor<T>(dims, format);
}
template <typename T>
dnnl::memory CreateMemory(const dnnl::memory::desc& desc,
const Tensor* tensor) {
return CreateMemory(desc, platform::to_void_cast<T>(tensor->data<T>()));
}
dnnl::memory CreateMemory(const dnnl::memory::desc& desc, void* data) {
return memory(desc, engine_, data);
}
template <typename T>
std::shared_ptr<dnnl::memory> CreateMemoryToBeCached(
const dnnl::memory::desc& desc, const Tensor* tensor) {
return CreateMemoryToBeCached(desc,
platform::to_void_cast<T>(tensor->data<T>()));
}
std::shared_ptr<dnnl::memory> CreateMemoryToBeCached(
const dnnl::memory::desc& desc, void* data) {
return std::make_shared<memory>(desc, engine_, data);
}
float sum_scale = 1.0f;
if (ctx.HasAttr("fuse_residual_connection") &&
ctx.Attr<bool>("fuse_residual_connection")) {
post_operations.append_sum(sum_scale);
}
// Create weights memory and transform to default MKL-DNN format
std::shared_ptr<dnnl::memory> CreateWeightsMemory(const Tensor* weights) {
auto dims = phi::vectorize(weights->dims());
std::swap(dims[0], dims[1]); // Correct output dimensions
auto src_desc = CreateMemDescriptor<float>(dims, MKLDNNMemoryFormat::io);
auto dst_desc = CreateMemDescriptor<float>(dims, MKLDNNMemoryFormat::oi);
// Transpose weights through MKL-DNN's reorder from io to oi format.
return Reorder(src_desc,
dst_desc,
platform::to_void_cast<float>(weights->data<float>()));
}
// ReLU from "fc_fuse_pass"
if (ctx.Attr<std::string>("activation_type") == "relu") {
post_operations.append_eltwise(
scale, dnnl::algorithm::eltwise_relu, 0.0f, 0.0f);
}
platform::AppendActivation(ctx, post_operations, scale);
void CacheWeightsAndBias(const MKLDNNDeviceContext& dev_ctx,
const ExecutionContext& ctx) {
std::string key = platform::CreateKey(dev_ctx);
key = platform::ExtendKeyWithThreadInfoIfNeeded(dev_ctx, key);
if (ctx.HasAttr("fused_output_scale")) {
float scale_alpha = ctx.Attr<float>("fused_output_scale");
post_operations.append_eltwise(
1.0, dnnl::algorithm::eltwise_linear, scale_alpha, 0.0f);
}
const std::string weights_key = key + ctx.InputName("W");
const std::string bias_key = key + ctx.InputName("Bias");
dev_ctx.SetBlob(weights_key, weights_);
dev_ctx.SetBlob(bias_key, bias_);
attributes.set_post_ops(post_operations);
return attributes;
}
// Compute the bias scales so that its values correspond to the
// scale of data being an output of weights and input multiplication
std::vector<float> ComputeBiasScales(const ExecutionContext& ctx) {
auto scale_in_data = ctx.Attr<float>("Scale_in");
auto scale_weights_data = ctx.Attr<std::vector<float>>("Scale_weights");
const size_t weight_scales_num = scale_weights_data.size();
std::vector<float> bias_scales(weight_scales_num);
std::vector<float> ComputeBiasScales(
const float scale_in, const std::vector<float>& scale_weights) {
std::vector<float> bias_scales(scale_weights.size());
#pragma omp parallel for
for (size_t i = 0; i < weight_scales_num; i++) {
if (scale_weights_data[i] == 0.0)
for (size_t i = 0; i < bias_scales.size(); ++i) {
if (scale_weights[i] == 0.0)
bias_scales[i] = 1.0f;
else
bias_scales[i] = scale_in_data * scale_weights_data[i];
bias_scales[i] = scale_in * scale_weights[i];
}
return bias_scales;
......@@ -442,18 +177,16 @@ class FCPrimitiveFactory {
? 1.0f
: ctx.Attr<float>("Scale_out");
const size_t weight_scales_num = scale_weights_data.size();
std::vector<float> output_shift_scale(weight_scales_num);
#pragma omp parallel for
for (size_t i = 0; i < weight_scales_num; i++) {
for (size_t i = 0; i < weight_scales_num; ++i) {
if (scale_weights_data[i] == 0.0)
output_shift_scale[i] = inner_scale;
scale_weights_data[i] = inner_scale;
else
output_shift_scale[i] =
scale_weights_data[i] =
inner_scale / (scale_in_data * scale_weights_data[i]);
}
return make_tuple(output_shift_scale, scale);
return make_tuple(scale_weights_data, scale);
}
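// Scale arithmetic used by the two helpers above, per output channel i:
//   bias_scale[i]         = Scale_in * Scale_weights[i]
//   output_shift_scale[i] = inner_scale / (Scale_in * Scale_weights[i])
// where inner_scale is Scale_out (or 1.0 when force_fp32_output is set).
// Example: Scale_in = 127, Scale_weights[i] = 2, Scale_out = 254 gives an
// output shift scale of 1.0, i.e. the int32 accumulator needs no extra rescale.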
// Computing MKL-DNN's scaling mask which determines along which dimension
......@@ -464,137 +197,300 @@ class FCPrimitiveFactory {
return is_multi_channel_quantizied ? 1 << slice_dimension : 0;
}
void QuantizeWeights(const ExecutionContext& ctx, memory::desc dst) {
weights_ = ReorderWithScale(
weights_, dst, ctx.Attr<std::vector<float>>("Scale_weights"));
}
std::shared_ptr<dnnl::memory> AcquireMemoryWithReorderAndAttrs(
const dnnl::memory::desc& user_md,
const dnnl::memory::desc& target_md,
void* ptr,
const dnnl::primitive_attr& attrs) {
std::shared_ptr<dnnl::memory> target_memory_p;
void QuantizeBias(const inner_product_forward::primitive_desc& fc_prim_desc,
const ExecutionContext& ctx) {
auto bias_scales = ComputeBiasScales(ctx);
bias_ = ReorderWithScale(bias_, fc_prim_desc.bias_desc(), bias_scales);
}
auto user_memory_p =
std::make_shared<dnnl::memory>(user_md, this->engine_, ptr);
target_memory_p = std::make_shared<dnnl::memory>(target_md, this->engine_);
auto reorder_p = std::make_shared<dnnl::reorder>(
*user_memory_p, *target_memory_p, attrs);
dnnl::primitive_attr CreateFCAttrs(const ExecutionContext& ctx) {
dnnl::primitive_attr attributes;
dnnl::post_ops post_operations;
auto& astream = platform::MKLDNNDeviceContext::tls().get_stream();
reorder_p->execute(
astream,
{{DNNL_ARG_FROM, *user_memory_p}, {DNNL_ARG_TO, *target_memory_p}});
astream.wait();
std::vector<float> output_shift_scale;
float scale;
std::tie(output_shift_scale, scale) = ComputeOutputShiftScale(ctx);
int mask = CreateMask(1, output_shift_scale.size() > 1);
attributes.set_output_scales(mask, output_shift_scale);
return target_memory_p;
}
float sum_scale = 1.0f;
if (ctx.HasAttr("fuse_residual_connection") &&
ctx.Attr<bool>("fuse_residual_connection")) {
post_operations.append_sum(sum_scale);
}
std::string memory_key_;
const platform::MKLDNNDeviceContext& dev_ctx_;
if (ctx.Attr<std::string>("activation_type") == "relu") {
constexpr float negative_slope = 0.0f;
constexpr float placeholder = 1.0f; // beta
post_operations.append_eltwise(
scale, dnnl::algorithm::eltwise_relu, negative_slope, placeholder);
} else if (ctx.Attr<std::string>("activation_type") == "gelu") {
constexpr float alpha = 0.0f;
constexpr float beta = 0.0f;
post_operations.append_eltwise(
scale, dnnl::algorithm::eltwise_gelu, alpha, beta);
} else if (ctx.Attr<std::string>("activation_type") == "gelu_tanh") {
constexpr float alpha = 0.0f;
constexpr float beta = 0.0f;
post_operations.append_eltwise(
scale, dnnl::algorithm::eltwise_gelu_tanh, alpha, beta);
} else if (ctx.Attr<std::string>("activation_type") == "gelu_erf") {
constexpr float alpha = 0.0f;
constexpr float beta = 0.0f;
post_operations.append_eltwise(
scale, dnnl::algorithm::eltwise_gelu_erf, alpha, beta);
} else if (ctx.Attr<std::string>("activation_type") == "tanh") {
constexpr float alpha = 0.0f;
constexpr float beta = 0.0f;
post_operations.append_eltwise(
scale, dnnl::algorithm::eltwise_tanh, alpha, beta);
} else if (ctx.Attr<std::string>("activation_type") == "sigmoid") {
constexpr float alpha = 0.0f;
constexpr float beta = 0.0f;
post_operations.append_eltwise(
scale, dnnl::algorithm::eltwise_logistic, alpha, beta);
} else if (ctx.Attr<std::string>("activation_type") == "mish") {
constexpr float alpha = 0.0f;
constexpr float beta = 0.0f;
post_operations.append_eltwise(
scale, dnnl::algorithm::eltwise_mish, alpha, beta);
} else if (ctx.Attr<std::string>("activation_type") == "hard_swish") {
constexpr float alpha = 0.0f;
constexpr float beta = 0.0f;
post_operations.append_eltwise(
scale, dnnl::algorithm::eltwise_hardswish, alpha, beta);
public:
std::shared_ptr<dnnl::memory> AcquireSrcMemoryWithReorder(
const phi::DenseTensor* x) {
const T_in* x_data = x->data<T_in>();
auto user_md = x->mem_desc();
if (x->dims().size() != 2) {
// reshape restrictions are always satisfied because in case of 3 or 4 dim
// input, plain layout is enforced
user_md = user_md.reshape(this->fwd_pd_->src_desc().dims());
}
if (ctx.HasAttr("fused_output_scale")) {
float scale_alpha = ctx.Attr<float>("fused_output_scale");
post_operations.append_eltwise(
1.0, dnnl::algorithm::eltwise_linear, scale_alpha, 0.0f);
}
return this->AcquireMemoryWithReorder(
user_md, this->fwd_pd_->src_desc(), to_void_cast<T_in>(x_data));
}
attributes.set_post_ops(post_operations);
return attributes;
std::shared_ptr<dnnl::memory> AcquireBiasMemoryWithReorder(
const phi::DenseTensor* bias,
const float scale_in,
const std::vector<float>& scale_weights) {
const float* bias_data = bias->data<float>();
if (IsInt8<T_w>() == false) {
// for BF16/FP32 bias is 1D and has no scales, so reorder is not needed
return this->AcquireMemoryFromPrimitive(this->fwd_pd_->bias_desc(),
to_void_cast<float>(bias_data));
} else {
const std::string bias_key = this->memory_key_ + "@bias";
auto memory_p = std::static_pointer_cast<dnnl::memory>(
this->dev_ctx_.GetBlob(bias_key));
if (!memory_p) {
const auto& scale_data = ComputeBiasScales(scale_in, scale_weights);
dnnl::primitive_attr attrs;
int mask = CreateMask(0, scale_data.size() > 1);
attrs.set_output_scales(mask, scale_data);
auto user_md = dnnl::memory::desc({bias->dims()[0]},
MKLDNNGetDataType<float>(),
dnnl::memory::format_tag::a);
memory_p = this->AcquireMemoryWithReorderAndAttrs(
user_md,
this->fwd_pd_->bias_desc(),
to_void_cast<float>(bias_data),
attrs);
this->dev_ctx_.SetBlob(bias_key, memory_p);
}
return memory_p;
}
}
dnnl::inner_product_forward::primitive_desc CreateFcPrimDesc(
const dnnl::memory::desc& input_desc,
const dnnl::memory::desc& weights_desc,
const dnnl::memory::desc& bias_desc,
const dnnl::memory::desc& dst_desc,
const dnnl::primitive_attr& attrs) {
auto fc_desc = inner_product_forward::desc(prop_kind::forward_scoring,
input_desc,
weights_desc,
bias_desc,
dst_desc);
std::shared_ptr<dnnl::memory> AcquireWeightsMemoryWithReorder(
const phi::DenseTensor* weights, const std::vector<float>& scale_data) {
const std::string weights_key = this->memory_key_ + "@weights";
auto memory_p = std::static_pointer_cast<dnnl::memory>(
this->dev_ctx_.GetBlob(weights_key));
if (!memory_p) {
const float* weights_data = weights->data<float>();
auto weights_dims = this->fwd_pd_->weights_desc().dims();
return inner_product_forward::primitive_desc(fc_desc, attrs, engine_);
auto user_md = dnnl::memory::desc(weights_dims,
MKLDNNGetDataType<float>(),
dnnl::memory::format_tag::io);
if (IsInt8<T_w>()) {
dnnl::primitive_attr attrs;
int mask = CreateMask(0, scale_data.size() > 1);
attrs.set_output_scales(mask, scale_data);
memory_p = this->AcquireMemoryWithReorderAndAttrs(
user_md,
this->fwd_pd_->weights_desc(),
to_void_cast<float>(weights_data),
attrs);
} else {
memory_p =
this->AcquireMemoryWithReorder(user_md,
this->fwd_pd_->weights_desc(),
to_void_cast<float>(weights_data));
}
this->dev_ctx_.SetBlob(weights_key, memory_p);
}
return memory_p;
}
// Create output memory based on output tensor and inner_product
// primitive descriptor format chosen for output
dnnl::memory CreateDstMemory(
const dnnl::inner_product_forward::primitive_desc& fc_prim_desc,
const ExecutionContext& ctx,
Tensor* output) {
std::shared_ptr<dnnl::memory> AcquireCustomDstMemory(
const ExecutionContext& ctx, phi::DenseTensor* out) {
if (ctx.HasAttr("fuse_residual_connection") &&
ctx.Attr<bool>("fuse_residual_connection")) {
auto* residual_param = ctx.Output<Tensor>("ResidualData");
auto* residual_param = ctx.Output<phi::DenseTensor>("ResidualData");
PADDLE_ENFORCE_EQ(
output->dims(),
out->dims(),
residual_param->dims(),
platform::errors::InvalidArgument(
"Output and elementwise parameter need to have the "
"same dimension sizes, but got output's dimension = %d"
" and residual param's dimension =%d .",
output->dims().size(),
out->dims().size(),
residual_param->dims().size()));
output->ShareDataWith(*residual_param);
out->ShareDataWith(*residual_param);
}
return this->template AcquireDstMemory<T_out>(out);
}
};
auto dst_desc = fc_prim_desc.dst_desc();
auto buffer_size = dst_desc.get_size();
T_out* output_data =
output->mutable_data<T_out>(ctx.GetPlace(), buffer_size);
memory dst_mem(dst_desc, engine_, to_void_cast<T_out>(output_data));
SetOutputFormat(ctx.Input<LoDTensor>("Input")->format(), output);
template <typename T_in, typename T_w>
class FCMKLDNNKernel : public framework::OpKernel<T_in> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
bool force_fp32_output = ctx.Attr<bool>("force_fp32_output");
bool fuse_relu = ctx.Attr<std::string>("activation_type") == "relu";
return dst_mem;
if (force_fp32_output) {
this->RunKernel<float>(ctx);
} else if (IsInt8<T_in>()) {
if (fuse_relu) {
this->RunKernel<uint8_t>(ctx);
} else {
this->RunKernel<int8_t>(ctx);
}
} else {
this->RunKernel<T_in>(ctx);
}
}
void PrepareSrcMem(const std::shared_ptr<inner_product_forward>& fc_p,
const std::shared_ptr<dnnl::memory>& src_mem,
const LoDTensor* x,
const dnnl::engine& engine) const {
auto x_md = x->mem_desc().reshape(src_mem->get_desc().dims());
if (x_md != src_mem->get_desc()) {
dnnl::memory x_mem(x_md, engine, to_void_cast<T_in>(x->data<T_in>()));
auto reorder_p = dnnl::reorder(x_mem, *src_mem);
auto& astream = paddle::platform::MKLDNNDeviceContext::tls().get_stream();
reorder_p.execute(astream, x_mem, *src_mem);
astream.wait();
} else {
src_mem->set_data_handle(to_void_cast<T_in>(x->data<T_in>()));
}
}
template <typename T_out = T_w>
void RunKernel(const framework::ExecutionContext& ctx) const {
const auto& dev_ctx =
ctx.template device_context<platform::MKLDNNDeviceContext>();
const auto& mkldnn_engine = dev_ctx.GetEngine();
const auto* x = ctx.Input<LoDTensor>("Input");
const auto* weights = ctx.Input<phi::DenseTensor>("W");
const auto* bias = ctx.Input<phi::DenseTensor>("Bias");
auto out = ctx.Output<LoDTensor>("Out");
const float scale_in = ctx.Attr<float>("Scale_in");
const auto& scale_weights = ctx.Attr<std::vector<float>>("Scale_weights");
std::shared_ptr<dnnl::inner_product_forward> fc_p;
std::shared_ptr<dnnl::memory> src_memory_p;
std::shared_ptr<dnnl::memory> weights_memory_p;
std::shared_ptr<dnnl::memory> bias_memory_p;
std::shared_ptr<dnnl::memory> dst_memory_p;
std::string cache_key;
cache_key.reserve(64);
cache_key = platform::ExtendKeyWithThreadInfoIfNeeded(
dev_ctx,
platform::CreateKey(dev_ctx,
ctx.InputName("Input"),
ctx.InputName("W"),
phi::vectorize(x->dims())));
auto inner_product_cache =
std::static_pointer_cast<InnerProductCache>(dev_ctx.GetBlob(cache_key));
RecomputeOutputDims(ctx, x, weights, out);
if (inner_product_cache) {
fc_p = std::make_shared<dnnl::inner_product_forward>(
inner_product_cache->inner_product_p);
src_memory_p =
std::make_shared<dnnl::memory>(inner_product_cache->src_mem);
PrepareSrcMem(fc_p, src_memory_p, x, mkldnn_engine);
weights_memory_p =
std::make_shared<dnnl::memory>(inner_product_cache->weights_mem);
dst_memory_p =
std::make_shared<dnnl::memory>(inner_product_cache->dst_mem);
if (ctx.HasAttr("fuse_residual_connection") &&
ctx.Attr<bool>("fuse_residual_connection")) {
auto* residual_param = ctx.Output<phi::DenseTensor>("ResidualData");
out->ShareDataWith(*residual_param);
}
auto out_ptr = out->mutable_data<T_out>(
ctx.GetPlace(), dst_memory_p->get_desc().get_size());
dst_memory_p->set_data_handle(out_ptr);
if (bias) {
bias_memory_p =
std::make_shared<dnnl::memory>(inner_product_cache->bias_mem);
}
} else {
auto in_col_dims = ctx.Attr<int>("in_num_col_dims");
FCMKLDNNHandler<T_in, T_w, T_out> handler(ctx,
dev_ctx,
x,
weights,
bias,
out,
in_col_dims,
mkldnn_engine,
ctx.GetPlace());
src_memory_p = handler.AcquireSrcMemoryWithReorder(x);
weights_memory_p =
handler.AcquireWeightsMemoryWithReorder(weights, scale_weights);
dst_memory_p = handler.AcquireCustomDstMemory(ctx, out);
if (bias) {
bias_memory_p =
handler.AcquireBiasMemoryWithReorder(bias, scale_in, scale_weights);
}
fc_p = handler.AcquireForwardPrimitive();
}
auto& astream = paddle::platform::MKLDNNDeviceContext::tls().get_stream();
std::unordered_map<int, dnnl::memory> fc_args = {
{DNNL_ARG_SRC, *src_memory_p},
{DNNL_ARG_WEIGHTS, *weights_memory_p},
{DNNL_ARG_DST, *dst_memory_p}};
if (bias) {
fc_args.insert({DNNL_ARG_BIAS, *bias_memory_p});
}
fc_p->execute(astream, fc_args);
astream.wait();
if (!inner_product_cache) {
auto ip_cache = std::make_shared<InnerProductCache>();
ip_cache->inner_product_p = *fc_p;
ip_cache->src_mem = *src_memory_p;
ip_cache->weights_mem = *weights_memory_p;
ip_cache->dst_mem = *dst_memory_p;
if (bias) {
ip_cache->bias_mem = *bias_memory_p;
}
dev_ctx.SetBlob(cache_key, ip_cache);
}
platform::SetOutMemDescWithLogicalLayoutFusesSupport(
ctx,
out,
dst_memory_p->get_desc().reshape(phi::vectorize(out->dims())));
}
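// Execution flow of RunKernel above: look up InnerProductCache by key; on a
// hit, reuse the stored primitive and memories and only refresh data handles
// (PrepareSrcMem reorders the input if its layout differs from the cached src
// descriptor); on a miss, build everything through FCMKLDNNHandler and store
// the result back into the device-context blob map for the next iteration.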
void RecomputeOutputDims(const ExecutionContext& ctx,
const LoDTensor* input,
const Tensor* w,
LoDTensor* output) {
const LoDTensor* x,
const phi::DenseTensor* weights,
LoDTensor* out) const {
int in_num_col_dims = ctx.Attr<int>("in_num_col_dims");
bool padding_weights = ctx.Attr<bool>("padding_weights");
PADDLE_ENFORCE_EQ(padding_weights,
......@@ -602,102 +498,16 @@ class FCPrimitiveFactory {
platform::errors::PermissionDenied(
"Weight padding in fc can not be used in MKLDNN."));
std::vector<int64_t> output_dims;
FCOutputSize(input->dims(),
w->dims(),
FCOutputSize(x->dims(),
weights->dims(),
output_dims,
in_num_col_dims,
padding_weights);
output->Resize(phi::make_ddim(output_dims));
output->set_lod(input->lod());
out->Resize(phi::make_ddim(output_dims));
out->set_lod(x->lod());
}
private:
const dnnl::engine& engine_;
paddle::optional<memory> input_;
paddle::optional<memory> output_;
std::shared_ptr<memory> bias_;
std::shared_ptr<memory> weights_;
paddle::optional<inner_product_forward> fc_;
};
// Attempt to fetch cached primitive factory based on provided parameters
// of input format, weight dimensions and output name.
// If not cached, create a new one.
template <typename T_in, typename T_w, typename T_out>
static std::shared_ptr<FCPrimitiveFactory<T_in, T_w, T_out>>
GetPrimitiveFactory(const MKLDNNDeviceContext& dev_ctx,
const std::string& key) {
auto prim_creator =
std::static_pointer_cast<FCPrimitiveFactory<T_in, T_w, T_out>>(
dev_ctx.GetBlob(key));
if (prim_creator == nullptr) {
prim_creator = std::make_shared<FCPrimitiveFactory<T_in, T_w, T_out>>(
dev_ctx.GetEngine());
dev_ctx.SetBlob(key, prim_creator);
}
return prim_creator;
}
// Choose appropriate primitive factory implementation based on inferred
// output type (uint8, int8 or float).
template <typename T_in, typename T_w>
static void ExecuteFc(const ExecutionContext& ctx,
const LoDTensor* input,
const Tensor* w,
const Tensor* bias,
LoDTensor* output,
bool fuse_relu,
bool force_fp32_output) {
auto& dev_ctx = ctx.template device_context<MKLDNNDeviceContext>();
std::string prim_key = platform::CreateKey(dev_ctx,
input->format(),
input->dims()[0],
phi::vectorize<int>(w->dims()),
ctx.OutputName("Out"));
prim_key = platform::ExtendKeyWithThreadInfoIfNeeded(dev_ctx, prim_key);
constexpr bool is_int8 =
std::is_same<T_in, int8_t>::value || std::is_same<T_in, uint8_t>::value;
bool is_bfloat16 = std::is_same<T_in, paddle::platform::bfloat16>::value;
if ((!is_int8 && !is_bfloat16) || force_fp32_output) {
GetPrimitiveFactory<T_in, T_w, float>(dev_ctx, prim_key)
->ExecuteFcPrimitive(input, w, bias, output, dev_ctx, ctx);
} else if (is_bfloat16) {
GetPrimitiveFactory<T_in, T_w, platform::bfloat16>(dev_ctx, prim_key)
->ExecuteFcPrimitive(input, w, bias, output, dev_ctx, ctx);
} else if (fuse_relu) {
GetPrimitiveFactory<T_in, T_w, uint8_t>(dev_ctx, prim_key)
->ExecuteFcPrimitive(input, w, bias, output, dev_ctx, ctx);
} else {
GetPrimitiveFactory<T_in, T_w, int8_t>(dev_ctx, prim_key)
->ExecuteFcPrimitive(input, w, bias, output, dev_ctx, ctx);
}
}
template <typename T_in, typename T_w>
class FCMKLDNNOpKernel : public framework::OpKernel<T_in> {
public:
void Compute(const paddle::framework::ExecutionContext& ctx) const override {
PADDLE_ENFORCE_EQ(
platform::is_cpu_place(ctx.GetPlace()),
true,
platform::errors::PreconditionNotMet("FC MKL-DNN must use CPUPlace."));
platform::MKLDNNDeviceContext::tls().log_lib_version();
auto input = ctx.Input<LoDTensor>("Input");
auto w = ctx.Input<Tensor>("W");
auto bias = ctx.Input<Tensor>("Bias");
auto output = ctx.Output<LoDTensor>("Out");
bool fuse_relu = ctx.Attr<std::string>("activation_type") == "relu";
bool force_fp32_output = ctx.Attr<bool>("force_fp32_output");
ExecuteFc<T_in, T_w>(
ctx, input, w, bias, output, fuse_relu, force_fp32_output);
output->set_layout(DataLayout::kMKLDNN);
}
};
} // namespace operators
} // namespace paddle
......@@ -710,7 +520,7 @@ REGISTER_OP_KERNEL_WITH_CUSTOM_TYPE(fc,
::paddle::platform::CPUPlace,
FP32,
ops::kFCMKLDNNFP32,
ops::FCMKLDNNKernel<float, float>);
REGISTER_OP_KERNEL_WITH_CUSTOM_TYPE(
fc,
......@@ -718,19 +528,19 @@ REGISTER_OP_KERNEL_WITH_CUSTOM_TYPE(
::paddle::platform::CPUPlace,
BF16,
ops::kFCMKLDNNFP32,
ops::FCMKLDNNKernel<paddle::platform::bfloat16,
paddle::platform::bfloat16>);
REGISTER_OP_KERNEL_WITH_CUSTOM_TYPE(fc,
MKLDNN,
::paddle::platform::CPUPlace,
U8,
ops::kFCMKLDNNINT8,
ops::FCMKLDNNKernel<uint8_t, int8_t>);
REGISTER_OP_KERNEL_WITH_CUSTOM_TYPE(fc,
MKLDNN,
::paddle::platform::CPUPlace,
S8,
ops::kFCMKLDNNINT8,
ops::FCMKLDNNOpKernel<int8_t, int8_t>);
ops::FCMKLDNNKernel<int8_t, int8_t>);
......@@ -21,72 +21,8 @@
namespace paddle {
namespace operators {
using Tensor = phi::DenseTensor;
using phi::DataLayout;
template <typename T>
class TransposeMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
......@@ -98,37 +34,87 @@ class TransposeMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
"Operator DNNL Transpose must use CPUPlace"));
auto& dev_ctx =
ctx.template device_context<paddle::platform::MKLDNNDeviceContext>();
const auto& dnnl_engine = dev_ctx.GetEngine();
std::vector<int> transpose_axis = ctx.Attr<std::vector<int>>("axis");
int ndims = transpose_axis.size();
auto* x = ctx.Input<Tensor>("X");
auto* out = ctx.Output<Tensor>("Out");
auto& astream = platform::MKLDNNDeviceContext::tls().get_stream();
if (ndims == 1) {
framework::TensorCopy(*x, x->place(), out);
out->set_mem_desc(x->mem_desc());
return;
}
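// For ndims > 1 the transpose is done below with a single oneDNN reorder: the
// destination memory is created with "fake" transposed strides, and the output
// descriptor is afterwards permuted so it describes the transposed tensor in a
// plain, contiguous layout.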
auto x_vec_dims = phi::vectorize(x->dims());
framework::proto::VarType::Type x_paddle_type =
framework::TransToProtoVarType(x->dtype());
dnnl::memory::data_type x_type = framework::ToMKLDNNDataType(x_paddle_type);
platform::ReorderMKLDNNHandler reorder_handler(
x_vec_dims, x_paddle_type, x_type, dnnl_engine);
auto reorder_src_memory_p = reorder_handler.AcquireSrcMemory(
x->mem_desc(), platform::to_void_cast(x->data<T>()));
auto dst_md =
dnnl::memory::desc(x_vec_dims,
x->mem_desc().data_type(),
platform::GetPlainMKLDNNFormat(x_vec_dims.size()));
// A trick is used here to fake the transpose of out_md, so that it can later
// be "untransposed", leaving the output data in a plain format tag.
auto dst_strides = FakeTranposeStrides(dst_md, transpose_axis);
dst_md =
dnnl::memory::desc(x_vec_dims, x->mem_desc().data_type(), dst_strides);
auto dst_data =
out->mutable_data(ctx.GetPlace(), x->type(), dst_md.get_size());
auto reorder_dst_memory_p =
std::make_shared<dnnl::memory>(dst_md, dnnl_engine, dst_data);
auto reorder_p = reorder_handler.AcquireReorder(reorder_dst_memory_p,
reorder_src_memory_p);
reorder_p->execute(astream, *reorder_src_memory_p, *reorder_dst_memory_p);
astream.wait();
platform::SetOutMemDescWithLogicalLayoutFusesSupport(
ctx,
out,
reorder_dst_memory_p->get_desc().permute_axes(
TransposeToPermuteAxis(transpose_axis)));
}
private:
// This is needed because oneDNN's permute_axes interprets the axes order
// differently than PaddlePaddle's transpose.
std::vector<int> TransposeToPermuteAxis(
const std::vector<int>& transpose_axis) const {
std::vector<int> permute_axis(transpose_axis.size());
for (size_t i = 0; i < transpose_axis.size(); ++i) {
permute_axis[transpose_axis[i]] = i;
}
return permute_axis;
}
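// Illustrative example (not from the original source): for
// transpose_axis = {0, 2, 3, 1} the inverse permutation computed above is
// permute_axis = {0, 3, 1, 2}, i.e. permute_axis[j] is the output position
// that input axis j moves to.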
std::vector<int64_t> FakeTranposeStrides(
const dnnl::memory::desc& dst_md,
const std::vector<int>& transpose_axis) const {
std::vector<int64_t> fake_strides(transpose_axis.size());
auto dims = dst_md.dims();
int total_stride = 1;
int ndims = static_cast<int>(dims.size());
for (int i = ndims - 1; i >= 0; --i) {
fake_strides[transpose_axis[i]] = total_stride;
total_stride *= dims[transpose_axis[i]];
}
return fake_strides;
}
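// Illustrative example (not from the original source): for dims = {2, 3, 4, 5}
// and transpose_axis = {0, 2, 1, 3} the loop produces
// fake_strides = {60, 5, 15, 1}; reordering into a memory with these strides
// and then permuting the axes of its descriptor leaves the transposed data in
// a plain (contiguous) layout.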
};
......@@ -140,44 +126,47 @@ class TransposeMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> {
true,
paddle::platform::errors::PreconditionNotMet(
"Operator DNNL TransposeGrad must use CPUPlace"));
const auto* dout = ctx.Input<Tensor>(framework::GradVarName("Out"));
auto* dx = ctx.Output<Tensor>(framework::GradVarName("X"));
if (!dx) return;
auto& dev_ctx =
ctx.template device_context<paddle::platform::MKLDNNDeviceContext>();
const auto& dnnl_engine = dev_ctx.GetEngine();
std::vector<int> transpose_axis = ctx.Attr<std::vector<int>>("axis");
auto& astream = platform::MKLDNNDeviceContext::tls().get_stream();
int ndims = transpose_axis.size();
if (ndims == 1) {
framework::TensorCopy(*dout, dout->place(), dx);
dx->set_mem_desc(dout->mem_desc());
return;
}
auto dout_vec_dims = phi::vectorize(dout->dims());
framework::proto::VarType::Type dout_paddle_type =
framework::TransToProtoVarType(dout->dtype());
dnnl::memory::data_type dout_type =
framework::ToMKLDNNDataType(dout_paddle_type);
platform::ReorderMKLDNNHandler reorder_handler(
dout_vec_dims, dout_paddle_type, dout_type, dnnl_engine);
auto reorder_src_memory_p = reorder_handler.AcquireSrcMemory(
dout->mem_desc(), platform::to_void_cast(dout->data<T>()));
auto reorder_dst_memory_p =
reorder_handler.AcquireDstMemory(dx, dout->mem_desc(), ctx.GetPlace());
auto reorder_p = reorder_handler.AcquireReorder(reorder_dst_memory_p,
reorder_src_memory_p);
reorder_p->execute(astream, *reorder_src_memory_p, *reorder_dst_memory_p);
astream.wait();
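// The reorder above is effectively a plain copy, since the destination uses
// dout's memory descriptor; the transpose of the gradient is expressed purely
// logically below, by permuting the axes of dx's memory descriptor.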
dx->set_mem_desc(
reorder_dst_memory_p->get_desc().permute_axes(transpose_axis));
}
};
......
......@@ -111,6 +111,69 @@ static void AppendActivation(const framework::ExecutionContext& ctx,
}
}
static void SetOutMemDescWithUnsqueeze2FuseSupport(
const framework::ExecutionContext& ctx,
phi::DenseTensor* out,
const dnnl::memory::desc& out_md) {
const std::vector<int>& fused_unsqueeze2_axes =
ctx.Attr<std::vector<int>>("fused_unsqueeze2_axes");
const std::vector<int64_t>& op_tz = out_md.dims();
std::vector<int64_t> unsqueezed_op_tz(
op_tz.size() + fused_unsqueeze2_axes.size(), 0);
for (const auto& axis : fused_unsqueeze2_axes) {
int positive_axis = axis < 0 ? unsqueezed_op_tz.size() + axis : axis;
unsqueezed_op_tz[positive_axis] = 1;
}
int j = 0;
for (size_t i = 0; i < unsqueezed_op_tz.size(); ++i) {
if (unsqueezed_op_tz[i] == 0) {
unsqueezed_op_tz[i] = op_tz[j++];
}
}
out->set_mem_desc(out_md.reshape(unsqueezed_op_tz));
out->Resize(phi::make_ddim(unsqueezed_op_tz));
}
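// Illustrative example (not from the original source): with out_md dims of
// {6, 4} and fused_unsqueeze2_axes = {0, 2}, the marked axes give {1, 0, 1, 0}
// and filling the remaining slots with the original dims yields
// unsqueezed_op_tz = {1, 6, 1, 4}.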
static void SetOutMemDescWithReshape2FuseSupport(
const framework::ExecutionContext& ctx,
phi::DenseTensor* out,
const dnnl::memory::desc& out_md) {
std::vector<int64_t> fused_reshape2_shape(
ctx.Attr<std::vector<int>>("fused_reshape2_shape").begin(),
ctx.Attr<std::vector<int>>("fused_reshape2_shape").end());
const int out_shape_numel = out->numel();
const int new_shape_numel = std::accumulate(fused_reshape2_shape.begin(),
fused_reshape2_shape.end(),
1,
std::multiplies<int64_t>());
for (size_t i = 0; i < fused_reshape2_shape.size(); ++i) {
if (fused_reshape2_shape[i] == -1) {
fused_reshape2_shape[i] = -out_shape_numel / new_shape_numel;
break;
}
}
out->set_mem_desc(out_md.reshape(fused_reshape2_shape));
out->Resize(phi::make_ddim(fused_reshape2_shape));
}
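// Illustrative example (not from the original source): if out has 24 elements
// and fused_reshape2_shape = {-1, 6}, then new_shape_numel = -6 and the -1
// entry becomes -24 / -6 = 4, giving the final shape {4, 6}.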
static void SetOutMemDescWithLogicalLayoutFusesSupport(
const framework::ExecutionContext& ctx,
phi::DenseTensor* out,
const dnnl::memory::desc& out_md) {
if (ctx.HasAttr("fused_unsqueeze2_axes")) {
SetOutMemDescWithUnsqueeze2FuseSupport(ctx, out, out_md);
} else if (ctx.HasAttr("fused_reshape2_shape")) {
SetOutMemDescWithReshape2FuseSupport(ctx, out, out_md);
} else {
out->set_mem_desc(out_md);
}
}
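// The fused_unsqueeze2_axes / fused_reshape2_shape attributes are attached to
// the preceding operator by the corresponding fuse passes, so the fused op can
// emit its output directly in the shape the removed unsqueeze2/reshape2 would
// have produced; without them the descriptor is passed through unchanged.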
template <typename T>
constexpr bool IsInt8() {
return std::is_same<T, int8_t>::value || std::is_same<T, uint8_t>::value;
......