未验证 提交 df470954 编写于 作者: Z zhaoyingli 提交者: GitHub

[AutoParallel] split data in dataloader (#42838)

* slice data in dist_loader & flag to scale grad

* bug fix

* update unittest

* enable static
上级 16ce33b0
......@@ -115,6 +115,9 @@ class DistributedContext:
self._is_initialized = False
# flag whether scale gradient with dp size
self._gradient_scale = True
@property
def serial_main_program(self):
return self._serial_main_program
......@@ -187,6 +190,14 @@ class DistributedContext:
return len(self._dist_tensors_for_program) or len(
self._dist_ops_for_program)
@property
def gradient_scale(self):
return self._gradient_scale
@gradient_scale.setter
def gradient_scale(self, gs):
self._gradient_scale = gs
def initialize(self):
if not self._is_initialized:
self._serial_main_program = self._original_serial_main_program.clone(
......
......@@ -34,7 +34,7 @@ class DistributedDataLoader(metaclass=abc.ABCMeta):
self.data_parallel_world_size = data_parallel_world_size
self.data_parallel_rank = data_parallel_rank
self.drop_lost = drop_last
if data_parallel_world_size is not None:
if data_parallel_world_size is not None and batch_size is not None:
assert batch_size % data_parallel_world_size == 0
@abc.abstractmethod
......@@ -56,12 +56,12 @@ class NonIterableGeneratorLoader(DistributedDataLoader):
steps_per_epoch=None,
data_parallel_world_size=None,
data_parallel_rank=None,
drop_last=False,
sample_generator=True):
drop_last=False):
self.feed_list = feed_list
self.places = places
self.steps_per_epoch = steps_per_epoch
self._sample_generator = sample_generator
self.dp_world_size = 1 if data_parallel_world_size is None else data_parallel_world_size
self.dp_rank = 0 if data_parallel_rank is None else data_parallel_rank
super(NonIterableGeneratorLoader, self).__init__(
dataset, batch_size, epochs, data_parallel_world_size,
......@@ -85,6 +85,9 @@ class NonIterableGeneratorLoader(DistributedDataLoader):
if self.steps_per_epoch is not None:
return self.steps_per_epoch
try:
if self.batch_size is None:
steps_per_epoch = len(self.dataset)
else:
steps_per_epoch = len(self.dataset) // self.batch_size
except:
raise ValueError(
......@@ -102,17 +105,28 @@ class NonIterableGeneratorLoader(DistributedDataLoader):
for idx in range(len(data)):
batch_data[idx].append(data[idx])
if (step + 1) % self.batch_size == 0:
yield batch_data
partial_data = []
for d in batch_data:
array = np.array(d)
partial_data.append(
np.split(array, self.dp_world_size)[self.dp_rank])
yield partial_data[:len(self.feed_list)]
batch_data = None
def batch_data_generator():
for data in self.dataset:
data = flatten(data)
yield data
partial_data = []
for d in data:
assert d.shape[0] % self.dp_world_size == 0, \
"Please padding dataset with data parallel size"
partial_data.append(
np.split(d, self.dp_world_size)[self.dp_rank])
yield partial_data[:len(self.feed_list)]
dataloader = paddle.fluid.io.DataLoader.from_generator(
feed_list=self.feed_list, capacity=70, iterable=False)
if self._sample_generator:
if self.batch_size is not None:
dataloader.set_batch_generator(sample_data_generator, self.places)
else:
dataloader.set_batch_generator(batch_data_generator, self.places)
......
......@@ -45,8 +45,6 @@ from .utils import print_program_with_dist_attr, to_list
from .process_group import get_all_process_groups, get_world_process_group
from .dist_context import DistributedContext, get_default_distributed_context
paddle.enable_static()
class Engine:
def __init__(self,
......@@ -82,6 +80,7 @@ class Engine:
def prepare(self,
optimizer=None,
loss=None,
gradient_scale=True,
metrics=None,
mode='train',
all_ranks=False):
......@@ -90,6 +89,7 @@ class Engine:
self._loss = loss
self._metrics = to_list(metrics)
self._mode = mode
self._gradient_scale = gradient_scale
# Build forward program
self._build(mode)
# Do the planning process
......@@ -149,6 +149,7 @@ class Engine:
self._serial_main_progs[mode], self._serial_startup_progs[mode],
self._optimizer, losses, self._feed_vars[mode],
self._fetch_vars[mode], self.strategy)
self._dist_contexts[mode].gradient_scale = self._gradient_scale
def _initialize(self, mode):
if self._nranks > 1:
......@@ -183,14 +184,14 @@ class Engine:
epochs=1,
steps_per_epoch=None,
use_program_cache=False,
return_numpy=True,
sample_generator=True):
return_numpy=True):
# TODO: callbacks
# TODO: evaluate after training
self.mode = 'train'
assert self.mode in self._dist_main_progs, "train model is not ready, please call `engine.prepare(mode='train')` first."
train_dataloader = self._create_dataloader(
train_data, batch_size, epochs, steps_per_epoch, sample_generator)
assert self.mode in self._dist_main_progs, \
"train model is not ready, please call `engine.prepare(mode='train')` first."
train_dataloader = self._create_dataloader(train_data, batch_size,
epochs, steps_per_epoch)
outputs = []
for epoch in range(epochs):
......@@ -209,12 +210,11 @@ class Engine:
eval_data,
batch_size=1,
use_program_cache=False,
return_numpy=True,
sample_generator=True):
return_numpy=True):
self.mode = 'eval'
assert self.mode in self._dist_main_progs, "eval model is not ready, please call `engine.prepare(mode='eval')` first."
eval_dataloader = self._create_dataloader(
eval_data, batch_size, sample_generator=sample_generator)
assert self.mode in self._dist_main_progs, \
"eval model is not ready, please call `engine.prepare(mode='eval')` first."
eval_dataloader = self._create_dataloader(eval_data, batch_size)
outputs = []
for step, data in enumerate(eval_dataloader):
......@@ -228,12 +228,11 @@ class Engine:
test_data,
batch_size=1,
use_program_cache=False,
return_numpy=True,
sample_generator=True):
return_numpy=True):
self.mode = 'predict'
assert self.mode in self._dist_main_progs, "predict model is not ready, please call `engine.prepare(mode='predict')` first."
test_dataloader = self._create_dataloader(
test_data, batch_size, sample_generator=sample_generator)
assert self.mode in self._dist_main_progs, \
"predict model is not ready, please call `engine.prepare(mode='predict')` first."
test_dataloader = self._create_dataloader(test_data, batch_size)
outputs = []
for step, data in enumerate(test_dataloader):
......@@ -304,21 +303,30 @@ class Engine:
dataset,
batch_size,
epochs=1,
steps_per_epoch=None,
sample_generator=True):
feed_list = self._feed_vars[self.mode]["inputs"] + self._feed_vars[
self.mode]["labels"]
steps_per_epoch=None):
dist_main_prog = self._dist_main_progs[self.mode][self._cur_rank]
dist_startup_prog = self._dist_startup_progs[self.mode][self._cur_rank]
dist_context = self._dist_contexts[self.mode]
dist_main_block = dist_main_prog.global_block()
serial_main_prog = self._serial_main_progs[self.mode]
serial_main_block = serial_main_prog.global_block()
# get feed_list from dist_program
inputs_var = self._feed_vars[self.mode]["inputs"]
labels_var = self._feed_vars[self.mode]["labels"]
feed_list = []
for var in inputs_var + labels_var:
if var.name in dist_main_block.vars:
feed_list.append(dist_main_block.vars[var.name])
dp_world_size, dp_rank = self._get_data_parallel_info(feed_list[0],
dist_context)
# remove the first three ops if multi run fit/evaluate/predict
op_size = len(dist_main_block.ops)
if dist_main_block.ops[0].type == 'create_py_reader':
op_size -= 3
for _ in range(3):
dist_main_block._remove_op(0, sync=False)
# insert read op at the end of program
places = paddle.static.cuda_places()
with fluid.program_guard(dist_main_prog, dist_startup_prog):
dataloader = NonIterableGeneratorLoader(
......@@ -328,7 +336,10 @@ class Engine:
batch_size,
epochs,
steps_per_epoch,
sample_generator=sample_generator)
data_parallel_world_size=dp_world_size,
data_parallel_rank=dp_rank)
# move read op from the end of program to the start of program
new_op_size = len(dist_main_block.ops)
for _ in range(new_op_size - 1, op_size - 1, -1):
op = dist_main_block.ops[new_op_size - 1]
......@@ -337,17 +348,6 @@ class Engine:
new_op = Operator(
dist_main_block, new_op_desc, type=new_op_desc.type())
dist_main_block.ops.insert(0, new_op)
for in_name in new_op.input_arg_names:
if "lod_tensor_blocking_queue" in in_name:
continue
if in_name not in dist_main_block.vars:
in_var = serial_main_block._var_recursive(in_name)
dist_main_block._clone_variable(in_var, in_var.persistable)
for out_name in new_op.output_arg_names:
if out_name not in dist_main_block.vars:
out_var = serial_main_block._var_recursive(out_name)
dist_main_block._clone_variable(out_var,
out_var.persistable)
dist_op = DistributedOperator(new_op)
dist_context.add_dist_op_for_program(dist_op)
for _ in range(new_op_size - op_size):
......@@ -387,6 +387,29 @@ class Engine:
return var
def _get_data_parallel_info(self, var, dist_context):
# get data parallel world size and current data parallel rank
from .utils import _get_comm_group, _get_corresponding_rank
tensor_dist_attr = dist_context.get_tensor_dist_attr_for_program(var)
process_mesh = tensor_dist_attr.process_mesh
dims_mapping = tensor_dist_attr.dims_mapping
if self._cur_rank not in process_mesh.processes:
rank_id = _get_corresponding_rank(dist_context, process_mesh,
self._cur_rank)
else:
rank_id = self._cur_rank
batch_size_axis = dims_mapping[0]
if batch_size_axis > -1 and process_mesh.topology[batch_size_axis] > 1:
group_ranks = _get_comm_group(process_mesh.processes,
process_mesh.topology,
batch_size_axis, rank_id)
return len(group_ranks), group_ranks.index(rank_id)
return None, None
def save(self, path, training=True, mode=None):
if not mode:
mode = self.mode
......
......@@ -457,6 +457,7 @@ class DistributedDefaultImpl0(DistributedOperatorImpl):
if len(allreduce_vars) > 0:
for varname in allreduce_vars:
added_ops = []
grad_var = main_block.var(varname)
allreduce_op = main_block.append_op(
......@@ -468,7 +469,9 @@ class DistributedDefaultImpl0(DistributedOperatorImpl):
'use_calc_stream': True,
OP_ROLE_KEY: OpRole.Backward
})
added_ops.append(allreduce_op)
if ctx.gradient_scale:
scale_op = main_block.append_op(
type='scale',
inputs={'X': grad_var},
......@@ -477,11 +480,12 @@ class DistributedDefaultImpl0(DistributedOperatorImpl):
'scale': 1.0 / dp_degree,
OP_ROLE_KEY: OpRole.Backward
})
added_ops.append(scale_op)
dims_mapping = ctx.get_tensor_dist_attr_for_program(
grad_var).dims_mapping
process_mesh = dist_attr.process_mesh
for op in [allreduce_op, scale_op]:
for op in added_ops:
op_attr = OperatorDistributedAttribute()
op_attr.process_mesh = process_mesh
op_attr.set_output_dims_mapping(grad_var.name,
......
......@@ -67,8 +67,6 @@ class Parallelizer:
serial_optimizer, dist_params_grads)
# Do reshard process
set_grad_var_shape(dist_main_prog, self._dist_context)
make_data_unshard(dist_main_prog, dist_startup_prog,
self._dist_context)
resharder = Resharder(dist_main_prog, dist_startup_prog, rank,
self._dist_context, dist_params_grads)
resharder.reshard()
......@@ -84,8 +82,6 @@ class Parallelizer:
dist_main_prog, dist_startup_prog, dist_params_grads = partitioner.partition(
serial_main_program, serial_startup_program, [])
# Do reshard process
make_data_unshard(dist_main_prog, dist_startup_prog,
self._dist_context)
resharder = Resharder(dist_main_prog, dist_startup_prog, rank,
self._dist_context, [], 1)
resharder.reshard()
......
......@@ -513,6 +513,8 @@ class Remover:
idx += 1
for var in remove_vars:
if block.vars[var].is_data:
continue
block._remove_var(var)
@staticmethod
......
......@@ -127,8 +127,7 @@ def train():
engine.prepare(optimizer, loss)
engine.fit(dataset,
batch_size=batch_size,
steps_per_epoch=batch_num * batch_size,
sample_generator=True)
steps_per_epoch=batch_num * batch_size)
eval_dataset = MyDataset(batch_size)
engine.prepare(optimizer, loss, mode='eval')
......
......@@ -117,6 +117,7 @@ def loss_func(eq_loss, bc_u, bc_value):
def main():
paddle.enable_static()
# dataset
train_dataset = LaplaceDataset(10)
# optimizer
......@@ -140,7 +141,7 @@ def main():
labels_spec=labels_spec,
strategy=dist_strategy)
engine.prepare(optimizer=optimizer, loss=loss_func)
res = engine.fit(train_dataset, sample_generator=False)
res = engine.fit(train_dataset, batch_size=None)
dist_context = engine.dist_context
block = engine.main_program.global_block()
......
......@@ -53,6 +53,10 @@ class TestPlannerReLaunch(unittest.TestCase):
"auto_parallel_rank_mapping.json")
if os.path.exists(rank_mapping_json_path):
os.remove(rank_mapping_json_path)
files_path = [path for path in os.listdir('.') if '.pkl' in path]
for path in files_path:
if os.path.exists(path):
os.remove(path)
log_path = os.path.join(file_dir, "log")
if os.path.exists(log_path):
shutil.rmtree(log_path)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册