Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
s920243400
PaddleDetection
提交
c51ab429
P
PaddleDetection
项目概览
s920243400
/
PaddleDetection
与 Fork 源项目一致
Fork自
PaddlePaddle / PaddleDetection
通知
2
Star
0
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleDetection
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
c51ab429
编写于
4月 14, 2017
作者:
T
Tao Luo
提交者:
GitHub
4月 14, 2017
浏览文件
操作
浏览文件
下载
差异文件
Merge pull request #1784 from luotao1/beam
add seqtext_print for seqToseq demo
上级
92edc2d8
555b2dfd
变更
2
隐藏空白更改
内联
并排
Showing
2 changed file
with
48 addition
and
7 deletion
+48
-7
demo/seqToseq/api_train_v2.py
demo/seqToseq/api_train_v2.py
+35
-4
python/paddle/v2/dataset/wmt14.py
python/paddle/v2/dataset/wmt14.py
+13
-3
未找到文件。
demo/seqToseq/api_train_v2.py
浏览文件 @
c51ab429
...
...
@@ -126,7 +126,7 @@ def seqToseq_net(source_dict_dim, target_dict_dim, is_generating=False):
def
main
():
paddle
.
init
(
use_gpu
=
False
,
trainer_count
=
1
)
is_generating
=
Tru
e
is_generating
=
Fals
e
# source and target dict dim.
dict_size
=
30000
...
...
@@ -167,16 +167,47 @@ def main():
# generate a english sequence to french
else
:
gen_creator
=
paddle
.
dataset
.
wmt14
.
test
(
dict_size
)
# use the first 3 samples for generation
gen_creator
=
paddle
.
dataset
.
wmt14
.
gen
(
dict_size
)
gen_data
=
[]
gen_num
=
3
for
item
in
gen_creator
():
gen_data
.
append
((
item
[
0
],
))
if
len
(
gen_data
)
==
3
:
if
len
(
gen_data
)
==
gen_num
:
break
beam_gen
=
seqToseq_net
(
source_dict_dim
,
target_dict_dim
,
is_generating
)
# get the pretrained model, whose bleu = 26.92
parameters
=
paddle
.
dataset
.
wmt14
.
model
()
trg_dict
=
paddle
.
dataset
.
wmt14
.
trg_dict
(
dict_size
)
# prob is the prediction probabilities, and id is the prediction word.
beam_result
=
paddle
.
infer
(
output_layer
=
beam_gen
,
parameters
=
parameters
,
input
=
gen_data
,
field
=
[
'prob'
,
'id'
])
# get the dictionary
src_dict
,
trg_dict
=
paddle
.
dataset
.
wmt14
.
get_dict
(
dict_size
)
# the delimited element of generated sequences is -1,
# the first element of each generated sequence is the sequence length
seq_list
=
[]
seq
=
[]
for
w
in
beam_result
[
1
]:
if
w
!=
-
1
:
seq
.
append
(
w
)
else
:
seq_list
.
append
(
' '
.
join
([
trg_dict
.
get
(
w
)
for
w
in
seq
[
1
:]]))
seq
=
[]
prob
=
beam_result
[
0
]
beam_size
=
3
for
i
in
xrange
(
gen_num
):
print
"
\n
*******************************************************
\n
"
print
"src:"
,
' '
.
join
(
[
src_dict
.
get
(
w
)
for
w
in
gen_data
[
i
][
0
]]),
"
\n
"
for
j
in
xrange
(
beam_size
):
print
"prob = %f:"
%
(
prob
[
i
][
j
]),
seq_list
[
i
*
beam_size
+
j
]
if
__name__
==
'__main__'
:
...
...
python/paddle/v2/dataset/wmt14.py
浏览文件 @
c51ab429
...
...
@@ -26,7 +26,7 @@ URL_DEV_TEST = 'http://www-lium.univ-lemans.fr/~schwenk/cslm_joint_paper/data/de
MD5_DEV_TEST
=
'7d7897317ddd8ba0ae5c5fa7248d3ff5'
# this is a small set of data for test. The original data is too large and will be add later.
URL_TRAIN
=
'http://paddlepaddle.cdn.bcebos.com/demo/wmt_shrinked_data/wmt14.tgz'
MD5_TRAIN
=
'
a755315dd01c2c35bde29a744ede23a6
'
MD5_TRAIN
=
'
0791583d57d5beb693b9414c5b36798c
'
# this is the pretrained model, whose bleu = 26.92
URL_MODEL
=
'http://paddlepaddle.bj.bcebos.com/demo/wmt_14/wmt14_model.tar.gz'
MD5_MODEL
=
'4ce14a26607fb8a1cc23bcdedb1895e4'
...
...
@@ -108,6 +108,11 @@ def test(dict_size):
download
(
URL_TRAIN
,
'wmt14'
,
MD5_TRAIN
),
'test/test'
,
dict_size
)
def
gen
(
dict_size
):
return
reader_creator
(
download
(
URL_TRAIN
,
'wmt14'
,
MD5_TRAIN
),
'gen/gen'
,
dict_size
)
def
model
():
tar_file
=
download
(
URL_MODEL
,
'wmt14'
,
MD5_MODEL
)
with
gzip
.
open
(
tar_file
,
'r'
)
as
f
:
...
...
@@ -115,10 +120,15 @@ def model():
return
parameters
def
trg_dict
(
dict_size
):
def
get_dict
(
dict_size
,
reverse
=
True
):
# if reverse = False, return dict = {'a':'001', 'b':'002', ...}
# else reverse = true, return dict = {'001':'a', '002':'b', ...}
tar_file
=
download
(
URL_TRAIN
,
'wmt14'
,
MD5_TRAIN
)
src_dict
,
trg_dict
=
__read_to_dict__
(
tar_file
,
dict_size
)
return
trg_dict
if
reverse
:
src_dict
=
{
v
:
k
for
k
,
v
in
src_dict
.
items
()}
trg_dict
=
{
v
:
k
for
k
,
v
in
trg_dict
.
items
()}
return
src_dict
,
trg_dict
def
fetch
():
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录