提交 b77f39b9 编写于 作者: S ShiningZhang

c++ client support uint8&int8

上级 83902386
...@@ -88,7 +88,7 @@ class PredictorData { ...@@ -88,7 +88,7 @@ class PredictorData {
const std::string& name, const std::string& name,
const std::vector<int>& shape, const std::vector<int>& shape,
const std::vector<int>& lod, const std::vector<int>& lod,
const int datatype = 3); const int datatype = 20);
const std::map<std::string, std::vector<float>>& float_data_map() const { const std::map<std::string, std::vector<float>>& float_data_map() const {
return _float_data_map; return _float_data_map;
...@@ -140,6 +140,8 @@ class PredictorData { ...@@ -140,6 +140,8 @@ class PredictorData {
int get_datatype(std::string name) const; int get_datatype(std::string name) const;
void set_datatype(std::string name, int type);
std::string print(); std::string print();
private: private:
...@@ -159,6 +161,7 @@ class PredictorData { ...@@ -159,6 +161,7 @@ class PredictorData {
oss << "{"; oss << "{";
oss << it->first << key_seg; oss << it->first << key_seg;
const std::vector<T2>& v = it->second; const std::vector<T2>& v = it->second;
oss << v.size() << key_seg;
for (size_t i = 0; i < v.size(); ++i) { for (size_t i = 0; i < v.size(); ++i) {
if (i != v.size() - 1) { if (i != v.size() - 1) {
oss << v[i] << val_seg; oss << v[i] << val_seg;
...@@ -184,7 +187,9 @@ class PredictorData { ...@@ -184,7 +187,9 @@ class PredictorData {
typename std::map<T1, T2>::const_iterator itEnd = map.end(); typename std::map<T1, T2>::const_iterator itEnd = map.end();
for (; it != itEnd; it++) { for (; it != itEnd; it++) {
oss << "{"; oss << "{";
oss << it->first << key_seg << it->second; oss << it->first << key_seg
<< "size=" << it->second.size() << key_seg
<< "type=" << this->get_datatype(it->first);
oss << "}"; oss << "}";
} }
return oss.str(); return oss.str();
......
...@@ -172,6 +172,10 @@ int PredictorData::get_datatype(std::string name) const { ...@@ -172,6 +172,10 @@ int PredictorData::get_datatype(std::string name) const {
return 0; return 0;
} }
void PredictorData::set_datatype(std::string name, int type) {
_datatype_map[name] = type;
}
std::string PredictorData::print() { std::string PredictorData::print() {
std::string res; std::string res;
res.append(map2string<std::string, float>(_float_data_map)); res.append(map2string<std::string, float>(_float_data_map));
...@@ -325,20 +329,25 @@ int PredictorInputs::GenProto(const PredictorInputs& inputs, ...@@ -325,20 +329,25 @@ int PredictorInputs::GenProto(const PredictorInputs& inputs,
tensor->set_name(feed_name[idx]); tensor->set_name(feed_name[idx]);
tensor->set_alias_name(name); tensor->set_alias_name(name);
const int string_shape_size = string_shape.size(); if (datatype == P_STRING) {
// string_shape[vec_idx] = [1];cause numpy has no datatype of string. const int string_shape_size = string_shape.size();
// we pass string via vector<vector<string> >. // string_shape[vec_idx] = [1];cause numpy has no datatype of string.
if (string_shape_size != 1) { // we pass string via vector<vector<string> >.
LOG(ERROR) << "string_shape_size should be 1-D, but received is : " if (string_shape_size != 1) {
<< string_shape_size; LOG(ERROR) << "string_shape_size should be 1-D, but received is : "
return -1; << string_shape_size;
} return -1;
switch (string_shape_size) { }
case 1: { switch (string_shape_size) {
tensor->add_data(string_data); case 1: {
break; tensor->add_data(string_data);
break;
}
} }
} else {
tensor->set_tensor_content(string_data);
} }
} }
return 0; return 0;
} }
...@@ -371,6 +380,8 @@ int PredictorOutputs::ParseProto(const Response& res, ...@@ -371,6 +380,8 @@ int PredictorOutputs::ParseProto(const Response& res,
std::shared_ptr<PredictorOutputs::PredictorOutput> predictor_output = std::shared_ptr<PredictorOutputs::PredictorOutput> predictor_output =
std::make_shared<PredictorOutputs::PredictorOutput>(); std::make_shared<PredictorOutputs::PredictorOutput>();
predictor_output->engine_name = output.engine_name(); predictor_output->engine_name = output.engine_name();
PredictorData& predictor_data = predictor_output->data;
std::map<std::string, std::vector<float>>& float_data_map = *predictor_output->data.mutable_float_data_map(); std::map<std::string, std::vector<float>>& float_data_map = *predictor_output->data.mutable_float_data_map();
std::map<std::string, std::vector<int64_t>>& int64_data_map = *predictor_output->data.mutable_int64_data_map(); std::map<std::string, std::vector<int64_t>>& int64_data_map = *predictor_output->data.mutable_int64_data_map();
std::map<std::string, std::vector<int32_t>>& int32_data_map = *predictor_output->data.mutable_int_data_map(); std::map<std::string, std::vector<int32_t>>& int32_data_map = *predictor_output->data.mutable_int_data_map();
...@@ -419,7 +430,13 @@ int PredictorOutputs::ParseProto(const Response& res, ...@@ -419,7 +430,13 @@ int PredictorOutputs::ParseProto(const Response& res,
int32_data_map[name] = std::vector<int32_t>( int32_data_map[name] = std::vector<int32_t>(
output.tensor(idx).int_data().begin(), output.tensor(idx).int_data().begin(),
output.tensor(idx).int_data().begin() + size); output.tensor(idx).int_data().begin() + size);
} else if (fetch_name_to_type[name] == P_UINT8
|| fetch_name_to_type[name] == P_INT8) {
VLOG(2) << "fetch var [" << name << "]type="
<< fetch_name_to_type[name];
string_data_map[name] = output.tensor(idx).tensor_content();
} }
predictor_data.set_datatype(name, output.tensor(idx).elem_type());
idx += 1; idx += 1;
} }
outputs.add_data(predictor_output); outputs.add_data(predictor_output);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册