Unverified · Commit b7a23adb, authored by Sławomir Siwek, committed by GitHub

FC + activation fuse passes (#45183)

* git

* style

* leave default relu in kernel

* style

* cleanup FCMKLDNN pattern

* merge conflicts

* update develop

* update develop

* add const

* rename to oneDNN and adjust attributes

* whitespace
Parent: da051350
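In short, this change renames the fused attribute written by the pass from "activation_type" to "fuse_activation" and takes the list of fusable activations from platform::GetSupportedActivations() instead of a hard-coded vector. The snippet below is a minimal, self-contained sketch of just the gelu handling visible in the diff (approximate=true maps to "gelu_tanh", otherwise "gelu_erf"); AttrMap and ResolveFuseActivation are illustrative stand-ins, not Paddle APIs.

// Illustrative sketch only (not Paddle source): how an activation op is
// folded into the fc op's "fuse_activation" attribute, with gelu resolved
// to gelu_tanh / gelu_erf based on its "approximate" attribute.
#include <iostream>
#include <map>
#include <string>

// Hypothetical stand-in for an op's attribute store.
using AttrMap = std::map<std::string, std::string>;

std::string ResolveFuseActivation(const std::string& act_type,
                                  const AttrMap& act_attrs) {
  if (act_type == "gelu" && act_attrs.count("approximate")) {
    return act_attrs.at("approximate") == "true" ? "gelu_tanh" : "gelu_erf";
  }
  return act_type;
}

int main() {
  AttrMap fc_attrs;  // attributes of the (hypothetical) fused fc op
  AttrMap gelu_attrs{{"approximate", "true"}};
  fc_attrs["fuse_activation"] = ResolveFuseActivation("gelu", gelu_attrs);
  fc_attrs["use_mkldnn"] = "true";
  std::cout << fc_attrs["fuse_activation"] << "\n";  // prints "gelu_tanh"
  return 0;
}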
paddle/fluid/framework/ir/mkldnn/fc_act_mkldnn_fuse_pass.cc

-// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
 //
 // Licensed under the Apache License, Version 2.0 (the "License");
 // you may not use this file except in compliance with the License.
@@ -14,9 +14,8 @@
 #include "paddle/fluid/framework/ir/mkldnn/fc_act_mkldnn_fuse_pass.h"
-#include "paddle/fluid/framework/ir/graph_pattern_detector.h"
 #include "paddle/fluid/framework/op_version_registry.h"
-#include "paddle/fluid/platform/enforce.h"
+#include "paddle/fluid/platform/mkldnn_reuse.h"
 #include "paddle/fluid/string/pretty_log.h"
 namespace paddle {
@@ -26,20 +25,20 @@ namespace ir {
 using string::PrettyLogDetail;
 void FuseFCActOneDNNPass::ApplyImpl(Graph *graph) const {
-  std::vector<std::string> act_types = {
-      "gelu", "tanh", "sigmoid", "mish", "hard_swish"};
-  for (std::string act_type : act_types) FuseFCAct(graph, act_type);
+  auto act_types = paddle::platform::GetSupportedActivations();
+  for (auto act_type : act_types) FuseFCAct(graph, act_type);
 }
 void FuseFCActOneDNNPass::FuseFCAct(Graph *graph,
                                     const std::string &act_type) const {
   PADDLE_ENFORCE_NOT_NULL(
       graph, platform::errors::InvalidArgument("Graph cannot be nullptr."));
-  FusePassBase::Init("fc_act", graph);
+  FusePassBase::Init("fc_" + act_type + "_mkldnn_fuse_pass", graph);
   GraphPatternDetector gpd;
-  patterns::OperatorActivation fc_act_pattern(gpd.mutable_pattern(), "fc_act");
+  patterns::OperatorActivation fc_act_pattern(
+      gpd.mutable_pattern(), "fc_" + act_type + "_mkldnn_fuse_pass");
   fc_act_pattern("fc", act_type);
   int found_fc_act_count = 0;
@@ -62,15 +61,23 @@ void FuseFCActOneDNNPass::FuseFCAct(Graph *graph,
                               "is used."));
     }
+    auto attr_map = paddle::platform::GetAttributeMap(act_type);
+    for (const auto &attr : attr_map) {
+      if (act_op->HasAttr(attr.first)) {
+        fc_op->SetAttr(attr.second, act_op->GetAttr(attr.first));
+      }
+    }
     if (act_type == "gelu" && act_op->HasAttr("approximate")) {
-      bool approximate = PADDLE_GET_CONST(bool, act_op->GetAttr("approximate"));
-      std::string type = approximate ? "_tanh" : "_erf";
-      fc_op->SetAttr("activation_type", act_type + type);
+      std::string gelu_act_type =
+          PADDLE_GET_CONST(bool, act_op->GetAttr("approximate")) ? "gelu_tanh"
+                                                                 : "gelu_erf";
+      fc_op->SetAttr("fuse_activation", gelu_act_type);
     } else {
-      fc_op->SetAttr("activation_type", act_type);
+      fc_op->SetAttr("fuse_activation", act_type);
     }
     fc_op->SetAttr("use_mkldnn", true);
     fc_op->SetOutput("Out", {act_out->Name()});
     IR_OP_VAR_LINK(fc, act_out);
@@ -80,7 +87,8 @@ void FuseFCActOneDNNPass::FuseFCAct(Graph *graph,
   gpd(graph, handler);
   AddStatis(found_fc_act_count);
-  if (!Has("disable_logs") || !Get<bool>("disable_logs"))
+  if ((!Has("disable_logs") || !Get<bool>("disable_logs")) &&
+      found_fc_act_count > 0)
     PrettyLogDetail(
         "--- fused %d fc with %s activation", found_fc_act_count, act_type);
 }
@@ -95,8 +103,16 @@ REGISTER_PASS_CAPABILITY(fc_act_mkldnn_fuse_pass)
     .AddCombination(
         paddle::framework::compatible::OpVersionComparatorCombination()
             .LE("fc", 0)
-            .LE("gelu", 0)
-            .LE("sigmoid", 0)
-            .LE("mish", 1)
-            .LE("hard_swish", 0)
-            .LE("tanh", 0));
+            .EQ("abs", 0)
+            .LE("clip", 1)
+            .EQ("gelu", 0)
+            .EQ("hard_sigmoid", 0)
+            .LE("hard_swish", 0)
+            .LE("leaky_relu", 1)
+            .LE("mish", 1)
+            .EQ("relu", 0)
+            .EQ("relu6", 0)
+            .EQ("sigmoid", 0)
+            .EQ("sqrt", 0)
+            .EQ("swish", 0)
+            .EQ("tanh", 0));
paddle/fluid/framework/ir/mkldnn/fc_act_mkldnn_fuse_pass.h

@@ -23,21 +23,14 @@ namespace paddle {
 namespace framework {
 namespace ir {
-/*
- * \brief Fuse the FC and activation operators into single OneDNN's
- *        FC with post-op.
- *
- * \note Currently only GeLU, hardswish, sigmoid, mish and tanh are supported
- *       as an activation function.
- */
 class FuseFCActOneDNNPass : public FusePassBase {
  public:
  virtual ~FuseFCActOneDNNPass() {}

  protected:
-  void ApplyImpl(ir::Graph *graph) const override;
-  void FuseFCAct(ir::Graph *graph, const std::string &act_types) const;
+  void ApplyImpl(Graph *graph) const override;
+  void FuseFCAct(Graph *graph, const std::string &act_types) const;
 };
 }  // namespace ir
...
paddle/fluid/framework/ir/mkldnn/fc_act_mkldnn_fuse_pass_tester.cc

@@ -34,12 +34,12 @@ TEST(FuseFCActOneDNNPass, ThrowUseMkldnn) {
       "fc",
       {
           {"Input", "x"},
-          {"Weights", "weights"},
+          {"W", "weights"},
           {"Bias", "bias"},
       },
       {{"Out", "fc_y"}},
       false);
-  test::CreateOp(&prog, "gelu", {{"Input", "fc_y"}}, {{"Out", "act_y"}}, false);
+  test::CreateOp(&prog, "gelu", {{"X", "fc_y"}}, {{"Out", "act_y"}}, false);
   Graph graph(prog);
   // No fusion in this attribute configuration
@@ -58,12 +58,12 @@ TEST(FuseFCActOneDNNPass, FuseWithGeluTanh) {
       "fc",
       {
           {"Input", "x"},
-          {"Weights", "weights"},
+          {"W", "weights"},
           {"Bias", "bias"},
       },
       {{"Out", "fc_y"}});
-  auto* act_op = test::CreateOp(
-      &prog, "gelu", {{"Input", "fc_y"}}, {{"Out", "act_y"}}, false);
+  auto* act_op =
+      test::CreateOp(&prog, "gelu", {{"X", "fc_y"}}, {{"Out", "act_y"}}, false);
   act_op->SetAttr("approximate", true);
   Graph graph(prog);
@@ -78,9 +78,9 @@ TEST(FuseFCActOneDNNPass, FuseWithGeluTanh) {
     const auto* op = node->Op();
     ASSERT_TRUE(op->HasAttr("use_mkldnn"));
     EXPECT_TRUE(PADDLE_GET_CONST(bool, op->GetAttr("use_mkldnn")));
-    ASSERT_TRUE(op->HasAttr("activation_type"));
+    ASSERT_TRUE(op->HasAttr("fuse_activation"));
     auto act_type =
-        PADDLE_GET_CONST(std::string, op->GetAttr("activation_type"));
+        PADDLE_GET_CONST(std::string, op->GetAttr("fuse_activation"));
     EXPECT_EQ(act_type.compare("gelu_tanh"), 0);
   }
 }
@@ -93,12 +93,12 @@ TEST(FuseFCActOneDNNPass, FuseWithGeluErf) {
       "fc",
       {
           {"Input", "x"},
-          {"Weights", "weights"},
+          {"W", "weights"},
           {"Bias", "bias"},
       },
       {{"Out", "fc_y"}});
-  auto* act_op = test::CreateOp(
-      &prog, "gelu", {{"Input", "fc_y"}}, {{"Out", "act_y"}}, false);
+  auto* act_op =
+      test::CreateOp(&prog, "gelu", {{"X", "fc_y"}}, {{"Out", "act_y"}}, false);
   act_op->SetAttr("approximate", false);
   Graph graph(prog);
@@ -113,9 +113,9 @@ TEST(FuseFCActOneDNNPass, FuseWithGeluErf) {
     const auto* op = node->Op();
     ASSERT_TRUE(op->HasAttr("use_mkldnn"));
     EXPECT_TRUE(PADDLE_GET_CONST(bool, op->GetAttr("use_mkldnn")));
-    ASSERT_TRUE(op->HasAttr("activation_type"));
+    ASSERT_TRUE(op->HasAttr("fuse_activation"));
     auto act_type =
-        PADDLE_GET_CONST(std::string, op->GetAttr("activation_type"));
+        PADDLE_GET_CONST(std::string, op->GetAttr("fuse_activation"));
     EXPECT_EQ(act_type.compare("gelu_erf"), 0);
   }
 }
@@ -128,11 +128,11 @@ TEST(FuseFCActOneDNNPass, FuseWithGeluAuto) {
       "fc",
       {
           {"Input", "x"},
-          {"Weights", "weights"},
+          {"W", "weights"},
           {"Bias", "bias"},
       },
       {{"Out", "fc_y"}});
-  test::CreateOp(&prog, "gelu", {{"Input", "fc_y"}}, {{"Out", "act_y"}}, false);
+  test::CreateOp(&prog, "gelu", {{"X", "fc_y"}}, {{"Out", "act_y"}}, false);
   Graph graph(prog);
   constexpr int removed_nodes_count = 2;
@@ -146,9 +146,9 @@ TEST(FuseFCActOneDNNPass, FuseWithGeluAuto) {
     const auto* op = node->Op();
     ASSERT_TRUE(op->HasAttr("use_mkldnn"));
     EXPECT_TRUE(PADDLE_GET_CONST(bool, op->GetAttr("use_mkldnn")));
-    ASSERT_TRUE(op->HasAttr("activation_type"));
+    ASSERT_TRUE(op->HasAttr("fuse_activation"));
     auto act_type =
-        PADDLE_GET_CONST(std::string, op->GetAttr("activation_type"));
+        PADDLE_GET_CONST(std::string, op->GetAttr("fuse_activation"));
     EXPECT_EQ(act_type.compare("gelu"), 0);
   }
 }
@@ -161,11 +161,11 @@ TEST(FuseFCActOneDNNPass, FuseWithTanh) {
       "fc",
       {
           {"Input", "x"},
-          {"Weights", "weights"},
+          {"W", "weights"},
           {"Bias", "bias"},
       },
       {{"Out", "fc_y"}});
-  test::CreateOp(&prog, "tanh", {{"Input", "fc_y"}}, {{"Out", "act_y"}}, false);
+  test::CreateOp(&prog, "tanh", {{"X", "fc_y"}}, {{"Out", "act_y"}}, false);
   Graph graph(prog);
   constexpr int removed_nodes_count = 2;
@@ -179,9 +179,9 @@ TEST(FuseFCActOneDNNPass, FuseWithTanh) {
     const auto* op = node->Op();
     ASSERT_TRUE(op->HasAttr("use_mkldnn"));
     EXPECT_TRUE(PADDLE_GET_CONST(bool, op->GetAttr("use_mkldnn")));
-    ASSERT_TRUE(op->HasAttr("activation_type"));
+    ASSERT_TRUE(op->HasAttr("fuse_activation"));
     auto act_type =
-        PADDLE_GET_CONST(std::string, op->GetAttr("activation_type"));
+        PADDLE_GET_CONST(std::string, op->GetAttr("fuse_activation"));
     EXPECT_EQ(act_type.compare("tanh"), 0);
   }
 }
@@ -194,12 +194,11 @@ TEST(FuseFCActOneDNNPass, FuseWithSigmoid) {
       "fc",
       {
           {"Input", "x"},
-          {"Weights", "weights"},
+          {"W", "weights"},
           {"Bias", "bias"},
       },
       {{"Out", "fc_y"}});
-  test::CreateOp(
-      &prog, "sigmoid", {{"Input", "fc_y"}}, {{"Out", "act_y"}}, false);
+  test::CreateOp(&prog, "sigmoid", {{"X", "fc_y"}}, {{"Out", "act_y"}}, false);
   Graph graph(prog);
   constexpr int removed_nodes_count = 2;
@@ -213,9 +212,9 @@ TEST(FuseFCActOneDNNPass, FuseWithSigmoid) {
     const auto* op = node->Op();
     ASSERT_TRUE(op->HasAttr("use_mkldnn"));
     EXPECT_TRUE(PADDLE_GET_CONST(bool, op->GetAttr("use_mkldnn")));
-    ASSERT_TRUE(op->HasAttr("activation_type"));
+    ASSERT_TRUE(op->HasAttr("fuse_activation"));
     auto act_type =
-        PADDLE_GET_CONST(std::string, op->GetAttr("activation_type"));
+        PADDLE_GET_CONST(std::string, op->GetAttr("fuse_activation"));
     EXPECT_EQ(act_type.compare("sigmoid"), 0);
   }
 }
@@ -228,11 +227,11 @@ TEST(FuseFCActOneDNNPass, FuseWithMish) {
       "fc",
       {
           {"Input", "x"},
-          {"Weights", "weights"},
+          {"W", "weights"},
           {"Bias", "bias"},
       },
       {{"Out", "fc_y"}});
-  test::CreateOp(&prog, "mish", {{"Input", "fc_y"}}, {{"Out", "act_y"}}, false);
+  test::CreateOp(&prog, "mish", {{"X", "fc_y"}}, {{"Out", "act_y"}}, false);
   Graph graph(prog);
   constexpr int removed_nodes_count = 2;
@@ -246,9 +245,9 @@ TEST(FuseFCActOneDNNPass, FuseWithMish) {
     const auto* op = node->Op();
     ASSERT_TRUE(op->HasAttr("use_mkldnn"));
     EXPECT_TRUE(PADDLE_GET_CONST(bool, op->GetAttr("use_mkldnn")));
-    ASSERT_TRUE(op->HasAttr("activation_type"));
+    ASSERT_TRUE(op->HasAttr("fuse_activation"));
     auto act_type =
-        PADDLE_GET_CONST(std::string, op->GetAttr("activation_type"));
+        PADDLE_GET_CONST(std::string, op->GetAttr("fuse_activation"));
     EXPECT_EQ(act_type.compare("mish"), 0);
   }
 }
@@ -261,12 +260,12 @@ TEST(FuseFCActOneDNNPass, FuseWithHardSwish) {
       "fc",
       {
           {"Input", "x"},
-          {"Weights", "weights"},
+          {"W", "weights"},
           {"Bias", "bias"},
       },
       {{"Out", "fc_y"}});
   test::CreateOp(
-      &prog, "hard_swish", {{"Input", "fc_y"}}, {{"Out", "act_y"}}, false);
+      &prog, "hard_swish", {{"X", "fc_y"}}, {{"Out", "act_y"}}, false);
   Graph graph(prog);
   constexpr int removed_nodes_count = 2;
@@ -280,9 +279,9 @@ TEST(FuseFCActOneDNNPass, FuseWithHardSwish) {
     const auto* op = node->Op();
     ASSERT_TRUE(op->HasAttr("use_mkldnn"));
     EXPECT_TRUE(PADDLE_GET_CONST(bool, op->GetAttr("use_mkldnn")));
-    ASSERT_TRUE(op->HasAttr("activation_type"));
+    ASSERT_TRUE(op->HasAttr("fuse_activation"));
     auto act_type =
-        PADDLE_GET_CONST(std::string, op->GetAttr("activation_type"));
+        PADDLE_GET_CONST(std::string, op->GetAttr("fuse_activation"));
     EXPECT_EQ(act_type.compare("hard_swish"), 0);
   }
 }
...
paddle/fluid/framework/ir/op_compat_sensible_pass.cc

@@ -242,7 +242,7 @@ bool OpCompat::Judge(const OpDesc& op_desc, const std::string& pass_name) {
     if (input_compats_.find(input_desc.first) == input_compats_.end()) {
       if (!input_desc.second.empty()) {
         LOG(WARNING) << "The Input (" << input_desc.first << ") of Operator ("
-                     << op_name_ << ") not reigistered in OpCompat!";
+                     << op_name_ << ") not registered in OpCompat!";
         return false;
       }
     }
@@ -269,7 +269,7 @@ bool OpCompat::Judge(const OpDesc& op_desc, const std::string& pass_name) {
     if (output_compats_.find(output_desc.first) == output_compats_.end()) {
       if (!output_desc.second.empty()) {
         LOG(WARNING) << "The Output (" << output_desc.first << ") of Operator ("
-                     << op_name_ << ") not reigistered in OpCompat!";
+                     << op_name_ << ") not registered in OpCompat!";
         return false;
       }
     }
...
paddle/fluid/operators/mkldnn/fc_mkldnn_op.cc

@@ -87,8 +87,7 @@ class FCMKLDNNHandler
                                       dnnl::memory::format_tag::a);
     }
-    dnnl::primitive_attr attrs;
-    HandlePostOps(ctx, &attrs);
+    const auto attrs = CreateFCAttrs(ctx);
     this->AcquireForwardPrimitiveDescriptor(attrs,
                                             prop_kind::forward_inference,
@@ -99,44 +98,33 @@ class FCMKLDNNHandler
   }
  private:
-  void HandlePostOps(const paddle::framework::ExecutionContext& ctx,
-                     dnnl::primitive_attr* attrs) {
-    static std::unordered_map<std::string, dnnl::algorithm> algo_map = {
-        {"relu", dnnl::algorithm::eltwise_relu},
-        {"gelu", dnnl::algorithm::eltwise_gelu},
-        {"gelu_tanh", dnnl::algorithm::eltwise_gelu_tanh},
-        {"gelu_erf", dnnl::algorithm::eltwise_gelu_erf},
-        {"tanh", dnnl::algorithm::eltwise_tanh},
-        {"sigmoid", dnnl::algorithm::eltwise_logistic},
-        {"hard_swish", dnnl::algorithm::eltwise_hardswish},
-        {"mish", dnnl::algorithm::eltwise_mish}};
+  dnnl::primitive_attr CreateFCAttrs(const ExecutionContext& ctx) {
+    dnnl::primitive_attr attributes;
+    dnnl::post_ops post_operations;
     std::vector<float> output_shift_scale;
     float scale = 1.0f;
     if (IsInt8<T_w>()) {
       std::tie(output_shift_scale, scale) = ComputeOutputShiftScale(ctx);
       int mask = CreateMask(1, output_shift_scale.size() > 1);
-      attrs->set_output_scales(mask, output_shift_scale);
+      attributes.set_output_scales(mask, output_shift_scale);
     }
-    dnnl::post_ops post_ops;
-    constexpr float sum_scale = 1.0f;
+    float sum_scale = 1.0f;
     if (ctx.HasAttr("fuse_residual_connection") &&
         ctx.Attr<bool>("fuse_residual_connection")) {
-      post_ops.append_sum(sum_scale);
+      post_operations.append_sum(sum_scale);
    }
-    std::string activation_type = ctx.Attr<std::string>("activation_type");
-    if (activation_type.empty() == false) {
-      constexpr float alpha = 0.0f;
-      constexpr float beta = 0.0f;
-      post_ops.append_eltwise(scale, algo_map[activation_type], alpha, beta);
+    // ReLU from "fc_fuse_pass"
+    if (ctx.Attr<std::string>("activation_type") == "relu") {
+      post_operations.append_eltwise(
+          scale, dnnl::algorithm::eltwise_relu, 0.0f, 0.0f);
     }
-    attrs->set_post_ops(post_ops);
+    platform::AppendActivation(ctx, post_operations, scale);
+    attributes.set_post_ops(post_operations);
+    return attributes;
   }
   // Compute the bias scales so that its values correspond to the
...
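The handler now assembles its dnnl::primitive_attr in CreateFCAttrs and delegates most activations to platform::AppendActivation. For reference, here is a minimal sketch of the underlying oneDNN 2.x post-op API that such a helper drives: an optional sum (residual add) followed by one eltwise post-op for the fused activation. MakeAttrsSketch is an illustrative helper, not a Paddle function, and it assumes the oneDNN 2.x signature of append_eltwise (with a leading scale argument), matching the code in this file.

// Standalone sketch built directly on the oneDNN 2.x C++ API ("dnnl.hpp").
#include <string>
#include "dnnl.hpp"

dnnl::primitive_attr MakeAttrsSketch(const std::string& fuse_activation,
                                     bool fuse_residual,
                                     float scale = 1.0f) {
  dnnl::post_ops ops;
  if (fuse_residual) {
    ops.append_sum(1.0f);  // accumulate into the existing output buffer
  }
  if (fuse_activation == "relu") {
    ops.append_eltwise(scale, dnnl::algorithm::eltwise_relu, 0.0f, 0.0f);
  } else if (fuse_activation == "gelu_tanh") {
    ops.append_eltwise(scale, dnnl::algorithm::eltwise_gelu_tanh, 0.0f, 0.0f);
  }
  dnnl::primitive_attr attrs;
  attrs.set_post_ops(ops);  // attached to the FC primitive descriptor later
  return attrs;
}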
paddle/fluid/inference/tests/api (unit test CMakeLists)

@@ -226,7 +226,8 @@ if(WITH_GPU AND TENSORRT_FOUND)
   set_tests_properties(test_conv_eltwiseadd_bn_fuse_pass PROPERTIES TIMEOUT
                        300)
   set_tests_properties(test_mkldnn_conv_mish_fuse_pass PROPERTIES TIMEOUT 300)
-  set_tests_properties(test_mkldnn_fc_mish_fuse_pass PROPERTIES TIMEOUT 300)
+  set_tests_properties(test_onednn_fc_activation_fuse_pass PROPERTIES TIMEOUT
+                       300)
   set_tests_properties(test_mkldnn_fc_elementwise_add_fuse_pass
                        PROPERTIES TIMEOUT 120)
   set_tests_properties(test_mkldnn_conv_affine_channel_fuse_pass
...
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from auto_scan_test import PassAutoScanTest
from program_config import TensorConfig, ProgramConfig
import numpy as np
import unittest
import hypothesis.strategies as st
class TestFCMishMkldnnFusePass(PassAutoScanTest):
def sample_program_config(self, draw):
x_shape = draw(
st.lists(st.integers(min_value=1, max_value=128),
min_size=2,
max_size=3))
in_num_col_dims = len(x_shape) - 1
w_shape = draw(
st.lists(st.integers(min_value=1, max_value=128),
min_size=2,
max_size=2))
w_shape[0] = int(np.prod(x_shape[in_num_col_dims:]))
fc_bias_shape = [w_shape[1]]
ops_config = [{
"op_type": "fc",
"op_inputs": {
"Input": ["fc_x"],
"W": ["fc_w"],
"Bias": ["fc_bias"]
},
"op_outputs": {
"Out": ["fc_out"]
},
"op_attrs": {
"activation_type": "",
"padding_weights": False,
"in_num_col_dims": in_num_col_dims,
"use_mkldnn": True
}
}, {
"op_type": "mish",
"op_inputs": {
"X": ["fc_out"]
},
"op_outputs": {
"Out": ["mish_output"]
},
"op_attrs": {},
}]
ops = self.generate_op_config(ops_config)
program_config = ProgramConfig(ops=ops,
weights={
"fc_w":
TensorConfig(shape=w_shape),
"fc_bias":
TensorConfig(shape=fc_bias_shape),
},
inputs={
"fc_x": TensorConfig(shape=x_shape),
},
outputs=["mish_output"])
return program_config
def sample_predictor_configs(self, program_config):
config = self.create_inference_config(
use_mkldnn=True, passes=["fc_act_mkldnn_fuse_pass"])
yield config, ["fc"], (1e-5, 1e-5)
def test(self):
self.run_and_statis(quant=False, passes=["fc_act_mkldnn_fuse_pass"])
if __name__ == "__main__":
unittest.main()
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from auto_scan_test import PassAutoScanTest
from program_config import TensorConfig, ProgramConfig, OpConfig
import numpy as np
from functools import partial
import unittest
import hypothesis.strategies as st
class TestFCActivationOneDNNFusePass(PassAutoScanTest):
def sample_program_config(self, draw):
fc_in = draw(st.sampled_from([32, 64]))
fc_wei = draw(st.sampled_from([64]))
activation_type = draw(
st.sampled_from([
'relu', 'gelu', 'swish', 'mish', 'sqrt', 'hard_swish',
'sigmoid', 'abs', 'relu6', 'clip', 'tanh', 'hard_sigmoid',
'leaky_relu'
]))
def generate_input(shape):
return np.random.random(shape).astype(np.float32)
fc_op = OpConfig(type="fc",
inputs={
"Input": ["fc_input"],
"W": ["fc_weight"],
"Bias": ["fc_bias"]
},
outputs={"Out": ["fc_output"]},
attrs={
"use_mkldnn": True,
"padding_weights": False,
"in_num_col_dims": 1,
})
if activation_type == "clip":
activation_op = OpConfig(
activation_type,
inputs={"X": ["fc_output"]},
outputs={"Out": ["activation_output"]},
min=draw(st.floats(min_value=0.1, max_value=0.49)),
max=draw(st.floats(min_value=0.5, max_value=1.0)))
elif activation_type == "gelu":
activation_op = OpConfig(activation_type,
inputs={"X": ["fc_output"]},
outputs={"Out": ["activation_output"]},
approximate=draw(st.booleans()))
elif activation_type == "leaky_relu":
activation_op = OpConfig(activation_type,
inputs={"X": ["fc_output"]},
outputs={"Out": ["activation_output"]},
alpha=draw(
st.floats(min_value=0.1,
max_value=1.0)))
elif activation_type == "relu6":
activation_op = OpConfig(activation_type,
inputs={"X": ["fc_output"]},
outputs={"Out": ["activation_output"]},
threshold=6)
elif activation_type == "swish":
activation_op = OpConfig(activation_type,
inputs={"X": ["fc_output"]},
outputs={"Out": ["activation_output"]},
beta=draw(
st.floats(min_value=0.1,
max_value=10.0)))
else:
activation_op = OpConfig(activation_type,
inputs={"X": ["fc_output"]},
outputs={"Out": ["activation_output"]})
model_net = [fc_op, activation_op]
program_config = ProgramConfig(
ops=model_net,
weights={
"fc_weight":
TensorConfig(
data_gen=partial(generate_input, [fc_wei, fc_wei])),
"fc_bias":
TensorConfig(data_gen=partial(generate_input, [fc_wei])),
},
inputs={
"fc_input":
TensorConfig(data_gen=partial(generate_input, [fc_in, fc_wei]))
},
outputs=["activation_output"])
return program_config
def sample_predictor_configs(self, program_config):
config = self.create_inference_config(
use_mkldnn=True, passes=["fc_act_mkldnn_fuse_pass"])
yield config, ["fc"], (1e-5, 1e-5)
def test(self):
self.run_and_statis(quant=False, passes=["fc_act_mkldnn_fuse_pass"])
if __name__ == "__main__":
unittest.main()