PaddlePaddle / hapi, commit 64a7e1a0
Authored Mar 24, 2020 by guosheng
Update transformer

Parent: 0baa8f68
Showing 3 changed files with 107 additions and 104 deletions (+107, -104):

    transformer/predict.py   +11   -8
    transformer/reader.py    +35  -23
    transformer/train.py     +61  -73
transformer/predict.py
@@ -88,14 +88,17 @@ def do_predict(args):

    # define model
    inputs = [
        Input(
            [None, None], "int64", name="src_word"),
        Input(
            [None, None], "int64", name="src_pos"),
        Input(
            [None, args.n_head, None, None],
            "float32",
            name="src_slf_attn_bias"),
        Input(
            [None, args.n_head, None, None],
            "float32",
            name="trg_src_attn_bias"),
    ]
    transformer = InferTransformer(
        args.src_vocab_size,
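For illustration only (not part of this commit), dummy NumPy arrays matching these Input declarations could look like the following; the batch size, sequence lengths and head count are invented:

import numpy as np

batch_size, src_len, trg_len, n_head = 4, 6, 5, 8   # invented sizes for illustration
src_word = np.zeros([batch_size, src_len], dtype="int64")        # matches [None, None], "int64"
src_pos = np.zeros([batch_size, src_len], dtype="int64")         # matches [None, None], "int64"
src_slf_attn_bias = np.zeros(
    [batch_size, n_head, src_len, src_len], dtype="float32")     # matches [None, n_head, None, None]
trg_src_attn_bias = np.zeros(
    [batch_size, n_head, trg_len, src_len], dtype="float32")     # matches [None, n_head, None, None]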
transformer/reader.py
@@ -19,6 +19,15 @@ import tarfile

This hunk adds the BatchSampler/DataLoader import and a placeholder TokenBatchSampler next to the existing imports:

import numpy as np

import paddle.fluid as fluid
from paddle.fluid.io import BatchSampler, DataLoader


class TokenBatchSampler(BatchSampler):
    def __init__(self):
        pass

    def __iter(self):
        pass


def pad_batch_data(insts,
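TokenBatchSampler is only a stub in this commit. As a sketch of the general idea (an assumption on my part, not code from this repository, and deliberately not tied to the BatchSampler base class), a sampler that groups example indices by a token budget rather than a fixed example count might look like this:

class TokenBudgetSampler(object):
    """Yield lists of example indices whose summed lengths stay under a token budget."""

    def __init__(self, seq_lens, max_tokens):
        self.seq_lens = seq_lens      # length of each example, e.g. number of source tokens
        self.max_tokens = max_tokens  # token budget per batch

    def __iter__(self):
        batch, num_tokens = [], 0
        for idx, seq_len in enumerate(self.seq_lens):
            if batch and num_tokens + seq_len > self.max_tokens:
                yield batch
                batch, num_tokens = [], 0
            batch.append(idx)
            num_tokens += seq_len
        if batch:
            yield batch

For example, iterating TokenBudgetSampler([3, 4, 5, 2], max_tokens=8) yields [0, 1] and then [2, 3].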
@@ -54,7 +63,8 @@ def pad_batch_data(insts,

    if is_target:
        # This is used to avoid attention on paddings and subsequent
        # words.
        slf_attn_bias_data = np.ones(
            (inst_data.shape[0], max_len, max_len))
        slf_attn_bias_data = np.triu(slf_attn_bias_data,
                                     1).reshape([-1, 1, max_len, max_len])
        slf_attn_bias_data = np.tile(slf_attn_bias_data,
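As a quick illustration of what this bias mask is doing (a standalone NumPy sketch, not code from the repository): np.triu(..., 1) keeps only the strictly upper triangle, so every position is flagged for all later positions, which is what blocks attention to subsequent words.

import numpy as np

max_len = 4
mask = np.triu(np.ones((1, max_len, max_len)), 1)
# mask[0] is
# [[0., 1., 1., 1.],
#  [0., 0., 1., 1.],
#  [0., 0., 0., 1.],
#  [0., 0., 0., 0.]]
# Scaled by a large negative constant and added to the attention logits,
# the 1s become effectively -inf scores for future positions.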
@@ -306,6 +316,7 @@ class DataProcessor(object):

    :param seed: The seed for random.
    :type seed: int
    """

    def __init__(self,
                 src_vocab_fpath,
                 trg_vocab_fpath,
@@ -360,21 +371,23 @@ class DataProcessor(object):

    def load_src_trg_ids(self, fpattern, tar_fname):
        converters = [
            Converter(
                vocab=self._src_vocab,
                beg=self._bos_idx,
                end=self._eos_idx,
                unk=self._unk_idx,
                delimiter=self._token_delimiter,
                add_beg=False)
        ]
        if not self._only_src:
            converters.append(
                Converter(
                    vocab=self._trg_vocab,
                    beg=self._bos_idx,
                    end=self._eos_idx,
                    unk=self._unk_idx,
                    delimiter=self._token_delimiter,
                    add_beg=True))

        converters = ComposedConverter(converters)
@@ -402,9 +415,8 @@ class DataProcessor(object):

            f = tarfile.open(fpaths[0], "rb")
            for line in f.extractfile(tar_fname):
                fields = line.strip(b"\n").split(self._field_delimiter)
                if (not self._only_src and len(fields) == 2) or (
                        self._only_src and len(fields) == 1):
                    yield fields
        else:
            for fpath in fpaths:
@@ -414,9 +426,8 @@ class DataProcessor(object):

                with open(fpath, "rb") as f:
                    for line in f:
                        fields = line.strip(b"\n").split(self._field_delimiter)
                        if (not self._only_src and len(fields) == 2) or (
                                self._only_src and len(fields) == 1):
                            yield fields

    @staticmethod
@@ -477,7 +488,8 @@ class DataProcessor(object):

                if self._only_src:
                    yield [[self._src_seq_ids[idx]] for idx in batch_ids]
                else:
                    yield [(self._src_seq_ids[idx], self._trg_seq_ids[idx][:-1],
                            self._trg_seq_ids[idx][1:]) for idx in batch_ids]

        return __impl__
@@ -512,8 +524,8 @@ class DataProcessor(object):

            for item in data_reader():
                inst_num_per_part = len(item) // count
                for i in range(count):
                    yield item[inst_num_per_part * i:inst_num_per_part *
                               (i + 1)]

        return __impl__
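To make the slicing above concrete (a standalone sketch; the numbers are invented): a batch of len(item) examples is cut into count contiguous shards of inst_num_per_part examples each, one shard per trainer.

item = list(range(8))                    # 8 examples in the batch, stand-ins for real instances
count = 2                                # e.g. two trainers / devices
inst_num_per_part = len(item) // count   # 4
shards = [
    item[inst_num_per_part * i:inst_num_per_part * (i + 1)] for i in range(count)
]
# shards == [[0, 1, 2, 3], [4, 5, 6, 7]]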
@@ -535,7 +547,7 @@ class DataProcessor(object):

The training reader now yields the model inputs and the labels as two separate tuples instead of one flat tuple (the changed line previously read "yield data_inputs"):

            for data in data_reader():
                data_inputs = prepare_train_input(data, src_pad_idx,
                                                  trg_pad_idx, n_head)
                yield data_inputs[:-2], data_inputs[-2:]

        def __for_predict__():
            for data in data_reader():
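Assuming prepare_train_input keeps the usual ordering (src_word, src_pos, src_slf_attn_bias, trg_word, trg_pos, trg_slf_attn_bias, trg_src_attn_bias, lbl_word, lbl_weight), which is an assumption rather than something shown in this diff, the split lines up with the inputs and labels declared in train.py below:

# hypothetical 9-element tuple in the order listed above
data_inputs = tuple("src_word src_pos src_slf_attn_bias trg_word trg_pos "
                    "trg_slf_attn_bias trg_src_attn_bias lbl_word lbl_weight".split())
model_inputs = data_inputs[:-2]   # the 7 entries matching the inputs list in train.py
labels = data_inputs[-2:]         # ('lbl_word', 'lbl_weight'), matching the labels list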
transformer/train.py
@@ -32,9 +32,35 @@ from utils.check import check_gpu, check_version

The ProgBarLogger import and the LoggerCallback class are newly added in this hunk:

import reader
from transformer import Transformer, CrossEntropyCriterion, NoamDecay
from model import Input
from callbacks import ProgBarLogger


class LoggerCallback(ProgBarLogger):
    def __init__(self, log_freq=1, verbose=2, loss_normalizer=0.):
        super(LoggerCallback, self).__init__(log_freq, verbose)
        self.loss_normalizer = loss_normalizer

    def on_train_begin(self, logs=None):
        super(LoggerCallback, self).on_train_begin(logs)
        self.train_metrics += ["normalized loss", "ppl"]

    def on_train_batch_end(self, step, logs=None):
        logs["normalized loss"] = logs["loss"][0] - self.loss_normalizer
        logs["ppl"] = np.exp(min(logs["loss"][0], 100))
        super(LoggerCallback, self).on_train_batch_end(step, logs)

    def on_eval_begin(self, logs=None):
        super(LoggerCallback, self).on_eval_begin(logs)
        self.eval_metrics += ["normalized loss", "ppl"]

    def on_eval_batch_end(self, step, logs=None):
        logs["normalized loss"] = logs["loss"][0] - self.loss_normalizer
        logs["ppl"] = np.exp(min(logs["loss"][0], 100))
        super(LoggerCallback, self).on_eval_batch_end(step, logs)


def do_train(args):
    init_context('dynamic' if FLAGS.dynamic else 'static')
    trainer_count = 1  #get_nranks()

    @contextlib.contextmanager
@@ -102,24 +128,31 @@ def do_train(args):

    # define model
    inputs = [
        Input(
            [None, None], "int64", name="src_word"),
        Input(
            [None, None], "int64", name="src_pos"),
        Input(
            [None, args.n_head, None, None],
            "float32",
            name="src_slf_attn_bias"),
        Input(
            [None, None], "int64", name="trg_word"),
        Input(
            [None, None], "int64", name="trg_pos"),
        Input(
            [None, args.n_head, None, None],
            "float32",
            name="trg_slf_attn_bias"),
        Input(
            [None, args.n_head, None, None],
            "float32",
            name="trg_src_attn_bias"),
    ]
    labels = [
        Input(
            [None, 1], "int64", name="label"),
        Input(
            [None, 1], "float32", name="weight"),
    ]

    transformer = Transformer(
@@ -149,7 +182,8 @@ def do_train(args):

The pretrained-model load now passes reset_optimizer=True (previously the call ended after the path argument):

    ## init from some pretrain models, to better solve the current task
    if args.init_from_pretrain_model:
        transformer.load(
            os.path.join(args.init_from_pretrain_model, "transformer"),
            reset_optimizer=True)

    # the best cross-entropy value with label smoothing
    loss_normalizer = -(
@@ -157,63 +191,17 @@ def do_train(args):

        (1. - args.label_smooth_eps)) + args.label_smooth_eps *
        np.log(args.label_smooth_eps / (args.trg_vocab_size - 1) + 1e-20))

The hand-written epoch/batch training loop below is removed and replaced by a single call to the model's fit API.

Removed:

    step_idx = 0

    # train loop
    for pass_id in range(args.epoch):
        pass_start_time = time.time()
        batch_id = 0
        for input_data in train_loader():
            losses = transformer.train(input_data[:-2], input_data[-2:])
            if step_idx % args.print_step == 0:
                total_avg_cost = np.sum(losses)
                if step_idx == 0:
                    logging.info(
                        "step_idx: %d, epoch: %d, batch: %d, avg loss: %f, "
                        "normalized loss: %f, ppl: %f" %
                        (step_idx, pass_id, batch_id, total_avg_cost,
                         total_avg_cost - loss_normalizer,
                         np.exp([min(total_avg_cost, 100)])))
                    avg_batch_time = time.time()
                else:
                    logging.info(
                        "step_idx: %d, epoch: %d, batch: %d, avg loss: %f, "
                        "normalized loss: %f, ppl: %f, speed: %.2f step/s" %
                        (step_idx, pass_id, batch_id, total_avg_cost,
                         total_avg_cost - loss_normalizer,
                         np.exp([min(total_avg_cost, 100)]),
                         args.print_step / (time.time() - avg_batch_time)))
                    avg_batch_time = time.time()
            if step_idx % args.save_step == 0 and step_idx != 0:
                # validation: how to accumulate with Model loss
                if args.validation_file:
                    total_avg_cost = 0
                    for idx, input_data in enumerate(val_loader()):
                        losses = transformer.eval(input_data[:-2],
                                                  input_data[-2:])
                        total_avg_cost += np.sum(losses)
                    total_avg_cost /= idx + 1
                    logging.info(
                        "validation, step_idx: %d, avg loss: %f, "
                        "normalized loss: %f, ppl: %f" %
                        (step_idx, total_avg_cost,
                         total_avg_cost - loss_normalizer,
                         np.exp([min(total_avg_cost, 100)])))
                transformer.save(
                    os.path.join(args.save_model, "step_" + str(step_idx),
                                 "transformer"))
            batch_id += 1
            step_idx += 1
        time_consumed = time.time() - pass_start_time

    if args.save_model:
        transformer.save(
            os.path.join(args.save_model, "step_final", "transformer"))

Added:

    transformer.fit(train_loader=train_loader,
                    eval_loader=val_loader,
                    epochs=1,
                    eval_freq=1,
                    save_freq=1,
                    verbose=2,
                    callbacks=[
                        LoggerCallback(
                            log_freq=args.print_step,
                            loss_normalizer=loss_normalizer)
                    ])


if __name__ == "__main__":
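For context on the two metrics LoggerCallback adds (a standalone numerical sketch; the epsilon and vocabulary size are example values, not taken from the repository's config): loss_normalizer is the best cross-entropy achievable under label smoothing, so "normalized loss" measures how far the training loss sits above that floor, and "ppl" is the exponentiated, clipped loss.

import numpy as np

label_smooth_eps = 0.1      # example smoothing factor
trg_vocab_size = 10000      # example target vocabulary size

# best achievable cross-entropy under label smoothing, as computed in do_train
loss_normalizer = -(
    (1. - label_smooth_eps) * np.log(1. - label_smooth_eps) +
    label_smooth_eps * np.log(label_smooth_eps / (trg_vocab_size - 1) + 1e-20))

loss = 7.0                                # example per-token training loss
normalized_loss = loss - loss_normalizer  # distance above the smoothing floor
ppl = np.exp(min(loss, 100))              # perplexity, with the same clipping as the callback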