提交 cddc7ff5 编写于 作者: W WangXi

refine numpy predict

上级 f5fd7c9d
......@@ -151,6 +151,10 @@ class PredictorRes {
class PredictorClient {
public:
template <typename T>
using batch_numpy_t = std::vector<
std::vector<py::array_t<T, py::array::c_style | py::array::forcecast>>>;
PredictorClient() {}
~PredictorClient() {}
......@@ -178,16 +182,15 @@ class PredictorClient {
PredictorRes& predict_res_batch, // NOLINT
const int& pid);
int numpy_predict(
const std::vector<std::vector<py::array_t<float>>>& float_feed_batch,
const std::vector<std::string>& float_feed_name,
const std::vector<std::vector<int>>& float_shape,
const std::vector<std::vector<py::array_t<int64_t>>>& int_feed_batch,
const std::vector<std::string>& int_feed_name,
const std::vector<std::vector<int>>& int_shape,
const std::vector<std::string>& fetch_name,
PredictorRes& predict_res_batch, // NOLINT
const int& pid);
int numpy_predict(const batch_numpy_t<float>& float_feed_batch,
const std::vector<std::string>& float_feed_name,
const std::vector<std::vector<int>>& float_shape,
const batch_numpy_t<int64_t>& int_feed_batch,
const std::vector<std::string>& int_feed_name,
const std::vector<std::vector<int>>& int_shape,
const std::vector<std::string>& fetch_name,
PredictorRes& predict_res_batch, // NOLINT
const int& pid);
private:
PredictorApi _api;
......
......@@ -335,10 +335,10 @@ int PredictorClient::batch_predict(
}
int PredictorClient::numpy_predict(
const std::vector<std::vector<py::array_t<float>>> &float_feed_batch,
const PredictorClient::batch_numpy_t<float> &float_feed_batch,
const std::vector<std::string> &float_feed_name,
const std::vector<std::vector<int>> &float_shape,
const std::vector<std::vector<py::array_t<int64_t>>> &int_feed_batch,
const PredictorClient::batch_numpy_t<int64_t> &int_feed_batch,
const std::vector<std::string> &int_feed_name,
const std::vector<std::vector<int>> &int_shape,
const std::vector<std::string> &fetch_name,
......@@ -369,8 +369,8 @@ int PredictorClient::numpy_predict(
VLOG(2) << "prepare batch " << bi;
std::vector<Tensor *> tensor_vec;
FeedInst *inst = req.add_insts();
std::vector<py::array_t<float>> float_feed = float_feed_batch[bi];
std::vector<py::array_t<int64_t>> int_feed = int_feed_batch[bi];
auto float_feed = float_feed_batch[bi];
auto int_feed = int_feed_batch[bi];
for (auto &name : float_feed_name) {
tensor_vec.push_back(inst->add_tensor_array());
}
......@@ -390,52 +390,24 @@ int PredictorClient::numpy_predict(
Tensor *tensor = tensor_vec[idx];
VLOG(2) << "prepare float feed " << name << " shape size "
<< float_shape[vec_idx].size();
auto size = 1;
for (uint32_t j = 0; j < float_shape[vec_idx].size(); ++j) {
tensor->add_shape(float_shape[vec_idx][j]);
size *= float_shape[vec_idx][j];
}
tensor->set_elem_type(1);
const int float_shape_size = float_shape[vec_idx].size();
switch (float_shape_size) {
case 4: {
auto float_array = float_feed[vec_idx].unchecked<4>();
for (ssize_t i = 0; i < float_array.shape(0); i++) {
for (ssize_t j = 0; j < float_array.shape(1); j++) {
for (ssize_t k = 0; k < float_array.shape(2); k++) {
for (ssize_t l = 0; l < float_array.shape(3); l++) {
tensor->add_float_data(float_array(i, j, k, l));
}
}
}
}
break;
}
case 3: {
auto float_array = float_feed[vec_idx].unchecked<3>();
for (ssize_t i = 0; i < float_array.shape(0); i++) {
for (ssize_t j = 0; j < float_array.shape(1); j++) {
for (ssize_t k = 0; k < float_array.shape(2); k++) {
tensor->add_float_data(float_array(i, j, k));
}
}
}
break;
}
case 2: {
auto float_array = float_feed[vec_idx].unchecked<2>();
for (ssize_t i = 0; i < float_array.shape(0); i++) {
for (ssize_t j = 0; j < float_array.shape(1); j++) {
tensor->add_float_data(float_array(i, j));
}
}
break;
}
case 1: {
auto float_array = float_feed[vec_idx].unchecked<1>();
for (ssize_t i = 0; i < float_array.shape(0); i++) {
tensor->add_float_data(float_array(i));
}
break;
}
auto float_buf = float_feed[vec_idx].request();
const float *float_ptr = static_cast<const float *>(float_buf.ptr);
if (size != float_buf.size) {
LOG(ERROR) << "feed size=" << size << "!= buf_size=" << float_buf.size;
return -1;
}
for (auto i = 0; i < float_buf.size; ++i) {
tensor->add_float_data(float_ptr[i]);
}
vec_idx++;
}
......@@ -449,53 +421,23 @@ int PredictorClient::numpy_predict(
Tensor *tensor = tensor_vec[idx];
VLOG(2) << "prepare int feed " << name << " shape size "
<< int_shape[vec_idx].size();
auto size = 1;
for (uint32_t j = 0; j < int_shape[vec_idx].size(); ++j) {
tensor->add_shape(int_shape[vec_idx][j]);
size *= int_shape[vec_idx][j];
}
tensor->set_elem_type(0);
const int int_shape_size = int_shape[vec_idx].size();
switch (int_shape_size) {
case 4: {
auto int_array = int_feed[vec_idx].unchecked<4>();
for (ssize_t i = 0; i < int_array.shape(0); i++) {
for (ssize_t j = 0; j < int_array.shape(1); j++) {
for (ssize_t k = 0; k < int_array.shape(2); k++) {
for (ssize_t l = 0; k < int_array.shape(3); l++) {
tensor->add_float_data(int_array(i, j, k, l));
}
}
}
}
break;
}
case 3: {
auto int_array = int_feed[vec_idx].unchecked<3>();
for (ssize_t i = 0; i < int_array.shape(0); i++) {
for (ssize_t j = 0; j < int_array.shape(1); j++) {
for (ssize_t k = 0; k < int_array.shape(2); k++) {
tensor->add_float_data(int_array(i, j, k));
}
}
}
break;
}
case 2: {
auto int_array = int_feed[vec_idx].unchecked<2>();
for (ssize_t i = 0; i < int_array.shape(0); i++) {
for (ssize_t j = 0; j < int_array.shape(1); j++) {
tensor->add_float_data(int_array(i, j));
}
}
break;
}
case 1: {
auto int_array = int_feed[vec_idx].unchecked<1>();
for (ssize_t i = 0; i < int_array.shape(0); i++) {
tensor->add_float_data(int_array(i));
}
break;
}
auto int_buf = int_feed[vec_idx].request();
const int64_t *int_ptr = static_cast<const int64_t *>(int_buf.ptr);
if (size != int_buf.size) {
LOG(ERROR) << "feed size=" << size << "!= buf_size=" << int_buf.size;
return -1;
}
for (auto i = 0; i < int_buf.size; ++i) {
tensor->add_int64_data(int_ptr[i]);
}
vec_idx++;
}
......
......@@ -103,12 +103,10 @@ PYBIND11_MODULE(serving_client, m) {
})
.def("numpy_predict",
[](PredictorClient &self,
const std::vector<std::vector<py::array_t<float>>>
&float_feed_batch,
const PredictorClient::batch_numpy_t<float> &float_feed_batch,
const std::vector<std::string> &float_feed_name,
const std::vector<std::vector<int>> &float_shape,
const std::vector<std::vector<py::array_t<int64_t>>>
&int_feed_batch,
const PredictorClient::batch_numpy_t<int64_t> &int_feed_batch,
const std::vector<std::string> &int_feed_name,
const std::vector<std::vector<int>> &int_shape,
const std::vector<std::string> &fetch_name,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册