diff --git a/paddle/gserver/layers/SequenceLastInstanceLayer.cpp b/paddle/gserver/layers/SequenceLastInstanceLayer.cpp index c70c2b74211814559b8982aae94eed92258444e8..d29e981ad66a59e3606178834c701df908ec2221 100644 --- a/paddle/gserver/layers/SequenceLastInstanceLayer.cpp +++ b/paddle/gserver/layers/SequenceLastInstanceLayer.cpp @@ -40,7 +40,6 @@ class SequenceLastInstanceLayer : public SequencePoolLayer { protected: MatrixPtr tmpSrc_; MatrixPtr tmpDest_; - bool select_first_; std::vector insId_; public: @@ -59,7 +58,7 @@ REGISTER_LAYER(seqlastins, SequenceLastInstanceLayer); bool SequenceLastInstanceLayer::init(const LayerMap& layerMap, const ParameterMap& parameterMap) { SequencePoolLayer::init(layerMap, parameterMap); - select_first_ = config_.select_first(); + reversed_ = config_.select_first(); tmpSrc_ = Matrix::create(nullptr, /* height= */ 1, 1, /* trans= */ false, useGpu_); @@ -83,7 +82,7 @@ void SequenceLastInstanceLayer::forward(PassType passType) { insId_.clear(); for (size_t seqId = 0; seqId < newBatchSize_; ++seqId) { - int insId = select_first_ ? starts[seqId] : starts[seqId + 1] - 1; + int insId = reversed_ ? starts[seqId] : starts[seqId + 1] - 1; insId_.push_back(insId); outputValue->subMatrix(seqId, 1, tmpDest_) diff --git a/paddle/gserver/layers/SequencePoolLayer.cpp b/paddle/gserver/layers/SequencePoolLayer.cpp index f853905103a0e5814d84b04a6dd0eb5ca6beb8d8..8c49502011582b534a2ba4113ffeffaa2f06a51c 100644 --- a/paddle/gserver/layers/SequencePoolLayer.cpp +++ b/paddle/gserver/layers/SequencePoolLayer.cpp @@ -68,8 +68,9 @@ void SequencePoolLayer::forward(PassType passType) { } if (stride_ > 0) { CHECK_EQ(input.hasSubseq(), 0UL) - << "sequence stride pooling is not suitable for hasSubseq now"; - output_.poolSequenceWithStride(input, stride_, &stridePositions_); + << "sequence stride pooling is invalid for hasSubseq now"; + output_.poolSequenceWithStride( + input, stride_, &stridePositions_, reversed_); newBatchSize_ = stridePositions_->getSize() - 1; } diff --git a/paddle/gserver/layers/SequencePoolLayer.h b/paddle/gserver/layers/SequencePoolLayer.h index 92d7a841f0c73421e26e5882241f2b0d0e2fba50..ff67c0ccadd20de5ec6a9b3a85c536a09c753873 100644 --- a/paddle/gserver/layers/SequencePoolLayer.h +++ b/paddle/gserver/layers/SequencePoolLayer.h @@ -49,6 +49,8 @@ protected: int stride_; // store the start position of each stride window IVectorPtr stridePositions_; + // Whether it is reversed sequence + bool reversed_ = false; public: explicit SequencePoolLayer(const LayerConfig& config) : Layer(config) {} diff --git a/paddle/parameter/Argument.cpp b/paddle/parameter/Argument.cpp index 3cc637587bc28ce02218c38489f48df352b1e574..afbda8bdc403f205f918cdf77388361687b568b9 100644 --- a/paddle/parameter/Argument.cpp +++ b/paddle/parameter/Argument.cpp @@ -561,11 +561,13 @@ void Argument::degradeSequence(const Argument& input) { void Argument::poolSequenceWithStride(const Argument& input, size_t stride, - IVectorPtr* stridePostions) { + IVectorPtr* stridePostions, + bool reversed) { /* * If input.sequenceStartPositions = [0, 9, 14, 17, 30] and stride = 5, - * then sequenceStartPositions = [0, 2, 3, 4, 7], - * and stridePostions = [0, 5, 9, 14, 17, 22, 27, 30] + * then sequenceStartPositions = [0, 2, 3, 4, 7]. + * If reversed = false, stridePostions = [0, 5, 9, 14, 17, 22, 27, 30]; + * else reversed = true, stridePostions = [0, 4, 9, 14, 17, 20, 25, 30] */ CHECK(input.sequenceStartPositions); CHECK_EQ(input.hasSubseq(), 0UL); @@ -584,14 +586,13 @@ void Argument::poolSequenceWithStride(const Argument& input, if (seqLength == 0) { // empty sequence tgtBuf[seqId + 1] = tgtBuf[seqId]; - } else if (seqLength < stride) { - tgtBuf[seqId + 1] = tgtBuf[seqId] + 1; } else { - tgtBuf[seqId + 1] = tgtBuf[seqId] + ceil((float)seqLength / stride); - int size = - (seqLength % stride) ? seqLength / stride : seqLength / stride - 1; - for (int i = 0; i < size; i++) { - stridePos.emplace_back(stridePos.back() + stride); + int size = ceil((float)seqLength / stride); + tgtBuf[seqId + 1] = tgtBuf[seqId] + size; + for (int i = 0; i < size - 1; i++) { + int cur = reversed ? starts[seqId + 1] - (size - 1 - i) * stride + : stridePos.back() + stride; + stridePos.emplace_back(cur); } } } diff --git a/paddle/parameter/Argument.h b/paddle/parameter/Argument.h index 95ea90ffc2a604046252add36b0bb2e493b6050f..49a0660ccf155f24f70788f54fe0e42f718b6169 100644 --- a/paddle/parameter/Argument.h +++ b/paddle/parameter/Argument.h @@ -298,7 +298,8 @@ struct Argument { */ void poolSequenceWithStride(const Argument& input, size_t stride, - IVectorPtr* stridePositions); + IVectorPtr* stridePositions, + bool reversed = false); /** * @brief getValueString will return the argument's output in string. There * are several kinds of output. The keys of output dictionary are 'value', diff --git a/paddle/parameter/tests/test_argument.cpp b/paddle/parameter/tests/test_argument.cpp index 692bbada10d03f87e591e6e15a5fa0fb0569c7fd..81fe4ee397351a013c8616ad08fb8cb4b8dae4d0 100644 --- a/paddle/parameter/tests/test_argument.cpp +++ b/paddle/parameter/tests/test_argument.cpp @@ -27,20 +27,26 @@ TEST(Argument, poolSequenceWithStride) { inStart[3] = 17; inStart[4] = 30; - IVectorPtr stridePositions; - output.poolSequenceWithStride(input, 5 /* stride */, &stridePositions); - - const int* outStart = output.sequenceStartPositions->getData(false); - CHECK_EQ(outStart[0], 0); - CHECK_EQ(outStart[1], 2); - CHECK_EQ(outStart[2], 3); - CHECK_EQ(outStart[3], 4); - CHECK_EQ(outStart[4], 7); - - CHECK_EQ(stridePositions->getSize(), 8); int strideResult[] = {0, 5, 9, 14, 17, 22, 27, 30}; - for (int i = 0; i < 8; i++) { - CHECK_EQ(stridePositions->getData()[i], strideResult[i]); + int strideResultReversed[] = {0, 4, 9, 14, 17, 20, 25, 30}; + + for (auto reversed : {false, true}) { + IVectorPtr stridePositions; + output.poolSequenceWithStride( + input, 5 /* stride */, &stridePositions, reversed); + + const int* outStart = output.sequenceStartPositions->getData(false); + CHECK_EQ(outStart[0], 0); + CHECK_EQ(outStart[1], 2); + CHECK_EQ(outStart[2], 3); + CHECK_EQ(outStart[3], 4); + CHECK_EQ(outStart[4], 7); + + CHECK_EQ(stridePositions->getSize(), 8); + auto result = reversed ? strideResultReversed : strideResult; + for (int i = 0; i < 8; i++) { + CHECK_EQ(stridePositions->getData()[i], result[i]); + } } } diff --git a/python/paddle/trainer/config_parser.py b/python/paddle/trainer/config_parser.py index 1a6d1c512d3931a838b413152b436d1294d91ca1..dc89419c40f8d527a3de0dc90ede0397f6f650c5 100644 --- a/python/paddle/trainer/config_parser.py +++ b/python/paddle/trainer/config_parser.py @@ -2497,7 +2497,7 @@ class SequenceLastInstanceLayer(LayerBase): config_assert( len(inputs) == 1, 'SequenceLastInstanceLayer must have 1 input') if trans_type == 'seq': - config_assert(stride == -1, 'subseq do not support stride window') + config_assert(stride == -1, 'subseq does not support stride window') self.config.trans_type = trans_type self.config.seq_pool_stride = stride self.set_layer_size(self.get_input_layer(0).size)