未验证 提交 70b67f10 编写于 作者: W Wilber 提交者: GitHub

fix go api bug. (#31857)

上级 e804f085
...@@ -50,6 +50,7 @@ output_data := value.Interface().([][]float32) ...@@ -50,6 +50,7 @@ output_data := value.Interface().([][]float32)
运行 运行
```bash ```bash
go mod init github.com/paddlepaddle
export LD_LIBRARY_PATH=`pwd`/paddle_c/paddle/lib:$LD_LIBRARY_PATH export LD_LIBRARY_PATH=`pwd`/paddle_c/paddle/lib:$LD_LIBRARY_PATH
go run ./demo/mobilenet.go go run ./demo/mobilenet.go
``` ```
...@@ -13,7 +13,7 @@ ...@@ -13,7 +13,7 @@
// limitations under the License. // limitations under the License.
package main package main
import "../paddle" import "github.com/paddlepaddle/paddle"
import "strings" import "strings"
import "io/ioutil" import "io/ioutil"
import "strconv" import "strconv"
......
...@@ -15,7 +15,7 @@ ...@@ -15,7 +15,7 @@
package paddle package paddle
// #cgo CFLAGS: -I${SRCDIR}/../paddle_c/paddle/include // #cgo CFLAGS: -I${SRCDIR}/../paddle_c/paddle/include
// #cgo LDFLAGS: -L${SRCDIR}/../paddle_c/paddle/lib -lpaddle_fluid_c // #cgo LDFLAGS: -L${SRCDIR}/../paddle_c/paddle/lib -lpaddle_inference_c
// #include <stdbool.h> // #include <stdbool.h>
// #include <paddle_c_api.h> // #include <paddle_c_api.h>
import "C" import "C"
......
...@@ -15,7 +15,7 @@ ...@@ -15,7 +15,7 @@
package paddle package paddle
// #cgo CFLAGS: -I${SRCDIR}/../paddle_c/paddle/include // #cgo CFLAGS: -I${SRCDIR}/../paddle_c/paddle/include
// #cgo LDFLAGS: -L${SRCDIR}/../paddle_c/paddle/lib -lpaddle_fluid_c // #cgo LDFLAGS: -L${SRCDIR}/../paddle_c/paddle/lib -lpaddle_inference_c
// #include <stdbool.h> // #include <stdbool.h>
// #include <stdlib.h> // #include <stdlib.h>
// #include <paddle_c_api.h> // #include <paddle_c_api.h>
......
...@@ -15,7 +15,7 @@ ...@@ -15,7 +15,7 @@
package paddle package paddle
// #cgo CFLAGS: -I${SRCDIR}/../paddle_c/paddle/include // #cgo CFLAGS: -I${SRCDIR}/../paddle_c/paddle/include
// #cgo LDFLAGS: -L${SRCDIR}/../paddle_c/paddle/lib -lpaddle_fluid_c // #cgo LDFLAGS: -L${SRCDIR}/../paddle_c/paddle/lib -lpaddle_inference_c
// #include <stdbool.h> // #include <stdbool.h>
// #include "paddle_c_api.h" // #include "paddle_c_api.h"
import "C" import "C"
...@@ -88,7 +88,7 @@ func (predictor *Predictor) GetInputNames() []string { ...@@ -88,7 +88,7 @@ func (predictor *Predictor) GetInputNames() []string {
} }
func (predictor *Predictor) GetOutputNames() []string { func (predictor *Predictor) GetOutputNames() []string {
names := make([]string, predictor.GetInputNum()) names := make([]string, predictor.GetOutputNum())
for i := 0; i < len(names); i++ { for i := 0; i < len(names); i++ {
names[i] = predictor.GetOutputName(i) names[i] = predictor.GetOutputName(i)
} }
......
...@@ -15,7 +15,7 @@ ...@@ -15,7 +15,7 @@
package paddle package paddle
// #cgo CFLAGS: -I${SRCDIR}/../paddle_c/paddle/include // #cgo CFLAGS: -I${SRCDIR}/../paddle_c/paddle/include
// #cgo LDFLAGS: -L${SRCDIR}/../paddle_c/paddle/lib -lpaddle_fluid_c // #cgo LDFLAGS: -L${SRCDIR}/../paddle_c/paddle/lib -lpaddle_inference_c
// #include <stdbool.h> // #include <stdbool.h>
// #include <stdlib.h> // #include <stdlib.h>
// #include <string.h> // #include <string.h>
...@@ -209,7 +209,7 @@ func DecodeTensor(r *bytes.Reader, shape []int32, t reflect.Type, ptr reflect.Va ...@@ -209,7 +209,7 @@ func DecodeTensor(r *bytes.Reader, shape []int32, t reflect.Type, ptr reflect.Va
value := reflect.Indirect(ptr) value := reflect.Indirect(ptr)
value.Set(reflect.MakeSlice(t, int(shape[0]), int(shape[0]))) value.Set(reflect.MakeSlice(t, int(shape[0]), int(shape[0])))
if len(shape) == 1 && value.Len() > 0 { if len(shape) == 1 && value.Len() > 0 {
switch value.Index(1).Kind() { switch value.Index(0).Kind() {
case reflect.Uint8, reflect.Int32, reflect.Int64, reflect.Float32: case reflect.Uint8, reflect.Int32, reflect.Int64, reflect.Float32:
binary.Read(r, Endian(), value.Interface()) binary.Read(r, Endian(), value.Interface())
return return
......
...@@ -207,13 +207,16 @@ int PD_GetOutputNum(const PD_Predictor* predictor) { ...@@ -207,13 +207,16 @@ int PD_GetOutputNum(const PD_Predictor* predictor) {
} }
const char* PD_GetInputName(const PD_Predictor* predictor, int n) { const char* PD_GetInputName(const PD_Predictor* predictor, int n) {
static std::vector<std::string> names = predictor->predictor->GetInputNames(); static std::vector<std::string> names;
names.resize(predictor->predictor->GetInputNames().size());
names[n] = predictor->predictor->GetInputNames()[n];
return names[n].c_str(); return names[n].c_str();
} }
const char* PD_GetOutputName(const PD_Predictor* predictor, int n) { const char* PD_GetOutputName(const PD_Predictor* predictor, int n) {
static std::vector<std::string> names = static std::vector<std::string> names;
predictor->predictor->GetOutputNames(); names.resize(predictor->predictor->GetOutputNames().size());
names[n] = predictor->predictor->GetOutputNames()[n];
return names[n].c_str(); return names[n].c_str();
} }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册