Unverified commit 047cd95c authored by zhaoyingli, committed by GitHub

[AutoParallel] fix update complete and add_to_collection (#48943)

* [AutoParallel] fix update complete and add_to_collection

* fix annotation

* fix amp fill_constant dist_attr
Parent e9eb5db3
@@ -1706,6 +1706,7 @@ class Completer:
                 "elementwise_max",
                 "elementwise_div",
             ]:
+                # complete op dist_attr with global world ranks
                 op_dist_attr = OperatorDistributedAttribute()
                 op_dist_attr.process_mesh = world_ranks
                 for in_name in op.input_arg_names:
@@ -1713,8 +1714,8 @@ class Completer:
                     in_dist_attr = self._dist_context.get_tensor_dist_attr_for_program(
                         in_var
                     )
-                    op_dist_attr.set_input_dist_attr(
-                        in_name, in_dist_attr
+                    op_dist_attr.set_input_dims_mapping(
+                        in_name, in_dist_attr.dims_mapping
                     )
                 for out_name in op.output_arg_names:
                     out_var = vars[out_name]
@@ -1726,10 +1727,11 @@ class Completer:
                     self._dist_context.set_tensor_dist_attr_for_program(
                         out_var, out_dist_attr
                     )
-                    op_dist_attr.set_output_dist_attr(
-                        out_name, out_dist_attr
+                    op_dist_attr.set_output_dims_mapping(
+                        out_name, out_dist_attr.dims_mapping
                     )
             else:
+                # get ref_process_mesh and ref_dims_mapping from input_var
                 in_var = vars[op.input("X")[0]]
                 in_dist_attr = (
                     self._dist_context.get_tensor_dist_attr_for_program(
@@ -1751,6 +1753,7 @@ class Completer:
                 assert ref_dist_attr is not None
                 ref_process_mesh = ref_dist_attr.process_mesh
+                # complete out_var's tensor_dist_attr
                 out_var = vars[op.output("Out")[0]]
                 out_dist_attr = TensorDistributedAttribute()
                 out_dist_attr.process_mesh = ref_process_mesh
@@ -1766,14 +1769,26 @@ class Completer:
                     out_var, out_dist_attr
                 )
+                # complete op's dist_attr
+                # complete op process_mesh with input_var's process_mesh
                 op_dist_attr = OperatorDistributedAttribute()
                 op_dist_attr.process_mesh = ref_process_mesh
-                op_dist_attr.set_input_dist_attr(
-                    in_var.name, in_dist_attr
-                )
-                op_dist_attr.set_output_dist_attr(
-                    out_var.name, out_dist_attr
-                )
+                for in_name in op.input_arg_names:
+                    in_var = vars[in_name]
+                    in_dist_attr = self._dist_context.get_tensor_dist_attr_for_program(
+                        in_var
+                    )
+                    op_dist_attr.set_input_dims_mapping(
+                        in_name, in_dist_attr.dims_mapping
+                    )
+                for out_name in op.output_arg_names:
+                    out_var = vars[out_name]
+                    out_dist_attr = self._dist_context.get_tensor_dist_attr_for_program(
+                        out_var
+                    )
+                    op_dist_attr.set_output_dims_mapping(
+                        out_name, out_dist_attr.dims_mapping
+                    )
                 self._dist_context.set_op_dist_attr_for_program(
                     op, op_dist_attr
......
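The core of the completion fix above: both the fill_constant/elementwise branch and the fallback branch now copy only the dims_mapping of each tensor onto the op's dist_attr (set_input_dims_mapping / set_output_dims_mapping) and iterate over every input and output arg name instead of only "X" and "Out". A minimal, self-contained sketch of that pattern follows; ToyTensorDistAttr, ToyOpDistAttr and complete_op_dist_attr are illustrative stand-ins, not Paddle's real classes.

# Toy stand-ins for Paddle's dist-attr classes; only the shape of the pattern matters here.
class ToyTensorDistAttr:
    def __init__(self, process_mesh, dims_mapping):
        self.process_mesh = process_mesh
        self.dims_mapping = dims_mapping


class ToyOpDistAttr:
    def __init__(self):
        self.process_mesh = None
        self.input_dims = {}
        self.output_dims = {}

    def set_input_dims_mapping(self, name, dims_mapping):
        self.input_dims[name] = dims_mapping

    def set_output_dims_mapping(self, name, dims_mapping):
        self.output_dims[name] = dims_mapping


def complete_op_dist_attr(input_names, output_names, tensor_dist_attrs, ref_process_mesh):
    """Mirror the patched loop: one dims_mapping per input/output arg name."""
    op_dist_attr = ToyOpDistAttr()
    op_dist_attr.process_mesh = ref_process_mesh
    for in_name in input_names:
        op_dist_attr.set_input_dims_mapping(
            in_name, tensor_dist_attrs[in_name].dims_mapping
        )
    for out_name in output_names:
        op_dist_attr.set_output_dims_mapping(
            out_name, tensor_dist_attrs[out_name].dims_mapping
        )
    return op_dist_attr


# Example: a scale-like op whose input is sharded on mesh dim 0 and whose output is replicated.
attrs = {
    "X": ToyTensorDistAttr([0, 1], [0, -1]),
    "Out": ToyTensorDistAttr([0, 1], [-1, -1]),
}
print(complete_op_dist_attr(["X"], ["Out"], attrs, ref_process_mesh=[0, 1]).input_dims)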
@@ -493,10 +493,10 @@ class Engine:
         # logging user fetches
         collect_fetches = get_collection(CollectionNames.FETCHES)
         logs_fetch = {}
-        for name, var in collect_fetches:
-            if var.name in fetch_names:
-                idx = fetch_names.index(var.name)
-                logs_fetch[name or var.name] = outs[idx]
+        for name, var_name in collect_fetches:
+            if var_name in fetch_names:
+                idx = fetch_names.index(var_name)
+                logs_fetch[name or var_name] = outs[idx]
         logs["fetches"] = logs_fetch
         return logs
......
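With this Engine change, the FETCHES collection yields (alias, var_name) string pairs instead of Variable objects, so the logging loop matches executed results purely by name. A framework-free sketch of that loop; build_fetch_logs and the sample names are illustrative only.

def build_fetch_logs(collect_fetches, fetch_names, outs):
    """Replicate the patched loop: locate each registered fetch by its variable name."""
    logs_fetch = {}
    for alias, var_name in collect_fetches:
        if var_name in fetch_names:
            idx = fetch_names.index(var_name)
            # prefer the user-supplied alias, fall back to the raw variable name
            logs_fetch[alias or var_name] = outs[idx]
    return logs_fetch


# Only "loss_0" was actually executed, so only it shows up in the logs.
print(build_fetch_logs([("acc", "accuracy_0"), (None, "loss_0")], ["loss_0"], [0.42]))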
@@ -256,6 +256,16 @@ def add_to_collection(collection_name, value, name=None):
 def fetch(tensor, name=None, logging=False):
+    if isinstance(tensor, paddle.fluid.framework.Variable):
+        tensor = tensor.name
+    elif isinstance(tensor, str):
+        tensor = tensor
+    else:
+        raise TypeError(
+            "Only support fetch `Variable` or `str`[`Variable`'s name], but got `{}`".format(
+                type(tensor)
+            )
+        )
     add_to_collection(CollectionNames.FETCHES, tensor, name)
     if logging:
         add_to_collection(CollectionNames.LOGGING, tensor, name)
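The new guard in fetch() normalizes its argument to a variable name before it is stored, which is what makes the name-based lookup in the Engine work. A self-contained sketch of that normalization, with ToyVariable and a plain dict standing in for Paddle's Variable and collection machinery.

class ToyVariable:
    """Stand-in for paddle.fluid.framework.Variable; only .name is used."""

    def __init__(self, name):
        self.name = name


_collections = {}


def add_to_collection(collection_name, value, name=None):
    _collections.setdefault(collection_name, []).append((name, value))


def fetch(tensor, name=None, logging=False):
    # Normalize to a plain string: Variable -> its name, str -> unchanged, else reject.
    if isinstance(tensor, ToyVariable):
        tensor = tensor.name
    elif isinstance(tensor, str):
        pass
    else:
        raise TypeError(
            "Only support fetch `Variable` or `str`[`Variable`'s name], "
            "but got `{}`".format(type(tensor))
        )
    add_to_collection("fetches", tensor, name)
    if logging:
        add_to_collection("logging", tensor, name)


fetch(ToyVariable("loss_0"), name="loss", logging=True)
fetch("accuracy_0")
print(_collections)  # both collections now hold name strings only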
@@ -800,6 +800,9 @@ class AMPPass(PassBase):
         pre_grad_name = first_backward_op.output_arg_names[0]
         first_backward_op._rename_output(pre_grad_name, cast_loss_grad.name)
+        naive_set_dist_op_attr_for_program_by_mesh_and_mapping(
+            first_backward_op, ref_mesh, [-1], self.dist_context
+        )
         cast_grad_op = main_block._insert_op(
             loss_op_idx + 3,
             type='cast',
@@ -871,6 +874,9 @@ class AMPPass(PassBase):
         first_backward_op._rename_output(
             pre_grad_name, self._scaled_loss_grad.name
         )
+        naive_set_dist_op_attr_for_program_by_mesh_and_mapping(
+            first_backward_op, ref_mesh, [-1], self.dist_context
+        )
         # FIXME(JZ-LIANG) a trick to insert backward op
         main_block._sync_with_cpp()
         elementwise_mul_grad_op_desc = main_block.desc._insert_op(
......
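The two AMP insertions give the renamed first backward op (the fill_constant that seeds loss@GRAD, per the commit message) an explicit dist_attr. Assuming naive_set_dist_op_attr_for_program_by_mesh_and_mapping behaves as its name suggests, an assumption rather than the helper's actual source, it attaches the reference mesh and applies the same dims mapping, here [-1] for the scalar gradient, to every input and output of the op. A simplified, dictionary-based sketch:

def naive_set_dist_op_attr(input_names, output_names, process_mesh, ref_mapping):
    """Assumed behaviour: one mesh for the op, the same dims_mapping for every arg."""
    return {
        "process_mesh": process_mesh,
        "inputs": {name: list(ref_mapping) for name in input_names},
        "outputs": {name: list(ref_mapping) for name in output_names},
    }


# The scalar loss gradient has no sharded dims, hence the [-1] mapping in the patch.
print(naive_set_dist_op_attr([], ["loss@GRAD"], process_mesh=[0, 1], ref_mapping=[-1]))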