提交 0a86cba1 编写于 作者: G guru4elephant

add go client api and imdb application

上级 1a860698
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
package main
import (
"os"
"serving_client"
"path"
"path/filepath"
)
func main() {
config_file_path string = os.Args[1]
handle = serving_client.LoadModelConfig(config_file_path)
serving_client.Connect("127.0.0.1", 9292, handle)
test_file_path string = os.Args[2]
fi, err := os.Open(test_file_path)
if err != nil {
fmt.Print(err)
}
defer fi.Close()
br := bufio.NewReader(fi)
fetch []int = {"cost", "acc", "prediction"}
result Map(string)[]float
for {
line, err := br.ReadString('\n')
if err == io.EOF {
break
}
var words = []int{}
var label = []int{}
s := strings.Split(line, " ")
value, err := strconv.Atoi(s[0])
feed_int_map Map(string)[]int
feed_float_map Map(string)[]float
for _, v := range s[1:value + 1] {
int_v := strconv.Atoi(v)
words = append(words, int_v)
}
label = append(label, strconv.Atoi(s[len(s)-1]))
feed_int_map["words"] = words
feed_int_map["label"] = label
result = serving_client.Predict(
feed_int_map, feed_float_map, fetch, handle)
}
}
\ No newline at end of file
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
package main
import (
"bytes"
"encoding/json"
"io/ioutil"
"log"
"net/http"
"encoding/binary"
"fmt"
)
type Tensor struct {
IntData []int64 `json:"int_data"`
ElemType int `json:"elem_type"`
Shape []int `json:"shape"`
}
type FeedInst struct {
TensorArray []Tensor `json:"tensor_array"`
}
type FetchInst struct {
TensorArray []Tensor `json:"tensor_array"`
}
type Request struct {
Insts []FeedInst `json:"insts"`
FetchVarNames []string `json:"fetch_var_names"`
}
type Response struct {
Insts []FetchInst `json:"insts"`
}
func LoadModelConfig(config string) Handle {
in, err := ioutil.ReadFile(config)
if err != nil {
log.Fatalln("Failed to read general model: ", err)
}
general_model_config := &pb.GeneralConfig{}
if err := proto.Unmarshal(in, general_model_config); err != nil {
log.Fatalln("Failed to parse GeneralConfig: ", err)
}
fmt.Println("Read message done.")
}
func Connect(url string, port int, handle Handle) {
handle.Url = url
handle.Port = port
}
func Predict(handle Handle, int_feed_map Map[string]int[],
float_feed_map Map[string]float[], fetch []string) {
contentType := "application/json;charset=utf-8"
var inst FeedInst
req := &Request{
Insts: []FeedInst{inst},
FetchVarNames: []string{fetch}}
b, err := json.Marshal(req)
fmt.Println(string(b))
body := bytes.NewBuffer(b)
resp, err := http.Post(url, contentType, body)
if err != nil {
log.Println("Post failed:", err)
return
}
defer resp.Body.Close()
content, err := ioutil.ReadAll(resp.Body)
if err != nil {
log.Println("Read failed:", err)
return
}
log.Println("content:", string(content))
}
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册