diff --git a/python/paddle/distributed/auto_parallel/completion.py b/python/paddle/distributed/auto_parallel/completion.py index 8960c47c1f5bf23e1bbe3d601e2a839350a7b416..8979239df5f11c7344e1127fba5e3c53fa709830 100644 --- a/python/paddle/distributed/auto_parallel/completion.py +++ b/python/paddle/distributed/auto_parallel/completion.py @@ -1850,11 +1850,11 @@ class Completer: op_dist_attr.set_output_dims_mapping( input_var.name, ref_dims_mapping ) - - input_var_attr.process_mesh = ref_process_mesh - self._dist_context.set_tensor_dist_attr_for_program( - input_var, input_var_attr - ) + if "SkipUpdate" not in input_name: + input_var_attr.process_mesh = ref_process_mesh + self._dist_context.set_tensor_dist_attr_for_program( + input_var, input_var_attr + ) self._dist_context.set_op_dist_attr_for_program( op, op_dist_attr