diff --git a/core/cube/cube-api/src/cube_cli.cpp b/core/cube/cube-api/src/cube_cli.cpp
index 6bbbd435db0448bb6a96fe5c160f292992d813e9..7f45e436a821af3db06d4a90fe4cc8cc4a4936b6 100644
--- a/core/cube/cube-api/src/cube_cli.cpp
+++ b/core/cube/cube-api/src/cube_cli.cpp
@@ -104,7 +104,7 @@ int run(int argc, char** argv, int thread_id) {
     keys.push_back(key_list[index]);
     index += 1;
     int ret = 0;
-    if (keys.size() > FLAGS_batch) {
+    if (keys.size() >= FLAGS_batch) {
       TIME_FLAG(seek_start);
       ret = cube->seek(FLAGS_dict, keys, &values);
       TIME_FLAG(seek_end);
@@ -214,8 +214,8 @@ int run_m(int argc, char** argv) {
             << " avg = " << std::to_string(mean_time)
             << " max = " << std::to_string(max_time)
             << " min = " << std::to_string(min_time);
-  LOG(INFO) << " total_request = " << std::to_string(request_num)
-            << " speed = " << std::to_string(1000000 * thread_num / mean_time)
+  LOG(INFO) << " total_request = " << std::to_string(request_num) << " speed = "
+            << std::to_string(1000000 * thread_num / mean_time)  // mean_time us
             << " query per second";
 }
diff --git a/core/cube/doc/performance.md b/core/cube/doc/performance.md
index 5fa297772c5fb3a0668e4a2a2f720b6f0a079ff7..c61e9eaad51be3d6e45508c6370616a146ea2fe8 100644
--- a/core/cube/doc/performance.md
+++ b/core/cube/doc/performance.md
@@ -2,6 +2,9 @@
 ## Machine configuration
 Intel(R) Xeon(R) CPU E5-2650 v4 @ 2.20GHz

+## Test method
+See the [stress test documentation](./press.md).
+
 ## Test data
 1 million sample key-value records. Keys are uint64; each value is 40 bytes (in real scenarios this typically corresponds to a 10-dimensional feature vector).
diff --git a/core/cube/doc/press.md b/core/cube/doc/press.md
new file mode 100644
index 0000000000000000000000000000000000000000..25c326f944cc936ea3a4b0695dee57ca24518837
--- /dev/null
+++ b/core/cube/doc/press.md
@@ -0,0 +1,76 @@
+# Cube stress test guide
+
+Deploy cube by following the [deployment and usage of the large-scale sparse parameter service Cube](https://github.com/PaddlePaddle/Serving/blob/master/doc/DEPLOY.md#2-大规模稀疏参数服务cube的部署和使用) document.
+
+Stress test toolkit:
+
+https://paddle-serving.bj.bcebos.com/data/cube/cube-press.tar.gz
+
+After extracting the archive, the cube-press directory contains the scripts, executables, configuration files, sample data, and data generation scripts used for client-side stress testing in both the single-machine and distributed scenarios.
+
+Among these, keys is the file of key values the client reads and queries with, and feature is the key-value file loaded by cube. In the data used for this test, keys range from 0 to 999999.
+
+## Single-machine scenario
+
+Deploy the cube service on a single physical machine, generate test data with the genernate_input.py script, and run the test.sh script to start the cube client and send requests to the cube server.
+
+genernate_input.py takes 1 argument, for example:
+
+```bash
+python genernate_input.py 1
+```
+
+The argument is the number of keys per line of generated data, i.e. the batch_size of the queries issued by test.sh.
+
+
+test.sh takes 3 arguments, for example:
+
+```bash
+sh test.sh 1 127.0.0.1:8027 100000
+```
+
+The first argument is the concurrency, the second is the IP and port of the cube server, and the third is the QPS.
+
+Output:
+
+The script runs 9 rounds of stress testing. Each round sends 10 requests, each lasting 1 second, and prints the average latency as well as the latency at several percentiles.
+
+**Notes:**
+
+Cube stress testing puts heavy demands on the machine's network card. Under high QPS a single client may not keep up; in that case use two or more clients and spread the queries evenly across them.
+
+If test.sh runs into problems and needs to be stopped, run kill_rpc_press.sh.
+
+## Distributed scenario
+
+After building Paddle Serving, the client of the distributed stress test tool is located at build/core/cube/cube-api/cube-cli, with its source code at core/cube/cube-api/src/cube_cli.cpp.
+
+Deploy the cube service on multiple machines and use cube-cli to run the performance test.
+
+**Notes:**
+
+The number of shards and replicas used when deploying the cube service affects performance. For the same data, more shards and replicas give better performance; the actual gain depends on the data.
+
+Usage:
+
+```shell
+./cube-cli --batch 500 --keys keys --dict dict --thread_num 1
+```
+
+Accepted arguments:
+
+--batch sets the batch size of each request.
+
+--keys sets the query file, with one key per line.
+
+--dict sets the name of the cube dictionary to query.
+
+--thread_num sets the number of client threads.
+
+Output:
+
+The average, maximum, and minimum query time of each thread.
+
+The mean, maximum, and minimum of the per-thread average query times across all threads in the process.
+
+The total number of requests and the QPS across all threads in the process.
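A note on the keys file used above: press.md only states that it holds one key per line over the 0~999999 range, so the sketch below is just an illustrative way to produce such a file (the script name and output path are made up and are not part of the cube-press toolkit):

```python
# make_keys.py -- illustrative helper; the real toolkit ships its own data scripts.
# Writes one integer key per line, covering the 0~999999 range described in press.md.

def write_keys(path="keys", num_keys=1000000):
    with open(path, "w") as f:
        for k in range(num_keys):
            f.write("{}\n".format(k))


if __name__ == "__main__":
    write_keys()
```

The resulting file can be handed to cube-cli through --keys, or split across several clients when a single machine's network card becomes the bottleneck.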
diff --git a/core/general-client/include/general_model.h b/core/general-client/include/general_model.h
index cec57bdd9e5ae955586693a2ad5a5143e5e4b74b..3567fbdaef75adf6dbf759056c6b4c6d062d1ca9 100644
--- a/core/general-client/include/general_model.h
+++ b/core/general-client/include/general_model.h
@@ -17,10 +17,11 @@

 #include
 #include
+#include
 #include
+#include
 #include
 #include
-#include

 #include "core/sdk-cpp/builtin_format.pb.h"
 #include "core/sdk-cpp/general_model_service.pb.h"
@@ -37,46 +38,51 @@ namespace general_model {

 typedef std::map<std::string, std::vector<float>> FetchedMap;

-typedef std::map<std::string, std::vector<std::vector<float> > >
-    BatchFetchedMap;
+typedef std::map<std::string, std::vector<std::vector<float>>> BatchFetchedMap;

 class PredictorClient {
 public:
  PredictorClient() {}
  ~PredictorClient() {}

-  void init(const std::string & client_conf);
+  void init(const std::string& client_conf);

-  void set_predictor_conf(
-      const std::string& conf_path,
-      const std::string& conf_file);
+  void set_predictor_conf(const std::string& conf_path,
+                          const std::string& conf_file);

  int create_predictor();

-  std::vector<std::vector<float> > predict(
-      const std::vector<std::vector<float> > & float_feed,
-      const std::vector<std::string> & float_feed_name,
-      const std::vector<std::vector<int64_t> > & int_feed,
-      const std::vector<std::string> & int_feed_name,
-      const std::vector<std::string> & fetch_name);
-
-  std::vector<std::vector<float> > predict_with_profile(
-      const std::vector<std::vector<float> > & float_feed,
-      const std::vector<std::string> & float_feed_name,
-      const std::vector<std::vector<int64_t> > & int_feed,
-      const std::vector<std::string> & int_feed_name,
-      const std::vector<std::string> & fetch_name);
+  std::vector<std::vector<float>> predict(
+      const std::vector<std::vector<float>>& float_feed,
+      const std::vector<std::string>& float_feed_name,
+      const std::vector<std::vector<int64_t>>& int_feed,
+      const std::vector<std::string>& int_feed_name,
+      const std::vector<std::string>& fetch_name);
+
+  std::vector<std::vector<std::vector<float>>> batch_predict(
+      const std::vector<std::vector<std::vector<float>>>& float_feed_batch,
+      const std::vector<std::string>& float_feed_name,
+      const std::vector<std::vector<std::vector<int64_t>>>& int_feed_batch,
+      const std::vector<std::string>& int_feed_name,
+      const std::vector<std::string>& fetch_name);
+
+  std::vector<std::vector<float>> predict_with_profile(
+      const std::vector<std::vector<float>>& float_feed,
+      const std::vector<std::string>& float_feed_name,
+      const std::vector<std::vector<int64_t>>& int_feed,
+      const std::vector<std::string>& int_feed_name,
+      const std::vector<std::string>& fetch_name);

 private:
  PredictorApi _api;
-  Predictor * _predictor;
+  Predictor* _predictor;
  std::string _predictor_conf;
  std::string _predictor_path;
  std::string _conf_file;
  std::map<std::string, int> _feed_name_to_idx;
  std::map<std::string, int> _fetch_name_to_idx;
  std::map<std::string, std::string> _fetch_name_to_var_name;
-  std::vector<std::vector<int> > _shape;
+  std::vector<std::vector<int>> _shape;
  std::vector<int> _type;
 };
diff --git a/core/general-client/src/general_model.cpp b/core/general-client/src/general_model.cpp
index 2b18543c8d278eae96adf144d4b0d108e542c296..a593117db992a76f9a223cc15a768c92601dc879 100644
--- a/core/general-client/src/general_model.cpp
+++ b/core/general-client/src/general_model.cpp
@@ -12,8 +12,8 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.

-#include
 #include "core/general-client/include/general_model.h"
+#include
 #include "core/sdk-cpp/builtin_format.pb.h"
 #include "core/sdk-cpp/include/common.h"
 #include "core/sdk-cpp/include/predictor_sdk.h"
@@ -28,7 +28,7 @@ namespace baidu {
 namespace paddle_serving {
 namespace general_model {

-void PredictorClient::init(const std::string & conf_file) {
+void PredictorClient::init(const std::string &conf_file) {
   _conf_file = conf_file;
   std::ifstream fin(conf_file);
   if (!fin) {
@@ -68,9 +68,8 @@ void PredictorClient::init(const std::string & conf_file) {
   }
 }

-void PredictorClient::set_predictor_conf(
-    const std::string & conf_path,
-    const std::string & conf_file) {
+void PredictorClient::set_predictor_conf(const std::string &conf_path,
+                                         const std::string &conf_file) {
   _predictor_path = conf_path;
   _predictor_conf = conf_file;
 }
@@ -83,14 +82,13 @@ int PredictorClient::create_predictor() {
   _api.thrd_initialize();
 }

-std::vector<std::vector<float> > PredictorClient::predict(
-    const std::vector<std::vector<float> > & float_feed,
-    const std::vector<std::string> & float_feed_name,
-    const std::vector<std::vector<int64_t> > & int_feed,
-    const std::vector<std::string> & int_feed_name,
-    const std::vector<std::string> & fetch_name) {
-
-  std::vector<std::vector<float> > fetch_result;
+std::vector<std::vector<float>> PredictorClient::predict(
+    const std::vector<std::vector<float>> &float_feed,
+    const std::vector<std::string> &float_feed_name,
+    const std::vector<std::vector<int64_t>> &int_feed,
+    const std::vector<std::string> &int_feed_name,
+    const std::vector<std::string> &fetch_name) {
+  std::vector<std::vector<float>> fetch_result;
   if (fetch_name.size() == 0) {
     return fetch_result;
   }
@@ -100,41 +98,43 @@ std::vector<std::vector<float> > PredictorClient::predict(
   _predictor = _api.fetch_predictor("general_model");
   Request req;
   std::vector<Tensor *> tensor_vec;
-  FeedInst * inst = req.add_insts();
-  for (auto & name : float_feed_name) {
+  FeedInst *inst = req.add_insts();
+  for (auto &name : float_feed_name) {
     tensor_vec.push_back(inst->add_tensor_array());
   }

-  for (auto & name : int_feed_name) {
+  for (auto &name : int_feed_name) {
     tensor_vec.push_back(inst->add_tensor_array());
   }

   int vec_idx = 0;
-  for (auto & name : float_feed_name) {
+  for (auto &name : float_feed_name) {
     int idx = _feed_name_to_idx[name];
-    Tensor * tensor = tensor_vec[idx];
+    Tensor *tensor = tensor_vec[idx];
     for (int j = 0; j < _shape[idx].size(); ++j) {
       tensor->add_shape(_shape[idx][j]);
     }
     tensor->set_elem_type(1);
     for (int j = 0; j < float_feed[vec_idx].size(); ++j) {
-      tensor->add_data(
-          (char *)(&(float_feed[vec_idx][j])), sizeof(float));
+      tensor->add_data(const_cast<char *>(reinterpret_cast<const char *>(
+                           &(float_feed[vec_idx][j]))),
+                       sizeof(float));
     }
     vec_idx++;
   }

   vec_idx = 0;
-  for (auto & name : int_feed_name) {
+  for (auto &name : int_feed_name) {
     int idx = _feed_name_to_idx[name];
-    Tensor * tensor = tensor_vec[idx];
+    Tensor *tensor = tensor_vec[idx];
     for (int j = 0; j < _shape[idx].size(); ++j) {
       tensor->add_shape(_shape[idx][j]);
     }
     tensor->set_elem_type(0);
     for (int j = 0; j < int_feed[vec_idx].size(); ++j) {
-      tensor->add_data(
-          (char *)(&(int_feed[vec_idx][j])), sizeof(int64_t));
+      tensor->add_data(const_cast<char *>(reinterpret_cast<const char *>(
+                           &(int_feed[vec_idx][j]))),
+                       sizeof(int64_t));
     }
     vec_idx++;
   }
@@ -147,7 +147,7 @@ std::vector<std::vector<float> > PredictorClient::predict(
     LOG(ERROR) << "failed call predictor with req: " << req.ShortDebugString();
     exit(-1);
   } else {
-    for (auto & name : fetch_name) {
+    for (auto &name : fetch_name) {
       int idx = _fetch_name_to_idx[name];
       int len = res.insts(0).tensor_array(idx).data_size();
       VLOG(3) << "fetch name: " << name;
@@ -162,8 +162,8 @@ std::vector<std::vector<float> > PredictorClient::predict(
         fetch_result[name][i] = *(const float *)
             res.insts(0).tensor_array(idx).data(i).c_str();
        */
-        fetch_result[idx][i] = *(const float *)
-            res.insts(0).tensor_array(idx).data(i).c_str();
+        fetch_result[idx][i] =
+            *(const float *)res.insts(0).tensor_array(idx).data(i).c_str();
       }
     }
   }
@@ -171,13 +171,110 @@ std::vector<std::vector<float> > PredictorClient::predict(
   return fetch_result;
 }

-std::vector<std::vector<float> > PredictorClient::predict_with_profile(
-    const std::vector<std::vector<float> > & float_feed,
-    const std::vector<std::string> & float_feed_name,
-    const std::vector<std::vector<int64_t> > & int_feed,
-    const std::vector<std::string> & int_feed_name,
-    const std::vector<std::string> & fetch_name) {
-  std::vector<std::vector<float> > res;
+std::vector<std::vector<std::vector<float>>> PredictorClient::batch_predict(
+    const std::vector<std::vector<std::vector<float>>> &float_feed_batch,
+    const std::vector<std::string> &float_feed_name,
+    const std::vector<std::vector<std::vector<int64_t>>> &int_feed_batch,
+    const std::vector<std::string> &int_feed_name,
+    const std::vector<std::string> &fetch_name) {
+  int batch_size = std::max(float_feed_batch.size(), int_feed_batch.size());
+  std::vector<std::vector<std::vector<float>>> fetch_result_batch;
+  if (fetch_name.size() == 0) {
+    return fetch_result_batch;
+  }
+  fetch_result_batch.resize(batch_size);
+  int fetch_name_num = fetch_name.size();
+  for (int bi = 0; bi < batch_size; bi++) {
+    fetch_result_batch[bi].resize(fetch_name_num);
+  }
+
+  _api.thrd_clear();
+  _predictor = _api.fetch_predictor("general_model");
+  Request req;
+  //
+  for (int bi = 0; bi < batch_size; bi++) {
+    std::vector<Tensor *> tensor_vec;
+    FeedInst *inst = req.add_insts();
+    std::vector<std::vector<float>> float_feed = float_feed_batch[bi];
+    std::vector<std::vector<int64_t>> int_feed = int_feed_batch[bi];
+    for (auto &name : float_feed_name) {
+      tensor_vec.push_back(inst->add_tensor_array());
+    }
+
+    for (auto &name : int_feed_name) {
+      tensor_vec.push_back(inst->add_tensor_array());
+    }
+
+    int vec_idx = 0;
+    for (auto &name : float_feed_name) {
+      int idx = _feed_name_to_idx[name];
+      Tensor *tensor = tensor_vec[idx];
+      for (int j = 0; j < _shape[idx].size(); ++j) {
+        tensor->add_shape(_shape[idx][j]);
+      }
+      tensor->set_elem_type(1);
+      for (int j = 0; j < float_feed[vec_idx].size(); ++j) {
+        tensor->add_data(const_cast<char *>(reinterpret_cast<const char *>(
+                             &(float_feed[vec_idx][j]))),
+                         sizeof(float));
+      }
+      vec_idx++;
+    }
+
+    vec_idx = 0;
+    for (auto &name : int_feed_name) {
+      int idx = _feed_name_to_idx[name];
+      Tensor *tensor = tensor_vec[idx];
+      for (int j = 0; j < _shape[idx].size(); ++j) {
+        tensor->add_shape(_shape[idx][j]);
+      }
+      tensor->set_elem_type(0);
+      VLOG(3) << "feed var name " << name << " index " << vec_idx
+              << "first data " << int_feed[vec_idx][0];
+      for (int j = 0; j < int_feed[vec_idx].size(); ++j) {
+        tensor->add_data(const_cast<char *>(reinterpret_cast<const char *>(
+                             &(int_feed[vec_idx][j]))),
+                         sizeof(int64_t));
+      }
+      vec_idx++;
+    }
+  }
+
+  Response res;
+
+  res.Clear();
+  if (_predictor->inference(&req, &res) != 0) {
+    LOG(ERROR) << "failed call predictor with req: " << req.ShortDebugString();
+    exit(-1);
+  } else {
+    for (int bi = 0; bi < batch_size; bi++) {
+      for (auto &name : fetch_name) {
+        int idx = _fetch_name_to_idx[name];
+        int len = res.insts(bi).tensor_array(idx).data_size();
+        VLOG(3) << "fetch name: " << name;
+        VLOG(3) << "tensor data size: " << len;
+        fetch_result_batch[bi][idx].resize(len);
+        VLOG(3)
+            << "fetch name " << name << " index " << idx << " first data "
+            << *(const float *)res.insts(bi).tensor_array(idx).data(0).c_str();
+        for (int i = 0; i < len; ++i) {
+          fetch_result_batch[bi][idx][i] =
+              *(const float *)res.insts(bi).tensor_array(idx).data(i).c_str();
+        }
+      }
+    }
+  }
+
+  return fetch_result_batch;
+}
+
+std::vector<std::vector<float>> PredictorClient::predict_with_profile(
+    const std::vector<std::vector<float>> &float_feed,
+    const std::vector<std::string> &float_feed_name,
+    const std::vector<std::vector<int64_t>> &int_feed,
+    const std::vector<std::string> &int_feed_name,
+    const std::vector<std::string> &fetch_name) {
+  std::vector<std::vector<float>> res;
   return res;
 }
diff --git a/core/general-client/src/pybind_general_model.cpp b/core/general-client/src/pybind_general_model.cpp
index 287b7e337d78f2f4ac0a11fc0334a79c53680eee..caa88acbcdc514bdcf94fbea2ee9458105d7bbd7 100644
--- a/core/general-client/src/pybind_general_model.cpp
+++ b/core/general-client/src/pybind_general_model.cpp
@@ -1,10 +1,23 @@
+// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
 #include
 #include
+#include
 #include

 #include "core/general-client/include/general_model.h"
-#include
-
 namespace py = pybind11;

 using baidu::paddle_serving::general_model::FetchedMap;
@@ -19,28 +32,45 @@ PYBIND11_MODULE(serving_client, m) {
   py::class_<PredictorClient>(m, "PredictorClient", py::buffer_protocol())
       .def(py::init())
       .def("init",
-           [](PredictorClient &self, const std::string & conf) {
+           [](PredictorClient &self, const std::string &conf) {
             self.init(conf);
           })
       .def("set_predictor_conf",
-           [](PredictorClient &self, const std::string & conf_path,
-              const std::string & conf_file) {
+           [](PredictorClient &self,
+              const std::string &conf_path,
+              const std::string &conf_file) {
             self.set_predictor_conf(conf_path, conf_file);
           })
       .def("create_predictor",
-           [](PredictorClient & self) {
-             self.create_predictor();
-           })
+           [](PredictorClient &self) { self.create_predictor(); })
      .def("predict",
          [](PredictorClient &self,
-             const std::vector<std::vector<float> > & float_feed,
-             const std::vector<std::string> & float_feed_name,
-             const std::vector<std::vector<int64_t> > & int_feed,
-             const std::vector<std::string> & int_feed_name,
-             const std::vector<std::string> & fetch_name) {
+             const std::vector<std::vector<float>> &float_feed,
+             const std::vector<std::string> &float_feed_name,
+             const std::vector<std::vector<int64_t>> &int_feed,
+             const std::vector<std::string> &int_feed_name,
+             const std::vector<std::string> &fetch_name) {
+            return self.predict(float_feed,
+                                float_feed_name,
+                                int_feed,
+                                int_feed_name,
+                                fetch_name);
+          })
-            return self.predict(float_feed, float_feed_name,
-                                int_feed, int_feed_name, fetch_name);
+      .def("batch_predict",
+           [](PredictorClient &self,
+              const std::vector<std::vector<std::vector<float>>>
+                  &float_feed_batch,
+              const std::vector<std::string> &float_feed_name,
+              const std::vector<std::vector<std::vector<int64_t>>>
+                  &int_feed_batch,
+              const std::vector<std::string> &int_feed_name,
+              const std::vector<std::string> &fetch_name) {
+            return self.batch_predict(float_feed_batch,
+                                      float_feed_name,
+                                      int_feed_batch,
+                                      int_feed_name,
+                                      fetch_name);
          });
 }
diff --git a/python/examples/imdb/README.md b/python/examples/imdb/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..8f040c409575010e784d6aca56855c297598608a
--- /dev/null
+++ b/python/examples/imdb/README.md
@@ -0,0 +1,16 @@
+### Usage
+
+Assume the data file is test.data and the client configuration file is inference.conf.
+
+Single-process client
+```
+cat test.data | python test_client.py inference.conf > result
+```
+Multi-process client, e.g. with 4 processes
+```
+python test_client_multithread.py inference.conf test.data 4 > result
+```
+Batch client, e.g. with batch size 4
+```
+cat test.data | python test_client_batch.py inference.conf 4 > result
+```
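test_client.py itself is not part of this change, so the sketch below only illustrates what a single-process client consistent with the README and with the scripts that follow might look like; the line parsing and the feed/fetch names are copied from those scripts, everything else is an assumption:

```python
# Hypothetical single-process client, mirroring test_client_multithread.py below.
from paddle_serving import Client
import sys

client = Client()
client.load_client_config(sys.argv[1])  # e.g. inference.conf
client.connect(["127.0.0.1:8010"])      # endpoint used by the example scripts

for line in sys.stdin:
    group = line.strip().split()
    words = [int(x) for x in group[1:int(group[0])]]
    label = [int(group[-1])]
    fetch_map = client.predict(
        feed={"words": words, "label": label},
        fetch=["acc", "cost", "prediction"])
    print("{} {}".format(fetch_map["prediction"][1], label[0]))
```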
diff --git a/python/examples/imdb/test_client_batch.py b/python/examples/imdb/test_client_batch.py
new file mode 100644
index 0000000000000000000000000000000000000000..bb9a1d871b91b88c4d63bd5f4fa270d3b82291bd
--- /dev/null
+++ b/python/examples/imdb/test_client_batch.py
@@ -0,0 +1,49 @@
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from paddle_serving import Client
+import sys
+import subprocess
+from multiprocessing import Pool
+import time
+
+
+def batch_predict(batch_size=4):
+    client = Client()
+    client.load_client_config(conf_file)
+    client.connect(["127.0.0.1:8010"])
+    start = time.time()
+    feed_batch = []
+    for line in sys.stdin:
+        group = line.strip().split()
+        words = [int(x) for x in group[1:int(group[0])]]
+        label = [int(group[-1])]
+        feed = {"words": words, "label": label}
+        fetch = ["acc", "cost", "prediction"]
+        feed_batch.append(feed)
+        if len(feed_batch) == batch_size:
+            fetch_batch = client.batch_predict(
+                feed_batch=feed_batch, fetch=fetch)
+            for i in range(batch_size):
+                print("{} {}".format(fetch_batch[i]["prediction"][1],
+                                     feed_batch[i]["label"][0]))
+            feed_batch = []
+    cost = time.time() - start
+    print("total cost : {}".format(cost))
+
+
+if __name__ == '__main__':
+    conf_file = sys.argv[1]
+    batch_size = int(sys.argv[2])
+    batch_predict(batch_size)
diff --git a/python/examples/imdb/test_client_multithread.py b/python/examples/imdb/test_client_multithread.py
new file mode 100644
index 0000000000000000000000000000000000000000..770d14665cf9c3287b8274ef11ae8945f5759b6d
--- /dev/null
+++ b/python/examples/imdb/test_client_multithread.py
@@ -0,0 +1,66 @@
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from paddle_serving import Client
+import sys
+import subprocess
+from multiprocessing import Pool
+import time
+
+
+def predict(p_id, p_size, data_list):
+    client = Client()
+    client.load_client_config(conf_file)
+    client.connect(["127.0.0.1:8010"])
+    result = []
+    for line in data_list:
+        group = line.strip().split()
+        words = [int(x) for x in group[1:int(group[0])]]
+        label = [int(group[-1])]
+        feed = {"words": words, "label": label}
+        fetch = ["acc", "cost", "prediction"]
+        fetch_map = client.predict(feed=feed, fetch=fetch)
+        #print("{} {}".format(fetch_map["prediction"][1], label[0]))
+        result.append([fetch_map["prediction"][1], label[0]])
+    return result
+
+
+def predict_multi_thread(p_num):
+    data_list = []
+    with open(data_file) as f:
+        for line in f.readlines():
+            data_list.append(line)
+    start = time.time()
+    p = Pool(p_num)
+    p_size = len(data_list) // p_num
+    result_list = []
+    for i in range(p_num):
+        result_list.append(
+            p.apply_async(predict,
+                          [i, p_size, data_list[i * p_size:(i + 1) * p_size]]))
+    p.close()
+    p.join()
+    for i in range(p_num):
+        result = result_list[i].get()
+        for j in result:
+            print("{} {}".format(j[0], j[1]))
+    cost = time.time() - start
+    print("{} threads cost {}".format(p_num, cost))
+
+
+if __name__ == '__main__':
+    conf_file = sys.argv[1]
+    data_file = sys.argv[2]
+    p_num = int(sys.argv[3])
+    predict_multi_thread(p_num)
diff --git a/python/paddle_serving/serving_client/__init__.py b/python/paddle_serving/serving_client/__init__.py
index d12dc2b4f2604f8a0f9e02adc74e0af298f999e3..e21b5c0bdd74883d050f50275e16b2cbedf712f0 100644
--- a/python/paddle_serving/serving_client/__init__.py
+++ b/python/paddle_serving/serving_client/__init__.py
@@ -19,6 +19,7 @@ import time
 int_type = 0
 float_type = 1

+
 class SDKConfig(object):
     def __init__(self):
         self.sdk_desc = sdk.SDKConf()
@@ -37,7 +38,8 @@ class SDKConfig(object):
         variant_desc = sdk.VariantConf()
         variant_desc.tag = "var1"
-        variant_desc.naming_conf.cluster = "list://{}".format(":".join(self.endpoints))
+        variant_desc.naming_conf.cluster = "list://{}".format(":".join(
+            self.endpoints))

         predictor_desc.variants.extend([variant_desc])

@@ -50,7 +52,7 @@ class SDKConfig(object):
         self.sdk_desc.default_variant_conf.connection_conf.hedge_request_timeout_ms = -1
         self.sdk_desc.default_variant_conf.connection_conf.hedge_fetch_retry_count = 2
         self.sdk_desc.default_variant_conf.connection_conf.connection_type = "pooled"
-        
+
         self.sdk_desc.default_variant_conf.naming_conf.cluster_filter_strategy = "Default"
         self.sdk_desc.default_variant_conf.naming_conf.load_balance_strategy = "la"

@@ -114,8 +116,7 @@ class Client(object):
         predictor_file = "%s_predictor.conf" % timestamp
         with open(predictor_path + predictor_file, "w") as fout:
             fout.write(sdk_desc)
-        self.client_handle_.set_predictor_conf(
-            predictor_path, predictor_file)
+        self.client_handle_.set_predictor_conf(predictor_path, predictor_file)
         self.client_handle_.create_predictor()

     def get_feed_names(self):
@@ -145,13 +146,52 @@ class Client(object):
                 fetch_names.append(key)

         result = self.client_handle_.predict(
-            float_slot, float_feed_names,
-            int_slot, int_feed_names,
-            fetch_names)
-
+            float_slot, float_feed_names, int_slot, int_feed_names, fetch_names)
+
         result_map = {}
         for i, name in enumerate(fetch_names):
             result_map[name] = result[i]
-
+
         return result_map

+    def batch_predict(self, feed_batch=[], fetch=[]):
+        int_slot_batch = []
+        float_slot_batch = []
+        int_feed_names = []
+        float_feed_names = []
+        fetch_names = []
+        counter = 0
+        for feed in feed_batch:
+            int_slot = []
+            float_slot = []
+            for key in feed:
+                if key not in self.feed_names_:
+                    continue
+                if self.feed_types_[key] == int_type:
+                    if counter == 0:
+                        int_feed_names.append(key)
+                    int_slot.append(feed[key])
+                elif self.feed_types_[key] == float_type:
+                    if counter == 0:
+                        float_feed_names.append(key)
+                    float_slot.append(feed[key])
+            counter += 1
+            int_slot_batch.append(int_slot)
+            float_slot_batch.append(float_slot)
+
+        for key in fetch:
+            if key in self.fetch_names_:
+                fetch_names.append(key)
+
+        result_batch = self.client_handle_.batch_predict(
+            float_slot_batch, float_feed_names, int_slot_batch, int_feed_names,
+            fetch_names)
+
+        result_map_batch = []
+        for result in result_batch:
+            result_map = {}
+            for i, name in enumerate(fetch_names):
+                result_map[name] = result[i]
+            result_map_batch.append(result_map)
+
+        return result_map_batch
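Taken together, these changes expose batch prediction end to end: Client.batch_predict packs one FeedInst per sample into a single Request on the C++ side and returns one result map per sample on the Python side. A minimal usage sketch follows; the feed and fetch names match the imdb example, while the sample values and endpoint are illustrative:

```python
from paddle_serving import Client

client = Client()
client.load_client_config("inference.conf")  # client config from the imdb example
client.connect(["127.0.0.1:8010"])           # endpoint used by the example scripts

# Two made-up samples; feed names are taken from the first sample,
# so every sample should provide the same feeds.
feed_batch = [
    {"words": [8, 233, 52, 601], "label": [0]},
    {"words": [17, 9, 2057], "label": [1]},
]
fetch = ["acc", "cost", "prediction"]

# One result map per sample, keyed by fetch name.
for result_map in client.batch_predict(feed_batch=feed_batch, fetch=fetch):
    print(result_map["prediction"])
```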