Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
MegEngine 天元
MegEngine
提交
81d8c73a
MegEngine
项目概览
MegEngine 天元
/
MegEngine
1 年多 前同步成功
通知
404
Star
4705
Fork
582
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
MegEngine
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
提交
81d8c73a
编写于
2月 07, 2022
作者:
M
Megvii Engine Team
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
perf(dispatch/trace): serval tricks to speed up trace
GitOrigin-RevId: 2bdd70cde2d19f43055218804abd65f4e5a54b89
上级
4fa61620
变更
2
显示空白变更内容
内联
并排
Showing
2 changed file
with
26 addition
and
21 deletion
+26
-21
imperative/src/impl/transformations/trace.cpp
imperative/src/impl/transformations/trace.cpp
+23
-18
imperative/src/include/megbrain/imperative/transformations/trace.h
...e/src/include/megbrain/imperative/transformations/trace.h
+3
-3
未找到文件。
imperative/src/impl/transformations/trace.cpp
浏览文件 @
81d8c73a
...
@@ -47,7 +47,7 @@ VarNodeArray TraceResult::dump(
...
@@ -47,7 +47,7 @@ VarNodeArray TraceResult::dump(
auto
&
node
=
nodes
[
input
];
auto
&
node
=
nodes
[
input
];
// TODO: cambricon CompNode
// TODO: cambricon CompNode
auto
host
=
std
::
make_shared
<
HostTensorND
>
(
auto
host
=
std
::
make_shared
<
HostTensorND
>
(
CompNode
::
load
(
"xpux"
),
shape
,
var
.
dtype
);
CompNode
::
load
(
"xpux"
),
shape
,
*
var
.
dtype
);
OperatorNodeConfig
config
;
OperatorNodeConfig
config
;
// if prefer_input_names, prefer names from dump args
// if prefer_input_names, prefer names from dump args
// else prefer names got from trace procedure
// else prefer names got from trace procedure
...
@@ -211,7 +211,6 @@ ValueRefList TracingTransformation::apply_transformation(
...
@@ -211,7 +211,6 @@ ValueRefList TracingTransformation::apply_transformation(
auto
&
var_info
=
m_vars
[
tracing_value
->
id
()];
auto
&
var_info
=
m_vars
[
tracing_value
->
id
()];
switch
(
get_attr
->
attr
())
{
switch
(
get_attr
->
attr
())
{
case
GetAttr
::
Shape
:
case
GetAttr
::
Shape
:
// TODO: reduce h2d when data or value is available
var_info
.
shape_required
=
true
;
var_info
.
shape_required
=
true
;
break
;
break
;
case
GetAttr
::
Data
:
case
GetAttr
::
Data
:
...
@@ -301,8 +300,8 @@ void CompiledTransformation::compile() {
...
@@ -301,8 +300,8 @@ void CompiledTransformation::compile() {
auto
box
=
make_box
<
DeviceTensorND
>
();
auto
box
=
make_box
<
DeviceTensorND
>
();
// TODO: attach ref count, release early
// TODO: attach ref count, release early
auto
outputs
=
opr
::
InputCallback
::
make
(
auto
outputs
=
opr
::
InputCallback
::
make
(
*
m_graph
,
[
box
]
{
return
box
->
take_value
();
},
var_info
->
device
,
*
m_graph
,
[
box
]
{
return
box
->
take_value
();
},
*
var_info
->
device
,
var_info
->
dtype
,
var_info
->
shape
,
io_links
,
m_input_shape_static
);
*
var_info
->
dtype
,
var_info
->
shape
,
io_links
,
m_input_shape_static
);
// attach input_callback to io_links
// attach input_callback to io_links
accessor
.
node
=
outputs
[
0
].
node
();
accessor
.
node
=
outputs
[
0
].
node
();
io_links
=
{
outputs
[
1
]};
io_links
=
{
outputs
[
1
]};
...
@@ -312,6 +311,11 @@ void CompiledTransformation::compile() {
...
@@ -312,6 +311,11 @@ void CompiledTransformation::compile() {
auto
make_output
=
[
&
](
TraceResult
::
VarInfo
*
var_info
,
SymbolVar
node
)
{
auto
make_output
=
[
&
](
TraceResult
::
VarInfo
*
var_info
,
SymbolVar
node
)
{
VarAccessor
accessor
;
VarAccessor
accessor
;
accessor
.
node
=
node
.
node
();
accessor
.
node
=
node
.
node
();
if
(
var_info
->
data_required
)
{
// reduce d2h when data is available
// FIXME: compile should not change var_info in-place
var_info
->
shape_required
=
false
;
}
if
(
var_info
->
shape_required
)
{
if
(
var_info
->
shape_required
)
{
// TODO: use static infer manager for some vars?
// TODO: use static infer manager for some vars?
auto
box
=
make_box
<
TensorShape
>
();
auto
box
=
make_box
<
TensorShape
>
();
...
@@ -334,6 +338,12 @@ void CompiledTransformation::compile() {
...
@@ -334,6 +338,12 @@ void CompiledTransformation::compile() {
accessor
.
data_getter
=
[
box
]()
->
DeviceTensorND
{
accessor
.
data_getter
=
[
box
]()
->
DeviceTensorND
{
return
box
->
get_value
();
return
box
->
get_value
();
};
};
if
(
!
accessor
.
shape_getter
)
{
// also implement shape_getter
accessor
.
shape_getter
=
[
box
]()
->
TensorShape
{
return
box
->
get_value
().
shape
();
};
}
}
}
if
(
var_info
->
value_required
)
{
if
(
var_info
->
value_required
)
{
struct
ValueWithEvent
{
struct
ValueWithEvent
{
...
@@ -341,7 +351,7 @@ void CompiledTransformation::compile() {
...
@@ -341,7 +351,7 @@ void CompiledTransformation::compile() {
CompNode
::
Event
*
event
=
nullptr
;
CompNode
::
Event
*
event
=
nullptr
;
};
};
auto
box
=
make_box
<
ValueWithEvent
>
();
auto
box
=
make_box
<
ValueWithEvent
>
();
auto
event
=
EventPool
::
without_timer
().
alloc_shared
(
var_info
->
device
);
auto
event
=
EventPool
::
without_timer
().
alloc_shared
(
*
var_info
->
device
);
auto
callback
=
[
box
,
event
](
DeviceTensorND
data
)
{
auto
callback
=
[
box
,
event
](
DeviceTensorND
data
)
{
HostTensorND
host_val
;
HostTensorND
host_val
;
host_val
.
copy_from
(
data
);
host_val
.
copy_from
(
data
);
...
@@ -355,7 +365,7 @@ void CompiledTransformation::compile() {
...
@@ -355,7 +365,7 @@ void CompiledTransformation::compile() {
};
};
SymbolVarArray
inputs
=
io_links
;
SymbolVarArray
inputs
=
io_links
;
inputs
.
insert
(
inputs
.
begin
(),
node
);
inputs
.
insert
(
inputs
.
begin
(),
node
);
auto
output
=
opr
::
OutputCallback
::
make
({
callback
,
fals
e
,
true
},
inputs
);
auto
output
=
opr
::
OutputCallback
::
make
({
callback
,
tru
e
,
true
},
inputs
);
io_links
=
{
output
};
io_links
=
{
output
};
accessor
.
value_getter
=
[
box
]()
->
HostTensorND
{
accessor
.
value_getter
=
[
box
]()
->
HostTensorND
{
auto
&&
[
value
,
event
]
=
box
->
get_value
();
auto
&&
[
value
,
event
]
=
box
->
get_value
();
...
@@ -486,11 +496,12 @@ void CompiledTransformation::trace_input(size_t id, ValueRef value) {
...
@@ -486,11 +496,12 @@ void CompiledTransformation::trace_input(size_t id, ValueRef value) {
DType
dtype
=
*
value
.
dtype
();
DType
dtype
=
*
value
.
dtype
();
CompNode
device
=
*
value
.
device
();
CompNode
device
=
*
value
.
device
();
trace_assert
(
trace_assert
(
var
.
dtype
==
dtype
,
"dtype mismatch: %s vs %s"
,
*
var
.
dtype
==
dtype
,
"dtype mismatch: %s vs %s"
,
var
.
dtype
.
name
(),
dtype
.
name
());
var
.
dtype
->
name
(),
dtype
.
name
());
trace_assert
(
trace_assert
(
var
.
device
==
device
,
"comp_node mismatch: %s vs %s"
,
*
var
.
device
==
device
,
"comp_node mismatch: %s vs %s"
,
var
.
device
.
to_string
().
c_str
(),
device
.
to_string
().
c_str
());
var
.
device
->
to_string
().
c_str
(),
device
.
to_string
().
c_str
());
}
}
var_accessor
.
data_setter
(
value
.
dev_tensor
()
->
as_nd
());
var_accessor
.
data_setter
(
value
.
dev_tensor
()
->
as_nd
());
break
;
break
;
...
@@ -535,17 +546,11 @@ ShapeValue::ref_t CompiledTransformation::TracedInfo::shape() const {
...
@@ -535,17 +546,11 @@ ShapeValue::ref_t CompiledTransformation::TracedInfo::shape() const {
}
}
DTypeValue
::
ref_t
CompiledTransformation
::
TracedInfo
::
dtype
()
const
{
DTypeValue
::
ref_t
CompiledTransformation
::
TracedInfo
::
dtype
()
const
{
if
(
!
m_dtype
)
{
return
m_var
->
dtype
;
m_dtype
=
DTypeValue
::
make
(
m_var
->
dtype
);
}
return
m_dtype
;
}
}
CompNodeValue
::
ref_t
CompiledTransformation
::
TracedInfo
::
comp_node
()
const
{
CompNodeValue
::
ref_t
CompiledTransformation
::
TracedInfo
::
comp_node
()
const
{
if
(
!
m_comp_node
)
{
return
m_var
->
device
;
m_comp_node
=
CompNodeValue
::
make
(
m_var
->
device
);
}
return
m_comp_node
;
}
}
auto
CompiledTransformation
::
TracedInfo
::
accessor
()
const
->
const
VarAccessor
&
{
auto
CompiledTransformation
::
TracedInfo
::
accessor
()
const
->
const
VarAccessor
&
{
return
*
m_accessor
;
return
*
m_accessor
;
...
...
imperative/src/include/megbrain/imperative/transformations/trace.h
浏览文件 @
81d8c73a
...
@@ -44,8 +44,8 @@ struct TraceResult {
...
@@ -44,8 +44,8 @@ struct TraceResult {
};
};
size_t
id
;
size_t
id
;
DType
dtype
;
DType
Value
::
ref_t
dtype
;
CompNode
device
;
CompNode
Value
::
ref_t
device
;
// if exists, assert equal when meet
// if exists, assert equal when meet
ValueRef
bound_data
;
ValueRef
bound_data
;
...
@@ -162,7 +162,7 @@ public:
...
@@ -162,7 +162,7 @@ public:
TypedValueRef
<
TracingValue
>
record_var
(
ValueRef
value
,
bool
capture
,
VarKind
kind
)
{
TypedValueRef
<
TracingValue
>
record_var
(
ValueRef
value
,
bool
capture
,
VarKind
kind
)
{
size_t
id
=
m_vars
.
size
();
size_t
id
=
m_vars
.
size
();
auto
wrapped_value
=
TracingValue
::
make
(
value
,
id
);
auto
wrapped_value
=
TracingValue
::
make
(
value
,
id
);
m_vars
.
push_back
({
id
,
*
value
.
dtype
(),
*
value
.
device
()});
m_vars
.
push_back
({
id
,
value
.
dtype
(),
value
.
device
()});
auto
&
var
=
m_vars
.
back
();
auto
&
var
=
m_vars
.
back
();
if
(
capture
)
{
if
(
capture
)
{
var
.
bound_data
=
value
;
var
.
bound_data
=
value
;
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录