From ab60ae920f9be253a0ad2b16acf11acc35c38143 Mon Sep 17 00:00:00 2001 From: JYChen Date: Tue, 1 Aug 2023 16:30:08 +0800 Subject: [PATCH] Remove old dataloader in ut (#55841) * done replaced in d2s models, except lac/transformer/sentiment * replace cinn_test * fix d2s transfomer model * fix cinn model ut --- .../contrib/test_image_classification_fp16.py | 6 -- .../test_multi_precision_fp16_train.py | 6 -- test/dygraph_to_static/test_bert.py | 49 +++++++++-- test/dygraph_to_static/test_mobile_net.py | 49 +++++++---- test/dygraph_to_static/test_resnet.py | 33 ++++++-- test/dygraph_to_static/test_resnet_v2.py | 33 ++++++-- test/dygraph_to_static/test_simnet.py | 37 +++++--- test/dygraph_to_static/test_simnet_v2.py | 35 ++++++-- test/dygraph_to_static/test_transformer.py | 39 +++++++-- test/dygraph_to_static/transformer_util.py | 84 +++++++++++++++++++ test/prim/model/test_resnet_cinn.py | 33 ++++++-- test/prim/model/test_resnet_prim.py | 33 ++++++-- test/prim/model/test_resnet_prim_cinn.py | 33 ++++++-- 13 files changed, 372 insertions(+), 98 deletions(-) diff --git a/test/contrib/test_image_classification_fp16.py b/test/contrib/test_image_classification_fp16.py index 9d192a1c76d..67e1189c0e5 100644 --- a/test/contrib/test_image_classification_fp16.py +++ b/test/contrib/test_image_classification_fp16.py @@ -497,12 +497,6 @@ class TestAmpWithNonIterableDataLoader(unittest.TestCase): label = paddle.static.data( name='label', shape=[-1, 1], dtype='int64' ) - py_reader = fluid.io.DataLoader.from_generator( - feed_list=[image, label], - capacity=4, - iterable=False, - use_double_buffer=False, - ) net = vgg16_bn_drop(image) logits = paddle.static.nn.fc( diff --git a/test/contrib/test_multi_precision_fp16_train.py b/test/contrib/test_multi_precision_fp16_train.py index 64dc91585cb..bb1783a74a6 100644 --- a/test/contrib/test_multi_precision_fp16_train.py +++ b/test/contrib/test_multi_precision_fp16_train.py @@ -283,12 +283,6 @@ class TestAmpWithNonIterableDataLoader(unittest.TestCase): label = paddle.static.data( name='label', shape=[-1, 1], dtype='int64' ) - py_reader = fluid.io.DataLoader.from_generator( - feed_list=[image, label], - capacity=4, - iterable=False, - use_double_buffer=False, - ) zero_var = paddle.tensor.fill_constant( shape=[1], dtype='int64', value=0 ) diff --git a/test/dygraph_to_static/test_bert.py b/test/dygraph_to_static/test_bert.py index df89f40b62b..ffb5a0d2517 100644 --- a/test/dygraph_to_static/test_bert.py +++ b/test/dygraph_to_static/test_bert.py @@ -35,6 +35,46 @@ STEP_NUM = 10 PRINT_STEP = 2 +class FakeBertDataset(paddle.io.Dataset): + def __init__(self, data_reader, steps): + self.src_ids = [] + self.pos_ids = [] + self.sent_ids = [] + self.input_mask = [] + self.mask_label = [] + self.mask_pos = [] + self.labels = [] + self.data_reader = data_reader + + self._generate_fake_data(1 * (steps + 1)) + + def _generate_fake_data(self, length): + for i, data in enumerate(self.data_reader.data_generator()()): + if i >= length: + break + self.src_ids.append(data[0]) + self.pos_ids.append(data[1]) + self.sent_ids.append(data[2]) + self.input_mask.append(data[3]) + self.mask_label.append(data[4]) + self.mask_pos.append(data[5]) + self.labels.append(data[6]) + + def __getitem__(self, idx): + return [ + self.src_ids[idx], + self.pos_ids[idx], + self.sent_ids[idx], + self.input_mask[idx], + self.mask_label[idx], + self.mask_pos[idx], + self.labels[idx], + ] + + def __len__(self): + return len(self.src_ids) + + class TestBert(unittest.TestCase): def setUp(self): self.bert_config = get_bert_config() @@ -56,11 +96,9 @@ class TestBert(unittest.TestCase): fluid.default_main_program().random_seed = SEED fluid.default_startup_program().random_seed = SEED - data_loader = fluid.io.DataLoader.from_generator( - capacity=50, iterable=True - ) - data_loader.set_batch_generator( - data_reader.data_generator(), places=place + fake_dataset = FakeBertDataset(data_reader, STEP_NUM) + data_loader = paddle.io.DataLoader( + fake_dataset, places=place, batch_size=None ) bert = PretrainModelLayer( @@ -80,6 +118,7 @@ class TestBert(unittest.TestCase): mask_pos, labels, ) = input_data + next_sent_acc, mask_lm_loss, total_loss = bert( src_ids=src_ids, position_ids=pos_ids, diff --git a/test/dygraph_to_static/test_mobile_net.py b/test/dygraph_to_static/test_mobile_net.py index 8b3273ba544..220779f1477 100644 --- a/test/dygraph_to_static/test_mobile_net.py +++ b/test/dygraph_to_static/test_mobile_net.py @@ -455,20 +455,33 @@ def create_optimizer(args, parameter_list): return optimizer -def fake_data_reader(batch_size, label_size): - local_random = np.random.RandomState(SEED) +class FakeDataSet(paddle.io.Dataset): + def __init__(self, batch_size, label_size, train_steps): + self.local_random = np.random.RandomState(SEED) + self.label_size = label_size - def reader(): - batch_data = [] - while True: - img = local_random.random_sample([3, 224, 224]).astype('float32') - label = local_random.randint(0, label_size, [1]).astype('int64') - batch_data.append([img, label]) - if len(batch_data) == batch_size: - yield batch_data - batch_data = [] + self.imgs = [] + self.labels = [] - return reader + self._generate_fake_data(batch_size * (train_steps + 1)) + + def _generate_fake_data(self, length): + for i in range(length): + img = self.local_random.random_sample([3, 224, 224]).astype( + 'float32' + ) + label = self.local_random.randint(0, self.label_size, [1]).astype( + 'int64' + ) + + self.imgs.append(img) + self.labels.append(label) + + def __getitem__(self, idx): + return [self.imgs[idx], self.labels[idx]] + + def __len__(self): + return len(self.imgs) class Args: @@ -513,9 +526,15 @@ def train_mobilenet(args, to_static): optimizer = create_optimizer(args=args, parameter_list=net.parameters()) # 3. reader - train_reader = fake_data_reader(args.batch_size, args.class_dim) - train_data_loader = fluid.io.DataLoader.from_generator(capacity=16) - train_data_loader.set_sample_list_generator(train_reader) + train_dataset = FakeDataSet( + args.batch_size, args.class_dim, args.train_step + ) + BatchSampler = paddle.io.BatchSampler( + train_dataset, batch_size=args.batch_size + ) + train_data_loader = paddle.io.DataLoader( + train_dataset, batch_sampler=BatchSampler + ) # 4. train loop loss_data = [] diff --git a/test/dygraph_to_static/test_resnet.py b/test/dygraph_to_static/test_resnet.py index afb001358b4..ed0382bcf28 100644 --- a/test/dygraph_to_static/test_resnet.py +++ b/test/dygraph_to_static/test_resnet.py @@ -220,6 +220,27 @@ def reader_decorator(reader): return __reader__ +class TransedFlowerDataSet(paddle.io.Dataset): + def __init__(self, flower_data, length): + self.img = [] + self.label = [] + self.flower_data = flower_data() + self._generate(length) + + def _generate(self, length): + for i, data in enumerate(self.flower_data): + if i >= length: + break + self.img.append(data[0]) + self.label.append(data[1]) + + def __getitem__(self, idx): + return self.img[idx], self.label[idx] + + def __len__(self): + return len(self.img) + + class ResNetHelper: def __init__(self): self.temp_dir = tempfile.TemporaryDirectory() @@ -243,15 +264,13 @@ class ResNetHelper: paddle.seed(SEED) paddle.framework.random._manual_program_seed(SEED) - train_reader = paddle.batch( + dataset = TransedFlowerDataSet( reader_decorator(paddle.dataset.flowers.train(use_xmap=False)), - batch_size=batch_size, - drop_last=True, + batch_size * (10 + 1), ) - data_loader = fluid.io.DataLoader.from_generator( - capacity=5, iterable=True + data_loader = paddle.io.DataLoader( + dataset, batch_size=batch_size, drop_last=True ) - data_loader.set_sample_list_generator(train_reader) resnet = ResNet() if to_static: @@ -315,8 +334,6 @@ class ResNetHelper: resnet.state_dict(), self.dy_state_dict_save_path + '.pdparams', ) - # avoid dataloader throw abort signaal - data_loader._reset() break return total_loss.numpy() diff --git a/test/dygraph_to_static/test_resnet_v2.py b/test/dygraph_to_static/test_resnet_v2.py index 2efbe46cedf..0805604e138 100644 --- a/test/dygraph_to_static/test_resnet_v2.py +++ b/test/dygraph_to_static/test_resnet_v2.py @@ -220,6 +220,27 @@ def reader_decorator(reader): return __reader__ +class TransedFlowerDataSet(paddle.io.Dataset): + def __init__(self, flower_data, length): + self.img = [] + self.label = [] + self.flower_data = flower_data() + self._generate(length) + + def _generate(self, length): + for i, data in enumerate(self.flower_data): + if i >= length: + break + self.img.append(data[0]) + self.label.append(data[1]) + + def __getitem__(self, idx): + return self.img[idx], self.label[idx] + + def __len__(self): + return len(self.img) + + class TestResnet(unittest.TestCase): def setUp(self): self.temp_dir = tempfile.TemporaryDirectory() @@ -250,15 +271,13 @@ class TestResnet(unittest.TestCase): paddle.seed(SEED) paddle.framework.random._manual_program_seed(SEED) - train_reader = paddle.batch( + dataset = TransedFlowerDataSet( reader_decorator(paddle.dataset.flowers.train(use_xmap=False)), - batch_size=batch_size, - drop_last=True, + batch_size * (10 + 1), ) - data_loader = paddle.fluid.io.DataLoader.from_generator( - capacity=5, iterable=True + data_loader = paddle.io.DataLoader( + dataset, batch_size=batch_size, drop_last=True ) - data_loader.set_sample_list_generator(train_reader) resnet = ResNet() optimizer = optimizer_setting(parameter_list=resnet.parameters()) @@ -311,8 +330,6 @@ class TestResnet(unittest.TestCase): resnet.state_dict(), self.dy_state_dict_save_path + '.pdparams', ) - # avoid dataloader throw abort signaal - data_loader._reset() break paddle.enable_static() diff --git a/test/dygraph_to_static/test_simnet.py b/test/dygraph_to_static/test_simnet.py index 2fdaa9fc836..fc966092ae9 100644 --- a/test/dygraph_to_static/test_simnet.py +++ b/test/dygraph_to_static/test_simnet.py @@ -73,8 +73,8 @@ def fake_vocabulary(): vocab = fake_vocabulary() -class FakeReaderProcessor: - def __init__(self, args, vocab): +class FakeReaderProcessor(paddle.io.Dataset): + def __init__(self, args, vocab, length): self.vocab = vocab self.seq_len = args.seq_len self.sample_size = args.fake_sample_size @@ -86,6 +86,10 @@ class FakeReaderProcessor: self.data_samples.append( np.array([query, pos_title, neg_title]).astype(np.int64) ) + self.query = [] + self.pos_title = [] + self.neg_title = [] + self._init_data(length) def get_reader(self, mode, epoch=0): def reader_with_pairwise(): @@ -95,8 +99,25 @@ class FakeReaderProcessor: return reader_with_pairwise + def _init_data(self, length): + reader = self.get_reader("train", epoch=args.epoch)() + for i, yield_data in enumerate(reader): + if i >= length: + break + self.query.append(yield_data[0]) + self.pos_title.append(yield_data[1]) + self.neg_title.append(yield_data[2]) -simnet_process = FakeReaderProcessor(args, vocab) + def __getitem__(self, idx): + return self.query[idx], self.pos_title[idx], self.neg_title[idx] + + def __len__(self): + return len(self.query) + + +simnet_process = FakeReaderProcessor( + args, vocab, args.batch_size * (args.epoch + 1) +) def train(conf_dict, to_static): @@ -133,14 +154,8 @@ def train(conf_dict, to_static): global_step = 0 losses = [] - train_loader = fluid.io.DataLoader.from_generator( - capacity=16, return_list=True, iterable=True, use_double_buffer=True - ) - get_train_examples = simnet_process.get_reader( - "train", epoch=args.epoch - ) - train_loader.set_sample_list_generator( - paddle.batch(get_train_examples, batch_size=args.batch_size), place + train_loader = paddle.io.DataLoader( + simnet_process, batch_size=args.batch_size, places=[place] ) for left, pos_right, neg_right in train_loader(): diff --git a/test/dygraph_to_static/test_simnet_v2.py b/test/dygraph_to_static/test_simnet_v2.py index a86259cc6d7..a49cc23af11 100644 --- a/test/dygraph_to_static/test_simnet_v2.py +++ b/test/dygraph_to_static/test_simnet_v2.py @@ -72,8 +72,8 @@ def fake_vocabulary(): vocab = fake_vocabulary() -class FakeReaderProcessor: - def __init__(self, args, vocab): +class FakeReaderProcessor(paddle.io.Dataset): + def __init__(self, args, vocab, length): self.vocab = vocab self.seq_len = args.seq_len self.sample_size = args.fake_sample_size @@ -85,6 +85,10 @@ class FakeReaderProcessor: self.data_samples.append( np.array([query, pos_title, neg_title]).astype(np.int64) ) + self.query = [] + self.pos_title = [] + self.neg_title = [] + self._init_data(length) def get_reader(self, mode, epoch=0): def reader_with_pairwise(): @@ -94,8 +98,25 @@ class FakeReaderProcessor: return reader_with_pairwise + def _init_data(self, length): + reader = self.get_reader("train", epoch=args.epoch)() + for i, yield_data in enumerate(reader): + if i >= length: + break + self.query.append(yield_data[0]) + self.pos_title.append(yield_data[1]) + self.neg_title.append(yield_data[2]) -simnet_process = FakeReaderProcessor(args, vocab) + def __getitem__(self, idx): + return self.query[idx], self.pos_title[idx], self.neg_title[idx] + + def __len__(self): + return len(self.query) + + +simnet_process = FakeReaderProcessor( + args, vocab, args.batch_size * (args.epoch + 1) +) def train(conf_dict, to_static): @@ -132,12 +153,8 @@ def train(conf_dict, to_static): global_step = 0 losses = [] - train_loader = paddle.fluid.io.DataLoader.from_generator( - capacity=16, return_list=True, iterable=True, use_double_buffer=True - ) - get_train_examples = simnet_process.get_reader("train", epoch=args.epoch) - train_loader.set_sample_list_generator( - paddle.batch(get_train_examples, batch_size=args.batch_size), place + train_loader = paddle.io.DataLoader( + simnet_process, batch_size=args.batch_size ) for left, pos_right, neg_right in train_loader(): diff --git a/test/dygraph_to_static/test_transformer.py b/test/dygraph_to_static/test_transformer.py index c0f540ac8eb..0942937bb68 100644 --- a/test/dygraph_to_static/test_transformer.py +++ b/test/dygraph_to_static/test_transformer.py @@ -63,10 +63,14 @@ def train_static(args, batch_generator): ] input_field = util.InputField(input_slots) # Define DataLoader - data_loader = fluid.io.DataLoader.from_generator( - input_field.feed_list, capacity=60 + data_loader = paddle.io.DataLoader( + batch_generator, + feed_list=input_field.feed_list, + return_list=False, + batch_size=None, + places=place, ) - data_loader.set_batch_generator(batch_generator, places=place) + # define model transformer = Transformer( args.src_vocab_size, @@ -183,8 +187,11 @@ def train_dygraph(args, batch_generator): paddle.seed(SEED) paddle.framework.random._manual_program_seed(SEED) # define data loader - train_loader = fluid.io.DataLoader.from_generator(capacity=10) - train_loader.set_batch_generator(batch_generator, places=place) + + train_loader = paddle.io.DataLoader( + batch_generator, batch_size=None, places=place + ) + # define model transformer = Transformer( args.src_vocab_size, @@ -322,8 +329,9 @@ def predict_dygraph(args, batch_generator): paddle.framework.random._manual_program_seed(SEED) # define data loader - test_loader = fluid.io.DataLoader.from_generator(capacity=10) - test_loader.set_batch_generator(batch_generator, places=place) + test_loader = paddle.io.DataLoader( + batch_generator, batch_size=None, places=place + ) # define model transformer = Transformer( @@ -433,8 +441,13 @@ def predict_static(args, batch_generator): input_field = util.InputField(input_slots) feed_list = input_field.feed_list - loader = fluid.io.DataLoader.from_generator( - feed_list=feed_list, capacity=10 + + loader = paddle.io.DataLoader( + batch_generator, + feed_list=feed_list, + return_list=False, + batch_size=None, + places=place, ) # define model @@ -533,6 +546,14 @@ class TestTransformer(unittest.TestCase): ) args.output_file = os.path.join(self.temp_dir.name, args.output_file) batch_generator = util.get_feed_data_reader(args, mode) + if mode == 'train': + batch_generator = util.TransedWMT16TrainDataSet( + batch_generator, args.batch_size * (args.epoch + 1) + ) + else: + batch_generator = util.TransedWMT16TestDataSet( + batch_generator, args.batch_size * (args.epoch + 1) + ) return args, batch_generator def _test_train(self): diff --git a/test/dygraph_to_static/transformer_util.py b/test/dygraph_to_static/transformer_util.py index 9be8c9ee2ea..17017fd3807 100644 --- a/test/dygraph_to_static/transformer_util.py +++ b/test/dygraph_to_static/transformer_util.py @@ -303,6 +303,90 @@ class InputField: ) +class TransedWMT16TrainDataSet(paddle.io.Dataset): + def __init__(self, data_reader, length): + self.src_word = [] + self.src_pos = [] + self.src_slf_attn_bias = [] + self.trg_word = [] + self.trg_pos = [] + self.trg_slf_attn_bias = [] + self.trg_src_attn_bias = [] + self.lbl_word = [] + self.lbl_weight = [] + + self.reader = data_reader() + self._generate(length) + + def _generate(self, length): + for i, data in enumerate(self.reader): + if i >= length: + break + self.src_word.append(data[0]) + self.src_pos.append(data[1]) + self.src_slf_attn_bias.append(data[2]) + self.trg_word.append(data[3]) + self.trg_pos.append(data[4]) + self.trg_slf_attn_bias.append(data[5]) + self.trg_src_attn_bias.append(data[6]) + self.lbl_word.append(data[7]) + self.lbl_weight.append(data[8]) + + def __getitem__(self, idx): + return ( + self.src_word[idx], + self.src_pos[idx], + self.src_slf_attn_bias[idx], + self.trg_word[idx], + self.trg_pos[idx], + self.trg_slf_attn_bias[idx], + self.trg_src_attn_bias[idx], + self.lbl_word[idx], + self.lbl_weight[idx], + ) + + def __len__(self): + return len(self.src_word) + + +class TransedWMT16TestDataSet(paddle.io.Dataset): + def __init__(self, data_reader, length): + self.src_word = [] + self.src_pos = [] + self.src_slf_attn_bias = [] + self.trg_word = [] + self.trg_pos = [] + self.trg_slf_attn_bias = [] + self.trg_src_attn_bias = [] + self.lbl_word = [] + self.lbl_weight = [] + + self.reader = data_reader() + self._generate(length) + + def _generate(self, length): + for i, data in enumerate(self.reader): + if i >= length: + break + self.src_word.append(data[0]) + self.src_pos.append(data[1]) + self.src_slf_attn_bias.append(data[2]) + self.trg_word.append(data[3]) + self.trg_slf_attn_bias.append(data[4]) + + def __getitem__(self, idx): + return ( + self.src_word[idx], + self.src_pos[idx], + self.src_slf_attn_bias[idx], + self.trg_word[idx], + self.trg_slf_attn_bias[idx], + ) + + def __len__(self): + return len(self.src_word) + + def load(program, model_path, executor=None, var_list=None): """ To load python2 saved models in python3. diff --git a/test/prim/model/test_resnet_cinn.py b/test/prim/model/test_resnet_cinn.py index 963b311b5b0..31877e78632 100644 --- a/test/prim/model/test_resnet_cinn.py +++ b/test/prim/model/test_resnet_cinn.py @@ -71,6 +71,27 @@ def reader_decorator(reader): return __reader__ +class TransedFlowerDataSet(paddle.io.Dataset): + def __init__(self, flower_data, length): + self.img = [] + self.label = [] + self.flower_data = flower_data() + self._generate(length) + + def _generate(self, length): + for i, data in enumerate(self.flower_data): + if i >= length: + break + self.img.append(data[0]) + self.label.append(data[1]) + + def __getitem__(self, idx): + return self.img[idx], self.label[idx] + + def __len__(self): + return len(self.img) + + def optimizer_setting(parameter_list=None): optimizer = fluid.optimizer.Momentum( learning_rate=base_lr, @@ -136,8 +157,6 @@ def run(model, data_loader, optimizer, mode): ) ) if batch_id >= end_step: - # avoid dataloader throw abort signaal - data_loader._reset() break print(losses) return losses @@ -153,13 +172,13 @@ def train(to_static, enable_prim, enable_cinn): paddle.framework.random._manual_program_seed(SEED) fluid.core._set_prim_all_enabled(enable_prim) - train_reader = paddle.batch( + dataset = TransedFlowerDataSet( reader_decorator(paddle.dataset.flowers.train(use_xmap=False)), - batch_size=batch_size, - drop_last=True, + batch_size * (10 + 1), + ) + data_loader = paddle.io.DataLoader( + dataset, batch_size=batch_size, drop_last=True ) - data_loader = fluid.io.DataLoader.from_generator(capacity=5, iterable=True) - data_loader.set_sample_list_generator(train_reader) resnet = resnet50(False) if to_static: diff --git a/test/prim/model/test_resnet_prim.py b/test/prim/model/test_resnet_prim.py index d55f795286a..8406bbd298f 100644 --- a/test/prim/model/test_resnet_prim.py +++ b/test/prim/model/test_resnet_prim.py @@ -71,6 +71,27 @@ def reader_decorator(reader): return __reader__ +class TransedFlowerDataSet(paddle.io.Dataset): + def __init__(self, flower_data, length): + self.img = [] + self.label = [] + self.flower_data = flower_data() + self._generate(length) + + def _generate(self, length): + for i, data in enumerate(self.flower_data): + if i >= length: + break + self.img.append(data[0]) + self.label.append(data[1]) + + def __getitem__(self, idx): + return self.img[idx], self.label[idx] + + def __len__(self): + return len(self.img) + + def optimizer_setting(parameter_list=None): optimizer = fluid.optimizer.Momentum( learning_rate=base_lr, @@ -136,8 +157,6 @@ def run(model, data_loader, optimizer, mode): ) ) if batch_id >= end_step: - # avoid dataloader throw abort signaal - data_loader._reset() break print(losses) return losses @@ -153,13 +172,13 @@ def train(to_static, enable_prim, enable_cinn): paddle.framework.random._manual_program_seed(SEED) fluid.core._set_prim_all_enabled(enable_prim) - train_reader = paddle.batch( + dataset = TransedFlowerDataSet( reader_decorator(paddle.dataset.flowers.train(use_xmap=False)), - batch_size=batch_size, - drop_last=True, + batch_size * (10 + 1), + ) + data_loader = paddle.io.DataLoader( + dataset, batch_size=batch_size, drop_last=True ) - data_loader = fluid.io.DataLoader.from_generator(capacity=5, iterable=True) - data_loader.set_sample_list_generator(train_reader) resnet = resnet50(False) if to_static: diff --git a/test/prim/model/test_resnet_prim_cinn.py b/test/prim/model/test_resnet_prim_cinn.py index 0acf6253934..5e20a983cc5 100644 --- a/test/prim/model/test_resnet_prim_cinn.py +++ b/test/prim/model/test_resnet_prim_cinn.py @@ -70,6 +70,27 @@ def reader_decorator(reader): return __reader__ +class TransedFlowerDataSet(paddle.io.Dataset): + def __init__(self, flower_data, length): + self.img = [] + self.label = [] + self.flower_data = flower_data() + self._generate(length) + + def _generate(self, length): + for i, data in enumerate(self.flower_data): + if i >= length: + break + self.img.append(data[0]) + self.label.append(data[1]) + + def __getitem__(self, idx): + return self.img[idx], self.label[idx] + + def __len__(self): + return len(self.img) + + def optimizer_setting(parameter_list=None): optimizer = fluid.optimizer.Momentum( learning_rate=base_lr, @@ -135,8 +156,6 @@ def run(model, data_loader, optimizer, mode): ) ) if batch_id >= end_step: - # avoid dataloader throw abort signaal - data_loader._reset() break print(losses) return losses @@ -152,13 +171,13 @@ def train(to_static, enable_prim, enable_cinn): paddle.framework.random._manual_program_seed(SEED) fluid.core._set_prim_all_enabled(enable_prim) - train_reader = paddle.batch( + dataset = TransedFlowerDataSet( reader_decorator(paddle.dataset.flowers.train(use_xmap=False)), - batch_size=batch_size, - drop_last=True, + batch_size * (10 + 1), + ) + data_loader = paddle.io.DataLoader( + dataset, batch_size=batch_size, drop_last=True ) - data_loader = fluid.io.DataLoader.from_generator(capacity=5, iterable=True) - data_loader.set_sample_list_generator(train_reader) resnet = resnet50(False) if to_static: -- GitLab