diff --git a/cmake/external/python.cmake b/cmake/external/python.cmake index 0accf1a8dd83560324716f0f4936be56dd7a9f1b..93d7275df05d723d7dd66ef0c5ac15672c051c34 100644 --- a/cmake/external/python.cmake +++ b/cmake/external/python.cmake @@ -221,3 +221,7 @@ ENDIF(PYTHONLIBS_FOUND AND PYTHONINTERP_FOUND) INCLUDE_DIRECTORIES(${PYTHON_INCLUDE_DIR}) INCLUDE_DIRECTORIES(${PYTHON_NUMPY_INCLUDE_DIR}) + +IF(NOT WITH_PYTHON) + SET(PYTHON_LIBRARIES "") +ENDIF() diff --git a/doc/api/v2/config/layer.rst b/doc/api/v2/config/layer.rst index db33a20487e579cda67a01c52ee646829df0f4e6..05817ec85455ac58566e90956a54cb86541f8488 100644 --- a/doc/api/v2/config/layer.rst +++ b/doc/api/v2/config/layer.rst @@ -109,6 +109,12 @@ sum_to_one_norm :members: sum_to_one_norm :noindex: +cross_channel_norm +------------------ +.. automodule:: paddle.v2.layer + :members: cross_channel_norm + :noindex: + Recurrent Layers ================ diff --git a/paddle/gserver/layers/CrossChannelNormLayer.cpp b/paddle/gserver/layers/CrossChannelNormLayer.cpp new file mode 100644 index 0000000000000000000000000000000000000000..3fbccc11032caa4878ce8dcfe7c34a261acee68b --- /dev/null +++ b/paddle/gserver/layers/CrossChannelNormLayer.cpp @@ -0,0 +1,122 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "Layer.h" +#include "NormLayer.h" +#include "paddle/math/BaseMatrix.h" +#include "paddle/math/Matrix.h" + +namespace paddle { + +MatrixPtr CrossChannelNormLayer::createSampleMatrix(MatrixPtr data, + size_t iter, + size_t spatialDim) { + return Matrix::create(data->getData() + iter * channels_ * spatialDim, + channels_, + spatialDim, + false, + useGpu_); +} + +MatrixPtr CrossChannelNormLayer::createSpatialMatrix(MatrixPtr data, + size_t iter, + size_t spatialDim) { + return Matrix::create( + data->getData() + iter * spatialDim, 1, spatialDim, false, useGpu_); +} + +void CrossChannelNormLayer::forward(PassType passType) { + Layer::forward(passType); + MatrixPtr inV = getInputValue(0); + + size_t batchSize = inV->getHeight(); + size_t dataDim = inV->getWidth(); + CHECK_EQ(getSize(), dataDim); + + reserveOutput(batchSize, dataDim); + MatrixPtr outV = getOutputValue(); + size_t spatialDim = dataDim / channels_; + + Matrix::resizeOrCreate(dataBuffer_, batchSize, dataDim, false, useGpu_); + Matrix::resizeOrCreate(spatialBuffer_, 1, spatialDim, false, useGpu_); + Matrix::resizeOrCreate(normBuffer_, batchSize, spatialDim, false, useGpu_); + normBuffer_->zeroMem(); + // add eps to avoid overflow + normBuffer_->addScalar(*normBuffer_, 1e-6); + inV->square2(*dataBuffer_); + for (size_t i = 0; i < batchSize; i++) { + const MatrixPtr inVTmp = createSampleMatrix(inV, i, spatialDim); + const MatrixPtr dataTmp = createSampleMatrix(dataBuffer_, i, spatialDim); + MatrixPtr outVTmp = createSampleMatrix(outV, i, spatialDim); + MatrixPtr normTmp = createSpatialMatrix(normBuffer_, i, spatialDim); + + // compute norm. 
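+    // dataBuffer_ holds the squared input; for this sample the block below
+    // reduces it over the channel dimension, takes the square root to get the
+    // per-location L2 norm, normalizes the input by that norm, and finally
+    // applies the trainable per-channel scale.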
+ spatialBuffer_->sumCols(*dataTmp, 1, 0); + spatialBuffer_->sqrt2(*spatialBuffer_); + normTmp->copyFrom(*spatialBuffer_); + outVTmp->copyFrom(*inVTmp); + outVTmp->divRowVector(*spatialBuffer_); + // scale the layer. + outVTmp->mulColVector(*scale_->getW()); + } +} + +void CrossChannelNormLayer::backward(const UpdateCallback& callback) { + MatrixPtr inG = getInputGrad(0); + MatrixPtr inV = getInputValue(0); + MatrixPtr outG = getOutputGrad(); + MatrixPtr outV = getOutputValue(); + + size_t batchSize = inG->getHeight(); + size_t dataDim = inG->getWidth(); + size_t spatialDim = dataDim / channels_; + + dataBuffer_->dotMul(*outG, *outV); + Matrix::resizeOrCreate(scaleDiff_, channels_, 1, false, useGpu_); + Matrix::resizeOrCreate(channelBuffer_, channels_, 1, false, useGpu_); + Matrix::resizeOrCreate(sampleBuffer_, channels_, spatialDim, false, useGpu_); + scaleDiff_->zeroMem(); + for (size_t i = 0; i < batchSize; i++) { + MatrixPtr outGTmp = createSampleMatrix(outG, i, spatialDim); + const MatrixPtr dataTmp = createSampleMatrix(dataBuffer_, i, spatialDim); + const MatrixPtr inVTmp = createSampleMatrix(inV, i, spatialDim); + const MatrixPtr inGTmp = createSampleMatrix(inG, i, spatialDim); + const MatrixPtr normTmp = createSpatialMatrix(normBuffer_, i, spatialDim); + + channelBuffer_->sumRows(*dataTmp, 1, 0); + channelBuffer_->dotDiv(*channelBuffer_, *(scale_->getW())); + // store a / scale[i] in scaleDiff_ temporary + scaleDiff_->add(*channelBuffer_, 1.); + + sampleBuffer_->dotMul(*inVTmp, *outGTmp); + spatialBuffer_->sumCols(*sampleBuffer_, 1., 1.); + // scale the grad + inGTmp->copyFrom(*inVTmp); + inGTmp->mulRowVector(*spatialBuffer_); + // divide by square of norm + spatialBuffer_->dotMul(*normTmp, *normTmp); + inGTmp->divRowVector(*spatialBuffer_); + // subtract + inGTmp->add(*outGTmp, -1, 1); + // divide by norm + inGTmp->divRowVector(*normTmp); + // scale the diff + inGTmp->mulColVector(*scale_->getW()); + } + // updata scale + if (scale_->getWGrad()) scale_->getWGrad()->copyFrom(*scaleDiff_); + scale_->getParameterPtr()->incUpdate(callback); +} + +} // namespace paddle diff --git a/paddle/gserver/layers/NormLayer.cpp b/paddle/gserver/layers/NormLayer.cpp index 3db0af2515ee9f64aa6c0b0a441e88562d9e398e..e094078bfe86e30c06e1b80ebc04c8213fe9abcf 100644 --- a/paddle/gserver/layers/NormLayer.cpp +++ b/paddle/gserver/layers/NormLayer.cpp @@ -26,6 +26,8 @@ Layer* NormLayer::create(const LayerConfig& config) { return new ResponseNormLayer(config); } else if (norm == "cmrnorm-projection") { return new CMRProjectionNormLayer(config); + } else if (norm == "cross-channel-norm") { + return new CrossChannelNormLayer(config); } else { LOG(FATAL) << "Unknown norm type: " << norm; return nullptr; @@ -54,4 +56,14 @@ bool ResponseNormLayer::init(const LayerMap& layerMap, return true; } +bool CrossChannelNormLayer::init(const LayerMap& layerMap, + const ParameterMap& parameterMap) { + Layer::init(layerMap, parameterMap); + CHECK(parameters_[0]); + const NormConfig& conf = config_.inputs(0).norm_conf(); + channels_ = conf.channels(); + scale_.reset(new Weight(channels_, 1, parameters_[0])); + return true; +} + } // namespace paddle diff --git a/paddle/gserver/layers/NormLayer.h b/paddle/gserver/layers/NormLayer.h index e77faaa322570933b3ea2de877b7859857306432..7c238ac944e52c3a83c2aa5deac18de3aff6db61 100644 --- a/paddle/gserver/layers/NormLayer.h +++ b/paddle/gserver/layers/NormLayer.h @@ -65,4 +65,35 @@ public: } }; +/** + * This layer applys normalization across the channels of each sample to a + * conv 
layer's output, and scales the output by a group of trainable factors + * whose dimensions equal to the number of channels. + * - Input: One and only one input layer are accepted. + * - Output: The normalized data of the input data. + * Reference: + * Wei Liu, Dragomir Anguelov, Dumitru Erhan, Christian Szegedy, Scott Reed, + * Cheng-Yang Fu, Alexander C. Berg. SSD: Single Shot MultiBox Detector + */ +class CrossChannelNormLayer : public NormLayer { +public: + explicit CrossChannelNormLayer(const LayerConfig& config) + : NormLayer(config) {} + bool init(const LayerMap& layerMap, const ParameterMap& parameterMap); + void forward(PassType passType); + void backward(const UpdateCallback& callback); + MatrixPtr createSampleMatrix(MatrixPtr data, size_t iter, size_t spatialDim); + MatrixPtr createSpatialMatrix(MatrixPtr data, size_t iter, size_t spatialDim); + +protected: + size_t channels_; + std::unique_ptr scale_; + MatrixPtr scaleDiff_; + MatrixPtr normBuffer_; + MatrixPtr dataBuffer_; + MatrixPtr channelBuffer_; + MatrixPtr spatialBuffer_; + MatrixPtr sampleBuffer_; +}; + } // namespace paddle diff --git a/paddle/gserver/layers/PriorBox.cpp b/paddle/gserver/layers/PriorBox.cpp index bcf5e912a50fef2cec8ebdf1e0dad9efa43fba2f..331bc7672ec0d39a7317c39f1d14e8dcadea471a 100644 --- a/paddle/gserver/layers/PriorBox.cpp +++ b/paddle/gserver/layers/PriorBox.cpp @@ -20,7 +20,7 @@ namespace paddle { /** * @brief A layer for generating priorbox locations and variances. * - Input: Two and only two input layer are accepted. The input layer must be - * be a data output layer and a convolution output layer. + * be a data output layer and a convolution output layer. * - Output: The priorbox locations and variances of the input data. * Reference: * Wei Liu, Dragomir Anguelov, Dumitru Erhan, Christian Szegedy, Scott Reed, @@ -45,27 +45,32 @@ protected: MatrixPtr buffer_; }; +REGISTER_LAYER(priorbox, PriorBoxLayer); + bool PriorBoxLayer::init(const LayerMap& layerMap, const ParameterMap& parameterMap) { Layer::init(layerMap, parameterMap); auto pbConf = config_.inputs(0).priorbox_conf(); + std::vector tmp; + aspectRatio_.push_back(1.); std::copy(pbConf.min_size().begin(), pbConf.min_size().end(), std::back_inserter(minSize_)); std::copy(pbConf.max_size().begin(), pbConf.max_size().end(), std::back_inserter(maxSize_)); - std::copy(pbConf.aspect_ratio().begin(), - pbConf.aspect_ratio().end(), - std::back_inserter(aspectRatio_)); std::copy(pbConf.variance().begin(), pbConf.variance().end(), std::back_inserter(variance_)); + std::copy(pbConf.aspect_ratio().begin(), + pbConf.aspect_ratio().end(), + std::back_inserter(tmp)); // flip - int inputRatioLength = aspectRatio_.size(); - for (int index = 0; index < inputRatioLength; index++) - aspectRatio_.push_back(1 / aspectRatio_[index]); - aspectRatio_.push_back(1.); + int inputRatioLength = tmp.size(); + for (int index = 0; index < inputRatioLength; index++) { + aspectRatio_.push_back(tmp[index]); + aspectRatio_.push_back(1 / tmp[index]); + } numPriors_ = aspectRatio_.size(); if (maxSize_.size() > 0) numPriors_++; return true; @@ -94,12 +99,12 @@ void PriorBoxLayer::forward(PassType passType) { for (int w = 0; w < layerWidth; ++w) { real centerX = (w + 0.5) * stepW; real centerY = (h + 0.5) * stepH; - int minSize = 0; + real minSize = 0; for (size_t s = 0; s < minSize_.size(); s++) { // first prior. minSize = minSize_[s]; - int boxWidth = minSize; - int boxHeight = minSize; + real boxWidth = minSize; + real boxHeight = minSize; // xmin, ymin, xmax, ymax. 
tmpPtr[idx++] = (centerX - boxWidth / 2.) / imageWidth; tmpPtr[idx++] = (centerY - boxHeight / 2.) / imageHeight; @@ -112,7 +117,7 @@ void PriorBoxLayer::forward(PassType passType) { CHECK_EQ(minSize_.size(), maxSize_.size()); // second prior. for (size_t s = 0; s < maxSize_.size(); s++) { - int maxSize = maxSize_[s]; + real maxSize = maxSize_[s]; boxWidth = boxHeight = sqrt(minSize * maxSize); tmpPtr[idx++] = (centerX - boxWidth / 2.) / imageWidth; tmpPtr[idx++] = (centerY - boxHeight / 2.) / imageHeight; @@ -145,6 +150,5 @@ void PriorBoxLayer::forward(PassType passType) { MatrixPtr outV = getOutputValue(); outV->copyFrom(buffer_->data_, dim * 2); } -REGISTER_LAYER(priorbox, PriorBoxLayer); } // namespace paddle diff --git a/paddle/gserver/tests/test_LayerGrad.cpp b/paddle/gserver/tests/test_LayerGrad.cpp index 5f8a7b79a06e014e3d9cb03ab033e0bce47a432a..0c22896d6e58f8705f4284b95d0a6e132cb8903d 100644 --- a/paddle/gserver/tests/test_LayerGrad.cpp +++ b/paddle/gserver/tests/test_LayerGrad.cpp @@ -1642,6 +1642,25 @@ TEST(Layer, PadLayer) { } } +TEST(Layer, CrossChannelNormLayer) { + TestConfig config; + config.layerConfig.set_type("norm"); + config.layerConfig.set_size(100); + LayerInputConfig* input = config.layerConfig.add_inputs(); + NormConfig* norm = input->mutable_norm_conf(); + norm->set_norm_type("cross-channel-norm"); + norm->set_channels(10); + norm->set_size(100); + norm->set_scale(0); + norm->set_pow(0); + norm->set_blocked(0); + config.inputDefs.push_back({INPUT_DATA, "layer_0", 100, 10}); + + for (auto useGpu : {false, true}) { + testLayerGrad(config, "cross-channel-norm", 10, false, useGpu, false, 5); + } +} + TEST(Layer, smooth_l1) { TestConfig config; config.layerConfig.set_type("smooth_l1"); diff --git a/paddle/math/BaseMatrix.cu b/paddle/math/BaseMatrix.cu index 0a0d92d1ae65f5b6020eb71fe2a6db5a3c625d9c..de48b6fac9c7d8125a552022c52353ef6bcef995 100644 --- a/paddle/math/BaseMatrix.cu +++ b/paddle/math/BaseMatrix.cu @@ -1453,6 +1453,24 @@ void BaseMatrixT::divRowVector(BaseMatrixT& b) { true_type() /* bAsRowVector */, false_type()); } +template +void BaseMatrixT::mulColVector(BaseMatrixT& b) { + MatrixOffset offset(0, 0, 0, 0); + int numRows = height_; + int numCols = width_; + applyBinary(binary::DotMul(), b, numRows, numCols, offset, + false_type(), true_type() /* bAsColVector */); +} + +template +void BaseMatrixT::divColVector(BaseMatrixT& b) { + MatrixOffset offset(0, 0, 0, 0); + int numRows = height_; + int numCols = width_; + applyBinary(binary::DotDiv(), b, numRows, numCols, offset, + false_type(), true_type() /* bAsColVector */); +} + template<> template int BaseMatrixT::applyRow(Agg agg, BaseMatrixT& b) { diff --git a/paddle/math/BaseMatrix.h b/paddle/math/BaseMatrix.h index 8691c87ac3b88499a9676d59af533e0f4713dfc3..6ed48c8d88ee698689de6f7a7f470b97a094ea5b 100644 --- a/paddle/math/BaseMatrix.h +++ b/paddle/math/BaseMatrix.h @@ -545,6 +545,9 @@ public: void mulRowVector(BaseMatrixT& b); void divRowVector(BaseMatrixT& b); + void mulColVector(BaseMatrixT& b); + void divColVector(BaseMatrixT& b); + void addP2P(BaseMatrixT& b); /** diff --git a/paddle/math/tests/test_BaseMatrix.cpp b/paddle/math/tests/test_BaseMatrix.cpp index 21918b86e1ad98766ceaf09dea3020d6e8592191..22ce39701fca7b650fc03794cb0701e0987d2dae 100644 --- a/paddle/math/tests/test_BaseMatrix.cpp +++ b/paddle/math/tests/test_BaseMatrix.cpp @@ -110,6 +110,8 @@ TEST(BaseMatrix, BaseMatrix) { compare(&BaseMatrix::addRowVector); compare(&BaseMatrix::mulRowVector); compare(&BaseMatrix::divRowVector); + 
compare(&BaseMatrix::mulColVector); + compare(&BaseMatrix::divColVector); compare(&BaseMatrix::addP2P); compare(&BaseMatrix::invSqrt); } diff --git a/paddle/scripts/docker/README.md b/paddle/scripts/docker/README.md index e5af5c9a1e6f96c5112895a1ec0b0c6ac57da666..7c90316ad82a6430d6c12d72e07b166b6d9d98a9 100644 --- a/paddle/scripts/docker/README.md +++ b/paddle/scripts/docker/README.md @@ -94,7 +94,7 @@ docker build -t paddle:dev --build-arg UBUNTU_MIRROR=mirror://mirrors.ubuntu.com Given the development image `paddle:dev`, the following command builds PaddlePaddle from the source tree on the development computer (host): ```bash -docker run -v $PWD:/paddle -e "WITH_GPU=OFF" -e "WITH_AVX=ON" -e "TEST=OFF" paddle:dev +docker run -v $PWD:/paddle -e "WITH_GPU=OFF" -e "WITH_AVX=ON" -e "WITH_TEST=OFF" -e "RUN_TEST=OFF" paddle:dev ``` This command mounts the source directory on the host into `/paddle` in the container, so the default entry point of `paddle:dev`, `build.sh`, could build the source code with possible local changes. When it writes to `/paddle/build` in the container, it writes to `$PWD/build` on the host indeed. @@ -108,7 +108,11 @@ This command mounts the source directory on the host into `/paddle` in the conta Users can specify the following Docker build arguments with either "ON" or "OFF" value: - `WITH_GPU`: ***Required***. Generates NVIDIA CUDA GPU code and relies on CUDA libraries. - `WITH_AVX`: ***Required***. Set to "OFF" prevents from generating AVX instructions. If you don't know what is AVX, you might want to set "ON". -- `TEST`: ***Optional, default OFF***. Build unit tests and run them after building. +- `WITH_TEST`: ***Optional, default OFF***. Build unit tests binaries. Once you've built the unit tests, you can run these test manually by the following command: + ```bash + docker run -v $PWD:/paddle -e "WITH_GPU=OFF" -e "WITH_AVX=ON" paddle:dev sh -c "cd /paddle/build; make coverall" + ``` +- `RUN_TEST`: ***Optional, default OFF***. Run unit tests after building. You can't run unit tests without building it. ### Build the Production Docker Image diff --git a/paddle/scripts/docker/build.sh b/paddle/scripts/docker/build.sh index 8d50ced23fc7430edd23d380c9fa12b2cd200a39..a0da561dfe962b7a0a0515d4104940175ebdecad 100644 --- a/paddle/scripts/docker/build.sh +++ b/paddle/scripts/docker/build.sh @@ -33,10 +33,10 @@ cmake .. 
\ -DWITH_SWIG_PY=ON \ -DCUDNN_ROOT=/usr/ \ -DWITH_STYLE_CHECK=${WITH_STYLE_CHECK:-OFF} \ - -DWITH_COVERAGE=${TEST:-OFF} \ + -DON_COVERALLS=${WITH_TEST:-OFF} \ -DCMAKE_EXPORT_COMPILE_COMMANDS=ON make -j `nproc` -if [[ ${TEST:-OFF} == "ON" ]]; then +if [[ ${RUN_TEST:-OFF} == "ON" ]]; then make coveralls fi make install diff --git a/python/paddle/trainer/config_parser.py b/python/paddle/trainer/config_parser.py index 1394773b4ff12aa751b8659a4461f94ee706892e..77361f8bc751446d89d8a812f48d33cd3dffc665 100644 --- a/python/paddle/trainer/config_parser.py +++ b/python/paddle/trainer/config_parser.py @@ -1220,9 +1220,11 @@ def parse_image(image, input_layer_name, image_conf): def parse_norm(norm, input_layer_name, norm_conf): norm_conf.norm_type = norm.norm_type - config_assert(norm.norm_type in ['rnorm', 'cmrnorm-projection'], - "norm-type %s is not in [rnorm, 'cmrnorm-projection']" % - norm.norm_type) + config_assert( + norm.norm_type in + ['rnorm', 'cmrnorm-projection', 'cross-channel-norm'], + "norm-type %s is not in [rnorm, cmrnorm-projection, cross-channel-norm]" + % norm.norm_type) norm_conf.channels = norm.channels norm_conf.size = norm.size norm_conf.scale = norm.scale @@ -1898,6 +1900,9 @@ class NormLayer(LayerBase): norm_conf) self.set_cnn_layer(name, norm_conf.output_y, norm_conf.output_x, norm_conf.channels, False) + if norm_conf.norm_type == "cross-channel-norm": + self.create_input_parameter(0, norm_conf.channels, + [norm_conf.channels, 1]) @config_layer('pool') diff --git a/python/paddle/trainer_config_helpers/layers.py b/python/paddle/trainer_config_helpers/layers.py index b006eb46d99fd09c7bc31e5de41ebdb39659b663..8d2329292b5b8b408473c2e33fc43b2e586d89b6 100755 --- a/python/paddle/trainer_config_helpers/layers.py +++ b/python/paddle/trainer_config_helpers/layers.py @@ -112,6 +112,7 @@ __all__ = [ 'out_prod_layer', 'print_layer', 'priorbox_layer', + 'cross_channel_norm_layer', 'spp_layer', 'pad_layer', 'eos_layer', @@ -1008,6 +1009,46 @@ def priorbox_layer(input, size=size) +@wrap_name_default("cross_channel_norm") +def cross_channel_norm_layer(input, name=None, param_attr=None): + """ + Normalize a layer's output. This layer is necessary for ssd. + This layer applys normalize across the channels of each sample to + a conv layer's output and scale the output by a group of trainable + factors which dimensions equal to the channel's number. + + :param name: The Layer Name. + :type name: basestring + :param input: The input layer. + :type input: LayerOutput + :param param_attr: The Parameter Attribute|list. + :type param_attr: ParameterAttribute + :return: LayerOutput + """ + assert input.num_filters is not None + Layer( + name=name, + type=LayerType.NORM_LAYER, + inputs=[ + Input( + input.name, + norm=Norm( + norm_type="cross-channel-norm", + channels=input.num_filters, + size=input.size, + scale=0, + pow=0, + blocked=0), + **param_attr.attr) + ]) + return LayerOutput( + name, + LayerType.NORM_LAYER, + parents=input, + num_filters=input.num_filters, + size=input.size) + + @wrap_name_default("seq_pooling") @wrap_bias_attr_default(has_bias=False) @wrap_param_default(['pooling_type'], default_factory=lambda _: MaxPooling()) diff --git a/python/paddle/v2/dataset/cifar.py b/python/paddle/v2/dataset/cifar.py index d9f7a830ee60a331b55a1e218923e690103e1c5b..3a8b98b8f045b0eb58be69649486cbd0a571f118 100644 --- a/python/paddle/v2/dataset/cifar.py +++ b/python/paddle/v2/dataset/cifar.py @@ -20,7 +20,7 @@ TODO(yuyang18): Complete the comments. 
import cPickle import itertools import numpy -import paddle.v2.dataset.common +from common import download import tarfile __all__ = ['train100', 'test100', 'train10', 'test10'] @@ -55,23 +55,23 @@ def reader_creator(filename, sub_name): def train100(): return reader_creator( - paddle.v2.dataset.common.download(CIFAR100_URL, 'cifar', CIFAR100_MD5), - 'train') + download(CIFAR100_URL, 'cifar', CIFAR100_MD5), 'train') def test100(): - return reader_creator( - paddle.v2.dataset.common.download(CIFAR100_URL, 'cifar', CIFAR100_MD5), - 'test') + return reader_creator(download(CIFAR100_URL, 'cifar', CIFAR100_MD5), 'test') def train10(): return reader_creator( - paddle.v2.dataset.common.download(CIFAR10_URL, 'cifar', CIFAR10_MD5), - 'data_batch') + download(CIFAR10_URL, 'cifar', CIFAR10_MD5), 'data_batch') def test10(): return reader_creator( - paddle.v2.dataset.common.download(CIFAR10_URL, 'cifar', CIFAR10_MD5), - 'test_batch') + download(CIFAR10_URL, 'cifar', CIFAR10_MD5), 'test_batch') + + +def fetch(): + download(CIFAR10_URL, 'cifar', CIFAR10_MD5) + download(CIFAR100_URL, 'cifar', CIFAR100_MD5) diff --git a/python/paddle/v2/dataset/common.py b/python/paddle/v2/dataset/common.py index 3021b68ddb02ecaa874e21681796c0912ad4cc06..7021a6da05dec6be216534112c2df2586e73390f 100644 --- a/python/paddle/v2/dataset/common.py +++ b/python/paddle/v2/dataset/common.py @@ -17,6 +17,8 @@ import hashlib import os import shutil import sys +import importlib +import paddle.v2.dataset __all__ = ['DATA_HOME', 'download', 'md5file'] @@ -69,3 +71,13 @@ def dict_add(a_dict, ele): a_dict[ele] += 1 else: a_dict[ele] = 1 + + +def fetch_all(): + for module_name in filter(lambda x: not x.startswith("__"), + dir(paddle.v2.dataset)): + if "fetch" in dir( + importlib.import_module("paddle.v2.dataset.%s" % module_name)): + getattr( + importlib.import_module("paddle.v2.dataset.%s" % module_name), + "fetch")() diff --git a/python/paddle/v2/dataset/conll05.py b/python/paddle/v2/dataset/conll05.py index 9eab49ee39325c1c60fc511e0bd834e83aa987f0..f1b0ce16f21ad13d4564242c2359355236093032 100644 --- a/python/paddle/v2/dataset/conll05.py +++ b/python/paddle/v2/dataset/conll05.py @@ -196,3 +196,11 @@ def test(): words_name='conll05st-release/test.wsj/words/test.wsj.words.gz', props_name='conll05st-release/test.wsj/props/test.wsj.props.gz') return reader_creator(reader, word_dict, verb_dict, label_dict) + + +def fetch(): + download(WORDDICT_URL, 'conll05st', WORDDICT_MD5) + download(VERBDICT_URL, 'conll05st', VERBDICT_MD5) + download(TRGDICT_URL, 'conll05st', TRGDICT_MD5) + download(EMB_URL, 'conll05st', EMB_MD5) + download(DATA_URL, 'conll05st', DATA_MD5) diff --git a/python/paddle/v2/dataset/imdb.py b/python/paddle/v2/dataset/imdb.py index 76019d9f54020ff6f02c17eb6047cbd014a8ccf2..5284017ce08de8beb559f58fb6006639f40f5580 100644 --- a/python/paddle/v2/dataset/imdb.py +++ b/python/paddle/v2/dataset/imdb.py @@ -123,3 +123,7 @@ def test(word_idx): def word_dict(): return build_dict( re.compile("aclImdb/((train)|(test))/((pos)|(neg))/.*\.txt$"), 150) + + +def fetch(): + paddle.v2.dataset.common.download(URL, 'imdb', MD5) diff --git a/python/paddle/v2/dataset/imikolov.py b/python/paddle/v2/dataset/imikolov.py index 97c160f111d09d61eb860c7f02552e635f2400a7..2931d06e7eb65bde887c56a8bc20e7a9c5e4d4e4 100644 --- a/python/paddle/v2/dataset/imikolov.py +++ b/python/paddle/v2/dataset/imikolov.py @@ -89,3 +89,7 @@ def train(word_idx, n): def test(word_idx, n): return reader_creator('./simple-examples/data/ptb.valid.txt', word_idx, n) + + +def fetch(): + 
paddle.v2.dataset.common.download(URL, "imikolov", MD5) diff --git a/python/paddle/v2/dataset/mnist.py b/python/paddle/v2/dataset/mnist.py index 16f2fcb99de4cb1971a7375a97b5daa209ee95ef..48a39b5493a8004d6eb034498a797af9c662bd19 100644 --- a/python/paddle/v2/dataset/mnist.py +++ b/python/paddle/v2/dataset/mnist.py @@ -106,3 +106,10 @@ def test(): TEST_IMAGE_MD5), paddle.v2.dataset.common.download(TEST_LABEL_URL, 'mnist', TEST_LABEL_MD5), 100) + + +def fetch(): + paddle.v2.dataset.common.download(TRAIN_IMAGE_URL, 'mnist', TRAIN_IMAGE_MD5) + paddle.v2.dataset.common.download(TRAIN_LABEL_URL, 'mnist', TRAIN_LABEL_MD5) + paddle.v2.dataset.common.download(TEST_IMAGE_URL, 'mnist', TEST_IMAGE_MD5) + paddle.v2.dataset.common.download(TEST_LABEL_URL, 'mnist', TRAIN_LABEL_MD5) diff --git a/python/paddle/v2/dataset/movielens.py b/python/paddle/v2/dataset/movielens.py index 25fd8227da2f219d75c6b830e65627ecf35be453..e148ddeca0370cd76128a31ce3a4d488e9737d98 100644 --- a/python/paddle/v2/dataset/movielens.py +++ b/python/paddle/v2/dataset/movielens.py @@ -30,6 +30,9 @@ __all__ = [ age_table = [1, 18, 25, 35, 45, 50, 56] +URL = 'http://files.grouplens.org/datasets/movielens/ml-1m.zip' +MD5 = 'c4d9eecfca2ab87c1945afe126590906' + class MovieInfo(object): def __init__(self, index, categories, title): @@ -77,10 +80,7 @@ USER_INFO = None def __initialize_meta_info__(): - fn = download( - url='http://files.grouplens.org/datasets/movielens/ml-1m.zip', - module_name='movielens', - md5sum='c4d9eecfca2ab87c1945afe126590906') + fn = download(URL, "movielens", MD5) global MOVIE_INFO if MOVIE_INFO is None: pattern = re.compile(r'^(.*)\((\d+)\)$') @@ -205,5 +205,9 @@ def unittest(): print train_count, test_count +def fetch(): + download(URL, "movielens", MD5) + + if __name__ == '__main__': unittest() diff --git a/python/paddle/v2/dataset/sentiment.py b/python/paddle/v2/dataset/sentiment.py index 71689fd61b6b14a7b5072caff4e2fd48a7f74072..0eeb6d5affd8c280fb74edc82cf24bf418ca8ef9 100644 --- a/python/paddle/v2/dataset/sentiment.py +++ b/python/paddle/v2/dataset/sentiment.py @@ -125,3 +125,7 @@ def test(): """ data_set = load_sentiment_data() return reader_creator(data_set[NUM_TRAINING_INSTANCES:]) + + +def fetch(): + nltk.download('movie_reviews', download_dir=common.DATA_HOME) diff --git a/python/paddle/v2/dataset/uci_housing.py b/python/paddle/v2/dataset/uci_housing.py index 27f454b137e3a40febd19cf085e2f4034cc16b24..dab8620441c966b19d8218025f8d8fa5b40d1c2c 100644 --- a/python/paddle/v2/dataset/uci_housing.py +++ b/python/paddle/v2/dataset/uci_housing.py @@ -89,3 +89,7 @@ def test(): yield d[:-1], d[-1:] return reader + + +def fetch(): + download(URL, 'uci_housing', MD5) diff --git a/python/paddle/v2/dataset/wmt14.py b/python/paddle/v2/dataset/wmt14.py index c686870a497668517d1c78c11c616ad8a71a2980..ee63a93f5ad918b5bbc949ae6ba29082b3f6abd5 100644 --- a/python/paddle/v2/dataset/wmt14.py +++ b/python/paddle/v2/dataset/wmt14.py @@ -16,7 +16,7 @@ wmt14 dataset """ import tarfile -import paddle.v2.dataset.common +from paddle.v2.dataset.common import download __all__ = ['train', 'test', 'build_dict'] @@ -95,11 +95,13 @@ def reader_creator(tar_file, file_name, dict_size): def train(dict_size): return reader_creator( - paddle.v2.dataset.common.download(URL_TRAIN, 'wmt14', MD5_TRAIN), - 'train/train', dict_size) + download(URL_TRAIN, 'wmt14', MD5_TRAIN), 'train/train', dict_size) def test(dict_size): return reader_creator( - paddle.v2.dataset.common.download(URL_TRAIN, 'wmt14', MD5_TRAIN), - 'test/test', dict_size) + 
download(URL_TRAIN, 'wmt14', MD5_TRAIN), 'test/test', dict_size) + + +def fetch(): + download(URL_TRAIN, 'wmt14', MD5_TRAIN)
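
A minimal NumPy sketch of the computation `CrossChannelNormLayer` performs, for reference only; it assumes the same `(batchSize, channels * spatialDim)` layout implied by `dataDim / channels_` in the C++ code, and `eps` stands in for the 1e-6 added to `normBuffer_`:

```python
# Illustrative sketch, not part of this patch.
import numpy as np

def cross_channel_norm_forward(x, scale, channels, eps=1e-6):
    """x: (batch, channels * spatial_dim); scale: (channels,) trainable factors."""
    batch, data_dim = x.shape
    spatial_dim = data_dim // channels
    x = x.reshape(batch, channels, spatial_dim)
    # L2 norm across channels at every spatial location.
    norm = np.sqrt((x * x).sum(axis=1, keepdims=True) + eps)
    y = (x / norm) * scale.reshape(1, channels, 1)
    return y.reshape(batch, data_dim)
```

On the config side, the new helper would be wired up roughly as below; the `conv` input and the initializer values are placeholders (SSD conventionally initializes the scale to 20), not something this diff prescribes:

```python
norm = cross_channel_norm_layer(
    input=conv,  # a conv layer output with num_filters set
    param_attr=ParamAttr(initial_mean=20.0, initial_std=0.0))
```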