Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
MegEngine 天元
MegEngine
提交
aeb7980b
MegEngine
项目概览
MegEngine 天元
/
MegEngine
1 年多 前同步成功
通知
410
Star
4707
Fork
583
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
MegEngine
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
提交
aeb7980b
编写于
12月 21, 2020
作者:
M
Megvii Engine Team
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
perf(mgb): outputs of the same opr and same compnode share the same callbackcaller
GitOrigin-RevId: 59b8e3bcbe0dd76f80f85bc1f46733364df769d1
上级
89b6dbc7
变更
3
隐藏空白更改
内联
并排
Showing
3 changed file
with
147 addition
and
28 deletion
+147
-28
src/core/impl/graph/cg_impl.cpp
src/core/impl/graph/cg_impl.cpp
+75
-28
src/core/impl/graph/cg_impl.h
src/core/impl/graph/cg_impl.h
+22
-0
src/core/test/graph/misc.cpp
src/core/test/graph/misc.cpp
+50
-0
未找到文件。
src/core/impl/graph/cg_impl.cpp
浏览文件 @
aeb7980b
...
...
@@ -148,13 +148,15 @@ size_t ComputingGraph::prealloc_static_storage(size_t size) {
/* ========================== CallbackCaller ========================== */
MGB_DEFINE_OPR_CLASS
(
ComputingGraphImpl
::
CallbackCaller
,
SingleCNOperatorNodeBase
)
// {
std
::
vector
<
ComputingGraph
::
Callback
>
m_cb
;
std
::
vector
<
std
::
vector
<
ComputingGraph
::
Callback
>
>
m_cb
;
void
scn_do_execute
()
override
{
auto
&&
dv
=
input
(
0
)
->
dev_tensor
();
for
(
auto
&&
i
:
m_cb
)
{
// const cast for backward API compatibility
i
(
const_cast
<
DeviceTensorND
&>
(
dv
));
for
(
size_t
i
=
0
;
i
<
input
().
size
();
++
i
)
{
auto
&&
in
=
input
(
i
)
->
dev_tensor
();
for
(
auto
&&
callback
:
m_cb
[
i
])
{
// const cast for backward API compatibility
callback
(
const_cast
<
DeviceTensorND
&>
(
in
));
}
}
}
...
...
@@ -168,14 +170,29 @@ MGB_DEFINE_OPR_CLASS(ComputingGraphImpl::CallbackCaller,
if
(
owner_graph
()
->
options
().
comp_node_seq_record_level
)
{
// the user callback usually copies from device to host, which
// involves tmp alloc if input is not contiguous
input
(
0
)
->
add_layout_constraint_contiguous
();
for
(
auto
&&
inp
:
input
())
{
inp
->
add_layout_constraint_contiguous
();
}
}
}
void
init_output_dtype
()
override
{
if
(
output
(
0
)
->
dtype
().
valid
())
{
return
;
}
mgb_assert
(
!
input
().
empty
());
DType
dtype
=
input
(
0
)
->
dtype
();
mgb_assert
(
dtype
.
valid
()
&&
dtype
!=
dtype
::
Byte
());
output
(
0
)
->
dtype
(
dtype
);
}
NodeProp
*
do_make_node_prop
()
const
override
{
auto
ret
=
Super
::
do_make_node_prop
();
ret
->
add_dep_type_existing_var
(
input
(
0
),
NodeProp
::
DepType
::
VALUE_ALLOW_EMPTY
);
for
(
auto
&&
inp
:
input
())
{
ret
->
add_dep_type_existing_var
(
inp
,
NodeProp
::
DepType
::
VALUE_ALLOW_EMPTY
);
}
return
ret
;
}
...
...
@@ -185,25 +202,38 @@ MGB_DEFINE_OPR_CLASS(ComputingGraphImpl::CallbackCaller,
}
public
:
CallbackCaller
(
VarNode
*
inp
)
:
Super
{
inp
->
owner_graph
(),
{},
"callback"
,
{
inp
}}
{
add_input
({
inp
});
CallbackCaller
(
const
VarNodeArrayView
&
inp
)
:
Super
{
inp
[
0
]
->
owner_graph
(),
{},
"callback"
,
inp
}
{
mgb_assert
(
!
inp
.
empty
());
m_cb
.
resize
(
inp
.
size
());
for
(
auto
&&
i
:
inp
)
{
add_input
({
i
});
}
using
F
=
VarNode
::
Flag
;
add_output
(
None
)
->
add_flag
(
F
::
ALLOW_EMPTY_SHAPE
)
.
add_flag
(
F
::
VOLATILE_CONTENT
);
}
static
SymbolVar
make
(
SymbolVar
inp
)
{
return
inp
.
insert_single_output_opr
<
CallbackCaller
>
(
inp
.
node
());
static
SymbolVar
make
(
const
VarNodeArrayView
&
inp
)
{
mgb_assert
(
!
inp
.
empty
());
return
SymbolVar
{
inp
[
0
]}
.
node
()
->
owner_graph
()
->
insert_opr
(
std
::
make_unique
<
CallbackCaller
>
(
inp
))
->
output
(
0
);
}
void
add_callback
(
const
ComputingGraph
::
Callback
&
cb
)
{
mgb_assert
(
cb
);
m_cb
.
push_back
(
cb
);
void
add_callback
(
const
ComputingGraph
::
Callback
&
cb
,
size_t
i
=
0
)
{
mgb_assert
(
cb
&&
i
<
m_cb
.
size
()
);
m_cb
[
i
]
.
push_back
(
cb
);
}
void
clear_callback
()
{
m_cb
.
clear
();
}
void
clear_callback
()
{
for
(
size_t
i
=
0
;
i
<
m_cb
.
size
();
++
i
)
{
m_cb
[
i
].
clear
();
}
}
}
;
MGB_DYN_TYPE_OBJ_FINAL_IMPL
(
ComputingGraphImpl
::
CallbackCaller
);
...
...
@@ -529,22 +559,39 @@ ComputingGraphImpl::CompileState ComputingGraphImpl::compile_prepare(
cmpnt
.
seq_comp_node_opt
.
optimize_comp_nodes
(
dest_vars
);
auto
init_opr_seq
=
[
&
]()
{
ThinHashMap
<
VarNode
*
,
CallbackCaller
*>
var2cb_caller
;
ThinHashMap
<
VarNode
*
,
size_t
>
var2idx
;
std
::
unordered_map
<
CallbackCallerKey
,
CallbackCallerVal
,
CallbackCallerKey
::
Hash
>
opr2vars
;
for
(
size_t
i
=
0
;
i
<
out_spec
.
size
();
++
i
)
{
auto
&&
cb
=
out_spec
[
i
].
second
;
if
(
cb
)
{
auto
var
=
dest_vars
[
i
];
auto
&&
cb_caller
=
var2cb_caller
[
var
];
if
(
!
cb_caller
)
{
auto
dvar
=
CallbackCaller
::
make
(
var
);
cb_caller
=
&
dvar
.
node
()
->
owner_opr
()
->
cast_final_safe
<
CallbackCaller
>
();
++
extra_info
.
var2recvinfo
[
dvar
.
node
()].
nr_direct_comp_req
;
cb_caller
->
clear_callback
();
CallbackCallerKey
key
{
var
->
owner_opr
(),
var
->
comp_node
()};
auto
&&
vals
=
opr2vars
[
key
];
auto
&&
var2idx_iter
=
var2idx
.
find
(
var
);
if
(
var2idx_iter
==
var2idx
.
end
())
{
vals
.
vars
.
push_back
(
var
);
vals
.
indexs
.
push_back
({
i
});
var2idx
[
var
]
=
vals
.
vars
.
size
()
-
1
;
}
else
{
vals
.
indexs
[
var2idx_iter
->
second
].
push_back
(
i
);
}
}
}
for
(
auto
&
item
:
opr2vars
)
{
auto
&&
val
=
item
.
second
;
auto
dvar
=
CallbackCaller
::
make
(
val
.
vars
);
CallbackCaller
*
cb_caller
=
&
dvar
.
node
()
->
owner_opr
()
->
cast_final_safe
<
CallbackCaller
>
();
++
extra_info
.
var2recvinfo
[
dvar
.
node
()].
nr_direct_comp_req
;
cb_caller
->
clear_callback
();
for
(
size_t
i
=
0
;
i
<
val
.
vars
.
size
();
++
i
)
{
for
(
auto
&&
idx
:
val
.
indexs
[
i
])
{
cb_caller
->
add_callback
(
out_spec
[
idx
].
second
,
i
);
dest_vars
[
idx
]
=
cb_caller
->
output
(
0
);
}
cb_caller
->
add_callback
(
cb
);
dest_vars
[
i
]
=
cb_caller
->
output
(
0
);
}
}
opr_seq
=
topo_sorter
().
get_comp_seq
(
extra_info
,
dest_vars
);
...
...
src/core/impl/graph/cg_impl.h
浏览文件 @
aeb7980b
...
...
@@ -40,6 +40,28 @@ class ComputingGraphImpl final : public ComputingGraph {
const
OprNodeArray
*
opr_seq
=
nullptr
;
};
struct
CallbackCallerKey
{
OperatorNodeBase
*
opr
;
CompNode
comp_node
;
bool
operator
==
(
const
CallbackCallerKey
&
rhs
)
const
{
return
opr
==
rhs
.
opr
&&
comp_node
==
rhs
.
comp_node
;
}
struct
Hash
{
size_t
operator
()(
const
CallbackCallerKey
&
b
)
const
{
return
hash_pair_combine
(
mgb
::
hash
(
b
.
opr
),
mgb
::
hash
(
b
.
comp_node
));
}
};
};
struct
CallbackCallerVal
{
SmallVector
<
VarNode
*>
vars
;
//! indexs of vars in out_spec.
SmallVector
<
SmallVector
<
size_t
>>
indexs
;
};
/*!
* Components for implementing algorithms on a computing graph.
*
...
...
src/core/test/graph/misc.cpp
浏览文件 @
aeb7980b
...
...
@@ -17,6 +17,7 @@
#include "megbrain/opr/tensor_manip.h"
#include "megbrain/opr/misc.h"
#include "megbrain/opr/indexing.h"
#include "megbrain/opr/tensor_manip.h"
#include "megbrain/graph/helper.h"
#include "megbrain/graph/grad_impl.h"
#include "megbrain/graph/event.h"
...
...
@@ -30,6 +31,7 @@
#include <atomic>
#include <chrono>
#include <array>
#include <memory>
using
namespace
mgb
;
...
...
@@ -2236,4 +2238,52 @@ TEST(TestGraph, FreeBias) {
}
}
TEST
(
TestGraph
,
CallbackCaller
)
{
using
namespace
opr
;
auto
cns
=
load_multiple_xpus
(
3
);
constexpr
size_t
C1
=
20
,
C2
=
30
,
C3
=
10
,
C4
=
40
;
constexpr
size_t
N
=
2
,
C
=
C1
+
C2
;
HostTensorGenerator
<>
gen
;
auto
host_opr0
=
gen
({
N
,
C
},
cns
[
0
]);
auto
graph
=
ComputingGraph
::
make
();
SymbolVar
opr0
=
opr
::
Host2DeviceCopy
::
make
(
*
graph
,
host_opr0
,
{
"opr0"
});
auto
spl0
=
opr
::
Split
::
make
(
opr0
,
Split
::
Options
::
make_partition
(
opr0
,
1
,
{
C1
,
C2
}),
OperatorNodeConfig
(
"split0"
).
comp_node_arr
({
cns
[
1
],
cns
[
2
]}));
auto
spl1
=
opr
::
Split
::
make
(
opr0
,
Split
::
Options
::
make_partition
(
opr0
,
1
,
{
C3
,
C4
}),
OperatorNodeConfig
(
"split1"
));
HostTensorND
host_spl00
,
host_spl01
,
host_spl10
,
host_spl11
;
auto
func
=
graph
->
compile
({
make_callback_copy
(
spl0
[
0
],
host_spl00
),
make_callback_copy
(
spl0
[
1
],
host_spl01
),
make_callback_copy
(
spl1
[
0
],
host_spl10
),
make_callback_copy
(
spl1
[
1
],
host_spl11
)});
func
->
execute
();
auto
o00
=
host_spl00
.
ptr
<
float
>
(),
o01
=
host_spl01
.
ptr
<
float
>
(),
o10
=
host_spl10
.
ptr
<
float
>
(),
o11
=
host_spl11
.
ptr
<
float
>
(),
c
=
host_opr0
->
ptr
<
float
>
();
for
(
size_t
i
=
0
,
it
=
host_opr0
->
layout
().
total_nr_elems
();
i
<
it
;
i
++
)
{
auto
ch
=
i
%
C
;
auto
n
=
i
/
C
;
if
(
ch
<
C1
)
{
MGB_ASSERT_FLOAT_EQ
(
o00
[
n
*
C1
+
ch
],
c
[
i
])
<<
ssprintf
(
"failed at %zd"
,
i
);
}
else
{
MGB_ASSERT_FLOAT_EQ
(
o01
[
n
*
C2
+
ch
-
C1
],
c
[
i
])
<<
ssprintf
(
"failed at %zd"
,
i
);
}
if
(
ch
<
C3
)
{
MGB_ASSERT_FLOAT_EQ
(
o10
[
n
*
C3
+
ch
],
c
[
i
])
<<
ssprintf
(
"failed at %zd"
,
i
);
}
else
{
MGB_ASSERT_FLOAT_EQ
(
o11
[
n
*
C4
+
ch
-
C3
],
c
[
i
])
<<
ssprintf
(
"failed at %zd"
,
i
);
}
}
}
// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录