Unverified commit db2c71a4, authored by zhaoyingli, committed by GitHub

[AutoParallel] fix 'op_role' for gradient merge & recompute (#44138)

* fix op_role

* fix engine

* update op_role
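For context, the convention this commit enforces: every op in the program carries an 'op_role' attribute (OP_ROLE_KEY) whose flags mark the Forward, Backward, Optimize and Loss stages, and ops inserted by the auto-parallel dist-op implementations should inherit the role of the source op they are derived from. Below is a minimal sketch of that idea, assuming the OpRole / OP_ROLE_KEY exports from paddle.distributed.auto_parallel.utils shown later in this diff; append_with_same_role is a hypothetical helper written for illustration, not part of this PR.

from paddle.distributed.auto_parallel.utils import OpRole, OP_ROLE_KEY

def is_optimize_op(op):
    # An op belongs to the optimize stage when the Optimize bit of its
    # op_role attribute is set (mirrors the predicate in auto_parallel/utils.py).
    return OP_ROLE_KEY in op.attr_names and \
        int(op.all_attrs()[OP_ROLE_KEY]) & int(OpRole.Optimize)

def append_with_same_role(block, src_op, **op_kwargs):
    # Hypothetical helper: an op appended while rewriting `src_op` inherits
    # src_op's role via attrs[OP_ROLE_KEY] = src_op.attr('op_role'), which is
    # what the dist-op implementations below now do explicitly.
    attrs = dict(op_kwargs.pop('attrs', {}))
    attrs[OP_ROLE_KEY] = src_op.attr('op_role')
    return block.append_op(attrs=attrs, **op_kwargs)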
Parent 7e3833a7
......@@ -18,7 +18,6 @@ from collections import defaultdict
import paddle
import paddle.utils as utils
import paddle.distributed.auto_parallel as auto
from paddle import fluid, static
from paddle.io import Dataset
......@@ -72,7 +71,6 @@ class Engine:
self._saver = DistributedSaver()
self._logger = get_logger(logging.INFO)
self._default_strategy = None
self._orig_main_prog = static.default_main_program()
self._orig_startup_prog = static.default_startup_program()
self._orig_dist_context = get_default_distributed_context()
......@@ -117,9 +115,11 @@ class Engine:
self._planned_mode = None
self._modes = ['train', 'eval', 'predict']
self._build()
# Do auto parallel process
# Build program and do auto parallel process
for mode in self._modes:
# Build forward program
self._build(mode)
for mode in self._modes:
# Do the planning process
self._plan(mode)
......@@ -129,56 +129,49 @@ class Engine:
# Init comm and startup program
self._initialize(mode)
def _build(self):
for mode in self._modes:
serial_main_prog = self._serial_main_progs.get(mode, None)
if serial_main_prog is not None:
return
losses = []
metrics = []
serial_main_prog = self._orig_main_prog.clone()
serial_startup_prog = self._orig_startup_prog.clone()
with static.program_guard(serial_main_prog, serial_startup_prog), \
utils.unique_name.guard():
inputs_spec = self.inputs_spec
labels_spec = self.labels_spec if self.labels_spec else []
inputs = [s._create_feed_layer() for s in inputs_spec]
labels = [s._create_feed_layer() for s in labels_spec]
outputs = to_list(self.model(*inputs))
if mode != "predict" and self._loss:
losses = to_list(self._loss(*(outputs + labels)))
if mode != "predict":
for metric in self._metrics:
metrics.extend(
to_list(metric.compute(*(outputs + labels))))
default_ctx = get_default_distributed_context()
if not default_ctx.has_annotation or self._default_strategy:
# We build the world process group because data parallelism
# needs all ranks by default.
new_process_group(list(range(self._nranks)))
default_ctx.data_parallel = True
# self._feed_vars[mode] = {"inputs": inputs, "labels": labels}
feed_vars = {"inputs": inputs, "labels": labels}
# self._fetch_vars[mode] = {
# "outputs": flatten(outputs),
# "loss": losses,
# "metrics": metrics
# }
fetch_vars = {
"outputs": flatten(outputs),
"loss": losses,
"metrics": metrics
}
self._dist_contexts[mode] = DistributedContext(
serial_main_prog, serial_startup_prog, self._optimizer, losses,
feed_vars, fetch_vars, self.cluster, self.strategy)
self._dist_contexts[mode].gradient_scale = self._gradient_scale
def _build(self, mode):
serial_main_prog = self._serial_main_progs.get(mode, None)
if serial_main_prog is not None:
return
losses = []
metrics = []
serial_main_prog = self._orig_main_prog.clone()
serial_startup_prog = self._orig_startup_prog.clone()
with static.program_guard(serial_main_prog, serial_startup_prog), \
utils.unique_name.guard():
inputs_spec = self.inputs_spec
labels_spec = self.labels_spec if self.labels_spec else []
inputs = [s._create_feed_layer() for s in inputs_spec]
labels = [s._create_feed_layer() for s in labels_spec]
outputs = to_list(self.model(*inputs))
if mode != "predict" and self._loss:
losses = to_list(self._loss(*(outputs + labels)))
if mode != "predict":
for metric in self._metrics:
metrics.extend(to_list(metric.compute(*(outputs + labels))))
default_ctx = get_default_distributed_context()
if not default_ctx.has_annotation:
# We build the world process group because data parallelism
# needs all ranks by default.
new_process_group(list(range(self._nranks)))
default_ctx.data_parallel = True
feed_vars = {"inputs": inputs, "labels": labels}
fetch_vars = {
"outputs": flatten(outputs),
"loss": losses,
"metrics": metrics
}
self._dist_contexts[mode] = DistributedContext(
serial_main_prog, serial_startup_prog, self._optimizer, losses,
feed_vars, fetch_vars, self.cluster, self.strategy)
self._dist_contexts[mode].gradient_scale = self._gradient_scale
def _plan(self, mode):
if self._planned_mode is None:
......@@ -240,7 +233,6 @@ class Engine:
continue
process_group.instantiate()
# initialize
self._place = _get_device()
if isinstance(self._place, fluid.CUDAPlace):
self._place = fluid.CUDAPlace(ParallelEnv().dev_id)
......@@ -273,8 +265,8 @@ class Engine:
train_dataloader = self._create_dataloader(train_data, batch_size,
epochs, steps_per_epoch)
usr_fetch = self._to_map_fetch(fetches)
fetch_loss = self._inner_fetch(self.fetch_vars["loss"])
usr_fetch = self._validate_fetches(fetches)
fetch_loss = self._validate_fetches(self.fetch_vars["loss"])
fetch_list, fetch_map = self._fetch_map(fetch_loss, usr_fetch)
for epoch in range(epochs):
......@@ -292,8 +284,7 @@ class Engine:
user_outs = outs[len(fetch_loss):]
user_fetch_list = fetch_list[len(fetch_loss):]
for i, out in enumerate(user_outs):
train_logs["train_" +
fetch_map[user_fetch_list[i]]] = out[0]
train_logs["train_" + fetch_map[user_fetch_list[i]]] = out
self._logger.info(train_logs)
def evaluate(self,
......@@ -307,9 +298,9 @@ class Engine:
"eval model is not ready, please call `engine.prepare()` first."
eval_dataloader = self._create_dataloader(eval_data, batch_size)
usr_fetch = self._to_map_fetch(fetches)
fetch_loss = self._inner_fetch(self.fetch_vars["loss"])
fetch_metrics = self._inner_fetch(self.fetch_vars["metrics"])
usr_fetch = self._validate_fetches(fetches)
fetch_loss = self._validate_fetches(self.fetch_vars["loss"])
fetch_metrics = self._validate_fetches(self.fetch_vars["metrics"])
inner_fetch = dict(fetch_loss, **fetch_metrics)
fetch_list, fetch_map = self._fetch_map(inner_fetch, usr_fetch)
......@@ -321,7 +312,7 @@ class Engine:
return_numpy=return_numpy)
# inner fetches
if fetch_loss:
eval_logs["eval_loss"] = outs[0]
eval_logs["eval_loss"] = outs[0][0]
# Metric
if fetch_metrics:
metric_out = outs[len(fetch_loss):len(inner_fetch)]
......@@ -331,9 +322,9 @@ class Engine:
for i, res in enumerate(to_list(results)):
eval_logs["eval_" + metric.name()[i]] = res
# usr fetches
usr_out = outs[len(inner_fetch):]
usr_outs = outs[len(inner_fetch):]
usr_fetch_list = fetch_list[len(inner_fetch):]
for i, out in enumerate(usr_out):
for i, out in enumerate(usr_outs):
eval_logs["eval_" + fetch_map[usr_fetch_list[i]]] = out
# logger
self._logger.info(eval_logs)
......@@ -349,8 +340,8 @@ class Engine:
"predict model is not ready, please call `engine.prepare()` first."
test_dataloader = self._create_dataloader(test_data, batch_size)
usr_fetch = self._to_map_fetch(fetches)
fetch_outputs = self._inner_fetch(self.fetch_vars["outputs"])
usr_fetch = self._validate_fetches(fetches)
fetch_outputs = self._validate_fetches(self.fetch_vars["outputs"])
fetch_list, fetch_map = self._fetch_map(fetch_outputs, usr_fetch)
outputs = []
......@@ -362,42 +353,11 @@ class Engine:
return_numpy=return_numpy)
outputs.append(outs[:len(fetch_outputs)])
for i, out in enumerate(outs):
predict_logs["pred_" + fetch_map[fetch_list[i]]] = out[0]
predict_logs["pred_" + fetch_map[fetch_list[i]]] = out
self._logger.info(predict_logs)
return outputs
def _local_var(self, var):
var_name = _to_name_str(var)
return var_name in self.main_program.global_block().vars
def _to_map_fetch(self, fetches):
if not fetches:
return {}
if isinstance(fetches, dict):
fetch_var_names = list(map(_to_name_str, fetches.values()))
usr_fetches = dict(zip(fetch_var_names, list(fetches.keys())))
elif isinstance(fetches, list):
fetch_var_names = list(map(_to_name_str, fetches))
usr_fetches = dict(zip(fetch_var_names, fetch_var_names))
return dict(filter(lambda x: self._local_var(x[0]),
usr_fetches.items()))
def _inner_fetch(self, fetch_vars):
fetch_list = list(
map(lambda x: x.name, list(filter(self._local_var, fetch_vars))))
inner_fetches = dict(zip(fetch_list, fetch_list))
return inner_fetches
def _fetch_map(self, inner_fetch, usr_fetch):
# replace an inner fetch name if the user set one for it
for iname in inner_fetch:
if iname in usr_fetch:
inner_fetch[iname] = usr_fetch[iname]
usr_fetch.pop(iname)
fetches = dict(inner_fetch, **usr_fetch)
return list(fetches.keys()), fetches
def _create_dataloader(self,
dataset,
batch_size,
......@@ -468,26 +428,35 @@ class Engine:
.format(i, spec))
return specs
def _set_data_parallel(self, var):
if self._nranks == 1:
self._default_strategy = 'serial'
auto.shard_tensor(var,
dist_attr={
"process_mesh": [0],
"dims_mapping":
[-1 for _ in range(len(var.shape))]
})
def _is_local_var(self, var):
var_name = _to_name_str(var)
return var_name in self.main_program.global_block().vars
def _validate_fetches(self, fetches):
# 1. Check the type of user-defined fetches
# 2. Prepare fetches_dict like {var_name: user_defined_name}
if not fetches:
return {}
if isinstance(fetches, dict):
fetch_var_names = list(map(_to_name_str, fetches.values()))
fetches_dict = dict(zip(fetch_var_names, list(fetches.keys())))
elif isinstance(fetches, list):
fetch_var_names = list(map(_to_name_str, fetches))
fetches_dict = dict(zip(fetch_var_names, fetch_var_names))
else:
self._default_strategy = 'dp'
auto.shard_tensor(var,
dist_attr={
"process_mesh":
list(range(self._nranks)),
"dims_mapping":
[0] + [-1 for _ in range(len(var.shape) - 1)]
})
return var
raise TypeError("'fetches' only support 'dict' and 'list', "
"but got '{}'".format(str(type(fetches))))
return dict(
filter(lambda x: self._is_local_var(x[0]), fetches_dict.items()))
def _fetch_map(self, inner_fetch, usr_fetch):
# replace an inner fetch name if the user set one for it
for iname in inner_fetch:
if iname in usr_fetch:
inner_fetch[iname] = usr_fetch[iname]
usr_fetch.pop(iname)
fetches = dict(inner_fetch, **usr_fetch)
return list(fetches.keys()), fetches
def _get_data_parallel_info(self, var, dist_context):
# get data parallel world size and current data parallel rank
......
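As a usage note on the fetch bookkeeping above, here is a standalone, pure-Python sketch of what the renamed _validate_fetches / _fetch_map pair computes; the helper names and the example variable names are illustrative only, not the Engine API itself.

def validate_fetches(fetches, local_var_names):
    # Normalize user fetches (dict or list) into {var_name: display_name} and
    # drop anything that does not live in the local partitioned program.
    if not fetches:
        return {}
    if isinstance(fetches, dict):
        fetches_dict = {name: key for key, name in fetches.items()}
    elif isinstance(fetches, list):
        fetches_dict = {name: name for name in fetches}
    else:
        raise TypeError("'fetches' only support 'dict' and 'list', "
                        "but got '{}'".format(str(type(fetches))))
    return {k: v for k, v in fetches_dict.items() if k in local_var_names}

def fetch_map(inner_fetch, usr_fetch):
    # A user-supplied display name overrides the inner name; the result is the
    # final fetch_list plus the mapping used to build log keys like "eval_my_acc".
    for iname in list(inner_fetch):
        if iname in usr_fetch:
            inner_fetch[iname] = usr_fetch.pop(iname)
    fetches = dict(inner_fetch, **usr_fetch)
    return list(fetches.keys()), fetches

local_vars = {"loss_0", "accuracy_0.tmp_0"}            # hypothetical var names
usr = validate_fetches({"my_acc": "accuracy_0.tmp_0"}, local_vars)
inner = validate_fetches(["loss_0"], local_vars)
fetch_list, fmap = fetch_map(inner, usr)
# fetch_list == ['loss_0', 'accuracy_0.tmp_0']; fmap['accuracy_0.tmp_0'] == 'my_acc'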
......@@ -137,7 +137,7 @@ class DistributedCheckFiniteAndUnscaleImpl(DistributedOperatorImpl):
attrs={
"in_dtype": inf_var.dtype,
"out_dtype": inf_var_int32.dtype,
OP_ROLE_KEY: OpRole.Backward
OP_ROLE_KEY: OpRole.Optimize
})
allreduce_op = main_block.append_op(type='c_allreduce_max',
inputs={'X': inf_var_int32},
......@@ -145,7 +145,7 @@ class DistributedCheckFiniteAndUnscaleImpl(DistributedOperatorImpl):
attrs={
'ring_id': group.id,
'use_calc_stream': True,
OP_ROLE_KEY: OpRole.Backward
OP_ROLE_KEY: OpRole.Optimize
})
cast_op2 = main_block.append_op(type='cast',
inputs={'X': inf_var_int32},
......@@ -153,7 +153,7 @@ class DistributedCheckFiniteAndUnscaleImpl(DistributedOperatorImpl):
attrs={
"in_dtype": inf_var_int32.dtype,
"out_dtype": inf_var.dtype,
OP_ROLE_KEY: OpRole.Backward
OP_ROLE_KEY: OpRole.Optimize
})
main_block._sync_with_cpp()
......
......@@ -222,7 +222,10 @@ class DistributedEmbeddingImpl(DistributedOperatorImpl):
'W': [Weight_var]
},
outputs={'Out': [intermediate_var_0]},
attrs={"start_index": relative_idx})
attrs={
"start_index": relative_idx,
OP_ROLE_KEY: src_op.attr('op_role')
})
if intermediate_var_0.shape != ref_shape:
intermediate_var_0.desc.set_shape(ref_shape)
......@@ -235,6 +238,7 @@ class DistributedEmbeddingImpl(DistributedOperatorImpl):
'ring_id': group.id,
'use_calc_stream': True,
'use_model_parallel': True,
OP_ROLE_KEY: src_op.attr('op_role')
})
if Out_var.shape != ref_shape:
Out_var.desc.set_shape(ref_shape)
......@@ -442,6 +446,7 @@ class DistributedEmbeddingImpl(DistributedOperatorImpl):
dp_group = new_process_group(group_ranks)
if need_gradient_allreduce:
added_ops = []
W_Grad_var = main_block.var(kwargs['W@GRAD'][0])
allreduce_op = main_block.append_op(type='c_allreduce_sum',
inputs={'X': [W_Grad_var]},
......@@ -451,19 +456,24 @@ class DistributedEmbeddingImpl(DistributedOperatorImpl):
'use_calc_stream': True,
OP_ROLE_KEY: OpRole.Backward
})
scale_op = main_block.append_op(type='scale',
inputs={'X': W_Grad_var},
outputs={'Out': W_Grad_var},
attrs={
'scale': 1.0 / dp_degree,
OP_ROLE_KEY: OpRole.Backward
})
added_ops.append(allreduce_op)
if ctx.gradient_scale:
scale_op = main_block.append_op(type='scale',
inputs={'X': W_Grad_var},
outputs={'Out': W_Grad_var},
attrs={
'scale': 1.0 / dp_degree,
OP_ROLE_KEY: OpRole.Backward
})
added_ops.append(scale_op)
main_block._sync_with_cpp()
dims_mapping = ctx.get_tensor_dist_attr_for_program(
W_Grad_var).dims_mapping
process_mesh = dist_attr.process_mesh
for op in [allreduce_op, scale_op]:
for op in added_ops:
op_attr = OperatorDistributedAttribute()
op_attr.process_mesh = process_mesh
op_attr.set_output_dims_mapping(W_Grad_var.name, dims_mapping)
......
......@@ -405,6 +405,7 @@ def _right_operand_parameter_matmul_backward(ctx, *args, **kwargs):
dp_group = new_process_group(group_ranks)
if need_gradient_allreduce and is_parameter_related(Y_var.name, main_block):
added_ops = []
Y_Grad_var = main_block.var(kwargs['Y@GRAD'][0])
allreduce_op = main_block.append_op(type='c_allreduce_sum',
inputs={'X': [Y_Grad_var]},
......@@ -414,19 +415,24 @@ def _right_operand_parameter_matmul_backward(ctx, *args, **kwargs):
'use_calc_stream': True,
OP_ROLE_KEY: OpRole.Backward
})
scale_op = main_block.append_op(type='scale',
inputs={'X': Y_Grad_var},
outputs={'Out': Y_Grad_var},
attrs={
'scale': 1.0 / dp_degree,
OP_ROLE_KEY: OpRole.Backward
})
added_ops.append(allreduce_op)
if ctx.gradient_scale:
scale_op = main_block.append_op(type='scale',
inputs={'X': Y_Grad_var},
outputs={'Out': Y_Grad_var},
attrs={
'scale': 1.0 / dp_degree,
OP_ROLE_KEY: OpRole.Backward
})
added_ops.append(scale_op)
main_block._sync_with_cpp()
dims_mapping = ctx.get_tensor_dist_attr_for_program(
Y_Grad_var).dims_mapping
process_mesh = dist_attr.process_mesh
for op in [allreduce_op, scale_op]:
for op in added_ops:
op_attr = OperatorDistributedAttribute()
op_attr.process_mesh = process_mesh
op_attr.set_output_dims_mapping(Y_Grad_var.name, dims_mapping)
......@@ -617,6 +623,7 @@ class DistributedMatmulImpl0(DistributedOperatorImpl):
'ring_id': group.id,
'use_calc_stream': True,
'use_model_parallel': True,
OP_ROLE_KEY: src_op.attr('op_role')
})
if intermediate_var_0.shape != ref_shape_x:
intermediate_var_0.desc.set_shape(ref_shape_x)
......@@ -629,6 +636,7 @@ class DistributedMatmulImpl0(DistributedOperatorImpl):
'transpose_X': False,
'transpose_Y': False,
'alpha': 1,
OP_ROLE_KEY: src_op.attr('op_role')
}
inputs = {'X': [intermediate_var_0], 'Y': [Weight_var]}
matmul_op = main_block.append_op(type='matmul',
......@@ -814,6 +822,7 @@ class DistributedMatmulImpl1(DistributedOperatorImpl):
'transpose_X': False,
'transpose_Y': False,
'alpha': 1,
OP_ROLE_KEY: src_op.attr('op_role')
}
inputs = {'X': X_var, 'Y': Weight_var}
......@@ -853,7 +862,8 @@ class DistributedMatmulImpl1(DistributedOperatorImpl):
attrs={
'ring_id': group.id,
'use_calc_stream': True,
'use_model_parallel': True
'use_model_parallel': True,
OP_ROLE_KEY: src_op.attr('op_role')
})
if Out_var.shape != ref_shape:
Out_var.desc.set_shape(ref_shape)
......@@ -1137,6 +1147,7 @@ class DistributedMatmulV2Impl0(DistributedOperatorImpl):
'ring_id': group.id,
'use_calc_stream': True,
'use_model_parallel': True,
OP_ROLE_KEY: src_op.attr('op_role'),
})
if intermediate_var_0.shape != ref_shape_x:
intermediate_var_0.desc.set_shape(ref_shape_x)
......@@ -1145,7 +1156,11 @@ class DistributedMatmulV2Impl0(DistributedOperatorImpl):
['float16', 'float32', 'float64'], 'linear')
check_dtype(intermediate_var_0.dtype, 'dtype',
['float16', 'float32', 'float64'], 'linear')
attrs = {'trans_x': False, 'trans_y': False}
attrs = {
'trans_x': False,
'trans_y': False,
OP_ROLE_KEY: src_op.attr('op_role')
}
inputs = {'X': [intermediate_var_0], 'Y': [Weight_var]}
matmul_v2_op = main_block.append_op(type='matmul_v2',
inputs=inputs,
......@@ -1322,7 +1337,11 @@ class DistributedMatmulV2Impl1(DistributedOperatorImpl):
'linear')
check_dtype(X_var.dtype, 'dtype', ['float16', 'float32', 'float64'],
'linear')
attrs = {'trans_x': False, 'trans_y': False}
attrs = {
'trans_x': False,
'trans_y': False,
OP_ROLE_KEY: src_op.attr('op_role')
}
inputs = {'X': X_var, 'Y': Weight_var}
# infer out var shape with op dist attr
......@@ -1361,7 +1380,8 @@ class DistributedMatmulV2Impl1(DistributedOperatorImpl):
attrs={
'ring_id': group.id,
'use_calc_stream': True,
'use_model_parallel': True
'use_model_parallel': True,
OP_ROLE_KEY: src_op.attr('op_role')
})
if Out_var.shape != ref_shape:
Out_var.desc.set_shape(ref_shape)
......@@ -1646,6 +1666,7 @@ class DistributedMulImpl0(DistributedOperatorImpl):
'ring_id': group.id,
'use_calc_stream': True,
'use_model_parallel': True,
OP_ROLE_KEY: src_op.attr('op_role')
})
if intermediate_var_0.shape != ref_shape_x:
intermediate_var_0.desc.set_shape(ref_shape_x)
......@@ -1657,7 +1678,8 @@ class DistributedMulImpl0(DistributedOperatorImpl):
# attrs = {'trans_x': False, 'trans_y': False}
attrs = {
"x_num_col_dims": src_op.desc.attr("x_num_col_dims"),
"y_num_col_dims": src_op.desc.attr("y_num_col_dims")
"y_num_col_dims": src_op.desc.attr("y_num_col_dims"),
OP_ROLE_KEY: src_op.attr('op_role')
}
inputs = {'X': [intermediate_var_0], 'Y': [Weight_var]}
mul_op = main_block.append_op(type='mul',
......@@ -1838,7 +1860,8 @@ class DistributedMulImpl1(DistributedOperatorImpl):
# attrs = {'trans_x': False, 'trans_y': False}
attrs = {
"x_num_col_dims": src_op.desc.attr("x_num_col_dims"),
"y_num_col_dims": src_op.desc.attr("y_num_col_dims")
"y_num_col_dims": src_op.desc.attr("y_num_col_dims"),
OP_ROLE_KEY: src_op.attr('op_role')
}
inputs = {'X': X_var, 'Y': Weight_var}
......@@ -1878,7 +1901,8 @@ class DistributedMulImpl1(DistributedOperatorImpl):
attrs={
'ring_id': group.id,
'use_calc_stream': True,
'use_model_parallel': True
'use_model_parallel': True,
OP_ROLE_KEY: src_op.attr('op_role')
})
if Out_var.shape != ref_shape:
Out_var.desc.set_shape(ref_shape)
......
......@@ -264,10 +264,12 @@ class Partitioner(object):
self._dist_context, **kinputs, **koutputs,
**{"grad_var_to_var": grad_var_to_var})
elif is_optimize_op(op):
# NOTE: BACKWARD_ONLY_DIST_OPS's op_role must be 2 because of the 1F1B pass
kinputs, koutputs = dist_op_context.prepare_context(op)
dist_op_impl = get_distributed_operator_impl_container(
"default").get_impl(0)
dist_op_impl.backward(self._dist_context, **kinputs, **koutputs)
dist_op_opt_impl = _get_dist_op_backward_implement(
op, self._dist_context, forward_op_id2forward_op)
dist_op_opt_impl.backward(self._dist_context, **kinputs,
**koutputs)
else:
raise NotImplementedError(
"partitioner only support forward and backward, optimize ops, but got {}"
......
......@@ -1065,7 +1065,7 @@ def set_grad_var_shape(program, dist_context):
"softmax", "cross_entropy2", "dropout", "tanh",
["slice_grad", "c_allgather"], "assign", "matmul_v2_grad_grad",
"elementwise_add_grad_grad", "shape", "sqrt",
"fused_softmax_mask_upper_triangle_grad"
"fused_softmax_mask_upper_triangle"
]
if op.type in need_set_shape_list:
for forward_op in block.ops:
......@@ -1096,11 +1096,9 @@ OpRole = core.op_proto_and_checker_maker.OpRole
def is_forward_op(op):
ref_role1 = int(core.op_proto_and_checker_maker.OpRole.Forward)
ref_role2 = int(core.op_proto_and_checker_maker.OpRole.Loss)
op_role = int(op.attr('op_role'))
return OP_ROLE_KEY in op.attr_names and (op_role == ref_role1
or op_role == ref_role2)
return OP_ROLE_KEY in op.attr_names and (op_role == int(OpRole.Forward)
or op_role == int(OpRole.Loss))
def is_backward_op(op):
......@@ -1113,9 +1111,14 @@ def is_optimize_op(op):
int(op.all_attrs()[OP_ROLE_KEY]) & int(OpRole.Optimize)
def is_lr_sched_op(op):
return OP_ROLE_KEY in op.attr_names and \
int(op.all_attrs()[OP_ROLE_KEY]) & int(OpRole.Optimize.LRSched)
def is_loss_op(op):
return OP_ROLE_KEY in op.attr_names and \
int(op.all_attrs()[OP_ROLE_KEY]) == (int(core.op_proto_and_checker_maker.OpRole.Forward) | int(core.op_proto_and_checker_maker.OpRole.Loss))
int(op.all_attrs()[OP_ROLE_KEY]) == (int(OpRole.Forward) | int(OpRole.Loss))
def is_prim_op(op):
......
......@@ -452,7 +452,7 @@ def _check_and_update_gradient(params_grads, loss_scaling, dist_context):
inputs = {'X': grads, 'Scale': loss_scaling}
outputs = {'Out': grads, 'FoundInfinite': found_inf}
attrs = {'op_role': OpRole.Backward}
attrs = {'op_role': OpRole.Optimize}
new_op = main_block.append_op(type='check_finite_and_unscale',
inputs=inputs,
outputs=outputs,
......@@ -732,7 +732,7 @@ class AMPPass(PassBase):
'incr_ratio': self.get_attr("incr_ratio"),
'decr_ratio': self.get_attr("decr_ratio"),
'stop_update': self.get_attr("stop_update"),
'op_role': OpRole.Backward
'op_role': OpRole.Optimize
}
new_op = main_block.append_op(type='update_loss_scaling',
......
......@@ -21,20 +21,13 @@ from paddle.framework import core
from paddle.fluid import layers
from paddle.fluid.framework import program_guard, device_guard
from .pass_base import PassBase, PassType, register_pass
from paddle.distributed.fleet.meta_optimizers.common import OpRole
from paddle.distributed.auto_parallel.utils import set_var_dist_attr
from paddle.distributed.auto_parallel.utils import set_var_dist_attr, is_optimize_op, OpRole, OP_ROLE_KEY
from paddle.distributed.auto_parallel.utils import naive_set_dist_op_attr_for_program_by_mesh_and_mapping
from paddle.distributed.auto_parallel.process_group import get_world_process_group
world_process_group = get_world_process_group()
def _is_the_optimizer_op(op):
OP_ROLE_KEY = core.op_proto_and_checker_maker.kOpRoleAttrName()
return OP_ROLE_KEY in op.attr_names and \
int(op.all_attrs()[OP_ROLE_KEY]) & int(OpRole.Optimize)
def _remove_and_get_optimizer_op(main_program, dist_context):
# 1 create tmp block
# 2 mv optimizer op from global program to tmp block
......@@ -43,9 +36,8 @@ def _remove_and_get_optimizer_op(main_program, dist_context):
temp_block = main_program._create_block()
removed_op_idx = []
optimize_ops_desc = []
skip_ops = ["increment", "elementwise_mod", "equal"]
for idx, op in enumerate(main_block.ops):
if _is_the_optimizer_op(op) and op.type not in skip_ops:
if is_optimize_op(op):
# append optimizer op to tmp block
new_op_desc = temp_block.desc.append_op()
new_op_desc.copy_from(op.desc)
......@@ -57,7 +49,8 @@ def _remove_and_get_optimizer_op(main_program, dist_context):
dist_context.del_dist_op_for_program(op)
for idx in removed_op_idx[::-1]:
main_block._remove_op(idx)
main_block._remove_op(idx, sync=False)
main_block._sync_with_cpp()
return optimize_ops_desc
......@@ -109,7 +102,7 @@ def _get_gm_cond_var(main_program, k_steps, dist_context):
outputs={'Out': [step_var]},
attrs={
'step': float(1.0),
'op_role': OpRole.Optimize
OP_ROLE_KEY: OpRole.Backward
})
naive_set_dist_op_attr_for_program_by_mesh_and_mapping(
increment_op, world_process_group.ranks, [-1], dist_context)
......@@ -123,7 +116,8 @@ def _get_gm_cond_var(main_program, k_steps, dist_context):
attrs={
'axis': -1,
'use_mkldnn': False,
'op_role': OpRole.Optimize
OP_ROLE_KEY:
OpRole.Backward
})
naive_set_dist_op_attr_for_program_by_mesh_and_mapping(
elementwise_mod_op, world_process_group.ranks, [-1], dist_context)
......@@ -134,7 +128,7 @@ def _get_gm_cond_var(main_program, k_steps, dist_context):
'Y': zero_var
},
outputs={'Out': cond_var},
attrs={'op_role': OpRole.Optimize})
attrs={OP_ROLE_KEY: OpRole.Backward})
naive_set_dist_op_attr_for_program_by_mesh_and_mapping(
equal_op, world_process_group.ranks, [-1], dist_context)
......@@ -143,7 +137,6 @@ def _get_gm_cond_var(main_program, k_steps, dist_context):
def _append_gradient_merge_backward_op(
main_program, startup_program, params_grads: List[Tuple[Any, Any]],
cond_var_name: str,
dist_context) -> Tuple[List[Tuple[Any, Any]], Dict[str, Any]]:
main_block = main_program.global_block()
startup_block = startup_program.global_block()
......@@ -201,7 +194,7 @@ def _append_gradient_merge_backward_op(
attrs={
'axis': -1,
'use_mkldnn': False,
'op_role': OpRole.Optimize
OP_ROLE_KEY: OpRole.Backward
})
new_params_to_grads.append([param, gradient_merge_var])
grad_to_gradient_merge[grad.name] = gradient_merge_var.name
......@@ -233,8 +226,7 @@ def _create_cond_block_and_update_optimizer(
'bias': 0.0,
'bias_after_scale': False
})
new_grad.op._set_attr(op_maker.kOpRoleAttrName(),
OpRole.Optimize)
new_grad.op._set_attr(OP_ROLE_KEY, OpRole.Optimize)
# append optimizer ops
for op_desc in optimize_ops_desc:
......@@ -272,29 +264,27 @@ def _create_cond_block_and_update_optimizer(
dtype=new_grad.dtype,
value=0.0,
out=new_grad)
new_grad.op._set_attr(op_maker.kOpRoleAttrName(),
op_maker.OpRole.Optimize)
new_grad.op._set_attr(OP_ROLE_KEY, op_maker.OpRole.Optimize)
layers.cond(cond_var, true_fn=true_apply_gradient, false_fn=None)
cond_op = main_program.global_block().ops[-1]
cond_op._set_attr('op_role', OpRole.Optimize)
cond_op._set_attr(OP_ROLE_KEY, OpRole.Optimize)
def parse_program(main_program, startup_program, params_grads, k_steps, avg,
dist_context):
# 1 create gradient_merge_cond
cond_var = _get_gm_cond_var(main_program, k_steps, dist_context)
# 2 remove optimizer_op from main_program
# 1 remove optimizer_op from main_program
optimize_ops_desc = _remove_and_get_optimizer_op(main_program, dist_context)
# back to block 0
main_program._rollback()
# 3 append gradient merge backward op to main_program
# 2 append gradient merge backward op to main_program
new_params_to_grads, grad_to_gradient_merge = _append_gradient_merge_backward_op(
main_program, startup_program, params_grads, cond_var.name,
dist_context)
main_program, startup_program, params_grads, dist_context)
# 3 create gradient_merge_cond
cond_var = _get_gm_cond_var(main_program, k_steps, dist_context)
# 4 create ConditionalBlock and append gradient merge optimizer ops
_create_cond_block_and_update_optimizer(main_program, cond_var,
......
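The condition variable built by _get_gm_cond_var above is just a k-step counter; below is a self-contained sketch of its semantics (plain-Python stand-ins for the static-graph vars, not the pass itself).

def gm_should_step(step, k_steps):
    # Mirrors the ops above: `increment` adds 1, `elementwise_mod` reduces the
    # counter modulo k_steps, and `equal` yields cond_var; all three are now
    # tagged OpRole.Backward so they stay in the backward stage.
    step = (step + 1) % k_steps
    return step, step == 0

# With k_steps = 4, the optimizer ops replayed inside layers.cond (tagged
# OpRole.Optimize) run on micro-steps 4, 8, ...; on the other steps only the
# OpRole.Backward gradient-accumulation ops execute.
step = 0
for micro_step in range(1, 9):
    step, do_update = gm_should_step(step, 4)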