# skip the input and output of operators inserted in the reshard phase
dist_op = dist_context.get_dist_op_for_program(op)
if dist_op:
    for var_name in op.output_arg_names:
        if var_name not in sub_block_op_outputs:
            sub_block_op_outputs.append(var_name)
    for var_name in op.input_arg_names:
        sub_block_op_inputs.add(var_name)

# rename the input of the op that consumes the reshard result
for name in op.input_arg_names:
    if name == var_name and op_dist_attr is not None:
        if op.desc.id() == matched_op.desc.id():
            op.desc._rename_input(name, target_tensor.name)
            op_dist_attr.set_input_dims_mapping(
                target_tensor.name, dims_mapping)
            op_dist_attr.set_input_dist_attr(name, None)
            continue

        # NOTE: For an op whose process mesh is a union, its input will not be
        # renamed by the reshard result of other ops for now, which means it
        # will need more reshard operations.
        continue
assert isinstance(rank_id, int), "The type of rank_id should be int, " \
    "but got {}.".format(type(rank_id))
assert isinstance(dist_context, DistributedContext), "The type of dist_context should be DistributedContext, " \
    "but got {}.".format(type(dist_context))
def _is_special_op(op):
    # assumed body (truncated in this excerpt): an op is "special" if its type
    # appears in the module-level _g_special_ops list defined elsewhere in the file
    global _g_special_ops
    if op.type in _g_special_ops:
        return True
    return False
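

if __name__ == "__main__":
    # Illustrative usage sketch only, not part of the original source: it shows
    # how _is_special_op lets the reshard pass skip ops it must leave untouched.
    # _FakeOp mimics just the `.type` attribute that _is_special_op reads, and
    # _g_special_ops is given a fallback value here in case it is not defined
    # in this excerpt.
    globals().setdefault("_g_special_ops",
                         ["check_finite_and_unscale", "update_loss_scaling"])

    class _FakeOp:
        def __init__(self, op_type):
            self.type = op_type

    for demo_op in [_FakeOp("matmul_v2"), _FakeOp("update_loss_scaling")]:
        if _is_special_op(demo_op):
            # special ops keep their inputs and outputs untouched
            continue
        print("op '{}' would go through the normal reshard checks".format(
            demo_op.type))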