Unverified commit 7f0eb2e3, authored by zyfncg, committed by GitHub

Refactor Pass for fused_conv (#48848)

* refactor conv_activation_mkldnn_fuse_pass

* refactor conv_affine_channel_mkldnn_fuse_pass

* fix conv_activation_mkldnn_fuse_pass

* fix mkldnn unittest

* refactor int8_scale_calculation_mkldnn_pass and params_quantization_mkldnn_pass

* refactor conv_elementwise_add_mkldnn_fuse_pass

* fix quant

* refactor conv_bn_fuse_pass

* fix conv_bn_fuse_pass

* refactor depthwise_conv_bn_fuse_pass

* fix unittest

* fix conv_bn_fuse_pass

* remove redundant conv2d in params_quantization_mkldnn_pass

* fix params_quantization_mkldnn_pass_tester
Parent commit: b8814777
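In short, each oneDNN fuse pass touched here now registers an OpCompat entry for fused_conv2d and, once a fusion actually applies, retags the matched conv2d/depthwise_conv2d op as fused_conv2d, so downstream passes, the quantization code, and the unit tests look for fused_conv2d. A minimal, self-contained sketch of that retagging step (the OpDesc below is a toy stand-in, not Paddle's class; it only illustrates the pattern repeated across the passes in this diff):

#include <iostream>
#include <string>

// Toy stand-in for the op descriptor, just to keep the sketch compilable.
struct OpDesc {
  std::string type;
  const std::string& Type() const { return type; }
  void SetType(const std::string& t) { type = t; }
};

// The step each fuse pass performs once it decides the fusion applies:
// plain convolutions are retagged so later passes see the fused op type.
void RetagAsFusedConv(OpDesc* conv) {
  if (conv->Type() == "conv2d" || conv->Type() == "depthwise_conv2d") {
    conv->SetType("fused_conv2d");
  }
}

int main() {
  OpDesc conv{"conv2d"};
  RetagAsFusedConv(&conv);
  std::cout << conv.Type() << std::endl;  // prints "fused_conv2d"
  return 0;
}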
@@ -213,6 +213,43 @@ ConvBNFusePass::ConvBNFusePass() {
.AddAttr("data_format")
.IsStringIn({"NCHW", "NHWC", "AnyLayout"})
.End();
AddOpCompat(OpCompat("fused_conv2d"))
.AddInput("Input")
.IsTensor()
.End()
.AddInput("Filter")
.IsTensor()
.End()
.AddInput("Bias")
.IsTensor()
.IsOptional()
.End()
.AddInput("ResidualData")
.IsTensor()
.IsOptional()
.End()
.AddOutput("Output")
.IsTensor()
.End()
.AddAttr("strides")
.IsType<std::vector<int>>()
.End()
.AddAttr("paddings")
.IsType<std::vector<int>>()
.End()
.AddAttr("padding_algorithm")
.IsOptional()
.IsStringIn({"EXPLICIT", "SAME", "VALID"})
.End()
.AddAttr("groups")
.IsNumGE(1)
.End()
.AddAttr("dilations")
.IsType<std::vector<int>>()
.End()
.AddAttr("data_format")
.IsStringIn({"NCHW", "NHWC", "AnyLayout"})
.End();
AddOpCompat(OpCompat("batch_norm")) AddOpCompat(OpCompat("batch_norm"))
.AddInput("X") .AddInput("X")
...@@ -361,6 +398,10 @@ void ConvBNFusePass::ApplyImpl(ir::Graph* graph) const { ...@@ -361,6 +398,10 @@ void ConvBNFusePass::ApplyImpl(ir::Graph* graph) const {
// with MKL-DNN fuse conv+bn into conv with bias // with MKL-DNN fuse conv+bn into conv with bias
// without MKL-DNN fuse conv+bn into conv+elementwise_add // without MKL-DNN fuse conv+bn into conv+elementwise_add
if (fuse_option == FUSE_MKLDNN) { if (fuse_option == FUSE_MKLDNN) {
if (conv->Op()->Type() == "conv2d" ||
conv->Op()->Type() == "depthwise_conv2d") {
conv->Op()->SetType("fused_conv2d");
}
auto input_names = conv->Op()->InputNames();
bool has_bias =
std::find(input_names.begin(), input_names.end(), "Bias") !=
@@ -818,6 +859,43 @@ DepthwiseConvBNFusePass::DepthwiseConvBNFusePass() {
.AddAttr("data_format")
.IsStringIn({"NCHW", "NHWC", "AnyLayout"})
.End();
AddOpCompat(OpCompat("fused_conv2d"))
.AddInput("Input")
.IsTensor()
.End()
.AddInput("Filter")
.IsTensor()
.End()
.AddInput("Bias")
.IsTensor()
.IsOptional()
.End()
.AddInput("ResidualData")
.IsTensor()
.IsOptional()
.End()
.AddOutput("Output")
.IsTensor()
.End()
.AddAttr("strides")
.IsType<std::vector<int>>()
.End()
.AddAttr("paddings")
.IsType<std::vector<int>>()
.End()
.AddAttr("padding_algorithm")
.IsOptional()
.IsStringIn({"EXPLICIT", "SAME", "VALID"})
.End()
.AddAttr("groups")
.IsNumGE(1)
.End()
.AddAttr("dilations")
.IsType<std::vector<int>>()
.End()
.AddAttr("data_format")
.IsStringIn({"NCHW", "NHWC", "AnyLayout"})
.End();
}
} // namespace ir
......
@@ -1826,20 +1826,20 @@ PDNode *patterns::ConvBias::operator()(
return eltwise_out_var;
}
- PDNode *patterns::Conv::operator()() {
+ PDNode *patterns::Conv::operator()(const std::string &conv_type) {
- auto conv_op = pattern->NewNode(conv_op_repr())->assert_is_op("conv2d");
+ auto conv_op = pattern->NewNode(conv_op_repr())->assert_is_op(conv_type);
auto input_var = pattern->NewNode(conv_input_repr())
->AsInput()
- ->assert_is_op_input("conv2d", "Input");
+ ->assert_is_op_input(conv_type, "Input");
auto filter_var = pattern->NewNode(conv_filter_repr())
->AsInput()
- ->assert_is_op_input("conv2d", "Filter");
+ ->assert_is_op_input(conv_type, "Filter");
auto output_var = pattern->NewNode(conv_output_repr())
->AsOutput()
- ->assert_is_op_output("conv2d", "Output");
+ ->assert_is_op_output(conv_type, "Output");
conv_op->LinksFrom({input_var, filter_var}).LinksTo({output_var});
return output_var;
@@ -2658,10 +2658,12 @@ PDNode *patterns::ConvElementwiseadd::operator()(PDNode *conv_in) {
}
PDNode *patterns::ConvAffineChannel::operator()(
- paddle::framework::ir::PDNode *conv_input, bool with_eltwise_add) {
+ paddle::framework::ir::PDNode *conv_input,
const std::string &conv_type,
bool with_eltwise_add) {
// Create Operators
- conv_input->assert_is_op_input("conv2d", "Input");
+ conv_input->assert_is_op_input(conv_type, "Input");
- auto *conv_op = pattern->NewNode(conv_repr())->assert_is_op("conv2d");
+ auto *conv_op = pattern->NewNode(conv_repr())->assert_is_op(conv_type);
PDNode *eltwise_op = nullptr;
if (with_eltwise_add) {
@@ -2676,11 +2678,11 @@ PDNode *patterns::ConvAffineChannel::operator()(
auto *conv_weight_var = pattern->NewNode(conv_weight_repr())
->AsInput()
->assert_is_persistable_var()
- ->assert_is_op_input("conv2d", "Filter");
+ ->assert_is_op_input(conv_type, "Filter");
auto *conv_out_var = pattern->NewNode(conv_out_repr())
->AsIntermediate()
- ->assert_is_only_output_of_op("conv2d");
+ ->assert_is_only_output_of_op(conv_type);
PDNode *eltwise_y_in_var = nullptr;
PDNode *eltwise_out_var = nullptr;
......
@@ -1044,7 +1044,7 @@ struct Conv : public PatternBase {
Conv(PDPattern* pattern, const std::string& name_scope)
: PatternBase(pattern, name_scope, "convolution") {}
- PDNode* operator()();
+ PDNode* operator()(const std::string& conv_type);
PATTERN_DECL_NODE(conv_op);
PATTERN_DECL_NODE(conv_input);
@@ -1544,7 +1544,9 @@ struct ConvAffineChannel : public PatternBase {
ConvAffineChannel(PDPattern* pattern, const std::string& name_scope)
: PatternBase(pattern, name_scope, "conv_affine_channel") {}
- PDNode* operator()(PDNode* conv_input, bool with_eltwise_add);
+ PDNode* operator()(PDNode* conv_input,
const std::string& conv_type,
bool with_eltwise_add);
// declare operator node's name
PATTERN_DECL_NODE(conv);
......
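With the signature change above, each caller now picks the convolution op type when instantiating the pattern. A fragment adapted from the pass-side changes later in this diff (not standalone code; gpd, name_scope and conv_type come from the enclosing pass, and conv_type is either "conv2d" or "fused_conv2d"):

// Build the detector and instantiate the Conv pattern for the requested op
// type instead of the previously hard-coded "conv2d".
GraphPatternDetector gpd;
patterns::Conv conv_pattern{gpd.mutable_pattern(), name_scope};
auto conv_output = conv_pattern(conv_type);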
@@ -26,7 +26,7 @@ using string::PrettyLogDetail;
void ConvActivationMkldnnFusePass::ApplyImpl(Graph* graph) const {
auto act_types = phi::funcs::GetSupportedActivations();
- std::vector<std::string> conv_types = {"conv2d", "fused_conv2d"};
+ std::vector<std::string> conv_types = {"fused_conv2d", "conv2d"};
for (auto& act_type : act_types) {
FuseConvConcatAct(graph, act_type);
@@ -64,6 +64,10 @@ void ConvActivationMkldnnFusePass::FuseConvAct(Graph* graph,
OpDesc* conv_op = conv->Op();
OpDesc* act_op = activation->Op();
if (conv_op->Type() == "conv2d") {
conv_op->SetType("fused_conv2d");
}
auto attr_map = phi::funcs::GetAttributeMap(act_type);
for (const auto& attrs : attr_map) {
if (act_op->HasAttr(attrs.first)) {
@@ -91,8 +95,9 @@ void ConvActivationMkldnnFusePass::FuseConvAct(Graph* graph,
AddStatis(found_conv_activation_count);
if ((!Has("disable_logs") || !Get<bool>("disable_logs")) &&
found_conv_activation_count > 0) {
- PrettyLogDetail("--- fused %d conv with %s activation",
+ PrettyLogDetail("--- fused %d %s with %s activation",
found_conv_activation_count,
conv_type,
act_type);
}
}
@@ -134,15 +139,20 @@ void ConvActivationMkldnnFusePass::FuseConvConcatAct(
bool is_not_conv_mkldnn =
!(prev_op_nodes[0]->Op()->GetAttrIfExists<bool>("use_mkldnn"));
- if (prev_op_nodes[0]->Op()->Type() != "conv2d" || is_not_conv_mkldnn) {
- LOG(WARNING)
- << "This fuse pass supports only conv2d (mkldnn) + activation.";
+ if ((prev_op_nodes[0]->Op()->Type() != "conv2d" &&
+ prev_op_nodes[0]->Op()->Type() != "fused_conv2d") ||
+ is_not_conv_mkldnn) {
+ LOG(WARNING) << "This fuse pass supports only conv2d(mkldnn) | "
+ "fused_conv2d(mkldnn) + activation.";
return;
}
}
for (auto node : concat_inputs) {
OpDesc* conv_op = node->inputs[0]->Op();
if (conv_op->Type() == "conv2d") {
conv_op->SetType("fused_conv2d");
}
OpDesc* act_op = activation_op->Op();
auto attr_map = phi::funcs::GetAttributeMap(act_type);
......
@@ -165,7 +165,8 @@ void MainTest(std::string activation) {
int conv_activation_count = 0;
for (auto* node : graph->Nodes()) {
- if (node->IsOp() && node->Op()->Type() == "conv2d") {
+ if (node->IsOp() && (node->Op()->Type() == "conv2d" ||
+ node->Op()->Type() == "fused_conv2d")) {
auto* op = node->Op();
ASSERT_TRUE(op->HasAttr("use_mkldnn"));
EXPECT_TRUE(PADDLE_GET_CONST(bool, op->GetAttr("use_mkldnn")));
......
@@ -143,6 +143,44 @@ ConvAffineChannelFusePass::ConvAffineChannelFusePass() {
.IsStringIn({"NCHW", "AnyLayout"})
.End();
AddOpCompat(OpCompat("fused_conv2d"))
.AddInput("Input")
.IsTensor()
.End()
.AddInput("Filter")
.IsTensor()
.End()
.AddInput("Bias")
.IsTensor()
.IsOptional()
.End()
.AddInput("ResidualData")
.IsTensor()
.IsOptional()
.End()
.AddOutput("Output")
.IsTensor()
.End()
.AddAttr("strides")
.IsType<std::vector<int>>()
.End()
.AddAttr("paddings")
.IsType<std::vector<int>>()
.End()
.AddAttr("padding_algorithm")
.IsOptional()
.IsStringIn({"EXPLICIT", "SAME", "VALID"})
.End()
.AddAttr("groups")
.IsNumGE(1)
.End()
.AddAttr("dilations")
.IsType<std::vector<int>>()
.End()
.AddAttr("data_format")
.IsStringIn({"NCHW", "AnyLayout"})
.End();
AddOpCompat(OpCompat("affine_channel")) AddOpCompat(OpCompat("affine_channel"))
.AddInput("X") .AddInput("X")
.IsTensor() .IsTensor()
...@@ -177,6 +215,12 @@ ConvAffineChannelFusePass::ConvAffineChannelFusePass() { ...@@ -177,6 +215,12 @@ ConvAffineChannelFusePass::ConvAffineChannelFusePass() {
} }
void ConvAffineChannelFusePass::ApplyImpl(ir::Graph* graph) const {
FuseConvAffineChannel(graph, "conv2d");
FuseConvAffineChannel(graph, "fused_conv2d");
}
void ConvAffineChannelFusePass::FuseConvAffineChannel(
ir::Graph* graph, const std::string& conv_type) const {
PADDLE_ENFORCE_NOT_NULL(
graph, platform::errors::InvalidArgument("Graph cannot be nullptr."));
FusePassBase::Init(name_scope_, graph);
@@ -190,10 +234,10 @@ void ConvAffineChannelFusePass::ApplyImpl(ir::Graph* graph) const {
gpd.mutable_pattern()
->NewNode(patterns::PDNodeName(name_scope_, "conv_input"))
->AsInput()
- ->assert_is_op_input("conv2d", "Input");
+ ->assert_is_op_input(conv_type, "Input");
patterns::ConvAffineChannel conv_ac_pattern(gpd.mutable_pattern(),
name_scope_);
- conv_ac_pattern(conv_input, false /*with_eltwise_add*/);
+ conv_ac_pattern(conv_input, conv_type, false /*with_eltwise_add*/);
int found_conv_ac_count = 0;
auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
......
@@ -36,6 +36,8 @@ class ConvAffineChannelFusePass : public FusePassBase {
protected:
void ApplyImpl(ir::Graph*) const override;
void FuseConvAffineChannel(ir::Graph* graph,
const std::string& conv_type) const;
const std::string name_scope_{"conv_affine_channel_mkldnn_fuse"};
};
......
@@ -60,6 +60,43 @@ ResidualConnectionMKLDNNFusePass::ResidualConnectionMKLDNNFusePass() {
.AddAttr("data_format")
.IsStringIn({"NHWC", "NCHW", "AnyLayout"})
.End();
AddOpCompat(OpCompat("fused_conv2d"))
.AddInput("Input")
.IsTensor()
.End()
.AddInput("Filter")
.IsTensor()
.End()
.AddInput("Bias")
.IsTensor()
.IsOptional()
.End()
.AddInput("ResidualData")
.IsTensor()
.IsOptional()
.End()
.AddOutput("Output")
.IsTensor()
.End()
.AddAttr("strides")
.IsType<std::vector<int>>()
.End()
.AddAttr("paddings")
.IsType<std::vector<int>>()
.End()
.AddAttr("padding_algorithm")
.IsOptional()
.IsStringIn({"EXPLICIT", "SAME", "VALID"})
.End()
.AddAttr("groups")
.IsNumGE(1)
.End()
.AddAttr("dilations")
.IsType<std::vector<int>>()
.End()
.AddAttr("data_format")
.IsStringIn({"NHWC", "NCHW", "AnyLayout"})
.End();
AddOpCompat(OpCompat("elementwise_add")) AddOpCompat(OpCompat("elementwise_add"))
.AddInput("X") .AddInput("X")
...@@ -79,12 +116,13 @@ ResidualConnectionMKLDNNFusePass::ResidualConnectionMKLDNNFusePass() { ...@@ -79,12 +116,13 @@ ResidualConnectionMKLDNNFusePass::ResidualConnectionMKLDNNFusePass() {
GraphWithStats ResidualConnectionMKLDNNFusePass::FuseConv(
const std::string& name_scope,
const GraphWithStats& graph_with_stats,
const std::string& conv_type,
bool as_x) const {
GraphPatternDetector gpd;
auto pattern = gpd.mutable_pattern();
patterns::Conv conv_pattern{pattern, name_scope};
- auto conv_output = conv_pattern();
+ auto conv_output = conv_pattern(conv_type);
patterns::ResidualElementwise elementwise_pattern{pattern, name_scope, as_x};
elementwise_pattern(
@@ -127,6 +165,10 @@ GraphWithStats ResidualConnectionMKLDNNFusePass::FuseConv(
return;
}
if (conv_op->Op()->Type() == "conv2d") {
conv_op->Op()->SetType("fused_conv2d");
}
conv_op->Op()->SetInput("ResidualData", {residual_data->Name()}); conv_op->Op()->SetInput("ResidualData", {residual_data->Name()});
conv_op->Op()->SetOutput("Output", {elementwise_out->Name()}); conv_op->Op()->SetOutput("Output", {elementwise_out->Name()});
conv_op->Op()->SetAttr("fuse_residual_connection", true); conv_op->Op()->SetAttr("fuse_residual_connection", true);
@@ -155,15 +197,16 @@ GraphWithStats ResidualConnectionMKLDNNFusePass::FuseConv(
GraphWithStats ResidualConnectionMKLDNNFusePass::FuseProjectionConv(
const std::string& name_scope,
- const GraphWithStats& graph_with_stats) const {
+ const GraphWithStats& graph_with_stats,
+ const std::string& conv_type) const {
GraphPatternDetector gpd;
auto pattern = gpd.mutable_pattern();
patterns::Conv conv_x_pattern{pattern, name_scope};
- auto conv_x_output = conv_x_pattern();
+ auto conv_x_output = conv_x_pattern(conv_type);
patterns::Conv conv_y_pattern{pattern, name_scope};
- auto conv_y_output = conv_y_pattern();
+ auto conv_y_output = conv_y_pattern(conv_type);
patterns::Elementwise elementwise_pattern{pattern, name_scope};
elementwise_pattern(conv_x_output, conv_y_output, "elementwise_add");
@@ -215,6 +258,9 @@ GraphWithStats ResidualConnectionMKLDNNFusePass::FuseProjectionConv(
if (HasFusedActivation(residual_conv_op)) return;
if (residual_conv_op->Op()->Type() == "conv2d") {
residual_conv_op->Op()->SetType("fused_conv2d");
}
residual_conv_op->Op()->SetInput("ResidualData", {projection_node->Name()});
residual_conv_op->Op()->SetOutput("Output", {elementwise_out->Name()});
@@ -243,10 +289,18 @@ GraphWithStats ResidualConnectionMKLDNNFusePass::FuseProjectionConv(
void ResidualConnectionMKLDNNFusePass::ApplyImpl(ir::Graph* graph) const {
FusePassBase::Init(name_scope_, graph);
auto graph_with_stats =
- FuseProjectionConv(name_scope_, std::make_pair(graph, 0));
- graph_with_stats = FuseConv(name_scope_, graph_with_stats, true);
- graph_with_stats = FuseConv(name_scope_, graph_with_stats, false);
+ FuseProjectionConv(name_scope_, std::make_pair(graph, 0), "fused_conv2d");
+ graph_with_stats =
+ FuseConv(name_scope_, graph_with_stats, "fused_conv2d", true);
+ graph_with_stats =
+ FuseConv(name_scope_, graph_with_stats, "fused_conv2d", false);
+ graph_with_stats =
+ FuseProjectionConv(name_scope_, graph_with_stats, "conv2d");
+ graph_with_stats = FuseConv(name_scope_, graph_with_stats, "conv2d", true);
+ graph_with_stats = FuseConv(name_scope_, graph_with_stats, "conv2d", false);
AddStatis(graph_with_stats.second);
}
......
@@ -27,10 +27,11 @@ class ResidualConnectionMKLDNNFusePass : public FusePassBase {
private:
GraphWithStats FuseConv(const std::string& name_scope,
const GraphWithStats& graph_with_stats,
const std::string& conv_type,
bool as_x) const;
- GraphWithStats FuseProjectionConv(
- const std::string& name_scope,
- const GraphWithStats& graph_with_stats) const;
+ GraphWithStats FuseProjectionConv(const std::string& name_scope,
+ const GraphWithStats& graph_with_stats,
+ const std::string& conv_type) const;
public:
ResidualConnectionMKLDNNFusePass();
......
@@ -100,8 +100,8 @@ TEST(ConvElementwiseAddMKLDNNFusePass, ConvolutionAsYWithElementwiseAddRelu) {
"relu",
nodes_removed,
nodes_added));
- EXPECT_TRUE(
- test::AssertOpsCount(graph, {{"conv2d", 1}, {"elementwise_add", 0}}));
+ EXPECT_TRUE(test::AssertOpsCount(
+ graph, {{"fused_conv2d", 1}, {"elementwise_add", 0}}));
}
TEST(ConvElementwiseAddMKLDNNFusePass,
@@ -134,8 +134,8 @@ TEST(ConvElementwiseAddMKLDNNFusePass,
"relu",
nodes_removed,
nodes_added));
- EXPECT_TRUE(
- test::AssertOpsCount(graph, {{"conv2d", 2}, {"elementwise_add", 0}}));
+ EXPECT_TRUE(test::AssertOpsCount(
+ graph, {{"conv2d", 1}, {"fused_conv2d", 1}, {"elementwise_add", 0}}));
}
TEST(ConvElementwiseAddMKLDNNFusePass,
@@ -159,8 +159,8 @@ TEST(ConvElementwiseAddMKLDNNFusePass,
"relu",
nodes_removed,
nodes_added));
- EXPECT_TRUE(
- test::AssertOpsCount(graph, {{"conv2d", 1}, {"elementwise_add", 0}}));
+ EXPECT_TRUE(test::AssertOpsCount(
+ graph, {{"fused_conv2d", 1}, {"elementwise_add", 0}}));
}
TEST(ConvElementwiseAddMKLDNNFusePass, ConvolutionAsXWithElementwiseAddRelu) {
@@ -185,8 +185,8 @@ TEST(ConvElementwiseAddMKLDNNFusePass, ConvolutionAsXWithElementwiseAddRelu) {
"relu",
nodes_removed,
nodes_added));
- EXPECT_TRUE(
- test::AssertOpsCount(graph, {{"conv2d", 1}, {"elementwise_add", 0}}));
+ EXPECT_TRUE(test::AssertOpsCount(
+ graph, {{"fused_conv2d", 1}, {"elementwise_add", 0}}));
}
TEST(ConvElementwiseAddMKLDNNFusePass,
@@ -210,8 +210,8 @@ TEST(ConvElementwiseAddMKLDNNFusePass,
"relu",
nodes_removed,
nodes_added));
- EXPECT_TRUE(
- test::AssertOpsCount(graph, {{"conv2d", 1}, {"elementwise_add", 0}}));
+ EXPECT_TRUE(test::AssertOpsCount(
+ graph, {{"fused_conv2d", 1}, {"elementwise_add", 0}}));
}
TEST(ConvElementwiseAddMKLDNNFusePass, NoFusion) {
......
@@ -452,6 +452,9 @@ void CPUQuantizePass::QuantizeConv(Graph* graph,
Graph* g) {
VLOG(4) << "Quantize conv2d op";
GET_IR_NODE_FROM_SUBGRAPH(conv_op, conv_op, conv_pattern);
if (conv_op->Op()->Type() == "conv2d") {
conv_op->Op()->SetType("fused_conv2d");
}
// skip if should not be quantized
if (!platform::HasOpINT8DataType(conv_op->Op())) {
@@ -1299,10 +1302,10 @@ void CPUQuantizePass::ApplyImpl(ir::Graph* graph) const {
platform::errors::InvalidArgument("Scope cannot be nullptr."));
GetQuantInfo(graph);
- QuantizeConv(graph, "conv2d", false /* with_residual_data */);
- QuantizeConv(graph, "conv2d", true /* with_residual_data */);
QuantizeConv(graph, "fused_conv2d", false /* with_residual_data */);
QuantizeConv(graph, "fused_conv2d", true /* with_residual_data */);
+ QuantizeConv(graph, "conv2d", false /* with_residual_data */);
+ QuantizeConv(graph, "conv2d", true /* with_residual_data */);
QuantizePool(graph);
QuantizeConcat(graph);
QuantizePriorBox(graph);
......
@@ -174,7 +174,7 @@ void PreparePass(std::unique_ptr<ir::Graph>* graph,
void CheckScales(const OpDesc* op, float scale, float shift) {
std::string type = op->Type();
std::vector<std::string> scale_names;
- if (type == "conv2d" || type == "fc") {
+ if (type == "conv2d" || type == "fused_conv2d" || type == "fc") {
EXPECT_EQ(op->GetAttrIfExists<std::vector<float>>("Scale_weights")[0],
scale);
scale_names.push_back("Scale_in");
@@ -330,7 +330,7 @@ TEST(CpuQuantizePass, quantize) {
// Insert nodes: 8 Quant + 8 IN + 7 OUT + 7 DEQUANT
int added_nodes = 8 + 8 + 7 + 7;
std::unordered_map<std::string, int> expected_operators = {
- {"conv2d", 4}, {"pool2d", 2}, {"quantize", 8}, {"dequantize", 7}};
+ {"fused_conv2d", 4}, {"pool2d", 2}, {"quantize", 8}, {"dequantize", 7}};
MainTest(BuildProgramDesc(use_mkldnn, mkldnn_data_type),
variable_names,
expected_operators,
@@ -343,7 +343,7 @@ TEST(CpuQuantizePass, do_not_quantize) {
std::string mkldnn_data_type = "float32";
int added_nodes = 0;
std::unordered_map<std::string, int> expected_operators = {
- {"conv2d", 4}, {"pool2d", 2}, {"quantize", 0}, {"dequantize", 0}};
+ {"fused_conv2d", 4}, {"pool2d", 2}, {"quantize", 0}, {"dequantize", 0}};
MainTest(BuildProgramDesc(use_mkldnn, mkldnn_data_type),
variable_names,
expected_operators,
......
@@ -341,6 +341,7 @@ void CPUQuantizeSquashPass::OpDequantSquash(Graph* graph) const {
if (dequant_in->outputs.size() == 1) {
if (any_op->Op()->Type() == "conv2d" ||
any_op->Op()->Type() == "fused_conv2d" ||
any_op->Op()->Type() == "conv2d_transpose" ||
any_op->Op()->Type() == "fc") {
// do not squash if fuse residual connection is true
......
@@ -60,9 +60,52 @@ Int8ScaleCalculationMkldnnPass::Int8ScaleCalculationMkldnnPass() {
.AddAttr("data_format")
.IsStringIn({"NCHW", "AnyLayout"})
.End();
AddOpCompat(OpCompat("fused_conv2d"))
.AddInput("Input")
.IsTensor()
.End()
.AddInput("Filter")
.IsTensor()
.End()
.AddInput("Bias")
.IsTensor()
.IsOptional()
.End()
.AddInput("ResidualData")
.IsTensor()
.IsOptional()
.End()
.AddOutput("Output")
.IsTensor()
.End()
.AddAttr("strides")
.IsType<std::vector<int>>()
.End()
.AddAttr("paddings")
.IsType<std::vector<int>>()
.End()
.AddAttr("padding_algorithm")
.IsOptional()
.IsStringIn({"EXPLICIT", "SAME", "VALID"})
.End()
.AddAttr("groups")
.IsNumGE(1)
.End()
.AddAttr("dilations")
.IsType<std::vector<int>>()
.End()
.AddAttr("data_format")
.IsStringIn({"NCHW", "AnyLayout"})
.End();
}
void Int8ScaleCalculationMkldnnPass::ApplyImpl(ir::Graph* graph) const {
Int8ScaleImpl(graph, "fused_conv2d");
Int8ScaleImpl(graph, "conv2d");
}
void Int8ScaleCalculationMkldnnPass::Int8ScaleImpl(
ir::Graph* graph, const std::string& conv_type) const {
PADDLE_ENFORCE_NOT_NULL(graph,
platform::errors::InvalidArgument(
"Pointer to graph argument should not be NULL."));
@@ -70,7 +113,7 @@ void Int8ScaleCalculationMkldnnPass::ApplyImpl(ir::Graph* graph) const {
GraphPatternDetector gpd;
patterns::Conv conv_pattern(gpd.mutable_pattern(),
"int8_scale_calculation_mkldnn_pass");
- conv_pattern();
+ conv_pattern(conv_type);
int found_int8_scales_count = 0;
auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
@@ -80,6 +123,9 @@ void Int8ScaleCalculationMkldnnPass::ApplyImpl(ir::Graph* graph) const {
return;
}
GET_IR_NODE_FROM_SUBGRAPH(conv_op, conv_op, conv_pattern);
if (conv_op->Op()->Type() == "conv2d") {
conv_op->Op()->SetType("fused_conv2d");
}
if (!platform::HasOpINT8DataType(conv_op->Op()) ||
conv_op->Op()->HasAttr("Sum_scale")) {
......
@@ -31,6 +31,7 @@ class Int8ScaleCalculationMkldnnPass : public FusePassBase {
protected:
void ApplyImpl(ir::Graph* graph) const override;
void Int8ScaleImpl(ir::Graph* graph, const std::string& conv_type) const;
};
} // namespace ir
......
@@ -77,7 +77,7 @@ void QuantizeConvInput(Scope* scope,
} // namespace
ParamsQuantizationMkldnnPass::ParamsQuantizationMkldnnPass() {
- AddOpCompat(OpCompat("conv2d"))
+ AddOpCompat(OpCompat("fused_conv2d"))
.AddInput("Input")
.IsTensor()
.End()
@@ -117,10 +117,11 @@ ParamsQuantizationMkldnnPass::ParamsQuantizationMkldnnPass() {
}
void ParamsQuantizationMkldnnPass::QuantizeConv(ir::Graph* graph,
const std::string& conv_type,
bool with_residual_data) const {
GraphPatternDetector gpd;
patterns::ConvResidual conv_pattern(gpd.mutable_pattern(), name_scope_);
- conv_pattern("conv2d", with_residual_data);
+ conv_pattern(conv_type, with_residual_data);
int params_to_int8_conv_found = 0;
@@ -159,8 +160,8 @@ void ParamsQuantizationMkldnnPass::QuantizeConv(ir::Graph* graph,
AddStatis(params_to_int8_conv_found);
std::stringstream msg_ss;
- msg_ss << "Quantized params of " << params_to_int8_conv_found
- << " conv2d ops";
+ msg_ss << "Quantized params of " << params_to_int8_conv_found << " "
+ << conv_type << " ops";
if (with_residual_data) msg_ss << " with residual connection";
paddle::string::PrettyLogDetail(msg_ss.str().c_str());
}
@@ -170,8 +171,8 @@ void ParamsQuantizationMkldnnPass::ApplyImpl(ir::Graph* graph) const {
platform::errors::InvalidArgument(
"Pointer to graph argument should not be NULL."));
FusePassBase::Init(name_scope_, graph);
- QuantizeConv(graph, true /*with_residual_data*/);
- QuantizeConv(graph, false /*with_residual_data*/);
+ QuantizeConv(graph, "fused_conv2d", true /*with_residual_data*/);
+ QuantizeConv(graph, "fused_conv2d", false /*with_residual_data*/);
}
} // namespace ir
......
@@ -32,7 +32,9 @@ class ParamsQuantizationMkldnnPass : public FusePassBase {
protected:
void ApplyImpl(ir::Graph* graph) const override;
- void QuantizeConv(Graph* graph, bool with_residual_connection) const;
+ void QuantizeConv(Graph* graph,
const std::string& conv_type,
bool with_residual_connection) const;
private:
const std::string name_scope_ = "params_quantization_mkldnn_pass";
......
@@ -141,7 +141,7 @@ struct ConvProgramStrategy : public ProgramStrategy {
protected:
OpDesc* CreateBasicConvOp(const std::string conv_name = "Conv1") {
auto op = program.MutableBlock(0)->AppendOp();
- op->SetType("conv2d");
+ op->SetType("fused_conv2d");
op->SetAttr("use_mkldnn", true);
op->SetAttr("name", conv_name);
op->SetAttr("mkldnn_data_type", std::string{"int8"});
......
@@ -648,7 +648,8 @@ void QuantDequantMkldnnPass::DequantizeWeights(
for (auto* op_node :
ir::TopologyVarientSort(*graph, static_cast<ir::SortKind>(0))) {
if (!op_node->IsOp()) continue;
- if (op_node->Name() == "conv2d" || op_node->Name() == "depthwise_conv2d") {
+ if (op_node->Name() == "conv2d" || op_node->Name() == "depthwise_conv2d" ||
+ op_node->Name() == "fused_conv2d") {
if (onnx_format_quantize_model) {
DequantizeOpWeightsFromONNXFormat(op_node,
scope,
@@ -708,8 +709,12 @@ void QuantDequantMkldnnPass::ApplyImpl(ir::Graph* graph) const {
const std::string pattern_name = "quant_dequant_mkldnn_pass";
FusePassBase::Init(pattern_name, graph);
- const std::unordered_set<std::string> skip_ops = {
- "conv2d", "depthwise_conv2d", "mul", "matmul", "matmul_v2"};
+ const std::unordered_set<std::string> skip_ops = {"conv2d",
+ "depthwise_conv2d",
+ "fused_conv2d",
+ "mul",
+ "matmul",
+ "matmul_v2"};
const std::unordered_set<std::string> fake_quantize_types = {
"fake_quantize_moving_average_abs_max", "fake_quantize_range_abs_max"};
......
@@ -1023,10 +1023,10 @@ void OpDesc::Flush() {
[](std::pair<std::string, Attribute> a,
std::pair<std::string, Attribute> b) { return a.first < b.first; });
- for (auto &attr : sorted_attrs) {
+ for (auto &attr : sorted_runtime_attrs) {
set_attr_desc(attr.first, attr.second);
}
- for (auto &attr : sorted_runtime_attrs) {
+ for (auto &attr : sorted_attrs) {
set_attr_desc(attr.first, attr.second);
}
......
@@ -129,6 +129,7 @@ void IRPassManager::CreatePasses(Argument *argument,
new std::unordered_set<int>(argument->quantize_excluded_op_ids()));
} else if (pass_name == "cpu_quantize_pass") {
if (argument->quantize_enabled_op_types().count("conv2d") ||
argument->quantize_enabled_op_types().count("fused_conv2d") ||
argument->quantize_enabled_op_types().count("depthwise_conv2d")) {
pass->Set("data_layout", new std::string("NHWC"));
}
......
@@ -121,7 +121,8 @@ void AnalysisPredictor::MkldnnQuantizer::CalculateScalesForOpOutputs(
// force unsigned type if already know it
bool is_unsigned = false;
bool compute_scale = true;
- if (op->Type() == "conv2d" || op->Type() == "fc") {
+ if (op->Type() == "conv2d" || op->Type() == "fused_conv2d" ||
+ op->Type() == "fc") {
// output of conv2d with relu must be unsigned
std::string fuse_activation =
op->GetAttrIfExists<std::string>("fuse_activation");
......
@@ -39,14 +39,44 @@ def {
name: "data_format"
type: STRING
}
}
extra {
attrs {
name: "fuse_activation"
type: STRING
}
attrs {
name: "fuse_alpha"
type: FLOAT
}
attrs {
name: "fuse_beta"
type: FLOAT
}
attrs {
name: "fuse_residual_connection"
type: BOOLEAN
}
attrs {
name: "Scale_in"
type: FLOAT
}
attrs {
name: "Scale_out"
type: FLOAT
}
attrs {
name: "Scale_in_eltwise"
type: FLOAT
}
attrs {
name: "Scale_weights"
type: FLOATS
}
attrs {
name: "Bias_scales"
type: FLOATS
}
attrs {
name: "force_fp32_output"
type: BOOLEAN
......
@@ -168,7 +168,7 @@ class TestConvBnFusePass(PassAutoScanTest):
if program_config.ops[0].attrs['use_mkldnn']:
config = self.create_inference_config()
config.enable_mkldnn()
- yield config, ['conv2d'], (1e-5, 1e-5)
+ yield config, ['fused_conv2d'], (1e-5, 1e-5)
else:
config = self.create_inference_config()
yield config, ['conv2d', 'elementwise_add'], (1e-5, 1e-5)
......
@@ -158,7 +158,7 @@ class TestConvConcatActivationMkldnnFusePass(PassAutoScanTest):
def sample_predictor_configs(self, program_config):
config = self.create_inference_config(use_mkldnn=True)
- yield config, ['conv2d', 'conv2d', 'concat'], (1e-5, 1e-5)
+ yield config, ['fused_conv2d', 'fused_conv2d', 'concat'], (1e-5, 1e-5)
def test(self):
self.run_and_statis(
......
@@ -118,7 +118,7 @@ class TestConvElementwiseAddMkldnnFusePass(PassAutoScanTest):
def sample_predictor_configs(self, program_config):
config = self.create_inference_config(use_mkldnn=True)
- yield config, ["relu", "conv2d", "conv2d"], (1e-5, 1e-5)
+ yield config, ["relu", "conv2d", "fused_conv2d"], (1e-5, 1e-5)
def test(self):
self.run_and_statis(
......
@@ -96,7 +96,7 @@ class TestConvGeluMkldnnFusePass(PassAutoScanTest):
def sample_predictor_configs(self, program_config):
config = self.create_inference_config(use_mkldnn=True)
- yield config, ["conv2d"], (1e-5, 1e-5)
+ yield config, ["fused_conv2d"], (1e-5, 1e-5)
def test(self):
self.run_and_statis(
......
@@ -93,7 +93,7 @@ class TestConvHardSigmoidMkldnnFusePass(PassAutoScanTest):
def sample_predictor_configs(self, program_config):
config = self.create_inference_config(use_mkldnn=True)
- yield config, ["conv2d"], (1e-5, 1e-5)
+ yield config, ["fused_conv2d"], (1e-5, 1e-5)
def test(self):
self.run_and_statis(
......
@@ -98,7 +98,7 @@ class TestConvHardSwishMkldnnFusePass(PassAutoScanTest):
def sample_predictor_configs(self, program_config):
config = self.create_inference_config(use_mkldnn=True)
- yield config, ["conv2d"], (1e-5, 1e-5)
+ yield config, ["fused_conv2d"], (1e-5, 1e-5)
def test(self):
self.run_and_statis(
......
@@ -97,7 +97,7 @@ class TestConvMishMkldnnFusePass(PassAutoScanTest):
def sample_predictor_configs(self, program_config):
config = self.create_inference_config(use_mkldnn=True)
- yield config, ["conv2d"], (1e-5, 1e-5)
+ yield config, ["fused_conv2d"], (1e-5, 1e-5)
def test(self):
self.run_and_statis(
......