general_model.cpp 15.9 KB
Newer Older
G
guru4elephant 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

G
guru4elephant 已提交
15
#include "core/general-client/include/general_model.h"
M
MRXLT 已提交
16
#include <fstream>
G
guru4elephant 已提交
17 18 19
#include "core/sdk-cpp/builtin_format.pb.h"
#include "core/sdk-cpp/include/common.h"
#include "core/sdk-cpp/include/predictor_sdk.h"
G
guru4elephant 已提交
20
#include "core/util/include/timer.h"
21 22 23
DEFINE_bool(profile_client, false, "");
DEFINE_bool(profile_server, false, "");

G
guru4elephant 已提交
24
using baidu::paddle_serving::Timer;
G
guru4elephant 已提交
25 26 27
using baidu::paddle_serving::predictor::general_model::Request;
using baidu::paddle_serving::predictor::general_model::Response;
using baidu::paddle_serving::predictor::general_model::Tensor;
S
ShiningZhang 已提交
28 29
// paddle inference support: FLOAT32, INT64, INT32, UINT8, INT8
// will support: FLOAT16
S
ShiningZhang 已提交
30 31 32 33 34 35 36 37 38 39 40 41 42
// Element types shared with the serving protobuf wire format.
// Values 0-11 mirror paddle inference's data-type numbering; P_STRING is
// pinned at 20 to stay in sync with the server side.
// paddle inference support: FLOAT32, INT64, INT32, UINT8, INT8
// will support: FLOAT16
enum ProtoDataType {
  P_INT64 = 0,
  P_FLOAT32 = 1,
  P_INT32 = 2,
  P_FP64 = 3,
  P_INT16 = 4,
  P_FP16 = 5,
  P_BF16 = 6,
  P_UINT8 = 7,
  P_INT8 = 8,
  P_BOOL = 9,
  P_COMPLEX64 = 10,
  P_COMPLEX128 = 11,
  P_STRING = 20,
};
45
std::once_flag gflags_init_flag;
M
MRXLT 已提交
46
namespace py = pybind11;
47

G
guru4elephant 已提交
48 49 50
namespace baidu {
namespace paddle_serving {
namespace general_model {
51
using configure::GeneralModelConfig;
G
guru4elephant 已提交
52

53 54
void PredictorClient::init_gflags(std::vector<std::string> argv) {
  std::call_once(gflags_init_flag, [&]() {
55
#ifndef BCLOUD
M
MRXLT 已提交
56
    FLAGS_logtostderr = true;
57
#endif
M
MRXLT 已提交
58 59 60 61
    argv.insert(argv.begin(), "dummy");
    int argc = argv.size();
    char **arr = new char *[argv.size()];
    std::string line;
H
HexToString 已提交
62
    for (size_t i = 0; i < argv.size(); ++i) {
M
MRXLT 已提交
63 64 65 66 67 68 69
      arr[i] = &argv[i][0];
      line += argv[i];
      line += ' ';
    }
    google::ParseCommandLineFlags(&argc, &arr, true);
    VLOG(2) << "Init commandline: " << line;
  });
70 71
}

H
HexToString 已提交
72
int PredictorClient::init(const std::vector<std::string> &conf_file) {
73 74
  try {
    GeneralModelConfig model_config;
H
HexToString 已提交
75
    if (configure::read_proto_conf(conf_file[0].c_str(), &model_config) != 0) {
76
      LOG(ERROR) << "Failed to load general model config"
H
HexToString 已提交
77
                 << ", file path: " << conf_file[0];
78 79
      return -1;
    }
H
HexToString 已提交
80

81 82 83 84
    _feed_name_to_idx.clear();
    _fetch_name_to_idx.clear();
    _shape.clear();
    int feed_var_num = model_config.feed_var_size();
H
HexToString 已提交
85
    _feed_name.clear();
H
HexToString 已提交
86
    VLOG(2) << "feed var num: " << feed_var_num;
87 88
    for (int i = 0; i < feed_var_num; ++i) {
      _feed_name_to_idx[model_config.feed_var(i).alias_name()] = i;
H
HexToString 已提交
89 90 91
      VLOG(2) << "feed [" << i << "]"
              << " name: " << model_config.feed_var(i).name();
      _feed_name.push_back(model_config.feed_var(i).name());
92 93
      VLOG(2) << "feed alias name: " << model_config.feed_var(i).alias_name()
              << " index: " << i;
94
      std::vector<int> tmp_feed_shape;
M
MRXLT 已提交
95 96
      VLOG(2) << "feed"
              << "[" << i << "] shape:";
97 98
      for (int j = 0; j < model_config.feed_var(i).shape_size(); ++j) {
        tmp_feed_shape.push_back(model_config.feed_var(i).shape(j));
M
MRXLT 已提交
99
        VLOG(2) << "shape[" << j << "]: " << model_config.feed_var(i).shape(j);
100 101
      }
      _type.push_back(model_config.feed_var(i).feed_type());
M
MRXLT 已提交
102 103 104
      VLOG(2) << "feed"
              << "[" << i
              << "] feed type: " << model_config.feed_var(i).feed_type();
105
      _shape.push_back(tmp_feed_shape);
G
guru4elephant 已提交
106 107
    }

H
HexToString 已提交
108
    if (conf_file.size() > 1) {
H
HexToString 已提交
109
      model_config.Clear();
H
HexToString 已提交
110 111
      if (configure::read_proto_conf(conf_file[conf_file.size() - 1].c_str(),
                                     &model_config) != 0) {
H
HexToString 已提交
112
        LOG(ERROR) << "Failed to load general model config"
H
HexToString 已提交
113
                   << ", file path: " << conf_file[conf_file.size() - 1];
H
HexToString 已提交
114 115 116 117 118
        return -1;
      }
    }
    int fetch_var_num = model_config.fetch_var_size();
    VLOG(2) << "fetch_var_num: " << fetch_var_num;
119 120
    for (int i = 0; i < fetch_var_num; ++i) {
      _fetch_name_to_idx[model_config.fetch_var(i).alias_name()] = i;
M
MRXLT 已提交
121 122
      VLOG(2) << "fetch [" << i << "]"
              << " alias name: " << model_config.fetch_var(i).alias_name();
123 124
      _fetch_name_to_var_name[model_config.fetch_var(i).alias_name()] =
          model_config.fetch_var(i).name();
125 126
      _fetch_name_to_type[model_config.fetch_var(i).alias_name()] =
          model_config.fetch_var(i).fetch_type();
127
    }
M
MRXLT 已提交
128
  } catch (std::exception &e) {
129 130
    LOG(ERROR) << "Failed load general model config" << e.what();
    return -1;
G
guru4elephant 已提交
131
  }
132
  return 0;
G
guru4elephant 已提交
133 134
}

M
MRXLT 已提交
135 136
void PredictorClient::set_predictor_conf(const std::string &conf_path,
                                         const std::string &conf_file) {
G
guru4elephant 已提交
137 138 139
  _predictor_path = conf_path;
  _predictor_conf = conf_file;
}
140 141 142
// Release predictor resources: thread-local state first, then the API itself.
// Always returns 0.
int PredictorClient::destroy_predictor() {
  _api.thrd_finalize();
  _api.destroy();
  return 0;
}

M
MRXLT 已提交
146
int PredictorClient::create_predictor_by_desc(const std::string &sdk_desc) {
G
guru4elephant 已提交
147 148 149 150
  if (_api.create(sdk_desc) != 0) {
    LOG(ERROR) << "Predictor Creation Failed";
    return -1;
  }
D
dongdaxiang 已提交
151
  // _api.thrd_initialize();
B
barrierye 已提交
152
  return 0;
G
guru4elephant 已提交
153 154
}

G
guru4elephant 已提交
155
int PredictorClient::create_predictor() {
G
guru4elephant 已提交
156 157
  VLOG(2) << "Predictor path: " << _predictor_path
          << " predictor file: " << _predictor_conf;
G
guru4elephant 已提交
158 159 160 161
  if (_api.create(_predictor_path.c_str(), _predictor_conf.c_str()) != 0) {
    LOG(ERROR) << "Predictor Creation Failed";
    return -1;
  }
D
dongdaxiang 已提交
162
  // _api.thrd_initialize();
B
barrierye 已提交
163
  return 0;
G
guru4elephant 已提交
164 165
}

M
MRXLT 已提交
166
int PredictorClient::numpy_predict(
H
HexToString 已提交
167
    const std::vector<py::array_t<float>> &float_feed,
M
MRXLT 已提交
168 169
    const std::vector<std::string> &float_feed_name,
    const std::vector<std::vector<int>> &float_shape,
W
wangjiawei04 已提交
170
    const std::vector<std::vector<int>> &float_lod_slot_batch,
H
HexToString 已提交
171
    const std::vector<py::array_t<int64_t>> &int_feed,
M
MRXLT 已提交
172 173
    const std::vector<std::string> &int_feed_name,
    const std::vector<std::vector<int>> &int_shape,
W
wangjiawei04 已提交
174
    const std::vector<std::vector<int>> &int_lod_slot_batch,
H
HexToString 已提交
175
    const std::vector<std::string> &string_feed,
H
HexToString 已提交
176 177 178
    const std::vector<std::string> &string_feed_name,
    const std::vector<std::vector<int>> &string_shape,
    const std::vector<std::vector<int>> &string_lod_slot_batch,
M
MRXLT 已提交
179 180
    const std::vector<std::string> &fetch_name,
    PredictorRes &predict_res_batch,
181 182
    const int &pid,
    const uint64_t log_id) {
M
MRXLT 已提交
183 184 185 186 187 188 189 190 191 192 193
  predict_res_batch.clear();
  Timer timeline;
  int64_t preprocess_start = timeline.TimeStampUS();

  _api.thrd_initialize();
  std::string variant_tag;
  _predictor = _api.fetch_predictor("general_model", &variant_tag);
  predict_res_batch.set_variant_tag(variant_tag);
  VLOG(2) << "fetch general model predictor done.";
  VLOG(2) << "float feed name size: " << float_feed_name.size();
  VLOG(2) << "int feed name size: " << int_feed_name.size();
H
HexToString 已提交
194
  VLOG(2) << "string feed name size: " << string_feed_name.size();
M
MRXLT 已提交
195 196
  VLOG(2) << "max body size : " << brpc::fLU64::FLAGS_max_body_size;
  Request req;
197
  req.set_log_id(log_id);
M
MRXLT 已提交
198 199 200 201
  for (auto &name : fetch_name) {
    req.add_fetch_var_names(name);
  }

H
HexToString 已提交
202
  int vec_idx = 0;
H
HexToString 已提交
203 204
  // batch is already in Tensor.
  std::vector<Tensor *> tensor_vec;
M
MRXLT 已提交
205

H
HexToString 已提交
206 207 208
  for (auto &name : float_feed_name) {
    tensor_vec.push_back(req.add_tensor());
  }
H
HexToString 已提交
209

H
HexToString 已提交
210 211 212
  for (auto &name : int_feed_name) {
    tensor_vec.push_back(req.add_tensor());
  }
M
MRXLT 已提交
213

H
HexToString 已提交
214 215 216
  for (auto &name : string_feed_name) {
    tensor_vec.push_back(req.add_tensor());
  }
H
HexToString 已提交
217

H
HexToString 已提交
218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237
  vec_idx = 0;
  for (auto &name : float_feed_name) {
    int idx = _feed_name_to_idx[name];
    if (idx >= tensor_vec.size()) {
      LOG(ERROR) << "idx > tensor_vec.size()";
      return -1;
    }
    VLOG(2) << "prepare float feed " << name << " idx " << idx;
    int nbytes = float_feed[vec_idx].nbytes();
    void *rawdata_ptr = (void *)(float_feed[vec_idx].data(0));
    int total_number = float_feed[vec_idx].size();
    Tensor *tensor = tensor_vec[idx];

    VLOG(2) << "prepare float feed " << name << " shape size "
            << float_shape[vec_idx].size();
    for (uint32_t j = 0; j < float_shape[vec_idx].size(); ++j) {
      tensor->add_shape(float_shape[vec_idx][j]);
    }
    for (uint32_t j = 0; j < float_lod_slot_batch[vec_idx].size(); ++j) {
      tensor->add_lod(float_lod_slot_batch[vec_idx][j]);
M
MRXLT 已提交
238
    }
H
HexToString 已提交
239
    tensor->set_elem_type(P_FLOAT32);
H
HexToString 已提交
240

H
HexToString 已提交
241 242
    tensor->set_name(_feed_name[idx]);
    tensor->set_alias_name(name);
M
MRXLT 已提交
243

H
HexToString 已提交
244 245 246 247
    tensor->mutable_float_data()->Resize(total_number, 0);
    memcpy(tensor->mutable_float_data()->mutable_data(), rawdata_ptr, nbytes);
    vec_idx++;
  }
M
MRXLT 已提交
248

H
HexToString 已提交
249 250 251 252 253 254
  vec_idx = 0;
  for (auto &name : int_feed_name) {
    int idx = _feed_name_to_idx[name];
    if (idx >= tensor_vec.size()) {
      LOG(ERROR) << "idx > tensor_vec.size()";
      return -1;
M
MRXLT 已提交
255
    }
H
HexToString 已提交
256 257 258 259
    Tensor *tensor = tensor_vec[idx];
    int nbytes = int_feed[vec_idx].nbytes();
    void *rawdata_ptr = (void *)(int_feed[vec_idx].data(0));
    int total_number = int_feed[vec_idx].size();
M
MRXLT 已提交
260

H
HexToString 已提交
261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279
    for (uint32_t j = 0; j < int_shape[vec_idx].size(); ++j) {
      tensor->add_shape(int_shape[vec_idx][j]);
    }
    for (uint32_t j = 0; j < int_lod_slot_batch[vec_idx].size(); ++j) {
      tensor->add_lod(int_lod_slot_batch[vec_idx][j]);
    }
    tensor->set_elem_type(_type[idx]);
    tensor->set_name(_feed_name[idx]);
    tensor->set_alias_name(name);

    if (_type[idx] == P_INT64) {
      tensor->mutable_int64_data()->Resize(total_number, 0);
      memcpy(tensor->mutable_int64_data()->mutable_data(), rawdata_ptr, nbytes);
    } else {
      tensor->mutable_int_data()->Resize(total_number, 0);
      memcpy(tensor->mutable_int_data()->mutable_data(), rawdata_ptr, nbytes);
    }
    vec_idx++;
  }
H
HexToString 已提交
280

S
ShiningZhang 已提交
281 282
  // Add !P_STRING feed data of string_input to tensor_content
  // UINT8 INT8 FLOAT16
H
HexToString 已提交
283 284 285 286 287 288 289
  vec_idx = 0;
  for (auto &name : string_feed_name) {
    int idx = _feed_name_to_idx[name];
    if (idx >= tensor_vec.size()) {
      LOG(ERROR) << "idx > tensor_vec.size()";
      return -1;
    }
S
ShiningZhang 已提交
290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318
    if (_type[idx] == P_STRING) {
      continue;
    }
    Tensor *tensor = tensor_vec[idx];

    for (uint32_t j = 0; j < string_shape[vec_idx].size(); ++j) {
      tensor->add_shape(string_shape[vec_idx][j]);
    }
    for (uint32_t j = 0; j < string_lod_slot_batch[vec_idx].size(); ++j) {
      tensor->add_lod(string_lod_slot_batch[vec_idx][j]);
    }
    tensor->set_elem_type(_type[idx]);
    tensor->set_name(_feed_name[idx]);
    tensor->set_alias_name(name);

    tensor->set_tensor_content(string_feed[vec_idx]);
    vec_idx++;
  }

  vec_idx = 0;
  for (auto &name : string_feed_name) {
    int idx = _feed_name_to_idx[name];
    if (idx >= tensor_vec.size()) {
      LOG(ERROR) << "idx > tensor_vec.size()";
      return -1;
    }
    if (_type[idx] != P_STRING) {
      continue;
    }
H
HexToString 已提交
319
    Tensor *tensor = tensor_vec[idx];
H
HexToString 已提交
320

H
HexToString 已提交
321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342
    for (uint32_t j = 0; j < string_shape[vec_idx].size(); ++j) {
      tensor->add_shape(string_shape[vec_idx][j]);
    }
    for (uint32_t j = 0; j < string_lod_slot_batch[vec_idx].size(); ++j) {
      tensor->add_lod(string_lod_slot_batch[vec_idx][j]);
    }
    tensor->set_elem_type(P_STRING);
    tensor->set_name(_feed_name[idx]);
    tensor->set_alias_name(name);

    const int string_shape_size = string_shape[vec_idx].size();
    // string_shape[vec_idx] = [1];cause numpy has no datatype of string.
    // we pass string via vector<vector<string> >.
    if (string_shape_size != 1) {
      LOG(ERROR) << "string_shape_size should be 1-D, but received is : "
                 << string_shape_size;
      return -1;
    }
    switch (string_shape_size) {
      case 1: {
        tensor->add_data(string_feed[vec_idx]);
        break;
H
HexToString 已提交
343 344
      }
    }
H
HexToString 已提交
345
    vec_idx++;
M
MRXLT 已提交
346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371
  }

  int64_t preprocess_end = timeline.TimeStampUS();
  int64_t client_infer_start = timeline.TimeStampUS();
  Response res;

  int64_t client_infer_end = 0;
  int64_t postprocess_start = 0;
  int64_t postprocess_end = 0;

  if (FLAGS_profile_client) {
    if (FLAGS_profile_server) {
      req.set_profile_server(true);
    }
  }

  res.Clear();
  if (_predictor->inference(&req, &res) != 0) {
    LOG(ERROR) << "failed call predictor with req: " << req.ShortDebugString();
    return -1;
  } else {
    client_infer_end = timeline.TimeStampUS();
    postprocess_start = client_infer_end;
    VLOG(2) << "get model output num";
    uint32_t model_num = res.outputs_size();
    VLOG(2) << "model num: " << model_num;
B
barrierye 已提交
372 373 374
    for (uint32_t m_idx = 0; m_idx < model_num; ++m_idx) {
      VLOG(2) << "process model output index: " << m_idx;
      auto output = res.outputs(m_idx);
B
barrierye 已提交
375 376
      ModelRes model;
      model.set_engine_name(output.engine_name());
H
HexToString 已提交
377 378 379
      // 在ResponseOp处,已经按照fetch_name对输出数据进行了处理
      // 所以,输出的数据与fetch_name是严格对应的,按顺序处理即可。
      for (int idx = 0; idx < output.tensor_size(); ++idx) {
B
barrierye 已提交
380
        // int idx = _fetch_name_to_idx[name];
H
HexToString 已提交
381 382
        const std::string name = output.tensor(idx).alias_name();
        model._tensor_alias_names.push_back(name);
H
HexToString 已提交
383
        int shape_size = output.tensor(idx).shape_size();
B
barrierye 已提交
384 385
        VLOG(2) << "fetch var " << name << " index " << idx << " shape size "
                << shape_size;
B
barrierye 已提交
386 387
        model._shape_map[name].resize(shape_size);
        for (int i = 0; i < shape_size; ++i) {
H
HexToString 已提交
388
          model._shape_map[name][i] = output.tensor(idx).shape(i);
B
barrierye 已提交
389
        }
H
HexToString 已提交
390
        int lod_size = output.tensor(idx).lod_size();
B
barrierye 已提交
391 392 393
        if (lod_size > 0) {
          model._lod_map[name].resize(lod_size);
          for (int i = 0; i < lod_size; ++i) {
H
HexToString 已提交
394
            model._lod_map[name][i] = output.tensor(idx).lod(i);
B
barrierye 已提交
395
          }
396 397
        }

H
HexToString 已提交
398
        if (_fetch_name_to_type[name] == P_INT64) {
M
MRXLT 已提交
399
          VLOG(2) << "ferch var " << name << "type int64";
H
HexToString 已提交
400
          int size = output.tensor(idx).int64_data_size();
W
WangXi 已提交
401
          model._int64_value_map[name] = std::vector<int64_t>(
H
HexToString 已提交
402 403
              output.tensor(idx).int64_data().begin(),
              output.tensor(idx).int64_data().begin() + size);
H
HexToString 已提交
404
        } else if (_fetch_name_to_type[name] == P_FLOAT32) {
B
barrierye 已提交
405
          VLOG(2) << "fetch var " << name << "type float";
H
HexToString 已提交
406
          int size = output.tensor(idx).float_data_size();
W
WangXi 已提交
407
          model._float_value_map[name] = std::vector<float>(
H
HexToString 已提交
408 409
              output.tensor(idx).float_data().begin(),
              output.tensor(idx).float_data().begin() + size);
H
HexToString 已提交
410
        } else if (_fetch_name_to_type[name] == P_INT32) {
M
MRXLT 已提交
411
          VLOG(2) << "fetch var " << name << "type int32";
H
HexToString 已提交
412
          int size = output.tensor(idx).int_data_size();
M
MRXLT 已提交
413
          model._int32_value_map[name] = std::vector<int32_t>(
H
HexToString 已提交
414 415
              output.tensor(idx).int_data().begin(),
              output.tensor(idx).int_data().begin() + size);
S
ShiningZhang 已提交
416 417 418 419 420 421 422 423 424
        } else if (_fetch_name_to_type[name] == P_UINT8) {
          VLOG(2) << "fetch var " << name << "type uint8";
          model._string_value_map[name] = output.tensor(idx).tensor_content();
        } else if (_fetch_name_to_type[name] == P_INT8) {
          VLOG(2) << "fetch var " << name << "type int8";
          model._string_value_map[name] = output.tensor(idx).tensor_content();
        } else if (_fetch_name_to_type[name] == P_FP16) {
          VLOG(2) << "fetch var " << name << "type float16";
          model._string_value_map[name] = output.tensor(idx).tensor_content();
M
MRXLT 已提交
425
        }
M
MRXLT 已提交
426
      }
B
barrierye 已提交
427
      predict_res_batch.add_model_res(std::move(model));
M
MRXLT 已提交
428
    }
429
    postprocess_end = timeline.TimeStampUS();
M
MRXLT 已提交
430 431
  }

M
MRXLT 已提交
432 433 434
  if (FLAGS_profile_client) {
    std::ostringstream oss;
    oss << "PROFILE\t"
M
MRXLT 已提交
435
        << "pid:" << pid << "\t"
M
MRXLT 已提交
436 437 438 439 440 441 442 443 444 445 446 447 448 449 450 451 452
        << "prepro_0:" << preprocess_start << " "
        << "prepro_1:" << preprocess_end << " "
        << "client_infer_0:" << client_infer_start << " "
        << "client_infer_1:" << client_infer_end << " ";
    if (FLAGS_profile_server) {
      int op_num = res.profile_time_size() / 2;
      for (int i = 0; i < op_num; ++i) {
        oss << "op" << i << "_0:" << res.profile_time(i * 2) << " ";
        oss << "op" << i << "_1:" << res.profile_time(i * 2 + 1) << " ";
      }
    }

    oss << "postpro_0:" << postprocess_start << " ";
    oss << "postpro_1:" << postprocess_end;

    fprintf(stderr, "%s\n", oss.str().c_str());
  }
D
dongdaxiang 已提交
453 454

  _api.thrd_clear();
M
MRXLT 已提交
455
  return 0;
M
MRXLT 已提交
456
}
G
guru4elephant 已提交
457 458 459
}  // namespace general_model
}  // namespace paddle_serving
}  // namespace baidu