提交 e6366e34 编写于 作者: Luo Tao

update with comments

上级 ba68ce1a
...@@ -42,7 +42,7 @@ class SequenceLastInstanceLayer : public SequencePoolLayer { ...@@ -42,7 +42,7 @@ class SequenceLastInstanceLayer : public SequencePoolLayer {
protected: protected:
MatrixPtr tmpSrc_; MatrixPtr tmpSrc_;
MatrixPtr tmpDest_; MatrixPtr tmpDest_;
std::vector<int> insId_; std::vector<int> instanceIds_;
public: public:
explicit SequenceLastInstanceLayer(const LayerConfig& config) explicit SequenceLastInstanceLayer(const LayerConfig& config)
...@@ -82,10 +82,10 @@ void SequenceLastInstanceLayer::forward(PassType passType) { ...@@ -82,10 +82,10 @@ void SequenceLastInstanceLayer::forward(PassType passType) {
AsyncGpuBlock asyncGpuBlock; AsyncGpuBlock asyncGpuBlock;
REGISTER_TIMER_INFO("SequenceLastInstanceLayerForward", getName().c_str()); REGISTER_TIMER_INFO("SequenceLastInstanceLayerForward", getName().c_str());
insId_.clear(); instanceIds_.clear();
for (size_t seqId = 0; seqId < newBatchSize_; ++seqId) { for (size_t seqId = 0; seqId < newBatchSize_; ++seqId) {
int insId = reversed_ ? starts[seqId] : starts[seqId + 1] - 1; int insId = reversed_ ? starts[seqId] : starts[seqId + 1] - 1;
insId_.push_back(insId); instanceIds_.push_back(insId);
outputValue->subMatrix(seqId, 1, tmpDest_) outputValue->subMatrix(seqId, 1, tmpDest_)
->assign(*(inputValue->subMatrix(insId, 1, tmpSrc_))); ->assign(*(inputValue->subMatrix(insId, 1, tmpSrc_)));
...@@ -111,7 +111,7 @@ void SequenceLastInstanceLayer::backward(const UpdateCallback& callback) { ...@@ -111,7 +111,7 @@ void SequenceLastInstanceLayer::backward(const UpdateCallback& callback) {
REGISTER_TIMER_INFO("SequenceLastInstanceLayerBackward", getName().c_str()); REGISTER_TIMER_INFO("SequenceLastInstanceLayerBackward", getName().c_str());
for (size_t seqId = 0; seqId < newBatchSize_; ++seqId) { for (size_t seqId = 0; seqId < newBatchSize_; ++seqId) {
inputGrad->subMatrix(insId_[seqId], 1, tmpDest_) inputGrad->subMatrix(instanceIds_[seqId], 1, tmpDest_)
->add(*(outputGrad->subMatrix(seqId, 1, tmpSrc_))); ->add(*(outputGrad->subMatrix(seqId, 1, tmpSrc_)));
} }
} }
......
...@@ -47,9 +47,9 @@ protected: ...@@ -47,9 +47,9 @@ protected:
size_t newBatchSize_; size_t newBatchSize_;
ICpuGpuVectorPtr startPositions_; ICpuGpuVectorPtr startPositions_;
int stride_; int stride_;
// store the start position of each window // Store the start position of each window.
IVectorPtr stridePositions_; IVectorPtr stridePositions_;
// Whether the input sequence is reversed or not // Whether the input sequence is reversed or not.
bool reversed_ = false; bool reversed_ = false;
public: public:
......
...@@ -807,7 +807,7 @@ TEST(Layer, ExpandLayer) { ...@@ -807,7 +807,7 @@ TEST(Layer, ExpandLayer) {
void testDegradeLayer(bool hasSubseq, void testDegradeLayer(bool hasSubseq,
string layer_type, string layer_type,
string trans_type, string trans_type,
int stride = -1) { int stride) {
TestConfig config; TestConfig config;
config.layerConfig.set_type(layer_type); config.layerConfig.set_type(layer_type);
config.layerConfig.set_size(10); config.layerConfig.set_size(10);
...@@ -844,29 +844,33 @@ void testDegradeLayer(bool hasSubseq, ...@@ -844,29 +844,33 @@ void testDegradeLayer(bool hasSubseq,
} }
TEST(Layer, MaxLayer) { TEST(Layer, MaxLayer) {
testDegradeLayer(false, "max", "non-seq"); // seq max to non-seq testDegradeLayer(false, "max", "non-seq", -1); // seq max to non-seq
testDegradeLayer(true, "max", "non-seq"); // hasSubseq max to non-seq testDegradeLayer(true, "max", "non-seq", -1); // hasSubseq max to non-seq
testDegradeLayer(true, "max", "seq"); // hasSubseq max to seq testDegradeLayer(true, "max", "seq", -1); // hasSubseq max to seq
} }
TEST(Layer, SequenceLastInstanceLayer) { TEST(Layer, SequenceLastInstanceLayer) {
testDegradeLayer(false, testDegradeLayer(false,
"seqlastins", "seqlastins",
"non-seq"); // seq seqlastins to non-seq "non-seq",
-1); // seq seqlastins to non-seq
testDegradeLayer(false, testDegradeLayer(false,
"seqlastins", "seqlastins",
"non-seq", "non-seq",
5); // seq seqlastins to a shorten seq, stride window = 5 5); // seq seqlastins to a shorten seq, stride window = 5
testDegradeLayer(true, testDegradeLayer(true,
"seqlastins", "seqlastins",
"non-seq"); // hasSubseq seqlastins to non-seq "non-seq",
testDegradeLayer(true, "seqlastins", "seq"); // hasSubseq seqlastins to seq -1); // hasSubseq seqlastins to non-seq
testDegradeLayer(
true, "seqlastins", "seq", -1); // hasSubseq seqlastins to seq
} }
TEST(Layer, AverageLayer) { TEST(Layer, AverageLayer) {
testDegradeLayer(false, "average", "non-seq"); // seq average to non-seq testDegradeLayer(false, "average", "non-seq", -1); // seq average to non-seq
testDegradeLayer(true, "average", "non-seq"); // hasSubseq average to non-seq testDegradeLayer(
testDegradeLayer(true, "average", "seq"); // hasSubseq average to seq true, "average", "non-seq", -1); // hasSubseq average to non-seq
testDegradeLayer(true, "average", "seq", -1); // hasSubseq average to seq
} }
TEST(Layer, SequenceConcatLayer) { TEST(Layer, SequenceConcatLayer) {
......
...@@ -563,12 +563,11 @@ void Argument::poolSequenceWithStride(const Argument& input, ...@@ -563,12 +563,11 @@ void Argument::poolSequenceWithStride(const Argument& input,
size_t stride, size_t stride,
IVectorPtr* stridePostions, IVectorPtr* stridePostions,
bool reversed) { bool reversed) {
/* // If input.sequenceStartPositions = [0, 9, 14, 17, 30] and stride = 5,
* If input.sequenceStartPositions = [0, 9, 14, 17, 30] and stride = 5, // then sequenceStartPositions = [0, 2, 3, 4, 7].
* then sequenceStartPositions = [0, 2, 3, 4, 7]. // If reversed = false, stridePostions = [0, 5, 9, 14, 17, 22, 27, 30];
* If reversed = false, stridePostions = [0, 5, 9, 14, 17, 22, 27, 30]; // else reversed = true, stridePostions = [0, 4, 9, 14, 17, 20, 25, 30]
* else reversed = true, stridePostions = [0, 4, 9, 14, 17, 20, 25, 30]
*/
CHECK(input.sequenceStartPositions); CHECK(input.sequenceStartPositions);
CHECK_EQ(input.hasSubseq(), 0UL); CHECK_EQ(input.hasSubseq(), 0UL);
CHECK_GT(stride, 0) << "stride must larger than 0"; CHECK_GT(stride, 0) << "stride must larger than 0";
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册