未验证 提交 70b67f10 编写于 作者: W Wilber 提交者: GitHub

fix go api bug. (#31857)

上级 e804f085
......@@ -50,6 +50,7 @@ output_data := value.Interface().([][]float32)
运行
```bash
go mod init github.com/paddlepaddle
export LD_LIBRARY_PATH=`pwd`/paddle_c/paddle/lib:$LD_LIBRARY_PATH
go run ./demo/mobilenet.go
```
......@@ -13,7 +13,7 @@
// limitations under the License.
package main
import "../paddle"
import "github.com/paddlepaddle/paddle"
import "strings"
import "io/ioutil"
import "strconv"
......
......@@ -15,7 +15,7 @@
package paddle
// #cgo CFLAGS: -I${SRCDIR}/../paddle_c/paddle/include
// #cgo LDFLAGS: -L${SRCDIR}/../paddle_c/paddle/lib -lpaddle_fluid_c
// #cgo LDFLAGS: -L${SRCDIR}/../paddle_c/paddle/lib -lpaddle_inference_c
// #include <stdbool.h>
// #include <paddle_c_api.h>
import "C"
......
......@@ -15,7 +15,7 @@
package paddle
// #cgo CFLAGS: -I${SRCDIR}/../paddle_c/paddle/include
// #cgo LDFLAGS: -L${SRCDIR}/../paddle_c/paddle/lib -lpaddle_fluid_c
// #cgo LDFLAGS: -L${SRCDIR}/../paddle_c/paddle/lib -lpaddle_inference_c
// #include <stdbool.h>
// #include <stdlib.h>
// #include <paddle_c_api.h>
......
......@@ -15,7 +15,7 @@
package paddle
// #cgo CFLAGS: -I${SRCDIR}/../paddle_c/paddle/include
// #cgo LDFLAGS: -L${SRCDIR}/../paddle_c/paddle/lib -lpaddle_fluid_c
// #cgo LDFLAGS: -L${SRCDIR}/../paddle_c/paddle/lib -lpaddle_inference_c
// #include <stdbool.h>
// #include "paddle_c_api.h"
import "C"
......@@ -88,7 +88,7 @@ func (predictor *Predictor) GetInputNames() []string {
}
func (predictor *Predictor) GetOutputNames() []string {
names := make([]string, predictor.GetInputNum())
names := make([]string, predictor.GetOutputNum())
for i := 0; i < len(names); i++ {
names[i] = predictor.GetOutputName(i)
}
......
......@@ -15,7 +15,7 @@
package paddle
// #cgo CFLAGS: -I${SRCDIR}/../paddle_c/paddle/include
// #cgo LDFLAGS: -L${SRCDIR}/../paddle_c/paddle/lib -lpaddle_fluid_c
// #cgo LDFLAGS: -L${SRCDIR}/../paddle_c/paddle/lib -lpaddle_inference_c
// #include <stdbool.h>
// #include <stdlib.h>
// #include <string.h>
......@@ -209,7 +209,7 @@ func DecodeTensor(r *bytes.Reader, shape []int32, t reflect.Type, ptr reflect.Va
value := reflect.Indirect(ptr)
value.Set(reflect.MakeSlice(t, int(shape[0]), int(shape[0])))
if len(shape) == 1 && value.Len() > 0 {
switch value.Index(1).Kind() {
switch value.Index(0).Kind() {
case reflect.Uint8, reflect.Int32, reflect.Int64, reflect.Float32:
binary.Read(r, Endian(), value.Interface())
return
......
......@@ -207,13 +207,16 @@ int PD_GetOutputNum(const PD_Predictor* predictor) {
}
const char* PD_GetInputName(const PD_Predictor* predictor, int n) {
static std::vector<std::string> names = predictor->predictor->GetInputNames();
static std::vector<std::string> names;
names.resize(predictor->predictor->GetInputNames().size());
names[n] = predictor->predictor->GetInputNames()[n];
return names[n].c_str();
}
const char* PD_GetOutputName(const PD_Predictor* predictor, int n) {
static std::vector<std::string> names =
predictor->predictor->GetOutputNames();
static std::vector<std::string> names;
names.resize(predictor->predictor->GetOutputNames().size());
names[n] = predictor->predictor->GetOutputNames()[n];
return names[n].c_str();
}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册