未验证 提交 bc8c7ccd 编写于 作者: Y Yu Yang 提交者: GitHub

Merge pull request #12114 from reyoung/feature/get_beam_search_prob_on_capi

Get BeamSearch Prob on C-API
...@@ -66,6 +66,17 @@ paddle_error paddle_arguments_get_value(paddle_arguments args, ...@@ -66,6 +66,17 @@ paddle_error paddle_arguments_get_value(paddle_arguments args,
return kPD_NO_ERROR; return kPD_NO_ERROR;
} }
PD_API paddle_error paddle_arguments_get_prob(paddle_arguments args,
uint64_t ID,
paddle_matrix mat) {
if (args == nullptr || mat == nullptr) return kPD_NULLPTR;
auto m = paddle::capi::cast<paddle::capi::CMatrix>(mat);
auto a = castArg(args);
if (ID >= a->args.size()) return kPD_OUT_OF_RANGE;
m->mat = a->args[ID].in;
return kPD_NO_ERROR;
}
paddle_error paddle_arguments_get_ids(paddle_arguments args, paddle_error paddle_arguments_get_ids(paddle_arguments args,
uint64_t ID, uint64_t ID,
paddle_ivector ids) { paddle_ivector ids) {
......
...@@ -87,6 +87,18 @@ PD_API paddle_error paddle_arguments_get_value(paddle_arguments args, ...@@ -87,6 +87,18 @@ PD_API paddle_error paddle_arguments_get_value(paddle_arguments args,
uint64_t ID, uint64_t ID,
paddle_matrix mat); paddle_matrix mat);
/**
* @brief paddle_arguments_get_prob Get the prob matrix of beam search, which
* slot ID is `ID`
* @param [in] args arguments array
* @param [in] ID array index
* @param [out] mat matrix pointer
* @return paddle_error
*/
PD_API paddle_error paddle_arguments_get_prob(paddle_arguments args,
uint64_t ID,
paddle_matrix mat);
/** /**
* @brief PDArgsGetIds Get the integer vector of one argument in array, which * @brief PDArgsGetIds Get the integer vector of one argument in array, which
* index is `ID`. * index is `ID`.
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册