Unverified Commit bc2265f8 authored by zhaoyingli, committed by GitHub

[AutoParallel] adapt lazyinit & fix pass (#45840)

* adapt lazy init and fix pass

* add unittest

* update comment

* fix amp and sharding

* remove clip_by_norm
Parent abc85c50
......@@ -26,14 +26,7 @@ from paddle.fluid.dataloader.dataloader_iter import _DatasetKind, default_collat
class DistributedDataLoader(metaclass=abc.ABCMeta):
def __init__(self,
dataset,
batch_size=1,
epochs=1,
data_parallel_world_size=None,
data_parallel_rank=None,
drop_last=False,
split_data=True):
def __init__(self, dataset, batch_size=1, epochs=1, drop_last=False):
if isinstance(dataset, IterableDataset):
self.dataset_kind = _DatasetKind.ITER
else:
......@@ -42,19 +35,11 @@ class DistributedDataLoader(metaclass=abc.ABCMeta):
self.dataset = dataset
self.epochs = epochs
self.drop_lost = drop_last
self.data_parallel_world_size = data_parallel_world_size
self.data_parallel_rank = data_parallel_rank
self.split_data = split_data
if batch_size is None:
self.batch_size = None
self.batch_sampler = None
else:
if data_parallel_world_size is not None:
for dp_world_size in data_parallel_world_size:
if dp_world_size is not None:
assert batch_size % dp_world_size == 0, \
"batch_size must be divisible by dp_world_size value {}".format(str(dp_world_size))
self.batch_size = batch_size
if isinstance(dataset, IterableDataset):
self.batch_sampler = _InfiniteIterableSampler(
......@@ -97,18 +82,22 @@ class NonIterableGeneratorLoader(DistributedDataLoader):
epochs=1,
steps_per_epoch=None,
collate_fn=None,
data_parallel_world_size=None,
data_parallel_rank=None,
data_parallel_world_size=[],
data_parallel_rank=[],
drop_last=False,
split_data=True):
self.feed_list = feed_list
self.places = places
self.steps_per_epoch = steps_per_epoch
assert len(data_parallel_world_size) == len(feed_list)
assert len(data_parallel_rank) == len(feed_list)
self.dp_world_sizes = data_parallel_world_size
self.dp_ranks = data_parallel_rank
self.split_data = split_data
super(NonIterableGeneratorLoader,
self).__init__(dataset, batch_size, epochs,
data_parallel_world_size, data_parallel_rank,
drop_last, split_data)
self).__init__(dataset, batch_size, epochs, drop_last)
if self.auto_collate_batch:
self.collate_fn = collate_fn or default_collate_fn
......@@ -154,13 +143,12 @@ class NonIterableGeneratorLoader(DistributedDataLoader):
def _create_inner_dataloader(self):
def sample_data_generator():
def data_generator():
while True:
try:
indices = next(self.sampler_iter)
batch = self.dataset_fetcher.fetch(indices)
if batch is None: break
except StopIteration:
self.dataset_fetcher = _DatasetKind.create_fetcher(
self.dataset_kind, self.dataset,
......@@ -169,53 +157,23 @@ class NonIterableGeneratorLoader(DistributedDataLoader):
break
partial_data = []
for i, d in enumerate(batch[:len(self.feed_list)]):
for i, d in enumerate(batch):
array = np.array(d)
if not self.split_data:
partial_data.append(array)
elif self.dp_world_sizes[i] is not None:
partial_data.append(
np.split(array,
self.dp_world_sizes[i])[self.dp_ranks[i]])
else:
partial_data.append(array)
yield partial_data
def batch_data_generator():
while True:
try:
indices = next(self.sampler_iter)
continue
batch = self.dataset_fetcher.fetch(indices)
if batch is None: break
except StopIteration:
break
partial_data = []
for i, d in enumerate(batch[:len(self.feed_list)]):
array = np.array(d)
if not self.split_data:
partial_data.append(array)
elif self.dp_world_sizes[i] is not None:
batch_size = array.shape[0]
assert batch_size % self.dp_world_sizes[i] == 0, \
"batch_size [{}] is not divisible by dp_world_size [{}]".format(str(batch_size), str(self.dp_world_sizes[i]))
partial_data.append(
np.split(array,
self.dp_world_sizes[i])[self.dp_ranks[i]])
else:
partial_data.append(array)
yield partial_data
self.dp_world_sizes = [
1 for _ in range(len(self.feed_list))
] if self.data_parallel_world_size is None else self.data_parallel_world_size
self.dp_ranks = [
0 for _ in range(len(self.feed_list))
] if self.data_parallel_rank is None else self.data_parallel_rank
yield partial_data
dataloader = paddle.fluid.io.DataLoader.from_generator(
feed_list=self.feed_list, capacity=70, iterable=False)
if self.batch_size is not None:
dataloader.set_batch_generator(sample_data_generator, self.places)
else:
dataloader.set_batch_generator(batch_data_generator, self.places)
dataloader.set_batch_generator(data_generator, self.places)
return dataloader
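
For reference, a minimal numpy-only sketch of the per-field splitting that `data_generator` above performs when `split_data` is enabled; the shapes, world size, and rank below are hypothetical:

```python
import numpy as np

# One feed field holding a global batch of 8 samples; a data-parallel group
# of size 4 keeps 8 / 4 = 2 samples per rank, and rank 1 takes the second shard.
field = np.arange(8 * 16).reshape(8, 16)
dp_world_size, dp_rank = 4, 1
shard = np.split(field, dp_world_size)[dp_rank]
assert shard.shape == (2, 16)
```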
......@@ -36,7 +36,7 @@ from paddle.fluid.dygraph.parallel import ParallelEnv
from paddle.distributed import fleet
from paddle.distributed.passes import new_pass, PassContext
from .hepler import ProgramHelper
from .helper import ProgramHelper
from ..collective import _get_global_env
from .cluster import Cluster, get_default_cluster
from .planner_v2 import Planner
......@@ -118,8 +118,7 @@ class Engine:
"'optimizer' must be object of class `paddle.optimizer.Optimizer`" \
" or `paddle.fluid.optimizer.Optimizer`."
)
self._optimizer = optimizer
self._all_ranks = all_ranks
self._optimizer = self._validate_opt(optimizer)
if loss and not isinstance(loss,
paddle.nn.Layer) and not callable(loss):
......@@ -136,6 +135,7 @@ class Engine:
self._metrics = to_list(metrics)
self._gradient_scale = gradient_scale
self._planned_mode = None
self._all_ranks = all_ranks
self._prepare_single_mode("train")
def _prepare_single_mode(self, mode):
......@@ -161,21 +161,23 @@ class Engine:
self._dygraph_mode = True
self._logger.info("Building model with 'to_static' method.")
program_helper = ProgramHelper(self.model, self._loss,
self._metrics, self.inputs_spec,
self.labels_spec)
inputs_spec = self.inputs_spec
labels_spec = self.labels_spec if self.labels_spec else []
self.program_helper = ProgramHelper(self.model, self._loss,
self._metrics, inputs_spec,
labels_spec)
# build forward main program
program_helper.build_program(mode)
self.program_helper.build_program(mode)
self.concrete_program = program_helper.concrete_program
serial_main_prog = program_helper.main_program
serial_startup_prog = program_helper.startup_program
self.concrete_program = self.program_helper.concrete_program
serial_main_prog = self.program_helper.main_program
serial_startup_prog = self.program_helper.startup_program
inputs = program_helper.input_vars
outputs = program_helper.output_vars
labels = program_helper.label_vars
losses = program_helper.loss_vars
metrics = program_helper.metric_vars
inputs = self.program_helper.input_vars
outputs = self.program_helper.output_vars
labels = self.program_helper.label_vars
losses = self.program_helper.loss_vars
metrics = self.program_helper.metric_vars
paddle.enable_static()
else:
......@@ -334,40 +336,17 @@ class Engine:
continue
process_group.instantiate()
self._place = _get_device()
if isinstance(self._place, fluid.CUDAPlace):
self._place = fluid.CUDAPlace(ParallelEnv().dev_id)
place = _get_device()
if isinstance(place, fluid.CUDAPlace):
place = fluid.CUDAPlace(ParallelEnv().dev_id)
if self._dygraph_mode:
paddle.disable_static()
main_program = self._dist_main_progs[mode][self._cur_rank]
for param in self.concrete_program.parameters:
# create var in scope and share parameters to scope
if param.name not in main_program.global_block().vars:
continue
# get param_var's dist_attr
var = main_program.global_block().vars[param.name]
var_dist_attr = self._dist_contexts[
mode].get_tensor_dist_attr_for_program(var)
dist_attr = {
"dims_mapping": var_dist_attr.dims_mapping,
"process_shape": var_dist_attr.process_mesh.topology,
"process_group": var_dist_attr.process_mesh.processes
}
# slice param_value with dist_attr
# share sliced_param_value with param_tensor in global_scope
from .converter import Converter
param_tensor = global_scope().var(param.name).get_tensor()
sliced_param = Converter.slice_with_dist_attr(
param.numpy(), dist_attr)
shared_tensor = paddle.to_tensor(sliced_param,
place=self._place)
param_tensor._share_data_with(
shared_tensor.value().get_tensor())
paddle.enable_static()
dist_context = self._dist_contexts[mode]
dist_main_program = self._dist_main_progs[mode][self._cur_rank]
self.program_helper.init(dist_main_program, place, dist_context)
if self._executor is None:
self._executor = paddle.static.Executor(self._place)
self._executor = paddle.static.Executor(place)
uninitialized = []
dist_startup_prog = self._dist_startup_progs[mode][self._cur_rank]
for var in dist_startup_prog.list_vars():
......@@ -411,7 +390,7 @@ class Engine:
data = np.array(param_t)
param_t.set(np.float16(data), place)
cast_parameters_to_fp16(self._place, prune_startup_prog)
cast_parameters_to_fp16(place, prune_startup_prog)
def fit(self,
train_data,
......@@ -577,15 +556,20 @@ class Engine:
dist_context = self._dist_contexts[self.mode]
dist_main_block = dist_main_prog.global_block()
# NOTE: Get feed_list from dist_program, then insert dataloader op
# with sharded var shape. Because predict_program does not contain
# labels var, so we will filter dataset's value with length of feed_list.
# NOTE: Get feed_list, then insert dataloader op with sharded var shape.
# Because predict_program does not contain the labels var, we add the labels
# var from serial_program to dist_program so that the length of feed_list
# stays equal to the length of the dataset's values.
inputs_var = self._feed_vars[self.mode]["inputs"]
labels_var = self._feed_vars[self.mode]["labels"]
feed_list = []
for var in inputs_var + labels_var:
if var.name in dist_main_block.vars:
feed_list.append(dist_main_block.vars[var.name])
else:
copy_var = dist_main_block._clone_variable(var, var.persistable)
copy_var.desc.set_original_id(var.desc.original_id())
feed_list.append(copy_var)
# remove the first three ops if fit/evaluate/predict are run multiple times
op_size = len(dist_main_block.ops)
......@@ -688,7 +672,7 @@ class Engine:
batch_size_axis, rank_id)
return len(group_ranks), group_ranks.index(rank_id)
return None, None
return 1, 0
def _set_recompute_ckpts(self):
# NOTE hack to enable recompute in engine api for GPT-3
......@@ -717,6 +701,11 @@ class Engine:
}
self._logger.info(logs)
def _validate_opt(self, optimizer):
optimizer._parameter_list = None
optimizer._param_groups = None
return optimizer
def save(self, path, training=True, mode=None):
if not mode:
mode = self.mode
......
......@@ -15,14 +15,18 @@
import logging
from collections import defaultdict
import paddle
from paddle.nn import Layer
from paddle.jit import to_static, not_to_static
from paddle.distributed.utils import get_logger
from paddle.fluid.framework import Operator, Parameter, _non_static_mode
from paddle.fluid.framework import program_guard
from paddle.fluid.executor import global_scope
from paddle.fluid.dygraph.dygraph_to_static.program_translator import StaticFunction
from .utils import to_list
from .converter import Converter
class ProxyLayer(Layer):
......@@ -89,13 +93,14 @@ class ProxyLayer(Layer):
# step 4. calculate metrics if needed
self._metric_vars[mode] = self.call_metrics(new_inputs)
def _predict(self, inputs):
def _predict(self, inputs, labels):
"""
Predict process of inner_layer with forward logic.
"""
# step 1. save feed variables of Program
mode = 'predict'
self._input_vars[mode] = inputs
self._label_vars[mode] = labels
# step 2. call inner_layer.forward
self._output_vars[mode] = self.inner_layer(*inputs)
......@@ -165,6 +170,10 @@ class ProxyLayer(Layer):
def metric_vars(self):
return self._metric_vars[self.mode]
@property
def startup_program(self):
return self.inner_layer._startup_program()
class BuildInfo:
......@@ -199,6 +208,7 @@ class ProgramHelper(object):
self.build_info = BuildInfo()
self._logger = get_logger(logging.INFO)
self.lazy_init = False
def reset(self):
"""
......@@ -221,8 +231,7 @@ class ProgramHelper(object):
return
self._logger.info("start to build program for mode = %s." % mode)
input_spec = [self.inputs_spec, self.labels_spec
] if mode != 'predict' else [self.inputs_spec]
input_spec = [self.inputs_spec, self.labels_spec]
static_func = to_static(self.static_func(), input_spec=input_spec)
func_name = '_' + mode
......@@ -238,6 +247,9 @@ class ProgramHelper(object):
"""
Create and sync parameters into the startup program.
"""
if len(self.startup_program.global_block().ops) > 1:
self.lazy_init = True
return
for param in self.concrete_program.parameters:
Parameter(name=param.name,
desc=param,
......@@ -294,6 +306,28 @@ class ProgramHelper(object):
func_name = '_' + self.proxy_layer.mode
return getattr(self.proxy_layer, func_name)
def init(self, main_program, place, dist_context):
if self.lazy_init:
return
for param in self.concrete_program.parameters:
# create var in scope and share parameters to scope
if param.name not in main_program.global_block().vars:
continue
# get param_var's dist_attr
var = main_program.global_block().vars[param.name]
var_dist_attr = dist_context.get_tensor_dist_attr_for_program(var)
dist_attr = {
"dims_mapping": var_dist_attr.dims_mapping,
"process_shape": var_dist_attr.process_mesh.topology,
"process_group": var_dist_attr.process_mesh.processes
}
# slice param_value with dist_attr
# share sliced_param_value with param_tensor in global_scope
param_tensor = global_scope().var(param.name).get_tensor()
sliced_param = Converter.slice_with_dist_attr(
param.numpy(), dist_attr)
param_tensor.set(sliced_param, place)
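
For intuition, a hedged numpy sketch of what slicing a parameter value with its dist_attr amounts to here; the real slicing is done by `Converter.slice_with_dist_attr`, and the mesh, mapping, and rank below are hypothetical:

```python
import numpy as np

# A [4, 8] parameter with dims_mapping [0, -1]: its first axis is partitioned
# over a 1-D process mesh of shape [2] with process_group [0, 1].
param_value = np.arange(32, dtype="float32").reshape(4, 8)
process_shape, process_group, cur_rank = [2], [0, 1], 1
sliced = np.split(param_value, process_shape[0], axis=0)[process_group.index(cur_rank)]
assert sliced.shape == (2, 8)  # rank 1 keeps rows 2..3
```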
@property
def concrete_program(self):
return self.static_func().concrete_program
......@@ -304,7 +338,12 @@ class ProgramHelper(object):
@property
def startup_program(self):
try:
return self.proxy_layer.startup_program
except Exception as err:
if isinstance(err, AssertionError):
return self.concrete_program.startup_program
raise err
@property
def input_vars(self):
......
......@@ -145,11 +145,6 @@ class Parallelizer:
params_grads):
# NOTE: `apply_gradients` will add an Accumulator for a parameter only once,
# but the optimizer is called repeatedly on re-launch, so the optimizer needs to be copied.
if self._dist_context._dygraph_mode:
paddle.disable_static()
optimizer = copy.deepcopy(optimizer)
paddle.enable_static()
else:
optimizer = copy.deepcopy(optimizer)
self._dist_context._lr_optimizer = optimizer
with program_guard(main_program, startup_program):
......@@ -222,6 +217,7 @@ class Parallelizer:
config = {}
config["dist_context"] = self._dist_context
config["global_rank"] = rank
config["use_sharding"] = self._strategy.sharding
dp_pass = new_pass("auto_parallel_data_parallel_optimization", config)
dp_pass.apply([main_program], [startup_program], self._pass_context)
......
......@@ -270,7 +270,7 @@ class Inserter:
dtype=tensor_type,
type=tensor.type,
lod_level=tensor.lod_level)
block._insert_op(idx,
cast_op = block._insert_op(idx,
type='cast',
inputs={'X': [tensor]},
outputs={'Out': [out]},
......@@ -279,6 +279,7 @@ class Inserter:
'out_dtype': out.dtype,
'op_role': op_role
})
cast_op._set_attr('op_namescope', "/auto_parallel/reshard")
return out
@staticmethod
......@@ -287,7 +288,7 @@ class Inserter:
op_type = 'send_v2'
# use pair comm group
process_group = new_process_group([src, dst])
block._insert_op(idx,
send_op = block._insert_op(idx,
type=op_type,
inputs={'X': [tensor]},
attrs={
......@@ -297,6 +298,7 @@ class Inserter:
'op_role': op_role,
'dynamic_shape': True
})
send_op._set_attr('op_namescope', "/auto_parallel/reshard")
@staticmethod
def insert_recv_op(block, idx, tensor, src, dst, op_role):
......@@ -304,7 +306,7 @@ class Inserter:
op_type = 'recv_v2'
# use pair group
process_group = new_process_group([src, dst])
block._insert_op(idx,
recv_op = block._insert_op(idx,
type=op_type,
inputs={'X': [tensor]},
outputs={'Out': [tensor]},
......@@ -317,6 +319,7 @@ class Inserter:
'op_role': op_role,
'dynamic_shape': True
})
recv_op._set_attr('op_namescope', "/auto_parallel/reshard")
@staticmethod
def insert_reset_lod_op(block, idx, X, Y, op_role):
......@@ -330,7 +333,7 @@ class Inserter:
dtype=X.dtype,
lod_level=X.lod_level)
block._insert_op(idx,
reset_op = block._insert_op(idx,
type="lod_reset",
inputs={
'X': X,
......@@ -338,6 +341,7 @@ class Inserter:
},
outputs={'Out': reset_lod_out},
attrs={'op_role': op_role})
reset_op._set_attr('op_namescope', "/auto_parallel/reshard")
return reset_lod_out
@staticmethod
......@@ -359,11 +363,12 @@ class Inserter:
type=tensors[0].type,
persistable=False,
stop_gradient=False)
block._insert_op(idx,
concat_op = block._insert_op(idx,
type='concat',
inputs=inputs,
outputs={'Out': [out]},
attrs=attrs)
concat_op._set_attr('op_namescope', "/auto_parallel/reshard")
return out
@staticmethod
......@@ -391,11 +396,12 @@ class Inserter:
inputs = {'X': [tensor]}
outputs = {"Out": [out]}
attrs = {"in_place": False}
block._insert_op(idx,
slice_op = block._insert_op(idx,
type="assign",
inputs=inputs,
outputs=outputs,
attrs=attrs)
slice_op._set_attr('op_namescope', "/auto_parallel/reshard")
return out
# use split once
......@@ -427,11 +433,12 @@ class Inserter:
for i in range(num_or_sections)
]
out = outs[cur_idx]
op = block._insert_op(idx,
split_op = block._insert_op(idx,
type="split",
inputs=inputs,
outputs={'Out': outs},
attrs=attrs)
split_op._set_attr('op_namescope', "/auto_parallel/reshard")
return out
# use slice
......@@ -449,12 +456,12 @@ class Inserter:
dtype=tensor.dtype,
type=tensor.type,
lod_level=tensor.lod_level)
block._insert_op(idx,
slice_op = block._insert_op(idx,
type="slice",
inputs=inputs,
outputs={'Out': [out]},
attrs=attrs)
slice_op._set_attr('op_namescope', "/auto_parallel/reshard")
return out
@staticmethod
......@@ -482,11 +489,12 @@ class Inserter:
persistable=False,
stop_gradient=False) for i in range(num_or_sections)
]
block._insert_op(idx,
split_op = block._insert_op(idx,
type="split",
inputs=inputs,
outputs={'Out': outs},
attrs=attrs)
split_op._set_attr('op_namescope', "/auto_parallel/reshard")
return outs
@staticmethod
......@@ -514,12 +522,13 @@ class Inserter:
attrs=attrs,
shape=[0],
op_type='fill_constant')
block._insert_op(idx,
fillconstant_op = block._insert_op(idx,
type='fill_constant',
inputs=inputs,
outputs={'Out': [out]},
attrs=attrs)
out.stop_gradient = True
fillconstant_op._set_attr('op_namescope', "/auto_parallel/reshard")
return out
@staticmethod
......@@ -537,7 +546,8 @@ class Inserter:
fill_constant_out.stop_gradient = True
# insert c_allreduce_sum op
block._insert_op(idx + 1,
allreduce_op = block._insert_op(
idx + 1,
type="c_allreduce_sum",
inputs={'X': [fill_constant_out]},
outputs={'Out': [fill_constant_out]},
......@@ -546,13 +556,15 @@ class Inserter:
'use_calc_stream': True,
'op_role': op_role
})
allreduce_op._set_attr('op_namescope', "/auto_parallel/reshard")
# insert c_sync_calc_stream op
block._insert_op(idx + 2,
sync_calc_op = block._insert_op(
idx + 2,
type="c_sync_calc_stream",
inputs={'X': [fill_constant_out]},
outputs={'Out': [fill_constant_out]},
attrs={'op_role': op_role})
sync_calc_op._set_attr('op_namescope', "/auto_parallel/reshard")
idx_offset = 3
# insert c_allgather op
......@@ -569,7 +581,7 @@ class Inserter:
type=tensor.type,
persistable=False,
stop_gradient=False)
block._insert_op(idx + idx_offset,
allgather_op = block._insert_op(idx + idx_offset,
type=op_type,
inputs={'X': [tensor]},
outputs={'Out': [allgather_out]},
......@@ -579,6 +591,7 @@ class Inserter:
'nranks': group.nranks,
'op_role': op_role
})
allgather_op._set_attr('op_namescope', "/auto_parallel/reshard")
idx_offset += 1
# insert split op
......
......@@ -26,6 +26,7 @@ from paddle.fluid.contrib.mixed_precision.fp16_utils import _keep_fp32_input, _k
from paddle.fluid.contrib.mixed_precision.fp16_utils import _valid_types, find_true_post_op, find_true_prev_op
from paddle.fluid.contrib.mixed_precision.fp16_utils import _is_in_black_varnames, _dtype_to_str, _rename_arg
from paddle.distributed.auto_parallel.dist_attribute import OperatorDistributedAttribute
from ..auto_parallel.utils import is_forward_op, is_backward_op, is_loss_op
world_process_group = get_world_process_group()
......@@ -222,21 +223,33 @@ class AMPState(object):
loss_op = get_loss_op(self._block)
loss_op_index = find_op_index(self._block.desc, loss_op.desc)
appended_grad_times = 0
idx = loss_op_index + 1
while idx < len(ops):
num_cast_ops = 0
grad_op = ops[idx]
# NOTE: the mapping in `grad_var_to_var` may change when a var is cast,
# which affects where the dist_op inserts the allreduce_sum op.
op_dist_attr = dist_context.get_op_dist_attr_for_program(grad_op)
if is_backward_op(grad_op) and (is_forward_op(ops[idx - 1])
or is_loss_op(ops[idx - 1])):
if not op_dist_attr.is_recompute:
appended_grad_times += 1
grad_op_orig_id = grad_op.desc.original_id()
dist_op_context = dist_context.dist_op_context
if grad_op_orig_id in dist_op_context.grad_op_id_to_op_id:
if self._is_fp16_op(grad_op_orig_id) == False: # fp32
num_cast_ops = self._insert_cast_op_backward(
grad_op, idx, core.VarDesc.VarType.FP16,
core.VarDesc.VarType.FP32, dist_context)
core.VarDesc.VarType.FP32, dist_context,
appended_grad_times)
elif self._is_fp16_op(grad_op_orig_id) == True: # fp16
num_cast_ops = self._insert_cast_op_backward(
grad_op, idx, core.VarDesc.VarType.FP32,
core.VarDesc.VarType.FP16, dist_context)
core.VarDesc.VarType.FP16, dist_context,
appended_grad_times)
elif grad_op.type == "sum":
in_var_name = grad_op.desc.input_arg_names()[0]
src_dtype = self._block.var(in_var_name).dtype
......@@ -258,7 +271,7 @@ class AMPState(object):
_update_backward_cast_ops(params_grads, dist_context)
def _insert_cast_op_backward(self, grad_op, idx, src_dtype, dst_dtype,
dist_context):
dist_context, appended_grad_times):
""" only for backward cast """
def _keep_fp32_input(op, in_name):
......@@ -328,7 +341,10 @@ class AMPState(object):
grad_op)
fwd_cast_name = self._var_name_dict[fwd_op_id][
out_var_name_prefix]
cast_name = fwd_cast_name + "@GRAD"
suffix = ""
if "@RENAME" in out_var_name:
suffix = out_var_name[out_var_name.find("@RENAME"):]
cast_name = fwd_cast_name + "@GRAD" + suffix
cast_var = self._block.vars.get(cast_name)
if cast_var is None or cast_var.dtype != dst_dtype:
grad_op.desc._rename_output(out_var_name, cast_name)
......@@ -347,6 +363,8 @@ class AMPState(object):
stop_gradient=out_var.stop_gradient)
set_var_dist_attr(dist_context, cast_var,
ref_mapping, ref_mesh)
dist_op_context.grad_var_to_var[
appended_grad_times][cast_name] = fwd_cast_name
cast_op = self._block._insert_op(
idx + 1,
......
......@@ -45,6 +45,7 @@ class DataParallelOptimizationPass(PassBase):
# NOTE: do not depend on loss and param_grads
self.set_attr("dist_context", None)
self.set_attr("global_rank", -1)
self.set_attr("use_sharding", False)
# {grad1: group1, grad2: group1, grad3: group2}
# record the order for fuse grad data memory
self._grad_name_to_group_map = OrderedDict()
......@@ -71,6 +72,7 @@ class DataParallelOptimizationPass(PassBase):
self.dist_context = self.get_attr("dist_context")
self.global_rank = int(self.get_attr("global_rank"))
self.use_sharding = self.get_attr("use_sharding")
with paddle.static.program_guard(main_program, startup_program):
self._analyze_program()
......@@ -224,7 +226,8 @@ class DataParallelOptimizationPass(PassBase):
num_dp_comm_stream = len(set(self._group_to_grad_name_map.keys()))
if num_dp_comm_stream > __max_stream_num_allow__:
return False
if self.use_sharding:
return False
return True
def _comms_overlap_calc(self):
......
......@@ -151,13 +151,14 @@ class RecomputeState(ProgramStats):
# modify dropout op's desc
self._ops.insert(op_idx, seed_op)
cur_op.desc.set_input("Seed", [var_unique_name])
cur_op.desc.remove_attr("fix_seed")
cur_op.desc.remove_attr("seed")
cur_op._remove_attr("fix_seed")
cur_op._remove_attr("seed")
cur_op_dist_attr.set_input_dist_attr(seed_var.name,
seed_var_dist_attr)
self._block._sync_with_cpp()
op_idx += 2
self._block._sync_with_cpp()
def _find_op_index(block, cur_op):
for idx in range(block.desc.op_size()):
......@@ -339,12 +340,13 @@ class RecomputePass(PassBase):
grad_op = ops[i]
# remove some attrs of dropout_grad op's desc
if grad_op.type == "dropout_grad":
grad_op.desc.remove_attr("fix_seed")
grad_op.desc.remove_attr("seed")
main_block._sync_with_cpp()
grad_op._remove_attr("fix_seed")
grad_op._remove_attr("seed")
# rename grad op's var_name which is not in 'vars_in_memory'
for key in var_name_dict:
if key not in grad_op.input_arg_names + grad_op.output_arg_names:
continue
self.reset_op_dist_attr(grad_op, var_name_dict)
_rename_arg_([grad_op.desc], key, var_name_dict[key])
......@@ -358,11 +360,11 @@ class RecomputePass(PassBase):
idx -= 1
segment_descs = ckpt_ops_dict[fwd_op_id][1]
for _, op_desc in reversed(list(enumerate(segment_descs))):
rc_desc = main_block.desc._insert_op(idx)
rc_op = main_block._insert_op_without_sync(idx,
type='nop')
rc_desc = rc_op.desc
rc_desc.copy_from(op_desc)
rc_desc.set_original_id(rc_desc.id())
rc_op = Operator(main_block, rc_desc)
main_block.ops.insert(idx, rc_op)
# set recomputed ops' dist attr
fwd_op_dist_attr = self._dist_context.get_op_dist_attr_for_program_with_id(
op_desc.original_id())
......@@ -371,7 +373,6 @@ class RecomputePass(PassBase):
var_name_dict)
ckpt_ops_dict[fwd_op_id][0] = False
main_block._sync_with_cpp()
main_program._sync_with_cpp()
......
......@@ -13,7 +13,7 @@
# limitations under the License.
from functools import reduce
from collections import OrderedDict
from collections import OrderedDict, defaultdict
import numpy as np
import paddle
......@@ -27,10 +27,7 @@ from paddle.distributed.auto_parallel.utils import _get_comm_group, naive_set_di
OpRole = core.op_proto_and_checker_maker.OpRole
OP_ROLE_KEY = core.op_proto_and_checker_maker.kOpRoleAttrName()
_skip_ops = [
'create_py_reader', 'create_double_buffer_reader', 'read', 'slice', 'split',
'assign', "send_v2"
]
_skip_ops = ['create_py_reader', 'create_double_buffer_reader', 'read']
# update here to support new optimizers
_supported_optimizer_type = [
"adam", "adamax", "adamw", "decayed_adagrad", "momentum", "dgc_momentum",
......@@ -38,6 +35,11 @@ _supported_optimizer_type = [
]
def _is_reshard_op(op):
return op.desc.has_attr("op_namescope") and \
"/auto_parallel/reshard" in op.desc.attr('op_namescope')
# NOTE we add the "auto_parallel" prefix to the pass in order to
# indicate that this pass should obey some constraints enforced by auto_parallel,
# for example all ops and vars should have a dist attr before and after the pass
......@@ -100,6 +102,10 @@ class ShardingPass(PassBase):
for op in main_block.ops:
if not _is_forward_op(op) or op.type in _skip_ops:
continue
# NOTE: the ops inserted by reshard have no dist_attr
# and should be skipped in sharding.
if _is_reshard_op(op):
continue
group = _inference_data_parallel_group_for_operator(
self.global_rank, op, self._dist_context)
if group is not None:
......@@ -187,8 +193,28 @@ class ShardingPass(PassBase):
if self._is_parameter_in_local_shard(param_name):
reversed_x.append(input_name)
# NOTE: When `reversed_x` is [], the check_finite_and_unscale op is replaced by a `fill_constant` op.
# The output of check_finite_and_unscale is set to False.
if reversed_x:
op.desc.set_input('X', reversed_x)
op.desc.set_output('Out', reversed_x)
else:
if op.type == "check_finite_and_unscale":
out_name = op.output_arg_names[0]
out_var = main_block.vars[out_name]
main_block._remove_op(idx, sync=False)
main_block._insert_op_without_sync(
idx,
type="fill_constant",
outputs={"Out": out_var},
attrs={
"shape": out_var.shape,
"dtype": out_var.dtype,
"value": 0,
})
else:
main_block._remove_op(idx, sync=False)
main_block._sync_with_cpp()
......@@ -359,6 +385,17 @@ class ShardingPass(PassBase):
else:
op._set_attr("ring_id", self.outer_dp_group.id)
# NOTE:
# var@GRAD = sum(var@GRAD@RENAME@0, var@GRAD@RENAME@1)
# If the var is not in the local rank and it is the output of several ops
# (in other words, the var has been renamed), the sum op should be removed.
if _is_param_grad_sum_op(op, main_block):
out_name = op.output_arg_names[0]
base_name = _get_base_name_from_grad_name(out_name)
sharding_info = self.varname_to_sharding_info[base_name]
if not sharding_info.is_in_local_shard(base_name):
main_block._remove_op(idx, sync=False)
main_block._sync_with_cpp()
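
A small hedged illustration of the gradient-naming convention the NOTE above relies on; the exact behavior of `_get_base_name_from_grad_name` is assumed here, not quoted from the source:

```python
# Renamed partial gradients keep the parameter name as a prefix, e.g.
#   "linear_0.w_0@GRAD@RENAME@block0@1"  ->  base name "linear_0.w_0"
grad_name = "linear_0.w_0@GRAD@RENAME@block0@1"
base_name = grad_name[:grad_name.find("@GRAD")]
assert base_name == "linear_0.w_0"
```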
def _shard_parameter(self, main_block, startup_block):
......@@ -606,6 +643,22 @@ def _is_param_grad_allreduce_op(op, block, dp_ring_ids):
return block.var(base_name).is_parameter
def _is_param_grad_sum_op(op, block):
if not is_backward_op(op):
return False
if op.type != "sum":
return False
output_name = op.output_arg_names[0]
base_name = _get_base_name_from_grad_name(output_name)
if not block.has_var(base_name):
return False
return block.var(base_name).is_parameter
def _is_forward_op(op):
return op.attr("op_role") == 0
......
......@@ -33,7 +33,7 @@ import paddle.distributed.auto_parallel as auto
from paddle.distributed.auto_parallel.engine import Engine
paddle.enable_static()
batch_size = 1
batch_size = 2
batch_num = 10
hidden_size = 1024
sequence_len = 512
......@@ -133,10 +133,7 @@ def train(fetch):
# train
train_dataset = MyDataset(batch_num * batch_size)
engine.fit(train_dataset,
batch_size=batch_size,
steps_per_epoch=batch_num * batch_size,
fetches=fetches)
engine.fit(train_dataset, batch_size=batch_size, fetches=fetches)
# eval
eval_dataset = MyDataset(batch_size)
......
......@@ -67,15 +67,14 @@ def create_data_holder(batch_size):
def generate_model(strategy):
modeling.init_global()
modeling._global_process_mesh = list(
range(paddle.distributed.get_world_size()))
if strategy == "serial":
modeling._global_parallel_strategy = "serial"
modeling._global_process_mesh = [0]
elif strategy == "mp":
modeling._global_parallel_strategy = "mp"
modeling._global_process_mesh = [0, 1]
elif strategy == "dp":
modeling._global_parallel_strategy = "dp"
modeling._global_process_mesh = [0, 1]
else:
raise ValueError("Only support serial, mp2 and dp2.")
......
......@@ -27,7 +27,6 @@ from paddle.io import Dataset
from paddle.static import InputSpec
from paddle.fluid.framework import _non_static_mode
from paddle.distributed.auto_parallel.engine import Engine
from paddle.distributed.auto_parallel.hepler import ProgramHelper
from test_to_static import MLPLayer, MyDataset
......
......@@ -23,11 +23,12 @@ import paddle.nn.functional as F
import paddle.distributed.auto_parallel as auto
import paddle.distributed.fleet as fleet
from paddle import LazyGuard
from paddle.io import Dataset
from paddle.static import InputSpec
from paddle.fluid.framework import _non_static_mode
from paddle.distributed.auto_parallel.engine import Engine
from paddle.distributed.auto_parallel.hepler import ProgramHelper
from paddle.distributed.auto_parallel.helper import ProgramHelper
batch_size = 4
batch_num = 30
......@@ -158,5 +159,29 @@ class TestToStatic(unittest.TestCase):
engine.predict(dataset, batch_size=batch_size)
class TestLazyInit(unittest.TestCase):
def test_lazy_init(self):
with LazyGuard():
mlp = MLPLayer(hidden_size=hidden_size,
intermediate_size=4 * hidden_size,
dropout_ratio=0.1,
initializer_range=0.02)
loss = paddle.nn.CrossEntropyLoss()
metrics = paddle.metric.Accuracy()
loss = paddle.nn.CrossEntropyLoss()
inputs = InputSpec([batch_size, hidden_size], 'float32', 'x')
labels = InputSpec([batch_size], 'int64', 'label')
program_helper = ProgramHelper(mlp, loss, [metrics], [inputs], [labels])
program_helper.build_program(mode='train')
ops = program_helper.startup_program.block(0).ops
vars = program_helper.startup_program.block(0).vars
assert len(vars.keys()) == len(ops)
program_helper.reset()
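
For context, a minimal sketch of the LazyGuard behavior this test depends on (assumption: under LazyGuard, parameters are not materialized eagerly; their initialization ops are recorded into a startup program, which ProgramHelper.startup_program exposes):

```python
import paddle
from paddle import LazyGuard

# The Linear's weight and bias are created lazily here; nothing is initialized
# until the recorded startup program is actually run.
with LazyGuard():
    linear = paddle.nn.Linear(8, 8)
```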
if __name__ == "__main__":
unittest.main()
......@@ -914,12 +914,6 @@ class GPTForPretraining(nn.Layer):
initializer_range=0.02,
):
super(GPTForPretraining, self).__init__()
self.output_embeddings = nn.Embedding(
vocab_size,
hidden_size,
weight_attr=paddle.ParamAttr(name="output_embeddings",
initializer=nn.initializer.Normal(
mean=0.0, std=initializer_range)))
self.gpt = gpt
def forward(self,
......@@ -938,9 +932,45 @@ class GPTForPretraining(nn.Layer):
encoder_outputs, cached_kvs = outputs[:2]
else:
encoder_outputs = outputs
logits = paddle.matmul(encoder_outputs,
self.output_embeddings.weight,
transpose_y=True)
x = encoder_outputs
w = self.gpt.embeddings.word_embeddings.weight
mesh = _global_process_mesh
x_dims_mapping = [-1 for i in range(len(x.shape))]
w_dims_mapping = [-1 for i in range(len(w.shape))]
if _global_parallel_strategy == "pp":
mesh = PP_MESH_LIST[-1]
elif _global_parallel_strategy == "dp":
x_dims_mapping = [0] + [-1 for i in range(len(x.shape) - 1)]
elif _global_parallel_strategy == "mp":
w_dims_mapping = [0] + [-1 for i in range(len(w.shape) - 1)]
elif _global_parallel_strategy == "dp_mp":
x_dims_mapping = [0] + [-1 for i in range(len(x.shape) - 1)]
w_dims_mapping = [1] + [-1 for i in range(len(w.shape) - 1)]
elif _global_parallel_strategy == "dp_pp":
mesh = DPPP_MESH_LIST[-1]
x_dims_mapping = [0] + [-1 for i in range(len(x.shape) - 1)]
elif _global_parallel_strategy == "mp_pp":
mesh = MPPP_MESH_LIST[-1]
w_dims_mapping = [0] + [-1 for i in range(len(w.shape) - 1)]
elif _global_parallel_strategy == "dp_mp_pp":
mesh = DPMPPP_MESH_LIST[-1]
x_dims_mapping = [0] + [-1 for i in range(len(x.shape) - 1)]
w_dims_mapping = [1] + [-1 for i in range(len(w.shape) - 1)]
matmul = auto.shard_op(paddle.matmul,
dist_attr={
'process_mesh': mesh,
x: {
"dims_mapping": x_dims_mapping
},
w: {
"dims_mapping": w_dims_mapping
}
})
logits = matmul(x, w, transpose_y=True)
if use_cache:
return logits, cached_kvs
else:
......@@ -958,6 +988,26 @@ class GPTPretrainingCriterion(nn.Layer):
self.loss_func = paddle.nn.CrossEntropyLoss(reduction="none")
def forward(self, prediction_scores, masked_lm_labels, loss_mask):
mesh = _global_process_mesh
dims_mapping = [-1 for i in range(len(loss_mask.shape))]
if _global_parallel_strategy == "dp":
dims_mapping = [0] + [-1 for i in range(len(loss_mask.shape) - 1)]
elif _global_parallel_strategy == "dp_mp":
dims_mapping = [0] + [-1 for i in range(len(loss_mask.shape) - 1)]
elif _global_parallel_strategy == "dp_pp":
mesh = DPPP_MESH_LIST[-1]
dims_mapping = [0] + [-1 for i in range(len(loss_mask.shape) - 1)]
elif _global_parallel_strategy == "dp_mp_pp":
mesh = DPMPPP_MESH_LIST[-1]
dims_mapping = [0] + [-1 for i in range(len(loss_mask.shape) - 1)]
auto.shard_tensor(loss_mask,
dist_attr={
"process_mesh": mesh,
"dims_mapping": dims_mapping
})
masked_lm_loss = self.loss_func(prediction_scores,
masked_lm_labels.unsqueeze(2))
loss_mask = loss_mask.reshape([-1])
......
......@@ -178,7 +178,6 @@ class AutoPallelPassTestBase(DistPassTestBase):
preds = model(tokens, position_ids, attention_mask)
criterion = GPTPretrainingCriterion()
loss = criterion(preds, labels, loss_mask)
clip = paddle.nn.ClipGradByNorm(clip_norm=1.0)
if kwargs.get('optimizer', None) == "LarsMomentum":
optimizer = paddle.fluid.optimizer.LarsMomentumOptimizer(
......@@ -189,7 +188,7 @@ class AutoPallelPassTestBase(DistPassTestBase):
beta1=0.9,
beta2=0.999,
epsilon=1e-08,
grad_clip=clip)
grad_clip=None)
optimizer = fleet.distributed_optimizer(optimizer)
startup_program = paddle.static.default_startup_program()
_, _, dist_startup_prog, dist_main_prog = optimizer.minimize(
......