Commit 699d5f26 authored by zhangruiqing01, committed by Yu Yang

modify RecurrentGradientMachine to support unequal-length inputs

* modify RecurrentGradientMachine to support hasSubSeq sequence inlinks with the same number of sentences but a different number of tokens per sentence

Change-Id: Ic71f00a4bb346b4fa93e650dfb4b1a0d8d2338b0
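
To make the change concrete, here is an illustrative (made-up) batch in which two inlinks agree on the number of samples and sentences but not on the number of tokens per sentence, encoded as cumulative start positions the way Argument stores them. Before this patch all inlinks had to share one set of start positions; after it, each inlink can carry its own.

# Illustrative data only: two inlinks with the same sample/sentence
# structure but different token counts per sentence.
a_seq_starts = [0, 5, 9]        # inlink A: 2 samples over 9 tokens
a_subseq_starts = [0, 3, 5, 9]  # sentences of 3, 2 and 4 tokens
b_seq_starts = [0, 6, 7]        # inlink B: same 2 samples over 7 tokens
b_subseq_starts = [0, 2, 6, 7]  # sentences of 2, 4 and 1 tokens

# Same number of samples and sentences per sample...
assert len(a_seq_starts) == len(b_seq_starts)
assert len(a_subseq_starts) == len(b_subseq_starts)
# ...but different token counts, so one shared Info no longer suffices.
assert a_seq_starts[-1] != b_seq_starts[-1]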
Parent 0f91ea7e
......@@ -12,7 +12,6 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/utils/Stat.h"
#include "paddle/utils/Util.h"
#include "paddle/utils/Flags.h"
......@@ -291,6 +290,8 @@ void RecurrentGradientMachine::init(
if (subModelConfig->evaluator_names_size() > 0) {
evaluator_.reset(frames_[0]->makeEvaluator());
}
targetInfoInlinkId_ = subModelConfig->target_inlinkid();
}
void RecurrentGradientMachine::resizeOrCreateFrames(int numFrames) {
......@@ -382,6 +383,16 @@ void RecurrentGradientMachine::forward(const std::vector<Argument>& inArgs,
size_t numSequences = input.getNumSequences();
const int* starts = input.sequenceStartPositions->getData(false);
bool hasSubseq = input.hasSubseq();
// In the case of !hasSubseq or targetInfoInlinkId_ == -1, all inlinks share
// the same in-frame info
bool shareInlinkInfo = !hasSubseq || targetInfoInlinkId_ == -1;
// By default, share info with the first inlink
if (shareInlinkInfo) {
targetInfoInlinkId_ = 0;
}
// check that hasSubseq is consistent between the config and the input
CHECK_EQ(hasSubseq, inFrameLines_[0].hasSubseq);
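
A minimal Python sketch of the decision made in the hunk above (function and variable names are mine, not the patch's): info is shared across inlinks unless the input has subsequences and a target inlink was configured explicitly.

def resolve_target_inlink(has_subseq, target_info_inlink_id):
    # share one Info across inlinks unless subsequences are present and a
    # target inlink was configured explicitly (target_inlinkid != -1)
    share_inlink_info = (not has_subseq) or target_info_inlink_id == -1
    if share_inlink_info:
        target_info_inlink_id = 0  # by default, share with the first inlink
    return share_inlink_info, target_info_inlink_id

assert resolve_target_inlink(False, -1) == (True, 0)
assert resolve_target_inlink(True, 1) == (False, 1)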
......@@ -394,10 +405,18 @@ void RecurrentGradientMachine::forward(const std::vector<Argument>& inArgs,
CHECK_EQ((size_t)input1.getNumSequences(), numSequences);
// check that all inputs have the same hasSubseq flag
CHECK_EQ(input.hasSubseq(), inFrameLines_[0].hasSubseq);
// if shareInlinkInfo, check:
// 1. all inlinks have the same total number of tokens
// 2. all inlinks have the same number of tokens for each sentence of each
// sample. If hasSubseq, one sample has multiple sentences; otherwise, one
// sample is one sentence
if (shareInlinkInfo) {
CHECK_EQ(input1.getBatchSize(), batchSize);
CHECK(std::equal(starts, starts + numSequences + 1,
input1.sequenceStartPositions->getData(false)));
}
}
if (hasSubseq) {
CHECK(input.subSequenceStartPositions);
......@@ -408,18 +427,43 @@ void RecurrentGradientMachine::forward(const std::vector<Argument>& inArgs,
for (size_t i = 1; i < inFrameLines_.size(); ++i) {
const Argument& input1 = inFrameLines_[i].inLayer->getOutput();
CHECK_EQ((size_t)input1.getNumSubSequences(), numSubSequences);
if (shareInlinkInfo) {
CHECK(std::equal(subStarts, subStarts + numSubSequences + 1,
input1.subSequenceStartPositions->getData(false)));
}
}
}
seqLengthAndStart_.clear();
input.getSeqLengthAndStart(&seqLengthAndStart_, &maxSequenceLength_);
resizeOrCreateFrames(maxSequenceLength_);
resizeBootFrame(numSequences);
info_.clear();
info_.resize(inFrameLines_.size());
seqLengthAndStart_.resize(inFrameLines_.size());
{
AsyncGpuBlock asyncGpuBlock;
createInFrameInfo(input, passType);
// if shareInlinkInfo, only calculate info of the first inlink
// else, calculate info for each inlink
if (shareInlinkInfo) {
input.getSeqLengthAndStart(&seqLengthAndStart_[0], &maxSequenceLength_);
createInFrameInfo(0, input, passType);
} else {
for (size_t i = 0; i < inFrameLines_.size(); i++) {
const Argument& input1 = inFrameLines_[i].inLayer->getOutput();
input1.getSeqLengthAndStart(&seqLengthAndStart_[i],
&maxSequenceLength_);
createInFrameInfo(i, input1, passType);
}
}
// each inFrameLine selects rows in its real layer one time
for (size_t i = 0; i < inFrameLines_.size(); i++) {
int curInlinkId = shareInlinkInfo ? 0 : i;
selectRowsOneTime(inFrameLines_[i].inLayer, info_[curInlinkId].allIds,
&(inFrameLines_[i].outArg), passType);
}
}
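
The dispatch above can be summarized with a small sketch (hypothetical helper names): with shared info only inlink 0 is analyzed, otherwise every inlink gets its own bookkeeping, and row selection for each inlink then indexes into the appropriate Info slot.

def build_infos(inlinks, share_inlink_info, create_info):
    # compute Info for inlink 0 only when shared, else one per inlink
    infos = {}
    if share_inlink_info:
        infos[0] = create_info(inlinks[0])
    else:
        for i, inlink in enumerate(inlinks):
            infos[i] = create_info(inlink)
    # every inlink still selects its own rows, but indexes the shared
    # slot 0 when info is shared
    return [(inlink, infos[0 if share_inlink_info else i])
            for i, inlink in enumerate(inlinks)]

pairs = build_infos(["a", "b"], True, lambda x: {"src": x})
assert all(info == {"src": "a"} for _, info in pairs)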
resizeOrCreateFrames(maxSequenceLength_);
resizeBootFrame(numSequences);
for (auto& memoryFrameLine : memoryFrameLines_) {
if (memoryFrameLine.rootAgent) {
......@@ -443,23 +487,29 @@ void RecurrentGradientMachine::forward(const std::vector<Argument>& inArgs,
auto gatherAgent =
dynamic_cast<GatherAgentLayer*>(outFrameLine.agentLayer.get());
CHECK_NOTNULL(gatherAgent);
gatherAgent->copyIdAndSequenceInfo(input, info_.allIds, info_.idIndex);
gatherAgent->copyIdAndSequenceInfo(input, info_[targetInfoInlinkId_].allIds,
info_[targetInfoInlinkId_].idIndex);
}
for (int i = 0; i < maxSequenceLength_; ++i) {
int idSize = info_.idIndex[i + 1] - info_.idIndex[i];
int idSize = 0;
// connect in_links
for (auto& inFrameLine : inFrameLines_) {
for (size_t j = 0; j < inFrameLines_.size(); ++j) {
// idSize denotes the total number of tokens scattered at time step i
idSize = info_[j].idIndex[i + 1] - info_[j].idIndex[i];
auto& inFrameLine = inFrameLines_[j];
auto scatterAgent =
dynamic_cast<ScatterAgentLayer*>(inFrameLine.agents[i].get());
scatterAgent->setRealLayerAndOutput(inFrameLine.inLayer,
inFrameLine.outArg, info_.allIds,
info_.idIndex[i], idSize);
inFrameLine.outArg, info_[j].allIds,
info_[j].idIndex[i], idSize);
if (hasSubseq) {
int size = info_.seqStartPosIndex[i + 1] - info_.seqStartPosIndex[i];
scatterAgent->setSequenceStartPositions(
info_.sequenceStartPositions, info_.seqStartPosIndex[i], size);
// size: the number of scattered sequenceStartPositions entries at step i
int size =
info_[j].seqStartPosIndex[i + 1] - info_[j].seqStartPosIndex[i];
scatterAgent->setSequenceStartPositions(info_[j].sequenceStartPositions,
info_[j].seqStartPosIndex[i],
size);
}
}
......@@ -471,6 +521,10 @@ void RecurrentGradientMachine::forward(const std::vector<Argument>& inArgs,
}
// connect memory links
// Use info_[0].idIndex here: a sequence with has_subseq=True does not
// support a Memory layer with a !hasSubseq boot layer, and inlinks
// without subsequences must all have the same length.
idSize = info_[0].idIndex[i + 1] - info_[0].idIndex[i];
for (auto& memoryFrameLine : memoryFrameLines_) {
NeuralNetwork::connect(
memoryFrameLine.agents[i],
......@@ -560,24 +614,33 @@ void RecurrentGradientMachine::removeBeamSearchStatisticsCallbacks() {
* If hasSubseq, this will also create scattered sequenceStartPositions
* information for all realLayers of inFrameLines one time.
*/
void RecurrentGradientMachine::createInFrameInfo(const Argument& input,
void RecurrentGradientMachine::createInFrameInfo(int inlinks_id,
const Argument& input,
PassType passType) {
bool hasSubseq = input.hasSubseq();
// numSequences: # samples(sequences) in a batch
size_t numSequences = input.getNumSequences();
std::vector<int> allIds;
info_.idIndex.clear();
info_.idIndex.push_back(0); // first idIndex = 0
Info* inlink_info = &info_[inlinks_id];
inlink_info->idIndex.clear();
inlink_info->idIndex.push_back(0); // first idIndex = 0
if (hasSubseq) { // for sequenceScatterAgentLayer
// numSubSequences: total number of sentences across all samples in the batch
size_t numSubSequences = input.getNumSubSequences();
std::vector<int> sequenceStartPositions;
info_.seqStartPosIndex.clear();
info_.seqStartPosIndex.push_back(0); // first seqStartPosIndex = 0
inlink_info->seqStartPosIndex.clear();
inlink_info->seqStartPosIndex.push_back(0); // first seqStartPosIndex = 0
// maxSequenceLength_: max number of sentences (subseqs) among all samples
for (int i = 0; i < maxSequenceLength_; ++i) {
sequenceStartPositions.push_back(0); // first element = 0
for (size_t j = 0; j < numSubSequences; ++j) {
if (std::get<3>(seqLengthAndStart_[j]) == i) {
int subSeqStart = std::get<1>(seqLengthAndStart_[j]);
int subSeqLength = std::get<0>(seqLengthAndStart_[j]);
for (size_t j = 0; j < numSubSequences; ++j) { // for each sentence
// seqLengthAndStart_[inlinks_id][j]:
// a 4-tuple of <subSeqLength, subSeqStart, seqIndex, subSeqIndex>
if (std::get<3>(seqLengthAndStart_[inlinks_id][j]) == i) {
// subseqstart: the cpuSubSequenceStartPositions of this subseq
int subSeqStart = std::get<1>(seqLengthAndStart_[inlinks_id][j]);
int subSeqLength = std::get<0>(seqLengthAndStart_[inlinks_id][j]);
for (int k = subSeqStart; k < subSeqStart + subSeqLength; ++k) {
allIds.push_back(k);
}
......@@ -585,37 +648,34 @@ void RecurrentGradientMachine::createInFrameInfo(const Argument& input,
subSeqLength);
}
}
info_.idIndex.push_back(allIds.size());
info_.seqStartPosIndex.push_back(sequenceStartPositions.size());
inlink_info->idIndex.push_back(allIds.size());
inlink_info->seqStartPosIndex.push_back(sequenceStartPositions.size());
}
// each inFrameLine creates sequenceStartPositions one time
CHECK_EQ(sequenceStartPositions.size(),
maxSequenceLength_ + numSubSequences);
CHECK_EQ(info_.seqStartPosIndex.size(),
CHECK_EQ(inlink_info->seqStartPosIndex.size(),
static_cast<size_t>(maxSequenceLength_ + 1));
createSeqPos(sequenceStartPositions, &info_.sequenceStartPositions);
createSeqPos(sequenceStartPositions, &inlink_info->sequenceStartPositions);
} else { // for scatterAgentLayer
for (int i = 0; i < maxSequenceLength_; ++i) {
for (size_t j = 0; j < numSequences; ++j) {
int seqLength = std::get<0>(seqLengthAndStart_[j]);
int seqLength = std::get<0>(seqLengthAndStart_[inlinks_id][j]);
if (i >= seqLength) {
break;
}
int seqStart = std::get<1>(seqLengthAndStart_[j]);
int seqStart = std::get<1>(seqLengthAndStart_[inlinks_id][j]);
allIds.push_back(reversed_ ? (seqStart + seqLength - 1 - i)
: (seqStart + i));
}
info_.idIndex.push_back(allIds.size());
inlink_info->idIndex.push_back(allIds.size());
}
}
// copy and check scatterId
copyScattedId(allIds, &info_.allIds, input.getBatchSize());
CHECK_EQ(info_.idIndex.size(), static_cast<size_t>(maxSequenceLength_ + 1));
// inFrameLine select rows in real layer one time
for (auto& inFrameLine : inFrameLines_) {
selectRowsOneTime(inFrameLine.inLayer, info_.allIds, &inFrameLine.outArg,
passType);
}
copyScattedId(allIds, &inlink_info->allIds, input.getBatchSize());
CHECK_EQ(inlink_info->idIndex.size(),
static_cast<size_t>(maxSequenceLength_ + 1));
}
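
For reference, a runnable Python sketch of the non-subsequence branch of createInFrameInfo above: for each time step i, collect the i-th token of every sequence that is still alive (optionally walking sequences in reverse). allIds scatters rows of the real layer into per-step frames, and idIndex records where each step's ids begin.

def create_in_frame_info(seq_length_and_start, max_seq_len, reversed_=False):
    # seq_length_and_start: per-sequence tuples (length, start, seqIdx,
    # seqIdx), sorted by decreasing length, which the early break relies on
    all_ids, id_index = [], [0]
    for i in range(max_seq_len):
        for seq_len, seq_start, _, _ in seq_length_and_start:
            if i >= seq_len:
                break  # shorter sequences have already ended
            all_ids.append(seq_start + seq_len - 1 - i if reversed_
                           else seq_start + i)
        id_index.append(len(all_ids))
    return all_ids, id_index

# two sequences stored at rows [0..2] and [3..4] of the real layer
ids, idx = create_in_frame_info([(3, 0, 0, 0), (2, 3, 1, 1)], 3)
assert ids == [0, 3, 1, 4, 2]  # step 0: rows 0,3; step 1: rows 1,4; step 2: row 2
assert idx == [0, 2, 4, 5]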
/* like createInFrameInfo, but for all realLayers of memoryFrameLines */
......@@ -633,7 +693,8 @@ void RecurrentGradientMachine::createMemoryFrameInfo(
sequenceStartPositions.push_back(0); // first element = 0
const int* starts = input.sequenceStartPositions->getData(false);
for (size_t i = 0; i < numSequences; ++i) {
int seqId = std::get<2>(seqLengthAndStart_[i]);
// memory info adopts the info of inlinks[0]
int seqId = std::get<2>(seqLengthAndStart_[0][i]);
for (int k = starts[seqId]; k < starts[seqId + 1]; ++k) {
allIds.push_back(k);
}
......@@ -645,7 +706,7 @@ void RecurrentGradientMachine::createMemoryFrameInfo(
} else { // for scatterAgentLayer
for (size_t i = 0; i < numSequences; ++i) {
allIds.push_back(std::get<2>(seqLengthAndStart_[i]));
allIds.push_back(std::get<2>(seqLengthAndStart_[0][i]));
}
}
// copy and check scatterId
......@@ -699,15 +760,16 @@ size_t RecurrentGradientMachine::getGenBatchSize() {
for (auto& memoryFrameLine : memoryFrameLines_) {
if (!memoryFrameLine.rootLayer) continue;
Argument& bootArg = memoryFrameLine.rootLayer->getOutput();
size_t batchSize = memoryFrameLine.is_sequence ?
bootArg.getNumSequences() : bootArg.getBatchSize();
size_t batchSize = memoryFrameLine.is_sequence ? bootArg.getNumSequences()
: bootArg.getBatchSize();
if (numSequences) {
CHECK_EQ(numSequences, batchSize);
} else {
numSequences = batchSize;
}
}
CHECK(numSequences) << "Fail to get batch size in generation. "
CHECK(numSequences)
<< "Fail to get batch size in generation. "
"At least one of the Memory layer MUST have a layer that is NOT in "
"the layer group to boot it, and this boot layer is used to "
"decide batch_size in generation process.";
......@@ -732,7 +794,9 @@ void RecurrentGradientMachine::generateSequence() {
// connect boot frame memory links
std::vector<int> ids(numSequences);
for (size_t i = 0; i < numSequences; ++i) { ids[i] = i; }
for (size_t i = 0; i < numSequences; ++i) {
ids[i] = i;
}
for (auto& memoryFrameLine : memoryFrameLines_) {
if (memoryFrameLine.rootAgent) {
auto scatterAgent =
......@@ -756,7 +820,8 @@ void RecurrentGradientMachine::generateSequence() {
// init outArg
size_t resultNum = generator_.config.num_results_per_sample();
IVector::resizeOrCreate(generator_.outArg.ids,
IVector::resizeOrCreate(
generator_.outArg.ids,
generator_.config.max_num_frames() * numSequences * resultNum, false);
if (resultNum > 1) {
CHECK_LE(resultNum, static_cast<size_t>(generator_.config.beam_size()));
......@@ -847,7 +912,9 @@ void RecurrentGradientMachine::oneWaySearch(size_t batchSize) {
// path.seqId = -1 indicates end of generation
// of an input sequence
finalPaths[seqIds_[j]].seqId = -1;
} else { scatterIds.push_back(j); }
} else {
scatterIds.push_back(j);
}
}
}
......@@ -856,8 +923,7 @@ void RecurrentGradientMachine::oneWaySearch(size_t batchSize) {
starts[0] = 0;
generator_.ids.clear();
for (size_t i = 0; i < batchSize; ++i) {
generator_.ids.insert(generator_.ids.end(),
finalPaths[i].ids.begin(),
generator_.ids.insert(generator_.ids.end(), finalPaths[i].ids.begin(),
finalPaths[i].ids.end());
starts[i + 1] = generator_.ids.size();
batchMachineIdVec_.insert(batchMachineIdVec_.end(),
......@@ -920,8 +986,8 @@ void RecurrentGradientMachine::forwardFrame(int machineCur) {
}
}
void RecurrentGradientMachine::singlePathExpand(
Path& curPath, size_t curPathId, std::vector<Path>& newPaths,
void RecurrentGradientMachine::singlePathExpand(Path& curPath, size_t curPathId,
std::vector<Path>& newPaths,
size_t expandWidth) {
int calc_id =
gDiyProbStart ? gDiyProbStart(curPath.ids.size(), curPath.ids.data()) : 0;
......@@ -946,19 +1012,20 @@ void RecurrentGradientMachine::singlePathExpand(
if (id == -1) break;
real newLogProb = generator_.config.log_prob() ? std::log(prob) : prob;
Path newPath(curPath, id, newLogProb,
curPathId /*machineId*/, k /*topIndex*/);
Path newPath(curPath, id, newLogProb, curPathId /*machineId*/,
k /*topIndex*/);
if (this->beamSearchCtrlCallbacks_) {
if (beamSearchCtrlCallbacks_->stopDetermineCandidates(
newPath.seqId, newPath.ids, newPath.probHistory)) return;
newPath.seqId, newPath.ids, newPath.probHistory))
return;
}
// outFrameLines_.size() > 1UL
if (dataArgsSize_) {
newPath.machineIdVec = curPath.machineIdVec;
newPath.machineIdVec.push_back(curPathId);
}
bool atEos = eosVec[index] == 1U ||
newPath.ids.size() >= (size_t)maxSequenceLength_;
bool atEos =
eosVec[index] == 1U || newPath.ids.size() >= (size_t)maxSequenceLength_;
// adjustNewPath
newPath.adjustProb(calc_id, atEos);
if (this->beamSearchCtrlCallbacks_) {
......@@ -966,16 +1033,18 @@ void RecurrentGradientMachine::singlePathExpand(
newPath.seqId, newPath.ids, newPath.probHistory, &newPath.logProb);
}
if (!newPath.isDropable()) {
atEos ? finalPaths_[curPath.seqId].push_back(newPath) :
newPaths.push_back(newPath);
atEos ? finalPaths_[curPath.seqId].push_back(newPath)
: newPaths.push_back(newPath);
}
} // for expandWidth
if (gDiyProbStop) { gDiyProbStop(calc_id); }
if (gDiyProbStop) {
gDiyProbStop(calc_id);
}
}
void RecurrentGradientMachine::beamExpand(
std::vector<Path>& paths, std::vector<Path>& newPaths) {
void RecurrentGradientMachine::beamExpand(std::vector<Path>& paths,
std::vector<Path>& newPaths) {
size_t candidatePathCount = paths.size();
// idVec.size() could be larger than candidatePathCount * beam,
// so the user can drop some nodes as desired.
......@@ -988,7 +1057,7 @@ void RecurrentGradientMachine::beamExpand(
int curSeqId = 0;
for (size_t j = 0; j <= candidatePathCount; j++) {
// expansions of a single sequence are all processed
curSeqId = (j < candidatePathCount? paths[j].seqId : curSeqId + 1);
curSeqId = (j < candidatePathCount ? paths[j].seqId : curSeqId + 1);
if (prevSeqId != -1 && curSeqId != prevSeqId) {
totalExpandCount += beamShrink(newPaths, prevSeqId, totalExpandCount);
}
......@@ -1000,11 +1069,14 @@ void RecurrentGradientMachine::beamExpand(
}
// Drop extra nodes to beam size.
size_t RecurrentGradientMachine::beamShrink(
std::vector<Path>& newPaths, size_t seqId, size_t totalExpandCount) {
size_t minNewPathSize = std::min(getBeamSize(),
newPaths.size() - totalExpandCount);
if (!minNewPathSize) { return 0; }
size_t RecurrentGradientMachine::beamShrink(std::vector<Path>& newPaths,
size_t seqId,
size_t totalExpandCount) {
size_t minNewPathSize =
std::min(getBeamSize(), newPaths.size() - totalExpandCount);
if (!minNewPathSize) {
return 0;
}
std::nth_element(newPaths.begin() + totalExpandCount,
newPaths.begin() + totalExpandCount + minNewPathSize,
newPaths.end(), Path::greaterPath);
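
beamShrink's selection step amounts to the following sketch (heapq.nlargest stands in for the partial sort done by std::nth_element, and Path is reduced to a dict; names are mine):

import heapq

def beam_shrink(new_paths, beam_size, total_expand_count):
    # keep at most beam_size of the newly expanded paths, by log prob
    candidates = new_paths[total_expand_count:]
    kept = heapq.nlargest(min(beam_size, len(candidates)), candidates,
                          key=lambda p: p["logProb"])
    new_paths[total_expand_count:] = kept
    return len(kept)

paths = [{"logProb": lp} for lp in (-1.0, -0.2, -3.5, -0.7)]
assert beam_shrink(paths, 2, 0) == 2
assert [p["logProb"] for p in paths] == [-0.2, -0.7]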
......@@ -1017,11 +1089,8 @@ size_t RecurrentGradientMachine::beamShrink(
// Remove the already formed paths that are relatively short
finalPaths_[seqId].erase(
std::remove_if(finalPaths_[seqId].begin(),
finalPaths_[seqId].end(),
[&](Path& p) {
return p.logProb < minPathLogProb;
}),
std::remove_if(finalPaths_[seqId].begin(), finalPaths_[seqId].end(),
[&](Path& p) { return p.logProb < minPathLogProb; }),
finalPaths_[seqId].end());
for (auto p : finalPaths_[seqId]) {
if (minFinalPathLogProb_[seqId] > p.logProb) {
......@@ -1067,7 +1136,8 @@ void RecurrentGradientMachine::fillGenOutputs() {
// in beam search, only the top-1 generated result is reserved here
// for out_links that are not the generated word indices.
batchMachineIdVec_.insert(batchMachineIdVec_.end(),
path.machineIdVec.begin(), path.machineIdVec.end());
path.machineIdVec.begin(),
path.machineIdVec.end());
}
}
starts[i + 1] = generator_.ids.size();
......@@ -1091,21 +1161,21 @@ void RecurrentGradientMachine::copyDataOutlinkFrame(size_t machineCur) {
void RecurrentGradientMachine::createDataOutlink(
std::vector<int>& machineIdVec) {
size_t seqNum = getBeamSize() > 1UL ?
finalPaths_.size() : finalPaths_[0].size();
size_t seqNum =
getBeamSize() > 1UL ? finalPaths_.size() : finalPaths_[0].size();
std::vector<int> starts(seqNum + 1, 0);
for (size_t i = 0; i < seqNum; ++i) {
size_t seqLen = getBeamSize() > 1UL ? finalPaths_[i][0].ids.size() :
finalPaths_[0][i].ids.size();
size_t seqLen = getBeamSize() > 1UL ? finalPaths_[i][0].ids.size()
: finalPaths_[0][i].ids.size();
starts[i + 1] = starts[i] + seqLen;
}
for (size_t i = 0; i < dataArgsSize_; i++) {
dataArgs_[i].concat(dataArgsFrame_[i], machineIdVec,
starts, useGpu_, HPPL_STREAM_1, PASS_TEST);
dataArgs_[i].concat(dataArgsFrame_[i], machineIdVec, starts, useGpu_,
HPPL_STREAM_1, PASS_TEST);
auto dataAgent = dynamic_cast<DataLayer*>(
outFrameLines_[i + 1].agentLayer.get());
auto dataAgent =
dynamic_cast<DataLayer*>(outFrameLines_[i + 1].agentLayer.get());
CHECK_NOTNULL(dataAgent);
dataAgent->setData(dataArgs_[i]);
}
......
......@@ -12,7 +12,6 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "GradientMachine.h"
......@@ -206,7 +205,10 @@ public:
/**
* @brief Path default ctor, first logProb is 0.
*/
Path() { logProb = 0; seqId = 0; }
Path() {
logProb = 0;
seqId = 0;
}
explicit Path(size_t seqId) : seqId(seqId) { logProb = 0; }
/**
......@@ -319,7 +321,9 @@ protected:
};
std::vector<MemoryFrameLine> memoryFrameLines_;
// All inFrameLines and outFrameLines have the same element as follows.
// Each inFrameLine (inlink) has its own info (the elements below), and all
// outFrameLines (outlinks) share the info of one inFrameLine, which is
// designated by targetInfoInlinkId_.
struct Info {
IVectorPtr allIds; // scattered id of realLayer
std::vector<int> idIndex; // index of allIds
......@@ -327,13 +331,23 @@ protected:
sequenceStartPositions; // scattered sequenceStartPositions
std::vector<int> seqStartPosIndex; // index of sequenceStartPositions
};
Info info_;
std::vector<Info> info_;
// each inlink has a "std::vector<std::tuple<int, int, int, int>>" that
// denotes its sequence info:
// if hasSubSeq, a tuple of (subSeqLength, subSeqStart, seqIndex, subSeqIndex)
// else, a tuple of (seqLength, seqStart, seqIndex, seqIndex)
std::vector<std::vector<std::tuple<int, int, int, int>>> seqLengthAndStart_;
// if no subSeq, tuple of (seqLength, seqStart, seqIndex, seqIndex)
// else, tuple of (subSeqLength, subSeqStart, seqIndex, subSeqIndex)
std::vector<std::tuple<int, int, int, int>> seqLengthAndStart_;
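
A hypothetical construction of one inlink's seqLengthAndStart_ entries from cumulative start positions, matching the tuple layout documented above (the descending-length sort in the no-subsequence case is what lets forward() break early):

def seq_length_and_start(starts, sub_starts=None):
    if sub_starts is None:
        # no subsequences: (seqLength, seqStart, seqIndex, seqIndex),
        # sorted by decreasing length
        return sorted([(starts[i + 1] - starts[i], starts[i], i, i)
                       for i in range(len(starts) - 1)],
                      key=lambda t: -t[0])
    # with subsequences: (subSeqLength, subSeqStart, seqIndex, subSeqIndex)
    out = []
    for j in range(len(sub_starts) - 1):
        s = sub_starts[j]
        seq_idx = max(i for i in range(len(starts) - 1) if starts[i] <= s)
        # sentence position within its owning sample (sample boundaries
        # always appear among the subsequence boundaries)
        sub_idx = j - sub_starts.index(starts[seq_idx])
        out.append((sub_starts[j + 1] - s, s, seq_idx, sub_idx))
    return out

assert seq_length_and_start([0, 2, 5]) == [(3, 2, 1, 1), (2, 0, 0, 0)]
assert seq_length_and_start([0, 5, 9], [0, 3, 5, 9]) == \
    [(3, 0, 0, 0), (2, 3, 0, 1), (4, 5, 1, 0)]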
// the id of the inlink which shares info with the outlinks
int targetInfoInlinkId_;
void createInFrameInfo(const Argument& input, PassType passType);
/* create scattered id information for all realLayers of inFrameLines one time.
* If hasSubseq, this will also create scattered sequenceStartPositions
* information for all realLayers of inFrameLines one time.
*/
void createInFrameInfo(int inlinks_id, const Argument& input,
PassType passType);
void createMemoryFrameInfo(MemoryFrameLine* memoryFrameLine,
PassType passType);
......@@ -363,6 +377,9 @@ protected:
NeuralNetwork* rootNetwork_;
bool reversed_;
// if hasSubseq: max number of sentences (subseqs) in batchSize samples
// else: max number of tokens in batchSize samples (sentences)
int maxSequenceLength_;
bool useGpu_;
bool stopBeamSearch_;
......@@ -415,7 +432,7 @@ private:
* @param machineIdVec : select a row of output matrix in each frame
* that the generation process expanded.
*/
void createDataOutlink(std::vector<int> & machineIdVec);
void createDataOutlink(std::vector<int>& machineIdVec);
/*
* @brief used in beam search, connect previous frame to form recurrent link
......
......@@ -452,6 +452,9 @@ message SubModelConfig {
repeated LinkConfig out_links = 10;
optional GeneratorConfig generator = 11;
// the id of the inlink which shares info with the outlinks; used in recurrent layer groups
optional int32 target_inlinkid = 12;
}
message ModelConfig {
......
......@@ -303,7 +303,8 @@ def MakeLayerNameInSubmodel(name, submodel_name = None):
@config_func
def RecurrentLayerGroupWithoutOutLinksBegin(name,
in_links,
seq_reversed=False):
seq_reversed=False,
target_inlinkname=""):
global g_current_submodel
config_assert(g_config.model_config.type == "recurrent_nn",
"RecurrentLayerGroup should be used only in recurrent_nn")
......@@ -311,14 +312,19 @@ def RecurrentLayerGroupWithoutOutLinksBegin(name,
SubModelBegin(name)
g_current_submodel.is_recurrent_layer_group = True
g_current_submodel.reversed = seq_reversed
g_current_submodel.target_inlinkid = -1
in_links_count = 0
for link in in_links:
for linkid, link in enumerate(in_links):
if isinstance(link, basestring):
name = link
has_subseq = False
else:
name = link.link_name
has_subseq = link.has_subseq
# assign target_inlinkid according to target_inlinkname
if target_inlinkname == name:
g_current_submodel.target_inlinkid = linkid
if in_links_count == 0:
in_links_has_subseq = has_subseq
else:
......@@ -331,6 +337,7 @@ def RecurrentLayerGroupWithoutOutLinksBegin(name,
SequenceScatterAgentLayer(name=name, size=layer.size)
else:
ScatterAgentLayer(name=name, size=layer.size)
pair = g_current_submodel.in_links.add()
pair.layer_name = layer_name
pair.link_name = MakeLayerNameInSubmodel(name)
......@@ -362,10 +369,12 @@ def RecurrentLayerGroupBegin(name,
in_links,
out_links,
generator=None,
target_inlinkname="",
seq_reversed=False):
RecurrentLayerGroupWithoutOutLinksBegin(name,
in_links,
seq_reversed)
seq_reversed,
target_inlinkname)
for link in out_links:
RecurrentLayerGroupSetOutLink(link)
......
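
The linkid assignment added in RecurrentLayerGroupWithoutOutLinksBegin can be checked in isolation with this sketch (plain str stands in for basestring, and the real function also builds the agent layers):

def assign_target_inlinkid(in_links, target_inlinkname):
    # the in_link whose name matches target_inlinkname becomes the target;
    # -1 (the default) means "share info with the first inlink"
    target_inlinkid = -1
    for linkid, link in enumerate(in_links):
        name = link if isinstance(link, str) else link.link_name
        if target_inlinkname == name:
            target_inlinkid = linkid
    return target_inlinkid

assert assign_target_inlinkid(["word", "label"], "label") == 1
assert assign_target_inlinkid(["word", "label"], "") == -1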