From 81d79f333a71a57375bf83fadbd09f7b31a9b04b Mon Sep 17 00:00:00 2001 From: zhangzhenghai Date: Tue, 4 Aug 2020 15:46:43 +0800 Subject: [PATCH] modify ge_runtime --- inc/framework/ge_runtime/model_runner.h | 12 +- inc/framework/ge_runtime/task_info.h | 135 ++++++++----------- src/ge/ge_runtime/model_runner.cc | 50 ------- src/ge/ge_runtime/output.cc | 2 +- src/ge/ge_runtime/runtime_model.cc | 91 ++++--------- src/ge/ge_runtime/runtime_model.h | 11 +- src/ge/ge_runtime/task/aicpu_task.cc | 14 +- src/ge/ge_runtime/task/aicpu_task.h | 6 - src/ge/ge_runtime/task/hccl_task.cc | 1 + src/ge/ge_runtime/task/label_goto_task.cc | 70 ---------- src/ge/ge_runtime/task/label_goto_task.h | 41 ------ src/ge/ge_runtime/task/label_set_task.cc | 70 ---------- src/ge/ge_runtime/task/label_set_task.h | 41 ------ src/ge/ge_runtime/task/label_switch_task.cc | 131 ------------------ src/ge/ge_runtime/task/label_switch_task.h | 44 ------ src/ge/ge_runtime/task/stream_switch_task.cc | 2 +- src/ge/ge_runtime/task/task.h | 6 - src/ge/ge_runtime/task/tbe_task.cc | 7 +- src/ge/ge_runtime/task/tbe_task.h | 4 - 19 files changed, 98 insertions(+), 640 deletions(-) delete mode 100644 src/ge/ge_runtime/task/label_goto_task.cc delete mode 100644 src/ge/ge_runtime/task/label_goto_task.h delete mode 100644 src/ge/ge_runtime/task/label_set_task.cc delete mode 100644 src/ge/ge_runtime/task/label_set_task.h delete mode 100644 src/ge/ge_runtime/task/label_switch_task.cc delete mode 100644 src/ge/ge_runtime/task/label_switch_task.h diff --git a/inc/framework/ge_runtime/model_runner.h b/inc/framework/ge_runtime/model_runner.h index e495dfd..6e7abcb 100644 --- a/inc/framework/ge_runtime/model_runner.h +++ b/inc/framework/ge_runtime/model_runner.h @@ -28,7 +28,7 @@ namespace ge { namespace model_runner { class RuntimeModel; -using RuntimeInfo = std::tuple; + class ModelRunner { public: static ModelRunner &Instance(); @@ -36,18 +36,8 @@ class ModelRunner { bool LoadDavinciModel(uint32_t device_id, uint64_t session_id, uint32_t model_id, std::shared_ptr davinci_model, std::shared_ptr listener); - bool DistributeTask(uint32_t model_id); - - bool LoadModelComplete(uint32_t model_id); - const std::vector &GetTaskIdList(uint32_t model_id) const; - const std::vector &GetStreamIdList(uint32_t model_id) const; - - const std::map> &GetRuntimeInfoMap(uint32_t model_id) const; - - void *GetModelHandle(uint32_t model_id) const; - bool UnloadModel(uint32_t model_id); bool RunModel(uint32_t model_id, const InputData &input_data, OutputData *output_data); diff --git a/inc/framework/ge_runtime/task_info.h b/inc/framework/ge_runtime/task_info.h index 68d7187..a48ed68 100644 --- a/inc/framework/ge_runtime/task_info.h +++ b/inc/framework/ge_runtime/task_info.h @@ -21,7 +21,6 @@ #include #include #include -#include #include #include "cce/taskdown_api.h" @@ -53,27 +52,21 @@ class TaskInfo { virtual ~TaskInfo() {} uint32_t stream_id() const { return stream_id_; } TaskInfoType type() const { return type_; } - std::string op_name() const { return op_name_; } - bool dump_flag() const { return dump_flag_; } protected: - TaskInfo(const std::string &op_name, uint32_t stream_id, TaskInfoType type, bool dump_flag) - : op_name_(op_name), stream_id_(stream_id), type_(type), dump_flag_(dump_flag) {} + TaskInfo(uint32_t stream_id, TaskInfoType type) : stream_id_(stream_id), type_(type) {} private: - std::string op_name_; uint32_t stream_id_; TaskInfoType type_; - bool dump_flag_; }; class CceTaskInfo : public TaskInfo { public: - CceTaskInfo(const std::string &op_name, uint32_t stream_id, const cce::ccOpContext &ctx, const std::string &stub_func, - uint32_t block_dim, const std::vector &args, uint32_t args_size, - const std::vector &sm_desc, const std::vector &flow_table, - const std::vector &args_offset, bool is_flowtable) - : TaskInfo(op_name, stream_id, TaskInfoType::CCE, false), + CceTaskInfo(uint32_t stream_id, const cce::ccOpContext &ctx, const std::string &stub_func, uint32_t block_dim, + const std::vector &args, uint32_t args_size, const std::vector &sm_desc, + const std::vector &flow_table, const std::vector &args_offset, bool is_flowtable) + : TaskInfo(stream_id, TaskInfoType::CCE), ctx_(ctx), stub_func_(stub_func), block_dim_(block_dim), @@ -109,11 +102,11 @@ class CceTaskInfo : public TaskInfo { class TbeTaskInfo : public TaskInfo { public: - TbeTaskInfo(const std::string &op_name, uint32_t stream_id, const std::string &stub_func, uint32_t block_dim, - const std::vector &args, uint32_t args_size, const std::vector &sm_desc, void *binary, - uint32_t binary_size, const std::vector &meta_data, const std::vector &input_data_addrs, - const std::vector &output_data_addrs, const std::vector &workspace_addrs, bool dump_flag) - : TaskInfo(op_name, stream_id, TaskInfoType::TBE, dump_flag), + TbeTaskInfo(uint32_t stream_id, const std::string &stub_func, uint32_t block_dim, const std::vector &args, + uint32_t args_size, const std::vector &sm_desc, void *binary, uint32_t binary_size, + const std::vector &meta_data, const std::vector &input_data_addrs, + const std::vector &output_data_addrs, const std::vector &workspace_addrs) + : TaskInfo(stream_id, TaskInfoType::TBE), stub_func_(stub_func), block_dim_(block_dim), args_(args), @@ -160,10 +153,9 @@ class TbeTaskInfo : public TaskInfo { class AicpuTaskInfo : public TaskInfo { public: - AicpuTaskInfo(const std::string &op_name, uint32_t stream_id, const string &so_name, const std::string &kernel_name, - const std::string &node_def, const std::vector &input_data_addrs, - const std::vector &output_data_addrs, bool dump_flag) - : TaskInfo(op_name, stream_id, TaskInfoType::AICPU, dump_flag), + AicpuTaskInfo(uint32_t stream_id, const string &so_name, const std::string &kernel_name, const std::string &node_def, + const std::vector &input_data_addrs, const std::vector &output_data_addrs) + : TaskInfo(stream_id, TaskInfoType::AICPU), so_name_(so_name), kernel_name_(kernel_name), node_def_(node_def), @@ -185,45 +177,37 @@ class AicpuTaskInfo : public TaskInfo { std::vector output_data_addrs_; }; -class LabelSetTaskInfo : public TaskInfo { +class LabelTaskInfo : public TaskInfo { public: - LabelSetTaskInfo(const std::string &op_name, uint32_t stream_id, uint32_t label_id) - : TaskInfo(op_name, stream_id, TaskInfoType::LABEL_SET, false), label_id_(label_id) {} - ~LabelSetTaskInfo() override {} uint32_t label_id() const { return label_id_; } - private: + protected: + LabelTaskInfo(uint32_t stream_id, TaskInfoType type, uint32_t label_id) + : TaskInfo(stream_id, type), label_id_(label_id) {} + virtual ~LabelTaskInfo() override {} + uint32_t label_id_; }; -class LabelGotoTaskInfo : public TaskInfo { +class LabelSetTaskInfo : public LabelTaskInfo { public: - LabelGotoTaskInfo(const std::string &op_name, uint32_t stream_id, uint32_t label_id) - : TaskInfo(op_name, stream_id, TaskInfoType::LABEL_GOTO, false), label_id_(label_id) {} - ~LabelGotoTaskInfo() override {} - uint32_t label_id() const { return label_id_; } - - private: - uint32_t label_id_; + LabelSetTaskInfo(uint32_t stream_id, uint32_t label_id) + : LabelTaskInfo(stream_id, TaskInfoType::LABEL_SET, label_id) {} + ~LabelSetTaskInfo() override {} }; -class LabelSwitchTaskInfo : public TaskInfo { +class LabelSwitchTaskInfo : public LabelTaskInfo { public: - LabelSwitchTaskInfo(const std::string &op_name, uint32_t stream_id, uint32_t label_size, - const std::vector &label_list, void *cond) - : TaskInfo(op_name, stream_id, TaskInfoType::LABEL_SWITCH, false), - label_size_(label_size), - label_list_(label_list), - cond_(cond) {} + LabelSwitchTaskInfo(uint32_t stream_id, uint32_t label_id) + : LabelTaskInfo(stream_id, TaskInfoType::LABEL_SWITCH, label_id) {} ~LabelSwitchTaskInfo() override {} - uint32_t label_size() { return label_size_; }; - const std::vector &label_list() { return label_list_; }; - void *cond() { return cond_; }; +}; - private: - uint32_t label_size_; - std::vector label_list_; - void *cond_; +class LabelGotoTaskInfo : public LabelTaskInfo { + public: + LabelGotoTaskInfo(uint32_t stream_id, uint32_t label_id) + : LabelTaskInfo(stream_id, TaskInfoType::LABEL_GOTO, label_id) {} + ~LabelGotoTaskInfo() override {} }; class EventTaskInfo : public TaskInfo { @@ -231,8 +215,8 @@ class EventTaskInfo : public TaskInfo { uint32_t event_id() const { return event_id_; } protected: - EventTaskInfo(const std::string &op_name, uint32_t stream_id, TaskInfoType type, uint32_t event_id) - : TaskInfo(op_name, stream_id, type, false), event_id_(event_id) {} + EventTaskInfo(uint32_t stream_id, TaskInfoType type, uint32_t event_id) + : TaskInfo(stream_id, type), event_id_(event_id) {} virtual ~EventTaskInfo() override {} uint32_t event_id_; @@ -240,41 +224,39 @@ class EventTaskInfo : public TaskInfo { class EventRecordTaskInfo : public EventTaskInfo { public: - EventRecordTaskInfo(const std::string &op_name, uint32_t stream_id, uint32_t event_id) - : EventTaskInfo(op_name, stream_id, TaskInfoType::EVENT_RECORD, event_id) {} + EventRecordTaskInfo(uint32_t stream_id, uint32_t event_id) + : EventTaskInfo(stream_id, TaskInfoType::EVENT_RECORD, event_id) {} ~EventRecordTaskInfo() override {} }; class EventWaitTaskInfo : public EventTaskInfo { public: - EventWaitTaskInfo(const std::string &op_name, uint32_t stream_id, uint32_t event_id) - : EventTaskInfo(op_name, stream_id, TaskInfoType::EVENT_WAIT, event_id) {} + EventWaitTaskInfo(uint32_t stream_id, uint32_t event_id) + : EventTaskInfo(stream_id, TaskInfoType::EVENT_WAIT, event_id) {} ~EventWaitTaskInfo() override {} }; class FusionStartTaskInfo : public TaskInfo { public: - explicit FusionStartTaskInfo(const std::string &op_name, uint32_t stream_id) - : TaskInfo(op_name, stream_id, TaskInfoType::FUSION_START, false) {} + explicit FusionStartTaskInfo(uint32_t stream_id) : TaskInfo(stream_id, TaskInfoType::FUSION_START) {} ~FusionStartTaskInfo() override {} }; class FusionEndTaskInfo : public TaskInfo { public: - explicit FusionEndTaskInfo(const std::string &op_name, uint32_t stream_id) - : TaskInfo(op_name, stream_id, TaskInfoType::FUSION_END, false) {} + explicit FusionEndTaskInfo(uint32_t stream_id) : TaskInfo(stream_id, TaskInfoType::FUSION_END) {} ~FusionEndTaskInfo() override {} }; class HcclTaskInfo : public TaskInfo { public: - HcclTaskInfo(const std::string &op_name, uint32_t stream_id, const std::string hccl_type, void *input_data_addr, - void *output_data_addr, void *workspace_addr, int64_t workspace_size, int64_t hccl_stream_num, + HcclTaskInfo(uint32_t stream_id, const std::string hccl_type, void *input_data_addr, void *output_data_addr, + void *workspace_addr, int64_t workspace_size, int64_t hccl_stream_num, const std::vector &private_def, void *ops_kernel_store, int32_t count, int64_t root_id, - int64_t op_type, int64_t data_type, const std::string &group, - std::function hcom_bind_model, std::function hcom_unbind_model, - std::function, void *)> hcom_distribute_task, bool dump_flag) - : TaskInfo(op_name, stream_id, TaskInfoType::HCCL, dump_flag), + int64_t op_type, int64_t data_type, std::function hcom_bind_model, + std::function hcom_unbind_model, + std::function, void *)> hcom_distribute_task) + : TaskInfo(stream_id, TaskInfoType::HCCL), hccl_type_(hccl_type), input_data_addr_(input_data_addr), output_data_addr_(output_data_addr), @@ -287,7 +269,6 @@ class HcclTaskInfo : public TaskInfo { root_id_(root_id), op_type_(op_type), data_type_(data_type), - group_(group), hcom_bind_model_(hcom_bind_model), hcom_unbind_model_(hcom_unbind_model), hcom_distribute_task_(hcom_distribute_task) {} @@ -305,7 +286,6 @@ class HcclTaskInfo : public TaskInfo { int64_t root_id() const { return root_id_; } int64_t op_type() const { return op_type_; } int64_t data_type() const { return data_type_; } - const std::string &group() const { return group_; } std::function hcom_bind_model() const { return hcom_bind_model_; } std::function hcom_unbind_model() const { return hcom_unbind_model_; } std::function, void *)> hcom_distribute_task() const { @@ -325,7 +305,6 @@ class HcclTaskInfo : public TaskInfo { int64_t root_id_; int64_t op_type_; int64_t data_type_; - std::string group_; std::function hcom_bind_model_; std::function hcom_unbind_model_; std::function, void *)> hcom_distribute_task_; @@ -333,11 +312,8 @@ class HcclTaskInfo : public TaskInfo { class ProfilerTraceTaskInfo : public TaskInfo { public: - ProfilerTraceTaskInfo(const std::string &op_name, uint32_t stream_id, uint64_t log_id, bool notify, uint32_t flat) - : TaskInfo(op_name, stream_id, TaskInfoType::PROFILER_TRACE, false), - log_id_(log_id), - notify_(notify), - flat_(flat) {} + ProfilerTraceTaskInfo(uint32_t stream_id, uint64_t log_id, bool notify, uint32_t flat) + : TaskInfo(stream_id, TaskInfoType::PROFILER_TRACE), log_id_(log_id), notify_(notify), flat_(flat) {} ~ProfilerTraceTaskInfo() override {} uint64_t log_id() const { return log_id_; } @@ -352,9 +328,8 @@ class ProfilerTraceTaskInfo : public TaskInfo { class MemcpyAsyncTaskInfo : public TaskInfo { public: - MemcpyAsyncTaskInfo(const std::string &op_name, uint32_t stream_id, void *dst, uint64_t dst_max, void *src, - uint64_t count, uint32_t kind, bool dump_flag) - : TaskInfo(op_name, stream_id, TaskInfoType::MEMCPY_ASYNC, dump_flag), + MemcpyAsyncTaskInfo(uint32_t stream_id, void *dst, uint64_t dst_max, void *src, uint64_t count, uint32_t kind) + : TaskInfo(stream_id, TaskInfoType::MEMCPY_ASYNC), dst_(dst), dst_max_(dst_max), src_(src), @@ -378,9 +353,9 @@ class MemcpyAsyncTaskInfo : public TaskInfo { class StreamSwitchTaskInfo : public TaskInfo { public: - StreamSwitchTaskInfo(const std::string &op_name, uint32_t stream_id, int64_t true_stream_id, void *input_addr, - void *value_addr, int64_t cond, int64_t data_type) - : TaskInfo(op_name, stream_id, TaskInfoType::STREAM_SWITCH, false), + StreamSwitchTaskInfo(uint32_t stream_id, int64_t true_stream_id, void *input_addr, void *value_addr, int64_t cond, + int64_t data_type) + : TaskInfo(stream_id, TaskInfoType::STREAM_SWITCH), true_stream_id_(true_stream_id), input_addr_(input_addr), value_addr_(value_addr), @@ -404,8 +379,8 @@ class StreamSwitchTaskInfo : public TaskInfo { class StreamActiveTaskInfo : public TaskInfo { public: - StreamActiveTaskInfo(const std::string &op_name, uint32_t stream_id, uint32_t active_stream_id) - : TaskInfo(op_name, stream_id, TaskInfoType::STREAM_ACTIVE, false), active_stream_id_(active_stream_id) {} + StreamActiveTaskInfo(uint32_t stream_id, uint32_t active_stream_id) + : TaskInfo(stream_id, TaskInfoType::STREAM_ACTIVE), active_stream_id_(active_stream_id) {} ~StreamActiveTaskInfo() override {} uint32_t active_stream_id() const { return active_stream_id_; } diff --git a/src/ge/ge_runtime/model_runner.cc b/src/ge/ge_runtime/model_runner.cc index 9961ab4..59952e3 100644 --- a/src/ge/ge_runtime/model_runner.cc +++ b/src/ge/ge_runtime/model_runner.cc @@ -49,24 +49,6 @@ bool ModelRunner::LoadDavinciModel(uint32_t device_id, uint64_t session_id, uint return true; } -bool ModelRunner::DistributeTask(uint32_t model_id) { - auto model_iter = runtime_models_.find(model_id); - if (model_iter == runtime_models_.end()) { - GELOGE(PARAM_INVALID, "Model id %u not found.", model_id); - return false; - } - return model_iter->second->DistributeTask(); -} - -bool ModelRunner::LoadModelComplete(uint32_t model_id) { - auto model_iter = runtime_models_.find(model_id); - if (model_iter == runtime_models_.end()) { - GELOGE(PARAM_INVALID, "Model id %u not found.", model_id); - return false; - } - return model_iter->second->LoadComplete(); -} - const std::vector &ModelRunner::GetTaskIdList(uint32_t model_id) const { auto model_iter = runtime_models_.find(model_id); if (model_iter == runtime_models_.end()) { @@ -78,38 +60,6 @@ const std::vector &ModelRunner::GetTaskIdList(uint32_t model_id) const return model_iter->second->GetTaskIdList(); } -const std::vector &ModelRunner::GetStreamIdList(uint32_t model_id) const { - auto model_iter = runtime_models_.find(model_id); - if (model_iter == runtime_models_.end()) { - GELOGE(PARAM_INVALID, "Model id %u not found.", model_id); - static const std::vector empty_ret; - return empty_ret; - } - - return model_iter->second->GetStreamIdList(); -} - -const std::map> &ModelRunner::GetRuntimeInfoMap(uint32_t model_id) const { - auto model_iter = runtime_models_.find(model_id); - if (model_iter == runtime_models_.end()) { - GELOGW("Model id %u not found.", model_id); - static const std::map> empty_ret; - return empty_ret; - } - - return model_iter->second->GetRuntimeInfoMap(); -} - -void *ModelRunner::GetModelHandle(uint32_t model_id) const { - auto model_iter = runtime_models_.find(model_id); - if (model_iter == runtime_models_.end()) { - GELOGW("Model id %u not found.", model_id); - return nullptr; - } - - return model_iter->second->GetModelHandle(); -} - bool ModelRunner::UnloadModel(uint32_t model_id) { auto iter = runtime_models_.find(model_id); if (iter != runtime_models_.end()) { diff --git a/src/ge/ge_runtime/output.cc b/src/ge/ge_runtime/output.cc index 5153f68..90c33bb 100644 --- a/src/ge/ge_runtime/output.cc +++ b/src/ge/ge_runtime/output.cc @@ -76,7 +76,7 @@ bool Output::CopyRslt(OutputData *rslt, uint32_t data_begin, uint32_t &data_inde DataBuffer data_buf = rslt->blobs[data_begin + data_count]; bool ret = SetDataBuf(data_buf, data_begin, data_count, i, support_mem_share); if (!ret) { - GELOGE(FAILED, "Copy data to host error. index: %lu, addr: %p", i, v_input_data_addr_[i]); + GELOGE(FAILED, "Copy data to host failed. index: %lu, addr: %p", i, v_input_data_addr_[i]); return ret; } data_index = data_begin + data_count; diff --git a/src/ge/ge_runtime/runtime_model.cc b/src/ge/ge_runtime/runtime_model.cc index f040505..c89ced9 100644 --- a/src/ge/ge_runtime/runtime_model.cc +++ b/src/ge/ge_runtime/runtime_model.cc @@ -28,6 +28,7 @@ namespace ge { namespace model_runner { + RuntimeModel::~RuntimeModel() { GELOGI("RuntimeModel destructor start"); @@ -115,34 +116,23 @@ bool RuntimeModel::InitEvent(uint32_t event_num) { return true; } -bool RuntimeModel::InitLabel(std::shared_ptr &davinci_model) { - GELOGI("batch number:%u.", davinci_model->GetBatchNum()); - label_list_.resize(davinci_model->GetBatchNum()); - for (auto &task_info : davinci_model->GetTaskInfoList()) { - if (task_info == nullptr) { - GELOGE(PARAM_INVALID, "task_info is null."); - continue; - } - - if (task_info->type() != TaskInfoType::LABEL_SET) { - continue; - } - auto label_set_task_info = std::static_pointer_cast(task_info); - - if (label_set_task_info->stream_id() >= stream_list_.size()) { - GELOGE(PARAM_INVALID, "Invalid stream id."); +bool RuntimeModel::InitLabel(uint32_t batch_num) { + GELOGI("batch number:%u.", batch_num); + for (uint32_t i = 0; (batch_num != 0 && i <= batch_num); ++i) { + rtLabel_t rt_lLabel = nullptr; + rtError_t rt_ret = rtLabelCreate(&rt_lLabel); + if (rt_ret != RT_ERROR_NONE) { + GELOGE(RT_FAILED, "Call rt api rtLabelCreate failed, i; %u; ret: 0x%X", i, rt_ret); return false; } - rtLabel_t rt_label = nullptr; - rtError_t rt_ret = rtLabelCreateEx(&rt_label, stream_list_[label_set_task_info->stream_id()]); - if (rt_ret != RT_ERROR_NONE) { - GELOGE(RT_FAILED, "Call rt api rtLabelCreate failed, ret: 0x%X", rt_ret); + if (rt_lLabel == nullptr) { + GELOGE(RT_FAILED, "rtLabel is nullptr!"); return false; } - label_list_[label_set_task_info->label_id()] = rt_label; - } + label_list_.emplace_back(rt_lLabel); + } return true; } @@ -174,7 +164,7 @@ bool RuntimeModel::InitResource(std::shared_ptr &davinci_model) { return false; } - if (!InitLabel(davinci_model)) { + if (!InitLabel(davinci_model->GetBatchNum())) { return false; } @@ -219,41 +209,20 @@ bool RuntimeModel::LoadTask() { return false; } task_id_list_.push_back(task_id); - stream_id_list_.push_back(stream_id); - if (task->Args() != nullptr) { - std::shared_ptr runtime_tuple = nullptr; - GE_MAKE_SHARED(runtime_tuple = std::make_shared(task_id, stream_id, task->Args()), return false); - auto emplace_ret = runtime_info_map_.emplace(task->task_name(), runtime_tuple); - if (!emplace_ret.second) { - GELOGW("Task name exist:%s", task->task_name().c_str()); - } - } } if (task_list_.empty()) { GELOGE(FAILED, "Task list is empty"); return false; } + GELOGI("Distribute task succ."); - GELOGI("LoadTask succ."); - return true; -} - -bool RuntimeModel::LoadComplete() { - uint32_t task_id = 0; - uint32_t stream_id = 0; - auto rt_ret = rtModelGetTaskId(rt_model_handle_, &task_id, &stream_id); - if (rt_ret != RT_ERROR_NONE) { - GELOGE(RT_FAILED, "Call rtModelGetTaskId failed, ret:0x%X", rt_ret); - return RT_FAILED; - } - task_id_list_.push_back(task_id); - stream_id_list_.push_back(stream_id); - - rt_ret = rtModelLoadComplete(rt_model_handle_); + auto rt_ret = rtModelLoadComplete(rt_model_handle_); if (rt_ret != RT_ERROR_NONE) { GELOGE(RT_FAILED, "Call rt api rtModelLoadComplete failed, ret: 0x%X.", rt_ret); return false; } + + GELOGI("LoadTask succ."); return true; } @@ -283,16 +252,14 @@ bool RuntimeModel::Load(uint32_t device_id, uint64_t session_id, std::shared_ptr } GenerateTask(device_id, session_id, davinci_model); - return status; -} -bool RuntimeModel::DistributeTask() { - bool status = LoadTask(); + status = LoadTask(); if (!status) { GELOGE(FAILED, "DistributeTask failed"); - return false; + return status; } - return true; + + return status; } bool RuntimeModel::Run() { @@ -303,14 +270,10 @@ bool RuntimeModel::Run() { return false; } - GELOGI("Run rtModelExecute success, ret = 0x%X", ret); + GELOGI("Run rtModelExecute success"); ret = rtStreamSynchronize(rt_model_stream_); if (ret != RT_ERROR_NONE) { - if (ret == RT_ERROR_END_OF_SEQUENCE) { - GELOGI("Model stream RT_ERROR_END_OF_SEQUENCE signal received, ret = 0x%X", ret); - return true; - } GELOGE(RT_FAILED, "Model stream sync failed, ret = 0x%X", ret); return false; } @@ -470,7 +433,7 @@ bool RuntimeModel::InitConstantInfo(std::shared_ptr &davinci_model } if (constant->output_tensors[0].size < constant->weight_data.size()) { - GELOGE(PARAM_INVALID, "Output size:%u less than weight data size:%zu", constant->output_tensors[0].size, + GELOGE(PARAM_INVALID, "Output size:%u is less than weight data size:%zu", constant->output_tensors[0].size, constant->weight_data.size()); return false; } @@ -485,8 +448,11 @@ bool RuntimeModel::InitConstantInfo(std::shared_ptr &davinci_model /// The logic of GetShapeSize is wrong, the scaler tensor's GetShapeSize is zero /// and that of unknown shape is zero too. /// Unknown shape will not appear here, so we can use zero judge a tensor is scaler or not. - int64_t elem_num = - (constant->weight_tensors[0].GetShapeSize() == 0) ? 1 : constant->weight_tensors[0].GetShapeSize(); + int64_t elem_num = constant->weight_tensors[0].GetShapeSize(); + if (elem_num == 0 && constant->weight_tensors[0].size == 0) { + elem_num = 1; + } + if (constant->weight_data.size() < sizeof(uint64_t)) { GELOGE(FAILED, "weight_data size is smaller than sizeof(uint64_t)"); return false; @@ -529,6 +495,5 @@ void RuntimeModel::CreateOutput(uint32_t index, const OpInfo &op_info, InputOutp const std::vector &RuntimeModel::GetTaskIdList() const { return task_id_list_; } -const std::vector &RuntimeModel::GetStreamIdList() const { return stream_id_list_; } } // namespace model_runner } // namespace ge diff --git a/src/ge/ge_runtime/runtime_model.h b/src/ge/ge_runtime/runtime_model.h index d0c466d..e8ff405 100644 --- a/src/ge/ge_runtime/runtime_model.h +++ b/src/ge/ge_runtime/runtime_model.h @@ -27,7 +27,7 @@ namespace ge { namespace model_runner { -using RuntimeInfo = std::tuple; + class Task; class RuntimeModel { public: @@ -35,12 +35,7 @@ class RuntimeModel { ~RuntimeModel(); bool Load(uint32_t device_id, uint64_t session_id, std::shared_ptr &davinci_model); - bool DistributeTask(); - bool LoadComplete(); const std::vector &GetTaskIdList() const; - const std::vector &GetStreamIdList() const; - const std::map> &GetRuntimeInfoMap() const { return runtime_info_map_; } - rtModel_t GetModelHandle() const { return rt_model_handle_; } bool Run(); bool CopyInputData(const InputData &input_data); bool GetInputOutputDescInfo(bool zero_copy, std::vector *input_desc, @@ -53,7 +48,7 @@ class RuntimeModel { bool LoadTask(); bool InitStream(std::shared_ptr &davinci_model); bool InitEvent(uint32_t event_num); - bool InitLabel(std::shared_ptr &davinci_model); + bool InitLabel(uint32_t batch_num); bool InitDataInfo(std::shared_ptr &davinci_model); bool InitOutputInfo(std::shared_ptr &davinci_model); bool InitConstantInfo(std::shared_ptr &davinci_model); @@ -82,8 +77,6 @@ class RuntimeModel { std::vector> constant_info_list_{}; std::vector task_id_list_{}; - std::vector stream_id_list_{}; - std::map> runtime_info_map_; }; } // namespace model_runner diff --git a/src/ge/ge_runtime/task/aicpu_task.cc b/src/ge/ge_runtime/task/aicpu_task.cc index 9b126ec..4cb7186 100644 --- a/src/ge/ge_runtime/task/aicpu_task.cc +++ b/src/ge/ge_runtime/task/aicpu_task.cc @@ -85,15 +85,11 @@ bool AicpuTask::Distribute() { return false; } - input_output_addr_ = reinterpret_cast(reinterpret_cast(args_) + io_addr_offset); - - auto dump_flag = task_info_->dump_flag() ? RT_KERNEL_DUMPFLAG : RT_KERNEL_DEFAULT; - GELOGI( - "Distribute AicpuTask start, args_size = %u, io_addrs_num = %u, so_name = %s, kernel_name = %s, dump_flag = %d.", - args_size, io_addrs_num, task_info_->so_name().data(), task_info_->kernel_name().data(), dump_flag); - rt_ret = rtCpuKernelLaunchWithFlag(reinterpret_cast(task_info_->so_name().data()), - reinterpret_cast(task_info_->kernel_name().data()), 1, args_, - args_size, nullptr, stream_, dump_flag); + GELOGI("Distribute AicpuTask start, args_size = %u, io_addrs_num = %u, so_name = %s, kernel_name = %s.", args_size, + io_addrs_num, task_info_->so_name().data(), task_info_->kernel_name().data()); + rt_ret = rtCpuKernelLaunch(reinterpret_cast(task_info_->so_name().data()), + reinterpret_cast(task_info_->kernel_name().data()), 1, args_, args_size, + nullptr, stream_); if (rt_ret != RT_ERROR_NONE) { GELOGE(RT_FAILED, "Call rt api failed, ret: 0x%X", rt_ret); return false; diff --git a/src/ge/ge_runtime/task/aicpu_task.h b/src/ge/ge_runtime/task/aicpu_task.h index cc21af8..f5cdc61 100644 --- a/src/ge/ge_runtime/task/aicpu_task.h +++ b/src/ge/ge_runtime/task/aicpu_task.h @@ -18,7 +18,6 @@ #define GE_GE_RUNTIME_TASK_AICPU_TASK_H_ #include -#include #include "ge_runtime/task/task.h" namespace ge { @@ -31,17 +30,12 @@ class AicpuTask : public TaskRepeater { bool Distribute() override; - void *Args() override { return input_output_addr_; } - - std::string task_name() const override { return task_info_->op_name(); } - private: static void ReleaseRtMem(void **ptr) noexcept; std::shared_ptr task_info_; void *stream_; void *args_; - void *input_output_addr_; }; } // namespace model_runner } // namespace ge diff --git a/src/ge/ge_runtime/task/hccl_task.cc b/src/ge/ge_runtime/task/hccl_task.cc index 3d5f850..54ae3bf 100644 --- a/src/ge/ge_runtime/task/hccl_task.cc +++ b/src/ge/ge_runtime/task/hccl_task.cc @@ -115,6 +115,7 @@ bool HcclTask::Distribute() { rt_ret = rtModelBindStream(rt_model_handle_, stream, RT_HEAD_STREAM); if (rt_ret != RT_ERROR_NONE) { GELOGE(RT_FAILED, "Call rt api failed, ret: 0x%X", rt_ret); + (void)rtStreamDestroy(stream); return false; } diff --git a/src/ge/ge_runtime/task/label_goto_task.cc b/src/ge/ge_runtime/task/label_goto_task.cc deleted file mode 100644 index d357acc..0000000 --- a/src/ge/ge_runtime/task/label_goto_task.cc +++ /dev/null @@ -1,70 +0,0 @@ -/** - * Copyright 2019-2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "ge_runtime/task/label_goto_task.h" -#include "ge_runtime/task/task_factory.h" - -namespace ge { -namespace model_runner { -LabelGotoTask::LabelGotoTask(const ModelContext &model_context, const std::shared_ptr &task_info) - : TaskRepeater(model_context, task_info), - task_info_(task_info), - stream_(nullptr), - label_(nullptr) { - if (task_info_ == nullptr) { - GELOGW("task_info_ is null!"); - return; - } - auto stream_list = model_context.stream_list(); - auto label_list = model_context.label_list(); - uint32_t stream_id = task_info->stream_id(); - uint32_t label_id = task_info->label_id(); - GELOGI("Stream list size:%zu, stream id:%u.", stream_list.size(), stream_id); - GELOGI("Label list size:%zu, label id:%u.", label_list.size(), label_id); - if (stream_id >= stream_list.size() || label_id >= label_list.size()) { - GELOGW("Stream/Label id invalid."); - return; - } - stream_ = stream_list[stream_id]; - label_ = label_list[label_id]; -} - -LabelGotoTask::~LabelGotoTask() {} - -bool LabelGotoTask::Distribute() { - GELOGI("LabelGotoTask Distribute start."); - if (stream_ == nullptr) { - GELOGE(PARAM_INVALID, "stream is null!"); - return false; - } - if (label_ == nullptr) { - GELOGE(PARAM_INVALID, "label is null!"); - return false; - } - rtError_t rt_ret = rtLabelGotoEx(label_, stream_); - if (rt_ret != RT_ERROR_NONE) { - GELOGE(RT_FAILED, "Call rt api failed, ret: 0x%X", rt_ret); - return false; - } - - GELOGI("DistributeTask end."); - return true; -} - -REGISTER_TASK(TaskInfoType::LABEL_GOTO, LabelGotoTask, LabelGotoTaskInfo); - -} // namespace model_runner -} // namespace ge diff --git a/src/ge/ge_runtime/task/label_goto_task.h b/src/ge/ge_runtime/task/label_goto_task.h deleted file mode 100644 index 4fd6d1b..0000000 --- a/src/ge/ge_runtime/task/label_goto_task.h +++ /dev/null @@ -1,41 +0,0 @@ -/** - * Copyright 2019-2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef GE_GE_RUNTIME_TASK_LABEL_GOTO_TASK_H_ -#define GE_GE_RUNTIME_TASK_LABEL_GOTO_TASK_H_ - -#include -#include "ge_runtime/task/task.h" - -namespace ge { -namespace model_runner { -class LabelGotoTask : public TaskRepeater { - public: - LabelGotoTask(const ModelContext &model_context, const std::shared_ptr &task_info); - - ~LabelGotoTask() override; - - bool Distribute() override; - - private: - std::shared_ptr task_info_; - void *stream_; - void *label_; -}; -} // namespace model_runner -} // namespace ge - -#endif // GE_GE_RUNTIME_TASK_LABEL_GOTO_TASK_H_ diff --git a/src/ge/ge_runtime/task/label_set_task.cc b/src/ge/ge_runtime/task/label_set_task.cc deleted file mode 100644 index 3ab5802..0000000 --- a/src/ge/ge_runtime/task/label_set_task.cc +++ /dev/null @@ -1,70 +0,0 @@ -/** - * Copyright 2019-2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "ge_runtime/task/label_set_task.h" -#include "ge_runtime/task/task_factory.h" - -namespace ge { -namespace model_runner { -LabelSetTask::LabelSetTask(const ModelContext &model_context, const std::shared_ptr &task_info) - : TaskRepeater(model_context, task_info), - task_info_(task_info), - stream_(nullptr), - label_(nullptr) { - if (task_info_ == nullptr) { - GELOGW("task_info_ is null!"); - return; - } - auto stream_list = model_context.stream_list(); - auto label_list = model_context.label_list(); - uint32_t stream_id = task_info->stream_id(); - uint32_t label_id = task_info->label_id(); - GELOGI("Stream list size:%zu, stream id:%u.", stream_list.size(), stream_id); - GELOGI("Label list size:%zu, label id:%u.", label_list.size(), label_id); - if (stream_id >= stream_list.size() || label_id >= label_list.size()) { - GELOGW("Stream/Label id invalid."); - return; - } - stream_ = stream_list[stream_id]; - label_ = label_list[label_id]; -} - -LabelSetTask::~LabelSetTask() {} - -bool LabelSetTask::Distribute() { - GELOGI("LabelSetTask Distribute start."); - if (stream_ == nullptr) { - GELOGE(PARAM_INVALID, "stream is null!"); - return false; - } - if (label_ == nullptr) { - GELOGE(PARAM_INVALID, "label is null!"); - return false; - } - rtError_t rt_ret = rtLabelSet(label_, stream_); - if (rt_ret != RT_ERROR_NONE) { - GELOGE(RT_FAILED, "Call rt api failed, ret: 0x%X", rt_ret); - return false; - } - - GELOGI("DistributeTask end."); - return true; -} - -REGISTER_TASK(TaskInfoType::LABEL_SET, LabelSetTask, LabelSetTaskInfo); - -} // namespace model_runner -} // namespace ge diff --git a/src/ge/ge_runtime/task/label_set_task.h b/src/ge/ge_runtime/task/label_set_task.h deleted file mode 100644 index 70bf158..0000000 --- a/src/ge/ge_runtime/task/label_set_task.h +++ /dev/null @@ -1,41 +0,0 @@ -/** - * Copyright 2019-2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef GE_GE_RUNTIME_TASK_LABEL_SET_TASK_H_ -#define GE_GE_RUNTIME_TASK_LABEL_SET_TASK_H_ - -#include -#include "ge_runtime/task/task.h" - -namespace ge { -namespace model_runner { -class LabelSetTask : public TaskRepeater { - public: - LabelSetTask(const ModelContext &model_context, const std::shared_ptr &task_info); - - ~LabelSetTask() override; - - bool Distribute() override; - - private: - std::shared_ptr task_info_; - void *stream_; - void *label_; -}; -} // namespace model_runner -} // namespace ge - -#endif // GE_GE_RUNTIME_TASK_LABEL_SET_TASK_H_ diff --git a/src/ge/ge_runtime/task/label_switch_task.cc b/src/ge/ge_runtime/task/label_switch_task.cc deleted file mode 100644 index a3c2d41..0000000 --- a/src/ge/ge_runtime/task/label_switch_task.cc +++ /dev/null @@ -1,131 +0,0 @@ -/** - * Copyright 2019-2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "ge_runtime/task/label_switch_task.h" -#include "ge_runtime/task/task_factory.h" - -namespace ge { -namespace model_runner { -LabelSwitchTask::LabelSwitchTask(const ModelContext &model_context, - const std::shared_ptr &task_info) - : TaskRepeater(model_context, task_info), - task_info_(task_info), - stream_(nullptr), - all_label_resource_(), - label_info_(nullptr) { - if (task_info_ == nullptr) { - GELOGW("task_info_ is null!"); - return; - } - - all_label_resource_ = model_context.label_list(); - auto stream_list = model_context.stream_list(); - uint32_t stream_id = task_info->stream_id(); - GELOGI("Stream list size:%zu, stream id:%u.", stream_list.size(), stream_id); - if (stream_id >= stream_list.size()) { - GELOGW("Stream id invalid."); - return; - } - stream_ = stream_list[stream_id]; -} - -LabelSwitchTask::~LabelSwitchTask() { - if (label_info_ != nullptr) { - rtError_t rt_ret = rtFree(label_info_); - if (rt_ret != RT_ERROR_NONE) { - GELOGE(RT_FAILED, "rtFree fwkOpBuf failed! ret: 0x%X.", rt_ret); - } - label_info_ = nullptr; - } -} - -bool LabelSwitchTask::Distribute() { - GELOGI("LabelSwitchTask Distribute start."); - if (!CheckParamValid()) { - return false; - } - - const std::vector &label_index_list = task_info_->label_list(); - std::vector label_list(task_info_->label_size(), nullptr); - - for (size_t i = 0; i < task_info_->label_size(); ++i) { - uint32_t label_index = label_index_list[i]; - if (label_index >= all_label_resource_.size()) { - GELOGE(PARAM_INVALID, "label %zu index is %u, but there are %zu labels in total.", i, label_index, - all_label_resource_.size()); - return false; - } - label_list[i] = all_label_resource_[label_index]; - GELOGI("Case %zu: label id %zu.", i, label_index); - } - - uint32_t label_info_size = sizeof(rtLabelDevInfo) * task_info_->label_size(); - rtError_t rt_ret = rtMalloc(&label_info_, label_info_size, RT_MEMORY_HBM); - if (rt_ret != RT_ERROR_NONE) { - GELOGE(RT_FAILED, "Call rt api failed, ret: 0x%X", rt_ret); - return false; - } - - rt_ret = rtLabelListCpy(label_list.data(), label_list.size(), label_info_, label_info_size); - if (rt_ret != RT_ERROR_NONE) { - GELOGE(RT_FAILED, "Call rt api failed, ret: 0x%X", rt_ret); - return false; - } - - rt_ret = rtLabelSwitchByIndex(task_info_->cond(), label_list.size(), label_info_, stream_); - if (rt_ret != RT_ERROR_NONE) { - GELOGE(RT_FAILED, "Call rt api failed, ret: 0x%X", rt_ret); - return false; - } - - GELOGI("DistributeTask end."); - return true; -} - -bool LabelSwitchTask::CheckParamValid() { - if (stream_ == nullptr) { - GELOGE(PARAM_INVALID, "stream is null!"); - return false; - } - - if (task_info_->label_list().empty()) { - GELOGE(PARAM_INVALID, "label_list is empty."); - return false; - } - - if (task_info_->label_size() != task_info_->label_list().size()) { - GELOGE(PARAM_INVALID, "label_list size %zu but label_size is %u.", task_info_->label_list().size(), - task_info_->label_size()); - return false; - } - - if (task_info_->label_size() >= UINT32_MAX / sizeof(rtLabelDevInfo)) { - GELOGE(PARAM_INVALID, "label_size %u will overflow.", task_info_->label_size()); - return false; - } - - if (label_info_ != nullptr) { - GELOGE(PARAM_INVALID, "label_info_ has dirty data."); - return false; - } - - return true; -} - -REGISTER_TASK(TaskInfoType::LABEL_SWITCH, LabelSwitchTask, LabelSwitchTaskInfo); - -} // namespace model_runner -} // namespace ge diff --git a/src/ge/ge_runtime/task/label_switch_task.h b/src/ge/ge_runtime/task/label_switch_task.h deleted file mode 100644 index 463faa3..0000000 --- a/src/ge/ge_runtime/task/label_switch_task.h +++ /dev/null @@ -1,44 +0,0 @@ -/** - * Copyright 2019-2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef GE_GE_RUNTIME_TASK_LABEL_SWITCH_TASK_H_ -#define GE_GE_RUNTIME_TASK_LABEL_SWITCH_TASK_H_ - -#include -#include "ge_runtime/task/task.h" - -namespace ge { -namespace model_runner { -class LabelSwitchTask : public TaskRepeater { - public: - LabelSwitchTask(const ModelContext &model_context, const std::shared_ptr &task_info); - - ~LabelSwitchTask() override; - - bool Distribute() override; - - private: - bool CheckParamValid(); - - std::shared_ptr task_info_; - void *stream_; - std::vector all_label_resource_; - void *label_info_; -}; -} // namespace model_runner -} // namespace ge - -#endif // GE_GE_RUNTIME_TASK_LABEL_SWITCH_TASK_H_ diff --git a/src/ge/ge_runtime/task/stream_switch_task.cc b/src/ge/ge_runtime/task/stream_switch_task.cc index 2adcb4b..9114113 100644 --- a/src/ge/ge_runtime/task/stream_switch_task.cc +++ b/src/ge/ge_runtime/task/stream_switch_task.cc @@ -51,7 +51,7 @@ bool StreamSwitchTask::Distribute() { } if (static_cast(task_info_->true_stream_id()) >= stream_list_.size()) { - GELOGE(PARAM_INVALID, "true_stream_id %ld must less than stream_list_ size %zu!", task_info_->true_stream_id(), + GELOGE(PARAM_INVALID, "true_stream_id %ld must be less than stream_list_ size %zu!", task_info_->true_stream_id(), stream_list_.size()); return false; } diff --git a/src/ge/ge_runtime/task/task.h b/src/ge/ge_runtime/task/task.h index 6c4df24..7c748a7 100644 --- a/src/ge/ge_runtime/task/task.h +++ b/src/ge/ge_runtime/task/task.h @@ -18,9 +18,7 @@ #define GE_GE_RUNTIME_TASK_TASK_H_ #include -#include #include -#include #include "runtime/rt_model.h" #include "ge_runtime/model_context.h" #include "ge_runtime/task_info.h" @@ -34,10 +32,6 @@ class Task { virtual ~Task() {} virtual bool Distribute() = 0; - - virtual void *Args() { return nullptr; } - - virtual std::string task_name() const { return ""; } }; template diff --git a/src/ge/ge_runtime/task/tbe_task.cc b/src/ge/ge_runtime/task/tbe_task.cc index e7025ae..8a3c36a 100644 --- a/src/ge/ge_runtime/task/tbe_task.cc +++ b/src/ge/ge_runtime/task/tbe_task.cc @@ -95,14 +95,15 @@ bool TbeTask::Distribute() { return false; } + GELOGI("InitTbeTask end."); GELOGI("DistributeTbeTask start."); - auto dump_flag = task_info_->dump_flag() ? RT_KERNEL_DUMPFLAG : RT_KERNEL_DEFAULT; - rt_ret = rtKernelLaunchWithFlag(stub_func_, task_info_->block_dim(), args_, args_size, nullptr, stream_, dump_flag); + rt_ret = rtKernelLaunch(stub_func_, task_info_->block_dim(), args_, args_size, nullptr, stream_); if (rt_ret != RT_ERROR_NONE) { GELOGE(RT_FAILED, "Call rt api rtKernelLaunch failed, ret: 0x%X", rt_ret); return false; } - GELOGI("[DataDump] task name:%s, dump_flag:%d", task_info_->op_name().c_str(), dump_flag); + + GELOGI("DistributeTbeTask end."); return true; } diff --git a/src/ge/ge_runtime/task/tbe_task.h b/src/ge/ge_runtime/task/tbe_task.h index a8ce626..994ba5e 100644 --- a/src/ge/ge_runtime/task/tbe_task.h +++ b/src/ge/ge_runtime/task/tbe_task.h @@ -30,10 +30,6 @@ class TbeTask : public TaskRepeater { bool Distribute() override; - void *Args() override { return args_; } - - std::string task_name() const override { return task_info_->op_name(); } - private: std::shared_ptr task_info_; void *stream_; -- GitLab