提交 b77f39b9 编写于 作者: S ShiningZhang

c++ client support uint8&int8

上级 83902386
......@@ -88,7 +88,7 @@ class PredictorData {
const std::string& name,
const std::vector<int>& shape,
const std::vector<int>& lod,
const int datatype = 3);
const int datatype = 20);
const std::map<std::string, std::vector<float>>& float_data_map() const {
return _float_data_map;
......@@ -140,6 +140,8 @@ class PredictorData {
int get_datatype(std::string name) const;
void set_datatype(std::string name, int type);
std::string print();
private:
......@@ -159,6 +161,7 @@ class PredictorData {
oss << "{";
oss << it->first << key_seg;
const std::vector<T2>& v = it->second;
oss << v.size() << key_seg;
for (size_t i = 0; i < v.size(); ++i) {
if (i != v.size() - 1) {
oss << v[i] << val_seg;
......@@ -184,7 +187,9 @@ class PredictorData {
typename std::map<T1, T2>::const_iterator itEnd = map.end();
for (; it != itEnd; it++) {
oss << "{";
oss << it->first << key_seg << it->second;
oss << it->first << key_seg
<< "size=" << it->second.size() << key_seg
<< "type=" << this->get_datatype(it->first);
oss << "}";
}
return oss.str();
......
......@@ -172,6 +172,10 @@ int PredictorData::get_datatype(std::string name) const {
return 0;
}
void PredictorData::set_datatype(std::string name, int type) {
_datatype_map[name] = type;
}
std::string PredictorData::print() {
std::string res;
res.append(map2string<std::string, float>(_float_data_map));
......@@ -325,20 +329,25 @@ int PredictorInputs::GenProto(const PredictorInputs& inputs,
tensor->set_name(feed_name[idx]);
tensor->set_alias_name(name);
const int string_shape_size = string_shape.size();
// string_shape[vec_idx] = [1];cause numpy has no datatype of string.
// we pass string via vector<vector<string> >.
if (string_shape_size != 1) {
LOG(ERROR) << "string_shape_size should be 1-D, but received is : "
<< string_shape_size;
return -1;
}
switch (string_shape_size) {
case 1: {
tensor->add_data(string_data);
break;
if (datatype == P_STRING) {
const int string_shape_size = string_shape.size();
// string_shape[vec_idx] = [1];cause numpy has no datatype of string.
// we pass string via vector<vector<string> >.
if (string_shape_size != 1) {
LOG(ERROR) << "string_shape_size should be 1-D, but received is : "
<< string_shape_size;
return -1;
}
switch (string_shape_size) {
case 1: {
tensor->add_data(string_data);
break;
}
}
} else {
tensor->set_tensor_content(string_data);
}
}
return 0;
}
......@@ -371,6 +380,8 @@ int PredictorOutputs::ParseProto(const Response& res,
std::shared_ptr<PredictorOutputs::PredictorOutput> predictor_output =
std::make_shared<PredictorOutputs::PredictorOutput>();
predictor_output->engine_name = output.engine_name();
PredictorData& predictor_data = predictor_output->data;
std::map<std::string, std::vector<float>>& float_data_map = *predictor_output->data.mutable_float_data_map();
std::map<std::string, std::vector<int64_t>>& int64_data_map = *predictor_output->data.mutable_int64_data_map();
std::map<std::string, std::vector<int32_t>>& int32_data_map = *predictor_output->data.mutable_int_data_map();
......@@ -419,7 +430,13 @@ int PredictorOutputs::ParseProto(const Response& res,
int32_data_map[name] = std::vector<int32_t>(
output.tensor(idx).int_data().begin(),
output.tensor(idx).int_data().begin() + size);
} else if (fetch_name_to_type[name] == P_UINT8
|| fetch_name_to_type[name] == P_INT8) {
VLOG(2) << "fetch var [" << name << "]type="
<< fetch_name_to_type[name];
string_data_map[name] = output.tensor(idx).tensor_content();
}
predictor_data.set_datatype(name, output.tensor(idx).elem_type());
idx += 1;
}
outputs.add_data(predictor_output);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册