未验证 提交 320958eb 编写于 作者: G gem5 提交者: GitHub

depthwise_conv 映射成 conv的逻辑中添加下cudnn版本的判断 (#50058)

上级 9737525a
......@@ -53,8 +53,13 @@ void MapOp2AnotherPass::ApplyImpl(ir::Graph* graph) const {
op_desc->SetAttr("shape", std::vector<int>{0, -1});
}
} else if (op_type == "depthwise_conv2d") {
op_desc->SetType(replaced_map[op_type]);
op_desc->SetAttr("use_cudnn", true);
auto groups = PADDLE_GET_CONST(int, op_desc->GetAttr("groups"));
if (groups > 1) {
#if CUDNN_VERSION >= 8100
op_desc->SetType(replaced_map[op_type]);
op_desc->SetAttr("use_cudnn", true);
#endif
}
}
op_desc->Flush();
++found_count;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册