diff --git a/core/general-client/CMakeLists.txt b/core/general-client/CMakeLists.txt
index d6079317a75d3f45b61920836e6695bd6b31d951..0a7f2ee4b2899a1e6c6b4557dc26f767efe842e1 100644
--- a/core/general-client/CMakeLists.txt
+++ b/core/general-client/CMakeLists.txt
@@ -3,3 +3,24 @@ add_subdirectory(pybind11)
 pybind11_add_module(serving_client src/general_model.cpp src/pybind_general_model.cpp)
 target_link_libraries(serving_client PRIVATE -Wl,--whole-archive utils sdk-cpp pybind python -Wl,--no-whole-archive -lpthread -lcrypto -lm -lrt -lssl -ldl -lz -Wl,-rpath,'$ORIGIN'/lib)
 endif()
+
+if(CLIENT)
+FILE(GLOB client_srcs include/*.h src/client.cpp src/brpc_client.cpp)
+add_library(client ${client_srcs})
+add_dependencies(client utils sdk-cpp)
+target_link_libraries(client utils sdk-cpp)
+endif()
+
+if(CLIENT)
+include_directories(SYSTEM ${CMAKE_CURRENT_LIST_DIR}/../../)
+add_executable(simple_client example/simple_client.cpp)
+
+add_dependencies(simple_client utils sdk-cpp client)
+
+target_link_libraries(simple_client -Wl,--whole-archive
+      -Wl,--no-whole-archive -lpthread -lcrypto -lm -lrt -lssl -ldl -lz -Wl,-rpath,'$ORIGIN'/lib)
+
+target_link_libraries(simple_client utils)
+target_link_libraries(simple_client sdk-cpp)
+target_link_libraries(simple_client client)
+endif()
\ No newline at end of file
diff --git a/core/general-client/README_CN.md b/core/general-client/README_CN.md
new file mode 100755
index 0000000000000000000000000000000000000000..d391ed8612b5296843b7b0dfadf951a699c9dfa5
--- /dev/null
+++ b/core/general-client/README_CN.md
@@ -0,0 +1,33 @@
+# C++ Client for Paddle Serving
+
+(简体中文|[English](./README.md))
+
+## Requesting a BRPC server
+
+### Starting the server
+
+Taking the fit_a_line model as an example, the server is started with the same command as any regular BRPC server:
+
+```
+cd ../../python/examples/fit_a_line
+sh get_data.sh
+python -m paddle_serving_server.serve --model uci_housing_model --thread 10 --port 9393
+```
+
+### Client-side prediction
+
+The client currently supports BRPC only.
+The BRPC wrapper is already implemented; see [brpc_client.cpp](./src/brpc_client.cpp) for details.
+
+```
+./simple_client --client_conf="uci_housing_client/serving_client_conf.prototxt" --server_port="127.0.0.1:9393" --test_type="brpc" --sample_type="fit_a_line"
+```
+
+More examples can be found in [simple_client.cpp](./example/simple_client.cpp).
+
+| Argument      | Type | Default                          | Description                                  |
+| ------------- | ---- | -------------------------------- | -------------------------------------------- |
+| `client_conf` | str  | `"serving_client_conf.prototxt"` | Path of the client conf                      |
+| `server_port` | str  | `"127.0.0.1:9393"`               | Exposed ip:port of the server                |
+| `test_type`   | str  | `"brpc"`                         | Request mode; currently only `"brpc"`        |
+| `sample_type` | str  | `"fit_a_line"`                   | Sample type; one of `"fit_a_line"`, `"bert"` |
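+
+The C++ API follows the same flow as the command-line tool: load the client config, connect, fill a `PredictorInputs`, then call `predict`. A minimal sketch (names follow `client.h` and the fit_a_line example; the `run()` wrapper is hypothetical and error handling is trimmed):
+
+```cpp
+#include <string>
+#include <vector>
+
+#include "core/general-client/include/brpc_client.h"
+
+using baidu::paddle_serving::client::ServingBrpcClient;
+using baidu::paddle_serving::client::PredictorInputs;
+using baidu::paddle_serving::client::PredictorOutputs;
+
+int run() {
+  ServingBrpcClient client;
+  // init() loads the client config file(s) and connects to ip:port.
+  if (client.init({"uci_housing_client/serving_client_conf.prototxt"},
+                  "127.0.0.1:9393") != 0) {
+    return -1;
+  }
+  PredictorInputs inputs;
+  std::vector<float> x(13, 0.1f);              // one sample, 13 features
+  inputs.add_float_data(x, "x", {1, 13}, {});  // data, feed name, shape, lod
+  PredictorOutputs outputs;
+  std::vector<std::string> fetch_name = {"price"};
+  return client.predict(inputs, outputs, fetch_name, /*log_id=*/0);
+}
+```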
diff --git a/core/general-client/example/simple_client.cpp b/core/general-client/example/simple_client.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..e1052c346f66569d36a4e7cddbe73ca4f70cbd9e
--- /dev/null
+++ b/core/general-client/example/simple_client.cpp
@@ -0,0 +1,129 @@
+// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <memory>
+#include <string>
+#include <vector>
+
+#include "core/general-client/include/brpc_client.h"
+
+using baidu::paddle_serving::client::ServingClient;
+using baidu::paddle_serving::client::ServingBrpcClient;
+using baidu::paddle_serving::client::PredictorInputs;
+using baidu::paddle_serving::client::PredictorOutputs;
+
+DEFINE_string(server_port, "127.0.0.1:9292", "ip:port");
+DEFINE_string(client_conf, "serving_client_conf.prototxt", "Path of client conf");
+DEFINE_string(test_type, "brpc", "brpc");
+// fit_a_line, bert
+DEFINE_string(sample_type, "fit_a_line", "List: fit_a_line, bert");
+
+namespace {
+int prepare_fit_a_line(PredictorInputs& input, std::vector<std::string>& fetch_name) {
+  std::vector<float> float_feed = {0.0137f, -0.1136f, 0.2553f, -0.0692f,
+                                   0.0582f, -0.0727f, -0.1583f, -0.0584f,
+                                   0.6283f, 0.4919f, 0.1856f, 0.0795f, -0.0332f};
+  std::vector<int> float_shape = {1, 13};
+  std::string feed_name = "x";
+  fetch_name = {"price"};
+  std::vector<int> lod;
+  input.add_float_data(float_feed, feed_name, float_shape, lod);
+  return 0;
+}
+
+int prepare_bert(PredictorInputs& input, std::vector<std::string>& fetch_name) {
+  {
+    std::vector<float> float_feed(128, 0.0f);
+    float_feed[0] = 1.0f;
+    std::vector<int> float_shape = {1, 128, 1};
+    std::string feed_name = "input_mask";
+    std::vector<int> lod;
+    input.add_float_data(float_feed, feed_name, float_shape, lod);
+  }
+  {
+    std::vector<int64_t> feed(128, 0);
+    std::vector<int> shape = {1, 128, 1};
+    std::string feed_name = "position_ids";
+    std::vector<int> lod;
+    input.add_int64_data(feed, feed_name, shape, lod);
+  }
+  {
+    std::vector<int64_t> feed(128, 0);
+    feed[0] = 101;
+    std::vector<int> shape = {1, 128, 1};
+    std::string feed_name = "input_ids";
+    std::vector<int> lod;
+    input.add_int64_data(feed, feed_name, shape, lod);
+  }
+  {
+    std::vector<int64_t> feed(128, 0);
+    std::vector<int> shape = {1, 128, 1};
+    std::string feed_name = "segment_ids";
+    std::vector<int> lod;
+    input.add_int64_data(feed, feed_name, shape, lod);
+  }
+
+  fetch_name = {"pooled_output"};
+  return 0;
+}
+}  // namespace
+
+int main(int argc, char* argv[]) {
+  google::ParseCommandLineFlags(&argc, &argv, true);
+  std::string url = FLAGS_server_port;
+  std::string conf = FLAGS_client_conf;
+  std::string test_type = FLAGS_test_type;
+  std::string sample_type = FLAGS_sample_type;
+  LOG(INFO) << "url = " << url << "; "
+            << "client_conf = " << conf << "; "
+            << "test_type = " << test_type << "; "
+            << "sample_type = " << sample_type;
+  std::unique_ptr<ServingClient> client;
+  // default type is brpc
+  // grpc & http will be added in the future
+  if (test_type == "brpc") {
+    client.reset(new ServingBrpcClient());
+  } else {
+    client.reset(new ServingBrpcClient());
+  }
+  std::vector<std::string> confs;
+  confs.push_back(conf);
+  if (client->init(confs, url) != 0) {
+    LOG(ERROR) << "Failed to init client!";
+    return -1;
+  }
+
+  PredictorInputs input;
+  PredictorOutputs output;
+  std::vector<std::string> fetch_name;
+
+  if (sample_type == "fit_a_line") {
+    prepare_fit_a_line(input, fetch_name);
+  } else if (sample_type == "bert") {
+    prepare_bert(input, fetch_name);
+  } else {
+    prepare_fit_a_line(input, fetch_name);
+  }
+
+  if (client->predict(input, output, fetch_name, 0) != 0) {
+    LOG(ERROR) << "Failed to predict!";
+  } else {
+    LOG(INFO) << output.print();
+  }
+
+  return 0;
+}
diff --git a/core/general-client/include/brpc_client.h b/core/general-client/include/brpc_client.h
new file mode 100644
index 0000000000000000000000000000000000000000..05fc23e89950f92307fa4c887be82c5023ebc368
--- /dev/null
+++ b/core/general-client/include/brpc_client.h
@@ -0,0 +1,50 @@
+// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+#include "core/general-client/include/client.h"
+#include "core/sdk-cpp/include/predictor_sdk.h"
+
+using baidu::paddle_serving::sdk_cpp::Predictor;
+using baidu::paddle_serving::sdk_cpp::PredictorApi;
+
+namespace baidu {
+namespace paddle_serving {
+namespace client {
+
+class ServingBrpcClient : public ServingClient {
+ public:
+  ServingBrpcClient() {};
+
+  ~ServingBrpcClient() {};
+
+  virtual int connect(const std::string server_port);
+
+  int predict(const PredictorInputs& inputs,
+              PredictorOutputs& outputs,
+              const std::vector<std::string>& fetch_name,
+              const uint64_t log_id);
+
+ private:
+  // generate the default SDKConf
+  std::string gen_desc(const std::string server_port);
+
+ private:
+  PredictorApi _api;
+  Predictor* _predictor;
+};
+
+}  // namespace client
+}  // namespace paddle_serving
+}  // namespace baidu
\ No newline at end of file
diff --git a/core/general-client/include/client.h b/core/general-client/include/client.h
new file mode 100644
index 0000000000000000000000000000000000000000..689732c512fcb7612cbd3af025a470f4cbfc84fe
--- /dev/null
+++ b/core/general-client/include/client.h
@@ -0,0 +1,257 @@
+// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+#include <map>
+#include <memory>
+#include <sstream>
+#include <string>
+#include <vector>
+
+namespace baidu {
+namespace paddle_serving {
+namespace predictor {
+namespace general_model {
+class Request;
+class Response;
+}  // namespace general_model
+}  // namespace predictor
+namespace client {
+
+class PredictorInputs;
+class PredictorOutputs;
+
+class ServingClient {
+ public:
+  ServingClient() {};
+
+  virtual ~ServingClient() = default;
+
+  int init(const std::vector<std::string>& client_conf,
+           const std::string server_port);
+
+  int load_client_config(const std::vector<std::string>& client_conf);
+
+  virtual int connect(const std::string server_port) = 0;
+
+  virtual int predict(const PredictorInputs& inputs,
+                      PredictorOutputs& outputs,
+                      const std::vector<std::string>& fetch_name,
+                      const uint64_t log_id) = 0;
+
+ protected:
+  std::map<std::string, int> _feed_name_to_idx;
+  std::vector<std::string> _feed_name;
+  std::map<std::string, int> _fetch_name_to_idx;
+  std::map<std::string, std::string> _fetch_name_to_var_name;
+  std::map<std::string, int> _fetch_name_to_type;
+  std::vector<std::vector<int>> _shape;
+  std::vector<int> _type;
+  std::vector<int64_t> _last_request_ts;
+};
+
+class PredictorData {
+ public:
+  PredictorData() {};
+  virtual ~PredictorData() {};
+
+  void add_float_data(const std::vector<float>& data,
+                      const std::string& name,
+                      const std::vector<int>& shape,
+                      const std::vector<int>& lod,
+                      const int datatype = 1);
+
+  void add_int64_data(const std::vector<int64_t>& data,
+                      const std::string& name,
+                      const std::vector<int>& shape,
+                      const std::vector<int>& lod,
+                      const int datatype = 0);
+
+  void add_int32_data(const std::vector<int32_t>& data,
+                      const std::string& name,
+                      const std::vector<int>& shape,
+                      const std::vector<int>& lod,
+                      const int datatype = 2);
+
+  void add_string_data(const std::string& data,
+                       const std::string& name,
+                       const std::vector<int>& shape,
+                       const std::vector<int>& lod,
+                       const int datatype = 3);
+
+  const std::map<std::string, std::vector<float>>& float_data_map() const {
+    return _float_data_map;
+  };
+
+  std::map<std::string, std::vector<float>>* mutable_float_data_map() {
+    return &_float_data_map;
+  };
+
+  const std::map<std::string, std::vector<int64_t>>& int64_data_map() const {
+    return _int64_data_map;
+  };
+
+  std::map<std::string, std::vector<int64_t>>* mutable_int64_data_map() {
+    return &_int64_data_map;
+  };
+
+  const std::map<std::string, std::vector<int32_t>>& int_data_map() const {
+    return _int32_data_map;
+  };
+
+  std::map<std::string, std::vector<int32_t>>* mutable_int_data_map() {
+    return &_int32_data_map;
+  };
+
+  const std::map<std::string, std::string>& string_data_map() const {
+    return _string_data_map;
+  };
+
+  std::map<std::string, std::string>* mutable_string_data_map() {
+    return &_string_data_map;
+  };
+
+  const std::map<std::string, std::vector<int>>& shape_map() const {
+    return _shape_map;
+  };
+
+  std::map<std::string, std::vector<int>>* mutable_shape_map() {
+    return &_shape_map;
+  };
+
+  const std::map<std::string, std::vector<int>>& lod_map() const {
+    return _lod_map;
+  };
+
+  std::map<std::string, std::vector<int>>* mutable_lod_map() {
+    return &_lod_map;
+  };
+
+  int get_datatype(std::string name) const;
+
+  std::string print();
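+
+  // Typical write path (hypothetical values; see example/simple_client.cpp):
+  //   PredictorData d;
+  //   d.add_float_data({0.5f, 1.5f}, "x", /*shape=*/{1, 2}, /*lod=*/{});
+  //   d.print();  // -> "{x:0.500000,1.500000}"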
+
+ private:
+  // used to print a vector data map, e.g. _float_data_map
+  template <typename T1, typename T2>
+  std::string map2string(const std::map<T1, std::vector<T2>>& map) {
+    std::ostringstream oss;
+    oss.str("");
+    oss.precision(6);
+    oss.setf(std::ios::fixed);
+    std::string key_seg = ":";
+    std::string val_seg = ",";
+    std::string end_seg = "\n";
+    typename std::map<T1, std::vector<T2>>::const_iterator it = map.begin();
+    typename std::map<T1, std::vector<T2>>::const_iterator itEnd = map.end();
+    for (; it != itEnd; it++) {
+      oss << "{";
+      oss << it->first << key_seg;
+      const std::vector<T2>& v = it->second;
+      for (size_t i = 0; i < v.size(); ++i) {
+        if (i != v.size() - 1) {
+          oss << v[i] << val_seg;
+        } else {
+          oss << v[i];
+        }
+      }
+      oss << "}";
+    }
+    return oss.str();
+  };
+
+  // used to print a data map without vectors, e.g. _string_data_map
+  template <typename T1, typename T2>
+  std::string map2string(const std::map<T1, T2>& map) {
+    std::ostringstream oss;
+    oss.str("");
+    std::string key_seg = ":";
+    std::string val_seg = ",";
+    std::string end_seg = "\n";
+    typename std::map<T1, T2>::const_iterator it = map.begin();
+    typename std::map<T1, T2>::const_iterator itEnd = map.end();
+    for (; it != itEnd; it++) {
+      oss << "{";
+      oss << it->first << key_seg << it->second;
+      oss << "}";
+    }
+    return oss.str();
+  };
+
+ protected:
+  std::map<std::string, std::vector<float>> _float_data_map;
+  std::map<std::string, std::vector<int64_t>> _int64_data_map;
+  std::map<std::string, std::vector<int32_t>> _int32_data_map;
+  std::map<std::string, std::string> _string_data_map;
+  std::map<std::string, std::vector<int>> _shape_map;
+  std::map<std::string, std::vector<int>> _lod_map;
+  std::map<std::string, int> _datatype_map;
+};
+
+class PredictorInputs : public PredictorData {
+ public:
+  PredictorInputs() {};
+  virtual ~PredictorInputs() {};
+
+  // generate proto from inputs
+  // feed_name_to_idx: mapping from alias name to idx
+  // feed_name: mapping from idx to name
+  static int GenProto(const PredictorInputs& inputs,
+                      const std::map<std::string, int>& feed_name_to_idx,
+                      const std::vector<std::string>& feed_name,
+                      predictor::general_model::Request& req);
+};
+
+class PredictorOutputs {
+ public:
+  struct PredictorOutput {
+    std::string engine_name;
+    PredictorData data;
+  };
+
+  PredictorOutputs() {};
+  virtual ~PredictorOutputs() {};
+
+  const std::vector<std::shared_ptr<PredictorOutputs::PredictorOutput>>& datas() {
+    return _datas;
+  };
+
+  std::vector<std::shared_ptr<PredictorOutputs::PredictorOutput>>* mutable_datas() {
+    return &_datas;
+  };
+
+  void add_data(const std::shared_ptr<PredictorOutputs::PredictorOutput>& data) {
+    _datas.push_back(data);
+  };
+
+  std::string print();
+
+  void clear();
+
+  // parse proto to outputs
+  // fetch_name: names of the variables to fetch
+  // fetch_name_to_type: mapping from fetch_name to datatype
+  static int ParseProto(const predictor::general_model::Response& res,
+                        const std::vector<std::string>& fetch_name,
+                        std::map<std::string, int>& fetch_name_to_type,
+                        PredictorOutputs& outputs);
+
+ protected:
+  std::vector<std::shared_ptr<PredictorOutputs::PredictorOutput>> _datas;
+};
+
+}  // namespace client
+}  // namespace paddle_serving
+}  // namespace baidu
\ No newline at end of file
diff --git a/core/general-client/src/brpc_client.cpp b/core/general-client/src/brpc_client.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..ad459604f3866b6aa1088e22de599071ebbae665
--- /dev/null
+++ b/core/general-client/src/brpc_client.cpp
@@ -0,0 +1,172 @@
+// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "core/general-client/include/brpc_client.h"
+#include "core/sdk-cpp/include/common.h"
+#include "core/util/include/timer.h"
+#include "core/sdk-cpp/builtin_format.pb.h"
+#include "core/sdk-cpp/general_model_service.pb.h"
+
+DEFINE_bool(profile_client, false, "");
+DEFINE_bool(profile_server, false, "");
+
+#define BRPC_MAX_BODY_SIZE 512 * 1024 * 1024
+
+namespace baidu {
+namespace paddle_serving {
+namespace client {
+
+using baidu::paddle_serving::Timer;
+using baidu::paddle_serving::predictor::general_model::Request;
+using baidu::paddle_serving::predictor::general_model::Response;
+using baidu::paddle_serving::predictor::general_model::Tensor;
+
+using configure::SDKConf;
+using configure::VariantConf;
+using configure::Predictor;
+
+int ServingBrpcClient::connect(const std::string server_port) {
+  brpc::fLU64::FLAGS_max_body_size = BRPC_MAX_BODY_SIZE;
+  if (_api.create(gen_desc(server_port)) != 0) {
+    LOG(ERROR) << "Predictor Creation Failed";
+    return -1;
+  }
+  // _api.thrd_initialize();
+  return 0;
+}
+
+std::string ServingBrpcClient::gen_desc(const std::string server_port) {
+  // default config for brpc
+  SDKConf sdk_conf;
+
+  Predictor* predictor = sdk_conf.add_predictors();
+  predictor->set_name("general_model");
+  predictor->set_service_name("baidu.paddle_serving.predictor.general_model.GeneralModelService");
+  predictor->set_endpoint_router("WeightedRandomRender");
+  predictor->mutable_weighted_random_render_conf()->set_variant_weight_list("100");
+  VariantConf* predictor_var = predictor->add_variants();
+  predictor_var->set_tag("default_tag_1");
+  std::string cluster = "list://" + server_port;
+  predictor_var->mutable_naming_conf()->set_cluster(cluster);
+
+  VariantConf* var = sdk_conf.mutable_default_variant_conf();
+  var->set_tag("default");
+  var->mutable_connection_conf()->set_connect_timeout_ms(2000);
+  var->mutable_connection_conf()->set_rpc_timeout_ms(200000);
+  var->mutable_connection_conf()->set_connect_retry_count(2);
+  var->mutable_connection_conf()->set_max_connection_per_host(100);
+  var->mutable_connection_conf()->set_hedge_request_timeout_ms(-1);
+  var->mutable_connection_conf()->set_hedge_fetch_retry_count(2);
+  var->mutable_connection_conf()->set_connection_type("pooled");
+
+  var->mutable_naming_conf()->set_cluster_filter_strategy("Default");
+  var->mutable_naming_conf()->set_load_balance_strategy("la");
+
+  var->mutable_rpc_parameter()->set_compress_type(0);
+  var->mutable_rpc_parameter()->set_package_size(20);
+  var->mutable_rpc_parameter()->set_protocol("baidu_std");
+  var->mutable_rpc_parameter()->set_max_channel_per_request(3);
+
+  return sdk_conf.SerializePartialAsString();
+}
+
+int ServingBrpcClient::predict(const PredictorInputs& inputs,
+                               PredictorOutputs& outputs,
+                               const std::vector<std::string>& fetch_name,
+                               const uint64_t log_id) {
+  Timer timeline;
+  int64_t preprocess_start = timeline.TimeStampUS();
+  // thread initialize for StubTLS
+  _api.thrd_initialize();
+  std::string variant_tag;
+  // the predictor is bound to the request via brpc::Controller
+  _predictor = _api.fetch_predictor("general_model", &variant_tag);
+  if (_predictor == NULL) {
+    LOG(ERROR) << "Failed to fetch predictor; cannot predict!";
+    return -1;
+  }
+  // predict_res_batch.set_variant_tag(variant_tag);
+  VLOG(2) << "fetch general model predictor done.";
+  VLOG(2) << "variant_tag:" << variant_tag;
+  VLOG(2) << "max body size : " << brpc::fLU64::FLAGS_max_body_size;
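+  // Build the request: log_id, the fetch variable names, and one Tensor per
+  // feed; the tensors are filled in by PredictorInputs::GenProto below.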
+  Request req;
+  req.set_log_id(log_id);
+  for (auto& name : fetch_name) {
+    req.add_fetch_var_names(name);
+  }
+
+  if (PredictorInputs::GenProto(inputs, _feed_name_to_idx, _feed_name, req) != 0) {
+    LOG(ERROR) << "Failed to preprocess req!";
+    return -1;
+  }
+
+  int64_t preprocess_end = timeline.TimeStampUS();
+  int64_t client_infer_start = timeline.TimeStampUS();
+  Response res;
+
+  int64_t client_infer_end = 0;
+  int64_t postprocess_start = 0;
+  int64_t postprocess_end = 0;
+
+  if (FLAGS_profile_client) {
+    if (FLAGS_profile_server) {
+      req.set_profile_server(true);
+    }
+  }
+
+  res.Clear();
+  if (_predictor->inference(&req, &res) != 0) {
+    LOG(ERROR) << "failed to call predictor with req: " << req.ShortDebugString();
+    return -1;
+  }
+
+  client_infer_end = timeline.TimeStampUS();
+  postprocess_start = client_infer_end;
+  if (PredictorOutputs::ParseProto(res, fetch_name, _fetch_name_to_type, outputs) != 0) {
+    LOG(ERROR) << "Failed to post_process res!";
+    return -1;
+  }
+  postprocess_end = timeline.TimeStampUS();
+
+  if (FLAGS_profile_client) {
+    std::ostringstream oss;
+    oss << "PROFILE\t"
+        << "pid:" << getpid() << "\t"
+        << "prepro_0:" << preprocess_start << " "
+        << "prepro_1:" << preprocess_end << " "
+        << "client_infer_0:" << client_infer_start << " "
+        << "client_infer_1:" << client_infer_end << " ";
+    if (FLAGS_profile_server) {
+      int op_num = res.profile_time_size() / 2;
+      for (int i = 0; i < op_num; ++i) {
+        oss << "op" << i << "_0:" << res.profile_time(i * 2) << " ";
+        oss << "op" << i << "_1:" << res.profile_time(i * 2 + 1) << " ";
+      }
+    }
+
+    oss << "postpro_0:" << postprocess_start << " ";
+    oss << "postpro_1:" << postprocess_end;
+
+    fprintf(stderr, "%s\n", oss.str().c_str());
+  }
+
+  // release predictor
+  _api.thrd_clear();
+
+  return 0;
+}
+
+}  // namespace client
+}  // namespace paddle_serving
+}  // namespace baidu
diff --git a/core/general-client/src/client.cpp b/core/general-client/src/client.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..56fb1cd1d53ba04d9d071e778594635e5e3cba6d
--- /dev/null
+++ b/core/general-client/src/client.cpp
@@ -0,0 +1,416 @@
+// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "core/general-client/include/client.h"
+#include "core/sdk-cpp/include/common.h"
+#include "core/sdk-cpp/general_model_service.pb.h"
+
+namespace baidu {
+namespace paddle_serving {
+namespace client {
+using configure::GeneralModelConfig;
+using baidu::paddle_serving::predictor::general_model::Request;
+using baidu::paddle_serving::predictor::general_model::Response;
+using baidu::paddle_serving::predictor::general_model::Tensor;
+enum ProtoDataType { P_INT64, P_FLOAT32, P_INT32, P_STRING };
+
+int ServingClient::init(const std::vector<std::string>& client_conf,
+                        const std::string server_port) {
+  if (load_client_config(client_conf) != 0) {
+    LOG(ERROR) << "Failed to load client config";
+    return -1;
+  }
+
+  // pure virtual func, implemented by the subclass
+  if (connect(server_port) != 0) {
+    LOG(ERROR) << "Failed to connect";
+    return -1;
+  }
+
+  return 0;
+}
+
+int ServingClient::load_client_config(const std::vector<std::string>& conf_file) {
+  try {
+    GeneralModelConfig model_config;
+    if (configure::read_proto_conf(conf_file[0].c_str(), &model_config) != 0) {
+      LOG(ERROR) << "Failed to load general model config"
+                 << ", file path: " << conf_file[0];
+      return -1;
+    }
+
+    _feed_name_to_idx.clear();
+    _fetch_name_to_idx.clear();
+    _shape.clear();
+    int feed_var_num = model_config.feed_var_size();
+    _feed_name.clear();
+    VLOG(2) << "feed var num: " << feed_var_num;
+    for (int i = 0; i < feed_var_num; ++i) {
+      _feed_name_to_idx[model_config.feed_var(i).alias_name()] = i;
+      VLOG(2) << "feed [" << i << "]"
+              << " name: " << model_config.feed_var(i).name();
+      _feed_name.push_back(model_config.feed_var(i).name());
+      VLOG(2) << "feed alias name: " << model_config.feed_var(i).alias_name()
+              << " index: " << i;
+      std::vector<int> tmp_feed_shape;
+      VLOG(2) << "feed"
+              << "[" << i << "] shape:";
+      for (int j = 0; j < model_config.feed_var(i).shape_size(); ++j) {
+        tmp_feed_shape.push_back(model_config.feed_var(i).shape(j));
+        VLOG(2) << "shape[" << j << "]: " << model_config.feed_var(i).shape(j);
+      }
+      _type.push_back(model_config.feed_var(i).feed_type());
+      VLOG(2) << "feed"
+              << "[" << i
+              << "] feed type: " << model_config.feed_var(i).feed_type();
+      _shape.push_back(tmp_feed_shape);
+    }
+
+    if (conf_file.size() > 1) {
+      model_config.Clear();
+      if (configure::read_proto_conf(conf_file[conf_file.size() - 1].c_str(),
+                                     &model_config) != 0) {
+        LOG(ERROR) << "Failed to load general model config"
+                   << ", file path: " << conf_file[conf_file.size() - 1];
+        return -1;
+      }
+    }
+    int fetch_var_num = model_config.fetch_var_size();
+    VLOG(2) << "fetch_var_num: " << fetch_var_num;
+    for (int i = 0; i < fetch_var_num; ++i) {
+      _fetch_name_to_idx[model_config.fetch_var(i).alias_name()] = i;
+      VLOG(2) << "fetch [" << i << "]"
+              << " alias name: " << model_config.fetch_var(i).alias_name();
+      _fetch_name_to_var_name[model_config.fetch_var(i).alias_name()] =
+          model_config.fetch_var(i).name();
+      _fetch_name_to_type[model_config.fetch_var(i).alias_name()] =
+          model_config.fetch_var(i).fetch_type();
+    }
+  } catch (std::exception& e) {
+    LOG(ERROR) << "Failed to load general model config: " << e.what();
+    return -1;
+  }
+  return 0;
+}
+
+void PredictorData::add_float_data(const std::vector<float>& data,
+                                   const std::string& name,
+                                   const std::vector<int>& shape,
+                                   const std::vector<int>& lod,
+                                   const int datatype) {
+  _float_data_map[name] = data;
+  _shape_map[name] = shape;
+  _lod_map[name] = lod;
+  _datatype_map[name] = datatype;
+}
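+
+// Note: the datatype defaults in client.h (int64 = 0, float = 1, int32 = 2,
+// string = 3) mirror the ProtoDataType enum above.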
+
+void PredictorData::add_int64_data(const std::vector<int64_t>& data,
+                                   const std::string& name,
+                                   const std::vector<int>& shape,
+                                   const std::vector<int>& lod,
+                                   const int datatype) {
+  _int64_data_map[name] = data;
+  _shape_map[name] = shape;
+  _lod_map[name] = lod;
+  _datatype_map[name] = datatype;
+}
+
+void PredictorData::add_int32_data(const std::vector<int32_t>& data,
+                                   const std::string& name,
+                                   const std::vector<int>& shape,
+                                   const std::vector<int>& lod,
+                                   const int datatype) {
+  _int32_data_map[name] = data;
+  _shape_map[name] = shape;
+  _lod_map[name] = lod;
+  _datatype_map[name] = datatype;
+}
+
+void PredictorData::add_string_data(const std::string& data,
+                                    const std::string& name,
+                                    const std::vector<int>& shape,
+                                    const std::vector<int>& lod,
+                                    const int datatype) {
+  _string_data_map[name] = data;
+  _shape_map[name] = shape;
+  _lod_map[name] = lod;
+  _datatype_map[name] = datatype;
+}
+
+int PredictorData::get_datatype(std::string name) const {
+  std::map<std::string, int>::const_iterator it = _datatype_map.find(name);
+  if (it != _datatype_map.end()) {
+    return it->second;
+  }
+  return 0;
+}
+
+std::string PredictorData::print() {
+  std::string res;
+  res.append(map2string(_float_data_map));
+  res.append(map2string(_int64_data_map));
+  res.append(map2string(_int32_data_map));
+  res.append(map2string(_string_data_map));
+  return res;
+}
+
+int PredictorInputs::GenProto(const PredictorInputs& inputs,
+                              const std::map<std::string, int>& feed_name_to_idx,
+                              const std::vector<std::string>& feed_name,
+                              Request& req) {
+  const std::map<std::string, std::vector<float>>& float_feed_map = inputs.float_data_map();
+  const std::map<std::string, std::vector<int64_t>>& int64_feed_map = inputs.int64_data_map();
+  const std::map<std::string, std::vector<int32_t>>& int32_feed_map = inputs.int_data_map();
+  const std::map<std::string, std::string>& string_feed_map = inputs.string_data_map();
+  const std::map<std::string, std::vector<int>>& shape_map = inputs.shape_map();
+  const std::map<std::string, std::vector<int>>& lod_map = inputs.lod_map();
+
+  VLOG(2) << "float feed name size: " << float_feed_map.size();
+  VLOG(2) << "int feed name size: " << int64_feed_map.size();
+  VLOG(2) << "string feed name size: " << string_feed_map.size();
+
+  // batch is already in Tensor.
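+  // Each loop below follows the same pattern: resolve the feed index from
+  // the alias name, append a Tensor to the request, copy shape/lod, set the
+  // element type, and copy the raw payload into the matching repeated field.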
+
+  for (std::map<std::string, std::vector<float>>::const_iterator iter = float_feed_map.begin();
+       iter != float_feed_map.end();
+       ++iter) {
+    std::string name = iter->first;
+    const std::vector<float>& float_data = iter->second;
+    const std::vector<int>& float_shape = shape_map.at(name);
+    const std::vector<int>& float_lod = lod_map.at(name);
+    // default datatype = P_FLOAT32
+    int datatype = inputs.get_datatype(name);
+    std::map<std::string, int>::const_iterator feed_name_it = feed_name_to_idx.find(name);
+    if (feed_name_it == feed_name_to_idx.end()) {
+      LOG(ERROR) << "Could not find [" << name << "] in feed_map!";
+      return -1;
+    }
+    int idx = feed_name_to_idx.at(name);
+    VLOG(2) << "prepare float feed " << name << " idx " << idx;
+    int total_number = float_data.size();
+    Tensor* tensor = req.add_tensor();
+
+    VLOG(2) << "prepare float feed " << name << " shape size "
+            << float_shape.size();
+    for (uint32_t j = 0; j < float_shape.size(); ++j) {
+      tensor->add_shape(float_shape[j]);
+    }
+    for (uint32_t j = 0; j < float_lod.size(); ++j) {
+      tensor->add_lod(float_lod[j]);
+    }
+    tensor->set_elem_type(datatype);
+
+    tensor->set_name(feed_name[idx]);
+    tensor->set_alias_name(name);
+
+    tensor->mutable_float_data()->Resize(total_number, 0);
+    memcpy(tensor->mutable_float_data()->mutable_data(), float_data.data(),
+           total_number * sizeof(float));
+  }
+
+  for (std::map<std::string, std::vector<int64_t>>::const_iterator iter = int64_feed_map.begin();
+       iter != int64_feed_map.end();
+       ++iter) {
+    std::string name = iter->first;
+    const std::vector<int64_t>& int64_data = iter->second;
+    const std::vector<int>& int64_shape = shape_map.at(name);
+    const std::vector<int>& int64_lod = lod_map.at(name);
+    // default datatype = P_INT64
+    int datatype = inputs.get_datatype(name);
+    std::map<std::string, int>::const_iterator feed_name_it = feed_name_to_idx.find(name);
+    if (feed_name_it == feed_name_to_idx.end()) {
+      LOG(ERROR) << "Could not find [" << name << "] in feed_map!";
+      return -1;
+    }
+    int idx = feed_name_to_idx.at(name);
+    Tensor* tensor = req.add_tensor();
+    int total_number = int64_data.size();
+
+    for (uint32_t j = 0; j < int64_shape.size(); ++j) {
+      tensor->add_shape(int64_shape[j]);
+    }
+    for (uint32_t j = 0; j < int64_lod.size(); ++j) {
+      tensor->add_lod(int64_lod[j]);
+    }
+    tensor->set_elem_type(datatype);
+    tensor->set_name(feed_name[idx]);
+    tensor->set_alias_name(name);
+
+    tensor->mutable_int64_data()->Resize(total_number, 0);
+    memcpy(tensor->mutable_int64_data()->mutable_data(), int64_data.data(),
+           total_number * sizeof(int64_t));
+  }
+
+  for (std::map<std::string, std::vector<int32_t>>::const_iterator iter = int32_feed_map.begin();
+       iter != int32_feed_map.end();
+       ++iter) {
+    std::string name = iter->first;
+    const std::vector<int32_t>& int32_data = iter->second;
+    const std::vector<int>& int32_shape = shape_map.at(name);
+    const std::vector<int>& int32_lod = lod_map.at(name);
+    // default datatype = P_INT32
+    int datatype = inputs.get_datatype(name);
+    std::map<std::string, int>::const_iterator feed_name_it = feed_name_to_idx.find(name);
+    if (feed_name_it == feed_name_to_idx.end()) {
+      LOG(ERROR) << "Could not find [" << name << "] in feed_map!";
+      return -1;
+    }
+    int idx = feed_name_to_idx.at(name);
+    Tensor* tensor = req.add_tensor();
+    int total_number = int32_data.size();
+
+    for (uint32_t j = 0; j < int32_shape.size(); ++j) {
+      tensor->add_shape(int32_shape[j]);
+    }
+    for (uint32_t j = 0; j < int32_lod.size(); ++j) {
+      tensor->add_lod(int32_lod[j]);
+    }
+    tensor->set_elem_type(datatype);
+    tensor->set_name(feed_name[idx]);
+    tensor->set_alias_name(name);
+
+    tensor->mutable_int_data()->Resize(total_number, 0);
+    memcpy(tensor->mutable_int_data()->mutable_data(), int32_data.data(),
+           total_number * sizeof(int32_t));
+  }
+
+  for (std::map<std::string, std::string>::const_iterator iter = string_feed_map.begin();
+       iter != string_feed_map.end();
+       ++iter) {
+    std::string name = iter->first;
+    const std::string& string_data = iter->second;
+    const std::vector<int>& string_shape = shape_map.at(name);
+    const std::vector<int>& string_lod = lod_map.at(name);
+    // default datatype = P_STRING
+    int datatype = inputs.get_datatype(name);
+    std::map<std::string, int>::const_iterator feed_name_it = feed_name_to_idx.find(name);
+    if (feed_name_it == feed_name_to_idx.end()) {
+      LOG(ERROR) << "Could not find [" << name << "] in feed_map!";
+      return -1;
+    }
+    int idx = feed_name_to_idx.at(name);
+    Tensor* tensor = req.add_tensor();
+
+    for (uint32_t j = 0; j < string_shape.size(); ++j) {
+      tensor->add_shape(string_shape[j]);
+    }
+    for (uint32_t j = 0; j < string_lod.size(); ++j) {
+      tensor->add_lod(string_lod[j]);
+    }
+    tensor->set_elem_type(datatype);
+    tensor->set_name(feed_name[idx]);
+    tensor->set_alias_name(name);
+
+    const int string_shape_size = string_shape.size();
+    // string_shape[vec_idx] = [1]; numpy has no string datatype, so strings
+    // are passed one per tensor via std::vector<std::string>.
+    if (string_shape_size != 1) {
+      LOG(ERROR) << "string_shape_size should be 1-D, but received is : "
+                 << string_shape_size;
+      return -1;
+    }
+    switch (string_shape_size) {
+      case 1: {
+        tensor->add_data(string_data);
+        break;
+      }
+    }
+  }
+  return 0;
+}
+
+std::string PredictorOutputs::print() {
+  std::string res = "";
+  for (size_t i = 0; i < _datas.size(); ++i) {
+    res.append(_datas[i]->engine_name);
+    res.append(":");
+    res.append(_datas[i]->data.print());
+    res.append("\n");
+  }
+  return res;
+}
+
+void PredictorOutputs::clear() {
+  _datas.clear();
+}
+
+int PredictorOutputs::ParseProto(const Response& res,
+                                 const std::vector<std::string>& fetch_name,
+                                 std::map<std::string, int>& fetch_name_to_type,
+                                 PredictorOutputs& outputs) {
+  VLOG(2) << "get model output num";
+  uint32_t model_num = res.outputs_size();
+  VLOG(2) << "model num: " << model_num;
+  for (uint32_t m_idx = 0; m_idx < model_num; ++m_idx) {
+    VLOG(2) << "process model output index: " << m_idx;
+    auto& output = res.outputs(m_idx);
+    std::shared_ptr<PredictorOutputs::PredictorOutput> predictor_output =
+        std::make_shared<PredictorOutputs::PredictorOutput>();
+    predictor_output->engine_name = output.engine_name();
+    std::map<std::string, std::vector<float>>& float_data_map = *predictor_output->data.mutable_float_data_map();
+    std::map<std::string, std::vector<int64_t>>& int64_data_map = *predictor_output->data.mutable_int64_data_map();
+    std::map<std::string, std::vector<int32_t>>& int32_data_map = *predictor_output->data.mutable_int_data_map();
+    std::map<std::string, std::string>& string_data_map = *predictor_output->data.mutable_string_data_map();
+    std::map<std::string, std::vector<int>>& shape_map = *predictor_output->data.mutable_shape_map();
+    std::map<std::string, std::vector<int>>& lod_map = *predictor_output->data.mutable_lod_map();
+
+    int idx = 0;
+    for (auto& name : fetch_name) {
+      // int idx = _fetch_name_to_idx[name];
+      int shape_size = output.tensor(idx).shape_size();
+      VLOG(2) << "fetch var " << name << " index " << idx << " shape size "
+              << shape_size;
+      shape_map[name].resize(shape_size);
+      for (int i = 0; i < shape_size; ++i) {
+        shape_map[name][i] = output.tensor(idx).shape(i);
+      }
+      int lod_size = output.tensor(idx).lod_size();
+      if (lod_size > 0) {
+        lod_map[name].resize(lod_size);
+        for (int i = 0; i < lod_size; ++i) {
+          lod_map[name][i] = output.tensor(idx).lod(i);
+        }
+      }
+      idx += 1;
+    }
+    idx = 0;
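+
+    // Second pass: copy each tensor's payload into the map matching its
+    // declared fetch type (P_INT64, P_FLOAT32, or P_INT32).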
+    for (auto& name : fetch_name) {
+      // int idx = _fetch_name_to_idx[name];
+      if (fetch_name_to_type[name] == P_INT64) {
+        VLOG(2) << "fetch var " << name << " type int64";
+        int size = output.tensor(idx).int64_data_size();
+        int64_data_map[name] = std::vector<int64_t>(
+            output.tensor(idx).int64_data().begin(),
+            output.tensor(idx).int64_data().begin() + size);
+      } else if (fetch_name_to_type[name] == P_FLOAT32) {
+        VLOG(2) << "fetch var " << name << " type float";
+        int size = output.tensor(idx).float_data_size();
+        float_data_map[name] = std::vector<float>(
+            output.tensor(idx).float_data().begin(),
+            output.tensor(idx).float_data().begin() + size);
+      } else if (fetch_name_to_type[name] == P_INT32) {
+        VLOG(2) << "fetch var " << name << " type int32";
+        int size = output.tensor(idx).int_data_size();
+        int32_data_map[name] = std::vector<int32_t>(
+            output.tensor(idx).int_data().begin(),
+            output.tensor(idx).int_data().begin() + size);
+      }
+      idx += 1;
+    }
+    outputs.add_data(predictor_output);
+  }
+  return 0;
+}
+
+}  // namespace client
+}  // namespace paddle_serving
+}  // namespace baidu