general_model_main.cpp 2.1 KB
Newer Older
G
guru4elephant 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include <cstdint>
#include <fstream>
#include <iostream>
#include <string>
#include <vector>

B
barrierye 已提交
18
#include "general_model.h"  // NOLINT
G
guru4elephant 已提交
19

B
barrierye 已提交
20
using namespace std;  // NOLINT
G
guru4elephant 已提交
21 22 23 24

using baidu::paddle_serving::general_model::PredictorClient;
using baidu::paddle_serving::general_model::FetchedMap;

B
barrierye 已提交
25 26
int main(int argc, char* argv[]) {
  PredictorClient* client = new PredictorClient();
G
guru4elephant 已提交
27 28 29
  client->init("inference.conf");
  client->set_predictor_conf("./", "predictor.conf");
  client->create_predictor();
B
barrierye 已提交
30 31
  std::vector<std::vector<float>> float_feed;
  std::vector<std::vector<int64_t>> int_feed;
G
guru4elephant 已提交
32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57
  std::vector<std::string> float_feed_name;
  std::vector<std::string> int_feed_name = {"words", "label"};
  std::vector<std::string> fetch_name = {"cost", "acc", "prediction"};

  std::string line;
  int64_t text_id = 0;
  int64_t label = 0;
  int text_id_num = 0;
  int label_num = 0;
  int line_num = 0;
  while (cin >> text_id_num) {
    int_feed.clear();
    float_feed.clear();
    std::vector<int64_t> ids;
    ids.reserve(text_id_num);
    for (int i = 0; i < text_id_num; ++i) {
      cin >> text_id;
      ids.push_back(text_id);
    }
    int_feed.push_back(ids);
    cin >> label_num;
    cin >> label;
    int_feed.push_back({label});

    FetchedMap result;

B
barrierye 已提交
58 59 60 61 62 63
    client->predict(float_feed,
                    float_feed_name,
                    int_feed,
                    int_feed_name,
                    fetch_name,
                    &result);
G
guru4elephant 已提交
64 65 66 67 68 69 70 71 72

    cout << label << "\t" << result["prediction"][1] << endl;

    line_num++;
    if (line_num % 100 == 0) {
      cerr << "line num: " << line_num << endl;
    }
  }
}