未验证 提交 71ab0458 编写于 作者: L liu zhengxi 提交者: GitHub

Fix pointer and c-api encapsulation (#22663)

* refine pointer and c-api prototype, test=develop

* fix new c api profile bug, test=develop

* add unit tests, test=develop
上级 0cb121cf
...@@ -117,8 +117,8 @@ PADDLE_CAPI_EXPORT extern PD_DataType PD_GetPaddleTensorDType( ...@@ -117,8 +117,8 @@ PADDLE_CAPI_EXPORT extern PD_DataType PD_GetPaddleTensorDType(
PADDLE_CAPI_EXPORT extern PD_PaddleBuf* PD_GetPaddleTensorData( PADDLE_CAPI_EXPORT extern PD_PaddleBuf* PD_GetPaddleTensorData(
const PD_Tensor* tensor); const PD_Tensor* tensor);
PADDLE_CAPI_EXPORT extern int* PD_GetPaddleTensorShape(const PD_Tensor* tensor, PADDLE_CAPI_EXPORT extern const int* PD_GetPaddleTensorShape(
int** size); const PD_Tensor* tensor, int* size);
// AnalysisPredictor // AnalysisPredictor
PADDLE_CAPI_EXPORT extern bool PD_PredictorRun(const PD_AnalysisConfig* config, PADDLE_CAPI_EXPORT extern bool PD_PredictorRun(const PD_AnalysisConfig* config,
...@@ -262,22 +262,32 @@ PADDLE_CAPI_EXPORT extern bool PD_ProfileEnabled( ...@@ -262,22 +262,32 @@ PADDLE_CAPI_EXPORT extern bool PD_ProfileEnabled(
PADDLE_CAPI_EXPORT extern void PD_SetInValid(PD_AnalysisConfig* config); PADDLE_CAPI_EXPORT extern void PD_SetInValid(PD_AnalysisConfig* config);
PADDLE_CAPI_EXPORT extern bool PD_IsValid(const PD_AnalysisConfig* config); PADDLE_CAPI_EXPORT extern bool PD_IsValid(const PD_AnalysisConfig* config);
PADDLE_CAPI_EXPORT extern void PD_DisableGlogInfo(PD_AnalysisConfig* config); PADDLE_CAPI_EXPORT extern void PD_DisableGlogInfo(PD_AnalysisConfig* config);
PADDLE_CAPI_EXPORT extern void PD_DeletePass(PD_AnalysisConfig* config, PADDLE_CAPI_EXPORT extern void PD_DeletePass(PD_AnalysisConfig* config,
char* pass_name); char* pass_name);
PADDLE_CAPI_EXPORT extern PD_Predictor* PD_NewPredictor( PADDLE_CAPI_EXPORT extern PD_Predictor* PD_NewPredictor(
const PD_AnalysisConfig* config); const PD_AnalysisConfig* config);
PADDLE_CAPI_EXPORT extern void PD_DeletePredictor(PD_Predictor* predictor); PADDLE_CAPI_EXPORT extern void PD_DeletePredictor(PD_Predictor* predictor);
PADDLE_CAPI_EXPORT extern int PD_GetInputNum(const PD_Predictor*); PADDLE_CAPI_EXPORT extern int PD_GetInputNum(const PD_Predictor*);
PADDLE_CAPI_EXPORT extern int PD_GetOutputNum(const PD_Predictor*); PADDLE_CAPI_EXPORT extern int PD_GetOutputNum(const PD_Predictor*);
PADDLE_CAPI_EXPORT extern const char* PD_GetInputName(const PD_Predictor*, int); PADDLE_CAPI_EXPORT extern const char* PD_GetInputName(const PD_Predictor*, int);
PADDLE_CAPI_EXPORT extern const char* PD_GetOutputName(const PD_Predictor*, PADDLE_CAPI_EXPORT extern const char* PD_GetOutputName(const PD_Predictor*,
int); int);
PADDLE_CAPI_EXPORT extern void PD_SetZeroCopyInput( PADDLE_CAPI_EXPORT extern void PD_SetZeroCopyInput(
PD_Predictor* predictor, const PD_ZeroCopyTensor* tensor); PD_Predictor* predictor, const PD_ZeroCopyTensor* tensor);
PADDLE_CAPI_EXPORT extern void PD_GetZeroCopyOutput(PD_Predictor* predictor, PADDLE_CAPI_EXPORT extern void PD_GetZeroCopyOutput(PD_Predictor* predictor,
PD_ZeroCopyTensor* tensor); PD_ZeroCopyTensor* tensor);
PADDLE_CAPI_EXPORT extern void PD_ZeroCopyRun(PD_Predictor* predictor); PADDLE_CAPI_EXPORT extern void PD_ZeroCopyRun(PD_Predictor* predictor);
#ifdef __cplusplus #ifdef __cplusplus
......
...@@ -180,7 +180,8 @@ PD_Predictor* PD_NewPredictor(const PD_AnalysisConfig* config) { ...@@ -180,7 +180,8 @@ PD_Predictor* PD_NewPredictor(const PD_AnalysisConfig* config) {
} }
void PD_DeletePredictor(PD_Predictor* predictor) { void PD_DeletePredictor(PD_Predictor* predictor) {
if (predictor == nullptr) { if (predictor) {
predictor->predictor = nullptr;
delete predictor; delete predictor;
predictor = nullptr; predictor = nullptr;
} }
......
...@@ -73,11 +73,10 @@ PD_PaddleBuf* PD_GetPaddleTensorData(const PD_Tensor* tensor) { ...@@ -73,11 +73,10 @@ PD_PaddleBuf* PD_GetPaddleTensorData(const PD_Tensor* tensor) {
return ret; return ret;
} }
int* PD_GetPaddleTensorShape(const PD_Tensor* tensor, int** size) { const int* PD_GetPaddleTensorShape(const PD_Tensor* tensor, int* size) {
PADDLE_ENFORCE_NOT_NULL(tensor); PADDLE_ENFORCE_NOT_NULL(tensor);
std::vector<int> shape = tensor->tensor.shape; const std::vector<int>& shape = tensor->tensor.shape;
int s = shape.size(); *size = shape.size();
*size = &s;
return shape.data(); return shape.data();
} }
......
...@@ -93,6 +93,8 @@ TEST(PD_AnalysisConfig, trt_fp16) { ...@@ -93,6 +93,8 @@ TEST(PD_AnalysisConfig, trt_fp16) {
false); false);
bool trt_enable = PD_TensorrtEngineEnabled(config); bool trt_enable = PD_TensorrtEngineEnabled(config);
CHECK(trt_enable) << "NO"; CHECK(trt_enable) << "NO";
PD_Predictor *predictor = PD_NewPredictor(config);
PD_DeletePredictor(predictor);
PD_DeleteAnalysisConfig(config); PD_DeleteAnalysisConfig(config);
} }
......
...@@ -67,8 +67,14 @@ void PD_run() { ...@@ -67,8 +67,14 @@ void PD_run() {
float* result = static_cast<float*>(PD_PaddleBufData(b)); float* result = static_cast<float*>(PD_PaddleBufData(b));
LOG(INFO) << *result; LOG(INFO) << *result;
PD_DeletePaddleTensor(input); PD_DeletePaddleTensor(input);
int* size; int size;
PD_GetPaddleTensorShape(out_data, &size); const int* out_shape = PD_GetPaddleTensorShape(out_data, &size);
CHECK(size == 2) << "The Output shape's size is NOT match.";
std::vector<int> ref_outshape_size({9, 6});
for (int i = 0; i < 2; ++i) {
CHECK(out_shape[i] == ref_outshape_size[i])
<< "The Output's shape is NOT match.";
}
PD_DeletePaddleBuf(buf); PD_DeletePaddleBuf(buf);
} }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册