提交 cbbec595 编写于 作者: L Luo Tao

adjust poolSequenceWithStride interface for average and max

上级 0291c018
...@@ -72,7 +72,8 @@ bool SequenceLastInstanceLayer::init(const LayerMap& layerMap, ...@@ -72,7 +72,8 @@ bool SequenceLastInstanceLayer::init(const LayerMap& layerMap,
void SequenceLastInstanceLayer::forward(PassType passType) { void SequenceLastInstanceLayer::forward(PassType passType) {
SequencePoolLayer::forward(passType); SequencePoolLayer::forward(passType);
const int* starts = startPositions_->getData(false); auto starts = (stride_ > 0) ? stridePositions_->getData()
: startPositions_->getData(false);
MatrixPtr inputValue = getInputValue(0); MatrixPtr inputValue = getInputValue(0);
MatrixPtr outputValue = getOutputValue(); MatrixPtr outputValue = getOutputValue();
...@@ -82,10 +83,7 @@ void SequenceLastInstanceLayer::forward(PassType passType) { ...@@ -82,10 +83,7 @@ void SequenceLastInstanceLayer::forward(PassType passType) {
insId_.clear(); insId_.clear();
for (size_t seqId = 0; seqId < newBatchSize_; ++seqId) { for (size_t seqId = 0; seqId < newBatchSize_; ++seqId) {
int insId = (stride_ > 0) int insId = select_first_ ? starts[seqId] : starts[seqId + 1] - 1;
? (select_first_ ? stridePositions_[seqId]
: stridePositions_[seqId + 1] - 1)
: (select_first_ ? starts[seqId] : starts[seqId + 1] - 1);
insId_.push_back(insId); insId_.push_back(insId);
outputValue->subMatrix(seqId, 1, tmpDest_) outputValue->subMatrix(seqId, 1, tmpDest_)
......
...@@ -70,7 +70,7 @@ void SequencePoolLayer::forward(PassType passType) { ...@@ -70,7 +70,7 @@ void SequencePoolLayer::forward(PassType passType) {
CHECK_EQ(input.hasSubseq(), 0UL) CHECK_EQ(input.hasSubseq(), 0UL)
<< "sequence stride pooling is not suitable for hasSubseq now"; << "sequence stride pooling is not suitable for hasSubseq now";
output_.poolSequenceWithStride(input, stride_, &stridePositions_); output_.poolSequenceWithStride(input, stride_, &stridePositions_);
newBatchSize_ = stridePositions_.size() - 1; newBatchSize_ = stridePositions_->getSize() - 1;
} }
resetOutput(newBatchSize_, dim); resetOutput(newBatchSize_, dim);
......
...@@ -48,7 +48,7 @@ protected: ...@@ -48,7 +48,7 @@ protected:
ICpuGpuVectorPtr startPositions_; ICpuGpuVectorPtr startPositions_;
int stride_; int stride_;
// store the start position of each stride window // store the start position of each stride window
std::vector<int> stridePositions_; IVectorPtr stridePositions_;
public: public:
explicit SequencePoolLayer(const LayerConfig& config) : Layer(config) {} explicit SequencePoolLayer(const LayerConfig& config) : Layer(config) {}
......
...@@ -561,7 +561,7 @@ void Argument::degradeSequence(const Argument& input) { ...@@ -561,7 +561,7 @@ void Argument::degradeSequence(const Argument& input) {
void Argument::poolSequenceWithStride(const Argument& input, void Argument::poolSequenceWithStride(const Argument& input,
size_t stride, size_t stride,
std::vector<int>* stridePostions) { IVectorPtr* stridePostions) {
/* /*
* If input.sequenceStartPositions = [0, 9, 14, 17, 30] and stride = 5, * If input.sequenceStartPositions = [0, 9, 14, 17, 30] and stride = 5,
* then sequenceStartPositions = [0, 2, 3, 4, 7], * then sequenceStartPositions = [0, 2, 3, 4, 7],
...@@ -577,10 +577,10 @@ void Argument::poolSequenceWithStride(const Argument& input, ...@@ -577,10 +577,10 @@ void Argument::poolSequenceWithStride(const Argument& input,
int* tgtBuf = sequenceStartPositions->getMutableData(false); int* tgtBuf = sequenceStartPositions->getMutableData(false);
// first index of target sequence and stride positions are both 0 // first index of target sequence and stride positions are both 0
tgtBuf[0] = 0; tgtBuf[0] = 0;
(*stridePostions).clear(); std::vector<int> stridePos;
for (size_t seqId = 0; seqId < numSequences; ++seqId) { for (size_t seqId = 0; seqId < numSequences; ++seqId) {
size_t seqLength = starts[seqId + 1] - starts[seqId]; size_t seqLength = starts[seqId + 1] - starts[seqId];
(*stridePostions).emplace_back(starts[seqId]); stridePos.emplace_back(starts[seqId]);
if (seqLength == 0) { if (seqLength == 0) {
// empty sequence // empty sequence
tgtBuf[seqId + 1] = tgtBuf[seqId]; tgtBuf[seqId + 1] = tgtBuf[seqId];
...@@ -591,12 +591,15 @@ void Argument::poolSequenceWithStride(const Argument& input, ...@@ -591,12 +591,15 @@ void Argument::poolSequenceWithStride(const Argument& input,
int size = int size =
(seqLength % stride) ? seqLength / stride : seqLength / stride - 1; (seqLength % stride) ? seqLength / stride : seqLength / stride - 1;
for (int i = 0; i < size; i++) { for (int i = 0; i < size; i++) {
(*stridePostions).emplace_back((*stridePostions).back() + stride); stridePos.emplace_back(stridePos.back() + stride);
} }
} }
} }
(*stridePostions).emplace_back(starts[numSequences]); stridePos.emplace_back(starts[numSequences]);
CHECK_EQ((*stridePostions).size() - 1, tgtBuf[numSequences]); int size = stridePos.size();
CHECK_EQ(size - 1, tgtBuf[numSequences]);
IVector::resizeOrCreate(*stridePostions, size, false);
(*stridePostions)->copyFrom(stridePos.data(), size);
} }
void Argument::getValueString( void Argument::getValueString(
......
...@@ -298,7 +298,7 @@ struct Argument { ...@@ -298,7 +298,7 @@ struct Argument {
*/ */
void poolSequenceWithStride(const Argument& input, void poolSequenceWithStride(const Argument& input,
size_t stride, size_t stride,
std::vector<int>* stridePositions); IVectorPtr* stridePositions);
/** /**
* @brief getValueString will return the argument's output in string. There * @brief getValueString will return the argument's output in string. There
* are several kinds of output. The keys of output dictionary are 'value', * are several kinds of output. The keys of output dictionary are 'value',
......
...@@ -27,8 +27,7 @@ TEST(Argument, poolSequenceWithStride) { ...@@ -27,8 +27,7 @@ TEST(Argument, poolSequenceWithStride) {
inStart[3] = 17; inStart[3] = 17;
inStart[4] = 30; inStart[4] = 30;
std::vector<int> stridePositions; IVectorPtr stridePositions;
stridePositions.clear();
output.poolSequenceWithStride(input, 5 /* stride */, &stridePositions); output.poolSequenceWithStride(input, 5 /* stride */, &stridePositions);
const int* outStart = output.sequenceStartPositions->getData(false); const int* outStart = output.sequenceStartPositions->getData(false);
...@@ -38,10 +37,10 @@ TEST(Argument, poolSequenceWithStride) { ...@@ -38,10 +37,10 @@ TEST(Argument, poolSequenceWithStride) {
CHECK_EQ(outStart[3], 4); CHECK_EQ(outStart[3], 4);
CHECK_EQ(outStart[4], 7); CHECK_EQ(outStart[4], 7);
CHECK_EQ(stridePositions.size(), 8); CHECK_EQ(stridePositions->getSize(), 8);
int strideResult[] = {0, 5, 9, 14, 17, 22, 27, 30}; int strideResult[] = {0, 5, 9, 14, 17, 22, 27, 30};
for (int i = 0; i < 8; i++) { for (int i = 0; i < 8; i++) {
CHECK_EQ(stridePositions[i], strideResult[i]); CHECK_EQ(stridePositions->getData()[i], strideResult[i]);
} }
} }
......
...@@ -1406,7 +1406,6 @@ def first_seq(input, ...@@ -1406,7 +1406,6 @@ def first_seq(input,
and a long sequence will be shortened. Note that for sequence with and a long sequence will be shortened. Note that for sequence with
sub-sequence, stride is default -1 now. sub-sequence, stride is default -1 now.
The simple usage is: The simple usage is:
.. code-block:: python .. code-block:: python
...@@ -1418,6 +1417,8 @@ def first_seq(input, ...@@ -1418,6 +1417,8 @@ def first_seq(input,
:type name: basestring :type name: basestring
:param input: Input layer name. :param input: Input layer name.
:type input: LayerOutput :type input: LayerOutput
:param stride: The length of the stride window used to split a long sequence.
:type stride: int
:param layer_attr: extra layer attributes. :param layer_attr: extra layer attributes.
:type layer_attr: ExtraLayerAttribute. :type layer_attr: ExtraLayerAttribute.
:return: LayerOutput object. :return: LayerOutput object.
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册