Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
MegEngine 天元
MegEngine
提交
b4581788
MegEngine
项目概览
MegEngine 天元
/
MegEngine
1 年多 前同步成功
通知
404
Star
4705
Fork
582
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
MegEngine
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
提交
b4581788
编写于
9月 26, 2021
作者:
M
Megvii Engine Team
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
feat(opr): add mutable tensor opr
GitOrigin-RevId: 7f8a3d7b661c18fb407047a78b965630b52e61d9
上级
47fe7663
变更
2
隐藏空白更改
内联
并排
Showing
2 changed file
with
86 addition
and
0 deletion
+86
-0
imperative/src/impl/opr_utility.cpp
imperative/src/impl/opr_utility.cpp
+63
-0
imperative/src/include/megbrain/imperative/opr_utility.h
imperative/src/include/megbrain/imperative/opr_utility.h
+23
-0
未找到文件。
imperative/src/impl/opr_utility.cpp
浏览文件 @
b4581788
...
...
@@ -271,6 +271,69 @@ void NopCallback::do_execute(ExecEnv& env) {
env
.
dispatch_on_comp_node
(
cn
,
runner
);
}
MGB_DYN_TYPE_OBJ_FINAL_IMPL
(
MutableTensor
);
MutableTensor
::
MutableTensor
(
cg
::
ComputingGraph
&
graph
,
std
::
shared_ptr
<
DeviceTensorND
>
dev_tensor
,
std
::
shared_ptr
<
HostTensorND
>
host_tensor
,
const
OperatorNodeConfig
&
config
)
:
Super
(
&
graph
,
config
,
{},
{})
{
m_dev_tensor
=
dev_tensor
;
m_host_tensor
=
host_tensor
;
add_output
(
None
)
->
add_flag
(
VarNode
::
Flag
::
ALLOW_EMPTY_SHAPE
)
.
add_flag
(
VarNode
::
Flag
::
NO_SYS_MEM_ALLOC
)
.
dtype
(
m_dev_tensor
->
dtype
());
add_equivalence_component
<
ScalarHash
<
const
void
*>>
(
this
);
}
SymbolVar
MutableTensor
::
make
(
cg
::
ComputingGraph
&
graph
,
std
::
shared_ptr
<
DeviceTensorND
>
dev_tensor
,
std
::
shared_ptr
<
HostTensorND
>
host_tensor
,
const
OperatorNodeConfig
&
config
)
{
return
graph
.
insert_opr
(
std
::
make_unique
<
MutableTensor
>
(
graph
,
dev_tensor
,
host_tensor
,
config
))
->
output
(
0
);
}
void
MutableTensor
::
init_output_comp_node
()
{
if
(
config
().
has_comp_node_set
())
{
mgb_assert
(
config
().
get_single_comp_node
()
==
m_dev_tensor
->
comp_node
(),
"comp_node mismatch"
);
}
comp_node
(
m_dev_tensor
->
comp_node
());
}
cg
::
OperatorNodeBase
::
NodeProp
*
MutableTensor
::
do_make_node_prop
()
const
{
auto
ret
=
Super
::
do_make_node_prop
();
ret
->
add_flag
(
NodeProp
::
Flag
::
IMPURE_OUTPUT_MEM_PLAN
);
return
ret
;
}
void
MutableTensor
::
scn_do_execute
()
{
output
(
0
)
->
reset_dev_tensor_from_tensor
(
*
m_dev_tensor
);
}
void
MutableTensor
::
init_output_static_infer_desc
()
{
using
namespace
cg
::
static_infer
;
auto
&
mgr
=
owner_graph
()
->
static_infer_manager
();
auto
infer_shape
=
[
this
](
TensorShape
&
dest
,
const
InpVal
&
)
{
dest
=
m_dev_tensor
->
shape
();
return
true
;
};
mgr
.
register_shape_infer
(
output
(
0
),
{
SourceType
::
MUTABLE
,
{},
infer_shape
});
if
(
m_host_tensor
)
{
auto
infer_value
=
[
this
](
DeviceTensorND
&
dest
,
const
InpVal
&
)
{
if
(
!
m_host_tensor
->
layout
().
ndim
)
{
return
false
;
}
dest
=
m_host_tensor
->
proxy_to_default_cpu
();
return
true
;
};
mgr
.
register_value_infer
(
output
(
0
),
{
SourceType
::
MUTABLE
,
{},
infer_value
});
}
}
}
// namespace opr
}
// namespace mgb
...
...
imperative/src/include/megbrain/imperative/opr_utility.h
浏览文件 @
b4581788
...
...
@@ -16,6 +16,7 @@
#include "megbrain/opr/internal/identical_fwd.h"
#include "megbrain/opr/internal/megdnn_opr_wrapper.h"
#include "megbrain/opr/internal/param_tag_defs.h"
#include "megbrain/opr/io.h"
#include "megbrain/opr/param_defs.h"
#include "megbrain/serialization/sereg.h"
...
...
@@ -106,6 +107,28 @@ protected:
private
:
callback_t
m_callback
;
}
;
MGB_DEFINE_OPR_CLASS
(
MutableTensor
,
cg
::
SingleCNOperatorNodeBase
)
// {
public
:
MutableTensor
(
cg
::
ComputingGraph
&
graph
,
std
::
shared_ptr
<
DeviceTensorND
>
dev_tensor
,
std
::
shared_ptr
<
HostTensorND
>
host_tensor
,
const
OperatorNodeConfig
&
config
);
static
SymbolVar
make
(
cg
::
ComputingGraph
&
graph
,
std
::
shared_ptr
<
DeviceTensorND
>
dev_tensor
,
std
::
shared_ptr
<
HostTensorND
>
host_tensor
=
{},
const
OperatorNodeConfig
&
config
=
{});
protected
:
void
init_output_comp_node
()
override
;
void
init_output_static_infer_desc
()
override
;
cg
::
OperatorNodeBase
::
NodeProp
*
do_make_node_prop
()
const
override
;
void
scn_do_execute
()
override
;
private
:
std
::
shared_ptr
<
DeviceTensorND
>
m_dev_tensor
;
std
::
shared_ptr
<
HostTensorND
>
m_host_tensor
;
}
;
}
// namespace opr
}
// namespace mgb
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录