提交 cf48c467 编写于 作者: M MRXLT 提交者: GitHub

Merge pull request #186 from MRXLT/general-server-v1

bug fix
......@@ -56,8 +56,7 @@ void PredictorClient::init_gflags(std::vector<std::string> argv) {
int PredictorClient::init(const std::string &conf_file) {
try {
GeneralModelConfig model_config;
if (configure::read_proto_conf(conf_file.c_str(),
&model_config) != 0) {
if (configure::read_proto_conf(conf_file.c_str(), &model_config) != 0) {
LOG(ERROR) << "Failed to load general model config"
<< ", file path: " << conf_file;
return -1;
......@@ -75,26 +74,27 @@ int PredictorClient::init(const std::string &conf_file) {
VLOG(2) << "feed alias name: " << model_config.feed_var(i).alias_name()
<< " index: " << i;
std::vector<int> tmp_feed_shape;
VLOG(2) << "feed" << "[" << i << "] shape:";
VLOG(2) << "feed"
<< "[" << i << "] shape:";
for (int j = 0; j < model_config.feed_var(i).shape_size(); ++j) {
tmp_feed_shape.push_back(model_config.feed_var(i).shape(j));
VLOG(2) << "shape[" << j << "]: "
<< model_config.feed_var(i).shape(j);
VLOG(2) << "shape[" << j << "]: " << model_config.feed_var(i).shape(j);
}
_type.push_back(model_config.feed_var(i).feed_type());
VLOG(2) << "feed" << "[" << i << "] feed type: "
<< model_config.feed_var(i).feed_type();
VLOG(2) << "feed"
<< "[" << i
<< "] feed type: " << model_config.feed_var(i).feed_type();
_shape.push_back(tmp_feed_shape);
}
for (int i = 0; i < fetch_var_num; ++i) {
_fetch_name_to_idx[model_config.fetch_var(i).alias_name()] = i;
VLOG(2) << "fetch [" << i << "]" << " alias name: "
<< model_config.fetch_var(i).alias_name();
VLOG(2) << "fetch [" << i << "]"
<< " alias name: " << model_config.fetch_var(i).alias_name();
_fetch_name_to_var_name[model_config.fetch_var(i).alias_name()] =
model_config.fetch_var(i).name();
}
} catch (std::exception& e) {
} catch (std::exception &e) {
LOG(ERROR) << "Failed load general model config" << e.what();
return -1;
}
......@@ -112,7 +112,7 @@ int PredictorClient::destroy_predictor() {
_api.destroy();
}
int PredictorClient::create_predictor_by_desc(const std::string & sdk_desc) {
int PredictorClient::create_predictor_by_desc(const std::string &sdk_desc) {
if (_api.create(sdk_desc) != 0) {
LOG(ERROR) << "Predictor Creation Failed";
return -1;
......@@ -156,7 +156,7 @@ std::vector<std::vector<float>> PredictorClient::predict(
VLOG(2) << "fetch name size: " << fetch_name.size();
Request req;
for (auto & name : fetch_name) {
for (auto &name : fetch_name) {
req.add_fetch_var_names(name);
}
std::vector<Tensor *> tensor_vec;
......@@ -288,7 +288,7 @@ std::vector<std::vector<std::vector<float>>> PredictorClient::batch_predict(
VLOG(2) << "float feed name size: " << float_feed_name.size();
VLOG(2) << "int feed name size: " << int_feed_name.size();
Request req;
for (auto & name : fetch_name) {
for (auto &name : fetch_name) {
req.add_fetch_var_names(name);
}
//
......@@ -324,7 +324,8 @@ std::vector<std::vector<std::vector<float>>> PredictorClient::batch_predict(
vec_idx++;
}
VLOG(2) << "batch [" << bi << "] " << "float feed value prepared";
VLOG(2) << "batch [" << bi << "] "
<< "float feed value prepared";
vec_idx = 0;
for (auto &name : int_feed_name) {
......@@ -344,7 +345,8 @@ std::vector<std::vector<std::vector<float>>> PredictorClient::batch_predict(
vec_idx++;
}
VLOG(2) << "batch [" << bi << "] " << "itn feed value prepared";
VLOG(2) << "batch [" << bi << "] "
<< "itn feed value prepared";
}
Response res;
......
......@@ -38,6 +38,8 @@ struct GeneralBlob {
int64_t time_stamp[20];
int p_size = 0;
int _batch_size;
void Clear() {
size_t tensor_count = tensor_vector.size();
for (size_t ti = 0; ti < tensor_count; ++ti) {
......@@ -46,30 +48,20 @@ struct GeneralBlob {
tensor_vector.clear();
}
int GetBatchSize() const {
if (tensor_vector.size() > 0) {
if (tensor_vector[0].lod.size() == 1) {
return tensor_vector[0].lod[0].size() - 1;
} else {
return tensor_vector[0].shape[0];
}
} else {
return -1;
}
}
int SetBatchSize(int batch_size) { _batch_size = batch_size; }
int GetBatchSize() const { return _batch_size; }
std::string ShortDebugString() const { return "Not implemented!"; }
};
static void AddBlobInfo(GeneralBlob * blob,
int64_t init_value) {
static void AddBlobInfo(GeneralBlob* blob, int64_t init_value) {
blob->time_stamp[blob->p_size] = init_value;
blob->p_size++;
}
static void CopyBlobInfo(const GeneralBlob * src,
GeneralBlob * tgt) {
memcpy(&(tgt->time_stamp[0]), &(src->time_stamp[0]),
static void CopyBlobInfo(const GeneralBlob* src, GeneralBlob* tgt) {
memcpy(&(tgt->time_stamp[0]),
&(src->time_stamp[0]),
src->p_size * sizeof(int64_t));
}
......
......@@ -12,11 +12,11 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "core/general-server/op/general_infer_op.h"
#include <algorithm>
#include <iostream>
#include <memory>
#include <sstream>
#include "core/general-server/op/general_infer_op.h"
#include "core/predictor/framework/infer.h"
#include "core/predictor/framework/memory.h"
#include "core/predictor/framework/resource.h"
......@@ -36,25 +36,26 @@ using baidu::paddle_serving::predictor::InferManager;
using baidu::paddle_serving::predictor::PaddleGeneralModelConfig;
int GeneralInferOp::inference() {
const GeneralBlob * input_blob =
get_depend_argument<GeneralBlob>(pre_name());
const GeneralBlob *input_blob = get_depend_argument<GeneralBlob>(pre_name());
GeneralBlob * output_blob = mutable_data<GeneralBlob>();
GeneralBlob *output_blob = mutable_data<GeneralBlob>();
if (!input_blob) {
LOG(ERROR) << "Failed mutable depended argument, op:"
<< pre_name();
LOG(ERROR) << "Failed mutable depended argument, op:" << pre_name();
return -1;
}
const TensorVector *in = &input_blob->tensor_vector;
TensorVector *out = &output_blob->tensor_vector;
int batch_size = input_blob->GetBatchSize();
output_blob->SetBatchSize(batch_size);
VLOG(2) << "infer batch size: " << batch_size;
Timer timeline;
int64_t start = timeline.TimeStampUS();
timeline.Start();
if (InferManager::instance().infer(GENERAL_MODEL_NAME, in, out, batch_size)) {
LOG(ERROR) << "Failed do infer in fluid model: " << GENERAL_MODEL_NAME;
return -1;
......
......@@ -12,12 +12,12 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "core/general-server/op/general_reader_op.h"
#include <algorithm>
#include <iostream>
#include <memory>
#include <sstream>
#include "core/general-server/op/general_infer_helper.h"
#include "core/general-server/op/general_reader_op.h"
#include "core/predictor/framework/infer.h"
#include "core/predictor/framework/memory.h"
#include "core/util/include/timer.h"
......@@ -75,7 +75,6 @@ int GeneralReaderOp::inference() {
int batch_size = req->insts_size();
int input_var_num = 0;
std::vector<int64_t> elem_type;
std::vector<int64_t> elem_size;
std::vector<int64_t> capacity;
......@@ -83,6 +82,8 @@ int GeneralReaderOp::inference() {
GeneralBlob *res = mutable_data<GeneralBlob>();
TensorVector *out = &res->tensor_vector;
res->SetBatchSize(batch_size);
if (!res) {
LOG(ERROR) << "Failed get op tls reader object output";
}
......
......@@ -12,12 +12,12 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "core/general-server/op/general_response_op.h"
#include <algorithm>
#include <iostream>
#include <memory>
#include <sstream>
#include "core/general-server/op/general_infer_helper.h"
#include "core/general-server/op/general_response_op.h"
#include "core/predictor/framework/infer.h"
#include "core/predictor/framework/memory.h"
#include "core/predictor/framework/resource.h"
......@@ -37,12 +37,10 @@ using baidu::paddle_serving::predictor::InferManager;
using baidu::paddle_serving::predictor::PaddleGeneralModelConfig;
int GeneralResponseOp::inference() {
const GeneralBlob *input_blob =
get_depend_argument<GeneralBlob>(pre_name());
const GeneralBlob *input_blob = get_depend_argument<GeneralBlob>(pre_name());
if (!input_blob) {
LOG(ERROR) << "Failed mutable depended argument, op: "
<< pre_name();
LOG(ERROR) << "Failed mutable depended argument, op: " << pre_name();
return -1;
}
......@@ -73,11 +71,12 @@ int GeneralResponseOp::inference() {
model_config->_fetch_alias_name_to_index[req->fetch_var_names(i)];
}
// response inst with only fetch_var_names
Response *res = mutable_data<Response>();
for (int i = 0; i < batch_size; ++i) {
FetchInst *fetch_inst = res->add_insts();
for (auto & idx : fetch_index) {
for (auto &idx : fetch_index) {
Tensor *tensor = fetch_inst->add_tensor_array();
// currently only response float tensor or lod_tensor
tensor->set_elem_type(1);
......@@ -87,8 +86,7 @@ int GeneralResponseOp::inference() {
} else {
VLOG(2) << "out[" << idx << "] is tensor";
for (int k = 1; k < in->at(idx).shape.size(); ++k) {
VLOG(2) << "shape[" << k - 1 << "]: "
<< in->at(idx).shape[k];
VLOG(2) << "shape[" << k - 1 << "]: " << in->at(idx).shape[k];
tensor->add_shape(in->at(idx).shape[k]);
}
}
......@@ -96,7 +94,7 @@ int GeneralResponseOp::inference() {
}
int var_idx = 0;
for (auto & idx : fetch_index) {
for (auto &idx : fetch_index) {
float *data_ptr = static_cast<float *>(in->at(idx).data.data());
int cap = 1;
for (int j = 1; j < in->at(idx).shape.size(); ++j) {
......@@ -104,19 +102,27 @@ int GeneralResponseOp::inference() {
}
if (model_config->_is_lod_fetch[idx]) {
for (int j = 0; j < batch_size; ++j) {
for (int k = in->at(idx).lod[0][j];
k < in->at(idx).lod[0][j + 1]; k++) {
for (int k = in->at(idx).lod[0][j]; k < in->at(idx).lod[0][j + 1];
k++) {
res->mutable_insts(j)->mutable_tensor_array(var_idx)->add_data(
reinterpret_cast<char *>(&(data_ptr[k])), sizeof(float));
}
}
} else {
int var_size = in->at(idx).shape[0];
if (var_size == batch_size) {
for (int j = 0; j < batch_size; ++j) {
for (int k = j * cap; k < (j + 1) * cap; ++k) {
res->mutable_insts(j)->mutable_tensor_array(var_idx)->add_data(
reinterpret_cast<char *>(&(data_ptr[k])), sizeof(float));
}
}
} else {
for (int j = 0; j < batch_size; ++j) {
res->mutable_insts(j)->mutable_tensor_array(var_idx)->add_data(
reinterpret_cast<char *>(&(data_ptr[0])), sizeof(float));
}
}
}
var_idx++;
}
......
......@@ -12,11 +12,11 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "core/general-server/op/general_text_response_op.h"
#include <algorithm>
#include <iostream>
#include <memory>
#include <sstream>
#include "core/general-server/op/general_text_response_op.h"
#include "core/predictor/framework/infer.h"
#include "core/predictor/framework/memory.h"
#include "core/predictor/framework/resource.h"
......@@ -36,12 +36,10 @@ using baidu::paddle_serving::predictor::InferManager;
using baidu::paddle_serving::predictor::PaddleGeneralModelConfig;
int GeneralTextResponseOp::inference() {
const GeneralBlob *input_blob =
get_depend_argument<GeneralBlob>(pre_name());
const GeneralBlob *input_blob = get_depend_argument<GeneralBlob>(pre_name());
if (!input_blob) {
LOG(ERROR) << "Failed mutable depended argument, op: "
<< pre_name();
LOG(ERROR) << "Failed mutable depended argument, op: " << pre_name();
return -1;
}
......@@ -74,7 +72,7 @@ int GeneralTextResponseOp::inference() {
for (int i = 0; i < batch_size; ++i) {
FetchInst *fetch_inst = res->add_insts();
for (auto & idx : fetch_index) {
for (auto &idx : fetch_index) {
Tensor *tensor = fetch_inst->add_tensor_array();
// currently only response float tensor or lod_tensor
tensor->set_elem_type(1);
......@@ -84,8 +82,7 @@ int GeneralTextResponseOp::inference() {
} else {
VLOG(2) << "out[" << idx << "] is tensor";
for (int k = 1; k < in->at(idx).shape.size(); ++k) {
VLOG(2) << "shape[" << k - 1 << "]: "
<< in->at(idx).shape[k];
VLOG(2) << "shape[" << k - 1 << "]: " << in->at(idx).shape[k];
tensor->add_shape(in->at(idx).shape[k]);
}
}
......@@ -93,7 +90,7 @@ int GeneralTextResponseOp::inference() {
}
int var_idx = 0;
for (auto & idx : fetch_index) {
for (auto &idx : fetch_index) {
float *data_ptr = static_cast<float *>(in->at(idx).data.data());
int cap = 1;
for (int j = 1; j < in->at(idx).shape.size(); ++j) {
......@@ -101,8 +98,8 @@ int GeneralTextResponseOp::inference() {
}
if (model_config->_is_lod_fetch[idx]) {
for (int j = 0; j < batch_size; ++j) {
for (int k = in->at(idx).lod[0][j];
k < in->at(idx).lod[0][j + 1]; k++) {
for (int k = in->at(idx).lod[0][j]; k < in->at(idx).lod[0][j + 1];
k++) {
res->mutable_insts(j)->mutable_tensor_array(var_idx)->add_float_data(
data_ptr[k]);
}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册