Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
MegEngine 天元
MegEngine
提交
14d8b709
MegEngine
项目概览
MegEngine 天元
/
MegEngine
1 年多 前同步成功
通知
410
Star
4707
Fork
583
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
MegEngine
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
提交
14d8b709
编写于
12月 18, 2020
作者:
M
Megvii Engine Team
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
perf(mge/imperative): add mini graph to partially replace proxy graph
GitOrigin-RevId: 73e2529ba53ccb6c0607f52aee40e69e2c289343
上级
c294b9d1
变更
9
展开全部
隐藏空白更改
内联
并排
Showing
9 changed file
with
799 addition
and
9 deletion
+799
-9
imperative/src/impl/interpreter_impl.cpp
imperative/src/impl/interpreter_impl.cpp
+4
-1
imperative/src/impl/proxy_graph/common.h
imperative/src/impl/proxy_graph/common.h
+10
-0
imperative/src/impl/proxy_graph/mini_graph.h
imperative/src/impl/proxy_graph/mini_graph.h
+617
-0
imperative/src/impl/proxy_graph/proxy_graph.cpp
imperative/src/impl/proxy_graph/proxy_graph.cpp
+27
-0
imperative/src/impl/proxy_graph/proxy_graph_base.h
imperative/src/impl/proxy_graph/proxy_graph_base.h
+118
-0
imperative/src/impl/proxy_graph_detail.cpp
imperative/src/impl/proxy_graph_detail.cpp
+5
-5
imperative/src/include/megbrain/imperative/physical_tensor.h
imperative/src/include/megbrain/imperative/physical_tensor.h
+8
-0
src/core/include/megbrain/graph/static_infer.h
src/core/include/megbrain/graph/static_infer.h
+5
-2
src/core/include/megbrain/graph/var_node.h
src/core/include/megbrain/graph/var_node.h
+5
-1
未找到文件。
imperative/src/impl/interpreter_impl.cpp
浏览文件 @
14d8b709
...
...
@@ -258,6 +258,9 @@ void ChannelImpl::produce_tensor(TensorInfo* dest, TensorPtr ptr, bool notice =
MGB_LOCK_GUARD
(
m_mutex
);
dest
->
value_fetched
=
ptr
->
value_fetched
();
// update tensor desc for static infer
// if (dest->desc.layout.ndim) {
// mgb_assert(dest->desc.layout.eq_shape(ptr->layout()));
// }
dest
->
desc
.
layout
=
ptr
->
layout
();
dest
->
desc
.
comp_node
=
ptr
->
comp_node
();
dest
->
ptr
=
std
::
move
(
ptr
);
...
...
@@ -363,7 +366,7 @@ void ChannelImpl::regenerate(TensorInfo* info, bool must_drop = false) {
}
inputs
.
push_back
(
i
->
ptr
);
}
auto
outputs
=
OpDef
::
apply_on_physical_tensor
(
*
path
.
op
,
inputs
);
auto
outputs
=
OpDef
::
apply_on_physical_tensor
(
*
path
.
op
,
inputs
);
for
(
size_t
i
=
0
;
i
<
outputs
.
size
();
i
++
)
{
auto
out_ptr
=
path
.
outputs
[
i
].
lock
();
if
(
out_ptr
)
{
...
...
imperative/src/impl/proxy_graph/common.h
0 → 100644
浏览文件 @
14d8b709
namespace
mgb
::
imperative
::
proxy_graph
{
// a "namespace" struct to simplify friend declaration,
// e.g. friend class mgb::imperative::proxy_graph::ProxyGraph
struct
ProxyGraph
{
struct
InputPlaceholder
;
struct
MiniGraph
;
};
}
// namespace mgb::imperative::proxy_graph
imperative/src/impl/proxy_graph/mini_graph.h
0 → 100644
浏览文件 @
14d8b709
此差异已折叠。
点击以展开。
imperative/src/impl/proxy_graph/proxy_graph.cpp
0 → 100644
浏览文件 @
14d8b709
#include "./mini_graph.h"
// #include "../proxy_graph.h"
namespace
mgb
::
imperative
::
proxy_graph
{
MGB_DYN_TYPE_OBJ_FINAL_IMPL
(
ProxyGraph
::
InputPlaceholder
);
thread_local
std
::
unique_ptr
<
ProxyGraphTypeI
>
ProxyGraphTypeI
::
sm_instance
=
{};
}
// namespace mgb::imperative::proxy_graph
namespace
mgb
::
imperative
::
proxy_graph_detail
{
std
::
tuple
<
SmallVector
<
LogicalTensorDesc
>
,
bool
>
infer_output_attrs_fallible
(
const
OpDef
&
def
,
const
SmallVector
<
LogicalTensorDesc
>&
inputs
)
{
auto
ret
=
proxy_graph
::
ProxyGraphTypeI
::
inst
().
infer_output_attrs_fallible
(
def
,
inputs
);
// auto ref = ProxyGraph::get_default_graph()->infer_output_attrs_fallible(def, inputs);
// auto& [a, _1] = ret;
// auto& [b, _2] = ref;
// if (a.size() != b.size()) mgb_trap();
// for (size_t i = 0; i < a.size(); ++i) {
// if (a[i].layout.dtype != b[i].layout.dtype) mgb_trap();
// if (a[i].comp_node != b[i].comp_node) mgb_trap();
// if (!a[i].layout.eq_shape(b[i].layout)) mgb_trap();
// }
return
ret
;
}
}
// namespace mgb::imperative::proxy_graph_detail
imperative/src/impl/proxy_graph/proxy_graph_base.h
0 → 100644
浏览文件 @
14d8b709
#include "megbrain/graph/cg.h"
namespace
mgb
::
imperative
::
proxy_graph
{
using
cg
::
VarNode
;
struct
ExecEnvBase
:
cg
::
GraphExecutable
::
ExecEnv
{
void
dispatch_on_comp_node
(
CompNode
,
Task
&&
task
)
override
{
task
();
}
void
dispatch_on_comp_node_with_mask
(
CompNode
,
Task
&&
,
cg
::
ExecutionMask
*
)
override
{
mgb_assert
(
0
);}
void
pause_exec
()
override
{
mgb_assert
(
0
);}
void
resume_exec
()
override
{
mgb_assert
(
0
);}
};
struct
StaticInferManagerBase
:
cg
::
static_infer
::
StaticInferManager
{
protected:
void
register_shape_infer
(
VarNode
*
,
const
cg
::
static_infer
::
ShapeInferDesc
&
)
override
{
mgb_assert
(
0
);};
void
register_value_infer
(
VarNode
*
,
const
cg
::
static_infer
::
ValueInferDesc
&
)
override
{
mgb_assert
(
0
);};
cg
::
static_infer
::
InferType
get_infer_type
(
VarNode
*
)
override
{
mgb_assert
(
0
);};
const
TensorShape
&
infer_shape
(
VarNode
*
)
override
{
mgb_assert
(
0
);}
const
TensorShape
*
infer_shape_fallible
(
VarNode
*
)
override
{
mgb_assert
(
0
);}
const
DeviceTensorND
&
infer_value
(
VarNode
*
)
override
{
mgb_assert
(
0
);}
const
DeviceTensorND
*
infer_value_fallible
(
VarNode
*
)
override
{
mgb_assert
(
0
);}
cg
::
static_infer
::
DepVal
get_rt_static_source_deps
(
const
cg
::
static_infer
::
DepElement
&
)
override
{
mgb_assert
(
0
);}
};
struct
SeqCompNodeOptimizerBase
:
cg
::
SeqCompNodeOptimizer
{
protected:
void
register_stream_var
(
VarNode
*
,
StreamPropType
)
override
{}
void
register_propagate_function
(
VarNode
*
,
PropFunction
)
override
{}
StreamPropType
stream_prop_type
(
VarNode
*
)
override
{
mgb_assert
(
0
);}
};
struct
ProxyGraphBase
:
cg
::
ComputingGraph
{
private:
VarReceiverInfo
m_var_receiver_info
;
SeqCompNodeOptimizerBase
m_seq_comp_node_optimizer
;
StaticInferManagerBase
m_static_infer_manager
;
protected:
MemPool
<
VarNode
>
m_var_node_pool
;
ProxyGraphBase
()
{
options
().
imperative_proxy_graph
=
true
;
options
().
no_force_inplace
=
true
;
options
().
log_level
=
0
;
m_var_receiver_info
.
dev_value
=
1
;
m_var_receiver_info
.
allow_empty_value
=
1
;
}
void
*
alloc_varnode_storage
()
override
{
return
m_var_node_pool
.
alloc_raw
();
}
void
free_varnode_storage
(
void
*
ptr
)
override
{
m_var_node_pool
.
free_raw
(
ptr
);
}
const
VarReceiverInfo
&
var_receiver_in_current_comp_seq
(
const
VarNode
*
var
)
const
override
{
return
m_var_receiver_info
;
}
cg
::
static_infer
::
StaticInferManager
&
static_infer_manager
()
override
{
return
m_static_infer_manager
;
}
cg
::
SeqCompNodeOptimizer
&
seq_comp_node_optimizer
()
override
{
return
m_seq_comp_node_optimizer
;
}
std
::
shared_ptr
<
void
>
on_comp_node_finalize
()
override
{
return
{};
}
std
::
unique_ptr
<
cg
::
AsyncExecutable
>
compile
(
const
OutputSpec
&
)
override
{
mgb_assert
(
0
);}
SmallVector
<
std
::
unique_ptr
<
cg
::
AsyncExecutable
>>
compile_multi_part
(
const
SmallVector
<
OutputSpec
>&
)
override
{
mgb_assert
(
0
);}
cg
::
AsyncExecutable
*
current_comp_seq
()
override
{
mgb_assert
(
0
);}
std
::
string
get_mem_allocation_info
()
const
override
{
mgb_assert
(
0
);}
VarNode
*
find_var_by_id
(
size_t
)
const
override
{
mgb_assert
(
0
);}
void
share_device_memory_with
(
ComputingGraph
&
)
override
{
mgb_assert
(
0
);}
void
set_device_memory_allocator
(
std
::
shared_ptr
<
cg
::
DeviceMemoryAllocator
>
)
override
{
mgb_assert
(
0
);}
size_t
get_device_memory_size
(
CompNode
)
override
{
mgb_assert
(
0
);}
size_t
clear_device_memory
()
override
{
mgb_assert
(
0
);}
void
set_as_subgraph
(
ComputingGraph
&
)
override
{
mgb_assert
(
0
);}
void
record_async_error
(
std
::
unique_ptr
<
MegBrainError
>
)
override
{
mgb_assert
(
0
);}
};
MGB_DEFINE_OPR_CLASS
(
ProxyGraph
::
InputPlaceholder
,
cg
::
OperatorNodeBase
)
// {
void
on_output_comp_node_stream_changed
()
override
{
mgb_assert
(
0
);}
void
init_output_comp_node
()
override
{}
void
init_output_format
()
override
{}
void
init_output_dtype
()
override
{}
void
init_output_static_infer_desc
()
override
{}
void
init_output_mem_plan
(
bool
)
override
{
mgb_assert
(
0
);}
void
do_execute
(
ExecEnv
&
)
override
{
mgb_assert
(
0
);}
public:
InputPlaceholder
(
cg
::
ComputingGraph
&
graph
)
:
Super
(
&
graph
,
{},
"placeholder"
,
{})
{
add_output
(
None
)
->
add_flag
(
VarNode
::
Flag
::
NO_SYS_MEM_ALLOC
);
// never dedup
add_equivalence_component
<
ScalarHash
<
void
*>>
(
this
);
}
InputPlaceholder
(
cg
::
ComputingGraph
&
graph
,
DType
dtype
,
CompNode
cn
)
:
InputPlaceholder
(
graph
)
{
output
(
0
)
->
dtype
(
dtype
).
comp_node
(
cn
);
}
};
using
InputPlaceholder
=
ProxyGraph
::
InputPlaceholder
;
}
// namespace mgb::imperative::proxy_graph
imperative/src/impl/proxy_graph_detail.cpp
浏览文件 @
14d8b709
...
...
@@ -80,11 +80,11 @@ apply_on_physical_tensor(const OpDef& def,
return
outputs
;
}
std
::
tuple
<
SmallVector
<
LogicalTensorDesc
>
,
bool
>
infer_output_attrs_fallible
(
const
OpDef
&
def
,
const
SmallVector
<
LogicalTensorDesc
>&
inputs
)
{
auto
&&
graph
=
ProxyGraph
::
get_default_graph
();
return
graph
->
infer_output_attrs_fallible
(
def
,
inputs
);
}
//
std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible(const OpDef& def,
//
const SmallVector<LogicalTensorDesc>& inputs) {
//
auto&& graph = ProxyGraph::get_default_graph();
//
return graph->infer_output_attrs_fallible(def, inputs);
//
}
namespace
{
...
...
imperative/src/include/megbrain/imperative/physical_tensor.h
浏览文件 @
14d8b709
...
...
@@ -89,10 +89,18 @@ public:
return
m_blob
->
comp_node
();
}
DType
dtype
()
const
{
return
m_layout
.
dtype
;
}
TensorLayout
layout
()
const
{
return
m_layout
;
}
const
TensorShape
&
shape
()
const
{
return
m_layout
;
}
DeviceTensorND
dev_tensor
();
static
TensorPtr
make_scalar
(
DTypeScalar
value
,
CompNode
cn
);
...
...
src/core/include/megbrain/graph/static_infer.h
浏览文件 @
14d8b709
...
...
@@ -16,7 +16,10 @@
namespace
mgb
{
namespace
imperative
{
class
ProxyGraph
;
class
ProxyGraph
;
namespace
proxy_graph
{
class
ProxyGraph
;
}
// namespace proxy_graph
}
// namespace imperative
namespace
cg
{
...
...
@@ -56,6 +59,7 @@ namespace static_infer {
friend
class
StaticInferManagerImpl
;
friend
class
imperative
::
ProxyGraph
;
friend
class
imperative
::
proxy_graph
::
ProxyGraph
;
public:
/*!
...
...
@@ -342,4 +346,3 @@ using StaticInferInpVal = static_infer::InpVal;
}
// mgb
// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}
src/core/include/megbrain/graph/var_node.h
浏览文件 @
14d8b709
...
...
@@ -23,7 +23,10 @@
namespace
mgb
{
namespace
imperative
{
class
ProxyGraph
;
class
ProxyGraph
;
namespace
proxy_graph
{
class
ProxyGraph
;
}
}
// namespace imperative
namespace
cg
{
...
...
@@ -587,6 +590,7 @@ class VarNode final: public GraphNodeBase {
friend
class
EagerEvalManager
;
friend
class
MemAllocPlan
;
friend
class
imperative
::
ProxyGraph
;
friend
class
imperative
::
proxy_graph
::
ProxyGraph
;
};
enum
class
VarNode
::
Flag
:
uint32_t
{
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录