diff --git a/python/paddle/device/cuda/graphs.py b/python/paddle/device/cuda/graphs.py index 5c9c8740d85bbd1c2c2397649f4fa320f4d856d3..76cac9082a325374132da57c64d189176ea2a81e 100644 --- a/python/paddle/device/cuda/graphs.py +++ b/python/paddle/device/cuda/graphs.py @@ -173,7 +173,7 @@ def construct_program_and_find_ins_outs(section, origin_program, section_idx): # This in var is generated from op outside this section # Only record once for same input ins.append(in_name) - elif later_ins.count(in_name) == 0: + elif later_ins.count(in_name) == 0 and outs.count(in_name) > 0: # this is var is generated from op inside this section, and only will be used inside this section outs.remove(in_name) for out_name in op.output_arg_names: @@ -248,13 +248,13 @@ def get_cuda_graph_sections(program): sub_block_related = (op.type == 'conditional_block' or op.type == 'while') if loss_related or sub_block_related: - # if loss_related is True + # If loss_related is True # The internal section contains loss related ops, # although these ops are between two cuda graph sections with same graph id, # they belong to none of these two sections. # The loss related op should be wrapped by user explicitly. - # if sub_block_related is True + # If sub_block_related is True # The internal section contains while op or conditional block op. # These two ops are not supported by cuda graph. Won't extend the section. internal_section = [] @@ -274,6 +274,7 @@ def get_cuda_graph_sections(program): current_section.append(internal_section[i]) current_idx.append(internal_idx[i]) internal_section = [] + internal_idx = [] current_section.append(op) current_idx.append(idx) else: