diff --git a/paddle/gserver/layers/AgentLayer.cpp b/paddle/gserver/layers/AgentLayer.cpp index 512932d9a55d770aa1e8b209afdc797e53a95ca4..15e7411b5fde0fa3a532394cf7d0e8477ef052d0 100644 --- a/paddle/gserver/layers/AgentLayer.cpp +++ b/paddle/gserver/layers/AgentLayer.cpp @@ -170,23 +170,22 @@ void ScatterAgentLayer::forward(PassType passType) { CHECK_EQ(realLayer_->getDeviceId(), this->getDeviceId()); int width = this->getSize(); - if (realOutArg_.hasSeq()) { - forwardSequence(passType); - } else if (realOutArg_.value || realOutArg_.ids) { - output_.subArgFrom( - realOutArg_, /* offset */ idIndex_, idSize_, width, useGpu_); - } else { // used in generation - if (realLayer_->getOutput().ids) { - IVector::resizeOrCreate(output_.ids, ids_->getSize(), useGpu_); - output_.ids->selectFrom(*realLayer_->getOutput().ids, *ids_); - } - if (realLayer_->getOutput().value) { - int height = ids_->getSize(); - resetOutput(height, width); - - const MatrixPtr& outV = getOutputValue(); - const MatrixPtr& realV = realLayer_->getOutputValue(); - outV->selectRows(*realV, *ids_); + if (selectionMode_) { + forwardWithSelection(passType); + } else { + if (realOutArg_.hasSeq()) { + output_.subArgFrom(realOutArg_, + /* offset */ idIndex_, + idSize_, + width, + useGpu_, + /* trans */ false, + /* seqFlag */ true, + /* seqStart */ seqStartPosIndex_, + /* seqSize */ numSequences_); + } else { + output_.subArgFrom( + realOutArg_, /* offset */ idIndex_, idSize_, width, useGpu_); } } } @@ -194,6 +193,8 @@ void ScatterAgentLayer::forward(PassType passType) { void ScatterAgentLayer::backward(const UpdateCallback& callback) { (void)callback; + CHECK(!selectionMode_); + const MatrixPtr& outputGrad = realOutArg_.grad; const MatrixPtr& realGrad = realLayer_->getOutputGrad(); if (realGrad) { @@ -208,7 +209,7 @@ void ScatterAgentLayer::backward(const UpdateCallback& callback) { REGISTER_LAYER(gather_agent, GatherAgentLayer); REGISTER_LAYER(scatter_agent, ScatterAgentLayer); -void ScatterAgentLayer::forwardSequence(PassType passType) { +void ScatterAgentLayer::forwardWithSelection(PassType passType) { Layer::forward(passType); CHECK_EQ(realLayer_->getDeviceId(), this->getDeviceId()); @@ -219,17 +220,19 @@ void ScatterAgentLayer::forwardSequence(PassType passType) { AsyncGpuBlock asyncGpuBlock; REGISTER_TIMER_INFO("SequenceAgentLayerForward", getName().c_str()); - if (realOutArg_.value || realOutArg_.ids) { - CHECK(realOutArg_.sequenceStartPositions); - output_.subArgFrom(realOutArg_, - /* offset */ idIndex_, - idSize_, - width, - useGpu_, - /* trans */ false, - /* seqFlag */ true, - /* seqStart */ seqStartPosIndex_, - /* seqSize */ numSequences_); + if (!input.hasSeq()) { + if (realLayer_->getOutput().ids) { + IVector::resizeOrCreate(output_.ids, ids_->getSize(), useGpu_); + output_.ids->selectFrom(*realLayer_->getOutput().ids, *ids_); + } + if (realLayer_->getOutput().value) { + int height = ids_->getSize(); + resetOutput(height, width); + + const MatrixPtr& outV = getOutputValue(); + const MatrixPtr& realV = realLayer_->getOutputValue(); + outV->selectRows(*realV, *ids_); + } } else { // Putting the generation logic here is really an ugly hack! // used in generation diff --git a/paddle/gserver/layers/AgentLayer.h b/paddle/gserver/layers/AgentLayer.h index 461b84b17e556b53e0734bff8e37a0d529a3290e..29681b29c6a9a10715548839f2d365eb4a0c7381 100644 --- a/paddle/gserver/layers/AgentLayer.h +++ b/paddle/gserver/layers/AgentLayer.h @@ -110,6 +110,9 @@ protected: // of real layer. ICpuGpuVectorPtr inputStartPos_; + // true for setRealLayer, false for setRealLayerAndOutput + bool selectionMode_; + public: explicit ScatterAgentLayer(const LayerConfig& config) : Layer(config) {} @@ -137,6 +140,7 @@ public: } else { cpuIds_ = ids_; } + selectionMode_ = true; } // set real layer and output, [idIndex, idIndex + idSize) of *ids* @@ -153,6 +157,7 @@ public: idIndex_ = idIndex; idSize_ = idSize; handleBackward_ = handleBackward; + selectionMode_ = false; } void setSequenceStartPositions(const ICpuGpuVectorPtr& sequenceStartPositions, @@ -166,7 +171,7 @@ public: void forward(PassType passType) override; void backward(const UpdateCallback& callback) override; - void forwardSequence(PassType passType); + void forwardWithSelection(PassType passType); }; } // namespace paddle