未验证 提交 756140a8 编写于 作者: X Xiaoyang LI 提交者: GitHub

turn off conv_leaky_relu fusion when target is not cuda

上级 b0fdeba0
...@@ -23,8 +23,15 @@ namespace lite { ...@@ -23,8 +23,15 @@ namespace lite {
namespace mir { namespace mir {
void ConvActivationFusePass::Apply(const std::unique_ptr<SSAGraph>& graph) { void ConvActivationFusePass::Apply(const std::unique_ptr<SSAGraph>& graph) {
std::vector<std::string> act_types{"relu"};
for (auto& place : graph->valid_places()) {
if (place.target == TARGET(kCUDA)) {
act_types.push_back("leaky_relu");
break;
}
}
for (auto conv_type : {"conv2d", "depthwise_conv2d"}) { for (auto conv_type : {"conv2d", "depthwise_conv2d"}) {
for (auto act_type : {"relu", "leaky_relu"}) { for (auto act_type : act_types) {
for (auto has_bias : {true, false}) { for (auto has_bias : {true, false}) {
fusion::ConvActivationFuser fuser(conv_type, act_type, has_bias); fusion::ConvActivationFuser fuser(conv_type, act_type, has_bias);
fuser(graph.get()); fuser(graph.get());
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册