未验证 提交 4c269ccb 编写于 作者: X xiaoxiaohehe001 提交者: GitHub

[Paddle Inference] Add conv_elementwise_act. (#43871)

* conv_fusion
上级 24d07b73
......@@ -105,6 +105,22 @@ ConvElementwiseAdd2ActFusePass::ConvElementwiseAdd2ActFusePass() {
.AddOutput("Out")
.IsTensor()
.End();
AddOpCompat(OpCompat("sigmoid"))
.AddInput("X")
.IsTensor()
.End()
.AddOutput("Out")
.IsTensor()
.End();
AddOpCompat(OpCompat("tanh"))
.AddInput("X")
.IsTensor()
.End()
.AddOutput("Out")
.IsTensor()
.End();
}
void ConvElementwiseAdd2ActFusePass::ApplyImpl(ir::Graph* graph) const {
......@@ -188,4 +204,6 @@ REGISTER_PASS_CAPABILITY(conv_elementwise_add2_act_fuse_pass)
.LE("conv2d", 1)
.LE("elementwise_add", 1)
.EQ("relu", 0)
.EQ("sigmoid", 0)
.EQ("tanh", 0)
.EQ("identity", 0));
......@@ -102,6 +102,22 @@ ConvElementwiseAddActFusePass::ConvElementwiseAddActFusePass() {
.AddOutput("Out")
.IsTensor()
.End();
AddOpCompat(OpCompat("sigmoid"))
.AddInput("X")
.IsTensor()
.End()
.AddOutput("Out")
.IsTensor()
.End();
AddOpCompat(OpCompat("tanh"))
.AddInput("X")
.IsTensor()
.End()
.AddOutput("Out")
.IsTensor()
.End();
}
void ConvElementwiseAddActFusePass::ApplyImpl(ir::Graph* graph) const {
......@@ -170,4 +186,6 @@ REGISTER_PASS_CAPABILITY(conv_elementwise_add_act_fuse_pass)
.LE("conv2d", 1)
.LE("elementwise_add", 1)
.EQ("relu", 0)
.EQ("sigmoid", 0)
.EQ("tanh", 0)
.EQ("identity", 0));
......@@ -2324,7 +2324,8 @@ PDNode *patterns::PriorBox::operator()() {
return boxes_var;
}
std::unordered_set<std::string> conv_act_set({"identity", "relu"});
std::unordered_set<std::string> conv_act_set(
{"identity", "relu", "sigmoid", "tanh"});
PDNode *patterns::ConvElementwiseaddAct::operator()(PDNode *conv_in) {
conv_in->AsInput();
......
......@@ -544,9 +544,11 @@ class CUDNNConvFusionOpKernel : public framework::OpKernel<T> {
namespace ops = paddle::operators;
#if CUDNN_VERSION >= 7100
REGISTER_OP_CUDA_KERNEL(conv2d_fusion,
REGISTER_OP_CUDA_KERNEL(
conv2d_fusion,
ops::CUDNNConvFusionOpKernel<float>,
ops::CUDNNConvFusionOpKernel<double>);
ops::CUDNNConvFusionOpKernel<double>,
ops::CUDNNConvFusionOpKernel<paddle::platform::float16>);
#endif
#ifdef PADDLE_WITH_HIP
REGISTER_OP_CUDA_KERNEL(conv2d_fusion, ops::CUDNNConvFusionOpKernel<float>);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册