提交 b6910529 编写于 作者: X xuwei06

Fix bug of ScatterAgentLayer for generation

上级 3438d650
...@@ -170,23 +170,22 @@ void ScatterAgentLayer::forward(PassType passType) { ...@@ -170,23 +170,22 @@ void ScatterAgentLayer::forward(PassType passType) {
CHECK_EQ(realLayer_->getDeviceId(), this->getDeviceId()); CHECK_EQ(realLayer_->getDeviceId(), this->getDeviceId());
int width = this->getSize(); int width = this->getSize();
if (selectionMode_) {
forwardWithSelection(passType);
} else {
if (realOutArg_.hasSeq()) { if (realOutArg_.hasSeq()) {
forwardSequence(passType); output_.subArgFrom(realOutArg_,
} else if (realOutArg_.value || realOutArg_.ids) { /* offset */ idIndex_,
idSize_,
width,
useGpu_,
/* trans */ false,
/* seqFlag */ true,
/* seqStart */ seqStartPosIndex_,
/* seqSize */ numSequences_);
} else {
output_.subArgFrom( output_.subArgFrom(
realOutArg_, /* offset */ idIndex_, idSize_, width, useGpu_); realOutArg_, /* offset */ idIndex_, idSize_, width, useGpu_);
} else { // used in generation
if (realLayer_->getOutput().ids) {
IVector::resizeOrCreate(output_.ids, ids_->getSize(), useGpu_);
output_.ids->selectFrom(*realLayer_->getOutput().ids, *ids_);
}
if (realLayer_->getOutput().value) {
int height = ids_->getSize();
resetOutput(height, width);
const MatrixPtr& outV = getOutputValue();
const MatrixPtr& realV = realLayer_->getOutputValue();
outV->selectRows(*realV, *ids_);
} }
} }
} }
...@@ -194,6 +193,8 @@ void ScatterAgentLayer::forward(PassType passType) { ...@@ -194,6 +193,8 @@ void ScatterAgentLayer::forward(PassType passType) {
void ScatterAgentLayer::backward(const UpdateCallback& callback) { void ScatterAgentLayer::backward(const UpdateCallback& callback) {
(void)callback; (void)callback;
CHECK(!selectionMode_);
const MatrixPtr& outputGrad = realOutArg_.grad; const MatrixPtr& outputGrad = realOutArg_.grad;
const MatrixPtr& realGrad = realLayer_->getOutputGrad(); const MatrixPtr& realGrad = realLayer_->getOutputGrad();
if (realGrad) { if (realGrad) {
...@@ -208,7 +209,7 @@ void ScatterAgentLayer::backward(const UpdateCallback& callback) { ...@@ -208,7 +209,7 @@ void ScatterAgentLayer::backward(const UpdateCallback& callback) {
REGISTER_LAYER(gather_agent, GatherAgentLayer); REGISTER_LAYER(gather_agent, GatherAgentLayer);
REGISTER_LAYER(scatter_agent, ScatterAgentLayer); REGISTER_LAYER(scatter_agent, ScatterAgentLayer);
void ScatterAgentLayer::forwardSequence(PassType passType) { void ScatterAgentLayer::forwardWithSelection(PassType passType) {
Layer::forward(passType); Layer::forward(passType);
CHECK_EQ(realLayer_->getDeviceId(), this->getDeviceId()); CHECK_EQ(realLayer_->getDeviceId(), this->getDeviceId());
...@@ -219,17 +220,19 @@ void ScatterAgentLayer::forwardSequence(PassType passType) { ...@@ -219,17 +220,19 @@ void ScatterAgentLayer::forwardSequence(PassType passType) {
AsyncGpuBlock asyncGpuBlock; AsyncGpuBlock asyncGpuBlock;
REGISTER_TIMER_INFO("SequenceAgentLayerForward", getName().c_str()); REGISTER_TIMER_INFO("SequenceAgentLayerForward", getName().c_str());
if (realOutArg_.value || realOutArg_.ids) { if (!input.hasSeq()) {
CHECK(realOutArg_.sequenceStartPositions); if (realLayer_->getOutput().ids) {
output_.subArgFrom(realOutArg_, IVector::resizeOrCreate(output_.ids, ids_->getSize(), useGpu_);
/* offset */ idIndex_, output_.ids->selectFrom(*realLayer_->getOutput().ids, *ids_);
idSize_, }
width, if (realLayer_->getOutput().value) {
useGpu_, int height = ids_->getSize();
/* trans */ false, resetOutput(height, width);
/* seqFlag */ true,
/* seqStart */ seqStartPosIndex_, const MatrixPtr& outV = getOutputValue();
/* seqSize */ numSequences_); const MatrixPtr& realV = realLayer_->getOutputValue();
outV->selectRows(*realV, *ids_);
}
} else { } else {
// Putting the generation logic here is really an ugly hack! // Putting the generation logic here is really an ugly hack!
// used in generation // used in generation
......
...@@ -110,6 +110,9 @@ protected: ...@@ -110,6 +110,9 @@ protected:
// of real layer. // of real layer.
ICpuGpuVectorPtr inputStartPos_; ICpuGpuVectorPtr inputStartPos_;
// true for setRealLayer, false for setRealLayerAndOutput
bool selectionMode_;
public: public:
explicit ScatterAgentLayer(const LayerConfig& config) : Layer(config) {} explicit ScatterAgentLayer(const LayerConfig& config) : Layer(config) {}
...@@ -137,6 +140,7 @@ public: ...@@ -137,6 +140,7 @@ public:
} else { } else {
cpuIds_ = ids_; cpuIds_ = ids_;
} }
selectionMode_ = true;
} }
// set real layer and output, [idIndex, idIndex + idSize) of *ids* // set real layer and output, [idIndex, idIndex + idSize) of *ids*
...@@ -153,6 +157,7 @@ public: ...@@ -153,6 +157,7 @@ public:
idIndex_ = idIndex; idIndex_ = idIndex;
idSize_ = idSize; idSize_ = idSize;
handleBackward_ = handleBackward; handleBackward_ = handleBackward;
selectionMode_ = false;
} }
void setSequenceStartPositions(const ICpuGpuVectorPtr& sequenceStartPositions, void setSequenceStartPositions(const ICpuGpuVectorPtr& sequenceStartPositions,
...@@ -166,7 +171,7 @@ public: ...@@ -166,7 +171,7 @@ public:
void forward(PassType passType) override; void forward(PassType passType) override;
void backward(const UpdateCallback& callback) override; void backward(const UpdateCallback& callback) override;
void forwardSequence(PassType passType); void forwardWithSelection(PassType passType);
}; };
} // namespace paddle } // namespace paddle
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册