提交 1e1a33b5 编写于 作者: H Haonan 提交者: Yu Yang

Argument concat for subsequence start positions

Change-Id: Ia60c008a8c922f66e6b5e2ca3e488fc4625d6506
上级 c3c76d69
...@@ -269,6 +269,9 @@ void Argument::concat(const std::vector<Argument>& args, ...@@ -269,6 +269,9 @@ void Argument::concat(const std::vector<Argument>& args,
const std::vector<int>& selectRows, const std::vector<int>& selectRows,
const std::vector<int>& seqStartPos, bool useGpu, const std::vector<int>& seqStartPos, bool useGpu,
hl_stream_t stream, PassType passType) { hl_stream_t stream, PassType passType) {
CHECK(!subSequenceStartPositions)
<< "undefined behavior for subsequence positions";
size_t batchSize = selectRows.size(); size_t batchSize = selectRows.size();
auto copyArg = [batchSize, stream](MatrixPtr& dst, MatrixPtr src, auto copyArg = [batchSize, stream](MatrixPtr& dst, MatrixPtr src,
int startRow, int pos, int size, int startRow, int pos, int size,
...@@ -347,9 +350,11 @@ void Argument::concat(const std::vector<Argument>& args, bool useGpu, ...@@ -347,9 +350,11 @@ void Argument::concat(const std::vector<Argument>& args, bool useGpu,
hl_stream_t stream, PassType passType) { hl_stream_t stream, PassType passType) {
int32_t batchSize = 0; int32_t batchSize = 0;
int64_t numSequences = 0; int64_t numSequences = 0;
int64_t numSubSequences = 0;
for (auto& arg : args) { for (auto& arg : args) {
batchSize += arg.getBatchSize(); batchSize += arg.getBatchSize();
numSequences += arg.getNumSequences(); numSequences += arg.getNumSequences();
numSubSequences += arg.getNumSubSequences();
} }
auto copyArg = [batchSize, stream](MatrixPtr& dst, MatrixPtr src, auto copyArg = [batchSize, stream](MatrixPtr& dst, MatrixPtr src,
...@@ -393,8 +398,26 @@ void Argument::concat(const std::vector<Argument>& args, bool useGpu, ...@@ -393,8 +398,26 @@ void Argument::concat(const std::vector<Argument>& args, bool useGpu,
std::copy(src->begin(), src->end(), dst->begin() + startRow); std::copy(src->begin(), src->end(), dst->begin() + startRow);
}; };
auto copySequencePos = []
(ICpuGpuVectorPtr& dstSeq, const ICpuGpuVectorPtr& srcSeq,
int dstNumSequences, int srcNumSequences,
int& startSequences, int startRow) {
if (srcSeq) {
ICpuGpuVector::resizeOrCreate(dstSeq, dstNumSequences + 1, false);
const int* src = srcSeq->getData(false);
int* dest = dstSeq->getMutableData(false);
for (int i = 0; i < srcNumSequences + 1; ++i) {
dest[i + startSequences] = src[i] + startRow;
}
startSequences += srcNumSequences;
} else {
dstSeq.reset();
}
};
int startRow = 0; int startRow = 0;
int startSequences = 0; int startSequences = 0;
int startSubSequences = 0;
dataId = args[0].dataId; dataId = args[0].dataId;
for (auto& arg : args) { for (auto& arg : args) {
CHECK_EQ(arg.dataId, dataId) << "Arguments in concat should have" CHECK_EQ(arg.dataId, dataId) << "Arguments in concat should have"
...@@ -403,17 +426,18 @@ void Argument::concat(const std::vector<Argument>& args, bool useGpu, ...@@ -403,17 +426,18 @@ void Argument::concat(const std::vector<Argument>& args, bool useGpu,
copyArg(value, arg.value, startRow, useGpu); copyArg(value, arg.value, startRow, useGpu);
if (passType != PASS_TEST) copyArg(grad, arg.grad, startRow, useGpu); if (passType != PASS_TEST) copyArg(grad, arg.grad, startRow, useGpu);
copyIds(ids, arg.ids, startRow, useGpu); copyIds(ids, arg.ids, startRow, useGpu);
if (arg.sequenceStartPositions) { copySequencePos(sequenceStartPositions,
ICpuGpuVector::resizeOrCreate(sequenceStartPositions, arg.sequenceStartPositions,
numSequences + 1, numSequences,
false); arg.getNumSequences(),
const int* src = arg.sequenceStartPositions->getData(false); startSequences,
int* dest = sequenceStartPositions->getMutableData(false); startRow);
for (int i = 0; i < arg.getNumSequences() + 1; ++i) { copySequencePos(subSequenceStartPositions,
dest[i + startSequences] = src[i] + startRow; arg.subSequenceStartPositions,
} numSubSequences,
startSequences += arg.getNumSequences(); arg.getNumSubSequences(),
} startSubSequences,
startRow);
copyStrs(strs, arg.strs, startRow, useGpu); copyStrs(strs, arg.strs, startRow, useGpu);
startRow += arg.getBatchSize(); startRow += arg.getBatchSize();
} }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册