Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
9a97c7f7
P
Paddle
项目概览
PaddlePaddle
/
Paddle
大约 1 年 前同步成功
通知
2298
Star
20931
Fork
5422
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
9a97c7f7
编写于
1月 17, 2018
作者:
Y
ying
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
add wmt16 into dataset.
上级
38c61053
变更
6
隐藏空白更改
内联
并排
Showing
6 changed file
with
482 addition
and
20 deletion
+482
-20
python/paddle/v2/dataset/__init__.py
python/paddle/v2/dataset/__init__.py
+14
-2
python/paddle/v2/dataset/common.py
python/paddle/v2/dataset/common.py
+15
-6
python/paddle/v2/dataset/tests/wmt16_test.py
python/paddle/v2/dataset/tests/wmt16_test.py
+66
-0
python/paddle/v2/dataset/wmt14.py
python/paddle/v2/dataset/wmt14.py
+13
-5
python/paddle/v2/dataset/wmt16.py
python/paddle/v2/dataset/wmt16.py
+348
-0
python/paddle/v2/fluid/layers/control_flow.py
python/paddle/v2/fluid/layers/control_flow.py
+26
-7
未找到文件。
python/paddle/v2/dataset/__init__.py
浏览文件 @
9a97c7f7
...
...
@@ -24,11 +24,23 @@ import conll05
import
uci_housing
import
sentiment
import
wmt14
import
wmt16
import
mq2007
import
flowers
import
voc2012
__all__
=
[
'mnist'
,
'imikolov'
,
'imdb'
,
'cifar'
,
'movielens'
,
'conll05'
,
'sentiment'
'uci_housing'
,
'wmt14'
,
'mq2007'
,
'flowers'
,
'voc2012'
'mnist'
,
'imikolov'
,
'imdb'
,
'cifar'
,
'movielens'
,
'conll05'
,
'sentiment'
'uci_housing'
,
'wmt14'
,
'wmt16'
,
'mq2007'
,
'flowers'
,
'voc2012'
,
]
python/paddle/v2/dataset/common.py
浏览文件 @
9a97c7f7
...
...
@@ -25,8 +25,12 @@ import glob
import
cPickle
as
pickle
__all__
=
[
'DATA_HOME'
,
'download'
,
'md5file'
,
'split'
,
'cluster_files_reader'
,
'convert'
'DATA_HOME'
,
'download'
,
'md5file'
,
'split'
,
'cluster_files_reader'
,
'convert'
,
]
DATA_HOME
=
os
.
path
.
expanduser
(
'~/.cache/paddle/dataset'
)
...
...
@@ -58,12 +62,15 @@ def md5file(fname):
return
hash_md5
.
hexdigest
()
def
download
(
url
,
module_name
,
md5sum
):
def
download
(
url
,
module_name
,
md5sum
,
save_name
=
None
):
dirname
=
os
.
path
.
join
(
DATA_HOME
,
module_name
)
if
not
os
.
path
.
exists
(
dirname
):
os
.
makedirs
(
dirname
)
filename
=
os
.
path
.
join
(
dirname
,
url
.
split
(
'/'
)[
-
1
])
filename
=
os
.
path
.
join
(
dirname
,
url
.
split
(
'/'
)[
-
1
]
if
save_name
is
None
else
save_name
)
retry
=
0
retry_limit
=
3
while
not
(
os
.
path
.
exists
(
filename
)
and
md5file
(
filename
)
==
md5sum
):
...
...
@@ -196,9 +203,11 @@ def convert(output_path, reader, line_count, name_prefix):
Convert data from reader to recordio format files.
:param output_path: directory in which output files will be saved.
:param reader: a data reader, from which the convert program will read data instances.
:param reader: a data reader, from which the convert program will read
data instances.
:param name_prefix: the name prefix of generated files.
:param max_lines_to_shuffle: the max lines numbers to shuffle before writing.
:param max_lines_to_shuffle: the max lines numbers to shuffle before
writing.
"""
assert
line_count
>=
1
...
...
python/paddle/v2/dataset/tests/wmt16_test.py
0 → 100644
浏览文件 @
9a97c7f7
# Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
paddle.v2.dataset.wmt16
import
unittest
class TestWMT16(unittest.TestCase):
    """Smoke tests for the paddle.v2.dataset.wmt16 readers and dictionaries.

    NOTE(review): these tests download the real wmt16 tarball on first run;
    they exercise the readers end-to-end rather than in isolation.
    """

    def checkout_one_sample(self, sample):
        # train data has 3 field: source language word indices,
        # target language word indices, and target next word indices.
        self.assertEqual(len(sample), 3)

        # test start mark and end mark in source word indices.
        # (index 0 is <s>, index 1 is <e> by dictionary construction)
        self.assertEqual(sample[0][0], 0)
        self.assertEqual(sample[0][-1], 1)

        # test start mark in target word indices
        self.assertEqual(sample[1][0], 0)

        # test end mark in target next word indices
        self.assertEqual(sample[2][-1], 1)

    def test_train(self):
        # check only the first 10 samples of the train split
        for idx, sample in enumerate(
                paddle.v2.dataset.wmt16.train(
                    src_dict_size=100000, trg_dict_size=100000)()):
            if idx >= 10:
                break
            self.checkout_one_sample(sample)

    def test_test(self):
        # check only the first 10 samples of the test split
        for idx, sample in enumerate(
                paddle.v2.dataset.wmt16.test(
                    src_dict_size=1000, trg_dict_size=1000)()):
            if idx >= 10:
                break
            self.checkout_one_sample(sample)

    def test_val(self):
        # check only the first 10 samples of the validation split
        for idx, sample in enumerate(
                paddle.v2.dataset.wmt16.validation(
                    src_dict_size=1000, trg_dict_size=1000)()):
            if idx >= 10:
                break
            self.checkout_one_sample(sample)

    def test_get_dict(self):
        dict_size = 1000
        # reverse=True returns an index -> word mapping
        word_dict = paddle.v2.dataset.wmt16.get_dict("en", dict_size, True)
        self.assertEqual(len(word_dict), dict_size)
        # the first three entries are always the special marks
        self.assertEqual(word_dict[0], "<s>")
        self.assertEqual(word_dict[1], "<e>")
        self.assertEqual(word_dict[2], "<unk>")
if
__name__
==
"__main__"
:
unittest
.
main
()
python/paddle/v2/dataset/wmt14.py
浏览文件 @
9a97c7f7
...
...
@@ -25,12 +25,20 @@ import gzip
import
paddle.v2.dataset.common
from
paddle.v2.parameters
import
Parameters
__all__
=
[
'train'
,
'test'
,
'build_dict'
,
'convert'
]
URL_DEV_TEST
=
'http://www-lium.univ-lemans.fr/~schwenk/cslm_joint_paper/data/dev+test.tgz'
__all__
=
[
'train'
,
'test'
,
'get_dict'
,
'convert'
,
]
URL_DEV_TEST
=
(
'http://www-lium.univ-lemans.fr/~schwenk/'
'cslm_joint_paper/data/dev+test.tgz'
)
MD5_DEV_TEST
=
'7d7897317ddd8ba0ae5c5fa7248d3ff5'
# this is a small set of data for test. The original data is too large and will be add later.
URL_TRAIN
=
'http://paddlepaddle.cdn.bcebos.com/demo/wmt_shrinked_data/wmt14.tgz'
# this is a small set of data for test. The original data is too large and
# will be add later.
URL_TRAIN
=
(
'http://paddlepaddle.cdn.bcebos.com/demo/'
'wmt_shrinked_data/wmt14.tgz'
)
MD5_TRAIN
=
'0791583d57d5beb693b9414c5b36798c'
# BLEU of this trained model is 26.92
URL_MODEL
=
'http://paddlepaddle.bj.bcebos.com/demo/wmt_14/wmt14_model.tar.gz'
...
...
python/paddle/v2/dataset/wmt16.py
0 → 100644
浏览文件 @
9a97c7f7
# Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
ACL2016 Multimodal Machine Translation. Please see this website for more details:
http://www.statmt.org/wmt16/multimodal-task.html#task1
If you use the dataset created for your task, please cite the following paper:
Multi30K: Multilingual English-German Image Descriptions.
@article{elliott-EtAl:2016:VL16,
author = {{Elliott}, D. and {Frank}, S. and {Sima"an}, K. and {Specia}, L.},
title = {Multi30K: Multilingual English-German Image Descriptions},
booktitle = {Proceedings of the 6th Workshop on Vision and Language},
year = {2016},
pages = {70--74},
year = 2016
}
"""
import
os
import
tarfile
import
gzip
from
collections
import
defaultdict
import
paddle.v2.dataset.common
__all__
=
[
"train"
,
"test"
,
"validation"
,
"convert"
,
"fetch"
,
"get_dict"
,
]
DATA_URL
=
(
"http://cloud.dlnel.org/filepub/"
"?uuid=46a0808e-ddd8-427c-bacd-0dbc6d045fed"
)
DATA_MD5
=
"0c38be43600334966403524a40dcd81e"
TOTAL_EN_WORDS
=
11250
TOTAL_DE_WORDS
=
19220
START_MARK
=
"<s>"
END_MARK
=
"<e>"
UNK_MARK
=
"<unk>"
def __build_dict__(tar_file, dict_size, save_path, lang):
    """Build a frequency-ranked word dictionary from the training split.

    Reads the "wmt16/train" member of the downloaded tarball, counts word
    frequencies for the requested language column, and writes the most
    frequent words to save_path (one word per line), preceded by the three
    special marks <s>, <e> and <unk>.

    :param tar_file: path to the wmt16 tarball.
    :param dict_size: total dictionary size, including the 3 special marks.
    :param save_path: file path the dictionary is written to.
    :param lang: "en" or "de"; selects column 0 or 1 of the tab-separated
        parallel corpus.
    """
    word_dict = defaultdict(int)
    with tarfile.open(tar_file, mode="r") as f:
        for line in f.extractfile("wmt16/train"):
            line_split = line.strip().split("\t")
            # skip malformed lines that do not have exactly two columns
            if len(line_split) != 2:
                continue
            # English text is column 0, German text is column 1
            sen = line_split[0] if lang == "en" else line_split[1]
            for w in sen.split():
                word_dict[w] += 1

    with open(save_path, "w") as fout:
        # the special marks always occupy the first three dictionary slots
        fout.write("%s\n%s\n%s\n" % (START_MARK, END_MARK, UNK_MARK))
        for idx, word in enumerate(
                sorted(
                    word_dict.iteritems(), key=lambda x: x[1], reverse=True)):
            # stop once the 3 special marks plus idx words fill dict_size
            if idx + 3 == dict_size:
                break
            fout.write("%s\n" % (word[0]))
def __load_dict__(tar_file, dict_size, lang, reverse=False):
    """Load the word dictionary for one language, building it if needed.

    :param tar_file: path to the wmt16 tarball (used to rebuild the dict).
    :param dict_size: expected dictionary size, including special marks.
    :param lang: "en" or "de".
    :param reverse: if False (default) the mapping is word -> index;
        if True it is index -> word.
    :return: the dictionary as a python dict.
    """
    dict_path = os.path.join(paddle.v2.dataset.common.DATA_HOME,
                             "wmt16/%s_%d.dict" % (lang, dict_size))
    # rebuild when the on-disk dictionary is missing or has a stale size
    if not os.path.exists(dict_path) or (
            len(open(dict_path, "r").readlines()) != dict_size):
        __build_dict__(tar_file, dict_size, dict_path, lang)

    word_dict = {}
    with open(dict_path, "r") as fdict:
        for idx, line in enumerate(fdict):
            if reverse:
                word_dict[idx] = line.strip()
            else:
                word_dict[line.strip()] = idx
    return word_dict
def __get_dict_size__(src_dict_size, trg_dict_size, src_lang):
    """Clip the requested dictionary sizes to the actual vocabulary sizes.

    :param src_dict_size: requested source language dictionary size.
    :param trg_dict_size: requested target language dictionary size.
    :param src_lang: "en" or "de"; the target language is the other one.
    :return: the (src_dict_size, trg_dict_size) tuple after clipping.
    """
    src_dict_size = min(src_dict_size, (TOTAL_EN_WORDS
                                        if src_lang == "en" else
                                        TOTAL_DE_WORDS))
    # BUG FIX: the original referenced the undefined name TOTAL_ENG_WORDS
    # here, which raised NameError whenever src_lang was "de". The module
    # constant is TOTAL_EN_WORDS.
    trg_dict_size = min(trg_dict_size, (TOTAL_DE_WORDS
                                        if src_lang == "en" else
                                        TOTAL_EN_WORDS))
    return src_dict_size, trg_dict_size
def reader_creator(tar_file, file_name, src_dict_size, trg_dict_size,
                   src_lang):
    """Create a reader over one split (train/test/val) of the wmt16 data.

    Each sample yielded by the returned reader is a 3-tuple:
    (source word id sequence wrapped in <s>...</e>,
     target word id sequence prefixed with <s>,
     target next-word id sequence suffixed with <e>).

    :param tar_file: path to the wmt16 tarball.
    :param file_name: member file inside the tarball, e.g. "wmt16/train".
    :param src_dict_size: source language dictionary size.
    :param trg_dict_size: target language dictionary size.
    :param src_lang: "en" or "de"; the target language is the other one.
    :return: a generator-producing callable (the reader).
    """

    def reader():
        src_dict = __load_dict__(tar_file, src_dict_size, src_lang)
        trg_dict = __load_dict__(tar_file, trg_dict_size,
                                 ("de" if src_lang == "en" else "en"))

        # the indice for start mark, end mark, and unk are the same in source
        # language and target language. Here uses the source language
        # dictionary to determine their indices.
        start_id = src_dict[START_MARK]
        end_id = src_dict[END_MARK]
        unk_id = src_dict[UNK_MARK]

        # English text is column 0 of the parallel corpus, German column 1
        src_col = 0 if src_lang == "en" else 1
        trg_col = 1 - src_col

        with tarfile.open(tar_file, mode="r") as f:
            for line in f.extractfile(file_name):
                line_split = line.strip().split("\t")
                # skip malformed lines without exactly two columns
                if len(line_split) != 2:
                    continue
                src_words = line_split[src_col].split()
                # out-of-vocabulary words map to the <unk> index
                src_ids = [start_id] + [
                    src_dict.get(w, unk_id) for w in src_words
                ] + [end_id]

                trg_words = line_split[trg_col].split()
                trg_ids = [trg_dict.get(w, unk_id) for w in trg_words]

                # next-word sequence is the target shifted by one position
                trg_ids_next = trg_ids + [end_id]
                trg_ids = [start_id] + trg_ids

                yield src_ids, trg_ids, trg_ids_next

    return reader
def train(src_dict_size, trg_dict_size, src_lang="en"):
    """
    WMT16 train set reader.

    This function returns the reader for train data. Each sample the reader
    returns is made up of three fields: the source language word index
    sequence, target language word index sequence and next word index
    sequence.

    NOTE:
    The original link for training data is:
    http://www.quest.dcs.shef.ac.uk/wmt16_files_mmt/training.tar.gz

    paddle.dataset.wmt16 provides a tokenized version of the original
    dataset by using moses's tokenization script:
    https://github.com/moses-smt/mosesdecoder/blob/master/scripts/tokenizer/tokenizer.perl

    Args:
        src_dict_size(int): Size of the source language dictionary. Three
                            special tokens will be added into the dictionary:
                            <s> for start mark, <e> for end mark, and <unk>
                            for unknown word.
        trg_dict_size(int): Size of the target language dictionary. Three
                            special tokens will be added into the dictionary:
                            <s> for start mark, <e> for end mark, and <unk>
                            for unknown word.
        src_lang(string): A string indicating which language is the source
                          language. Available options are: "en" for English
                          and "de" for Germany.

    Returns:
        callable: The train reader.
    """
    # BUG FIX: the original wrote assert(condition, message) — asserting a
    # two-element tuple, which is always truthy, so the check never fired.
    assert src_lang in ["en", "de"], ("An error language type. Only support: "
                                      "en (for English); de(for Germany)")
    src_dict_size, trg_dict_size = __get_dict_size__(src_dict_size,
                                                     trg_dict_size, src_lang)

    return reader_creator(
        tar_file=paddle.v2.dataset.common.download(DATA_URL, "wmt16", DATA_MD5,
                                                   "wmt16.tar.gz"),
        file_name="wmt16/train",
        src_dict_size=src_dict_size,
        trg_dict_size=trg_dict_size,
        src_lang=src_lang)
def test(src_dict_size, trg_dict_size, src_lang="en"):
    """
    WMT16 test set reader.

    This function returns the reader for test data. Each sample the reader
    returns is made up of three fields: the source language word index
    sequence, target language word index sequence and next word index
    sequence.

    NOTE:
    The original link for test data is:
    http://www.quest.dcs.shef.ac.uk/wmt16_files_mmt/mmt16_task1_test.tar.gz

    paddle.dataset.wmt16 provides a tokenized version of the original
    dataset by using moses's tokenization script:
    https://github.com/moses-smt/mosesdecoder/blob/master/scripts/tokenizer/tokenizer.perl

    Args:
        src_dict_size(int): Size of the source language dictionary. Three
                            special tokens will be added into the dictionary:
                            <s> for start mark, <e> for end mark, and <unk>
                            for unknown word.
        trg_dict_size(int): Size of the target language dictionary. Three
                            special tokens will be added into the dictionary:
                            <s> for start mark, <e> for end mark, and <unk>
                            for unknown word.
        src_lang(string): A string indicating which language is the source
                          language. Available options are: "en" for English
                          and "de" for Germany.

    Returns:
        callable: The test reader.
    """
    # BUG FIX: the original wrote assert(condition, message) — asserting a
    # two-element tuple, which is always truthy, so the check never fired.
    assert src_lang in ["en", "de"], (
        "An error language type. "
        "Only support: en (for English); de(for Germany)")
    src_dict_size, trg_dict_size = __get_dict_size__(src_dict_size,
                                                     trg_dict_size, src_lang)

    return reader_creator(
        tar_file=paddle.v2.dataset.common.download(DATA_URL, "wmt16", DATA_MD5,
                                                   "wmt16.tar.gz"),
        file_name="wmt16/test",
        src_dict_size=src_dict_size,
        trg_dict_size=trg_dict_size,
        src_lang=src_lang)
def validation(src_dict_size, trg_dict_size, src_lang="en"):
    """
    WMT16 validation set reader.

    This function returns the reader for validation data. Each sample the
    reader returns is made up of three fields: the source language word
    index sequence, target language word index sequence and next word index
    sequence.

    NOTE:
    The original link for validation data is:
    http://www.quest.dcs.shef.ac.uk/wmt16_files_mmt/validation.tar.gz

    paddle.dataset.wmt16 provides a tokenized version of the original
    dataset by using moses's tokenization script:
    https://github.com/moses-smt/mosesdecoder/blob/master/scripts/tokenizer/tokenizer.perl

    Args:
        src_dict_size(int): Size of the source language dictionary. Three
                            special tokens will be added into the dictionary:
                            <s> for start mark, <e> for end mark, and <unk>
                            for unknown word.
        trg_dict_size(int): Size of the target language dictionary. Three
                            special tokens will be added into the dictionary:
                            <s> for start mark, <e> for end mark, and <unk>
                            for unknown word.
        src_lang(string): A string indicating which language is the source
                          language. Available options are: "en" for English
                          and "de" for Germany.

    Returns:
        callable: The validation reader.
    """
    # BUG FIX: the original wrote assert(condition, message) — asserting a
    # two-element tuple, which is always truthy, so the check never fired.
    assert src_lang in ["en", "de"], (
        "An error language type. "
        "Only support: en (for English); de(for Germany)")
    src_dict_size, trg_dict_size = __get_dict_size__(src_dict_size,
                                                     trg_dict_size, src_lang)

    return reader_creator(
        tar_file=paddle.v2.dataset.common.download(DATA_URL, "wmt16", DATA_MD5,
                                                   "wmt16.tar.gz"),
        file_name="wmt16/val",
        src_dict_size=src_dict_size,
        trg_dict_size=trg_dict_size,
        src_lang=src_lang)
def get_dict(lang, dict_size, reverse=False):
    """
    Return the word dictionary for the specified language.

    Args:
        lang(string): A string indicating which language is the source
                      language. Available options are: "en" for English
                      and "de" for Germany.
        dict_size(int): Size of the specified language dictionary.
        reverse(bool): If reverse is set to False, the returned python
                       dictionary will use word as key and use index as
                       value. If reverse is set to True, the returned python
                       dictionary will use index as key and word as value.

    Returns:
        dict: The word dictionary for the specific language.
    """
    # clip the requested size to the actual vocabulary size
    if lang == "en":
        dict_size = min(dict_size, TOTAL_EN_WORDS)
    else:
        dict_size = min(dict_size, TOTAL_DE_WORDS)

    dict_path = os.path.join(paddle.v2.dataset.common.DATA_HOME,
                             "wmt16/%s_%d.dict" % (lang, dict_size))
    # BUG FIX: the original wrote assert(condition, message) — asserting a
    # two-element tuple, which is always truthy, so a missing dictionary
    # was never reported here.
    assert os.path.exists(dict_path), (
        "Word dictionary does not exist. "
        "Please invoke paddle.dataset.wmt16.train/test/validation "
        "first to build the dictionary.")
    tar_file = os.path.join(paddle.v2.dataset.common.DATA_HOME, "wmt16.tar.gz")
    return __load_dict__(tar_file, dict_size, lang, reverse)
def fetch():
    """download the entire dataset.
    """
    # BUG FIX: the original called paddle.v4.dataset.common.download — the
    # module is paddle.v2, so fetch() always raised AttributeError.
    paddle.v2.dataset.common.download(DATA_URL, "wmt16", DATA_MD5,
                                      "wmt16.tar.gz")
def convert(path, src_dict_size, trg_dict_size, src_lang):
    """Converts dataset to recordio format.

    :param path: output directory for the recordio files.
    :param src_dict_size: source language dictionary size.
    :param trg_dict_size: target language dictionary size.
    :param src_lang: "en" or "de"; the target language is the other one.
    """
    # each split is written in shuffled batches of 1000 lines
    paddle.v2.dataset.common.convert(path,
                                     train(
                                         src_dict_size=src_dict_size,
                                         trg_dict_size=trg_dict_size,
                                         src_lang=src_lang), 1000,
                                     "wmt16_train")
    paddle.v2.dataset.common.convert(path,
                                     test(
                                         src_dict_size=src_dict_size,
                                         trg_dict_size=trg_dict_size,
                                         src_lang=src_lang), 1000,
                                     "wmt16_test")
    paddle.v2.dataset.common.convert(path,
                                     validation(
                                         src_dict_size=src_dict_size,
                                         trg_dict_size=trg_dict_size,
                                         src_lang=src_lang), 1000,
                                     "wmt16_validation")
python/paddle/v2/fluid/layers/control_flow.py
浏览文件 @
9a97c7f7
...
...
@@ -19,13 +19,32 @@ import contextlib
from
..registry
import
autodoc
__all__
=
[
'split_lod_tensor'
,
'merge_lod_tensor'
,
'BlockGuard'
,
'BlockGuardWithCompletion'
,
'StaticRNNMemoryLink'
,
'WhileGuard'
,
'While'
,
'lod_rank_table'
,
'max_sequence_len'
,
'topk'
,
'lod_tensor_to_array'
,
'array_to_lod_tensor'
,
'increment'
,
'array_write'
,
'create_array'
,
'less_than'
,
'array_read'
,
'shrink_memory'
,
'array_length'
,
'IfElse'
,
'DynamicRNN'
,
'ConditionalBlock'
,
'StaticRNN'
,
'reorder_lod_tensor_by_rank'
,
'ParallelDo'
,
'Print'
'split_lod_tensor'
,
'merge_lod_tensor'
,
'BlockGuard'
,
'BlockGuardWithCompletion'
,
'StaticRNNMemoryLink'
,
'WhileGuard'
,
'While'
,
'lod_rank_table'
,
'max_sequence_len'
,
'topk'
,
'lod_tensor_to_array'
,
'array_to_lod_tensor'
,
'increment'
,
'array_write'
,
'create_array'
,
'less_than'
,
'array_read'
,
'shrink_memory'
,
'array_length'
,
'IfElse'
,
'DynamicRNN'
,
'ConditionalBlock'
,
'StaticRNN'
,
'reorder_lod_tensor_by_rank'
,
'ParallelDo'
,
'Print'
,
]
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录