提交 1c2c236d 编写于 作者: M MRXLT

Merge remote-tracking branch 'origin/ce-script' into ce-script

...@@ -114,70 +114,48 @@ int GeneralResponseOp::inference() { ...@@ -114,70 +114,48 @@ int GeneralResponseOp::inference() {
for (int j = 0; j < in->at(idx).shape.size(); ++j) { for (int j = 0; j < in->at(idx).shape.size(); ++j) {
cap *= in->at(idx).shape[j]; cap *= in->at(idx).shape[j];
} }
if (in->at(idx).dtype == paddle::PaddleDType::INT64) {
FetchInst *fetch_p = output->mutable_insts(0);
auto dtype = in->at(idx).dtype;
if (dtype == paddle::PaddleDType::INT64) {
VLOG(2) << "Prepare int64 var [" << model_config->_fetch_name[idx] VLOG(2) << "Prepare int64 var [" << model_config->_fetch_name[idx]
<< "]."; << "].";
int64_t *data_ptr = static_cast<int64_t *>(in->at(idx).data.data()); int64_t *data_ptr = static_cast<int64_t *>(in->at(idx).data.data());
if (model_config->_is_lod_fetch[idx]) { // from
FetchInst *fetch_p = output->mutable_insts(0); // https://stackoverflow.com/questions/15499641/copy-a-stdvector-to-a-repeated-field-from-protobuf-with-memcpy
for (int j = 0; j < in->at(idx).lod[0].size(); ++j) { // `Swap` method is faster than `{}` method.
fetch_p->mutable_tensor_array(var_idx)->add_lod( google::protobuf::RepeatedField<int64_t> tmp_data(data_ptr,
in->at(idx).lod[0][j]); data_ptr + cap);
} fetch_p->mutable_tensor_array(var_idx)->mutable_int64_data()->Swap(
for (int j = 0; j < cap; ++j) { &tmp_data);
fetch_p->mutable_tensor_array(var_idx)->add_int64_data(data_ptr[j]); } else if (dtype == paddle::PaddleDType::FLOAT32) {
}
} else {
FetchInst *fetch_p = output->mutable_insts(0);
for (int j = 0; j < cap; ++j) {
fetch_p->mutable_tensor_array(var_idx)->add_int64_data(data_ptr[j]);
}
}
VLOG(2) << "fetch var [" << model_config->_fetch_name[idx] << "] ready";
var_idx++;
} else if (in->at(idx).dtype == paddle::PaddleDType::FLOAT32) {
VLOG(2) << "Prepare float var [" << model_config->_fetch_name[idx] VLOG(2) << "Prepare float var [" << model_config->_fetch_name[idx]
<< "]."; << "].";
float *data_ptr = static_cast<float *>(in->at(idx).data.data()); float *data_ptr = static_cast<float *>(in->at(idx).data.data());
if (model_config->_is_lod_fetch[idx]) { google::protobuf::RepeatedField<float> tmp_data(data_ptr,
FetchInst *fetch_p = output->mutable_insts(0); data_ptr + cap);
for (int j = 0; j < in->at(idx).lod[0].size(); ++j) { fetch_p->mutable_tensor_array(var_idx)->mutable_float_data()->Swap(
fetch_p->mutable_tensor_array(var_idx)->add_lod( &tmp_data);
in->at(idx).lod[0][j]); } else if (dtype == paddle::PaddleDType::INT32) {
}
for (int j = 0; j < cap; ++j) {
fetch_p->mutable_tensor_array(var_idx)->add_float_data(data_ptr[j]);
}
} else {
FetchInst *fetch_p = output->mutable_insts(0);
for (int j = 0; j < cap; ++j) {
fetch_p->mutable_tensor_array(var_idx)->add_float_data(data_ptr[j]);
}
}
VLOG(2) << "fetch var [" << model_config->_fetch_name[idx] << "] ready";
var_idx++;
} else if (in->at(idx).dtype == paddle::PaddleDType::INT32) {
VLOG(2) << "Prepare int32 var [" << model_config->_fetch_name[idx] VLOG(2) << "Prepare int32 var [" << model_config->_fetch_name[idx]
<< "]."; << "].";
int32_t *data_ptr = static_cast<int32_t *>(in->at(idx).data.data()); int32_t *data_ptr = static_cast<int32_t *>(in->at(idx).data.data());
if (model_config->_is_lod_fetch[idx]) { google::protobuf::RepeatedField<int32_t> tmp_data(data_ptr,
FetchInst *fetch_p = output->mutable_insts(0); data_ptr + cap);
for (int j = 0; j < in->at(idx).lod[0].size(); ++j) { fetch_p->mutable_tensor_array(var_idx)->mutable_int_data()->Swap(
fetch_p->mutable_tensor_array(var_idx)->add_lod( &tmp_data);
in->at(idx).lod[0][j]); }
}
for (int j = 0; j < cap; ++j) { if (model_config->_is_lod_fetch[idx]) {
fetch_p->mutable_tensor_array(var_idx)->add_int_data(data_ptr[j]); for (int j = 0; j < in->at(idx).lod[0].size(); ++j) {
} fetch_p->mutable_tensor_array(var_idx)->add_lod(
} else { in->at(idx).lod[0][j]);
FetchInst *fetch_p = output->mutable_insts(0);
for (int j = 0; j < cap; ++j) {
fetch_p->mutable_tensor_array(var_idx)->add_int_data(data_ptr[j]);
}
} }
VLOG(2) << "fetch var [" << model_config->_fetch_name[idx] << "] ready";
var_idx++;
} }
VLOG(2) << "fetch var [" << model_config->_fetch_name[idx] << "] ready";
var_idx++;
} }
} }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册