Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
cd161926
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
cd161926
编写于
6月 26, 2017
作者:
Y
Yi Wang
浏览文件
操作
浏览文件
下载
差异文件
Merge branch 'develop' of
https://github.com/paddlepaddle/paddle
into memory_cpu_allocator
上级
f149d183
bf57345e
变更
13
隐藏空白更改
内联
并排
Showing
13 changed file
with
168 addition
and
60 deletion
+168
-60
paddle/gserver/gradientmachines/MultiGradientMachine.cpp
paddle/gserver/gradientmachines/MultiGradientMachine.cpp
+15
-7
paddle/gserver/gradientmachines/MultiGradientMachine.h
paddle/gserver/gradientmachines/MultiGradientMachine.h
+4
-2
python/paddle/v2/dataset/cifar.py
python/paddle/v2/dataset/cifar.py
+23
-8
python/paddle/v2/dataset/common.py
python/paddle/v2/dataset/common.py
+4
-1
python/paddle/v2/dataset/conll05.py
python/paddle/v2/dataset/conll05.py
+26
-12
python/paddle/v2/dataset/imdb.py
python/paddle/v2/dataset/imdb.py
+10
-1
python/paddle/v2/dataset/imikolov.py
python/paddle/v2/dataset/imikolov.py
+13
-1
python/paddle/v2/dataset/mnist.py
python/paddle/v2/dataset/mnist.py
+9
-1
python/paddle/v2/dataset/movielens.py
python/paddle/v2/dataset/movielens.py
+13
-4
python/paddle/v2/dataset/sentiment.py
python/paddle/v2/dataset/sentiment.py
+16
-6
python/paddle/v2/dataset/uci_housing.py
python/paddle/v2/dataset/uci_housing.py
+14
-6
python/paddle/v2/dataset/wmt14.py
python/paddle/v2/dataset/wmt14.py
+21
-9
python/paddle/v2/reader/tests/creator_test.py
python/paddle/v2/reader/tests/creator_test.py
+0
-2
未找到文件。
paddle/gserver/gradientmachines/MultiGradientMachine.cpp
浏览文件 @
cd161926
...
...
@@ -166,11 +166,21 @@ MultiGradientMachine::MultiGradientMachine(const ModelConfig& config,
outArgStream_
=
HPPL_STREAM_1
;
start
();
}
void
MultiGradientMachine
::
start
()
{
for
(
auto
&
thread
:
threads_
)
{
thread
->
start
();
}
}
void
MultiGradientMachine
::
finish
()
{
for
(
auto
&
thread
:
threads_
)
{
thread
->
stop
();
}
}
std
::
vector
<
const
std
::
vector
<
ParameterPtr
>*>
MultiGradientMachine
::
getSlaveParameters
()
{
std
::
vector
<
const
std
::
vector
<
ParameterPtr
>*>
vec
;
...
...
@@ -326,12 +336,6 @@ void MultiGradientMachine::onPassEnd() {
}
}
void
MultiGradientMachine
::
finish
()
{
for
(
auto
&
thread
:
threads_
)
{
thread
->
stop
();
}
}
Evaluator
*
MultiGradientMachine
::
makeEvaluator
()
const
{
return
threads_
[
0
]
->
getGradientMachine
()
->
makeEvaluator
();
}
...
...
@@ -445,7 +449,7 @@ TrainerThread::TrainerThread(const ModelConfig& config,
gradStream_
=
HPPL_STREAM_2
;
valueStream_
=
HPPL_STREAM_3
;
stopping_
=
fals
e
;
stopping_
=
tru
e
;
updateCounter_
=
0
;
parameterUpdated_
=
false
;
}
...
...
@@ -453,6 +457,10 @@ TrainerThread::TrainerThread(const ModelConfig& config,
TrainerThread
::~
TrainerThread
()
{
stop
();
}
void
TrainerThread
::
start
()
{
if
(
!
stopping_
)
return
;
stopping_
=
false
;
gradientMachine_
->
start
();
computeThread_
.
reset
(
new
std
::
thread
([
this
]()
{
computeThread
();
}));
...
...
paddle/gserver/gradientmachines/MultiGradientMachine.h
浏览文件 @
cd161926
...
...
@@ -176,6 +176,10 @@ public:
explicit
MultiGradientMachine
(
const
ModelConfig
&
config
,
bool
useGpu
);
virtual
void
start
();
virtual
void
finish
();
virtual
void
prefetch
(
const
std
::
vector
<
Argument
>&
inArgs
);
virtual
void
forward
(
const
std
::
vector
<
Argument
>&
inArgs
,
...
...
@@ -193,8 +197,6 @@ public:
virtual
void
onPassEnd
();
virtual
void
finish
();
virtual
Evaluator
*
makeEvaluator
()
const
;
virtual
void
eval
(
Evaluator
*
evaluator
)
const
;
...
...
python/paddle/v2/dataset/cifar.py
浏览文件 @
cd161926
...
...
@@ -31,10 +31,10 @@ images per class.
import
cPickle
import
itertools
import
numpy
from
common
import
download
import
paddle.v2.dataset.common
import
tarfile
__all__
=
[
'train100'
,
'test100'
,
'train10'
,
'test10'
]
__all__
=
[
'train100'
,
'test100'
,
'train10'
,
'test10'
,
'convert'
]
URL_PREFIX
=
'https://www.cs.toronto.edu/~kriz/'
CIFAR10_URL
=
URL_PREFIX
+
'cifar-10-python.tar.gz'
...
...
@@ -75,7 +75,8 @@ def train100():
:rtype: callable
"""
return
reader_creator
(
download
(
CIFAR100_URL
,
'cifar'
,
CIFAR100_MD5
),
'train'
)
paddle
.
v2
.
dataset
.
common
.
download
(
CIFAR100_URL
,
'cifar'
,
CIFAR100_MD5
),
'train'
)
def
test100
():
...
...
@@ -88,7 +89,9 @@ def test100():
:return: Test reader creator.
:rtype: callable
"""
return
reader_creator
(
download
(
CIFAR100_URL
,
'cifar'
,
CIFAR100_MD5
),
'test'
)
return
reader_creator
(
paddle
.
v2
.
dataset
.
common
.
download
(
CIFAR100_URL
,
'cifar'
,
CIFAR100_MD5
),
'test'
)
def
train10
():
...
...
@@ -102,7 +105,8 @@ def train10():
:rtype: callable
"""
return
reader_creator
(
download
(
CIFAR10_URL
,
'cifar'
,
CIFAR10_MD5
),
'data_batch'
)
paddle
.
v2
.
dataset
.
common
.
download
(
CIFAR10_URL
,
'cifar'
,
CIFAR10_MD5
),
'data_batch'
)
def
test10
():
...
...
@@ -116,9 +120,20 @@ def test10():
:rtype: callable
"""
return
reader_creator
(
download
(
CIFAR10_URL
,
'cifar'
,
CIFAR10_MD5
),
'test_batch'
)
paddle
.
v2
.
dataset
.
common
.
download
(
CIFAR10_URL
,
'cifar'
,
CIFAR10_MD5
),
'test_batch'
)
def
fetch
():
download
(
CIFAR10_URL
,
'cifar'
,
CIFAR10_MD5
)
download
(
CIFAR100_URL
,
'cifar'
,
CIFAR100_MD5
)
paddle
.
v2
.
dataset
.
common
.
download
(
CIFAR10_URL
,
'cifar'
,
CIFAR10_MD5
)
paddle
.
v2
.
dataset
.
common
.
download
(
CIFAR100_URL
,
'cifar'
,
CIFAR100_MD5
)
def
convert
(
path
):
"""
Converts dataset to recordio format
"""
paddle
.
v2
.
dataset
.
common
.
convert
(
path
,
train100
(),
10
,
"cifar_train100"
)
paddle
.
v2
.
dataset
.
common
.
convert
(
path
,
test100
(),
10
,
"cifar_test100"
)
paddle
.
v2
.
dataset
.
common
.
convert
(
path
,
train10
(),
10
,
"cifar_train10"
)
paddle
.
v2
.
dataset
.
common
.
convert
(
path
,
test10
(),
10
,
"cifar_test10"
)
python/paddle/v2/dataset/common.py
浏览文件 @
cd161926
...
...
@@ -23,7 +23,10 @@ import paddle.v2.dataset
import
cPickle
import
glob
__all__
=
[
'DATA_HOME'
,
'download'
,
'md5file'
,
'split'
,
'cluster_files_reader'
]
__all__
=
[
'DATA_HOME'
,
'download'
,
'md5file'
,
'split'
,
'cluster_files_reader'
,
'convert'
]
DATA_HOME
=
os
.
path
.
expanduser
(
'~/.cache/paddle/dataset'
)
...
...
python/paddle/v2/dataset/conll05.py
浏览文件 @
cd161926
...
...
@@ -23,9 +23,9 @@ to initialize SRL model.
import
tarfile
import
gzip
import
itertools
from
common
import
download
import
paddle.v2.dataset.common
__all__
=
[
'test, get_dict'
,
'get_embedding'
]
__all__
=
[
'test, get_dict'
,
'get_embedding'
,
'convert'
]
DATA_URL
=
'http://www.cs.upc.edu/~srlconll/conll05st-tests.tar.gz'
DATA_MD5
=
'387719152ae52d60422c016e92a742fc'
...
...
@@ -182,9 +182,15 @@ def get_dict():
"""
Get the word, verb and label dictionary of Wikipedia corpus.
"""
word_dict
=
load_dict
(
download
(
WORDDICT_URL
,
'conll05st'
,
WORDDICT_MD5
))
verb_dict
=
load_dict
(
download
(
VERBDICT_URL
,
'conll05st'
,
VERBDICT_MD5
))
label_dict
=
load_dict
(
download
(
TRGDICT_URL
,
'conll05st'
,
TRGDICT_MD5
))
word_dict
=
load_dict
(
paddle
.
v2
.
dataset
.
common
.
download
(
WORDDICT_URL
,
'conll05st'
,
WORDDICT_MD5
))
verb_dict
=
load_dict
(
paddle
.
v2
.
dataset
.
common
.
download
(
VERBDICT_URL
,
'conll05st'
,
VERBDICT_MD5
))
label_dict
=
load_dict
(
paddle
.
v2
.
dataset
.
common
.
download
(
TRGDICT_URL
,
'conll05st'
,
TRGDICT_MD5
))
return
word_dict
,
verb_dict
,
label_dict
...
...
@@ -192,7 +198,7 @@ def get_embedding():
"""
Get the trained word vector based on Wikipedia corpus.
"""
return
download
(
EMB_URL
,
'conll05st'
,
EMB_MD5
)
return
paddle
.
v2
.
dataset
.
common
.
download
(
EMB_URL
,
'conll05st'
,
EMB_MD5
)
def
test
():
...
...
@@ -209,15 +215,23 @@ def test():
"""
word_dict
,
verb_dict
,
label_dict
=
get_dict
()
reader
=
corpus_reader
(
download
(
DATA_URL
,
'conll05st'
,
DATA_MD5
),
paddle
.
v2
.
dataset
.
common
.
download
(
DATA_URL
,
'conll05st'
,
DATA_MD5
),
words_name
=
'conll05st-release/test.wsj/words/test.wsj.words.gz'
,
props_name
=
'conll05st-release/test.wsj/props/test.wsj.props.gz'
)
return
reader_creator
(
reader
,
word_dict
,
verb_dict
,
label_dict
)
def
fetch
():
download
(
WORDDICT_URL
,
'conll05st'
,
WORDDICT_MD5
)
download
(
VERBDICT_URL
,
'conll05st'
,
VERBDICT_MD5
)
download
(
TRGDICT_URL
,
'conll05st'
,
TRGDICT_MD5
)
download
(
EMB_URL
,
'conll05st'
,
EMB_MD5
)
download
(
DATA_URL
,
'conll05st'
,
DATA_MD5
)
paddle
.
v2
.
dataset
.
common
.
download
(
WORDDICT_URL
,
'conll05st'
,
WORDDICT_MD5
)
paddle
.
v2
.
dataset
.
common
.
download
(
VERBDICT_URL
,
'conll05st'
,
VERBDICT_MD5
)
paddle
.
v2
.
dataset
.
common
.
download
(
TRGDICT_URL
,
'conll05st'
,
TRGDICT_MD5
)
paddle
.
v2
.
dataset
.
common
.
download
(
EMB_URL
,
'conll05st'
,
EMB_MD5
)
paddle
.
v2
.
dataset
.
common
.
download
(
DATA_URL
,
'conll05st'
,
DATA_MD5
)
def
convert
(
path
):
"""
Converts dataset to recordio format
"""
paddle
.
v2
.
dataset
.
common
.
convert
(
path
,
test
(),
10
,
"conl105_train"
)
paddle
.
v2
.
dataset
.
common
.
convert
(
path
,
test
(),
10
,
"conl105_test"
)
python/paddle/v2/dataset/imdb.py
浏览文件 @
cd161926
...
...
@@ -28,7 +28,7 @@ import re
import
string
import
threading
__all__
=
[
'build_dict'
,
'train'
,
'test'
]
__all__
=
[
'build_dict'
,
'train'
,
'test'
,
'convert'
]
URL
=
'http://ai.stanford.edu/%7Eamaas/data/sentiment/aclImdb_v1.tar.gz'
MD5
=
'7c2ac02c03563afcf9b574c7e56c153a'
...
...
@@ -166,3 +166,12 @@ def word_dict():
def
fetch
():
paddle
.
v2
.
dataset
.
common
.
download
(
URL
,
'imdb'
,
MD5
)
def
convert
(
path
):
"""
Converts dataset to recordio format
"""
w
=
word_dict
()
paddle
.
v2
.
dataset
.
common
.
convert
(
path
,
lambda
:
train
(
w
),
10
,
"imdb_train"
)
paddle
.
v2
.
dataset
.
common
.
convert
(
path
,
lambda
:
test
(
w
),
10
,
"imdb_test"
)
python/paddle/v2/dataset/imikolov.py
浏览文件 @
cd161926
...
...
@@ -22,7 +22,7 @@ import paddle.v2.dataset.common
import
collections
import
tarfile
__all__
=
[
'train'
,
'test'
,
'build_dict'
]
__all__
=
[
'train'
,
'test'
,
'build_dict'
,
'convert'
]
URL
=
'http://www.fit.vutbr.cz/~imikolov/rnnlm/simple-examples.tgz'
MD5
=
'30177ea32e27c525793142b6bf2c8e2d'
...
...
@@ -146,3 +146,15 @@ def test(word_idx, n, data_type=DataType.NGRAM):
def
fetch
():
paddle
.
v2
.
dataset
.
common
.
download
(
URL
,
"imikolov"
,
MD5
)
def
convert
(
path
):
"""
Converts dataset to recordio format
"""
N
=
5
word_dict
=
build_dict
()
paddle
.
v2
.
dataset
.
common
.
convert
(
path
,
train
(
word_dict
,
N
),
10
,
"imikolov_train"
)
paddle
.
v2
.
dataset
.
common
.
convert
(
path
,
test
(
word_dict
,
N
),
10
,
"imikolov_test"
)
python/paddle/v2/dataset/mnist.py
浏览文件 @
cd161926
...
...
@@ -21,7 +21,7 @@ import paddle.v2.dataset.common
import
subprocess
import
numpy
import
platform
__all__
=
[
'train'
,
'test'
]
__all__
=
[
'train'
,
'test'
,
'convert'
]
URL_PREFIX
=
'http://yann.lecun.com/exdb/mnist/'
TEST_IMAGE_URL
=
URL_PREFIX
+
't10k-images-idx3-ubyte.gz'
...
...
@@ -113,3 +113,11 @@ def fetch():
paddle
.
v2
.
dataset
.
common
.
download
(
TRAIN_LABEL_URL
,
'mnist'
,
TRAIN_LABEL_MD5
)
paddle
.
v2
.
dataset
.
common
.
download
(
TEST_IMAGE_URL
,
'mnist'
,
TEST_IMAGE_MD5
)
paddle
.
v2
.
dataset
.
common
.
download
(
TEST_LABEL_URL
,
'mnist'
,
TRAIN_LABEL_MD5
)
def
convert
(
path
):
"""
Converts dataset to recordio format
"""
paddle
.
v2
.
dataset
.
common
.
convert
(
path
,
train
(),
10
,
"minist_train"
)
paddle
.
v2
.
dataset
.
common
.
convert
(
path
,
test
(),
10
,
"minist_test"
)
python/paddle/v2/dataset/movielens.py
浏览文件 @
cd161926
...
...
@@ -23,14 +23,15 @@ set and test set into paddle reader creators.
"""
import
zipfile
from
common
import
download
import
paddle.v2.dataset.common
import
re
import
random
import
functools
__all__
=
[
'train'
,
'test'
,
'get_movie_title_dict'
,
'max_movie_id'
,
'max_user_id'
,
'age_table'
,
'movie_categories'
,
'max_job_id'
,
'user_info'
,
'movie_info'
'age_table'
,
'movie_categories'
,
'max_job_id'
,
'user_info'
,
'movie_info'
,
'convert'
]
age_table
=
[
1
,
18
,
25
,
35
,
45
,
50
,
56
]
...
...
@@ -99,7 +100,7 @@ USER_INFO = None
def
__initialize_meta_info__
():
fn
=
download
(
URL
,
"movielens"
,
MD5
)
fn
=
paddle
.
v2
.
dataset
.
common
.
download
(
URL
,
"movielens"
,
MD5
)
global
MOVIE_INFO
if
MOVIE_INFO
is
None
:
pattern
=
re
.
compile
(
r
'^(.*)\((\d+)\)$'
)
...
...
@@ -246,7 +247,15 @@ def unittest():
def
fetch
():
download
(
URL
,
"movielens"
,
MD5
)
paddle
.
v2
.
dataset
.
common
.
download
(
URL
,
"movielens"
,
MD5
)
def
convert
(
path
):
"""
Converts dataset to recordio format
"""
paddle
.
v2
.
dataset
.
common
.
convert
(
path
,
train
(),
10
,
"movielens_train"
)
paddle
.
v2
.
dataset
.
common
.
convert
(
path
,
test
(),
10
,
"movielens_test"
)
if
__name__
==
'__main__'
:
...
...
python/paddle/v2/dataset/sentiment.py
浏览文件 @
cd161926
...
...
@@ -26,9 +26,9 @@ from itertools import chain
import
nltk
from
nltk.corpus
import
movie_reviews
import
common
import
paddle.v2.dataset.
common
__all__
=
[
'train'
,
'test'
,
'get_word_dict'
]
__all__
=
[
'train'
,
'test'
,
'get_word_dict'
,
'convert'
]
NUM_TRAINING_INSTANCES
=
1600
NUM_TOTAL_INSTANCES
=
2000
...
...
@@ -39,12 +39,13 @@ def download_data_if_not_yet():
"""
try
:
# make sure that nltk can find the data
if
common
.
DATA_HOME
not
in
nltk
.
data
.
path
:
nltk
.
data
.
path
.
append
(
common
.
DATA_HOME
)
if
paddle
.
v2
.
dataset
.
common
.
DATA_HOME
not
in
nltk
.
data
.
path
:
nltk
.
data
.
path
.
append
(
paddle
.
v2
.
dataset
.
common
.
DATA_HOME
)
movie_reviews
.
categories
()
except
LookupError
:
print
"Downloading movie_reviews data set, please wait....."
nltk
.
download
(
'movie_reviews'
,
download_dir
=
common
.
DATA_HOME
)
nltk
.
download
(
'movie_reviews'
,
download_dir
=
paddle
.
v2
.
dataset
.
common
.
DATA_HOME
)
print
"Download data set success....."
print
"Path is "
+
nltk
.
data
.
find
(
'corpora/movie_reviews'
).
path
...
...
@@ -128,4 +129,13 @@ def test():
def
fetch
():
nltk
.
download
(
'movie_reviews'
,
download_dir
=
common
.
DATA_HOME
)
nltk
.
download
(
'movie_reviews'
,
download_dir
=
paddle
.
v2
.
dataset
.
common
.
DATA_HOME
)
def
convert
(
path
):
"""
Converts dataset to recordio format
"""
paddle
.
v2
.
dataset
.
common
.
convert
(
path
,
train
,
10
,
"sentiment_train"
)
paddle
.
v2
.
dataset
.
common
.
convert
(
path
,
test
,
10
,
"sentiment_test"
)
python/paddle/v2/dataset/uci_housing.py
浏览文件 @
cd161926
...
...
@@ -14,14 +14,14 @@
"""
UCI Housing dataset.
This module will download dataset from
This module will
paddle.v2.dataset.common.
download dataset from
https://archive.ics.uci.edu/ml/machine-learning-databases/housing/ and
parse training set and test set into paddle reader creators.
"""
import
numpy
as
np
import
os
from
common
import
download
import
paddle.v2.dataset.common
__all__
=
[
'train'
,
'test'
]
...
...
@@ -29,7 +29,7 @@ URL = 'https://archive.ics.uci.edu/ml/machine-learning-databases/housing/housing
MD5
=
'd4accdce7a25600298819f8e28e8d593'
feature_names
=
[
'CRIM'
,
'ZN'
,
'INDUS'
,
'CHAS'
,
'NOX'
,
'RM'
,
'AGE'
,
'DIS'
,
'RAD'
,
'TAX'
,
'PTRATIO'
,
'B'
,
'LSTAT'
'PTRATIO'
,
'B'
,
'LSTAT'
,
'convert'
]
UCI_TRAIN_DATA
=
None
...
...
@@ -82,7 +82,7 @@ def train():
:rtype: callable
"""
global
UCI_TRAIN_DATA
load_data
(
download
(
URL
,
'uci_housing'
,
MD5
))
load_data
(
paddle
.
v2
.
dataset
.
common
.
download
(
URL
,
'uci_housing'
,
MD5
))
def
reader
():
for
d
in
UCI_TRAIN_DATA
:
...
...
@@ -102,7 +102,7 @@ def test():
:rtype: callable
"""
global
UCI_TEST_DATA
load_data
(
download
(
URL
,
'uci_housing'
,
MD5
))
load_data
(
paddle
.
v2
.
dataset
.
common
.
download
(
URL
,
'uci_housing'
,
MD5
))
def
reader
():
for
d
in
UCI_TEST_DATA
:
...
...
@@ -112,4 +112,12 @@ def test():
def
fetch
():
download
(
URL
,
'uci_housing'
,
MD5
)
paddle
.
v2
.
dataset
.
common
.
download
(
URL
,
'uci_housing'
,
MD5
)
def
convert
(
path
):
"""
Converts dataset to recordio format
"""
paddle
.
v2
.
dataset
.
common
.
convert
(
path
,
train
(),
10
,
"uci_housing_train"
)
paddle
.
v2
.
dataset
.
common
.
convert
(
path
,
test
(),
10
,
"uci_houseing_test"
)
python/paddle/v2/dataset/wmt14.py
浏览文件 @
cd161926
...
...
@@ -22,10 +22,10 @@ parse training set and test set into paddle reader creators.
import
tarfile
import
gzip
from
paddle.v2.dataset.common
import
download
import
paddle.v2.dataset.common
from
paddle.v2.parameters
import
Parameters
__all__
=
[
'train'
,
'test'
,
'build_dict'
]
__all__
=
[
'train'
,
'test'
,
'build_dict'
,
'convert'
]
URL_DEV_TEST
=
'http://www-lium.univ-lemans.fr/~schwenk/cslm_joint_paper/data/dev+test.tgz'
MD5_DEV_TEST
=
'7d7897317ddd8ba0ae5c5fa7248d3ff5'
...
...
@@ -115,7 +115,8 @@ def train(dict_size):
:rtype: callable
"""
return
reader_creator
(
download
(
URL_TRAIN
,
'wmt14'
,
MD5_TRAIN
),
'train/train'
,
dict_size
)
paddle
.
v2
.
dataset
.
common
.
download
(
URL_TRAIN
,
'wmt14'
,
MD5_TRAIN
),
'train/train'
,
dict_size
)
def
test
(
dict_size
):
...
...
@@ -130,16 +131,18 @@ def test(dict_size):
:rtype: callable
"""
return
reader_creator
(
download
(
URL_TRAIN
,
'wmt14'
,
MD5_TRAIN
),
'test/test'
,
dict_size
)
paddle
.
v2
.
dataset
.
common
.
download
(
URL_TRAIN
,
'wmt14'
,
MD5_TRAIN
),
'test/test'
,
dict_size
)
def
gen
(
dict_size
):
return
reader_creator
(
download
(
URL_TRAIN
,
'wmt14'
,
MD5_TRAIN
),
'gen/gen'
,
dict_size
)
paddle
.
v2
.
dataset
.
common
.
download
(
URL_TRAIN
,
'wmt14'
,
MD5_TRAIN
),
'gen/gen'
,
dict_size
)
def
model
():
tar_file
=
download
(
URL_MODEL
,
'wmt14'
,
MD5_MODEL
)
tar_file
=
paddle
.
v2
.
dataset
.
common
.
download
(
URL_MODEL
,
'wmt14'
,
MD5_MODEL
)
with
gzip
.
open
(
tar_file
,
'r'
)
as
f
:
parameters
=
Parameters
.
from_tar
(
f
)
return
parameters
...
...
@@ -148,7 +151,7 @@ def model():
def
get_dict
(
dict_size
,
reverse
=
True
):
# if reverse = False, return dict = {'a':'001', 'b':'002', ...}
# else reverse = true, return dict = {'001':'a', '002':'b', ...}
tar_file
=
download
(
URL_TRAIN
,
'wmt14'
,
MD5_TRAIN
)
tar_file
=
paddle
.
v2
.
dataset
.
common
.
download
(
URL_TRAIN
,
'wmt14'
,
MD5_TRAIN
)
src_dict
,
trg_dict
=
__read_to_dict__
(
tar_file
,
dict_size
)
if
reverse
:
src_dict
=
{
v
:
k
for
k
,
v
in
src_dict
.
items
()}
...
...
@@ -157,5 +160,14 @@ def get_dict(dict_size, reverse=True):
def
fetch
():
download
(
URL_TRAIN
,
'wmt14'
,
MD5_TRAIN
)
download
(
URL_MODEL
,
'wmt14'
,
MD5_MODEL
)
paddle
.
v2
.
dataset
.
common
.
download
(
URL_TRAIN
,
'wmt14'
,
MD5_TRAIN
)
paddle
.
v2
.
dataset
.
common
.
download
(
URL_MODEL
,
'wmt14'
,
MD5_MODEL
)
def
convert
(
path
):
"""
Converts dataset to recordio format
"""
dict_size
=
30000
paddle
.
v2
.
dataset
.
common
.
convert
(
path
,
train
(
dict_size
),
10
,
"wmt14_train"
)
paddle
.
v2
.
dataset
.
common
.
convert
(
path
,
test
(
dict_size
),
10
,
"wmt14_test"
)
python/paddle/v2/reader/tests/creator_test.py
浏览文件 @
cd161926
...
...
@@ -13,9 +13,7 @@
# limitations under the License.
import
os
import
unittest
import
numpy
as
np
import
paddle.v2.reader.creator
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录