Unverified commit b8d560a7, authored by MRXLT and committed by GitHub

Merge pull request #186 from MRXLT/general-server-v1

bug fix
@@ -38,26 +38,25 @@ using configure::GeneralModelConfig;
 void PredictorClient::init_gflags(std::vector<std::string> argv) {
   std::call_once(gflags_init_flag, [&]() {
     FLAGS_logtostderr = true;
     argv.insert(argv.begin(), "dummy");
     int argc = argv.size();
     char **arr = new char *[argv.size()];
     std::string line;
     for (size_t i = 0; i < argv.size(); i++) {
       arr[i] = &argv[i][0];
       line += argv[i];
       line += ' ';
     }
     google::ParseCommandLineFlags(&argc, &arr, true);
     VLOG(2) << "Init commandline: " << line;
   });
 }

 int PredictorClient::init(const std::string &conf_file) {
   try {
     GeneralModelConfig model_config;
-    if (configure::read_proto_conf(conf_file.c_str(),
-                                   &model_config) != 0) {
+    if (configure::read_proto_conf(conf_file.c_str(), &model_config) != 0) {
       LOG(ERROR) << "Failed to load general model config"
                  << ", file path: " << conf_file;
       return -1;
@@ -75,26 +74,27 @@ int PredictorClient::init(const std::string &conf_file) {
       VLOG(2) << "feed alias name: " << model_config.feed_var(i).alias_name()
               << " index: " << i;
       std::vector<int> tmp_feed_shape;
-      VLOG(2) << "feed" << "[" << i << "] shape:";
+      VLOG(2) << "feed"
+              << "[" << i << "] shape:";
       for (int j = 0; j < model_config.feed_var(i).shape_size(); ++j) {
         tmp_feed_shape.push_back(model_config.feed_var(i).shape(j));
-        VLOG(2) << "shape[" << j << "]: "
-                << model_config.feed_var(i).shape(j);
+        VLOG(2) << "shape[" << j << "]: " << model_config.feed_var(i).shape(j);
       }
       _type.push_back(model_config.feed_var(i).feed_type());
-      VLOG(2) << "feed" << "[" << i << "] feed type: "
-              << model_config.feed_var(i).feed_type();
+      VLOG(2) << "feed"
+              << "[" << i
+              << "] feed type: " << model_config.feed_var(i).feed_type();
       _shape.push_back(tmp_feed_shape);
     }

     for (int i = 0; i < fetch_var_num; ++i) {
       _fetch_name_to_idx[model_config.fetch_var(i).alias_name()] = i;
-      VLOG(2) << "fetch [" << i << "]" << " alias name: "
-              << model_config.fetch_var(i).alias_name();
+      VLOG(2) << "fetch [" << i << "]"
+              << " alias name: " << model_config.fetch_var(i).alias_name();
       _fetch_name_to_var_name[model_config.fetch_var(i).alias_name()] =
           model_config.fetch_var(i).name();
     }
-  } catch (std::exception& e) {
+  } catch (std::exception &e) {
     LOG(ERROR) << "Failed load general model config" << e.what();
     return -1;
   }
@@ -112,7 +112,7 @@ int PredictorClient::destroy_predictor() {
   _api.destroy();
 }

-int PredictorClient::create_predictor_by_desc(const std::string & sdk_desc) {
+int PredictorClient::create_predictor_by_desc(const std::string &sdk_desc) {
   if (_api.create(sdk_desc) != 0) {
     LOG(ERROR) << "Predictor Creation Failed";
     return -1;
@@ -156,7 +156,7 @@ std::vector<std::vector<float>> PredictorClient::predict(
   VLOG(2) << "fetch name size: " << fetch_name.size();

   Request req;
-  for (auto & name : fetch_name) {
+  for (auto &name : fetch_name) {
     req.add_fetch_var_names(name);
   }
   std::vector<Tensor *> tensor_vec;
@@ -247,7 +247,7 @@ std::vector<std::vector<float>> PredictorClient::predict(
         << "prepro_1:" << preprocess_end << " "
         << "client_infer_0:" << client_infer_start << " "
         << "client_infer_1:" << client_infer_end << " ";

     if (FLAGS_profile_server) {
       int op_num = res.profile_time_size() / 2;
       for (int i = 0; i < op_num; ++i) {
@@ -255,7 +255,7 @@ std::vector<std::vector<float>> PredictorClient::predict(
         oss << "op" << i << "_1:" << res.profile_time(i * 2 + 1) << " ";
       }
     }

     oss << "postpro_0:" << postprocess_start << " ";
     oss << "postpro_1:" << postprocess_end;
@@ -288,7 +288,7 @@ std::vector<std::vector<std::vector<float>>> PredictorClient::batch_predict(
   VLOG(2) << "float feed name size: " << float_feed_name.size();
   VLOG(2) << "int feed name size: " << int_feed_name.size();
   Request req;
-  for (auto & name : fetch_name) {
+  for (auto &name : fetch_name) {
     req.add_fetch_var_names(name);
   }
   //
@@ -324,7 +324,8 @@ std::vector<std::vector<std::vector<float>>> PredictorClient::batch_predict(
       vec_idx++;
     }
-    VLOG(2) << "batch [" << bi << "] " << "float feed value prepared";
+    VLOG(2) << "batch [" << bi << "] "
+            << "float feed value prepared";

     vec_idx = 0;
     for (auto &name : int_feed_name) {
@@ -344,7 +345,8 @@ std::vector<std::vector<std::vector<float>>> PredictorClient::batch_predict(
       vec_idx++;
     }
-    VLOG(2) << "batch [" << bi << "] " << "itn feed value prepared";
+    VLOG(2) << "batch [" << bi << "] "
+            << "itn feed value prepared";
   }

   Response res;
......
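Note on the client hunk above: `PredictorClient::init_gflags` feeds gflags from a `std::vector<std::string>` by prepending a dummy program name and exposing the strings as a mutable `char *` array before calling `google::ParseCommandLineFlags`. Below is a minimal standalone sketch of that same pattern, assuming gflags is installed; the `rpc_timeout_ms` flag and the `InitGflagsFromVector` name are illustrative only and are not part of the serving client.

```cpp
#include <gflags/gflags.h>

#include <iostream>
#include <string>
#include <vector>

// Hypothetical flag used only for this sketch.
DEFINE_int32(rpc_timeout_ms, 2000, "example flag parsed from a string vector");

// Mirror of the init_gflags pattern: prepend a dummy program name,
// expose the strings as a mutable char* array, then hand them to gflags.
void InitGflagsFromVector(std::vector<std::string> argv) {
  argv.insert(argv.begin(), "dummy");  // argv[0] stands in for the program name
  int argc = static_cast<int>(argv.size());
  std::vector<char *> arr(argv.size());
  for (size_t i = 0; i < argv.size(); ++i) {
    arr[i] = &argv[i][0];  // gflags may rearrange these pointers in place
  }
  char **argv_ptr = arr.data();
  google::ParseCommandLineFlags(&argc, &argv_ptr, true);
}

int main() {
  InitGflagsFromVector({"--rpc_timeout_ms=500"});
  std::cout << "rpc_timeout_ms = " << FLAGS_rpc_timeout_ms << std::endl;
  return 0;
}
```

Linking against gflags (for example with `-lgflags`) should be enough to build and run the sketch.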
@@ -38,6 +38,8 @@ struct GeneralBlob {
   int64_t time_stamp[20];
   int p_size = 0;

+  int _batch_size;
+
   void Clear() {
     size_t tensor_count = tensor_vector.size();
     for (size_t ti = 0; ti < tensor_count; ++ti) {
@@ -45,31 +47,21 @@ struct GeneralBlob {
     }
     tensor_vector.clear();
   }

-  int GetBatchSize() const {
-    if (tensor_vector.size() > 0) {
-      if (tensor_vector[0].lod.size() == 1) {
-        return tensor_vector[0].lod[0].size() - 1;
-      } else {
-        return tensor_vector[0].shape[0];
-      }
-    } else {
-      return -1;
-    }
-  }
+  int SetBatchSize(int batch_size) { _batch_size = batch_size; }
+  int GetBatchSize() const { return _batch_size; }

   std::string ShortDebugString() const { return "Not implemented!"; }
 };

-static void AddBlobInfo(GeneralBlob * blob,
-                        int64_t init_value) {
+static void AddBlobInfo(GeneralBlob* blob, int64_t init_value) {
   blob->time_stamp[blob->p_size] = init_value;
   blob->p_size++;
 }

-static void CopyBlobInfo(const GeneralBlob * src,
-                         GeneralBlob * tgt) {
-  memcpy(&(tgt->time_stamp[0]), &(src->time_stamp[0]),
-         src->p_size * sizeof(int64_t));
+static void CopyBlobInfo(const GeneralBlob* src, GeneralBlob* tgt) {
+  memcpy(&(tgt->time_stamp[0]),
+         &(src->time_stamp[0]),
+         src->p_size * sizeof(int64_t));
 }
......
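The substantive change in this header: the batch size is no longer derived on demand from the first tensor (its `lod` offsets for variable-length inputs, otherwise `shape[0]`) but stored once in `_batch_size` and exposed through `SetBatchSize`/`GetBatchSize`. The following self-contained sketch contrasts the two strategies using simplified stand-in types (`FakeTensor`, `MiniBlob`) rather than the framework's real tensor and blob types.

```cpp
#include <cstdint>
#include <iostream>
#include <vector>

// Simplified stand-in for the tensors the old GetBatchSize() inspected:
// only the metadata fields it actually read.
struct FakeTensor {
  std::vector<int> shape;
  std::vector<std::vector<size_t>> lod;  // lod[0] holds instance offsets
};

// Old behaviour: derive the batch size from the first tensor.
int DerivedBatchSize(const std::vector<FakeTensor> &tensors) {
  if (tensors.empty()) return -1;
  if (tensors[0].lod.size() == 1) {
    // lod = {0, 3, 5} describes two variable-length instances.
    return static_cast<int>(tensors[0].lod[0].size()) - 1;
  }
  return tensors[0].shape[0];  // dense tensor: first dimension is the batch
}

// New behaviour: the batch size is recorded once and simply read back.
struct MiniBlob {
  int _batch_size = 0;
  void SetBatchSize(int batch_size) { _batch_size = batch_size; }
  int GetBatchSize() const { return _batch_size; }
};

int main() {
  FakeTensor lod_tensor;
  lod_tensor.lod = {{0, 3, 5}};
  lod_tensor.shape = {5, 1};
  std::cout << "derived: " << DerivedBatchSize({lod_tensor}) << "\n";  // 2

  MiniBlob blob;
  blob.SetBatchSize(2);  // set once by the reader, read by downstream ops
  std::cout << "stored:  " << blob.GetBatchSize() << "\n";             // 2
  return 0;
}
```

With the value stored explicitly, ops whose output tensors have a first dimension different from the request batch size (see the response-op change further down) can still report the correct batch size.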
@@ -12,11 +12,11 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.

+#include "core/general-server/op/general_infer_op.h"
 #include <algorithm>
 #include <iostream>
 #include <memory>
 #include <sstream>
-#include "core/general-server/op/general_infer_op.h"
 #include "core/predictor/framework/infer.h"
 #include "core/predictor/framework/memory.h"
 #include "core/predictor/framework/resource.h"
@@ -36,25 +36,26 @@ using baidu::paddle_serving::predictor::InferManager;
 using baidu::paddle_serving::predictor::PaddleGeneralModelConfig;

 int GeneralInferOp::inference() {
-  const GeneralBlob * input_blob =
-      get_depend_argument<GeneralBlob>(pre_name());
-  GeneralBlob * output_blob = mutable_data<GeneralBlob>();
+  const GeneralBlob *input_blob = get_depend_argument<GeneralBlob>(pre_name());
+  GeneralBlob *output_blob = mutable_data<GeneralBlob>();

   if (!input_blob) {
-    LOG(ERROR) << "Failed mutable depended argument, op:"
-               << pre_name();
+    LOG(ERROR) << "Failed mutable depended argument, op:" << pre_name();
     return -1;
   }

   const TensorVector *in = &input_blob->tensor_vector;
   TensorVector *out = &output_blob->tensor_vector;
   int batch_size = input_blob->GetBatchSize();
+  output_blob->SetBatchSize(batch_size);
   VLOG(2) << "infer batch size: " << batch_size;

   Timer timeline;
   int64_t start = timeline.TimeStampUS();
   timeline.Start();

   if (InferManager::instance().infer(GENERAL_MODEL_NAME, in, out, batch_size)) {
     LOG(ERROR) << "Failed do infer in fluid model: " << GENERAL_MODEL_NAME;
     return -1;
......
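The profiling side of `GeneralBlob` (the `time_stamp[20]` array, `p_size`, `AddBlobInfo`, `CopyBlobInfo`) is an append-only record: each op pushes a start and an end timestamp and copies the accumulated history to its output blob, which is why the client divides `profile_time_size()` by two to recover the number of ops. A simplified standalone sketch of that pattern follows; the names and timestamp values are illustrative, and unlike `CopyBlobInfo` above the sketch also copies the counter so it runs on its own.

```cpp
#include <cstdint>
#include <cstring>
#include <iostream>

// Same fixed-size record shape as GeneralBlob's profiling fields.
struct ProfileRecord {
  int64_t time_stamp[20];
  int p_size = 0;
};

static void AddInfo(ProfileRecord *rec, int64_t value) {
  rec->time_stamp[rec->p_size] = value;  // caller must keep p_size < 20
  rec->p_size++;
}

static void CopyInfo(const ProfileRecord *src, ProfileRecord *tgt) {
  std::memcpy(&(tgt->time_stamp[0]),
              &(src->time_stamp[0]),
              src->p_size * sizeof(int64_t));
  tgt->p_size = src->p_size;  // kept in sync here so the sketch is self-contained
}

int main() {
  ProfileRecord reader, infer;
  AddInfo(&reader, 1000);  // reader op start (made-up microsecond values)
  AddInfo(&reader, 1200);  // reader op end
  CopyInfo(&reader, &infer);
  AddInfo(&infer, 1200);   // infer op start
  AddInfo(&infer, 4200);   // infer op end
  // Timestamps travel in start/end pairs, so two entries describe one op.
  for (int i = 0; i < infer.p_size / 2; ++i) {
    std::cout << "op" << i << ": " << infer.time_stamp[2 * i] << " -> "
              << infer.time_stamp[2 * i + 1] << "\n";
  }
  return 0;
}
```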
@@ -12,12 +12,12 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.

+#include "core/general-server/op/general_reader_op.h"
 #include <algorithm>
 #include <iostream>
 #include <memory>
 #include <sstream>
 #include "core/general-server/op/general_infer_helper.h"
-#include "core/general-server/op/general_reader_op.h"
 #include "core/predictor/framework/infer.h"
 #include "core/predictor/framework/memory.h"
 #include "core/util/include/timer.h"
@@ -75,7 +75,6 @@ int GeneralReaderOp::inference() {
   int batch_size = req->insts_size();
   int input_var_num = 0;

   std::vector<int64_t> elem_type;
   std::vector<int64_t> elem_size;
   std::vector<int64_t> capacity;
@@ -83,6 +82,8 @@ int GeneralReaderOp::inference() {
   GeneralBlob *res = mutable_data<GeneralBlob>();
   TensorVector *out = &res->tensor_vector;

+  res->SetBatchSize(batch_size);
+
   if (!res) {
     LOG(ERROR) << "Failed get op tls reader object output";
   }
......
@@ -12,12 +12,12 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.

+#include "core/general-server/op/general_response_op.h"
 #include <algorithm>
 #include <iostream>
 #include <memory>
 #include <sstream>
 #include "core/general-server/op/general_infer_helper.h"
-#include "core/general-server/op/general_response_op.h"
 #include "core/predictor/framework/infer.h"
 #include "core/predictor/framework/memory.h"
 #include "core/predictor/framework/resource.h"
@@ -37,12 +37,10 @@ using baidu::paddle_serving::predictor::InferManager;
 using baidu::paddle_serving::predictor::PaddleGeneralModelConfig;

 int GeneralResponseOp::inference() {
-  const GeneralBlob *input_blob =
-      get_depend_argument<GeneralBlob>(pre_name());
+  const GeneralBlob *input_blob = get_depend_argument<GeneralBlob>(pre_name());

   if (!input_blob) {
-    LOG(ERROR) << "Failed mutable depended argument, op: "
-               << pre_name();
+    LOG(ERROR) << "Failed mutable depended argument, op: " << pre_name();
     return -1;
   }
@@ -61,7 +59,7 @@ int GeneralResponseOp::inference() {
   VLOG(2) << "start to call load general model_conf op";
   baidu::paddle_serving::predictor::Resource &resource =
       baidu::paddle_serving::predictor::Resource::instance();

   VLOG(2) << "get resource pointer done.";
   std::shared_ptr<PaddleGeneralModelConfig> model_config =
       resource.get_general_model_config();
@@ -73,11 +71,12 @@ int GeneralResponseOp::inference() {
         model_config->_fetch_alias_name_to_index[req->fetch_var_names(i)];
   }

+  // response inst with only fetch_var_names
   Response *res = mutable_data<Response>();
   for (int i = 0; i < batch_size; ++i) {
     FetchInst *fetch_inst = res->add_insts();
-    for (auto & idx : fetch_index) {
+    for (auto &idx : fetch_index) {
       Tensor *tensor = fetch_inst->add_tensor_array();
       // currently only response float tensor or lod_tensor
       tensor->set_elem_type(1);
@@ -87,8 +86,7 @@ int GeneralResponseOp::inference() {
       } else {
         VLOG(2) << "out[" << idx << "] is tensor";
         for (int k = 1; k < in->at(idx).shape.size(); ++k) {
-          VLOG(2) << "shape[" << k - 1 << "]: "
-                  << in->at(idx).shape[k];
+          VLOG(2) << "shape[" << k - 1 << "]: " << in->at(idx).shape[k];
           tensor->add_shape(in->at(idx).shape[k]);
         }
       }
@@ -96,7 +94,7 @@ int GeneralResponseOp::inference() {
   }

   int var_idx = 0;
-  for (auto & idx : fetch_index) {
+  for (auto &idx : fetch_index) {
     float *data_ptr = static_cast<float *>(in->at(idx).data.data());
     int cap = 1;
     for (int j = 1; j < in->at(idx).shape.size(); ++j) {
@@ -104,17 +102,25 @@ int GeneralResponseOp::inference() {
     }
     if (model_config->_is_lod_fetch[idx]) {
       for (int j = 0; j < batch_size; ++j) {
-        for (int k = in->at(idx).lod[0][j];
-             k < in->at(idx).lod[0][j + 1]; k++) {
+        for (int k = in->at(idx).lod[0][j]; k < in->at(idx).lod[0][j + 1];
+             k++) {
           res->mutable_insts(j)->mutable_tensor_array(var_idx)->add_data(
               reinterpret_cast<char *>(&(data_ptr[k])), sizeof(float));
         }
       }
     } else {
-      for (int j = 0; j < batch_size; ++j) {
-        for (int k = j * cap; k < (j + 1) * cap; ++k) {
-          res->mutable_insts(j)->mutable_tensor_array(var_idx)->add_data(
-              reinterpret_cast<char *>(&(data_ptr[k])), sizeof(float));
+      int var_size = in->at(idx).shape[0];
+      if (var_size == batch_size) {
+        for (int j = 0; j < batch_size; ++j) {
+          for (int k = j * cap; k < (j + 1) * cap; ++k) {
+            res->mutable_insts(j)->mutable_tensor_array(var_idx)->add_data(
+                reinterpret_cast<char *>(&(data_ptr[k])), sizeof(float));
+          }
+        }
+      } else {
+        for (int j = 0; j < batch_size; ++j) {
+          res->mutable_insts(j)->mutable_tensor_array(var_idx)->add_data(
+              reinterpret_cast<char *>(&(data_ptr[0])), sizeof(float));
        }
      }
    }
......
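The behavioral fix in the response op above: the packing loop for a dense (non-lod) fetch variable previously assumed its first dimension always equals the batch size. Now, when `shape[0]` does not match the batch size (for example an output already aggregated over the batch), every instance receives the leading element instead of the loop reading past the buffer. Below is a framework-free sketch of that branching; `FlatOutput`, `PackDenseFetch`, and the returned layout are illustrative stand-ins, not the serving framework's types.

```cpp
#include <iostream>
#include <vector>

// Flattened fetch output: the first shape dimension may or may not
// equal the request batch size.
struct FlatOutput {
  std::vector<int> shape;
  std::vector<float> data;
};

// Pack one fetched variable into per-instance slots, mirroring the
// batch-size check added in the response op.
std::vector<std::vector<float>> PackDenseFetch(const FlatOutput &out,
                                               int batch_size) {
  int cap = 1;
  for (size_t j = 1; j < out.shape.size(); ++j) {
    cap *= out.shape[j];  // elements per instance
  }
  std::vector<std::vector<float>> insts(batch_size);
  int var_size = out.shape[0];
  if (var_size == batch_size) {
    // Normal case: instance j owns the slice [j * cap, (j + 1) * cap).
    for (int j = 0; j < batch_size; ++j) {
      insts[j].assign(out.data.begin() + j * cap,
                      out.data.begin() + (j + 1) * cap);
    }
  } else {
    // Shared case: the leading element is replicated to every instance.
    for (int j = 0; j < batch_size; ++j) {
      insts[j].push_back(out.data[0]);
    }
  }
  return insts;
}

int main() {
  FlatOutput per_instance{{2, 3}, {1, 2, 3, 4, 5, 6}};
  FlatOutput shared{{1, 1}, {0.5f}};
  auto a = PackDenseFetch(per_instance, 2);  // {1,2,3} and {4,5,6}
  auto b = PackDenseFetch(shared, 2);        // {0.5} and {0.5}
  std::cout << a[1][0] << " " << b[1][0] << "\n";  // prints "4 0.5"
  return 0;
}
```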
@@ -12,11 +12,11 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.

+#include "core/general-server/op/general_text_response_op.h"
 #include <algorithm>
 #include <iostream>
 #include <memory>
 #include <sstream>
-#include "core/general-server/op/general_text_response_op.h"
 #include "core/predictor/framework/infer.h"
 #include "core/predictor/framework/memory.h"
 #include "core/predictor/framework/resource.h"
@@ -36,12 +36,10 @@ using baidu::paddle_serving::predictor::InferManager;
 using baidu::paddle_serving::predictor::PaddleGeneralModelConfig;

 int GeneralTextResponseOp::inference() {
-  const GeneralBlob *input_blob =
-      get_depend_argument<GeneralBlob>(pre_name());
+  const GeneralBlob *input_blob = get_depend_argument<GeneralBlob>(pre_name());

   if (!input_blob) {
-    LOG(ERROR) << "Failed mutable depended argument, op: "
-               << pre_name();
+    LOG(ERROR) << "Failed mutable depended argument, op: " << pre_name();
     return -1;
   }
@@ -68,13 +66,13 @@ int GeneralTextResponseOp::inference() {
     fetch_index[i] =
         model_config->_fetch_alias_name_to_index[req->fetch_var_names(i)];
   }

   // response inst with only fetch_var_names
   Response *res = mutable_data<Response>();

   for (int i = 0; i < batch_size; ++i) {
     FetchInst *fetch_inst = res->add_insts();
-    for (auto & idx : fetch_index) {
+    for (auto &idx : fetch_index) {
       Tensor *tensor = fetch_inst->add_tensor_array();
       // currently only response float tensor or lod_tensor
       tensor->set_elem_type(1);
@@ -84,8 +82,7 @@ int GeneralTextResponseOp::inference() {
       } else {
         VLOG(2) << "out[" << idx << "] is tensor";
         for (int k = 1; k < in->at(idx).shape.size(); ++k) {
-          VLOG(2) << "shape[" << k - 1 << "]: "
-                  << in->at(idx).shape[k];
+          VLOG(2) << "shape[" << k - 1 << "]: " << in->at(idx).shape[k];
           tensor->add_shape(in->at(idx).shape[k]);
         }
       }
@@ -93,7 +90,7 @@ int GeneralTextResponseOp::inference() {
   }

   int var_idx = 0;
-  for (auto & idx : fetch_index) {
+  for (auto &idx : fetch_index) {
     float *data_ptr = static_cast<float *>(in->at(idx).data.data());
     int cap = 1;
     for (int j = 1; j < in->at(idx).shape.size(); ++j) {
@@ -101,8 +98,8 @@ int GeneralTextResponseOp::inference() {
     }
     if (model_config->_is_lod_fetch[idx]) {
       for (int j = 0; j < batch_size; ++j) {
-        for (int k = in->at(idx).lod[0][j];
-             k < in->at(idx).lod[0][j + 1]; k++) {
+        for (int k = in->at(idx).lod[0][j]; k < in->at(idx).lod[0][j + 1];
+             k++) {
           res->mutable_insts(j)->mutable_tensor_array(var_idx)->add_float_data(
               data_ptr[k]);
         }
@@ -117,10 +114,10 @@ int GeneralTextResponseOp::inference() {
     }
     var_idx++;
   }

   if (req->profile_server()) {
     int64_t end = timeline.TimeStampUS();

     for (int i = 0; i < input_blob->p_size; ++i) {
       res->add_profile_time(input_blob->time_stamp[i]);
     }
......
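Both response ops split lod outputs the same way: `lod[0]` holds instance offsets, so instance `j` owns elements from `lod[0][j]` up to but not including `lod[0][j + 1]`. A small standalone sketch of that slicing; the `SplitByLod` helper is illustrative and not part of the ops.

```cpp
#include <iostream>
#include <vector>

// Variable-length output for a batch: lod0 holds offsets into data,
// so instance j covers data[lod0[j] .. lod0[j + 1]).
std::vector<std::vector<float>> SplitByLod(const std::vector<size_t> &lod0,
                                           const std::vector<float> &data) {
  std::vector<std::vector<float>> insts;
  for (size_t j = 0; j + 1 < lod0.size(); ++j) {
    insts.emplace_back(data.begin() + lod0[j], data.begin() + lod0[j + 1]);
  }
  return insts;
}

int main() {
  // Two instances: the first produced three values, the second two.
  std::vector<size_t> lod0 = {0, 3, 5};
  std::vector<float> data = {0.1f, 0.2f, 0.3f, 0.4f, 0.5f};
  auto insts = SplitByLod(lod0, data);
  for (size_t j = 0; j < insts.size(); ++j) {
    std::cout << "instance " << j << " has " << insts[j].size() << " values\n";
  }
  return 0;
}
```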