Commit 3de8d893 authored by barrierye

update general_model

Parent 4ee4f83b
@@ -52,7 +52,24 @@ class ModelRes {
       const std::string& name) {
     return _float_map[name];
   }
+  void set_engine_name(const std::string& engine_name) {
+    _engine_name = engine_name;
+  }
+  const std::string& engine_name() { return _engine_name; }
+  ModelRes(ModelRes&& res)
+      : _engine_name(std::move(res._engine_name)),
+        _int64_map(std::move(res._int64_map)),
+        _float_map(std::move(res._float_map)) {}
+  ModelRes& operator=(ModelRes&& res) {
+    if (this != &res) {
+      _engine_name = std::move(res._engine_name);
+      _int64_map = std::move(res._int64_map);
+      _float_map = std::move(res._float_map);
+    }
+    return *this;
+  }
 public:
+  std::string _engine_name;
   std::map<std::string, std::vector<std::vector<int64_t>>> _int64_map;
   std::map<std::string, std::vector<std::vector<float>>> _float_map;
 };
@@ -63,7 +80,10 @@ class PredictorRes {
   ~PredictorRes() {}

  public:
-  void clear() { _models.clear();}
+  void clear() {
+    _models.clear();
+    _engine_names.clear();
+  }
   const std::vector<std::vector<int64_t>>& get_int64_by_name(
       const int model_idx, const std::string& name) {
     return _models[model_idx].get_int64_by_name(name);
@@ -72,16 +92,23 @@ class PredictorRes {
       const int model_idx, const std::string& name) {
     return _models[model_idx].get_float_by_name(name);
   }
+  void add_model_res(ModelRes&& res) {
+    _engine_names.push_back(res.engine_name());
+    _models.emplace_back(std::move(res));
+  }
   void set_variant_tag(const std::string& variant_tag) {
     _variant_tag = variant_tag;
   }
   const std::string& variant_tag() { return _variant_tag; }
   int model_num() {return _models.size();}
+  const std::vector<std::string>& get_engine_names() {
+    return _engine_names;
+  }
-
-  std::vector<ModelRes> _models;

 private:
+  std::vector<ModelRes> _models;
   std::string _variant_tag;
+  std::vector<std::string> _engine_names;
 };

 class PredictorClient {
......
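For context, a minimal usage sketch of the multi-model result API added above: ModelRes now carries its engine name, and PredictorRes collects one ModelRes per model plus the matching engine names. This sketch is not part of the commit; the engine name "general_infer_0", the fetch name "score", and the helper read_first_score are illustrative assumptions, and it assumes ModelRes is default-constructible.

```cpp
#include <string>
#include <utility>
#include <vector>
// Assumes the ModelRes / PredictorRes declarations from the header diff above.

float read_first_score(PredictorRes& predict_res) {
  ModelRes model_res;
  model_res.set_engine_name("general_infer_0");     // illustrative engine name
  model_res._float_map["score"] = {{0.5f, 0.3f}};   // placeholder fetch output
  predict_res.add_model_res(std::move(model_res));  // also records the engine name

  float first = 0.0f;
  for (int m = 0; m < predict_res.model_num(); ++m) {
    const std::string& engine = predict_res.get_engine_names()[m];
    const std::vector<std::vector<float>>& score =
        predict_res.get_float_by_name(m, "score");  // "score" is a placeholder name
    if (!engine.empty() && !score.empty()) {
      first = score[0][0];  // per-model, per-sample output consumed here
    }
  }
  return first;
}
```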
@@ -112,6 +112,7 @@ void PredictorClient::set_predictor_conf(const std::string &conf_path,
 int PredictorClient::destroy_predictor() {
   _api.thrd_finalize();
   _api.destroy();
+  return 0;
 }

 int PredictorClient::create_predictor_by_desc(const std::string &sdk_desc) {
@@ -120,6 +121,7 @@ int PredictorClient::create_predictor_by_desc(const std::string &sdk_desc) {
     return -1;
   }
   _api.thrd_initialize();
+  return 0;
 }

 int PredictorClient::create_predictor() {
@@ -130,6 +132,7 @@ int PredictorClient::create_predictor() {
     return -1;
   }
   _api.thrd_initialize();
+  return 0;
 }

 int PredictorClient::predict(const std::vector<std::vector<float>> &float_feed,
@@ -166,11 +169,11 @@ int PredictorClient::predict(const std::vector<std::vector<float>> &float_feed,
   for (auto &name : float_feed_name) {
     int idx = _feed_name_to_idx[name];
     Tensor *tensor = tensor_vec[idx];
-    for (int j = 0; j < _shape[idx].size(); ++j) {
+    for (uint32_t j = 0; j < _shape[idx].size(); ++j) {
       tensor->add_shape(_shape[idx][j]);
     }
     tensor->set_elem_type(1);
-    for (int j = 0; j < float_feed[vec_idx].size(); ++j) {
+    for (uint32_t j = 0; j < float_feed[vec_idx].size(); ++j) {
       tensor->add_float_data(float_feed[vec_idx][j]);
     }
     vec_idx++;
@@ -182,11 +185,11 @@ int PredictorClient::predict(const std::vector<std::vector<float>> &float_feed,
   for (auto &name : int_feed_name) {
     int idx = _feed_name_to_idx[name];
     Tensor *tensor = tensor_vec[idx];
-    for (int j = 0; j < _shape[idx].size(); ++j) {
+    for (uint32_t j = 0; j < _shape[idx].size(); ++j) {
       tensor->add_shape(_shape[idx][j]);
     }
     tensor->set_elem_type(0);
-    for (int j = 0; j < int_feed[vec_idx].size(); ++j) {
+    for (uint32_t j = 0; j < int_feed[vec_idx].size(); ++j) {
       tensor->add_int64_data(int_feed[vec_idx][j]);
     }
     vec_idx++;
@@ -321,11 +324,11 @@ int PredictorClient::batch_predict(
   for (auto &name : float_feed_name) {
     int idx = _feed_name_to_idx[name];
     Tensor *tensor = tensor_vec[idx];
-    for (int j = 0; j < _shape[idx].size(); ++j) {
+    for (uint32_t j = 0; j < _shape[idx].size(); ++j) {
       tensor->add_shape(_shape[idx][j]);
     }
     tensor->set_elem_type(1);
-    for (int j = 0; j < float_feed[vec_idx].size(); ++j) {
+    for (uint32_t j = 0; j < float_feed[vec_idx].size(); ++j) {
       tensor->add_float_data(float_feed[vec_idx][j]);
     }
     vec_idx++;
@@ -338,13 +341,13 @@ int PredictorClient::batch_predict(
   for (auto &name : int_feed_name) {
     int idx = _feed_name_to_idx[name];
     Tensor *tensor = tensor_vec[idx];
-    for (int j = 0; j < _shape[idx].size(); ++j) {
+    for (uint32_t j = 0; j < _shape[idx].size(); ++j) {
       tensor->add_shape(_shape[idx][j]);
     }
     tensor->set_elem_type(0);
     VLOG(3) << "feed var name " << name << " index " << vec_idx
             << "first data " << int_feed[vec_idx][0];
-    for (int j = 0; j < int_feed[vec_idx].size(); ++j) {
+    for (uint32_t j = 0; j < int_feed[vec_idx].size(); ++j) {
       tensor->add_int64_data(int_feed[vec_idx][j]);
     }
     vec_idx++;
......
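The loop-counter changes above (int to uint32_t) appear to target the signed/unsigned comparison warning that g++ and clang emit when a signed index is compared against std::vector::size(), which returns an unsigned type. A standalone illustration, not taken from this commit:

```cpp
#include <cstdint>
#include <vector>

int main() {
  std::vector<float> feed = {0.1f, 0.2f, 0.3f};
  float sum = 0.0f;
  // Comparing a signed int counter against feed.size() (a size_t) triggers
  // -Wsign-compare under -Wall/-Wextra; an unsigned counter avoids it.
  for (uint32_t j = 0; j < feed.size(); ++j) {
    sum += feed[j];  // stands in for tensor->add_float_data(feed[j])
  }
  return sum > 0.0f ? 0 : 1;
}
```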