diff --git a/go/demo/mobilenet_c.cc b/go/demo/mobilenet_c.cc index e20635b0bb7bc617731fa8181ca15555192bb1be..6a5cc683c9f9a9c88f73a3ca5ebac274210f3b7a 100644 --- a/go/demo/mobilenet_c.cc +++ b/go/demo/mobilenet_c.cc @@ -29,6 +29,7 @@ int main(int argc, char *argv[]) { printf("Output num: %d\n", output_num); PD_ZeroCopyTensor input; + PD_InitZeroCopyTensor(&input); input.name = const_cast<char *>(PD_GetInputName(predictor, 0)); // NOLINT input.data.capacity = sizeof(float) * 1 * 3 * 300 * 300; input.data.length = input.data.capacity; @@ -48,6 +49,7 @@ int main(int argc, char *argv[]) { PD_InitZeroCopyTensor(&output); output.name = const_cast<char *>(PD_GetOutputName(predictor, 0)); // NOLINT PD_GetZeroCopyOutput(predictor, &output); + PD_DestroyZeroCopyTensor(&output); PD_DeleteAnalysisConfig(config); diff --git a/go/demo/mobilenet_cxx.cc b/go/demo/mobilenet_cxx.cc index a3df9a0e930bfe8755899f39f7d868dd1438f55f..7bdd6b2b03b24e2393e746edde754f763e9dd986 100644 --- a/go/demo/mobilenet_cxx.cc +++ b/go/demo/mobilenet_cxx.cc @@ -35,13 +35,13 @@ int main(int argc, char *argv[]) { input->copy_from_cpu(data.data()); predictor->ZeroCopyRun(); auto output_name = predictor->GetOutputNames()[0]; - output = predictor->GetOutputTensor(output_name); + auto output = predictor->GetOutputTensor(output_name); return 0; } void SetConfig(paddle::AnalysisConfig *config) { config->SetModel("data/model/__model__", "data/model/__params__"); - config->SwitchUseFeedFetchOps(true); + config->SwitchUseFeedFetchOps(false); config->SwitchSpecifyInputNames(true); config->SwitchIrOptim(false); } diff --git a/go/paddle/common.go b/go/paddle/common.go index b29efbdf3022b644d42b22cc5f005f12e0775ab9..346dc3f395bfe6ba4ae9fdc716f41a58b2bb4725 100644 --- a/go/paddle/common.go +++ b/go/paddle/common.go @@ -21,6 +21,14 @@ package paddle import "C" import "fmt" +type Precision C.Precision + +const ( + kFloat32 Precision = C.kFloat32 + kInt8 Precision = C.kInt8 + kHalf Precision = C.kHalf +) + func ConvertCBooleanToGo(b 
C.bool) bool { var c_false C.bool if b != c_false { diff --git a/go/paddle/config.go b/go/paddle/config.go index 524249a067121f0094ca6c215e96aa85963c8132..5480365d58069b44b24fbb027e62e7809e4eb5cf 100644 --- a/go/paddle/config.go +++ b/go/paddle/config.go @@ -43,8 +43,13 @@ func (config *AnalysisConfig) SetModel(model, params string) { //C.printString((*C.char)(unsafe.Pointer(&s[0]))) c_model := C.CString(model) defer C.free(unsafe.Pointer(c_model)) - c_params := C.CString(params) - defer C.free(unsafe.Pointer(c_params)) + var c_params *C.char + if params == "" { + c_params = nil + } else { + c_params = C.CString(params) + defer C.free(unsafe.Pointer(c_params)) + } C.PD_SetModel(config.c, c_model, c_params) } @@ -61,8 +66,8 @@ func (config *AnalysisConfig) ParamsFile() string { return C.GoString(C.PD_ParamsFile(config.c)) } -func (config *AnalysisConfig) EnableUseGpu(memory_pool_init_size_mb uint64, device_id int) { - C.PD_EnableUseGpu(config.c, C.ulong(memory_pool_init_size_mb), C.int(device_id)) +func (config *AnalysisConfig) EnableUseGpu(memory_pool_init_size_mb int, device_id int) { + C.PD_EnableUseGpu(config.c, C.int(memory_pool_init_size_mb), C.int(device_id)) } func (config *AnalysisConfig) DisableGpu() { @@ -113,7 +118,9 @@ func (config *AnalysisConfig) SpecifyInputName() bool { return ConvertCBooleanToGo(C.PD_SpecifyInputName(config.c)) } -//func (config *AnalysisConfig) EnableTensorRtEngine(workspace_size int) +func (config *AnalysisConfig) EnableTensorRtEngine(workspace_size int, max_batch_size int, min_subgraph_size int, precision Precision, use_static bool, use_calib_mode bool) { + C.PD_EnableTensorRtEngine(config.c, C.int(workspace_size), C.int(max_batch_size), C.int(min_subgraph_size), C.Precision(precision), C.bool(use_static), C.bool(use_calib_mode)) +} func (config *AnalysisConfig) TensorrtEngineEnabled() bool { return ConvertCBooleanToGo(C.PD_TensorrtEngineEnabled(config.c)) @@ -175,15 +182,15 @@ func (config *AnalysisConfig) DisableGlogInfo() { } 
func (config *AnalysisConfig) DeletePass(pass string) { - c_pass := C.CString(pass) - defer C.free(unsafe.Pointer(c_pass)) - C.PD_DeletePass(config.c, c_pass) + c_pass := C.CString(pass) + defer C.free(unsafe.Pointer(c_pass)) + C.PD_DeletePass(config.c, c_pass) } func (config *AnalysisConfig) SetInValid() { - C.PD_SetInValid(config.c) + C.PD_SetInValid(config.c) } func (config *AnalysisConfig) IsValid() bool { - return ConvertCBooleanToGo(C.PD_IsValid(config.c)) + return ConvertCBooleanToGo(C.PD_IsValid(config.c)) } diff --git a/paddle/fluid/inference/capi/paddle_c_api.h b/paddle/fluid/inference/capi/paddle_c_api.h index 87fb515dd6a6a25125c170188cce792951f0ccfd..e9157c06dec3659e3287d8157b6b29a10a7830da 100644 --- a/paddle/fluid/inference/capi/paddle_c_api.h +++ b/paddle/fluid/inference/capi/paddle_c_api.h @@ -161,9 +161,9 @@ PADDLE_CAPI_EXPORT extern const char* PD_ProgFile( PADDLE_CAPI_EXPORT extern const char* PD_ParamsFile( const PD_AnalysisConfig* config); -PADDLE_CAPI_EXPORT extern void PD_EnableUseGpu( - PD_AnalysisConfig* config, uint64_t memory_pool_init_size_mb, - int device_id); +PADDLE_CAPI_EXPORT extern void PD_EnableUseGpu(PD_AnalysisConfig* config, + int memory_pool_init_size_mb, + int device_id); PADDLE_CAPI_EXPORT extern void PD_DisableGpu(PD_AnalysisConfig* config); diff --git a/paddle/fluid/inference/capi/pd_config.cc b/paddle/fluid/inference/capi/pd_config.cc index 368489fde29cc0f51798e8ec7aebf413628c3403..458637377364f92fe976a90fbe7d3e953aed53da 100644 --- a/paddle/fluid/inference/capi/pd_config.cc +++ b/paddle/fluid/inference/capi/pd_config.cc @@ -79,10 +79,11 @@ const char* PD_ParamsFile(const PD_AnalysisConfig* config) { return config->config.params_file().c_str(); } -void PD_EnableUseGpu(PD_AnalysisConfig* config, - uint64_t memory_pool_init_size_mb, int device_id) { +void PD_EnableUseGpu(PD_AnalysisConfig* config, int memory_pool_init_size_mb, + int device_id) { PADDLE_ENFORCE_NOT_NULL(config); - 
config->config.EnableUseGpu(memory_pool_init_size_mb, device_id); + config->config.EnableUseGpu(static_cast<uint64_t>(memory_pool_init_size_mb), + device_id); } void PD_DisableGpu(PD_AnalysisConfig* config) {