Unverified commit db2c71a4, authored by zhaoyingli, committed by GitHub

[AutoParallel] fix 'op_role' for gradient merge & recompute (#44138)

* fix op_role

* fix engine

* update op_role
Parent 7e3833a7
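Background for these changes: every operator in a Paddle static Program carries an 'op_role' attribute (OP_ROLE_KEY), and graph passes such as AMP, recompute, and gradient merge decide how to treat each op by that role. This commit fixes the role assigned to several ops that the auto-parallel passes insert. A minimal sketch of how the attribute can be inspected; the toy network and SGD optimizer below are illustrative only and are not part of this change:

import paddle
from paddle.fluid import core

paddle.enable_static()

OP_ROLE_KEY = core.op_proto_and_checker_maker.kOpRoleAttrName()

main_prog, startup_prog = paddle.static.Program(), paddle.static.Program()
with paddle.static.program_guard(main_prog, startup_prog):
    x = paddle.static.data(name='x', shape=[4, 8], dtype='float32')
    w = paddle.static.create_parameter(shape=[8, 2], dtype='float32')
    loss = paddle.mean(paddle.matmul(x, w))
    paddle.optimizer.SGD(learning_rate=0.1).minimize(loss)

# Forward ops come first, then the loss op, the backward (grad) ops, and the
# optimizer ops; each one records its role under OP_ROLE_KEY.
for op in main_prog.global_block().ops:
    role = int(op.attr(OP_ROLE_KEY)) if OP_ROLE_KEY in op.attr_names else None
    print(op.type, role)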
...@@ -18,7 +18,6 @@ from collections import defaultdict ...@@ -18,7 +18,6 @@ from collections import defaultdict
import paddle import paddle
import paddle.utils as utils import paddle.utils as utils
import paddle.distributed.auto_parallel as auto
from paddle import fluid, static from paddle import fluid, static
from paddle.io import Dataset from paddle.io import Dataset
...@@ -72,7 +71,6 @@ class Engine: ...@@ -72,7 +71,6 @@ class Engine:
self._saver = DistributedSaver() self._saver = DistributedSaver()
self._logger = get_logger(logging.INFO) self._logger = get_logger(logging.INFO)
self._default_strategy = None
self._orig_main_prog = static.default_main_program() self._orig_main_prog = static.default_main_program()
self._orig_startup_prog = static.default_startup_program() self._orig_startup_prog = static.default_startup_program()
self._orig_dist_context = get_default_distributed_context() self._orig_dist_context = get_default_distributed_context()
...@@ -117,9 +115,11 @@ class Engine: ...@@ -117,9 +115,11 @@ class Engine:
self._planned_mode = None self._planned_mode = None
self._modes = ['train', 'eval', 'predict'] self._modes = ['train', 'eval', 'predict']
self._build()
# Do auto parallel process # Build program and do auto parallel process
for mode in self._modes:
# Build forward program
self._build(mode)
for mode in self._modes: for mode in self._modes:
# Do the planning process # Do the planning process
self._plan(mode) self._plan(mode)
...@@ -129,56 +129,49 @@ class Engine: ...@@ -129,56 +129,49 @@ class Engine:
# Init comm and startup program # Init comm and startup program
self._initialize(mode) self._initialize(mode)
def _build(self): def _build(self, mode):
for mode in self._modes:
serial_main_prog = self._serial_main_progs.get(mode, None) serial_main_prog = self._serial_main_progs.get(mode, None)
if serial_main_prog is not None: if serial_main_prog is not None:
return return
losses = [] losses = []
metrics = [] metrics = []
serial_main_prog = self._orig_main_prog.clone() serial_main_prog = self._orig_main_prog.clone()
serial_startup_prog = self._orig_startup_prog.clone() serial_startup_prog = self._orig_startup_prog.clone()
with static.program_guard(serial_main_prog, serial_startup_prog), \ with static.program_guard(serial_main_prog, serial_startup_prog), \
utils.unique_name.guard(): utils.unique_name.guard():
inputs_spec = self.inputs_spec inputs_spec = self.inputs_spec
labels_spec = self.labels_spec if self.labels_spec else [] labels_spec = self.labels_spec if self.labels_spec else []
inputs = [s._create_feed_layer() for s in inputs_spec] inputs = [s._create_feed_layer() for s in inputs_spec]
labels = [s._create_feed_layer() for s in labels_spec] labels = [s._create_feed_layer() for s in labels_spec]
outputs = to_list(self.model(*inputs)) outputs = to_list(self.model(*inputs))
if mode != "predict" and self._loss: if mode != "predict" and self._loss:
losses = to_list(self._loss(*(outputs + labels))) losses = to_list(self._loss(*(outputs + labels)))
if mode != "predict": if mode != "predict":
for metric in self._metrics: for metric in self._metrics:
metrics.extend( metrics.extend(to_list(metric.compute(*(outputs + labels))))
to_list(metric.compute(*(outputs + labels))))
default_ctx = get_default_distributed_context()
default_ctx = get_default_distributed_context() if not default_ctx.has_annotation:
if not default_ctx.has_annotation or self._default_strategy: # We build the world process group because the data parallel
# We build the world process group because the data parallel # needs all ranks by default.
# needs all ranks by default. new_process_group(list(range(self._nranks)))
new_process_group(list(range(self._nranks))) default_ctx.data_parallel = True
default_ctx.data_parallel = True
feed_vars = {"inputs": inputs, "labels": labels}
# self._feed_vars[mode] = {"inputs": inputs, "labels": labels}
feed_vars = {"inputs": inputs, "labels": labels} fetch_vars = {
"outputs": flatten(outputs),
# self._fetch_vars[mode] = { "loss": losses,
# "outputs": flatten(outputs), "metrics": metrics
# "loss": losses, }
# "metrics": metrics
# } self._dist_contexts[mode] = DistributedContext(
fetch_vars = { serial_main_prog, serial_startup_prog, self._optimizer, losses,
"outputs": flatten(outputs), feed_vars, fetch_vars, self.cluster, self.strategy)
"loss": losses, self._dist_contexts[mode].gradient_scale = self._gradient_scale
"metrics": metrics
}
self._dist_contexts[mode] = DistributedContext(
serial_main_prog, serial_startup_prog, self._optimizer, losses,
feed_vars, fetch_vars, self.cluster, self.strategy)
self._dist_contexts[mode].gradient_scale = self._gradient_scale
def _plan(self, mode): def _plan(self, mode):
if self._planned_mode is None: if self._planned_mode is None:
...@@ -240,7 +233,6 @@ class Engine: ...@@ -240,7 +233,6 @@ class Engine:
continue continue
process_group.instantiate() process_group.instantiate()
# initialize
self._place = _get_device() self._place = _get_device()
if isinstance(self._place, fluid.CUDAPlace): if isinstance(self._place, fluid.CUDAPlace):
self._place = fluid.CUDAPlace(ParallelEnv().dev_id) self._place = fluid.CUDAPlace(ParallelEnv().dev_id)
...@@ -273,8 +265,8 @@ class Engine: ...@@ -273,8 +265,8 @@ class Engine:
train_dataloader = self._create_dataloader(train_data, batch_size, train_dataloader = self._create_dataloader(train_data, batch_size,
epochs, steps_per_epoch) epochs, steps_per_epoch)
usr_fetch = self._to_map_fetch(fetches) usr_fetch = self._validate_fetches(fetches)
fetch_loss = self._inner_fetch(self.fetch_vars["loss"]) fetch_loss = self._validate_fetches(self.fetch_vars["loss"])
fetch_list, fetch_map = self._fetch_map(fetch_loss, usr_fetch) fetch_list, fetch_map = self._fetch_map(fetch_loss, usr_fetch)
for epoch in range(epochs): for epoch in range(epochs):
...@@ -292,8 +284,7 @@ class Engine: ...@@ -292,8 +284,7 @@ class Engine:
user_outs = outs[len(fetch_loss):] user_outs = outs[len(fetch_loss):]
user_fetch_list = fetch_list[len(fetch_loss):] user_fetch_list = fetch_list[len(fetch_loss):]
for i, out in enumerate(user_outs): for i, out in enumerate(user_outs):
train_logs["train_" + train_logs["train_" + fetch_map[user_fetch_list[i]]] = out
fetch_map[user_fetch_list[i]]] = out[0]
self._logger.info(train_logs) self._logger.info(train_logs)
def evaluate(self, def evaluate(self,
...@@ -307,9 +298,9 @@ class Engine: ...@@ -307,9 +298,9 @@ class Engine:
"eval model is not ready, please call `engine.prepare()` first." "eval model is not ready, please call `engine.prepare()` first."
eval_dataloader = self._create_dataloader(eval_data, batch_size) eval_dataloader = self._create_dataloader(eval_data, batch_size)
usr_fetch = self._to_map_fetch(fetches) usr_fetch = self._validate_fetches(fetches)
fetch_loss = self._inner_fetch(self.fetch_vars["loss"]) fetch_loss = self._validate_fetches(self.fetch_vars["loss"])
fetch_metrics = self._inner_fetch(self.fetch_vars["metrics"]) fetch_metrics = self._validate_fetches(self.fetch_vars["metrics"])
inner_fetch = dict(fetch_loss, **fetch_metrics) inner_fetch = dict(fetch_loss, **fetch_metrics)
fetch_list, fetch_map = self._fetch_map(inner_fetch, usr_fetch) fetch_list, fetch_map = self._fetch_map(inner_fetch, usr_fetch)
...@@ -321,7 +312,7 @@ class Engine: ...@@ -321,7 +312,7 @@ class Engine:
return_numpy=return_numpy) return_numpy=return_numpy)
# inner fetches # inner fetches
if fetch_loss: if fetch_loss:
eval_logs["eval_loss"] = outs[0] eval_logs["eval_loss"] = outs[0][0]
# Metric # Metric
if fetch_metrics: if fetch_metrics:
metric_out = outs[len(fetch_loss):len(inner_fetch)] metric_out = outs[len(fetch_loss):len(inner_fetch)]
...@@ -331,9 +322,9 @@ class Engine: ...@@ -331,9 +322,9 @@ class Engine:
for i, res in enumerate(to_list(results)): for i, res in enumerate(to_list(results)):
eval_logs["eval_" + metric.name()[i]] = res eval_logs["eval_" + metric.name()[i]] = res
# usr fetches # usr fetches
usr_out = outs[len(inner_fetch):] usr_outs = outs[len(inner_fetch):]
usr_fetch_list = fetch_list[len(inner_fetch):] usr_fetch_list = fetch_list[len(inner_fetch):]
for i, out in enumerate(usr_out): for i, out in enumerate(usr_outs):
eval_logs["eval_" + fetch_map[usr_fetch_list[i]]] = out eval_logs["eval_" + fetch_map[usr_fetch_list[i]]] = out
# logger # logger
self._logger.info(eval_logs) self._logger.info(eval_logs)
...@@ -349,8 +340,8 @@ class Engine: ...@@ -349,8 +340,8 @@ class Engine:
"predict model is not ready, please call `engine.prepare()` first." "predict model is not ready, please call `engine.prepare()` first."
test_dataloader = self._create_dataloader(test_data, batch_size) test_dataloader = self._create_dataloader(test_data, batch_size)
usr_fetch = self._to_map_fetch(fetches) usr_fetch = self._validate_fetches(fetches)
fetch_outputs = self._inner_fetch(self.fetch_vars["outputs"]) fetch_outputs = self._validate_fetches(self.fetch_vars["outputs"])
fetch_list, fetch_map = self._fetch_map(fetch_outputs, usr_fetch) fetch_list, fetch_map = self._fetch_map(fetch_outputs, usr_fetch)
outputs = [] outputs = []
...@@ -362,42 +353,11 @@ class Engine: ...@@ -362,42 +353,11 @@ class Engine:
return_numpy=return_numpy) return_numpy=return_numpy)
outputs.append(outs[:len(fetch_outputs)]) outputs.append(outs[:len(fetch_outputs)])
for i, out in enumerate(outs): for i, out in enumerate(outs):
predict_logs["pred_" + fetch_map[fetch_list[i]]] = out[0] predict_logs["pred_" + fetch_map[fetch_list[i]]] = out
self._logger.info(predict_logs) self._logger.info(predict_logs)
return outputs return outputs
def _local_var(self, var):
var_name = _to_name_str(var)
return var_name in self.main_program.global_block().vars
def _to_map_fetch(self, fetches):
if not fetches:
return {}
if isinstance(fetches, dict):
fetch_var_names = list(map(_to_name_str, fetches.values()))
usr_fetches = dict(zip(fetch_var_names, list(fetches.keys())))
elif isinstance(fetches, list):
fetch_var_names = list(map(_to_name_str, fetches))
usr_fetches = dict(zip(fetch_var_names, fetch_var_names))
return dict(filter(lambda x: self._local_var(x[0]),
usr_fetches.items()))
def _inner_fetch(self, fetch_vars):
fetch_list = list(
map(lambda x: x.name, list(filter(self._local_var, fetch_vars))))
inner_fetches = dict(zip(fetch_list, fetch_list))
return inner_fetches
def _fetch_map(self, inner_fetch, usr_fetch):
# replace inner fetch name if usr set for it
for iname in inner_fetch:
if iname in usr_fetch:
inner_fetch[iname] = usr_fetch[iname]
usr_fetch.pop(iname)
fetches = dict(inner_fetch, **usr_fetch)
return list(fetches.keys()), fetches
def _create_dataloader(self, def _create_dataloader(self,
dataset, dataset,
batch_size, batch_size,
...@@ -468,26 +428,35 @@ class Engine: ...@@ -468,26 +428,35 @@ class Engine:
.format(i, spec)) .format(i, spec))
return specs return specs
def _set_data_parallel(self, var): def _is_local_var(self, var):
if self._nranks == 1: var_name = _to_name_str(var)
self._default_strategy = 'serial' return var_name in self.main_program.global_block().vars
auto.shard_tensor(var,
dist_attr={ def _validate_fetches(self, fetches):
"process_mesh": [0], # 1. Check user-defined fetches type
"dims_mapping": # 2. Prepare fetches_dict like {user_defined_name: var_name}
[-1 for _ in range(len(var.shape))] if not fetches:
}) return {}
if isinstance(fetches, dict):
fetch_var_names = list(map(_to_name_str, fetches.values()))
fetches_dict = dict(zip(fetch_var_names, list(fetches.keys())))
elif isinstance(fetches, list):
fetch_var_names = list(map(_to_name_str, fetches))
fetches_dict = dict(zip(fetch_var_names, fetch_var_names))
else: else:
self._default_strategy = 'dp' raise TypeError("'fetches' only support 'dict' and 'list', "
auto.shard_tensor(var, "but got '{}'".format(str(type(fetches))))
dist_attr={ return dict(
"process_mesh": filter(lambda x: self._is_local_var(x[0]), fetches_dict.items()))
list(range(self._nranks)),
"dims_mapping": def _fetch_map(self, inner_fetch, usr_fetch):
[0] + [-1 for _ in range(len(var.shape) - 1)] # replace inner fetch name if usr set for it
}) for iname in inner_fetch:
if iname in usr_fetch:
return var inner_fetch[iname] = usr_fetch[iname]
usr_fetch.pop(iname)
fetches = dict(inner_fetch, **usr_fetch)
return list(fetches.keys()), fetches
def _get_data_parallel_info(self, var, dist_context): def _get_data_parallel_info(self, var, dist_context):
# get data parallel world size and current data parallel rank # get data parallel world size and current data parallel rank
......
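The engine changes above drop _local_var, _to_map_fetch and _inner_fetch in favour of _is_local_var and _validate_fetches, and keep _fetch_map for merging. A framework-free sketch of that bookkeeping, assuming variables are identified directly by name and skipping the _to_name_str and local-variable filtering steps (the variable names below are made up):

from typing import Dict, List, Tuple, Union


def validate_fetches(fetches: Union[Dict[str, str], List[str], None]) -> Dict[str, str]:
    """Normalize user fetches into {var_name: display_name}."""
    if not fetches:
        return {}
    if isinstance(fetches, dict):
        return {var_name: user_name for user_name, var_name in fetches.items()}
    if isinstance(fetches, list):
        return {var_name: var_name for var_name in fetches}
    raise TypeError("'fetches' only support 'dict' and 'list', "
                    "but got '{}'".format(type(fetches)))


def fetch_map(inner_fetch: Dict[str, str],
              usr_fetch: Dict[str, str]) -> Tuple[List[str], Dict[str, str]]:
    """Merge engine-internal fetches (loss/metrics/outputs) with user fetches.

    If the user supplied a display name for a variable the engine already
    fetches, the user's name replaces the internal one and the duplicate user
    entry is dropped, so each variable is fetched exactly once.
    """
    for var_name in inner_fetch:
        if var_name in usr_fetch:
            inner_fetch[var_name] = usr_fetch.pop(var_name)
    merged = dict(inner_fetch, **usr_fetch)
    return list(merged.keys()), merged


# Example: the loss is fetched internally and the user also asks for it as
# "my_loss", plus an extra activation tensor.
usr = validate_fetches({"my_loss": "mean_0.tmp_0", "act": "relu_1.tmp_0"})
fetch_list, mapping = fetch_map({"mean_0.tmp_0": "mean_0.tmp_0"}, usr)
print(fetch_list)  # ['mean_0.tmp_0', 'relu_1.tmp_0']
print(mapping)     # {'mean_0.tmp_0': 'my_loss', 'relu_1.tmp_0': 'act'}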
...@@ -137,7 +137,7 @@ class DistributedCheckFiniteAndUnscaleImpl(DistributedOperatorImpl): ...@@ -137,7 +137,7 @@ class DistributedCheckFiniteAndUnscaleImpl(DistributedOperatorImpl):
attrs={ attrs={
"in_dtype": inf_var.dtype, "in_dtype": inf_var.dtype,
"out_dtype": inf_var_int32.dtype, "out_dtype": inf_var_int32.dtype,
OP_ROLE_KEY: OpRole.Backward OP_ROLE_KEY: OpRole.Optimize
}) })
allreduce_op = main_block.append_op(type='c_allreduce_max', allreduce_op = main_block.append_op(type='c_allreduce_max',
inputs={'X': inf_var_int32}, inputs={'X': inf_var_int32},
...@@ -145,7 +145,7 @@ class DistributedCheckFiniteAndUnscaleImpl(DistributedOperatorImpl): ...@@ -145,7 +145,7 @@ class DistributedCheckFiniteAndUnscaleImpl(DistributedOperatorImpl):
attrs={ attrs={
'ring_id': group.id, 'ring_id': group.id,
'use_calc_stream': True, 'use_calc_stream': True,
OP_ROLE_KEY: OpRole.Backward OP_ROLE_KEY: OpRole.Optimize
}) })
cast_op2 = main_block.append_op(type='cast', cast_op2 = main_block.append_op(type='cast',
inputs={'X': inf_var_int32}, inputs={'X': inf_var_int32},
...@@ -153,7 +153,7 @@ class DistributedCheckFiniteAndUnscaleImpl(DistributedOperatorImpl): ...@@ -153,7 +153,7 @@ class DistributedCheckFiniteAndUnscaleImpl(DistributedOperatorImpl):
attrs={ attrs={
"in_dtype": inf_var_int32.dtype, "in_dtype": inf_var_int32.dtype,
"out_dtype": inf_var.dtype, "out_dtype": inf_var.dtype,
OP_ROLE_KEY: OpRole.Backward OP_ROLE_KEY: OpRole.Optimize
}) })
main_block._sync_with_cpp() main_block._sync_with_cpp()
......
...@@ -222,7 +222,10 @@ class DistributedEmbeddingImpl(DistributedOperatorImpl): ...@@ -222,7 +222,10 @@ class DistributedEmbeddingImpl(DistributedOperatorImpl):
'W': [Weight_var] 'W': [Weight_var]
}, },
outputs={'Out': [intermediate_var_0]}, outputs={'Out': [intermediate_var_0]},
attrs={"start_index": relative_idx}) attrs={
"start_index": relative_idx,
OP_ROLE_KEY: src_op.attr('op_role')
})
if intermediate_var_0.shape != ref_shape: if intermediate_var_0.shape != ref_shape:
intermediate_var_0.desc.set_shape(ref_shape) intermediate_var_0.desc.set_shape(ref_shape)
...@@ -235,6 +238,7 @@ class DistributedEmbeddingImpl(DistributedOperatorImpl): ...@@ -235,6 +238,7 @@ class DistributedEmbeddingImpl(DistributedOperatorImpl):
'ring_id': group.id, 'ring_id': group.id,
'use_calc_stream': True, 'use_calc_stream': True,
'use_model_parallel': True, 'use_model_parallel': True,
OP_ROLE_KEY: src_op.attr('op_role')
}) })
if Out_var.shape != ref_shape: if Out_var.shape != ref_shape:
Out_var.desc.set_shape(ref_shape) Out_var.desc.set_shape(ref_shape)
...@@ -442,6 +446,7 @@ class DistributedEmbeddingImpl(DistributedOperatorImpl): ...@@ -442,6 +446,7 @@ class DistributedEmbeddingImpl(DistributedOperatorImpl):
dp_group = new_process_group(group_ranks) dp_group = new_process_group(group_ranks)
if need_gradient_allreduce: if need_gradient_allreduce:
added_ops = []
W_Grad_var = main_block.var(kwargs['W@GRAD'][0]) W_Grad_var = main_block.var(kwargs['W@GRAD'][0])
allreduce_op = main_block.append_op(type='c_allreduce_sum', allreduce_op = main_block.append_op(type='c_allreduce_sum',
inputs={'X': [W_Grad_var]}, inputs={'X': [W_Grad_var]},
...@@ -451,19 +456,24 @@ class DistributedEmbeddingImpl(DistributedOperatorImpl): ...@@ -451,19 +456,24 @@ class DistributedEmbeddingImpl(DistributedOperatorImpl):
'use_calc_stream': True, 'use_calc_stream': True,
OP_ROLE_KEY: OpRole.Backward OP_ROLE_KEY: OpRole.Backward
}) })
scale_op = main_block.append_op(type='scale', added_ops.append(allreduce_op)
inputs={'X': W_Grad_var},
outputs={'Out': W_Grad_var}, if ctx.gradient_scale:
attrs={ scale_op = main_block.append_op(type='scale',
'scale': 1.0 / dp_degree, inputs={'X': W_Grad_var},
OP_ROLE_KEY: OpRole.Backward outputs={'Out': W_Grad_var},
}) attrs={
'scale': 1.0 / dp_degree,
OP_ROLE_KEY: OpRole.Backward
})
added_ops.append(scale_op)
main_block._sync_with_cpp() main_block._sync_with_cpp()
dims_mapping = ctx.get_tensor_dist_attr_for_program( dims_mapping = ctx.get_tensor_dist_attr_for_program(
W_Grad_var).dims_mapping W_Grad_var).dims_mapping
process_mesh = dist_attr.process_mesh process_mesh = dist_attr.process_mesh
for op in [allreduce_op, scale_op]: for op in added_ops:
op_attr = OperatorDistributedAttribute() op_attr = OperatorDistributedAttribute()
op_attr.process_mesh = process_mesh op_attr.process_mesh = process_mesh
op_attr.set_output_dims_mapping(W_Grad_var.name, dims_mapping) op_attr.set_output_dims_mapping(W_Grad_var.name, dims_mapping)
......
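In the backward branch above, the appended ops are now collected in added_ops and the scale op is emitted only when ctx.gradient_scale is set; both ops keep the Backward role. A simplified, framework-free sketch of that pattern, with plain dicts standing in for real op descriptors and an assumed numeric value for the Backward role:

from typing import Any, Dict, List

OP_ROLE_KEY = "op_role"
OP_ROLE_BACKWARD = 0x0001  # assumed value of OpRole.Backward


def sync_param_grad(block: List[Dict[str, Any]], grad_name: str, ring_id: int,
                    dp_degree: int, gradient_scale: bool) -> List[Dict[str, Any]]:
    """Append a c_allreduce_sum and, optionally, a 1/dp_degree scale op."""
    added_ops = []
    allreduce = {"type": "c_allreduce_sum", "X": grad_name, "Out": grad_name,
                 "ring_id": ring_id, "use_calc_stream": True,
                 OP_ROLE_KEY: OP_ROLE_BACKWARD}
    block.append(allreduce)
    added_ops.append(allreduce)
    if gradient_scale:
        scale = {"type": "scale", "X": grad_name, "Out": grad_name,
                 "scale": 1.0 / dp_degree, OP_ROLE_KEY: OP_ROLE_BACKWARD}
        block.append(scale)
        added_ops.append(scale)
    return added_ops


# Every appended op receives the same distributed attributes afterwards, which
# is why they are collected in one list instead of being hard-coded as a pair.
block: List[Dict[str, Any]] = []
for op in sync_param_grad(block, "w@GRAD", ring_id=0, dp_degree=4,
                          gradient_scale=True):
    op["dims_mapping"] = [-1, -1]        # placeholder for the real dist attr
    op["process_mesh"] = [0, 1, 2, 3]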
...@@ -405,6 +405,7 @@ def _right_operand_parameter_matmul_backward(ctx, *args, **kwargs): ...@@ -405,6 +405,7 @@ def _right_operand_parameter_matmul_backward(ctx, *args, **kwargs):
dp_group = new_process_group(group_ranks) dp_group = new_process_group(group_ranks)
if need_gradient_allreduce and is_parameter_related(Y_var.name, main_block): if need_gradient_allreduce and is_parameter_related(Y_var.name, main_block):
added_ops = []
Y_Grad_var = main_block.var(kwargs['Y@GRAD'][0]) Y_Grad_var = main_block.var(kwargs['Y@GRAD'][0])
allreduce_op = main_block.append_op(type='c_allreduce_sum', allreduce_op = main_block.append_op(type='c_allreduce_sum',
inputs={'X': [Y_Grad_var]}, inputs={'X': [Y_Grad_var]},
...@@ -414,19 +415,24 @@ def _right_operand_parameter_matmul_backward(ctx, *args, **kwargs): ...@@ -414,19 +415,24 @@ def _right_operand_parameter_matmul_backward(ctx, *args, **kwargs):
'use_calc_stream': True, 'use_calc_stream': True,
OP_ROLE_KEY: OpRole.Backward OP_ROLE_KEY: OpRole.Backward
}) })
scale_op = main_block.append_op(type='scale', added_ops.append(allreduce_op)
inputs={'X': Y_Grad_var},
outputs={'Out': Y_Grad_var}, if ctx.gradient_scale:
attrs={ scale_op = main_block.append_op(type='scale',
'scale': 1.0 / dp_degree, inputs={'X': Y_Grad_var},
OP_ROLE_KEY: OpRole.Backward outputs={'Out': Y_Grad_var},
}) attrs={
'scale': 1.0 / dp_degree,
OP_ROLE_KEY: OpRole.Backward
})
added_ops.append(scale_op)
main_block._sync_with_cpp() main_block._sync_with_cpp()
dims_mapping = ctx.get_tensor_dist_attr_for_program( dims_mapping = ctx.get_tensor_dist_attr_for_program(
Y_Grad_var).dims_mapping Y_Grad_var).dims_mapping
process_mesh = dist_attr.process_mesh process_mesh = dist_attr.process_mesh
for op in [allreduce_op, scale_op]: for op in added_ops:
op_attr = OperatorDistributedAttribute() op_attr = OperatorDistributedAttribute()
op_attr.process_mesh = process_mesh op_attr.process_mesh = process_mesh
op_attr.set_output_dims_mapping(Y_Grad_var.name, dims_mapping) op_attr.set_output_dims_mapping(Y_Grad_var.name, dims_mapping)
...@@ -617,6 +623,7 @@ class DistributedMatmulImpl0(DistributedOperatorImpl): ...@@ -617,6 +623,7 @@ class DistributedMatmulImpl0(DistributedOperatorImpl):
'ring_id': group.id, 'ring_id': group.id,
'use_calc_stream': True, 'use_calc_stream': True,
'use_model_parallel': True, 'use_model_parallel': True,
OP_ROLE_KEY: src_op.attr('op_role')
}) })
if intermediate_var_0.shape != ref_shape_x: if intermediate_var_0.shape != ref_shape_x:
intermediate_var_0.desc.set_shape(ref_shape_x) intermediate_var_0.desc.set_shape(ref_shape_x)
...@@ -629,6 +636,7 @@ class DistributedMatmulImpl0(DistributedOperatorImpl): ...@@ -629,6 +636,7 @@ class DistributedMatmulImpl0(DistributedOperatorImpl):
'transpose_X': False, 'transpose_X': False,
'transpose_Y': False, 'transpose_Y': False,
'alpha': 1, 'alpha': 1,
OP_ROLE_KEY: src_op.attr('op_role')
} }
inputs = {'X': [intermediate_var_0], 'Y': [Weight_var]} inputs = {'X': [intermediate_var_0], 'Y': [Weight_var]}
matmul_op = main_block.append_op(type='matmul', matmul_op = main_block.append_op(type='matmul',
...@@ -814,6 +822,7 @@ class DistributedMatmulImpl1(DistributedOperatorImpl): ...@@ -814,6 +822,7 @@ class DistributedMatmulImpl1(DistributedOperatorImpl):
'transpose_X': False, 'transpose_X': False,
'transpose_Y': False, 'transpose_Y': False,
'alpha': 1, 'alpha': 1,
OP_ROLE_KEY: src_op.attr('op_role')
} }
inputs = {'X': X_var, 'Y': Weight_var} inputs = {'X': X_var, 'Y': Weight_var}
...@@ -853,7 +862,8 @@ class DistributedMatmulImpl1(DistributedOperatorImpl): ...@@ -853,7 +862,8 @@ class DistributedMatmulImpl1(DistributedOperatorImpl):
attrs={ attrs={
'ring_id': group.id, 'ring_id': group.id,
'use_calc_stream': True, 'use_calc_stream': True,
'use_model_parallel': True 'use_model_parallel': True,
OP_ROLE_KEY: src_op.attr('op_role')
}) })
if Out_var.shape != ref_shape: if Out_var.shape != ref_shape:
Out_var.desc.set_shape(ref_shape) Out_var.desc.set_shape(ref_shape)
...@@ -1137,6 +1147,7 @@ class DistributedMatmulV2Impl0(DistributedOperatorImpl): ...@@ -1137,6 +1147,7 @@ class DistributedMatmulV2Impl0(DistributedOperatorImpl):
'ring_id': group.id, 'ring_id': group.id,
'use_calc_stream': True, 'use_calc_stream': True,
'use_model_parallel': True, 'use_model_parallel': True,
OP_ROLE_KEY: src_op.attr('op_role'),
}) })
if intermediate_var_0.shape != ref_shape_x: if intermediate_var_0.shape != ref_shape_x:
intermediate_var_0.desc.set_shape(ref_shape_x) intermediate_var_0.desc.set_shape(ref_shape_x)
...@@ -1145,7 +1156,11 @@ class DistributedMatmulV2Impl0(DistributedOperatorImpl): ...@@ -1145,7 +1156,11 @@ class DistributedMatmulV2Impl0(DistributedOperatorImpl):
['float16', 'float32', 'float64'], 'linear') ['float16', 'float32', 'float64'], 'linear')
check_dtype(intermediate_var_0.dtype, 'dtype', check_dtype(intermediate_var_0.dtype, 'dtype',
['float16', 'float32', 'float64'], 'linear') ['float16', 'float32', 'float64'], 'linear')
attrs = {'trans_x': False, 'trans_y': False} attrs = {
'trans_x': False,
'trans_y': False,
OP_ROLE_KEY: src_op.attr('op_role')
}
inputs = {'X': [intermediate_var_0], 'Y': [Weight_var]} inputs = {'X': [intermediate_var_0], 'Y': [Weight_var]}
matmul_v2_op = main_block.append_op(type='matmul_v2', matmul_v2_op = main_block.append_op(type='matmul_v2',
inputs=inputs, inputs=inputs,
...@@ -1322,7 +1337,11 @@ class DistributedMatmulV2Impl1(DistributedOperatorImpl): ...@@ -1322,7 +1337,11 @@ class DistributedMatmulV2Impl1(DistributedOperatorImpl):
'linear') 'linear')
check_dtype(X_var.dtype, 'dtype', ['float16', 'float32', 'float64'], check_dtype(X_var.dtype, 'dtype', ['float16', 'float32', 'float64'],
'linear') 'linear')
attrs = {'trans_x': False, 'trans_y': False} attrs = {
'trans_x': False,
'trans_y': False,
OP_ROLE_KEY: src_op.attr('op_role')
}
inputs = {'X': X_var, 'Y': Weight_var} inputs = {'X': X_var, 'Y': Weight_var}
# infer out var shape with op dist attr # infer out var shape with op dist attr
...@@ -1361,7 +1380,8 @@ class DistributedMatmulV2Impl1(DistributedOperatorImpl): ...@@ -1361,7 +1380,8 @@ class DistributedMatmulV2Impl1(DistributedOperatorImpl):
attrs={ attrs={
'ring_id': group.id, 'ring_id': group.id,
'use_calc_stream': True, 'use_calc_stream': True,
'use_model_parallel': True 'use_model_parallel': True,
OP_ROLE_KEY: src_op.attr('op_role')
}) })
if Out_var.shape != ref_shape: if Out_var.shape != ref_shape:
Out_var.desc.set_shape(ref_shape) Out_var.desc.set_shape(ref_shape)
...@@ -1646,6 +1666,7 @@ class DistributedMulImpl0(DistributedOperatorImpl): ...@@ -1646,6 +1666,7 @@ class DistributedMulImpl0(DistributedOperatorImpl):
'ring_id': group.id, 'ring_id': group.id,
'use_calc_stream': True, 'use_calc_stream': True,
'use_model_parallel': True, 'use_model_parallel': True,
OP_ROLE_KEY: src_op.attr('op_role')
}) })
if intermediate_var_0.shape != ref_shape_x: if intermediate_var_0.shape != ref_shape_x:
intermediate_var_0.desc.set_shape(ref_shape_x) intermediate_var_0.desc.set_shape(ref_shape_x)
...@@ -1657,7 +1678,8 @@ class DistributedMulImpl0(DistributedOperatorImpl): ...@@ -1657,7 +1678,8 @@ class DistributedMulImpl0(DistributedOperatorImpl):
# attrs = {'trans_x': False, 'trans_y': False} # attrs = {'trans_x': False, 'trans_y': False}
attrs = { attrs = {
"x_num_col_dims": src_op.desc.attr("x_num_col_dims"), "x_num_col_dims": src_op.desc.attr("x_num_col_dims"),
"y_num_col_dims": src_op.desc.attr("y_num_col_dims") "y_num_col_dims": src_op.desc.attr("y_num_col_dims"),
OP_ROLE_KEY: src_op.attr('op_role')
} }
inputs = {'X': [intermediate_var_0], 'Y': [Weight_var]} inputs = {'X': [intermediate_var_0], 'Y': [Weight_var]}
mul_op = main_block.append_op(type='mul', mul_op = main_block.append_op(type='mul',
...@@ -1838,7 +1860,8 @@ class DistributedMulImpl1(DistributedOperatorImpl): ...@@ -1838,7 +1860,8 @@ class DistributedMulImpl1(DistributedOperatorImpl):
# attrs = {'trans_x': False, 'trans_y': False} # attrs = {'trans_x': False, 'trans_y': False}
attrs = { attrs = {
"x_num_col_dims": src_op.desc.attr("x_num_col_dims"), "x_num_col_dims": src_op.desc.attr("x_num_col_dims"),
"y_num_col_dims": src_op.desc.attr("y_num_col_dims") "y_num_col_dims": src_op.desc.attr("y_num_col_dims"),
OP_ROLE_KEY: src_op.attr('op_role')
} }
inputs = {'X': X_var, 'Y': Weight_var} inputs = {'X': X_var, 'Y': Weight_var}
...@@ -1878,7 +1901,8 @@ class DistributedMulImpl1(DistributedOperatorImpl): ...@@ -1878,7 +1901,8 @@ class DistributedMulImpl1(DistributedOperatorImpl):
attrs={ attrs={
'ring_id': group.id, 'ring_id': group.id,
'use_calc_stream': True, 'use_calc_stream': True,
'use_model_parallel': True 'use_model_parallel': True,
OP_ROLE_KEY: src_op.attr('op_role')
}) })
if Out_var.shape != ref_shape: if Out_var.shape != ref_shape:
Out_var.desc.set_shape(ref_shape) Out_var.desc.set_shape(ref_shape)
......
...@@ -264,10 +264,12 @@ class Partitioner(object): ...@@ -264,10 +264,12 @@ class Partitioner(object):
self._dist_context, **kinputs, **koutputs, self._dist_context, **kinputs, **koutputs,
**{"grad_var_to_var": grad_var_to_var}) **{"grad_var_to_var": grad_var_to_var})
elif is_optimize_op(op): elif is_optimize_op(op):
# NOTE: BACKWARD_ONLY_DIST_OPS's op_role must be 2 because of the 1F1B pass
kinputs, koutputs = dist_op_context.prepare_context(op) kinputs, koutputs = dist_op_context.prepare_context(op)
dist_op_impl = get_distributed_operator_impl_container( dist_op_opt_impl = _get_dist_op_backward_implement(
"default").get_impl(0) op, self._dist_context, forward_op_id2forward_op)
dist_op_impl.backward(self._dist_context, **kinputs, **koutputs) dist_op_opt_impl.backward(self._dist_context, **kinputs,
**koutputs)
else: else:
raise NotImplementedError( raise NotImplementedError(
"partitioner only support forward and backward, optimize ops, but got {}" "partitioner only support forward and backward, optimize ops, but got {}"
......
...@@ -1065,7 +1065,7 @@ def set_grad_var_shape(program, dist_context): ...@@ -1065,7 +1065,7 @@ def set_grad_var_shape(program, dist_context):
"softmax", "cross_entropy2", "dropout", "tanh", "softmax", "cross_entropy2", "dropout", "tanh",
["slice_grad", "c_allgather"], "assign", "matmul_v2_grad_grad", ["slice_grad", "c_allgather"], "assign", "matmul_v2_grad_grad",
"elementwise_add_grad_grad", "shape", "sqrt", "elementwise_add_grad_grad", "shape", "sqrt",
"fused_softmax_mask_upper_triangle_grad" "fused_softmax_mask_upper_triangle"
] ]
if op.type in need_set_shape_list: if op.type in need_set_shape_list:
for forward_op in block.ops: for forward_op in block.ops:
...@@ -1096,11 +1096,9 @@ OpRole = core.op_proto_and_checker_maker.OpRole ...@@ -1096,11 +1096,9 @@ OpRole = core.op_proto_and_checker_maker.OpRole
def is_forward_op(op): def is_forward_op(op):
ref_role1 = int(core.op_proto_and_checker_maker.OpRole.Forward)
ref_role2 = int(core.op_proto_and_checker_maker.OpRole.Loss)
op_role = int(op.attr('op_role')) op_role = int(op.attr('op_role'))
return OP_ROLE_KEY in op.attr_names and (op_role == ref_role1 return OP_ROLE_KEY in op.attr_names and (op_role == int(OpRole.Forward)
or op_role == ref_role2) or op_role == int(OpRole.Loss))
def is_backward_op(op): def is_backward_op(op):
...@@ -1113,9 +1111,14 @@ def is_optimize_op(op): ...@@ -1113,9 +1111,14 @@ def is_optimize_op(op):
int(op.all_attrs()[OP_ROLE_KEY]) & int(OpRole.Optimize) int(op.all_attrs()[OP_ROLE_KEY]) & int(OpRole.Optimize)
def is_lr_sched_op(op):
return OP_ROLE_KEY in op.attr_names and \
int(op.all_attrs()[OP_ROLE_KEY]) & int(OpRole.Optimize.LRSched)
def is_loss_op(op): def is_loss_op(op):
return OP_ROLE_KEY in op.attr_names and \ return OP_ROLE_KEY in op.attr_names and \
int(op.all_attrs()[OP_ROLE_KEY]) == (int(core.op_proto_and_checker_maker.OpRole.Forward) | int(core.op_proto_and_checker_maker.OpRole.Loss)) int(op.all_attrs()[OP_ROLE_KEY]) == (int(OpRole.Forward) | int(OpRole.Loss))
def is_prim_op(op): def is_prim_op(op):
......
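The helpers above test the op_role bitmask: is_optimize_op and the new is_lr_sched_op use a bitwise AND, while is_loss_op requires exact equality with Forward | Loss. A self-contained sketch of that logic; the numeric constants are assumptions mirroring Paddle's OpRole enum, not values taken from this diff:

# Assumed role values; only the bit arithmetic matters for the predicates.
FORWARD, BACKWARD, OPTIMIZE, LRSCHED, LOSS = 0x0000, 0x0001, 0x0002, 0x0010, 0x0100


def is_optimize_role(role: int) -> bool:
    # any op whose role carries the Optimize bit
    return bool(role & OPTIMIZE)


def is_lr_sched_role(role: int) -> bool:
    # mirrors is_lr_sched_op: any op whose role carries the LRSched bit
    return bool(role & LRSCHED)


def is_loss_role(role: int) -> bool:
    # the loss op is a forward op additionally tagged with the Loss bit,
    # so the check is exact equality rather than a bit test
    return role == (FORWARD | LOSS)


assert is_loss_role(FORWARD | LOSS) and not is_loss_role(FORWARD)
assert is_optimize_role(OPTIMIZE | LRSCHED) and is_lr_sched_role(OPTIMIZE | LRSCHED)
assert not is_optimize_role(BACKWARD)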
...@@ -452,7 +452,7 @@ def _check_and_update_gradient(params_grads, loss_scaling, dist_context): ...@@ -452,7 +452,7 @@ def _check_and_update_gradient(params_grads, loss_scaling, dist_context):
inputs = {'X': grads, 'Scale': loss_scaling} inputs = {'X': grads, 'Scale': loss_scaling}
outputs = {'Out': grads, 'FoundInfinite': found_inf} outputs = {'Out': grads, 'FoundInfinite': found_inf}
attrs = {'op_role': OpRole.Backward} attrs = {'op_role': OpRole.Optimize}
new_op = main_block.append_op(type='check_finite_and_unscale', new_op = main_block.append_op(type='check_finite_and_unscale',
inputs=inputs, inputs=inputs,
outputs=outputs, outputs=outputs,
...@@ -732,7 +732,7 @@ class AMPPass(PassBase): ...@@ -732,7 +732,7 @@ class AMPPass(PassBase):
'incr_ratio': self.get_attr("incr_ratio"), 'incr_ratio': self.get_attr("incr_ratio"),
'decr_ratio': self.get_attr("decr_ratio"), 'decr_ratio': self.get_attr("decr_ratio"),
'stop_update': self.get_attr("stop_update"), 'stop_update': self.get_attr("stop_update"),
'op_role': OpRole.Backward 'op_role': OpRole.Optimize
} }
new_op = main_block.append_op(type='update_loss_scaling', new_op = main_block.append_op(type='update_loss_scaling',
......
...@@ -21,20 +21,13 @@ from paddle.framework import core ...@@ -21,20 +21,13 @@ from paddle.framework import core
from paddle.fluid import layers from paddle.fluid import layers
from paddle.fluid.framework import program_guard, device_guard from paddle.fluid.framework import program_guard, device_guard
from .pass_base import PassBase, PassType, register_pass from .pass_base import PassBase, PassType, register_pass
from paddle.distributed.fleet.meta_optimizers.common import OpRole from paddle.distributed.auto_parallel.utils import set_var_dist_attr, is_optimize_op, OpRole, OP_ROLE_KEY
from paddle.distributed.auto_parallel.utils import set_var_dist_attr
from paddle.distributed.auto_parallel.utils import naive_set_dist_op_attr_for_program_by_mesh_and_mapping from paddle.distributed.auto_parallel.utils import naive_set_dist_op_attr_for_program_by_mesh_and_mapping
from paddle.distributed.auto_parallel.process_group import get_world_process_group from paddle.distributed.auto_parallel.process_group import get_world_process_group
world_process_group = get_world_process_group() world_process_group = get_world_process_group()
def _is_the_optimizer_op(op):
OP_ROLE_KEY = core.op_proto_and_checker_maker.kOpRoleAttrName()
return OP_ROLE_KEY in op.attr_names and \
int(op.all_attrs()[OP_ROLE_KEY]) & int(OpRole.Optimize)
def _remove_and_get_optimizer_op(main_program, dist_context): def _remove_and_get_optimizer_op(main_program, dist_context):
# 1 create tmp block # 1 create tmp block
# 2 mv optimizer op from global program to tmp block # 2 mv optimizer op from global program to tmp block
...@@ -43,9 +36,8 @@ def _remove_and_get_optimizer_op(main_program, dist_context): ...@@ -43,9 +36,8 @@ def _remove_and_get_optimizer_op(main_program, dist_context):
temp_block = main_program._create_block() temp_block = main_program._create_block()
removed_op_idx = [] removed_op_idx = []
optimize_ops_desc = [] optimize_ops_desc = []
skip_ops = ["increment", "elementwise_mod", "equal"]
for idx, op in enumerate(main_block.ops): for idx, op in enumerate(main_block.ops):
if _is_the_optimizer_op(op) and op.type not in skip_ops: if is_optimize_op(op):
# append optimizer op to tmp block # append optimizer op to tmp block
new_op_desc = temp_block.desc.append_op() new_op_desc = temp_block.desc.append_op()
new_op_desc.copy_from(op.desc) new_op_desc.copy_from(op.desc)
...@@ -57,7 +49,8 @@ def _remove_and_get_optimizer_op(main_program, dist_context): ...@@ -57,7 +49,8 @@ def _remove_and_get_optimizer_op(main_program, dist_context):
dist_context.del_dist_op_for_program(op) dist_context.del_dist_op_for_program(op)
for idx in removed_op_idx[::-1]: for idx in removed_op_idx[::-1]:
main_block._remove_op(idx) main_block._remove_op(idx, sync=False)
main_block._sync_with_cpp()
return optimize_ops_desc return optimize_ops_desc
...@@ -109,7 +102,7 @@ def _get_gm_cond_var(main_program, k_steps, dist_context): ...@@ -109,7 +102,7 @@ def _get_gm_cond_var(main_program, k_steps, dist_context):
outputs={'Out': [step_var]}, outputs={'Out': [step_var]},
attrs={ attrs={
'step': float(1.0), 'step': float(1.0),
'op_role': OpRole.Optimize OP_ROLE_KEY: OpRole.Backward
}) })
naive_set_dist_op_attr_for_program_by_mesh_and_mapping( naive_set_dist_op_attr_for_program_by_mesh_and_mapping(
increment_op, world_process_group.ranks, [-1], dist_context) increment_op, world_process_group.ranks, [-1], dist_context)
...@@ -123,7 +116,8 @@ def _get_gm_cond_var(main_program, k_steps, dist_context): ...@@ -123,7 +116,8 @@ def _get_gm_cond_var(main_program, k_steps, dist_context):
attrs={ attrs={
'axis': -1, 'axis': -1,
'use_mkldnn': False, 'use_mkldnn': False,
'op_role': OpRole.Optimize OP_ROLE_KEY:
OpRole.Backward
}) })
naive_set_dist_op_attr_for_program_by_mesh_and_mapping( naive_set_dist_op_attr_for_program_by_mesh_and_mapping(
elementwise_mod_op, world_process_group.ranks, [-1], dist_context) elementwise_mod_op, world_process_group.ranks, [-1], dist_context)
...@@ -134,7 +128,7 @@ def _get_gm_cond_var(main_program, k_steps, dist_context): ...@@ -134,7 +128,7 @@ def _get_gm_cond_var(main_program, k_steps, dist_context):
'Y': zero_var 'Y': zero_var
}, },
outputs={'Out': cond_var}, outputs={'Out': cond_var},
attrs={'op_role': OpRole.Optimize}) attrs={OP_ROLE_KEY: OpRole.Backward})
naive_set_dist_op_attr_for_program_by_mesh_and_mapping( naive_set_dist_op_attr_for_program_by_mesh_and_mapping(
equal_op, world_process_group.ranks, [-1], dist_context) equal_op, world_process_group.ranks, [-1], dist_context)
...@@ -143,7 +137,6 @@ def _get_gm_cond_var(main_program, k_steps, dist_context): ...@@ -143,7 +137,6 @@ def _get_gm_cond_var(main_program, k_steps, dist_context):
def _append_gradient_merge_backward_op( def _append_gradient_merge_backward_op(
main_program, startup_program, params_grads: List[Tuple[Any, Any]], main_program, startup_program, params_grads: List[Tuple[Any, Any]],
cond_var_name: str,
dist_context) -> Tuple[List[Tuple[Any, Any]], Dict[str, Any]]: dist_context) -> Tuple[List[Tuple[Any, Any]], Dict[str, Any]]:
main_block = main_program.global_block() main_block = main_program.global_block()
startup_block = startup_program.global_block() startup_block = startup_program.global_block()
...@@ -201,7 +194,7 @@ def _append_gradient_merge_backward_op( ...@@ -201,7 +194,7 @@ def _append_gradient_merge_backward_op(
attrs={ attrs={
'axis': -1, 'axis': -1,
'use_mkldnn': False, 'use_mkldnn': False,
'op_role': OpRole.Optimize OP_ROLE_KEY: OpRole.Backward
}) })
new_params_to_grads.append([param, gradient_merge_var]) new_params_to_grads.append([param, gradient_merge_var])
grad_to_gradient_merge[grad.name] = gradient_merge_var.name grad_to_gradient_merge[grad.name] = gradient_merge_var.name
...@@ -233,8 +226,7 @@ def _create_cond_block_and_update_optimizer( ...@@ -233,8 +226,7 @@ def _create_cond_block_and_update_optimizer(
'bias': 0.0, 'bias': 0.0,
'bias_after_scale': False 'bias_after_scale': False
}) })
new_grad.op._set_attr(op_maker.kOpRoleAttrName(), new_grad.op._set_attr(OP_ROLE_KEY, OpRole.Optimize)
OpRole.Optimize)
# append optimizer ops # append optimizer ops
for op_desc in optimize_ops_desc: for op_desc in optimize_ops_desc:
...@@ -272,29 +264,27 @@ def _create_cond_block_and_update_optimizer( ...@@ -272,29 +264,27 @@ def _create_cond_block_and_update_optimizer(
dtype=new_grad.dtype, dtype=new_grad.dtype,
value=0.0, value=0.0,
out=new_grad) out=new_grad)
new_grad.op._set_attr(op_maker.kOpRoleAttrName(), new_grad.op._set_attr(OP_ROLE_KEY, op_maker.OpRole.Optimize)
op_maker.OpRole.Optimize)
layers.cond(cond_var, true_fn=true_apply_gradient, false_fn=None) layers.cond(cond_var, true_fn=true_apply_gradient, false_fn=None)
cond_op = main_program.global_block().ops[-1] cond_op = main_program.global_block().ops[-1]
cond_op._set_attr('op_role', OpRole.Optimize) cond_op._set_attr(OP_ROLE_KEY, OpRole.Optimize)
def parse_program(main_program, startup_program, params_grads, k_steps, avg, def parse_program(main_program, startup_program, params_grads, k_steps, avg,
dist_context): dist_context):
# 1 create gradient_merge_cond # 1 remove optimizer_op from main_program
cond_var = _get_gm_cond_var(main_program, k_steps, dist_context)
# 2 remove optimizer_op from main_program
optimize_ops_desc = _remove_and_get_optimizer_op(main_program, dist_context) optimize_ops_desc = _remove_and_get_optimizer_op(main_program, dist_context)
# back to block 0 # back to block 0
main_program._rollback() main_program._rollback()
# 3 append gradient merge backward op to main_program # 2 append gradient merge backward op to main_program
new_params_to_grads, grad_to_gradient_merge = _append_gradient_merge_backward_op( new_params_to_grads, grad_to_gradient_merge = _append_gradient_merge_backward_op(
main_program, startup_program, params_grads, cond_var.name, main_program, startup_program, params_grads, dist_context)
dist_context)
# 3 create gradient_merge_cond
cond_var = _get_gm_cond_var(main_program, k_steps, dist_context)
# 4 create ConditionalBlock and append gradient merge optimizer ops # 4 create ConditionalBlock and append gradient merge optimizer ops
_create_cond_block_and_update_optimizer(main_program, cond_var, _create_cond_block_and_update_optimizer(main_program, cond_var,
......
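Taken together, the reordered parse_program now (1) removes the optimizer ops, (2) appends the gradient accumulation ops tagged Backward, (3) builds the step-counter condition (increment, elementwise_mod, equal, now also tagged Backward), and (4) wraps the real update in a conditional block tagged Optimize. A plain-Python sketch of the semantics this produces, not of the pass itself, which rewrites the program desc:

from typing import Dict


def gradient_merge_step(step: int, k_steps: int, avg: bool,
                        grads: Dict[str, float],
                        merged: Dict[str, float],
                        params: Dict[str, float],
                        lr: float) -> int:
    # (2) gradient-merge backward: accumulate into the @GradientMerge buffers
    for name, g in grads.items():
        merged[name] = merged.get(name, 0.0) + g
    # (3) cond var: increment the step, take it modulo k_steps, compare with zero
    step = (step + 1) % k_steps
    cond = (step == 0)
    # (4) conditional block: average if requested, apply the update, reset buffers
    if cond:
        for name in merged:
            if avg:
                merged[name] /= k_steps
            params[name] -= lr * merged[name]  # stand-in for the real optimizer ops
            merged[name] = 0.0
    return step


params = {"w": 1.0}
merged: Dict[str, float] = {}
step = 0
for micro_batch_grad in [0.1, 0.3, 0.2, 0.4]:  # k_steps = 2, so two updates
    step = gradient_merge_step(step, 2, avg=True, grads={"w": micro_batch_grad},
                               merged=merged, params=params, lr=0.5)
print(params)  # {'w': 0.75}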