// serving_client_api.go (committed by guru4elephant).
//   Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package serving_client

import (
	"bytes"
	"encoding/json"
	"io/ioutil"
	"log"
	"net"
	"net/http"

	pb "github.com/PaddlePaddle/Serving/go/proto"
	"github.com/golang/protobuf/proto"
)

// Tensor is one tensor in the JSON bodies exchanged with the
// GeneralModelService inference endpoint.
type Tensor struct {
	// Data is a raw byte payload (encoding/json emits []byte as base64).
	Data []byte `json:"data"`
	// FloatData carries float32 values; Predict reads fetched outputs from it.
	FloatData []float32 `json:"float_data"`
	// IntData carries int64 values; Predict sends its feed inputs in it.
	IntData []int64 `json:"int_data"`
	// ElemType tags the payload's element type; Predict sends 0 alongside IntData.
	ElemType int `json:"elem_type"`
	// Shape lists the tensor's dimensions.
	Shape []int `json:"shape"`
}

// FeedInst is one input instance: one tensor per feed variable.
type FeedInst struct {
	TensorArray []Tensor `json:"tensor_array"`
}

// FetchInst is one output instance: the tensors returned for the fetch variables.
type FetchInst struct {
	TensorArray []Tensor `json:"tensor_array"`
}

// Request is the JSON body POSTed to the inference endpoint.
type Request struct {
	// Insts holds the input instances to run.
	Insts []FeedInst `json:"insts"`
	// FetchVarNames names the output variables the server should return.
	FetchVarNames []string `json:"fetch_var_names"`
}

// Response is the JSON body returned by the inference endpoint.
type Response struct {
	// Insts holds the output instances.
	Insts []FetchInst `json:"insts"`
	// MeanInferUs is decoded but not read by this client; presumably the
	// server-reported mean inference time in microseconds — TODO confirm.
	MeanInferUs float32 `json:"mean_infer_us"`
}

// Handle holds a model's feed/fetch metadata, as built by LoadModelConfig,
// together with the server address set by Connect. Field order is kept
// stable so existing positional composite literals still compile.
type Handle struct {
	// Url is the server host and Port its port; Predict addresses the server with them.
	Url  string
	Port string

	// FeedAliasNameMap maps a feed alias name to its variable name.
	FeedAliasNameMap map[string]string
	// FeedShapeMap maps a feed variable name (not its alias) to its shape.
	FeedShapeMap map[string][]int
	// FeedNameMap maps a feed variable name to its index in the model config.
	FeedNameMap map[string]int
	// FeedAliasNames lists the feed alias names in model-config order.
	FeedAliasNames []string
	// FetchNameMap maps a fetch variable name to its index in the model config.
	FetchNameMap map[string]int
	// FetchAliasNameMap maps a fetch alias name to its variable name.
	FetchAliasNameMap map[string]string
}

func LoadModelConfig(config string) Handle {
     in, err := ioutil.ReadFile(config)
     if err != nil {
     	log.Fatalln("Failed to read general model: ", err)
     }
     general_model_config := &pb.GeneralModelConfig{}
     if err := proto.Unmarshal(in, general_model_config); err != nil {
     	log.Fatalln("Failed to parse GeneralModelConfig: ", err)
     }
G
guru4elephant 已提交
73
     log.Println("read protobuf succeed")
G
guru4elephant 已提交
74
     handle := Handle{}
G
guru4elephant 已提交
75 76 77 78 79 80 81
     handle.FeedNameMap = map[string]int{}
     handle.FeedAliasNameMap = map[string]string{}
     handle.FeedShapeMap = map[string][]int{}
     handle.FetchNameMap = map[string]int{}
     handle.FetchAliasNameMap = map[string]string{}
     handle.FeedAliasNames = []string{}

G
guru4elephant 已提交
82 83 84 85 86 87 88 89
     for i, v := range general_model_config.FeedVar {
     	 handle.FeedNameMap[*v.Name] = i
	 tmp_array := []int{}
	 for _, vv := range v.Shape {
	     tmp_array = append(tmp_array, int(vv))
	 }
	 handle.FeedShapeMap[*v.Name] = tmp_array
	 handle.FeedAliasNameMap[*v.AliasName] = *v.Name
G
guru4elephant 已提交
90
	 handle.FeedAliasNames = append(handle.FeedAliasNames, *v.AliasName)
G
guru4elephant 已提交
91 92 93 94 95 96 97 98 99 100
     }

     for i, v := range general_model_config.FetchVar {
     	 handle.FetchNameMap[*v.Name] = i
	 handle.FetchAliasNameMap[*v.AliasName] = *v.Name
     }

     return handle
}

G
guru4elephant 已提交
101
func Connect(url string, port string, handle Handle) Handle {
G
guru4elephant 已提交
102 103 104 105 106 107 108 109 110 111
     handle.Url = url
     handle.Port = port
     return handle
}

func Predict(handle Handle, int_feed_map map[string][]int64, fetch []string) map[string][]float32 {
     contentType := "application/json;charset=utf-8"

     var tensor_array []Tensor
     var inst FeedInst
G
guru4elephant 已提交
112 113 114 115 116 117 118 119 120
     tensor_array = []Tensor{}
     inst = FeedInst{}

     for i := 0; i < len(handle.FeedAliasNames); i++ {
     	 key_i := handle.FeedAliasNames[i]
	 var tmp Tensor
	 tmp.IntData = []int64{}
	 tmp.Shape = []int{}
	 tmp.IntData = int_feed_map[key_i]
G
guru4elephant 已提交
121
	 tmp.ElemType = 0
G
guru4elephant 已提交
122
	 tmp.Shape = handle.FeedShapeMap[key_i]
G
guru4elephant 已提交
123 124 125
	 tensor_array = append(tensor_array, tmp)
     }

G
guru4elephant 已提交
126 127
     inst.TensorArray = tensor_array

G
guru4elephant 已提交
128 129 130 131 132 133 134 135
     req := &Request{
     	 Insts: []FeedInst{inst},
	 FetchVarNames: fetch}

     b, err := json.Marshal(req)

     body := bytes.NewBuffer(b)

G
guru4elephant 已提交
136 137 138 139 140 141 142 143
     var post_address bytes.Buffer
     post_address.WriteString("http://")
     post_address.WriteString(handle.Url)
     post_address.WriteString(":")
     post_address.WriteString(handle.Port)
     post_address.WriteString("/GeneralModelService/inference")

     resp, err := http.Post(post_address.String(), contentType, body)
G
guru4elephant 已提交
144 145 146 147 148 149 150 151
     if err != nil {
     	log.Println("Post failed:", err)
     }

     defer resp.Body.Close()

     content, err := ioutil.ReadAll(resp.Body)
     if err != nil {
G
guru4elephant 已提交
152
      	log.Println("Read failed:", err)
G
guru4elephant 已提交
153 154
     }

G
guru4elephant 已提交
155 156
     response_json := Response{}
     err = json.Unmarshal([]byte(content), &response_json)
G
guru4elephant 已提交
157 158

     var result map[string][]float32
G
guru4elephant 已提交
159 160 161 162 163
     result = map[string][]float32{}
     for i, v := range fetch {
     	 result[v] = response_json.Insts[0].TensorArray[i].FloatData
     }
     
G
guru4elephant 已提交
164 165
     return result
}