未验证 提交 74eb82de 编写于 作者: F flame 提交者: GitHub

fix go api bug (#22669)

上级 14672a63
......@@ -29,6 +29,7 @@ int main(int argc, char *argv[]) {
printf("Output num: %d\n", output_num);
PD_ZeroCopyTensor input;
PD_InitZeroCopyTensor(&input);
input.name = const_cast<char *>(PD_GetInputName(predictor, 0)); // NOLINT
input.data.capacity = sizeof(float) * 1 * 3 * 300 * 300;
input.data.length = input.data.capacity;
......@@ -48,6 +49,7 @@ int main(int argc, char *argv[]) {
PD_InitZeroCopyTensor(&output);
output.name = const_cast<char *>(PD_GetOutputName(predictor, 0)); // NOLINT
PD_GetZeroCopyOutput(predictor, &output);
PD_DestroyZeroCopyTensor(&output);
PD_DeleteAnalysisConfig(config);
......
......@@ -35,13 +35,13 @@ int main(int argc, char *argv[]) {
input->copy_from_cpu(data.data());
predictor->ZeroCopyRun();
auto output_name = predictor->GetOutputNames()[0];
output = predictor->GetOutputTensor(output_name);
auto output = predictor->GetOutputTensor(output_name);
return 0;
}
void SetConfig(paddle::AnalysisConfig *config) {
config->SetModel("data/model/__model__", "data/model/__params__");
config->SwitchUseFeedFetchOps(true);
config->SwitchUseFeedFetchOps(false);
config->SwitchSpecifyInputNames(true);
config->SwitchIrOptim(false);
}
......@@ -21,6 +21,14 @@ package paddle
import "C"
import "fmt"
type Precision C.Precision
const (
kFloat32 Precision = C.kFloat32
kInt8 Precision = C.kInt8
kHalf Precision = C.kHalf
)
func ConvertCBooleanToGo(b C.bool) bool {
var c_false C.bool
if b != c_false {
......
......@@ -43,8 +43,13 @@ func (config *AnalysisConfig) SetModel(model, params string) {
//C.printString((*C.char)(unsafe.Pointer(&s[0])))
c_model := C.CString(model)
defer C.free(unsafe.Pointer(c_model))
c_params := C.CString(params)
defer C.free(unsafe.Pointer(c_params))
var c_params *C.char
if params == "" {
c_params = nil
} else {
c_params = C.CString(params)
defer C.free(unsafe.Pointer(c_params))
}
C.PD_SetModel(config.c, c_model, c_params)
}
......@@ -61,8 +66,8 @@ func (config *AnalysisConfig) ParamsFile() string {
return C.GoString(C.PD_ParamsFile(config.c))
}
func (config *AnalysisConfig) EnableUseGpu(memory_pool_init_size_mb uint64, device_id int) {
C.PD_EnableUseGpu(config.c, C.ulong(memory_pool_init_size_mb), C.int(device_id))
func (config *AnalysisConfig) EnableUseGpu(memory_pool_init_size_mb int, device_id int) {
C.PD_EnableUseGpu(config.c, C.int(memory_pool_init_size_mb), C.int(device_id))
}
func (config *AnalysisConfig) DisableGpu() {
......@@ -113,7 +118,9 @@ func (config *AnalysisConfig) SpecifyInputName() bool {
return ConvertCBooleanToGo(C.PD_SpecifyInputName(config.c))
}
//func (config *AnalysisConfig) EnableTensorRtEngine(workspace_size int)
func (config *AnalysisConfig) EnableTensorRtEngine(workspace_size int, max_batch_size int, min_subgraph_size int, precision Precision, use_static bool, use_calib_mode bool) {
C.PD_EnableTensorRtEngine(config.c, C.int(workspace_size), C.int(max_batch_size), C.int(min_subgraph_size), C.Precision(precision), C.bool(use_static), C.bool(use_calib_mode))
}
func (config *AnalysisConfig) TensorrtEngineEnabled() bool {
return ConvertCBooleanToGo(C.PD_TensorrtEngineEnabled(config.c))
......@@ -175,15 +182,15 @@ func (config *AnalysisConfig) DisableGlogInfo() {
}
func (config *AnalysisConfig) DeletePass(pass string) {
c_pass := C.CString(pass)
defer C.free(unsafe.Pointer(c_pass))
C.PD_DeletePass(config.c, c_pass)
c_pass := C.CString(pass)
defer C.free(unsafe.Pointer(c_pass))
C.PD_DeletePass(config.c, c_pass)
}
func (config *AnalysisConfig) SetInValid() {
C.PD_SetInValid(config.c)
C.PD_SetInValid(config.c)
}
func (config *AnalysisConfig) IsValid() bool {
return ConvertCBooleanToGo(C.PD_IsValid(config.c))
return ConvertCBooleanToGo(C.PD_IsValid(config.c))
}
......@@ -161,9 +161,9 @@ PADDLE_CAPI_EXPORT extern const char* PD_ProgFile(
PADDLE_CAPI_EXPORT extern const char* PD_ParamsFile(
const PD_AnalysisConfig* config);
PADDLE_CAPI_EXPORT extern void PD_EnableUseGpu(
PD_AnalysisConfig* config, uint64_t memory_pool_init_size_mb,
int device_id);
PADDLE_CAPI_EXPORT extern void PD_EnableUseGpu(PD_AnalysisConfig* config,
int memory_pool_init_size_mb,
int device_id);
PADDLE_CAPI_EXPORT extern void PD_DisableGpu(PD_AnalysisConfig* config);
......
......@@ -79,10 +79,11 @@ const char* PD_ParamsFile(const PD_AnalysisConfig* config) {
return config->config.params_file().c_str();
}
void PD_EnableUseGpu(PD_AnalysisConfig* config,
uint64_t memory_pool_init_size_mb, int device_id) {
void PD_EnableUseGpu(PD_AnalysisConfig* config, int memory_pool_init_size_mb,
int device_id) {
PADDLE_ENFORCE_NOT_NULL(config);
config->config.EnableUseGpu(memory_pool_init_size_mb, device_id);
config->config.EnableUseGpu(static_cast<uint64_t>(memory_pool_init_size_mb),
device_id);
}
void PD_DisableGpu(PD_AnalysisConfig* config) {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册