Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
a0e6f416
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
a0e6f416
编写于
9月 04, 2018
作者:
M
minqiyang
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Port transformer
上级
c838fa31
变更
1
隐藏空白更改
内联
并排
Showing
1 changed file
with
20 addition
and
14 deletion
+20
-14
python/paddle/fluid/tests/unittests/dist_transformer.py
python/paddle/fluid/tests/unittests/dist_transformer.py
+20
-14
未找到文件。
python/paddle/fluid/tests/unittests/dist_transformer.py
浏览文件 @
a0e6f416
...
...
@@ -36,6 +36,7 @@ import paddle.fluid as fluid
import
paddle.fluid.layers
as
layers
from
paddle.fluid
import
core
from
test_dist_base
import
TestDistRunnerBase
,
runtime_main
import
paddle.compat
as
cpt
from
paddle.compat
import
long_type
import
hashlib
...
...
@@ -315,8 +316,9 @@ def pad_batch_data(insts,
"""
return_list
=
[]
max_len
=
max
(
len
(
inst
)
for
inst
in
insts
)
num_token
=
reduce
(
lambda
x
,
y
:
x
+
y
,
[
len
(
inst
)
for
inst
in
insts
])
if
return_num_token
else
0
num_token
=
six
.
moves
.
reduce
(
lambda
x
,
y
:
x
+
y
,
[
len
(
inst
)
for
inst
in
insts
])
if
return_num_token
else
0
# Any token included in dict can be used to pad, since the paddings' loss
# will be masked out by weights and make no effect on parameter gradients.
inst_data
=
np
.
array
(
...
...
@@ -328,7 +330,7 @@ def pad_batch_data(insts,
return_list
+=
[
inst_weight
.
astype
(
"float32"
).
reshape
([
-
1
,
1
])]
else
:
# position data
inst_pos
=
np
.
array
([
range
(
1
,
len
(
inst
)
+
1
)
+
[
0
]
*
(
max_len
-
len
(
inst
))
list
(
range
(
1
,
len
(
inst
)
+
1
)
)
+
[
0
]
*
(
max_len
-
len
(
inst
))
for
inst
in
insts
])
return_list
+=
[
inst_pos
.
astype
(
"int64"
).
reshape
([
-
1
,
1
])]
...
...
@@ -385,10 +387,11 @@ def prepare_batch_input(insts, data_input_names, src_pad_idx, trg_pad_idx,
return_num_token
=
True
)
data_input_dict
=
dict
(
zip
(
data_input_names
,
[
src_word
,
src_pos
,
src_slf_attn_bias
,
trg_word
,
trg_pos
,
trg_slf_attn_bias
,
trg_src_attn_bias
,
lbl_word
,
lbl_weight
]))
list
(
zip
(
data_input_names
,
[
src_word
,
src_pos
,
src_slf_attn_bias
,
trg_word
,
trg_pos
,
trg_slf_attn_bias
,
trg_src_attn_bias
,
lbl_word
,
lbl_weight
])))
return
data_input_dict
,
np
.
asarray
([
num_token
],
dtype
=
"float32"
)
...
...
@@ -561,7 +564,7 @@ def train_loop(exe, train_progm, dev_count, sum_cost, avg_cost, lr_scheduler,
np
.
log
(
TrainTaskConfig
.
label_smooth_eps
/
(
ModelHyperParams
.
trg_vocab_size
-
1
)
+
1e-20
))
init
=
False
for
pass_id
in
xrange
(
TrainTaskConfig
.
pass_num
):
for
pass_id
in
six
.
moves
.
xrange
(
TrainTaskConfig
.
pass_num
):
pass_start_time
=
time
.
time
()
for
batch_id
,
data
in
enumerate
(
train_data
()):
if
batch_id
>=
5
:
...
...
@@ -587,11 +590,11 @@ def train_loop(exe, train_progm, dev_count, sum_cost, avg_cost, lr_scheduler,
ModelHyperParams
.
eos_idx
,
ModelHyperParams
.
n_head
,
ModelHyperParams
.
d_model
)
total_num_token
+=
num_token
feed_kv_pairs
=
data_input_dict
.
items
(
)
feed_kv_pairs
=
list
(
data_input_dict
.
items
()
)
if
TrainTaskConfig
.
local
:
feed_kv_pairs
+=
{
feed_kv_pairs
+=
list
(
{
lr_scheduler
.
learning_rate
.
name
:
lr_rate
}.
items
()
}.
items
()
)
feed_list
.
append
(
dict
(
feed_kv_pairs
))
if
not
init
:
...
...
@@ -873,6 +876,7 @@ class DataReader(object):
f
=
tarfile
.
open
(
fpaths
[
0
],
"r"
)
for
line
in
f
.
extractfile
(
tar_fname
):
line
=
cpt
.
to_text
(
line
)
fields
=
line
.
strip
(
"
\n
"
).
split
(
self
.
_field_delimiter
)
if
(
not
self
.
_only_src
and
len
(
fields
)
==
2
)
or
(
self
.
_only_src
and
len
(
fields
)
==
1
):
...
...
@@ -882,8 +886,9 @@ class DataReader(object):
if
not
os
.
path
.
isfile
(
fpath
):
raise
IOError
(
"Invalid file: %s"
%
fpath
)
with
open
(
fpath
,
"r"
)
as
f
:
with
open
(
fpath
,
"r
b
"
)
as
f
:
for
line
in
f
:
line
=
cpt
.
to_text
(
line
)
fields
=
line
.
strip
(
"
\n
"
).
split
(
self
.
_field_delimiter
)
if
(
not
self
.
_only_src
and
len
(
fields
)
==
2
)
or
(
self
.
_only_src
and
len
(
fields
)
==
1
):
...
...
@@ -892,8 +897,9 @@ class DataReader(object):
@
staticmethod
def
load_dict
(
dict_path
,
reverse
=
False
):
word_dict
=
{}
with
open
(
dict_path
,
"r"
)
as
fdict
:
with
open
(
dict_path
,
"r
b
"
)
as
fdict
:
for
idx
,
line
in
enumerate
(
fdict
):
line
=
cpt
.
to_text
(
line
)
if
reverse
:
word_dict
[
idx
]
=
line
.
strip
(
"
\n
"
)
else
:
...
...
@@ -1034,7 +1040,7 @@ def multi_head_attention(queries,
# size of the input as the output dimension size.
return
layers
.
reshape
(
x
=
trans_x
,
shape
=
map
(
int
,
[
0
,
0
,
trans_x
.
shape
[
2
]
*
trans_x
.
shape
[
3
]]
))
shape
=
list
(
map
(
int
,
[
0
,
0
,
trans_x
.
shape
[
2
]
*
trans_x
.
shape
[
3
]])
))
def
scaled_dot_product_attention
(
q
,
k
,
v
,
attn_bias
,
d_model
,
dropout_rate
):
"""
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录