提交 d986670e 编写于 作者: M MRXLT

fix batch size

上级 4d435db4
......@@ -37,6 +37,8 @@ struct GeneralBlob {
double infer_time;
std::vector<std::string> fetch_name_vector;
int _batch_size;
void Clear() {
size_t tensor_count = tensor_vector.size();
for (size_t ti = 0; ti < tensor_count; ++ti) {
......@@ -45,6 +47,10 @@ struct GeneralBlob {
tensor_vector.clear();
}
int SetBatchSize(int batch_size) { _batch_size = batch_size; }
int GetBatchSize() const { return _batch_size; }
/*
int GetBatchSize() const {
if (tensor_vector.size() > 0) {
if (tensor_vector[0].lod.size() == 1) {
......@@ -56,7 +62,7 @@ struct GeneralBlob {
return -1;
}
}
*/
std::string ShortDebugString() const { return "Not implemented!"; }
};
......
......@@ -12,11 +12,11 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "core/general-server/op/general_infer_op.h"
#include <algorithm>
#include <iostream>
#include <memory>
#include <sstream>
#include "core/general-server/op/general_infer_op.h"
#include "core/predictor/framework/infer.h"
#include "core/predictor/framework/memory.h"
#include "core/predictor/framework/resource.h"
......@@ -36,33 +36,31 @@ using baidu::paddle_serving::predictor::InferManager;
using baidu::paddle_serving::predictor::PaddleGeneralModelConfig;
int GeneralInferOp::inference() {
const GeneralBlob * input_blob =
get_depend_argument<GeneralBlob>(pre_name());
const GeneralBlob *input_blob = get_depend_argument<GeneralBlob>(pre_name());
GeneralBlob * output_blob = mutable_data<GeneralBlob>();
GeneralBlob *output_blob = mutable_data<GeneralBlob>();
if (!input_blob) {
LOG(ERROR) << "Failed mutable depended argument, op:"
<< pre_name();
LOG(ERROR) << "Failed mutable depended argument, op:" << pre_name();
return -1;
}
const TensorVector *in = &input_blob->tensor_vector;
TensorVector *out = &output_blob->tensor_vector;
int batch_size = input_blob->GetBatchSize();
output_blob->SetBatchSize(batch_size);
VLOG(2) << "infer batch size: " << batch_size;
// infer
Timer timeline;
double infer_time = 0.0;
timeline.Start();
// Timer timeline;
// double infer_time = 0.0;
// timeline.Start();
if (InferManager::instance().infer(GENERAL_MODEL_NAME, in, out, batch_size)) {
LOG(ERROR) << "Failed do infer in fluid model: " << GENERAL_MODEL_NAME;
return -1;
}
timeline.Pause();
infer_time = timeline.ElapsedUS();
// timeline.Pause();
// infer_time = timeline.ElapsedUS();
return 0;
}
DEFINE_OP(GeneralInferOp);
......
......@@ -12,12 +12,12 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "core/general-server/op/general_reader_op.h"
#include <algorithm>
#include <iostream>
#include <memory>
#include <sstream>
#include "core/general-server/op/general_infer_helper.h"
#include "core/general-server/op/general_reader_op.h"
#include "core/predictor/framework/infer.h"
#include "core/predictor/framework/memory.h"
......@@ -73,7 +73,6 @@ int GeneralReaderOp::inference() {
int batch_size = req->insts_size();
int input_var_num = 0;
std::vector<int64_t> elem_type;
std::vector<int64_t> elem_size;
std::vector<int64_t> capacity;
......@@ -81,6 +80,8 @@ int GeneralReaderOp::inference() {
GeneralBlob *res = mutable_data<GeneralBlob>();
TensorVector *out = &res->tensor_vector;
res->SetBatchSize(batch_size);
if (!res) {
LOG(ERROR) << "Failed get op tls reader object output";
}
......
......@@ -12,12 +12,12 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "core/general-server/op/general_response_op.h"
#include <algorithm>
#include <iostream>
#include <memory>
#include <sstream>
#include "core/general-server/op/general_infer_helper.h"
#include "core/general-server/op/general_response_op.h"
#include "core/predictor/framework/infer.h"
#include "core/predictor/framework/memory.h"
#include "core/predictor/framework/resource.h"
......@@ -37,12 +37,10 @@ using baidu::paddle_serving::predictor::InferManager;
using baidu::paddle_serving::predictor::PaddleGeneralModelConfig;
int GeneralResponseOp::inference() {
const GeneralBlob *input_blob =
get_depend_argument<GeneralBlob>(pre_name());
const GeneralBlob *input_blob = get_depend_argument<GeneralBlob>(pre_name());
if (!input_blob) {
LOG(ERROR) << "Failed mutable depended argument, op: "
<< pre_name();
LOG(ERROR) << "Failed mutable depended argument, op: " << pre_name();
return -1;
}
......@@ -75,7 +73,7 @@ int GeneralResponseOp::inference() {
for (int i = 0; i < batch_size; ++i) {
FetchInst *fetch_inst = res->add_insts();
for (auto & idx : fetch_index) {
for (auto &idx : fetch_index) {
Tensor *tensor = fetch_inst->add_tensor_array();
// currently only response float tensor or lod_tensor
tensor->set_elem_type(1);
......@@ -85,8 +83,7 @@ int GeneralResponseOp::inference() {
} else {
VLOG(2) << "out[" << idx << "] is tensor";
for (int k = 1; k < in->at(idx).shape.size(); ++k) {
VLOG(2) << "shape[" << k - 1 << "]: "
<< in->at(idx).shape[k];
VLOG(2) << "shape[" << k - 1 << "]: " << in->at(idx).shape[k];
tensor->add_shape(in->at(idx).shape[k]);
}
}
......@@ -94,7 +91,7 @@ int GeneralResponseOp::inference() {
}
int var_idx = 0;
for (auto & idx : fetch_index) {
for (auto &idx : fetch_index) {
float *data_ptr = static_cast<float *>(in->at(idx).data.data());
int cap = 1;
for (int j = 1; j < in->at(idx).shape.size(); ++j) {
......@@ -102,8 +99,8 @@ int GeneralResponseOp::inference() {
}
if (model_config->_is_lod_fetch[idx]) {
for (int j = 0; j < batch_size; ++j) {
for (int k = in->at(idx).lod[0][j];
k < in->at(idx).lod[0][j + 1]; k++) {
for (int k = in->at(idx).lod[0][j]; k < in->at(idx).lod[0][j + 1];
k++) {
res->mutable_insts(j)->mutable_tensor_array(var_idx)->add_data(
reinterpret_cast<char *>(&(data_ptr[k])), sizeof(float));
}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册