diff --git a/.travis.yml b/.travis.yml index 376c693602b56fe719decfeb41c217497e143e12..8c8c6699d3d9abddd65a3a224c2bceedc7d88348 100644 --- a/.travis.yml +++ b/.travis.yml @@ -38,7 +38,7 @@ before_install: # Paddle is using protobuf 3.1 currently. Protobuf 3.2 breaks the compatibility. So we specify the python # protobuf version. - pip install numpy wheel 'protobuf==3.1' sphinx==1.5.6 recommonmark sphinx-rtd-theme==0.1.9 virtualenv pre-commit requests==2.9.2 LinkChecker - - pip install rarfile + - pip install rarfile nltk==3.2.2 scipy==0.19.0 recordio matplotlib Pillow - curl https://glide.sh/get | bash - eval "$(GIMME_GO_VERSION=1.8.3 gimme)" - go get -u github.com/alecthomas/gometalinter diff --git a/doc/api/v2/config/layer.rst b/doc/api/v2/config/layer.rst index dc7c6d2e59726a235c791a5ae340509345ba10fa..cb330ea5e1b914587a725c9b90a33053f3fbbc3d 100644 --- a/doc/api/v2/config/layer.rst +++ b/doc/api/v2/config/layer.rst @@ -257,6 +257,11 @@ seq_concat .. autoclass:: paddle.v2.layer.seq_concat :noindex: +kmax_sequence_score +------------------- +.. autoclass:: paddle.v2.layer.kmax_sequence_score + :noindex: + sub_nested_seq -------------- .. autoclass:: paddle.v2.layer.sub_nested_seq diff --git a/doc/templates/conf.py.cn.in b/doc/templates/conf.py.cn.in index 95cad835b11816f4d2e256c2abd662a545a5bad2..673948dfe7928240817b552141ec9bc2f8a672b7 100644 --- a/doc/templates/conf.py.cn.in +++ b/doc/templates/conf.py.cn.in @@ -13,15 +13,11 @@ # serve to show the default. import sys import os, subprocess +sys.path.insert(0, os.path.abspath('@PROJ_ROOT@/python')) import shlex from recommonmark import parser, transform -try: - import py_paddle - import paddle - import paddle.v2 -except ImportError: - print("Must install paddle python package before generating documentation") - sys.exit(1) +import paddle +import paddle.v2 MarkdownParser = parser.CommonMarkParser AutoStructify = transform.AutoStructify diff --git a/doc/templates/conf.py.en.in b/doc/templates/conf.py.en.in index b477f0120c4fa0544012080b7cfb8572d3c44b04..b6b50b7dcd5647b50a13703160489323ed90a1b4 100644 --- a/doc/templates/conf.py.en.in +++ b/doc/templates/conf.py.en.in @@ -13,15 +13,11 @@ # serve to show the default. import sys import os, subprocess +sys.path.insert(0, os.path.abspath('@PROJ_ROOT@/python')) import shlex from recommonmark import parser, transform -try: - import py_paddle - import paddle - import paddle.v2 -except ImportError: - print("Must install paddle python package before generating documentation") - sys.exit(1) +import paddle +import paddle.v2 MarkdownParser = parser.CommonMarkParser diff --git a/paddle/framework/CMakeLists.txt b/paddle/framework/CMakeLists.txt index 1ebab6b8ab107089bd6f52bf99245b82e5b494e0..ad0785e2702e078220ece4cb760b99a1eeecd114 100644 --- a/paddle/framework/CMakeLists.txt +++ b/paddle/framework/CMakeLists.txt @@ -38,14 +38,15 @@ cc_test(backward_test SRCS backward_test.cc DEPS backward) if(WITH_PYTHON) cc_library(paddle_pybind SHARED - SRCS pybind.cc - DEPS pybind python backward - fc_op - sgd_op - add_op - mean_op - cross_entropy_op - gaussian_random_op - fill_zeros_like_op - recurrent_op) + SRCS pybind.cc + DEPS pybind python backward + fc_op + sgd_op + add_op + mean_op + cross_entropy_op + recurrent_op + uniform_random_op + gaussian_random_op + fill_zeros_like_op) endif(WITH_PYTHON) diff --git a/paddle/framework/pybind.cc b/paddle/framework/pybind.cc index 27ee47c2b55df774155b014c791826ea24335d4c..6b515c9bb178e250b098717b5c352f86fadbe34f 100644 --- a/paddle/framework/pybind.cc +++ b/paddle/framework/pybind.cc @@ -43,6 +43,8 @@ USE_OP(rowwise_add); USE_OP(fill_zeros_like); USE_OP_WITHOUT_KERNEL(recurrent_op); USE_OP(gaussian_random); +USE_OP(uniform_random); + namespace paddle { namespace framework { template diff --git a/paddle/gserver/layers/KmaxSeqScoreLayer.cpp b/paddle/gserver/layers/KmaxSeqScoreLayer.cpp new file mode 100644 index 0000000000000000000000000000000000000000..8ce591d4762466e1ed4b2970cb9cae9203bc0a2b --- /dev/null +++ b/paddle/gserver/layers/KmaxSeqScoreLayer.cpp @@ -0,0 +1,117 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "Layer.h" + +namespace paddle { + +class KmaxSeqScoreLayer : public Layer { +private: + MatrixPtr scores_; + size_t beamSize_; + void kmaxScorePerSeq(const real* score, + real* sortedRes, + const ICpuGpuVectorPtr seqStartPos); + +public: + explicit KmaxSeqScoreLayer(const LayerConfig& config) : Layer(config) {} + + bool init(const LayerMap& layerMap, + const ParameterMap& parameterMap) override; + + void forward(PassType passType) override; + void backward(const UpdateCallback& callback = nullptr) override; +}; + +REGISTER_LAYER(kmax_seq_score, KmaxSeqScoreLayer); + +bool KmaxSeqScoreLayer::init(const LayerMap& layerMap, + const ParameterMap& parameterMap) { + bool ret = Layer::init(layerMap, parameterMap); + CHECK_EQ(1U, inputLayers_.size()); + + beamSize_ = config_.beam_size(); + CHECK_GE(beamSize_, 1U); + + setNeedSequenceInfo(false); + setNeedGradient(false); + return ret; +} + +void KmaxSeqScoreLayer::kmaxScorePerSeq(const real* scores, + real* sortedIds, + const ICpuGpuVectorPtr seqStartPos) { + int* starts = seqStartPos->getMutableData(false); + std::vector indices; + for (size_t i = 0; i < seqStartPos->getSize() - 1; ++i) { + int seqLen = starts[i + 1] - starts[i]; + int k = std::min(static_cast(beamSize_), seqLen); + + indices.resize(seqLen, 0); + std::iota(begin(indices), end(indices), 0.); + std::vector tmpScore(scores + starts[i], scores + starts[i + 1]); + std::partial_sort( + begin(indices), + begin(indices) + k, + end(indices), + [&](size_t a, size_t b) { return tmpScore[a] > tmpScore[b]; }); + memcpy(sortedIds + (i * beamSize_), indices.data(), k * sizeof(real)); + } +} + +void KmaxSeqScoreLayer::forward(PassType passType) { + Layer::forward(passType); + + const Argument& input = getInput(0); + const MatrixPtr inputScore = getInputValue(0); + + CHECK(input.hasSeq() || input.hasSubseq()) + << "input of " << getName() + << " must be a sequence or a nested sequence."; + CHECK_EQ(input.value->getWidth(), 1UL) + << "input of " << getName() + << " is score over a sequence or a nested sequence, so its width " + << " must be 1."; + + if (useGpu_) { + // this Layer runs only in CPU, if the model is runing on GPU, + // then copy the input to this layer from GPU to CPU. + Matrix::resizeOrCreate(scores_, + inputScore->getHeight(), + 1, + false /* trans */, + false /* useGpu */); + scores_->copyFrom(*inputScore); + } else { + scores_ = inputScore; + } + + Matrix::resizeOrCreate( + output_.value, + input.hasSubseq() ? input.getNumSubSequences() : input.getNumSequences(), + beamSize_, + false, + false); + output_.value->one(); + output_.value->mulScalar(-1.); + + kmaxScorePerSeq(scores_->getData(), + output_.value->getData(), + input.hasSubseq() ? input.subSequenceStartPositions + : input.sequenceStartPositions); +} + +void KmaxSeqScoreLayer::backward(const UpdateCallback& callback) {} + +} // namespace paddle diff --git a/paddle/gserver/tests/CMakeLists.txt b/paddle/gserver/tests/CMakeLists.txt index 5511ab6b8bb05108e76cc0913264d864d2fecf5b..209d0ab9c8d7e8463c8636b1412622a94f359fb1 100644 --- a/paddle/gserver/tests/CMakeLists.txt +++ b/paddle/gserver/tests/CMakeLists.txt @@ -66,6 +66,16 @@ add_unittest_without_exec(test_BatchNorm add_test(NAME test_BatchNorm COMMAND test_BatchNorm) + + +################# test_KmaxSeqScore ####################### +add_unittest_without_exec(test_KmaxSeqScore + test_KmaxSeqScore.cpp + LayerGradUtil.cpp) + +add_test(NAME test_KmaxSeqScore + COMMAND test_KmaxSeqScore) + ################## test_Evaluator ####################### add_unittest(test_Evaluator test_Evaluator.cpp) diff --git a/paddle/gserver/tests/test_KmaxSeqScore.cpp b/paddle/gserver/tests/test_KmaxSeqScore.cpp new file mode 100644 index 0000000000000000000000000000000000000000..f958b4974d45ef65f8f374148a31ad3a6ce7632f --- /dev/null +++ b/paddle/gserver/tests/test_KmaxSeqScore.cpp @@ -0,0 +1,160 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include +#include +#include +#include +#include "ModelConfig.pb.h" +#include "paddle/gserver/layers/DataLayer.h" +#include "paddle/trainer/Trainer.h" +#include "paddle/utils/GlobalConstants.h" + +#include "LayerGradUtil.h" +#include "paddle/testing/TestUtil.h" + +using namespace paddle; // NOLINT +using namespace std; // NOLINT + +DECLARE_bool(use_gpu); +DECLARE_int32(gpu_id); +DECLARE_bool(thread_local_rand_use_global_seed); + +vector randSampling(int range, int n) { + CHECK_GE(range, n); + vector num(range); + iota(begin(num), end(num), 0); + if (range == n) return num; + + random_shuffle(begin(num), end(num)); + num.resize(n); + return num; +} + +void genRandomSeqInfo(vector& seqStartPosition, + vector& subSeqStartPosition) { + const int maxSeqNum = 100; + // generate random start position information + int seqNum = 1 + (rand() % maxSeqNum); + seqStartPosition.resize(seqNum + 1, 0); + subSeqStartPosition.resize(1, 0); + + for (int i = 0; i < seqNum; ++i) { + int subSeqLen = 1 + (rand() % maxSeqNum); + for (int j = 0; j < subSeqLen; ++j) + subSeqStartPosition.push_back(subSeqStartPosition.back() + subSeqLen); + seqStartPosition[i + 1] = subSeqStartPosition.back(); + } +} + +void genRandomGroundTruth(real* values, + vector>& groundTruth, + vector& startPos, + size_t beamSize) { + groundTruth.resize(startPos.size() - 1, vector(beamSize, -1)); + for (size_t i = 0; i < startPos.size() - 1; ++i) { + int seqLen = startPos[i + 1] - startPos[i]; + vector pos = + randSampling(seqLen, min(static_cast(beamSize), seqLen)); + for (size_t j = 0; j < pos.size(); ++j) { + groundTruth[i][j] = pos[j]; + values[startPos[i] + pos[j]] = 1.; + } + } +} + +void checkLayerOut(vector> groundTruth, + real* layerOut, + size_t beamSize) { + for (size_t i = 0; i < groundTruth.size(); ++i) { + int begPos = i * beamSize; + vector tmp(layerOut + begPos, layerOut + begPos + beamSize); + sort(begin(tmp), end(tmp)); + sort(begin(groundTruth[i]), end(groundTruth[i])); + for (size_t j = 0; j < beamSize; ++j) CHECK_EQ(tmp[j], groundTruth[i][j]); + } +} + +TEST(Layer, kmaxSeqScoreLayer) { + const size_t maxBeamSize = 100; + int beamSize = 1 + (rand() % maxBeamSize); + + vector seqStartPosition; + vector subSeqStartPosition; + genRandomSeqInfo(seqStartPosition, subSeqStartPosition); + MatrixPtr inValue = + Matrix::create(subSeqStartPosition.back(), 1, false, false); + + for (auto hasSubseq : {false, true}) { + vector> groundTruth; + inValue->randomizeUniform(); + genRandomGroundTruth(inValue->getData(), + groundTruth, + hasSubseq ? subSeqStartPosition : seqStartPosition, + beamSize); + + for (auto useGpu : {false, true}) { + TestConfig config; + config.layerConfig.set_type("kmax_seq_score"); + config.layerConfig.set_beam_size(beamSize); + + if (hasSubseq) { + config.inputDefs.push_back({INPUT_SELF_DEFINE_DATA, + "scores", + inValue, + seqStartPosition, + subSeqStartPosition}); + } else { + config.inputDefs.push_back( + {INPUT_SELF_DEFINE_DATA, "scores", inValue, seqStartPosition}); + } + config.layerConfig.add_inputs(); + + // data layer initialize + std::vector dataLayers; + LayerMap layerMap; + vector datas; + initDataLayer( + config, + &dataLayers, + &datas, + &layerMap, + "kmax_seq_score", + 100 /* actually this parameter is unused in self-defined input*/, + false, + useGpu); + // test layer initialize + std::vector parameters; + LayerPtr kmaxSeqScoreLayer; + FLAGS_use_gpu = useGpu; + initTestLayer(config, &layerMap, ¶meters, &kmaxSeqScoreLayer); + kmaxSeqScoreLayer->forward(PASS_TRAIN); + + const MatrixPtr outValue = kmaxSeqScoreLayer->getOutputValue(); + CHECK_EQ(outValue->getHeight(), + hasSubseq ? subSeqStartPosition.size() - 1 + : seqStartPosition.size() - 1); + CHECK_EQ(outValue->getWidth(), beamSize); + checkLayerOut(groundTruth, outValue->getData(), beamSize); + } + } +} + +int main(int argc, char** argv) { + testing::InitGoogleTest(&argc, argv); + initMain(argc, argv); + FLAGS_thread_local_rand_use_global_seed = true; + srand((size_t)(time(NULL))); + return RUN_ALL_TESTS(); +} diff --git a/paddle/operators/CMakeLists.txt b/paddle/operators/CMakeLists.txt index 3b60df021868655b94b164456a866f9f89c5b8fc..e103049f1968732ab5990b4d267a39a6c722aa08 100644 --- a/paddle/operators/CMakeLists.txt +++ b/paddle/operators/CMakeLists.txt @@ -67,3 +67,5 @@ op_library(fc_op op_library(recurrent_op SRCS recurrent_op.cc rnn/recurrent_op_utils.cc DEPS op_desc tensor op_registry operator net_op) cc_test(recurrent_op_test SRCS recurrent_op_test.cc DEPS recurrent_op gtest mul_op add_op) +op_library(uniform_random_op + SRCS uniform_random_op.cc uniform_random_op.cu) diff --git a/paddle/operators/uniform_random_op.cc b/paddle/operators/uniform_random_op.cc new file mode 100644 index 0000000000000000000000000000000000000000..405b84b76d2e24db25d2ff16e99495f2f132ef09 --- /dev/null +++ b/paddle/operators/uniform_random_op.cc @@ -0,0 +1,84 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. */ + +#include +#include +#include "paddle/framework/op_registry.h" +#include "paddle/framework/operator.h" + +namespace paddle { +namespace operators { + +// It seems that Eigen::Tensor::random in GPU will SEGFAULT. +// Use std::random and thrust::random(thrust is a std library in CUDA) to +// implement uniform random. +template +class CPUUniformRandomKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& context) const override { + auto* tensor = context.Output(0); + T* data = tensor->mutable_data(context.GetPlace()); + unsigned int seed = + static_cast(context.op_.GetAttr("seed")); + std::minstd_rand engine; + if (seed == 0) { + seed = std::random_device()(); + } + engine.seed(seed); + std::uniform_real_distribution dist( + static_cast(context.op_.GetAttr("min")), + static_cast(context.op_.GetAttr("max"))); + for (ssize_t i = 0; i < framework::product(tensor->dims()); ++i) { + data[i] = dist(engine); + } + } +}; + +class UniformRandomOp : public framework::OperatorWithKernel { + protected: + void InferShape(const framework::InferShapeContext& ctx) const override { + PADDLE_ENFORCE(GetAttr("min") < GetAttr("max"), + "uniform_random's min must less then max"); + auto* tensor = ctx.Output(0); + auto dims = GetAttr>("dims"); + tensor->Resize(framework::make_ddim(dims)); + } +}; + +class UniformRandomOpMaker : public framework::OpProtoAndCheckerMaker { + public: + UniformRandomOpMaker(framework::OpProto* proto, + framework::OpAttrChecker* op_checker) + : framework::OpProtoAndCheckerMaker(proto, op_checker) { + AddOutput("Out", "The output tensor of uniform random op"); + AddComment(R"DOC(Uniform random operator. + +Used to initialize tensor with uniform random generator. +)DOC"); + AddAttr>("dims", "the dimension of random tensor"); + AddAttr("min", "Minimum value of uniform random").SetDefault(-1.0f); + AddAttr("max", "Maximun value of uniform random").SetDefault(1.0f); + AddAttr("seed", + "Random seed of uniform random. " + "0 means generate a seed by system") + .SetDefault(0); + } +}; +} // namespace operators +} // namespace paddle + +REGISTER_OP(uniform_random, paddle::operators::UniformRandomOp, + paddle::operators::UniformRandomOpMaker); +REGISTER_OP_CPU_KERNEL(uniform_random, + paddle::operators::CPUUniformRandomKernel); diff --git a/paddle/operators/uniform_random_op.cu b/paddle/operators/uniform_random_op.cu new file mode 100644 index 0000000000000000000000000000000000000000..f1a63e52ec0d3d46a505a89d7d7916bf93a58221 --- /dev/null +++ b/paddle/operators/uniform_random_op.cu @@ -0,0 +1,70 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. */ + +#include +#include +#include +#include +#include "paddle/framework/op_registry.h" +#include "paddle/framework/operator.h" + +namespace paddle { +namespace operators { + +template +struct UniformGenerator { + T min_, max_; + unsigned int seed_; + + __host__ __device__ UniformGenerator(T min, T max, int seed) + : min_(min), max_(max), seed_(seed) {} + + __host__ __device__ T operator()(const unsigned int n) const { + thrust::minstd_rand rng; + rng.seed(seed_); + thrust::uniform_real_distribution dist(min_, max_); + rng.discard(n); + return dist(rng); + } +}; + +// It seems that Eigen::Tensor::random in GPU will SEGFAULT. +// Use std::random and thrust::random(thrust is a std library in CUDA) to +// implement uniform random. +template +class GPUUniformRandomKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& context) const override { + auto* tensor = context.Output(0); + T* data = tensor->mutable_data(context.GetPlace()); + unsigned int seed = + static_cast(context.op_.GetAttr("seed")); + if (seed == 0) { + seed = std::random_device()(); + } + T min = static_cast(context.op_.GetAttr("min")); + T max = static_cast(context.op_.GetAttr("max")); + thrust::counting_iterator index_sequence_begin(0); + ssize_t N = framework::product(tensor->dims()); + thrust::transform(index_sequence_begin, index_sequence_begin + N, + thrust::device_ptr(data), + UniformGenerator(min, max, seed)); + } +}; + +} // namespace operators +} // namespace paddle + +REGISTER_OP_GPU_KERNEL(uniform_random, + paddle::operators::GPUUniformRandomKernel); diff --git a/paddle/scripts/travis/build_doc.sh b/paddle/scripts/travis/build_doc.sh index 33fb5d84e2701c163b5d1b1bb3362ee81ebb34ea..dfcff38302703066e868c60e213f0f7cbc55a31e 100755 --- a/paddle/scripts/travis/build_doc.sh +++ b/paddle/scripts/travis/build_doc.sh @@ -5,15 +5,9 @@ set -e mkdir -p $TRAVIS_BUILD_DIR/build cd $TRAVIS_BUILD_DIR/build -# Compile paddle binaries first -cmake .. -DCMAKE_BUILD_TYPE=Debug -DWITH_GPU=OFF -DWITH_DOC=OFF -DWITH_MKLDNN=OFF -DWITH_MKLML=OFF -DWITH_GOLANG=ON -DWITH_STYLE_CHECK=OFF - -mkdir output -make -j `nproc` -find .. -name '*whl' | xargs pip install # install all wheels. -rm -rf * # Compile Documentation only. cmake .. -DCMAKE_BUILD_TYPE=Debug -DWITH_GPU=OFF -DWITH_MKLDNN=OFF -DWITH_MKLML=OFF -DWITH_DOC=ON +make -j `nproc` gen_proto_py make -j `nproc` paddle_docs paddle_docs_cn # check websites for broken links @@ -35,6 +29,7 @@ TARGET_BRANCH="gh-pages" SOURCE_BRANCH="master" # Clone the repo to output directory +mkdir output git clone $REPO output cd output diff --git a/proto/CMakeLists.txt b/proto/CMakeLists.txt index 18584cafe7971bad281b498908c54780250791b7..e1cea8bd0de5394020a498725485cea025512e48 100644 --- a/proto/CMakeLists.txt +++ b/proto/CMakeLists.txt @@ -17,7 +17,7 @@ foreach(filename ${proto_filenames}) COMMAND ${PROTOBUF_PROTOC_EXECUTABLE} ARGS "--python_out=${PROJ_ROOT}/python/paddle/proto" "-I" ${CMAKE_CURRENT_SOURCE_DIR} ${ABS_FIL} - DEPENDS ${ABS_FIL} ${external_project_dependencies}) + DEPENDS ${ABS_FIL} protoc) endforeach() add_custom_target(gen_proto_py ALL DEPENDS ${PROTO_GEN_PY}) diff --git a/python/paddle/trainer/config_parser.py b/python/paddle/trainer/config_parser.py index c8fc49e20da2e212330e0cccc10fbeb4e25b87a8..b7b696ef0c13e1bae2e910e08d1a1ea3e45cd5d5 100644 --- a/python/paddle/trainer/config_parser.py +++ b/python/paddle/trainer/config_parser.py @@ -3248,6 +3248,16 @@ class CTCLayer(LayerBase): config_assert(len(self.inputs) == 2, 'CTCLayer must have 2 inputs') +@config_layer('kmax_seq_score') +class KmaxSeqScoreLayer(LayerBase): + def __init__(self, name, inputs, beam_size, **xargs): + super(KmaxSeqScoreLayer, self).__init__( + name, 'kmax_seq_score', 0, inputs=inputs, **xargs) + config_assert( + len(self.inputs) == 1, 'KmaxSeqScoreLayer has only one input.') + self.config.beam_size = beam_size + + @config_layer('warp_ctc') class WarpCTCLayer(LayerBase): def __init__(self, diff --git a/python/paddle/trainer_config_helpers/layers.py b/python/paddle/trainer_config_helpers/layers.py index 2c7cebc359173da5bf713e251f2773e7a04c76ed..1bc55c869601551aff5fc0311458f906385522d2 100755 --- a/python/paddle/trainer_config_helpers/layers.py +++ b/python/paddle/trainer_config_helpers/layers.py @@ -132,6 +132,7 @@ __all__ = [ 'sub_nested_seq_layer', 'clip_layer', 'slice_projection', + 'kmax_sequence_score_layer', ] @@ -228,6 +229,8 @@ class LayerType(object): SUB_NESTED_SEQ = 'sub_nested_seq' CLIP_LAYER = 'clip' + KMAX_SEQ_SCORE = 'kmax_seq_score' + @staticmethod def is_layer_type(type_name): """ @@ -6158,7 +6161,8 @@ def clip_layer(input, min, max, name=None): :type min: double :param max: The upper threshold for clipping. :type max: double - :return: LayerOutput + :return: LayerOutput object. + :rtype: LayerOutput """ Layer( name=name, @@ -6168,3 +6172,41 @@ def clip_layer(input, min, max, name=None): max=max) return LayerOutput( name, LayerType.CLIP_LAYER, parents=[input], size=input.size) + + +@wrap_name_default() +@layer_support() +def kmax_sequence_score_layer(input, name=None, beam_size=1): + """ + This layer accepts one input which are scores over a sequence or a nested + sequence, and returns indices of beam_size sequences with highest scores. + + .. code-block:: python + + kmax_indices = kmax_sequence_score_layer(input=input_layer, beam_size) + + + :param name: The Layer Name. + :type name: basestring + :param input: The input layer. It stores scores over a sequence or a nested + sequence and its size must be 1. + :type input: LayerOutput. + :param beam_size: squence indices with top beam_size scores are returned. + :type beam_size: double + :return: LayerOutput object. + :rtype: LayerOutput + """ + assert isinstance(input, LayerOutput), ("kmax_sequence_score_layer " + "accepts only one input.") + assert input.size == 1, ( + "input of kmax_sequence_score_layer is a score" + "over a sequence or a nested sequence, so its width must be 1.") + + Layer( + name=name, + type=LayerType.KMAX_SEQ_SCORE, + inputs=[input.name], + beam_size=beam_size) + + return LayerOutput( + name, LayerType.KMAX_SEQ_SCORE, parents=[input], size=input.size) diff --git a/python/paddle/trainer_config_helpers/tests/configs/file_list.sh b/python/paddle/trainer_config_helpers/tests/configs/file_list.sh index a627369275112aad7cbf745c78c4e2df2804e0f5..a61beb871ad064c617fa141451afcb2a5ac64854 100755 --- a/python/paddle/trainer_config_helpers/tests/configs/file_list.sh +++ b/python/paddle/trainer_config_helpers/tests/configs/file_list.sh @@ -8,6 +8,6 @@ test_spp_layer test_bilinear_interp test_maxout test_bi_grumemory math_ops test_seq_concat_reshape test_pad test_smooth_l1 test_multiplex_layer test_prelu_layer test_row_conv test_detection_output_layer test_multibox_loss_layer test_recursive_topology test_gated_unit_layer test_clip_layer test_row_l2_norm_layer -test_seq_select_layers) +test_kmax_seq_socre_layer test_seq_select_layers) export whole_configs=(test_split_datasource) diff --git a/python/paddle/trainer_config_helpers/tests/configs/protostr/test_kmax_seq_socre_layer.protostr b/python/paddle/trainer_config_helpers/tests/configs/protostr/test_kmax_seq_socre_layer.protostr new file mode 100644 index 0000000000000000000000000000000000000000..81bd71f68eb3f2c04ccd46ee3b77a07543395c60 --- /dev/null +++ b/python/paddle/trainer_config_helpers/tests/configs/protostr/test_kmax_seq_socre_layer.protostr @@ -0,0 +1,66 @@ +type: "nn" +layers { + name: "input" + type: "data" + size: 300 + active_type: "" +} +layers { + name: "data" + type: "data" + size: 128 + active_type: "" +} +layers { + name: "__fc_layer_0__" + type: "fc" + size: 1 + active_type: "exponential" + inputs { + input_layer_name: "data" + input_parameter_name: "___fc_layer_0__.w0" + } + bias_parameter_name: "___fc_layer_0__.wbias" +} +layers { + name: "__kmax_sequence_score_layer_0__" + type: "kmax_seq_score" + active_type: "" + inputs { + input_layer_name: "__fc_layer_0__" + } + beam_size: 5 +} +parameters { + name: "___fc_layer_0__.w0" + size: 128 + initial_mean: 0.0 + initial_std: 0.0883883476483 + dims: 128 + dims: 1 + initial_strategy: 0 + initial_smart: true +} +parameters { + name: "___fc_layer_0__.wbias" + size: 1 + initial_mean: 0.0 + initial_std: 0.0 + dims: 1 + dims: 1 + initial_strategy: 0 + initial_smart: false +} +input_layer_names: "data" +output_layer_names: "__kmax_sequence_score_layer_0__" +sub_models { + name: "root" + layer_names: "input" + layer_names: "data" + layer_names: "__fc_layer_0__" + layer_names: "__kmax_sequence_score_layer_0__" + input_layer_names: "data" + output_layer_names: "__kmax_sequence_score_layer_0__" + is_recurrent_layer_group: false +} + diff --git a/python/paddle/trainer_config_helpers/tests/configs/test_kmax_seq_socre_layer.py b/python/paddle/trainer_config_helpers/tests/configs/test_kmax_seq_socre_layer.py new file mode 100644 index 0000000000000000000000000000000000000000..d245c5a41c793e1f02f306bfe64071bd9885906e --- /dev/null +++ b/python/paddle/trainer_config_helpers/tests/configs/test_kmax_seq_socre_layer.py @@ -0,0 +1,11 @@ +#!/usr/bin/env python +#coding=utf-8 +from paddle.trainer_config_helpers import * + +data = data_layer(name='input', size=300) + +data = data_layer(name="data", size=128) +scores = fc_layer(input=data, size=1, act=ExpActivation()) +kmax_seq_id = kmax_sequence_score_layer(input=scores, beam_size=5) + +outputs(kmax_seq_id) diff --git a/python/paddle/v2/framework/tests/CMakeLists.txt b/python/paddle/v2/framework/tests/CMakeLists.txt index 8b9f1b1b1110eeac504dff8d1b0a77c045c007e5..cbe9369118255aec1fde808125a0ff61ff0f0f56 100644 --- a/python/paddle/v2/framework/tests/CMakeLists.txt +++ b/python/paddle/v2/framework/tests/CMakeLists.txt @@ -25,3 +25,4 @@ py_test(test_op_creation_methods SRCS test_op_creation_methods.py) py_test(test_operator SRCS test_operator.py) py_test(test_gaussian_random_op SRCS test_gaussian_random_op.py) +py_test(test_uniform_random_op SRCS test_uniform_random_op.py) diff --git a/python/paddle/v2/framework/tests/test_uniform_random_op.py b/python/paddle/v2/framework/tests/test_uniform_random_op.py new file mode 100644 index 0000000000000000000000000000000000000000..c3d2bb44da3977c0899b2609a8efe15b7e1789f2 --- /dev/null +++ b/python/paddle/v2/framework/tests/test_uniform_random_op.py @@ -0,0 +1,35 @@ +import unittest +from paddle.v2.framework.op import Operator +import paddle.v2.framework.core as core +import numpy + + +class UniformRandomTest(unittest.TestCase): + def test_uniform_random_cpu(self): + self.uniform_random_test(place=core.CPUPlace()) + + def test_uniform_random_gpu(self): + if core.is_compile_gpu(): + self.uniform_random_test(place=core.GPUPlace(0)) + + def uniform_random_test(self, place): + scope = core.Scope() + scope.new_var("X").get_tensor() + + op = Operator( + "uniform_random", + Out="X", + dims=[1000, 784], + min=-5.0, + max=10.0, + seed=10) + + op.infer_shape(scope) + ctx = core.DeviceContext.create(place) + op.run(scope, ctx) + tensor = numpy.array(scope.find_var("X").get_tensor()) + self.assertAlmostEqual(tensor.mean(), 2.5, delta=0.1) + + +if __name__ == '__main__': + unittest.main()