未验证 提交 c006be5d 编写于 作者: D Dong Daxiang 提交者: GitHub

Merge pull request #122 from MRXLT/general-model-config

add cube press doc && general client batch predict
......@@ -104,7 +104,7 @@ int run(int argc, char** argv, int thread_id) {
keys.push_back(key_list[index]);
index += 1;
int ret = 0;
if (keys.size() > FLAGS_batch) {
if (keys.size() >= FLAGS_batch) {
TIME_FLAG(seek_start);
ret = cube->seek(FLAGS_dict, keys, &values);
TIME_FLAG(seek_end);
......@@ -214,8 +214,8 @@ int run_m(int argc, char** argv) {
<< " avg = " << std::to_string(mean_time)
<< " max = " << std::to_string(max_time)
<< " min = " << std::to_string(min_time);
LOG(INFO) << " total_request = " << std::to_string(request_num)
<< " speed = " << std::to_string(1000000 * thread_num / mean_time)
LOG(INFO) << " total_request = " << std::to_string(request_num) << " speed = "
<< std::to_string(1000000 * thread_num / mean_time) // mean_time us
<< " query per second";
}
......
......@@ -2,6 +2,9 @@
## 机器配置
Intel(R) Xeon(R) CPU E5-2650 v4 @ 2.20GHz
## 测试方法
请参考[压测文档](./press.md)
## 测试数据
100w条样例kv 数据。 key为uint_64类型,单条value长度 40 Byte (一般实际场景对应一个10维特征向量)。
......
# cube压测文档
参考[大规模稀疏参数服务Cube的部署和使用](https://github.com/PaddlePaddle/Serving/blob/master/doc/DEPLOY.md#2-大规模稀疏参数服务cube的部署和使用)文档进行cube的部署。
压测工具链接:
https://paddle-serving.bj.bcebos.com/data/cube/cube-press.tar.gz
将压缩包解压,cube-press目录下包含了单机场景和分布式场景下client端压测所使用的脚本、可执行文件、配置文件、样例数据以及数据生成脚本。
其中,keys为client要读取用来查询的key值文件,feature为cube要加载的key-value文件,本次测试中使用的数据,key的范围为0~999999。
## 单机场景
在单个物理机部署cube服务,使用genernate_input.py脚本生成测数据,执行test.sh脚本启动cube client端向cube server发送请求。
genernate_input.py脚本接受1个参数,示例:
```bash
python genernate_input.py 1
```
参数表示生成的数据每一行含有多少个key,即test.sh执行的查询操作中的batch_size。
test.sh脚本接受3个参数,示例:
```bash
sh test.sh 1 127.0.0.1:8027 100000
```
第一个参数为并发数,第二个参数为cube server的ip与端口,第三个参数为qps。
输出:
脚本会进行9次压测,每次发送10次请求,每次请求耗时1秒,每次压测会打印出平均延时以及不同分位数的延时。
**注意事项:**
cube压测对于机器的网卡要求较高,高QPS的情况下单个client可能无法承受,可以采用两个或多个client,将查询请求进行平均。
如果执行test.sh出现问题需要停止,可以执行kill_rpc_press.sh
## 分布式场景
编译paddle serving完成后,分布式压测工具的客户端路径为 build/core/cube/cube-api/cube-cli,对应的源代码为core/cube/cube-api/src/cube_cli.cpp
在多台机器上部署cube服务,使用client_cli进行性能测试。
**注意事项:**
cube服务部署时的分片数和副本数会对性能造成影响,相同数据的条件下,分片数和副本数越多,性能越好,实际提升程度与数据相关。
使用方法:
```shell
./cube-cli --batch 500 --keys keys --dict dict --thread_num 1
```
接受的参数:
--batch指定每次请求的batch size。
--keys指定查询用的文件,文件中每一行为1个key。
--dict指定要查询的cube词典名。
--thread_num指定client端线程数
输出:
每个线程的查询的平均时间、最大时间、最小时间
进程中所有线程的查询的平均时间的平均值、最大值、最小值
进程中所有线程的总请求数、QPS
......@@ -17,10 +17,11 @@
#include <sys/types.h>
#include <unistd.h>
#include <algorithm>
#include <fstream>
#include <map>
#include <string>
#include <vector>
#include <map>
#include "core/sdk-cpp/builtin_format.pb.h"
#include "core/sdk-cpp/general_model_service.pb.h"
......@@ -37,46 +38,51 @@ namespace general_model {
typedef std::map<std::string, std::vector<float>> FetchedMap;
typedef std::map<std::string, std::vector<std::vector<float> > >
BatchFetchedMap;
typedef std::map<std::string, std::vector<std::vector<float>>> BatchFetchedMap;
class PredictorClient {
public:
PredictorClient() {}
~PredictorClient() {}
void init(const std::string & client_conf);
void init(const std::string& client_conf);
void set_predictor_conf(
const std::string& conf_path,
const std::string& conf_file);
void set_predictor_conf(const std::string& conf_path,
const std::string& conf_file);
int create_predictor();
std::vector<std::vector<float> > predict(
const std::vector<std::vector<float> > & float_feed,
const std::vector<std::string> & float_feed_name,
const std::vector<std::vector<int64_t> > & int_feed,
const std::vector<std::string> & int_feed_name,
const std::vector<std::string> & fetch_name);
std::vector<std::vector<float> > predict_with_profile(
const std::vector<std::vector<float> > & float_feed,
const std::vector<std::string> & float_feed_name,
const std::vector<std::vector<int64_t> > & int_feed,
const std::vector<std::string> & int_feed_name,
const std::vector<std::string> & fetch_name);
std::vector<std::vector<float>> predict(
const std::vector<std::vector<float>>& float_feed,
const std::vector<std::string>& float_feed_name,
const std::vector<std::vector<int64_t>>& int_feed,
const std::vector<std::string>& int_feed_name,
const std::vector<std::string>& fetch_name);
std::vector<std::vector<std::vector<float>>> batch_predict(
const std::vector<std::vector<std::vector<float>>>& float_feed_batch,
const std::vector<std::string>& float_feed_name,
const std::vector<std::vector<std::vector<int64_t>>>& int_feed_batch,
const std::vector<std::string>& int_feed_name,
const std::vector<std::string>& fetch_name);
std::vector<std::vector<float>> predict_with_profile(
const std::vector<std::vector<float>>& float_feed,
const std::vector<std::string>& float_feed_name,
const std::vector<std::vector<int64_t>>& int_feed,
const std::vector<std::string>& int_feed_name,
const std::vector<std::string>& fetch_name);
private:
PredictorApi _api;
Predictor * _predictor;
Predictor* _predictor;
std::string _predictor_conf;
std::string _predictor_path;
std::string _conf_file;
std::map<std::string, int> _feed_name_to_idx;
std::map<std::string, int> _fetch_name_to_idx;
std::map<std::string, std::string> _fetch_name_to_var_name;
std::vector<std::vector<int> > _shape;
std::vector<std::vector<int>> _shape;
std::vector<int> _type;
};
......
......@@ -12,8 +12,8 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include <fstream>
#include "core/general-client/include/general_model.h"
#include <fstream>
#include "core/sdk-cpp/builtin_format.pb.h"
#include "core/sdk-cpp/include/common.h"
#include "core/sdk-cpp/include/predictor_sdk.h"
......@@ -28,7 +28,7 @@ namespace baidu {
namespace paddle_serving {
namespace general_model {
void PredictorClient::init(const std::string & conf_file) {
void PredictorClient::init(const std::string &conf_file) {
_conf_file = conf_file;
std::ifstream fin(conf_file);
if (!fin) {
......@@ -68,9 +68,8 @@ void PredictorClient::init(const std::string & conf_file) {
}
}
void PredictorClient::set_predictor_conf(
const std::string & conf_path,
const std::string & conf_file) {
void PredictorClient::set_predictor_conf(const std::string &conf_path,
const std::string &conf_file) {
_predictor_path = conf_path;
_predictor_conf = conf_file;
}
......@@ -83,14 +82,13 @@ int PredictorClient::create_predictor() {
_api.thrd_initialize();
}
std::vector<std::vector<float> > PredictorClient::predict(
const std::vector<std::vector<float> > & float_feed,
const std::vector<std::string> & float_feed_name,
const std::vector<std::vector<int64_t> > & int_feed,
const std::vector<std::string> & int_feed_name,
const std::vector<std::string> & fetch_name) {
std::vector<std::vector<float> > fetch_result;
std::vector<std::vector<float>> PredictorClient::predict(
const std::vector<std::vector<float>> &float_feed,
const std::vector<std::string> &float_feed_name,
const std::vector<std::vector<int64_t>> &int_feed,
const std::vector<std::string> &int_feed_name,
const std::vector<std::string> &fetch_name) {
std::vector<std::vector<float>> fetch_result;
if (fetch_name.size() == 0) {
return fetch_result;
}
......@@ -100,41 +98,43 @@ std::vector<std::vector<float> > PredictorClient::predict(
_predictor = _api.fetch_predictor("general_model");
Request req;
std::vector<Tensor *> tensor_vec;
FeedInst * inst = req.add_insts();
for (auto & name : float_feed_name) {
FeedInst *inst = req.add_insts();
for (auto &name : float_feed_name) {
tensor_vec.push_back(inst->add_tensor_array());
}
for (auto & name : int_feed_name) {
for (auto &name : int_feed_name) {
tensor_vec.push_back(inst->add_tensor_array());
}
int vec_idx = 0;
for (auto & name : float_feed_name) {
for (auto &name : float_feed_name) {
int idx = _feed_name_to_idx[name];
Tensor * tensor = tensor_vec[idx];
Tensor *tensor = tensor_vec[idx];
for (int j = 0; j < _shape[idx].size(); ++j) {
tensor->add_shape(_shape[idx][j]);
}
tensor->set_elem_type(1);
for (int j = 0; j < float_feed[vec_idx].size(); ++j) {
tensor->add_data(
(char *)(&(float_feed[vec_idx][j])), sizeof(float));
tensor->add_data(const_cast<char *>(reinterpret_cast<const char *>(
&(float_feed[vec_idx][j]))),
sizeof(float));
}
vec_idx++;
}
vec_idx = 0;
for (auto & name : int_feed_name) {
for (auto &name : int_feed_name) {
int idx = _feed_name_to_idx[name];
Tensor * tensor = tensor_vec[idx];
Tensor *tensor = tensor_vec[idx];
for (int j = 0; j < _shape[idx].size(); ++j) {
tensor->add_shape(_shape[idx][j]);
}
tensor->set_elem_type(0);
for (int j = 0; j < int_feed[vec_idx].size(); ++j) {
tensor->add_data(
(char *)(&(int_feed[vec_idx][j])), sizeof(int64_t));
tensor->add_data(const_cast<char *>(reinterpret_cast<const char *>(
&(int_feed[vec_idx][j]))),
sizeof(int64_t));
}
vec_idx++;
}
......@@ -147,7 +147,7 @@ std::vector<std::vector<float> > PredictorClient::predict(
LOG(ERROR) << "failed call predictor with req: " << req.ShortDebugString();
exit(-1);
} else {
for (auto & name : fetch_name) {
for (auto &name : fetch_name) {
int idx = _fetch_name_to_idx[name];
int len = res.insts(0).tensor_array(idx).data_size();
VLOG(3) << "fetch name: " << name;
......@@ -162,8 +162,8 @@ std::vector<std::vector<float> > PredictorClient::predict(
fetch_result[name][i] = *(const float *)
res.insts(0).tensor_array(idx).data(i).c_str();
*/
fetch_result[idx][i] = *(const float *)
res.insts(0).tensor_array(idx).data(i).c_str();
fetch_result[idx][i] =
*(const float *)res.insts(0).tensor_array(idx).data(i).c_str();
}
}
}
......@@ -171,13 +171,110 @@ std::vector<std::vector<float> > PredictorClient::predict(
return fetch_result;
}
std::vector<std::vector<float> > PredictorClient::predict_with_profile(
const std::vector<std::vector<float> > & float_feed,
const std::vector<std::string> & float_feed_name,
const std::vector<std::vector<int64_t> > & int_feed,
const std::vector<std::string> & int_feed_name,
const std::vector<std::string> & fetch_name) {
std::vector<std::vector<float> > res;
std::vector<std::vector<std::vector<float>>> PredictorClient::batch_predict(
const std::vector<std::vector<std::vector<float>>> &float_feed_batch,
const std::vector<std::string> &float_feed_name,
const std::vector<std::vector<std::vector<int64_t>>> &int_feed_batch,
const std::vector<std::string> &int_feed_name,
const std::vector<std::string> &fetch_name) {
int batch_size = std::max(float_feed_batch.size(), int_feed_batch.size());
std::vector<std::vector<std::vector<float>>> fetch_result_batch;
if (fetch_name.size() == 0) {
return fetch_result_batch;
}
fetch_result_batch.resize(batch_size);
int fetch_name_num = fetch_name.size();
for (int bi = 0; bi < batch_size; bi++) {
fetch_result_batch[bi].resize(fetch_name_num);
}
_api.thrd_clear();
_predictor = _api.fetch_predictor("general_model");
Request req;
//
for (int bi = 0; bi < batch_size; bi++) {
std::vector<Tensor *> tensor_vec;
FeedInst *inst = req.add_insts();
std::vector<std::vector<float>> float_feed = float_feed_batch[bi];
std::vector<std::vector<int64_t>> int_feed = int_feed_batch[bi];
for (auto &name : float_feed_name) {
tensor_vec.push_back(inst->add_tensor_array());
}
for (auto &name : int_feed_name) {
tensor_vec.push_back(inst->add_tensor_array());
}
int vec_idx = 0;
for (auto &name : float_feed_name) {
int idx = _feed_name_to_idx[name];
Tensor *tensor = tensor_vec[idx];
for (int j = 0; j < _shape[idx].size(); ++j) {
tensor->add_shape(_shape[idx][j]);
}
tensor->set_elem_type(1);
for (int j = 0; j < float_feed[vec_idx].size(); ++j) {
tensor->add_data(const_cast<char *>(reinterpret_cast<const char *>(
&(float_feed[vec_idx][j]))),
sizeof(float));
}
vec_idx++;
}
vec_idx = 0;
for (auto &name : int_feed_name) {
int idx = _feed_name_to_idx[name];
Tensor *tensor = tensor_vec[idx];
for (int j = 0; j < _shape[idx].size(); ++j) {
tensor->add_shape(_shape[idx][j]);
}
tensor->set_elem_type(0);
VLOG(3) << "feed var name " << name << " index " << vec_idx
<< "first data " << int_feed[vec_idx][0];
for (int j = 0; j < int_feed[vec_idx].size(); ++j) {
tensor->add_data(const_cast<char *>(reinterpret_cast<const char *>(
&(int_feed[vec_idx][j]))),
sizeof(int64_t));
}
vec_idx++;
}
}
Response res;
res.Clear();
if (_predictor->inference(&req, &res) != 0) {
LOG(ERROR) << "failed call predictor with req: " << req.ShortDebugString();
exit(-1);
} else {
for (int bi = 0; bi < batch_size; bi++) {
for (auto &name : fetch_name) {
int idx = _fetch_name_to_idx[name];
int len = res.insts(bi).tensor_array(idx).data_size();
VLOG(3) << "fetch name: " << name;
VLOG(3) << "tensor data size: " << len;
fetch_result_batch[bi][idx].resize(len);
VLOG(3)
<< "fetch name " << name << " index " << idx << " first data "
<< *(const float *)res.insts(bi).tensor_array(idx).data(0).c_str();
for (int i = 0; i < len; ++i) {
fetch_result_batch[bi][idx][i] =
*(const float *)res.insts(bi).tensor_array(idx).data(i).c_str();
}
}
}
}
return fetch_result_batch;
}
std::vector<std::vector<float>> PredictorClient::predict_with_profile(
const std::vector<std::vector<float>> &float_feed,
const std::vector<std::string> &float_feed_name,
const std::vector<std::vector<int64_t>> &int_feed,
const std::vector<std::string> &int_feed_name,
const std::vector<std::string> &fetch_name) {
std::vector<std::vector<float>> res;
return res;
}
......
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <Python.h>
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>
#include <unordered_map>
#include "core/general-client/include/general_model.h"
#include <pybind11/stl.h>
namespace py = pybind11;
using baidu::paddle_serving::general_model::FetchedMap;
......@@ -19,28 +32,45 @@ PYBIND11_MODULE(serving_client, m) {
py::class_<PredictorClient>(m, "PredictorClient", py::buffer_protocol())
.def(py::init())
.def("init",
[](PredictorClient &self, const std::string & conf) {
[](PredictorClient &self, const std::string &conf) {
self.init(conf);
})
.def("set_predictor_conf",
[](PredictorClient &self, const std::string & conf_path,
const std::string & conf_file) {
[](PredictorClient &self,
const std::string &conf_path,
const std::string &conf_file) {
self.set_predictor_conf(conf_path, conf_file);
})
.def("create_predictor",
[](PredictorClient & self) {
self.create_predictor();
})
[](PredictorClient &self) { self.create_predictor(); })
.def("predict",
[](PredictorClient &self,
const std::vector<std::vector<float> > & float_feed,
const std::vector<std::string> & float_feed_name,
const std::vector<std::vector<int64_t> > & int_feed,
const std::vector<std::string> & int_feed_name,
const std::vector<std::string> & fetch_name) {
const std::vector<std::vector<float>> &float_feed,
const std::vector<std::string> &float_feed_name,
const std::vector<std::vector<int64_t>> &int_feed,
const std::vector<std::string> &int_feed_name,
const std::vector<std::string> &fetch_name) {
return self.predict(float_feed,
float_feed_name,
int_feed,
int_feed_name,
fetch_name);
})
return self.predict(float_feed, float_feed_name,
int_feed, int_feed_name, fetch_name);
.def("batch_predict",
[](PredictorClient &self,
const std::vector<std::vector<std::vector<float>>>
&float_feed_batch,
const std::vector<std::string> &float_feed_name,
const std::vector<std::vector<std::vector<int64_t>>>
&int_feed_batch,
const std::vector<std::string> &int_feed_name,
const std::vector<std::string> &fetch_name) {
return self.batch_predict(float_feed_batch,
float_feed_name,
int_feed_batch,
int_feed_name,
fetch_name);
});
}
......
### 使用方法
假设数据文件为test.data,配置文件为inference.conf
单进程client
```
cat test.data | python test_client.py inference.conf > result
```
多进程client,若进程数为4
```
python test_client_multithread.py inference.conf test.data 4 > result
```
batch clienit,若batch size为4
```
cat test.data | python test_client_batch.py inference.conf 4 > result
```
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from paddle_serving import Client
import sys
import subprocess
from multiprocessing import Pool
import time
def batch_predict(batch_size=4):
client = Client()
client.load_client_config(conf_file)
client.connect(["127.0.0.1:8010"])
start = time.time()
feed_batch = []
for line in sys.stdin:
group = line.strip().split()
words = [int(x) for x in group[1:int(group[0])]]
label = [int(group[-1])]
feed = {"words": words, "label": label}
fetch = ["acc", "cost", "prediction"]
feed_batch.append(feed)
if len(feed_batch) == batch_size:
fetch_batch = client.batch_predict(
feed_batch=feed_batch, fetch=fetch)
for i in range(batch_size):
print("{} {}".format(fetch_batch[i]["prediction"][1],
feed_batch[i]["label"][0]))
feed_batch = []
cost = time.time() - start
print("total cost : {}".format(cost))
if __name__ == '__main__':
conf_file = sys.argv[1]
batch_size = int(sys.argv[2])
batch_predict(batch_size)
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from paddle_serving import Client
import sys
import subprocess
from multiprocessing import Pool
import time
def predict(p_id, p_size, data_list):
client = Client()
client.load_client_config(conf_file)
client.connect(["127.0.0.1:8010"])
result = []
for line in data_list:
group = line.strip().split()
words = [int(x) for x in group[1:int(group[0])]]
label = [int(group[-1])]
feed = {"words": words, "label": label}
fetch = ["acc", "cost", "prediction"]
fetch_map = client.predict(feed=feed, fetch=fetch)
#print("{} {}".format(fetch_map["prediction"][1], label[0]))
result.append([fetch_map["prediction"][1], label[0]])
return result
def predict_multi_thread(p_num):
data_list = []
with open(data_file) as f:
for line in f.readlines():
data_list.append(line)
start = time.time()
p = Pool(p_num)
p_size = len(data_list) / p_num
result_list = []
for i in range(p_num):
result_list.append(
p.apply_async(predict,
[i, p_size, data_list[i * p_size:(i + 1) * p_size]]))
p.close()
p.join()
for i in range(p_num):
result = result_list[i].get()
for j in result:
print("{} {}".format(j[0], j[1]))
cost = time.time() - start
print("{} threads cost {}".format(p_num, cost))
if __name__ == '__main__':
conf_file = sys.argv[1]
data_file = sys.argv[2]
p_num = int(sys.argv[3])
predict_multi_thread(p_num)
......@@ -19,6 +19,7 @@ import time
int_type = 0
float_type = 1
class SDKConfig(object):
def __init__(self):
self.sdk_desc = sdk.SDKConf()
......@@ -37,7 +38,8 @@ class SDKConfig(object):
variant_desc = sdk.VariantConf()
variant_desc.tag = "var1"
variant_desc.naming_conf.cluster = "list://{}".format(":".join(self.endpoints))
variant_desc.naming_conf.cluster = "list://{}".format(":".join(
self.endpoints))
predictor_desc.variants.extend([variant_desc])
......@@ -50,7 +52,7 @@ class SDKConfig(object):
self.sdk_desc.default_variant_conf.connection_conf.hedge_request_timeout_ms = -1
self.sdk_desc.default_variant_conf.connection_conf.hedge_fetch_retry_count = 2
self.sdk_desc.default_variant_conf.connection_conf.connection_type = "pooled"
self.sdk_desc.default_variant_conf.naming_conf.cluster_filter_strategy = "Default"
self.sdk_desc.default_variant_conf.naming_conf.load_balance_strategy = "la"
......@@ -114,8 +116,7 @@ class Client(object):
predictor_file = "%s_predictor.conf" % timestamp
with open(predictor_path + predictor_file, "w") as fout:
fout.write(sdk_desc)
self.client_handle_.set_predictor_conf(
predictor_path, predictor_file)
self.client_handle_.set_predictor_conf(predictor_path, predictor_file)
self.client_handle_.create_predictor()
def get_feed_names(self):
......@@ -145,13 +146,52 @@ class Client(object):
fetch_names.append(key)
result = self.client_handle_.predict(
float_slot, float_feed_names,
int_slot, int_feed_names,
fetch_names)
float_slot, float_feed_names, int_slot, int_feed_names, fetch_names)
result_map = {}
for i, name in enumerate(fetch_names):
result_map[name] = result[i]
return result_map
def batch_predict(self, feed_batch=[], fetch=[]):
int_slot_batch = []
float_slot_batch = []
int_feed_names = []
float_feed_names = []
fetch_names = []
counter = 0
for feed in feed_batch:
int_slot = []
float_slot = []
for key in feed:
if key not in self.feed_names_:
continue
if self.feed_types_[key] == int_type:
if counter == 0:
int_feed_names.append(key)
int_slot.append(feed[key])
elif self.feed_types_[key] == float_type:
if counter == 0:
float_feed_names.append(key)
float_slot.append(feed[key])
counter += 1
int_slot_batch.append(int_slot)
float_slot_batch.append(float_slot)
for key in fetch:
if key in self.fetch_names_:
fetch_names.append(key)
result_batch = self.client_handle_.batch_predict(
float_slot_batch, float_feed_names, int_slot_batch, int_feed_names,
fetch_names)
result_map_batch = []
for result in result_batch:
result_map = {}
for i, name in enumerate(fetch_names):
result_map[name] = result[i]
result_map_batch.append(result_map)
return result_map_batch
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册