未验证 提交 73cb3a0f 编写于 作者: L Leo Chen 提交者: GitHub

fix transformer pe (#5157)

* fix incorrect usage of data loader

* refine code

* remove unused code

* refine code
上级 d4e1f1a4
......@@ -33,7 +33,7 @@ def min_max_filer(data, max_len, min_len=0):
return (data_min_len >= min_len) and (data_max_len <= max_len)
def create_data_loader(args):
def create_data_loader(args, places=None):
root = None if args.root == "None" else args.root
(src_vocab, trg_vocab) = WMT14ende.get_vocab(root=root)
padding_vocab = (
......@@ -67,14 +67,14 @@ def create_data_loader(args):
data_loader = DataLoader(
dataset=dataset,
places=places,
batch_sampler=batch_sampler,
collate_fn=partial(
prepare_train_input,
bos_idx=args.bos_idx,
eos_idx=args.eos_idx,
pad_idx=args.bos_idx),
num_workers=0,
return_list=True)
num_workers=0)
data_loaders[i] = (data_loader)
return data_loaders
......
......@@ -34,28 +34,10 @@ def parse_args():
return args
def batch_creator(loader, trainer_count):
batch = []
for data in loader:
batch.append(data)
if len(batch) == trainer_count:
yield batch
batch = []
# DO NOT drop last.
if len(batch) > 0:
while len(batch) < trainer_count:
batch.append(batch[-1])
yield batch
def do_train(args):
paddle.enable_static()
if args.use_gpu:
trainer_count = len(os.environ['CUDA_VISIBLE_DEVICES'].split(','))
place = paddle.set_device("gpu:0")
else:
trainer_count = int(os.environ['CPU_NUM'])
place = paddle.set_device("cpu")
places = paddle.static.cuda_places() if args.use_gpu else paddle.static.cpu_places()
trainer_count = len(places)
# Set seed for CE
random_seed = eval(str(args.random_seed))
......@@ -63,7 +45,7 @@ def do_train(args):
paddle.seed(random_seed)
# Define data loader
(train_loader), (eval_loader) = reader.create_data_loader(args)
(train_loader), (eval_loader) = reader.create_data_loader(args, places)
train_program = paddle.static.Program()
startup_program = paddle.static.Program()
......@@ -108,11 +90,10 @@ def do_train(args):
optimizer.minimize(avg_cost)
exe = paddle.static.Executor(place)
exe = paddle.static.Executor()
exe.run(startup_program)
build_strategy = paddle.static.BuildStrategy()
build_strategy.enable_inplace = True
exec_strategy = paddle.static.ExecutionStrategy()
compiled_train_program = paddle.static.CompiledProgram(
......@@ -138,7 +119,7 @@ def do_train(args):
batch_id = 0
batch_start = time.time()
pass_start_time = batch_start
for data in batch_creator(train_loader, trainer_count):
for data in train_loader():
# NOTE: used for benchmark and use None as default.
if args.max_iter and step_idx == args.max_iter:
return
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册