提交 1c5a7c43 编写于 作者: H hedaoyuan

follow comments

上级 f8c9c889
......@@ -192,6 +192,7 @@ public:
SequenceIdArg(const TensorShape& shape, ArgType argType = UNSPECIFIED)
: BufferArg(VALUE_TYPE_INT32, shape, argType) {
CHECK_EQ(shape_.ndims(), (size_t)1);
CHECK_GT(shape_[0], 1);
numSeqs_ = shape_[0] - 1;
}
......
......@@ -85,6 +85,7 @@ void testBufferArgs(const BufferArgs& inputs,
}
void testBufferArgs(const BufferArgs& inputs, const CheckBufferArg& check) {
EXPECT_EQ(inputs.size(), 1);
check(inputs[0]);
}
......
......@@ -172,7 +172,7 @@ protected:
void initArg(SequenceIdArg& arg, size_t batchSize) {
size_t numSeqs = arg.numSeqs();
int* buf = (int*)arg.data();
int* buf = reinterpret_cast<int*>(arg.data());
int pos = 0;
size_t maxLen = 2 * batchSize / numSeqs;
for (int i = 0; i < (int)numSeqs; ++i) {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册