Unverified commit df470954, authored by zhaoyingli, committed by GitHub

[AutoParallel] split data in dataloader (#42838)

* slice data in dist_loader & flag to scale grad

* bug fix

* update unittest

* enable static
Parent commit: 16ce33b0
@@ -115,6 +115,9 @@ class DistributedContext:
         self._is_initialized = False
+        # flag whether scale gradient with dp size
+        self._gradient_scale = True

     @property
     def serial_main_program(self):
         return self._serial_main_program
@@ -187,6 +190,14 @@ class DistributedContext:
         return len(self._dist_tensors_for_program) or len(
             self._dist_ops_for_program)

+    @property
+    def gradient_scale(self):
+        return self._gradient_scale
+
+    @gradient_scale.setter
+    def gradient_scale(self, gs):
+        self._gradient_scale = gs
+
     def initialize(self):
         if not self._is_initialized:
             self._serial_main_program = self._original_serial_main_program.clone(
...
@@ -34,7 +34,7 @@ class DistributedDataLoader(metaclass=abc.ABCMeta):
         self.data_parallel_world_size = data_parallel_world_size
         self.data_parallel_rank = data_parallel_rank
         self.drop_lost = drop_last
-        if data_parallel_world_size is not None:
+        if data_parallel_world_size is not None and batch_size is not None:
             assert batch_size % data_parallel_world_size == 0

     @abc.abstractmethod
@@ -56,12 +56,12 @@ class NonIterableGeneratorLoader(DistributedDataLoader):
                  steps_per_epoch=None,
                  data_parallel_world_size=None,
                  data_parallel_rank=None,
-                 drop_last=False,
-                 sample_generator=True):
+                 drop_last=False):
         self.feed_list = feed_list
         self.places = places
         self.steps_per_epoch = steps_per_epoch
-        self._sample_generator = sample_generator
+        self.dp_world_size = 1 if data_parallel_world_size is None else data_parallel_world_size
+        self.dp_rank = 0 if data_parallel_rank is None else data_parallel_rank

         super(NonIterableGeneratorLoader, self).__init__(
             dataset, batch_size, epochs, data_parallel_world_size,
@@ -85,7 +85,10 @@ class NonIterableGeneratorLoader(DistributedDataLoader):
         if self.steps_per_epoch is not None:
             return self.steps_per_epoch
         try:
-            steps_per_epoch = len(self.dataset) // self.batch_size
+            if self.batch_size is None:
+                steps_per_epoch = len(self.dataset)
+            else:
+                steps_per_epoch = len(self.dataset) // self.batch_size
         except:
             raise ValueError(
                 "Pleace set `steps_per_epoch` or implement `__len__` methond in dataset class."
@@ -102,17 +105,28 @@ class NonIterableGeneratorLoader(DistributedDataLoader):
                 for idx in range(len(data)):
                     batch_data[idx].append(data[idx])
                 if (step + 1) % self.batch_size == 0:
-                    yield batch_data
+                    partial_data = []
+                    for d in batch_data:
+                        array = np.array(d)
+                        partial_data.append(
+                            np.split(array, self.dp_world_size)[self.dp_rank])
+                    yield partial_data[:len(self.feed_list)]
                     batch_data = None

         def batch_data_generator():
             for data in self.dataset:
                 data = flatten(data)
-                yield data
+                partial_data = []
+                for d in data:
+                    assert d.shape[0] % self.dp_world_size == 0, \
+                        "Please padding dataset with data parallel size"
+                    partial_data.append(
+                        np.split(d, self.dp_world_size)[self.dp_rank])
+                yield partial_data[:len(self.feed_list)]

         dataloader = paddle.fluid.io.DataLoader.from_generator(
             feed_list=self.feed_list, capacity=70, iterable=False)
-        if self._sample_generator:
+        if self.batch_size is not None:
             dataloader.set_batch_generator(sample_data_generator, self.places)
         else:
             dataloader.set_batch_generator(batch_data_generator, self.places)
...
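The heart of this patch is the per-rank slice inside the two generators above: each data-parallel rank now feeds only its share of every batch instead of the full batch. A minimal NumPy sketch of that slicing (the helper name `shard_batch` and the sample shapes are illustrative and not part of the patch; `dp_world_size`/`dp_rank` mirror the new loader attributes):

```python
import numpy as np

def shard_batch(batch, dp_world_size, dp_rank):
    """Split every field of a batch along axis 0 and keep this rank's shard.

    Mirrors the np.split(...)[dp_rank] pattern used by the new generators;
    assumes the leading (batch) dimension divides evenly by dp_world_size.
    """
    sharded = []
    for field in batch:
        array = np.array(field)
        assert array.shape[0] % dp_world_size == 0, \
            "pad the dataset so the batch divides evenly across ranks"
        sharded.append(np.split(array, dp_world_size)[dp_rank])
    return sharded

# A global batch of 8 samples split across 2 data-parallel ranks:
images = np.random.rand(8, 3, 32, 32).astype("float32")
labels = np.arange(8).reshape(8, 1)
rank0 = shard_batch([images, labels], dp_world_size=2, dp_rank=0)
rank1 = shard_batch([images, labels], dp_world_size=2, dp_rank=1)
print(rank0[0].shape, rank1[1].ravel())  # (4, 3, 32, 32) [4 5 6 7]
```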
@@ -45,8 +45,6 @@ from .utils import print_program_with_dist_attr, to_list
 from .process_group import get_all_process_groups, get_world_process_group
 from .dist_context import DistributedContext, get_default_distributed_context

-paddle.enable_static()

 class Engine:

     def __init__(self,
@@ -82,6 +80,7 @@ class Engine:
     def prepare(self,
                 optimizer=None,
                 loss=None,
+                gradient_scale=True,
                 metrics=None,
                 mode='train',
                 all_ranks=False):
@@ -90,6 +89,7 @@ class Engine:
         self._loss = loss
         self._metrics = to_list(metrics)
         self._mode = mode
+        self._gradient_scale = gradient_scale
         # Build forward program
         self._build(mode)
         # Do the planning process
@@ -149,6 +149,7 @@ class Engine:
             self._serial_main_progs[mode], self._serial_startup_progs[mode],
             self._optimizer, losses, self._feed_vars[mode],
             self._fetch_vars[mode], self.strategy)
+        self._dist_contexts[mode].gradient_scale = self._gradient_scale

     def _initialize(self, mode):
         if self._nranks > 1:
@@ -183,14 +184,14 @@ class Engine:
             epochs=1,
             steps_per_epoch=None,
             use_program_cache=False,
-            return_numpy=True,
-            sample_generator=True):
+            return_numpy=True):
         # TODO: callbacks
         # TODO: evaluate after training
         self.mode = 'train'
-        assert self.mode in self._dist_main_progs, "train model is not ready, please call `engine.prepare(mode='train')` first."
-        train_dataloader = self._create_dataloader(
-            train_data, batch_size, epochs, steps_per_epoch, sample_generator)
+        assert self.mode in self._dist_main_progs, \
+            "train model is not ready, please call `engine.prepare(mode='train')` first."
+        train_dataloader = self._create_dataloader(train_data, batch_size,
+                                                   epochs, steps_per_epoch)

         outputs = []
         for epoch in range(epochs):
@@ -209,12 +210,11 @@ class Engine:
                  eval_data,
                  batch_size=1,
                  use_program_cache=False,
-                 return_numpy=True,
-                 sample_generator=True):
+                 return_numpy=True):
         self.mode = 'eval'
-        assert self.mode in self._dist_main_progs, "eval model is not ready, please call `engine.prepare(mode='eval')` first."
-        eval_dataloader = self._create_dataloader(
-            eval_data, batch_size, sample_generator=sample_generator)
+        assert self.mode in self._dist_main_progs, \
+            "eval model is not ready, please call `engine.prepare(mode='eval')` first."
+        eval_dataloader = self._create_dataloader(eval_data, batch_size)

         outputs = []
         for step, data in enumerate(eval_dataloader):
@@ -228,12 +228,11 @@ class Engine:
                 test_data,
                 batch_size=1,
                 use_program_cache=False,
-                return_numpy=True,
-                sample_generator=True):
+                return_numpy=True):
         self.mode = 'predict'
-        assert self.mode in self._dist_main_progs, "predict model is not ready, please call `engine.prepare(mode='predict')` first."
-        test_dataloader = self._create_dataloader(
-            test_data, batch_size, sample_generator=sample_generator)
+        assert self.mode in self._dist_main_progs, \
+            "predict model is not ready, please call `engine.prepare(mode='predict')` first."
+        test_dataloader = self._create_dataloader(test_data, batch_size)

         outputs = []
         for step, data in enumerate(test_dataloader):
@@ -304,21 +303,30 @@ class Engine:
                            dataset,
                            batch_size,
                            epochs=1,
-                           steps_per_epoch=None,
-                           sample_generator=True):
-        feed_list = self._feed_vars[self.mode]["inputs"] + self._feed_vars[
-            self.mode]["labels"]
+                           steps_per_epoch=None):
         dist_main_prog = self._dist_main_progs[self.mode][self._cur_rank]
         dist_startup_prog = self._dist_startup_progs[self.mode][self._cur_rank]
         dist_context = self._dist_contexts[self.mode]
         dist_main_block = dist_main_prog.global_block()
-        serial_main_prog = self._serial_main_progs[self.mode]
-        serial_main_block = serial_main_prog.global_block()
+
+        # get feed_list from dist_program
+        inputs_var = self._feed_vars[self.mode]["inputs"]
+        labels_var = self._feed_vars[self.mode]["labels"]
+        feed_list = []
+        for var in inputs_var + labels_var:
+            if var.name in dist_main_block.vars:
+                feed_list.append(dist_main_block.vars[var.name])
+        dp_world_size, dp_rank = self._get_data_parallel_info(feed_list[0],
+                                                              dist_context)
+
+        # remove the first three ops if multi run fit/evaluate/predict
         op_size = len(dist_main_block.ops)
         if dist_main_block.ops[0].type == 'create_py_reader':
             op_size -= 3
             for _ in range(3):
                 dist_main_block._remove_op(0, sync=False)
+
+        # insert read op at the end of program
         places = paddle.static.cuda_places()
         with fluid.program_guard(dist_main_prog, dist_startup_prog):
             dataloader = NonIterableGeneratorLoader(
@@ -328,7 +336,10 @@ class Engine:
                 batch_size,
                 epochs,
                 steps_per_epoch,
-                sample_generator=sample_generator)
+                data_parallel_world_size=dp_world_size,
+                data_parallel_rank=dp_rank)
+
+        # move read op from the end of program to the start of program
         new_op_size = len(dist_main_block.ops)
         for _ in range(new_op_size - 1, op_size - 1, -1):
             op = dist_main_block.ops[new_op_size - 1]
@@ -337,17 +348,6 @@ class Engine:
             new_op = Operator(
                 dist_main_block, new_op_desc, type=new_op_desc.type())
             dist_main_block.ops.insert(0, new_op)
-            for in_name in new_op.input_arg_names:
-                if "lod_tensor_blocking_queue" in in_name:
-                    continue
-                if in_name not in dist_main_block.vars:
-                    in_var = serial_main_block._var_recursive(in_name)
-                    dist_main_block._clone_variable(in_var, in_var.persistable)
-            for out_name in new_op.output_arg_names:
-                if out_name not in dist_main_block.vars:
-                    out_var = serial_main_block._var_recursive(out_name)
-                    dist_main_block._clone_variable(out_var,
-                                                    out_var.persistable)
             dist_op = DistributedOperator(new_op)
             dist_context.add_dist_op_for_program(dist_op)
         for _ in range(new_op_size - op_size):
@@ -387,6 +387,29 @@ class Engine:
         return var

+    def _get_data_parallel_info(self, var, dist_context):
+        # get data parallel world size and current data parallel rank
+        from .utils import _get_comm_group, _get_corresponding_rank
+
+        tensor_dist_attr = dist_context.get_tensor_dist_attr_for_program(var)
+        process_mesh = tensor_dist_attr.process_mesh
+        dims_mapping = tensor_dist_attr.dims_mapping
+
+        if self._cur_rank not in process_mesh.processes:
+            rank_id = _get_corresponding_rank(dist_context, process_mesh,
+                                              self._cur_rank)
+        else:
+            rank_id = self._cur_rank
+
+        batch_size_axis = dims_mapping[0]
+        if batch_size_axis > -1 and process_mesh.topology[batch_size_axis] > 1:
+            group_ranks = _get_comm_group(process_mesh.processes,
+                                          process_mesh.topology,
+                                          batch_size_axis, rank_id)
+            return len(group_ranks), group_ranks.index(rank_id)
+
+        return None, None
+
     def save(self, path, training=True, mode=None):
         if not mode:
             mode = self.mode
...
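The new `_get_data_parallel_info` helper reads the distributed attribute of the first feed variable: if the batch axis (`dims_mapping[0]`) is sharded over a process-mesh dimension larger than 1, the data-parallel group is the set of ranks along that mesh axis, and the loader receives that group's size and this rank's position in it. A toy re-derivation of the same arithmetic in plain NumPy (function and argument names are invented for illustration; the real helper uses Paddle's `_get_comm_group`/`_get_corresponding_rank` utilities and also maps ranks that lie outside the mesh):

```python
import numpy as np

def data_parallel_info(processes, topology, dims_mapping, cur_rank):
    """Toy lookup of (dp_world_size, dp_rank) from a process mesh.

    processes   : flat list of ranks in the mesh, e.g. [0, 1, 2, 3]
    topology    : mesh shape, e.g. [2, 2]
    dims_mapping: mesh axis each tensor axis is sharded on; dims_mapping[0]
                  is the batch axis (-1 means replicated)
    Returns (None, None) when the batch axis is not sharded.
    """
    batch_axis = dims_mapping[0]
    if batch_axis < 0 or topology[batch_axis] <= 1:
        return None, None
    mesh = np.array(processes).reshape(topology)
    coord = list(np.argwhere(mesh == cur_rank)[0])
    # ranks that differ from cur_rank only along the batch mesh axis
    index = [slice(None) if axis == batch_axis else int(c)
             for axis, c in enumerate(coord)]
    group_ranks = mesh[tuple(index)].tolist()
    return len(group_ranks), group_ranks.index(cur_rank)

# 2x2 mesh, batch axis sharded on mesh axis 0: rank 3 belongs to dp group [1, 3]
print(data_parallel_info([0, 1, 2, 3], [2, 2], [0, -1], cur_rank=3))  # (2, 1)
```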
@@ -457,6 +457,7 @@ class DistributedDefaultImpl0(DistributedOperatorImpl):
         if len(allreduce_vars) > 0:
             for varname in allreduce_vars:
+                added_ops = []
                 grad_var = main_block.var(varname)
                 allreduce_op = main_block.append_op(
@@ -468,20 +469,23 @@ class DistributedDefaultImpl0(DistributedOperatorImpl):
                         'use_calc_stream': True,
                         OP_ROLE_KEY: OpRole.Backward
                     })
+                added_ops.append(allreduce_op)

-                scale_op = main_block.append_op(
-                    type='scale',
-                    inputs={'X': grad_var},
-                    outputs={'Out': grad_var},
-                    attrs={
-                        'scale': 1.0 / dp_degree,
-                        OP_ROLE_KEY: OpRole.Backward
-                    })
+                if ctx.gradient_scale:
+                    scale_op = main_block.append_op(
+                        type='scale',
+                        inputs={'X': grad_var},
+                        outputs={'Out': grad_var},
+                        attrs={
+                            'scale': 1.0 / dp_degree,
+                            OP_ROLE_KEY: OpRole.Backward
+                        })
+                    added_ops.append(scale_op)

                 dims_mapping = ctx.get_tensor_dist_attr_for_program(
                     grad_var).dims_mapping
                 process_mesh = dist_attr.process_mesh
-                for op in [allreduce_op, scale_op]:
+                for op in added_ops:
                     op_attr = OperatorDistributedAttribute()
                     op_attr.process_mesh = process_mesh
                     op_attr.set_output_dims_mapping(grad_var.name,
...
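With the `added_ops` change above, the backward rule always emits an allreduce-sum on each data-parallel gradient but appends the `scale` op (factor `1.0 / dp_degree`) only when `ctx.gradient_scale` is true, so gradients are either averaged (the default) or left as a plain sum. A small NumPy sketch of the two behaviours (illustrative only, not the actual program transformation):

```python
import numpy as np

def allreduce_grads(per_rank_grads, gradient_scale=True):
    """Sum gradients across data-parallel ranks and optionally average them.

    Mimics what the backward rule now emits: an allreduce-sum followed by a
    scale(1/dp_degree) op only when gradient_scale is True.
    """
    dp_degree = len(per_rank_grads)
    reduced = np.sum(per_rank_grads, axis=0)       # allreduce-sum
    if gradient_scale:
        reduced = reduced * (1.0 / dp_degree)      # scale op
    return reduced

grads = [np.array([2.0, 4.0]), np.array([6.0, 8.0])]
print(allreduce_grads(grads, gradient_scale=True))   # [4. 6.]  (averaged)
print(allreduce_grads(grads, gradient_scale=False))  # [ 8. 12.] (summed)
```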
@@ -67,8 +67,6 @@ class Parallelizer:
                 serial_optimizer, dist_params_grads)
             # Do reshard process
             set_grad_var_shape(dist_main_prog, self._dist_context)
-            make_data_unshard(dist_main_prog, dist_startup_prog,
-                              self._dist_context)
             resharder = Resharder(dist_main_prog, dist_startup_prog, rank,
                                   self._dist_context, dist_params_grads)
             resharder.reshard()
@@ -84,8 +82,6 @@ class Parallelizer:
             dist_main_prog, dist_startup_prog, dist_params_grads = partitioner.partition(
                 serial_main_program, serial_startup_program, [])
             # Do reshard process
-            make_data_unshard(dist_main_prog, dist_startup_prog,
-                              self._dist_context)
             resharder = Resharder(dist_main_prog, dist_startup_prog, rank,
                                   self._dist_context, [], 1)
             resharder.reshard()
...
@@ -513,6 +513,8 @@ class Remover:
                 idx += 1

         for var in remove_vars:
+            if block.vars[var].is_data:
+                continue
             block._remove_var(var)

     @staticmethod
...
@@ -127,8 +127,7 @@ def train():
     engine.prepare(optimizer, loss)
     engine.fit(dataset,
               batch_size=batch_size,
-               steps_per_epoch=batch_num * batch_size,
-               sample_generator=True)
+               steps_per_epoch=batch_num * batch_size)

     eval_dataset = MyDataset(batch_size)
     engine.prepare(optimizer, loss, mode='eval')
...
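For callers, the visible API change is that `sample_generator` disappears from `fit`/`evaluate`/`predict`: a dataset of individual samples just passes `batch_size`, while a pre-batched dataset passes `batch_size=None` (as the high-order-grad test below now does). A hedged usage sketch assembled only from the calls visible in these tests; `MyDataset`, `PreBatchedDataset`, `mlp`, `loss`, `optimizer`, `inputs_spec`, and `labels_spec` are placeholders, and the import path is an assumption, so treat this as pseudocode rather than a recipe:

```python
# Hypothetical usage of the updated Engine API (placeholders are not defined here).
import paddle
from paddle.distributed.auto_parallel.engine import Engine

paddle.enable_static()  # scripts now enable static mode themselves; engine.py no longer does

engine = Engine(mlp, inputs_spec=inputs_spec, labels_spec=labels_spec)
# gradient_scale=False would keep summed (un-averaged) data-parallel gradients
engine.prepare(optimizer=optimizer, loss=loss, gradient_scale=True)

# per-sample dataset: the loader batches it and slices each batch across dp ranks
engine.fit(MyDataset(), batch_size=64, steps_per_epoch=100)

# pre-batched dataset: batch_size=None selects the batch_data_generator path
engine.fit(PreBatchedDataset(), batch_size=None)
```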
@@ -117,6 +117,7 @@ def loss_func(eq_loss, bc_u, bc_value):

 def main():
+    paddle.enable_static()
     # dataset
     train_dataset = LaplaceDataset(10)
     # optimizer
@@ -140,7 +141,7 @@ def main():
         labels_spec=labels_spec,
         strategy=dist_strategy)
     engine.prepare(optimizer=optimizer, loss=loss_func)
-    res = engine.fit(train_dataset, sample_generator=False)
+    res = engine.fit(train_dataset, batch_size=None)

     dist_context = engine.dist_context
     block = engine.main_program.global_block()
...
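The `batch_size=None` path used above goes through `batch_data_generator`, which asserts that every pre-batched array divides evenly across the data-parallel ranks ("Please padding dataset with data parallel size"). A small illustrative helper showing one way a user-side dataset could be padded to satisfy that check (the helper name and the repeat-last-sample strategy are assumptions, not part of the patch):

```python
import numpy as np

def pad_to_dp_multiple(batch, dp_world_size):
    """Pad a pre-batched array along axis 0 so it splits evenly across ranks.

    Satisfies the loader's `shape[0] % dp_world_size == 0` assertion by
    repeating the last sample; any other padding scheme would also work.
    """
    remainder = batch.shape[0] % dp_world_size
    if remainder == 0:
        return batch
    pad = np.repeat(batch[-1:], dp_world_size - remainder, axis=0)
    return np.concatenate([batch, pad], axis=0)

batch = np.arange(10).reshape(5, 2)            # 5 samples, 2 dp ranks -> pad to 6
padded = pad_to_dp_multiple(batch, dp_world_size=2)
print(padded.shape)                            # (6, 2)
print(np.split(padded, 2)[1])                  # rank 1's shard, incl. the padded row
```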
@@ -53,6 +53,10 @@ class TestPlannerReLaunch(unittest.TestCase):
                                              "auto_parallel_rank_mapping.json")
         if os.path.exists(rank_mapping_json_path):
             os.remove(rank_mapping_json_path)
+        files_path = [path for path in os.listdir('.') if '.pkl' in path]
+        for path in files_path:
+            if os.path.exists(path):
+                os.remove(path)
         log_path = os.path.join(file_dir, "log")
         if os.path.exists(log_path):
             shutil.rmtree(log_path)
...