Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
dashen
graphengine
提交
81d79f33
G
graphengine
项目概览
dashen
/
graphengine
与 Fork 源项目一致
Fork自
MindSpore / graphengine
通知
0
Star
0
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
G
graphengine
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
体验新版 GitCode,发现更多精彩内容 >>
提交
81d79f33
编写于
8月 04, 2020
作者:
Z
zhangzhenghai
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
modify ge_runtime
上级
9275e7b0
变更
19
显示空白变更内容
内联
并排
Showing
19 changed file
with
98 addition
and
640 deletion
+98
-640
inc/framework/ge_runtime/model_runner.h
inc/framework/ge_runtime/model_runner.h
+1
-11
inc/framework/ge_runtime/task_info.h
inc/framework/ge_runtime/task_info.h
+55
-80
src/ge/ge_runtime/model_runner.cc
src/ge/ge_runtime/model_runner.cc
+0
-50
src/ge/ge_runtime/output.cc
src/ge/ge_runtime/output.cc
+1
-1
src/ge/ge_runtime/runtime_model.cc
src/ge/ge_runtime/runtime_model.cc
+28
-63
src/ge/ge_runtime/runtime_model.h
src/ge/ge_runtime/runtime_model.h
+2
-9
src/ge/ge_runtime/task/aicpu_task.cc
src/ge/ge_runtime/task/aicpu_task.cc
+5
-9
src/ge/ge_runtime/task/aicpu_task.h
src/ge/ge_runtime/task/aicpu_task.h
+0
-6
src/ge/ge_runtime/task/hccl_task.cc
src/ge/ge_runtime/task/hccl_task.cc
+1
-0
src/ge/ge_runtime/task/label_goto_task.cc
src/ge/ge_runtime/task/label_goto_task.cc
+0
-70
src/ge/ge_runtime/task/label_goto_task.h
src/ge/ge_runtime/task/label_goto_task.h
+0
-41
src/ge/ge_runtime/task/label_set_task.cc
src/ge/ge_runtime/task/label_set_task.cc
+0
-70
src/ge/ge_runtime/task/label_set_task.h
src/ge/ge_runtime/task/label_set_task.h
+0
-41
src/ge/ge_runtime/task/label_switch_task.cc
src/ge/ge_runtime/task/label_switch_task.cc
+0
-131
src/ge/ge_runtime/task/label_switch_task.h
src/ge/ge_runtime/task/label_switch_task.h
+0
-44
src/ge/ge_runtime/task/stream_switch_task.cc
src/ge/ge_runtime/task/stream_switch_task.cc
+1
-1
src/ge/ge_runtime/task/task.h
src/ge/ge_runtime/task/task.h
+0
-6
src/ge/ge_runtime/task/tbe_task.cc
src/ge/ge_runtime/task/tbe_task.cc
+4
-3
src/ge/ge_runtime/task/tbe_task.h
src/ge/ge_runtime/task/tbe_task.h
+0
-4
未找到文件。
inc/framework/ge_runtime/model_runner.h
浏览文件 @
81d79f33
...
...
@@ -28,7 +28,7 @@
namespace
ge
{
namespace
model_runner
{
class
RuntimeModel
;
using
RuntimeInfo
=
std
::
tuple
<
uint32_t
,
uint32_t
,
void
*>
;
class
ModelRunner
{
public:
static
ModelRunner
&
Instance
();
...
...
@@ -36,18 +36,8 @@ class ModelRunner {
bool
LoadDavinciModel
(
uint32_t
device_id
,
uint64_t
session_id
,
uint32_t
model_id
,
std
::
shared_ptr
<
DavinciModel
>
davinci_model
,
std
::
shared_ptr
<
ModelListener
>
listener
);
bool
DistributeTask
(
uint32_t
model_id
);
bool
LoadModelComplete
(
uint32_t
model_id
);
const
std
::
vector
<
uint32_t
>
&
GetTaskIdList
(
uint32_t
model_id
)
const
;
const
std
::
vector
<
uint32_t
>
&
GetStreamIdList
(
uint32_t
model_id
)
const
;
const
std
::
map
<
std
::
string
,
std
::
shared_ptr
<
RuntimeInfo
>>
&
GetRuntimeInfoMap
(
uint32_t
model_id
)
const
;
void
*
GetModelHandle
(
uint32_t
model_id
)
const
;
bool
UnloadModel
(
uint32_t
model_id
);
bool
RunModel
(
uint32_t
model_id
,
const
InputData
&
input_data
,
OutputData
*
output_data
);
...
...
inc/framework/ge_runtime/task_info.h
浏览文件 @
81d79f33
...
...
@@ -21,7 +21,6 @@
#include <functional>
#include <memory>
#include <string>
#include <utility>
#include <vector>
#include "cce/taskdown_api.h"
...
...
@@ -53,27 +52,21 @@ class TaskInfo {
virtual
~
TaskInfo
()
{}
uint32_t
stream_id
()
const
{
return
stream_id_
;
}
TaskInfoType
type
()
const
{
return
type_
;
}
std
::
string
op_name
()
const
{
return
op_name_
;
}
bool
dump_flag
()
const
{
return
dump_flag_
;
}
protected:
TaskInfo
(
const
std
::
string
&
op_name
,
uint32_t
stream_id
,
TaskInfoType
type
,
bool
dump_flag
)
:
op_name_
(
op_name
),
stream_id_
(
stream_id
),
type_
(
type
),
dump_flag_
(
dump_flag
)
{}
TaskInfo
(
uint32_t
stream_id
,
TaskInfoType
type
)
:
stream_id_
(
stream_id
),
type_
(
type
)
{}
private:
std
::
string
op_name_
;
uint32_t
stream_id_
;
TaskInfoType
type_
;
bool
dump_flag_
;
};
class
CceTaskInfo
:
public
TaskInfo
{
public:
CceTaskInfo
(
const
std
::
string
&
op_name
,
uint32_t
stream_id
,
const
cce
::
ccOpContext
&
ctx
,
const
std
::
string
&
stub_func
,
uint32_t
block_dim
,
const
std
::
vector
<
uint8_t
>
&
args
,
uint32_t
args_size
,
const
std
::
vector
<
uint8_t
>
&
sm_desc
,
const
std
::
vector
<
uint8_t
>
&
flow_table
,
const
std
::
vector
<
uint8_t
>
&
args_offset
,
bool
is_flowtable
)
:
TaskInfo
(
op_name
,
stream_id
,
TaskInfoType
::
CCE
,
false
),
CceTaskInfo
(
uint32_t
stream_id
,
const
cce
::
ccOpContext
&
ctx
,
const
std
::
string
&
stub_func
,
uint32_t
block_dim
,
const
std
::
vector
<
uint8_t
>
&
args
,
uint32_t
args_size
,
const
std
::
vector
<
uint8_t
>
&
sm_desc
,
const
std
::
vector
<
uint8_t
>
&
flow_table
,
const
std
::
vector
<
uint8_t
>
&
args_offset
,
bool
is_flowtable
)
:
TaskInfo
(
stream_id
,
TaskInfoType
::
CCE
),
ctx_
(
ctx
),
stub_func_
(
stub_func
),
block_dim_
(
block_dim
),
...
...
@@ -109,11 +102,11 @@ class CceTaskInfo : public TaskInfo {
class
TbeTaskInfo
:
public
TaskInfo
{
public:
TbeTaskInfo
(
const
std
::
string
&
op_name
,
uint32_t
stream_id
,
const
std
::
string
&
stub_func
,
uint32_t
block_dim
,
const
std
::
vector
<
uint8_t
>
&
args
,
uint32_t
args_size
,
const
std
::
vector
<
uint8_t
>
&
sm_desc
,
void
*
binary
,
uint32_t
binary_size
,
const
std
::
vector
<
uint8_t
>
&
meta_data
,
const
std
::
vector
<
void
*>
&
input_data_addrs
,
const
std
::
vector
<
void
*>
&
output_data_addrs
,
const
std
::
vector
<
void
*>
&
workspace_addrs
,
bool
dump_flag
)
:
TaskInfo
(
op_name
,
stream_id
,
TaskInfoType
::
TBE
,
dump_flag
),
TbeTaskInfo
(
uint32_t
stream_id
,
const
std
::
string
&
stub_func
,
uint32_t
block_dim
,
const
std
::
vector
<
uint8_t
>
&
args
,
uint32_t
args_size
,
const
std
::
vector
<
uint8_t
>
&
sm_desc
,
void
*
binary
,
uint32_t
binary_size
,
const
std
::
vector
<
uint8_t
>
&
meta_data
,
const
std
::
vector
<
void
*>
&
input_data_addrs
,
const
std
::
vector
<
void
*>
&
output_data_addrs
,
const
std
::
vector
<
void
*>
&
workspace_addrs
)
:
TaskInfo
(
stream_id
,
TaskInfoType
::
TBE
),
stub_func_
(
stub_func
),
block_dim_
(
block_dim
),
args_
(
args
),
...
...
@@ -160,10 +153,9 @@ class TbeTaskInfo : public TaskInfo {
class
AicpuTaskInfo
:
public
TaskInfo
{
public:
AicpuTaskInfo
(
const
std
::
string
&
op_name
,
uint32_t
stream_id
,
const
string
&
so_name
,
const
std
::
string
&
kernel_name
,
const
std
::
string
&
node_def
,
const
std
::
vector
<
void
*>
&
input_data_addrs
,
const
std
::
vector
<
void
*>
&
output_data_addrs
,
bool
dump_flag
)
:
TaskInfo
(
op_name
,
stream_id
,
TaskInfoType
::
AICPU
,
dump_flag
),
AicpuTaskInfo
(
uint32_t
stream_id
,
const
string
&
so_name
,
const
std
::
string
&
kernel_name
,
const
std
::
string
&
node_def
,
const
std
::
vector
<
void
*>
&
input_data_addrs
,
const
std
::
vector
<
void
*>
&
output_data_addrs
)
:
TaskInfo
(
stream_id
,
TaskInfoType
::
AICPU
),
so_name_
(
so_name
),
kernel_name_
(
kernel_name
),
node_def_
(
node_def
),
...
...
@@ -185,45 +177,37 @@ class AicpuTaskInfo : public TaskInfo {
std
::
vector
<
void
*>
output_data_addrs_
;
};
class
Label
Set
TaskInfo
:
public
TaskInfo
{
class
LabelTaskInfo
:
public
TaskInfo
{
public:
LabelSetTaskInfo
(
const
std
::
string
&
op_name
,
uint32_t
stream_id
,
uint32_t
label_id
)
:
TaskInfo
(
op_name
,
stream_id
,
TaskInfoType
::
LABEL_SET
,
false
),
label_id_
(
label_id
)
{}
~
LabelSetTaskInfo
()
override
{}
uint32_t
label_id
()
const
{
return
label_id_
;
}
private:
protected:
LabelTaskInfo
(
uint32_t
stream_id
,
TaskInfoType
type
,
uint32_t
label_id
)
:
TaskInfo
(
stream_id
,
type
),
label_id_
(
label_id
)
{}
virtual
~
LabelTaskInfo
()
override
{}
uint32_t
label_id_
;
};
class
Label
GotoTaskInfo
:
public
TaskInfo
{
class
Label
SetTaskInfo
:
public
Label
TaskInfo
{
public:
LabelGotoTaskInfo
(
const
std
::
string
&
op_name
,
uint32_t
stream_id
,
uint32_t
label_id
)
:
TaskInfo
(
op_name
,
stream_id
,
TaskInfoType
::
LABEL_GOTO
,
false
),
label_id_
(
label_id
)
{}
~
LabelGotoTaskInfo
()
override
{}
uint32_t
label_id
()
const
{
return
label_id_
;
}
private:
uint32_t
label_id_
;
LabelSetTaskInfo
(
uint32_t
stream_id
,
uint32_t
label_id
)
:
LabelTaskInfo
(
stream_id
,
TaskInfoType
::
LABEL_SET
,
label_id
)
{}
~
LabelSetTaskInfo
()
override
{}
};
class
LabelSwitchTaskInfo
:
public
TaskInfo
{
class
LabelSwitchTaskInfo
:
public
Label
TaskInfo
{
public:
LabelSwitchTaskInfo
(
const
std
::
string
&
op_name
,
uint32_t
stream_id
,
uint32_t
label_size
,
const
std
::
vector
<
uint32_t
>
&
label_list
,
void
*
cond
)
:
TaskInfo
(
op_name
,
stream_id
,
TaskInfoType
::
LABEL_SWITCH
,
false
),
label_size_
(
label_size
),
label_list_
(
label_list
),
cond_
(
cond
)
{}
LabelSwitchTaskInfo
(
uint32_t
stream_id
,
uint32_t
label_id
)
:
LabelTaskInfo
(
stream_id
,
TaskInfoType
::
LABEL_SWITCH
,
label_id
)
{}
~
LabelSwitchTaskInfo
()
override
{}
uint32_t
label_size
()
{
return
label_size_
;
};
const
std
::
vector
<
uint32_t
>
&
label_list
()
{
return
label_list_
;
};
void
*
cond
()
{
return
cond_
;
};
};
private:
uint32_t
label_size_
;
std
::
vector
<
uint32_t
>
label_list_
;
void
*
cond_
;
class
LabelGotoTaskInfo
:
public
LabelTaskInfo
{
public:
LabelGotoTaskInfo
(
uint32_t
stream_id
,
uint32_t
label_id
)
:
LabelTaskInfo
(
stream_id
,
TaskInfoType
::
LABEL_GOTO
,
label_id
)
{}
~
LabelGotoTaskInfo
()
override
{}
};
class
EventTaskInfo
:
public
TaskInfo
{
...
...
@@ -231,8 +215,8 @@ class EventTaskInfo : public TaskInfo {
uint32_t
event_id
()
const
{
return
event_id_
;
}
protected:
EventTaskInfo
(
const
std
::
string
&
op_name
,
uint32_t
stream_id
,
TaskInfoType
type
,
uint32_t
event_id
)
:
TaskInfo
(
op_name
,
stream_id
,
type
,
fals
e
),
event_id_
(
event_id
)
{}
EventTaskInfo
(
uint32_t
stream_id
,
TaskInfoType
type
,
uint32_t
event_id
)
:
TaskInfo
(
stream_id
,
typ
e
),
event_id_
(
event_id
)
{}
virtual
~
EventTaskInfo
()
override
{}
uint32_t
event_id_
;
...
...
@@ -240,41 +224,39 @@ class EventTaskInfo : public TaskInfo {
class
EventRecordTaskInfo
:
public
EventTaskInfo
{
public:
EventRecordTaskInfo
(
const
std
::
string
&
op_name
,
uint32_t
stream_id
,
uint32_t
event_id
)
:
EventTaskInfo
(
op_name
,
stream_id
,
TaskInfoType
::
EVENT_RECORD
,
event_id
)
{}
EventRecordTaskInfo
(
uint32_t
stream_id
,
uint32_t
event_id
)
:
EventTaskInfo
(
stream_id
,
TaskInfoType
::
EVENT_RECORD
,
event_id
)
{}
~
EventRecordTaskInfo
()
override
{}
};
class
EventWaitTaskInfo
:
public
EventTaskInfo
{
public:
EventWaitTaskInfo
(
const
std
::
string
&
op_name
,
uint32_t
stream_id
,
uint32_t
event_id
)
:
EventTaskInfo
(
op_name
,
stream_id
,
TaskInfoType
::
EVENT_WAIT
,
event_id
)
{}
EventWaitTaskInfo
(
uint32_t
stream_id
,
uint32_t
event_id
)
:
EventTaskInfo
(
stream_id
,
TaskInfoType
::
EVENT_WAIT
,
event_id
)
{}
~
EventWaitTaskInfo
()
override
{}
};
class
FusionStartTaskInfo
:
public
TaskInfo
{
public:
explicit
FusionStartTaskInfo
(
const
std
::
string
&
op_name
,
uint32_t
stream_id
)
:
TaskInfo
(
op_name
,
stream_id
,
TaskInfoType
::
FUSION_START
,
false
)
{}
explicit
FusionStartTaskInfo
(
uint32_t
stream_id
)
:
TaskInfo
(
stream_id
,
TaskInfoType
::
FUSION_START
)
{}
~
FusionStartTaskInfo
()
override
{}
};
class
FusionEndTaskInfo
:
public
TaskInfo
{
public:
explicit
FusionEndTaskInfo
(
const
std
::
string
&
op_name
,
uint32_t
stream_id
)
:
TaskInfo
(
op_name
,
stream_id
,
TaskInfoType
::
FUSION_END
,
false
)
{}
explicit
FusionEndTaskInfo
(
uint32_t
stream_id
)
:
TaskInfo
(
stream_id
,
TaskInfoType
::
FUSION_END
)
{}
~
FusionEndTaskInfo
()
override
{}
};
class
HcclTaskInfo
:
public
TaskInfo
{
public:
HcclTaskInfo
(
const
std
::
string
&
op_name
,
uint32_t
stream_id
,
const
std
::
string
hccl_type
,
void
*
in
put_data_addr
,
void
*
output_data_addr
,
void
*
workspace_addr
,
int64_t
workspace_size
,
int64_t
hccl_stream_num
,
HcclTaskInfo
(
uint32_t
stream_id
,
const
std
::
string
hccl_type
,
void
*
input_data_addr
,
void
*
out
put_data_addr
,
void
*
workspace_addr
,
int64_t
workspace_size
,
int64_t
hccl_stream_num
,
const
std
::
vector
<
uint8_t
>
&
private_def
,
void
*
ops_kernel_store
,
int32_t
count
,
int64_t
root_id
,
int64_t
op_type
,
int64_t
data_type
,
const
std
::
string
&
group
,
std
::
function
<
bool
(
void
*
,
void
*
)
>
hcom_bind_model
,
std
::
function
<
bool
(
void
*
)
>
hcom_unbind_model
,
std
::
function
<
bool
(
std
::
shared_ptr
<
HcclTaskInfo
>
,
void
*
)
>
hcom_distribute_task
,
bool
dump_flag
)
:
TaskInfo
(
op_name
,
stream_id
,
TaskInfoType
::
HCCL
,
dump_flag
),
int64_t
op_type
,
int64_t
data_type
,
std
::
function
<
bool
(
void
*
,
void
*
)
>
hcom_bind_model
,
std
::
function
<
bool
(
void
*
)
>
hcom_unbind_model
,
std
::
function
<
bool
(
std
::
shared_ptr
<
HcclTaskInfo
>
,
void
*
)
>
hcom_distribute_task
)
:
TaskInfo
(
stream_id
,
TaskInfoType
::
HCCL
),
hccl_type_
(
hccl_type
),
input_data_addr_
(
input_data_addr
),
output_data_addr_
(
output_data_addr
),
...
...
@@ -287,7 +269,6 @@ class HcclTaskInfo : public TaskInfo {
root_id_
(
root_id
),
op_type_
(
op_type
),
data_type_
(
data_type
),
group_
(
group
),
hcom_bind_model_
(
hcom_bind_model
),
hcom_unbind_model_
(
hcom_unbind_model
),
hcom_distribute_task_
(
hcom_distribute_task
)
{}
...
...
@@ -305,7 +286,6 @@ class HcclTaskInfo : public TaskInfo {
int64_t
root_id
()
const
{
return
root_id_
;
}
int64_t
op_type
()
const
{
return
op_type_
;
}
int64_t
data_type
()
const
{
return
data_type_
;
}
const
std
::
string
&
group
()
const
{
return
group_
;
}
std
::
function
<
bool
(
void
*
,
void
*
)
>
hcom_bind_model
()
const
{
return
hcom_bind_model_
;
}
std
::
function
<
bool
(
void
*
)
>
hcom_unbind_model
()
const
{
return
hcom_unbind_model_
;
}
std
::
function
<
bool
(
std
::
shared_ptr
<
HcclTaskInfo
>
,
void
*
)
>
hcom_distribute_task
()
const
{
...
...
@@ -325,7 +305,6 @@ class HcclTaskInfo : public TaskInfo {
int64_t
root_id_
;
int64_t
op_type_
;
int64_t
data_type_
;
std
::
string
group_
;
std
::
function
<
bool
(
void
*
,
void
*
)
>
hcom_bind_model_
;
std
::
function
<
bool
(
void
*
)
>
hcom_unbind_model_
;
std
::
function
<
bool
(
std
::
shared_ptr
<
HcclTaskInfo
>
,
void
*
)
>
hcom_distribute_task_
;
...
...
@@ -333,11 +312,8 @@ class HcclTaskInfo : public TaskInfo {
class
ProfilerTraceTaskInfo
:
public
TaskInfo
{
public:
ProfilerTraceTaskInfo
(
const
std
::
string
&
op_name
,
uint32_t
stream_id
,
uint64_t
log_id
,
bool
notify
,
uint32_t
flat
)
:
TaskInfo
(
op_name
,
stream_id
,
TaskInfoType
::
PROFILER_TRACE
,
false
),
log_id_
(
log_id
),
notify_
(
notify
),
flat_
(
flat
)
{}
ProfilerTraceTaskInfo
(
uint32_t
stream_id
,
uint64_t
log_id
,
bool
notify
,
uint32_t
flat
)
:
TaskInfo
(
stream_id
,
TaskInfoType
::
PROFILER_TRACE
),
log_id_
(
log_id
),
notify_
(
notify
),
flat_
(
flat
)
{}
~
ProfilerTraceTaskInfo
()
override
{}
uint64_t
log_id
()
const
{
return
log_id_
;
}
...
...
@@ -352,9 +328,8 @@ class ProfilerTraceTaskInfo : public TaskInfo {
class
MemcpyAsyncTaskInfo
:
public
TaskInfo
{
public:
MemcpyAsyncTaskInfo
(
const
std
::
string
&
op_name
,
uint32_t
stream_id
,
void
*
dst
,
uint64_t
dst_max
,
void
*
src
,
uint64_t
count
,
uint32_t
kind
,
bool
dump_flag
)
:
TaskInfo
(
op_name
,
stream_id
,
TaskInfoType
::
MEMCPY_ASYNC
,
dump_flag
),
MemcpyAsyncTaskInfo
(
uint32_t
stream_id
,
void
*
dst
,
uint64_t
dst_max
,
void
*
src
,
uint64_t
count
,
uint32_t
kind
)
:
TaskInfo
(
stream_id
,
TaskInfoType
::
MEMCPY_ASYNC
),
dst_
(
dst
),
dst_max_
(
dst_max
),
src_
(
src
),
...
...
@@ -378,9 +353,9 @@ class MemcpyAsyncTaskInfo : public TaskInfo {
class
StreamSwitchTaskInfo
:
public
TaskInfo
{
public:
StreamSwitchTaskInfo
(
const
std
::
string
&
op_name
,
uint32_t
stream_id
,
int64_t
true_stream_id
,
void
*
input_addr
,
void
*
value_addr
,
int64_t
cond
,
int64_t
data_type
)
:
TaskInfo
(
op_name
,
stream_id
,
TaskInfoType
::
STREAM_SWITCH
,
false
),
StreamSwitchTaskInfo
(
uint32_t
stream_id
,
int64_t
true_stream_id
,
void
*
input_addr
,
void
*
value_addr
,
int64_t
cond
,
int64_t
data_type
)
:
TaskInfo
(
stream_id
,
TaskInfoType
::
STREAM_SWITCH
),
true_stream_id_
(
true_stream_id
),
input_addr_
(
input_addr
),
value_addr_
(
value_addr
),
...
...
@@ -404,8 +379,8 @@ class StreamSwitchTaskInfo : public TaskInfo {
class
StreamActiveTaskInfo
:
public
TaskInfo
{
public:
StreamActiveTaskInfo
(
const
std
::
string
&
op_name
,
uint32_t
stream_id
,
uint32_t
active_stream_id
)
:
TaskInfo
(
op_name
,
stream_id
,
TaskInfoType
::
STREAM_ACTIVE
,
false
),
active_stream_id_
(
active_stream_id
)
{}
StreamActiveTaskInfo
(
uint32_t
stream_id
,
uint32_t
active_stream_id
)
:
TaskInfo
(
stream_id
,
TaskInfoType
::
STREAM_ACTIVE
),
active_stream_id_
(
active_stream_id
)
{}
~
StreamActiveTaskInfo
()
override
{}
uint32_t
active_stream_id
()
const
{
return
active_stream_id_
;
}
...
...
src/ge/ge_runtime/model_runner.cc
浏览文件 @
81d79f33
...
...
@@ -49,24 +49,6 @@ bool ModelRunner::LoadDavinciModel(uint32_t device_id, uint64_t session_id, uint
return
true
;
}
bool
ModelRunner
::
DistributeTask
(
uint32_t
model_id
)
{
auto
model_iter
=
runtime_models_
.
find
(
model_id
);
if
(
model_iter
==
runtime_models_
.
end
())
{
GELOGE
(
PARAM_INVALID
,
"Model id %u not found."
,
model_id
);
return
false
;
}
return
model_iter
->
second
->
DistributeTask
();
}
bool
ModelRunner
::
LoadModelComplete
(
uint32_t
model_id
)
{
auto
model_iter
=
runtime_models_
.
find
(
model_id
);
if
(
model_iter
==
runtime_models_
.
end
())
{
GELOGE
(
PARAM_INVALID
,
"Model id %u not found."
,
model_id
);
return
false
;
}
return
model_iter
->
second
->
LoadComplete
();
}
const
std
::
vector
<
uint32_t
>
&
ModelRunner
::
GetTaskIdList
(
uint32_t
model_id
)
const
{
auto
model_iter
=
runtime_models_
.
find
(
model_id
);
if
(
model_iter
==
runtime_models_
.
end
())
{
...
...
@@ -78,38 +60,6 @@ const std::vector<uint32_t> &ModelRunner::GetTaskIdList(uint32_t model_id) const
return
model_iter
->
second
->
GetTaskIdList
();
}
const
std
::
vector
<
uint32_t
>
&
ModelRunner
::
GetStreamIdList
(
uint32_t
model_id
)
const
{
auto
model_iter
=
runtime_models_
.
find
(
model_id
);
if
(
model_iter
==
runtime_models_
.
end
())
{
GELOGE
(
PARAM_INVALID
,
"Model id %u not found."
,
model_id
);
static
const
std
::
vector
<
uint32_t
>
empty_ret
;
return
empty_ret
;
}
return
model_iter
->
second
->
GetStreamIdList
();
}
const
std
::
map
<
std
::
string
,
std
::
shared_ptr
<
RuntimeInfo
>>
&
ModelRunner
::
GetRuntimeInfoMap
(
uint32_t
model_id
)
const
{
auto
model_iter
=
runtime_models_
.
find
(
model_id
);
if
(
model_iter
==
runtime_models_
.
end
())
{
GELOGW
(
"Model id %u not found."
,
model_id
);
static
const
std
::
map
<
std
::
string
,
std
::
shared_ptr
<
RuntimeInfo
>>
empty_ret
;
return
empty_ret
;
}
return
model_iter
->
second
->
GetRuntimeInfoMap
();
}
void
*
ModelRunner
::
GetModelHandle
(
uint32_t
model_id
)
const
{
auto
model_iter
=
runtime_models_
.
find
(
model_id
);
if
(
model_iter
==
runtime_models_
.
end
())
{
GELOGW
(
"Model id %u not found."
,
model_id
);
return
nullptr
;
}
return
model_iter
->
second
->
GetModelHandle
();
}
bool
ModelRunner
::
UnloadModel
(
uint32_t
model_id
)
{
auto
iter
=
runtime_models_
.
find
(
model_id
);
if
(
iter
!=
runtime_models_
.
end
())
{
...
...
src/ge/ge_runtime/output.cc
浏览文件 @
81d79f33
...
...
@@ -76,7 +76,7 @@ bool Output::CopyRslt(OutputData *rslt, uint32_t data_begin, uint32_t &data_inde
DataBuffer
data_buf
=
rslt
->
blobs
[
data_begin
+
data_count
];
bool
ret
=
SetDataBuf
(
data_buf
,
data_begin
,
data_count
,
i
,
support_mem_share
);
if
(
!
ret
)
{
GELOGE
(
FAILED
,
"Copy data to host
error
. index: %lu, addr: %p"
,
i
,
v_input_data_addr_
[
i
]);
GELOGE
(
FAILED
,
"Copy data to host
failed
. index: %lu, addr: %p"
,
i
,
v_input_data_addr_
[
i
]);
return
ret
;
}
data_index
=
data_begin
+
data_count
;
...
...
src/ge/ge_runtime/runtime_model.cc
浏览文件 @
81d79f33
...
...
@@ -28,6 +28,7 @@
namespace
ge
{
namespace
model_runner
{
RuntimeModel
::~
RuntimeModel
()
{
GELOGI
(
"RuntimeModel destructor start"
);
...
...
@@ -115,34 +116,23 @@ bool RuntimeModel::InitEvent(uint32_t event_num) {
return
true
;
}
bool
RuntimeModel
::
InitLabel
(
std
::
shared_ptr
<
DavinciModel
>
&
davinci_model
)
{
GELOGI
(
"batch number:%u."
,
davinci_model
->
GetBatchNum
());
label_list_
.
resize
(
davinci_model
->
GetBatchNum
());
for
(
auto
&
task_info
:
davinci_model
->
GetTaskInfoList
())
{
if
(
task_info
==
nullptr
)
{
GELOGE
(
PARAM_INVALID
,
"task_info is null."
);
continue
;
}
if
(
task_info
->
type
()
!=
TaskInfoType
::
LABEL_SET
)
{
continue
;
}
auto
label_set_task_info
=
std
::
static_pointer_cast
<
LabelSetTaskInfo
>
(
task_info
);
if
(
label_set_task_info
->
stream_id
()
>=
stream_list_
.
size
())
{
GELOGE
(
PARAM_INVALID
,
"Invalid stream id."
);
bool
RuntimeModel
::
InitLabel
(
uint32_t
batch_num
)
{
GELOGI
(
"batch number:%u."
,
batch_num
);
for
(
uint32_t
i
=
0
;
(
batch_num
!=
0
&&
i
<=
batch_num
);
++
i
)
{
rtLabel_t
rt_lLabel
=
nullptr
;
rtError_t
rt_ret
=
rtLabelCreate
(
&
rt_lLabel
);
if
(
rt_ret
!=
RT_ERROR_NONE
)
{
GELOGE
(
RT_FAILED
,
"Call rt api rtLabelCreate failed, i; %u; ret: 0x%X"
,
i
,
rt_ret
);
return
false
;
}
rtLabel_t
rt_label
=
nullptr
;
rtError_t
rt_ret
=
rtLabelCreateEx
(
&
rt_label
,
stream_list_
[
label_set_task_info
->
stream_id
()]);
if
(
rt_ret
!=
RT_ERROR_NONE
)
{
GELOGE
(
RT_FAILED
,
"Call rt api rtLabelCreate failed, ret: 0x%X"
,
rt_ret
);
if
(
rt_lLabel
==
nullptr
)
{
GELOGE
(
RT_FAILED
,
"rtLabel is nullptr!"
);
return
false
;
}
label_list_
[
label_set_task_info
->
label_id
()]
=
rt_label
;
}
label_list_
.
emplace_back
(
rt_lLabel
);
}
return
true
;
}
...
...
@@ -174,7 +164,7 @@ bool RuntimeModel::InitResource(std::shared_ptr<DavinciModel> &davinci_model) {
return
false
;
}
if
(
!
InitLabel
(
davinci_model
))
{
if
(
!
InitLabel
(
davinci_model
->
GetBatchNum
()
))
{
return
false
;
}
...
...
@@ -219,41 +209,20 @@ bool RuntimeModel::LoadTask() {
return
false
;
}
task_id_list_
.
push_back
(
task_id
);
stream_id_list_
.
push_back
(
stream_id
);
if
(
task
->
Args
()
!=
nullptr
)
{
std
::
shared_ptr
<
RuntimeInfo
>
runtime_tuple
=
nullptr
;
GE_MAKE_SHARED
(
runtime_tuple
=
std
::
make_shared
<
RuntimeInfo
>
(
task_id
,
stream_id
,
task
->
Args
()),
return
false
);
auto
emplace_ret
=
runtime_info_map_
.
emplace
(
task
->
task_name
(),
runtime_tuple
);
if
(
!
emplace_ret
.
second
)
{
GELOGW
(
"Task name exist:%s"
,
task
->
task_name
().
c_str
());
}
}
}
if
(
task_list_
.
empty
())
{
GELOGE
(
FAILED
,
"Task list is empty"
);
return
false
;
}
GELOGI
(
"Distribute task succ."
);
GELOGI
(
"LoadTask succ."
);
return
true
;
}
bool
RuntimeModel
::
LoadComplete
()
{
uint32_t
task_id
=
0
;
uint32_t
stream_id
=
0
;
auto
rt_ret
=
rtModelGetTaskId
(
rt_model_handle_
,
&
task_id
,
&
stream_id
);
if
(
rt_ret
!=
RT_ERROR_NONE
)
{
GELOGE
(
RT_FAILED
,
"Call rtModelGetTaskId failed, ret:0x%X"
,
rt_ret
);
return
RT_FAILED
;
}
task_id_list_
.
push_back
(
task_id
);
stream_id_list_
.
push_back
(
stream_id
);
rt_ret
=
rtModelLoadComplete
(
rt_model_handle_
);
auto
rt_ret
=
rtModelLoadComplete
(
rt_model_handle_
);
if
(
rt_ret
!=
RT_ERROR_NONE
)
{
GELOGE
(
RT_FAILED
,
"Call rt api rtModelLoadComplete failed, ret: 0x%X."
,
rt_ret
);
return
false
;
}
GELOGI
(
"LoadTask succ."
);
return
true
;
}
...
...
@@ -283,16 +252,14 @@ bool RuntimeModel::Load(uint32_t device_id, uint64_t session_id, std::shared_ptr
}
GenerateTask
(
device_id
,
session_id
,
davinci_model
);
return
status
;
}
bool
RuntimeModel
::
DistributeTask
()
{
bool
status
=
LoadTask
();
status
=
LoadTask
();
if
(
!
status
)
{
GELOGE
(
FAILED
,
"DistributeTask failed"
);
return
false
;
return
status
;
}
return
true
;
return
status
;
}
bool
RuntimeModel
::
Run
()
{
...
...
@@ -303,14 +270,10 @@ bool RuntimeModel::Run() {
return
false
;
}
GELOGI
(
"Run rtModelExecute success
, ret = 0x%X"
,
ret
);
GELOGI
(
"Run rtModelExecute success
"
);
ret
=
rtStreamSynchronize
(
rt_model_stream_
);
if
(
ret
!=
RT_ERROR_NONE
)
{
if
(
ret
==
RT_ERROR_END_OF_SEQUENCE
)
{
GELOGI
(
"Model stream RT_ERROR_END_OF_SEQUENCE signal received, ret = 0x%X"
,
ret
);
return
true
;
}
GELOGE
(
RT_FAILED
,
"Model stream sync failed, ret = 0x%X"
,
ret
);
return
false
;
}
...
...
@@ -470,7 +433,7 @@ bool RuntimeModel::InitConstantInfo(std::shared_ptr<DavinciModel> &davinci_model
}
if
(
constant
->
output_tensors
[
0
].
size
<
constant
->
weight_data
.
size
())
{
GELOGE
(
PARAM_INVALID
,
"Output size:%u less than weight data size:%zu"
,
constant
->
output_tensors
[
0
].
size
,
GELOGE
(
PARAM_INVALID
,
"Output size:%u
is
less than weight data size:%zu"
,
constant
->
output_tensors
[
0
].
size
,
constant
->
weight_data
.
size
());
return
false
;
}
...
...
@@ -485,8 +448,11 @@ bool RuntimeModel::InitConstantInfo(std::shared_ptr<DavinciModel> &davinci_model
/// The logic of GetShapeSize is wrong, the scaler tensor's GetShapeSize is zero
/// and that of unknown shape is zero too.
/// Unknown shape will not appear here, so we can use zero judge a tensor is scaler or not.
int64_t
elem_num
=
(
constant
->
weight_tensors
[
0
].
GetShapeSize
()
==
0
)
?
1
:
constant
->
weight_tensors
[
0
].
GetShapeSize
();
int64_t
elem_num
=
constant
->
weight_tensors
[
0
].
GetShapeSize
();
if
(
elem_num
==
0
&&
constant
->
weight_tensors
[
0
].
size
==
0
)
{
elem_num
=
1
;
}
if
(
constant
->
weight_data
.
size
()
<
sizeof
(
uint64_t
))
{
GELOGE
(
FAILED
,
"weight_data size is smaller than sizeof(uint64_t)"
);
return
false
;
...
...
@@ -529,6 +495,5 @@ void RuntimeModel::CreateOutput(uint32_t index, const OpInfo &op_info, InputOutp
const
std
::
vector
<
uint32_t
>
&
RuntimeModel
::
GetTaskIdList
()
const
{
return
task_id_list_
;
}
const
std
::
vector
<
uint32_t
>
&
RuntimeModel
::
GetStreamIdList
()
const
{
return
stream_id_list_
;
}
}
// namespace model_runner
}
// namespace ge
src/ge/ge_runtime/runtime_model.h
浏览文件 @
81d79f33
...
...
@@ -27,7 +27,7 @@
namespace
ge
{
namespace
model_runner
{
using
RuntimeInfo
=
std
::
tuple
<
uint32_t
,
uint32_t
,
void
*>
;
class
Task
;
class
RuntimeModel
{
public:
...
...
@@ -35,12 +35,7 @@ class RuntimeModel {
~
RuntimeModel
();
bool
Load
(
uint32_t
device_id
,
uint64_t
session_id
,
std
::
shared_ptr
<
DavinciModel
>
&
davinci_model
);
bool
DistributeTask
();
bool
LoadComplete
();
const
std
::
vector
<
uint32_t
>
&
GetTaskIdList
()
const
;
const
std
::
vector
<
uint32_t
>
&
GetStreamIdList
()
const
;
const
std
::
map
<
std
::
string
,
std
::
shared_ptr
<
RuntimeInfo
>>
&
GetRuntimeInfoMap
()
const
{
return
runtime_info_map_
;
}
rtModel_t
GetModelHandle
()
const
{
return
rt_model_handle_
;
}
bool
Run
();
bool
CopyInputData
(
const
InputData
&
input_data
);
bool
GetInputOutputDescInfo
(
bool
zero_copy
,
std
::
vector
<
InputOutputDescInfo
>
*
input_desc
,
...
...
@@ -53,7 +48,7 @@ class RuntimeModel {
bool
LoadTask
();
bool
InitStream
(
std
::
shared_ptr
<
DavinciModel
>
&
davinci_model
);
bool
InitEvent
(
uint32_t
event_num
);
bool
InitLabel
(
std
::
shared_ptr
<
DavinciModel
>
&
davinci_model
);
bool
InitLabel
(
uint32_t
batch_num
);
bool
InitDataInfo
(
std
::
shared_ptr
<
DavinciModel
>
&
davinci_model
);
bool
InitOutputInfo
(
std
::
shared_ptr
<
DavinciModel
>
&
davinci_model
);
bool
InitConstantInfo
(
std
::
shared_ptr
<
DavinciModel
>
&
davinci_model
);
...
...
@@ -82,8 +77,6 @@ class RuntimeModel {
std
::
vector
<
std
::
shared_ptr
<
OpInfo
>>
constant_info_list_
{};
std
::
vector
<
uint32_t
>
task_id_list_
{};
std
::
vector
<
uint32_t
>
stream_id_list_
{};
std
::
map
<
std
::
string
,
std
::
shared_ptr
<
RuntimeInfo
>>
runtime_info_map_
;
};
}
// namespace model_runner
...
...
src/ge/ge_runtime/task/aicpu_task.cc
浏览文件 @
81d79f33
...
...
@@ -85,15 +85,11 @@ bool AicpuTask::Distribute() {
return
false
;
}
input_output_addr_
=
reinterpret_cast
<
void
*>
(
reinterpret_cast
<
uint8_t
*>
(
args_
)
+
io_addr_offset
);
auto
dump_flag
=
task_info_
->
dump_flag
()
?
RT_KERNEL_DUMPFLAG
:
RT_KERNEL_DEFAULT
;
GELOGI
(
"Distribute AicpuTask start, args_size = %u, io_addrs_num = %u, so_name = %s, kernel_name = %s, dump_flag = %d."
,
args_size
,
io_addrs_num
,
task_info_
->
so_name
().
data
(),
task_info_
->
kernel_name
().
data
(),
dump_flag
);
rt_ret
=
rtCpuKernelLaunchWithFlag
(
reinterpret_cast
<
const
void
*>
(
task_info_
->
so_name
().
data
()),
reinterpret_cast
<
const
void
*>
(
task_info_
->
kernel_name
().
data
()),
1
,
args_
,
args_size
,
nullptr
,
stream_
,
dump_flag
);
GELOGI
(
"Distribute AicpuTask start, args_size = %u, io_addrs_num = %u, so_name = %s, kernel_name = %s."
,
args_size
,
io_addrs_num
,
task_info_
->
so_name
().
data
(),
task_info_
->
kernel_name
().
data
());
rt_ret
=
rtCpuKernelLaunch
(
reinterpret_cast
<
const
void
*>
(
task_info_
->
so_name
().
data
()),
reinterpret_cast
<
const
void
*>
(
task_info_
->
kernel_name
().
data
()),
1
,
args_
,
args_size
,
nullptr
,
stream_
);
if
(
rt_ret
!=
RT_ERROR_NONE
)
{
GELOGE
(
RT_FAILED
,
"Call rt api failed, ret: 0x%X"
,
rt_ret
);
return
false
;
...
...
src/ge/ge_runtime/task/aicpu_task.h
浏览文件 @
81d79f33
...
...
@@ -18,7 +18,6 @@
#define GE_GE_RUNTIME_TASK_AICPU_TASK_H_
#include <memory>
#include <string>
#include "ge_runtime/task/task.h"
namespace
ge
{
...
...
@@ -31,17 +30,12 @@ class AicpuTask : public TaskRepeater<AicpuTaskInfo> {
bool
Distribute
()
override
;
void
*
Args
()
override
{
return
input_output_addr_
;
}
std
::
string
task_name
()
const
override
{
return
task_info_
->
op_name
();
}
private:
static
void
ReleaseRtMem
(
void
**
ptr
)
noexcept
;
std
::
shared_ptr
<
AicpuTaskInfo
>
task_info_
;
void
*
stream_
;
void
*
args_
;
void
*
input_output_addr_
;
};
}
// namespace model_runner
}
// namespace ge
...
...
src/ge/ge_runtime/task/hccl_task.cc
浏览文件 @
81d79f33
...
...
@@ -115,6 +115,7 @@ bool HcclTask::Distribute() {
rt_ret
=
rtModelBindStream
(
rt_model_handle_
,
stream
,
RT_HEAD_STREAM
);
if
(
rt_ret
!=
RT_ERROR_NONE
)
{
GELOGE
(
RT_FAILED
,
"Call rt api failed, ret: 0x%X"
,
rt_ret
);
(
void
)
rtStreamDestroy
(
stream
);
return
false
;
}
...
...
src/ge/ge_runtime/task/label_goto_task.cc
已删除
100644 → 0
浏览文件 @
9275e7b0
/**
* Copyright 2019-2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "ge_runtime/task/label_goto_task.h"
#include "ge_runtime/task/task_factory.h"
namespace
ge
{
namespace
model_runner
{
LabelGotoTask
::
LabelGotoTask
(
const
ModelContext
&
model_context
,
const
std
::
shared_ptr
<
LabelGotoTaskInfo
>
&
task_info
)
:
TaskRepeater
<
LabelGotoTaskInfo
>
(
model_context
,
task_info
),
task_info_
(
task_info
),
stream_
(
nullptr
),
label_
(
nullptr
)
{
if
(
task_info_
==
nullptr
)
{
GELOGW
(
"task_info_ is null!"
);
return
;
}
auto
stream_list
=
model_context
.
stream_list
();
auto
label_list
=
model_context
.
label_list
();
uint32_t
stream_id
=
task_info
->
stream_id
();
uint32_t
label_id
=
task_info
->
label_id
();
GELOGI
(
"Stream list size:%zu, stream id:%u."
,
stream_list
.
size
(),
stream_id
);
GELOGI
(
"Label list size:%zu, label id:%u."
,
label_list
.
size
(),
label_id
);
if
(
stream_id
>=
stream_list
.
size
()
||
label_id
>=
label_list
.
size
())
{
GELOGW
(
"Stream/Label id invalid."
);
return
;
}
stream_
=
stream_list
[
stream_id
];
label_
=
label_list
[
label_id
];
}
LabelGotoTask
::~
LabelGotoTask
()
{}
bool
LabelGotoTask
::
Distribute
()
{
GELOGI
(
"LabelGotoTask Distribute start."
);
if
(
stream_
==
nullptr
)
{
GELOGE
(
PARAM_INVALID
,
"stream is null!"
);
return
false
;
}
if
(
label_
==
nullptr
)
{
GELOGE
(
PARAM_INVALID
,
"label is null!"
);
return
false
;
}
rtError_t
rt_ret
=
rtLabelGotoEx
(
label_
,
stream_
);
if
(
rt_ret
!=
RT_ERROR_NONE
)
{
GELOGE
(
RT_FAILED
,
"Call rt api failed, ret: 0x%X"
,
rt_ret
);
return
false
;
}
GELOGI
(
"DistributeTask end."
);
return
true
;
}
REGISTER_TASK
(
TaskInfoType
::
LABEL_GOTO
,
LabelGotoTask
,
LabelGotoTaskInfo
);
}
// namespace model_runner
}
// namespace ge
src/ge/ge_runtime/task/label_goto_task.h
已删除
100644 → 0
浏览文件 @
9275e7b0
/**
* Copyright 2019-2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef GE_GE_RUNTIME_TASK_LABEL_GOTO_TASK_H_
#define GE_GE_RUNTIME_TASK_LABEL_GOTO_TASK_H_
#include <memory>
#include "ge_runtime/task/task.h"
namespace
ge
{
namespace
model_runner
{
class
LabelGotoTask
:
public
TaskRepeater
<
LabelGotoTaskInfo
>
{
public:
LabelGotoTask
(
const
ModelContext
&
model_context
,
const
std
::
shared_ptr
<
LabelGotoTaskInfo
>
&
task_info
);
~
LabelGotoTask
()
override
;
bool
Distribute
()
override
;
private:
std
::
shared_ptr
<
LabelGotoTaskInfo
>
task_info_
;
void
*
stream_
;
void
*
label_
;
};
}
// namespace model_runner
}
// namespace ge
#endif // GE_GE_RUNTIME_TASK_LABEL_GOTO_TASK_H_
src/ge/ge_runtime/task/label_set_task.cc
已删除
100644 → 0
浏览文件 @
9275e7b0
/**
* Copyright 2019-2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "ge_runtime/task/label_set_task.h"
#include "ge_runtime/task/task_factory.h"
namespace
ge
{
namespace
model_runner
{
LabelSetTask
::
LabelSetTask
(
const
ModelContext
&
model_context
,
const
std
::
shared_ptr
<
LabelSetTaskInfo
>
&
task_info
)
:
TaskRepeater
<
LabelSetTaskInfo
>
(
model_context
,
task_info
),
task_info_
(
task_info
),
stream_
(
nullptr
),
label_
(
nullptr
)
{
if
(
task_info_
==
nullptr
)
{
GELOGW
(
"task_info_ is null!"
);
return
;
}
auto
stream_list
=
model_context
.
stream_list
();
auto
label_list
=
model_context
.
label_list
();
uint32_t
stream_id
=
task_info
->
stream_id
();
uint32_t
label_id
=
task_info
->
label_id
();
GELOGI
(
"Stream list size:%zu, stream id:%u."
,
stream_list
.
size
(),
stream_id
);
GELOGI
(
"Label list size:%zu, label id:%u."
,
label_list
.
size
(),
label_id
);
if
(
stream_id
>=
stream_list
.
size
()
||
label_id
>=
label_list
.
size
())
{
GELOGW
(
"Stream/Label id invalid."
);
return
;
}
stream_
=
stream_list
[
stream_id
];
label_
=
label_list
[
label_id
];
}
LabelSetTask
::~
LabelSetTask
()
{}
bool
LabelSetTask
::
Distribute
()
{
GELOGI
(
"LabelSetTask Distribute start."
);
if
(
stream_
==
nullptr
)
{
GELOGE
(
PARAM_INVALID
,
"stream is null!"
);
return
false
;
}
if
(
label_
==
nullptr
)
{
GELOGE
(
PARAM_INVALID
,
"label is null!"
);
return
false
;
}
rtError_t
rt_ret
=
rtLabelSet
(
label_
,
stream_
);
if
(
rt_ret
!=
RT_ERROR_NONE
)
{
GELOGE
(
RT_FAILED
,
"Call rt api failed, ret: 0x%X"
,
rt_ret
);
return
false
;
}
GELOGI
(
"DistributeTask end."
);
return
true
;
}
REGISTER_TASK
(
TaskInfoType
::
LABEL_SET
,
LabelSetTask
,
LabelSetTaskInfo
);
}
// namespace model_runner
}
// namespace ge
src/ge/ge_runtime/task/label_set_task.h
已删除
100644 → 0
浏览文件 @
9275e7b0
/**
* Copyright 2019-2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef GE_GE_RUNTIME_TASK_LABEL_SET_TASK_H_
#define GE_GE_RUNTIME_TASK_LABEL_SET_TASK_H_
#include <memory>
#include "ge_runtime/task/task.h"
namespace
ge
{
namespace
model_runner
{
class
LabelSetTask
:
public
TaskRepeater
<
LabelSetTaskInfo
>
{
public:
LabelSetTask
(
const
ModelContext
&
model_context
,
const
std
::
shared_ptr
<
LabelSetTaskInfo
>
&
task_info
);
~
LabelSetTask
()
override
;
bool
Distribute
()
override
;
private:
std
::
shared_ptr
<
LabelSetTaskInfo
>
task_info_
;
void
*
stream_
;
void
*
label_
;
};
}
// namespace model_runner
}
// namespace ge
#endif // GE_GE_RUNTIME_TASK_LABEL_SET_TASK_H_
src/ge/ge_runtime/task/label_switch_task.cc
已删除
100644 → 0
浏览文件 @
9275e7b0
/**
* Copyright 2019-2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "ge_runtime/task/label_switch_task.h"
#include "ge_runtime/task/task_factory.h"
namespace
ge
{
namespace
model_runner
{
LabelSwitchTask
::
LabelSwitchTask
(
const
ModelContext
&
model_context
,
const
std
::
shared_ptr
<
LabelSwitchTaskInfo
>
&
task_info
)
:
TaskRepeater
<
LabelSwitchTaskInfo
>
(
model_context
,
task_info
),
task_info_
(
task_info
),
stream_
(
nullptr
),
all_label_resource_
(),
label_info_
(
nullptr
)
{
if
(
task_info_
==
nullptr
)
{
GELOGW
(
"task_info_ is null!"
);
return
;
}
all_label_resource_
=
model_context
.
label_list
();
auto
stream_list
=
model_context
.
stream_list
();
uint32_t
stream_id
=
task_info
->
stream_id
();
GELOGI
(
"Stream list size:%zu, stream id:%u."
,
stream_list
.
size
(),
stream_id
);
if
(
stream_id
>=
stream_list
.
size
())
{
GELOGW
(
"Stream id invalid."
);
return
;
}
stream_
=
stream_list
[
stream_id
];
}
LabelSwitchTask
::~
LabelSwitchTask
()
{
if
(
label_info_
!=
nullptr
)
{
rtError_t
rt_ret
=
rtFree
(
label_info_
);
if
(
rt_ret
!=
RT_ERROR_NONE
)
{
GELOGE
(
RT_FAILED
,
"rtFree fwkOpBuf failed! ret: 0x%X."
,
rt_ret
);
}
label_info_
=
nullptr
;
}
}
bool
LabelSwitchTask
::
Distribute
()
{
GELOGI
(
"LabelSwitchTask Distribute start."
);
if
(
!
CheckParamValid
())
{
return
false
;
}
const
std
::
vector
<
uint32_t
>
&
label_index_list
=
task_info_
->
label_list
();
std
::
vector
<
void
*>
label_list
(
task_info_
->
label_size
(),
nullptr
);
for
(
size_t
i
=
0
;
i
<
task_info_
->
label_size
();
++
i
)
{
uint32_t
label_index
=
label_index_list
[
i
];
if
(
label_index
>=
all_label_resource_
.
size
())
{
GELOGE
(
PARAM_INVALID
,
"label %zu index is %u, but there are %zu labels in total."
,
i
,
label_index
,
all_label_resource_
.
size
());
return
false
;
}
label_list
[
i
]
=
all_label_resource_
[
label_index
];
GELOGI
(
"Case %zu: label id %zu."
,
i
,
label_index
);
}
uint32_t
label_info_size
=
sizeof
(
rtLabelDevInfo
)
*
task_info_
->
label_size
();
rtError_t
rt_ret
=
rtMalloc
(
&
label_info_
,
label_info_size
,
RT_MEMORY_HBM
);
if
(
rt_ret
!=
RT_ERROR_NONE
)
{
GELOGE
(
RT_FAILED
,
"Call rt api failed, ret: 0x%X"
,
rt_ret
);
return
false
;
}
rt_ret
=
rtLabelListCpy
(
label_list
.
data
(),
label_list
.
size
(),
label_info_
,
label_info_size
);
if
(
rt_ret
!=
RT_ERROR_NONE
)
{
GELOGE
(
RT_FAILED
,
"Call rt api failed, ret: 0x%X"
,
rt_ret
);
return
false
;
}
rt_ret
=
rtLabelSwitchByIndex
(
task_info_
->
cond
(),
label_list
.
size
(),
label_info_
,
stream_
);
if
(
rt_ret
!=
RT_ERROR_NONE
)
{
GELOGE
(
RT_FAILED
,
"Call rt api failed, ret: 0x%X"
,
rt_ret
);
return
false
;
}
GELOGI
(
"DistributeTask end."
);
return
true
;
}
bool
LabelSwitchTask
::
CheckParamValid
()
{
if
(
stream_
==
nullptr
)
{
GELOGE
(
PARAM_INVALID
,
"stream is null!"
);
return
false
;
}
if
(
task_info_
->
label_list
().
empty
())
{
GELOGE
(
PARAM_INVALID
,
"label_list is empty."
);
return
false
;
}
if
(
task_info_
->
label_size
()
!=
task_info_
->
label_list
().
size
())
{
GELOGE
(
PARAM_INVALID
,
"label_list size %zu but label_size is %u."
,
task_info_
->
label_list
().
size
(),
task_info_
->
label_size
());
return
false
;
}
if
(
task_info_
->
label_size
()
>=
UINT32_MAX
/
sizeof
(
rtLabelDevInfo
))
{
GELOGE
(
PARAM_INVALID
,
"label_size %u will overflow."
,
task_info_
->
label_size
());
return
false
;
}
if
(
label_info_
!=
nullptr
)
{
GELOGE
(
PARAM_INVALID
,
"label_info_ has dirty data."
);
return
false
;
}
return
true
;
}
REGISTER_TASK
(
TaskInfoType
::
LABEL_SWITCH
,
LabelSwitchTask
,
LabelSwitchTaskInfo
);
}
// namespace model_runner
}
// namespace ge
src/ge/ge_runtime/task/label_switch_task.h
已删除
100644 → 0
浏览文件 @
9275e7b0
/**
* Copyright 2019-2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef GE_GE_RUNTIME_TASK_LABEL_SWITCH_TASK_H_
#define GE_GE_RUNTIME_TASK_LABEL_SWITCH_TASK_H_
#include <memory>
#include "ge_runtime/task/task.h"
namespace
ge
{
namespace
model_runner
{
class
LabelSwitchTask
:
public
TaskRepeater
<
LabelSwitchTaskInfo
>
{
public:
LabelSwitchTask
(
const
ModelContext
&
model_context
,
const
std
::
shared_ptr
<
LabelSwitchTaskInfo
>
&
task_info
);
~
LabelSwitchTask
()
override
;
bool
Distribute
()
override
;
private:
bool
CheckParamValid
();
std
::
shared_ptr
<
LabelSwitchTaskInfo
>
task_info_
;
void
*
stream_
;
std
::
vector
<
void
*>
all_label_resource_
;
void
*
label_info_
;
};
}
// namespace model_runner
}
// namespace ge
#endif // GE_GE_RUNTIME_TASK_LABEL_SWITCH_TASK_H_
src/ge/ge_runtime/task/stream_switch_task.cc
浏览文件 @
81d79f33
...
...
@@ -51,7 +51,7 @@ bool StreamSwitchTask::Distribute() {
}
if
(
static_cast
<
uint64_t
>
(
task_info_
->
true_stream_id
())
>=
stream_list_
.
size
())
{
GELOGE
(
PARAM_INVALID
,
"true_stream_id %ld must less than stream_list_ size %zu!"
,
task_info_
->
true_stream_id
(),
GELOGE
(
PARAM_INVALID
,
"true_stream_id %ld must
be
less than stream_list_ size %zu!"
,
task_info_
->
true_stream_id
(),
stream_list_
.
size
());
return
false
;
}
...
...
src/ge/ge_runtime/task/task.h
浏览文件 @
81d79f33
...
...
@@ -18,9 +18,7 @@
#define GE_GE_RUNTIME_TASK_TASK_H_
#include <memory>
#include <utility>
#include <vector>
#include <string>
#include "runtime/rt_model.h"
#include "ge_runtime/model_context.h"
#include "ge_runtime/task_info.h"
...
...
@@ -34,10 +32,6 @@ class Task {
virtual
~
Task
()
{}
virtual
bool
Distribute
()
=
0
;
virtual
void
*
Args
()
{
return
nullptr
;
}
virtual
std
::
string
task_name
()
const
{
return
""
;
}
};
template
<
class
T
>
...
...
src/ge/ge_runtime/task/tbe_task.cc
浏览文件 @
81d79f33
...
...
@@ -95,14 +95,15 @@ bool TbeTask::Distribute() {
return
false
;
}
GELOGI
(
"InitTbeTask end."
);
GELOGI
(
"DistributeTbeTask start."
);
auto
dump_flag
=
task_info_
->
dump_flag
()
?
RT_KERNEL_DUMPFLAG
:
RT_KERNEL_DEFAULT
;
rt_ret
=
rtKernelLaunchWithFlag
(
stub_func_
,
task_info_
->
block_dim
(),
args_
,
args_size
,
nullptr
,
stream_
,
dump_flag
);
rt_ret
=
rtKernelLaunch
(
stub_func_
,
task_info_
->
block_dim
(),
args_
,
args_size
,
nullptr
,
stream_
);
if
(
rt_ret
!=
RT_ERROR_NONE
)
{
GELOGE
(
RT_FAILED
,
"Call rt api rtKernelLaunch failed, ret: 0x%X"
,
rt_ret
);
return
false
;
}
GELOGI
(
"[DataDump] task name:%s, dump_flag:%d"
,
task_info_
->
op_name
().
c_str
(),
dump_flag
);
GELOGI
(
"DistributeTbeTask end."
);
return
true
;
}
...
...
src/ge/ge_runtime/task/tbe_task.h
浏览文件 @
81d79f33
...
...
@@ -30,10 +30,6 @@ class TbeTask : public TaskRepeater<TbeTaskInfo> {
bool
Distribute
()
override
;
void
*
Args
()
override
{
return
args_
;
}
std
::
string
task_name
()
const
override
{
return
task_info_
->
op_name
();
}
private:
std
::
shared_ptr
<
TbeTaskInfo
>
task_info_
;
void
*
stream_
;
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录