未验证 提交 74eb82de 编写于 作者: F flame 提交者: GitHub

fix go api bug (#22669)

上级 14672a63
...@@ -29,6 +29,7 @@ int main(int argc, char *argv[]) { ...@@ -29,6 +29,7 @@ int main(int argc, char *argv[]) {
printf("Output num: %d\n", output_num); printf("Output num: %d\n", output_num);
PD_ZeroCopyTensor input; PD_ZeroCopyTensor input;
PD_InitZeroCopyTensor(&input);
input.name = const_cast<char *>(PD_GetInputName(predictor, 0)); // NOLINT input.name = const_cast<char *>(PD_GetInputName(predictor, 0)); // NOLINT
input.data.capacity = sizeof(float) * 1 * 3 * 300 * 300; input.data.capacity = sizeof(float) * 1 * 3 * 300 * 300;
input.data.length = input.data.capacity; input.data.length = input.data.capacity;
...@@ -48,6 +49,7 @@ int main(int argc, char *argv[]) { ...@@ -48,6 +49,7 @@ int main(int argc, char *argv[]) {
PD_InitZeroCopyTensor(&output); PD_InitZeroCopyTensor(&output);
output.name = const_cast<char *>(PD_GetOutputName(predictor, 0)); // NOLINT output.name = const_cast<char *>(PD_GetOutputName(predictor, 0)); // NOLINT
PD_GetZeroCopyOutput(predictor, &output); PD_GetZeroCopyOutput(predictor, &output);
PD_DestroyZeroCopyTensor(&output); PD_DestroyZeroCopyTensor(&output);
PD_DeleteAnalysisConfig(config); PD_DeleteAnalysisConfig(config);
......
...@@ -35,13 +35,13 @@ int main(int argc, char *argv[]) { ...@@ -35,13 +35,13 @@ int main(int argc, char *argv[]) {
input->copy_from_cpu(data.data()); input->copy_from_cpu(data.data());
predictor->ZeroCopyRun(); predictor->ZeroCopyRun();
auto output_name = predictor->GetOutputNames()[0]; auto output_name = predictor->GetOutputNames()[0];
output = predictor->GetOutputTensor(output_name); auto output = predictor->GetOutputTensor(output_name);
return 0; return 0;
} }
void SetConfig(paddle::AnalysisConfig *config) { void SetConfig(paddle::AnalysisConfig *config) {
config->SetModel("data/model/__model__", "data/model/__params__"); config->SetModel("data/model/__model__", "data/model/__params__");
config->SwitchUseFeedFetchOps(true); config->SwitchUseFeedFetchOps(false);
config->SwitchSpecifyInputNames(true); config->SwitchSpecifyInputNames(true);
config->SwitchIrOptim(false); config->SwitchIrOptim(false);
} }
...@@ -21,6 +21,14 @@ package paddle ...@@ -21,6 +21,14 @@ package paddle
import "C" import "C"
import "fmt" import "fmt"
type Precision C.Precision
const (
kFloat32 Precision = C.kFloat32
kInt8 Precision = C.kInt8
kHalf Precision = C.kHalf
)
func ConvertCBooleanToGo(b C.bool) bool { func ConvertCBooleanToGo(b C.bool) bool {
var c_false C.bool var c_false C.bool
if b != c_false { if b != c_false {
......
...@@ -43,8 +43,13 @@ func (config *AnalysisConfig) SetModel(model, params string) { ...@@ -43,8 +43,13 @@ func (config *AnalysisConfig) SetModel(model, params string) {
//C.printString((*C.char)(unsafe.Pointer(&s[0]))) //C.printString((*C.char)(unsafe.Pointer(&s[0])))
c_model := C.CString(model) c_model := C.CString(model)
defer C.free(unsafe.Pointer(c_model)) defer C.free(unsafe.Pointer(c_model))
c_params := C.CString(params) var c_params *C.char
defer C.free(unsafe.Pointer(c_params)) if params == "" {
c_params = nil
} else {
c_params = C.CString(params)
defer C.free(unsafe.Pointer(c_params))
}
C.PD_SetModel(config.c, c_model, c_params) C.PD_SetModel(config.c, c_model, c_params)
} }
...@@ -61,8 +66,8 @@ func (config *AnalysisConfig) ParamsFile() string { ...@@ -61,8 +66,8 @@ func (config *AnalysisConfig) ParamsFile() string {
return C.GoString(C.PD_ParamsFile(config.c)) return C.GoString(C.PD_ParamsFile(config.c))
} }
func (config *AnalysisConfig) EnableUseGpu(memory_pool_init_size_mb uint64, device_id int) { func (config *AnalysisConfig) EnableUseGpu(memory_pool_init_size_mb int, device_id int) {
C.PD_EnableUseGpu(config.c, C.ulong(memory_pool_init_size_mb), C.int(device_id)) C.PD_EnableUseGpu(config.c, C.int(memory_pool_init_size_mb), C.int(device_id))
} }
func (config *AnalysisConfig) DisableGpu() { func (config *AnalysisConfig) DisableGpu() {
...@@ -113,7 +118,9 @@ func (config *AnalysisConfig) SpecifyInputName() bool { ...@@ -113,7 +118,9 @@ func (config *AnalysisConfig) SpecifyInputName() bool {
return ConvertCBooleanToGo(C.PD_SpecifyInputName(config.c)) return ConvertCBooleanToGo(C.PD_SpecifyInputName(config.c))
} }
//func (config *AnalysisConfig) EnableTensorRtEngine(workspace_size int) func (config *AnalysisConfig) EnableTensorRtEngine(workspace_size int, max_batch_size int, min_subgraph_size int, precision Precision, use_static bool, use_calib_mode bool) {
C.PD_EnableTensorRtEngine(config.c, C.int(workspace_size), C.int(max_batch_size), C.int(min_subgraph_size), C.Precision(precision), C.bool(use_static), C.bool(use_calib_mode))
}
func (config *AnalysisConfig) TensorrtEngineEnabled() bool { func (config *AnalysisConfig) TensorrtEngineEnabled() bool {
return ConvertCBooleanToGo(C.PD_TensorrtEngineEnabled(config.c)) return ConvertCBooleanToGo(C.PD_TensorrtEngineEnabled(config.c))
...@@ -175,15 +182,15 @@ func (config *AnalysisConfig) DisableGlogInfo() { ...@@ -175,15 +182,15 @@ func (config *AnalysisConfig) DisableGlogInfo() {
} }
func (config *AnalysisConfig) DeletePass(pass string) { func (config *AnalysisConfig) DeletePass(pass string) {
c_pass := C.CString(pass) c_pass := C.CString(pass)
defer C.free(unsafe.Pointer(c_pass)) defer C.free(unsafe.Pointer(c_pass))
C.PD_DeletePass(config.c, c_pass) C.PD_DeletePass(config.c, c_pass)
} }
func (config *AnalysisConfig) SetInValid() { func (config *AnalysisConfig) SetInValid() {
C.PD_SetInValid(config.c) C.PD_SetInValid(config.c)
} }
func (config *AnalysisConfig) IsValid() bool { func (config *AnalysisConfig) IsValid() bool {
return ConvertCBooleanToGo(C.PD_IsValid(config.c)) return ConvertCBooleanToGo(C.PD_IsValid(config.c))
} }
...@@ -161,9 +161,9 @@ PADDLE_CAPI_EXPORT extern const char* PD_ProgFile( ...@@ -161,9 +161,9 @@ PADDLE_CAPI_EXPORT extern const char* PD_ProgFile(
PADDLE_CAPI_EXPORT extern const char* PD_ParamsFile( PADDLE_CAPI_EXPORT extern const char* PD_ParamsFile(
const PD_AnalysisConfig* config); const PD_AnalysisConfig* config);
PADDLE_CAPI_EXPORT extern void PD_EnableUseGpu( PADDLE_CAPI_EXPORT extern void PD_EnableUseGpu(PD_AnalysisConfig* config,
PD_AnalysisConfig* config, uint64_t memory_pool_init_size_mb, int memory_pool_init_size_mb,
int device_id); int device_id);
PADDLE_CAPI_EXPORT extern void PD_DisableGpu(PD_AnalysisConfig* config); PADDLE_CAPI_EXPORT extern void PD_DisableGpu(PD_AnalysisConfig* config);
......
...@@ -79,10 +79,11 @@ const char* PD_ParamsFile(const PD_AnalysisConfig* config) { ...@@ -79,10 +79,11 @@ const char* PD_ParamsFile(const PD_AnalysisConfig* config) {
return config->config.params_file().c_str(); return config->config.params_file().c_str();
} }
void PD_EnableUseGpu(PD_AnalysisConfig* config, void PD_EnableUseGpu(PD_AnalysisConfig* config, int memory_pool_init_size_mb,
uint64_t memory_pool_init_size_mb, int device_id) { int device_id) {
PADDLE_ENFORCE_NOT_NULL(config); PADDLE_ENFORCE_NOT_NULL(config);
config->config.EnableUseGpu(memory_pool_init_size_mb, device_id); config->config.EnableUseGpu(static_cast<uint64_t>(memory_pool_init_size_mb),
device_id);
} }
void PD_DisableGpu(PD_AnalysisConfig* config) { void PD_DisableGpu(PD_AnalysisConfig* config) {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册