From 4c5418f972081d81e4b738b0ed5cbcd7b32e6e68 Mon Sep 17 00:00:00 2001 From: guru4elephant Date: Thu, 13 Feb 2020 16:51:16 +0800 Subject: [PATCH] add general model config --- go/serving_client/serving_client.go | 40 ++++++++++++++++++++++++++--- 1 file changed, 36 insertions(+), 4 deletions(-) diff --git a/go/serving_client/serving_client.go b/go/serving_client/serving_client.go index a85d0b14..919a5b85 100644 --- a/go/serving_client/serving_client.go +++ b/go/serving_client/serving_client.go @@ -22,6 +22,7 @@ import ( "net/http" "encoding/binary" "fmt" + "general_model_config" ) type Tensor struct { @@ -47,6 +48,16 @@ type Response struct { Insts []FetchInst `json:"insts"` } +type Handle struct { + Url string + Port int + FeedAliasNameMap Map[string]string + FeedShapeMap Map[string][]int + FeedNameMap Map[string]int + FetchNameMap Map[string]int + FetchAliasNameMap Map[string]string +} + func LoadModelConfig(config string) Handle { in, err := ioutil.ReadFile(config) if err != nil { @@ -56,19 +67,40 @@ func LoadModelConfig(config string) Handle { if err := proto.Unmarshal(in, general_model_config); err != nil { log.Fatalln("Failed to parse GeneralConfig: ", err) } - fmt.Println("Read message done.") + handle := Handle{} + for i, v in range general_model_config.FeedVar { + handle.FeedNameMap[v.Name] = i + handle.FeedShapeMap[v.Name] = v + handle.FeedAliasNameMap[v.AliasName] = v.Name + } + + for i, v in range general_model_config.FetchVar { + handle.FetchNameMap[v.Name] = i + handle.FetchAliasNameMap[v.AliasName] = v.Name + } + + return handle } -func Connect(url string, port int, handle Handle) { +func Connect(url string, port int, handle Handle) Handle { handle.Url = url handle.Port = port + return handle } -func Predict(handle Handle, int_feed_map Map[string]int[], - float_feed_map Map[string]float[], fetch []string) { +func Predict(handle Handle, int_feed_map Map[string]int{}, fetch []string) { contentType := "application/json;charset=utf-8" + var tensor_array []Tensor{} var inst FeedInst + for k, v := range int_feed_map { + var tmp Tensor + tmp.Data = int_feed_map[k] + tmp.ElemType = 0 + tmp.Shape = handle.FeedNameShape[k] + tensor_array = append(tensor_array, tmp) + } + req := &Request{ Insts: []FeedInst{inst}, FetchVarNames: []string{fetch}} -- GitLab