Argument.cpp 19.8 KB
Newer Older
Z
zhangjinchao01 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16
/* Copyright (c) 2016 Baidu, Inc. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */


#include "Argument.h"
17
#include "paddle/math/SparseMatrix.h"
Z
zhangjinchao01 已提交
18 19 20 21 22 23 24

#include <algorithm>

namespace paddle {
/**
 * Copy the whole matrix `src` into `dest`.
 *
 * A null `src` clears `dest`. Otherwise `dest` is obtained from
 * src->clone(0, 0, useGpu) when it does not exist yet, or checked against
 * `useGpu` and resized to src's height/width when it does; the contents are
 * then copied on `stream`.
 */
static void resizeAndCopy(MatrixPtr& dest, const MatrixPtr& src, bool useGpu,
                          hl_stream_t stream) {
  if (!src) {
    dest.reset();
    return;
  }

  if (dest) {
    // Reuse the existing destination; it must live on the requested device.
    CHECK_EQ(dest->useGpu(), useGpu);
    dest->resize(src->getHeight(), src->getWidth());
  } else {
    dest = src->clone(0, 0, useGpu);
  }
  dest->copyFrom(*src, stream);
}

/**
 * Copy the whole integer vector `src` into `dest`.
 *
 * A null `src` clears `dest`; otherwise `dest` is resized (or created) to
 * src's size via IVector::resizeOrCreate and filled on `stream`.
 */
static void resizeAndCopy(IVectorPtr& dest, const IVectorPtr& src, bool useGpu,
                          hl_stream_t stream) {
  if (!src) {
    dest.reset();
    return;
  }
  IVector::resizeOrCreate(dest, src->getSize(), useGpu);
  dest->copyFrom(*src, stream);
}

/**
 * Copy the whole ICpuGpuVector `src` into `dest`.
 *
 * A null `src` clears `dest`; otherwise `dest` is resized (or created) to
 * src's size via ICpuGpuVector::resizeOrCreate and filled on `stream`.
 */
static void resizeAndCopy(ICpuGpuVectorPtr& dest,
                          const ICpuGpuVectorPtr& src,
                          bool useGpu,
                          hl_stream_t stream) {
  if (!src) {
    dest.reset();
    return;
  }
  ICpuGpuVector::resizeOrCreate(dest, src->getSize(), useGpu);
  dest->copyFrom(*src, stream);
}

/**
 * Copy `copySize` rows of `src`, starting at row `startRow`, into `dest`.
 *
 * A null `src` clears `dest`. The requested range is checked against src's
 * height. `dest` is created through src->clone(copySize, width, useGpu) when
 * missing, otherwise checked against `useGpu` and resized.
 */
static void resizeAndCopy(MatrixPtr& dest, const MatrixPtr& src,
                          int32_t startRow, int32_t copySize, bool useGpu,
                          hl_stream_t stream = HPPL_STREAM_DEFAULT) {
  if (!src) {
    dest.reset();
    return;
  }

  CHECK_LE(static_cast<size_t>(startRow) + copySize, src->getHeight());
  const int rows = copySize;
  const int cols = src->getWidth();

  if (dest) {
    CHECK_EQ(dest->useGpu(), useGpu);
    dest->resize(rows, cols);
  } else {
    dest = src->clone(rows, cols, useGpu);
  }

  MatrixPtr rowsToCopy = src->subMatrix(startRow, copySize);
  if (dynamic_cast<GpuSparseMatrix*>(dest.get()) == nullptr) {
    dest->copyFrom(*rowsToCopy, stream);
    return;
  }

  // copy a subMatrix of CpuSparseMatrix to GpuSparseMatrix.
  // First copy it to CPU, and then copy it to the GPU.
  MatrixPtr staging = src->clone(rows, cols, false);
  staging->copyFrom(*rowsToCopy, stream);
  dest->copyFrom(*staging, stream);
}

/**
 * Copy `copySize` elements of `src`, starting at `startPos`, into `dest`.
 *
 * A null `src` clears `dest`. The range is checked against src's size, then
 * `dest` is resized (or created) to `copySize` elements and filled from
 * src's raw data pointer on `stream`.
 */
static void resizeAndCopy(IVectorPtr& dest, const IVectorPtr& src,
                          int32_t startPos, int32_t copySize, bool useGpu,
                          hl_stream_t stream = HPPL_STREAM_DEFAULT) {
  if (!src) {
    dest.reset();
    return;
  }

  CHECK_LE(static_cast<size_t>(startPos) + copySize, src->getSize());

  const int count = copySize;
  IVector::resizeOrCreate(dest, count, useGpu);
  dest->copyFrom(src->getData() + startPos, count, stream);
}

/**
 * Copy `copySize` elements of `src`, starting at `startPos`, into `dest`.
 *
 * A null `src` clears `dest`. The range is checked against src's size, then
 * `dest` is resized (or created) to `copySize` elements and filled through
 * ICpuGpuVector::copyFrom on `stream`.
 */
static void resizeAndCopy(ICpuGpuVectorPtr& dest,
                          const ICpuGpuVectorPtr& src,
                          int32_t startPos,
                          int32_t copySize,
                          bool useGpu,
                          hl_stream_t stream = HPPL_STREAM_DEFAULT) {
  if (!src) {
    dest.reset();
    return;
  }

  CHECK_LE(static_cast<size_t>(startPos) + copySize, src->getSize());

  ICpuGpuVector::resizeOrCreate(dest, copySize, useGpu);
  dest->copyFrom(*src, startPos, copySize, useGpu, stream);
}

/**
 * Copy every pointer stored in `src` into `dest` (CPU only).
 *
 * A null `src` clears `dest`. Requesting `useGpu` is a fatal error
 * ("not implemented"). `stream` is not used by this overload.
 */
static void resizeAndCopy(UserDefinedVectorPtr& dest,
                          const UserDefinedVectorPtr& src, bool useGpu,
                          hl_stream_t stream) {
  if (!src) {
    dest.reset();
    return;
  }

  CHECK(!useGpu) << "not implemented";
  const size_t count = src->size();
  if (dest) {
    dest->resize(count);
  } else {
    dest = std::make_shared<std::vector<void*>>(count);
  }
  std::copy(src->begin(), src->begin() + count, dest->begin());
}

/**
 * Copy `copySize` pointers of `src`, starting at `startPos`, into `dest`
 * (CPU only).
 *
 * A null `src` clears `dest`. Requesting `useGpu` is a fatal error
 * ("not implemented"); the range is checked against src's size.
 * `stream` is not used by this overload.
 */
static void resizeAndCopy(UserDefinedVectorPtr& dest,
                          const UserDefinedVectorPtr& src, int32_t startPos,
                          int32_t copySize, bool useGpu,
                          hl_stream_t stream = HPPL_STREAM_DEFAULT) {
  if (!src) {
    dest.reset();
    return;
  }

  CHECK(!useGpu) << "not implemented";
  CHECK_LE(static_cast<size_t>(startPos) + copySize, src->size());

  const size_t count = copySize;
  if (dest) {
    dest->resize(count);
  } else {
    dest = std::make_shared<std::vector<void*>>(count);
  }
  const auto first = src->begin() + startPos;
  std::copy(first, first + count, dest->begin());
}

/**
 * Copy every string stored in `src` into `dest`.
 *
 * A null `src` clears `dest`. `useGpu` and `stream` are accepted for
 * signature parity with the other overloads but are not used here.
 */
static void resizeAndCopy(SVectorPtr& dest, const SVectorPtr& src, bool useGpu,
                          hl_stream_t stream) {
  if (!src) {
    dest.reset();
    return;
  }

  const size_t count = src->size();
  if (dest) {
    dest->resize(count);
  } else {
    dest = std::make_shared<std::vector<std::string>>(count);
  }
  std::copy(src->begin(), src->begin() + count, dest->begin());
}

/**
 * Copy `copySize` strings of `src`, starting at `startPos`, into `dest`.
 *
 * A null `src` clears `dest`; the range is checked against src's size.
 * `useGpu` and `stream` are accepted for signature parity with the other
 * overloads but are not used here.
 */
static void resizeAndCopy(SVectorPtr& dest, const SVectorPtr& src,
                          int32_t startPos, int32_t copySize, bool useGpu,
                          hl_stream_t stream = HPPL_STREAM_DEFAULT) {
  if (!src) {
    dest.reset();
    return;
  }

  CHECK_LE(static_cast<size_t>(startPos) + copySize, src->size());
  const size_t count = copySize;
  if (dest) {
    dest->resize(count);
  } else {
    dest = std::make_shared<std::vector<std::string>>(count);
  }
  const auto first = src->begin() + startPos;
  std::copy(first, first + count, dest->begin());
}

186 187 188 189 190
void Argument::resizeAndCopyFrom(const Argument& src, bool useGpu) {
   resizeAndCopyFrom(src, useGpu, HPPL_STREAM_DEFAULT);
   hl_stream_synchronize(HPPL_STREAM_DEFAULT);
}

Z
zhangjinchao01 已提交
191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207
/**
 * Copy all of `src` into this Argument, issuing the copies on `stream`.
 * This overload does not call hl_stream_synchronize itself.
 *
 * value, grad, in, ids, udp and strs follow `useGpu`; the sequence start
 * positions are always requested with useGpu == false.
 * subSequenceStartPositions is only touched when src.hasSubseq() is true;
 * otherwise whatever this Argument already holds is left as is.
 */
void Argument::resizeAndCopyFrom(const Argument& src, bool useGpu,
                                 hl_stream_t stream) {
  const bool positionsOnGpu = false;

  dataId = src.dataId;

  resizeAndCopy(value, src.value, useGpu, stream);
  resizeAndCopy(grad, src.grad, useGpu, stream);
  resizeAndCopy(in, src.in, useGpu, stream);
  resizeAndCopy(ids, src.ids, useGpu, stream);

  resizeAndCopy(sequenceStartPositions, src.sequenceStartPositions,
                positionsOnGpu, stream);
  if (src.hasSubseq()) {
    resizeAndCopy(subSequenceStartPositions, src.subSequenceStartPositions,
                  positionsOnGpu, stream);
  }

  resizeAndCopy(udp, src.udp, useGpu, stream);
  resizeAndCopy(strs, src.strs, useGpu, stream);
}

208 209 210 211 212 213 214 215
int32_t Argument::resizeAndCopyFrom(const Argument& src, int32_t startSeq,
                                    int32_t copySize, bool useGpu) {
    int32_t size = resizeAndCopyFrom(src, startSeq, copySize, useGpu,
                                     HPPL_STREAM_DEFAULT);
    hl_stream_synchronize(HPPL_STREAM_DEFAULT);
    return size;
}

Z
zhangjinchao01 已提交
216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285
int32_t Argument::resizeAndCopyFrom(const Argument& src, int32_t startSeq,
                                    int32_t copySize, bool useGpu,
                                    hl_stream_t stream) {
  dataId = src.dataId;

  if (!src.sequenceStartPositions) {
    // non-sequence input, copy samples directly
    int32_t startRow = startSeq;
    resizeAndCopy(in, src.in, startRow, copySize, useGpu, stream);
    resizeAndCopy(value, src.value, startRow, copySize, useGpu, stream);
    resizeAndCopy(grad, src.grad, startRow, copySize, useGpu, stream);
    resizeAndCopy(ids, src.ids, startRow, copySize, useGpu, stream);
    resizeAndCopy(udp, src.udp, startRow, copySize, useGpu, stream);
    resizeAndCopy(strs, src.strs, startRow, copySize, useGpu, stream);
    return copySize;
  } else {
    // sequence input
    const int* sequence = src.sequenceStartPositions->getData(false);
    int32_t startRow = sequence[startSeq];           // sample start from here
    int32_t endRow = sequence[startSeq + copySize];  // sample end
    int32_t copyFeatureSize = endRow - startRow;     // num of samples
    resizeAndCopy(in, src.in, startRow, copyFeatureSize, useGpu, stream);
    resizeAndCopy(value, src.value, startRow, copyFeatureSize, useGpu, stream);
    resizeAndCopy(grad, src.grad, startRow, copyFeatureSize, useGpu, stream);
    resizeAndCopy(ids, src.ids, startRow, copyFeatureSize, useGpu, stream);
    resizeAndCopy(udp, src.udp, startRow, copySize, useGpu, stream);
    resizeAndCopy(sequenceStartPositions, src.sequenceStartPositions,
                  startSeq, copySize + 1, false, stream);
    // modify new sequenceStartPositions
    int* destSequences = sequenceStartPositions->getMutableData(false);
    for (int i = 0; i < copySize + 1; i++) {
      destSequences[i] -= startRow;
    }
    CHECK_EQ(destSequences[0], 0);
    CHECK_EQ(destSequences[copySize], copyFeatureSize);
    if (src.hasSubseq()) {
      // sequence has sub-sequence
      int* subSequence = src.subSequenceStartPositions->getMutableData(false);
      int32_t subStartSeq = 0;
      int32_t subEndSeq = 0;
      int numSubSequences = src.getNumSubSequences();
      for (int i = 0; i < numSubSequences + 1; i++) {
        if (subSequence[i] == startRow) {
          subStartSeq = i;
        } else if (subSequence[i] == endRow) {
          subEndSeq = i;
          break;
        }
      }
      int32_t copySubSize = subEndSeq - subStartSeq;
      resizeAndCopy(subSequenceStartPositions,
                    src.subSequenceStartPositions, subStartSeq,
                    copySubSize + 1, false, stream);
      // modify new subSequenceStartPositions
      int* destSubSequences = subSequenceStartPositions->getMutableData(false);
      for (int i = 0; i < copySubSize + 1; i++) {
        destSubSequences[i] -= startRow;
      }
      CHECK_EQ(destSubSequences[0], 0);
      CHECK_EQ(destSubSequences[copySubSize], copyFeatureSize);
    }
    resizeAndCopy(strs, src.strs, startRow, copySize, useGpu, stream);
    return copyFeatureSize;
  }
}

void Argument::concat(const std::vector<Argument>& args,
                      const std::vector<int>& selectRows,
                      const std::vector<int>& seqStartPos, bool useGpu,
                      hl_stream_t stream, PassType passType) {
286 287 288
  CHECK(!subSequenceStartPositions)
          << "undefined behavior for subsequence positions";

Z
zhangjinchao01 已提交
289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366
  size_t batchSize = selectRows.size();
  auto copyArg = [batchSize, stream](MatrixPtr& dst, MatrixPtr src,
                                     int startRow, int pos, int size,
                                     bool useGpu) {
    if (!src) {
      dst.reset();
      return;
    }
    size_t width = src->getWidth();
    if (!dst) {
      dst = src->clone(batchSize, width, useGpu);
    } else {
      dst->resize(batchSize, width);
    }

    MatrixPtr tmpMatrix = dst->subMatrix(startRow, size);
    tmpMatrix->copyFrom(*src->subMatrix(pos, size), stream);
  };

  auto copyIds = [batchSize, stream](IVectorPtr& dst, const IVectorPtr& src,
                                     int startRow, int pos, int size,
                                     bool useGpu) {
    if (!src) {
      dst.reset();
      return;
    }
    IVector::resizeOrCreate(dst, batchSize, useGpu);
    dst->subVec(startRow, size)->copyFrom(*src->subVec(pos, size), stream);
  };

  auto copyStrs = [batchSize, stream](SVectorPtr& dst, const SVectorPtr& src,
                                      int startRow, int pos, int size,
                                      bool useGpu) {
    if (!src) {
      dst.reset();
      return;
    }
    if (!dst) {
      dst = std::make_shared<std::vector<std::string>>(batchSize);
    } else {
      dst->resize(batchSize);
    }
    std::copy(src->begin() + pos, src->begin() + pos + size,
              dst->begin() + startRow);
  };

  dataId = args[0].dataId;
  CHECK_NE(seqStartPos.size(), 0UL);
  size_t sampleNum = seqStartPos.size() - 1;
  for (size_t i = 0; i < sampleNum; ++i) {
    int startPos = seqStartPos[i];
    int endPos = seqStartPos[i + 1];
    CHECK_GE(args.size(), static_cast<size_t>(endPos - startPos));
    for (int j = startPos; j < endPos; ++j) {
      const Argument& arg = args[j - startPos];
      CHECK_EQ(arg.dataId, dataId) << "Arguments in concat should have"
                                   << " same dataId";
      const int copySize = 1;
      const int rowIdx = selectRows[j];
      copyArg(in, arg.in, j, rowIdx, copySize, useGpu);
      copyArg(value, arg.value, j, rowIdx, copySize, useGpu);
      if (passType != PASS_TEST) {
        copyArg(grad, arg.grad, j, rowIdx, copySize, useGpu);
      }
      copyIds(ids, arg.ids, j, rowIdx, copySize, useGpu);
      copyStrs(strs, arg.strs, j, rowIdx, copySize, useGpu);
    }
  }
  ICpuGpuVector::resizeOrCreate(sequenceStartPositions,
                          seqStartPos.size(), useGpu);
  sequenceStartPositions->copyFrom(seqStartPos.data(),
                                   seqStartPos.size(), useGpu);
}

void Argument::concat(const std::vector<Argument>& args, bool useGpu,
                      hl_stream_t stream, PassType passType) {
  int32_t batchSize = 0;
  int64_t numSequences = 0;
367
  int64_t numSubSequences = 0;
Z
zhangjinchao01 已提交
368 369 370
  for (auto& arg : args) {
    batchSize += arg.getBatchSize();
    numSequences += arg.getNumSequences();
371
    numSubSequences += arg.getNumSubSequences();
Z
zhangjinchao01 已提交
372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414
  }

  auto copyArg = [batchSize, stream](MatrixPtr& dst, MatrixPtr src,
                                     int startRow, bool useGpu) {
    if (!src) {
      dst.reset();
      return;
    }
    size_t width = src->getWidth();
    if (!dst) {
      dst = src->clone(batchSize, width, useGpu);
    } else {
      dst->resize(batchSize, width);
    }

    MatrixPtr tmpMatrix = dst->subMatrix(startRow, src->getHeight());
    tmpMatrix->copyFrom(*src, stream);
  };

  auto copyIds = [batchSize, stream](IVectorPtr& dst, const IVectorPtr& src,
                                     int startRow, bool useGpu) {
    if (!src) {
      dst.reset();
      return;
    }
    IVector::resizeOrCreate(dst, batchSize, useGpu);
    dst->subVec(startRow, src->getSize())->copyFrom(*src, stream);
  };

  auto copyStrs = [batchSize, stream](SVectorPtr& dst, const SVectorPtr& src,
                                      int startRow, bool useGpu) {
    if (!src) {
      dst.reset();
      return;
    }
    if (!dst) {
      dst = std::make_shared<std::vector<std::string>>(batchSize);
    } else {
      dst->resize(batchSize);
    }
    std::copy(src->begin(), src->end(), dst->begin() + startRow);
  };

415 416 417 418 419 420 421 422 423 424 425 426 427 428 429 430 431
  auto copySequencePos = []
          (ICpuGpuVectorPtr& dstSeq, const ICpuGpuVectorPtr& srcSeq,
           int dstNumSequences, int srcNumSequences,
           int& startSequences, int startRow) {
      if (srcSeq) {
          ICpuGpuVector::resizeOrCreate(dstSeq, dstNumSequences + 1, false);
          const int* src = srcSeq->getData(false);
          int* dest = dstSeq->getMutableData(false);
          for (int i = 0; i < srcNumSequences + 1; ++i) {
              dest[i + startSequences] = src[i] + startRow;
          }
          startSequences += srcNumSequences;
      } else {
          dstSeq.reset();
      }
  };

Z
zhangjinchao01 已提交
432 433
  int startRow = 0;
  int startSequences = 0;
434
  int startSubSequences = 0;
Z
zhangjinchao01 已提交
435 436 437 438 439 440 441 442
  dataId = args[0].dataId;
  for (auto& arg : args) {
    CHECK_EQ(arg.dataId, dataId) << "Arguments in concat should have"
                                 << " same dataId";
    copyArg(in, arg.in, startRow, useGpu);
    copyArg(value, arg.value, startRow, useGpu);
    if (passType != PASS_TEST) copyArg(grad, arg.grad, startRow, useGpu);
    copyIds(ids, arg.ids, startRow, useGpu);
443 444 445 446 447 448 449 450 451 452 453 454
    copySequencePos(sequenceStartPositions,
                    arg.sequenceStartPositions,
                    numSequences,
                    arg.getNumSequences(),
                    startSequences,
                    startRow);
    copySequencePos(subSequenceStartPositions,
                    arg.subSequenceStartPositions,
                    numSubSequences,
                    arg.getNumSubSequences(),
                    startSubSequences,
                    startRow);
Z
zhangjinchao01 已提交
455 456 457 458 459 460 461 462 463 464 465 466 467 468 469 470 471 472 473 474 475 476 477 478 479
    copyStrs(strs, arg.strs, startRow, useGpu);
    startRow += arg.getBatchSize();
  }
}

/**
 * Partition `argus` into groups of consecutive Arguments with the same
 * dataId, preserving order. `arguGroups` is cleared first.
 *
 * An Argument whose dataId is -1 always opens a group of its own start;
 * any other Argument opens a new group only when its dataId differs from
 * the previous Argument's.
 */
void Argument::splitByDataId(const std::vector<Argument>& argus,
                             std::vector<std::vector<Argument>>* arguGroups) {
  arguGroups->clear();
  int previousId = -1;
  for (const auto& current : argus) {
    const bool opensGroup =
        (current.dataId == -1) || (current.dataId != previousId);
    if (opensGroup) {
      arguGroups->emplace_back();
      previousId = current.dataId;
    }
    arguGroups->back().push_back(current);
  }
}

480
void Argument::getSeqInfo(std::vector<SeqInfo>* seqInfo) const {
Z
zhangjinchao01 已提交
481
  const int* starts = sequenceStartPositions->getData(false);
482 483 484 485 486 487 488 489 490 491 492 493 494 495
  const int* subStarts = hasSubseq()
      ? subSequenceStartPositions->getData(false) : nullptr;
  size_t numSequences = getNumSequences();
  seqInfo->reserve(numSequences);
  int subSeqEnd = 0;
  for (size_t i = 0; i < numSequences; ++i) {
    SeqInfo info;
    info.seqStart = starts[i];
    info.subLevelLength = starts[i + 1] - starts[i];
    info.seqId = i;
    if (hasSubseq()) {
      info.subSeqStart = subSeqEnd;
      while (subStarts[subSeqEnd] < starts[i + 1]) {
        ++subSeqEnd;
Z
zhangjinchao01 已提交
496
      }
497 498 499 500
      info.topLevelLength = subSeqEnd - info.subSeqStart;
    } else {
      info.topLevelLength = info.subLevelLength;
      info.subSeqStart = 0;  // not used
Z
zhangjinchao01 已提交
501
    }
502
    seqInfo->push_back(info);
Z
zhangjinchao01 已提交
503
  }
504 505 506 507
  std::sort(seqInfo->begin(), seqInfo->end(),
            [](const SeqInfo& a, const SeqInfo& b) {
              return a.topLevelLength > b.topLevelLength;
            });
Z
zhangjinchao01 已提交
508 509 510 511 512 513 514 515 516 517 518 519 520 521 522 523 524 525 526 527 528 529 530 531 532 533 534 535 536 537 538 539 540 541 542 543 544 545 546 547 548 549 550 551 552 553 554 555 556
}

/**
 * Verify that every sequence start position also appears among the
 * sub-sequence start positions; logs FATAL otherwise.
 *
 * Both position arrays are walked in step, so the check relies on them
 * being in ascending order (assumed, not verified here).
 */
void Argument::checkSubset() const {
  const auto seqCount = getNumSequences();
  const auto subSeqCount = getNumSubSequences();
  if (seqCount > subSeqCount) {
    LOG(FATAL) << "numSubSequences is less than numSequences ("
               << subSeqCount << " vs. " << seqCount << ")";
  }

  const int* seqStarts = sequenceStartPositions->getData(false);
  const int* subSeqStarts = subSequenceStartPositions->getData(false);
  int seqId = 0;
  int subSeqId = 0;
  while (seqId < seqCount && subSeqId < subSeqCount) {
    if (seqStarts[seqId] < subSeqStarts[subSeqId]) {
      // The sub-sequence cursor already moved past this sequence start.
      LOG(FATAL) << "seqStartPositions is not subset of subSeqStartPositions";
    } else {
      if (seqStarts[seqId] == subSeqStarts[subSeqId]) {
        ++seqId;
      }
      ++subSeqId;
    }
  }
  if (seqId < seqCount) {
    LOG(FATAL) << "seqStartPositions is not subset of subSeqStartPositions";
  }
}

/**
 * Build sequenceStartPositions for this Argument expressed in units of
 * `input`'s sub-sequences.
 *
 * Entry i is the index of the sub-sequence whose start position equals the
 * start of input's sequence i; the final entry is the total number of
 * sub-sequences. `input` must have sub-sequences (fatal check). The result
 * is always held with useGpu == false; the `useGpu` parameter is not used.
 *
 * @param input   Argument with both sequence and sub-sequence positions.
 * @param useGpu  unused.
 */
void Argument::degradeSequence(const Argument& input, bool useGpu) {
  CHECK_EQ(input.hasSubseq(), 1UL);
  size_t numSequences = input.getNumSequences();
  size_t numSubSequences = input.getNumSubSequences();
  ICpuGpuVector::resizeOrCreate(sequenceStartPositions,
                                 numSequences + 1,
                                 false);
  int* tgtBuf = sequenceStartPositions->getMutableData(false);
  const int* starts = input.sequenceStartPositions->getData(false);
  const int* subStarts = input.subSequenceStartPositions->getData(false);
  size_t seqId = 0;
  for (size_t subSeqId = 0; subSeqId < numSubSequences; ++subSeqId) {
    // Stop matching once every sequence has been placed: otherwise a later
    // sub-sequence start equal to the final boundary would advance seqId
    // past numSequences and the next comparison would read beyond `starts`.
    // tgtBuf[numSequences] is written unconditionally below, so results for
    // well-formed input are unchanged.
    if (seqId < numSequences && subStarts[subSeqId] == starts[seqId]) {
      tgtBuf[seqId] = subSeqId;
      seqId++;
    }
  }
  tgtBuf[numSequences] = numSubSequences;
}

void Argument::subArgFrom(const Argument& input, size_t offset, size_t height,
                          size_t width, bool useGpu, bool trans, bool seqFlag,
                          size_t seqStart, size_t seqSize) {
557 558 559 560 561 562 563
  if (input.value) {
    value = Matrix::create(input.value->getData() + offset * width,
                           height, width, trans, useGpu);
  }
  if (input.ids) {
    ids = IVector::create(input.ids->getData() + offset, height, useGpu);
  }
Z
zhangjinchao01 已提交
564
  if (input.grad) {
565 566
    grad = Matrix::create(input.grad->getData() + offset * width,
                          height, width, trans, useGpu);
Z
zhangjinchao01 已提交
567 568 569 570 571 572 573 574 575
  }
  if (seqFlag) {
    sequenceStartPositions = std::make_shared<ICpuGpuVector>(
        *(input.sequenceStartPositions),
        seqStart, seqSize);
  }
}

}  // namespace paddle