Unverified  Commit bbf31a4e  authored by zhaoyingli, committed by GitHub

bug fix (#39630)

Parent 8c7ee8c2
@@ -442,7 +442,7 @@ class Completer:
             assert forward_op is not None
             if grad_op.type == "concat" and forward_op.type == "split":
-                forward_op_dist_attr = dist_context.get_op_dist_attr_for_program(
+                forward_op_dist_attr = self._dist_context.get_op_dist_attr_for_program(
                     forward_op)
                 output_var = vars[grad_op.desc.output('Out')[0]]
                 split_input_var_name = forward_op.input("X")[0]
@@ -458,14 +458,14 @@ class Completer:
                 output_var_dist_attr = TensorDistributedAttribute()
                 output_var_dist_attr.dims_mapping = ref_dims_mapping
                 output_var_dist_attr.process_mesh = ref_mesh
-                dist_context.set_tensor_dist_attr_for_program(
+                self._dist_context.set_tensor_dist_attr_for_program(
                     output_var, output_var_dist_attr)

                 grad_op_dist_attr.set_output_dims_mapping(output_var.name,
                                                           ref_dims_mapping)
                 grad_op_dist_attr.process_mesh = ref_mesh
-                dist_context.set_op_dist_attr_for_program(grad_op,
-                                                          grad_op_dist_attr)
+                self._dist_context.set_op_dist_attr_for_program(
+                    grad_op, grad_op_dist_attr)
                 continue

             # op dist attr
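The hunks above repair a scoping bug: the Completer method referred to a bare dist_context name, which is not bound in this scope, instead of the instance's self._dist_context. A minimal, self-contained sketch of the failure mode (hypothetical method, not the real Completer):

    class Completer:
        def __init__(self, dist_context):
            self._dist_context = dist_context   # the only binding that exists

        def complete_backward(self):
            # return dist_context       # NameError: 'dist_context' is not defined
            return self._dist_context   # the fix: read the instance attribute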
@@ -579,6 +579,28 @@ class Completer:
             # TODO to add attribute for moment var
             op = ops[idx]
             if int(op.attr('op_role')) == int(OpRole.Optimize):
+                if op.type == "clip_by_norm":
+                    param_grad = vars[op.input("X")[0]]
+                    param_grad_dist_attr = self._dist_context.get_tensor_dist_attr_for_program(
+                        param_grad)
+                    assert param_grad_dist_attr is not None
+                    ref_process_mesh = param_grad_dist_attr.process_mesh
+                    ref_dims_mapping = param_grad_dist_attr.dims_mapping
+
+                    out = vars[op.output("Out")[0]]
+                    out_dist_attr = TensorDistributedAttribute()
+                    out_dist_attr.process_mesh = ref_process_mesh
+                    out_dist_attr.dims_mapping = ref_dims_mapping
+                    self._dist_context.set_tensor_dist_attr_for_program(
+                        out, out_dist_attr)
+
+                    op_dist_attr = OperatorDistributedAttribute()
+                    op_dist_attr.process_mesh = ref_process_mesh
+                    op_dist_attr.set_input_dist_attr(param_grad.name,
+                                                     param_grad_dist_attr)
+                    op_dist_attr.set_output_dist_attr(out.name, out_dist_attr)
+                    self._dist_context.set_op_dist_attr_for_program(
+                        op, op_dist_attr)
+
                 if "Grad" in op.input_names and "Param" in ops[idx].input_names:
                     assert len(op.input(
...
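The added branch gives clip_by_norm the completion treatment other optimizer-role ops already get: the op is shape-preserving on its gradient input, so the output inherits the input's process mesh and dims mapping verbatim. A standalone sketch of that inheritance rule with simplified stand-in types (the real attributes live on TensorDistributedAttribute):

    from dataclasses import dataclass, field
    from typing import List

    @dataclass
    class TensorDistAttr:                  # simplified stand-in
        process_mesh: List[int] = field(default_factory=list)
        dims_mapping: List[int] = field(default_factory=list)

    def complete_shape_preserving_output(in_attr: TensorDistAttr) -> TensorDistAttr:
        # a shape-preserving op such as clip_by_norm cannot change sharding,
        # so the output simply mirrors the input's mesh and dims_mapping
        return TensorDistAttr(process_mesh=list(in_attr.process_mesh),
                              dims_mapping=list(in_attr.dims_mapping))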
@@ -142,7 +142,8 @@ class TensorCostNode(CostNode):
         elif node.dtype == paddle.uint8:
             self.dtype_factor = 1
         else:
-            raise NotImplementedError("{} not counted".format(node.dtype))
+            self.dtype_factor = 2
+            # raise NotImplementedError("{} not counted".format(node.dtype))

         self.batch_size = None
         if batch_size is not None:
             self.batch_size = batch_size
...
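Instead of aborting the cost model on an unrecognized dtype, the node now falls back to a factor of 2 (the old raise is kept as a comment). The same fallback can be written as a dict lookup; only the uint8 entry and the default of 2 are taken from this hunk, anything else would be an assumption:

    import paddle

    DTYPE_FACTOR = {paddle.uint8: 1}       # entries beyond this are not shown here

    def dtype_factor(dtype):
        # unknown dtypes now cost a conservative factor of 2 instead of raising
        return DTYPE_FACTOR.get(dtype, 2)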
@@ -86,7 +86,7 @@ class DistributedOperator:
                     tensor_dims_mapping)
         for tensor_name in self._serial_op.output_arg_names:
             tensor = self._serial_op.block._var_recursive(tensor_name)
-            if tensor.type == core.VarDesc.VarType.READER:
+            if tensor.type == core.VarDesc.VarType.READER or tensor.type == core.VarDesc.VarType.STEP_SCOPES:
                 tensor_shape = []
             else:
                 tensor_shape = tensor.shape
...
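STEP_SCOPES variables (produced by control-flow ops) carry no tensor shape, so querying .shape on them fails just as it does for READER vars. A sketch of the same guard written as set membership (equivalent behavior; the helper name is hypothetical):

    from paddle.fluid import core

    SHAPELESS_TYPES = {
        core.VarDesc.VarType.READER,
        core.VarDesc.VarType.STEP_SCOPES,
    }

    def tensor_shape_or_empty(tensor):
        # shapeless framework vars contribute an empty shape to dist attrs
        return [] if tensor.type in SHAPELESS_TYPES else tensor.shape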
@@ -26,7 +26,7 @@ from ..process_group import new_process_group
 from ..dist_attribute import OperatorDistributedAttribute
 from paddle.distributed.auto_parallel.process_group import get_world_process_group

-global_process_mesh = get_world_process_group().ranks
+world_process_group = get_world_process_group()


 class DistributedCheckFiniteAndUnscale(DistributedOperatorImplContainer):
@@ -119,7 +119,7 @@ class DistributedCheckFiniteAndUnscaleImpl(DistributedOperatorImpl):
         main_block._sync_with_cpp()

         # sync result
-        group = new_process_group(global_process_mesh)
+        group = new_process_group(world_process_group.ranks)

         inf_var = main_block.var(kwargs['FoundInfinite'][0])
         inf_var_int32 = main_block.create_var(
...
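The renamed binding matters: global_process_mesh = get_world_process_group().ranks snapshots the rank list at import time, before the parallelizer has populated the group, whereas keeping the group object and dereferencing .ranks at call time always sees the current membership. A condensed contrast using the same identifiers as the hunk:

    from paddle.distributed.auto_parallel.process_group import (
        get_world_process_group, new_process_group)

    # before: freezes whatever ranks existed when the module was imported
    global_process_mesh = get_world_process_group().ranks

    # after: defer the read until the group is actually needed
    world_process_group = get_world_process_group()
    group = new_process_group(world_process_group.ranks)   # fresh at each call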
@@ -222,6 +222,8 @@ class AutoParallelizer:
             HAS_ALLGATHER.clear()
             _g_process_group_map.clear()
             _g_process_group_map[0] = ProcessGroup(0, [])
+            for process_mesh in dist_context._process_meshes:
+                _g_process_group_map[0].add_ranks(process_mesh.processes)
         return dist_optimize_ops, dist_params_grads, dist_startup_prog, dist_main_prog, g_process_group_map

     def parallelize(self,
...
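After clearing the map, group 0 is rebuilt from every process mesh in the dist_context, so the default group again spans all participating ranks. A minimal stand-in for the ProcessGroup behavior this relies on (simplified; the union-keeping semantics of add_ranks are an assumption here):

    class ProcessGroup:                        # simplified stand-in
        def __init__(self, group_id, ranks):
            self.id = group_id
            self.ranks = list(ranks)

        def add_ranks(self, ranks):
            # keep a de-duplicated union, preserving insertion order
            for r in ranks:
                if r not in self.ranks:
                    self.ranks.append(r)

    group0 = ProcessGroup(0, [])
    for mesh_processes in ([0, 1], [2, 3], [1, 2]):
        group0.add_ranks(mesh_processes)
    print(group0.ranks)                        # [0, 1, 2, 3]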
@@ -381,7 +381,7 @@ def _get_dist_op_backward_implement(backward_op, dist_context,
         op_dist_attr = dist_context.get_op_dist_attr_for_program(backward_op)
         assert op_dist_attr.impl_idx >= 0
         dist_op_impl = get_distributed_operator_impl_container(
-            backward_op.type).get_impl(op_dist_attr.impl_idx)
+            op_dist_attr.impl_type).get_impl(op_dist_attr.impl_idx)
         return dist_op_impl

     dist_op = get_distributed_operator_impl_container("default")
...
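Looking the container up by op_dist_attr.impl_type rather than backward_op.type lets passes that annotate an explicit implementation (the AMP hunks below set "check_finite_and_unscale" and "update_loss_scaling") resolve to their dedicated distributed impl even when the op's type string differs. A registry sketch with assumed names, not the real container API:

    _g_containers = {}                      # impl_type -> container (assumed shape)

    def register(impl_type, container):
        _g_containers[impl_type] = container

    def get_distributed_operator_impl_container(impl_type):
        # fall back to the generic implementation when nothing is registered
        return _g_containers.get(impl_type, _g_containers.get("default"))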
@@ -1013,18 +1013,18 @@ def reshard(auto_parallel_main_prog, auto_parallel_startup_prog, rank_id,
     assert isinstance(dist_context, DistributedContext), "The type of dist_context should be DistributedContext, " \
         "but got {}.".format(type(dist_context))

+    def _is_special_op(op):
+        global _g_special_ops
+        if op.type in _g_special_ops:
+            return True
+        return False
+
     block = auto_parallel_main_prog.global_block()
     idx = 0
     while idx < len(block.ops):
         pre_op_count = len(block.ops)
         op = block.ops[idx]

-        def _is_special_op(op):
-            global _g_special_ops
-            if op.type in _g_special_ops:
-                return True
-            return False
-
         if _is_special_op(op):
             idx += 1
             continue
@@ -1053,6 +1053,7 @@ def reshard(auto_parallel_main_prog, auto_parallel_startup_prog, rank_id,
     # insert send and recv op if output process mesh is different from tensor process mesh
     idx = 0
     skip_ops = ["create_py_reader", "create_double_buffer_reader", "read"]
+    skip_ops += _g_special_ops
     while idx < len(block.ops):
         pre_op_count = len(block.ops)
         op = block.ops[idx]
...
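Two independent fixes here: _is_special_op is hoisted out of the scan loop (it was being re-created as a fresh function object on every op), and the second scan now skips the same special ops as the first via skip_ops += _g_special_ops. A self-contained micro-benchmark of the hoisting effect (illustrative only, toy data):

    import timeit

    OPS = ["matmul"] * 10_000
    SPECIAL = {"c_allreduce_sum"}

    def scan_inner_def():
        for op in OPS:
            def is_special(o):             # rebuilt on every iteration
                return o in SPECIAL
            is_special(op)

    def scan_hoisted_def():
        def is_special(o):                 # built once per call
            return o in SPECIAL
        for op in OPS:
            is_special(op)

    print(timeit.timeit(scan_inner_def, number=100))
    print(timeit.timeit(scan_hoisted_def, number=100))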
@@ -26,7 +26,7 @@ from paddle.fluid.contrib.mixed_precision.fp16_utils import _keep_fp32_input, _k
 from paddle.fluid.contrib.mixed_precision.fp16_utils import _valid_types, find_true_post_op, find_true_prev_op
 from paddle.fluid.contrib.mixed_precision.fp16_utils import _is_in_black_varnames, _dtype_to_str, _rename_arg
 from paddle.distributed.auto_parallel.dist_attribute import OperatorDistributedAttribute

-global_process_mesh = get_world_process_group().ranks
+world_process_group = get_world_process_group()


 class AMPState(object):
@@ -445,7 +445,7 @@ def _check_and_update_gradient(params_grads, loss_scaling, dist_context):
         type=core.VarDesc.VarType.LOD_TENSOR,
         persistable=False,
         stop_gradient=False)
-    set_var_dist_attr(dist_context, found_inf, [-1], global_process_mesh)
+    set_var_dist_attr(dist_context, found_inf, [-1], world_process_group.ranks)

     inputs = {'X': grads, 'Scale': loss_scaling}
     outputs = {'Out': grads, 'FoundInfinite': found_inf}
@@ -457,9 +457,10 @@ def _check_and_update_gradient(params_grads, loss_scaling, dist_context):
         attrs=attrs)

     new_op_dist_attr = OperatorDistributedAttribute()
-    new_op_dist_attr.process_mesh = global_process_mesh
-    if len(global_process_mesh) > 1:
-        new_op_dist_attr.impl_idx = 0
+    new_op_dist_attr.process_mesh = world_process_group.ranks
+    new_op_dist_attr.impl_idx = 0
+    if len(world_process_group.ranks) > 1:
+        new_op_dist_attr.impl_type = "check_finite_and_unscale"
     for g in grads:
         g_dist_attr = dist_context.get_tensor_dist_attr_for_program(g)
         assert g_dist_attr is not None
@@ -550,7 +551,7 @@ class AMPPass(PassBase):
             dtype='float32',
             persistable=True)
         set_var_dist_attr(self.dist_context, self._loss_scaling, [-1],
-                          global_process_mesh)
+                          world_process_group.ranks)

         if self.get_attr("use_dynamic_loss_scaling"):
             self._num_good_steps = paddle.static.create_global_var(
@@ -560,7 +561,7 @@ class AMPPass(PassBase):
                 dtype='int32',
                 persistable=True)
             set_var_dist_attr(self.dist_context, self._num_good_steps, [-1],
-                              global_process_mesh)
+                              world_process_group.ranks)

             self._num_bad_steps = paddle.static.create_global_var(
                 name=unique_name.generate("num_bad_steps"),
@@ -569,7 +570,7 @@ class AMPPass(PassBase):
                 dtype='int32',
                 persistable=True)
             set_var_dist_attr(self.dist_context, self._num_bad_steps, [-1],
-                              global_process_mesh)
+                              world_process_group.ranks)

     def _scale_loss(self):
@@ -700,9 +701,10 @@ class AMPPass(PassBase):
             attrs=attrs)

         new_op_dist_attr = OperatorDistributedAttribute()
-        new_op_dist_attr.process_mesh = global_process_mesh
-        if len(global_process_mesh) > 1:
-            new_op_dist_attr.impl_idx = 0
+        new_op_dist_attr.process_mesh = world_process_group.ranks
+        new_op_dist_attr.impl_idx = 0
+        if len(world_process_group.ranks) > 1:
+            new_op_dist_attr.impl_type = "update_loss_scaling"
         for g in grads:
             g_dist_attr = self.dist_context.get_tensor_dist_attr_for_program(g)
             assert g_dist_attr is not None
...
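Besides the call-time rank reads, these hunks annotate the new scalar vars with dims_mapping [-1]. As a reminder of the convention used throughout auto_parallel: entry i of dims_mapping names the mesh axis along which tensor dimension i is sharded, and -1 means that dimension is replicated, so [-1] marks a scalar flag like found_inf as replicated on every rank of the world group. A worked illustration of the convention (toy values, not a Paddle API call):

    # mesh axes:           0        1
    process_mesh = [[0, 1], [2, 3]]    # 2 x 2 mesh over ranks 0..3

    # a [batch, hidden] tensor sharded on batch along mesh axis 0,
    # replicated along hidden:
    dims_mapping = [0, -1]

    # a scalar flag such as found_inf: single dim, replicated everywhere
    scalar_dims_mapping = [-1]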
@@ -382,6 +382,7 @@ class RecomputePass(PassBase):
             new_dist_attr = OperatorDistributedAttribute()
             new_dist_attr.is_recompute = True
             new_dist_attr.impl_idx = old_dist_attr.impl_idx
+            new_dist_attr.impl_type = old_dist_attr.impl_type
             new_dist_attr.process_mesh = old_dist_attr.process_mesh
             for input in old_dist_attr.inputs_dist_attrs.keys():
                 if input in var_name_dict.keys():
...
@@ -40,7 +40,7 @@ class TestRecomputePass(AutoPallelPassTestBase):
     def apply_passes(self):
         dist_strategy = fleet.DistributedStrategy()
         dist_strategy.recompute = True
-        dist_strategy.recompute_configs = {"checkpoints": ["tmp3", "tmp6"]}
+        dist_strategy.recompute_configs = {"checkpoints": ["tmp_3", "tmp_6"]}
         dist_strategy.semi_auto = True
         fleet.init(is_collective=True, strategy=dist_strategy)
...