From b77f39b97651762cfc3b2cf537924fe93059b203 Mon Sep 17 00:00:00 2001 From: ShiningZhang Date: Fri, 3 Sep 2021 18:01:01 +0800 Subject: [PATCH] c++ client support uint8&int8 --- core/general-client/include/client.h | 9 ++++-- core/general-client/src/client.cpp | 41 ++++++++++++++++++++-------- 2 files changed, 36 insertions(+), 14 deletions(-) diff --git a/core/general-client/include/client.h b/core/general-client/include/client.h index 689732c5..11c6a2b7 100644 --- a/core/general-client/include/client.h +++ b/core/general-client/include/client.h @@ -88,7 +88,7 @@ class PredictorData { const std::string& name, const std::vector& shape, const std::vector& lod, - const int datatype = 3); + const int datatype = 20); const std::map>& float_data_map() const { return _float_data_map; @@ -140,6 +140,8 @@ class PredictorData { int get_datatype(std::string name) const; + void set_datatype(std::string name, int type); + std::string print(); private: @@ -159,6 +161,7 @@ class PredictorData { oss << "{"; oss << it->first << key_seg; const std::vector& v = it->second; + oss << v.size() << key_seg; for (size_t i = 0; i < v.size(); ++i) { if (i != v.size() - 1) { oss << v[i] << val_seg; @@ -184,7 +187,9 @@ class PredictorData { typename std::map::const_iterator itEnd = map.end(); for (; it != itEnd; it++) { oss << "{"; - oss << it->first << key_seg << it->second; + oss << it->first << key_seg + << "size=" << it->second.size() << key_seg + << "type=" << this->get_datatype(it->first); oss << "}"; } return oss.str(); diff --git a/core/general-client/src/client.cpp b/core/general-client/src/client.cpp index 0b9f067f..4d3b99f2 100644 --- a/core/general-client/src/client.cpp +++ b/core/general-client/src/client.cpp @@ -172,6 +172,10 @@ int PredictorData::get_datatype(std::string name) const { return 0; } +void PredictorData::set_datatype(std::string name, int type) { + _datatype_map[name] = type; +} + std::string PredictorData::print() { std::string res; res.append(map2string(_float_data_map)); @@ -325,20 +329,25 @@ int PredictorInputs::GenProto(const PredictorInputs& inputs, tensor->set_name(feed_name[idx]); tensor->set_alias_name(name); - const int string_shape_size = string_shape.size(); - // string_shape[vec_idx] = [1];cause numpy has no datatype of string. - // we pass string via vector >. - if (string_shape_size != 1) { - LOG(ERROR) << "string_shape_size should be 1-D, but received is : " - << string_shape_size; - return -1; - } - switch (string_shape_size) { - case 1: { - tensor->add_data(string_data); - break; + if (datatype == P_STRING) { + const int string_shape_size = string_shape.size(); + // string_shape[vec_idx] = [1];cause numpy has no datatype of string. + // we pass string via vector >. + if (string_shape_size != 1) { + LOG(ERROR) << "string_shape_size should be 1-D, but received is : " + << string_shape_size; + return -1; + } + switch (string_shape_size) { + case 1: { + tensor->add_data(string_data); + break; + } } + } else { + tensor->set_tensor_content(string_data); } + } return 0; } @@ -371,6 +380,8 @@ int PredictorOutputs::ParseProto(const Response& res, std::shared_ptr predictor_output = std::make_shared(); predictor_output->engine_name = output.engine_name(); + + PredictorData& predictor_data = predictor_output->data; std::map>& float_data_map = *predictor_output->data.mutable_float_data_map(); std::map>& int64_data_map = *predictor_output->data.mutable_int64_data_map(); std::map>& int32_data_map = *predictor_output->data.mutable_int_data_map(); @@ -419,7 +430,13 @@ int PredictorOutputs::ParseProto(const Response& res, int32_data_map[name] = std::vector( output.tensor(idx).int_data().begin(), output.tensor(idx).int_data().begin() + size); + } else if (fetch_name_to_type[name] == P_UINT8 + || fetch_name_to_type[name] == P_INT8) { + VLOG(2) << "fetch var [" << name << "]type=" + << fetch_name_to_type[name]; + string_data_map[name] = output.tensor(idx).tensor_content(); } + predictor_data.set_datatype(name, output.tensor(idx).elem_type()); idx += 1; } outputs.add_data(predictor_output); -- GitLab