提交 bbef17fa 编写于 作者: S ShiningZhang

fix pybind: merge the traversal of string_map

上级 24d76a63
...@@ -287,9 +287,6 @@ int PredictorClient::numpy_predict( ...@@ -287,9 +287,6 @@ int PredictorClient::numpy_predict(
LOG(ERROR) << "idx > tensor_vec.size()"; LOG(ERROR) << "idx > tensor_vec.size()";
return -1; return -1;
} }
if (_type[idx] == P_STRING) {
continue;
}
Tensor *tensor = tensor_vec[idx]; Tensor *tensor = tensor_vec[idx];
for (uint32_t j = 0; j < string_shape[vec_idx].size(); ++j) { for (uint32_t j = 0; j < string_shape[vec_idx].size(); ++j) {
...@@ -298,48 +295,27 @@ int PredictorClient::numpy_predict( ...@@ -298,48 +295,27 @@ int PredictorClient::numpy_predict(
for (uint32_t j = 0; j < string_lod_slot_batch[vec_idx].size(); ++j) { for (uint32_t j = 0; j < string_lod_slot_batch[vec_idx].size(); ++j) {
tensor->add_lod(string_lod_slot_batch[vec_idx][j]); tensor->add_lod(string_lod_slot_batch[vec_idx][j]);
} }
tensor->set_elem_type(_type[idx]);
tensor->set_name(_feed_name[idx]); tensor->set_name(_feed_name[idx]);
tensor->set_alias_name(name); tensor->set_alias_name(name);
tensor->set_tensor_content(string_feed[vec_idx]);
vec_idx++;
}
vec_idx = 0;
for (auto &name : string_feed_name) {
int idx = _feed_name_to_idx[name];
if (idx >= tensor_vec.size()) {
LOG(ERROR) << "idx > tensor_vec.size()";
return -1;
}
if (_type[idx] != P_STRING) { if (_type[idx] != P_STRING) {
continue; tensor->set_elem_type(_type[idx]);
} tensor->set_tensor_content(string_feed[vec_idx]);
Tensor *tensor = tensor_vec[idx]; } else {
tensor->set_elem_type(P_STRING);
for (uint32_t j = 0; j < string_shape[vec_idx].size(); ++j) { const int string_shape_size = string_shape[vec_idx].size();
tensor->add_shape(string_shape[vec_idx][j]); // string_shape[vec_idx] = [1];cause numpy has no datatype of string.
} // we pass string via vector<vector<string> >.
for (uint32_t j = 0; j < string_lod_slot_batch[vec_idx].size(); ++j) { if (string_shape_size != 1) {
tensor->add_lod(string_lod_slot_batch[vec_idx][j]); LOG(ERROR) << "string_shape_size should be 1-D, but received is : "
} << string_shape_size;
tensor->set_elem_type(P_STRING); return -1;
tensor->set_name(_feed_name[idx]); }
tensor->set_alias_name(name); switch (string_shape_size) {
case 1: {
const int string_shape_size = string_shape[vec_idx].size(); tensor->add_data(string_feed[vec_idx]);
// string_shape[vec_idx] = [1];cause numpy has no datatype of string. break;
// we pass string via vector<vector<string> >. }
if (string_shape_size != 1) {
LOG(ERROR) << "string_shape_size should be 1-D, but received is : "
<< string_shape_size;
return -1;
}
switch (string_shape_size) {
case 1: {
tensor->add_data(string_feed[vec_idx]);
break;
} }
} }
vec_idx++; vec_idx++;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册