提交 c0099661 编写于 作者: G guru4elephant

add imdb example for go client

上级 e11adceb
package main
import (
"io"
"os"
"fmt"
"bufio"
"strings"
"strconv"
)
func main() {
score_file := os.Args[1]
fi, err := os.Open(score_file)
if err != nil {
fmt.Print(err)
}
defer fi.Close()
br := bufio.NewReader(fi)
total := int(0)
acc := int(0)
for {
line, err := br.ReadString('\n')
if err == io.EOF {
break
}
line = strings.Trim(line, "\n")
s := strings.Split(line, "\t")
prob_str := strings.Trim(s[0], " ")
label_str := strings.Trim(s[1], " ")
prob, err := strconv.ParseFloat(prob_str, 32)
if err != nil {
panic(err)
}
label, err := strconv.ParseFloat(label_str, 32)
if err != nil {
panic(err)
}
if (prob - 0.5) * (label - 0.5) > 0 {
acc++
}
total++
}
fmt.Println("total num: ", total)
fmt.Println("acc num: ", acc)
fmt.Println("acc: ", float32(acc) / float32(total))
}
\ No newline at end of file
......@@ -16,10 +16,10 @@ package main
import (
"io"
"fmt"
"strings"
"bufio"
"strconv"
"fmt"
"os"
"serving_client"
)
......@@ -28,7 +28,7 @@ func main() {
var config_file_path string
config_file_path = os.Args[1]
handle := serving_client.LoadModelConfig(config_file_path)
serving_client.Connect("127.0.0.1", 9292, handle)
handle = serving_client.Connect("127.0.0.1", "9292", handle)
test_file_path := os.Args[2]
fi, err := os.Open(test_file_path)
......@@ -49,6 +49,8 @@ func main() {
break
}
line = strings.Trim(line, "\n")
var words = []int64{}
s := strings.Split(line, " ")
......@@ -60,13 +62,18 @@ func main() {
words = append(words, int64(int_v))
}
label, _ := strconv.Atoi(s[len(s)-1])
label, err := strconv.Atoi(s[len(s)-1])
if err != nil {
panic(err)
}
feed_int_map = map[string][]int64{}
feed_int_map["words"] = words
feed_int_map["label"] = []int64{int64(label)}
result = serving_client.Predict(handle,
feed_int_map, fetch)
fmt.Println(result["prediction"][0])
fmt.Println(result["prediction"][1], "\t", int64(label))
}
}
\ No newline at end of file
// Code generated by protoc-gen-go. DO NOT EDIT.
// source: general_model_config.proto
package baidu_paddle_serving_configure
import (
fmt "fmt"
proto "github.com/golang/protobuf/proto"
math "math"
)
// Reference imports to suppress errors if they are not otherwise used.
var _ = proto.Marshal
var _ = fmt.Errorf
var _ = math.Inf
// This is a compile-time assertion to ensure that this generated file
// is compatible with the proto package it is being compiled against.
// A compilation error at this line likely means your copy of the
// proto package needs to be updated.
const _ = proto.ProtoPackageIsVersion3 // please upgrade the proto package
type FeedVar struct {
Name *string `protobuf:"bytes,1,opt,name=name" json:"name,omitempty"`
AliasName *string `protobuf:"bytes,2,opt,name=alias_name,json=aliasName" json:"alias_name,omitempty"`
IsLodTensor *bool `protobuf:"varint,3,opt,name=is_lod_tensor,json=isLodTensor,def=0" json:"is_lod_tensor,omitempty"`
FeedType *int32 `protobuf:"varint,4,opt,name=feed_type,json=feedType,def=0" json:"feed_type,omitempty"`
Shape []int32 `protobuf:"varint,5,rep,name=shape" json:"shape,omitempty"`
XXX_NoUnkeyedLiteral struct{} `json:"-"`
XXX_unrecognized []byte `json:"-"`
XXX_sizecache int32 `json:"-"`
}
func (m *FeedVar) Reset() { *m = FeedVar{} }
func (m *FeedVar) String() string { return proto.CompactTextString(m) }
func (*FeedVar) ProtoMessage() {}
func (*FeedVar) Descriptor() ([]byte, []int) {
return fileDescriptor_efa52beffa29d37a, []int{0}
}
func (m *FeedVar) XXX_Unmarshal(b []byte) error {
return xxx_messageInfo_FeedVar.Unmarshal(m, b)
}
func (m *FeedVar) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) {
return xxx_messageInfo_FeedVar.Marshal(b, m, deterministic)
}
func (m *FeedVar) XXX_Merge(src proto.Message) {
xxx_messageInfo_FeedVar.Merge(m, src)
}
func (m *FeedVar) XXX_Size() int {
return xxx_messageInfo_FeedVar.Size(m)
}
func (m *FeedVar) XXX_DiscardUnknown() {
xxx_messageInfo_FeedVar.DiscardUnknown(m)
}
var xxx_messageInfo_FeedVar proto.InternalMessageInfo
const Default_FeedVar_IsLodTensor bool = false
const Default_FeedVar_FeedType int32 = 0
func (m *FeedVar) GetName() string {
if m != nil && m.Name != nil {
return *m.Name
}
return ""
}
func (m *FeedVar) GetAliasName() string {
if m != nil && m.AliasName != nil {
return *m.AliasName
}
return ""
}
func (m *FeedVar) GetIsLodTensor() bool {
if m != nil && m.IsLodTensor != nil {
return *m.IsLodTensor
}
return Default_FeedVar_IsLodTensor
}
func (m *FeedVar) GetFeedType() int32 {
if m != nil && m.FeedType != nil {
return *m.FeedType
}
return Default_FeedVar_FeedType
}
func (m *FeedVar) GetShape() []int32 {
if m != nil {
return m.Shape
}
return nil
}
type FetchVar struct {
Name *string `protobuf:"bytes,1,opt,name=name" json:"name,omitempty"`
AliasName *string `protobuf:"bytes,2,opt,name=alias_name,json=aliasName" json:"alias_name,omitempty"`
IsLodTensor *bool `protobuf:"varint,3,opt,name=is_lod_tensor,json=isLodTensor,def=0" json:"is_lod_tensor,omitempty"`
Shape []int32 `protobuf:"varint,4,rep,name=shape" json:"shape,omitempty"`
XXX_NoUnkeyedLiteral struct{} `json:"-"`
XXX_unrecognized []byte `json:"-"`
XXX_sizecache int32 `json:"-"`
}
func (m *FetchVar) Reset() { *m = FetchVar{} }
func (m *FetchVar) String() string { return proto.CompactTextString(m) }
func (*FetchVar) ProtoMessage() {}
func (*FetchVar) Descriptor() ([]byte, []int) {
return fileDescriptor_efa52beffa29d37a, []int{1}
}
func (m *FetchVar) XXX_Unmarshal(b []byte) error {
return xxx_messageInfo_FetchVar.Unmarshal(m, b)
}
func (m *FetchVar) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) {
return xxx_messageInfo_FetchVar.Marshal(b, m, deterministic)
}
func (m *FetchVar) XXX_Merge(src proto.Message) {
xxx_messageInfo_FetchVar.Merge(m, src)
}
func (m *FetchVar) XXX_Size() int {
return xxx_messageInfo_FetchVar.Size(m)
}
func (m *FetchVar) XXX_DiscardUnknown() {
xxx_messageInfo_FetchVar.DiscardUnknown(m)
}
var xxx_messageInfo_FetchVar proto.InternalMessageInfo
const Default_FetchVar_IsLodTensor bool = false
func (m *FetchVar) GetName() string {
if m != nil && m.Name != nil {
return *m.Name
}
return ""
}
func (m *FetchVar) GetAliasName() string {
if m != nil && m.AliasName != nil {
return *m.AliasName
}
return ""
}
func (m *FetchVar) GetIsLodTensor() bool {
if m != nil && m.IsLodTensor != nil {
return *m.IsLodTensor
}
return Default_FetchVar_IsLodTensor
}
func (m *FetchVar) GetShape() []int32 {
if m != nil {
return m.Shape
}
return nil
}
type GeneralModelConfig struct {
FeedVar []*FeedVar `protobuf:"bytes,1,rep,name=feed_var,json=feedVar" json:"feed_var,omitempty"`
FetchVar []*FetchVar `protobuf:"bytes,2,rep,name=fetch_var,json=fetchVar" json:"fetch_var,omitempty"`
XXX_NoUnkeyedLiteral struct{} `json:"-"`
XXX_unrecognized []byte `json:"-"`
XXX_sizecache int32 `json:"-"`
}
func (m *GeneralModelConfig) Reset() { *m = GeneralModelConfig{} }
func (m *GeneralModelConfig) String() string { return proto.CompactTextString(m) }
func (*GeneralModelConfig) ProtoMessage() {}
func (*GeneralModelConfig) Descriptor() ([]byte, []int) {
return fileDescriptor_efa52beffa29d37a, []int{2}
}
func (m *GeneralModelConfig) XXX_Unmarshal(b []byte) error {
return xxx_messageInfo_GeneralModelConfig.Unmarshal(m, b)
}
func (m *GeneralModelConfig) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) {
return xxx_messageInfo_GeneralModelConfig.Marshal(b, m, deterministic)
}
func (m *GeneralModelConfig) XXX_Merge(src proto.Message) {
xxx_messageInfo_GeneralModelConfig.Merge(m, src)
}
func (m *GeneralModelConfig) XXX_Size() int {
return xxx_messageInfo_GeneralModelConfig.Size(m)
}
func (m *GeneralModelConfig) XXX_DiscardUnknown() {
xxx_messageInfo_GeneralModelConfig.DiscardUnknown(m)
}
var xxx_messageInfo_GeneralModelConfig proto.InternalMessageInfo
func (m *GeneralModelConfig) GetFeedVar() []*FeedVar {
if m != nil {
return m.FeedVar
}
return nil
}
func (m *GeneralModelConfig) GetFetchVar() []*FetchVar {
if m != nil {
return m.FetchVar
}
return nil
}
func init() {
proto.RegisterType((*FeedVar)(nil), "baidu.paddle_serving.configure.FeedVar")
proto.RegisterType((*FetchVar)(nil), "baidu.paddle_serving.configure.FetchVar")
proto.RegisterType((*GeneralModelConfig)(nil), "baidu.paddle_serving.configure.GeneralModelConfig")
}
func init() { proto.RegisterFile("general_model_config.proto", fileDescriptor_efa52beffa29d37a) }
var fileDescriptor_efa52beffa29d37a = []byte{
// 283 bytes of a gzipped FileDescriptorProto
0x1f, 0x8b, 0x08, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0xff, 0xb4, 0xd0, 0x31, 0x4b, 0xc4, 0x30,
0x14, 0x07, 0x70, 0x72, 0x6d, 0xb9, 0xf6, 0x1d, 0x2e, 0xc1, 0xa1, 0x08, 0x1e, 0xe5, 0x16, 0xe3,
0x52, 0xc4, 0xf1, 0x46, 0xc5, 0x73, 0x51, 0x87, 0x72, 0xb8, 0x86, 0xd8, 0xbc, 0xb6, 0x81, 0x5c,
0x53, 0x92, 0xde, 0xc1, 0x2d, 0x7e, 0x13, 0xf1, 0xab, 0x4a, 0x93, 0x43, 0x9c, 0x74, 0x72, 0x7b,
0x79, 0xff, 0xf0, 0xde, 0xe3, 0x07, 0x17, 0x2d, 0xf6, 0x68, 0x85, 0xe6, 0x3b, 0x23, 0x51, 0xf3,
0xda, 0xf4, 0x8d, 0x6a, 0xcb, 0xc1, 0x9a, 0xd1, 0xd0, 0xe5, 0x9b, 0x50, 0x72, 0x5f, 0x0e, 0x42,
0x4a, 0x8d, 0xdc, 0xa1, 0x3d, 0xa8, 0xbe, 0x2d, 0xc3, 0x97, 0xbd, 0xc5, 0xd5, 0x07, 0x81, 0xf9,
0x06, 0x51, 0xbe, 0x0a, 0x4b, 0x29, 0xc4, 0xbd, 0xd8, 0x61, 0x4e, 0x0a, 0xc2, 0xb2, 0xca, 0xd7,
0xf4, 0x12, 0x40, 0x68, 0x25, 0x1c, 0xf7, 0xc9, 0xcc, 0x27, 0x99, 0xef, 0xbc, 0x4c, 0xf1, 0x35,
0x9c, 0x29, 0xc7, 0xb5, 0x91, 0x7c, 0xc4, 0xde, 0x19, 0x9b, 0x47, 0x05, 0x61, 0xe9, 0x3a, 0x69,
0x84, 0x76, 0x58, 0x2d, 0x94, 0x7b, 0x32, 0x72, 0xeb, 0x13, 0xba, 0x84, 0xac, 0x41, 0x94, 0x7c,
0x3c, 0x0e, 0x98, 0xc7, 0x05, 0x61, 0xc9, 0x9a, 0xdc, 0x54, 0xe9, 0xd4, 0xdb, 0x1e, 0x07, 0xa4,
0xe7, 0x90, 0xb8, 0x4e, 0x0c, 0x98, 0x27, 0x45, 0xc4, 0x92, 0x2a, 0x3c, 0x56, 0xef, 0x90, 0x6e,
0x70, 0xac, 0xbb, 0xff, 0xbf, 0xef, 0x7b, 0x7f, 0xfc, 0x73, 0xff, 0x27, 0x01, 0xfa, 0x18, 0x78,
0x9f, 0x27, 0xdd, 0x7b, 0x2f, 0x47, 0xef, 0xc0, 0x1f, 0xce, 0x0f, 0xc2, 0xe6, 0xa4, 0x88, 0xd8,
0xe2, 0xf6, 0xaa, 0xfc, 0x5d, 0xba, 0x3c, 0x29, 0x57, 0xf3, 0xe6, 0xc4, 0xfd, 0x30, 0x81, 0x8c,
0x75, 0xe7, 0x87, 0xcc, 0xfc, 0x10, 0xf6, 0xf7, 0x90, 0x60, 0x31, 0xb9, 0x85, 0xea, 0x2b, 0x00,
0x00, 0xff, 0xff, 0x08, 0x27, 0x9c, 0x1a, 0xfe, 0x01, 0x00, 0x00,
}
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
package main
import (
"bytes"
"encoding/json"
"io/ioutil"
"log"
"net/http"
"encoding/binary"
"fmt"
"general_model_config"
)
type Tensor struct {
IntData []int64 `json:"int_data"`
ElemType int `json:"elem_type"`
Shape []int `json:"shape"`
}
type FeedInst struct {
TensorArray []Tensor `json:"tensor_array"`
}
type FetchInst struct {
TensorArray []Tensor `json:"tensor_array"`
}
type Request struct {
Insts []FeedInst `json:"insts"`
FetchVarNames []string `json:"fetch_var_names"`
}
type Response struct {
Insts []FetchInst `json:"insts"`
}
type Handle struct {
Url string
Port int
FeedAliasNameMap Map[string]string
FeedShapeMap Map[string][]int
FeedNameMap Map[string]int
FetchNameMap Map[string]int
FetchAliasNameMap Map[string]string
}
func LoadModelConfig(config string) Handle {
in, err := ioutil.ReadFile(config)
if err != nil {
log.Fatalln("Failed to read general model: ", err)
}
general_model_config := &pb.GeneralConfig{}
if err := proto.Unmarshal(in, general_model_config); err != nil {
log.Fatalln("Failed to parse GeneralConfig: ", err)
}
handle := Handle{}
for i, v in range general_model_config.FeedVar {
handle.FeedNameMap[v.Name] = i
handle.FeedShapeMap[v.Name] = v
handle.FeedAliasNameMap[v.AliasName] = v.Name
}
for i, v in range general_model_config.FetchVar {
handle.FetchNameMap[v.Name] = i
handle.FetchAliasNameMap[v.AliasName] = v.Name
}
return handle
}
func Connect(url string, port int, handle Handle) Handle {
handle.Url = url
handle.Port = port
return handle
}
func Predict(handle Handle, int_feed_map Map[string]int{}, fetch []string) {
contentType := "application/json;charset=utf-8"
var tensor_array []Tensor{}
var inst FeedInst
for k, v := range int_feed_map {
var tmp Tensor
tmp.Data = int_feed_map[k]
tmp.ElemType = 0
tmp.Shape = handle.FeedNameShape[k]
tensor_array = append(tensor_array, tmp)
}
req := &Request{
Insts: []FeedInst{inst},
FetchVarNames: []string{fetch}}
b, err := json.Marshal(req)
fmt.Println(string(b))
body := bytes.NewBuffer(b)
resp, err := http.Post(url, contentType, body)
if err != nil {
log.Println("Post failed:", err)
return
}
defer resp.Body.Close()
content, err := ioutil.ReadAll(resp.Body)
if err != nil {
log.Println("Read failed:", err)
return
}
log.Println("content:", string(content))
}
......@@ -20,12 +20,13 @@ import (
"io/ioutil"
"log"
"net/http"
"fmt"
pb "general_model"
"github.com/golang/protobuf/proto"
)
type Tensor struct {
Data []byte `json:"data"`
FloatData []float32 `json:"float_data"`
IntData []int64 `json:"int_data"`
ElemType int `json:"elem_type"`
Shape []int `json:"shape"`
......@@ -46,14 +47,16 @@ type Request struct {
type Response struct {
Insts []FetchInst `json:"insts"`
MeanInferUs float32 `json:"mean_infer_us"`
}
type Handle struct {
Url string
Port int
Port string
FeedAliasNameMap map[string]string
FeedShapeMap map[string][]int
FeedNameMap map[string]int
FeedAliasNames []string
FetchNameMap map[string]int
FetchAliasNameMap map[string]string
}
......@@ -67,16 +70,24 @@ func LoadModelConfig(config string) Handle {
if err := proto.Unmarshal(in, general_model_config); err != nil {
log.Fatalln("Failed to parse GeneralModelConfig: ", err)
}
log.Println("read protobuf succeed")
handle := Handle{}
handle.FeedNameMap = map[string]int{}
handle.FeedAliasNameMap = map[string]string{}
handle.FeedShapeMap = map[string][]int{}
handle.FetchNameMap = map[string]int{}
handle.FetchAliasNameMap = map[string]string{}
handle.FeedAliasNames = []string{}
for i, v := range general_model_config.FeedVar {
handle.FeedNameMap[*v.Name] = i
// handle.FeedShapeMap[*v.Name] := []int
tmp_array := []int{}
for _, vv := range v.Shape {
tmp_array = append(tmp_array, int(vv))
}
handle.FeedShapeMap[*v.Name] = tmp_array
handle.FeedAliasNameMap[*v.AliasName] = *v.Name
handle.FeedAliasNames = append(handle.FeedAliasNames, *v.AliasName)
}
for i, v := range general_model_config.FetchVar {
......@@ -87,7 +98,7 @@ func LoadModelConfig(config string) Handle {
return handle
}
func Connect(url string, port int, handle Handle) Handle {
func Connect(url string, port string, handle Handle) Handle {
handle.Url = url
handle.Port = port
return handle
......@@ -98,24 +109,38 @@ func Predict(handle Handle, int_feed_map map[string][]int64, fetch []string) map
var tensor_array []Tensor
var inst FeedInst
for k, _ := range int_feed_map {
var tmp Tensor
tmp.IntData = int_feed_map[k]
tensor_array = []Tensor{}
inst = FeedInst{}
for i := 0; i < len(handle.FeedAliasNames); i++ {
key_i := handle.FeedAliasNames[i]
var tmp Tensor
tmp.IntData = []int64{}
tmp.Shape = []int{}
tmp.IntData = int_feed_map[key_i]
tmp.ElemType = 0
tmp.Shape = handle.FeedShapeMap[k]
tmp.Shape = handle.FeedShapeMap[key_i]
tensor_array = append(tensor_array, tmp)
}
inst.TensorArray = tensor_array
req := &Request{
Insts: []FeedInst{inst},
FetchVarNames: fetch}
b, err := json.Marshal(req)
fmt.Println(string(b))
body := bytes.NewBuffer(b)
resp, err := http.Post(handle.Url, contentType, body)
var post_address bytes.Buffer
post_address.WriteString("http://")
post_address.WriteString(handle.Url)
post_address.WriteString(":")
post_address.WriteString(handle.Port)
post_address.WriteString("/GeneralModelService/inference")
resp, err := http.Post(post_address.String(), contentType, body)
if err != nil {
log.Println("Post failed:", err)
}
......@@ -124,11 +149,17 @@ func Predict(handle Handle, int_feed_map map[string][]int64, fetch []string) map
content, err := ioutil.ReadAll(resp.Body)
if err != nil {
log.Println("Read failed:", err)
log.Println("Read failed:", err)
}
log.Println("content:", string(content))
response_json := Response{}
err = json.Unmarshal([]byte(content), &response_json)
var result map[string][]float32
result = map[string][]float32{}
for i, v := range fetch {
result[v] = response_json.Insts[0].TensorArray[i].FloatData
}
return result
}
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册