Unverified commit ab60ae92, authored by JYChen, committed by GitHub

Remove old dataloader in ut (#55841)

* done: replaced in d2s models, except lac/transformer/sentiment

* replace cinn_test

* fix d2s transformer model

* fix cinn model ut
Parent 75c29ac1
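The pattern applied throughout this commit: legacy `fluid.io.DataLoader.from_generator` readers in the unit tests are replaced by map-style `paddle.io.Dataset` subclasses consumed through `paddle.io.DataLoader`. A minimal, self-contained sketch of the new pattern (the `RandomDataset` name and shapes here are illustrative, not taken from the diff):

import numpy as np
import paddle


class RandomDataset(paddle.io.Dataset):
    # illustrative stand-in for the Fake*/Transed* datasets added below
    def __init__(self, num_samples):
        self.num_samples = num_samples

    def __getitem__(self, idx):
        image = np.random.random([3, 224, 224]).astype('float32')
        label = np.random.randint(0, 10, [1]).astype('int64')
        return image, label

    def __len__(self):
        return self.num_samples


loader = paddle.io.DataLoader(RandomDataset(16), batch_size=4, drop_last=True)
for image, label in loader:
    print(image.shape, label.shape)  # [4, 3, 224, 224] and [4, 1]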
@@ -497,12 +497,6 @@ class TestAmpWithNonIterableDataLoader(unittest.TestCase):
         label = paddle.static.data(
             name='label', shape=[-1, 1], dtype='int64'
         )
-        py_reader = fluid.io.DataLoader.from_generator(
-            feed_list=[image, label],
-            capacity=4,
-            iterable=False,
-            use_double_buffer=False,
-        )
         net = vgg16_bn_drop(image)
         logits = paddle.static.nn.fc(
...
@@ -283,12 +283,6 @@ class TestAmpWithNonIterableDataLoader(unittest.TestCase):
         label = paddle.static.data(
             name='label', shape=[-1, 1], dtype='int64'
         )
-        py_reader = fluid.io.DataLoader.from_generator(
-            feed_list=[image, label],
-            capacity=4,
-            iterable=False,
-            use_double_buffer=False,
-        )
         zero_var = paddle.tensor.fill_constant(
             shape=[1], dtype='int64', value=0
         )
...
@@ -35,6 +35,46 @@ STEP_NUM = 10
 PRINT_STEP = 2
 
 
+class FakeBertDataset(paddle.io.Dataset):
+    def __init__(self, data_reader, steps):
+        self.src_ids = []
+        self.pos_ids = []
+        self.sent_ids = []
+        self.input_mask = []
+        self.mask_label = []
+        self.mask_pos = []
+        self.labels = []
+        self.data_reader = data_reader
+        self._generate_fake_data(1 * (steps + 1))
+
+    def _generate_fake_data(self, length):
+        for i, data in enumerate(self.data_reader.data_generator()()):
+            if i >= length:
+                break
+            self.src_ids.append(data[0])
+            self.pos_ids.append(data[1])
+            self.sent_ids.append(data[2])
+            self.input_mask.append(data[3])
+            self.mask_label.append(data[4])
+            self.mask_pos.append(data[5])
+            self.labels.append(data[6])
+
+    def __getitem__(self, idx):
+        return [
+            self.src_ids[idx],
+            self.pos_ids[idx],
+            self.sent_ids[idx],
+            self.input_mask[idx],
+            self.mask_label[idx],
+            self.mask_pos[idx],
+            self.labels[idx],
+        ]
+
+    def __len__(self):
+        return len(self.src_ids)
+
+
 class TestBert(unittest.TestCase):
     def setUp(self):
         self.bert_config = get_bert_config()
@@ -56,11 +96,9 @@ class TestBert(unittest.TestCase):
             fluid.default_main_program().random_seed = SEED
             fluid.default_startup_program().random_seed = SEED
 
-            data_loader = fluid.io.DataLoader.from_generator(
-                capacity=50, iterable=True
-            )
-            data_loader.set_batch_generator(
-                data_reader.data_generator(), places=place
-            )
+            fake_dataset = FakeBertDataset(data_reader, STEP_NUM)
+            data_loader = paddle.io.DataLoader(
+                fake_dataset, places=place, batch_size=None
+            )
 
             bert = PretrainModelLayer(
@@ -80,6 +118,7 @@ class TestBert(unittest.TestCase):
                 mask_pos,
                 labels,
             ) = input_data
+
             next_sent_acc, mask_lm_loss, total_loss = bert(
                 src_ids=src_ids,
                 position_ids=pos_ids,
...
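In the BERT test the reader already emits whole batches, so the replacement loader is created with `batch_size=None`, which disables automatic batching and passes each dataset item through unchanged. A sketch of that behaviour (the dataset name and shapes are illustrative assumptions):

import numpy as np
import paddle


class PreBatchedDataset(paddle.io.Dataset):
    def __getitem__(self, idx):
        # each item is already a full batch of 8 samples
        return np.random.random([8, 16]).astype('float32')

    def __len__(self):
        return 4


loader = paddle.io.DataLoader(PreBatchedDataset(), batch_size=None)
for batch in loader:
    print(batch.shape)  # [8, 16]: no extra batch dimension is added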
@@ -455,20 +455,33 @@ def create_optimizer(args, parameter_list):
     return optimizer
 
 
-def fake_data_reader(batch_size, label_size):
-    local_random = np.random.RandomState(SEED)
-
-    def reader():
-        batch_data = []
-        while True:
-            img = local_random.random_sample([3, 224, 224]).astype('float32')
-            label = local_random.randint(0, label_size, [1]).astype('int64')
-            batch_data.append([img, label])
-            if len(batch_data) == batch_size:
-                yield batch_data
-                batch_data = []
-
-    return reader
+class FakeDataSet(paddle.io.Dataset):
+    def __init__(self, batch_size, label_size, train_steps):
+        self.local_random = np.random.RandomState(SEED)
+        self.label_size = label_size
+
+        self.imgs = []
+        self.labels = []
+
+        self._generate_fake_data(batch_size * (train_steps + 1))
+
+    def _generate_fake_data(self, length):
+        for i in range(length):
+            img = self.local_random.random_sample([3, 224, 224]).astype(
+                'float32'
+            )
+            label = self.local_random.randint(0, self.label_size, [1]).astype(
+                'int64'
+            )
+            self.imgs.append(img)
+            self.labels.append(label)
+
+    def __getitem__(self, idx):
+        return [self.imgs[idx], self.labels[idx]]
+
+    def __len__(self):
+        return len(self.imgs)
 
 
 class Args:
@@ -513,9 +526,15 @@ def train_mobilenet(args, to_static):
         optimizer = create_optimizer(args=args, parameter_list=net.parameters())
 
         # 3. reader
-        train_reader = fake_data_reader(args.batch_size, args.class_dim)
-        train_data_loader = fluid.io.DataLoader.from_generator(capacity=16)
-        train_data_loader.set_sample_list_generator(train_reader)
+        train_dataset = FakeDataSet(
+            args.batch_size, args.class_dim, args.train_step
+        )
+        BatchSampler = paddle.io.BatchSampler(
+            train_dataset, batch_size=args.batch_size
+        )
+        train_data_loader = paddle.io.DataLoader(
+            train_dataset, batch_sampler=BatchSampler
+        )
 
         # 4. train loop
         loss_data = []
...
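The MobileNet test builds an explicit `paddle.io.BatchSampler` instead of passing `batch_size` directly to the `DataLoader`; the two are equivalent here, but a sampler makes the batching policy (shuffle, drop_last) a separate, swappable object. A minimal sketch, assuming a tiny in-memory dataset invented for illustration:

import numpy as np
import paddle


class TinyImageSet(paddle.io.Dataset):
    # hypothetical 10-sample dataset, for illustration only
    def __getitem__(self, idx):
        return np.random.random([3, 8, 8]).astype('float32')

    def __len__(self):
        return 10


dataset = TinyImageSet()
sampler = paddle.io.BatchSampler(
    dataset, batch_size=4, shuffle=False, drop_last=False
)
loader = paddle.io.DataLoader(dataset, batch_sampler=sampler)
for batch in loader:
    print(batch.shape)  # [4, 3, 8, 8] for full batches, [2, 3, 8, 8] for the tail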
@@ -220,6 +220,27 @@ def reader_decorator(reader):
     return __reader__
 
 
+class TransedFlowerDataSet(paddle.io.Dataset):
+    def __init__(self, flower_data, length):
+        self.img = []
+        self.label = []
+        self.flower_data = flower_data()
+        self._generate(length)
+
+    def _generate(self, length):
+        for i, data in enumerate(self.flower_data):
+            if i >= length:
+                break
+            self.img.append(data[0])
+            self.label.append(data[1])
+
+    def __getitem__(self, idx):
+        return self.img[idx], self.label[idx]
+
+    def __len__(self):
+        return len(self.img)
+
+
 class ResNetHelper:
     def __init__(self):
         self.temp_dir = tempfile.TemporaryDirectory()
@@ -243,15 +264,13 @@ class ResNetHelper:
             paddle.seed(SEED)
             paddle.framework.random._manual_program_seed(SEED)
 
-            train_reader = paddle.batch(
-                reader_decorator(paddle.dataset.flowers.train(use_xmap=False)),
-                batch_size=batch_size,
-                drop_last=True,
-            )
-            data_loader = fluid.io.DataLoader.from_generator(
-                capacity=5, iterable=True
-            )
-            data_loader.set_sample_list_generator(train_reader)
+            dataset = TransedFlowerDataSet(
+                reader_decorator(paddle.dataset.flowers.train(use_xmap=False)),
+                batch_size * (10 + 1),
+            )
+            data_loader = paddle.io.DataLoader(
+                dataset, batch_size=batch_size, drop_last=True
+            )
 
             resnet = ResNet()
             if to_static:
@@ -315,8 +334,6 @@ class ResNetHelper:
                         resnet.state_dict(),
                         self.dy_state_dict_save_path + '.pdparams',
                     )
-                    # avoid dataloader throw abort signaal
-                    data_loader._reset()
                     break
 
         return total_loss.numpy()
...
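The removed `data_loader._reset()` call was a workaround for the legacy loader, which could throw an abort signal when the loop exited early (per the deleted comment); with `paddle.io.DataLoader` an early `break` tears the loader down cleanly, so the tests simply exit the loop. A sketch of that, under the same assumptions as the earlier examples:

import numpy as np
import paddle


class TinyDataset(paddle.io.Dataset):
    def __getitem__(self, idx):
        return np.array([idx], dtype='int64')

    def __len__(self):
        return 100


loader = paddle.io.DataLoader(TinyDataset(), batch_size=10)
for batch_id, batch in enumerate(loader):
    if batch_id >= 2:
        break  # early exit is safe; no manual reset is required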
@@ -220,6 +220,27 @@ def reader_decorator(reader):
     return __reader__
 
 
+class TransedFlowerDataSet(paddle.io.Dataset):
+    def __init__(self, flower_data, length):
+        self.img = []
+        self.label = []
+        self.flower_data = flower_data()
+        self._generate(length)
+
+    def _generate(self, length):
+        for i, data in enumerate(self.flower_data):
+            if i >= length:
+                break
+            self.img.append(data[0])
+            self.label.append(data[1])
+
+    def __getitem__(self, idx):
+        return self.img[idx], self.label[idx]
+
+    def __len__(self):
+        return len(self.img)
+
+
 class TestResnet(unittest.TestCase):
     def setUp(self):
         self.temp_dir = tempfile.TemporaryDirectory()
@@ -250,15 +271,13 @@ class TestResnet(unittest.TestCase):
            paddle.seed(SEED)
            paddle.framework.random._manual_program_seed(SEED)
 
-            train_reader = paddle.batch(
-                reader_decorator(paddle.dataset.flowers.train(use_xmap=False)),
-                batch_size=batch_size,
-                drop_last=True,
-            )
-            data_loader = paddle.fluid.io.DataLoader.from_generator(
-                capacity=5, iterable=True
-            )
-            data_loader.set_sample_list_generator(train_reader)
+            dataset = TransedFlowerDataSet(
+                reader_decorator(paddle.dataset.flowers.train(use_xmap=False)),
+                batch_size * (10 + 1),
+            )
+            data_loader = paddle.io.DataLoader(
+                dataset, batch_size=batch_size, drop_last=True
+            )
 
             resnet = ResNet()
             optimizer = optimizer_setting(parameter_list=resnet.parameters())
@@ -311,8 +330,6 @@ class TestResnet(unittest.TestCase):
                         resnet.state_dict(),
                         self.dy_state_dict_save_path + '.pdparams',
                     )
-                    # avoid dataloader throw abort signaal
-                    data_loader._reset()
                     break
 
         paddle.enable_static()
...
@@ -73,8 +73,8 @@ def fake_vocabulary():
 vocab = fake_vocabulary()
 
 
-class FakeReaderProcessor:
-    def __init__(self, args, vocab):
+class FakeReaderProcessor(paddle.io.Dataset):
+    def __init__(self, args, vocab, length):
         self.vocab = vocab
         self.seq_len = args.seq_len
         self.sample_size = args.fake_sample_size
@@ -86,6 +86,10 @@ class FakeReaderProcessor:
             self.data_samples.append(
                 np.array([query, pos_title, neg_title]).astype(np.int64)
            )
+        self.query = []
+        self.pos_title = []
+        self.neg_title = []
+        self._init_data(length)
 
     def get_reader(self, mode, epoch=0):
         def reader_with_pairwise():
@@ -95,8 +99,25 @@ class FakeReaderProcessor:
 
         return reader_with_pairwise
 
+    def _init_data(self, length):
+        reader = self.get_reader("train", epoch=args.epoch)()
+        for i, yield_data in enumerate(reader):
+            if i >= length:
+                break
+            self.query.append(yield_data[0])
+            self.pos_title.append(yield_data[1])
+            self.neg_title.append(yield_data[2])
+
+    def __getitem__(self, idx):
+        return self.query[idx], self.pos_title[idx], self.neg_title[idx]
+
+    def __len__(self):
+        return len(self.query)
+
 
-simnet_process = FakeReaderProcessor(args, vocab)
+simnet_process = FakeReaderProcessor(
+    args, vocab, args.batch_size * (args.epoch + 1)
+)
 
 
 def train(conf_dict, to_static):
@@ -133,14 +154,8 @@ def train(conf_dict, to_static):
     global_step = 0
     losses = []
 
-    train_loader = fluid.io.DataLoader.from_generator(
-        capacity=16, return_list=True, iterable=True, use_double_buffer=True
-    )
-    get_train_examples = simnet_process.get_reader(
-        "train", epoch=args.epoch
-    )
-    train_loader.set_sample_list_generator(
-        paddle.batch(get_train_examples, batch_size=args.batch_size), place
-    )
+    train_loader = paddle.io.DataLoader(
+        simnet_process, batch_size=args.batch_size, places=[place]
+    )
 
     for left, pos_right, neg_right in train_loader():
...
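Because `FakeReaderProcessor.__getitem__` returns a `(query, pos_title, neg_title)` tuple, the loader yields one batched tensor per field, which is what lets the training loop unpack `left, pos_right, neg_right` directly. A self-contained sketch of that field-wise batching (names and shapes are illustrative):

import numpy as np
import paddle


class TripletDataset(paddle.io.Dataset):
    def __getitem__(self, idx):
        q = np.full([5], idx, dtype='int64')
        pos = np.full([5], idx + 1, dtype='int64')
        neg = np.full([5], idx + 2, dtype='int64')
        return q, pos, neg

    def __len__(self):
        return 8


loader = paddle.io.DataLoader(TripletDataset(), batch_size=4)
for q, pos, neg in loader:
    print(q.shape, pos.shape, neg.shape)  # [4, 5] for each field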
@@ -72,8 +72,8 @@ def fake_vocabulary():
 vocab = fake_vocabulary()
 
 
-class FakeReaderProcessor:
-    def __init__(self, args, vocab):
+class FakeReaderProcessor(paddle.io.Dataset):
+    def __init__(self, args, vocab, length):
         self.vocab = vocab
         self.seq_len = args.seq_len
         self.sample_size = args.fake_sample_size
@@ -85,6 +85,10 @@ class FakeReaderProcessor:
             self.data_samples.append(
                 np.array([query, pos_title, neg_title]).astype(np.int64)
             )
+        self.query = []
+        self.pos_title = []
+        self.neg_title = []
+        self._init_data(length)
 
     def get_reader(self, mode, epoch=0):
         def reader_with_pairwise():
@@ -94,8 +98,25 @@ class FakeReaderProcessor:
 
         return reader_with_pairwise
 
+    def _init_data(self, length):
+        reader = self.get_reader("train", epoch=args.epoch)()
+        for i, yield_data in enumerate(reader):
+            if i >= length:
+                break
+            self.query.append(yield_data[0])
+            self.pos_title.append(yield_data[1])
+            self.neg_title.append(yield_data[2])
+
+    def __getitem__(self, idx):
+        return self.query[idx], self.pos_title[idx], self.neg_title[idx]
+
+    def __len__(self):
+        return len(self.query)
+
 
-simnet_process = FakeReaderProcessor(args, vocab)
+simnet_process = FakeReaderProcessor(
+    args, vocab, args.batch_size * (args.epoch + 1)
+)
 
 
 def train(conf_dict, to_static):
@@ -132,12 +153,8 @@ def train(conf_dict, to_static):
     global_step = 0
     losses = []
 
-    train_loader = paddle.fluid.io.DataLoader.from_generator(
-        capacity=16, return_list=True, iterable=True, use_double_buffer=True
-    )
-    get_train_examples = simnet_process.get_reader("train", epoch=args.epoch)
-    train_loader.set_sample_list_generator(
-        paddle.batch(get_train_examples, batch_size=args.batch_size), place
-    )
+    train_loader = paddle.io.DataLoader(
+        simnet_process, batch_size=args.batch_size
+    )
 
     for left, pos_right, neg_right in train_loader():
...
@@ -63,10 +63,14 @@ def train_static(args, batch_generator):
     ]
     input_field = util.InputField(input_slots)
     # Define DataLoader
-    data_loader = fluid.io.DataLoader.from_generator(
-        input_field.feed_list, capacity=60
-    )
-    data_loader.set_batch_generator(batch_generator, places=place)
+    data_loader = paddle.io.DataLoader(
+        batch_generator,
+        feed_list=input_field.feed_list,
+        return_list=False,
+        batch_size=None,
+        places=place,
+    )
     # define model
     transformer = Transformer(
         args.src_vocab_size,
@@ -183,8 +187,11 @@ def train_dygraph(args, batch_generator):
     paddle.seed(SEED)
     paddle.framework.random._manual_program_seed(SEED)
     # define data loader
-    train_loader = fluid.io.DataLoader.from_generator(capacity=10)
-    train_loader.set_batch_generator(batch_generator, places=place)
+
+    train_loader = paddle.io.DataLoader(
+        batch_generator, batch_size=None, places=place
+    )
+
     # define model
     transformer = Transformer(
         args.src_vocab_size,
@@ -322,8 +329,9 @@ def predict_dygraph(args, batch_generator):
     paddle.framework.random._manual_program_seed(SEED)
     # define data loader
-    test_loader = fluid.io.DataLoader.from_generator(capacity=10)
-    test_loader.set_batch_generator(batch_generator, places=place)
+    test_loader = paddle.io.DataLoader(
+        batch_generator, batch_size=None, places=place
+    )
     # define model
     transformer = Transformer(
@@ -433,8 +441,13 @@ def predict_static(args, batch_generator):
     input_field = util.InputField(input_slots)
     feed_list = input_field.feed_list
-    loader = fluid.io.DataLoader.from_generator(
-        feed_list=feed_list, capacity=10
-    )
+
+    loader = paddle.io.DataLoader(
+        batch_generator,
+        feed_list=feed_list,
+        return_list=False,
+        batch_size=None,
+        places=place,
+    )
     # define model
@@ -533,6 +546,14 @@ class TestTransformer(unittest.TestCase):
         )
         args.output_file = os.path.join(self.temp_dir.name, args.output_file)
         batch_generator = util.get_feed_data_reader(args, mode)
+        if mode == 'train':
+            batch_generator = util.TransedWMT16TrainDataSet(
+                batch_generator, args.batch_size * (args.epoch + 1)
+            )
+        else:
+            batch_generator = util.TransedWMT16TestDataSet(
+                batch_generator, args.batch_size * (args.epoch + 1)
+            )
         return args, batch_generator
 
     def _test_train(self):
...
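In the static-graph paths the loader is built with `feed_list` and `return_list=False`, so iterating it yields feed dictionaries that can be handed to `Executor.run` instead of a list of tensors. A minimal sketch under those assumptions (the one-layer network and dataset here are placeholders, not the test's model):

import numpy as np
import paddle

paddle.enable_static()

main_prog = paddle.static.Program()
startup_prog = paddle.static.Program()
with paddle.static.program_guard(main_prog, startup_prog):
    x = paddle.static.data(name='x', shape=[None, 4], dtype='float32')
    out = paddle.static.nn.fc(x, size=2)


class XDataset(paddle.io.Dataset):
    def __getitem__(self, idx):
        return np.random.random([4]).astype('float32')

    def __len__(self):
        return 8


place = paddle.CPUPlace()
loader = paddle.io.DataLoader(
    XDataset(), feed_list=[x], places=place, batch_size=4, return_list=False
)
exe = paddle.static.Executor(place)
exe.run(startup_prog)
for data in loader():
    # with return_list=False, data is a list of one feed dict per place
    (out_val,) = exe.run(main_prog, feed=data, fetch_list=[out])
    print(out_val.shape)  # (4, 2)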
@@ -303,6 +303,90 @@ class InputField:
         )
 
 
+class TransedWMT16TrainDataSet(paddle.io.Dataset):
+    def __init__(self, data_reader, length):
+        self.src_word = []
+        self.src_pos = []
+        self.src_slf_attn_bias = []
+        self.trg_word = []
+        self.trg_pos = []
+        self.trg_slf_attn_bias = []
+        self.trg_src_attn_bias = []
+        self.lbl_word = []
+        self.lbl_weight = []
+        self.reader = data_reader()
+        self._generate(length)
+
+    def _generate(self, length):
+        for i, data in enumerate(self.reader):
+            if i >= length:
+                break
+            self.src_word.append(data[0])
+            self.src_pos.append(data[1])
+            self.src_slf_attn_bias.append(data[2])
+            self.trg_word.append(data[3])
+            self.trg_pos.append(data[4])
+            self.trg_slf_attn_bias.append(data[5])
+            self.trg_src_attn_bias.append(data[6])
+            self.lbl_word.append(data[7])
+            self.lbl_weight.append(data[8])
+
+    def __getitem__(self, idx):
+        return (
+            self.src_word[idx],
+            self.src_pos[idx],
+            self.src_slf_attn_bias[idx],
+            self.trg_word[idx],
+            self.trg_pos[idx],
+            self.trg_slf_attn_bias[idx],
+            self.trg_src_attn_bias[idx],
+            self.lbl_word[idx],
+            self.lbl_weight[idx],
+        )
+
+    def __len__(self):
+        return len(self.src_word)
+
+
+class TransedWMT16TestDataSet(paddle.io.Dataset):
+    def __init__(self, data_reader, length):
+        self.src_word = []
+        self.src_pos = []
+        self.src_slf_attn_bias = []
+        self.trg_word = []
+        self.trg_pos = []
+        self.trg_slf_attn_bias = []
+        self.trg_src_attn_bias = []
+        self.lbl_word = []
+        self.lbl_weight = []
+        self.reader = data_reader()
+        self._generate(length)
+
+    def _generate(self, length):
+        for i, data in enumerate(self.reader):
+            if i >= length:
+                break
+            self.src_word.append(data[0])
+            self.src_pos.append(data[1])
+            self.src_slf_attn_bias.append(data[2])
+            self.trg_word.append(data[3])
+            self.trg_slf_attn_bias.append(data[4])
+
+    def __getitem__(self, idx):
+        return (
+            self.src_word[idx],
+            self.src_pos[idx],
+            self.src_slf_attn_bias[idx],
+            self.trg_word[idx],
+            self.trg_slf_attn_bias[idx],
+        )
+
+    def __len__(self):
+        return len(self.src_word)
+
+
 def load(program, model_path, executor=None, var_list=None):
     """
     To load python2 saved models in python3.
...
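The Transed* helpers above all follow the same recipe: eagerly drain a fixed number of samples from a generator-style reader into lists, then serve them as an indexable dataset; the `batch_size * (epoch + 1)` lengths used by the tests simply over-provision enough samples for every step. A generic sketch of the recipe (the endless reader is a stand-in for the WMT16/flowers readers):

import paddle


def fake_reader():
    # stand-in for a generator-style reader that never terminates
    i = 0
    while True:
        yield [float(i)], [i % 2]
        i += 1


class DrainedDataset(paddle.io.Dataset):
    def __init__(self, reader, length):
        # materialize exactly `length` samples, then stop pulling
        self.samples = []
        for i, data in enumerate(reader):
            if i >= length:
                break
            self.samples.append(data)

    def __getitem__(self, idx):
        return self.samples[idx]

    def __len__(self):
        return len(self.samples)


dataset = DrainedDataset(fake_reader(), 4 * (2 + 1))  # batch_size * (steps + 1)
print(len(dataset))  # 12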
@@ -71,6 +71,27 @@ def reader_decorator(reader):
     return __reader__
 
 
+class TransedFlowerDataSet(paddle.io.Dataset):
+    def __init__(self, flower_data, length):
+        self.img = []
+        self.label = []
+        self.flower_data = flower_data()
+        self._generate(length)
+
+    def _generate(self, length):
+        for i, data in enumerate(self.flower_data):
+            if i >= length:
+                break
+            self.img.append(data[0])
+            self.label.append(data[1])
+
+    def __getitem__(self, idx):
+        return self.img[idx], self.label[idx]
+
+    def __len__(self):
+        return len(self.img)
+
+
 def optimizer_setting(parameter_list=None):
     optimizer = fluid.optimizer.Momentum(
         learning_rate=base_lr,
@@ -136,8 +157,6 @@ def run(model, data_loader, optimizer, mode):
             )
         )
         if batch_id >= end_step:
-            # avoid dataloader throw abort signaal
-            data_loader._reset()
             break
     print(losses)
     return losses
@@ -153,13 +172,13 @@ def train(to_static, enable_prim, enable_cinn):
     paddle.framework.random._manual_program_seed(SEED)
     fluid.core._set_prim_all_enabled(enable_prim)
 
-    train_reader = paddle.batch(
-        reader_decorator(paddle.dataset.flowers.train(use_xmap=False)),
-        batch_size=batch_size,
-        drop_last=True,
-    )
-    data_loader = fluid.io.DataLoader.from_generator(capacity=5, iterable=True)
-    data_loader.set_sample_list_generator(train_reader)
+    dataset = TransedFlowerDataSet(
+        reader_decorator(paddle.dataset.flowers.train(use_xmap=False)),
+        batch_size * (10 + 1),
+    )
+    data_loader = paddle.io.DataLoader(
+        dataset, batch_size=batch_size, drop_last=True
+    )
 
     resnet = resnet50(False)
     if to_static:
...
@@ -71,6 +71,27 @@ def reader_decorator(reader):
     return __reader__
 
 
+class TransedFlowerDataSet(paddle.io.Dataset):
+    def __init__(self, flower_data, length):
+        self.img = []
+        self.label = []
+        self.flower_data = flower_data()
+        self._generate(length)
+
+    def _generate(self, length):
+        for i, data in enumerate(self.flower_data):
+            if i >= length:
+                break
+            self.img.append(data[0])
+            self.label.append(data[1])
+
+    def __getitem__(self, idx):
+        return self.img[idx], self.label[idx]
+
+    def __len__(self):
+        return len(self.img)
+
+
 def optimizer_setting(parameter_list=None):
     optimizer = fluid.optimizer.Momentum(
         learning_rate=base_lr,
@@ -136,8 +157,6 @@ def run(model, data_loader, optimizer, mode):
             )
         )
         if batch_id >= end_step:
-            # avoid dataloader throw abort signaal
-            data_loader._reset()
             break
     print(losses)
     return losses
@@ -153,13 +172,13 @@ def train(to_static, enable_prim, enable_cinn):
     paddle.framework.random._manual_program_seed(SEED)
     fluid.core._set_prim_all_enabled(enable_prim)
 
-    train_reader = paddle.batch(
-        reader_decorator(paddle.dataset.flowers.train(use_xmap=False)),
-        batch_size=batch_size,
-        drop_last=True,
-    )
-    data_loader = fluid.io.DataLoader.from_generator(capacity=5, iterable=True)
-    data_loader.set_sample_list_generator(train_reader)
+    dataset = TransedFlowerDataSet(
+        reader_decorator(paddle.dataset.flowers.train(use_xmap=False)),
+        batch_size * (10 + 1),
+    )
+    data_loader = paddle.io.DataLoader(
+        dataset, batch_size=batch_size, drop_last=True
+    )
 
     resnet = resnet50(False)
     if to_static:
...
@@ -70,6 +70,27 @@ def reader_decorator(reader):
     return __reader__
 
 
+class TransedFlowerDataSet(paddle.io.Dataset):
+    def __init__(self, flower_data, length):
+        self.img = []
+        self.label = []
+        self.flower_data = flower_data()
+        self._generate(length)
+
+    def _generate(self, length):
+        for i, data in enumerate(self.flower_data):
+            if i >= length:
+                break
+            self.img.append(data[0])
+            self.label.append(data[1])
+
+    def __getitem__(self, idx):
+        return self.img[idx], self.label[idx]
+
+    def __len__(self):
+        return len(self.img)
+
+
 def optimizer_setting(parameter_list=None):
     optimizer = fluid.optimizer.Momentum(
         learning_rate=base_lr,
@@ -135,8 +156,6 @@ def run(model, data_loader, optimizer, mode):
             )
         )
         if batch_id >= end_step:
-            # avoid dataloader throw abort signaal
-            data_loader._reset()
             break
     print(losses)
     return losses
@@ -152,13 +171,13 @@ def train(to_static, enable_prim, enable_cinn):
     paddle.framework.random._manual_program_seed(SEED)
    fluid.core._set_prim_all_enabled(enable_prim)
 
-    train_reader = paddle.batch(
-        reader_decorator(paddle.dataset.flowers.train(use_xmap=False)),
-        batch_size=batch_size,
-        drop_last=True,
-    )
-    data_loader = fluid.io.DataLoader.from_generator(capacity=5, iterable=True)
-    data_loader.set_sample_list_generator(train_reader)
+    dataset = TransedFlowerDataSet(
+        reader_decorator(paddle.dataset.flowers.train(use_xmap=False)),
+        batch_size * (10 + 1),
+    )
+    data_loader = paddle.io.DataLoader(
+        dataset, batch_size=batch_size, drop_last=True
+    )
 
     resnet = resnet50(False)
     if to_static:
...