Unverified commit 7b7ef3ba authored by zhaoyingli, committed by GitHub

fix oprole (#45610)

Parent 6c5f9aa8
@@ -224,13 +224,15 @@ class Partitioner(object):
                 forward_op_id2forward_op[
                     serial_ops[idx].desc.original_id()] = serial_ops[idx]
 
-        appended_grad_times = 0
         # partiiton
+        appended_grad_times = 0
         for idx, op in enumerate(serial_ops):
+            op_dist_attr = self._dist_context.get_op_dist_attr_for_program(op)
             if is_backward_op(op) and (is_forward_op(serial_ops[idx - 1])
                                        or is_loss_op(serial_ops[idx - 1])):
-                appended_grad_times += 1
+                if not op_dist_attr.is_recompute:
+                    appended_grad_times += 1
 
             # partititon input variables
             for serial_input_varname in op.desc.input_arg_names():
@@ -256,7 +258,6 @@ class Partitioner(object):
                         serial_output_varname] = new_varname
 
             # partition op
-            op_dist_attr = self._dist_context.get_op_dist_attr_for_program(op)
             if is_forward_op(op) or op_dist_attr.is_recompute:
                 kinputs, koutputs = dist_op_context.prepare_context(op)
                 dist_op_forward_impl = _get_dist_op_forward_implement(
...
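The two hunks above change when appended_grad_times advances: the op_dist_attr lookup is hoisted to the top of the loop (and removed from its old spot further down, where the second hunk deletes it), and the counter now skips forward-to-backward boundaries formed by recompute ops, since a recomputed op carries the backward role but merely re-runs forward work. Below is a minimal, self-contained sketch of the fixed control flow; the Op dataclass and the three predicates are hypothetical stand-ins for Paddle's real op-role helpers, used only to illustrate the counting logic.

from dataclasses import dataclass

@dataclass
class Op:
    role: str                 # "forward", "loss", or "backward"
    is_recompute: bool = False

def is_forward_op(op):
    return op.role == "forward"

def is_loss_op(op):
    return op.role == "loss"

def is_backward_op(op):
    return op.role == "backward"

def count_grad_segments(serial_ops):
    """Mirror of the fixed loop: count forward->backward boundaries,
    but do not let a recomputed backward op open a new segment."""
    appended_grad_times = 0
    for idx, op in enumerate(serial_ops):
        if idx > 0 and is_backward_op(op) and (
                is_forward_op(serial_ops[idx - 1])
                or is_loss_op(serial_ops[idx - 1])):
            if not op.is_recompute:   # the guard this commit adds
                appended_grad_times += 1
    return appended_grad_times

ops = [Op("forward"), Op("loss"),
       Op("backward", is_recompute=True),  # recompute op: not a real boundary
       Op("backward")]
print(count_grad_segments(ops))  # 0; without the guard this would be 1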
@@ -18,14 +18,13 @@ import paddle
 from paddle.framework import core
 from paddle.fluid import unique_name
 from .pass_base import register_pass
-from paddle.distributed.fleet.meta_optimizers.common import OpRole
 from paddle.fluid.data_feeder import check_variable_and_dtype, check_type
 from paddle.distributed.auto_parallel.utils import set_var_dist_attr, naive_set_dist_op_attr_for_program_by_mesh_and_mapping
 from paddle.distributed.auto_parallel.process_group import get_world_process_group
 from paddle.fluid.contrib.mixed_precision.fp16_utils import AutoMixedPrecisionLists
 from paddle.fluid.contrib.mixed_precision.fp16_utils import _keep_layer_norm_scale_bias_to_fp32, _need_keep_fp32, _valid_types, _dtype_to_str
 from paddle.distributed.auto_parallel.dist_attribute import OperatorDistributedAttribute
-from paddle.distributed.auto_parallel.utils import is_forward_op, is_backward_op
+from paddle.distributed.auto_parallel.utils import is_forward_op, is_backward_op, OP_ROLE_KEY, OpRole
 from .auto_parallel_amp import AMPPass
 
 world_process_group = get_world_process_group()
@@ -331,6 +330,7 @@ class FP16State(object):
                 attrs={
                     "in_dtype": in_var.dtype,
                     "out_dtype": cast_var.dtype,
+                    OP_ROLE_KEY: OpRole.Forward
                 })
             naive_set_dist_op_attr_for_program_by_mesh_and_mapping(
                 cast_op, ref_mesh, ref_mapping, dist_context)
@@ -409,6 +409,7 @@ class FP16State(object):
                     attrs={
                         "in_dtype": dst_dtype,
                         "out_dtype": src_dtype,
+                        OP_ROLE_KEY: OpRole.Backward
                     })
                 grad.desc.set_dtype(src_dtype)
...
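The FP16State hunks tag the cast ops that the fp16 pass inserts with an explicit op role, so passes that filter by role (is_forward_op, is_backward_op) classify them correctly: casts inserted while lowering forward ops get OpRole.Forward, and the casts that restore gradient dtypes get OpRole.Backward. Below is a minimal sketch of appending one role-tagged cast op in a Paddle static graph; the variable names are illustrative, and OP_ROLE_KEY and OpRole are resolved here straight from core rather than via the utils import used in the diff.

import paddle
from paddle.framework import core

paddle.enable_static()

OP_ROLE_KEY = core.op_proto_and_checker_maker.kOpRoleAttrName()  # "op_role"
OpRole = core.op_proto_and_checker_maker.OpRole

main = paddle.static.Program()
with paddle.static.program_guard(main):
    block = main.global_block()
    x = paddle.static.data(name="x", shape=[4, 8], dtype="float32")
    # Hypothetical cast target, mirroring the fp32 -> fp16 cast the pass inserts.
    y = block.create_var(name="x.cast_fp16",
                         dtype=core.VarDesc.VarType.FP16,
                         persistable=False,
                         stop_gradient=x.stop_gradient)
    block.append_op(type="cast",
                    inputs={"X": [x]},
                    outputs={"Out": [y]},
                    attrs={
                        "in_dtype": x.dtype,
                        "out_dtype": y.dtype,
                        # The attribute this commit adds to each inserted cast;
                        # gradient-restoring casts get OpRole.Backward instead.
                        OP_ROLE_KEY: OpRole.Forward,
                    })
    print(block.ops[-1].attr(OP_ROLE_KEY))  # expected 0, i.e. OpRole.Forward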