Commit 2d065ac7 authored by smallv0221

fix warmup step bug and qa padding

Parent bec2933a
@@ -181,7 +181,7 @@ def do_train(args):
         args.learning_rate,
         lambda current_step, warmup_proportion=args.warmup_proportion,
         num_training_steps=args.max_steps if args.max_steps > 0 else
-        (len(train_ds.examples)//args.batch_size*args.num_train_epochs): float(
+        (len(train_data_loader)*args.num_train_epochs): float(
         current_step) / float(max(1, warmup_proportion*num_training_steps))
         if current_step < warmup_proportion*num_training_steps else max(
             0.0,
......
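The warmup fix swaps len(train_ds.examples)//args.batch_size*args.num_train_epochs for len(train_data_loader)*args.num_train_epochs when args.max_steps is unset: the loader's length is the number of batches the loop actually runs per epoch, while the raw example count drifts from it once long contexts are split into multiple features and partial batches are handled, so the old formula mis-sized the warmup window. Below is a minimal free-standing sketch of the factor the lambda returns; warmup_factor is a hypothetical name, and the post-warmup branch is assumed to be the usual linear decay, since the hunk is truncated right after max(0.0,.

    def warmup_factor(current_step, warmup_proportion, num_training_steps):
        # The scheduler multiplies args.learning_rate by this factor each step.
        warmup_steps = warmup_proportion * num_training_steps
        if current_step < warmup_steps:
            # Linear ramp from 0 to 1 over the first warmup_proportion of training.
            return float(current_step) / float(max(1, warmup_steps))
        # Assumed tail (elided in the hunk): linear decay from 1 toward 0,
        # clamped at 0 by the max(0.0, ...) visible above.
        return max(0.0, float(num_training_steps - current_step) /
                   float(max(1, num_training_steps - warmup_steps)))

    # Example: with 1000 total steps and 10% warmup, step 10 gives 10/100 = 0.1.
    print(warmup_factor(10, 0.1, 1000))

With the fix, num_training_steps defaults to len(train_data_loader)*args.num_train_epochs whenever args.max_steps <= 0, so the warmup window tracks the steps the training loop will actually execute.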
@@ -243,7 +243,14 @@ class SQuAD(Dataset):
                 segment_ids.append(1)
             input_ids = tokenizer.convert_tokens_to_ids(tokens)
+            input_ids = input_ids + [
+                tokenizer.vocab[tokenizer.pad_token]
+                for _ in range(self.max_seq_length - len(input_ids))
+            ]
+            segment_ids = segment_ids + [
+                tokenizer.vocab[tokenizer.pad_token]
+                for _ in range(self.max_seq_length - len(segment_ids))
+            ]
             input_mask = [1] * len(input_ids)
             start_position = None
......
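The QA padding hunk right-pads both input_ids and segment_ids to self.max_seq_length with tokenizer.vocab[tokenizer.pad_token]; in the standard BERT vocabulary [PAD] maps to id 0, which is why the same lookup also serves as the segment id on padding positions. A self-contained sketch of the same step, with pad_features as a hypothetical helper outside the class; the tokenizer is assumed to expose the vocab mapping and pad_token attribute used in the hunk above.

    def pad_features(input_ids, segment_ids, tokenizer, max_seq_length):
        # Pad id for BERT-style vocabularies ([PAD] -> 0).
        pad_id = tokenizer.vocab[tokenizer.pad_token]
        # Right-pad both sequences to the fixed length so batches can be
        # stacked into tensors of uniform shape.
        input_ids = input_ids + [pad_id] * (max_seq_length - len(input_ids))
        segment_ids = segment_ids + [pad_id] * (max_seq_length - len(segment_ids))
        return input_ids, segment_ids

Note that input_mask = [1] * len(input_ids) now runs after the padding, so it marks all max_seq_length positions, padding included.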