提交 d369577f 编写于 作者: L Luo Tao

add reversed poolSequenceWithStride

上级 08d6622d
......@@ -40,7 +40,6 @@ class SequenceLastInstanceLayer : public SequencePoolLayer {
protected:
MatrixPtr tmpSrc_;
MatrixPtr tmpDest_;
bool select_first_;
std::vector<int> insId_;
public:
......@@ -59,7 +58,7 @@ REGISTER_LAYER(seqlastins, SequenceLastInstanceLayer);
bool SequenceLastInstanceLayer::init(const LayerMap& layerMap,
const ParameterMap& parameterMap) {
SequencePoolLayer::init(layerMap, parameterMap);
select_first_ = config_.select_first();
reversed_ = config_.select_first();
tmpSrc_ =
Matrix::create(nullptr, /* height= */ 1, 1, /* trans= */ false, useGpu_);
......@@ -83,7 +82,7 @@ void SequenceLastInstanceLayer::forward(PassType passType) {
insId_.clear();
for (size_t seqId = 0; seqId < newBatchSize_; ++seqId) {
int insId = select_first_ ? starts[seqId] : starts[seqId + 1] - 1;
int insId = reversed_ ? starts[seqId] : starts[seqId + 1] - 1;
insId_.push_back(insId);
outputValue->subMatrix(seqId, 1, tmpDest_)
......
......@@ -68,8 +68,9 @@ void SequencePoolLayer::forward(PassType passType) {
}
if (stride_ > 0) {
CHECK_EQ(input.hasSubseq(), 0UL)
<< "sequence stride pooling is not suitable for hasSubseq now";
output_.poolSequenceWithStride(input, stride_, &stridePositions_);
<< "sequence stride pooling is invalid for hasSubseq now";
output_.poolSequenceWithStride(
input, stride_, &stridePositions_, reversed_);
newBatchSize_ = stridePositions_->getSize() - 1;
}
......
......@@ -49,6 +49,8 @@ protected:
int stride_;
// store the start position of each stride window
IVectorPtr stridePositions_;
// Whether it is reversed sequence
bool reversed_ = false;
public:
explicit SequencePoolLayer(const LayerConfig& config) : Layer(config) {}
......
......@@ -561,11 +561,13 @@ void Argument::degradeSequence(const Argument& input) {
void Argument::poolSequenceWithStride(const Argument& input,
size_t stride,
IVectorPtr* stridePostions) {
IVectorPtr* stridePostions,
bool reversed) {
/*
* If input.sequenceStartPositions = [0, 9, 14, 17, 30] and stride = 5,
* then sequenceStartPositions = [0, 2, 3, 4, 7],
* and stridePostions = [0, 5, 9, 14, 17, 22, 27, 30]
* then sequenceStartPositions = [0, 2, 3, 4, 7].
* If reversed = false, stridePostions = [0, 5, 9, 14, 17, 22, 27, 30];
* else reversed = true, stridePostions = [0, 4, 9, 14, 17, 20, 25, 30]
*/
CHECK(input.sequenceStartPositions);
CHECK_EQ(input.hasSubseq(), 0UL);
......@@ -584,14 +586,13 @@ void Argument::poolSequenceWithStride(const Argument& input,
if (seqLength == 0) {
// empty sequence
tgtBuf[seqId + 1] = tgtBuf[seqId];
} else if (seqLength < stride) {
tgtBuf[seqId + 1] = tgtBuf[seqId] + 1;
} else {
tgtBuf[seqId + 1] = tgtBuf[seqId] + ceil((float)seqLength / stride);
int size =
(seqLength % stride) ? seqLength / stride : seqLength / stride - 1;
for (int i = 0; i < size; i++) {
stridePos.emplace_back(stridePos.back() + stride);
int size = ceil((float)seqLength / stride);
tgtBuf[seqId + 1] = tgtBuf[seqId] + size;
for (int i = 0; i < size - 1; i++) {
int cur = reversed ? starts[seqId + 1] - (size - 1 - i) * stride
: stridePos.back() + stride;
stridePos.emplace_back(cur);
}
}
}
......
......@@ -298,7 +298,8 @@ struct Argument {
*/
void poolSequenceWithStride(const Argument& input,
size_t stride,
IVectorPtr* stridePositions);
IVectorPtr* stridePositions,
bool reversed = false);
/**
* @brief getValueString will return the argument's output in string. There
* are several kinds of output. The keys of output dictionary are 'value',
......
......@@ -27,20 +27,26 @@ TEST(Argument, poolSequenceWithStride) {
inStart[3] = 17;
inStart[4] = 30;
IVectorPtr stridePositions;
output.poolSequenceWithStride(input, 5 /* stride */, &stridePositions);
const int* outStart = output.sequenceStartPositions->getData(false);
CHECK_EQ(outStart[0], 0);
CHECK_EQ(outStart[1], 2);
CHECK_EQ(outStart[2], 3);
CHECK_EQ(outStart[3], 4);
CHECK_EQ(outStart[4], 7);
CHECK_EQ(stridePositions->getSize(), 8);
int strideResult[] = {0, 5, 9, 14, 17, 22, 27, 30};
for (int i = 0; i < 8; i++) {
CHECK_EQ(stridePositions->getData()[i], strideResult[i]);
int strideResultReversed[] = {0, 4, 9, 14, 17, 20, 25, 30};
for (auto reversed : {false, true}) {
IVectorPtr stridePositions;
output.poolSequenceWithStride(
input, 5 /* stride */, &stridePositions, reversed);
const int* outStart = output.sequenceStartPositions->getData(false);
CHECK_EQ(outStart[0], 0);
CHECK_EQ(outStart[1], 2);
CHECK_EQ(outStart[2], 3);
CHECK_EQ(outStart[3], 4);
CHECK_EQ(outStart[4], 7);
CHECK_EQ(stridePositions->getSize(), 8);
auto result = reversed ? strideResultReversed : strideResult;
for (int i = 0; i < 8; i++) {
CHECK_EQ(stridePositions->getData()[i], result[i]);
}
}
}
......
......@@ -2497,7 +2497,7 @@ class SequenceLastInstanceLayer(LayerBase):
config_assert(
len(inputs) == 1, 'SequenceLastInstanceLayer must have 1 input')
if trans_type == 'seq':
config_assert(stride == -1, 'subseq do not support stride window')
config_assert(stride == -1, 'subseq does not support stride window')
self.config.trans_type = trans_type
self.config.seq_pool_stride = stride
self.set_layer_size(self.get_input_layer(0).size)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册