机器未来 / Paddle (fork of PaddlePaddle / Paddle)
Commit 6a31509d
Authored on Sep 03, 2018 by minqiyang
Port release 0.15.0 code to python3
Parent: 8f2ce7ca

Showing 2 changed files with 25 additions and 19 deletions:

python/paddle/fluid/tests/unittests/dist_transformer.py   +20 -14
python/paddle/fluid/transpiler/distribute_transpiler.py   +5  -5
python/paddle/fluid/tests/unittests/dist_transformer.py
@@ -36,6 +36,7 @@ import paddle.fluid as fluid
 import paddle.fluid.layers as layers
 from paddle.fluid import core
 from test_dist_base import TestDistRunnerBase, runtime_main
+import paddle.compat as cpt
 from paddle.compat import long_type

 import hashlib
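The newly added `import paddle.compat as cpt` supplies the `cpt.to_text` helper that later hunks in this file use to decode lines read in binary mode. A minimal sketch of what such a helper does (a hypothetical stand-in, not Paddle's actual implementation):

# Hypothetical stand-in for cpt.to_text; paddle.compat's real helper
# may handle more cases than plain bytes.
def to_text(obj, encoding="utf-8"):
    # On Python 3, lines from a file opened in "rb" mode are bytes;
    # decode them so downstream .strip()/.split() work on str.
    if isinstance(obj, bytes):
        return obj.decode(encoding)
    return obj

assert to_text(b"hello\n") == "hello\n"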
@@ -315,8 +316,9 @@ def pad_batch_data(insts,
     """
     return_list = []
     max_len = max(len(inst) for inst in insts)
-    num_token = reduce(lambda x, y: x + y,
-                       [len(inst) for inst in insts]) if return_num_token else 0
+    num_token = six.moves.reduce(
+        lambda x, y: x + y,
+        [len(inst) for inst in insts]) if return_num_token else 0
     # Any token included in dict can be used to pad, since the paddings' loss
     # will be masked out by weights and make no effect on parameter gradients.
     inst_data = np.array(
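The builtin `reduce` is gone in Python 3 (it moved to `functools`); `six.moves.reduce` resolves to the right function on both interpreters. A quick check of the rewritten expression in isolation:

import six

insts = [[1, 2, 3], [4, 5]]
return_num_token = True
# six.moves.reduce is the builtin on Py2 and functools.reduce on Py3.
num_token = six.moves.reduce(
    lambda x, y: x + y,
    [len(inst) for inst in insts]) if return_num_token else 0
assert num_token == 5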
@@ -328,7 +330,7 @@ def pad_batch_data(insts,
         return_list += [inst_weight.astype("float32").reshape([-1, 1])]
     else:  # position data
         inst_pos = np.array([
-            range(1, len(inst) + 1) + [0] * (max_len - len(inst))
+            list(range(1, len(inst) + 1)) + [0] * (max_len - len(inst))
             for inst in insts
         ])
         return_list += [inst_pos.astype("int64").reshape([-1, 1])]
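On Python 2, `range` returned a list, so concatenating it with `[0] * (...)` worked directly; on Python 3 it returns a lazy range object and `+` raises TypeError, hence the `list()` wrapper. For example:

max_len, inst = 5, ["a", "b", "c"]
# Py3: range(...) + [...] -> TypeError; materialize the range first.
positions = list(range(1, len(inst) + 1)) + [0] * (max_len - len(inst))
assert positions == [1, 2, 3, 0, 0]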
@@ -385,10 +387,11 @@ def prepare_batch_input(insts, data_input_names, src_pad_idx, trg_pad_idx,
         return_num_token=True)
     data_input_dict = dict(
-        zip(data_input_names, [
-            src_word, src_pos, src_slf_attn_bias, trg_word, trg_pos,
-            trg_slf_attn_bias, trg_src_attn_bias, lbl_word, lbl_weight
-        ]))
+        list(
+            zip(data_input_names, [
+                src_word, src_pos, src_slf_attn_bias, trg_word, trg_pos,
+                trg_slf_attn_bias, trg_src_attn_bias, lbl_word, lbl_weight
+            ])))
     return data_input_dict, np.asarray([num_token], dtype="float32")
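`zip` likewise became lazy in Python 3. `dict()` can consume the iterator directly, so the added `list()` here looks like the mechanical output of a 2to3-style rewrite rather than a semantic necessity; both spellings build the same dict:

data_input_names = ("src_word", "src_pos")
feeds = ([3, 1, 4], [0, 1, 2])
d_eager = dict(list(zip(data_input_names, feeds)))
d_lazy = dict(zip(data_input_names, feeds))
assert d_eager == d_lazy == {"src_word": [3, 1, 4], "src_pos": [0, 1, 2]}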
@@ -561,7 +564,7 @@ def train_loop(exe, train_progm, dev_count, sum_cost, avg_cost, lr_scheduler,
         np.log(TrainTaskConfig.label_smooth_eps / (
             ModelHyperParams.trg_vocab_size - 1) + 1e-20))
     init = False
-    for pass_id in xrange(TrainTaskConfig.pass_num):
+    for pass_id in six.moves.xrange(TrainTaskConfig.pass_num):
         pass_start_time = time.time()
         for batch_id, data in enumerate(train_data()):
             if batch_id >= 5:
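`xrange` no longer exists on Python 3; `six.moves.xrange` maps to `xrange` on Python 2 and to `range` on Python 3, keeping the loop lazy on both:

import six

pass_num = 3  # stand-in for TrainTaskConfig.pass_num
for pass_id in six.moves.xrange(pass_num):
    print(pass_id)  # 0, 1, 2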
@@ -587,11 +590,11 @@ def train_loop(exe, train_progm, dev_count, sum_cost, avg_cost, lr_scheduler,
                     ModelHyperParams.eos_idx, ModelHyperParams.n_head,
                     ModelHyperParams.d_model)
                 total_num_token += num_token
-                feed_kv_pairs = data_input_dict.items()
+                feed_kv_pairs = list(data_input_dict.items())
                 if TrainTaskConfig.local:
-                    feed_kv_pairs += {
-                        lr_scheduler.learning_rate.name: lr_rate
-                    }.items()
+                    feed_kv_pairs += list({
+                        lr_scheduler.learning_rate.name: lr_rate
+                    }.items())
                 feed_list.append(dict(feed_kv_pairs))
                 if not init:
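On Python 3, `dict.items()` returns a view object that does not support `+=` concatenation, so both occurrences are materialized into lists before being merged and turned back into a feed dict:

data_input_dict = {"src_word": 1, "src_pos": 2}
lr_name, lr_rate = "learning_rate", 0.001  # stand-ins for the real feed
feed_kv_pairs = list(data_input_dict.items())
# A Py3 items() view would raise TypeError under +=; lists concatenate.
feed_kv_pairs += list({lr_name: lr_rate}.items())
assert dict(feed_kv_pairs) == {"src_word": 1, "src_pos": 2,
                               "learning_rate": 0.001}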
@@ -873,6 +876,7 @@ class DataReader(object):
             f = tarfile.open(fpaths[0], "r")
             for line in f.extractfile(tar_fname):
+                line = cpt.to_text(line)
                 fields = line.strip("\n").split(self._field_delimiter)
                 if (not self._only_src and len(fields) == 2) or (
                         self._only_src and len(fields) == 1):
@@ -882,8 +886,9 @@ class DataReader(object):
             if not os.path.isfile(fpath):
                 raise IOError("Invalid file: %s" % fpath)
-            with open(fpath, "r") as f:
+            with open(fpath, "rb") as f:
                 for line in f:
+                    line = cpt.to_text(line)
                     fields = line.strip("\n").split(self._field_delimiter)
                     if (not self._only_src and len(fields) == 2) or (
                             self._only_src and len(fields) == 1):
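Opening the file in "rb" makes it yield `bytes` on both interpreters, and `cpt.to_text` then decodes each line to text before it is stripped and split. The same pattern with a plain decode standing in for `cpt.to_text`:

import io

# io.BytesIO stands in for open(fpath, "rb").
f = io.BytesIO(b"source sentence\ttarget sentence\n")
for line in f:
    line = line.decode("utf-8")  # what cpt.to_text does for bytes
    fields = line.strip("\n").split("\t")
assert fields == ["source sentence", "target sentence"]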
@@ -892,8 +897,9 @@ class DataReader(object):
     @staticmethod
     def load_dict(dict_path, reverse=False):
         word_dict = {}
-        with open(dict_path, "r") as fdict:
+        with open(dict_path, "rb") as fdict:
             for idx, line in enumerate(fdict):
+                line = cpt.to_text(line)
                 if reverse:
                     word_dict[idx] = line.strip("\n")
                 else:
@@ -1034,7 +1040,7 @@ def multi_head_attention(queries,
         # size of the input as the output dimension size.
         return layers.reshape(
             x=trans_x,
-            shape=map(int, [0, 0, trans_x.shape[2] * trans_x.shape[3]]))
+            shape=list(map(int, [0, 0, trans_x.shape[2] * trans_x.shape[3]])))


 def scaled_dot_product_attention(q, k, v, attn_bias, d_model, dropout_rate):
     """
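`map` also returns an iterator on Python 3; the `shape` argument is wrapped in `list()` so the op receives a concrete list of ints, presumably because the attribute handling downstream cannot consume a one-shot iterator. In isolation:

trans_x_shape = [32, 8, 10, 64]  # illustrative [batch, heads, len, dim]
shape = list(map(int, [0, 0, trans_x_shape[2] * trans_x_shape[3]]))
assert shape == [0, 0, 640]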
python/paddle/fluid/transpiler/distribute_transpiler.py
@@ -293,7 +293,7 @@ class DistributeTranspiler(object):
         input_deps = grad_name_to_send_dummy_out.values()
         program.global_block().append_op(
             type="send_barrier",
-            inputs={"X": input_deps},
+            inputs={"X": list(input_deps)},
             outputs={"Out": send_barrier_out},
             attrs={
                 "endpoints": pserver_endpoints,
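Same idea for `dict.values()`: on Python 3 it is a view, and the `send_barrier` op's input list is materialized with `list()` before being handed to `append_op`. In isolation:

grad_name_to_send_dummy_out = {"w@GRAD": "dummy_out_0"}
input_deps = grad_name_to_send_dummy_out.values()
inputs = {"X": list(input_deps)}  # concrete list, not a Py3 view
assert inputs == {"X": ["dummy_out_0"]}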
@@ -394,7 +394,7 @@ class DistributeTranspiler(object):
         Args:
             recv_vars (list): Variable list to recv for current trainer_id
             eplist (list): A list of strings indicating

         Returns:
             Program: trainer side startup program.
(the changed line is not visible in this capture; context only)
@@ -448,7 +448,7 @@ class DistributeTranspiler(object):
             if len(splited_var) <= 1:
                 continue
             # NOTE: if enable memory optimization, origin vars maybe removed.
-            if startup_program.global_block().vars.has_key(varname):
+            if varname in startup_program.global_block().vars:
                 orig_param = startup_program.global_block().vars[varname]
             else:
                 origin_param_var = self.origin_program.global_block().vars[
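`dict.has_key` was removed in Python 3; membership via `in` works on both versions and is the idiomatic spelling anyway:

block_vars = {"fc_0.w_0": object()}  # stand-in for global_block().vars
varname = "fc_0.w_0"
# block_vars.has_key(varname) raises AttributeError on Py3.
if varname in block_vars:
    orig_param = block_vars[varname]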
@@ -677,7 +677,7 @@ class DistributeTranspiler(object):
         Args:
             endpoint (str): current pserver endpoint.

         Returns:
             tuple: (main_program, startup_program), of type "Program"
         """
(the changed line is not visible in this capture; context only)
@@ -700,7 +700,7 @@ class DistributeTranspiler(object):
             endpoint (str): current pserver endpoint.
             pserver_program (Program): deprecated, call get_pserver_program first.
             startup_program (Program): deprecated, should pass startup_program
                 when initalizing
         Returns:
             Program: parameter server side startup program.
(the changed line is not visible in this capture; context only)