提交 1e1a33b5 编写于 作者: H Haonan 提交者: Yu Yang

Argument concat for subsequence start positions

Change-Id: Ia60c008a8c922f66e6b5e2ca3e488fc4625d6506
上级 c3c76d69
......@@ -269,6 +269,9 @@ void Argument::concat(const std::vector<Argument>& args,
const std::vector<int>& selectRows,
const std::vector<int>& seqStartPos, bool useGpu,
hl_stream_t stream, PassType passType) {
CHECK(!subSequenceStartPositions)
<< "undefined behavior for subsequence positions";
size_t batchSize = selectRows.size();
auto copyArg = [batchSize, stream](MatrixPtr& dst, MatrixPtr src,
int startRow, int pos, int size,
......@@ -347,9 +350,11 @@ void Argument::concat(const std::vector<Argument>& args, bool useGpu,
hl_stream_t stream, PassType passType) {
int32_t batchSize = 0;
int64_t numSequences = 0;
int64_t numSubSequences = 0;
for (auto& arg : args) {
batchSize += arg.getBatchSize();
numSequences += arg.getNumSequences();
numSubSequences += arg.getNumSubSequences();
}
auto copyArg = [batchSize, stream](MatrixPtr& dst, MatrixPtr src,
......@@ -393,8 +398,26 @@ void Argument::concat(const std::vector<Argument>& args, bool useGpu,
std::copy(src->begin(), src->end(), dst->begin() + startRow);
};
// Copies one source argument's start-position table into the shared
// destination table, shifting every copied position by `startRow`.
//
//   dstSeq           destination vector; (re)sized to dstNumSequences + 1
//                    entries when a source is present, released otherwise.
//   srcSeq           source start positions; may be null.
//   dstNumSequences  total sequence count the destination must hold.
//   srcNumSequences  sequence count of the source; srcNumSequences + 1
//                    entries are read from it.
//   startSequences   in/out write offset into the destination; advanced by
//                    srcNumSequences after a copy, so the next source's
//                    first entry lands on this source's last entry.
//   startRow         value added to every copied position.
//
// NOTE(review): the `false` passed to resizeOrCreate/getData/getMutableData
// presumably selects the CPU-side buffer -- confirm against ICpuGpuVector.
auto copySequencePos = [](ICpuGpuVectorPtr& dstSeq,
                          const ICpuGpuVectorPtr& srcSeq,
                          int dstNumSequences, int srcNumSequences,
                          int& startSequences, int startRow) {
  // Without source positions the destination is dropped entirely.
  if (!srcSeq) {
    dstSeq.reset();
    return;
  }

  ICpuGpuVector::resizeOrCreate(dstSeq, dstNumSequences + 1, false);
  const int* srcPos = srcSeq->getData(false);
  // Point directly at the slot where this source's entries begin.
  int* dstPos = dstSeq->getMutableData(false) + startSequences;

  // A table for N sequences carries N + 1 boundaries, hence the inclusive
  // upper bound.
  for (int k = 0; k <= srcNumSequences; ++k) {
    dstPos[k] = srcPos[k] + startRow;
  }
  startSequences += srcNumSequences;
};
int startRow = 0;
int startSequences = 0;
int startSubSequences = 0;
dataId = args[0].dataId;
for (auto& arg : args) {
CHECK_EQ(arg.dataId, dataId) << "Arguments in concat should have"
......@@ -403,17 +426,18 @@ void Argument::concat(const std::vector<Argument>& args, bool useGpu,
copyArg(value, arg.value, startRow, useGpu);
if (passType != PASS_TEST) copyArg(grad, arg.grad, startRow, useGpu);
copyIds(ids, arg.ids, startRow, useGpu);
if (arg.sequenceStartPositions) {
ICpuGpuVector::resizeOrCreate(sequenceStartPositions,
numSequences + 1,
false);
const int* src = arg.sequenceStartPositions->getData(false);
int* dest = sequenceStartPositions->getMutableData(false);
for (int i = 0; i < arg.getNumSequences() + 1; ++i) {
dest[i + startSequences] = src[i] + startRow;
}
startSequences += arg.getNumSequences();
}
copySequencePos(sequenceStartPositions,
arg.sequenceStartPositions,
numSequences,
arg.getNumSequences(),
startSequences,
startRow);
copySequencePos(subSequenceStartPositions,
arg.subSequenceStartPositions,
numSubSequences,
arg.getNumSubSequences(),
startSubSequences,
startRow);
copyStrs(strs, arg.strs, startRow, useGpu);
startRow += arg.getBatchSize();
}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册