未验证 提交 6d8ba1ef 编写于 作者: L liu zhengxi 提交者: GitHub

[Perform] Apply pad factor to pad vocab (#5134)

* add pad factor to pad vocab
上级 60a71fe4
......@@ -66,6 +66,8 @@ n_best: 1
src_vocab_size: 10000
# Size of target word dictionary
trg_vocab_size: 10000
# Used to pad vocab size to be multiple of pad_factor.
pad_factor: 8
# Index for <bos> token
bos_idx: 0
# Index for <eos> token
......
......@@ -36,7 +36,11 @@ def min_max_filer(data, max_len, min_len=0):
def create_data_loader(args):
root = None if args.root == "None" else args.root
(src_vocab, trg_vocab) = WMT14ende.get_vocab(root=root)
args.src_vocab_size, args.trg_vocab_size = len(src_vocab), len(trg_vocab)
padding_vocab = (
lambda x: (x + args.pad_factor - 1) // args.pad_factor * args.pad_factor
)
args.src_vocab_size = padding_vocab(len(src_vocab))
args.trg_vocab_size = padding_vocab(len(trg_vocab))
transform_func = WMT14ende.get_default_transform_func(root=root)
datasets = [
WMT14ende.get_datasets(
......@@ -107,7 +111,11 @@ def create_data_loader(args):
def create_infer_loader(args):
root = None if args.root == "None" else args.root
(src_vocab, trg_vocab) = WMT14ende.get_vocab(root=root)
args.src_vocab_size, args.trg_vocab_size = len(src_vocab), len(trg_vocab)
padding_vocab = (
lambda x: (x + args.pad_factor - 1) // args.pad_factor * args.pad_factor
)
args.src_vocab_size = padding_vocab(len(src_vocab))
args.trg_vocab_size = padding_vocab(len(trg_vocab))
transform_func = WMT14ende.get_default_transform_func(root=root)
dataset = WMT14ende.get_datasets(
mode="test", transform_func=transform_func).filter(
......
......@@ -66,6 +66,8 @@ n_best: 1
src_vocab_size: 10000
# Size of target word dictionary
trg_vocab_size: 10000
# Used to pad vocab size to be multiple of pad_factor.
pad_factor: 8
# Index for <bos> token
bos_idx: 0
# Index for <eos> token
......
......@@ -66,6 +66,8 @@ n_best: 1
src_vocab_size: 10000
# Size of target word dictionary
trg_vocab_size: 10000
# Used to pad vocab size to be multiple of pad_factor.
pad_factor: 8
# Index for <bos> token
bos_idx: 0
# Index for <eos> token
......
......@@ -36,7 +36,11 @@ def min_max_filer(data, max_len, min_len=0):
def create_data_loader(args):
root = None if args.root == "None" else args.root
(src_vocab, trg_vocab) = WMT14ende.get_vocab(root=root)
args.src_vocab_size, args.trg_vocab_size = len(src_vocab), len(trg_vocab)
padding_vocab = (
lambda x: (x + args.pad_factor - 1) // args.pad_factor * args.pad_factor
)
args.src_vocab_size = padding_vocab(len(src_vocab))
args.trg_vocab_size = padding_vocab(len(trg_vocab))
transform_func = WMT14ende.get_default_transform_func(root=root)
datasets = [
WMT14ende.get_datasets(
......@@ -107,7 +111,11 @@ def create_data_loader(args):
def create_infer_loader(args):
root = None if args.root == "None" else args.root
(src_vocab, trg_vocab) = WMT14ende.get_vocab(root=root)
args.src_vocab_size, args.trg_vocab_size = len(src_vocab), len(trg_vocab)
padding_vocab = (
lambda x: (x + args.pad_factor - 1) // args.pad_factor * args.pad_factor
)
args.src_vocab_size = padding_vocab(len(src_vocab))
args.trg_vocab_size = padding_vocab(len(trg_vocab))
transform_func = WMT14ende.get_default_transform_func(root=root)
dataset = WMT14ende.get_datasets(
mode="test", transform_func=transform_func).filter(
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册