Unverified · Commit 2ea7a6a3 authored by zhaoyingli, committed by GitHub

[AutoParallel] organize dataloader in engine (#56788)

Parent 97b09e81
@@ -949,6 +949,7 @@ class Engine:
                 ...     batch_size=64)
         """
         self._mode = 'train'
         self._inputs_spec, self._labels_spec = self._prepare_data_spec(
             train_data, train_sample_split, batch_size
         )
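For context, the `fit` docstring excerpted above ends with a call like the one below. This is a sketch assembled from that docstring's example, assuming the usual MNIST/LeNet setup used throughout engine.py; it is illustrative, not part of this diff:

```python
import paddle
import paddle.vision.transforms as T
from paddle.distributed.fleet import auto
from paddle.vision.datasets import MNIST

# Dataset and model as in the Engine docstrings (assumed setup).
transform = T.Compose([T.Transpose(), T.Normalize([127.5], [127.5])])
train_dataset = MNIST(mode='train', transform=transform)

model = paddle.vision.models.LeNet()
loss = paddle.nn.CrossEntropyLoss()
optimizer = paddle.optimizer.Adam(
    learning_rate=0.001, parameters=model.parameters()
)

# batch_size=64 is the user-facing global batch size; the hunks below
# decide how it is divided across data-parallel ranks or accumulation steps.
engine = auto.Engine(model, loss, optimizer)
engine.fit(train_dataset, epochs=2, batch_size=64)
```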
@@ -1011,14 +1012,14 @@ class Engine:
             logs = {}
             cbks.on_epoch_begin(epoch)
-            for step, data in enumerate(train_dataloader):
+            for step, batch in enumerate(train_dataloader):
                 if auto_utils.use_new_executor():
-                    feeds = self._validate_feed(data)
+                    batches = self._validate_batch(batch)
                 else:
-                    feeds = [{}]
+                    batches = [{}]
                 try:
-                    for micro_feed in feeds:
+                    for micro_batch in batches:
                         with paddle.profiler.utils._nvprof_range(
                             iter_id=step,
                             start=nvprof_range[0],
@@ -1027,7 +1028,7 @@ class Engine:
                             cbks.on_batch_begin('train', step, logs)
                             outs = self._executor.run(
                                 self.main_program,
-                                feed=micro_feed,
+                                feed=micro_batch,
                                 fetch_list=fetch_names,
                                 use_program_cache=self._strategy.use_cache,
                                 return_numpy=self._strategy.return_numpy,
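The renamed loop above runs the executor once per micro-batch rather than once per dataloader step. A minimal standalone sketch of that shape, with `split_batch` and `run_program` as stand-ins for `_validate_batch` and `self._executor.run` (hypothetical names, not Paddle API):

```python
def run_epoch(dataloader, split_batch, run_program):
    # One dataloader step yields one global batch; split_batch turns it
    # into a list of per-micro-batch feed dicts (as _validate_batch does),
    # and the program runs once per micro-batch.
    for step, batch in enumerate(dataloader):
        batches = split_batch(batch)
        for micro_batch in batches:
            run_program(feed=micro_batch)

# e.g. with acc_steps=2, a loader yielding 10 batches triggers 20 runs.
```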
@@ -1136,12 +1137,13 @@ class Engine:
         self._inputs_spec, self._labels_spec = self._prepare_data_spec(
             valid_data, valid_sample_split, batch_size
         )
-        micro_batch_size = self._validate_batch_size(batch_size)
         if not self._has_prepared[self._mode]:
             self._prepare_program(self._mode)
         else:
             self._switch_mode(self._mode)
+        micro_batch_size = self._validate_batch_size(batch_size)
         valid_dataloader = self._prepare_dataloader_from_generator(
             dataset=valid_data,
             capacity=70,
@@ -1243,12 +1245,13 @@ class Engine:
         self._inputs_spec, self._labels_spec = self._prepare_data_spec(
             test_data, test_sample_split, batch_size
         )
-        micro_batch_size = self._validate_batch_size(batch_size)
         if not self._has_prepared[self._mode]:
             self._prepare_program(self._mode)
         else:
             self._switch_mode(self._mode)
+        micro_batch_size = self._validate_batch_size(batch_size)
         test_dataloader = self._prepare_dataloader_from_generator(
             dataset=test_data,
             capacity=70,
@@ -1304,19 +1307,21 @@ class Engine:
     ):
         if mode is not None:
             self.to_mode(mode)
         self._inputs_spec, self._labels_spec = self._prepare_data_spec(
             dataset, sample_split, batch_size
         )
-        micro_batch_size = self._validate_batch_size(batch_size)
         if not self._has_prepared[self._mode]:
             self._prepare_program(self._mode)
         else:
             self._switch_mode(self._mode)
+        batch_size = self._validate_batch_size(batch_size)
         dataloader = self._prepare_dataloader(
             dataset,
             return_list=False,
-            batch_size=micro_batch_size,
+            batch_size=batch_size,
             shuffle=shuffle,
             drop_last=drop_last,
             collate_fn=collate_fn,
@@ -1351,12 +1356,13 @@ class Engine:
         self._inputs_spec, self._labels_spec = self._prepare_data_spec(
             dataset, sample_split, batch_size
         )
-        micro_batch_size = self._validate_batch_size(batch_size)
         if not self._has_prepared[self._mode]:
             self._prepare_program(self._mode)
         else:
             self._switch_mode(self._mode)
+        micro_batch_size = self._validate_batch_size(batch_size)
         dataloader = self._prepare_dataloader_from_generator(
             dataset=dataset,
             capacity=capacity,
@@ -1582,8 +1588,20 @@ class Engine:
     def _validate_batch_size(self, batch_size):
         if batch_size is None:
             return None
-        if self._strategy.pipeline.enable and auto_utils.use_new_executor():
-            return batch_size
+
+        if auto_utils.use_new_executor():
+            assert (
+                len(set(self._dp_world_sizes)) == 1
+            ), "DistributedBatchSampler only supports one data parallel group, but got [{}] different data parallel groups".format(
+                len(set(self._dp_world_sizes))
+            )
+            assert (
+                batch_size % self._dp_world_sizes[0] == 0
+            ), "batch_size [{}] is not divisible by dp_world_size [{}]".format(
+                str(batch_size), str(self._dp_world_sizes[0])
+            )
+            return batch_size // self._dp_world_sizes[0]
+        else:
             assert (
                 batch_size % self._acc_steps == 0
             ), "Requires batch_size:[{}] to be divisible by acc_steps:[{}].".format(
@@ -1591,24 +1609,27 @@ class Engine:
             )
             return batch_size // self._acc_steps
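The effect of the rewritten `_validate_batch_size` above: under the new executor the user-facing batch size is treated as a global batch that the `DistributedBatchSampler` shards across a single data-parallel group; otherwise it is divided across gradient-accumulation steps. A standalone sketch of that arithmetic (simplified reimplementation, assuming `dp_world_size` and `acc_steps` as plain ints):

```python
def per_step_batch_size(batch_size, use_new_executor, dp_world_size=1, acc_steps=1):
    # Mirrors the two branches above: divide by the data-parallel world
    # size under the new executor, by acc_steps otherwise.
    if batch_size is None:
        return None
    if use_new_executor:
        assert batch_size % dp_world_size == 0
        return batch_size // dp_world_size
    assert batch_size % acc_steps == 0
    return batch_size // acc_steps

print(per_step_batch_size(64, True, dp_world_size=4))  # 16 samples per rank
print(per_step_batch_size(64, False, acc_steps=8))     # 8 samples per accumulation step
```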
-    def _validate_feed(self, feed):
-        if feed is None:
+    def _validate_batch(self, batch):
+        if batch is None:
             return [None]
-        # pp with schedule or naive-pp
-        if self._strategy.pipeline.enable or self._acc_steps == 1:
-            return feed
+
+        if self._strategy.pipeline.enable or self._acc_steps == 1:
+            # pp with schedule or naive-pp
+            return batch
         else:
             # split feed data with gradient_merge k_steps
             feed_names = []
-            split_feeds = []
-            for feed_name, cur_feed in feed[0].items():
+            split_batches = []
+            for feed_name, cur_feed in batch[0].items():
                 feed_names.append(feed_name)
-                split_feeds.append(np.split(np.array(cur_feed), self._acc_steps, 0))
-            micro_feeds = []
+                split_batches.append(
+                    np.split(np.array(cur_feed), self._acc_steps, 0)
+                )
+            batches = []
             for i in range(self._acc_steps):
-                split_feed = [sf[i] for sf in split_feeds]
-                micro_feeds.append(dict(zip(feed_names, split_feed)))
-            return micro_feeds
+                micro_batch = [split_batch[i] for split_batch in split_batches]
+                batches.append(dict(zip(feed_names, micro_batch)))
+            return batches

     def _validate_spec(self, specs):
         specs = auto_utils.to_list(specs)
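The gradient-merge branch of the renamed `_validate_batch` above splits every feed array along the batch axis into `acc_steps` slices and regroups them into one feed dict per accumulation step. A worked example of that exact transform (the sample arrays and `acc_steps=2` are assumed for illustration):

```python
import numpy as np

acc_steps = 2
# A batch as _validate_batch receives it: a list holding one feed dict.
batch = [{
    "image": np.arange(8).reshape(4, 2),      # batch axis of size 4
    "label": np.array([[0], [1], [2], [3]]),
}]

feed_names = []
split_batches = []
for feed_name, cur_feed in batch[0].items():
    feed_names.append(feed_name)
    # Split along axis 0 into acc_steps equal slices.
    split_batches.append(np.split(np.array(cur_feed), acc_steps, 0))

micro_batches = []
for i in range(acc_steps):
    micro_batch = [split_batch[i] for split_batch in split_batches]
    micro_batches.append(dict(zip(feed_names, micro_batch)))

# Two feeds of half the batch: image (2, 2) and label (2, 1) each.
print([{k: v.shape for k, v in mb.items()} for mb in micro_batches])
```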