Unverified commit 047cd95c authored by zhaoyingli, committed by GitHub

[AutoParallel] fix update complete and add_to_collection (#48943)

* [AutoParallel] fix update complete and add_to_collection

* fix annotation

* fix amp fill_constant dist_attr
Parent e9eb5db3
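For context on the dist_attr fields this commit touches: in auto-parallel, a tensor's `dims_mapping` has one entry per tensor dimension, where `-1` means replicated and `k >= 0` means sharded along mesh axis `k`. A small self-contained sketch (plain Python; the helper name is illustrative, not Paddle API):

```python
# Illustrative helper (not part of Paddle): derive the per-rank local shape
# implied by a dims_mapping. -1 = replicated, k >= 0 = sharded on mesh axis k.
def local_shape(global_shape, mesh_shape, dims_mapping):
    return [
        size if axis == -1 else size // mesh_shape[axis]
        for size, axis in zip(global_shape, dims_mapping)
    ]

# An [8, 4] tensor on a 2x2 mesh, sharded along mesh axis 0 in dim 0:
assert local_shape([8, 4], [2, 2], [0, -1]) == [4, 4]
# The AMP hunks below use dims_mapping [-1]: a scalar-like grad, replicated.
assert local_shape([1], [2, 2], [-1]) == [1]
```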
@@ -1706,6 +1706,7 @@ class Completer:
                     "elementwise_max",
                     "elementwise_div",
                 ]:
+                    # complete op dist_attr with global world ranks
                     op_dist_attr = OperatorDistributedAttribute()
                     op_dist_attr.process_mesh = world_ranks
                     for in_name in op.input_arg_names:
@@ -1713,8 +1714,8 @@ class Completer:
                         in_dist_attr = self._dist_context.get_tensor_dist_attr_for_program(
                             in_var
                         )
-                        op_dist_attr.set_input_dist_attr(
-                            in_name, in_dist_attr
+                        op_dist_attr.set_input_dims_mapping(
+                            in_name, in_dist_attr.dims_mapping
                         )
                     for out_name in op.output_arg_names:
                         out_var = vars[out_name]
@@ -1726,10 +1727,11 @@ class Completer:
                         self._dist_context.set_tensor_dist_attr_for_program(
                             out_var, out_dist_attr
                         )
-                        op_dist_attr.set_output_dist_attr(
-                            out_name, out_dist_attr
+                        op_dist_attr.set_output_dims_mapping(
+                            out_name, out_dist_attr.dims_mapping
                         )
                 else:
+                    # get ref_process_mesh and ref_dims_mapping from input_var
                     in_var = vars[op.input("X")[0]]
                     in_dist_attr = (
                         self._dist_context.get_tensor_dist_attr_for_program(
@@ -1751,6 +1753,7 @@ class Completer:
                     assert ref_dist_attr is not None
                     ref_process_mesh = ref_dist_attr.process_mesh
+                    # complete out_var's tensor_dist_attr
                     out_var = vars[op.output("Out")[0]]
                     out_dist_attr = TensorDistributedAttribute()
                     out_dist_attr.process_mesh = ref_process_mesh
@@ -1766,14 +1769,26 @@ class Completer:
                         out_var, out_dist_attr
                     )
-                    # complete op'd dist_attr
+                    # complete op process_mesh with input_var's process_mesh
                     op_dist_attr = OperatorDistributedAttribute()
                     op_dist_attr.process_mesh = ref_process_mesh
-                    op_dist_attr.set_input_dist_attr(
-                        in_var.name, in_dist_attr
-                    )
-                    op_dist_attr.set_output_dist_attr(
-                        out_var.name, out_dist_attr
-                    )
+                    for in_name in op.input_arg_names:
+                        in_var = vars[in_name]
+                        in_dist_attr = self._dist_context.get_tensor_dist_attr_for_program(
+                            in_var
+                        )
+                        op_dist_attr.set_input_dims_mapping(
+                            in_name, in_dist_attr.dims_mapping
+                        )
+                    for out_name in op.output_arg_names:
+                        out_var = vars[out_name]
+                        out_dist_attr = self._dist_context.get_tensor_dist_attr_for_program(
+                            out_var
+                        )
+                        op_dist_attr.set_output_dims_mapping(
+                            out_name, out_dist_attr.dims_mapping
+                        )
                     self._dist_context.set_op_dist_attr_for_program(
                         op, op_dist_attr
                     )
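To make the intent of the Completer hunks concrete, here is a minimal standalone sketch (mock classes, not the real `OperatorDistributedAttribute` API) of the new behavior: the op's dist_attr records only each argument's `dims_mapping`, looked up per input/output name, instead of attaching whole tensor dist_attrs:

```python
# Mock stand-ins for TensorDistributedAttribute / OperatorDistributedAttribute,
# just to show the shape of the fix.
class MockTensorDistAttr:
    def __init__(self, process_mesh, dims_mapping):
        self.process_mesh = process_mesh
        self.dims_mapping = dims_mapping

class MockOpDistAttr:
    def __init__(self, process_mesh):
        self.process_mesh = process_mesh
        self.input_dims_mappings = {}
        self.output_dims_mappings = {}

    def set_input_dims_mapping(self, name, dims_mapping):
        self.input_dims_mappings[name] = dims_mapping

    def set_output_dims_mapping(self, name, dims_mapping):
        self.output_dims_mappings[name] = dims_mapping

# Tensor dist_attrs as the dist_context would hold them.
tensor_attrs = {
    "X": MockTensorDistAttr(process_mesh=[0, 1], dims_mapping=[0, -1]),
    "Out": MockTensorDistAttr(process_mesh=[0, 1], dims_mapping=[0, -1]),
}

# New behavior: iterate every arg name and copy only its dims_mapping.
op_attr = MockOpDistAttr(process_mesh=[0, 1])
for in_name in ["X"]:
    op_attr.set_input_dims_mapping(in_name, tensor_attrs[in_name].dims_mapping)
for out_name in ["Out"]:
    op_attr.set_output_dims_mapping(out_name, tensor_attrs[out_name].dims_mapping)

assert op_attr.input_dims_mappings["X"] == [0, -1]
assert op_attr.output_dims_mappings["Out"] == [0, -1]
```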
@@ -493,10 +493,10 @@ class Engine:
         # logging user fetches
         collect_fetches = get_collection(CollectionNames.FETCHES)
         logs_fetch = {}
-        for name, var in collect_fetches:
-            if var.name in fetch_names:
-                idx = fetch_names.index(var.name)
-                logs_fetch[name or var.name] = outs[idx]
+        for name, var_name in collect_fetches:
+            if var_name in fetch_names:
+                idx = fetch_names.index(var_name)
+                logs_fetch[name or var_name] = outs[idx]
         logs["fetches"] = logs_fetch
         return logs
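A standalone sketch of the fixed lookup (plain Python with mock data): since the FETCHES collection now stores variable names rather than `Variable` objects, entries are matched against `fetch_names` directly:

```python
# Mock data: (user-given name, variable name) pairs, as the collection stores
# them after this change, plus the executor's fetch list and outputs.
collect_fetches = [("train_loss", "loss_var"), (None, "acc_var")]
fetch_names = ["loss_var", "acc_var"]
outs = [0.25, 0.9]

# Same loop shape as the fixed Engine code: match by name, key the log entry
# by the user-given name when present, else by the variable name.
logs_fetch = {}
for name, var_name in collect_fetches:
    if var_name in fetch_names:
        idx = fetch_names.index(var_name)
        logs_fetch[name or var_name] = outs[idx]

assert logs_fetch == {"train_loss": 0.25, "acc_var": 0.9}
```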
@@ -256,6 +256,16 @@ def add_to_collection(collection_name, value, name=None):
 def fetch(tensor, name=None, logging=False):
     if isinstance(tensor, paddle.fluid.framework.Variable):
         tensor = tensor.name
+    elif isinstance(tensor, str):
+        tensor = tensor
+    else:
+        raise TypeError(
+            "Only support fetch `Variable` or `str`[`Variable`'s name], but got `{}`".format(
+                type(tensor)
+            )
+        )
     add_to_collection(CollectionNames.FETCHES, tensor, name)
     if logging:
         add_to_collection(CollectionNames.LOGGING, tensor, name)
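The net effect of this hunk is that `fetch` accepts either a `Variable` or its name as a `str`; anything else raises `TypeError`. A minimal standalone sketch of that dispatch (`FakeVariable` is a stand-in for `paddle.fluid.framework.Variable`, and `normalize_fetch_target` is an illustrative name, not Paddle API):

```python
# Stand-in for paddle.fluid.framework.Variable: only .name matters here.
class FakeVariable:
    def __init__(self, name):
        self.name = name

def normalize_fetch_target(tensor):
    # Variables are reduced to their name; strings pass through unchanged;
    # any other type is rejected, mirroring the new `fetch` behavior.
    if isinstance(tensor, FakeVariable):
        return tensor.name
    elif isinstance(tensor, str):
        return tensor
    raise TypeError(
        "Only support fetch `Variable` or `str`, but got `{}`".format(type(tensor))
    )

assert normalize_fetch_target(FakeVariable("loss")) == "loss"
assert normalize_fetch_target("loss") == "loss"  # str name, newly supported
```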
@@ -800,6 +800,9 @@ class AMPPass(PassBase):
         pre_grad_name = first_backward_op.output_arg_names[0]
         first_backward_op._rename_output(pre_grad_name, cast_loss_grad.name)
+        naive_set_dist_op_attr_for_program_by_mesh_and_mapping(
+            first_backward_op, ref_mesh, [-1], self.dist_context
+        )
         cast_grad_op = main_block._insert_op(
             loss_op_idx + 3,
             type='cast',
@@ -871,6 +874,9 @@
         first_backward_op._rename_output(
             pre_grad_name, self._scaled_loss_grad.name
         )
+        naive_set_dist_op_attr_for_program_by_mesh_and_mapping(
+            first_backward_op, ref_mesh, [-1], self.dist_context
+        )
         # FIXME(JZ-LIANG) a trick to insert backward op
         main_block._sync_with_cpp()
         elementwise_mul_grad_op_desc = main_block.desc._insert_op(
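Both AMP hunks attach a dist_attr to the first backward op (the `fill_constant` that seeds the loss gradient) using the loss's process mesh and a `[-1]` (replicated) dims mapping, matching the "fix amp fill_constant dist_attr" note in the commit message. A rough standalone sketch of what a helper in the spirit of `naive_set_dist_op_attr_for_program_by_mesh_and_mapping` does (mock types and an assumed behavior, not the real implementation):

```python
# Mock op / dist_attr / context types; only the shape of the helper is shown.
class MockOp:
    def __init__(self, input_arg_names, output_arg_names):
        self.input_arg_names = input_arg_names
        self.output_arg_names = output_arg_names

class MockOpDistAttr:
    def __init__(self):
        self.process_mesh = None
        self.input_dims_mappings = {}
        self.output_dims_mappings = {}

class MockDistContext:
    def __init__(self):
        self.op_dist_attrs = {}

def set_dist_op_attr_by_mesh_and_mapping(op, process_mesh, dims_mapping, ctx):
    """Assumed behavior: give every input/output of `op` the same mesh and
    dims_mapping, then register the op dist_attr in the context."""
    attr = MockOpDistAttr()
    attr.process_mesh = process_mesh
    for name in op.input_arg_names:
        attr.input_dims_mappings[name] = dims_mapping
    for name in op.output_arg_names:
        attr.output_dims_mappings[name] = dims_mapping
    ctx.op_dist_attrs[id(op)] = attr

# fill_constant has no inputs; its single output is the seeded loss grad.
first_backward_op = MockOp([], ["loss@GRAD"])
ctx = MockDistContext()
set_dist_op_attr_by_mesh_and_mapping(first_backward_op, [0, 1], [-1], ctx)
assert ctx.op_dist_attrs[id(first_backward_op)].output_dims_mappings["loss@GRAD"] == [-1]
```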