未验证 提交 7b7ef3ba 编写于 作者: Z zhaoyingli 提交者: GitHub

fix oprole (#45610)

上级 6c5f9aa8
......@@ -224,13 +224,15 @@ class Partitioner(object):
forward_op_id2forward_op[
serial_ops[idx].desc.original_id()] = serial_ops[idx]
appended_grad_times = 0
# partition
appended_grad_times = 0
for idx, op in enumerate(serial_ops):
op_dist_attr = self._dist_context.get_op_dist_attr_for_program(op)
if is_backward_op(op) and (is_forward_op(serial_ops[idx - 1])
or is_loss_op(serial_ops[idx - 1])):
appended_grad_times += 1
if not op_dist_attr.is_recompute:
appended_grad_times += 1
# partition input variables
for serial_input_varname in op.desc.input_arg_names():
......@@ -256,7 +258,6 @@ class Partitioner(object):
serial_output_varname] = new_varname
# partition op
op_dist_attr = self._dist_context.get_op_dist_attr_for_program(op)
if is_forward_op(op) or op_dist_attr.is_recompute:
kinputs, koutputs = dist_op_context.prepare_context(op)
dist_op_forward_impl = _get_dist_op_forward_implement(
......
......@@ -18,14 +18,13 @@ import paddle
from paddle.framework import core
from paddle.fluid import unique_name
from .pass_base import register_pass
from paddle.distributed.fleet.meta_optimizers.common import OpRole
from paddle.fluid.data_feeder import check_variable_and_dtype, check_type
from paddle.distributed.auto_parallel.utils import set_var_dist_attr, naive_set_dist_op_attr_for_program_by_mesh_and_mapping
from paddle.distributed.auto_parallel.process_group import get_world_process_group
from paddle.fluid.contrib.mixed_precision.fp16_utils import AutoMixedPrecisionLists
from paddle.fluid.contrib.mixed_precision.fp16_utils import _keep_layer_norm_scale_bias_to_fp32, _need_keep_fp32, _valid_types, _dtype_to_str
from paddle.distributed.auto_parallel.dist_attribute import OperatorDistributedAttribute
from paddle.distributed.auto_parallel.utils import is_forward_op, is_backward_op
from paddle.distributed.auto_parallel.utils import is_forward_op, is_backward_op, OP_ROLE_KEY, OpRole
from .auto_parallel_amp import AMPPass
world_process_group = get_world_process_group()
......@@ -331,6 +330,7 @@ class FP16State(object):
attrs={
"in_dtype": in_var.dtype,
"out_dtype": cast_var.dtype,
OP_ROLE_KEY: OpRole.Forward
})
naive_set_dist_op_attr_for_program_by_mesh_and_mapping(
cast_op, ref_mesh, ref_mapping, dist_context)
......@@ -409,6 +409,7 @@ class FP16State(object):
attrs={
"in_dtype": dst_dtype,
"out_dtype": src_dtype,
OP_ROLE_KEY: OpRole.Backward
})
grad.desc.set_dtype(src_dtype)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册