Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
8bef3f4d
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
8bef3f4d
编写于
3月 04, 2017
作者:
Y
Yu Yang
提交者:
GitHub
3月 04, 2017
浏览文件
操作
浏览文件
下载
差异文件
Merge pull request #1553 from helinwang/wmt14
dataset reader for wmt14
上级
af607df2
d6c6a996
变更
2
隐藏空白更改
内联
并排
Showing
2 changed file
with
144 addition
and
1 deletion
+144
-1
python/paddle/v2/dataset/__init__.py
python/paddle/v2/dataset/__init__.py
+2
-1
python/paddle/v2/dataset/wmt14.py
python/paddle/v2/dataset/wmt14.py
+142
-0
未找到文件。
python/paddle/v2/dataset/__init__.py
浏览文件 @
8bef3f4d
...
...
@@ -20,8 +20,9 @@ import movielens
import
conll05
import
uci_housing
import
sentiment
import
wmt14
__all__
=
[
'mnist'
,
'imikolov'
,
'imdb'
,
'cifar'
,
'movielens'
,
'conll05'
,
'sentiment'
'uci_housing'
'uci_housing'
,
'wmt14'
]
python/paddle/v2/dataset/wmt14.py
0 → 100644
浏览文件 @
8bef3f4d
# Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
wmt14 dataset
"""
import
paddle.v2.dataset.common
import
tarfile
import
os.path
import
itertools
__all__
=
[
'train'
,
'test'
,
'build_dict'
]
URL_DEV_TEST
=
'http://www-lium.univ-lemans.fr/~schwenk/cslm_joint_paper/data/dev+test.tgz'
MD5_DEV_TEST
=
'7d7897317ddd8ba0ae5c5fa7248d3ff5'
URL_TRAIN
=
'http://localhost:8000/train.tgz'
MD5_TRAIN
=
'72de99da2830ea5a3a2c4eb36092bbc7'
def
word_count
(
f
,
word_freq
=
None
):
add
=
paddle
.
v2
.
dataset
.
common
.
dict_add
if
word_freq
==
None
:
word_freq
=
{}
for
l
in
f
:
for
w
in
l
.
strip
().
split
():
add
(
word_freq
,
w
)
add
(
word_freq
,
'<s>'
)
add
(
word_freq
,
'<e>'
)
return
word_freq
def
get_word_dix
(
word_freq
):
TYPO_FREQ
=
50
word_freq
=
filter
(
lambda
x
:
x
[
1
]
>
TYPO_FREQ
,
word_freq
.
items
())
word_freq_sorted
=
sorted
(
word_freq
,
key
=
lambda
x
:
(
-
x
[
1
],
x
[
0
]))
words
,
_
=
list
(
zip
(
*
word_freq_sorted
))
word_idx
=
dict
(
zip
(
words
,
xrange
(
len
(
words
))))
word_idx
[
'<unk>'
]
=
len
(
words
)
return
word_idx
def
get_word_freq
(
train
,
dev
):
word_freq
=
word_count
(
train
,
word_count
(
dev
))
if
'<unk>'
in
word_freq
:
# remove <unk> for now, since we will set it as last index
del
word_freq
[
'<unk>'
]
return
word_freq
def
build_dict
():
base_dir
=
'./wmt14-data'
train_en_filename
=
base_dir
+
'/train/train.en'
train_fr_filename
=
base_dir
+
'/train/train.fr'
dev_en_filename
=
base_dir
+
'/dev/ntst1213.en'
dev_fr_filename
=
base_dir
+
'/dev/ntst1213.fr'
if
not
os
.
path
.
exists
(
train_en_filename
)
or
not
os
.
path
.
exists
(
train_fr_filename
):
with
tarfile
.
open
(
paddle
.
v2
.
dataset
.
common
.
download
(
URL_TRAIN
,
'wmt14'
,
MD5_TRAIN
))
as
tf
:
tf
.
extractall
(
base_dir
)
if
not
os
.
path
.
exists
(
dev_en_filename
)
or
not
os
.
path
.
exists
(
dev_fr_filename
):
with
tarfile
.
open
(
paddle
.
v2
.
dataset
.
common
.
download
(
URL_DEV_TEST
,
'wmt14'
,
MD5_DEV_TEST
))
as
tf
:
tf
.
extractall
(
base_dir
)
f_en
=
open
(
train_en_filename
)
f_fr
=
open
(
train_fr_filename
)
f_en_dev
=
open
(
dev_en_filename
)
f_fr_dev
=
open
(
dev_fr_filename
)
word_freq_en
=
get_word_freq
(
f_en
,
f_en_dev
)
word_freq_fr
=
get_word_freq
(
f_fr
,
f_fr_dev
)
f_en
.
close
()
f_fr
.
close
()
f_en_dev
.
close
()
f_fr_dev
.
close
()
return
get_word_dix
(
word_freq_en
),
get_word_dix
(
word_freq_fr
)
def
reader_creator
(
directory
,
path_en
,
path_fr
,
URL
,
MD5
,
dict_en
,
dict_fr
):
def
reader
():
if
not
os
.
path
.
exists
(
path_en
)
or
not
os
.
path
.
exists
(
path_fr
):
with
tarfile
.
open
(
paddle
.
v2
.
dataset
.
common
.
download
(
URL
,
'wmt14'
,
MD5
))
as
tf
:
tf
.
extractall
(
directory
)
f_en
=
open
(
path_en
)
f_fr
=
open
(
path_fr
)
UNK_en
=
dict_en
[
'<unk>'
]
UNK_fr
=
dict_fr
[
'<unk>'
]
for
en
,
fr
in
itertools
.
izip
(
f_en
,
f_fr
):
src_ids
=
[
dict_en
.
get
(
w
,
UNK_en
)
for
w
in
en
.
strip
().
split
()]
tar_ids
=
[
dict_fr
.
get
(
w
,
UNK_fr
)
for
w
in
[
'<s>'
]
+
fr
.
strip
().
split
()
+
[
'<e>'
]
]
# remove sequence whose length > 80 in training mode
if
len
(
src_ids
)
==
0
or
len
(
tar_ids
)
<=
1
or
len
(
src_ids
)
>
80
or
len
(
tar_ids
)
>
80
:
continue
yield
src_ids
,
tar_ids
[:
-
1
],
tar_ids
[
1
:]
f_en
.
close
()
f_fr
.
close
()
return
reader
def
train
(
dict_en
,
dict_fr
):
directory
=
'./wmt14-data'
return
reader_creator
(
directory
,
directory
+
'/train/train.en'
,
directory
+
'/train/train.fr'
,
URL_TRAIN
,
MD5_TRAIN
,
dict_en
,
dict_fr
)
def
test
(
dict_en
,
dict_fr
):
directory
=
'./wmt14-data'
return
reader_creator
(
directory
,
directory
+
'/dev/ntst1213.en'
,
directory
+
'/dev/ntst1213.fr'
,
URL_DEV_TEST
,
MD5_DEV_TEST
,
dict_en
,
dict_fr
)
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录