未验证 提交 53690719 编写于 作者: J jakpiase 提交者: GitHub

[Need review] Added conv + hard_sigmoid oneDNN fuse pass (#36869)

* added conv + hard_sigmoid fuse pass

* Removed IsOptional() statements

* Reverted removing optional
上级 703487c6
......@@ -88,6 +88,13 @@ void ConvActivationFusePass::ApplyImpl(ir::Graph* graph) const {
desc->SetAttr("fuse_beta",
activation->Op()->GetAttrIfExists<float>("beta"));
if (activation_type() == "hard_sigmoid") {
desc->SetAttr("fuse_alpha",
activation->Op()->GetAttrIfExists<float>("slope"));
desc->SetAttr("fuse_beta",
activation->Op()->GetAttrIfExists<float>("offset"));
}
GraphSafeRemoveNodes(graph, {activation, conv_out});
PADDLE_ENFORCE_GT(subgraph.count(conv_input), 0UL,
......@@ -213,6 +220,26 @@ Conv2DHardSwishFusePass::Conv2DHardSwishFusePass() {
.End();
}
Conv2DHardSigmoidFusePass::Conv2DHardSigmoidFusePass() {
AddOpCompat(OpCompat("hard_sigmoid"))
.AddInput("X")
.IsTensor()
.End()
.AddOutput("Out")
.IsTensor()
.End()
// optional, default=0.2
.AddAttr("slope")
.IsOptional()
.IsType<float>()
.End()
// optional, default=0.5
.AddAttr("offset")
.IsOptional()
.IsType<float>()
.End();
}
} // namespace ir
} // namespace framework
} // namespace paddle
......@@ -259,3 +286,11 @@ REGISTER_PASS_CAPABILITY(conv_hard_swish_mkldnn_fuse_pass)
paddle::framework::compatible::OpVersionComparatorCombination()
.LE("conv2d", 1)
.EQ("hard_swish", 0));
REGISTER_PASS(conv_hard_sigmoid_mkldnn_fuse_pass,
paddle::framework::ir::Conv2DHardSigmoidFusePass);
REGISTER_PASS_CAPABILITY(conv_hard_sigmoid_mkldnn_fuse_pass)
.AddCombination(
paddle::framework::compatible::OpVersionComparatorCombination()
.LE("conv2d", 1)
.EQ("hard_sigmoid", 0));
......@@ -72,6 +72,15 @@ class Conv2DHardSwishFusePass : public ConvActivationFusePass {
Conv2DHardSwishFusePass();
std::string activation_type() const { return "hard_swish"; }
};
/*
* Fuse Conv and HardSigmoid class
*/
class Conv2DHardSigmoidFusePass : public ConvActivationFusePass {
public:
Conv2DHardSigmoidFusePass();
std::string activation_type() const { return "hard_sigmoid"; }
};
} // namespace ir
} // namespace framework
} // namespace paddle
......@@ -148,6 +148,9 @@ TEST(ConvActivationFusePass, conv_swish_fuse_pass) { MainTest("swish"); }
TEST(ConvActivationFusePass, conv_hard_swish_fuse_pass) {
MainTest("hard_swish");
}
TEST(ConvActivationFusePass, conv_hard_sigmoid_fuse_pass) {
MainTest("hard_sigmoid");
}
} // namespace ir
} // namespace framework
......
......@@ -249,6 +249,7 @@ void CpuPassStrategy::EnableMKLDNN() {
"conv_relu6_mkldnn_fuse_pass", //
"conv_swish_mkldnn_fuse_pass", //
"conv_hard_swish_mkldnn_fuse_pass", //
"conv_hard_sigmoid_mkldnn_fuse_pass", //
"scale_matmul_fuse_pass", //
"reshape_transpose_matmul_mkldnn_fuse_pass", //
"matmul_transpose_reshape_fuse_pass", //
......
type: "hard_sigmoid"
def {
inputs {
name: "X"
}
outputs {
name: "Out"
}
attrs {
name: "slope"
type: FLOAT
}
attrs {
name: "offset"
type: FLOAT
}
}
......@@ -475,23 +475,25 @@ class ConvMKLDNNHandlerT
}
// Fusion with ReLU layer is executed through the PostOps feature. Create a
// PostOps object and configure it to execute an eltwise relu operation.
constexpr float scale = 1.0f;
if (fuse_activation == "relu" || fuse_activation == "leaky_relu") {
constexpr float scale = 1.0f;
post_operations.append_eltwise(scale, mkldnn::algorithm::eltwise_relu,
fuse_alpha, fuse_beta);
} else if (fuse_activation == "relu6") {
constexpr float scale = 1.0f;
post_operations.append_eltwise(scale,
mkldnn::algorithm::eltwise_bounded_relu,
fuse_alpha, fuse_beta);
} else if (fuse_activation == "swish") {
constexpr float scale = 1.0f;
post_operations.append_eltwise(scale, mkldnn::algorithm::eltwise_swish,
fuse_alpha, fuse_beta);
} else if (fuse_activation == "hard_swish") {
constexpr float scale = 1.0f;
post_operations.append_eltwise(
scale, mkldnn::algorithm::eltwise_hardswish, fuse_alpha, fuse_beta);
} else if (fuse_activation == "hard_sigmoid") {
post_operations.append_eltwise(scale, mkldnn::algorithm::eltwise_linear,
fuse_alpha, fuse_beta);
post_operations.append_eltwise(scale, mkldnn::algorithm::eltwise_clip,
0.0f, 1.0f);
}
conv_attr.set_post_ops(post_operations);
return conv_attr;
......
......@@ -102,5 +102,14 @@ class ConvActivationMkldnnFusePassTest_5(ConvActivationMkldnnFusePassTest):
self.pass_name = 'conv_hard_swish_mkldnn_fuse_pass'
class ConvHardSigmoidOneDNNFusePassTest(ConvActivationMkldnnFusePassTest):
def set_params(self):
self.conv_num_filters = 5
self.conv_filter_size = 5
self.conv_bias_attr = True
self.act = "hard_sigmoid"
self.pass_name = 'conv_hard_sigmoid_mkldnn_fuse_pass'
if __name__ == "__main__":
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册