predictor.go 2.6 KB
Newer Older
L
LKKlein 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101
package paddle

// #include <stdbool.h>
// #include "paddle_c_api.h"
import "C"

import (
	"reflect"
	"runtime"
	"unsafe"
)

// Predictor is a Go handle around a C PD_Predictor.
type Predictor struct {
	// c is the underlying C predictor pointer handed to the PD_* C functions.
	c *C.PD_Predictor
}

// NewPredictor builds a Predictor from config by calling PD_NewPredictor.
//
// A finalizer is attached to the returned value so that the C predictor is
// deleted once the Go object is garbage collected.
func NewPredictor(config *AnalysisConfig) *Predictor {
	p := &Predictor{c: C.PD_NewPredictor(config.c)}
	runtime.SetFinalizer(p, (*Predictor).finalize)
	return p
}

// finalize deletes the underlying C predictor. NewPredictor registers it as
// the finalizer of every Predictor it returns.
func (predictor *Predictor) finalize() {
	handle := predictor.c
	C.PD_DeletePredictor(handle)
}

// DeletePredictor releases the C predictor held by predictor right away,
// instead of waiting for the garbage collector.
//
// It is safe to call with a nil predictor or more than once: the C handle is
// set to nil after the first release and later calls do nothing. The predictor
// must not be used for inference after this call.
func DeletePredictor(predictor *Predictor) {
	if predictor == nil || predictor.c == nil {
		return
	}
	// NewPredictor attaches a finalizer that also calls PD_DeletePredictor.
	// Clear it first, otherwise the same C pointer would be deleted a second
	// time when the Go object is collected.
	runtime.SetFinalizer(predictor, nil)
	C.PD_DeletePredictor(predictor.c)
	predictor.c = nil
}

// GetInputNum returns the value of PD_GetInputNum for this predictor,
// converted to int.
func (predictor *Predictor) GetInputNum() int {
	n := int(C.PD_GetInputNum(predictor.c))
	// The C call only sees the raw pointer; keep the Go object reachable so
	// its finalizer cannot delete the C predictor while the call is running.
	runtime.KeepAlive(predictor)
	return n
}

// GetOutputNum returns the value of PD_GetOutputNum for this predictor,
// converted to int.
func (predictor *Predictor) GetOutputNum() int {
	n := int(C.PD_GetOutputNum(predictor.c))
	// The C call only sees the raw pointer; keep the Go object reachable so
	// its finalizer cannot delete the C predictor while the call is running.
	runtime.KeepAlive(predictor)
	return n
}

// GetInputName returns a Go copy of the C string that PD_GetInputName reports
// for index n.
func (predictor *Predictor) GetInputName(n int) string {
	name := C.GoString(C.PD_GetInputName(predictor.c, C.int(n)))
	// Keep the predictor reachable until the C string has been copied, so the
	// finalizer cannot delete the C predictor in the middle of the call.
	runtime.KeepAlive(predictor)
	return name
}

// GetOutputName returns a Go copy of the C string that PD_GetOutputName
// reports for index n.
func (predictor *Predictor) GetOutputName(n int) string {
	name := C.GoString(C.PD_GetOutputName(predictor.c, C.int(n)))
	// Keep the predictor reachable until the C string has been copied, so the
	// finalizer cannot delete the C predictor in the middle of the call.
	runtime.KeepAlive(predictor)
	return name
}

// GetInputTensors returns one new ZeroCopyTensor per predictor input, in
// index order, with each tensor's C-side name set from PD_GetInputName.
// The result is nil when the predictor reports no inputs.
func (predictor *Predictor) GetInputTensors() [](*ZeroCopyTensor) {
	var result [](*ZeroCopyTensor)
	// Query the count once; each GetInputNum call crosses into C.
	n := predictor.GetInputNum()
	if n > 0 {
		result = make([](*ZeroCopyTensor), 0, n)
	}
	for i := 0; i < n; i++ {
		tensor := NewZeroCopyTensor()
		// NOTE(review): the C name pointer is stored as-is, not copied;
		// confirm it stays valid for as long as the tensor is used.
		tensor.c.name = C.PD_GetInputName(predictor.c, C.int(i))
		result = append(result, tensor)
	}
	// Keep the predictor reachable through the last C call so its finalizer
	// cannot delete the C predictor while that call is running.
	runtime.KeepAlive(predictor)
	return result
}

// GetOutputTensors returns one new ZeroCopyTensor per predictor output, in
// index order, with each tensor's C-side name set from PD_GetOutputName.
// The result is nil when the predictor reports no outputs.
func (predictor *Predictor) GetOutputTensors() [](*ZeroCopyTensor) {
	var result [](*ZeroCopyTensor)
	// Query the count once; each GetOutputNum call crosses into C.
	n := predictor.GetOutputNum()
	if n > 0 {
		result = make([](*ZeroCopyTensor), 0, n)
	}
	for i := 0; i < n; i++ {
		tensor := NewZeroCopyTensor()
		// NOTE(review): the C name pointer is stored as-is, not copied;
		// confirm it stays valid for as long as the tensor is used.
		tensor.c.name = C.PD_GetOutputName(predictor.c, C.int(i))
		result = append(result, tensor)
	}
	// Keep the predictor reachable through the last C call so its finalizer
	// cannot delete the C predictor while that call is running.
	runtime.KeepAlive(predictor)
	return result
}

// GetInputNames collects the name of every input, in index order, by calling
// GetInputName for each index below GetInputNum.
func (predictor *Predictor) GetInputNames() []string {
	count := predictor.GetInputNum()
	names := make([]string, count)
	for idx := range names {
		names[idx] = predictor.GetInputName(idx)
	}
	return names
}

// GetOutputNames collects the name of every output, in index order, by
// calling GetOutputName for each index below GetOutputNum.
func (predictor *Predictor) GetOutputNames() []string {
	count := predictor.GetOutputNum()
	names := make([]string, count)
	for idx := range names {
		names[idx] = predictor.GetOutputName(idx)
	}
	return names
}

// SetZeroCopyInput passes tensor's C handle to PD_SetZeroCopyInput for this
// predictor.
func (predictor *Predictor) SetZeroCopyInput(tensor *ZeroCopyTensor) {
	C.PD_SetZeroCopyInput(predictor.c, tensor.c)
	// Keep the predictor reachable so its finalizer cannot delete the C
	// predictor while the call is running.
	runtime.KeepAlive(predictor)
	// NOTE(review): kept alive as well in case ZeroCopyTensor is also managed
	// by a finalizer (not visible here); harmless otherwise.
	runtime.KeepAlive(tensor)
}

// GetZeroCopyOutput calls PD_GetZeroCopyOutput with tensor's C handle and
// then refreshes the Go side of tensor from it: the name is copied from the
// C string, and Reshape is called with the shape read from the C buffer.
func (predictor *Predictor) GetZeroCopyOutput(tensor *ZeroCopyTensor) {
	C.PD_GetZeroCopyOutput(predictor.c, tensor.c)
	tensor.name = C.GoString(tensor.c.name)

	// The code treats shape.length as a size in bytes; dividing by
	// sizeof(int) gives the number of int32 dimensions.
	dims := int(tensor.c.shape.length / C.sizeof_int)

	// Build a slice view directly over the C shape buffer (no copy) by
	// filling in the header of an existing slice variable.
	var shape []int32
	shapeHdr := (*reflect.SliceHeader)(unsafe.Pointer(&shape))
	shapeHdr.Data = uintptr(unsafe.Pointer(tensor.c.shape.data))
	shapeHdr.Len = dims
	shapeHdr.Cap = dims
	tensor.Reshape(shape)

	// Keep the predictor reachable until all C-side data has been read, so
	// its finalizer cannot delete the C predictor in the meantime.
	runtime.KeepAlive(predictor)
}

// ZeroCopyRun calls PD_ZeroCopyRun on the underlying C predictor.
func (predictor *Predictor) ZeroCopyRun() {
	C.PD_ZeroCopyRun(predictor.c)
	// The C call only sees the raw pointer; keep the Go object reachable so
	// its finalizer cannot delete the C predictor while the run is in flight.
	runtime.KeepAlive(predictor)
}