imdb_client.go 1.9 KB
Newer Older
G
guru4elephant 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13
//   Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
14 15 16 17

package main

import (
G
guru4elephant 已提交
18
       "io"
G
guru4elephant 已提交
19
       "fmt"
G
guru4elephant 已提交
20 21 22
       "strings"
       "bufio"
       "strconv"
23
       "os"
G
guru4elephant 已提交
24
       "serving_client"
25 26 27
)

func main() {
G
guru4elephant 已提交
28 29 30
     var config_file_path string
     config_file_path = os.Args[1]
     handle := serving_client.LoadModelConfig(config_file_path)
G
guru4elephant 已提交
31
     handle = serving_client.Connect("127.0.0.1", "9292", handle)
32

G
guru4elephant 已提交
33
     test_file_path := os.Args[2]
34 35 36 37 38 39 40 41
     fi, err := os.Open(test_file_path)
     if err != nil {
     	fmt.Print(err)
     }

     defer fi.Close()
     br := bufio.NewReader(fi)

G
guru4elephant 已提交
42
     fetch := []string{"cost", "acc", "prediction"}     
43

G
guru4elephant 已提交
44
     var result map[string][]float32
45 46 47 48 49 50 51

     for {
     	 line, err := br.ReadString('\n')
	 if err == io.EOF {
	    break
	 }

G
guru4elephant 已提交
52 53
	 line = strings.Trim(line, "\n")

G
guru4elephant 已提交
54
	 var words = []int64{}
55 56 57

	 s := strings.Split(line, " ")
	 value, err := strconv.Atoi(s[0])
G
guru4elephant 已提交
58
	 var feed_int_map map[string][]int64
59 60

	 for _, v := range s[1:value + 1] {
G
guru4elephant 已提交
61 62
	     int_v, _ := strconv.Atoi(v)
	     words = append(words, int64(int_v))
63 64
	 }

G
guru4elephant 已提交
65 66 67 68 69
	 label, err := strconv.Atoi(s[len(s)-1])

	 if err != nil {
	    panic(err)
	 }
70

G
guru4elephant 已提交
71
	 feed_int_map = map[string][]int64{}
72
	 feed_int_map["words"] = words
G
guru4elephant 已提交
73
	 feed_int_map["label"] = []int64{int64(label)}
74
	 
G
guru4elephant 已提交
75 76
	 result = serving_client.Predict(handle,
	 	 feed_int_map, fetch)
G
guru4elephant 已提交
77
	 fmt.Println(result["prediction"][1], "\t", int64(label))
78 79
     }
}