提交 81d79f33 编写于 作者: Z zhangzhenghai

modify ge_runtime

上级 9275e7b0
......@@ -28,7 +28,7 @@
namespace ge {
namespace model_runner {
class RuntimeModel;
using RuntimeInfo = std::tuple<uint32_t, uint32_t, void *>;
class ModelRunner {
public:
static ModelRunner &Instance();
......@@ -36,18 +36,8 @@ class ModelRunner {
bool LoadDavinciModel(uint32_t device_id, uint64_t session_id, uint32_t model_id,
std::shared_ptr<DavinciModel> davinci_model, std::shared_ptr<ModelListener> listener);
bool DistributeTask(uint32_t model_id);
bool LoadModelComplete(uint32_t model_id);
const std::vector<uint32_t> &GetTaskIdList(uint32_t model_id) const;
const std::vector<uint32_t> &GetStreamIdList(uint32_t model_id) const;
const std::map<std::string, std::shared_ptr<RuntimeInfo>> &GetRuntimeInfoMap(uint32_t model_id) const;
void *GetModelHandle(uint32_t model_id) const;
bool UnloadModel(uint32_t model_id);
bool RunModel(uint32_t model_id, const InputData &input_data, OutputData *output_data);
......
......@@ -21,7 +21,6 @@
#include <functional>
#include <memory>
#include <string>
#include <utility>
#include <vector>
#include "cce/taskdown_api.h"
......@@ -53,27 +52,21 @@ class TaskInfo {
virtual ~TaskInfo() {}
uint32_t stream_id() const { return stream_id_; }
TaskInfoType type() const { return type_; }
std::string op_name() const { return op_name_; }
bool dump_flag() const { return dump_flag_; }
protected:
TaskInfo(const std::string &op_name, uint32_t stream_id, TaskInfoType type, bool dump_flag)
: op_name_(op_name), stream_id_(stream_id), type_(type), dump_flag_(dump_flag) {}
TaskInfo(uint32_t stream_id, TaskInfoType type) : stream_id_(stream_id), type_(type) {}
private:
std::string op_name_;
uint32_t stream_id_;
TaskInfoType type_;
bool dump_flag_;
};
class CceTaskInfo : public TaskInfo {
public:
CceTaskInfo(const std::string &op_name, uint32_t stream_id, const cce::ccOpContext &ctx, const std::string &stub_func,
uint32_t block_dim, const std::vector<uint8_t> &args, uint32_t args_size,
const std::vector<uint8_t> &sm_desc, const std::vector<uint8_t> &flow_table,
const std::vector<uint8_t> &args_offset, bool is_flowtable)
: TaskInfo(op_name, stream_id, TaskInfoType::CCE, false),
CceTaskInfo(uint32_t stream_id, const cce::ccOpContext &ctx, const std::string &stub_func, uint32_t block_dim,
const std::vector<uint8_t> &args, uint32_t args_size, const std::vector<uint8_t> &sm_desc,
const std::vector<uint8_t> &flow_table, const std::vector<uint8_t> &args_offset, bool is_flowtable)
: TaskInfo(stream_id, TaskInfoType::CCE),
ctx_(ctx),
stub_func_(stub_func),
block_dim_(block_dim),
......@@ -109,11 +102,11 @@ class CceTaskInfo : public TaskInfo {
class TbeTaskInfo : public TaskInfo {
public:
TbeTaskInfo(const std::string &op_name, uint32_t stream_id, const std::string &stub_func, uint32_t block_dim,
const std::vector<uint8_t> &args, uint32_t args_size, const std::vector<uint8_t> &sm_desc, void *binary,
uint32_t binary_size, const std::vector<uint8_t> &meta_data, const std::vector<void *> &input_data_addrs,
const std::vector<void *> &output_data_addrs, const std::vector<void *> &workspace_addrs, bool dump_flag)
: TaskInfo(op_name, stream_id, TaskInfoType::TBE, dump_flag),
TbeTaskInfo(uint32_t stream_id, const std::string &stub_func, uint32_t block_dim, const std::vector<uint8_t> &args,
uint32_t args_size, const std::vector<uint8_t> &sm_desc, void *binary, uint32_t binary_size,
const std::vector<uint8_t> &meta_data, const std::vector<void *> &input_data_addrs,
const std::vector<void *> &output_data_addrs, const std::vector<void *> &workspace_addrs)
: TaskInfo(stream_id, TaskInfoType::TBE),
stub_func_(stub_func),
block_dim_(block_dim),
args_(args),
......@@ -160,10 +153,9 @@ class TbeTaskInfo : public TaskInfo {
class AicpuTaskInfo : public TaskInfo {
public:
AicpuTaskInfo(const std::string &op_name, uint32_t stream_id, const string &so_name, const std::string &kernel_name,
const std::string &node_def, const std::vector<void *> &input_data_addrs,
const std::vector<void *> &output_data_addrs, bool dump_flag)
: TaskInfo(op_name, stream_id, TaskInfoType::AICPU, dump_flag),
AicpuTaskInfo(uint32_t stream_id, const string &so_name, const std::string &kernel_name, const std::string &node_def,
const std::vector<void *> &input_data_addrs, const std::vector<void *> &output_data_addrs)
: TaskInfo(stream_id, TaskInfoType::AICPU),
so_name_(so_name),
kernel_name_(kernel_name),
node_def_(node_def),
......@@ -185,45 +177,37 @@ class AicpuTaskInfo : public TaskInfo {
std::vector<void *> output_data_addrs_;
};
class LabelSetTaskInfo : public TaskInfo {
class LabelTaskInfo : public TaskInfo {
public:
LabelSetTaskInfo(const std::string &op_name, uint32_t stream_id, uint32_t label_id)
: TaskInfo(op_name, stream_id, TaskInfoType::LABEL_SET, false), label_id_(label_id) {}
~LabelSetTaskInfo() override {}
uint32_t label_id() const { return label_id_; }
private:
protected:
LabelTaskInfo(uint32_t stream_id, TaskInfoType type, uint32_t label_id)
: TaskInfo(stream_id, type), label_id_(label_id) {}
virtual ~LabelTaskInfo() override {}
uint32_t label_id_;
};
class LabelGotoTaskInfo : public TaskInfo {
class LabelSetTaskInfo : public LabelTaskInfo {
public:
LabelGotoTaskInfo(const std::string &op_name, uint32_t stream_id, uint32_t label_id)
: TaskInfo(op_name, stream_id, TaskInfoType::LABEL_GOTO, false), label_id_(label_id) {}
~LabelGotoTaskInfo() override {}
uint32_t label_id() const { return label_id_; }
private:
uint32_t label_id_;
LabelSetTaskInfo(uint32_t stream_id, uint32_t label_id)
: LabelTaskInfo(stream_id, TaskInfoType::LABEL_SET, label_id) {}
~LabelSetTaskInfo() override {}
};
class LabelSwitchTaskInfo : public TaskInfo {
class LabelSwitchTaskInfo : public LabelTaskInfo {
public:
LabelSwitchTaskInfo(const std::string &op_name, uint32_t stream_id, uint32_t label_size,
const std::vector<uint32_t> &label_list, void *cond)
: TaskInfo(op_name, stream_id, TaskInfoType::LABEL_SWITCH, false),
label_size_(label_size),
label_list_(label_list),
cond_(cond) {}
LabelSwitchTaskInfo(uint32_t stream_id, uint32_t label_id)
: LabelTaskInfo(stream_id, TaskInfoType::LABEL_SWITCH, label_id) {}
~LabelSwitchTaskInfo() override {}
uint32_t label_size() { return label_size_; };
const std::vector<uint32_t> &label_list() { return label_list_; };
void *cond() { return cond_; };
};
private:
uint32_t label_size_;
std::vector<uint32_t> label_list_;
void *cond_;
class LabelGotoTaskInfo : public LabelTaskInfo {
public:
LabelGotoTaskInfo(uint32_t stream_id, uint32_t label_id)
: LabelTaskInfo(stream_id, TaskInfoType::LABEL_GOTO, label_id) {}
~LabelGotoTaskInfo() override {}
};
class EventTaskInfo : public TaskInfo {
......@@ -231,8 +215,8 @@ class EventTaskInfo : public TaskInfo {
uint32_t event_id() const { return event_id_; }
protected:
EventTaskInfo(const std::string &op_name, uint32_t stream_id, TaskInfoType type, uint32_t event_id)
: TaskInfo(op_name, stream_id, type, false), event_id_(event_id) {}
EventTaskInfo(uint32_t stream_id, TaskInfoType type, uint32_t event_id)
: TaskInfo(stream_id, type), event_id_(event_id) {}
virtual ~EventTaskInfo() override {}
uint32_t event_id_;
......@@ -240,41 +224,39 @@ class EventTaskInfo : public TaskInfo {
class EventRecordTaskInfo : public EventTaskInfo {
public:
EventRecordTaskInfo(const std::string &op_name, uint32_t stream_id, uint32_t event_id)
: EventTaskInfo(op_name, stream_id, TaskInfoType::EVENT_RECORD, event_id) {}
EventRecordTaskInfo(uint32_t stream_id, uint32_t event_id)
: EventTaskInfo(stream_id, TaskInfoType::EVENT_RECORD, event_id) {}
~EventRecordTaskInfo() override {}
};
class EventWaitTaskInfo : public EventTaskInfo {
public:
EventWaitTaskInfo(const std::string &op_name, uint32_t stream_id, uint32_t event_id)
: EventTaskInfo(op_name, stream_id, TaskInfoType::EVENT_WAIT, event_id) {}
EventWaitTaskInfo(uint32_t stream_id, uint32_t event_id)
: EventTaskInfo(stream_id, TaskInfoType::EVENT_WAIT, event_id) {}
~EventWaitTaskInfo() override {}
};
class FusionStartTaskInfo : public TaskInfo {
public:
explicit FusionStartTaskInfo(const std::string &op_name, uint32_t stream_id)
: TaskInfo(op_name, stream_id, TaskInfoType::FUSION_START, false) {}
explicit FusionStartTaskInfo(uint32_t stream_id) : TaskInfo(stream_id, TaskInfoType::FUSION_START) {}
~FusionStartTaskInfo() override {}
};
class FusionEndTaskInfo : public TaskInfo {
public:
explicit FusionEndTaskInfo(const std::string &op_name, uint32_t stream_id)
: TaskInfo(op_name, stream_id, TaskInfoType::FUSION_END, false) {}
explicit FusionEndTaskInfo(uint32_t stream_id) : TaskInfo(stream_id, TaskInfoType::FUSION_END) {}
~FusionEndTaskInfo() override {}
};
class HcclTaskInfo : public TaskInfo {
public:
HcclTaskInfo(const std::string &op_name, uint32_t stream_id, const std::string hccl_type, void *input_data_addr,
void *output_data_addr, void *workspace_addr, int64_t workspace_size, int64_t hccl_stream_num,
HcclTaskInfo(uint32_t stream_id, const std::string hccl_type, void *input_data_addr, void *output_data_addr,
void *workspace_addr, int64_t workspace_size, int64_t hccl_stream_num,
const std::vector<uint8_t> &private_def, void *ops_kernel_store, int32_t count, int64_t root_id,
int64_t op_type, int64_t data_type, const std::string &group,
std::function<bool(void *, void *)> hcom_bind_model, std::function<bool(void *)> hcom_unbind_model,
std::function<bool(std::shared_ptr<HcclTaskInfo>, void *)> hcom_distribute_task, bool dump_flag)
: TaskInfo(op_name, stream_id, TaskInfoType::HCCL, dump_flag),
int64_t op_type, int64_t data_type, std::function<bool(void *, void *)> hcom_bind_model,
std::function<bool(void *)> hcom_unbind_model,
std::function<bool(std::shared_ptr<HcclTaskInfo>, void *)> hcom_distribute_task)
: TaskInfo(stream_id, TaskInfoType::HCCL),
hccl_type_(hccl_type),
input_data_addr_(input_data_addr),
output_data_addr_(output_data_addr),
......@@ -287,7 +269,6 @@ class HcclTaskInfo : public TaskInfo {
root_id_(root_id),
op_type_(op_type),
data_type_(data_type),
group_(group),
hcom_bind_model_(hcom_bind_model),
hcom_unbind_model_(hcom_unbind_model),
hcom_distribute_task_(hcom_distribute_task) {}
......@@ -305,7 +286,6 @@ class HcclTaskInfo : public TaskInfo {
int64_t root_id() const { return root_id_; }
int64_t op_type() const { return op_type_; }
int64_t data_type() const { return data_type_; }
const std::string &group() const { return group_; }
std::function<bool(void *, void *)> hcom_bind_model() const { return hcom_bind_model_; }
std::function<bool(void *)> hcom_unbind_model() const { return hcom_unbind_model_; }
std::function<bool(std::shared_ptr<HcclTaskInfo>, void *)> hcom_distribute_task() const {
......@@ -325,7 +305,6 @@ class HcclTaskInfo : public TaskInfo {
int64_t root_id_;
int64_t op_type_;
int64_t data_type_;
std::string group_;
std::function<bool(void *, void *)> hcom_bind_model_;
std::function<bool(void *)> hcom_unbind_model_;
std::function<bool(std::shared_ptr<HcclTaskInfo>, void *)> hcom_distribute_task_;
......@@ -333,11 +312,8 @@ class HcclTaskInfo : public TaskInfo {
class ProfilerTraceTaskInfo : public TaskInfo {
public:
ProfilerTraceTaskInfo(const std::string &op_name, uint32_t stream_id, uint64_t log_id, bool notify, uint32_t flat)
: TaskInfo(op_name, stream_id, TaskInfoType::PROFILER_TRACE, false),
log_id_(log_id),
notify_(notify),
flat_(flat) {}
ProfilerTraceTaskInfo(uint32_t stream_id, uint64_t log_id, bool notify, uint32_t flat)
: TaskInfo(stream_id, TaskInfoType::PROFILER_TRACE), log_id_(log_id), notify_(notify), flat_(flat) {}
~ProfilerTraceTaskInfo() override {}
uint64_t log_id() const { return log_id_; }
......@@ -352,9 +328,8 @@ class ProfilerTraceTaskInfo : public TaskInfo {
class MemcpyAsyncTaskInfo : public TaskInfo {
public:
MemcpyAsyncTaskInfo(const std::string &op_name, uint32_t stream_id, void *dst, uint64_t dst_max, void *src,
uint64_t count, uint32_t kind, bool dump_flag)
: TaskInfo(op_name, stream_id, TaskInfoType::MEMCPY_ASYNC, dump_flag),
MemcpyAsyncTaskInfo(uint32_t stream_id, void *dst, uint64_t dst_max, void *src, uint64_t count, uint32_t kind)
: TaskInfo(stream_id, TaskInfoType::MEMCPY_ASYNC),
dst_(dst),
dst_max_(dst_max),
src_(src),
......@@ -378,9 +353,9 @@ class MemcpyAsyncTaskInfo : public TaskInfo {
class StreamSwitchTaskInfo : public TaskInfo {
public:
StreamSwitchTaskInfo(const std::string &op_name, uint32_t stream_id, int64_t true_stream_id, void *input_addr,
void *value_addr, int64_t cond, int64_t data_type)
: TaskInfo(op_name, stream_id, TaskInfoType::STREAM_SWITCH, false),
StreamSwitchTaskInfo(uint32_t stream_id, int64_t true_stream_id, void *input_addr, void *value_addr, int64_t cond,
int64_t data_type)
: TaskInfo(stream_id, TaskInfoType::STREAM_SWITCH),
true_stream_id_(true_stream_id),
input_addr_(input_addr),
value_addr_(value_addr),
......@@ -404,8 +379,8 @@ class StreamSwitchTaskInfo : public TaskInfo {
class StreamActiveTaskInfo : public TaskInfo {
public:
StreamActiveTaskInfo(const std::string &op_name, uint32_t stream_id, uint32_t active_stream_id)
: TaskInfo(op_name, stream_id, TaskInfoType::STREAM_ACTIVE, false), active_stream_id_(active_stream_id) {}
StreamActiveTaskInfo(uint32_t stream_id, uint32_t active_stream_id)
: TaskInfo(stream_id, TaskInfoType::STREAM_ACTIVE), active_stream_id_(active_stream_id) {}
~StreamActiveTaskInfo() override {}
uint32_t active_stream_id() const { return active_stream_id_; }
......
......@@ -49,24 +49,6 @@ bool ModelRunner::LoadDavinciModel(uint32_t device_id, uint64_t session_id, uint
return true;
}
// Forwards a task-distribution request to the runtime model stored under
// model_id. Logs an error and returns false when no such model is registered;
// otherwise returns whatever the model's DistributeTask() returns.
bool ModelRunner::DistributeTask(uint32_t model_id) {
  const auto found = runtime_models_.find(model_id);
  if (found != runtime_models_.end()) {
    return found->second->DistributeTask();
  }

  GELOGE(PARAM_INVALID, "Model id %u not found.", model_id);
  return false;
}
// Forwards a load-complete request to the runtime model stored under
// model_id. Logs an error and returns false when the id is unknown;
// otherwise returns the result of the model's LoadComplete().
bool ModelRunner::LoadModelComplete(uint32_t model_id) {
  const auto found = runtime_models_.find(model_id);
  if (found != runtime_models_.end()) {
    return found->second->LoadComplete();
  }

  GELOGE(PARAM_INVALID, "Model id %u not found.", model_id);
  return false;
}
const std::vector<uint32_t> &ModelRunner::GetTaskIdList(uint32_t model_id) const {
auto model_iter = runtime_models_.find(model_id);
if (model_iter == runtime_models_.end()) {
......@@ -78,38 +60,6 @@ const std::vector<uint32_t> &ModelRunner::GetTaskIdList(uint32_t model_id) const
return model_iter->second->GetTaskIdList();
}
// Returns the stream-id list of the runtime model stored under model_id.
// For an unknown id an error is logged and a reference to a function-local
// static empty vector is returned, so the returned reference is always valid.
const std::vector<uint32_t> &ModelRunner::GetStreamIdList(uint32_t model_id) const {
  const auto found = runtime_models_.find(model_id);
  if (found != runtime_models_.end()) {
    return found->second->GetStreamIdList();
  }

  GELOGE(PARAM_INVALID, "Model id %u not found.", model_id);
  static const std::vector<uint32_t> empty_ret;
  return empty_ret;
}
// Returns the runtime-info map of the runtime model stored under model_id.
// An unknown id is only a warning here: a reference to a function-local
// static empty map is returned in that case.
const std::map<std::string, std::shared_ptr<RuntimeInfo>> &ModelRunner::GetRuntimeInfoMap(uint32_t model_id) const {
  const auto found = runtime_models_.find(model_id);
  if (found != runtime_models_.end()) {
    return found->second->GetRuntimeInfoMap();
  }

  GELOGW("Model id %u not found.", model_id);
  static const std::map<std::string, std::shared_ptr<RuntimeInfo>> empty_ret;
  return empty_ret;
}
// Returns the handle reported by the runtime model stored under model_id,
// or nullptr (after logging a warning) when the id is unknown.
void *ModelRunner::GetModelHandle(uint32_t model_id) const {
  const auto found = runtime_models_.find(model_id);
  if (found != runtime_models_.end()) {
    return found->second->GetModelHandle();
  }

  GELOGW("Model id %u not found.", model_id);
  return nullptr;
}
bool ModelRunner::UnloadModel(uint32_t model_id) {
auto iter = runtime_models_.find(model_id);
if (iter != runtime_models_.end()) {
......
......@@ -76,7 +76,7 @@ bool Output::CopyRslt(OutputData *rslt, uint32_t data_begin, uint32_t &data_inde
DataBuffer data_buf = rslt->blobs[data_begin + data_count];
bool ret = SetDataBuf(data_buf, data_begin, data_count, i, support_mem_share);
if (!ret) {
GELOGE(FAILED, "Copy data to host error. index: %lu, addr: %p", i, v_input_data_addr_[i]);
GELOGE(FAILED, "Copy data to host failed. index: %lu, addr: %p", i, v_input_data_addr_[i]);
return ret;
}
data_index = data_begin + data_count;
......
......@@ -28,6 +28,7 @@
namespace ge {
namespace model_runner {
RuntimeModel::~RuntimeModel() {
GELOGI("RuntimeModel destructor start");
......@@ -115,34 +116,23 @@ bool RuntimeModel::InitEvent(uint32_t event_num) {
return true;
}
bool RuntimeModel::InitLabel(std::shared_ptr<DavinciModel> &davinci_model) {
GELOGI("batch number:%u.", davinci_model->GetBatchNum());
label_list_.resize(davinci_model->GetBatchNum());
for (auto &task_info : davinci_model->GetTaskInfoList()) {
if (task_info == nullptr) {
GELOGE(PARAM_INVALID, "task_info is null.");
continue;
}
if (task_info->type() != TaskInfoType::LABEL_SET) {
continue;
}
auto label_set_task_info = std::static_pointer_cast<LabelSetTaskInfo>(task_info);
if (label_set_task_info->stream_id() >= stream_list_.size()) {
GELOGE(PARAM_INVALID, "Invalid stream id.");
bool RuntimeModel::InitLabel(uint32_t batch_num) {
GELOGI("batch number:%u.", batch_num);
for (uint32_t i = 0; (batch_num != 0 && i <= batch_num); ++i) {
rtLabel_t rt_lLabel = nullptr;
rtError_t rt_ret = rtLabelCreate(&rt_lLabel);
if (rt_ret != RT_ERROR_NONE) {
GELOGE(RT_FAILED, "Call rt api rtLabelCreate failed, i; %u; ret: 0x%X", i, rt_ret);
return false;
}
rtLabel_t rt_label = nullptr;
rtError_t rt_ret = rtLabelCreateEx(&rt_label, stream_list_[label_set_task_info->stream_id()]);
if (rt_ret != RT_ERROR_NONE) {
GELOGE(RT_FAILED, "Call rt api rtLabelCreate failed, ret: 0x%X", rt_ret);
if (rt_lLabel == nullptr) {
GELOGE(RT_FAILED, "rtLabel is nullptr!");
return false;
}
label_list_[label_set_task_info->label_id()] = rt_label;
}
label_list_.emplace_back(rt_lLabel);
}
return true;
}
......@@ -174,7 +164,7 @@ bool RuntimeModel::InitResource(std::shared_ptr<DavinciModel> &davinci_model) {
return false;
}
if (!InitLabel(davinci_model)) {
if (!InitLabel(davinci_model->GetBatchNum())) {
return false;
}
......@@ -219,41 +209,20 @@ bool RuntimeModel::LoadTask() {
return false;
}
task_id_list_.push_back(task_id);
stream_id_list_.push_back(stream_id);
if (task->Args() != nullptr) {
std::shared_ptr<RuntimeInfo> runtime_tuple = nullptr;
GE_MAKE_SHARED(runtime_tuple = std::make_shared<RuntimeInfo>(task_id, stream_id, task->Args()), return false);
auto emplace_ret = runtime_info_map_.emplace(task->task_name(), runtime_tuple);
if (!emplace_ret.second) {
GELOGW("Task name exist:%s", task->task_name().c_str());
}
}
}
if (task_list_.empty()) {
GELOGE(FAILED, "Task list is empty");
return false;
}
GELOGI("Distribute task succ.");
GELOGI("LoadTask succ.");
return true;
}
bool RuntimeModel::LoadComplete() {
uint32_t task_id = 0;
uint32_t stream_id = 0;
auto rt_ret = rtModelGetTaskId(rt_model_handle_, &task_id, &stream_id);
if (rt_ret != RT_ERROR_NONE) {
GELOGE(RT_FAILED, "Call rtModelGetTaskId failed, ret:0x%X", rt_ret);
return RT_FAILED;
}
task_id_list_.push_back(task_id);
stream_id_list_.push_back(stream_id);
rt_ret = rtModelLoadComplete(rt_model_handle_);
auto rt_ret = rtModelLoadComplete(rt_model_handle_);
if (rt_ret != RT_ERROR_NONE) {
GELOGE(RT_FAILED, "Call rt api rtModelLoadComplete failed, ret: 0x%X.", rt_ret);
return false;
}
GELOGI("LoadTask succ.");
return true;
}
......@@ -283,16 +252,14 @@ bool RuntimeModel::Load(uint32_t device_id, uint64_t session_id, std::shared_ptr
}
GenerateTask(device_id, session_id, davinci_model);
return status;
}
bool RuntimeModel::DistributeTask() {
bool status = LoadTask();
status = LoadTask();
if (!status) {
GELOGE(FAILED, "DistributeTask failed");
return false;
return status;
}
return true;
return status;
}
bool RuntimeModel::Run() {
......@@ -303,14 +270,10 @@ bool RuntimeModel::Run() {
return false;
}
GELOGI("Run rtModelExecute success, ret = 0x%X", ret);
GELOGI("Run rtModelExecute success");
ret = rtStreamSynchronize(rt_model_stream_);
if (ret != RT_ERROR_NONE) {
if (ret == RT_ERROR_END_OF_SEQUENCE) {
GELOGI("Model stream RT_ERROR_END_OF_SEQUENCE signal received, ret = 0x%X", ret);
return true;
}
GELOGE(RT_FAILED, "Model stream sync failed, ret = 0x%X", ret);
return false;
}
......@@ -470,7 +433,7 @@ bool RuntimeModel::InitConstantInfo(std::shared_ptr<DavinciModel> &davinci_model
}
if (constant->output_tensors[0].size < constant->weight_data.size()) {
GELOGE(PARAM_INVALID, "Output size:%u less than weight data size:%zu", constant->output_tensors[0].size,
GELOGE(PARAM_INVALID, "Output size:%u is less than weight data size:%zu", constant->output_tensors[0].size,
constant->weight_data.size());
return false;
}
......@@ -485,8 +448,11 @@ bool RuntimeModel::InitConstantInfo(std::shared_ptr<DavinciModel> &davinci_model
/// The logic of GetShapeSize is wrong, the scaler tensor's GetShapeSize is zero
/// and that of unknown shape is zero too.
/// Unknown shape will not appear here, so we can use zero judge a tensor is scaler or not.
int64_t elem_num =
(constant->weight_tensors[0].GetShapeSize() == 0) ? 1 : constant->weight_tensors[0].GetShapeSize();
int64_t elem_num = constant->weight_tensors[0].GetShapeSize();
if (elem_num == 0 && constant->weight_tensors[0].size == 0) {
elem_num = 1;
}
if (constant->weight_data.size() < sizeof(uint64_t)) {
GELOGE(FAILED, "weight_data size is smaller than sizeof(uint64_t)");
return false;
......@@ -529,6 +495,5 @@ void RuntimeModel::CreateOutput(uint32_t index, const OpInfo &op_info, InputOutp
const std::vector<uint32_t> &RuntimeModel::GetTaskIdList() const { return task_id_list_; }
const std::vector<uint32_t> &RuntimeModel::GetStreamIdList() const { return stream_id_list_; }
} // namespace model_runner
} // namespace ge
......@@ -27,7 +27,7 @@
namespace ge {
namespace model_runner {
using RuntimeInfo = std::tuple<uint32_t, uint32_t, void *>;
class Task;
class RuntimeModel {
public:
......@@ -35,12 +35,7 @@ class RuntimeModel {
~RuntimeModel();
bool Load(uint32_t device_id, uint64_t session_id, std::shared_ptr<DavinciModel> &davinci_model);
bool DistributeTask();
bool LoadComplete();
const std::vector<uint32_t> &GetTaskIdList() const;
const std::vector<uint32_t> &GetStreamIdList() const;
const std::map<std::string, std::shared_ptr<RuntimeInfo>> &GetRuntimeInfoMap() const { return runtime_info_map_; }
rtModel_t GetModelHandle() const { return rt_model_handle_; }
bool Run();
bool CopyInputData(const InputData &input_data);
bool GetInputOutputDescInfo(bool zero_copy, std::vector<InputOutputDescInfo> *input_desc,
......@@ -53,7 +48,7 @@ class RuntimeModel {
bool LoadTask();
bool InitStream(std::shared_ptr<DavinciModel> &davinci_model);
bool InitEvent(uint32_t event_num);
bool InitLabel(std::shared_ptr<DavinciModel> &davinci_model);
bool InitLabel(uint32_t batch_num);
bool InitDataInfo(std::shared_ptr<DavinciModel> &davinci_model);
bool InitOutputInfo(std::shared_ptr<DavinciModel> &davinci_model);
bool InitConstantInfo(std::shared_ptr<DavinciModel> &davinci_model);
......@@ -82,8 +77,6 @@ class RuntimeModel {
std::vector<std::shared_ptr<OpInfo>> constant_info_list_{};
std::vector<uint32_t> task_id_list_{};
std::vector<uint32_t> stream_id_list_{};
std::map<std::string, std::shared_ptr<RuntimeInfo>> runtime_info_map_;
};
} // namespace model_runner
......
......@@ -85,15 +85,11 @@ bool AicpuTask::Distribute() {
return false;
}
input_output_addr_ = reinterpret_cast<void *>(reinterpret_cast<uint8_t *>(args_) + io_addr_offset);
auto dump_flag = task_info_->dump_flag() ? RT_KERNEL_DUMPFLAG : RT_KERNEL_DEFAULT;
GELOGI(
"Distribute AicpuTask start, args_size = %u, io_addrs_num = %u, so_name = %s, kernel_name = %s, dump_flag = %d.",
args_size, io_addrs_num, task_info_->so_name().data(), task_info_->kernel_name().data(), dump_flag);
rt_ret = rtCpuKernelLaunchWithFlag(reinterpret_cast<const void *>(task_info_->so_name().data()),
reinterpret_cast<const void *>(task_info_->kernel_name().data()), 1, args_,
args_size, nullptr, stream_, dump_flag);
GELOGI("Distribute AicpuTask start, args_size = %u, io_addrs_num = %u, so_name = %s, kernel_name = %s.", args_size,
io_addrs_num, task_info_->so_name().data(), task_info_->kernel_name().data());
rt_ret = rtCpuKernelLaunch(reinterpret_cast<const void *>(task_info_->so_name().data()),
reinterpret_cast<const void *>(task_info_->kernel_name().data()), 1, args_, args_size,
nullptr, stream_);
if (rt_ret != RT_ERROR_NONE) {
GELOGE(RT_FAILED, "Call rt api failed, ret: 0x%X", rt_ret);
return false;
......
......@@ -18,7 +18,6 @@
#define GE_GE_RUNTIME_TASK_AICPU_TASK_H_
#include <memory>
#include <string>
#include "ge_runtime/task/task.h"
namespace ge {
......@@ -31,17 +30,12 @@ class AicpuTask : public TaskRepeater<AicpuTaskInfo> {
bool Distribute() override;
void *Args() override { return input_output_addr_; }
std::string task_name() const override { return task_info_->op_name(); }
private:
static void ReleaseRtMem(void **ptr) noexcept;
std::shared_ptr<AicpuTaskInfo> task_info_;
void *stream_;
void *args_;
void *input_output_addr_;
};
} // namespace model_runner
} // namespace ge
......
......@@ -115,6 +115,7 @@ bool HcclTask::Distribute() {
rt_ret = rtModelBindStream(rt_model_handle_, stream, RT_HEAD_STREAM);
if (rt_ret != RT_ERROR_NONE) {
GELOGE(RT_FAILED, "Call rt api failed, ret: 0x%X", rt_ret);
(void)rtStreamDestroy(stream);
return false;
}
......
/**
* Copyright 2019-2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "ge_runtime/task/label_goto_task.h"
#include "ge_runtime/task/task_factory.h"
namespace ge {
namespace model_runner {
/// Builds a label-goto task and resolves the stream and label handles it uses.
///
/// @param model_context  Provides the stream list and the label list.
/// @param task_info      Supplies the stream id and label id to look up.
///
/// When task_info is null, or either id is out of range for its list, a
/// warning is logged and stream_ / label_ stay nullptr; Distribute() then
/// fails with an error for the missing handle.
LabelGotoTask::LabelGotoTask(const ModelContext &model_context, const std::shared_ptr<LabelGotoTaskInfo> &task_info)
    : TaskRepeater<LabelGotoTaskInfo>(model_context, task_info),
      task_info_(task_info),
      stream_(nullptr),
      label_(nullptr) {
  if (task_info_ == nullptr) {
    GELOGW("task_info_ is null!");
    return;
  }

  // Bind by const reference: only a size check and a single element read are
  // needed, so there is no reason to copy either list.
  const auto &stream_list = model_context.stream_list();
  const auto &label_list = model_context.label_list();
  const uint32_t stream_id = task_info_->stream_id();
  const uint32_t label_id = task_info_->label_id();
  GELOGI("Stream list size:%zu, stream id:%u.", stream_list.size(), stream_id);
  GELOGI("Label list size:%zu, label id:%u.", label_list.size(), label_id);

  if (stream_id >= stream_list.size() || label_id >= label_list.size()) {
    GELOGW("Stream/Label id invalid.");
    return;
  }
  stream_ = stream_list[stream_id];
  label_ = label_list[label_id];
}
// Nothing to release here: the destructor body is empty.
LabelGotoTask::~LabelGotoTask() = default;
bool LabelGotoTask::Distribute() {
GELOGI("LabelGotoTask Distribute start.");
if (stream_ == nullptr) {
GELOGE(PARAM_INVALID, "stream is null!");
return false;
}
if (label_ == nullptr) {
GELOGE(PARAM_INVALID, "label is null!");
return false;
}
rtError_t rt_ret = rtLabelGotoEx(label_, stream_);
if (rt_ret != RT_ERROR_NONE) {
GELOGE(RT_FAILED, "Call rt api failed, ret: 0x%X", rt_ret);
return false;
}
GELOGI("DistributeTask end.");
return true;
}
// NOTE(review): REGISTER_TASK is defined outside this file (presumably in
// ge_runtime/task/task_factory.h); it appears to associate
// TaskInfoType::LABEL_GOTO with LabelGotoTask / LabelGotoTaskInfo — confirm.
REGISTER_TASK(TaskInfoType::LABEL_GOTO, LabelGotoTask, LabelGotoTaskInfo);
} // namespace model_runner
} // namespace ge
/**
* Copyright 2019-2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef GE_GE_RUNTIME_TASK_LABEL_GOTO_TASK_H_
#define GE_GE_RUNTIME_TASK_LABEL_GOTO_TASK_H_
#include <memory>
#include "ge_runtime/task/task.h"
namespace ge {
namespace model_runner {
// Task wrapper for a LabelGotoTaskInfo. The constructor looks up a stream and
// a label from the model context by the ids in task_info; Distribute() passes
// them to rtLabelGotoEx.
class LabelGotoTask : public TaskRepeater<LabelGotoTaskInfo> {
public:
// Resolves stream_ and label_; both stay nullptr if task_info is null or an id is out of range.
LabelGotoTask(const ModelContext &model_context, const std::shared_ptr<LabelGotoTaskInfo> &task_info);
~LabelGotoTask() override;
// Returns true only when both handles are set and the runtime call succeeds.
bool Distribute() override;
private:
// Task description this object was built from.
std::shared_ptr<LabelGotoTaskInfo> task_info_;
// Handle taken from the model context's stream list (not released by this class).
void *stream_;
// Handle taken from the model context's label list (not released by this class).
void *label_;
};
} // namespace model_runner
} // namespace ge
#endif // GE_GE_RUNTIME_TASK_LABEL_GOTO_TASK_H_
/**
* Copyright 2019-2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "ge_runtime/task/label_set_task.h"
#include "ge_runtime/task/task_factory.h"
namespace ge {
namespace model_runner {
/// Builds a label-set task and resolves the stream and label handles it uses.
///
/// @param model_context  Provides the stream list and the label list.
/// @param task_info      Supplies the stream id and label id to look up.
///
/// When task_info is null, or either id is out of range for its list, a
/// warning is logged and stream_ / label_ stay nullptr; Distribute() then
/// fails with an error for the missing handle.
LabelSetTask::LabelSetTask(const ModelContext &model_context, const std::shared_ptr<LabelSetTaskInfo> &task_info)
    : TaskRepeater<LabelSetTaskInfo>(model_context, task_info),
      task_info_(task_info),
      stream_(nullptr),
      label_(nullptr) {
  if (task_info_ == nullptr) {
    GELOGW("task_info_ is null!");
    return;
  }

  // Bind by const reference: only a size check and a single element read are
  // needed, so there is no reason to copy either list.
  const auto &stream_list = model_context.stream_list();
  const auto &label_list = model_context.label_list();
  const uint32_t stream_id = task_info_->stream_id();
  const uint32_t label_id = task_info_->label_id();
  GELOGI("Stream list size:%zu, stream id:%u.", stream_list.size(), stream_id);
  GELOGI("Label list size:%zu, label id:%u.", label_list.size(), label_id);

  if (stream_id >= stream_list.size() || label_id >= label_list.size()) {
    GELOGW("Stream/Label id invalid.");
    return;
  }
  stream_ = stream_list[stream_id];
  label_ = label_list[label_id];
}
// Nothing to release here: the destructor body is empty.
LabelSetTask::~LabelSetTask() = default;
bool LabelSetTask::Distribute() {
GELOGI("LabelSetTask Distribute start.");
if (stream_ == nullptr) {
GELOGE(PARAM_INVALID, "stream is null!");
return false;
}
if (label_ == nullptr) {
GELOGE(PARAM_INVALID, "label is null!");
return false;
}
rtError_t rt_ret = rtLabelSet(label_, stream_);
if (rt_ret != RT_ERROR_NONE) {
GELOGE(RT_FAILED, "Call rt api failed, ret: 0x%X", rt_ret);
return false;
}
GELOGI("DistributeTask end.");
return true;
}
// NOTE(review): REGISTER_TASK is defined outside this file (presumably in
// ge_runtime/task/task_factory.h); it appears to associate
// TaskInfoType::LABEL_SET with LabelSetTask / LabelSetTaskInfo — confirm.
REGISTER_TASK(TaskInfoType::LABEL_SET, LabelSetTask, LabelSetTaskInfo);
} // namespace model_runner
} // namespace ge
/**
* Copyright 2019-2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef GE_GE_RUNTIME_TASK_LABEL_SET_TASK_H_
#define GE_GE_RUNTIME_TASK_LABEL_SET_TASK_H_
#include <memory>
#include "ge_runtime/task/task.h"
namespace ge {
namespace model_runner {
// Task wrapper for a LabelSetTaskInfo. The constructor looks up a stream and
// a label from the model context by the ids in task_info; Distribute() passes
// them to rtLabelSet.
class LabelSetTask : public TaskRepeater<LabelSetTaskInfo> {
public:
// Resolves stream_ and label_; both stay nullptr if task_info is null or an id is out of range.
LabelSetTask(const ModelContext &model_context, const std::shared_ptr<LabelSetTaskInfo> &task_info);
~LabelSetTask() override;
// Returns true only when both handles are set and the runtime call succeeds.
bool Distribute() override;
private:
// Task description this object was built from.
std::shared_ptr<LabelSetTaskInfo> task_info_;
// Handle taken from the model context's stream list (not released by this class).
void *stream_;
// Handle taken from the model context's label list (not released by this class).
void *label_;
};
} // namespace model_runner
} // namespace ge
#endif // GE_GE_RUNTIME_TASK_LABEL_SET_TASK_H_
/**
* Copyright 2019-2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "ge_runtime/task/label_switch_task.h"
#include "ge_runtime/task/task_factory.h"
namespace ge {
namespace model_runner {
// Captures everything Distribute() needs from the model context: the full label list and the
// stream selected by the task info's stream id.
//
// Nothing is reported to the caller here. When task_info is null or its stream id is out of
// range, a warning is logged and stream_ stays nullptr, which CheckParamValid() rejects later.
LabelSwitchTask::LabelSwitchTask(const ModelContext &model_context,
                                 const std::shared_ptr<LabelSwitchTaskInfo> &task_info)
    : TaskRepeater<LabelSwitchTaskInfo>(model_context, task_info),
      task_info_(task_info),
      stream_(nullptr),
      all_label_resource_(),
      label_info_(nullptr) {
  if (task_info_ == nullptr) {
    GELOGW("task_info_ is null!");
    return;
  }

  all_label_resource_ = model_context.label_list();

  auto streams = model_context.stream_list();
  uint32_t wanted_stream = task_info->stream_id();
  GELOGI("Stream list size:%zu, stream id:%u.", streams.size(), wanted_stream);

  // Only a stream id that indexes into the list yields a usable stream handle.
  if (wanted_stream < streams.size()) {
    stream_ = streams[wanted_stream];
  } else {
    GELOGW("Stream id invalid.");
  }
}
// Releases the label table buffer (label_info_) that Distribute() allocated with rtMalloc, if any.
// A failing rtFree is only logged: a destructor has no way to report it to a caller.
LabelSwitchTask::~LabelSwitchTask() {
  if (label_info_ == nullptr) {
    return;
  }

  const rtError_t rt_ret = rtFree(label_info_);
  if (rt_ret != RT_ERROR_NONE) {
    // Name the buffer that is actually being released so the log points at the right allocation.
    GELOGE(RT_FAILED, "rtFree label_info_ failed! ret: 0x%X.", rt_ret);
  }
  // Drop the handle regardless of the result so it is never handed to rtFree twice.
  label_info_ = nullptr;
}
// Builds the label table for this switch and issues the label-switch-by-index call on stream_.
//
// Steps:
//   1. Validate the task parameters with CheckParamValid().
//   2. Translate every label index from the task info into the label handle stored in
//      all_label_resource_, rejecting indices outside that list.
//   3. Allocate label_info_ with rtMalloc and fill it through rtLabelListCpy.
//   4. Call rtLabelSwitchByIndex with the task's condition, the label count and label_info_.
//
// Returns true on success; false when validation fails, a label index is out of range, or a
// runtime call fails. Once allocated, label_info_ is kept by the task (also on a later failure)
// and released by the destructor.
bool LabelSwitchTask::Distribute() {
  GELOGI("LabelSwitchTask Distribute start.");
  if (!CheckParamValid()) {
    return false;
  }

  const std::vector<uint32_t> &label_index_list = task_info_->label_list();
  std::vector<void *> label_list(task_info_->label_size(), nullptr);
  // CheckParamValid() guarantees label_size() == label_index_list.size(), so indexing with i is in range.
  for (size_t i = 0; i < task_info_->label_size(); ++i) {
    uint32_t label_index = label_index_list[i];
    if (label_index >= all_label_resource_.size()) {
      GELOGE(PARAM_INVALID, "label %zu index is %u, but there are %zu labels in total.", i, label_index,
             all_label_resource_.size());
      return false;
    }
    label_list[i] = all_label_resource_[label_index];
    // label_index is a uint32_t, so it must be printed with %u (%zu expects a size_t argument).
    GELOGI("Case %zu: label id %u.", i, label_index);
  }

  // The multiplication cannot wrap: CheckParamValid() bounds label_size() by UINT32_MAX / sizeof(rtLabelDevInfo).
  uint32_t label_info_size = sizeof(rtLabelDevInfo) * task_info_->label_size();
  rtError_t rt_ret = rtMalloc(&label_info_, label_info_size, RT_MEMORY_HBM);
  if (rt_ret != RT_ERROR_NONE) {
    GELOGE(RT_FAILED, "Call rt api failed, ret: 0x%X", rt_ret);
    return false;
  }

  rt_ret = rtLabelListCpy(label_list.data(), label_list.size(), label_info_, label_info_size);
  if (rt_ret != RT_ERROR_NONE) {
    GELOGE(RT_FAILED, "Call rt api failed, ret: 0x%X", rt_ret);
    return false;
  }

  rt_ret = rtLabelSwitchByIndex(task_info_->cond(), label_list.size(), label_info_, stream_);
  if (rt_ret != RT_ERROR_NONE) {
    GELOGE(RT_FAILED, "Call rt api failed, ret: 0x%X", rt_ret);
    return false;
  }

  GELOGI("DistributeTask end.");
  return true;
}
// Verifies the preconditions of Distribute(). Each failed check logs its own error and returns false.
//
// Checks, in order:
//   - stream_ was resolved. The constructor leaves it nullptr when task_info_ is null, so this
//     check also keeps the task_info_ dereferences below from running on a null pointer.
//   - the label list is not empty.
//   - label_size() equals the number of entries in the label list.
//   - label_size() is below UINT32_MAX / sizeof(rtLabelDevInfo).
//   - label_info_ is still nullptr.
bool LabelSwitchTask::CheckParamValid() {
  if (stream_ == nullptr) {
    GELOGE(PARAM_INVALID, "stream is null!");
    return false;
  }

  const auto &labels = task_info_->label_list();
  if (labels.empty()) {
    GELOGE(PARAM_INVALID, "label_list is empty.");
    return false;
  }

  const auto declared_size = task_info_->label_size();
  if (declared_size != labels.size()) {
    GELOGE(PARAM_INVALID, "label_list size %zu but label_size is %u.", labels.size(), declared_size);
    return false;
  }

  if (declared_size >= UINT32_MAX / sizeof(rtLabelDevInfo)) {
    GELOGE(PARAM_INVALID, "label_size %u will overflow.", declared_size);
    return false;
  }

  if (label_info_ != nullptr) {
    GELOGE(PARAM_INVALID, "label_info_ has dirty data.");
    return false;
  }

  return true;
}
// Ties TaskInfoType::LABEL_SWITCH to the LabelSwitchTask class and its LabelSwitchTaskInfo description.
// NOTE(review): REGISTER_TASK is defined outside this view; presumably it registers a creator with the
// task factory so that a LABEL_SWITCH task info is turned into a LabelSwitchTask - confirm against the macro.
REGISTER_TASK(TaskInfoType::LABEL_SWITCH, LabelSwitchTask, LabelSwitchTaskInfo);
} // namespace model_runner
} // namespace ge
/**
* Copyright 2019-2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef GE_GE_RUNTIME_TASK_LABEL_SWITCH_TASK_H_
#define GE_GE_RUNTIME_TASK_LABEL_SWITCH_TASK_H_
#include <memory>
#include "ge_runtime/task/task.h"
namespace ge {
namespace model_runner {
// Runtime task described by a LabelSwitchTaskInfo: Distribute() turns the label indices of the task
// info into a label table and issues a label switch by index on the task's stream.
//
// The task owns label_info_, a buffer obtained from the runtime in Distribute() and released in
// the destructor.
class LabelSwitchTask : public TaskRepeater<LabelSwitchTaskInfo> {
 public:
  // Resolves the stream and the label list from model_context. Problems are only logged here and
  // surface later as a failing Distribute().
  LabelSwitchTask(const ModelContext &model_context, const std::shared_ptr<LabelSwitchTaskInfo> &task_info);
  ~LabelSwitchTask() override;

  // Not copyable: a copy would share label_info_ with the original and both destructors would free it.
  LabelSwitchTask(const LabelSwitchTask &) = delete;
  LabelSwitchTask &operator=(const LabelSwitchTask &) = delete;

  // Validates the parameters, builds the label table and issues the switch. Returns false on any failure.
  bool Distribute() override;

 private:
  // Checks the preconditions of Distribute() and logs the first one that does not hold.
  bool CheckParamValid();

  // Task description received by the constructor.
  std::shared_ptr<LabelSwitchTaskInfo> task_info_;
  // Stream selected by the task info's stream id; nullptr when it could not be resolved.
  void *stream_;
  // Copy of the model context's label list, indexed by the label indices of the task info.
  std::vector<void *> all_label_resource_;
  // Label table buffer allocated in Distribute(); nullptr until then.
  void *label_info_;
};
} // namespace model_runner
} // namespace ge
#endif // GE_GE_RUNTIME_TASK_LABEL_SWITCH_TASK_H_
......@@ -51,7 +51,7 @@ bool StreamSwitchTask::Distribute() {
}
if (static_cast<uint64_t>(task_info_->true_stream_id()) >= stream_list_.size()) {
GELOGE(PARAM_INVALID, "true_stream_id %ld must less than stream_list_ size %zu!", task_info_->true_stream_id(),
GELOGE(PARAM_INVALID, "true_stream_id %ld must be less than stream_list_ size %zu!", task_info_->true_stream_id(),
stream_list_.size());
return false;
}
......
......@@ -18,9 +18,7 @@
#define GE_GE_RUNTIME_TASK_TASK_H_
#include <memory>
#include <utility>
#include <vector>
#include <string>
#include "runtime/rt_model.h"
#include "ge_runtime/model_context.h"
#include "ge_runtime/task_info.h"
......@@ -34,10 +32,6 @@ class Task {
virtual ~Task() {}
virtual bool Distribute() = 0;
virtual void *Args() { return nullptr; }
virtual std::string task_name() const { return ""; }
};
template <class T>
......
......@@ -95,14 +95,15 @@ bool TbeTask::Distribute() {
return false;
}
GELOGI("InitTbeTask end.");
GELOGI("DistributeTbeTask start.");
auto dump_flag = task_info_->dump_flag() ? RT_KERNEL_DUMPFLAG : RT_KERNEL_DEFAULT;
rt_ret = rtKernelLaunchWithFlag(stub_func_, task_info_->block_dim(), args_, args_size, nullptr, stream_, dump_flag);
rt_ret = rtKernelLaunch(stub_func_, task_info_->block_dim(), args_, args_size, nullptr, stream_);
if (rt_ret != RT_ERROR_NONE) {
GELOGE(RT_FAILED, "Call rt api rtKernelLaunch failed, ret: 0x%X", rt_ret);
return false;
}
GELOGI("[DataDump] task name:%s, dump_flag:%d", task_info_->op_name().c_str(), dump_flag);
GELOGI("DistributeTbeTask end.");
return true;
}
......
......@@ -30,10 +30,6 @@ class TbeTask : public TaskRepeater<TbeTaskInfo> {
bool Distribute() override;
void *Args() override { return args_; }
std::string task_name() const override { return task_info_->op_name(); }
private:
std::shared_ptr<TbeTaskInfo> task_info_;
void *stream_;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册