提交 2548edd4 编写于 作者: G guru4elephant

send and recv through int64 and float value

上级 7fd482b3
......@@ -31,7 +31,7 @@ message( "WITH_GPU = ${WITH_GPU}")
# Paddle Version should be one of:
# latest: latest develop build
# version number like 1.5.2
SET(PADDLE_VERSION "latest")
SET(PADDLE_VERSION "1.6.3")
if (WITH_GPU)
SET(PADDLE_LIB_VERSION "${PADDLE_VERSION}-gpu-cuda${CUDA_VERSION_MAJOR}-cudnn7-avx-mkl")
......
......@@ -26,7 +26,8 @@ message FetchVar {
optional string name = 1;
optional string alias_name = 2;
optional bool is_lod_tensor = 3 [ default = false ];
repeated int32 shape = 4;
optional int32 fetch_type = 4 [ default = 0 ];
repeated int32 shape = 5;
}
message GeneralModelConfig {
repeated FeedVar feed_var = 1;
......
......@@ -39,9 +39,25 @@ namespace baidu {
namespace paddle_serving {
namespace general_model {
typedef std::map<std::string, std::vector<float>> FetchedMap;
class PredictorRes {
public:
PredictorRes() {}
~PredictorRes() {}
public:
const std::vector<std::vector<int64_t>> & get_int64_by_name(
const std::string & name) {
return _int64_map[name];
}
const std::vector<std::vector<float>> & get_float_by_name(
const std::string & name) {
return _float_map[name];
}
typedef std::map<std::string, std::vector<std::vector<float>>> BatchFetchedMap;
public:
std::map<std::string, std::vector<std::vector<int64_t>>> _int64_map;
std::map<std::string, std::vector<std::vector<float>>> _float_map;
};
class PredictorClient {
public:
......@@ -60,6 +76,13 @@ class PredictorClient {
int create_predictor();
int destroy_predictor();
int predict(const std::vector<std::vector<float>>& float_feed,
const std::vector<std::string>& float_feed_name,
const std::vector<std::vector<int64_t>>& int_feed,
const std::vector<std::string>& int_feed_name,
const std::vector<std::string>& fetch_name,
PredictorRes & predict_res); // NOLINT
std::vector<std::vector<float>> predict(
const std::vector<std::vector<float>>& float_feed,
const std::vector<std::string>& float_feed_name,
......@@ -74,13 +97,6 @@ class PredictorClient {
const std::vector<std::string>& int_feed_name,
const std::vector<std::string>& fetch_name);
std::vector<std::vector<float>> predict_with_profile(
const std::vector<std::vector<float>>& float_feed,
const std::vector<std::string>& float_feed_name,
const std::vector<std::vector<int64_t>>& int_feed,
const std::vector<std::string>& int_feed_name,
const std::vector<std::string>& fetch_name);
private:
PredictorApi _api;
Predictor* _predictor;
......@@ -90,6 +106,7 @@ class PredictorClient {
std::map<std::string, int> _feed_name_to_idx;
std::map<std::string, int> _fetch_name_to_idx;
std::map<std::string, std::string> _fetch_name_to_var_name;
std::map<std::string, int> _fetch_name_to_type;
std::vector<std::vector<int>> _shape;
std::vector<int> _type;
std::vector<int64_t> _last_request_ts;
......
......@@ -93,6 +93,8 @@ int PredictorClient::init(const std::string &conf_file) {
<< " alias name: " << model_config.fetch_var(i).alias_name();
_fetch_name_to_var_name[model_config.fetch_var(i).alias_name()] =
model_config.fetch_var(i).name();
_fetch_name_to_type[model_config.fetch_var(i).alias_name()] =
model_config.fetch_var(i).fetch_type();
}
} catch (std::exception &e) {
LOG(ERROR) << "Failed load general model config" << e.what();
......@@ -130,35 +132,25 @@ int PredictorClient::create_predictor() {
_api.thrd_initialize();
}
std::vector<std::vector<float>> PredictorClient::predict(
const std::vector<std::vector<float>> &float_feed,
const std::vector<std::string> &float_feed_name,
const std::vector<std::vector<int64_t>> &int_feed,
const std::vector<std::string> &int_feed_name,
const std::vector<std::string> &fetch_name) {
std::vector<std::vector<float>> fetch_result;
if (fetch_name.size() == 0) {
return fetch_result;
}
int PredictorClient::predict(
const std::vector<std::vector<float>>& float_feed,
const std::vector<std::string>& float_feed_name,
const std::vector<std::vector<int64_t>>& int_feed,
const std::vector<std::string>& int_feed_name,
const std::vector<std::string>& fetch_name,
PredictorRes & predict_res) { // NOLINT
predict_res._int64_map.clear();
predict_res._float_map.clear();
Timer timeline;
int64_t preprocess_start = timeline.TimeStampUS();
// we save infer_us at fetch_result[fetch_name.size()]
fetch_result.resize(fetch_name.size());
_api.thrd_clear();
_predictor = _api.fetch_predictor("general_model");
VLOG(2) << "fetch general model predictor done.";
VLOG(2) << "float feed name size: " << float_feed_name.size();
VLOG(2) << "int feed name size: " << int_feed_name.size();
VLOG(2) << "fetch name size: " << fetch_name.size();
Request req;
for (auto &name : fetch_name) {
req.add_fetch_var_names(name);
}
std::vector<Tensor *> tensor_vec;
FeedInst *inst = req.add_insts();
for (auto &name : float_feed_name) {
......@@ -168,7 +160,6 @@ std::vector<std::vector<float>> PredictorClient::predict(
for (auto &name : int_feed_name) {
tensor_vec.push_back(inst->add_tensor_array());
}
VLOG(2) << "prepare tensor vec done.";
int vec_idx = 0;
for (auto &name : float_feed_name) {
......@@ -179,16 +170,14 @@ std::vector<std::vector<float>> PredictorClient::predict(
}
tensor->set_elem_type(1);
for (int j = 0; j < float_feed[vec_idx].size(); ++j) {
tensor->add_data(const_cast<char *>(reinterpret_cast<const char *>(
&(float_feed[vec_idx][j]))),
sizeof(float));
tensor->add_float_data(float_feed[vec_idx][j]);
}
vec_idx++;
}
VLOG(2) << "feed float feed var done.";
vec_idx = 0;
for (auto &name : int_feed_name) {
int idx = _feed_name_to_idx[name];
Tensor *tensor = tensor_vec[idx];
......@@ -197,15 +186,12 @@ std::vector<std::vector<float>> PredictorClient::predict(
}
tensor->set_elem_type(0);
for (int j = 0; j < int_feed[vec_idx].size(); ++j) {
tensor->add_data(const_cast<char *>(reinterpret_cast<const char *>(
&(int_feed[vec_idx][j]))),
sizeof(int64_t));
tensor->add_int64_data(int_feed[vec_idx][j]);
}
vec_idx++;
}
int64_t preprocess_end = timeline.TimeStampUS();
int64_t client_infer_start = timeline.TimeStampUS();
Response res;
......@@ -222,22 +208,33 @@ std::vector<std::vector<float>> PredictorClient::predict(
res.Clear();
if (_predictor->inference(&req, &res) != 0) {
LOG(ERROR) << "failed call predictor with req: " << req.ShortDebugString();
exit(-1);
return -1;
} else {
VLOG(2) << "predict done.";
client_infer_end = timeline.TimeStampUS();
postprocess_start = client_infer_end;
for (auto &name : fetch_name) {
int idx = _fetch_name_to_idx[name];
int len = res.insts(0).tensor_array(idx).data_size();
VLOG(2) << "fetch name: " << name;
VLOG(2) << "tensor data size: " << len;
fetch_result[idx].resize(len);
for (int i = 0; i < len; ++i) {
fetch_result[idx][i] =
*(const float *)res.insts(0).tensor_array(idx).data(i).c_str();
if (_fetch_name_to_type[name] == 0) {
int len = res.insts(0).tensor_array(idx).int64_data_size();
predict_res._int64_map[name].resize(1);
predict_res._int64_map[name][0].resize(len);
for (int i = 0; i < len; ++i) {
predict_res._int64_map[name][0][i] =
res.insts(0).tensor_array(idx).int64_data(i);
}
} else if (_fetch_name_to_type[name] == 1) {
int len = res.insts(0).tensor_array(idx).float_data_size();
predict_res._float_map[name].resize(1);
predict_res._float_map[name][0].resize(len);
for (int i = 0; i < len; ++i) {
predict_res._float_map[name][0][i] =
res.insts(0).tensor_array(idx).float_data(i);
}
}
postprocess_end = timeline.TimeStampUS();
}
postprocess_end = timeline.TimeStampUS();
}
if (FLAGS_profile_client) {
......@@ -247,7 +244,7 @@ std::vector<std::vector<float>> PredictorClient::predict(
<< "prepro_1:" << preprocess_end << " "
<< "client_infer_0:" << client_infer_start << " "
<< "client_infer_1:" << client_infer_end << " ";
if (FLAGS_profile_server) {
int op_num = res.profile_time_size() / 2;
for (int i = 0; i < op_num; ++i) {
......@@ -255,14 +252,13 @@ std::vector<std::vector<float>> PredictorClient::predict(
oss << "op" << i << "_1:" << res.profile_time(i * 2 + 1) << " ";
}
}
oss << "postpro_0:" << postprocess_start << " ";
oss << "postpro_1:" << postprocess_end;
fprintf(stderr, "%s\n", oss.str().c_str());
}
return fetch_result;
return 0;
}
std::vector<std::vector<std::vector<float>>> PredictorClient::batch_predict(
......@@ -321,9 +317,12 @@ std::vector<std::vector<std::vector<float>>> PredictorClient::batch_predict(
}
tensor->set_elem_type(1);
for (int j = 0; j < float_feed[vec_idx].size(); ++j) {
/*
tensor->add_data(const_cast<char *>(reinterpret_cast<const char *>(
&(float_feed[vec_idx][j]))),
sizeof(float));
*/
tensor->add_float_data(float_feed[vec_idx][j]);
}
vec_idx++;
}
......@@ -342,9 +341,12 @@ std::vector<std::vector<std::vector<float>>> PredictorClient::batch_predict(
VLOG(3) << "feed var name " << name << " index " << vec_idx
<< "first data " << int_feed[vec_idx][0];
for (int j = 0; j < int_feed[vec_idx].size(); ++j) {
/*
tensor->add_data(const_cast<char *>(reinterpret_cast<const char *>(
&(int_feed[vec_idx][j]))),
sizeof(int64_t));
*/
tensor->add_int64_data(int_feed[vec_idx][j]);
}
vec_idx++;
}
......@@ -387,10 +389,18 @@ std::vector<std::vector<std::vector<float>>> PredictorClient::batch_predict(
VLOG(2)
<< "fetch name " << name << " index " << idx << " first data "
<< *(const float *)res.insts(bi).tensor_array(idx).data(0).c_str();
for (int i = 0; i < len; ++i) {
fetch_result_batch[bi][idx][i] =
*(const float *)res.insts(bi).tensor_array(idx).data(i).c_str();
/*
if (_fetch_name_to_va[name] == 0) { // int64
for (int i = 0; i < len; ++i) {
fetch_result_batch[bi][idx][i] =
*(const int64 *)res.insts(bi).tensor_array(idx).int64_data(i).c_str();
}
} else {
for (int i = 0; i < len; ++i) {
fetch_result_batch
}
}
*/
}
}
postprocess_end = timeline.TimeStampUS();
......@@ -420,16 +430,6 @@ std::vector<std::vector<std::vector<float>>> PredictorClient::batch_predict(
return fetch_result_batch;
}
std::vector<std::vector<float>> PredictorClient::predict_with_profile(
const std::vector<std::vector<float>> &float_feed,
const std::vector<std::string> &float_feed_name,
const std::vector<std::vector<int64_t>> &int_feed,
const std::vector<std::string> &int_feed_name,
const std::vector<std::string> &fetch_name) {
std::vector<std::vector<float>> res;
return res;
}
} // namespace general_model
} // namespace paddle_serving
} // namespace baidu
......@@ -20,8 +20,6 @@
namespace py = pybind11;
using baidu::paddle_serving::general_model::FetchedMap;
namespace baidu {
namespace paddle_serving {
namespace general_model {
......@@ -29,6 +27,18 @@ namespace general_model {
PYBIND11_MODULE(serving_client, m) {
m.doc() = R"pddoc(this is a practice
)pddoc";
py::class_<PredictorRes>(m, "PredictorRes", py::buffer_protocol())
.def(py::init())
.def("get_int64_by_name",
[](PredictorRes &self, std::string & name) {
return self.get_int64_by_name(name);
}, py::return_value_policy::reference)
.def("get_float_by_name",
[](PredictorRes &self, std::string & name) {
return self.get_float_by_name(name);
}, py::return_value_policy::reference);
py::class_<PredictorClient>(m, "PredictorClient", py::buffer_protocol())
.def(py::init())
.def("init_gflags",
......@@ -52,6 +62,21 @@ PYBIND11_MODULE(serving_client, m) {
[](PredictorClient &self) { self.create_predictor(); })
.def("destroy_predictor",
[](PredictorClient &self) { self.destroy_predictor(); })
.def("predict",
[](PredictorClient &self,
const std::vector<std::vector<float>> &float_feed,
const std::vector<std::string> &float_feed_name,
const std::vector<std::vector<int64_t>> &int_feed,
const std::vector<std::string> &int_feed_name,
const std::vector<std::string> &fetch_name,
PredictorRes & predict_res) {
return self.predict(float_feed,
float_feed_name,
int_feed,
int_feed_name,
fetch_name,
predict_res);
})
.def("predict",
[](PredictorClient &self,
const std::vector<std::vector<float>> &float_feed,
......@@ -65,7 +90,6 @@ PYBIND11_MODULE(serving_client, m) {
int_feed_name,
fetch_name);
})
.def("batch_predict",
[](PredictorClient &self,
const std::vector<std::vector<std::vector<float>>>
......
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "core/general-server/op/general_dist_kv_op.h"
#include <algorithm>
#include <iostream>
#include <memory>
#include <sstream>
#include "core/general-server/op/general_infer_helper.h"
#include "core/predictor/framework/infer.h"
#include "core/predictor/framework/memory.h"
#include "core/util/include/timer.h"
namespace baidu {
namespace paddle_serving {
namespace serving {
using baidu::paddle_serving::Timer;
using baidu::paddle_serving::predictor::MempoolWrapper;
using baidu::paddle_serving::predictor::general_model::Tensor;
using baidu::paddle_serving::predictor::general_model::Request;
using baidu::paddle_serving::predictor::general_model::FeedInst;
using baidu::paddle_serving::predictor::PaddleGeneralModelConfig;
int GeneralDistKVOp::inference() {
// reade request from client
const GeneralBlob *input_blob = get_depend_argument<GeneralBlob>(pre_name());
VLOG(2) << "precedent name: " << pre_name();
const TensorVector *in = &input_blob->tensor_vector;
VLOG(2) << "input size: " << in->size();
int batch_size = input_blob->GetBatchSize();
int input_var_num = 0;
GeneralBlob *res = mutable_data<GeneralBlob>();
TensorVector *out = &res->tensor_vector;
VLOG(2) << "input batch size: " << batch_size;
res->SetBatchSize(batch_size);
if (!res) {
LOG(ERROR) << "Failed get op tls reader object output";
}
Timer timeline;
int64_t start = timeline.TimeStampUS();
VLOG(2) << "Going to init lod tensor";
for (int i = 0; i < in->size(); ++i) {
paddle::PaddleTensor lod_tensor;
CopyLod(&in->at(i), &lod_tensor);
lod_tensor.dtype = in->at(i).dtype;
lod_tensor.name = in->at(i).name;
VLOG(2) << "lod tensor [" << i << "].name = " << lod_tensor.name;
out->push_back(lod_tensor);
}
VLOG(2) << "pack done.";
for (int i = 0; i < out->size(); ++i) {
int64_t *src_ptr = static_cast<int64_t *>(in->at(i).data.data());
out->at(i).data.Resize(
out->at(i).lod[0].back() * sizeof(int64_t));
out->at(i).shape = {out->at(i).lod[0].back(), 1};
int64_t *tgt_ptr = static_cast<int64_t *>(out->at(i).data.data());
for (int j = 0; j < out->at(i).lod[0].back(); ++j) {
tgt_ptr[j] = src_ptr[j];
}
}
VLOG(2) << "output done.";
timeline.Pause();
int64_t end = timeline.TimeStampUS();
res->p_size = 0;
AddBlobInfo(res, start);
AddBlobInfo(res, end);
VLOG(2) << "read data from client success";
return 0;
}
DEFINE_OP(GeneralDistKVOp);
} // namespace serving
} // namespace paddle_serving
} // namespace baidu
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <vector>
#ifdef BCLOUD
#ifdef WITH_GPU
#include "paddle/paddle_inference_api.h"
#else
#include "paddle/fluid/inference/api/paddle_inference_api.h"
#endif
#else
#include "paddle_inference_api.h" // NOLINT
#endif
#include <string>
#include "core/predictor/framework/resource.h"
#include "core/general-server/op/general_infer_helper.h"
#include "core/general-server/general_model_service.pb.h"
namespace baidu {
namespace paddle_serving {
namespace serving {
class GeneralDistKVOp :
public baidu::paddle_serving::predictor::OpWithChannel<GeneralBlob> {
public:
typedef std::vector<paddle::PaddleTensor> TensorVector;
DECLARE_OP(GeneralDistKVOp);
int inference();
};
} // namespace serving
} // namespace paddle_serving
} // namespace baidu
......@@ -65,6 +65,19 @@ static void CopyBlobInfo(const GeneralBlob* src, GeneralBlob* tgt) {
src->p_size * sizeof(int64_t));
}
static void CopyLod(const paddle::PaddleTensor* src,
paddle::PaddleTensor* tgt) {
VLOG(2) << "copy lod done.";
tgt->lod.resize(src->lod.size());
VLOG(2) << "src lod size: " << src->lod.size();
for (int i = 0; i < src->lod.size(); ++i) {
tgt->lod[i].resize(src->lod[i].size());
for (int j = 0; j < src->lod[i].size(); ++j) {
tgt->lod[i][j] = src->lod[i][j];
}
}
}
} // namespace serving
} // namespace paddle_serving
} // namespace baidu
......@@ -104,17 +104,21 @@ int GeneralReaderOp::inference() {
VLOG(2) << "print general model config done.";
// TODO(guru4elephant): how to do conditional check?
/*
int ret = conf_check(req, model_config);
if (ret != 0) {
LOG(INFO) << "model conf of server:";
LOG(ERROR) << "model conf of server:";
resource.print_general_model_config(model_config);
return 0;
}
*/
// package tensor
elem_type.resize(var_num);
elem_size.resize(var_num);
capacity.resize(var_num);
// prepare basic information for input
for (int i = 0; i < var_num; ++i) {
paddle::PaddleTensor lod_tensor;
elem_type[i] = req->insts(0).tensor_array(i).elem_type();
......@@ -146,14 +150,22 @@ int GeneralReaderOp::inference() {
out->push_back(lod_tensor);
}
// specify the memory needed for output tensor_vector
for (int i = 0; i < var_num; ++i) {
if (out->at(i).lod.size() == 1) {
for (int j = 0; j < batch_size; ++j) {
const Tensor &tensor = req->insts(j).tensor_array(i);
int data_len = tensor.data_size();
VLOG(2) << "tensor size for var[" << i << "]: " << tensor.data_size();
int data_len = 0;
if (tensor.int64_data_size() > 0) {
data_len = tensor.int64_data_size();
} else {
data_len = tensor.float_data_size();
}
VLOG(2) << "tensor size for var[" << i << "]: " << data_len;
int cur_len = out->at(i).lod[0].back();
VLOG(2) << "current len: " << cur_len;
out->at(i).lod[0].push_back(cur_len + data_len);
VLOG(2) << "new len: " << cur_len + data_len;
}
......@@ -168,14 +180,16 @@ int GeneralReaderOp::inference() {
}
}
// fill the data into output general_blob
for (int i = 0; i < var_num; ++i) {
if (elem_type[i] == 0) {
int64_t *dst_ptr = static_cast<int64_t *>(out->at(i).data.data());
int offset = 0;
for (int j = 0; j < batch_size; ++j) {
for (int k = 0; k < req->insts(j).tensor_array(i).data_size(); ++k) {
int elem_num = req->insts(j).tensor_array(i).int64_data_size();
for (int k = 0; k < elem_num; ++k) {
dst_ptr[offset + k] =
*(const int64_t *)req->insts(j).tensor_array(i).data(k).c_str();
req->insts(j).tensor_array(i).int64_data(k);
}
if (out->at(i).lod.size() == 1) {
offset = out->at(i).lod[0][j + 1];
......@@ -187,9 +201,10 @@ int GeneralReaderOp::inference() {
float *dst_ptr = static_cast<float *>(out->at(i).data.data());
int offset = 0;
for (int j = 0; j < batch_size; ++j) {
for (int k = 0; k < req->insts(j).tensor_array(i).data_size(); ++k) {
int elem_num = req->insts(j).tensor_array(i).float_data_size();
for (int k = 0; k < elem_num; ++k) {
dst_ptr[offset + k] =
*(const float *)req->insts(j).tensor_array(i).data(k).c_str();
req->insts(j).tensor_array(i).float_data(k);
}
if (out->at(i).lod.size() == 1) {
offset = out->at(i).lod[0][j + 1];
......@@ -200,6 +215,8 @@ int GeneralReaderOp::inference() {
}
}
VLOG(2) << "output size: " << out->size();
timeline.Pause();
int64_t end = timeline.TimeStampUS();
res->p_size = 0;
......
......@@ -95,36 +95,69 @@ int GeneralResponseOp::inference() {
int var_idx = 0;
for (auto &idx : fetch_index) {
float *data_ptr = static_cast<float *>(in->at(idx).data.data());
int cap = 1;
for (int j = 1; j < in->at(idx).shape.size(); ++j) {
cap *= in->at(idx).shape[j];
}
if (model_config->_is_lod_fetch[idx]) {
for (int j = 0; j < batch_size; ++j) {
for (int k = in->at(idx).lod[0][j]; k < in->at(idx).lod[0][j + 1];
k++) {
res->mutable_insts(j)->mutable_tensor_array(var_idx)->add_data(
reinterpret_cast<char *>(&(data_ptr[k])), sizeof(float));
if (in->at(idx).dtype == paddle::PaddleDType::INT64) {
int64_t *data_ptr = static_cast<int64_t *>(in->at(idx).data.data());
if (model_config->_is_lod_fetch[idx]) {
for (int j = 0; j < batch_size; ++j) {
for (int k = in->at(idx).lod[0][j]; k < in->at(idx).lod[0][j + 1];
k++) {
FetchInst *fetch_p = res->mutable_insts(j);
fetch_p->mutable_tensor_array(var_idx)->add_int64_data(data_ptr[k]);
}
}
} else {
int var_size = in->at(idx).shape[0];
if (var_size == batch_size) {
for (int j = 0; j < batch_size; ++j) {
for (int k = j * cap; k < (j + 1) * cap; ++k) {
FetchInst *fetch_p = res->mutable_insts(j);
fetch_p->mutable_tensor_array(var_idx)->add_int64_data(
data_ptr[k]);
}
}
} else {
for (int j = 0; j < batch_size; ++j) {
FetchInst *fetch_p = res->mutable_insts(j);
fetch_p->mutable_tensor_array(var_idx)->add_int64_data(
data_ptr[0]);
}
}
}
} else {
int var_size = in->at(idx).shape[0];
if (var_size == batch_size) {
var_idx++;
} else if (in->at(idx).dtype == paddle::PaddleDType::FLOAT32) {
float *data_ptr = static_cast<float *>(in->at(idx).data.data());
if (model_config->_is_lod_fetch[idx]) {
for (int j = 0; j < batch_size; ++j) {
for (int k = j * cap; k < (j + 1) * cap; ++k) {
res->mutable_insts(j)->mutable_tensor_array(var_idx)->add_data(
reinterpret_cast<char *>(&(data_ptr[k])), sizeof(float));
for (int k = in->at(idx).lod[0][j]; k < in->at(idx).lod[0][j + 1];
k++) {
FetchInst *fetch_p = res->mutable_insts(j);
fetch_p->mutable_tensor_array(var_idx)->add_float_data(data_ptr[k]);
}
}
} else {
for (int j = 0; j < batch_size; ++j) {
res->mutable_insts(j)->mutable_tensor_array(var_idx)->add_data(
reinterpret_cast<char *>(&(data_ptr[0])), sizeof(float));
int var_size = in->at(idx).shape[0];
if (var_size == batch_size) {
for (int j = 0; j < batch_size; ++j) {
for (int k = j * cap; k < (j + 1) * cap; ++k) {
FetchInst * fetch_p = res->mutable_insts(j);
fetch_p->mutable_tensor_array(var_idx)->add_float_data(
data_ptr[k]);
}
}
} else {
for (int j = 0; j < batch_size; ++j) {
FetchInst * fetch_p = res->mutable_insts(j);
fetch_p->mutable_tensor_array(var_idx)->add_float_data(
data_ptr[0]);
}
}
}
var_idx++;
}
var_idx++;
}
if (req->profile_server()) {
......
......@@ -22,9 +22,10 @@ option cc_generic_services = true;
message Tensor {
repeated bytes data = 1;
repeated int32 int_data = 2;
repeated float float_data = 3;
optional int32 elem_type = 4;
repeated int32 shape = 5;
repeated int64 int64_data = 3;
repeated float float_data = 4;
optional int32 elem_type = 5;
repeated int32 shape = 6;
};
message FeedInst {
......
......@@ -20,11 +20,12 @@ package baidu.paddle_serving.predictor.general_model;
option cc_generic_services = true;
message Tensor {
repeated bytes data = 1; // most general format
repeated int32 int_data = 2; // for simple debug only
repeated float float_data = 3; // for simple debug only
optional int32 elem_type = 4; // support int64, float32
repeated int32 shape = 5;
repeated bytes data = 1;
repeated int32 int_data = 2;
repeated int64 int64_data = 3;
repeated float float_data = 4;
optional int32 elem_type = 5;
repeated int32 shape = 6;
};
message FeedInst {
......
......@@ -16,7 +16,7 @@ def ctr_dnn_model_dataset(dense_input, sparse_inputs, label,
return fluid.layers.sequence_pool(input=emb, pool_type='sum')
sparse_embed_seq = list(map(embedding_layer, sparse_inputs))
concated = fluid.layers.concat(sparse_embed_seq + [dense_input], axis=1)
concated = fluid.layers.concat(sparse_embed_seq, axis=1)
fc1 = fluid.layers.fc(input=concated, size=400, act='relu',
param_attr=fluid.ParamAttr(initializer=fluid.initializer.Normal(
scale=1 / math.sqrt(concated.shape[1]))))
......
......@@ -21,12 +21,8 @@ label_list = []
prob_list = []
for data in reader():
feed_dict = {}
feed_dict["dense_0"] = data[0][0]
for i in range(1, 27):
feed_dict["sparse_{}".format(i - 1)] = data[0][i]
feed_dict["label"] = data[0][-1]
fetch_map = client.predict(feed=feed_dict, fetch=["prob"])
prob_list.append(fetch_map["prob"][0])
label_list.append(data[0][-1][0])
print(fetch_map)
print(auc(prob_list, label_list))
......@@ -6,11 +6,15 @@ from paddle_serving_server import Server
op_maker = OpMaker()
read_op = op_maker.create('general_reader')
dist_op = op_maker.create('general_dist_kv')
general_infer_op = op_maker.create('general_infer')
response_op = op_maker.create('general_response')
op_seq_maker = OpSeqMaker()
op_seq_maker.add_op(read_op)
op_seq_maker.add_op(dist_op)
op_seq_maker.add_op(general_infer_op)
op_seq_maker.add_op(response_op)
server = Server()
server.set_op_sequence(op_seq_maker.get_op_sequence())
......
......@@ -3,7 +3,7 @@ import sys
client = Client()
client.load_client_config(sys.argv[1])
client.connect(["127.0.0.1:9292"])
client.connect(["127.0.0.1:9393"])
for line in sys.stdin:
group = line.strip().split()
......
......@@ -73,6 +73,7 @@ class Client(object):
self.feed_names_ = []
self.fetch_names_ = []
self.client_handle_ = None
self.result_handle_ = None
self.feed_shapes_ = []
self.feed_types_ = {}
self.feed_names_to_idx_ = {}
......@@ -87,6 +88,7 @@ class Client(object):
def load_client_config(self, path):
from .serving_client import PredictorClient
from .serving_client import PredictorRes
model_conf = m_config.GeneralModelConfig()
f = open(path, 'r')
model_conf = google.protobuf.text_format.Merge(
......@@ -96,6 +98,7 @@ class Client(object):
# get feed vars, fetch vars
# get feed shapes, feed types
# map feed names to index
self.result_handle_ = PredictorRes()
self.client_handle_ = PredictorClient()
self.client_handle_.init(path)
read_env_flags = ["profile_client", "profile_server"]
......@@ -120,6 +123,7 @@ class Client(object):
sdk_desc = predictor_sdk.gen_desc()
self.client_handle_.create_predictor_by_desc(sdk_desc.SerializeToString(
))
def get_feed_names(self):
return self.feed_names_
......@@ -147,15 +151,19 @@ class Client(object):
if key in self.fetch_names_:
fetch_names.append(key)
'''
result = self.client_handle_.predict(
float_slot, float_feed_names, int_slot, int_feed_names, fetch_names)
'''
ret = self.client_handle_.predict(
float_slot, float_feed_names, int_slot, int_feed_names, fetch_names, self.result_handle_)
# TODO(guru4elephant): the order of fetch var name should be consistent with
# general_model_config, this is not friendly
# In the future, we need make the number of fetched variable changable
result_map = {}
for i, name in enumerate(fetch_names):
result_map[name] = result[i]
result_map[name] = self.result_handle_.get_float_by_name(name)
return result_map
......
......@@ -62,6 +62,13 @@ def save_model(server_model_folder,
fetch_var.alias_name = key
fetch_var.name = fetch_var_dict[key].name
fetch_var.is_lod_tensor = fetch_var_dict[key].lod_level >= 1
if fetch_var_dict[key].dtype == core.VarDesc.VarType.INT32 or \
fetch_var_dict[key].dtype == core.VarDesc.VarType.INT64:
fetch_var.fetch_type = 0
if fetch_var_dict[key].dtype == core.VarDesc.VarType.FP32:
fetch_var.fetch_type = 1
if fetch_var.is_lod_tensor:
fetch_var.shape.extend([-1])
else:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册