Unverified commit 347e4b2e, authored by Sławomir Siwek, committed by GitHub

Generalize conv+activation fuse pass (#43382)

* consolidate conv act passes

* generalize conv_activation

* integrate conv+act tests

* code style format

* whitespaces

* remove timeout from old tests

* implement comments from review

* restore ut

* whitespace

* code style

* transpose

* fixes after review

* method for getting act

* Change Paddle_enforce error type

* code format

* add missing opcompats
Parent: 9aa89b99
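The heart of the change: instead of one fuse-pass subclass per activation, a single generic FuseConvAct routine is driven by a per-activation attribute map. The sketch below is an illustrative Python restatement, not Paddle source; it only summarizes how each activation's Paddle attribute names are renamed onto the generic fuse_alpha / fuse_beta attributes consumed by the fused conv2d op, as built in ConvActivationMkldnnFusePass::ApplyImpl in the C++ diff that follows.

# Illustrative only -- a Python restatement of the attrs_map built in the
# C++ ApplyImpl below; it is not part of the Paddle codebase.
ATTRS_MAP = {
    "swish":        {"beta": "fuse_alpha"},
    "relu6":        {"threshold": "fuse_alpha"},
    "hard_sigmoid": {"slope": "fuse_alpha", "offset": "fuse_beta"},
    "clip":         {"min": "fuse_alpha", "max": "fuse_beta"},
    # every other supported activation (relu, mish, sqrt, hard_swish,
    # sigmoid, abs, gelu, tanh, leaky_relu) maps:
    #   "alpha" -> "fuse_alpha", "beta" -> "fuse_beta"
}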
-// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
 //
 // Licensed under the Apache License, Version 2.0 (the "License");
 // you may not use this file except in compliance with the License.
@@ -14,110 +14,109 @@
 #include "paddle/fluid/framework/ir/mkldnn/conv_activation_mkldnn_fuse_pass.h"
-#include <vector>
 #include "paddle/fluid/framework/op_version_registry.h"
-#include "paddle/fluid/platform/enforce.h"
+#include "paddle/fluid/string/pretty_log.h"

-namespace paddle {
-namespace framework {
-class OpDesc;
-}  // namespace framework
-}  // namespace paddle
-
 namespace paddle {
 namespace framework {
 namespace ir {

-class Graph;
+using string::PrettyLogDetail;
+
+void ConvActivationMkldnnFusePass::ApplyImpl(Graph* graph) const {
+  std::vector<std::string> act_types = {
+      "relu", "mish", "swish", "sqrt", "hard_swish", "sigmoid", "abs",
+      "gelu", "relu6", "clip", "tanh", "hard_sigmoid", "leaky_relu"};
+  std::vector<std::string> conv_types = {"conv2d"};
+
+  for (const auto& conv_type : conv_types)
+    for (auto& act_type : act_types) {
+      std::unordered_map<std::string, std::string> attrs_map;
+
+      if (act_type == "swish")
+        attrs_map.emplace("beta", "fuse_alpha");
+      else if (act_type == "relu6")
+        attrs_map.emplace("threshold", "fuse_alpha");
+      else if (act_type == "hard_sigmoid") {
+        attrs_map.emplace("slope", "fuse_alpha");
+        attrs_map.emplace("offset", "fuse_beta");
+      } else if (act_type == "clip") {
+        attrs_map.emplace("min", "fuse_alpha");
+        attrs_map.emplace("max", "fuse_beta");
+      } else {
+        attrs_map.emplace("alpha", "fuse_alpha");
+        attrs_map.emplace("beta", "fuse_beta");
+      }
+      FuseConvAct(graph, conv_type, act_type, attrs_map);
+    }
+}
-void ConvActivationFusePass::ApplyImpl(ir::Graph* graph) const {
+void ConvActivationMkldnnFusePass::FuseConvAct(
+    Graph* graph, const std::string& conv_type, std::string& act_type,
+    const std::unordered_map<std::string, std::string>& attrs_map) const {
   PADDLE_ENFORCE_NOT_NULL(
       graph, platform::errors::InvalidArgument("Graph cannot be nullptr."));
-  FusePassBase::Init("conv_activation_mkldnn_fuse", graph);
+  FusePassBase::Init(conv_type + "_" + act_type + "_mkldnn_fuse_pass", graph);

   GraphPatternDetector gpd;
   auto* conv_input = gpd.mutable_pattern()
                          ->NewNode("conv_activation_mkldnn_fuse/conv_input")
                          ->AsInput()
-                         ->assert_is_op_input(conv_type(), "Input");
-  patterns::ConvActivation conv_activation_pattern(
-      gpd.mutable_pattern(), "conv_activation_mkldnn_fuse");
-  conv_activation_pattern(conv_input, conv_type(), activation_type());
+                         ->assert_is_op_input(conv_type, "Input");
+  patterns::ConvActivation conv_act_pattern(gpd.mutable_pattern(),
+                                            "conv_activation_mkldnn_fuse");
+  conv_act_pattern(conv_input, conv_type, act_type);

   int found_conv_activation_count = 0;
   auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
                      Graph* g) {
-    VLOG(4) << "handle " + conv_type() + "+" + activation_type() + " fuse";
+    VLOG(4) << "handle " + conv_type + "+" + act_type + " fuse";

     if (!IsCompat(subgraph, g)) {
       LOG(WARNING) << "conv_activation_mkldnn_fuse_pass op compat failed.";
       return;
     }
-    GET_IR_NODE_FROM_SUBGRAPH(conv_weight, conv_weight,
-                              conv_activation_pattern);  // Filter
-    GET_IR_NODE_FROM_SUBGRAPH(conv_out, conv_out,
-                              conv_activation_pattern);  // tmp
-    GET_IR_NODE_FROM_SUBGRAPH(conv, conv, conv_activation_pattern);  // CONV op
-    GET_IR_NODE_FROM_SUBGRAPH(activation_out, activation_out,
-                              conv_activation_pattern);  // Out
-    GET_IR_NODE_FROM_SUBGRAPH(activation, activation,
-                              conv_activation_pattern);  // Activation op
-
-    // Transform Conv node into ConvActivation node.
-    OpDesc* desc = conv->Op();
-    desc->SetOutput("Output",
-                    std::vector<std::string>({activation_out->Name()}));
-
-    if (activation_type() == "gelu" &&
-        activation->Op()->HasAttr("approximate")) {
-      bool approximate =
-          BOOST_GET_CONST(bool, activation->Op()->GetAttr("approximate"));
-      std::string type = approximate ? "_tanh" : "_erf";
-      desc->SetAttr("fuse_activation", "gelu" + type);
-    } else {
-      desc->SetAttr("fuse_activation", activation_type());
-    }
-
-    // MKLDNN ops use alpha and beta as activation parameters but paddle ops are
-    // not generalized
-    if (activation_type() == "relu6") {
-      desc->SetAttr(
-          "fuse_alpha",
-          BOOST_GET_CONST(float, activation->Op()->GetAttr("threshold")));
-    } else if (activation_type() == "swish") {
-      // paddle uses beta but mkldnn uses alpha for swish
-      desc->SetAttr("fuse_alpha",
-                    activation->Op()->GetAttrIfExists<float>("beta"));
-    } else {
-      desc->SetAttr("fuse_alpha",
-                    activation->Op()->GetAttrIfExists<float>("alpha"));
-    }
-    desc->SetAttr("fuse_beta",
-                  activation->Op()->GetAttrIfExists<float>("beta"));
-
-    if (activation_type() == "hard_sigmoid") {
-      desc->SetAttr("fuse_alpha",
-                    activation->Op()->GetAttrIfExists<float>("slope"));
-      desc->SetAttr("fuse_beta",
-                    activation->Op()->GetAttrIfExists<float>("offset"));
-    }
-
-    GraphSafeRemoveNodes(graph, {activation, conv_out});
-
-    PADDLE_ENFORCE_GT(subgraph.count(conv_input), 0UL,
-                      platform::errors::InvalidArgument(
-                          "Subgraph has to contain conv input node."));
+    GET_IR_NODE_FROM_SUBGRAPH(conv_weight, conv_weight, conv_act_pattern);
+    GET_IR_NODE_FROM_SUBGRAPH(conv_out, conv_out, conv_act_pattern);
+    GET_IR_NODE_FROM_SUBGRAPH(conv, conv, conv_act_pattern);
+    GET_IR_NODE_FROM_SUBGRAPH(activation_out, activation_out, conv_act_pattern);
+    GET_IR_NODE_FROM_SUBGRAPH(activation, activation, conv_act_pattern);
+
+    OpDesc* conv_op = conv->Op();
+    OpDesc* act_op = activation->Op();
+
+    for (const auto& attrs : attrs_map) {
+      if (act_op->HasAttr(attrs.first)) {
+        conv_op->SetAttr(attrs.second, act_op->GetAttr(attrs.first));
+      }
+    }
+
+    if (act_type == "gelu" && activation->Op()->HasAttr("approximate")) {
+      act_type = BOOST_GET_CONST(bool, activation->Op()->GetAttr("approximate"))
+                     ? "gelu_tanh"
+                     : "gelu_erf";
+      conv_op->SetAttr("fuse_alpha", 0.0f);
+      conv_op->SetAttr("fuse_beta", 0.0f);
+    }
+    conv_op->SetAttr("fuse_activation", act_type);
+    conv_op->SetOutput("Output", {activation_out->Name()});
+
     IR_NODE_LINK_TO(conv, activation_out);
+    GraphSafeRemoveNodes(graph, {activation, conv_out});
     found_conv_activation_count++;
   };

   gpd(graph, handler);
   AddStatis(found_conv_activation_count);
+
+  if (!Has("disable_logs") || !Get<bool>("disable_logs")) {
+    PrettyLogDetail("--- fused %d conv with %s activation",
+                    found_conv_activation_count, act_type);
+  }
 }
-ConvActivationFusePass::ConvActivationFusePass() {
+ConvActivationMkldnnFusePass::ConvActivationMkldnnFusePass() {
   AddOpCompat(OpCompat("conv2d"))
       .AddInput("Input")
       .IsTensor()
@@ -142,8 +141,6 @@ ConvActivationFusePass::ConvActivationFusePass() {
       .AddAttr("paddings")
      .IsType<std::vector<int>>()
      .End()
-      // IsStringIn({"EXPLICIT", "SAME", "VALID"}), MobileNetV2 has no this
-      // attribute
      .AddAttr("padding_algorithm")
      .IsOptional()
      .IsStringIn({"EXPLICIT", "SAME", "VALID"})
@@ -154,7 +151,6 @@ ConvActivationFusePass::ConvActivationFusePass() {
      .AddAttr("dilations")
      .IsType<std::vector<int>>()
      .End()
-      // IsStringIn({"NHWC", "NCHW"}) MobileNetV2 has no this attribute
      .AddAttr("data_format")
      .IsOptional()
      .IsStringIn({"NCHW", "NHWC", "AnyLayout"})
@@ -167,8 +163,7 @@ ConvActivationFusePass::ConvActivationFusePass() {
      .AddOutput("Out")
      .IsTensor()
      .End();
-}
-
-Conv2DLeakyReLUFusePass::Conv2DLeakyReLUFusePass() {
   AddOpCompat(OpCompat("leaky_relu"))
      .AddInput("X")
      .IsTensor()
@@ -176,12 +171,10 @@ Conv2DLeakyReLUFusePass::Conv2DLeakyReLUFusePass() {
      .AddOutput("Out")
      .IsTensor()
      .End()
-      // float, default=0.02
      .AddAttr("alpha")
      .IsType<float>()
      .End();
-}
-
-Conv2DReLU6FusePass::Conv2DReLU6FusePass() {
   AddOpCompat(OpCompat("relu6"))
      .AddInput("X")
      .IsTensor()
@@ -189,12 +182,10 @@ Conv2DReLU6FusePass::Conv2DReLU6FusePass() {
      .AddOutput("Out")
      .IsTensor()
      .End()
-      // default = 6.0f
      .AddAttr("threshold")
      .IsType<float>()
      .End();
-}
-
-Conv2DSwishFusePass::Conv2DSwishFusePass() {
   AddOpCompat(OpCompat("swish"))
      .AddInput("X")
      .IsTensor()
@@ -205,8 +196,7 @@ Conv2DSwishFusePass::Conv2DSwishFusePass() {
      .AddAttr("beta")
      .IsType<float>()
      .End();
-}
-
-Conv2DHardSwishFusePass::Conv2DHardSwishFusePass() {
   AddOpCompat(OpCompat("hard_swish"))
      .AddInput("X")
      .IsTensor()
@@ -214,23 +204,19 @@ Conv2DHardSwishFusePass::Conv2DHardSwishFusePass() {
      .AddOutput("Out")
      .IsTensor()
      .End()
-      // float, optional, default=6.0
      .AddAttr("threshold")
      .IsOptional()
      .IsType<float>()
      .End()
-      // float, optional, default=6.0
      .AddAttr("scale")
      .IsOptional()
      .IsType<float>()
      .End()
-      // float, optional, default=3.0
      .AddAttr("offset")
      .IsOptional()
      .IsType<float>()
      .End();
-}
-
-Conv2DMishFusePass::Conv2DMishFusePass() {
   AddOpCompat(OpCompat("mish"))
      .AddInput("X")
      .IsTensor()
@@ -238,8 +224,7 @@ Conv2DMishFusePass::Conv2DMishFusePass() {
      .AddOutput("Out")
      .IsTensor()
      .End();
-}
-
-Conv2DHardSigmoidFusePass::Conv2DHardSigmoidFusePass() {
   AddOpCompat(OpCompat("hard_sigmoid"))
      .AddInput("X")
      .IsTensor()
@@ -247,19 +232,15 @@ Conv2DHardSigmoidFusePass::Conv2DHardSigmoidFusePass() {
      .AddOutput("Out")
      .IsTensor()
      .End()
-      // optional, default=0.2
      .AddAttr("slope")
      .IsOptional()
      .IsType<float>()
      .End()
-      // optional, default=0.5
      .AddAttr("offset")
      .IsOptional()
      .IsType<float>()
      .End();
-}
-
-Conv2DGeluFusePass::Conv2DGeluFusePass() {
   AddOpCompat(OpCompat("gelu"))
      .AddInput("X")
      .IsTensor()
@@ -270,6 +251,38 @@ Conv2DGeluFusePass::Conv2DGeluFusePass() {
      .AddAttr("approximate")
      .IsType<bool>()
      .End();
+
+  AddOpCompat(OpCompat("tanh"))
+      .AddInput("X")
+      .IsTensor()
+      .End()
+      .AddOutput("Out")
+      .IsTensor()
+      .End();
+
+  AddOpCompat(OpCompat("sigmoid"))
+      .AddInput("X")
+      .IsTensor()
+      .End()
+      .AddOutput("Out")
+      .IsTensor()
+      .End();
+
+  AddOpCompat(OpCompat("sqrt"))
+      .AddInput("X")
+      .IsTensor()
+      .End()
+      .AddOutput("Out")
+      .IsTensor()
+      .End();
+
+  AddOpCompat(OpCompat("abs"))
+      .AddInput("X")
+      .IsTensor()
+      .End()
+      .AddOutput("Out")
+      .IsTensor()
+      .End();
 }

 }  // namespace ir
@@ -277,68 +290,22 @@ Conv2DGeluFusePass::Conv2DGeluFusePass() {
 }  // namespace paddle
 REGISTER_PASS(conv_activation_mkldnn_fuse_pass,
-              paddle::framework::ir::ConvActivationFusePass);
-REGISTER_PASS(conv_relu_mkldnn_fuse_pass,
-              paddle::framework::ir::ConvActivationFusePass);
-REGISTER_PASS_CAPABILITY(conv_relu_mkldnn_fuse_pass)
-    .AddCombination(
-        paddle::framework::compatible::OpVersionComparatorCombination()
-            .LE("conv2d", 1)
-            .EQ("relu", 0));
-
-REGISTER_PASS(conv_leaky_relu_mkldnn_fuse_pass,
-              paddle::framework::ir::Conv2DLeakyReLUFusePass);
-REGISTER_PASS_CAPABILITY(conv_leaky_relu_mkldnn_fuse_pass)
-    .AddCombination(
-        paddle::framework::compatible::OpVersionComparatorCombination()
-            .LE("conv2d", 1)
-            .LE("leaky_relu", 1));
-
-REGISTER_PASS(conv_relu6_mkldnn_fuse_pass,
-              paddle::framework::ir::Conv2DReLU6FusePass);
-REGISTER_PASS_CAPABILITY(conv_relu6_mkldnn_fuse_pass)
-    .AddCombination(
-        paddle::framework::compatible::OpVersionComparatorCombination()
-            .LE("conv2d", 1)
-            .EQ("relu6", 0));
-
-REGISTER_PASS(conv_swish_mkldnn_fuse_pass,
-              paddle::framework::ir::Conv2DSwishFusePass);
-REGISTER_PASS_CAPABILITY(conv_swish_mkldnn_fuse_pass)
-    .AddCombination(
-        paddle::framework::compatible::OpVersionComparatorCombination()
-            .LE("conv2d", 1)
-            .EQ("swish", 0));
-
-REGISTER_PASS(conv_hard_swish_mkldnn_fuse_pass,
-              paddle::framework::ir::Conv2DHardSwishFusePass);
-REGISTER_PASS_CAPABILITY(conv_hard_swish_mkldnn_fuse_pass)
-    .AddCombination(
-        paddle::framework::compatible::OpVersionComparatorCombination()
-            .LE("conv2d", 1)
-            .EQ("hard_swish", 0));
-
-REGISTER_PASS(conv_mish_mkldnn_fuse_pass,
-              paddle::framework::ir::Conv2DMishFusePass);
-REGISTER_PASS_CAPABILITY(conv_mish_mkldnn_fuse_pass)
-    .AddCombination(
-        paddle::framework::compatible::OpVersionComparatorCombination()
-            .LE("conv2d", 1)
-            .EQ("mish", 1));
-
-REGISTER_PASS(conv_hard_sigmoid_mkldnn_fuse_pass,
-              paddle::framework::ir::Conv2DHardSigmoidFusePass);
-REGISTER_PASS_CAPABILITY(conv_hard_sigmoid_mkldnn_fuse_pass)
-    .AddCombination(
-        paddle::framework::compatible::OpVersionComparatorCombination()
-            .LE("conv2d", 1)
-            .EQ("hard_sigmoid", 0));
-
-REGISTER_PASS(conv_gelu_mkldnn_fuse_pass,
-              paddle::framework::ir::Conv2DGeluFusePass);
-REGISTER_PASS_CAPABILITY(conv_gelu_mkldnn_fuse_pass)
+              paddle::framework::ir::ConvActivationMkldnnFusePass);
+
+REGISTER_PASS_CAPABILITY(conv_activation_mkldnn_fuse_pass)
     .AddCombination(
         paddle::framework::compatible::OpVersionComparatorCombination()
             .LE("conv2d", 1)
-            .EQ("gelu", 0));
+            .EQ("abs", 0)
+            .LE("clip", 1)
+            .EQ("gelu", 0)
+            .EQ("hard_sigmoid", 0)
+            .LE("hard_swish", 0)
+            .LE("leaky_relu", 1)
+            .LE("mish", 1)
+            .EQ("relu", 0)
+            .EQ("relu6", 0)
+            .EQ("sigmoid", 0)
+            .EQ("sqrt", 0)
+            .EQ("swish", 0)
+            .EQ("tanh", 0));
-// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
 //
 // Licensed under the Apache License, Version 2.0 (the "License");
 // you may not use this file except in compliance with the License.
@@ -18,84 +18,22 @@
 #include "paddle/fluid/framework/ir/fuse_pass_base.h"
 #include "paddle/fluid/framework/ir/graph.h"
-#include "paddle/fluid/framework/ir/graph_pattern_detector.h"
-#include "paddle/fluid/framework/ir/pass.h"

 namespace paddle {
 namespace framework {
 namespace ir {

-/*
- * Fuse Conv and Activation base class.
- */
-class Graph;
-
-class ConvActivationFusePass : public FusePassBase {
- public:
-  ConvActivationFusePass();
-  virtual ~ConvActivationFusePass() {}
-  virtual std::string conv_type() const { return "conv2d"; }
-  virtual std::string activation_type() const { return "relu"; }
-
- protected:
-  void ApplyImpl(ir::Graph* graph) const override;
-  const std::string name_scope_{"conv_activation_mkldnn_fuse"};
-};
-
-/*
- * Fuse Conv and LeakyReLU class
- */
-class Conv2DLeakyReLUFusePass : public ConvActivationFusePass {
- public:
-  Conv2DLeakyReLUFusePass();
-  std::string activation_type() const { return "leaky_relu"; }
-};
-
-/*
- * Fuse Conv and BoundedReLU class
- */
-class Conv2DReLU6FusePass : public ConvActivationFusePass {
- public:
-  Conv2DReLU6FusePass();
-  std::string activation_type() const { return "relu6"; }
-};
-
-/*
- * Fuse Conv and Swish class
- */
-class Conv2DSwishFusePass : public ConvActivationFusePass {
- public:
-  Conv2DSwishFusePass();
-  std::string activation_type() const { return "swish"; }
-};
-
-/*
- * Fuse Conv and HardSwish class
- */
-class Conv2DHardSwishFusePass : public ConvActivationFusePass {
- public:
-  Conv2DHardSwishFusePass();
-  std::string activation_type() const { return "hard_swish"; }
-};
-
-/*
- * Fuse Conv and Mish class
- */
-class Conv2DMishFusePass : public ConvActivationFusePass {
- public:
-  Conv2DMishFusePass();
-  std::string activation_type() const { return "mish"; }
-};
-
-/*
- * Fuse Conv and HardSigmoid class
- */
-class Conv2DHardSigmoidFusePass : public ConvActivationFusePass {
- public:
-  Conv2DHardSigmoidFusePass();
-  std::string activation_type() const { return "hard_sigmoid"; }
-};
-
-/*
- * Fuse Conv and Gelu class
- */
-class Conv2DGeluFusePass : public ConvActivationFusePass {
- public:
-  Conv2DGeluFusePass();
-  std::string activation_type() const { return "gelu"; }
-};
+class ConvActivationMkldnnFusePass : public FusePassBase {
+ public:
+  ConvActivationMkldnnFusePass();
+  virtual ~ConvActivationMkldnnFusePass() {}
+
+ protected:
+  void ApplyImpl(Graph *graph) const override;
+
+  void FuseConvAct(
+      Graph *graph, const std::string &conv_type, std::string &act_type,
+      const std::unordered_map<std::string, std::string> &attrs_map) const;
+};

 }  // namespace ir
......
@@ -104,8 +104,7 @@ void MainTest(std::string activation) {
   std::unique_ptr<ir::Graph> graph(new ir::Graph(prog));

-  auto pass =
-      PassRegistry::Instance().Get("conv_" + activation + "_mkldnn_fuse_pass");
+  auto pass = PassRegistry::Instance().Get("conv_activation_mkldnn_fuse_pass");

   int original_nodes_num = graph->Nodes().size();
......
@@ -27,7 +27,7 @@ using string::PrettyLogDetail;
 void ElementwiseActivationOneDNNPass::ApplyImpl(Graph *graph) const {
   std::vector<std::string> act_types = {
-      "relu", "tanh", "leaky_relu", "swish", "hardswish", "sqrt",
+      "relu", "tanh", "leaky_relu", "swish", "hard_swish", "sqrt",
       "abs",  "clip", "gelu",       "relu6", "sigmoid"};
   std::vector<std::string> elt_types = {"elementwise_add", "elementwise_sub",
                                         "elementwise_mul"};
@@ -56,7 +56,7 @@ void ElementwiseActivationOneDNNPass::FuseElementwiseAct(
     const std::unordered_map<std::string, std::string> &attr_map) const {
   PADDLE_ENFORCE_NOT_NULL(
       graph, platform::errors::InvalidArgument("Graph cannot be nullptr."));
-  FusePassBase::Init("elementwise_act", graph);
+  FusePassBase::Init(elt_type + "_" + act_type + "_mkldnn_fuse_pass", graph);

   GraphPatternDetector gpd;
   auto *elementwise_input = gpd.mutable_pattern()
......
@@ -302,15 +302,7 @@ void CpuPassStrategy::EnableMKLDNN() {
            // "conv3d_bias_mkldnn_fuse_pass",  //
            "conv_elementwise_add_mkldnn_fuse_pass",
            "conv_concat_relu_mkldnn_fuse_pass",
-           "conv_relu_mkldnn_fuse_pass",          //
-           "conv_leaky_relu_mkldnn_fuse_pass",    //
-           "conv_relu6_mkldnn_fuse_pass",         //
-           "conv_swish_mkldnn_fuse_pass",         //
-           "conv_hard_swish_mkldnn_fuse_pass",    //
-           "conv_mish_mkldnn_fuse_pass",          //
-           "conv_hard_sigmoid_mkldnn_fuse_pass",  //
-           // TODO(baoachun) fix int8 accuracy
-           "conv_gelu_mkldnn_fuse_pass",
+           "conv_activation_mkldnn_fuse_pass",    //
            "scale_matmul_fuse_pass",                        //
            "reshape_transpose_matmul_mkldnn_fuse_pass",     //
            "reshape_transpose_matmul_v2_mkldnn_fuse_pass",  //
@@ -403,14 +395,7 @@ void CpuPassStrategy::EnableMkldnnInt8() {
     passes_.push_back("conv_transpose_bias_mkldnn_fuse_pass");
     passes_.push_back("conv_elementwise_add_mkldnn_fuse_pass");
     passes_.push_back("conv_concat_relu_mkldnn_fuse_pass");
-    passes_.push_back("conv_relu_mkldnn_fuse_pass");
-    passes_.push_back("conv_leaky_relu_mkldnn_fuse_pass");
-    passes_.push_back("conv_relu6_mkldnn_fuse_pass");
-    passes_.push_back("conv_swish_mkldnn_fuse_pass");
-    passes_.push_back("conv_hard_swish_mkldnn_fuse_pass");
-    passes_.push_back("conv_mish_mkldnn_fuse_pass");
-    passes_.push_back("conv_hard_sigmoid_mkldnn_fuse_pass");
-    passes_.push_back("conv_gelu_mkldnn_fuse_pass");
+    passes_.push_back("conv_activation_mkldnn_fuse_pass");
     passes_.push_back("fc_fuse_pass");
     passes_.push_back("repeated_fc_relu_fuse_pass");
     passes_.push_back("fc_mkldnn_pass");
......
@@ -104,12 +104,12 @@ TEST(Analyzer_vit_ocr, fuse_status) {
   SetConfig(&cfg, true);
   int num_ops;
   auto predictor = CreatePaddlePredictor<AnalysisConfig>(cfg);
-  auto fuse_status = GetFuseStatis(
+  auto fuse_statis = GetFuseStatis(
       static_cast<AnalysisPredictor *>(predictor.get()), &num_ops);

-  CHECK_EQ(fuse_status.at("fc_mkldnn_pass"), 33);
-  CHECK_EQ(fuse_status.at("conv_activation_mkldnn_fuse"), 2);
-  CHECK_EQ(fuse_status.at("fc_elementwise_add_mkldnn_fuse"), 16);
+  CHECK_EQ(fuse_statis.at("fc_mkldnn_pass"), 33);
+  CHECK_EQ(fuse_statis.at("conv2d_gelu_mkldnn_fuse_pass"), 2);
+  CHECK_EQ(fuse_statis.at("fc_elementwise_add_mkldnn_fuse"), 16);
 }
 #endif
......
@@ -61,27 +61,10 @@ class EltwiseMKLDNNKernel : public framework::OpKernel<T> {
                        ? ctx.Attr<float>("activation_beta")
                        : 0.0f;

-      static std::unordered_map<std::string, dnnl::algorithm> algo_map = {
-          {"relu", dnnl::algorithm::eltwise_relu},
-          {"tanh", dnnl::algorithm::eltwise_tanh},
-          {"leaky_relu", dnnl::algorithm::eltwise_relu},
-          {"swish", dnnl::algorithm::eltwise_swish},
-          {"hardswish", dnnl::algorithm::eltwise_hardswish},
-          {"sqrt", dnnl::algorithm::eltwise_sqrt},
-          {"abs", dnnl::algorithm::eltwise_abs},
-          {"clip", dnnl::algorithm::eltwise_clip},
-          {"gelu", dnnl::algorithm::eltwise_gelu_erf},
-          {"gelu_tanh", dnnl::algorithm::eltwise_gelu_tanh},
-          {"relu6", dnnl::algorithm::eltwise_bounded_relu},
-          {"sigmoid", dnnl::algorithm::eltwise_logistic}};
-
-      const auto& activation_type =
-          algo_map.find(ctx.Attr<std::string>("activation_type"));
-
-      if (activation_type != algo_map.end()) {
-        post_operations.append_eltwise(scale, activation_type->second, alpha,
-                                       beta);
-      }
+      const auto activation_algorithm = platform::AcquireActivationAlgorithm(
+          ctx.Attr<std::string>("activation_type"));
+
+      post_operations.append_eltwise(scale, activation_algorithm, alpha, beta);
     }
     return post_operations;
   }
......
@@ -505,41 +505,20 @@ class ConvMKLDNNHandlerT
     if (fuse_residual_conn) {
       post_operations.append_sum(sum_scale);
     }
-    // Fusion with ReLU layer is executed through the PostOps feature. Create a
-    // PostOps object and configure it to execute an eltwise relu operation.
-    if (fuse_activation == "relu" || fuse_activation == "leaky_relu") {
-      post_operations.append_eltwise(activation_scale,
-                                     dnnl::algorithm::eltwise_relu, fuse_alpha,
-                                     fuse_beta);
-    } else if (fuse_activation == "relu6") {
-      post_operations.append_eltwise(activation_scale,
-                                     dnnl::algorithm::eltwise_bounded_relu,
-                                     fuse_alpha, fuse_beta);
-    } else if (fuse_activation == "swish") {
-      post_operations.append_eltwise(activation_scale,
-                                     dnnl::algorithm::eltwise_swish, fuse_alpha,
-                                     fuse_beta);
-    } else if (fuse_activation == "hard_swish") {
-      post_operations.append_eltwise(activation_scale,
-                                     dnnl::algorithm::eltwise_hardswish,
-                                     fuse_alpha, fuse_beta);
-    } else if (fuse_activation == "mish") {
-      post_operations.append_eltwise(activation_scale,
-                                     dnnl::algorithm::eltwise_mish, fuse_alpha,
-                                     fuse_beta);
-    } else if (fuse_activation == "hard_sigmoid") {
+
+    if (fuse_activation == "hard_sigmoid") {
       post_operations.append_eltwise(activation_scale,
                                      dnnl::algorithm::eltwise_linear,
                                      fuse_alpha, fuse_beta);
       post_operations.append_eltwise(activation_scale,
                                      dnnl::algorithm::eltwise_clip, 0.0f, 1.0f);
-    } else if (fuse_activation == "gelu_tanh") {
-      post_operations.append_eltwise(
-          activation_scale, dnnl::algorithm::eltwise_gelu_tanh, 0.0f, 0.0f);
-    } else if (fuse_activation == "gelu_erf") {
-      post_operations.append_eltwise(
-          activation_scale, dnnl::algorithm::eltwise_gelu_erf, 0.0f, 0.0f);
+    } else if (fuse_activation != "") {
+      const auto activation_algorithm =
+          platform::AcquireActivationAlgorithm(fuse_activation);
+      post_operations.append_eltwise(activation_scale, activation_algorithm,
+                                     fuse_alpha, fuse_beta);
     }
     conv_attr.set_post_ops(post_operations);
     return conv_attr;
   }
......
@@ -951,6 +951,33 @@ class ActivationMKLDNNHandler
   }
 };

+static const dnnl::algorithm AcquireActivationAlgorithm(
+    std::string activation_name) {
+  std::unordered_map<std::string, dnnl::algorithm> activation_map = {
+      {"abs", dnnl::algorithm::eltwise_abs},
+      {"clip", dnnl::algorithm::eltwise_clip},
+      {"gelu", dnnl::algorithm::eltwise_gelu_erf},
+      {"gelu_erf", dnnl::algorithm::eltwise_gelu_erf},
+      {"gelu_tanh", dnnl::algorithm::eltwise_gelu_tanh},
+      {"hard_swish", dnnl::algorithm::eltwise_hardswish},
+      {"leaky_relu", dnnl::algorithm::eltwise_relu},
+      {"mish", dnnl::algorithm::eltwise_mish},
+      {"relu", dnnl::algorithm::eltwise_relu},
+      {"relu6", dnnl::algorithm::eltwise_bounded_relu},
+      {"sigmoid", dnnl::algorithm::eltwise_logistic},
+      {"sqrt", dnnl::algorithm::eltwise_sqrt},
+      {"swish", dnnl::algorithm::eltwise_swish},
+      {"tanh", dnnl::algorithm::eltwise_tanh}};
+
+  const auto& activation_type = activation_map.find(activation_name);
+
+  PADDLE_ENFORCE_NE(activation_type, activation_map.end(),
+                    platform::errors::InvalidArgument(
+                        "Activation '%s' not found in oneDNN algorithms mapper",
+                        activation_name));
+
+  return activation_type->second;
+}
+
 class ReorderMKLDNNHandler {
  public:
  ReorderMKLDNNHandler(std::vector<int64_t>& dims,  // NOLINT
......
@@ -432,14 +432,7 @@ class Quant2Int8MkldnnPass(object):
         graph = self._apply_pass(graph, 'conv_transpose_bias_mkldnn_fuse_pass')
         graph = self._apply_pass(graph, 'conv_elementwise_add_mkldnn_fuse_pass')
         graph = self._apply_pass(graph, 'conv_concat_relu_mkldnn_fuse_pass')
-        graph = self._apply_pass(graph, 'conv_relu_mkldnn_fuse_pass')
-        graph = self._apply_pass(graph, 'conv_leaky_relu_mkldnn_fuse_pass')
-        graph = self._apply_pass(graph, 'conv_relu6_mkldnn_fuse_pass')
-        graph = self._apply_pass(graph, 'conv_swish_mkldnn_fuse_pass')
-        graph = self._apply_pass(graph, 'conv_hard_swish_mkldnn_fuse_pass')
-        graph = self._apply_pass(graph, 'conv_mish_mkldnn_fuse_pass')
-        graph = self._apply_pass(graph, 'conv_hard_sigmoid_mkldnn_fuse_pass')
-        graph = self._apply_pass(graph, 'conv_gelu_mkldnn_fuse_pass')
+        graph = self._apply_pass(graph, 'conv_activation_mkldnn_fuse_pass')
         graph = self._apply_pass(graph, 'fc_fuse_pass',
                                  ['use_gpu', 'use_fc_padding'], [False, False])
         graph = self._apply_pass(graph, 'repeated_fc_relu_fuse_pass')
......
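As the Quant2Int8MkldnnPass change above shows, callers now apply a single consolidated pass name. A minimal end-user sketch, assuming the standard paddle.inference Python API and a hypothetical model directory:

# Hedged usage sketch (not part of this PR): with oneDNN enabled, the CPU
# pass strategy now includes the consolidated conv_activation_mkldnn_fuse_pass.
import paddle.inference as paddle_infer

config = paddle_infer.Config("./my_conv_model")  # hypothetical model path
config.enable_mkldnn()
predictor = paddle_infer.create_predictor(config)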
@@ -170,28 +170,24 @@ class TestConvActMkldnnFusePass(PassAutoScanTest):
         # 11. Generate legal attr of act
         act_op = None
-        self.passes = None
+        self.passes = ["conv_activation_mkldnn_fuse_pass"]
         if act_type == "relu6":
-            self.passes = ["conv_relu6_mkldnn_fuse_pass"]
             threshold = draw(st.floats(min_value=1.0, max_value=10.0))
             act_op = OpConfig("relu6",
                               inputs={"X": ["conv2d_out"]},
                               outputs={"Out": ["relu_out"]},
                               threshold=threshold)
-        if act_type == "leaky_relu":
-            self.passes = ["conv_leaky_relu_mkldnn_fuse_pass"]
+        elif act_type == "leaky_relu":
             alpha = draw(st.floats(min_value=0.1, max_value=1.0))
             act_op = OpConfig("leaky_relu",
                               inputs={"X": ["conv2d_out"]},
                               outputs={"Out": ["relu_out"]},
                               alpha=alpha)
-        if act_type == "relu":
-            self.passes = ["conv_relu_mkldnn_fuse_pass"]
+        elif act_type == "relu":
             act_op = OpConfig("relu",
                               inputs={"X": ["conv2d_out"]},
                               outputs={"Out": ["relu_out"]})
-        if act_type == "swish":
-            self.passes = ["conv_swish_mkldnn_fuse_pass"]
+        elif act_type == "swish":
             beta = draw(st.floats(min_value=0.1, max_value=1.0))
             act_op = OpConfig("swish",
                               inputs={"X": ["conv2d_out"]},
......
@@ -18,8 +18,6 @@ import unittest
 import numpy as np
 from inference_pass_test import InferencePassTest
 import paddle.fluid as fluid
-import paddle.fluid.core as core
-from paddle.fluid.core import AnalysisConfig
 from paddle.fluid.core import PassVersionChecker
@@ -42,13 +40,13 @@ class ConvActivationMkldnnFusePassTest(InferencePassTest):
         }
         self.fetch_list = [conv_out]
         self.enable_mkldnn = True
+        self.pass_name = 'conv_activation_mkldnn_fuse_pass'

     def set_params(self):
         self.conv_num_filters = 3
         self.conv_filter_size = 3
         self.conv_bias_attr = False
         self.act = "relu"
-        self.pass_name = 'conv_relu_mkldnn_fuse_pass'

     def test_check_output(self):
         use_gpu = False
@@ -65,7 +63,6 @@ class ConvActivationMkldnnFusePassTest_1(ConvActivationMkldnnFusePassTest):
         self.conv_filter_size = 5
         self.conv_bias_attr = True
         self.act = "relu"
-        self.pass_name = 'conv_relu_mkldnn_fuse_pass'


 class ConvActivationMkldnnFusePassTest_2(ConvActivationMkldnnFusePassTest):
@@ -75,7 +72,6 @@ class ConvActivationMkldnnFusePassTest_2(ConvActivationMkldnnFusePassTest):
         self.conv_filter_size = 3
         self.conv_bias_attr = False
         self.act = "leaky_relu"
-        self.pass_name = 'conv_leaky_relu_mkldnn_fuse_pass'


 class ConvActivationMkldnnFusePassTest_3(ConvActivationMkldnnFusePassTest):
@@ -85,7 +81,6 @@ class ConvActivationMkldnnFusePassTest_3(ConvActivationMkldnnFusePassTest):
         self.conv_filter_size = 5
         self.conv_bias_attr = True
         self.act = "leaky_relu"
-        self.pass_name = 'conv_leaky_relu_mkldnn_fuse_pass'


 class ConvActivationMkldnnFusePassTest_4(ConvActivationMkldnnFusePassTest):
@@ -95,7 +90,6 @@ class ConvActivationMkldnnFusePassTest_4(ConvActivationMkldnnFusePassTest):
         self.conv_filter_size = 3
         self.conv_bias_attr = False
         self.act = "relu6"
-        self.pass_name = 'conv_relu6_mkldnn_fuse_pass'


 class ConvActivationMkldnnFusePassTest_5(ConvActivationMkldnnFusePassTest):
@@ -105,7 +99,6 @@ class ConvActivationMkldnnFusePassTest_5(ConvActivationMkldnnFusePassTest):
         self.conv_filter_size = 5
         self.conv_bias_attr = True
         self.act = "hard_swish"
-        self.pass_name = 'conv_hard_swish_mkldnn_fuse_pass'


 class ConvActivationMkldnnFusePassTest_6(ConvActivationMkldnnFusePassTest):
@@ -115,7 +108,6 @@ class ConvActivationMkldnnFusePassTest_6(ConvActivationMkldnnFusePassTest):
         self.conv_filter_size = 5
         self.conv_bias_attr = True
         self.act = "mish"
-        self.pass_name = 'conv_mish_mkldnn_fuse_pass'


 class ConvHardSigmoidOneDNNFusePassTest(ConvActivationMkldnnFusePassTest):
@@ -125,7 +117,6 @@ class ConvHardSigmoidOneDNNFusePassTest(ConvActivationMkldnnFusePassTest):
         self.conv_filter_size = 5
         self.conv_bias_attr = True
         self.act = "hard_sigmoid"
-        self.pass_name = 'conv_hard_sigmoid_mkldnn_fuse_pass'


 if __name__ == "__main__":
......
@@ -102,7 +102,8 @@ class TestConvGeluMkldnnFusePass(PassAutoScanTest):
             yield config, ["conv2d"], (1e-5, 1e-5)

     def test(self):
-        self.run_and_statis(quant=False, passes=["conv_gelu_mkldnn_fuse_pass"])
+        self.run_and_statis(quant=False,
+                            passes=["conv_activation_mkldnn_fuse_pass"])

 if __name__ == "__main__":
......
@@ -104,7 +104,7 @@ class TestConvHardSigmoidMkldnnFusePass(PassAutoScanTest):
     def test(self):
         self.run_and_statis(quant=False,
-                            passes=["conv_hard_sigmoid_mkldnn_fuse_pass"])
+                            passes=["conv_activation_mkldnn_fuse_pass"])

 if __name__ == "__main__":
......
@@ -106,7 +106,7 @@ class TestConvHardSwishMkldnnFusePass(PassAutoScanTest):
     def test(self):
         self.run_and_statis(quant=False,
-                            passes=["conv_hard_swish_mkldnn_fuse_pass"])
+                            passes=["conv_activation_mkldnn_fuse_pass"])

 if __name__ == "__main__":
......
@@ -99,7 +99,8 @@ class TestConvMishMkldnnFusePass(PassAutoScanTest):
             yield config, ["conv2d"], (1e-5, 1e-5)

     def test(self):
-        self.run_and_statis(quant=False, passes=["conv_mish_mkldnn_fuse_pass"])
+        self.run_and_statis(quant=False,
+                            passes=["conv_activation_mkldnn_fuse_pass"])

 if __name__ == "__main__":
......