Commit a5a60679

feat(imperative/interpreter): add more dispatch mode in apply_op

GitOrigin-RevId: 2663504470e6cf83a4ce5d84131f0cbd2f39716e

Authored Dec 22, 2020 by Megvii Engine Team
Parent: 45e20602
Showing 8 changed files with 287 additions and 76 deletions.
- imperative/src/impl/interpreter_impl.cpp (+128 −57)
- imperative/src/impl/interpreter_impl.h (+11 −0)
- imperative/src/impl/op_def.cpp (+14 −0)
- imperative/src/impl/op_trait.cpp (+14 −1)
- imperative/src/impl/op_trait.h (+8 −0)
- imperative/src/impl/ops/elemwise.cpp (+33 −8)
- imperative/src/impl/ops/tensor_manip.cpp (+53 −10)
- imperative/src/include/megbrain/imperative/op_def.h (+26 −0)
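In outline, the commit splits `ChannelImpl::apply_op` into two dispatch paths, selected per op by the new `OpDef::decide_dispatch_mode`: ops whose inputs all carry a host value and are small (shape- or scalar-sized) take the new `DEFAULT_CPU` path and are computed eagerly on the host, while everything else keeps the existing asynchronous `KERNEL` path. A minimal sketch of the size/host-value heuristic, using simplified stand-in types rather than MegEngine's real `LogicalTensorDesc`:

```cpp
#include <cstddef>
#include <vector>

// Illustrative stand-ins only; the real decision is made per op through
// OpTrait::decide_dispatch_mode (see op_def.cpp / op_trait.cpp below).
enum DispatchMode { DEFAULT_CPU = 0, KERNEL = 1 };

struct FakeDesc {
    bool has_host_value;   // is a host copy of the data already available?
    std::size_t nr_elems;  // total number of elements
};

// Mirrors the heuristic registered for elemwise: stay on the host only when
// every input has a host value and is small (the commit uses
// TensorShape::MAX_NDIM as the size threshold).
DispatchMode decide(const std::vector<FakeDesc>& inputs, std::size_t threshold) {
    for (auto&& inp : inputs) {
        if (!inp.has_host_value || inp.nr_elems == 0 || inp.nr_elems > threshold)
            return KERNEL;
    }
    return DEFAULT_CPU;
}
```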
imperative/src/impl/interpreter_impl.cpp

```diff
@@ -29,7 +29,7 @@ Interpreter& Interpreter::inst() {
     return inst_;
 }
 
-void* ChannelImpl::put(const HostTensorND& value, bool no_cache) {
+Handle ChannelImpl::put(const HostTensorND& value, bool no_cache) {
     auto info = alloc();
     info->desc.layout = value.layout();
     info->desc.comp_node = value.comp_node();
@@ -39,7 +39,7 @@ void* ChannelImpl::put(const HostTensorND& value, bool no_cache) {
     return info;
 }
 
-void* ChannelImpl::put(const DeviceTensorND& data) {
+Handle ChannelImpl::put(const DeviceTensorND& data) {
     auto info = alloc();
     info->desc.layout = data.layout();
     info->desc.comp_node = data.comp_node();
@@ -48,12 +48,12 @@ void* ChannelImpl::put(const DeviceTensorND& data) {
     return info;
 }
 
-void ChannelImpl::del(void* handle) {
+void ChannelImpl::del(Handle handle) {
     mgb_assert(m_valid_handle.erase(handle), "invalid handle: %p", handle);
     m_buffer.enqueue(Del{reinterpret_cast<TensorInfo*>(handle)});
 }
 
-void ChannelImpl::swap_in(void* handle) {
+void ChannelImpl::swap_in(Handle handle) {
     if (m_enable_evict & SWAP) {
         mgb_assert(m_valid_handle.find(handle) != m_valid_handle.end(),
                 "invalid handle: %p", handle);
@@ -61,7 +61,7 @@ void ChannelImpl::swap_in(void* handle) {
     }
 }
 
-void ChannelImpl::swap_out(void* handle) {
+void ChannelImpl::swap_out(Handle handle) {
     if (m_enable_evict & SWAP) {
         mgb_assert(m_valid_handle.find(handle) != m_valid_handle.end(),
                 "invalid handle: %p", handle);
@@ -69,7 +69,7 @@ void ChannelImpl::swap_out(void* handle) {
     }
 }
 
-void ChannelImpl::drop(void* handle) {
+void ChannelImpl::drop(Handle handle) {
     if (m_enable_evict & DROP) {
         mgb_assert(m_valid_handle.find(handle) != m_valid_handle.end(),
                 "invalid handle: %p", handle);
@@ -77,45 +77,91 @@ void ChannelImpl::drop(void* handle) {
     }
 }
 
-SmallVector<void*> ChannelImpl::apply_op(
+void ChannelImpl::dispatch_default_cpu(
         std::shared_ptr<OpDef> op,
-        const SmallVector<void*>& inputs) {
-    for (auto i : inputs) {
-        mgb_assert(m_valid_handle.find(i) != m_valid_handle.end(),
-                "invalid handle: %p", i);
-    }
-    SmallVector<TensorInfo*> input_infos;
-    input_infos.reserve(inputs.size());
-    SmallVector<LogicalTensorDesc> input_descs;
-    input_descs.reserve(inputs.size());
+        const SmallVector<TensorInfo*>& input_infos,
+        const SmallVector<LogicalTensorDesc>& input_descs,
+        SmallVector<Handle>* outputs) {
+    auto [output_descs, validated] = OpDef::infer_output_attrs_fallible(*op, input_descs);
+    SmallVector<DeviceTensorND> input_tensornds;
+    input_tensornds.reserve(input_descs.size());
+    CompNode output_cn;
     {
         MGB_LOCK_GUARD(m_mutex);
-        for (auto i : inputs) {
-            auto info = reinterpret_cast<TensorInfo*>(i);
-            mgb_assert(!info->invalid, "Invalid tensor, unable to apply_op!");
-            input_infos.push_back(info);
-            input_descs.push_back(info->desc);
+        for (auto&& info : input_infos) {
+            mgb_assert(info->ptr, "invalid tensor ptr!");
+            if (!output_cn.valid()) {
+                output_cn = info->ptr->comp_node();
+            } else {
+                mgb_assert(output_cn == info->ptr->comp_node(), "cannot decide output comp node");
+            }
+            mgb_assert(info->ptr->try_get_value(), "no valid host value");
+            input_tensornds.emplace_back(info->ptr->get_value().proxy_to_default_cpu());
         }
     }
 
+    outputs->reserve(output_descs.size());
+    SmallVector<DeviceTensorND> output_tensornds;
+    output_tensornds.reserve(output_descs.size());
+    for (auto&& desc : output_descs) {
+        // TODO: may conflict with condtake, which need alloc inside
+        mgb_assert(!desc.layout.is_empty());
+        // use HostTensorND alloc_host for cuda pinned memory
+        output_tensornds.emplace_back(HostTensorND(output_cn, desc.layout).proxy_to_default_cpu());
+    }
+
+    OpDef::apply_on_device_tensornd(*op, input_tensornds, &output_tensornds);
+
+    SmallVector<TensorInfo*> output_infos;
+    output_infos.reserve(output_descs.size());
+    for (auto&& tensornd : output_tensornds) {
+        // tensornd -> host_tensornd
+        HostTensorND host_tensornd = HostTensorND::make_proxy(tensornd)
+            .proxy_to_comp_node(output_cn);
+        // tensornd -> desc
+        LogicalTensorDesc desc = {tensornd.layout(), output_cn, tensornd};
+        // tensornd -> tensor
+        auto info = alloc();
+        info->desc = desc;
+        m_valid_handle.insert(info);
+        output_infos.push_back(info);
+        info->ptr = Tensor::make(host_tensornd, true); // host_only=true
+        info->value_fetched = true;
+        outputs->push_back(info);
+    }
+
+    if (m_enable_evict & DROP) {
+        for (auto out : output_infos) {
+            out->path.op = op;
+            for (auto out_ : output_infos) {
+                out->path.outputs.push_back(m_st.at(out_));
+            }
+            for (auto inp : input_infos) {
+                out->path.inputs.push_back(m_st.at(inp));
+                inp->path.dep_outputs.push_back(m_st.at(out));
+            }
+        }
+    }
+}
+
+void ChannelImpl::dispatch_kernel(
+        std::shared_ptr<OpDef> op,
+        const SmallVector<TensorInfo*>& input_infos,
+        const SmallVector<LogicalTensorDesc>& input_descs,
+        SmallVector<Handle>* outputs) {
     auto [output_descs, validated] = OpDef::infer_output_attrs_fallible(*op, input_descs);
 
     ApplyOp cmd{std::move(op)};
     cmd.inputs = std::move(input_infos);
     cmd.outputs.reserve(output_descs.size());
-    SmallVector<void*> outputs;
-    // FIXME: remove this check when op check is correct
-    bool validated_bkp = true;
-    for (size_t i = 0; i < output_descs.size(); i++) {
-        auto&& desc = output_descs[i];
-        if (desc.layout.ndim == 0) {
-            validated_bkp = false;
-        }
+    outputs->reserve(output_descs.size());
+    for (auto&& desc : output_descs) {
         auto info = alloc();
         info->desc = desc;
         m_valid_handle.insert(info);
         cmd.outputs.push_back(info);
-        outputs.push_back(info);
+        outputs->push_back(info);
     }
     if (m_enable_evict & DROP) {
         for (auto out : cmd.outputs) {
@@ -130,20 +176,55 @@ SmallVector<void*> ChannelImpl::apply_op(
         }
     }
     m_buffer.enqueue(std::move(cmd));
-    if (!(validated && validated_bkp) && m_async_level == 1) {
+    if (!validated && m_async_level == 1) {
         sync();
     } else if (m_async_level == 0) {
         sync();
         // check device error
-        for (auto&& oup : outputs) {
+        for (auto&& oup : *outputs) {
             auto info = reinterpret_cast<TensorInfo*>(oup);
             info->ptr->comp_node().sync();
         }
     }
 }
 
+SmallVector<Handle> ChannelImpl::apply_op(
+        std::shared_ptr<OpDef> op,
+        const SmallVector<Handle>& inputs) {
+    for (auto i : inputs) {
+        mgb_assert(m_valid_handle.find(i) != m_valid_handle.end(),
+                "invalid handle: %p", i);
+    }
+    SmallVector<TensorInfo*> input_infos;
+    input_infos.reserve(inputs.size());
+    SmallVector<LogicalTensorDesc> input_descs;
+    input_descs.reserve(inputs.size());
+    {
+        MGB_LOCK_GUARD(m_mutex);
+        for (auto i : inputs) {
+            auto info = reinterpret_cast<TensorInfo*>(i);
+            mgb_assert(!info->invalid, "Invalid tensor, unable to apply_op!");
+            input_infos.push_back(info);
+            input_descs.push_back(info->desc);
+        }
+    }
+
+    SmallVector<Handle> outputs;
+    switch (OpDef::decide_dispatch_mode(*op, input_descs)) {
+        case DEFAULT_CPU: {
+            dispatch_default_cpu(op, input_infos, input_descs, &outputs);
+            break;
+        }
+        case KERNEL: {
+            dispatch_kernel(op, input_infos, input_descs, &outputs);
+            break;
+        }
+    }
+    mgb_assert(outputs.size() > 0, "Invalid dispatch mode!");
+    return outputs;
+}
 
-HostTensorND ChannelImpl::get_value(void* handle) {
+HostTensorND ChannelImpl::get_value(Handle handle) {
     mgb_assert(m_valid_handle.find(handle) != m_valid_handle.end(),
             "invalid handle: %p", handle);
     auto info = reinterpret_cast<TensorInfo*>(handle);
@@ -163,7 +244,7 @@ HostTensorND ChannelImpl::get_value(void* handle) {
     return info->ptr->get_value();
 }
 
-TensorShape ChannelImpl::get_shape(void* handle) {
+TensorShape ChannelImpl::get_shape(Handle handle) {
     mgb_assert(m_valid_handle.find(handle) != m_valid_handle.end(),
             "invalid handle: %p", handle);
     auto info = reinterpret_cast<TensorInfo*>(handle);
@@ -184,7 +265,7 @@ TensorShape ChannelImpl::get_shape(void* handle) {
     return ret;
 }
 
-DType ChannelImpl::get_dtype(void* handle) {
+DType ChannelImpl::get_dtype(Handle handle) {
     mgb_assert(m_valid_handle.find(handle) != m_valid_handle.end(),
             "invalid handle: %p", handle);
    auto info = reinterpret_cast<TensorInfo*>(handle);
@@ -193,7 +274,7 @@ DType ChannelImpl::get_dtype(void* handle) {
     return ret;
 }
 
-CompNode ChannelImpl::get_device(void* handle) {
+CompNode ChannelImpl::get_device(Handle handle) {
     mgb_assert(m_valid_handle.find(handle) != m_valid_handle.end(),
             "invalid handle: %p", handle);
     auto info = reinterpret_cast<TensorInfo*>(handle);
@@ -202,7 +283,7 @@ CompNode ChannelImpl::get_device(void* handle) {
     return ret;
 }
 
-DeviceTensorND ChannelImpl::get_dev_tensor(void* handle) {
+DeviceTensorND ChannelImpl::get_dev_tensor(Handle handle) {
     mgb_assert(m_valid_handle.find(handle) != m_valid_handle.end(),
             "invalid handle: %p", handle);
     auto info = reinterpret_cast<TensorInfo*>(handle);
@@ -262,25 +343,15 @@ ChannelImpl::~ChannelImpl() {
 }
 
 void ChannelImpl::produce_tensor(TensorInfo* dest, TensorPtr ptr, bool notice = true) {
-    if (notice) {
-        MGB_LOCK_GUARD(m_mutex);
-        dest->value_fetched = ptr->value_fetched();
-        // update tensor desc for static infer
-        // if (dest->desc.layout.ndim) {
-        //     mgb_assert(dest->desc.layout.eq_shape(ptr->layout()));
-        // }
-        dest->desc.layout = ptr->layout();
-        dest->desc.comp_node = ptr->comp_node();
-        dest->ptr = std::move(ptr);
-        if (m_waitee == dest) {
-            m_cv.notify_all();
-        }
-    } else {
-        dest->value_fetched = ptr->value_fetched();
-        // update tensor desc for static infer
-        dest->desc.layout = ptr->layout();
-        dest->desc.comp_node = ptr->comp_node();
-        dest->ptr = std::move(ptr);
+    auto lock = notice ? std::unique_lock<std::mutex>(m_mutex)
+                       : std::unique_lock<std::mutex>();
+    dest->value_fetched = ptr->value_fetched();
+    // update tensor desc for static infer
+    dest->desc.layout = ptr->layout();
+    dest->desc.comp_node = ptr->comp_node();
+    dest->ptr = std::move(ptr);
+    if (notice && m_waitee == dest) {
+        m_cv.notify_all();
     }
 }
@@ -295,7 +366,7 @@ void ChannelImpl::do_swap_out(TensorInfo* dest) {
     dest->evict_type = SWAP;
     dest->value_fetched = false;
     // TODO: swap in parallel
-    dest->h_value.copy_from(dest->ptr->dev_tensor()).sync();
+    dest->h_value = dest->ptr->get_value();
     dest->ptr.reset();
 }
```
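From the caller's side, the effect of the split above is that the `DEFAULT_CPU` path yields a ready host value immediately, while the `KERNEL` path only enqueues a command that a later `sync()` drains. A compilable toy of that behavioral difference (illustrative names only, not the real `ChannelImpl` API):

```cpp
#include <functional>
#include <iostream>
#include <vector>

struct ToyChannel {
    // Eager path: compute immediately on the calling thread.
    int dispatch_default_cpu(int a, int b) { return a + b; }

    // Async path: enqueue work for a worker to execute later.
    std::vector<std::function<void()>> queue;
    void dispatch_kernel(int a, int b, int* out) {
        queue.push_back([=] { *out = a + b; });
    }
    void sync() {
        for (auto&& task : queue) task();
        queue.clear();
    }
};

int main() {
    ToyChannel chan;
    int eager = chan.dispatch_default_cpu(1, 2);  // ready at once
    int deferred = 0;
    chan.dispatch_kernel(3, 4, &deferred);        // only enqueued
    chan.sync();                                  // now materialized
    std::cout << eager << ' ' << deferred << '\n';  // prints: 3 7
}
```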
imperative/src/impl/interpreter_impl.h

```diff
@@ -198,6 +198,17 @@ private:
     void do_drop(TensorInfo* dest);
     void regenerate(TensorInfo* dest, bool must_drop);
 
+    void dispatch_default_cpu(
+            std::shared_ptr<OpDef> op,
+            const SmallVector<TensorInfo*>& input_infos,
+            const SmallVector<LogicalTensorDesc>& input_descs,
+            SmallVector<Handle>* outputs);
+    void dispatch_kernel(
+            std::shared_ptr<OpDef> op,
+            const SmallVector<TensorInfo*>& input_infos,
+            const SmallVector<LogicalTensorDesc>& input_descs,
+            SmallVector<Handle>* outputs);
+
     std::mutex m_mutex;
     std::condition_variable m_cv;
     MemPool<TensorInfo> m_pool;
```
imperative/src/impl/op_def.cpp

```diff
@@ -30,12 +30,26 @@ std::shared_ptr<OpDef> OpDef::make_from_op_node(
     return trait->make_from_op_node(node);
 }
 
+DispatchMode OpDef::decide_dispatch_mode(
+    const OpDef& def,
+    const SmallVector<LogicalTensorDesc>& inputs) {
+    return def.trait()->decide_dispatch_mode(def, inputs);
+}
+
 SmallVector<TensorPtr> OpDef::apply_on_physical_tensor(
     const OpDef& def,
     SmallVector<TensorPtr> inputs) {
     return def.trait()->apply_on_physical_tensor(def, std::move(inputs));
 }
 
+void OpDef::apply_on_device_tensornd(
+    const OpDef& def,
+    const SmallVector<DeviceTensorND>& inputs,
+    SmallVector<DeviceTensorND>* outputs) {
+    def.trait()->apply_on_device_tensornd(def, inputs, outputs);
+    return;
+}
+
 VarNodeArray OpDef::apply_on_var_node(
     const OpDef& def,
     const VarNodeArray& inputs) {
```
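`op_def.cpp` shows the delegation pattern used throughout the imperative runtime: each static `OpDef` entry point is a one-line forward into the op's `OpTrait` function table, and `op_trait.cpp` (next) fills any empty slot with a fallback; for `decide_dispatch_mode` the fallback conservatively returns `KERNEL`. A self-contained miniature of the pattern, with hypothetical names:

```cpp
#include <cstdio>

// Hypothetical miniature of the OpTrait function-table pattern (not the real API).
struct MiniOp;
enum MiniDispatchMode { MINI_DEFAULT_CPU, MINI_KERNEL };

struct MiniTrait {
    // Slot filled by the op's registration, or by a fallback if left empty.
    MiniDispatchMode (*decide_dispatch_mode)(const MiniOp&);
};

struct MiniOp {
    const MiniTrait* trait;
};

// Static entry point: a thin forward into the trait table, the same shape as
// OpDef::decide_dispatch_mode above.
MiniDispatchMode decide_dispatch_mode(const MiniOp& op) {
    return op.trait->decide_dispatch_mode(op);
}

// Fallback used when an op registers no heuristic: stay on the kernel path.
MiniDispatchMode fallback_decide(const MiniOp&) { return MINI_KERNEL; }

int main() {
    MiniTrait trait{nullptr};
    if (!trait.decide_dispatch_mode) trait.decide_dispatch_mode = fallback_decide;
    MiniOp op{&trait};
    std::printf("%d\n", decide_dispatch_mode(op));  // 1 (MINI_KERNEL)
}
```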
imperative/src/impl/op_trait.cpp

```diff
@@ -9,12 +9,16 @@
  * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  */
 
 #include <exception>
 #include <sstream>
 #include <stdexcept>
 
+#include "megbrain/imperative/op_def.h"
+#include "megbrain/imperative/ops/opr_attr.h"
+#include "megbrain/imperative/proxy_graph_detail.h"
+#include "megbrain/tensor.h"
+
 #include "./op_trait.h"
-#include "megbrain/imperative/proxy_graph_detail.h"
 
 namespace mgb {
 namespace imperative {
@@ -62,6 +66,12 @@ void OpTrait::for_each_trait(thin_function<void(OpTrait&)> visitor){
     }
 }
 
+DispatchMode fallback_decide_dispatch_mode(
+    const OpDef& def,
+    const SmallVector<LogicalTensorDesc>& inputs) {
+    return KERNEL;
+}
+
 OpTraitRegistry& OpTraitRegistry::fallback() {
     if (trait->apply_on_var_node) {
         // fallback to proxy graph impl
@@ -78,6 +88,9 @@ OpTraitRegistry& OpTraitRegistry::fallback() {
                 proxy_graph_detail::make_backward_graph;
         }
     }
+    if (!trait->decide_dispatch_mode) {
+        trait->decide_dispatch_mode = fallback_decide_dispatch_mode;
+    }
     return *this;
 }
```
imperative/src/impl/op_trait.h

```diff
@@ -60,8 +60,12 @@ struct ToVarNodeArray<cg::OperatorNodeBase*>: std::true_type {
 using OpDefMaker = detail::OpMeth<
         decltype(OpDef::make_from_op_node)>;
+using DecideDispatchMode = detail::OpMeth<
+        decltype(OpDef::decide_dispatch_mode)>;
 using ApplyOnPhysicalTensor = detail::OpMeth<
         decltype(OpDef::apply_on_physical_tensor)>;
+using ApplyOnDeviceTensorND = detail::OpMeth<
+        decltype(OpDef::apply_on_device_tensornd)>;
 using ApplyOnVarNode = detail::OpMeth<
         decltype(OpDef::apply_on_var_node)>;
 using InferOutputAttrsFallible = detail::OpMeth<
@@ -74,7 +78,9 @@ using IsSame = detail::OpMeth<bool(const OpDef&, const OpDef&)>;
 struct OpTrait {
     const char* name;
     OpDefMaker make_from_op_node;
+    DecideDispatchMode decide_dispatch_mode;
     ApplyOnPhysicalTensor apply_on_physical_tensor;
+    ApplyOnDeviceTensorND apply_on_device_tensornd;
     ApplyOnVarNode apply_on_var_node;
     InferOutputAttrsFallible infer_output_attrs_fallible;
     GradMaker make_backward_graph;
@@ -88,7 +94,9 @@ struct OpTrait {
 #define FOR_EACH_OP_METH(cb) \
     cb(make_from_op_node) \
+    cb(decide_dispatch_mode) \
     cb(apply_on_physical_tensor) \
+    cb(apply_on_device_tensornd) \
     cb(apply_on_var_node) \
     cb(infer_output_attrs_fallible) \
     cb(make_backward_graph) \
```
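The trait slots above are typed with `detail::OpMeth<decltype(OpDef::...)>`, deriving each slot's signature directly from the corresponding `OpDef` static method so the two cannot drift apart. A sketch of the same `decltype` trick in plain standard C++ (simplified; the real `OpMeth` adds more machinery):

```cpp
#include <type_traits>

struct OpDefLike {
    // Hypothetical signature standing in for an OpDef static method.
    static int decide(const char* name, int n);
};

// Derive a function-pointer slot type from the function's own type,
// as OpMeth<decltype(OpDef::decide_dispatch_mode)> does.
template <typename Sig>
using Slot = std::add_pointer_t<Sig>;

using DecideSlot = Slot<decltype(OpDefLike::decide)>;  // int (*)(const char*, int)

static_assert(std::is_same_v<DecideSlot, int (*)(const char*, int)>,
              "slot type tracks the method signature");
```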
imperative/src/impl/ops/elemwise.cpp

```diff
@@ -68,23 +68,46 @@ std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible(
     return {{{TensorLayout(out_shape, out_dt, inputs[0].layout.format), out_cn}}, true};
 }
 
-SmallVector<TensorPtr> apply_on_physical_tensor(
+DispatchMode decide_dispatch_mode(
         const OpDef& def,
-        const SmallVector<TensorPtr>& inputs) {
+        const SmallVector<LogicalTensorDesc>& inputs) {
+    bool host_computable = true;
+    constexpr int size_threshhold = TensorShape::MAX_NDIM;
+    for (auto&& inp : inputs) {
+        if (inp.value.empty() || inp.value.layout().ndim == 0
+                || inp.value.layout().total_nr_elems() > size_threshhold) {
+            host_computable = false;
+            break;
+        }
+    }
+    return host_computable ? DEFAULT_CPU : KERNEL;
+}
+
+void apply_on_device_tensornd(
+        const OpDef& def,
+        const SmallVector<DeviceTensorND>& inputs,
+        SmallVector<DeviceTensorND>* outputs) {
     auto&& op_def = def.cast_final_safe<Elemwise>();
+    auto trait = megdnn::Elemwise::ModeTrait::from_mode(op_def.mode);
+    mgb_assert(inputs.size() == trait.arity,
+               "%s expects %u inputs; got %zu actually",
+               trait.name, trait.arity, inputs.size());
+    auto&& dnn_opr = opr::intl::create_megdnn_opr<megdnn::Elemwise>(inputs[0].comp_node());
+    opr::Elemwise::perform(op_def.mode, (*outputs)[0], inputs, dnn_opr);
+}
+
+SmallVector<TensorPtr> apply_on_physical_tensor(
+        const OpDef& def,
+        const SmallVector<TensorPtr>& inputs) {
-    DeviceTensorND out;
-    SmallVector<DeviceTensorND> dt_inputs(inputs.size());
+    SmallVector<DeviceTensorND> inp_tensornds(inputs.size());
     for (unsigned i = 0; i < inputs.size(); ++i){
-        dt_inputs[i] = inputs[i]->dev_tensor();
+        inp_tensornds[i] = inputs[i]->dev_tensor();
     }
-    auto&& dnn_opr = opr::intl::create_megdnn_opr<megdnn::Elemwise>(inputs[0]->comp_node());
-    opr::Elemwise::perform(op_def.mode, out, dt_inputs, dnn_opr);
-    return {Tensor::make(out)};
+    SmallVector<DeviceTensorND> oup_tensornds = {{inp_tensornds[0].comp_node(), inp_tensornds[0].dtype()}};
+    apply_on_device_tensornd(def, inp_tensornds, &oup_tensornds);
+    return {Tensor::make(oup_tensornds[0])};
 }
 
 MGB_DEFINE_OPR_CLASS(ForceInplaceElemwise, cg::SingleCNOperatorNodeBaseT<opr::mixin::MegDNNOprHolder>) //{
@@ -214,8 +237,10 @@ std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_inplace_add_output_attrs_
 OP_TRAIT_REG(Elemwise, Elemwise, opr::Elemwise)
     .make_from_op_node(make_from_op_node)
+    .decide_dispatch_mode(decide_dispatch_mode)
     .apply_on_var_node(apply_on_var_node)
     .infer_output_attrs_fallible(infer_output_attrs_fallible)
+    .apply_on_device_tensornd(apply_on_device_tensornd)
     .apply_on_physical_tensor(apply_on_physical_tensor)
     .fallback();
```
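The elemwise change is a refactor toward sharing one execution routine between both dispatch paths: `apply_on_physical_tensor` now just allocates the output container and defers the arithmetic to `apply_on_device_tensornd`, which `dispatch_default_cpu` can also call on host-proxied tensors. A simplified sketch of that shape (stand-in types and a toy add kernel, not MegEngine's API):

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Stand-in "tensor": just a flat float buffer.
using Tensor = std::vector<float>;

// Shared execution routine: both dispatch paths funnel through here,
// with outputs allocated by the caller (toy elemwise add).
void apply_on_device_tensornd(const std::vector<Tensor>& in,
                              std::vector<Tensor>* out) {
    assert(in.size() == 2 && in[0].size() == in[1].size());
    (*out)[0].resize(in[0].size());
    for (std::size_t i = 0; i < in[0].size(); ++i)
        (*out)[0][i] = in[0][i] + in[1][i];
}

// The kernel-path wrapper: allocate the output slot, then delegate.
std::vector<Tensor> apply_on_physical_tensor(const std::vector<Tensor>& in) {
    std::vector<Tensor> out(1);          // allocate output container first
    apply_on_device_tensornd(in, &out);  // shared execution routine
    return out;
}
```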
imperative/src/impl/ops/tensor_manip.cpp

```diff
@@ -15,8 +15,8 @@
 #include "../op_trait.h"
 
 namespace mgb::imperative {
-namespace {
+namespace { namespace get_var_shape {
 
 cg::OperatorNodeBase* apply_on_var_node(
         const OpDef& def,
         const VarNodeArray& inputs) {
@@ -24,17 +24,38 @@ cg::OperatorNodeBase* apply_on_var_node(
     return opr::GetVarShape::make(inputs, op_def.param()).node()->owner_opr();
 }
 
-SmallVector<TensorPtr> apply_on_physical_tensor(
+DispatchMode decide_dispatch_mode(
         const OpDef& def,
-        const SmallVector<TensorPtr>& inputs) {
+        const SmallVector<LogicalTensorDesc>& inputs) {
+    bool host_computable = true;
+    for (auto&& inp : inputs) {
+        // FIXME(czh): remove value chech after proxy graph's
+        // apply_on_device_tensornd is supported and output Tensor
+        // is made before add_task.
+        // then if layout is valid, ptr->layout must be ready
+        if (inp.value.empty() || inp.value.layout().ndim == 0) {
+            host_computable = false;
+            break;
+        }
+    }
+    return host_computable ? DEFAULT_CPU : KERNEL;
+}
+
+void apply_on_device_tensornd(
+        const OpDef& def,
+        const SmallVector<DeviceTensorND>& inputs,
+        SmallVector<DeviceTensorND>* outputs) {
     auto&& op_def = def.cast_final_safe<GetVarShape>();
     mgb_assert(inputs.size() == 1, "GetVarShape take 1 input, got %lu", inputs.size());
     auto&& inp = inputs[0];
-    auto&& shp = inp->layout();
+    auto&& shp = inp.layout();
     mgb_assert(shp.ndim != 0, "input shape invalid");
+    mgb_assert((*outputs)[0].comp_node() == CompNode::default_cpu(),
+        "GetVarShape's apply_on_device_tensornd should receive default_cpu outputs.");
 
     HostTensorND hv;
-    if (op_def.axis == opr::GetVarShape::Param::INVALID_AXIS){
-        hv = HostTensorND(inp->comp_node(), {shp.ndim}, dtype::Int32());
+    if (op_def.axis == opr::GetVarShape::Param::INVALID_AXIS) {
+        hv = HostTensorND(CompNode::default_cpu(), {shp.ndim}, dtype::Int32());
         auto* ptr = hv.ptr<dt_int32>();
         for (size_t i = 0; i < shp.ndim; ++i) {
             ptr[i] = shp.shape[i];
@@ -45,11 +66,29 @@ SmallVector<TensorPtr> apply_on_physical_tensor(
             axis += shp.ndim;
         }
         mgb_assert(axis >= 0 && axis < (int32_t)shp.ndim);
-        hv = HostTensorND(inp->comp_node(), {1}, dtype::Int32());
+        hv = HostTensorND(CompNode::default_cpu(), {1}, dtype::Int32());
         auto* ptr = hv.ptr<dt_int32>();
         ptr[0] = shp.shape[axis];
     }
-    return {Tensor::make(std::move(hv))};
+    (*outputs)[0] = DeviceTensorND::make_proxy(hv);
+}
+
+SmallVector<TensorPtr> apply_on_physical_tensor(
+        const OpDef& def,
+        const SmallVector<TensorPtr>& inputs) {
+    SmallVector<DeviceTensorND> input_tensornds;
+    input_tensornds.reserve(inputs.size());
+    for (auto&& inp : inputs) {
+        input_tensornds.push_back(inp->dev_tensor());
+    }
+    SmallVector<DeviceTensorND> output_tensornds = {{CompNode::default_cpu(), dtype::Int32()}};
+
+    apply_on_device_tensornd(def, input_tensornds, &output_tensornds);
+
+    // restore to input comp_node
+    HostTensorND host_tensornd = HostTensorND::make_proxy(output_tensornds[0])
+        .proxy_to_comp_node(inputs[0]->comp_node());
+
+    return {Tensor::make(std::move(host_tensornd))};
 }
 
 std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible(
@@ -62,7 +101,7 @@ std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible(
         return {{{TensorLayout(dtype::Int32()), desc.comp_node}}, false};
     }
     DeviceTensorND value;
-    if (op_def.axis == opr::GetVarShape::Param::INVALID_AXIS){
+    if (op_def.axis == opr::GetVarShape::Param::INVALID_AXIS) {
         value = DeviceTensorND(CompNode::default_cpu(), {desc.layout.ndim}, dtype::Int32());
         auto* ptr = value.ptr<dt_int32>();
         for (size_t i = 0; i < desc.layout.ndim; ++i) {
@@ -88,11 +127,15 @@ std::shared_ptr<OpDef> make_from_op_node(cg::OperatorNodeBase* node_) {
 OP_TRAIT_REG(GetVarShape, GetVarShape, opr::GetVarShape)
     .make_from_op_node(make_from_op_node)
+    .decide_dispatch_mode(decide_dispatch_mode)
     .infer_output_attrs_fallible(infer_output_attrs_fallible)
     .apply_on_var_node(apply_on_var_node)
+    .apply_on_device_tensornd(apply_on_device_tensornd)
     .apply_on_physical_tensor(apply_on_physical_tensor)
     .fallback();
+} // get_var_shape
 
 namespace param_pack {
 TensorShapeArray get_shapes(const std::vector<std::vector<size_t>>& shapes) {
     TensorShapeArray ret;
     for (auto&& i : shapes) {
@@ -156,6 +199,6 @@ cg::OperatorNodeBase* param_pack_concat_apply_on_var_node(
 OP_TRAIT_REG(ParamPackConcat, ParamPackConcat, mgb::opr::ParamPackConcat)
     .apply_on_var_node(param_pack_concat_apply_on_var_node)
     .fallback();
-} // namespace
+} // param_pack
 } // namespace mgb::imperative
```
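`GetVarShape` is the motivating case for the `DEFAULT_CPU` path: its result is just the input's shape metadata written into a tiny Int32 tensor, so no device work is needed at all. A self-contained sketch of the computation (simplified; the real op encodes the whole-shape case with `INVALID_AXIS` and supports negative axes as shown in the diff):

```cpp
#include <cstdint>
#include <iostream>
#include <optional>
#include <vector>

// Stand-in for GetVarShape's host computation: read shape metadata only.
std::vector<int32_t> get_var_shape(const std::vector<std::size_t>& shape,
                                   std::optional<int> axis) {
    if (!axis)  // whole shape requested (INVALID_AXIS in the real op)
        return std::vector<int32_t>(shape.begin(), shape.end());
    int a = *axis < 0 ? *axis + (int)shape.size() : *axis;  // normalize negative axis
    return {(int32_t)shape[(std::size_t)a]};
}

int main() {
    for (auto v : get_var_shape({2, 3, 4}, std::nullopt))
        std::cout << v << ' ';  // prints: 2 3 4
    std::cout << get_var_shape({2, 3, 4}, -1)[0] << '\n';  // prints: 4
}
```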
imperative/src/include/megbrain/imperative/op_def.h

```diff
@@ -20,6 +20,11 @@ namespace imperative {
 class OpDef;
 struct OpTrait;
 
+enum DispatchMode {
+    DEFAULT_CPU = 0,
+    KERNEL = 1
+};
+
 struct BackwardGraphResult {
     std::shared_ptr<OpDef> backward;
     std::vector<bool> save_for_backward;
@@ -36,10 +41,31 @@ public:
     static std::shared_ptr<OpDef> make_from_op_node(
         cg::OperatorNodeBase* node);
 
+    /*!
+     * \brief Decide which dispatch method to be used according to the inputs'
+     * host value and size.
+     *
+     * \param def Specific :c:expr:`OpDef` to be executed.
+     * \param inputs Input tensor descriptions.
+     * \return Which DispatchMode to be used, such as `CUDA` or `DEFAULT_CPU`.
+     */
+    static DispatchMode decide_dispatch_mode(
+        const OpDef& def,
+        const SmallVector<LogicalTensorDesc>& inputs);
+
     static SmallVector<TensorPtr> apply_on_physical_tensor(
         const OpDef& def,
         SmallVector<TensorPtr> inputs);
+
+    /*!
+     * \brief Call the corresponding dnn op to calculate results. Output
+     * tensors' device memory should be allocated outside.
+     */
+    static void apply_on_device_tensornd(
+        const OpDef& def,
+        const SmallVector<DeviceTensorND>& inputs,
+        SmallVector<DeviceTensorND>* outputs);
+
     static cg::VarNodeArray apply_on_var_node(
         const OpDef& def,
         const VarNodeArray& inputs);
```