// serving_client.go
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package main

// NOTE(review): encoding/binary is not used in this file, and the pb / proto
// packages referenced by LoadModelConfig are not imported yet.
import (
	"bytes"
	"encoding/binary"
	"encoding/json"
	"fmt"
	"io/ioutil"
	"log"
	"net/http"
	"sort"
)

// JSON wire types exchanged with the serving HTTP endpoint. encoding/json
// emits fields in declaration order using the tags below, so field order and
// tags together define the encoded form.
type (
	// Tensor is one feed or fetch value: its integer payload, an element-type
	// code and its shape.
	Tensor struct {
		IntData  []int64 `json:"int_data"`
		ElemType int     `json:"elem_type"`
		Shape    []int   `json:"shape"`
	}

	// FeedInst is one input instance, holding the tensors sent to the model.
	FeedInst struct {
		TensorArray []Tensor `json:"tensor_array"`
	}

	// FetchInst is one output instance, holding returned tensors.
	FetchInst struct {
		TensorArray []Tensor `json:"tensor_array"`
	}

	// Request is the body that Predict marshals and POSTs: the input
	// instances plus the names of the variables to fetch.
	Request struct {
		Insts         []FeedInst `json:"insts"`
		FetchVarNames []string   `json:"fetch_var_names"`
	}

	// Response presumably models the server's reply; it is not decoded
	// anywhere in this file yet.
	Response struct {
		Insts []FetchInst `json:"insts"`
	}
)

// LoadModelConfig reads the file at path config, parses it as a serialized
// pb.GeneralConfig protobuf, and returns a Handle.
//
// If the file cannot be read or parsed, the process exits via log.Fatalln.
//
// NOTE(review): the original body had no return statement and did not
// compile. For now a zero-value Handle is returned. The parsed config is not
// copied into it because Handle's fields are not visible in this file.
// TODO populate the returned Handle from generalModelConfig.
func LoadModelConfig(config string) Handle {
	in, err := ioutil.ReadFile(config)
	if err != nil {
		log.Fatalln("Failed to read general model: ", err)
	}
	generalModelConfig := &pb.GeneralConfig{}
	if err := proto.Unmarshal(in, generalModelConfig); err != nil {
		log.Fatalln("Failed to parse GeneralConfig: ", err)
	}
	fmt.Println("Read message done.")

	var handle Handle
	return handle
}

// Connect stores the server address (url and port) on handle and returns
// the updated handle.
//
// handle is passed by value. If Handle is a plain struct type, the
// assignments below only change Connect's local copy, so callers should use
// the result: h = Connect(url, port, h). Handle's definition is not visible
// in this file.
func Connect(url string, port int, handle Handle) Handle {
	handle.Url = url
	handle.Port = port
	return handle
}

func Predict(handle Handle, int_feed_map Map[string]int[],
             float_feed_map Map[string]float[], fetch []string) {
     contentType := "application/json;charset=utf-8"

     var inst FeedInst
     req := &Request{
     	 Insts: []FeedInst{inst},
	 FetchVarNames: []string{fetch}}

     b, err := json.Marshal(req)
     fmt.Println(string(b))

     body := bytes.NewBuffer(b)

     resp, err := http.Post(url, contentType, body)
     if err != nil {
     	log.Println("Post failed:", err)
	return
     }

     defer resp.Body.Close()

     content, err := ioutil.ReadAll(resp.Body)
     if err != nil {
     	log.Println("Read failed:", err)
	return
     }

     log.Println("content:", string(content))
}