提交 36f0aa73 编写于 作者: C caoying03

fix code style to pass CI.

上级 3d1b8719
...@@ -28,8 +28,9 @@ void CostForOneSequence::calValidExpandStep() { ...@@ -28,8 +28,9 @@ void CostForOneSequence::calValidExpandStep() {
start, start,
start + goldRowIds_[i - 1] * beamSize_ + goldColIds_[i - 1], start + goldRowIds_[i - 1] * beamSize_ + goldColIds_[i - 1],
[](const real& val) { return val != -1.; }); [](const real& val) { return val != -1.; });
} else } else {
goldRowIds_[i] = 0; goldRowIds_[i] = 0;
}
real* start = real* start =
beams_->candidateIds[i]->getData() + goldRowIds_[i] * beamSize_; beams_->candidateIds[i]->getData() + goldRowIds_[i] * beamSize_;
...@@ -288,7 +289,7 @@ void CrossEntropyOverBeam::copyInputsToCpu() { ...@@ -288,7 +289,7 @@ void CrossEntropyOverBeam::copyInputsToCpu() {
void CrossEntropyOverBeam::splitBatchBeams() { void CrossEntropyOverBeam::splitBatchBeams() {
beamCosts_.resize(batchSize_); beamCosts_.resize(batchSize_);
beamPerSeq_.resize(batchSize_, beamExpanCount_); beamPerSeq_.resize(batchSize_, BeamExpansion(beamExpanCount_));
for (size_t i = 0; i < beamExpanCount_; ++i) { for (size_t i = 0; i < beamExpanCount_; ++i) {
int* seqStarts = int* seqStarts =
...@@ -300,8 +301,9 @@ void CrossEntropyOverBeam::splitBatchBeams() { ...@@ -300,8 +301,9 @@ void CrossEntropyOverBeam::splitBatchBeams() {
subSeqStarts = subSeqStarts =
getInput(i * 3).subSequenceStartPositions->getMutableData(false); getInput(i * 3).subSequenceStartPositions->getMutableData(false);
maxLen = getInput(i * 3).subSequenceStartPositions->getSize() - 1; maxLen = getInput(i * 3).subSequenceStartPositions->getSize() - 1;
} else } else {
maxLen = getInput(i).sequenceStartPositions->getSize() - 1; maxLen = getInput(i).sequenceStartPositions->getSize() - 1;
}
for (size_t j = 0; j < batchSize_; ++j) { for (size_t j = 0; j < batchSize_; ++j) {
beamPerSeq_[j].scores[i] = beamPerSeq_[j].scores[i] =
...@@ -348,8 +350,9 @@ void CrossEntropyOverBeam::resizeOutput() { ...@@ -348,8 +350,9 @@ void CrossEntropyOverBeam::resizeOutput() {
inGrad->getWidth(), inGrad->getWidth(),
false, false,
false); false);
} else } else {
candidateScoreGrad_[i] = std::move(inGrad); candidateScoreGrad_[i] = std::move(inGrad);
}
candidateScoreGrad_[i]->zeroMem(); candidateScoreGrad_[i]->zeroMem();
} }
} }
......
...@@ -31,7 +31,7 @@ struct BeamExpansion { ...@@ -31,7 +31,7 @@ struct BeamExpansion {
size_t expansionCount; size_t expansionCount;
BeamExpansion(int n) { explicit BeamExpansion(int n) {
expansionCount = n; expansionCount = n;
scores.resize(expansionCount); scores.resize(expansionCount);
seqInfo.resize(expansionCount); seqInfo.resize(expansionCount);
...@@ -39,7 +39,7 @@ struct BeamExpansion { ...@@ -39,7 +39,7 @@ struct BeamExpansion {
scoreGrad.resize(expansionCount); scoreGrad.resize(expansionCount);
gold.resize(expansionCount); gold.resize(expansionCount);
}; }
}; };
typedef std::shared_ptr<BeamExpansion> BeamExpansionPtr; typedef std::shared_ptr<BeamExpansion> BeamExpansionPtr;
...@@ -74,7 +74,7 @@ private: ...@@ -74,7 +74,7 @@ private:
CHECK_GT(beams_->seqInfo[beamId]->getSize() - 1, rowId); CHECK_GT(beams_->seqInfo[beamId]->getSize() - 1, rowId);
int* starts = beams_->seqInfo[beamId]->getData(); int* starts = beams_->seqInfo[beamId]->getData();
return starts[rowId] - starts[0]; return starts[rowId] - starts[0];
}; }
size_t beamSize_; size_t beamSize_;
size_t validExpansionCount_; size_t validExpansionCount_;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册