Unverified commit bbf31a4e authored by zhaoyingli, committed by GitHub

bug fix (#39630)

Parent 8c7ee8c2
@@ -442,7 +442,7 @@ class Completer:
                 assert forward_op is not None
 
                 if grad_op.type == "concat" and forward_op.type == "split":
-                    forward_op_dist_attr = dist_context.get_op_dist_attr_for_program(
+                    forward_op_dist_attr = self._dist_context.get_op_dist_attr_for_program(
                         forward_op)
                     output_var = vars[grad_op.desc.output('Out')[0]]
                     split_input_var_name = forward_op.input("X")[0]
@@ -458,14 +458,14 @@ class Completer:
                     output_var_dist_attr = TensorDistributedAttribute()
                     output_var_dist_attr.dims_mapping = ref_dims_mapping
                     output_var_dist_attr.process_mesh = ref_mesh
-                    dist_context.set_tensor_dist_attr_for_program(
+                    self._dist_context.set_tensor_dist_attr_for_program(
                         output_var, output_var_dist_attr)
 
                     grad_op_dist_attr.set_output_dims_mapping(output_var.name,
                                                               ref_dims_mapping)
                     grad_op_dist_attr.process_mesh = ref_mesh
-                    dist_context.set_op_dist_attr_for_program(grad_op,
-                                                              grad_op_dist_attr)
+                    self._dist_context.set_op_dist_attr_for_program(
+                        grad_op, grad_op_dist_attr)
                     continue
 
                 # op dist attr
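Note on the two hunks above: the completion logic now lives on the Completer class, so the bare name `dist_context` no longer resolves inside these methods; the context held by the instance has to be used instead. A minimal sketch of the scoping pattern (a simplified, hypothetical Completer, not the real class):

```python
# Hypothetical, trimmed-down illustration of the scoping fix.
class Completer:
    def __init__(self, dist_context):
        # The DistributedContext is stored on the instance at construction time.
        self._dist_context = dist_context

    def _complete_grad_op(self, grad_op, forward_op):
        # Inside a method only self._dist_context is in scope; referring to a
        # bare `dist_context` here would raise NameError at runtime.
        return self._dist_context.get_op_dist_attr_for_program(forward_op)
```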
@@ -579,6 +579,28 @@ class Completer:
             # TODO to add attribute for moment var
             op = ops[idx]
             if int(op.attr('op_role')) == int(OpRole.Optimize):
+                if op.type == "clip_by_norm":
+                    param_grad = vars[op.input("X")[0]]
+                    param_grad_dist_attr = self._dist_context.get_tensor_dist_attr_for_program(
+                        param_grad)
+                    assert param_grad_dist_attr is not None
+                    ref_process_mesh = param_grad_dist_attr.process_mesh
+                    ref_dims_mapping = param_grad_dist_attr.dims_mapping
+
+                    out = vars[op.output("Out")[0]]
+                    out_dist_attr = TensorDistributedAttribute()
+                    out_dist_attr.process_mesh = ref_process_mesh
+                    out_dist_attr.dims_mapping = ref_dims_mapping
+                    self._dist_context.set_tensor_dist_attr_for_program(
+                        out, out_dist_attr)
+
+                    op_dist_attr = OperatorDistributedAttribute()
+                    op_dist_attr.process_mesh = ref_process_mesh
+                    op_dist_attr.set_input_dist_attr(param_grad.name,
+                                                     param_grad_dist_attr)
+                    op_dist_attr.set_output_dist_attr(out.name, out_dist_attr)
+                    self._dist_context.set_op_dist_attr_for_program(
+                        op, op_dist_attr)
+
                 if "Grad" in op.input_names and "Param" in ops[idx].input_names:
                     assert len(op.input(
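The new branch above gives the clip_by_norm ops in the optimizer section a distributed attribute copied from their input gradient, so gradient clipping no longer leaves unannotated tensors behind. For context, one typical (illustrative, not taken from this patch) way such ops end up in a program is attaching a ClipGradByNorm rule to the optimizer:

```python
import paddle

# Illustrative usage only: with grad_clip set, gradient clipping is applied by
# the optimizer; in static auto-parallel programs this shows up as clip_by_norm
# ops in the Optimize section.
linear = paddle.nn.Linear(16, 16)
clip = paddle.nn.ClipGradByNorm(clip_norm=1.0)
opt = paddle.optimizer.SGD(learning_rate=0.1,
                           parameters=linear.parameters(),
                           grad_clip=clip)
```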
@@ -142,7 +142,8 @@ class TensorCostNode(CostNode):
         elif node.dtype == paddle.uint8:
             self.dtype_factor = 1
         else:
-            raise NotImplementedError("{} not counted".format(node.dtype))
+            self.dtype_factor = 2
+            # raise NotImplementedError("{} not counted".format(node.dtype))
         self.batch_size = None
         if batch_size is not None:
             self.batch_size = batch_size
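The cost model used to raise for any dtype outside its if/elif chain; it now falls back to a factor of 2 so an uncovered dtype does not abort cost estimation. A hedged, table-driven sketch of the same sizing idea, with the patch's default:

```python
import paddle

# Sketch only: byte-size factors per dtype, defaulting to 2 as the patch does
# for anything the table does not cover.
_DTYPE_FACTOR = {
    paddle.float64: 8,
    paddle.float32: 4,
    paddle.float16: 2,
    paddle.int64: 8,
    paddle.int32: 4,
    paddle.uint8: 1,
}

def dtype_factor(dtype, default=2):
    return _DTYPE_FACTOR.get(dtype, default)
```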
@@ -86,7 +86,7 @@ class DistributedOperator:
                         tensor_dims_mapping)
         for tensor_name in self._serial_op.output_arg_names:
             tensor = self._serial_op.block._var_recursive(tensor_name)
-            if tensor.type == core.VarDesc.VarType.READER:
+            if tensor.type == core.VarDesc.VarType.READER or tensor.type == core.VarDesc.VarType.STEP_SCOPES:
                 tensor_shape = []
             else:
                 tensor_shape = tensor.shape
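STEP_SCOPES variables (like READER variables) do not hold a dense tensor, so reading `.shape` on them is meaningless; the widened check treats both as shapeless. The same guard, written as a small hedged helper:

```python
from paddle.fluid import core

# Hedged helper: var types that carry no dense tensor get an empty shape
# instead of reading .shape.
_SHAPELESS_VAR_TYPES = (
    core.VarDesc.VarType.READER,
    core.VarDesc.VarType.STEP_SCOPES,
)

def var_shape_or_empty(var):
    return [] if var.type in _SHAPELESS_VAR_TYPES else var.shape
```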
@@ -26,7 +26,7 @@ from ..process_group import new_process_group
 from ..dist_attribute import OperatorDistributedAttribute
 from paddle.distributed.auto_parallel.process_group import get_world_process_group
 
-global_process_mesh = get_world_process_group().ranks
+world_process_group = get_world_process_group()
 
 
 class DistributedCheckFiniteAndUnscale(DistributedOperatorImplContainer):
@@ -119,7 +119,7 @@ class DistributedCheckFiniteAndUnscaleImpl(DistributedOperatorImpl):
         main_block._sync_with_cpp()
 
         # sync result
-        group = new_process_group(global_process_mesh)
+        group = new_process_group(world_process_group.ranks)
 
         inf_var = main_block.var(kwargs['FoundInfinite'][0])
         inf_var_int32 = main_block.create_var(
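In the two hunks above, the module-level name changes from a snapshot of the ranks (`global_process_mesh`) to the group object itself, and `.ranks` is read at the call site, so the allreduce that syncs FoundInfinite always uses the group's current membership. A hedged sketch using the helpers referenced in the patch:

```python
from paddle.distributed.auto_parallel.process_group import (
    get_world_process_group, new_process_group)

# Sketch: keep the group object around and take .ranks only when the
# communication group for the FoundInfinite sync is actually created.
world_process_group = get_world_process_group()
sync_group = new_process_group(world_process_group.ranks)
```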
@@ -222,6 +222,8 @@ class AutoParallelizer:
         HAS_ALLGATHER.clear()
         _g_process_group_map.clear()
         _g_process_group_map[0] = ProcessGroup(0, [])
+        for process_mesh in dist_context._process_meshes:
+            _g_process_group_map[0].add_ranks(process_mesh.processes)
         return dist_optimize_ops, dist_params_grads, dist_startup_prog, dist_main_prog, g_process_group_map
 
     def parallelize(self,
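The added loop repopulates process group 0 (the world group) after the map is cleared, so every rank that appears in any ProcessMesh of the DistributedContext is a member again. A hedged sketch of the same bookkeeping in isolation (the helper name and parameters are illustrative, not part of the patch):

```python
# Illustrative helper: group 0 acts as the world group, so it must cover
# every rank referenced by any process mesh in the DistributedContext.
def rebuild_world_group(process_group_map, process_group_cls, dist_context):
    process_group_map.clear()
    process_group_map[0] = process_group_cls(0, [])
    for process_mesh in dist_context._process_meshes:
        process_group_map[0].add_ranks(process_mesh.processes)
```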
@@ -381,7 +381,7 @@ def _get_dist_op_backward_implement(backward_op, dist_context,
         op_dist_attr = dist_context.get_op_dist_attr_for_program(backward_op)
         assert op_dist_attr.impl_idx >= 0
         dist_op_impl = get_distributed_operator_impl_container(
-            backward_op.type).get_impl(op_dist_attr.impl_idx)
+            op_dist_attr.impl_type).get_impl(op_dist_attr.impl_idx)
         return dist_op_impl
 
     dist_op = get_distributed_operator_impl_container("default")
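The backward op's own type is not necessarily the key under which its distributed implementation container is registered; the attribute produced during completion records which container was chosen, so the lookup now goes through `op_dist_attr.impl_type`. A hedged sketch of the dispatch, with the container lookup passed in to keep it self-contained:

```python
# Sketch only: resolve the distributed implementation from the attribute
# recorded during completion instead of from the op's type string.
def resolve_backward_impl(op_dist_attr, get_container):
    # get_container stands in for get_distributed_operator_impl_container.
    assert op_dist_attr.impl_idx >= 0
    return get_container(op_dist_attr.impl_type).get_impl(op_dist_attr.impl_idx)
```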
@@ -1013,18 +1013,18 @@ def reshard(auto_parallel_main_prog, auto_parallel_startup_prog, rank_id,
     assert isinstance(dist_context, DistributedContext), "The type of dist_context should be DistributedContext, " \
         "but got {}.".format(type(dist_context))
 
-    block = auto_parallel_main_prog.global_block()
-    idx = 0
-    while idx < len(block.ops):
-        pre_op_count = len(block.ops)
-        op = block.ops[idx]
+    def _is_special_op(op):
+        global _g_special_ops
+        if op.type in _g_special_ops:
+            return True
+        return False
+
+    block = auto_parallel_main_prog.global_block()
+    idx = 0
+    while idx < len(block.ops):
+        pre_op_count = len(block.ops)
+        op = block.ops[idx]
+
+        if _is_special_op(op):
+            idx += 1
+            continue
@@ -1053,6 +1053,7 @@ def reshard(auto_parallel_main_prog, auto_parallel_startup_prog, rank_id,
     # insert send and recv op if output process mesh is different from tensor process mesh
     idx = 0
     skip_ops = ["create_py_reader", "create_double_buffer_reader", "read"]
+    skip_ops += _g_special_ops
     while idx < len(block.ops):
         pre_op_count = len(block.ops)
         op = block.ops[idx]
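Both reshard loops walk `block.ops` by index because the loop body may insert or delete ops, and they now step over the op types listed in `_g_special_ops`, which the send/recv insertion must not touch. A hedged sketch of that iteration pattern:

```python
# Sketch of the skip pattern: iterate by index (the body may grow or shrink
# block.ops) and step over op types on the skip list.
def walk_ops(block, skip_types, visit):
    idx = 0
    while idx < len(block.ops):
        op = block.ops[idx]
        if op.type in skip_types:
            idx += 1
            continue
        visit(block, idx, op)
        idx += 1  # the real pass also adjusts for ops inserted by visit()
```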
@@ -26,7 +26,7 @@ from paddle.fluid.contrib.mixed_precision.fp16_utils import _keep_fp32_input, _k
 from paddle.fluid.contrib.mixed_precision.fp16_utils import _valid_types, find_true_post_op, find_true_prev_op
 from paddle.fluid.contrib.mixed_precision.fp16_utils import _is_in_black_varnames, _dtype_to_str, _rename_arg
 from paddle.distributed.auto_parallel.dist_attribute import OperatorDistributedAttribute
 
-global_process_mesh = get_world_process_group().ranks
+world_process_group = get_world_process_group()
 
 
 class AMPState(object):
@@ -445,7 +445,7 @@ def _check_and_update_gradient(params_grads, loss_scaling, dist_context):
         type=core.VarDesc.VarType.LOD_TENSOR,
         persistable=False,
         stop_gradient=False)
-    set_var_dist_attr(dist_context, found_inf, [-1], global_process_mesh)
+    set_var_dist_attr(dist_context, found_inf, [-1], world_process_group.ranks)
 
     inputs = {'X': grads, 'Scale': loss_scaling}
     outputs = {'Out': grads, 'FoundInfinite': found_inf}
@@ -457,9 +457,10 @@ def _check_and_update_gradient(params_grads, loss_scaling, dist_context):
         attrs=attrs)
 
     new_op_dist_attr = OperatorDistributedAttribute()
-    new_op_dist_attr.process_mesh = global_process_mesh
-    if len(global_process_mesh) > 1:
+    new_op_dist_attr.process_mesh = world_process_group.ranks
+    new_op_dist_attr.impl_idx = 0
+    if len(world_process_group.ranks) > 1:
         new_op_dist_attr.impl_type = "check_finite_and_unscale"
     for g in grads:
         g_dist_attr = dist_context.get_tensor_dist_attr_for_program(g)
         assert g_dist_attr is not None
@@ -550,7 +551,7 @@ class AMPPass(PassBase):
             dtype='float32',
             persistable=True)
         set_var_dist_attr(self.dist_context, self._loss_scaling, [-1],
-                          global_process_mesh)
+                          world_process_group.ranks)
 
         if self.get_attr("use_dynamic_loss_scaling"):
             self._num_good_steps = paddle.static.create_global_var(
@@ -560,7 +561,7 @@ class AMPPass(PassBase):
                 dtype='int32',
                 persistable=True)
             set_var_dist_attr(self.dist_context, self._num_good_steps, [-1],
-                              global_process_mesh)
+                              world_process_group.ranks)
 
             self._num_bad_steps = paddle.static.create_global_var(
                 name=unique_name.generate("num_bad_steps"),
@@ -569,7 +570,7 @@ class AMPPass(PassBase):
                 dtype='int32',
                 persistable=True)
             set_var_dist_attr(self.dist_context, self._num_bad_steps, [-1],
-                              global_process_mesh)
+                              world_process_group.ranks)
 
     def _scale_loss(self):
@@ -700,9 +701,10 @@ class AMPPass(PassBase):
             attrs=attrs)
 
         new_op_dist_attr = OperatorDistributedAttribute()
-        new_op_dist_attr.process_mesh = global_process_mesh
-        if len(global_process_mesh) > 1:
+        new_op_dist_attr.process_mesh = world_process_group.ranks
+        new_op_dist_attr.impl_idx = 0
+        if len(world_process_group.ranks) > 1:
             new_op_dist_attr.impl_type = "update_loss_scaling"
         for g in grads:
             g_dist_attr = self.dist_context.get_tensor_dist_attr_for_program(g)
             assert g_dist_attr is not None
@@ -382,6 +382,7 @@ class RecomputePass(PassBase):
             new_dist_attr = OperatorDistributedAttribute()
             new_dist_attr.is_recompute = True
             new_dist_attr.impl_idx = old_dist_attr.impl_idx
+            new_dist_attr.impl_type = old_dist_attr.impl_type
             new_dist_attr.process_mesh = old_dist_attr.process_mesh
             for input in old_dist_attr.inputs_dist_attrs.keys():
                 if input in var_name_dict.keys():
@@ -40,7 +40,7 @@ class TestRecomputePass(AutoPallelPassTestBase):
     def apply_passes(self):
         dist_strategy = fleet.DistributedStrategy()
         dist_strategy.recompute = True
-        dist_strategy.recompute_configs = {"checkpoints": ["tmp3", "tmp6"]}
+        dist_strategy.recompute_configs = {"checkpoints": ["tmp_3", "tmp_6"]}
         dist_strategy.semi_auto = True
         fleet.init(is_collective=True, strategy=dist_strategy)
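The recompute checkpoints must name variables that actually exist in the serial program; auto-generated activation names follow the `tmp_<n>` pattern, so `"tmp3"`/`"tmp6"` did not refer to real activations. A small hedged helper for validating checkpoint names against a static Program (the helper itself is illustrative; `Program.list_vars()` is the standard API):

```python
# Illustrative check: keep only the checkpoint names that exist in the program.
def existing_checkpoints(main_program, wanted_names):
    names = {var.name for var in main_program.list_vars()}
    return [name for name in wanted_names if name in names]
```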