提交 d369577f 编写于 作者: L Luo Tao

add reversed poolSequenceWithStride

上级 08d6622d
...@@ -40,7 +40,6 @@ class SequenceLastInstanceLayer : public SequencePoolLayer { ...@@ -40,7 +40,6 @@ class SequenceLastInstanceLayer : public SequencePoolLayer {
protected: protected:
MatrixPtr tmpSrc_; MatrixPtr tmpSrc_;
MatrixPtr tmpDest_; MatrixPtr tmpDest_;
bool select_first_;
std::vector<int> insId_; std::vector<int> insId_;
public: public:
...@@ -59,7 +58,7 @@ REGISTER_LAYER(seqlastins, SequenceLastInstanceLayer); ...@@ -59,7 +58,7 @@ REGISTER_LAYER(seqlastins, SequenceLastInstanceLayer);
bool SequenceLastInstanceLayer::init(const LayerMap& layerMap, bool SequenceLastInstanceLayer::init(const LayerMap& layerMap,
const ParameterMap& parameterMap) { const ParameterMap& parameterMap) {
SequencePoolLayer::init(layerMap, parameterMap); SequencePoolLayer::init(layerMap, parameterMap);
select_first_ = config_.select_first(); reversed_ = config_.select_first();
tmpSrc_ = tmpSrc_ =
Matrix::create(nullptr, /* height= */ 1, 1, /* trans= */ false, useGpu_); Matrix::create(nullptr, /* height= */ 1, 1, /* trans= */ false, useGpu_);
...@@ -83,7 +82,7 @@ void SequenceLastInstanceLayer::forward(PassType passType) { ...@@ -83,7 +82,7 @@ void SequenceLastInstanceLayer::forward(PassType passType) {
insId_.clear(); insId_.clear();
for (size_t seqId = 0; seqId < newBatchSize_; ++seqId) { for (size_t seqId = 0; seqId < newBatchSize_; ++seqId) {
int insId = select_first_ ? starts[seqId] : starts[seqId + 1] - 1; int insId = reversed_ ? starts[seqId] : starts[seqId + 1] - 1;
insId_.push_back(insId); insId_.push_back(insId);
outputValue->subMatrix(seqId, 1, tmpDest_) outputValue->subMatrix(seqId, 1, tmpDest_)
......
...@@ -68,8 +68,9 @@ void SequencePoolLayer::forward(PassType passType) { ...@@ -68,8 +68,9 @@ void SequencePoolLayer::forward(PassType passType) {
} }
if (stride_ > 0) { if (stride_ > 0) {
CHECK_EQ(input.hasSubseq(), 0UL) CHECK_EQ(input.hasSubseq(), 0UL)
<< "sequence stride pooling is not suitable for hasSubseq now"; << "sequence stride pooling is invalid for hasSubseq now";
output_.poolSequenceWithStride(input, stride_, &stridePositions_); output_.poolSequenceWithStride(
input, stride_, &stridePositions_, reversed_);
newBatchSize_ = stridePositions_->getSize() - 1; newBatchSize_ = stridePositions_->getSize() - 1;
} }
......
...@@ -49,6 +49,8 @@ protected: ...@@ -49,6 +49,8 @@ protected:
int stride_; int stride_;
// store the start position of each stride window // store the start position of each stride window
IVectorPtr stridePositions_; IVectorPtr stridePositions_;
// Whether it is reversed sequence
bool reversed_ = false;
public: public:
explicit SequencePoolLayer(const LayerConfig& config) : Layer(config) {} explicit SequencePoolLayer(const LayerConfig& config) : Layer(config) {}
......
...@@ -561,11 +561,13 @@ void Argument::degradeSequence(const Argument& input) { ...@@ -561,11 +561,13 @@ void Argument::degradeSequence(const Argument& input) {
void Argument::poolSequenceWithStride(const Argument& input, void Argument::poolSequenceWithStride(const Argument& input,
size_t stride, size_t stride,
IVectorPtr* stridePostions) { IVectorPtr* stridePostions,
bool reversed) {
/* /*
* If input.sequenceStartPositions = [0, 9, 14, 17, 30] and stride = 5, * If input.sequenceStartPositions = [0, 9, 14, 17, 30] and stride = 5,
* then sequenceStartPositions = [0, 2, 3, 4, 7], * then sequenceStartPositions = [0, 2, 3, 4, 7].
* and stridePostions = [0, 5, 9, 14, 17, 22, 27, 30] * If reversed = false, stridePostions = [0, 5, 9, 14, 17, 22, 27, 30];
* else reversed = true, stridePostions = [0, 4, 9, 14, 17, 20, 25, 30]
*/ */
CHECK(input.sequenceStartPositions); CHECK(input.sequenceStartPositions);
CHECK_EQ(input.hasSubseq(), 0UL); CHECK_EQ(input.hasSubseq(), 0UL);
...@@ -584,14 +586,13 @@ void Argument::poolSequenceWithStride(const Argument& input, ...@@ -584,14 +586,13 @@ void Argument::poolSequenceWithStride(const Argument& input,
if (seqLength == 0) { if (seqLength == 0) {
// empty sequence // empty sequence
tgtBuf[seqId + 1] = tgtBuf[seqId]; tgtBuf[seqId + 1] = tgtBuf[seqId];
} else if (seqLength < stride) {
tgtBuf[seqId + 1] = tgtBuf[seqId] + 1;
} else { } else {
tgtBuf[seqId + 1] = tgtBuf[seqId] + ceil((float)seqLength / stride); int size = ceil((float)seqLength / stride);
int size = tgtBuf[seqId + 1] = tgtBuf[seqId] + size;
(seqLength % stride) ? seqLength / stride : seqLength / stride - 1; for (int i = 0; i < size - 1; i++) {
for (int i = 0; i < size; i++) { int cur = reversed ? starts[seqId + 1] - (size - 1 - i) * stride
stridePos.emplace_back(stridePos.back() + stride); : stridePos.back() + stride;
stridePos.emplace_back(cur);
} }
} }
} }
......
...@@ -298,7 +298,8 @@ struct Argument { ...@@ -298,7 +298,8 @@ struct Argument {
*/ */
void poolSequenceWithStride(const Argument& input, void poolSequenceWithStride(const Argument& input,
size_t stride, size_t stride,
IVectorPtr* stridePositions); IVectorPtr* stridePositions,
bool reversed = false);
/** /**
* @brief getValueString will return the argument's output in string. There * @brief getValueString will return the argument's output in string. There
* are several kinds of output. The keys of output dictionary are 'value', * are several kinds of output. The keys of output dictionary are 'value',
......
...@@ -27,20 +27,26 @@ TEST(Argument, poolSequenceWithStride) { ...@@ -27,20 +27,26 @@ TEST(Argument, poolSequenceWithStride) {
inStart[3] = 17; inStart[3] = 17;
inStart[4] = 30; inStart[4] = 30;
IVectorPtr stridePositions;
output.poolSequenceWithStride(input, 5 /* stride */, &stridePositions);
const int* outStart = output.sequenceStartPositions->getData(false);
CHECK_EQ(outStart[0], 0);
CHECK_EQ(outStart[1], 2);
CHECK_EQ(outStart[2], 3);
CHECK_EQ(outStart[3], 4);
CHECK_EQ(outStart[4], 7);
CHECK_EQ(stridePositions->getSize(), 8);
int strideResult[] = {0, 5, 9, 14, 17, 22, 27, 30}; int strideResult[] = {0, 5, 9, 14, 17, 22, 27, 30};
for (int i = 0; i < 8; i++) { int strideResultReversed[] = {0, 4, 9, 14, 17, 20, 25, 30};
CHECK_EQ(stridePositions->getData()[i], strideResult[i]);
for (auto reversed : {false, true}) {
IVectorPtr stridePositions;
output.poolSequenceWithStride(
input, 5 /* stride */, &stridePositions, reversed);
const int* outStart = output.sequenceStartPositions->getData(false);
CHECK_EQ(outStart[0], 0);
CHECK_EQ(outStart[1], 2);
CHECK_EQ(outStart[2], 3);
CHECK_EQ(outStart[3], 4);
CHECK_EQ(outStart[4], 7);
CHECK_EQ(stridePositions->getSize(), 8);
auto result = reversed ? strideResultReversed : strideResult;
for (int i = 0; i < 8; i++) {
CHECK_EQ(stridePositions->getData()[i], result[i]);
}
} }
} }
......
...@@ -2497,7 +2497,7 @@ class SequenceLastInstanceLayer(LayerBase): ...@@ -2497,7 +2497,7 @@ class SequenceLastInstanceLayer(LayerBase):
config_assert( config_assert(
len(inputs) == 1, 'SequenceLastInstanceLayer must have 1 input') len(inputs) == 1, 'SequenceLastInstanceLayer must have 1 input')
if trans_type == 'seq': if trans_type == 'seq':
config_assert(stride == -1, 'subseq do not support stride window') config_assert(stride == -1, 'subseq does not support stride window')
self.config.trans_type = trans_type self.config.trans_type = trans_type
self.config.seq_pool_stride = stride self.config.seq_pool_stride = stride
self.set_layer_size(self.get_input_layer(0).size) self.set_layer_size(self.get_input_layer(0).size)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册