提交 699d5f26 编写于 作者: Z zhangruiqing01 提交者: Yu Yang

modify RecurrentGradientMachine to support unequal length inputs

* modify RecurrentGradientMachine to support hasSubSeq sequence inlinks with the same number of sentence but different number of tokens for each sentence

Change-Id: Ic71f00a4bb346b4fa93e650dfb4b1a0d8d2338b0
上级 0f91ea7e
...@@ -12,7 +12,6 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. ...@@ -12,7 +12,6 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include "paddle/utils/Stat.h" #include "paddle/utils/Stat.h"
#include "paddle/utils/Util.h" #include "paddle/utils/Util.h"
#include "paddle/utils/Flags.h" #include "paddle/utils/Flags.h"
...@@ -291,6 +290,8 @@ void RecurrentGradientMachine::init( ...@@ -291,6 +290,8 @@ void RecurrentGradientMachine::init(
if (subModelConfig->evaluator_names_size() > 0) { if (subModelConfig->evaluator_names_size() > 0) {
evaluator_.reset(frames_[0]->makeEvaluator()); evaluator_.reset(frames_[0]->makeEvaluator());
} }
targetInfoInlinkId_ = subModelConfig->target_inlinkid();
} }
void RecurrentGradientMachine::resizeOrCreateFrames(int numFrames) { void RecurrentGradientMachine::resizeOrCreateFrames(int numFrames) {
...@@ -382,6 +383,16 @@ void RecurrentGradientMachine::forward(const std::vector<Argument>& inArgs, ...@@ -382,6 +383,16 @@ void RecurrentGradientMachine::forward(const std::vector<Argument>& inArgs,
size_t numSequences = input.getNumSequences(); size_t numSequences = input.getNumSequences();
const int* starts = input.sequenceStartPositions->getData(false); const int* starts = input.sequenceStartPositions->getData(false);
bool hasSubseq = input.hasSubseq(); bool hasSubseq = input.hasSubseq();
// In case of !hasSubseq or targetInfoInlinkId_ == -1, all inlinks share the
// same inframe info
bool shareInlinkInfo = !hasSubseq || targetInfoInlinkId_ == -1;
// Defaultly, share info with the first inlink
if (shareInlinkInfo) {
targetInfoInlinkId_ = 0;
}
// check hasSubseq in both config and input are the same // check hasSubseq in both config and input are the same
CHECK_EQ(hasSubseq, inFrameLines_[0].hasSubseq); CHECK_EQ(hasSubseq, inFrameLines_[0].hasSubseq);
...@@ -394,10 +405,18 @@ void RecurrentGradientMachine::forward(const std::vector<Argument>& inArgs, ...@@ -394,10 +405,18 @@ void RecurrentGradientMachine::forward(const std::vector<Argument>& inArgs,
CHECK_EQ((size_t)input1.getNumSequences(), numSequences); CHECK_EQ((size_t)input1.getNumSequences(), numSequences);
// check all inputs should have same hasSubseq flag // check all inputs should have same hasSubseq flag
CHECK_EQ(input.hasSubseq(), inFrameLines_[0].hasSubseq); CHECK_EQ(input.hasSubseq(), inFrameLines_[0].hasSubseq);
// if shareInlinkInfo, checks:
// 1. all inlinks have same number of total tokens
// 2. all inlinks have same number of tokens for each sentence of each
// sample. If hasSubseq, one sample has multiple sentence, else, one
// sample is one sentence
if (shareInlinkInfo) {
CHECK_EQ(input1.getBatchSize(), batchSize); CHECK_EQ(input1.getBatchSize(), batchSize);
CHECK(std::equal(starts, starts + numSequences + 1, CHECK(std::equal(starts, starts + numSequences + 1,
input1.sequenceStartPositions->getData(false))); input1.sequenceStartPositions->getData(false)));
} }
}
if (hasSubseq) { if (hasSubseq) {
CHECK(input.subSequenceStartPositions); CHECK(input.subSequenceStartPositions);
...@@ -408,18 +427,43 @@ void RecurrentGradientMachine::forward(const std::vector<Argument>& inArgs, ...@@ -408,18 +427,43 @@ void RecurrentGradientMachine::forward(const std::vector<Argument>& inArgs,
for (size_t i = 1; i < inFrameLines_.size(); ++i) { for (size_t i = 1; i < inFrameLines_.size(); ++i) {
const Argument& input1 = inFrameLines_[i].inLayer->getOutput(); const Argument& input1 = inFrameLines_[i].inLayer->getOutput();
CHECK_EQ((size_t)input1.getNumSubSequences(), numSubSequences); CHECK_EQ((size_t)input1.getNumSubSequences(), numSubSequences);
if (shareInlinkInfo) {
CHECK(std::equal(subStarts, subStarts + numSubSequences + 1, CHECK(std::equal(subStarts, subStarts + numSubSequences + 1,
input1.subSequenceStartPositions->getData(false))); input1.subSequenceStartPositions->getData(false)));
} }
} }
}
seqLengthAndStart_.clear(); seqLengthAndStart_.clear();
input.getSeqLengthAndStart(&seqLengthAndStart_, &maxSequenceLength_); info_.clear();
resizeOrCreateFrames(maxSequenceLength_); info_.resize(inFrameLines_.size());
resizeBootFrame(numSequences); seqLengthAndStart_.resize(inFrameLines_.size());
{
AsyncGpuBlock asyncGpuBlock; AsyncGpuBlock asyncGpuBlock;
createInFrameInfo(input, passType); // if shareInlinkInfo, only calculate info of the first inlink
// else, calculate info for each inlink
if (shareInlinkInfo) {
input.getSeqLengthAndStart(&seqLengthAndStart_[0], &maxSequenceLength_);
createInFrameInfo(0, input, passType);
} else {
for (size_t i = 0; i < inFrameLines_.size(); i++) {
const Argument& input1 = inFrameLines_[i].inLayer->getOutput();
input1.getSeqLengthAndStart(&seqLengthAndStart_[i],
&maxSequenceLength_);
createInFrameInfo(i, input1, passType);
}
}
// inFrameLine select rows in real layer one time
for (size_t i = 0; i < inFrameLines_.size(); i++) {
int curInlinkId = shareInlinkInfo ? 0 : i;
selectRowsOneTime(inFrameLines_[i].inLayer, info_[curInlinkId].allIds,
&(inFrameLines_[i].outArg), passType);
}
}
resizeOrCreateFrames(maxSequenceLength_);
resizeBootFrame(numSequences);
for (auto& memoryFrameLine : memoryFrameLines_) { for (auto& memoryFrameLine : memoryFrameLines_) {
if (memoryFrameLine.rootAgent) { if (memoryFrameLine.rootAgent) {
...@@ -443,23 +487,29 @@ void RecurrentGradientMachine::forward(const std::vector<Argument>& inArgs, ...@@ -443,23 +487,29 @@ void RecurrentGradientMachine::forward(const std::vector<Argument>& inArgs,
auto gatherAgent = auto gatherAgent =
dynamic_cast<GatherAgentLayer*>(outFrameLine.agentLayer.get()); dynamic_cast<GatherAgentLayer*>(outFrameLine.agentLayer.get());
CHECK_NOTNULL(gatherAgent); CHECK_NOTNULL(gatherAgent);
gatherAgent->copyIdAndSequenceInfo(input, info_.allIds, info_.idIndex); gatherAgent->copyIdAndSequenceInfo(input, info_[targetInfoInlinkId_].allIds,
info_[targetInfoInlinkId_].idIndex);
} }
for (int i = 0; i < maxSequenceLength_; ++i) { for (int i = 0; i < maxSequenceLength_; ++i) {
int idSize = info_.idIndex[i + 1] - info_.idIndex[i]; int idSize = 0;
// connect in_links // connect in_links
for (auto& inFrameLine : inFrameLines_) { for (size_t j = 0; j < inFrameLines_.size(); ++j) {
// idSize denotes the sum number of tokens in each length i
idSize = info_[j].idIndex[i + 1] - info_[j].idIndex[i];
InFrameLine inFrameLine = inFrameLines_[j];
auto scatterAgent = auto scatterAgent =
dynamic_cast<ScatterAgentLayer*>(inFrameLine.agents[i].get()); dynamic_cast<ScatterAgentLayer*>(inFrameLine.agents[i].get());
scatterAgent->setRealLayerAndOutput(inFrameLine.inLayer, scatterAgent->setRealLayerAndOutput(inFrameLine.inLayer,
inFrameLine.outArg, info_.allIds, inFrameLine.outArg, info_[j].allIds,
info_.idIndex[i], idSize); info_[j].idIndex[i], idSize);
if (hasSubseq) { if (hasSubseq) {
int size = info_.seqStartPosIndex[i + 1] - info_.seqStartPosIndex[i]; // size: the length of subsequence
scatterAgent->setSequenceStartPositions( int size =
info_.sequenceStartPositions, info_.seqStartPosIndex[i], size); info_[j].seqStartPosIndex[i + 1] - info_[j].seqStartPosIndex[i];
scatterAgent->setSequenceStartPositions(info_[j].sequenceStartPositions,
info_[j].seqStartPosIndex[i],
size);
} }
} }
...@@ -471,6 +521,10 @@ void RecurrentGradientMachine::forward(const std::vector<Argument>& inArgs, ...@@ -471,6 +521,10 @@ void RecurrentGradientMachine::forward(const std::vector<Argument>& inArgs,
} }
// connect memory links // connect memory links
// Adopt info_[0].idIndex because seq which has_subseq=True
// doesn't support Memory with !hasSubseq bootlayer;
// And inlinks that !hasSubSeq must have same inlink length.
idSize = info_[0].idIndex[i + 1] - info_[0].idIndex[i];
for (auto& memoryFrameLine : memoryFrameLines_) { for (auto& memoryFrameLine : memoryFrameLines_) {
NeuralNetwork::connect( NeuralNetwork::connect(
memoryFrameLine.agents[i], memoryFrameLine.agents[i],
...@@ -560,24 +614,33 @@ void RecurrentGradientMachine::removeBeamSearchStatisticsCallbacks() { ...@@ -560,24 +614,33 @@ void RecurrentGradientMachine::removeBeamSearchStatisticsCallbacks() {
* If hasSubseq, will also create scattered sequenceStartPositions infomation * If hasSubseq, will also create scattered sequenceStartPositions infomation
* for all realLayer of inFrameLines one time. * for all realLayer of inFrameLines one time.
*/ */
void RecurrentGradientMachine::createInFrameInfo(const Argument& input,
void RecurrentGradientMachine::createInFrameInfo(int inlinks_id,
const Argument& input,
PassType passType) { PassType passType) {
bool hasSubseq = input.hasSubseq(); bool hasSubseq = input.hasSubseq();
// numSequences: # samples(sequences) in a batch
size_t numSequences = input.getNumSequences(); size_t numSequences = input.getNumSequences();
std::vector<int> allIds; std::vector<int> allIds;
info_.idIndex.clear(); Info* inlink_info = &info_[inlinks_id];
info_.idIndex.push_back(0); // first idIndex = 0 inlink_info->idIndex.clear();
inlink_info->idIndex.push_back(0); // first idIndex = 0
if (hasSubseq) { // for sequenceScatterAgentLayer if (hasSubseq) { // for sequenceScatterAgentLayer
// numSubSequences : all sentences within all samples(batch)
size_t numSubSequences = input.getNumSubSequences(); size_t numSubSequences = input.getNumSubSequences();
std::vector<int> sequenceStartPositions; std::vector<int> sequenceStartPositions;
info_.seqStartPosIndex.clear(); inlink_info->seqStartPosIndex.clear();
info_.seqStartPosIndex.push_back(0); // first seqStartPosIndex = 0 inlink_info->seqStartPosIndex.push_back(0); // first seqStartPosIndex = 0
// maxSequenceLength_: max number of sentences(subseq) in allsamples
for (int i = 0; i < maxSequenceLength_; ++i) { for (int i = 0; i < maxSequenceLength_; ++i) {
sequenceStartPositions.push_back(0); // first element = 0 sequenceStartPositions.push_back(0); // first element = 0
for (size_t j = 0; j < numSubSequences; ++j) { for (size_t j = 0; j < numSubSequences; ++j) { // for each sentence
if (std::get<3>(seqLengthAndStart_[j]) == i) { // seqLengthAndStart_[inlinks_id][j]:
int subSeqStart = std::get<1>(seqLengthAndStart_[j]); // a 4-tuple including <subseqlen, subseqstart, seqid, subseqid>
int subSeqLength = std::get<0>(seqLengthAndStart_[j]); if (std::get<3>(seqLengthAndStart_[inlinks_id][j]) == i) {
// subseqstart: the cpuSubSequenceStartPositions of this subseq
int subSeqStart = std::get<1>(seqLengthAndStart_[inlinks_id][j]);
int subSeqLength = std::get<0>(seqLengthAndStart_[inlinks_id][j]);
for (int k = subSeqStart; k < subSeqStart + subSeqLength; ++k) { for (int k = subSeqStart; k < subSeqStart + subSeqLength; ++k) {
allIds.push_back(k); allIds.push_back(k);
} }
...@@ -585,37 +648,34 @@ void RecurrentGradientMachine::createInFrameInfo(const Argument& input, ...@@ -585,37 +648,34 @@ void RecurrentGradientMachine::createInFrameInfo(const Argument& input,
subSeqLength); subSeqLength);
} }
} }
info_.idIndex.push_back(allIds.size()); inlink_info->idIndex.push_back(allIds.size());
info_.seqStartPosIndex.push_back(sequenceStartPositions.size()); inlink_info->seqStartPosIndex.push_back(sequenceStartPositions.size());
} }
// inFrameLine create sequenceStartPositions one time // inFrameLine create sequenceStartPositions one time
CHECK_EQ(sequenceStartPositions.size(), CHECK_EQ(sequenceStartPositions.size(),
maxSequenceLength_ + numSubSequences); maxSequenceLength_ + numSubSequences);
CHECK_EQ(info_.seqStartPosIndex.size(), CHECK_EQ(inlink_info->seqStartPosIndex.size(),
static_cast<size_t>(maxSequenceLength_ + 1)); static_cast<size_t>(maxSequenceLength_ + 1));
createSeqPos(sequenceStartPositions, &info_.sequenceStartPositions); createSeqPos(sequenceStartPositions, &inlink_info->sequenceStartPositions);
} else { // for scatterAgentLayer } else { // for scatterAgentLayer
for (int i = 0; i < maxSequenceLength_; ++i) { for (int i = 0; i < maxSequenceLength_; ++i) {
for (size_t j = 0; j < numSequences; ++j) { for (size_t j = 0; j < numSequences; ++j) {
int seqLength = std::get<0>(seqLengthAndStart_[j]); int seqLength = std::get<0>(seqLengthAndStart_[inlinks_id][j]);
if (i >= seqLength) { if (i >= seqLength) {
break; break;
} }
int seqStart = std::get<1>(seqLengthAndStart_[j]); int seqStart = std::get<1>(seqLengthAndStart_[inlinks_id][j]);
allIds.push_back(reversed_ ? (seqStart + seqLength - 1 - i) allIds.push_back(reversed_ ? (seqStart + seqLength - 1 - i)
: (seqStart + i)); : (seqStart + i));
} }
info_.idIndex.push_back(allIds.size()); inlink_info->idIndex.push_back(allIds.size());
} }
} }
// copy and check scatterId // copy and check scatterId
copyScattedId(allIds, &info_.allIds, input.getBatchSize()); copyScattedId(allIds, &inlink_info->allIds, input.getBatchSize());
CHECK_EQ(info_.idIndex.size(), static_cast<size_t>(maxSequenceLength_ + 1)); CHECK_EQ(inlink_info->idIndex.size(),
// inFrameLine select rows in real layer one time static_cast<size_t>(maxSequenceLength_ + 1));
for (auto& inFrameLine : inFrameLines_) {
selectRowsOneTime(inFrameLine.inLayer, info_.allIds, &inFrameLine.outArg,
passType);
}
} }
/* like createInFrameInfo, but for all realLayer of memoryFrameLines*/ /* like createInFrameInfo, but for all realLayer of memoryFrameLines*/
...@@ -633,7 +693,8 @@ void RecurrentGradientMachine::createMemoryFrameInfo( ...@@ -633,7 +693,8 @@ void RecurrentGradientMachine::createMemoryFrameInfo(
sequenceStartPositions.push_back(0); // first element = 0 sequenceStartPositions.push_back(0); // first element = 0
const int* starts = input.sequenceStartPositions->getData(false); const int* starts = input.sequenceStartPositions->getData(false);
for (size_t i = 0; i < numSequences; ++i) { for (size_t i = 0; i < numSequences; ++i) {
int seqId = std::get<2>(seqLengthAndStart_[i]); // memory info adopt info of inlinks[0]
int seqId = std::get<2>(seqLengthAndStart_[0][i]);
for (int k = starts[seqId]; k < starts[seqId + 1]; ++k) { for (int k = starts[seqId]; k < starts[seqId + 1]; ++k) {
allIds.push_back(k); allIds.push_back(k);
} }
...@@ -645,7 +706,7 @@ void RecurrentGradientMachine::createMemoryFrameInfo( ...@@ -645,7 +706,7 @@ void RecurrentGradientMachine::createMemoryFrameInfo(
} else { // for scatterAgentLayer } else { // for scatterAgentLayer
for (size_t i = 0; i < numSequences; ++i) { for (size_t i = 0; i < numSequences; ++i) {
allIds.push_back(std::get<2>(seqLengthAndStart_[i])); allIds.push_back(std::get<2>(seqLengthAndStart_[0][i]));
} }
} }
// copy and check scatterId // copy and check scatterId
...@@ -699,15 +760,16 @@ size_t RecurrentGradientMachine::getGenBatchSize() { ...@@ -699,15 +760,16 @@ size_t RecurrentGradientMachine::getGenBatchSize() {
for (auto& memoryFrameLine : memoryFrameLines_) { for (auto& memoryFrameLine : memoryFrameLines_) {
if (!memoryFrameLine.rootLayer) continue; if (!memoryFrameLine.rootLayer) continue;
Argument& bootArg = memoryFrameLine.rootLayer->getOutput(); Argument& bootArg = memoryFrameLine.rootLayer->getOutput();
size_t batchSize = memoryFrameLine.is_sequence ? size_t batchSize = memoryFrameLine.is_sequence ? bootArg.getNumSequences()
bootArg.getNumSequences() : bootArg.getBatchSize(); : bootArg.getBatchSize();
if (numSequences) { if (numSequences) {
CHECK_EQ(numSequences, batchSize); CHECK_EQ(numSequences, batchSize);
} else { } else {
numSequences = batchSize; numSequences = batchSize;
} }
} }
CHECK(numSequences) << "Fail to get batch size in generation. " CHECK(numSequences)
<< "Fail to get batch size in generation. "
"At least one of the Memory layer MUST have a layer that is NOT in " "At least one of the Memory layer MUST have a layer that is NOT in "
"the layer group to boot it, and this boot layer is used to " "the layer group to boot it, and this boot layer is used to "
"decide batch_size in generation process."; "decide batch_size in generation process.";
...@@ -732,7 +794,9 @@ void RecurrentGradientMachine::generateSequence() { ...@@ -732,7 +794,9 @@ void RecurrentGradientMachine::generateSequence() {
// connect boot frame memory links // connect boot frame memory links
std::vector<int> ids(numSequences); std::vector<int> ids(numSequences);
for (size_t i = 0; i < numSequences; ++i) { ids[i] = i; } for (size_t i = 0; i < numSequences; ++i) {
ids[i] = i;
}
for (auto& memoryFrameLine : memoryFrameLines_) { for (auto& memoryFrameLine : memoryFrameLines_) {
if (memoryFrameLine.rootAgent) { if (memoryFrameLine.rootAgent) {
auto scatterAgent = auto scatterAgent =
...@@ -756,7 +820,8 @@ void RecurrentGradientMachine::generateSequence() { ...@@ -756,7 +820,8 @@ void RecurrentGradientMachine::generateSequence() {
// init outArg // init outArg
size_t resultNum = generator_.config.num_results_per_sample(); size_t resultNum = generator_.config.num_results_per_sample();
IVector::resizeOrCreate(generator_.outArg.ids, IVector::resizeOrCreate(
generator_.outArg.ids,
generator_.config.max_num_frames() * numSequences * resultNum, false); generator_.config.max_num_frames() * numSequences * resultNum, false);
if (resultNum > 1) { if (resultNum > 1) {
CHECK_LE(resultNum, static_cast<size_t>(generator_.config.beam_size())); CHECK_LE(resultNum, static_cast<size_t>(generator_.config.beam_size()));
...@@ -847,7 +912,9 @@ void RecurrentGradientMachine::oneWaySearch(size_t batchSize) { ...@@ -847,7 +912,9 @@ void RecurrentGradientMachine::oneWaySearch(size_t batchSize) {
// path.seqId = -1 indicates end of generation // path.seqId = -1 indicates end of generation
// of an input sequence // of an input sequence
finalPaths[seqIds_[j]].seqId = -1; finalPaths[seqIds_[j]].seqId = -1;
} else { scatterIds.push_back(j); } } else {
scatterIds.push_back(j);
}
} }
} }
...@@ -856,8 +923,7 @@ void RecurrentGradientMachine::oneWaySearch(size_t batchSize) { ...@@ -856,8 +923,7 @@ void RecurrentGradientMachine::oneWaySearch(size_t batchSize) {
starts[0] = 0; starts[0] = 0;
generator_.ids.clear(); generator_.ids.clear();
for (size_t i = 0; i < batchSize; ++i) { for (size_t i = 0; i < batchSize; ++i) {
generator_.ids.insert(generator_.ids.end(), generator_.ids.insert(generator_.ids.end(), finalPaths[i].ids.begin(),
finalPaths[i].ids.begin(),
finalPaths[i].ids.end()); finalPaths[i].ids.end());
starts[i + 1] = generator_.ids.size(); starts[i + 1] = generator_.ids.size();
batchMachineIdVec_.insert(batchMachineIdVec_.end(), batchMachineIdVec_.insert(batchMachineIdVec_.end(),
...@@ -920,8 +986,8 @@ void RecurrentGradientMachine::forwardFrame(int machineCur) { ...@@ -920,8 +986,8 @@ void RecurrentGradientMachine::forwardFrame(int machineCur) {
} }
} }
void RecurrentGradientMachine::singlePathExpand( void RecurrentGradientMachine::singlePathExpand(Path& curPath, size_t curPathId,
Path& curPath, size_t curPathId, std::vector<Path>& newPaths, std::vector<Path>& newPaths,
size_t expandWidth) { size_t expandWidth) {
int calc_id = int calc_id =
gDiyProbStart ? gDiyProbStart(curPath.ids.size(), curPath.ids.data()) : 0; gDiyProbStart ? gDiyProbStart(curPath.ids.size(), curPath.ids.data()) : 0;
...@@ -946,19 +1012,20 @@ void RecurrentGradientMachine::singlePathExpand( ...@@ -946,19 +1012,20 @@ void RecurrentGradientMachine::singlePathExpand(
if (id == -1) break; if (id == -1) break;
real newLogProb = generator_.config.log_prob() ? std::log(prob) : prob; real newLogProb = generator_.config.log_prob() ? std::log(prob) : prob;
Path newPath(curPath, id, newLogProb, Path newPath(curPath, id, newLogProb, curPathId /*machineId*/,
curPathId /*machineId*/, k /*topIndex*/); k /*topIndex*/);
if (this->beamSearchCtrlCallbacks_) { if (this->beamSearchCtrlCallbacks_) {
if (beamSearchCtrlCallbacks_->stopDetermineCandidates( if (beamSearchCtrlCallbacks_->stopDetermineCandidates(
newPath.seqId, newPath.ids, newPath.probHistory)) return; newPath.seqId, newPath.ids, newPath.probHistory))
return;
} }
// outFrameLines_.size() > 1UL // outFrameLines_.size() > 1UL
if (dataArgsSize_) { if (dataArgsSize_) {
newPath.machineIdVec = curPath.machineIdVec; newPath.machineIdVec = curPath.machineIdVec;
newPath.machineIdVec.push_back(curPathId); newPath.machineIdVec.push_back(curPathId);
} }
bool atEos = eosVec[index] == 1U || bool atEos =
newPath.ids.size() >= (size_t)maxSequenceLength_; eosVec[index] == 1U || newPath.ids.size() >= (size_t)maxSequenceLength_;
// adjustNewPath // adjustNewPath
newPath.adjustProb(calc_id, atEos); newPath.adjustProb(calc_id, atEos);
if (this->beamSearchCtrlCallbacks_) { if (this->beamSearchCtrlCallbacks_) {
...@@ -966,16 +1033,18 @@ void RecurrentGradientMachine::singlePathExpand( ...@@ -966,16 +1033,18 @@ void RecurrentGradientMachine::singlePathExpand(
newPath.seqId, newPath.ids, newPath.probHistory, &newPath.logProb); newPath.seqId, newPath.ids, newPath.probHistory, &newPath.logProb);
} }
if (!newPath.isDropable()) { if (!newPath.isDropable()) {
atEos ? finalPaths_[curPath.seqId].push_back(newPath) : atEos ? finalPaths_[curPath.seqId].push_back(newPath)
newPaths.push_back(newPath); : newPaths.push_back(newPath);
} }
} // for expandWidth } // for expandWidth
if (gDiyProbStop) { gDiyProbStop(calc_id); } if (gDiyProbStop) {
gDiyProbStop(calc_id);
}
} }
void RecurrentGradientMachine::beamExpand( void RecurrentGradientMachine::beamExpand(std::vector<Path>& paths,
std::vector<Path>& paths, std::vector<Path>& newPaths) { std::vector<Path>& newPaths) {
size_t candidatePathCount = paths.size(); size_t candidatePathCount = paths.size();
// idVec.size() could be larger than candidatePathCount * beam, // idVec.size() could be larger than candidatePathCount * beam,
// so user can drop some node customly. // so user can drop some node customly.
...@@ -988,7 +1057,7 @@ void RecurrentGradientMachine::beamExpand( ...@@ -988,7 +1057,7 @@ void RecurrentGradientMachine::beamExpand(
int curSeqId = 0; int curSeqId = 0;
for (size_t j = 0; j <= candidatePathCount; j++) { for (size_t j = 0; j <= candidatePathCount; j++) {
// expansions of a single sequence are all processed // expansions of a single sequence are all processed
curSeqId = (j < candidatePathCount? paths[j].seqId : curSeqId + 1); curSeqId = (j < candidatePathCount ? paths[j].seqId : curSeqId + 1);
if (prevSeqId != -1 && curSeqId != prevSeqId) { if (prevSeqId != -1 && curSeqId != prevSeqId) {
totalExpandCount += beamShrink(newPaths, prevSeqId, totalExpandCount); totalExpandCount += beamShrink(newPaths, prevSeqId, totalExpandCount);
} }
...@@ -1000,11 +1069,14 @@ void RecurrentGradientMachine::beamExpand( ...@@ -1000,11 +1069,14 @@ void RecurrentGradientMachine::beamExpand(
} }
// Drop extra nodes to beam size. // Drop extra nodes to beam size.
size_t RecurrentGradientMachine::beamShrink( size_t RecurrentGradientMachine::beamShrink(std::vector<Path>& newPaths,
std::vector<Path>& newPaths, size_t seqId, size_t totalExpandCount) { size_t seqId,
size_t minNewPathSize = std::min(getBeamSize(), size_t totalExpandCount) {
newPaths.size() - totalExpandCount); size_t minNewPathSize =
if (!minNewPathSize) { return 0; } std::min(getBeamSize(), newPaths.size() - totalExpandCount);
if (!minNewPathSize) {
return 0;
}
std::nth_element(newPaths.begin() + totalExpandCount, std::nth_element(newPaths.begin() + totalExpandCount,
newPaths.begin() + totalExpandCount + minNewPathSize, newPaths.begin() + totalExpandCount + minNewPathSize,
newPaths.end(), Path::greaterPath); newPaths.end(), Path::greaterPath);
...@@ -1017,11 +1089,8 @@ size_t RecurrentGradientMachine::beamShrink( ...@@ -1017,11 +1089,8 @@ size_t RecurrentGradientMachine::beamShrink(
// Remove the already formed paths that are relatively short // Remove the already formed paths that are relatively short
finalPaths_[seqId].erase( finalPaths_[seqId].erase(
std::remove_if(finalPaths_[seqId].begin(), std::remove_if(finalPaths_[seqId].begin(), finalPaths_[seqId].end(),
finalPaths_[seqId].end(), [&](Path& p) { return p.logProb < minPathLogProb; }),
[&](Path& p) {
return p.logProb < minPathLogProb;
}),
finalPaths_[seqId].end()); finalPaths_[seqId].end());
for (auto p : finalPaths_[seqId]) { for (auto p : finalPaths_[seqId]) {
if (minFinalPathLogProb_[seqId] > p.logProb) { if (minFinalPathLogProb_[seqId] > p.logProb) {
...@@ -1067,7 +1136,8 @@ void RecurrentGradientMachine::fillGenOutputs() { ...@@ -1067,7 +1136,8 @@ void RecurrentGradientMachine::fillGenOutputs() {
// in beam search, here only reserved the top 1 generated result // in beam search, here only reserved the top 1 generated result
// for out_links that are not the generated word indices. // for out_links that are not the generated word indices.
batchMachineIdVec_.insert(batchMachineIdVec_.end(), batchMachineIdVec_.insert(batchMachineIdVec_.end(),
path.machineIdVec.begin(), path.machineIdVec.end()); path.machineIdVec.begin(),
path.machineIdVec.end());
} }
} }
starts[i + 1] = generator_.ids.size(); starts[i + 1] = generator_.ids.size();
...@@ -1091,21 +1161,21 @@ void RecurrentGradientMachine::copyDataOutlinkFrame(size_t machineCur) { ...@@ -1091,21 +1161,21 @@ void RecurrentGradientMachine::copyDataOutlinkFrame(size_t machineCur) {
void RecurrentGradientMachine::createDataOutlink( void RecurrentGradientMachine::createDataOutlink(
std::vector<int>& machineIdVec) { std::vector<int>& machineIdVec) {
size_t seqNum = getBeamSize() > 1UL ? size_t seqNum =
finalPaths_.size() : finalPaths_[0].size(); getBeamSize() > 1UL ? finalPaths_.size() : finalPaths_[0].size();
std::vector<int> starts(seqNum + 1, 0); std::vector<int> starts(seqNum + 1, 0);
for (size_t i = 0; i < seqNum; ++i) { for (size_t i = 0; i < seqNum; ++i) {
size_t seqLen = getBeamSize() > 1UL ? finalPaths_[i][0].ids.size() : size_t seqLen = getBeamSize() > 1UL ? finalPaths_[i][0].ids.size()
finalPaths_[0][i].ids.size(); : finalPaths_[0][i].ids.size();
starts[i + 1] = starts[i] + seqLen; starts[i + 1] = starts[i] + seqLen;
} }
for (size_t i = 0; i < dataArgsSize_; i++) { for (size_t i = 0; i < dataArgsSize_; i++) {
dataArgs_[i].concat(dataArgsFrame_[i], machineIdVec, dataArgs_[i].concat(dataArgsFrame_[i], machineIdVec, starts, useGpu_,
starts, useGpu_, HPPL_STREAM_1, PASS_TEST); HPPL_STREAM_1, PASS_TEST);
auto dataAgent = dynamic_cast<DataLayer*>( auto dataAgent =
outFrameLines_[i + 1].agentLayer.get()); dynamic_cast<DataLayer*>(outFrameLines_[i + 1].agentLayer.get());
CHECK_NOTNULL(dataAgent); CHECK_NOTNULL(dataAgent);
dataAgent->setData(dataArgs_[i]); dataAgent->setData(dataArgs_[i]);
} }
......
...@@ -12,7 +12,6 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. ...@@ -12,7 +12,6 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#pragma once #pragma once
#include "GradientMachine.h" #include "GradientMachine.h"
...@@ -206,7 +205,10 @@ public: ...@@ -206,7 +205,10 @@ public:
/** /**
* @brief Path default ctor, first logProb is 0. * @brief Path default ctor, first logProb is 0.
*/ */
Path() { logProb = 0; seqId = 0; } Path() {
logProb = 0;
seqId = 0;
}
explicit Path(size_t seqId) : seqId(seqId) { logProb = 0; } explicit Path(size_t seqId) : seqId(seqId) { logProb = 0; }
/** /**
...@@ -319,7 +321,9 @@ protected: ...@@ -319,7 +321,9 @@ protected:
}; };
std::vector<MemoryFrameLine> memoryFrameLines_; std::vector<MemoryFrameLine> memoryFrameLines_;
// All inFrameLines and outFrameLines have the same element as follows. // Each inFrameLines(inlinks) has its own info(elements) below,
// and all outFrameLines(outlinks) share the info with one inFrameLine,
// which is assigned by targetInfoInlinkId_.
struct Info { struct Info {
IVectorPtr allIds; // scattered id of realLayer IVectorPtr allIds; // scattered id of realLayer
std::vector<int> idIndex; // index of allIds std::vector<int> idIndex; // index of allIds
...@@ -327,13 +331,23 @@ protected: ...@@ -327,13 +331,23 @@ protected:
sequenceStartPositions; // scattered sequenceStartPositions sequenceStartPositions; // scattered sequenceStartPositions
std::vector<int> seqStartPosIndex; // index of sequenceStartPositions std::vector<int> seqStartPosIndex; // index of sequenceStartPositions
}; };
Info info_; std::vector<Info> info_;
// each inlinks has a "std::vector<std::tuple<int, int, int, int>>" denotes
// its sequence info:
// if hasSubSeq, tuple of (subSeqLength, subSeqStart, seqIndex, subSeqIndex)
// else, tuple of (seqLength, seqStart, seqIndex, seqIndex)
std::vector<std::vector<std::tuple<int, int, int, int>>> seqLengthAndStart_;
// if no subSeq, tuple of (seqLength, seqStart, seqIndex, seqIndex) // the id of inlink which share info with outlinks
// else, tuple of (subSeqLength, subSeqStart, seqIndex, subSeqIndex) int targetInfoInlinkId_;
std::vector<std::tuple<int, int, int, int>> seqLengthAndStart_;
void createInFrameInfo(const Argument& input, PassType passType); /* create scattered id infomation for all realLayer of inFrameLines one time.
* If hasSubseq, will also create scattered sequenceStartPositions infomation
* for all realLayer of inFrameLines one time.
*/
void createInFrameInfo(int inlinks_id, const Argument& input,
PassType passType);
void createMemoryFrameInfo(MemoryFrameLine* memoryFrameLine, void createMemoryFrameInfo(MemoryFrameLine* memoryFrameLine,
PassType passType); PassType passType);
...@@ -363,6 +377,9 @@ protected: ...@@ -363,6 +377,9 @@ protected:
NeuralNetwork* rootNetwork_; NeuralNetwork* rootNetwork_;
bool reversed_; bool reversed_;
// if hasSubseq: max number of sentences(subseq)in batchsize samples
// else: max number of tokens in batchsize samples(sentences)
int maxSequenceLength_; int maxSequenceLength_;
bool useGpu_; bool useGpu_;
bool stopBeamSearch_; bool stopBeamSearch_;
...@@ -415,7 +432,7 @@ private: ...@@ -415,7 +432,7 @@ private:
* @param machineIdVec : select a row of output matrix in each frame * @param machineIdVec : select a row of output matrix in each frame
* that the generation process expanded. * that the generation process expanded.
*/ */
void createDataOutlink(std::vector<int> & machineIdVec); void createDataOutlink(std::vector<int>& machineIdVec);
/* /*
* @brief used in beam search, connect previous frame to form recurrent link * @brief used in beam search, connect previous frame to form recurrent link
......
...@@ -452,6 +452,9 @@ message SubModelConfig { ...@@ -452,6 +452,9 @@ message SubModelConfig {
repeated LinkConfig out_links = 10; repeated LinkConfig out_links = 10;
optional GeneratorConfig generator = 11; optional GeneratorConfig generator = 11;
// the id of inlink which share info with outlinks, used in recurrent layer group
optional int32 target_inlinkid = 12;
} }
message ModelConfig { message ModelConfig {
......
...@@ -303,7 +303,8 @@ def MakeLayerNameInSubmodel(name, submodel_name = None): ...@@ -303,7 +303,8 @@ def MakeLayerNameInSubmodel(name, submodel_name = None):
@config_func @config_func
def RecurrentLayerGroupWithoutOutLinksBegin(name, def RecurrentLayerGroupWithoutOutLinksBegin(name,
in_links, in_links,
seq_reversed=False): seq_reversed=False,
target_inlinkname=""):
global g_current_submodel global g_current_submodel
config_assert(g_config.model_config.type == "recurrent_nn", config_assert(g_config.model_config.type == "recurrent_nn",
"RecurrentLayerGroup should be used only in recurrent_nn") "RecurrentLayerGroup should be used only in recurrent_nn")
...@@ -311,14 +312,19 @@ def RecurrentLayerGroupWithoutOutLinksBegin(name, ...@@ -311,14 +312,19 @@ def RecurrentLayerGroupWithoutOutLinksBegin(name,
SubModelBegin(name) SubModelBegin(name)
g_current_submodel.is_recurrent_layer_group = True g_current_submodel.is_recurrent_layer_group = True
g_current_submodel.reversed = seq_reversed g_current_submodel.reversed = seq_reversed
g_current_submodel.target_inlinkid = -1
in_links_count = 0 in_links_count = 0
for link in in_links: for linkid, link in enumerate(in_links):
if isinstance(link, basestring): if isinstance(link, basestring):
name = link name = link
has_subseq = False has_subseq = False
else: else:
name = link.link_name name = link.link_name
has_subseq = link.has_subseq has_subseq = link.has_subseq
# assign target_inlinkid according to target_inlinkname
if target_inlinkname == name:
g_current_submodel.target_inlinkid = linkid
if in_links_count == 0: if in_links_count == 0:
in_links_has_subseq = has_subseq in_links_has_subseq = has_subseq
else: else:
...@@ -331,6 +337,7 @@ def RecurrentLayerGroupWithoutOutLinksBegin(name, ...@@ -331,6 +337,7 @@ def RecurrentLayerGroupWithoutOutLinksBegin(name,
SequenceScatterAgentLayer(name=name, size=layer.size) SequenceScatterAgentLayer(name=name, size=layer.size)
else: else:
ScatterAgentLayer(name=name, size=layer.size) ScatterAgentLayer(name=name, size=layer.size)
pair = g_current_submodel.in_links.add() pair = g_current_submodel.in_links.add()
pair.layer_name = layer_name pair.layer_name = layer_name
pair.link_name = MakeLayerNameInSubmodel(name) pair.link_name = MakeLayerNameInSubmodel(name)
...@@ -362,10 +369,12 @@ def RecurrentLayerGroupBegin(name, ...@@ -362,10 +369,12 @@ def RecurrentLayerGroupBegin(name,
in_links, in_links,
out_links, out_links,
generator=None, generator=None,
target_inlinkname="",
seq_reversed=False): seq_reversed=False):
RecurrentLayerGroupWithoutOutLinksBegin(name, RecurrentLayerGroupWithoutOutLinksBegin(name,
in_links, in_links,
seq_reversed) seq_reversed,
target_inlinkname)
for link in out_links: for link in out_links:
RecurrentLayerGroupSetOutLink(link) RecurrentLayerGroupSetOutLink(link)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册