diff --git a/paddle/legacy/capi/Arguments.cpp b/paddle/legacy/capi/Arguments.cpp index 87fac3d6c6abe37b128213d4ffd66f8c1573a910..0ce1770c76c2e145d0b2bf71332cc4593517f195 100644 --- a/paddle/legacy/capi/Arguments.cpp +++ b/paddle/legacy/capi/Arguments.cpp @@ -66,6 +66,17 @@ paddle_error paddle_arguments_get_value(paddle_arguments args, return kPD_NO_ERROR; } +PD_API paddle_error paddle_arguments_get_prob(paddle_arguments args, + uint64_t ID, + paddle_matrix mat) { + if (args == nullptr || mat == nullptr) return kPD_NULLPTR; + auto m = paddle::capi::cast(mat); + auto a = castArg(args); + if (ID >= a->args.size()) return kPD_OUT_OF_RANGE; + m->mat = a->args[ID].in; + return kPD_NO_ERROR; +} + paddle_error paddle_arguments_get_ids(paddle_arguments args, uint64_t ID, paddle_ivector ids) { diff --git a/paddle/legacy/capi/arguments.h b/paddle/legacy/capi/arguments.h index 69a66bb012c318bc8317c246d690a7f4baffd248..ceb64ee6aa74a8ba4b5cb9045b366dcda8f8cc90 100644 --- a/paddle/legacy/capi/arguments.h +++ b/paddle/legacy/capi/arguments.h @@ -87,6 +87,18 @@ PD_API paddle_error paddle_arguments_get_value(paddle_arguments args, uint64_t ID, paddle_matrix mat); +/** + * @brief paddle_arguments_get_prob Get the prob matrix of beam search, which + * slot ID is `ID` + * @param [in] args arguments array + * @param [in] ID array index + * @param [out] mat matrix pointer + * @return paddle_error + */ +PD_API paddle_error paddle_arguments_get_prob(paddle_arguments args, + uint64_t ID, + paddle_matrix mat); + /** * @brief PDArgsGetIds Get the integer vector of one argument in array, which * index is `ID`.