未验证 提交 4c269ccb 编写于 作者: X xiaoxiaohehe001 提交者: GitHub

[Paddle Inference] Add conv_elementwise_act. (#43871)

* conv_fusion
上级 24d07b73
...@@ -105,6 +105,22 @@ ConvElementwiseAdd2ActFusePass::ConvElementwiseAdd2ActFusePass() { ...@@ -105,6 +105,22 @@ ConvElementwiseAdd2ActFusePass::ConvElementwiseAdd2ActFusePass() {
.AddOutput("Out") .AddOutput("Out")
.IsTensor() .IsTensor()
.End(); .End();
AddOpCompat(OpCompat("sigmoid"))
.AddInput("X")
.IsTensor()
.End()
.AddOutput("Out")
.IsTensor()
.End();
AddOpCompat(OpCompat("tanh"))
.AddInput("X")
.IsTensor()
.End()
.AddOutput("Out")
.IsTensor()
.End();
} }
void ConvElementwiseAdd2ActFusePass::ApplyImpl(ir::Graph* graph) const { void ConvElementwiseAdd2ActFusePass::ApplyImpl(ir::Graph* graph) const {
...@@ -188,4 +204,6 @@ REGISTER_PASS_CAPABILITY(conv_elementwise_add2_act_fuse_pass) ...@@ -188,4 +204,6 @@ REGISTER_PASS_CAPABILITY(conv_elementwise_add2_act_fuse_pass)
.LE("conv2d", 1) .LE("conv2d", 1)
.LE("elementwise_add", 1) .LE("elementwise_add", 1)
.EQ("relu", 0) .EQ("relu", 0)
.EQ("sigmoid", 0)
.EQ("tanh", 0)
.EQ("identity", 0)); .EQ("identity", 0));
...@@ -102,6 +102,22 @@ ConvElementwiseAddActFusePass::ConvElementwiseAddActFusePass() { ...@@ -102,6 +102,22 @@ ConvElementwiseAddActFusePass::ConvElementwiseAddActFusePass() {
.AddOutput("Out") .AddOutput("Out")
.IsTensor() .IsTensor()
.End(); .End();
AddOpCompat(OpCompat("sigmoid"))
.AddInput("X")
.IsTensor()
.End()
.AddOutput("Out")
.IsTensor()
.End();
AddOpCompat(OpCompat("tanh"))
.AddInput("X")
.IsTensor()
.End()
.AddOutput("Out")
.IsTensor()
.End();
} }
void ConvElementwiseAddActFusePass::ApplyImpl(ir::Graph* graph) const { void ConvElementwiseAddActFusePass::ApplyImpl(ir::Graph* graph) const {
...@@ -170,4 +186,6 @@ REGISTER_PASS_CAPABILITY(conv_elementwise_add_act_fuse_pass) ...@@ -170,4 +186,6 @@ REGISTER_PASS_CAPABILITY(conv_elementwise_add_act_fuse_pass)
.LE("conv2d", 1) .LE("conv2d", 1)
.LE("elementwise_add", 1) .LE("elementwise_add", 1)
.EQ("relu", 0) .EQ("relu", 0)
.EQ("sigmoid", 0)
.EQ("tanh", 0)
.EQ("identity", 0)); .EQ("identity", 0));
...@@ -2324,7 +2324,8 @@ PDNode *patterns::PriorBox::operator()() { ...@@ -2324,7 +2324,8 @@ PDNode *patterns::PriorBox::operator()() {
return boxes_var; return boxes_var;
} }
std::unordered_set<std::string> conv_act_set({"identity", "relu"}); std::unordered_set<std::string> conv_act_set(
{"identity", "relu", "sigmoid", "tanh"});
PDNode *patterns::ConvElementwiseaddAct::operator()(PDNode *conv_in) { PDNode *patterns::ConvElementwiseaddAct::operator()(PDNode *conv_in) {
conv_in->AsInput(); conv_in->AsInput();
......
...@@ -544,9 +544,11 @@ class CUDNNConvFusionOpKernel : public framework::OpKernel<T> { ...@@ -544,9 +544,11 @@ class CUDNNConvFusionOpKernel : public framework::OpKernel<T> {
namespace ops = paddle::operators; namespace ops = paddle::operators;
#if CUDNN_VERSION >= 7100 #if CUDNN_VERSION >= 7100
REGISTER_OP_CUDA_KERNEL(conv2d_fusion, REGISTER_OP_CUDA_KERNEL(
ops::CUDNNConvFusionOpKernel<float>, conv2d_fusion,
ops::CUDNNConvFusionOpKernel<double>); ops::CUDNNConvFusionOpKernel<float>,
ops::CUDNNConvFusionOpKernel<double>,
ops::CUDNNConvFusionOpKernel<paddle::platform::float16>);
#endif #endif
#ifdef PADDLE_WITH_HIP #ifdef PADDLE_WITH_HIP
REGISTER_OP_CUDA_KERNEL(conv2d_fusion, ops::CUDNNConvFusionOpKernel<float>); REGISTER_OP_CUDA_KERNEL(conv2d_fusion, ops::CUDNNConvFusionOpKernel<float>);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册