diff --git a/python/paddle/distributed/auto_parallel/completion.py b/python/paddle/distributed/auto_parallel/completion.py
index 643408d49cbe6a0f2719fddf76fd904431b90098..4f2b2e798743bd0b291824e61c629b686c945e9e
--- a/python/paddle/distributed/auto_parallel/completion.py
+++ b/python/paddle/distributed/auto_parallel/completion.py
@@ -1706,6 +1706,7 @@ class Completer:
                         "elementwise_max",
                         "elementwise_div",
                     ]:
+                        # complete op dist_attr with global world ranks
                         op_dist_attr = OperatorDistributedAttribute()
                         op_dist_attr.process_mesh = world_ranks
                         for in_name in op.input_arg_names:
@@ -1713,8 +1714,8 @@ class Completer:
                             in_dist_attr = self._dist_context.get_tensor_dist_attr_for_program(
                                 in_var
                             )
-                            op_dist_attr.set_input_dist_attr(
-                                in_name, in_dist_attr
+                            op_dist_attr.set_input_dims_mapping(
+                                in_name, in_dist_attr.dims_mapping
                             )
                         for out_name in op.output_arg_names:
                             out_var = vars[out_name]
@@ -1726,10 +1727,11 @@ class Completer:
                             self._dist_context.set_tensor_dist_attr_for_program(
                                 out_var, out_dist_attr
                             )
-                            op_dist_attr.set_output_dist_attr(
-                                out_name, out_dist_attr
+                            op_dist_attr.set_output_dims_mapping(
+                                out_name, out_dist_attr.dims_mapping
                             )
                     else:
+                        # get ref_process_mesh and ref_dims_mapping from input_var
                         in_var = vars[op.input("X")[0]]
                         in_dist_attr = (
                             self._dist_context.get_tensor_dist_attr_for_program(
@@ -1751,6 +1753,7 @@ class Completer:
                             assert ref_dist_attr is not None
                             ref_process_mesh = ref_dist_attr.process_mesh

+                        # complete out_var's tensor_dist_attr
                         out_var = vars[op.output("Out")[0]]
                         out_dist_attr = TensorDistributedAttribute()
                         out_dist_attr.process_mesh = ref_process_mesh
@@ -1766,14 +1769,26 @@ class Completer:
                             out_var, out_dist_attr
                         )

+                        # complete op's dist_attr
+                        # complete op process_mesh with input_var's process_mesh
                         op_dist_attr = OperatorDistributedAttribute()
                         op_dist_attr.process_mesh = ref_process_mesh
-                        op_dist_attr.set_input_dist_attr(
-                            in_var.name, in_dist_attr
-                        )
-                        op_dist_attr.set_output_dist_attr(
-                            out_var.name, out_dist_attr
-                        )
+                        for in_name in op.input_arg_names:
+                            in_var = vars[in_name]
+                            in_dist_attr = self._dist_context.get_tensor_dist_attr_for_program(
+                                in_var
+                            )
+                            op_dist_attr.set_input_dims_mapping(
+                                in_name, in_dist_attr.dims_mapping
+                            )
+                        for out_name in op.output_arg_names:
+                            out_var = vars[out_name]
+                            out_dist_attr = self._dist_context.get_tensor_dist_attr_for_program(
+                                out_var
+                            )
+                            op_dist_attr.set_output_dims_mapping(
+                                out_name, out_dist_attr.dims_mapping
+                            )
                         self._dist_context.set_op_dist_attr_for_program(
                             op, op_dist_attr
                         )
diff --git a/python/paddle/distributed/auto_parallel/engine.py b/python/paddle/distributed/auto_parallel/engine.py
index dc7470283aef8859d3c46462a117f3b6bb427e6d..e9ed861106617b5f76418dcf736c7b3fc72bd348 100644
--- a/python/paddle/distributed/auto_parallel/engine.py
+++ b/python/paddle/distributed/auto_parallel/engine.py
@@ -493,10 +493,10 @@ class Engine:
         # logging user fetches
         collect_fetches = get_collection(CollectionNames.FETCHES)
         logs_fetch = {}
-        for name, var in collect_fetches:
-            if var.name in fetch_names:
-                idx = fetch_names.index(var.name)
-                logs_fetch[name or var.name] = outs[idx]
+        for name, var_name in collect_fetches:
+            if var_name in fetch_names:
+                idx = fetch_names.index(var_name)
+                logs_fetch[name or var_name] = outs[idx]
         logs["fetches"] = logs_fetch
         return logs
diff --git a/python/paddle/distributed/auto_parallel/interface.py b/python/paddle/distributed/auto_parallel/interface.py
index 882b63b39395b1254dfa8112fe3c072fb20a64f7..98316100a8dbcef04e3267ef47b616bd407569f9 100644
--- a/python/paddle/distributed/auto_parallel/interface.py
+++ b/python/paddle/distributed/auto_parallel/interface.py
@@ -256,6 +256,16 @@ def add_to_collection(collection_name, value, name=None):


 def fetch(tensor, name=None, logging=False):
+    if isinstance(tensor, paddle.fluid.framework.Variable):
+        tensor = tensor.name
+    elif isinstance(tensor, str):
+        tensor = tensor
+    else:
+        raise TypeError(
+            "Only support fetch `Variable` or `str`[`Variable`'s name], but got `{}`".format(
+                type(tensor)
+            )
+        )
     add_to_collection(CollectionNames.FETCHES, tensor, name)
     if logging:
         add_to_collection(CollectionNames.LOGGING, tensor, name)
diff --git a/python/paddle/distributed/passes/auto_parallel_amp.py b/python/paddle/distributed/passes/auto_parallel_amp.py
index cba613676d58de12ac44fa8ffa3e412c4002ef93..06b8e4a19e88f90cf0f44b0a89d7237e346f12bb 100644
--- a/python/paddle/distributed/passes/auto_parallel_amp.py
+++ b/python/paddle/distributed/passes/auto_parallel_amp.py
@@ -800,6 +800,9 @@ class AMPPass(PassBase):

             pre_grad_name = first_backward_op.output_arg_names[0]
             first_backward_op._rename_output(pre_grad_name, cast_loss_grad.name)
+            naive_set_dist_op_attr_for_program_by_mesh_and_mapping(
+                first_backward_op, ref_mesh, [-1], self.dist_context
+            )
             cast_grad_op = main_block._insert_op(
                 loss_op_idx + 3,
                 type='cast',
@@ -871,6 +874,9 @@ class AMPPass(PassBase):
             first_backward_op._rename_output(
                 pre_grad_name, self._scaled_loss_grad.name
             )
+            naive_set_dist_op_attr_for_program_by_mesh_and_mapping(
+                first_backward_op, ref_mesh, [-1], self.dist_context
+            )
             # FIXME(JZ-LIANG) a trick to insert backward op
             main_block._sync_with_cpp()
             elementwise_mul_grad_op_desc = main_block.desc._insert_op(
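
Usage note (not part of the patch): a minimal sketch of how the reworked fetch() behaves, assuming that fetch, get_collection, and CollectionNames are importable from the interface module touched above and that the script runs under static graph mode. A Variable argument is normalized to its name, a plain string is stored as-is, and any other type raises TypeError, which matches the updated loop in engine.py that unpacks (name, var_name) pairs from the FETCHES collection.

import paddle
from paddle.distributed.auto_parallel.interface import (
    CollectionNames,
    fetch,
    get_collection,
)

paddle.enable_static()
x = paddle.static.data(name="x", shape=[2, 3], dtype="float32")

# A Variable is reduced to its name before being stored; a str is kept as-is.
fetch(x, name="input_x", logging=True)
fetch("some_var_name")

# The FETCHES collection holds (name, variable_name) pairs of strings,
# which engine.py iterates as `for name, var_name in collect_fetches`.
print(get_collection(CollectionNames.FETCHES))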