提交 bbef17fa 编写于 作者: S ShiningZhang

fix pybind: merge the traversal of string_map

上级 24d76a63
...@@ -287,9 +287,6 @@ int PredictorClient::numpy_predict( ...@@ -287,9 +287,6 @@ int PredictorClient::numpy_predict(
LOG(ERROR) << "idx > tensor_vec.size()"; LOG(ERROR) << "idx > tensor_vec.size()";
return -1; return -1;
} }
if (_type[idx] == P_STRING) {
continue;
}
Tensor *tensor = tensor_vec[idx]; Tensor *tensor = tensor_vec[idx];
for (uint32_t j = 0; j < string_shape[vec_idx].size(); ++j) { for (uint32_t j = 0; j < string_shape[vec_idx].size(); ++j) {
...@@ -298,36 +295,14 @@ int PredictorClient::numpy_predict( ...@@ -298,36 +295,14 @@ int PredictorClient::numpy_predict(
for (uint32_t j = 0; j < string_lod_slot_batch[vec_idx].size(); ++j) { for (uint32_t j = 0; j < string_lod_slot_batch[vec_idx].size(); ++j) {
tensor->add_lod(string_lod_slot_batch[vec_idx][j]); tensor->add_lod(string_lod_slot_batch[vec_idx][j]);
} }
tensor->set_elem_type(_type[idx]);
tensor->set_name(_feed_name[idx]); tensor->set_name(_feed_name[idx]);
tensor->set_alias_name(name); tensor->set_alias_name(name);
tensor->set_tensor_content(string_feed[vec_idx]);
vec_idx++;
}
vec_idx = 0;
for (auto &name : string_feed_name) {
int idx = _feed_name_to_idx[name];
if (idx >= tensor_vec.size()) {
LOG(ERROR) << "idx > tensor_vec.size()";
return -1;
}
if (_type[idx] != P_STRING) { if (_type[idx] != P_STRING) {
continue; tensor->set_elem_type(_type[idx]);
} tensor->set_tensor_content(string_feed[vec_idx]);
Tensor *tensor = tensor_vec[idx]; } else {
for (uint32_t j = 0; j < string_shape[vec_idx].size(); ++j) {
tensor->add_shape(string_shape[vec_idx][j]);
}
for (uint32_t j = 0; j < string_lod_slot_batch[vec_idx].size(); ++j) {
tensor->add_lod(string_lod_slot_batch[vec_idx][j]);
}
tensor->set_elem_type(P_STRING); tensor->set_elem_type(P_STRING);
tensor->set_name(_feed_name[idx]);
tensor->set_alias_name(name);
const int string_shape_size = string_shape[vec_idx].size(); const int string_shape_size = string_shape[vec_idx].size();
// string_shape[vec_idx] = [1];cause numpy has no datatype of string. // string_shape[vec_idx] = [1];cause numpy has no datatype of string.
// we pass string via vector<vector<string> >. // we pass string via vector<vector<string> >.
...@@ -342,6 +317,7 @@ int PredictorClient::numpy_predict( ...@@ -342,6 +317,7 @@ int PredictorClient::numpy_predict(
break; break;
} }
} }
}
vec_idx++; vec_idx++;
} }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册