提交 d986670e 编写于 作者: M MRXLT

fix batch size

上级 4d435db4
...@@ -37,6 +37,8 @@ struct GeneralBlob { ...@@ -37,6 +37,8 @@ struct GeneralBlob {
double infer_time; double infer_time;
std::vector<std::string> fetch_name_vector; std::vector<std::string> fetch_name_vector;
int _batch_size;
void Clear() { void Clear() {
size_t tensor_count = tensor_vector.size(); size_t tensor_count = tensor_vector.size();
for (size_t ti = 0; ti < tensor_count; ++ti) { for (size_t ti = 0; ti < tensor_count; ++ti) {
...@@ -45,6 +47,10 @@ struct GeneralBlob { ...@@ -45,6 +47,10 @@ struct GeneralBlob {
tensor_vector.clear(); tensor_vector.clear();
} }
int SetBatchSize(int batch_size) { _batch_size = batch_size; }
int GetBatchSize() const { return _batch_size; }
/*
int GetBatchSize() const { int GetBatchSize() const {
if (tensor_vector.size() > 0) { if (tensor_vector.size() > 0) {
if (tensor_vector[0].lod.size() == 1) { if (tensor_vector[0].lod.size() == 1) {
...@@ -56,7 +62,7 @@ struct GeneralBlob { ...@@ -56,7 +62,7 @@ struct GeneralBlob {
return -1; return -1;
} }
} }
*/
std::string ShortDebugString() const { return "Not implemented!"; } std::string ShortDebugString() const { return "Not implemented!"; }
}; };
......
...@@ -12,11 +12,11 @@ ...@@ -12,11 +12,11 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
#include "core/general-server/op/general_infer_op.h"
#include <algorithm> #include <algorithm>
#include <iostream> #include <iostream>
#include <memory> #include <memory>
#include <sstream> #include <sstream>
#include "core/general-server/op/general_infer_op.h"
#include "core/predictor/framework/infer.h" #include "core/predictor/framework/infer.h"
#include "core/predictor/framework/memory.h" #include "core/predictor/framework/memory.h"
#include "core/predictor/framework/resource.h" #include "core/predictor/framework/resource.h"
...@@ -36,33 +36,31 @@ using baidu::paddle_serving::predictor::InferManager; ...@@ -36,33 +36,31 @@ using baidu::paddle_serving::predictor::InferManager;
using baidu::paddle_serving::predictor::PaddleGeneralModelConfig; using baidu::paddle_serving::predictor::PaddleGeneralModelConfig;
int GeneralInferOp::inference() { int GeneralInferOp::inference() {
const GeneralBlob * input_blob = const GeneralBlob *input_blob = get_depend_argument<GeneralBlob>(pre_name());
get_depend_argument<GeneralBlob>(pre_name());
GeneralBlob * output_blob = mutable_data<GeneralBlob>(); GeneralBlob *output_blob = mutable_data<GeneralBlob>();
if (!input_blob) { if (!input_blob) {
LOG(ERROR) << "Failed mutable depended argument, op:" LOG(ERROR) << "Failed mutable depended argument, op:" << pre_name();
<< pre_name();
return -1; return -1;
} }
const TensorVector *in = &input_blob->tensor_vector; const TensorVector *in = &input_blob->tensor_vector;
TensorVector *out = &output_blob->tensor_vector; TensorVector *out = &output_blob->tensor_vector;
int batch_size = input_blob->GetBatchSize(); int batch_size = input_blob->GetBatchSize();
output_blob->SetBatchSize(batch_size);
VLOG(2) << "infer batch size: " << batch_size; VLOG(2) << "infer batch size: " << batch_size;
// infer // infer
Timer timeline; // Timer timeline;
double infer_time = 0.0; // double infer_time = 0.0;
timeline.Start(); // timeline.Start();
if (InferManager::instance().infer(GENERAL_MODEL_NAME, in, out, batch_size)) { if (InferManager::instance().infer(GENERAL_MODEL_NAME, in, out, batch_size)) {
LOG(ERROR) << "Failed do infer in fluid model: " << GENERAL_MODEL_NAME; LOG(ERROR) << "Failed do infer in fluid model: " << GENERAL_MODEL_NAME;
return -1; return -1;
} }
timeline.Pause(); // timeline.Pause();
infer_time = timeline.ElapsedUS(); // infer_time = timeline.ElapsedUS();
return 0; return 0;
} }
DEFINE_OP(GeneralInferOp); DEFINE_OP(GeneralInferOp);
......
...@@ -12,12 +12,12 @@ ...@@ -12,12 +12,12 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
#include "core/general-server/op/general_reader_op.h"
#include <algorithm> #include <algorithm>
#include <iostream> #include <iostream>
#include <memory> #include <memory>
#include <sstream> #include <sstream>
#include "core/general-server/op/general_infer_helper.h" #include "core/general-server/op/general_infer_helper.h"
#include "core/general-server/op/general_reader_op.h"
#include "core/predictor/framework/infer.h" #include "core/predictor/framework/infer.h"
#include "core/predictor/framework/memory.h" #include "core/predictor/framework/memory.h"
...@@ -73,7 +73,6 @@ int GeneralReaderOp::inference() { ...@@ -73,7 +73,6 @@ int GeneralReaderOp::inference() {
int batch_size = req->insts_size(); int batch_size = req->insts_size();
int input_var_num = 0; int input_var_num = 0;
std::vector<int64_t> elem_type; std::vector<int64_t> elem_type;
std::vector<int64_t> elem_size; std::vector<int64_t> elem_size;
std::vector<int64_t> capacity; std::vector<int64_t> capacity;
...@@ -81,6 +80,8 @@ int GeneralReaderOp::inference() { ...@@ -81,6 +80,8 @@ int GeneralReaderOp::inference() {
GeneralBlob *res = mutable_data<GeneralBlob>(); GeneralBlob *res = mutable_data<GeneralBlob>();
TensorVector *out = &res->tensor_vector; TensorVector *out = &res->tensor_vector;
res->SetBatchSize(batch_size);
if (!res) { if (!res) {
LOG(ERROR) << "Failed get op tls reader object output"; LOG(ERROR) << "Failed get op tls reader object output";
} }
......
...@@ -12,12 +12,12 @@ ...@@ -12,12 +12,12 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
#include "core/general-server/op/general_response_op.h"
#include <algorithm> #include <algorithm>
#include <iostream> #include <iostream>
#include <memory> #include <memory>
#include <sstream> #include <sstream>
#include "core/general-server/op/general_infer_helper.h" #include "core/general-server/op/general_infer_helper.h"
#include "core/general-server/op/general_response_op.h"
#include "core/predictor/framework/infer.h" #include "core/predictor/framework/infer.h"
#include "core/predictor/framework/memory.h" #include "core/predictor/framework/memory.h"
#include "core/predictor/framework/resource.h" #include "core/predictor/framework/resource.h"
...@@ -37,12 +37,10 @@ using baidu::paddle_serving::predictor::InferManager; ...@@ -37,12 +37,10 @@ using baidu::paddle_serving::predictor::InferManager;
using baidu::paddle_serving::predictor::PaddleGeneralModelConfig; using baidu::paddle_serving::predictor::PaddleGeneralModelConfig;
int GeneralResponseOp::inference() { int GeneralResponseOp::inference() {
const GeneralBlob *input_blob = const GeneralBlob *input_blob = get_depend_argument<GeneralBlob>(pre_name());
get_depend_argument<GeneralBlob>(pre_name());
if (!input_blob) { if (!input_blob) {
LOG(ERROR) << "Failed mutable depended argument, op: " LOG(ERROR) << "Failed mutable depended argument, op: " << pre_name();
<< pre_name();
return -1; return -1;
} }
...@@ -75,7 +73,7 @@ int GeneralResponseOp::inference() { ...@@ -75,7 +73,7 @@ int GeneralResponseOp::inference() {
for (int i = 0; i < batch_size; ++i) { for (int i = 0; i < batch_size; ++i) {
FetchInst *fetch_inst = res->add_insts(); FetchInst *fetch_inst = res->add_insts();
for (auto & idx : fetch_index) { for (auto &idx : fetch_index) {
Tensor *tensor = fetch_inst->add_tensor_array(); Tensor *tensor = fetch_inst->add_tensor_array();
// currently only response float tensor or lod_tensor // currently only response float tensor or lod_tensor
tensor->set_elem_type(1); tensor->set_elem_type(1);
...@@ -85,8 +83,7 @@ int GeneralResponseOp::inference() { ...@@ -85,8 +83,7 @@ int GeneralResponseOp::inference() {
} else { } else {
VLOG(2) << "out[" << idx << "] is tensor"; VLOG(2) << "out[" << idx << "] is tensor";
for (int k = 1; k < in->at(idx).shape.size(); ++k) { for (int k = 1; k < in->at(idx).shape.size(); ++k) {
VLOG(2) << "shape[" << k - 1 << "]: " VLOG(2) << "shape[" << k - 1 << "]: " << in->at(idx).shape[k];
<< in->at(idx).shape[k];
tensor->add_shape(in->at(idx).shape[k]); tensor->add_shape(in->at(idx).shape[k]);
} }
} }
...@@ -94,7 +91,7 @@ int GeneralResponseOp::inference() { ...@@ -94,7 +91,7 @@ int GeneralResponseOp::inference() {
} }
int var_idx = 0; int var_idx = 0;
for (auto & idx : fetch_index) { for (auto &idx : fetch_index) {
float *data_ptr = static_cast<float *>(in->at(idx).data.data()); float *data_ptr = static_cast<float *>(in->at(idx).data.data());
int cap = 1; int cap = 1;
for (int j = 1; j < in->at(idx).shape.size(); ++j) { for (int j = 1; j < in->at(idx).shape.size(); ++j) {
...@@ -102,8 +99,8 @@ int GeneralResponseOp::inference() { ...@@ -102,8 +99,8 @@ int GeneralResponseOp::inference() {
} }
if (model_config->_is_lod_fetch[idx]) { if (model_config->_is_lod_fetch[idx]) {
for (int j = 0; j < batch_size; ++j) { for (int j = 0; j < batch_size; ++j) {
for (int k = in->at(idx).lod[0][j]; for (int k = in->at(idx).lod[0][j]; k < in->at(idx).lod[0][j + 1];
k < in->at(idx).lod[0][j + 1]; k++) { k++) {
res->mutable_insts(j)->mutable_tensor_array(var_idx)->add_data( res->mutable_insts(j)->mutable_tensor_array(var_idx)->add_data(
reinterpret_cast<char *>(&(data_ptr[k])), sizeof(float)); reinterpret_cast<char *>(&(data_ptr[k])), sizeof(float));
} }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册