提交 1344d989 编写于 作者: S ShiningZhang

add comment for client

上级 54814efe
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. // Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
// //
// Licensed under the Apache License, Version 2.0 (the "License"); // Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License. // you may not use this file except in compliance with the License.
...@@ -17,17 +17,16 @@ ...@@ -17,17 +17,16 @@
#include "core/general-client/include/brpc_client.h" #include "core/general-client/include/brpc_client.h"
using namespace std; // NOLINT
using baidu::paddle_serving::client::ServingClient; using baidu::paddle_serving::client::ServingClient;
using baidu::paddle_serving::client::ServingBrpcClient; using baidu::paddle_serving::client::ServingBrpcClient;
using baidu::paddle_serving::client::PredictorInputs; using baidu::paddle_serving::client::PredictorInputs;
using baidu::paddle_serving::client::PredictorOutputs; using baidu::paddle_serving::client::PredictorOutputs;
DEFINE_string(server_port, "127.0.0.1:9292", ""); DEFINE_string(server_port, "127.0.0.1:9292", "ip:port");
DEFINE_string(client_conf, "serving_client_conf.prototxt", ""); DEFINE_string(client_conf, "serving_client_conf.prototxt", "Path of client conf");
DEFINE_string(test_type, "brpc", ""); DEFINE_string(test_type, "brpc", "brpc");
DEFINE_string(sample_type, "fit_a_line", ""); // fit_a_line, bert
DEFINE_string(sample_type, "fit_a_line", "List: fit_a_line, bert");
namespace { namespace {
int prepare_fit_a_line(PredictorInputs& input, std::vector<std::string>& fetch_name) { int prepare_fit_a_line(PredictorInputs& input, std::vector<std::string>& fetch_name) {
......
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. // Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
// //
// Licensed under the Apache License, Version 2.0 (the "License"); // Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License. // you may not use this file except in compliance with the License.
...@@ -33,7 +33,7 @@ class ServingBrpcClient : public ServingClient { ...@@ -33,7 +33,7 @@ class ServingBrpcClient : public ServingClient {
int predict(const PredictorInputs& inputs, int predict(const PredictorInputs& inputs,
PredictorOutputs& outputs, PredictorOutputs& outputs,
std::vector<std::string>& fetch_name, const std::vector<std::string>& fetch_name,
const uint64_t log_id); const uint64_t log_id);
private: private:
......
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. // Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
// //
// Licensed under the Apache License, Version 2.0 (the "License"); // Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License. // you may not use this file except in compliance with the License.
...@@ -47,7 +47,7 @@ class ServingClient { ...@@ -47,7 +47,7 @@ class ServingClient {
virtual int predict(const PredictorInputs& inputs, virtual int predict(const PredictorInputs& inputs,
PredictorOutputs& outputs, PredictorOutputs& outputs,
std::vector<std::string>& fetch_name, const std::vector<std::string>& fetch_name,
const uint64_t log_id) = 0; const uint64_t log_id) = 0;
protected: protected:
...@@ -66,75 +66,75 @@ class PredictorData { ...@@ -66,75 +66,75 @@ class PredictorData {
PredictorData() {}; PredictorData() {};
virtual ~PredictorData() {}; virtual ~PredictorData() {};
virtual void add_float_data(const std::vector<float>& data, void add_float_data(const std::vector<float>& data,
const std::string& name, const std::string& name,
const std::vector<int>& shape, const std::vector<int>& shape,
const std::vector<int>& lod); const std::vector<int>& lod);
virtual void add_int64_data(const std::vector<int64_t>& data, void add_int64_data(const std::vector<int64_t>& data,
const std::string& name, const std::string& name,
const std::vector<int>& shape, const std::vector<int>& shape,
const std::vector<int>& lod); const std::vector<int>& lod);
virtual void add_int32_data(const std::vector<int32_t>& data, void add_int32_data(const std::vector<int32_t>& data,
const std::string& name, const std::string& name,
const std::vector<int>& shape, const std::vector<int>& shape,
const std::vector<int>& lod); const std::vector<int>& lod);
virtual void add_string_data(const std::string& data, void add_string_data(const std::string& data,
const std::string& name, const std::string& name,
const std::vector<int>& shape, const std::vector<int>& shape,
const std::vector<int>& lod); const std::vector<int>& lod);
virtual const std::map<std::string, std::vector<float>>& float_data_map() const { const std::map<std::string, std::vector<float>>& float_data_map() const {
return _float_data_map; return _float_data_map;
}; };
virtual std::map<std::string, std::vector<float>>* mutable_float_data_map() { std::map<std::string, std::vector<float>>* mutable_float_data_map() {
return &_float_data_map; return &_float_data_map;
}; };
virtual const std::map<std::string, std::vector<int64_t>>& int64_data_map() const { const std::map<std::string, std::vector<int64_t>>& int64_data_map() const {
return _int64_data_map; return _int64_data_map;
}; };
virtual std::map<std::string, std::vector<int64_t>>* mutable_int64_data_map() { std::map<std::string, std::vector<int64_t>>* mutable_int64_data_map() {
return &_int64_data_map; return &_int64_data_map;
}; };
virtual const std::map<std::string, std::vector<int32_t>>& int_data_map() const { const std::map<std::string, std::vector<int32_t>>& int_data_map() const {
return _int32_data_map; return _int32_data_map;
}; };
virtual std::map<std::string, std::vector<int32_t>>* mutable_int_data_map() { std::map<std::string, std::vector<int32_t>>* mutable_int_data_map() {
return &_int32_data_map; return &_int32_data_map;
}; };
virtual const std::map<std::string, std::string>& string_data_map() const { const std::map<std::string, std::string>& string_data_map() const {
return _string_data_map; return _string_data_map;
}; };
virtual std::map<std::string, std::string>* mutable_string_data_map() { std::map<std::string, std::string>* mutable_string_data_map() {
return &_string_data_map; return &_string_data_map;
}; };
virtual const std::map<std::string, std::vector<int>>& shape_map() const { const std::map<std::string, std::vector<int>>& shape_map() const {
return _shape_map; return _shape_map;
}; };
virtual std::map<std::string, std::vector<int>>* mutable_shape_map() { std::map<std::string, std::vector<int>>* mutable_shape_map() {
return &_shape_map; return &_shape_map;
}; };
virtual const std::map<std::string, std::vector<int>>& lod_map() const { const std::map<std::string, std::vector<int>>& lod_map() const {
return _lod_map; return _lod_map;
}; };
virtual std::map<std::string, std::vector<int>>* mutable_lod_map() { std::map<std::string, std::vector<int>>* mutable_lod_map() {
return &_lod_map; return &_lod_map;
}; };
virtual std::string print(); std::string print();
private: private:
template<typename T1, typename T2> template<typename T1, typename T2>
...@@ -196,7 +196,7 @@ class PredictorInputs : public PredictorData { ...@@ -196,7 +196,7 @@ class PredictorInputs : public PredictorData {
PredictorInputs() {}; PredictorInputs() {};
virtual ~PredictorInputs() {}; virtual ~PredictorInputs() {};
static int gen_proto(const PredictorInputs& inputs, static int GenProto(const PredictorInputs& inputs,
const std::map<std::string, int>& feed_name_to_idx, const std::map<std::string, int>& feed_name_to_idx,
const std::vector<std::string>& feed_name, const std::vector<std::string>& feed_name,
predictor::general_model::Request& req); predictor::general_model::Request& req);
...@@ -212,23 +212,23 @@ class PredictorOutputs { ...@@ -212,23 +212,23 @@ class PredictorOutputs {
PredictorOutputs() {}; PredictorOutputs() {};
virtual ~PredictorOutputs() {}; virtual ~PredictorOutputs() {};
virtual std::vector<std::shared_ptr<PredictorOutputs::PredictorOutput>>& datas() { const std::vector<std::shared_ptr<PredictorOutputs::PredictorOutput>>& datas() {
return _datas; return _datas;
}; };
virtual std::vector<std::shared_ptr<PredictorOutputs::PredictorOutput>>* mutable_datas() { std::vector<std::shared_ptr<PredictorOutputs::PredictorOutput>>* mutable_datas() {
return &_datas; return &_datas;
}; };
virtual void add_data(const std::shared_ptr<PredictorOutputs::PredictorOutput>& data) { void add_data(const std::shared_ptr<PredictorOutputs::PredictorOutput>& data) {
_datas.push_back(data); _datas.push_back(data);
}; };
virtual std::string print(); std::string print();
virtual void clear(); void clear();
static int parse_proto(const predictor::general_model::Response& res, static int ParseProto(const predictor::general_model::Response& res,
const std::vector<std::string>& fetch_name, const std::vector<std::string>& fetch_name,
std::map<std::string, int>& fetch_name_to_type, std::map<std::string, int>& fetch_name_to_type,
PredictorOutputs& outputs); PredictorOutputs& outputs);
......
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. // Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
// //
// Licensed under the Apache License, Version 2.0 (the "License"); // Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License. // you may not use this file except in compliance with the License.
...@@ -47,6 +47,7 @@ int ServingBrpcClient::connect(const std::string server_port) { ...@@ -47,6 +47,7 @@ int ServingBrpcClient::connect(const std::string server_port) {
} }
std::string ServingBrpcClient::gen_desc(const std::string server_port) { std::string ServingBrpcClient::gen_desc(const std::string server_port) {
// default config for brpc
SDKConf sdk_conf; SDKConf sdk_conf;
Predictor* predictor = sdk_conf.add_predictors(); Predictor* predictor = sdk_conf.add_predictors();
...@@ -83,13 +84,14 @@ std::string ServingBrpcClient::gen_desc(const std::string server_port) { ...@@ -83,13 +84,14 @@ std::string ServingBrpcClient::gen_desc(const std::string server_port) {
int ServingBrpcClient::predict(const PredictorInputs& inputs, int ServingBrpcClient::predict(const PredictorInputs& inputs,
PredictorOutputs& outputs, PredictorOutputs& outputs,
std::vector<std::string>& fetch_name, const std::vector<std::string>& fetch_name,
const uint64_t log_id) { const uint64_t log_id) {
Timer timeline; Timer timeline;
int64_t preprocess_start = timeline.TimeStampUS(); int64_t preprocess_start = timeline.TimeStampUS();
// thread initialize for StubTLS
_api.thrd_initialize(); _api.thrd_initialize();
std::string variant_tag; std::string variant_tag;
// predictor is bound to request with brpc::Controller
_predictor = _api.fetch_predictor("general_model", &variant_tag); _predictor = _api.fetch_predictor("general_model", &variant_tag);
if (_predictor == NULL) { if (_predictor == NULL) {
LOG(ERROR) << "Failed fetch predictor so predict error!"; LOG(ERROR) << "Failed fetch predictor so predict error!";
...@@ -105,7 +107,7 @@ int ServingBrpcClient::predict(const PredictorInputs& inputs, ...@@ -105,7 +107,7 @@ int ServingBrpcClient::predict(const PredictorInputs& inputs,
req.add_fetch_var_names(name); req.add_fetch_var_names(name);
} }
if (PredictorInputs::gen_proto(inputs, _feed_name_to_idx, _feed_name, req) != 0) { if (PredictorInputs::GenProto(inputs, _feed_name_to_idx, _feed_name, req) != 0) {
LOG(ERROR) << "Failed to preprocess req!"; LOG(ERROR) << "Failed to preprocess req!";
return -1; return -1;
} }
...@@ -132,7 +134,7 @@ int ServingBrpcClient::predict(const PredictorInputs& inputs, ...@@ -132,7 +134,7 @@ int ServingBrpcClient::predict(const PredictorInputs& inputs,
client_infer_end = timeline.TimeStampUS(); client_infer_end = timeline.TimeStampUS();
postprocess_start = client_infer_end; postprocess_start = client_infer_end;
if (PredictorOutputs::parse_proto(res, fetch_name, _fetch_name_to_type, outputs) != 0) { if (PredictorOutputs::ParseProto(res, fetch_name, _fetch_name_to_type, outputs) != 0) {
LOG(ERROR) << "Failed to post_process res!"; LOG(ERROR) << "Failed to post_process res!";
return -1; return -1;
} }
...@@ -160,6 +162,7 @@ int ServingBrpcClient::predict(const PredictorInputs& inputs, ...@@ -160,6 +162,7 @@ int ServingBrpcClient::predict(const PredictorInputs& inputs,
fprintf(stderr, "%s\n", oss.str().c_str()); fprintf(stderr, "%s\n", oss.str().c_str());
} }
// release predictor
_api.thrd_clear(); _api.thrd_clear();
return 0; return 0;
......
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. // Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
// //
// Licensed under the Apache License, Version 2.0 (the "License"); // Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License. // you may not use this file except in compliance with the License.
...@@ -32,6 +32,7 @@ int ServingClient::init(const std::vector<std::string>& client_conf, ...@@ -32,6 +32,7 @@ int ServingClient::init(const std::vector<std::string>& client_conf,
return -1; return -1;
} }
// pure virtual func, subclass implementation
if (connect(server_port) != 0) { if (connect(server_port) != 0) {
LOG(ERROR) << "Failed to connect"; LOG(ERROR) << "Failed to connect";
return -1; return -1;
...@@ -148,7 +149,7 @@ std::string PredictorData::print() { ...@@ -148,7 +149,7 @@ std::string PredictorData::print() {
return res; return res;
} }
int PredictorInputs::gen_proto(const PredictorInputs& inputs, int PredictorInputs::GenProto(const PredictorInputs& inputs,
const std::map<std::string, int>& feed_name_to_idx, const std::map<std::string, int>& feed_name_to_idx,
const std::vector<std::string>& feed_name, const std::vector<std::string>& feed_name,
Request& req) { Request& req) {
...@@ -317,7 +318,7 @@ void PredictorOutputs::clear() { ...@@ -317,7 +318,7 @@ void PredictorOutputs::clear() {
_datas.clear(); _datas.clear();
} }
int PredictorOutputs::parse_proto(const Response& res, int PredictorOutputs::ParseProto(const Response& res,
const std::vector<std::string>& fetch_name, const std::vector<std::string>& fetch_name,
std::map<std::string, int>& fetch_name_to_type, std::map<std::string, int>& fetch_name_to_type,
PredictorOutputs& outputs) { PredictorOutputs& outputs) {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册