提交 f3294541 编写于 作者: L liaogang

Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into cpu_mem

......@@ -4,6 +4,7 @@ cache:
directories:
- $HOME/.ccache
- $HOME/.cache/pip
- $TRAVIS_BUILD_DIR/build/third_party
sudo: required
dist: trusty
os:
......@@ -41,7 +42,9 @@ before_install:
- |
function timeout() { perl -e 'alarm shift; exec @ARGV' "$@"; }
script:
- paddle/scripts/travis/$JOB.sh
- |
timeout 2580 paddle/scripts/travis/${JOB}.sh # 43min timeout
RESULT=$?; if [ $RESULT -eq 0 ] || [ $RESULT -eq 142 ]; then true; else false; fi;
notifications:
email:
on_success: change
......
......@@ -18,6 +18,7 @@ func main() {
etcdEndpoint := flag.String("etcd-endpoint", "http://127.0.0.1:2379",
"comma separated endpoint string for pserver to connect to etcd")
etcdTimeout := flag.Int("etcd-timeout", 5, "timeout for etcd calls")
numPservers := flag.Int("num-pservers", 1, "total pserver count in a training job")
logLevel := flag.String("log-level", "info",
"log level, possible values: debug, info, warning, error, fatal, panic")
flag.Parse()
......@@ -29,7 +30,7 @@ func main() {
log.SetLevel(level)
timeout := time.Second * time.Duration((*etcdTimeout))
s, err := pserver.NewService(*etcdEndpoint, timeout)
s, err := pserver.NewService(*etcdEndpoint, *numPservers, timeout)
if err != nil {
panic(err)
}
......
......@@ -73,7 +73,7 @@ type Service struct {
// NewService creates a new service, will bypass etcd registration if no
// endpoints specified.
func NewService(endpoints string, timeout time.Duration) (*Service, error) {
func NewService(endpoints string, numPservers int, timeout time.Duration) (*Service, error) {
s := &Service{opt: newOptimizer(sgd, 0.005)}
s.paramMap = make(map[string]Parameter)
s.initialized = make(chan struct{})
......@@ -103,6 +103,22 @@ func NewService(endpoints string, timeout time.Duration) (*Service, error) {
log.Debugf("inited client to %s", s.etcdEndpoints)
break
}
// Initialize /ps_desired inside a transaction, because multiple pservers may
// try to write it at the same time.
for {
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
_, err := s.initDesiredPsercers(ctx, numPservers)
cancel()
if err != nil {
log.Warn(err)
time.Sleep(s.etcdTimeout)
continue
}
break
}
// TODO: when implementing extending or reducing pservers, /ps_desired is
// changed, then we need to watch /ps_desired node for events. For now, just
// write once when init and read from it.
// wait and set s.desired init value
for {
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
......@@ -141,6 +157,16 @@ func NewService(endpoints string, timeout time.Duration) (*Service, error) {
return s, nil
}
// initDesiredPsercers writes numPservers (as a decimal string) under the
// PsDesired key when that key currently reads back as the empty string, and
// leaves the key untouched otherwise. The read and the conditional write run
// inside one STM transaction that is aborted with ctx and uses
// RepeatableReads isolation. It returns whatever concurrency.NewSTM returns.
func (s *Service) initDesiredPsercers(ctx context.Context, numPservers int) (*clientv3.TxnResponse, error) {
	// seed is the transaction body; it never reports an error of its own.
	seed := func(stm concurrency.STM) error {
		if stm.Get(PsDesired) != "" {
			// A value is already present, so keep it.
			return nil
		}
		stm.Put(PsDesired, strconv.Itoa(numPservers))
		return nil
	}

	return concurrency.NewSTM(
		s.etcdClient,
		seed,
		concurrency.WithAbortContext(ctx),
		concurrency.WithIsolation(concurrency.RepeatableReads),
	)
}
// registerPserverEtcd registers pserver node on etcd using transaction.
func (s *Service) registerPserverEtcd(ctx context.Context) (*clientv3.TxnResponse, error) {
return concurrency.NewSTM(s.etcdClient, func(c concurrency.STM) error {
......
......@@ -2,3 +2,5 @@ cc_library(ddim SRCS ddim.cc)
cc_test(ddim_test SRCS ddim_test.cc DEPS ddim)
nv_test(dim_test SRCS dim_test.cu DEPS ddim)
cc_test(variable_test SRCS variable_test.cc)
//#include <stdexcept>
//#include <unittest/unittest.h>
#include <sstream>
#include <vector>
......
/*
Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#pragma once
#include <memory>
#include <typeindex>
#include <typeinfo>
#include "paddle/platform/assert.h"
namespace paddle {
namespace framework {
// Variable is a type-erased holder for a single value of arbitrary type.
// It starts out holding nothing; GetMutable<T>() creates the value on demand.
class Variable {
 public:
  // Returns a const reference to the held value, viewed as a T.
  // Asserts that a value is present and that its recorded type is T.
  template <typename T>
  const T& Get() const {
    PADDLE_ASSERT(holder_ != nullptr);
    PADDLE_ASSERT(std::type_index(typeid(T)) ==
                  std::type_index(holder_->Type()));
    const void* raw = holder_->Ptr();
    return *static_cast<const T*>(raw);
  }

  // Returns a mutable pointer to the held T. If nothing is held yet, or the
  // held value has a different type, the holder is replaced by one that owns
  // a freshly value-initialized T.
  template <typename T>
  T* GetMutable() {
    const bool holds_t =
        holder_ != nullptr &&
        std::type_index(holder_->Type()) == std::type_index(typeid(T));
    if (!holds_t) {
      holder_.reset(new PlaceholderImpl<T>(new T()));
    }
    return static_cast<T*>(holder_->Ptr());
  }

 private:
  // Type-independent interface through which Variable reaches the stored
  // object and its type information.
  struct Placeholder {
    virtual ~Placeholder() {}
    virtual const std::type_info& Type() const = 0;
    virtual void* Ptr() const = 0;
  };

  // PlaceholderImpl carries the concrete type T, so that T does not have to
  // appear as a template parameter of Variable itself. It owns the object
  // through a unique_ptr and remembers typeid(T).
  template <typename T>
  struct PlaceholderImpl : public Placeholder {
    PlaceholderImpl(T* ptr) : ptr_(ptr), type_(typeid(T)) {}

    const std::type_info& Type() const override { return type_; }
    void* Ptr() const override { return static_cast<void*>(ptr_.get()); }

    std::unique_ptr<T> ptr_;
    const std::type_info& type_;
  };

  // Points to a PlaceholderImpl<T> for whichever T was last requested through
  // GetMutable; null until then.
  std::unique_ptr<Placeholder> holder_;
};
} // namespace framework
} // namespace paddle
# Design Doc: Variable
Variable is also known as *blob* in MxNet and Caffe2. It is the input and output type of operators, where a neural network is a graph of operators.
## Requirements: Lazy Memory Allocation
For the flexibility of a DL system, a variable should be able to contain any typed value -- a tensor in most cases, but could also be some integer IDs or a scope of other variables in the case of RNN.
To use the minimum amount of memory, we'd like a variable to allocate memory only when it has to; that is, we want lazy memory allocation. Let's take the following example:
```cpp
Variable vr, v1, v2;
Tensor* t1 = new Tensor();
Tensor* t2 = new Tensor();
Randomize(
/* malloc */ v1.GetMutable<Tensor>().mutable_data<float16>(DDim(100,200)),
/* size */ t1.Size());
Randomize(
/* malloc */ v2.GetMutable<Tensor>().mutable_data<float16>(DDim(200,300)),
/* size */ t2.Size());
Mult(
/*result*/ vr.GetMutable<Tensor>().mutable_data<v1.Type()>(SizeOfMult(v1, v2)),
/*input1*/ v1.Get<Tensor>().data(),
/*input2*/ v2.Get<Tensor>().data());
```
We see that a variable holds nothing until `Variable::GetMutable<Tensor>()` allocates a tensor and puts it in the variable. Similarly, a tensor does not get its memory until `Tensor::mutable_data()` is called.
With this syntax, memory is allocated lazily at the point where we call `Randomize` and `Mult` -- the functions that mutate the variable -- which also saves us some lines of C++ code.
## Implementation: Type Hiding
To make memory allocation lazy, we cannot assume that we know the type held by a variable at definition time. In other words, `class Variable` cannot be a template `template <typename T> class Variable`.
Because we don't know the type `T`, we cannot save a `T*` as `Variable's` data member. Instead, we save an interface object `Placeholder`, who can return the pointer to the saved object via `Placeholder::Ptr()` as `void*`.
Nevertheless, `Variable` still needs to know `T`, so that it can `delete<T>(ptr)` and so that `Variable::Get` can check the expected type against the saved object's type.
We save `T` in `PlaceholderImpl`, the implementation of `Placeholder`. Please be aware that `PlaceholderImpl` is a class template and `T` is passed in as a template parameter.
Because `PlaceholderImpl` knows `T`, it can save and return `typeid(T)` for the type comparison in `Variable::Get` and `Variable::GetMutable`.
## Conclusion
The type-hiding technique utilizes C++ class templates, interfaces and derivation, and C++ RTTI (typeid). This combination saves us from defining something like `caffe2::TypeMeta`, which takes hundreds of lines of C++ code.
/*
Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include <memory>
#include <string>
#include "gtest/gtest.h"
#include "paddle/framework/variable.h"
// Checks that GetMutable<T>() hands back storage that Get<T>() later reads,
// first for a local struct type and then, on the same Variable, for
// std::string.
TEST(Variable, GetMutable) {
  using paddle::framework::Variable;

  struct Tensor {
    int content_;
  };

  std::unique_ptr<Variable> var(new Variable());

  // Store a Tensor and read it back through the const accessor.
  Tensor* tensor = var->GetMutable<Tensor>();
  tensor->content_ = 1234;
  const Tensor& tensor_view = var->Get<Tensor>();
  EXPECT_EQ(1234, tensor_view.content_);

  // Request a different type from the same variable and read it back.
  std::string* text = var->GetMutable<std::string>();
  *text = "hello";
  const std::string& text_view = var->Get<std::string>();
  EXPECT_EQ("hello", text_view);
}
......@@ -2,3 +2,4 @@ nv_test(cuda_test SRCS cuda_test.cu)
cc_library(place SRCS place.cc)
cc_test(place_test SRCS place_test.cc DEPS place glog gflags)
cc_test(must_check_test SRCS must_check_test.cc)
......@@ -10,24 +10,17 @@ See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
/**
* This header defines some useful attribute by each compiler. It is the
* abstract layer of compilers.
*/
#ifdef __GNUC__
#define GCC_VERSION \
(__GNUC__ * 10000 + __GNUC_MINOR__ * 100 + __GNUC_PATCHLEVEL__)
#else
#define GCC_VERSION
#endif
/**
* __must_check macro. It requires that the function's return value be used;
* otherwise the compiler raises a warning. Paddle treats all compile
* warnings as errors.
*/
#if GCC_VERSION >= 30400
#ifdef __GNUC__
#if (__GNUC__ * 10000 + __GNUC_MINOR__ * 100 + __GNUC_PATCHLEVEL__) >= 30400
#define __must_check __attribute__((warn_unused_result))
#else
#define __must_check
#endif
#else
#define __must_check
#endif
#include <gtest/gtest.h>
#include <paddle/platform/must_check.h>
// A function annotated with __must_check, used as the subject of the test
// below. It always returns 0.
int __must_check SomeFunctionMustCheck() { return 0; }
// Placeholder test: its body is intentionally empty apart from a
// commented-out call.
TEST(MustCheck, all) {
  // The call below is expected not to compile, because the return value of
  // SomeFunctionMustCheck is marked __must_check and would be discarded here.
  // It is therefore left commented out.
  // SomeFunctionMustCheck();
}
\ No newline at end of file
......@@ -7,6 +7,7 @@ cd $TRAVIS_BUILD_DIR/build
# Compile Documentation only.
cmake .. -DCMAKE_BUILD_TYPE=Debug -DWITH_GPU=OFF -DWITH_DOC=OFF -DWITH_STYLE_CHECK=OFF
mkdir output
make -j `nproc`
find .. -name '*whl' | xargs pip install # install all wheels.
......
......@@ -19,7 +19,7 @@ limitations under the License. */
#include <stdio.h>
#include <memory>
#include <string>
#include "Compiler.h"
#include "paddle/platform/must_check.h"
namespace paddle {
......
......@@ -31,10 +31,10 @@ images per class.
import cPickle
import itertools
import numpy
from common import download
import paddle.v2.dataset.common
import tarfile
__all__ = ['train100', 'test100', 'train10', 'test10']
__all__ = ['train100', 'test100', 'train10', 'test10', 'convert']
URL_PREFIX = 'https://www.cs.toronto.edu/~kriz/'
CIFAR10_URL = URL_PREFIX + 'cifar-10-python.tar.gz'
......@@ -75,7 +75,8 @@ def train100():
:rtype: callable
"""
return reader_creator(
download(CIFAR100_URL, 'cifar', CIFAR100_MD5), 'train')
paddle.v2.dataset.common.download(CIFAR100_URL, 'cifar', CIFAR100_MD5),
'train')
def test100():
......@@ -88,7 +89,9 @@ def test100():
:return: Test reader creator.
:rtype: callable
"""
return reader_creator(download(CIFAR100_URL, 'cifar', CIFAR100_MD5), 'test')
return reader_creator(
paddle.v2.dataset.common.download(CIFAR100_URL, 'cifar', CIFAR100_MD5),
'test')
def train10():
......@@ -102,7 +105,8 @@ def train10():
:rtype: callable
"""
return reader_creator(
download(CIFAR10_URL, 'cifar', CIFAR10_MD5), 'data_batch')
paddle.v2.dataset.common.download(CIFAR10_URL, 'cifar', CIFAR10_MD5),
'data_batch')
def test10():
......@@ -116,9 +120,20 @@ def test10():
:rtype: callable
"""
return reader_creator(
download(CIFAR10_URL, 'cifar', CIFAR10_MD5), 'test_batch')
paddle.v2.dataset.common.download(CIFAR10_URL, 'cifar', CIFAR10_MD5),
'test_batch')
def fetch():
download(CIFAR10_URL, 'cifar', CIFAR10_MD5)
download(CIFAR100_URL, 'cifar', CIFAR100_MD5)
paddle.v2.dataset.common.download(CIFAR10_URL, 'cifar', CIFAR10_MD5)
paddle.v2.dataset.common.download(CIFAR100_URL, 'cifar', CIFAR100_MD5)
def convert(path):
    """
    Converts dataset to recordio format.

    Each of the four CIFAR reader creators is built in turn and handed to
    paddle.v2.dataset.common.convert together with ``path``, the literal 10
    (forwarded unchanged; its meaning is defined by common.convert, which is
    not shown here) and an output name.
    """
    datasets = (
        (train100, "cifar_train100"),
        (test100, "cifar_test100"),
        (train10, "cifar_train10"),
        (test10, "cifar_test10"),
    )
    # The reader is created inside the loop so that creation and conversion
    # stay interleaved per dataset.
    for make_reader, name in datasets:
        paddle.v2.dataset.common.convert(path, make_reader(), 10, name)
......@@ -23,7 +23,10 @@ import paddle.v2.dataset
import cPickle
import glob
__all__ = ['DATA_HOME', 'download', 'md5file', 'split', 'cluster_files_reader']
__all__ = [
'DATA_HOME', 'download', 'md5file', 'split', 'cluster_files_reader',
'convert'
]
DATA_HOME = os.path.expanduser('~/.cache/paddle/dataset')
......
......@@ -23,9 +23,9 @@ to initialize SRL model.
import tarfile
import gzip
import itertools
from common import download
import paddle.v2.dataset.common
__all__ = ['test, get_dict', 'get_embedding']
# Public names of this module. 'test' and 'get_dict' must be separate entries:
# written as the single string 'test, get_dict' they name no attribute, so
# neither function is exported and a star-import fails on the lookup.
__all__ = ['test', 'get_dict', 'get_embedding', 'convert']
DATA_URL = 'http://www.cs.upc.edu/~srlconll/conll05st-tests.tar.gz'
DATA_MD5 = '387719152ae52d60422c016e92a742fc'
......@@ -182,9 +182,15 @@ def get_dict():
"""
Get the word, verb and label dictionary of Wikipedia corpus.
"""
word_dict = load_dict(download(WORDDICT_URL, 'conll05st', WORDDICT_MD5))
verb_dict = load_dict(download(VERBDICT_URL, 'conll05st', VERBDICT_MD5))
label_dict = load_dict(download(TRGDICT_URL, 'conll05st', TRGDICT_MD5))
word_dict = load_dict(
paddle.v2.dataset.common.download(WORDDICT_URL, 'conll05st',
WORDDICT_MD5))
verb_dict = load_dict(
paddle.v2.dataset.common.download(VERBDICT_URL, 'conll05st',
VERBDICT_MD5))
label_dict = load_dict(
paddle.v2.dataset.common.download(TRGDICT_URL, 'conll05st',
TRGDICT_MD5))
return word_dict, verb_dict, label_dict
......@@ -192,7 +198,7 @@ def get_embedding():
"""
Get the trained word vector based on Wikipedia corpus.
"""
return download(EMB_URL, 'conll05st', EMB_MD5)
return paddle.v2.dataset.common.download(EMB_URL, 'conll05st', EMB_MD5)
def test():
......@@ -209,15 +215,23 @@ def test():
"""
word_dict, verb_dict, label_dict = get_dict()
reader = corpus_reader(
download(DATA_URL, 'conll05st', DATA_MD5),
paddle.v2.dataset.common.download(DATA_URL, 'conll05st', DATA_MD5),
words_name='conll05st-release/test.wsj/words/test.wsj.words.gz',
props_name='conll05st-release/test.wsj/props/test.wsj.props.gz')
return reader_creator(reader, word_dict, verb_dict, label_dict)
def fetch():
download(WORDDICT_URL, 'conll05st', WORDDICT_MD5)
download(VERBDICT_URL, 'conll05st', VERBDICT_MD5)
download(TRGDICT_URL, 'conll05st', TRGDICT_MD5)
download(EMB_URL, 'conll05st', EMB_MD5)
download(DATA_URL, 'conll05st', DATA_MD5)
paddle.v2.dataset.common.download(WORDDICT_URL, 'conll05st', WORDDICT_MD5)
paddle.v2.dataset.common.download(VERBDICT_URL, 'conll05st', VERBDICT_MD5)
paddle.v2.dataset.common.download(TRGDICT_URL, 'conll05st', TRGDICT_MD5)
paddle.v2.dataset.common.download(EMB_URL, 'conll05st', EMB_MD5)
paddle.v2.dataset.common.download(DATA_URL, 'conll05st', DATA_MD5)
def convert(path):
    """
    Converts dataset to recordio format.

    Calls paddle.v2.dataset.common.convert twice, each time with a fresh
    ``test()`` reader, ``path`` and the literal 10 (forwarded unchanged; its
    meaning is defined by common.convert, which is not shown here).
    """
    # NOTE(review): both outputs are produced from test(), and the names are
    # spelled "conl105" rather than "conll05" -- confirm this is intended.
    for name in ("conl105_train", "conl105_test"):
        paddle.v2.dataset.common.convert(path, test(), 10, name)
......@@ -28,7 +28,7 @@ import re
import string
import threading
__all__ = ['build_dict', 'train', 'test']
__all__ = ['build_dict', 'train', 'test', 'convert']
URL = 'http://ai.stanford.edu/%7Eamaas/data/sentiment/aclImdb_v1.tar.gz'
MD5 = '7c2ac02c03563afcf9b574c7e56c153a'
......@@ -166,3 +166,12 @@ def word_dict():
def fetch():
paddle.v2.dataset.common.download(URL, 'imdb', MD5)
def convert(path):
    """
    Converts dataset to recordio format.

    Builds the word dictionary once via ``word_dict()`` and passes two
    zero-argument callables, wrapping ``train`` and ``test`` with that
    dictionary, to paddle.v2.dataset.common.convert.
    """
    vocab = word_dict()

    def train_reader():
        return train(vocab)

    def test_reader():
        return test(vocab)

    paddle.v2.dataset.common.convert(path, train_reader, 10, "imdb_train")
    paddle.v2.dataset.common.convert(path, test_reader, 10, "imdb_test")
......@@ -22,7 +22,7 @@ import paddle.v2.dataset.common
import collections
import tarfile
__all__ = ['train', 'test', 'build_dict']
__all__ = ['train', 'test', 'build_dict', 'convert']
URL = 'http://www.fit.vutbr.cz/~imikolov/rnnlm/simple-examples.tgz'
MD5 = '30177ea32e27c525793142b6bf2c8e2d'
......@@ -146,3 +146,15 @@ def test(word_idx, n, data_type=DataType.NGRAM):
def fetch():
paddle.v2.dataset.common.download(URL, "imikolov", MD5)
def convert(path):
    """
    Converts dataset to recordio format.

    Builds the dictionary via ``build_dict()``, then creates the train and
    test readers with that dictionary and n = 5 and forwards each one to
    paddle.v2.dataset.common.convert.
    """
    n = 5
    vocab = build_dict()
    datasets = ((train, "imikolov_train"), (test, "imikolov_test"))
    # Each reader is created right before its own conversion.
    for make_reader, name in datasets:
        paddle.v2.dataset.common.convert(path, make_reader(vocab, n), 10, name)
......@@ -21,7 +21,7 @@ import paddle.v2.dataset.common
import subprocess
import numpy
import platform
__all__ = ['train', 'test']
__all__ = ['train', 'test', 'convert']
URL_PREFIX = 'http://yann.lecun.com/exdb/mnist/'
TEST_IMAGE_URL = URL_PREFIX + 't10k-images-idx3-ubyte.gz'
......@@ -113,3 +113,11 @@ def fetch():
paddle.v2.dataset.common.download(TRAIN_LABEL_URL, 'mnist', TRAIN_LABEL_MD5)
paddle.v2.dataset.common.download(TEST_IMAGE_URL, 'mnist', TEST_IMAGE_MD5)
paddle.v2.dataset.common.download(TEST_LABEL_URL, 'mnist', TRAIN_LABEL_MD5)
def convert(path):
    """
    Converts dataset to recordio format.

    Creates the train and test readers in turn and forwards each to
    paddle.v2.dataset.common.convert with ``path`` and the literal 10
    (forwarded unchanged; its meaning is defined by common.convert, which is
    not shown here).
    """
    # NOTE(review): the output names are spelled "minist", not "mnist" --
    # confirm whether consumers rely on this spelling before changing it.
    datasets = ((train, "minist_train"), (test, "minist_test"))
    for make_reader, name in datasets:
        paddle.v2.dataset.common.convert(path, make_reader(), 10, name)
......@@ -23,14 +23,15 @@ set and test set into paddle reader creators.
"""
import zipfile
from common import download
import paddle.v2.dataset.common
import re
import random
import functools
__all__ = [
'train', 'test', 'get_movie_title_dict', 'max_movie_id', 'max_user_id',
'age_table', 'movie_categories', 'max_job_id', 'user_info', 'movie_info'
'age_table', 'movie_categories', 'max_job_id', 'user_info', 'movie_info',
'convert'
]
age_table = [1, 18, 25, 35, 45, 50, 56]
......@@ -99,7 +100,7 @@ USER_INFO = None
def __initialize_meta_info__():
fn = download(URL, "movielens", MD5)
fn = paddle.v2.dataset.common.download(URL, "movielens", MD5)
global MOVIE_INFO
if MOVIE_INFO is None:
pattern = re.compile(r'^(.*)\((\d+)\)$')
......@@ -246,7 +247,15 @@ def unittest():
def fetch():
download(URL, "movielens", MD5)
paddle.v2.dataset.common.download(URL, "movielens", MD5)
def convert(path):
    """
    Converts dataset to recordio format.

    Creates the train and test readers in turn and forwards each to
    paddle.v2.dataset.common.convert with ``path`` and the literal 10
    (forwarded unchanged; its meaning is defined by common.convert, which is
    not shown here).
    """
    datasets = ((train, "movielens_train"), (test, "movielens_test"))
    for make_reader, name in datasets:
        paddle.v2.dataset.common.convert(path, make_reader(), 10, name)
if __name__ == '__main__':
......
......@@ -26,9 +26,9 @@ from itertools import chain
import nltk
from nltk.corpus import movie_reviews
import common
import paddle.v2.dataset.common
__all__ = ['train', 'test', 'get_word_dict']
__all__ = ['train', 'test', 'get_word_dict', 'convert']
NUM_TRAINING_INSTANCES = 1600
NUM_TOTAL_INSTANCES = 2000
......@@ -39,12 +39,13 @@ def download_data_if_not_yet():
"""
try:
# make sure that nltk can find the data
if common.DATA_HOME not in nltk.data.path:
nltk.data.path.append(common.DATA_HOME)
if paddle.v2.dataset.common.DATA_HOME not in nltk.data.path:
nltk.data.path.append(paddle.v2.dataset.common.DATA_HOME)
movie_reviews.categories()
except LookupError:
print "Downloading movie_reviews data set, please wait....."
nltk.download('movie_reviews', download_dir=common.DATA_HOME)
nltk.download(
'movie_reviews', download_dir=paddle.v2.dataset.common.DATA_HOME)
print "Download data set success....."
print "Path is " + nltk.data.find('corpora/movie_reviews').path
......@@ -128,4 +129,13 @@ def test():
def fetch():
nltk.download('movie_reviews', download_dir=common.DATA_HOME)
nltk.download(
'movie_reviews', download_dir=paddle.v2.dataset.common.DATA_HOME)
def convert(path):
    """
    Converts dataset to recordio format.

    Unlike a call such as ``train()``, this passes the ``train`` and ``test``
    functions themselves (uncalled) to paddle.v2.dataset.common.convert.
    """
    datasets = ((train, "sentiment_train"), (test, "sentiment_test"))
    for reader, name in datasets:
        paddle.v2.dataset.common.convert(path, reader, 10, name)
......@@ -14,14 +14,14 @@
"""
UCI Housing dataset.
This module will download dataset from
This module will download the dataset from
https://archive.ics.uci.edu/ml/machine-learning-databases/housing/ and
parse training set and test set into paddle reader creators.
"""
import numpy as np
import os
from common import download
import paddle.v2.dataset.common
# Public names of this module. 'convert' is defined further down in this file
# and is exported here like the other public entry points.
__all__ = ['train', 'test', 'convert']
......@@ -29,7 +29,7 @@ URL = 'https://archive.ics.uci.edu/ml/machine-learning-databases/housing/housing
MD5 = 'd4accdce7a25600298819f8e28e8d593'
# Names of the 13 input feature columns. 'convert' is a function of this
# module, not a feature, and therefore must not appear in this list.
feature_names = [
    'CRIM', 'ZN', 'INDUS', 'CHAS', 'NOX', 'RM', 'AGE', 'DIS', 'RAD', 'TAX',
    'PTRATIO', 'B', 'LSTAT'
]
UCI_TRAIN_DATA = None
......@@ -82,7 +82,7 @@ def train():
:rtype: callable
"""
global UCI_TRAIN_DATA
load_data(download(URL, 'uci_housing', MD5))
load_data(paddle.v2.dataset.common.download(URL, 'uci_housing', MD5))
def reader():
for d in UCI_TRAIN_DATA:
......@@ -102,7 +102,7 @@ def test():
:rtype: callable
"""
global UCI_TEST_DATA
load_data(download(URL, 'uci_housing', MD5))
load_data(paddle.v2.dataset.common.download(URL, 'uci_housing', MD5))
def reader():
for d in UCI_TEST_DATA:
......@@ -112,4 +112,12 @@ def test():
def fetch():
download(URL, 'uci_housing', MD5)
paddle.v2.dataset.common.download(URL, 'uci_housing', MD5)
def convert(path):
    """
    Converts dataset to recordio format.

    Creates the train and test readers in turn and forwards each to
    paddle.v2.dataset.common.convert with ``path`` and the literal 10
    (forwarded unchanged; its meaning is defined by common.convert, which is
    not shown here).
    """
    paddle.v2.dataset.common.convert(path, train(), 10, "uci_housing_train")
    # The name is spelled "uci_housing" to match the train output above.
    paddle.v2.dataset.common.convert(path, test(), 10, "uci_housing_test")
......@@ -22,10 +22,10 @@ parse training set and test set into paddle reader creators.
import tarfile
import gzip
from paddle.v2.dataset.common import download
import paddle.v2.dataset.common
from paddle.v2.parameters import Parameters
__all__ = ['train', 'test', 'build_dict']
__all__ = ['train', 'test', 'build_dict', 'convert']
URL_DEV_TEST = 'http://www-lium.univ-lemans.fr/~schwenk/cslm_joint_paper/data/dev+test.tgz'
MD5_DEV_TEST = '7d7897317ddd8ba0ae5c5fa7248d3ff5'
......@@ -115,7 +115,8 @@ def train(dict_size):
:rtype: callable
"""
return reader_creator(
download(URL_TRAIN, 'wmt14', MD5_TRAIN), 'train/train', dict_size)
paddle.v2.dataset.common.download(URL_TRAIN, 'wmt14', MD5_TRAIN),
'train/train', dict_size)
def test(dict_size):
......@@ -130,16 +131,18 @@ def test(dict_size):
:rtype: callable
"""
return reader_creator(
download(URL_TRAIN, 'wmt14', MD5_TRAIN), 'test/test', dict_size)
paddle.v2.dataset.common.download(URL_TRAIN, 'wmt14', MD5_TRAIN),
'test/test', dict_size)
def gen(dict_size):
return reader_creator(
download(URL_TRAIN, 'wmt14', MD5_TRAIN), 'gen/gen', dict_size)
paddle.v2.dataset.common.download(URL_TRAIN, 'wmt14', MD5_TRAIN),
'gen/gen', dict_size)
def model():
tar_file = download(URL_MODEL, 'wmt14', MD5_MODEL)
tar_file = paddle.v2.dataset.common.download(URL_MODEL, 'wmt14', MD5_MODEL)
with gzip.open(tar_file, 'r') as f:
parameters = Parameters.from_tar(f)
return parameters
......@@ -148,7 +151,7 @@ def model():
def get_dict(dict_size, reverse=True):
# if reverse = False, return dict = {'a':'001', 'b':'002', ...}
# else reverse = true, return dict = {'001':'a', '002':'b', ...}
tar_file = download(URL_TRAIN, 'wmt14', MD5_TRAIN)
tar_file = paddle.v2.dataset.common.download(URL_TRAIN, 'wmt14', MD5_TRAIN)
src_dict, trg_dict = __read_to_dict__(tar_file, dict_size)
if reverse:
src_dict = {v: k for k, v in src_dict.items()}
......@@ -157,5 +160,14 @@ def get_dict(dict_size, reverse=True):
def fetch():
download(URL_TRAIN, 'wmt14', MD5_TRAIN)
download(URL_MODEL, 'wmt14', MD5_MODEL)
paddle.v2.dataset.common.download(URL_TRAIN, 'wmt14', MD5_TRAIN)
paddle.v2.dataset.common.download(URL_MODEL, 'wmt14', MD5_MODEL)
def convert(path):
    """
    Converts dataset to recordio format.

    Creates the train and test readers with a dictionary size of 30000 and
    forwards each to paddle.v2.dataset.common.convert with ``path`` and the
    literal 10 (forwarded unchanged; its meaning is defined by common.convert,
    which is not shown here).
    """
    vocab_size = 30000
    datasets = ((train, "wmt14_train"), (test, "wmt14_test"))
    # Each reader is created right before its own conversion.
    for make_reader, name in datasets:
        paddle.v2.dataset.common.convert(path, make_reader(vocab_size), 10, name)
......@@ -13,9 +13,7 @@
# limitations under the License.
import os
import unittest
import numpy as np
import paddle.v2.reader.creator
......
......@@ -13,6 +13,7 @@ packages=['paddle',
setup_requires=["requests",
"numpy",
"protobuf==3.1",
"recordio",
"matplotlib",
"rarfile"]
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册