提交 93af8d73 编写于 作者: M MRXLT

change batch predict interface

上级 21e1cc08
...@@ -17,6 +17,7 @@ ...@@ -17,6 +17,7 @@
#include <sys/types.h> #include <sys/types.h>
#include <unistd.h> #include <unistd.h>
#include <algorithm>
#include <fstream> #include <fstream>
#include <map> #include <map>
#include <string> #include <string>
...@@ -58,13 +59,12 @@ class PredictorClient { ...@@ -58,13 +59,12 @@ class PredictorClient {
const std::vector<std::string>& int_feed_name, const std::vector<std::string>& int_feed_name,
const std::vector<std::string>& fetch_name); const std::vector<std::string>& fetch_name);
std::vector<std::vector<std::vector<float>>> predict_for_batch( std::vector<std::vector<std::vector<float>>> batch_predict(
const std::vector<std::vector<std::vector<float>>>& float_feed_batch, const std::vector<std::vector<std::vector<float>>>& float_feed_batch,
const std::vector<std::string>& float_feed_name, const std::vector<std::string>& float_feed_name,
const std::vector<std::vector<std::vector<int64_t>>>& int_feed_batch, const std::vector<std::vector<std::vector<int64_t>>>& int_feed_batch,
const std::vector<std::string>& int_feed_name, const std::vector<std::string>& int_feed_name,
const std::vector<std::string>& fetch_name, const std::vector<std::string>& fetch_name);
const int64_t& batch_size);
std::vector<std::vector<float>> predict_with_profile( std::vector<std::vector<float>> predict_with_profile(
const std::vector<std::vector<float>>& float_feed, const std::vector<std::vector<float>>& float_feed,
......
...@@ -171,13 +171,13 @@ std::vector<std::vector<float>> PredictorClient::predict( ...@@ -171,13 +171,13 @@ std::vector<std::vector<float>> PredictorClient::predict(
return fetch_result; return fetch_result;
} }
std::vector<std::vector<std::vector<float>>> PredictorClient::predict_for_batch( std::vector<std::vector<std::vector<float>>> PredictorClient::batch_predict(
const std::vector<std::vector<std::vector<float>>> &float_feed_batch, const std::vector<std::vector<std::vector<float>>> &float_feed_batch,
const std::vector<std::string> &float_feed_name, const std::vector<std::string> &float_feed_name,
const std::vector<std::vector<std::vector<int64_t>>> &int_feed_batch, const std::vector<std::vector<std::vector<int64_t>>> &int_feed_batch,
const std::vector<std::string> &int_feed_name, const std::vector<std::string> &int_feed_name,
const std::vector<std::string> &fetch_name, const std::vector<std::string> &fetch_name) {
const int64_t &batch_size) { int batch_size = std::max(float_feed_batch.size(), int_feed_batch.size());
std::vector<std::vector<std::vector<float>>> fetch_result_batch; std::vector<std::vector<std::vector<float>>> fetch_result_batch;
if (fetch_name.size() == 0) { if (fetch_name.size() == 0) {
return fetch_result_batch; return fetch_result_batch;
...@@ -229,6 +229,8 @@ std::vector<std::vector<std::vector<float>>> PredictorClient::predict_for_batch( ...@@ -229,6 +229,8 @@ std::vector<std::vector<std::vector<float>>> PredictorClient::predict_for_batch(
tensor->add_shape(_shape[idx][j]); tensor->add_shape(_shape[idx][j]);
} }
tensor->set_elem_type(0); tensor->set_elem_type(0);
VLOG(3) << "feed var name " << name << " index " << vec_idx
<< "first data " << int_feed[vec_idx][0];
for (int j = 0; j < int_feed[vec_idx].size(); ++j) { for (int j = 0; j < int_feed[vec_idx].size(); ++j) {
tensor->add_data(const_cast<char *>(reinterpret_cast<const char *>( tensor->add_data(const_cast<char *>(reinterpret_cast<const char *>(
&(int_feed[vec_idx][j]))), &(int_feed[vec_idx][j]))),
...@@ -248,10 +250,13 @@ std::vector<std::vector<std::vector<float>>> PredictorClient::predict_for_batch( ...@@ -248,10 +250,13 @@ std::vector<std::vector<std::vector<float>>> PredictorClient::predict_for_batch(
for (int bi = 0; bi < batch_size; bi++) { for (int bi = 0; bi < batch_size; bi++) {
for (auto &name : fetch_name) { for (auto &name : fetch_name) {
int idx = _fetch_name_to_idx[name]; int idx = _fetch_name_to_idx[name];
int len = res.insts(0).tensor_array(idx).data_size(); int len = res.insts(bi).tensor_array(idx).data_size();
VLOG(3) << "fetch name: " << name; VLOG(3) << "fetch name: " << name;
VLOG(3) << "tensor data size: " << len; VLOG(3) << "tensor data size: " << len;
fetch_result_batch[bi][idx].resize(len); fetch_result_batch[bi][idx].resize(len);
VLOG(3)
<< "fetch name " << name << " index " << idx << " first data "
<< *(const float *)res.insts(bi).tensor_array(idx).data(0).c_str();
for (int i = 0; i < len; ++i) { for (int i = 0; i < len; ++i) {
fetch_result_batch[bi][idx][i] = fetch_result_batch[bi][idx][i] =
*(const float *)res.insts(bi).tensor_array(idx).data(i).c_str(); *(const float *)res.insts(bi).tensor_array(idx).data(i).c_str();
......
...@@ -57,7 +57,7 @@ PYBIND11_MODULE(serving_client, m) { ...@@ -57,7 +57,7 @@ PYBIND11_MODULE(serving_client, m) {
fetch_name); fetch_name);
}) })
.def("predict_for_batch", .def("batch_predict",
[](PredictorClient &self, [](PredictorClient &self,
const std::vector<std::vector<std::vector<float>>> const std::vector<std::vector<std::vector<float>>>
&float_feed_batch, &float_feed_batch,
...@@ -65,14 +65,12 @@ PYBIND11_MODULE(serving_client, m) { ...@@ -65,14 +65,12 @@ PYBIND11_MODULE(serving_client, m) {
const std::vector<std::vector<std::vector<int64_t>>> const std::vector<std::vector<std::vector<int64_t>>>
&int_feed_batch, &int_feed_batch,
const std::vector<std::string> &int_feed_name, const std::vector<std::string> &int_feed_name,
const std::vector<std::string> &fetch_name, const std::vector<std::string> &fetch_name) {
const int64_t &batch_size) { return self.batch_predict(float_feed_batch,
return self.predict_for_batch(float_feed_batch, float_feed_name,
float_feed_name, int_feed_batch,
int_feed_batch, int_feed_name,
int_feed_name, fetch_name);
fetch_name,
batch_size);
}); });
} }
......
...@@ -19,7 +19,7 @@ from multiprocessing import Pool ...@@ -19,7 +19,7 @@ from multiprocessing import Pool
import time import time
def predict_for_batch(batch_size=4): def batch_predict(batch_size=4):
client = Client() client = Client()
client.load_client_config(conf_file) client.load_client_config(conf_file)
client.connect(["127.0.0.1:8010"]) client.connect(["127.0.0.1:8010"])
...@@ -33,7 +33,7 @@ def predict_for_batch(batch_size=4): ...@@ -33,7 +33,7 @@ def predict_for_batch(batch_size=4):
fetch = ["acc", "cost", "prediction"] fetch = ["acc", "cost", "prediction"]
feed_batch.append(feed) feed_batch.append(feed)
if len(feed_batch) == batch_size: if len(feed_batch) == batch_size:
fetch_batch = client.predict_for_batch( fetch_batch = client.batch_predict(
feed_batch=feed_batch, fetch=fetch) feed_batch=feed_batch, fetch=fetch)
for i in range(batch_size): for i in range(batch_size):
print("{} {}".format(fetch_batch[i]["prediction"][1], print("{} {}".format(fetch_batch[i]["prediction"][1],
...@@ -47,4 +47,4 @@ def predict_for_batch(batch_size=4): ...@@ -47,4 +47,4 @@ def predict_for_batch(batch_size=4):
if __name__ == '__main__': if __name__ == '__main__':
conf_file = sys.argv[1] conf_file = sys.argv[1]
batch_size = int(sys.argv[2]) batch_size = int(sys.argv[2])
predict_for_batch(batch_size) batch_predict(batch_size)
...@@ -154,8 +154,7 @@ class Client(object): ...@@ -154,8 +154,7 @@ class Client(object):
return result_map return result_map
def predict_for_batch(self, feed_batch=[], fetch=[]): def batch_predict(self, feed_batch=[], fetch=[]):
batch_size = len(feed_batch)
int_slot_batch = [] int_slot_batch = []
float_slot_batch = [] float_slot_batch = []
int_feed_names = [] int_feed_names = []
...@@ -184,9 +183,9 @@ class Client(object): ...@@ -184,9 +183,9 @@ class Client(object):
if key in self.fetch_names_: if key in self.fetch_names_:
fetch_names.append(key) fetch_names.append(key)
result_batch = self.client_handle_.predict_for_batch( result_batch = self.client_handle_.batch_predict(
float_slot_batch, float_feed_names, int_slot_batch, int_feed_names, float_slot_batch, float_feed_names, int_slot_batch, int_feed_names,
fetch_names, batch_size) fetch_names)
result_map_batch = [] result_map_batch = []
for result in result_batch: for result in result_batch:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册