未验证 提交 73cb3a0f 编写于 作者: L Leo Chen 提交者: GitHub

fix transformer pe (#5157)

* fix incorrect usage of data loader

* refine code

* remove unused code

* refine code
上级 d4e1f1a4
...@@ -33,7 +33,7 @@ def min_max_filer(data, max_len, min_len=0): ...@@ -33,7 +33,7 @@ def min_max_filer(data, max_len, min_len=0):
return (data_min_len >= min_len) and (data_max_len <= max_len) return (data_min_len >= min_len) and (data_max_len <= max_len)
def create_data_loader(args): def create_data_loader(args, places=None):
root = None if args.root == "None" else args.root root = None if args.root == "None" else args.root
(src_vocab, trg_vocab) = WMT14ende.get_vocab(root=root) (src_vocab, trg_vocab) = WMT14ende.get_vocab(root=root)
padding_vocab = ( padding_vocab = (
...@@ -67,14 +67,14 @@ def create_data_loader(args): ...@@ -67,14 +67,14 @@ def create_data_loader(args):
data_loader = DataLoader( data_loader = DataLoader(
dataset=dataset, dataset=dataset,
places=places,
batch_sampler=batch_sampler, batch_sampler=batch_sampler,
collate_fn=partial( collate_fn=partial(
prepare_train_input, prepare_train_input,
bos_idx=args.bos_idx, bos_idx=args.bos_idx,
eos_idx=args.eos_idx, eos_idx=args.eos_idx,
pad_idx=args.bos_idx), pad_idx=args.bos_idx),
num_workers=0, num_workers=0)
return_list=True)
data_loaders[i] = (data_loader) data_loaders[i] = (data_loader)
return data_loaders return data_loaders
......
...@@ -34,28 +34,10 @@ def parse_args(): ...@@ -34,28 +34,10 @@ def parse_args():
return args return args
def batch_creator(loader, trainer_count):
    """Regroup items from ``loader`` into lists of ``trainer_count`` items.

    Yields one list per ``trainer_count`` consecutive samples. The final,
    possibly short group is NOT dropped: it is padded by repeating its last
    sample until it reaches ``trainer_count`` elements, so every yielded
    batch has the same length.
    """
    pending = []
    for sample in loader:
        pending.append(sample)
        if len(pending) == trainer_count:
            yield pending
            pending = []
    # Keep the tail: pad with copies of the final sample instead of dropping.
    if pending:
        pending += [pending[-1]] * (trainer_count - len(pending))
        yield pending
def do_train(args): def do_train(args):
paddle.enable_static() paddle.enable_static()
if args.use_gpu: places = paddle.static.cuda_places() if args.use_gpu else paddle.static.cpu_places()
trainer_count = len(os.environ['CUDA_VISIBLE_DEVICES'].split(',')) trainer_count = len(places)
place = paddle.set_device("gpu:0")
else:
trainer_count = int(os.environ['CPU_NUM'])
place = paddle.set_device("cpu")
# Set seed for CE # Set seed for CE
random_seed = eval(str(args.random_seed)) random_seed = eval(str(args.random_seed))
...@@ -63,7 +45,7 @@ def do_train(args): ...@@ -63,7 +45,7 @@ def do_train(args):
paddle.seed(random_seed) paddle.seed(random_seed)
# Define data loader # Define data loader
(train_loader), (eval_loader) = reader.create_data_loader(args) (train_loader), (eval_loader) = reader.create_data_loader(args, places)
train_program = paddle.static.Program() train_program = paddle.static.Program()
startup_program = paddle.static.Program() startup_program = paddle.static.Program()
...@@ -108,11 +90,10 @@ def do_train(args): ...@@ -108,11 +90,10 @@ def do_train(args):
optimizer.minimize(avg_cost) optimizer.minimize(avg_cost)
exe = paddle.static.Executor(place) exe = paddle.static.Executor()
exe.run(startup_program) exe.run(startup_program)
build_strategy = paddle.static.BuildStrategy() build_strategy = paddle.static.BuildStrategy()
build_strategy.enable_inplace = True
exec_strategy = paddle.static.ExecutionStrategy() exec_strategy = paddle.static.ExecutionStrategy()
compiled_train_program = paddle.static.CompiledProgram( compiled_train_program = paddle.static.CompiledProgram(
...@@ -138,7 +119,7 @@ def do_train(args): ...@@ -138,7 +119,7 @@ def do_train(args):
batch_id = 0 batch_id = 0
batch_start = time.time() batch_start = time.time()
pass_start_time = batch_start pass_start_time = batch_start
for data in batch_creator(train_loader, trainer_count): for data in train_loader():
# NOTE: used for benchmark and use None as default. # NOTE: used for benchmark and use None as default.
if args.max_iter and step_idx == args.max_iter: if args.max_iter and step_idx == args.max_iter:
return return
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册