未验证 提交 347e4b2e 编写于 作者: S Sławomir Siwek 提交者: GitHub

Generalize conv+activation fuse pass (#43382)

* consolidate conv act passes

* generalize conv_activation

* integrate conv+act tests

* code style format

* whitespaces

* remove timeout from old tests

* implement comments from review

* restore ut

* whitespace

* code style

* transpose

* fixes after review

* method for gettin act

* Change Paddle_enforce error type

* code format

* add missing opcompats
上级 9aa89b99
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
......@@ -14,110 +14,109 @@
#include "paddle/fluid/framework/ir/mkldnn/conv_activation_mkldnn_fuse_pass.h"
#include <vector>
#include "paddle/fluid/framework/op_version_registry.h"
#include "paddle/fluid/platform/enforce.h"
namespace paddle {
namespace framework {
class OpDesc;
} // namespace framework
} // namespace paddle
#include "paddle/fluid/string/pretty_log.h"
namespace paddle {
namespace framework {
namespace ir {
class Graph;
using string::PrettyLogDetail;
void ConvActivationMkldnnFusePass::ApplyImpl(Graph* graph) const {
std::vector<std::string> act_types = {
"relu", "mish", "swish", "sqrt", "hard_swish", "sigmoid", "abs",
"gelu", "relu6", "clip", "tanh", "hard_sigmoid", "leaky_relu"};
std::vector<std::string> conv_types = {"conv2d"};
for (const auto& conv_type : conv_types)
for (auto& act_type : act_types) {
std::unordered_map<std::string, std::string> attrs_map;
if (act_type == "swish")
attrs_map.emplace("beta", "fuse_alpha");
else if (act_type == "relu6")
attrs_map.emplace("threshold", "fuse_alpha");
else if (act_type == "hard_sigmoid") {
attrs_map.emplace("slope", "fuse_alpha");
attrs_map.emplace("offset", "fuse_beta");
} else if (act_type == "clip") {
attrs_map.emplace("min", "fuse_alpha");
attrs_map.emplace("max", "fuse_beta");
} else {
attrs_map.emplace("alpha", "fuse_alpha");
attrs_map.emplace("beta", "fuse_beta");
}
FuseConvAct(graph, conv_type, act_type, attrs_map);
}
}
void ConvActivationFusePass::ApplyImpl(ir::Graph* graph) const {
void ConvActivationMkldnnFusePass::FuseConvAct(
Graph* graph, const std::string& conv_type, std::string& act_type,
const std::unordered_map<std::string, std::string>& attrs_map) const {
PADDLE_ENFORCE_NOT_NULL(
graph, platform::errors::InvalidArgument("Graph cannot be nullptr."));
FusePassBase::Init("conv_activation_mkldnn_fuse", graph);
FusePassBase::Init(conv_type + "_" + act_type + "_mkldnn_fuse_pass", graph);
GraphPatternDetector gpd;
auto* conv_input = gpd.mutable_pattern()
->NewNode("conv_activation_mkldnn_fuse/conv_input")
->AsInput()
->assert_is_op_input(conv_type(), "Input");
patterns::ConvActivation conv_activation_pattern(
gpd.mutable_pattern(), "conv_activation_mkldnn_fuse");
conv_activation_pattern(conv_input, conv_type(), activation_type());
->assert_is_op_input(conv_type, "Input");
patterns::ConvActivation conv_act_pattern(gpd.mutable_pattern(),
"conv_activation_mkldnn_fuse");
conv_act_pattern(conv_input, conv_type, act_type);
int found_conv_activation_count = 0;
auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
Graph* g) {
VLOG(4) << "handle " + conv_type() + "+" + activation_type() + " fuse";
VLOG(4) << "handle " + conv_type + "+" + act_type + " fuse";
if (!IsCompat(subgraph, g)) {
LOG(WARNING) << "conv_activation_mkldnn_fuse_pass op compat failed.";
return;
}
GET_IR_NODE_FROM_SUBGRAPH(conv_weight, conv_weight,
conv_activation_pattern); // Filter
GET_IR_NODE_FROM_SUBGRAPH(conv_out, conv_out,
conv_activation_pattern); // tmp
GET_IR_NODE_FROM_SUBGRAPH(conv, conv, conv_activation_pattern); // CONV op
GET_IR_NODE_FROM_SUBGRAPH(activation_out, activation_out,
conv_activation_pattern); // Out
GET_IR_NODE_FROM_SUBGRAPH(activation, activation,
conv_activation_pattern); // Activation op
// Transform Conv node into ConvActivation node.
OpDesc* desc = conv->Op();
desc->SetOutput("Output",
std::vector<std::string>({activation_out->Name()}));
if (activation_type() == "gelu" &&
activation->Op()->HasAttr("approximate")) {
bool approximate =
BOOST_GET_CONST(bool, activation->Op()->GetAttr("approximate"));
std::string type = approximate ? "_tanh" : "_erf";
desc->SetAttr("fuse_activation", "gelu" + type);
} else {
desc->SetAttr("fuse_activation", activation_type());
}
// MKLDNN ops use alpha and beta as activation parameters but paddle ops are
// not generalized
if (activation_type() == "relu6") {
desc->SetAttr(
"fuse_alpha",
BOOST_GET_CONST(float, activation->Op()->GetAttr("threshold")));
} else if (activation_type() == "swish") {
// paddle uses beta but mkldnn uses alpha for swish
desc->SetAttr("fuse_alpha",
activation->Op()->GetAttrIfExists<float>("beta"));
} else {
desc->SetAttr("fuse_alpha",
activation->Op()->GetAttrIfExists<float>("alpha"));
GET_IR_NODE_FROM_SUBGRAPH(conv_weight, conv_weight, conv_act_pattern);
GET_IR_NODE_FROM_SUBGRAPH(conv_out, conv_out, conv_act_pattern);
GET_IR_NODE_FROM_SUBGRAPH(conv, conv, conv_act_pattern);
GET_IR_NODE_FROM_SUBGRAPH(activation_out, activation_out, conv_act_pattern);
GET_IR_NODE_FROM_SUBGRAPH(activation, activation, conv_act_pattern);
OpDesc* conv_op = conv->Op();
OpDesc* act_op = activation->Op();
for (const auto& attrs : attrs_map) {
if (act_op->HasAttr(attrs.first)) {
conv_op->SetAttr(attrs.second, act_op->GetAttr(attrs.first));
}
desc->SetAttr("fuse_beta",
activation->Op()->GetAttrIfExists<float>("beta"));
if (activation_type() == "hard_sigmoid") {
desc->SetAttr("fuse_alpha",
activation->Op()->GetAttrIfExists<float>("slope"));
desc->SetAttr("fuse_beta",
activation->Op()->GetAttrIfExists<float>("offset"));
}
GraphSafeRemoveNodes(graph, {activation, conv_out});
if (act_type == "gelu" && activation->Op()->HasAttr("approximate")) {
act_type = BOOST_GET_CONST(bool, activation->Op()->GetAttr("approximate"))
? "gelu_tanh"
: "gelu_erf";
conv_op->SetAttr("fuse_alpha", 0.0f);
conv_op->SetAttr("fuse_beta", 0.0f);
}
conv_op->SetAttr("fuse_activation", act_type);
conv_op->SetOutput("Output", {activation_out->Name()});
PADDLE_ENFORCE_GT(subgraph.count(conv_input), 0UL,
platform::errors::InvalidArgument(
"Subgraph has to contain conv input node."));
IR_NODE_LINK_TO(conv, activation_out);
GraphSafeRemoveNodes(graph, {activation, conv_out});
found_conv_activation_count++;
};
gpd(graph, handler);
AddStatis(found_conv_activation_count);
if (!Has("disable_logs") || !Get<bool>("disable_logs")) {
PrettyLogDetail("--- fused %d conv with %s activation",
found_conv_activation_count, act_type);
}
}
ConvActivationFusePass::ConvActivationFusePass() {
ConvActivationMkldnnFusePass::ConvActivationMkldnnFusePass() {
AddOpCompat(OpCompat("conv2d"))
.AddInput("Input")
.IsTensor()
......@@ -142,8 +141,6 @@ ConvActivationFusePass::ConvActivationFusePass() {
.AddAttr("paddings")
.IsType<std::vector<int>>()
.End()
// IsStringIn({"EXPLICIT", "SAME", "VALID"}), MobileNetV2 has no this
// attribute
.AddAttr("padding_algorithm")
.IsOptional()
.IsStringIn({"EXPLICIT", "SAME", "VALID"})
......@@ -154,7 +151,6 @@ ConvActivationFusePass::ConvActivationFusePass() {
.AddAttr("dilations")
.IsType<std::vector<int>>()
.End()
// IsStringIn({"NHWC", "NCHW"}) MobileNetV2 has no this attribute
.AddAttr("data_format")
.IsOptional()
.IsStringIn({"NCHW", "NHWC", "AnyLayout"})
......@@ -167,8 +163,7 @@ ConvActivationFusePass::ConvActivationFusePass() {
.AddOutput("Out")
.IsTensor()
.End();
}
Conv2DLeakyReLUFusePass::Conv2DLeakyReLUFusePass() {
AddOpCompat(OpCompat("leaky_relu"))
.AddInput("X")
.IsTensor()
......@@ -176,12 +171,10 @@ Conv2DLeakyReLUFusePass::Conv2DLeakyReLUFusePass() {
.AddOutput("Out")
.IsTensor()
.End()
// float, default=0.02
.AddAttr("alpha")
.IsType<float>()
.End();
}
Conv2DReLU6FusePass::Conv2DReLU6FusePass() {
AddOpCompat(OpCompat("relu6"))
.AddInput("X")
.IsTensor()
......@@ -189,12 +182,10 @@ Conv2DReLU6FusePass::Conv2DReLU6FusePass() {
.AddOutput("Out")
.IsTensor()
.End()
// default = 6.0f
.AddAttr("threshold")
.IsType<float>()
.End();
}
Conv2DSwishFusePass::Conv2DSwishFusePass() {
AddOpCompat(OpCompat("swish"))
.AddInput("X")
.IsTensor()
......@@ -205,8 +196,7 @@ Conv2DSwishFusePass::Conv2DSwishFusePass() {
.AddAttr("beta")
.IsType<float>()
.End();
}
Conv2DHardSwishFusePass::Conv2DHardSwishFusePass() {
AddOpCompat(OpCompat("hard_swish"))
.AddInput("X")
.IsTensor()
......@@ -214,23 +204,19 @@ Conv2DHardSwishFusePass::Conv2DHardSwishFusePass() {
.AddOutput("Out")
.IsTensor()
.End()
// float, optional, default=6.0
.AddAttr("threshold")
.IsOptional()
.IsType<float>()
.End()
// float, optional, default=6.0
.AddAttr("scale")
.IsOptional()
.IsType<float>()
.End()
// float, optional, default=3.0
.AddAttr("offset")
.IsOptional()
.IsType<float>()
.End();
}
Conv2DMishFusePass::Conv2DMishFusePass() {
AddOpCompat(OpCompat("mish"))
.AddInput("X")
.IsTensor()
......@@ -238,8 +224,7 @@ Conv2DMishFusePass::Conv2DMishFusePass() {
.AddOutput("Out")
.IsTensor()
.End();
}
Conv2DHardSigmoidFusePass::Conv2DHardSigmoidFusePass() {
AddOpCompat(OpCompat("hard_sigmoid"))
.AddInput("X")
.IsTensor()
......@@ -247,19 +232,15 @@ Conv2DHardSigmoidFusePass::Conv2DHardSigmoidFusePass() {
.AddOutput("Out")
.IsTensor()
.End()
// optional, default=0.2
.AddAttr("slope")
.IsOptional()
.IsType<float>()
.End()
// optional, default=0.5
.AddAttr("offset")
.IsOptional()
.IsType<float>()
.End();
}
Conv2DGeluFusePass::Conv2DGeluFusePass() {
AddOpCompat(OpCompat("gelu"))
.AddInput("X")
.IsTensor()
......@@ -270,6 +251,38 @@ Conv2DGeluFusePass::Conv2DGeluFusePass() {
.AddAttr("approximate")
.IsType<bool>()
.End();
AddOpCompat(OpCompat("tanh"))
.AddInput("X")
.IsTensor()
.End()
.AddOutput("Out")
.IsTensor()
.End();
AddOpCompat(OpCompat("sigmoid"))
.AddInput("X")
.IsTensor()
.End()
.AddOutput("Out")
.IsTensor()
.End();
AddOpCompat(OpCompat("sqrt"))
.AddInput("X")
.IsTensor()
.End()
.AddOutput("Out")
.IsTensor()
.End();
AddOpCompat(OpCompat("abs"))
.AddInput("X")
.IsTensor()
.End()
.AddOutput("Out")
.IsTensor()
.End();
}
} // namespace ir
......@@ -277,68 +290,22 @@ Conv2DGeluFusePass::Conv2DGeluFusePass() {
} // namespace paddle
REGISTER_PASS(conv_activation_mkldnn_fuse_pass,
paddle::framework::ir::ConvActivationFusePass);
REGISTER_PASS(conv_relu_mkldnn_fuse_pass,
paddle::framework::ir::ConvActivationFusePass);
REGISTER_PASS_CAPABILITY(conv_relu_mkldnn_fuse_pass)
.AddCombination(
paddle::framework::compatible::OpVersionComparatorCombination()
.LE("conv2d", 1)
.EQ("relu", 0));
REGISTER_PASS(conv_leaky_relu_mkldnn_fuse_pass,
paddle::framework::ir::Conv2DLeakyReLUFusePass);
REGISTER_PASS_CAPABILITY(conv_leaky_relu_mkldnn_fuse_pass)
.AddCombination(
paddle::framework::compatible::OpVersionComparatorCombination()
.LE("conv2d", 1)
.LE("leaky_relu", 1));
REGISTER_PASS(conv_relu6_mkldnn_fuse_pass,
paddle::framework::ir::Conv2DReLU6FusePass);
REGISTER_PASS_CAPABILITY(conv_relu6_mkldnn_fuse_pass)
.AddCombination(
paddle::framework::compatible::OpVersionComparatorCombination()
.LE("conv2d", 1)
.EQ("relu6", 0));
REGISTER_PASS(conv_swish_mkldnn_fuse_pass,
paddle::framework::ir::Conv2DSwishFusePass);
REGISTER_PASS_CAPABILITY(conv_swish_mkldnn_fuse_pass)
.AddCombination(
paddle::framework::compatible::OpVersionComparatorCombination()
.LE("conv2d", 1)
.EQ("swish", 0));
REGISTER_PASS(conv_hard_swish_mkldnn_fuse_pass,
paddle::framework::ir::Conv2DHardSwishFusePass);
REGISTER_PASS_CAPABILITY(conv_hard_swish_mkldnn_fuse_pass)
.AddCombination(
paddle::framework::compatible::OpVersionComparatorCombination()
.LE("conv2d", 1)
.EQ("hard_swish", 0));
REGISTER_PASS(conv_mish_mkldnn_fuse_pass,
paddle::framework::ir::Conv2DMishFusePass);
REGISTER_PASS_CAPABILITY(conv_mish_mkldnn_fuse_pass)
.AddCombination(
paddle::framework::compatible::OpVersionComparatorCombination()
.LE("conv2d", 1)
.EQ("mish", 1));
REGISTER_PASS(conv_hard_sigmoid_mkldnn_fuse_pass,
paddle::framework::ir::Conv2DHardSigmoidFusePass);
REGISTER_PASS_CAPABILITY(conv_hard_sigmoid_mkldnn_fuse_pass)
.AddCombination(
paddle::framework::compatible::OpVersionComparatorCombination()
.LE("conv2d", 1)
.EQ("hard_sigmoid", 0));
paddle::framework::ir::ConvActivationMkldnnFusePass);
REGISTER_PASS(conv_gelu_mkldnn_fuse_pass,
paddle::framework::ir::Conv2DGeluFusePass);
REGISTER_PASS_CAPABILITY(conv_gelu_mkldnn_fuse_pass)
REGISTER_PASS_CAPABILITY(conv_activation_mkldnn_fuse_pass)
.AddCombination(
paddle::framework::compatible::OpVersionComparatorCombination()
.LE("conv2d", 1)
.EQ("gelu", 0));
.EQ("abs", 0)
.LE("clip", 1)
.EQ("gelu", 0)
.EQ("hard_sigmoid", 0)
.LE("hard_swish", 0)
.LE("leaky_relu", 1)
.LE("mish", 1)
.EQ("relu", 0)
.EQ("relu6", 0)
.EQ("sigmoid", 0)
.EQ("sqrt", 0)
.EQ("swish", 0)
.EQ("tanh", 0));
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
......@@ -18,84 +18,22 @@
#include "paddle/fluid/framework/ir/fuse_pass_base.h"
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/graph_pattern_detector.h"
#include "paddle/fluid/framework/ir/pass.h"
namespace paddle {
namespace framework {
namespace ir {
/*
* Fuse Conv and Activation base class.
*/
class Graph;
class ConvActivationFusePass : public FusePassBase {
class ConvActivationMkldnnFusePass : public FusePassBase {
public:
ConvActivationFusePass();
virtual ~ConvActivationFusePass() {}
virtual std::string conv_type() const { return "conv2d"; }
virtual std::string activation_type() const { return "relu"; }
ConvActivationMkldnnFusePass();
virtual ~ConvActivationMkldnnFusePass() {}
protected:
void ApplyImpl(ir::Graph* graph) const override;
const std::string name_scope_{"conv_activation_mkldnn_fuse"};
};
/*
* Fuse Conv and LeakyReLU class
*/
class Conv2DLeakyReLUFusePass : public ConvActivationFusePass {
public:
Conv2DLeakyReLUFusePass();
std::string activation_type() const { return "leaky_relu"; }
};
/*
* Fuse Conv and BoundedReLU class
*/
class Conv2DReLU6FusePass : public ConvActivationFusePass {
public:
Conv2DReLU6FusePass();
std::string activation_type() const { return "relu6"; }
};
/*
* Fuse Conv and Swish class
*/
class Conv2DSwishFusePass : public ConvActivationFusePass {
public:
Conv2DSwishFusePass();
std::string activation_type() const { return "swish"; }
};
/*
* Fuse Conv and HardSwish class
*/
class Conv2DHardSwishFusePass : public ConvActivationFusePass {
public:
Conv2DHardSwishFusePass();
std::string activation_type() const { return "hard_swish"; }
};
/*
* Fuse Conv and Mish class
*/
class Conv2DMishFusePass : public ConvActivationFusePass {
public:
Conv2DMishFusePass();
std::string activation_type() const { return "mish"; }
};
/*
* Fuse Conv and HardSigmoid class
*/
class Conv2DHardSigmoidFusePass : public ConvActivationFusePass {
public:
Conv2DHardSigmoidFusePass();
std::string activation_type() const { return "hard_sigmoid"; }
};
void ApplyImpl(Graph *graph) const override;
/*
* Fuse Conv and Gelu class
*/
class Conv2DGeluFusePass : public ConvActivationFusePass {
public:
Conv2DGeluFusePass();
std::string activation_type() const { return "gelu"; }
void FuseConvAct(
Graph *graph, const std::string &conv_type, std::string &act_type,
const std::unordered_map<std::string, std::string> &attrs_map) const;
};
} // namespace ir
......
......@@ -104,8 +104,7 @@ void MainTest(std::string activation) {
std::unique_ptr<ir::Graph> graph(new ir::Graph(prog));
auto pass =
PassRegistry::Instance().Get("conv_" + activation + "_mkldnn_fuse_pass");
auto pass = PassRegistry::Instance().Get("conv_activation_mkldnn_fuse_pass");
int original_nodes_num = graph->Nodes().size();
......
......@@ -27,7 +27,7 @@ using string::PrettyLogDetail;
void ElementwiseActivationOneDNNPass::ApplyImpl(Graph *graph) const {
std::vector<std::string> act_types = {
"relu", "tanh", "leaky_relu", "swish", "hardswish", "sqrt",
"relu", "tanh", "leaky_relu", "swish", "hard_swish", "sqrt",
"abs", "clip", "gelu", "relu6", "sigmoid"};
std::vector<std::string> elt_types = {"elementwise_add", "elementwise_sub",
"elementwise_mul"};
......@@ -56,7 +56,7 @@ void ElementwiseActivationOneDNNPass::FuseElementwiseAct(
const std::unordered_map<std::string, std::string> &attr_map) const {
PADDLE_ENFORCE_NOT_NULL(
graph, platform::errors::InvalidArgument("Graph cannot be nullptr."));
FusePassBase::Init("elementwise_act", graph);
FusePassBase::Init(elt_type + "_" + act_type + "_mkldnn_fuse_pass", graph);
GraphPatternDetector gpd;
auto *elementwise_input = gpd.mutable_pattern()
......
......@@ -302,15 +302,7 @@ void CpuPassStrategy::EnableMKLDNN() {
// "conv3d_bias_mkldnn_fuse_pass", //
"conv_elementwise_add_mkldnn_fuse_pass",
"conv_concat_relu_mkldnn_fuse_pass",
"conv_relu_mkldnn_fuse_pass", //
"conv_leaky_relu_mkldnn_fuse_pass", //
"conv_relu6_mkldnn_fuse_pass", //
"conv_swish_mkldnn_fuse_pass", //
"conv_hard_swish_mkldnn_fuse_pass", //
"conv_mish_mkldnn_fuse_pass", //
"conv_hard_sigmoid_mkldnn_fuse_pass", //
// TODO(baoachun) fix int8 accuracy
"conv_gelu_mkldnn_fuse_pass",
"conv_activation_mkldnn_fuse_pass", //
"scale_matmul_fuse_pass", //
"reshape_transpose_matmul_mkldnn_fuse_pass", //
"reshape_transpose_matmul_v2_mkldnn_fuse_pass", //
......@@ -403,14 +395,7 @@ void CpuPassStrategy::EnableMkldnnInt8() {
passes_.push_back("conv_transpose_bias_mkldnn_fuse_pass");
passes_.push_back("conv_elementwise_add_mkldnn_fuse_pass");
passes_.push_back("conv_concat_relu_mkldnn_fuse_pass");
passes_.push_back("conv_relu_mkldnn_fuse_pass");
passes_.push_back("conv_leaky_relu_mkldnn_fuse_pass");
passes_.push_back("conv_relu6_mkldnn_fuse_pass");
passes_.push_back("conv_swish_mkldnn_fuse_pass");
passes_.push_back("conv_hard_swish_mkldnn_fuse_pass");
passes_.push_back("conv_mish_mkldnn_fuse_pass");
passes_.push_back("conv_hard_sigmoid_mkldnn_fuse_pass");
passes_.push_back("conv_gelu_mkldnn_fuse_pass");
passes_.push_back("conv_activation_mkldnn_fuse_pass");
passes_.push_back("fc_fuse_pass");
passes_.push_back("repeated_fc_relu_fuse_pass");
passes_.push_back("fc_mkldnn_pass");
......
......@@ -104,12 +104,12 @@ TEST(Analyzer_vit_ocr, fuse_status) {
SetConfig(&cfg, true);
int num_ops;
auto predictor = CreatePaddlePredictor<AnalysisConfig>(cfg);
auto fuse_status = GetFuseStatis(
auto fuse_statis = GetFuseStatis(
static_cast<AnalysisPredictor *>(predictor.get()), &num_ops);
CHECK_EQ(fuse_status.at("fc_mkldnn_pass"), 33);
CHECK_EQ(fuse_status.at("conv_activation_mkldnn_fuse"), 2);
CHECK_EQ(fuse_status.at("fc_elementwise_add_mkldnn_fuse"), 16);
CHECK_EQ(fuse_statis.at("fc_mkldnn_pass"), 33);
CHECK_EQ(fuse_statis.at("conv2d_gelu_mkldnn_fuse_pass"), 2);
CHECK_EQ(fuse_statis.at("fc_elementwise_add_mkldnn_fuse"), 16);
}
#endif
......
......@@ -61,27 +61,10 @@ class EltwiseMKLDNNKernel : public framework::OpKernel<T> {
? ctx.Attr<float>("activation_beta")
: 0.0f;
static std::unordered_map<std::string, dnnl::algorithm> algo_map = {
{"relu", dnnl::algorithm::eltwise_relu},
{"tanh", dnnl::algorithm::eltwise_tanh},
{"leaky_relu", dnnl::algorithm::eltwise_relu},
{"swish", dnnl::algorithm::eltwise_swish},
{"hardswish", dnnl::algorithm::eltwise_hardswish},
{"sqrt", dnnl::algorithm::eltwise_sqrt},
{"abs", dnnl::algorithm::eltwise_abs},
{"clip", dnnl::algorithm::eltwise_clip},
{"gelu", dnnl::algorithm::eltwise_gelu_erf},
{"gelu_tanh", dnnl::algorithm::eltwise_gelu_tanh},
{"relu6", dnnl::algorithm::eltwise_bounded_relu},
{"sigmoid", dnnl::algorithm::eltwise_logistic}};
const auto& activation_type =
algo_map.find(ctx.Attr<std::string>("activation_type"));
if (activation_type != algo_map.end()) {
post_operations.append_eltwise(scale, activation_type->second, alpha,
beta);
}
const auto activation_algorithm = platform::AcquireActivationAlgorithm(
ctx.Attr<std::string>("activation_type"));
post_operations.append_eltwise(scale, activation_algorithm, alpha, beta);
}
return post_operations;
}
......
......@@ -505,41 +505,20 @@ class ConvMKLDNNHandlerT
if (fuse_residual_conn) {
post_operations.append_sum(sum_scale);
}
// Fusion with ReLU layer is executed through the PostOps feature. Create a
// PostOps object and configure it to execute an eltwise relu operation.
if (fuse_activation == "relu" || fuse_activation == "leaky_relu") {
post_operations.append_eltwise(activation_scale,
dnnl::algorithm::eltwise_relu, fuse_alpha,
fuse_beta);
} else if (fuse_activation == "relu6") {
post_operations.append_eltwise(activation_scale,
dnnl::algorithm::eltwise_bounded_relu,
fuse_alpha, fuse_beta);
} else if (fuse_activation == "swish") {
post_operations.append_eltwise(activation_scale,
dnnl::algorithm::eltwise_swish, fuse_alpha,
fuse_beta);
} else if (fuse_activation == "hard_swish") {
post_operations.append_eltwise(activation_scale,
dnnl::algorithm::eltwise_hardswish,
fuse_alpha, fuse_beta);
} else if (fuse_activation == "mish") {
post_operations.append_eltwise(activation_scale,
dnnl::algorithm::eltwise_mish, fuse_alpha,
fuse_beta);
} else if (fuse_activation == "hard_sigmoid") {
if (fuse_activation == "hard_sigmoid") {
post_operations.append_eltwise(activation_scale,
dnnl::algorithm::eltwise_linear,
fuse_alpha, fuse_beta);
post_operations.append_eltwise(activation_scale,
dnnl::algorithm::eltwise_clip, 0.0f, 1.0f);
} else if (fuse_activation == "gelu_tanh") {
post_operations.append_eltwise(
activation_scale, dnnl::algorithm::eltwise_gelu_tanh, 0.0f, 0.0f);
} else if (fuse_activation == "gelu_erf") {
post_operations.append_eltwise(
activation_scale, dnnl::algorithm::eltwise_gelu_erf, 0.0f, 0.0f);
} else if (fuse_activation != "") {
const auto activation_algorithm =
platform::AcquireActivationAlgorithm(fuse_activation);
post_operations.append_eltwise(activation_scale, activation_algorithm,
fuse_alpha, fuse_beta);
}
conv_attr.set_post_ops(post_operations);
return conv_attr;
}
......
......@@ -951,6 +951,33 @@ class ActivationMKLDNNHandler
}
};
static const dnnl::algorithm AcquireActivationAlgorithm(
std::string activation_name) {
std::unordered_map<std::string, dnnl::algorithm> activation_map = {
{"abs", dnnl::algorithm::eltwise_abs},
{"clip", dnnl::algorithm::eltwise_clip},
{"gelu", dnnl::algorithm::eltwise_gelu_erf},
{"gelu_erf", dnnl::algorithm::eltwise_gelu_erf},
{"gelu_tanh", dnnl::algorithm::eltwise_gelu_tanh},
{"hard_swish", dnnl::algorithm::eltwise_hardswish},
{"leaky_relu", dnnl::algorithm::eltwise_relu},
{"mish", dnnl::algorithm::eltwise_mish},
{"relu", dnnl::algorithm::eltwise_relu},
{"relu6", dnnl::algorithm::eltwise_bounded_relu},
{"sigmoid", dnnl::algorithm::eltwise_logistic},
{"sqrt", dnnl::algorithm::eltwise_sqrt},
{"swish", dnnl::algorithm::eltwise_swish},
{"tanh", dnnl::algorithm::eltwise_tanh}};
const auto& activation_type = activation_map.find(activation_name);
PADDLE_ENFORCE_NE(activation_type, activation_map.end(),
platform::errors::InvalidArgument(
"Activation '%s' not found in oneDNN algorithms mapper",
activation_name));
return activation_type->second;
}
class ReorderMKLDNNHandler {
public:
ReorderMKLDNNHandler(std::vector<int64_t>& dims, // NOLINT
......
......@@ -432,14 +432,7 @@ class Quant2Int8MkldnnPass(object):
graph = self._apply_pass(graph, 'conv_transpose_bias_mkldnn_fuse_pass')
graph = self._apply_pass(graph, 'conv_elementwise_add_mkldnn_fuse_pass')
graph = self._apply_pass(graph, 'conv_concat_relu_mkldnn_fuse_pass')
graph = self._apply_pass(graph, 'conv_relu_mkldnn_fuse_pass')
graph = self._apply_pass(graph, 'conv_leaky_relu_mkldnn_fuse_pass')
graph = self._apply_pass(graph, 'conv_relu6_mkldnn_fuse_pass')
graph = self._apply_pass(graph, 'conv_swish_mkldnn_fuse_pass')
graph = self._apply_pass(graph, 'conv_hard_swish_mkldnn_fuse_pass')
graph = self._apply_pass(graph, 'conv_mish_mkldnn_fuse_pass')
graph = self._apply_pass(graph, 'conv_hard_sigmoid_mkldnn_fuse_pass')
graph = self._apply_pass(graph, 'conv_gelu_mkldnn_fuse_pass')
graph = self._apply_pass(graph, 'conv_activation_mkldnn_fuse_pass')
graph = self._apply_pass(graph, 'fc_fuse_pass',
['use_gpu', 'use_fc_padding'], [False, False])
graph = self._apply_pass(graph, 'repeated_fc_relu_fuse_pass')
......
......@@ -170,28 +170,24 @@ class TestConvActMkldnnFusePass(PassAutoScanTest):
# 11. Generate legal attr of act
act_op = None
self.passes = None
self.passes = ["conv_activation_mkldnn_fuse_pass"]
if act_type == "relu6":
self.passes = ["conv_relu6_mkldnn_fuse_pass"]
threshold = draw(st.floats(min_value=1.0, max_value=10.0))
act_op = OpConfig("relu6",
inputs={"X": ["conv2d_out"]},
outputs={"Out": ["relu_out"]},
threshold=threshold)
if act_type == "leaky_relu":
self.passes = ["conv_leaky_relu_mkldnn_fuse_pass"]
elif act_type == "leaky_relu":
alpha = draw(st.floats(min_value=0.1, max_value=1.0))
act_op = OpConfig("leaky_relu",
inputs={"X": ["conv2d_out"]},
outputs={"Out": ["relu_out"]},
alpha=alpha)
if act_type == "relu":
self.passes = ["conv_relu_mkldnn_fuse_pass"]
elif act_type == "relu":
act_op = OpConfig("relu",
inputs={"X": ["conv2d_out"]},
outputs={"Out": ["relu_out"]})
if act_type == "swish":
self.passes = ["conv_swish_mkldnn_fuse_pass"]
elif act_type == "swish":
beta = draw(st.floats(min_value=0.1, max_value=1.0))
act_op = OpConfig("swish",
inputs={"X": ["conv2d_out"]},
......
......@@ -18,8 +18,6 @@ import unittest
import numpy as np
from inference_pass_test import InferencePassTest
import paddle.fluid as fluid
import paddle.fluid.core as core
from paddle.fluid.core import AnalysisConfig
from paddle.fluid.core import PassVersionChecker
......@@ -42,13 +40,13 @@ class ConvActivationMkldnnFusePassTest(InferencePassTest):
}
self.fetch_list = [conv_out]
self.enable_mkldnn = True
self.pass_name = 'conv_activation_mkldnn_fuse_pass'
def set_params(self):
self.conv_num_filters = 3
self.conv_filter_size = 3
self.conv_bias_attr = False
self.act = "relu"
self.pass_name = 'conv_relu_mkldnn_fuse_pass'
def test_check_output(self):
use_gpu = False
......@@ -65,7 +63,6 @@ class ConvActivationMkldnnFusePassTest_1(ConvActivationMkldnnFusePassTest):
self.conv_filter_size = 5
self.conv_bias_attr = True
self.act = "relu"
self.pass_name = 'conv_relu_mkldnn_fuse_pass'
class ConvActivationMkldnnFusePassTest_2(ConvActivationMkldnnFusePassTest):
......@@ -75,7 +72,6 @@ class ConvActivationMkldnnFusePassTest_2(ConvActivationMkldnnFusePassTest):
self.conv_filter_size = 3
self.conv_bias_attr = False
self.act = "leaky_relu"
self.pass_name = 'conv_leaky_relu_mkldnn_fuse_pass'
class ConvActivationMkldnnFusePassTest_3(ConvActivationMkldnnFusePassTest):
......@@ -85,7 +81,6 @@ class ConvActivationMkldnnFusePassTest_3(ConvActivationMkldnnFusePassTest):
self.conv_filter_size = 5
self.conv_bias_attr = True
self.act = "leaky_relu"
self.pass_name = 'conv_leaky_relu_mkldnn_fuse_pass'
class ConvActivationMkldnnFusePassTest_4(ConvActivationMkldnnFusePassTest):
......@@ -95,7 +90,6 @@ class ConvActivationMkldnnFusePassTest_4(ConvActivationMkldnnFusePassTest):
self.conv_filter_size = 3
self.conv_bias_attr = False
self.act = "relu6"
self.pass_name = 'conv_relu6_mkldnn_fuse_pass'
class ConvActivationMkldnnFusePassTest_5(ConvActivationMkldnnFusePassTest):
......@@ -105,7 +99,6 @@ class ConvActivationMkldnnFusePassTest_5(ConvActivationMkldnnFusePassTest):
self.conv_filter_size = 5
self.conv_bias_attr = True
self.act = "hard_swish"
self.pass_name = 'conv_hard_swish_mkldnn_fuse_pass'
class ConvActivationMkldnnFusePassTest_6(ConvActivationMkldnnFusePassTest):
......@@ -115,7 +108,6 @@ class ConvActivationMkldnnFusePassTest_6(ConvActivationMkldnnFusePassTest):
self.conv_filter_size = 5
self.conv_bias_attr = True
self.act = "mish"
self.pass_name = 'conv_mish_mkldnn_fuse_pass'
class ConvHardSigmoidOneDNNFusePassTest(ConvActivationMkldnnFusePassTest):
......@@ -125,7 +117,6 @@ class ConvHardSigmoidOneDNNFusePassTest(ConvActivationMkldnnFusePassTest):
self.conv_filter_size = 5
self.conv_bias_attr = True
self.act = "hard_sigmoid"
self.pass_name = 'conv_hard_sigmoid_mkldnn_fuse_pass'
if __name__ == "__main__":
......
......@@ -102,7 +102,8 @@ class TestConvGeluMkldnnFusePass(PassAutoScanTest):
yield config, ["conv2d"], (1e-5, 1e-5)
def test(self):
self.run_and_statis(quant=False, passes=["conv_gelu_mkldnn_fuse_pass"])
self.run_and_statis(quant=False,
passes=["conv_activation_mkldnn_fuse_pass"])
if __name__ == "__main__":
......
......@@ -104,7 +104,7 @@ class TestConvHardSigmoidMkldnnFusePass(PassAutoScanTest):
def test(self):
self.run_and_statis(quant=False,
passes=["conv_hard_sigmoid_mkldnn_fuse_pass"])
passes=["conv_activation_mkldnn_fuse_pass"])
if __name__ == "__main__":
......
......@@ -106,7 +106,7 @@ class TestConvHardSwishMkldnnFusePass(PassAutoScanTest):
def test(self):
self.run_and_statis(quant=False,
passes=["conv_hard_swish_mkldnn_fuse_pass"])
passes=["conv_activation_mkldnn_fuse_pass"])
if __name__ == "__main__":
......
......@@ -99,7 +99,8 @@ class TestConvMishMkldnnFusePass(PassAutoScanTest):
yield config, ["conv2d"], (1e-5, 1e-5)
def test(self):
self.run_and_statis(quant=False, passes=["conv_mish_mkldnn_fuse_pass"])
self.run_and_statis(quant=False,
passes=["conv_activation_mkldnn_fuse_pass"])
if __name__ == "__main__":
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册