未验证 提交 5cf3f898 编写于 作者: Y Yuang Liu 提交者: GitHub

[cuda graph] bug fix for cuda graph static mode (#43539)

上级 890c7315
......@@ -173,7 +173,7 @@ def construct_program_and_find_ins_outs(section, origin_program, section_idx):
# This in var is generated from op outside this section
# Only record once for same input
ins.append(in_name)
elif later_ins.count(in_name) == 0:
elif later_ins.count(in_name) == 0 and outs.count(in_name) > 0:
# this is var is generated from op inside this section, and only will be used inside this section
outs.remove(in_name)
for out_name in op.output_arg_names:
......@@ -248,13 +248,13 @@ def get_cuda_graph_sections(program):
sub_block_related = (op.type == 'conditional_block'
or op.type == 'while')
if loss_related or sub_block_related:
# if loss_related is True
# If loss_related is True
# The internal section contains loss related ops,
# although these ops are between two cuda graph sections with same graph id,
# they belong to none of these two sections.
# The loss related op should be wrapped by user explicitly.
# if sub_block_related is True
# If sub_block_related is True
# The internal section contains while op or conditional block op.
# These two ops are not supported by cuda graph. Won't extend the section.
internal_section = []
......@@ -274,6 +274,7 @@ def get_cuda_graph_sections(program):
current_section.append(internal_section[i])
current_idx.append(internal_idx[i])
internal_section = []
internal_idx = []
current_section.append(op)
current_idx.append(idx)
else:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册