diff --git a/paddle/capi/Arguments.cpp b/paddle/capi/Arguments.cpp index baabd44cc0c63caea6838ada9ea94e2b71876ea1..8d00bda3cb90cda9950533fff86c7c2f54b0448e 100644 --- a/paddle/capi/Arguments.cpp +++ b/paddle/capi/Arguments.cpp @@ -87,7 +87,6 @@ int PDArgsSetIds(PD_Arguments args, uint64_t ID, PD_IVector ids) { int PDArgsSetSequenceStartPos(PD_Arguments args, uint64_t ID, PD_IVector seqPos) { - //! TODO(lizhao): Complete this method. if (args == nullptr || seqPos == nullptr) return kPD_NULLPTR; auto iv = paddle::capi::cast(seqPos); if (iv->vec == nullptr) return kPD_NULLPTR; @@ -101,13 +100,12 @@ int PDArgsSetSequenceStartPos(PD_Arguments args, int PDArgsSetSubSequenceStartPos(PD_Arguments args, uint64_t ID, PD_IVector subSeqPos) { - //! TODO(lizhao): Complete this method. if (args == nullptr || subSeqPos == nullptr) return kPD_NULLPTR; auto iv = paddle::capi::cast(subSeqPos); if (iv->vec == nullptr) return kPD_NULLPTR; auto a = castArg(args); if (ID >= a->args.size()) return kPD_OUT_OF_RANGE; - a->args[ID].sequenceStartPositions = + a->args[ID].subSequenceStartPositions = std::make_shared(iv->vec); return kPD_NO_ERROR; } @@ -115,26 +113,24 @@ int PDArgsSetSubSequenceStartPos(PD_Arguments args, int PDArgsGetSequenceStartPos(PD_Arguments args, uint64_t ID, PD_IVector seqPos) { - //! TODO(lizhao): Complete this method. if (args == nullptr || seqPos == nullptr) return kPD_NULLPTR; auto iv = castIVec(seqPos); auto a = castArg(args); if (ID >= a->args.size()) return kPD_OUT_OF_RANGE; - std::make_shared(iv->vec) = - a->args[ID].sequenceStartPositions; + paddle::Argument& arg = a->args[ID]; + iv->vec = arg.sequenceStartPositions->getMutableVector(false); return kPD_NO_ERROR; } int PDArgsGetSubSequenceStartPos(PD_Arguments args, uint64_t ID, PD_IVector subSeqPos) { - //! TODO(lizhao): Complete this method. if (args == nullptr || subSeqPos == nullptr) return kPD_NULLPTR; auto iv = castIVec(subSeqPos); auto a = castArg(args); if (ID >= a->args.size()) return kPD_OUT_OF_RANGE; - std::make_shared(iv->vec) = - a->args[ID].sequenceStartPositions; + paddle::Argument& arg = a->args[ID]; + iv->vec = arg.subSequenceStartPositions->getMutableVector(false); return kPD_NO_ERROR; } } diff --git a/paddle/capi/Vector.cpp b/paddle/capi/Vector.cpp index 38a5fbc00a8bebd524816d8a39364d7ca0c196c8..af2192551370f4d6ec3e0affc2cf8fa5b3ac2021 100644 --- a/paddle/capi/Vector.cpp +++ b/paddle/capi/Vector.cpp @@ -27,7 +27,6 @@ int PDIVecCreateNone(PD_IVector* ivec) { } int PDIVectorCreate(PD_IVector* ivec, int* array, uint64_t size, bool copy) { - //! TODO(lizhao): Complete this method. if (ivec == nullptr) return kPD_NULLPTR; auto ptr = new paddle::capi::CIVector(); if (copy) { @@ -55,7 +54,6 @@ int PDIVectorGet(PD_IVector ivec, int** buffer) { } int PDIVectorResize(PD_IVector ivec, uint64_t size) { - //! TODO(lizhao): Complete this method. if (ivec == nullptr) return kPD_NULLPTR; auto v = cast(ivec); if (v->vec == nullptr) return kPD_NULLPTR; @@ -64,7 +62,6 @@ int PDIVectorResize(PD_IVector ivec, uint64_t size) { } int PDIVectorGetSize(PD_IVector ivec, uint64_t* size) { - //! TODO(lizhao): Complete this method. if (ivec == nullptr) return kPD_NULLPTR; auto v = cast(ivec); if (v->vec == nullptr) return kPD_NULLPTR; diff --git a/paddle/capi/tests/test_Arguments.cpp b/paddle/capi/tests/test_Arguments.cpp index 1186d2921ba463b9b7e4156e9985e458d52484b0..9357f3a58468e9d7ffc38fc595895c5e5568a09a 100644 --- a/paddle/capi/tests/test_Arguments.cpp +++ b/paddle/capi/tests/test_Arguments.cpp @@ -89,7 +89,8 @@ TEST(CAPIArguments, ids) { ASSERT_EQ(kPD_NO_ERROR, PDArgsDestroy(args)); } -TEST(CAPIArguments, Sequence) { +template +void testSequenceHelper(T1 setter, T2 getter) { PD_Arguments args; ASSERT_EQ(kPD_NO_ERROR, PDArgsCreateNone(&args)); ASSERT_EQ(kPD_NO_ERROR, PDArgsResize(args, 1)); @@ -97,12 +98,27 @@ TEST(CAPIArguments, Sequence) { PD_IVector ivec; int array[3] = {1, 2, 3}; ASSERT_EQ(kPD_NO_ERROR, PDIVectorCreate(&ivec, array, 3, true)); - ASSERT_EQ(kPD_NO_ERROR, PDArgsSetSequenceStartPos(args, 0, ivec)); + ASSERT_EQ(kPD_NO_ERROR, setter(args, 0, ivec)); PD_IVector val; ASSERT_EQ(kPD_NO_ERROR, PDIVecCreateNone(&val)); - ASSERT_EQ(kPD_NO_ERROR, PDArgsGetSequenceStartPos(args, 0, val)); + ASSERT_EQ(kPD_NO_ERROR, getter(args, 0, val)); + uint64_t size; + ASSERT_EQ(kPD_NO_ERROR, PDIVectorGetSize(val, &size)); + + int* rawBuf; + ASSERT_EQ(kPD_NO_ERROR, PDIVectorGet(val, &rawBuf)); + for (size_t i = 0; i < size; ++i) { + ASSERT_EQ(array[i], rawBuf[i]); + } + ASSERT_EQ(kPD_NO_ERROR, PDIVecDestroy(ivec)); ASSERT_EQ(kPD_NO_ERROR, PDIVecDestroy(val)); ASSERT_EQ(kPD_NO_ERROR, PDArgsDestroy(args)); } + +TEST(CAPIArguments, Sequence) { + testSequenceHelper(PDArgsSetSequenceStartPos, PDArgsGetSequenceStartPos); + testSequenceHelper(PDArgsSetSubSequenceStartPos, + PDArgsGetSubSequenceStartPos); +}