diff --git a/lite/core/mir/fusion/conv_activation_fuse_pass.cc b/lite/core/mir/fusion/conv_activation_fuse_pass.cc index 613482eb6cd73edad8e10581a1a81ef260ae3c47..ff064fb2ee93fc540e932da36fb07bb78eef989a 100644 --- a/lite/core/mir/fusion/conv_activation_fuse_pass.cc +++ b/lite/core/mir/fusion/conv_activation_fuse_pass.cc @@ -23,8 +23,15 @@ namespace lite { namespace mir { void ConvActivationFusePass::Apply(const std::unique_ptr& graph) { + std::vector act_types{"relu"}; + for (auto& place : graph->valid_places()) { + if (place.target == TARGET(kCUDA)) { + act_types.push_back("leaky_relu"); + break; + } + } for (auto conv_type : {"conv2d", "depthwise_conv2d"}) { - for (auto act_type : {"relu", "leaky_relu"}) { + for (auto act_type : act_types) { for (auto has_bias : {true, false}) { fusion::ConvActivationFuser fuser(conv_type, act_type, has_bias); fuser(graph.get());