提交 0afd5c30 编写于 作者: Y Yu Yang

Stash

上级 b5288289
...@@ -21,27 +21,31 @@ using paddle::capi::cast; ...@@ -21,27 +21,31 @@ using paddle::capi::cast;
#define castIVec(v) cast<paddle::capi::CIVector>(v) #define castIVec(v) cast<paddle::capi::CIVector>(v)
extern "C" { extern "C" {
PD_Arguments PDArgsCreateNone() { return new paddle::capi::CArguments(); } paddle_arguments paddle_arguments_create_none() {
return new paddle::capi::CArguments();
}
paddle_error PDArgsDestroy(PD_Arguments args) { paddle_error paddle_arguments_destroy(paddle_arguments args) {
if (args == nullptr) return kPD_NULLPTR; if (args == nullptr) return kPD_NULLPTR;
delete castArg(args); delete castArg(args);
return kPD_NO_ERROR; return kPD_NO_ERROR;
} }
paddle_error PDArgsGetSize(PD_Arguments args, uint64_t* size) { paddle_error paddle_arguments_size(paddle_arguments args, uint64_t* size) {
if (args == nullptr || size == nullptr) return kPD_NULLPTR; if (args == nullptr || size == nullptr) return kPD_NULLPTR;
*size = castArg(args)->args.size(); *size = castArg(args)->args.size();
return kPD_NO_ERROR; return kPD_NO_ERROR;
} }
paddle_error PDArgsResize(PD_Arguments args, uint64_t size) { paddle_error paddle_arguments_resize(paddle_arguments args, uint64_t size) {
if (args == nullptr) return kPD_NULLPTR; if (args == nullptr) return kPD_NULLPTR;
castArg(args)->args.resize(size); castArg(args)->args.resize(size);
return kPD_NO_ERROR; return kPD_NO_ERROR;
} }
paddle_error PDArgsSetValue(PD_Arguments args, uint64_t ID, paddle_matrix mat) { paddle_error paddle_arguments_set_value(paddle_arguments args,
uint64_t ID,
paddle_matrix mat) {
if (args == nullptr || mat == nullptr) return kPD_NULLPTR; if (args == nullptr || mat == nullptr) return kPD_NULLPTR;
auto m = paddle::capi::cast<paddle::capi::CMatrix>(mat); auto m = paddle::capi::cast<paddle::capi::CMatrix>(mat);
if (m->mat == nullptr) return kPD_NULLPTR; if (m->mat == nullptr) return kPD_NULLPTR;
...@@ -51,7 +55,9 @@ paddle_error PDArgsSetValue(PD_Arguments args, uint64_t ID, paddle_matrix mat) { ...@@ -51,7 +55,9 @@ paddle_error PDArgsSetValue(PD_Arguments args, uint64_t ID, paddle_matrix mat) {
return kPD_NO_ERROR; return kPD_NO_ERROR;
} }
paddle_error PDArgsGetValue(PD_Arguments args, uint64_t ID, paddle_matrix mat) { paddle_error paddle_arguments_value(paddle_arguments args,
uint64_t ID,
paddle_matrix mat) {
if (args == nullptr || mat == nullptr) return kPD_NULLPTR; if (args == nullptr || mat == nullptr) return kPD_NULLPTR;
auto m = paddle::capi::cast<paddle::capi::CMatrix>(mat); auto m = paddle::capi::cast<paddle::capi::CMatrix>(mat);
auto a = castArg(args); auto a = castArg(args);
...@@ -60,7 +66,9 @@ paddle_error PDArgsGetValue(PD_Arguments args, uint64_t ID, paddle_matrix mat) { ...@@ -60,7 +66,9 @@ paddle_error PDArgsGetValue(PD_Arguments args, uint64_t ID, paddle_matrix mat) {
return kPD_NO_ERROR; return kPD_NO_ERROR;
} }
paddle_error PDArgsGetIds(PD_Arguments args, uint64_t ID, paddle_ivector ids) { paddle_error paddle_arguments_ids(paddle_arguments args,
uint64_t ID,
paddle_ivector ids) {
if (args == nullptr || ids == nullptr) return kPD_NULLPTR; if (args == nullptr || ids == nullptr) return kPD_NULLPTR;
auto iv = castIVec(ids); auto iv = castIVec(ids);
auto a = castArg(args); auto a = castArg(args);
...@@ -69,7 +77,9 @@ paddle_error PDArgsGetIds(PD_Arguments args, uint64_t ID, paddle_ivector ids) { ...@@ -69,7 +77,9 @@ paddle_error PDArgsGetIds(PD_Arguments args, uint64_t ID, paddle_ivector ids) {
return kPD_NO_ERROR; return kPD_NO_ERROR;
} }
paddle_error PDArgsSetIds(PD_Arguments args, uint64_t ID, paddle_ivector ids) { paddle_error paddle_arguments_set_ids(paddle_arguments args,
uint64_t ID,
paddle_ivector ids) {
//! TODO(lizhao): Complete this method. //! TODO(lizhao): Complete this method.
if (args == nullptr || ids == nullptr) return kPD_NULLPTR; if (args == nullptr || ids == nullptr) return kPD_NULLPTR;
auto iv = paddle::capi::cast<paddle::capi::CIVector>(ids); auto iv = paddle::capi::cast<paddle::capi::CIVector>(ids);
...@@ -80,7 +90,7 @@ paddle_error PDArgsSetIds(PD_Arguments args, uint64_t ID, paddle_ivector ids) { ...@@ -80,7 +90,7 @@ paddle_error PDArgsSetIds(PD_Arguments args, uint64_t ID, paddle_ivector ids) {
return kPD_NO_ERROR; return kPD_NO_ERROR;
} }
paddle_error PDArgsSetSequenceStartPos(PD_Arguments args, paddle_error paddle_arguments_set_sequence_start_pos(paddle_arguments args,
uint64_t ID, uint64_t ID,
paddle_ivector seqPos) { paddle_ivector seqPos) {
if (args == nullptr || seqPos == nullptr) return kPD_NULLPTR; if (args == nullptr || seqPos == nullptr) return kPD_NULLPTR;
...@@ -93,9 +103,8 @@ paddle_error PDArgsSetSequenceStartPos(PD_Arguments args, ...@@ -93,9 +103,8 @@ paddle_error PDArgsSetSequenceStartPos(PD_Arguments args,
return kPD_NO_ERROR; return kPD_NO_ERROR;
} }
paddle_error PDArgsSetSubSequenceStartPos(PD_Arguments args, paddle_error paddle_arguments_set_sub_sequence_start_pos(
uint64_t ID, paddle_arguments args, uint64_t ID, paddle_ivector subSeqPos) {
paddle_ivector subSeqPos) {
if (args == nullptr || subSeqPos == nullptr) return kPD_NULLPTR; if (args == nullptr || subSeqPos == nullptr) return kPD_NULLPTR;
auto iv = paddle::capi::cast<paddle::capi::CIVector>(subSeqPos); auto iv = paddle::capi::cast<paddle::capi::CIVector>(subSeqPos);
if (iv->vec == nullptr) return kPD_NULLPTR; if (iv->vec == nullptr) return kPD_NULLPTR;
...@@ -106,7 +115,7 @@ paddle_error PDArgsSetSubSequenceStartPos(PD_Arguments args, ...@@ -106,7 +115,7 @@ paddle_error PDArgsSetSubSequenceStartPos(PD_Arguments args,
return kPD_NO_ERROR; return kPD_NO_ERROR;
} }
paddle_error PDArgsGetSequenceStartPos(PD_Arguments args, paddle_error paddle_arguments_sequence_start_pos(paddle_arguments args,
uint64_t ID, uint64_t ID,
paddle_ivector seqPos) { paddle_ivector seqPos) {
if (args == nullptr || seqPos == nullptr) return kPD_NULLPTR; if (args == nullptr || seqPos == nullptr) return kPD_NULLPTR;
...@@ -118,7 +127,7 @@ paddle_error PDArgsGetSequenceStartPos(PD_Arguments args, ...@@ -118,7 +127,7 @@ paddle_error PDArgsGetSequenceStartPos(PD_Arguments args,
return kPD_NO_ERROR; return kPD_NO_ERROR;
} }
paddle_error PDArgsGetSubSequenceStartPos(PD_Arguments args, paddle_error paddle_arguments_sub_sequence_start_pos(paddle_arguments args,
uint64_t ID, uint64_t ID,
paddle_ivector subSeqPos) { paddle_ivector subSeqPos) {
if (args == nullptr || subSeqPos == nullptr) return kPD_NULLPTR; if (args == nullptr || subSeqPos == nullptr) return kPD_NULLPTR;
......
...@@ -70,8 +70,8 @@ paddle_error PDGradientMachineLoadParameterFromDisk(PD_GradientMachine machine, ...@@ -70,8 +70,8 @@ paddle_error PDGradientMachineLoadParameterFromDisk(PD_GradientMachine machine,
} }
paddle_error PDGradientMachineForward(PD_GradientMachine machine, paddle_error PDGradientMachineForward(PD_GradientMachine machine,
PD_Arguments inArgs, paddle_arguments inArgs,
PD_Arguments outArgs, paddle_arguments outArgs,
bool isTrain) { bool isTrain) {
auto m = cast(machine); auto m = cast(machine);
auto in = paddle::capi::cast<paddle::capi::CArguments>(inArgs); auto in = paddle::capi::cast<paddle::capi::CArguments>(inArgs);
......
...@@ -36,20 +36,21 @@ extern "C" { ...@@ -36,20 +36,21 @@ extern "C" {
* Arguments functions. Each argument means layer output. Arguments means a * Arguments functions. Each argument means layer output. Arguments means a
* array of arguemnt. * array of arguemnt.
*/ */
typedef void* PD_Arguments; typedef void* paddle_arguments;
/** /**
* @brief PDArgsCreateNone Create a array of arguments, which size is zero. * @brief paddle_arguments_create_none Create a array of arguments, which size
* is zero.
* @return Arguemnts * @return Arguemnts
*/ */
PD_API PD_Arguments PDArgsCreateNone(); PD_API paddle_arguments paddle_arguments_create_none();
/** /**
* @brief PDArgsDestroy Destroy the arguments * @brief paddle_arguments_destroy Destroy the arguments
* @param args arguments to destroy * @param args arguments to destroy
* @return paddle_error * @return paddle_error
*/ */
PD_API paddle_error PDArgsDestroy(PD_Arguments args); PD_API paddle_error paddle_arguments_destroy(paddle_arguments args);
/** /**
* @brief PDArgsGetSize Get size of arguments array * @brief PDArgsGetSize Get size of arguments array
...@@ -57,7 +58,8 @@ PD_API paddle_error PDArgsDestroy(PD_Arguments args); ...@@ -57,7 +58,8 @@ PD_API paddle_error PDArgsDestroy(PD_Arguments args);
* @param [out] size array size * @param [out] size array size
* @return paddle_error * @return paddle_error
*/ */
PD_API paddle_error PDArgsGetSize(PD_Arguments args, uint64_t* size); PD_API paddle_error paddle_arguments_size(paddle_arguments args,
uint64_t* size);
/** /**
* @brief PDArgsResize Resize a arguments array. * @brief PDArgsResize Resize a arguments array.
...@@ -65,7 +67,8 @@ PD_API paddle_error PDArgsGetSize(PD_Arguments args, uint64_t* size); ...@@ -65,7 +67,8 @@ PD_API paddle_error PDArgsGetSize(PD_Arguments args, uint64_t* size);
* @param size target size of array * @param size target size of array
* @return paddle_error * @return paddle_error
*/ */
PD_API paddle_error PDArgsResize(PD_Arguments args, uint64_t size); PD_API paddle_error paddle_arguments_resize(paddle_arguments args,
uint64_t size);
/** /**
* @brief PDArgsSetValue Set value matrix of one argument in array, which index * @brief PDArgsSetValue Set value matrix of one argument in array, which index
...@@ -75,7 +78,7 @@ PD_API paddle_error PDArgsResize(PD_Arguments args, uint64_t size); ...@@ -75,7 +78,7 @@ PD_API paddle_error PDArgsResize(PD_Arguments args, uint64_t size);
* @param mat matrix pointer * @param mat matrix pointer
* @return paddle_error * @return paddle_error
*/ */
PD_API paddle_error PDArgsSetValue(PD_Arguments args, PD_API paddle_error paddle_arguments_set_value(paddle_arguments args,
uint64_t ID, uint64_t ID,
paddle_matrix mat); paddle_matrix mat);
...@@ -87,7 +90,7 @@ PD_API paddle_error PDArgsSetValue(PD_Arguments args, ...@@ -87,7 +90,7 @@ PD_API paddle_error PDArgsSetValue(PD_Arguments args,
* @param [out] mat matrix pointer * @param [out] mat matrix pointer
* @return paddle_error * @return paddle_error
*/ */
PD_API paddle_error PDArgsGetValue(PD_Arguments args, PD_API paddle_error paddle_arguments_value(paddle_arguments args,
uint64_t ID, uint64_t ID,
paddle_matrix mat); paddle_matrix mat);
...@@ -99,7 +102,7 @@ PD_API paddle_error PDArgsGetValue(PD_Arguments args, ...@@ -99,7 +102,7 @@ PD_API paddle_error PDArgsGetValue(PD_Arguments args,
* @param ids integer vector pointer * @param ids integer vector pointer
* @return paddle_error * @return paddle_error
*/ */
PD_API paddle_error PDArgsGetIds(PD_Arguments args, PD_API paddle_error paddle_arguments_ids(paddle_arguments args,
uint64_t ID, uint64_t ID,
paddle_ivector ids); paddle_ivector ids);
...@@ -111,7 +114,7 @@ PD_API paddle_error PDArgsGetIds(PD_Arguments args, ...@@ -111,7 +114,7 @@ PD_API paddle_error PDArgsGetIds(PD_Arguments args,
* @param [out] ids integer vector pointer * @param [out] ids integer vector pointer
* @return paddle_error * @return paddle_error
*/ */
PD_API paddle_error PDArgsSetIds(PD_Arguments args, PD_API paddle_error paddle_arguments_set_ids(paddle_arguments args,
uint64_t ID, uint64_t ID,
paddle_ivector ids); paddle_ivector ids);
...@@ -123,9 +126,8 @@ PD_API paddle_error PDArgsSetIds(PD_Arguments args, ...@@ -123,9 +126,8 @@ PD_API paddle_error PDArgsSetIds(PD_Arguments args,
* @param seqPos sequence position array. * @param seqPos sequence position array.
* @return paddle_error * @return paddle_error
*/ */
PD_API paddle_error PDArgsSetSequenceStartPos(PD_Arguments args, PD_API paddle_error paddle_arguments_set_sequence_start_pos(
uint64_t ID, paddle_arguments args, uint64_t ID, paddle_ivector seqPos);
paddle_ivector seqPos);
/** /**
* @brief PDArgsGetSequenceStartPos Get sequence start position vector of one * @brief PDArgsGetSequenceStartPos Get sequence start position vector of one
* argument in array, which index is `ID`. * argument in array, which index is `ID`.
...@@ -134,7 +136,7 @@ PD_API paddle_error PDArgsSetSequenceStartPos(PD_Arguments args, ...@@ -134,7 +136,7 @@ PD_API paddle_error PDArgsSetSequenceStartPos(PD_Arguments args,
* @param [out] seqPos sequence position array * @param [out] seqPos sequence position array
* @return paddle_error * @return paddle_error
*/ */
PD_API paddle_error PDArgsGetSequenceStartPos(PD_Arguments args, PD_API paddle_error paddle_arguments_sequence_start_pos(paddle_arguments args,
uint64_t ID, uint64_t ID,
paddle_ivector seqPos); paddle_ivector seqPos);
...@@ -146,9 +148,8 @@ PD_API paddle_error PDArgsGetSequenceStartPos(PD_Arguments args, ...@@ -146,9 +148,8 @@ PD_API paddle_error PDArgsGetSequenceStartPos(PD_Arguments args,
* @param subSeqPos sub-sequence start position array. * @param subSeqPos sub-sequence start position array.
* @return paddle_error * @return paddle_error
*/ */
PD_API paddle_error PDArgsSetSubSequenceStartPos(PD_Arguments args, PD_API paddle_error paddle_arguments_set_sub_sequence_start_pos(
uint64_t ID, paddle_arguments args, uint64_t ID, paddle_ivector subSeqPos);
paddle_ivector subSeqPos);
/** /**
* @brief PDArgsGetSubSequenceStartPos Get sub-sequence start position vector of * @brief PDArgsGetSubSequenceStartPos Get sub-sequence start position vector of
...@@ -158,9 +159,8 @@ PD_API paddle_error PDArgsSetSubSequenceStartPos(PD_Arguments args, ...@@ -158,9 +159,8 @@ PD_API paddle_error PDArgsSetSubSequenceStartPos(PD_Arguments args,
* @param subSeqPos sub-sequence start position array * @param subSeqPos sub-sequence start position array
* @return paddle_error * @return paddle_error
*/ */
PD_API paddle_error PDArgsGetSubSequenceStartPos(PD_Arguments args, PD_API paddle_error paddle_arguments_sub_sequence_start_pos(
uint64_t ID, paddle_arguments args, uint64_t ID, paddle_ivector subSeqPos);
paddle_ivector subSeqPos);
/** /**
* @brief GradientMachine means a neural network. * @brief GradientMachine means a neural network.
*/ */
...@@ -195,8 +195,8 @@ PD_API paddle_error PDGradientMachineLoadParameterFromDisk( ...@@ -195,8 +195,8 @@ PD_API paddle_error PDGradientMachineLoadParameterFromDisk(
* @return paddle_error * @return paddle_error
*/ */
PD_API paddle_error PDGradientMachineForward(PD_GradientMachine machine, PD_API paddle_error PDGradientMachineForward(PD_GradientMachine machine,
PD_Arguments inArgs, paddle_arguments inArgs,
PD_Arguments outArgs, paddle_arguments outArgs,
bool isTrain); bool isTrain);
/** /**
......
...@@ -28,27 +28,27 @@ static std::vector<pd_real> randomBuffer(size_t bufSize) { ...@@ -28,27 +28,27 @@ static std::vector<pd_real> randomBuffer(size_t bufSize) {
} }
TEST(CAPIArguments, create) { TEST(CAPIArguments, create) {
PD_Arguments args = PDArgsCreateNone(); paddle_arguments args = paddle_arguments_create_none();
uint64_t size; uint64_t size;
ASSERT_EQ(kPD_NO_ERROR, PDArgsGetSize(args, &size)); ASSERT_EQ(kPD_NO_ERROR, paddle_arguments_size(args, &size));
ASSERT_EQ(0UL, size); ASSERT_EQ(0UL, size);
ASSERT_EQ(kPD_NO_ERROR, PDArgsDestroy(args)); ASSERT_EQ(kPD_NO_ERROR, paddle_arguments_destroy(args));
} }
TEST(CAPIArguments, value) { TEST(CAPIArguments, value) {
PD_Arguments args = PDArgsCreateNone(); paddle_arguments args = paddle_arguments_create_none();
ASSERT_EQ(kPD_NO_ERROR, PDArgsResize(args, 1)); ASSERT_EQ(kPD_NO_ERROR, paddle_arguments_resize(args, 1));
paddle_matrix mat = paddle_matrix_create(128, 64, false); paddle_matrix mat = paddle_matrix_create(128, 64, false);
for (size_t i = 0; i < 128; ++i) { for (size_t i = 0; i < 128; ++i) {
std::vector<pd_real> sampleBuf = randomBuffer(64); std::vector<pd_real> sampleBuf = randomBuffer(64);
paddle_matrix_set_row(mat, i, sampleBuf.data()); paddle_matrix_set_row(mat, i, sampleBuf.data());
} }
ASSERT_EQ(kPD_NO_ERROR, PDArgsSetValue(args, 0, mat)); ASSERT_EQ(kPD_NO_ERROR, paddle_arguments_set_value(args, 0, mat));
paddle_matrix val = paddle_matrix_create_none(); paddle_matrix val = paddle_matrix_create_none();
ASSERT_EQ(kPD_NO_ERROR, PDArgsGetValue(args, 0, val)); ASSERT_EQ(kPD_NO_ERROR, paddle_arguments_value(args, 0, val));
for (size_t i = 0; i < 128; ++i) { for (size_t i = 0; i < 128; ++i) {
pd_real* row1; pd_real* row1;
...@@ -63,29 +63,29 @@ TEST(CAPIArguments, value) { ...@@ -63,29 +63,29 @@ TEST(CAPIArguments, value) {
ASSERT_EQ(kPD_NO_ERROR, paddle_ivector_destroy(ivec)); ASSERT_EQ(kPD_NO_ERROR, paddle_ivector_destroy(ivec));
ASSERT_EQ(kPD_NO_ERROR, paddle_matrix_destroy(val)); ASSERT_EQ(kPD_NO_ERROR, paddle_matrix_destroy(val));
ASSERT_EQ(kPD_NO_ERROR, paddle_matrix_destroy(mat)); ASSERT_EQ(kPD_NO_ERROR, paddle_matrix_destroy(mat));
ASSERT_EQ(kPD_NO_ERROR, PDArgsDestroy(args)); ASSERT_EQ(kPD_NO_ERROR, paddle_arguments_destroy(args));
} }
TEST(CAPIArguments, ids) { TEST(CAPIArguments, ids) {
PD_Arguments args = PDArgsCreateNone(); paddle_arguments args = paddle_arguments_create_none();
ASSERT_EQ(kPD_NO_ERROR, PDArgsResize(args, 1)); ASSERT_EQ(kPD_NO_ERROR, paddle_arguments_resize(args, 1));
paddle_ivector ivec; paddle_ivector ivec;
int array[3] = {1, 2, 3}; int array[3] = {1, 2, 3};
ivec = paddle_ivector_create(array, 3, true, false); ivec = paddle_ivector_create(array, 3, true, false);
ASSERT_EQ(kPD_NO_ERROR, PDArgsSetIds(args, 0, ivec)); ASSERT_EQ(kPD_NO_ERROR, paddle_arguments_set_ids(args, 0, ivec));
paddle_ivector val = paddle_ivector_create_none(); paddle_ivector val = paddle_ivector_create_none();
ASSERT_EQ(kPD_NO_ERROR, PDArgsGetIds(args, 0, val)); ASSERT_EQ(kPD_NO_ERROR, paddle_arguments_ids(args, 0, val));
ASSERT_EQ(kPD_NO_ERROR, paddle_ivector_destroy(ivec)); ASSERT_EQ(kPD_NO_ERROR, paddle_ivector_destroy(ivec));
ASSERT_EQ(kPD_NO_ERROR, paddle_ivector_destroy(val)); ASSERT_EQ(kPD_NO_ERROR, paddle_ivector_destroy(val));
ASSERT_EQ(kPD_NO_ERROR, PDArgsDestroy(args)); ASSERT_EQ(kPD_NO_ERROR, paddle_arguments_destroy(args));
} }
template <typename T1, typename T2> template <typename T1, typename T2>
void testSequenceHelper(T1 setter, T2 getter) { void testSequenceHelper(T1 setter, T2 getter) {
PD_Arguments args = PDArgsCreateNone(); paddle_arguments args = paddle_arguments_create_none();
ASSERT_EQ(kPD_NO_ERROR, PDArgsResize(args, 1)); ASSERT_EQ(kPD_NO_ERROR, paddle_arguments_resize(args, 1));
paddle_ivector ivec; paddle_ivector ivec;
int array[3] = {1, 2, 3}; int array[3] = {1, 2, 3};
...@@ -105,11 +105,12 @@ void testSequenceHelper(T1 setter, T2 getter) { ...@@ -105,11 +105,12 @@ void testSequenceHelper(T1 setter, T2 getter) {
ASSERT_EQ(kPD_NO_ERROR, paddle_ivector_destroy(ivec)); ASSERT_EQ(kPD_NO_ERROR, paddle_ivector_destroy(ivec));
ASSERT_EQ(kPD_NO_ERROR, paddle_ivector_destroy(val)); ASSERT_EQ(kPD_NO_ERROR, paddle_ivector_destroy(val));
ASSERT_EQ(kPD_NO_ERROR, PDArgsDestroy(args)); ASSERT_EQ(kPD_NO_ERROR, paddle_arguments_destroy(args));
} }
TEST(CAPIArguments, Sequence) { TEST(CAPIArguments, Sequence) {
testSequenceHelper(PDArgsSetSequenceStartPos, PDArgsGetSequenceStartPos); testSequenceHelper(paddle_arguments_set_sequence_start_pos,
testSequenceHelper(PDArgsSetSubSequenceStartPos, paddle_arguments_sequence_start_pos);
PDArgsGetSubSequenceStartPos); testSequenceHelper(paddle_arguments_set_sub_sequence_start_pos,
paddle_arguments_sub_sequence_start_pos);
} }
...@@ -55,10 +55,10 @@ TEST(GradientMachine, testPredict) { ...@@ -55,10 +55,10 @@ TEST(GradientMachine, testPredict) {
PDGradientMachineCreateSharedParam( PDGradientMachineCreateSharedParam(
machine, &buffer[0], (int)buffer.size(), &machineSlave)); machine, &buffer[0], (int)buffer.size(), &machineSlave));
std::swap(machineSlave, machine); std::swap(machineSlave, machine);
PD_Arguments outArgs = PDArgsCreateNone(); paddle_arguments outArgs = paddle_arguments_create_none();
PD_Arguments inArgs = PDArgsCreateNone(); paddle_arguments inArgs = paddle_arguments_create_none();
ASSERT_EQ(kPD_NO_ERROR, PDArgsResize(inArgs, 1)); ASSERT_EQ(kPD_NO_ERROR, paddle_arguments_resize(inArgs, 1));
paddle_matrix mat = paddle_matrix_create(1, 100, false); paddle_matrix mat = paddle_matrix_create(1, 100, false);
static_assert(std::is_same<pd_real, paddle::real>::value, ""); static_assert(std::is_same<pd_real, paddle::real>::value, "");
...@@ -67,15 +67,15 @@ TEST(GradientMachine, testPredict) { ...@@ -67,15 +67,15 @@ TEST(GradientMachine, testPredict) {
ASSERT_EQ(kPD_NO_ERROR, paddle_matrix_get_row(mat, 0, &rowPtr)); ASSERT_EQ(kPD_NO_ERROR, paddle_matrix_get_row(mat, 0, &rowPtr));
memcpy(rowPtr, data.data(), data.size() * sizeof(pd_real)); memcpy(rowPtr, data.data(), data.size() * sizeof(pd_real));
ASSERT_EQ(kPD_NO_ERROR, PDArgsSetValue(inArgs, 0, mat)); ASSERT_EQ(kPD_NO_ERROR, paddle_arguments_set_value(inArgs, 0, mat));
ASSERT_EQ(kPD_NO_ERROR, ASSERT_EQ(kPD_NO_ERROR,
PDGradientMachineForward(machine, inArgs, outArgs, false)); PDGradientMachineForward(machine, inArgs, outArgs, false));
uint64_t sz; uint64_t sz;
ASSERT_EQ(kPD_NO_ERROR, PDArgsGetSize(outArgs, &sz)); ASSERT_EQ(kPD_NO_ERROR, paddle_arguments_size(outArgs, &sz));
ASSERT_EQ(1UL, sz); ASSERT_EQ(1UL, sz);
ASSERT_EQ(kPD_NO_ERROR, PDArgsGetValue(outArgs, 0, mat)); ASSERT_EQ(kPD_NO_ERROR, paddle_arguments_value(outArgs, 0, mat));
std::vector<paddle::Argument> paddleInArgs; std::vector<paddle::Argument> paddleInArgs;
std::vector<paddle::Argument> paddleOutArgs; std::vector<paddle::Argument> paddleOutArgs;
paddleInArgs.resize(1); paddleInArgs.resize(1);
...@@ -97,8 +97,8 @@ TEST(GradientMachine, testPredict) { ...@@ -97,8 +97,8 @@ TEST(GradientMachine, testPredict) {
} }
ASSERT_EQ(kPD_NO_ERROR, paddle_matrix_destroy(mat)); ASSERT_EQ(kPD_NO_ERROR, paddle_matrix_destroy(mat));
ASSERT_EQ(kPD_NO_ERROR, PDArgsDestroy(inArgs)); ASSERT_EQ(kPD_NO_ERROR, paddle_arguments_destroy(inArgs));
ASSERT_EQ(kPD_NO_ERROR, PDArgsDestroy(outArgs)); ASSERT_EQ(kPD_NO_ERROR, paddle_arguments_destroy(outArgs));
std::swap(machineSlave, machine); std::swap(machineSlave, machine);
ASSERT_EQ(kPD_NO_ERROR, PDGradientMachineDestroy(machineSlave)); ASSERT_EQ(kPD_NO_ERROR, PDGradientMachineDestroy(machineSlave));
ASSERT_EQ(kPD_NO_ERROR, PDGradientMachineDestroy(machine)); ASSERT_EQ(kPD_NO_ERROR, PDGradientMachineDestroy(machine));
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册