Unverified commit 2ea7a6a3, authored by zhaoyingli, committed by GitHub

[AutoParallel]organize dataloder in engine (#56788)

Parent 97b09e81
...
@@ -949,6 +949,7 @@ class Engine:
                 ...     batch_size=64)
         """
         self._mode = 'train'
         self._inputs_spec, self._labels_spec = self._prepare_data_spec(
             train_data, train_sample_split, batch_size
         )
...
@@ -1011,14 +1012,14 @@ class Engine:
             logs = {}
             cbks.on_epoch_begin(epoch)
-            for step, data in enumerate(train_dataloader):
+            for step, batch in enumerate(train_dataloader):
                 if auto_utils.use_new_executor():
-                    feeds = self._validate_feed(data)
+                    batches = self._validate_batch(batch)
                 else:
-                    feeds = [{}]
+                    batches = [{}]

                 try:
-                    for micro_feed in feeds:
+                    for micro_batch in batches:
                         with paddle.profiler.utils._nvprof_range(
                             iter_id=step,
                             start=nvprof_range[0],
...
@@ -1027,7 +1028,7 @@ class Engine:
                             cbks.on_batch_begin('train', step, logs)
                             outs = self._executor.run(
                                 self.main_program,
-                                feed=micro_feed,
+                                feed=micro_batch,
                                 fetch_list=fetch_names,
                                 use_program_cache=self._strategy.use_cache,
                                 return_numpy=self._strategy.return_numpy,
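The hunks above rename the per-step feed handling in fit(): each global batch is validated and, under the new executor, pre-split into micro-batches that are fed to the executor one at a time. A minimal standalone sketch of that control flow, with a stub run() standing in for self._executor.run() (illustrative only, not Engine code):

    def run_one_epoch(train_dataloader, validate_batch, run, use_new_executor):
        # Mirrors the renamed loop: one validate/split per step, then one
        # executor run per resulting micro-batch.
        for step, batch in enumerate(train_dataloader):
            batches = validate_batch(batch) if use_new_executor else [{}]
            for micro_batch in batches:
                run(feed=micro_batch)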
...
@@ -1136,12 +1137,13 @@ class Engine:
         self._inputs_spec, self._labels_spec = self._prepare_data_spec(
             valid_data, valid_sample_split, batch_size
         )
-        micro_batch_size = self._validate_batch_size(batch_size)
         if not self._has_prepared[self._mode]:
             self._prepare_program(self._mode)
         else:
             self._switch_mode(self._mode)

+        micro_batch_size = self._validate_batch_size(batch_size)
         valid_dataloader = self._prepare_dataloader_from_generator(
             dataset=valid_data,
             capacity=70,
...
@@ -1243,12 +1245,13 @@ class Engine:
         self._inputs_spec, self._labels_spec = self._prepare_data_spec(
             test_data, test_sample_split, batch_size
         )
-        micro_batch_size = self._validate_batch_size(batch_size)
         if not self._has_prepared[self._mode]:
             self._prepare_program(self._mode)
         else:
             self._switch_mode(self._mode)

+        micro_batch_size = self._validate_batch_size(batch_size)
         test_dataloader = self._prepare_dataloader_from_generator(
             dataset=test_data,
             capacity=70,
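In both evaluate() and predict(), the _validate_batch_size() call moves below the _prepare_program()/_switch_mode() block. The likely reason, visible in the rewritten method further down, is that the new-executor path divides by self._dp_world_sizes, which is only populated once the program has been prepared. A rough sketch of the resulting arithmetic, using assumed example values (not taken from this diff):

    batch_size = 64
    dp_world_size = 2  # size of the single data-parallel group (new executor)
    acc_steps = 4      # gradient-accumulation steps (legacy executor)

    micro_batch_size_new = batch_size // dp_world_size  # 32 samples per rank
    micro_batch_size_old = batch_size // acc_steps      # 16 samples per micro-step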
...
@@ -1304,19 +1307,21 @@ class Engine:
     ):
         if mode is not None:
             self.to_mode(mode)

         self._inputs_spec, self._labels_spec = self._prepare_data_spec(
             dataset, sample_split, batch_size
         )
-        micro_batch_size = self._validate_batch_size(batch_size)
         if not self._has_prepared[self._mode]:
             self._prepare_program(self._mode)
         else:
             self._switch_mode(self._mode)

+        batch_size = self._validate_batch_size(batch_size)
         dataloader = self._prepare_dataloader(
             dataset,
             return_list=False,
-            batch_size=micro_batch_size,
+            batch_size=batch_size,
             shuffle=shuffle,
             drop_last=drop_last,
             collate_fn=collate_fn,
...
@@ -1351,12 +1356,13 @@ class Engine:
         self._inputs_spec, self._labels_spec = self._prepare_data_spec(
             dataset, sample_split, batch_size
         )
-        micro_batch_size = self._validate_batch_size(batch_size)
         if not self._has_prepared[self._mode]:
             self._prepare_program(self._mode)
         else:
             self._switch_mode(self._mode)

+        micro_batch_size = self._validate_batch_size(batch_size)
         dataloader = self._prepare_dataloader_from_generator(
             dataset=dataset,
             capacity=capacity,
...
@@ -1582,33 +1588,48 @@ class Engine:
     def _validate_batch_size(self, batch_size):
         if batch_size is None:
             return None
-        if self._strategy.pipeline.enable and auto_utils.use_new_executor():
-            return batch_size
-        assert (
-            batch_size % self._acc_steps == 0
-        ), "Requires batch_size:[{}] to be divisible by acc_steps:[{}].".format(
-            batch_size, self._acc_steps
-        )
-        return batch_size // self._acc_steps
+
+        if auto_utils.use_new_executor():
+            assert (
+                len(set(self._dp_world_sizes)) == 1
+            ), "DistributedBatchSampler only support one data parallel group, but got [{}] different data parallel groups".format(
+                len(set(self._dp_world_sizes))
+            )
+            assert (
+                batch_size % self._dp_world_sizes[0] == 0
+            ), "batch_size [{}] is not divisible by dp_world_size [{}]".format(
+                str(batch_size), str(self._dp_world_sizes[0])
+            )
+            return batch_size // self._dp_world_sizes[0]
+        else:
+            assert (
+                batch_size % self._acc_steps == 0
+            ), "Requires batch_size:[{}] to be divisible by acc_steps:[{}].".format(
+                batch_size, self._acc_steps
+            )
+            return batch_size // self._acc_steps

-    def _validate_feed(self, feed):
-        if feed is None:
+    def _validate_batch(self, batch):
+        if batch is None:
             return [None]
-        # pp with schedule or navie-pp
+
         if self._strategy.pipeline.enable or self._acc_steps == 1:
-            return feed
-
-        # split feed data with gradient_merge k_steps
-        feed_names = []
-        split_feeds = []
-        for feed_name, cur_feed in feed[0].items():
-            feed_names.append(feed_name)
-            split_feeds.append(np.split(np.array(cur_feed), self._acc_steps, 0))
-        micro_feeds = []
-        for i in range(self._acc_steps):
-            split_feed = [sf[i] for sf in split_feeds]
-            micro_feeds.append(dict(zip(feed_names, split_feed)))
-        return micro_feeds
+            # pp with schedule or navie-pp
+            return batch
+        else:
+            # split feed data with gradient_merge k_steps
+            feed_names = []
+            split_batches = []
+            for feed_name, cur_feed in batch[0].items():
+                feed_names.append(feed_name)
+                split_batches.append(
+                    np.split(np.array(cur_feed), self._acc_steps, 0)
+                )
+            baches = []
+            for i in range(self._acc_steps):
+                micro_batch = [split_batch[i] for split_batch in split_batches]
+                baches.append(dict(zip(feed_names, micro_batch)))
+            return baches

     def _validate_spec(self, specs):
         specs = auto_utils.to_list(specs)
...
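To make the gradient-merge splitting in _validate_batch concrete, here is a small self-contained sketch of the same np.split pattern on toy data (feed names and shapes are made up for illustration):

    import numpy as np

    acc_steps = 4
    # A toy feed list with one dict, like the `batch` argument: 8 samples.
    batch = [{
        "input": np.arange(8 * 3).reshape(8, 3),
        "label": np.arange(8).reshape(8, 1),
    }]

    feed_names = []
    split_batches = []
    for feed_name, cur_feed in batch[0].items():
        feed_names.append(feed_name)
        # Split the leading (batch) axis into acc_steps equal chunks.
        split_batches.append(np.split(np.array(cur_feed), acc_steps, 0))

    micro_batches = []
    for i in range(acc_steps):
        micro_batch = [split_batch[i] for split_batch in split_batches]
        micro_batches.append(dict(zip(feed_names, micro_batch)))

    assert len(micro_batches) == acc_steps
    assert micro_batches[0]["input"].shape == (2, 3)  # 8 samples / 4 steps
    assert micro_batches[0]["label"].shape == (2, 1)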