提交 93af8d73 编写于 作者: M MRXLT

change batch predict interface

上级 21e1cc08
......@@ -17,6 +17,7 @@
#include <sys/types.h>
#include <unistd.h>
#include <algorithm>
#include <fstream>
#include <map>
#include <string>
......@@ -58,13 +59,12 @@ class PredictorClient {
const std::vector<std::string>& int_feed_name,
const std::vector<std::string>& fetch_name);
std::vector<std::vector<std::vector<float>>> predict_for_batch(
std::vector<std::vector<std::vector<float>>> batch_predict(
const std::vector<std::vector<std::vector<float>>>& float_feed_batch,
const std::vector<std::string>& float_feed_name,
const std::vector<std::vector<std::vector<int64_t>>>& int_feed_batch,
const std::vector<std::string>& int_feed_name,
const std::vector<std::string>& fetch_name,
const int64_t& batch_size);
const std::vector<std::string>& fetch_name);
std::vector<std::vector<float>> predict_with_profile(
const std::vector<std::vector<float>>& float_feed,
......
......@@ -171,13 +171,13 @@ std::vector<std::vector<float>> PredictorClient::predict(
return fetch_result;
}
std::vector<std::vector<std::vector<float>>> PredictorClient::predict_for_batch(
std::vector<std::vector<std::vector<float>>> PredictorClient::batch_predict(
const std::vector<std::vector<std::vector<float>>> &float_feed_batch,
const std::vector<std::string> &float_feed_name,
const std::vector<std::vector<std::vector<int64_t>>> &int_feed_batch,
const std::vector<std::string> &int_feed_name,
const std::vector<std::string> &fetch_name,
const int64_t &batch_size) {
const std::vector<std::string> &fetch_name) {
int batch_size = std::max(float_feed_batch.size(), int_feed_batch.size());
std::vector<std::vector<std::vector<float>>> fetch_result_batch;
if (fetch_name.size() == 0) {
return fetch_result_batch;
......@@ -229,6 +229,8 @@ std::vector<std::vector<std::vector<float>>> PredictorClient::predict_for_batch(
tensor->add_shape(_shape[idx][j]);
}
tensor->set_elem_type(0);
VLOG(3) << "feed var name " << name << " index " << vec_idx
<< "first data " << int_feed[vec_idx][0];
for (int j = 0; j < int_feed[vec_idx].size(); ++j) {
tensor->add_data(const_cast<char *>(reinterpret_cast<const char *>(
&(int_feed[vec_idx][j]))),
......@@ -248,10 +250,13 @@ std::vector<std::vector<std::vector<float>>> PredictorClient::predict_for_batch(
for (int bi = 0; bi < batch_size; bi++) {
for (auto &name : fetch_name) {
int idx = _fetch_name_to_idx[name];
int len = res.insts(0).tensor_array(idx).data_size();
int len = res.insts(bi).tensor_array(idx).data_size();
VLOG(3) << "fetch name: " << name;
VLOG(3) << "tensor data size: " << len;
fetch_result_batch[bi][idx].resize(len);
VLOG(3)
<< "fetch name " << name << " index " << idx << " first data "
<< *(const float *)res.insts(bi).tensor_array(idx).data(0).c_str();
for (int i = 0; i < len; ++i) {
fetch_result_batch[bi][idx][i] =
*(const float *)res.insts(bi).tensor_array(idx).data(i).c_str();
......
......@@ -57,7 +57,7 @@ PYBIND11_MODULE(serving_client, m) {
fetch_name);
})
.def("predict_for_batch",
.def("batch_predict",
[](PredictorClient &self,
const std::vector<std::vector<std::vector<float>>>
&float_feed_batch,
......@@ -65,14 +65,12 @@ PYBIND11_MODULE(serving_client, m) {
const std::vector<std::vector<std::vector<int64_t>>>
&int_feed_batch,
const std::vector<std::string> &int_feed_name,
const std::vector<std::string> &fetch_name,
const int64_t &batch_size) {
return self.predict_for_batch(float_feed_batch,
const std::vector<std::string> &fetch_name) {
return self.batch_predict(float_feed_batch,
float_feed_name,
int_feed_batch,
int_feed_name,
fetch_name,
batch_size);
fetch_name);
});
}
......
......@@ -19,7 +19,7 @@ from multiprocessing import Pool
import time
def predict_for_batch(batch_size=4):
def batch_predict(batch_size=4):
client = Client()
client.load_client_config(conf_file)
client.connect(["127.0.0.1:8010"])
......@@ -33,7 +33,7 @@ def predict_for_batch(batch_size=4):
fetch = ["acc", "cost", "prediction"]
feed_batch.append(feed)
if len(feed_batch) == batch_size:
fetch_batch = client.predict_for_batch(
fetch_batch = client.batch_predict(
feed_batch=feed_batch, fetch=fetch)
for i in range(batch_size):
print("{} {}".format(fetch_batch[i]["prediction"][1],
......@@ -47,4 +47,4 @@ def predict_for_batch(batch_size=4):
if __name__ == '__main__':
conf_file = sys.argv[1]
batch_size = int(sys.argv[2])
predict_for_batch(batch_size)
batch_predict(batch_size)
......@@ -154,8 +154,7 @@ class Client(object):
return result_map
def predict_for_batch(self, feed_batch=[], fetch=[]):
batch_size = len(feed_batch)
def batch_predict(self, feed_batch=[], fetch=[]):
int_slot_batch = []
float_slot_batch = []
int_feed_names = []
......@@ -184,9 +183,9 @@ class Client(object):
if key in self.fetch_names_:
fetch_names.append(key)
result_batch = self.client_handle_.predict_for_batch(
result_batch = self.client_handle_.batch_predict(
float_slot_batch, float_feed_names, int_slot_batch, int_feed_names,
fetch_names, batch_size)
fetch_names)
result_map_batch = []
for result in result_batch:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册