Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
fbea3918
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
fbea3918
编写于
4月 12, 2017
作者:
T
Tao Luo
提交者:
GitHub
4月 12, 2017
浏览文件
操作
浏览文件
下载
差异文件
Merge pull request #1774 from luotao1/wmt14
add wmt14 trg_dict
上级
6226b821
a42233c2
变更
2
显示空白变更内容
内联
并排
Showing
2 changed file
with
53 addition
and
41 deletion
+53
-41
demo/seqToseq/api_train_v2.py
demo/seqToseq/api_train_v2.py
+46
-40
python/paddle/v2/dataset/wmt14.py
python/paddle/v2/dataset/wmt14.py
+7
-1
未找到文件。
demo/seqToseq/api_train_v2.py
浏览文件 @
fbea3918
...
...
@@ -126,33 +126,28 @@ def seqToseq_net(source_dict_dim, target_dict_dim, is_generating=False):
def
main
():
paddle
.
init
(
use_gpu
=
False
,
trainer_count
=
1
)
is_generating
=
True
# source and target dict dim.
dict_size
=
30000
source_dict_dim
=
target_dict_dim
=
dict_size
# define network topology
# train the network
if
not
is_generating
:
cost
=
seqToseq_net
(
source_dict_dim
,
target_dict_dim
)
parameters
=
paddle
.
parameters
.
create
(
cost
)
# define optimize method and trainer
optimizer
=
paddle
.
optimizer
.
Adam
(
learning_rate
=
5e-5
,
regularization
=
paddle
.
optimizer
.
L2Regularization
(
rate
=
1e-3
))
regularization
=
paddle
.
optimizer
.
L2Regularization
(
rate
=
8e-4
))
trainer
=
paddle
.
trainer
.
SGD
(
cost
=
cost
,
parameters
=
parameters
,
update_equation
=
optimizer
)
# define data reader
feeding
=
{
'source_language_word'
:
0
,
'target_language_word'
:
1
,
'target_language_next_word'
:
2
}
wmt14_reader
=
paddle
.
batch
(
paddle
.
reader
.
shuffle
(
paddle
.
dataset
.
wmt14
.
train
(
dict_size
=
dict_size
),
buf_size
=
8192
),
paddle
.
dataset
.
wmt14
.
train
(
dict_size
),
buf_size
=
8192
),
batch_size
=
5
)
# define event_handler callback
...
...
@@ -160,17 +155,28 @@ def main():
if
isinstance
(
event
,
paddle
.
event
.
EndIteration
):
if
event
.
batch_id
%
10
==
0
:
print
"
\n
Pass %d, Batch %d, Cost %f, %s"
%
(
event
.
pass_id
,
event
.
batch_id
,
event
.
cost
,
event
.
metrics
)
event
.
pass_id
,
event
.
batch_id
,
event
.
cost
,
event
.
metrics
)
else
:
sys
.
stdout
.
write
(
'.'
)
sys
.
stdout
.
flush
()
# start to train
trainer
.
train
(
reader
=
wmt14_reader
,
event_handler
=
event_handler
,
num_passes
=
10000
,
feeding
=
feeding
)
reader
=
wmt14_reader
,
event_handler
=
event_handler
,
num_passes
=
2
)
# generate a english sequence to french
else
:
gen_creator
=
paddle
.
dataset
.
wmt14
.
test
(
dict_size
)
gen_data
=
[]
for
item
in
gen_creator
():
gen_data
.
append
((
item
[
0
],
))
if
len
(
gen_data
)
==
3
:
break
beam_gen
=
seqToseq_net
(
source_dict_dim
,
target_dict_dim
,
is_generating
)
parameters
=
paddle
.
dataset
.
wmt14
.
model
()
trg_dict
=
paddle
.
dataset
.
wmt14
.
trg_dict
(
dict_size
)
if
__name__
==
'__main__'
:
...
...
python/paddle/v2/dataset/wmt14.py
浏览文件 @
fbea3918
...
...
@@ -29,7 +29,7 @@ URL_TRAIN = 'http://paddlepaddle.cdn.bcebos.com/demo/wmt_shrinked_data/wmt14.tgz
MD5_TRAIN
=
'a755315dd01c2c35bde29a744ede23a6'
# this is the pretrained model, whose bleu = 26.92
URL_MODEL
=
'http://paddlepaddle.bj.bcebos.com/demo/wmt_14/wmt14_model.tar.gz'
MD5_MODEL
=
'
6b097d23e15654608c6f74923e975535
'
MD5_MODEL
=
'
4ce14a26607fb8a1cc23bcdedb1895e4
'
START
=
"<s>"
END
=
"<e>"
...
...
@@ -115,6 +115,12 @@ def model():
return
parameters
def
trg_dict
(
dict_size
):
tar_file
=
download
(
URL_TRAIN
,
'wmt14'
,
MD5_TRAIN
)
src_dict
,
trg_dict
=
__read_to_dict__
(
tar_file
,
dict_size
)
return
trg_dict
def
fetch
():
download
(
URL_TRAIN
,
'wmt14'
,
MD5_TRAIN
)
download
(
URL_MODEL
,
'wmt14'
,
MD5_MODEL
)
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录