提交 e6366e34 编写于 作者: L Luo Tao

update with comments

上级 ba68ce1a
......@@ -42,7 +42,7 @@ class SequenceLastInstanceLayer : public SequencePoolLayer {
protected:
MatrixPtr tmpSrc_;
MatrixPtr tmpDest_;
std::vector<int> insId_;
std::vector<int> instanceIds_;
public:
explicit SequenceLastInstanceLayer(const LayerConfig& config)
......@@ -82,10 +82,10 @@ void SequenceLastInstanceLayer::forward(PassType passType) {
AsyncGpuBlock asyncGpuBlock;
REGISTER_TIMER_INFO("SequenceLastInstanceLayerForward", getName().c_str());
insId_.clear();
instanceIds_.clear();
for (size_t seqId = 0; seqId < newBatchSize_; ++seqId) {
int insId = reversed_ ? starts[seqId] : starts[seqId + 1] - 1;
insId_.push_back(insId);
instanceIds_.push_back(insId);
outputValue->subMatrix(seqId, 1, tmpDest_)
->assign(*(inputValue->subMatrix(insId, 1, tmpSrc_)));
......@@ -111,7 +111,7 @@ void SequenceLastInstanceLayer::backward(const UpdateCallback& callback) {
REGISTER_TIMER_INFO("SequenceLastInstanceLayerBackward", getName().c_str());
for (size_t seqId = 0; seqId < newBatchSize_; ++seqId) {
inputGrad->subMatrix(insId_[seqId], 1, tmpDest_)
inputGrad->subMatrix(instanceIds_[seqId], 1, tmpDest_)
->add(*(outputGrad->subMatrix(seqId, 1, tmpSrc_)));
}
}
......
......@@ -47,9 +47,9 @@ protected:
size_t newBatchSize_;
ICpuGpuVectorPtr startPositions_;
int stride_;
// store the start position of each window
// Store the start position of each window.
IVectorPtr stridePositions_;
// Whether the input sequence is reversed or not
// Whether the input sequence is reversed or not.
bool reversed_ = false;
public:
......
......@@ -807,7 +807,7 @@ TEST(Layer, ExpandLayer) {
void testDegradeLayer(bool hasSubseq,
string layer_type,
string trans_type,
int stride = -1) {
int stride) {
TestConfig config;
config.layerConfig.set_type(layer_type);
config.layerConfig.set_size(10);
......@@ -844,29 +844,33 @@ void testDegradeLayer(bool hasSubseq,
}
TEST(Layer, MaxLayer) {
testDegradeLayer(false, "max", "non-seq"); // seq max to non-seq
testDegradeLayer(true, "max", "non-seq"); // hasSubseq max to non-seq
testDegradeLayer(true, "max", "seq"); // hasSubseq max to seq
testDegradeLayer(false, "max", "non-seq", -1); // seq max to non-seq
testDegradeLayer(true, "max", "non-seq", -1); // hasSubseq max to non-seq
testDegradeLayer(true, "max", "seq", -1); // hasSubseq max to seq
}
TEST(Layer, SequenceLastInstanceLayer) {
testDegradeLayer(false,
"seqlastins",
"non-seq"); // seq seqlastins to non-seq
"non-seq",
-1); // seq seqlastins to non-seq
testDegradeLayer(false,
"seqlastins",
"non-seq",
5); // seq seqlastins to a shorten seq, stride window = 5
testDegradeLayer(true,
"seqlastins",
"non-seq"); // hasSubseq seqlastins to non-seq
testDegradeLayer(true, "seqlastins", "seq"); // hasSubseq seqlastins to seq
"non-seq",
-1); // hasSubseq seqlastins to non-seq
testDegradeLayer(
true, "seqlastins", "seq", -1); // hasSubseq seqlastins to seq
}
TEST(Layer, AverageLayer) {
testDegradeLayer(false, "average", "non-seq"); // seq average to non-seq
testDegradeLayer(true, "average", "non-seq"); // hasSubseq average to non-seq
testDegradeLayer(true, "average", "seq"); // hasSubseq average to seq
testDegradeLayer(false, "average", "non-seq", -1); // seq average to non-seq
testDegradeLayer(
true, "average", "non-seq", -1); // hasSubseq average to non-seq
testDegradeLayer(true, "average", "seq", -1); // hasSubseq average to seq
}
TEST(Layer, SequenceConcatLayer) {
......
......@@ -563,12 +563,11 @@ void Argument::poolSequenceWithStride(const Argument& input,
size_t stride,
IVectorPtr* stridePostions,
bool reversed) {
/*
* If input.sequenceStartPositions = [0, 9, 14, 17, 30] and stride = 5,
* then sequenceStartPositions = [0, 2, 3, 4, 7].
* If reversed = false, stridePostions = [0, 5, 9, 14, 17, 22, 27, 30];
* else reversed = true, stridePostions = [0, 4, 9, 14, 17, 20, 25, 30]
*/
// If input.sequenceStartPositions = [0, 9, 14, 17, 30] and stride = 5,
// then sequenceStartPositions = [0, 2, 3, 4, 7].
// If reversed = false, stridePostions = [0, 5, 9, 14, 17, 22, 27, 30];
// else reversed = true, stridePostions = [0, 4, 9, 14, 17, 20, 25, 30]
CHECK(input.sequenceStartPositions);
CHECK_EQ(input.hasSubseq(), 0UL);
CHECK_GT(stride, 0) << "stride must larger than 0";
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册