From 08fa3983b0832d08dfa05df55f0a9aa00f94ee5e Mon Sep 17 00:00:00 2001 From: yuyang18 Date: Thu, 12 Jul 2018 12:14:55 +0800 Subject: [PATCH] Get BeamSearch Prob on C-API --- paddle/legacy/capi/Arguments.cpp | 11 +++++++++++ paddle/legacy/capi/arguments.h | 12 ++++++++++++ 2 files changed, 23 insertions(+) diff --git a/paddle/legacy/capi/Arguments.cpp b/paddle/legacy/capi/Arguments.cpp index 87fac3d6c..0ce1770c7 100644 --- a/paddle/legacy/capi/Arguments.cpp +++ b/paddle/legacy/capi/Arguments.cpp @@ -66,6 +66,17 @@ paddle_error paddle_arguments_get_value(paddle_arguments args, return kPD_NO_ERROR; } +PD_API paddle_error paddle_arguments_get_prob(paddle_arguments args, + uint64_t ID, + paddle_matrix mat) { + if (args == nullptr || mat == nullptr) return kPD_NULLPTR; + auto m = paddle::capi::cast(mat); + auto a = castArg(args); + if (ID >= a->args.size()) return kPD_OUT_OF_RANGE; + m->mat = a->args[ID].in; + return kPD_NO_ERROR; +} + paddle_error paddle_arguments_get_ids(paddle_arguments args, uint64_t ID, paddle_ivector ids) { diff --git a/paddle/legacy/capi/arguments.h b/paddle/legacy/capi/arguments.h index 69a66bb01..ceb64ee6a 100644 --- a/paddle/legacy/capi/arguments.h +++ b/paddle/legacy/capi/arguments.h @@ -87,6 +87,18 @@ PD_API paddle_error paddle_arguments_get_value(paddle_arguments args, uint64_t ID, paddle_matrix mat); +/** + * @brief paddle_arguments_get_prob Get the prob matrix of beam search, which + * slot ID is `ID` + * @param [in] args arguments array + * @param [in] ID array index + * @param [out] mat matrix pointer + * @return paddle_error + */ +PD_API paddle_error paddle_arguments_get_prob(paddle_arguments args, + uint64_t ID, + paddle_matrix mat); + /** * @brief PDArgsGetIds Get the integer vector of one argument in array, which * index is `ID`. -- GitLab