Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
hapi
提交
26ff946f
H
hapi
项目概览
PaddlePaddle
/
hapi
通知
11
Star
2
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
4
列表
看板
标记
里程碑
合并请求
7
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
H
hapi
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
4
Issue
4
列表
看板
标记
里程碑
合并请求
7
合并请求
7
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
26ff946f
编写于
3月 30, 2020
作者:
G
guosheng
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Update Transformer dataloader, fit, parallel.
上级
0e47f4c4
变更
2
隐藏空白更改
内联
并排
Showing
2 changed file
with
26 addition
and
28 deletion
+26
-28
transformer/reader.py
transformer/reader.py
+25
-28
transformer/train.py
transformer/train.py
+1
-0
未找到文件。
transformer/reader.py
浏览文件 @
26ff946f
...
...
@@ -257,23 +257,21 @@ class Seq2SeqDataset(Dataset):
def
load_src_trg_ids
(
self
,
fpattern
,
tar_fname
):
converters
=
[
Converter
(
vocab
=
self
.
_src_vocab
,
beg
=
self
.
_bos_idx
,
end
=
self
.
_eos_idx
,
unk
=
self
.
_unk_idx
,
delimiter
=
self
.
_token_delimiter
,
add_beg
=
False
)
Converter
(
vocab
=
self
.
_src_vocab
,
beg
=
self
.
_bos_idx
,
end
=
self
.
_eos_idx
,
unk
=
self
.
_unk_idx
,
delimiter
=
self
.
_token_delimiter
,
add_beg
=
False
)
]
if
not
self
.
_only_src
:
converters
.
append
(
Converter
(
vocab
=
self
.
_trg_vocab
,
beg
=
self
.
_bos_idx
,
end
=
self
.
_eos_idx
,
unk
=
self
.
_unk_idx
,
delimiter
=
self
.
_token_delimiter
,
add_beg
=
True
))
Converter
(
vocab
=
self
.
_trg_vocab
,
beg
=
self
.
_bos_idx
,
end
=
self
.
_eos_idx
,
unk
=
self
.
_unk_idx
,
delimiter
=
self
.
_token_delimiter
,
add_beg
=
True
))
converters
=
ComposedConverter
(
converters
)
...
...
@@ -301,8 +299,9 @@ class Seq2SeqDataset(Dataset):
f
=
tarfile
.
open
(
fpaths
[
0
],
"rb"
)
for
line
in
f
.
extractfile
(
tar_fname
):
fields
=
line
.
strip
(
b
"
\n
"
).
split
(
self
.
_field_delimiter
)
if
(
not
self
.
_only_src
and
len
(
fields
)
==
2
)
or
(
self
.
_only_src
and
len
(
fields
)
==
1
):
if
(
not
self
.
_only_src
and
len
(
fields
)
==
2
)
or
(
self
.
_only_src
and
len
(
fields
)
==
1
):
yield
fields
else
:
for
fpath
in
fpaths
:
...
...
@@ -332,7 +331,8 @@ class Seq2SeqDataset(Dataset):
self
.
_trg_vocab
),
self
.
_bos_idx
,
self
.
_eos_idx
,
self
.
_unk_idx
def
__getitem__
(
self
,
idx
):
return
(
self
.
_src_seq_ids
[
idx
],
self
.
_trg_seq_ids
[
idx
]
return
(
self
.
_src_seq_ids
[
idx
],
self
.
_trg_seq_ids
[
idx
][:
-
1
],
self
.
_trg_seq_ids
[
idx
][
1
:]
)
if
not
self
.
_only_src
else
self
.
_src_seq_ids
[
idx
]
def
__len__
(
self
):
...
...
@@ -365,13 +365,14 @@ class Seq2SeqBatchSampler(BatchSampler):
def
__iter__
(
self
):
# global sort or global shuffle
if
self
.
_sort_type
==
SortType
.
GLOBAL
:
infos
=
sorted
(
self
.
dataset
.
_sample_infos
,
key
=
lambda
x
:
x
.
max_len
)
infos
=
sorted
(
self
.
_dataset
.
_sample_infos
,
key
=
lambda
x
:
x
.
max_len
)
else
:
if
self
.
_shuffle
:
infos
=
self
.
dataset
.
_sample_infos
infos
=
self
.
_
dataset
.
_sample_infos
self
.
_random
.
shuffle
(
infos
)
else
:
infos
=
self
.
dataset
.
_sample_infos
infos
=
self
.
_
dataset
.
_sample_infos
if
self
.
_sort_type
==
SortType
.
POOL
:
reverse
=
True
...
...
@@ -385,9 +386,9 @@ class Seq2SeqBatchSampler(BatchSampler):
batches
=
[]
batch_creator
=
TokenBatchCreator
(
self
.
_batch_size
)
if
self
.
_use_token_batch
else
SentenceBatchCreator
(
self
.
_batch_size
*
self
.
_nranks
)
self
.
_batch_size
)
if
self
.
_use_token_batch
else
SentenceBatchCreator
(
self
.
_batch_size
*
self
.
_nranks
)
batch_creator
=
MinMaxFilter
(
self
.
_max_length
,
self
.
_min_length
,
batch_creator
)
...
...
@@ -422,8 +423,4 @@ class Seq2SeqBatchSampler(BatchSampler):
yield
batch_indices
def
__len__
(
self
):
pass
@
property
def
dev_id
(
self
):
return
self
.
_dev_id
return
100
transformer/train.py
浏览文件 @
26ff946f
...
...
@@ -123,6 +123,7 @@ def do_train(args):
num_workers
=
0
,
return_list
=
True
)
transformer
=
Transformer
(
args
.
src_vocab_size
,
args
.
trg_vocab_size
,
args
.
max_length
+
1
,
args
.
n_layer
,
args
.
n_head
,
args
.
d_key
,
args
.
d_value
,
args
.
d_model
,
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录