提交 c54db434 编写于 作者: M mindspore-ci-bot 提交者: Gitee

!37 Add GetStreamIdList in ge_runtime

Merge pull request !37 from caifubi/master
......@@ -38,6 +38,8 @@ class ModelRunner {
const std::vector<uint32_t> &GetTaskIdList(uint32_t model_id) const;
const std::vector<uint32_t> &GetStreamIdList(uint32_t model_id) const;
bool UnloadModel(uint32_t model_id);
bool RunModel(uint32_t model_id, const InputData &input_data, OutputData *output_data);
......
......@@ -60,6 +60,17 @@ const std::vector<uint32_t> &ModelRunner::GetTaskIdList(uint32_t model_id) const
return model_iter->second->GetTaskIdList();
}
const std::vector<uint32_t> &ModelRunner::GetStreamIdList(uint32_t model_id) const {
auto model_iter = runtime_models_.find(model_id);
if (model_iter == runtime_models_.end()) {
GELOGE(PARAM_INVALID, "Model id %u not found.", model_id);
static const std::vector<uint32_t> empty_ret;
return empty_ret;
}
return model_iter->second->GetStreamIdList();
}
bool ModelRunner::UnloadModel(uint32_t model_id) {
auto iter = runtime_models_.find(model_id);
if (iter != runtime_models_.end()) {
......
......@@ -220,6 +220,7 @@ bool RuntimeModel::LoadTask() {
return false;
}
task_id_list_.push_back(task_id);
stream_id_list_.push_back(stream_id);
}
if (task_list_.empty()) {
GELOGE(FAILED, "Task list is empty");
......@@ -507,5 +508,6 @@ void RuntimeModel::CreateOutput(uint32_t index, const OpInfo &op_info, InputOutp
const std::vector<uint32_t> &RuntimeModel::GetTaskIdList() const { return task_id_list_; }
const std::vector<uint32_t> &RuntimeModel::GetStreamIdList() const { return stream_id_list_; }
} // namespace model_runner
} // namespace ge
......@@ -36,6 +36,7 @@ class RuntimeModel {
bool Load(uint32_t device_id, uint64_t session_id, std::shared_ptr<DavinciModel> &davinci_model);
const std::vector<uint32_t> &GetTaskIdList() const;
const std::vector<uint32_t> &GetStreamIdList() const;
bool Run();
bool CopyInputData(const InputData &input_data);
bool GetInputOutputDescInfo(bool zero_copy, std::vector<InputOutputDescInfo> *input_desc,
......@@ -77,6 +78,7 @@ class RuntimeModel {
std::vector<std::shared_ptr<OpInfo>> constant_info_list_{};
std::vector<uint32_t> task_id_list_{};
std::vector<uint32_t> stream_id_list_{};
};
} // namespace model_runner
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册