diff --git a/python/paddle/distributed/auto_parallel/static/engine.py b/python/paddle/distributed/auto_parallel/static/engine.py
index d1261eca247f1d3b52a4c6bdd46fd747b7b12feb..2b66ac4e16a22d663716574466253d5553ee953d 100644
--- a/python/paddle/distributed/auto_parallel/static/engine.py
+++ b/python/paddle/distributed/auto_parallel/static/engine.py
@@ -949,6 +949,7 @@ class Engine:
                 ...     batch_size=64)
         """
         self._mode = 'train'
+
         self._inputs_spec, self._labels_spec = self._prepare_data_spec(
             train_data, train_sample_split, batch_size
         )
@@ -1011,14 +1012,14 @@ class Engine:
 
             logs = {}
             cbks.on_epoch_begin(epoch)
-            for step, data in enumerate(train_dataloader):
+            for step, batch in enumerate(train_dataloader):
                 if auto_utils.use_new_executor():
-                    feeds = self._validate_feed(data)
+                    batches = self._validate_batch(batch)
                 else:
-                    feeds = [{}]
+                    batches = [{}]
 
                 try:
-                    for micro_feed in feeds:
+                    for micro_batch in batches:
                         with paddle.profiler.utils._nvprof_range(
                             iter_id=step,
                             start=nvprof_range[0],
@@ -1027,7 +1028,7 @@ class Engine:
                             cbks.on_batch_begin('train', step, logs)
                             outs = self._executor.run(
                                 self.main_program,
-                                feed=micro_feed,
+                                feed=micro_batch,
                                 fetch_list=fetch_names,
                                 use_program_cache=self._strategy.use_cache,
                                 return_numpy=self._strategy.return_numpy,
@@ -1136,12 +1137,13 @@ class Engine:
         self._inputs_spec, self._labels_spec = self._prepare_data_spec(
             valid_data, valid_sample_split, batch_size
         )
-        micro_batch_size = self._validate_batch_size(batch_size)
+
         if not self._has_prepared[self._mode]:
             self._prepare_program(self._mode)
         else:
             self._switch_mode(self._mode)
 
+        micro_batch_size = self._validate_batch_size(batch_size)
         valid_dataloader = self._prepare_dataloader_from_generator(
             dataset=valid_data,
             capacity=70,
@@ -1243,12 +1245,13 @@ class Engine:
         self._inputs_spec, self._labels_spec = self._prepare_data_spec(
             test_data, test_sample_split, batch_size
         )
-        micro_batch_size = self._validate_batch_size(batch_size)
+
         if not self._has_prepared[self._mode]:
             self._prepare_program(self._mode)
         else:
             self._switch_mode(self._mode)
 
+        micro_batch_size = self._validate_batch_size(batch_size)
         test_dataloader = self._prepare_dataloader_from_generator(
             dataset=test_data,
             capacity=70,
@@ -1304,19 +1307,21 @@ class Engine:
     ):
         if mode is not None:
             self.to_mode(mode)
+
         self._inputs_spec, self._labels_spec = self._prepare_data_spec(
             dataset, sample_split, batch_size
         )
-        micro_batch_size = self._validate_batch_size(batch_size)
+
         if not self._has_prepared[self._mode]:
             self._prepare_program(self._mode)
         else:
             self._switch_mode(self._mode)
 
+        batch_size = self._validate_batch_size(batch_size)
         dataloader = self._prepare_dataloader(
             dataset,
             return_list=False,
-            batch_size=micro_batch_size,
+            batch_size=batch_size,
             shuffle=shuffle,
             drop_last=drop_last,
             collate_fn=collate_fn,
@@ -1351,12 +1356,13 @@ class Engine:
         self._inputs_spec, self._labels_spec = self._prepare_data_spec(
             dataset, sample_split, batch_size
         )
-        micro_batch_size = self._validate_batch_size(batch_size)
+
         if not self._has_prepared[self._mode]:
             self._prepare_program(self._mode)
         else:
             self._switch_mode(self._mode)
 
+        micro_batch_size = self._validate_batch_size(batch_size)
         dataloader = self._prepare_dataloader_from_generator(
             dataset=dataset,
             capacity=capacity,
@@ -1582,33 +1588,48 @@ class Engine:
     def _validate_batch_size(self, batch_size):
         if batch_size is None:
             return None
-        if self._strategy.pipeline.enable and auto_utils.use_new_executor():
-            return batch_size
-        assert (
-            batch_size % self._acc_steps == 0
-        ), "Requires batch_size:[{}] to be divisible by acc_steps:[{}].".format(
-            batch_size, self._acc_steps
-        )
-        return batch_size // self._acc_steps
 
-    def _validate_feed(self, feed):
-        if feed is None:
+        if auto_utils.use_new_executor():
+            assert (
+                len(set(self._dp_world_sizes)) == 1
+            ), "DistributedBatchSampler only supports one data parallel group, but got [{}] different data parallel groups".format(
+                len(set(self._dp_world_sizes))
+            )
+            assert (
+                batch_size % self._dp_world_sizes[0] == 0
+            ), "batch_size [{}] is not divisible by dp_world_size [{}]".format(
+                str(batch_size), str(self._dp_world_sizes[0])
+            )
+            return batch_size // self._dp_world_sizes[0]
+        else:
+            assert (
+                batch_size % self._acc_steps == 0
+            ), "Requires batch_size:[{}] to be divisible by acc_steps:[{}].".format(
+                batch_size, self._acc_steps
+            )
+            return batch_size // self._acc_steps
+
+    def _validate_batch(self, batch):
+        if batch is None:
             return [None]
-        # pp with schedule or navie-pp
+
         if self._strategy.pipeline.enable or self._acc_steps == 1:
-            return feed
-
-        # split feed data with gradient_merge k_steps
-        feed_names = []
-        split_feeds = []
-        for feed_name, cur_feed in feed[0].items():
-            feed_names.append(feed_name)
-            split_feeds.append(np.split(np.array(cur_feed), self._acc_steps, 0))
-        micro_feeds = []
-        for i in range(self._acc_steps):
-            split_feed = [sf[i] for sf in split_feeds]
-            micro_feeds.append(dict(zip(feed_names, split_feed)))
-        return micro_feeds
+            # pp with schedule or naive-pp
+            return batch
+        else:
+            # split feed data with gradient_merge k_steps
+            feed_names = []
+            split_batches = []
+            for feed_name, cur_feed in batch[0].items():
+                feed_names.append(feed_name)
+                split_batches.append(
+                    np.split(np.array(cur_feed), self._acc_steps, 0)
+                )
+            batches = []
+            for i in range(self._acc_steps):
+                micro_batch = [split_batch[i] for split_batch in split_batches]
+                batches.append(dict(zip(feed_names, micro_batch)))
+            return batches
 
     def _validate_spec(self, specs):
         specs = auto_utils.to_list(specs)
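
For reviewers, a minimal standalone sketch of the gradient-merge split that the renamed `_validate_batch` performs: one full batch is cut into `acc_steps` equal micro-batches along axis 0, producing one feed dict per micro-step. This is an illustration outside the patch; the array shapes and the `acc_steps` value are made up.

```python
import numpy as np

acc_steps = 4  # hypothetical gradient accumulation steps
# A batch as the dataloader yields it: a one-element list holding a feed dict.
batch = [{"image": np.zeros((8, 3, 32, 32)), "label": np.zeros((8, 1))}]

feed_names = []
split_batches = []
for feed_name, cur_feed in batch[0].items():
    feed_names.append(feed_name)
    # np.split requires the batch dimension to divide evenly by acc_steps,
    # which is what _validate_batch_size asserts up front.
    split_batches.append(np.split(np.array(cur_feed), acc_steps, 0))

micro_batches = []
for i in range(acc_steps):
    micro_batch = [split_batch[i] for split_batch in split_batches]
    micro_batches.append(dict(zip(feed_names, micro_batch)))

assert len(micro_batches) == acc_steps
assert micro_batches[0]["image"].shape == (2, 3, 32, 32)  # 8 // 4 samples each
```

Each dict in `micro_batches` is then fed to one `_executor.run(...)` call, as in the rewritten training loop above. Under the new executor, `_validate_batch_size` instead divides the global batch size by the (single) data-parallel world size before the dataloader is built.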