Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
MegEngine 天元
MegEngine
提交
fb20cb36
MegEngine
项目概览
MegEngine 天元
/
MegEngine
1 年多 前同步成功
通知
404
Star
4705
Fork
582
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
MegEngine
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
提交
fb20cb36
编写于
9月 15, 2021
作者:
M
Megvii Engine Team
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
docs(mge/traced_module): update traced_module api doc
GitOrigin-RevId: 19a95d26c71e672376c5fda00a4e7dc6050e1c6a
上级
c7a8d945
变更
6
展开全部
显示空白变更内容
内联
并排
Showing
6 changed file
with
632 addition
and
80 deletion
+632
-80
imperative/python/megengine/__init__.py
imperative/python/megengine/__init__.py
+1
-0
imperative/python/megengine/traced_module/expr.py
imperative/python/megengine/traced_module/expr.py
+49
-8
imperative/python/megengine/traced_module/fake_quant.py
imperative/python/megengine/traced_module/fake_quant.py
+5
-2
imperative/python/megengine/traced_module/node.py
imperative/python/megengine/traced_module/node.py
+51
-17
imperative/python/megengine/traced_module/pytree.py
imperative/python/megengine/traced_module/pytree.py
+21
-6
imperative/python/megengine/traced_module/traced_module.py
imperative/python/megengine/traced_module/traced_module.py
+505
-47
未找到文件。
imperative/python/megengine/__init__.py
浏览文件 @
fb20cb36
...
...
@@ -130,3 +130,4 @@ import megengine.optimizer
import
megengine.quantization
import
megengine.random
import
megengine.utils
import
megengine.traced_module
imperative/python/megengine/traced_module/expr.py
浏览文件 @
fb20cb36
...
...
@@ -33,15 +33,22 @@ def rstrip(s: str, __chars: str):
class
Expr
:
"""``Expr`` represents the operations(i.e. CallMethod, CallFunction, Apply, GetAttr, Input, Constant) on ``Node``."""
r
"""``Expr`` represents the operations (i.e. ``CallMethod``, ``CallFunction``, ``Apply``,
``GetAttr``, ``Input``, ``Constant``) on ``Node``.
"""
__total_id
=
0
inputs
=
None
# type: List[Node]
r
"""The input Nodes of this Expr."""
outputs
=
None
# type: List[Node]
r
"""The output Nodes of this Expr."""
const_val
=
None
# type: List[Any]
r
"""The non-tensor object in the input of the operation."""
arg_def
=
None
# type: TreeDef
r
"""The :class:`TreeDef` used to reconstruct the input of the operation."""
out_def
=
None
# type: TreeDef
r
"""The :class:`TreeDef` used to reconstruct the output of the operation."""
_top_graph
=
None
# type: weakref.ReferenceType
__total_id
=
0
def
__init__
(
self
)
->
None
:
self
.
_id
=
Expr
.
__total_id
...
...
@@ -125,6 +132,11 @@ class Expr:
return
inputs
,
{}
def
replace_inputs
(
self
,
repl_dict
:
Dict
[
Node
,
Node
]):
r
"""Replace the input Nodes of this Expr.
Args:
repl_dict: the map {old_Node: new_Node} that specifies how to replace the input Nodes.
"""
while
repl_dict
:
node
,
repl_node
=
repl_dict
.
popitem
()
assert
type
(
node
)
==
type
(
repl_node
)
...
...
@@ -147,16 +159,19 @@ class Expr:
@
property
def
kwargs
(
self
):
r
"""Get the the keyword arguments of the operation corresponding to this Expr."""
_
,
kwargs
=
self
.
unflatten_args
(
self
.
inputs
)
return
kwargs
@
property
def
args
(
self
):
r
"""Get the the positional arguments of the operation corresponding to this Expr."""
args
,
_
=
self
.
unflatten_args
(
self
.
inputs
)
return
args
@
property
def
top_graph
(
self
):
r
"""Get the parent graph of this Expr."""
if
self
.
_top_graph
:
return
self
.
_top_graph
()
return
None
...
...
@@ -168,17 +183,18 @@ class Expr:
return
state
@
classmethod
def
get_total
_id
(
cls
):
def
_get_next
_id
(
cls
):
return
cls
.
__total_id
@
classmethod
def
set_total
_id
(
cls
,
id
:
int
=
0
):
def
_set_next
_id
(
cls
,
id
:
int
=
0
):
assert
isinstance
(
id
,
int
)
cls
.
__total_id
=
id
# expr: None (i.e. fake expression which is used to mark input)
class
Input
(
Expr
):
r
"""A fake Expr which is used to mark the input of graph."""
name
=
None
def
__init__
(
self
,
name
=
None
,
type
=
None
,
orig_name
=
None
):
...
...
@@ -204,13 +220,15 @@ class Input(Expr):
return
expr
.
outputs
[
0
]
def
__repr__
(
self
):
return
"%{}:
\t
{} = Input(
{})"
.
format
(
self
.
_id
,
self
.
outputs
[
0
],
self
.
name
)
return
"%{}:
\t
{} = Input(
)"
.
format
(
self
.
_id
,
self
.
outputs
[
0
]
)
# expr: outputs = getattr(inputs[0], self.name)
class
GetAttr
(
Expr
):
name
=
None
r
"""``Getattr`` represents the fetch of an attribute from the ``Module`` hierarchy."""
name
=
None
r
"""name: the qualified name of the attribute to be retrieved."""
def
__init__
(
self
,
module
,
name
,
type
=
None
,
orig_name
=
None
):
super
().
__init__
()
assert
isinstance
(
module
,
ModuleNode
)
...
...
@@ -251,6 +269,13 @@ class GetAttr(Expr):
# expr: outputs = inputs[0].__call__(*inputs[1:])
class
CallMethod
(
Expr
):
r
"""``CallMethod`` represents a call to the ``__call__`` method of ``Module`` or a method of ``Tensor``.
Args:
node: the Node to be called.
method: the method name.
Default: "__call__"
"""
def
__init__
(
self
,
node
,
method
=
"__call__"
):
super
().
__init__
()
if
isinstance
(
node
,
type
):
...
...
@@ -320,8 +345,12 @@ class CallMethod(Expr):
# expr: outputs = apply(self.opdef, *inputs)
class
Apply
(
Expr
):
opdef
=
None
r
"""``Apply`` represents a call to :func:`apply`.
Args:
opdef: the applied :class:`OpDef`.
"""
opdef
=
None
def
__init__
(
self
,
opdef
):
super
().
__init__
()
assert
isinstance
(
opdef
,
OpDef
)
...
...
@@ -388,6 +417,11 @@ class Apply(Expr):
class
CallFunction
(
Expr
):
r
"""``CallFunction`` represents a call to a built-in function.
Args:
func: a built-in function.
"""
def
__init__
(
self
,
func
):
super
().
__init__
()
assert
isinstance
(
func
,
Callable
)
...
...
@@ -425,7 +459,14 @@ class CallFunction(Expr):
# expr outputs = self.value
class
Constant
(
Expr
):
r
"""``Constant`` represents a ``Tensor`` or "Module" which is not the attribute of a Module.
Args:
c: a const Tensor or Module.
name: the name of output Node.
"""
value
=
None
r
"""The const Tensor or Module"""
# TODO: constant cache to reduce the size of dumped model
_constant_cache
=
{}
...
...
imperative/python/megengine/traced_module/fake_quant.py
浏览文件 @
fb20cb36
...
...
@@ -15,6 +15,8 @@ from ..quantization.utils import QParams, QuantMode, fake_quant_tensor
class
FakeQuantize
(
_FakeQuantize
,
QParamsModuleMixin
):
r
"""A module to do quant and dequant according to :attr:`~.FakeQuantize.qparams`."""
def
__init__
(
self
,
dtype
:
Union
[
str
,
QuantDtypeMeta
],
enable
:
bool
=
True
,
**
kwargs
):
...
...
@@ -35,9 +37,10 @@ class FakeQuantize(_FakeQuantize, QParamsModuleMixin):
return
self
.
qparams
def
set_qparams
(
self
,
qparams
:
QParams
):
r
"""
r
"""Initialize :attr:`~.FakeQuantize.qparams`.
Args:
qparams: used to set initial scale
.
qparams: used to set initial ``scale`` and ``zero_point``
.
"""
if
qparams
.
scale
is
None
:
raise
AssertionError
(
"Can not get an initialized scale"
)
...
...
imperative/python/megengine/traced_module/node.py
浏览文件 @
fb20cb36
...
...
@@ -11,29 +11,29 @@ from typing import Any, Dict, List, Tuple, Type
import
numpy
from
..
import
get_logger
from
..core._imperative_rt.core2
import
Tensor
as
RawTensor
from
..module
import
Module
from
..tensor
import
Tensor
logger
=
get_logger
(
__name__
)
class
Node
:
r
"""``Node`` represents the variables (Tensor/Module/other python object) used in Module's forward method.
They are inputs/outputs of Expr(the operations on variables).
Args
:
expr: the Expr which produces the node
name: the name of the node
class
Node
:
r
"""``Node`` represents the variables (``Tensor``, ``Module``) used in Module's forward method.
They are inputs/outputs of Expr (the operations on variables).
"""
expr
=
None
__total_id
=
0
_id
=
None
expr
=
None
# type: Expr
r
"""The Expr which produces the Node."""
__total_id
=
0
# type: int
_id
=
None
# type: int
_top_graph
=
None
# type: weakref.ReferenceType
_name
=
None
_orig_name
=
None
_format_spec
=
""
_name
=
None
# type: str
_orig_name
=
None
# type: str
_format_spec
=
""
# type: str
def
__init__
(
self
,
expr
:
"Expr"
,
name
:
str
=
None
,
orig_name
:
str
=
None
):
def
__init__
(
self
,
expr
:
"Expr"
,
name
:
str
,
orig_name
:
str
):
self
.
expr
=
expr
self
.
users
=
[]
# List[Expr]
self
.
_id
=
Node
.
__total_id
...
...
@@ -73,24 +73,42 @@ class Node:
else
:
return
name
if
name
else
(
"%d"
%
self
.
_id
)
@
property
def
name
(
self
):
r
"""Return the name of this Node."""
return
self
.
_name
@
name
.
setter
def
name
(
self
,
new_name
:
str
):
graph
=
self
.
top_graph
assert
graph
is
not
None
,
"The parent graph of this Node cannot be None."
assert
new_name
not
in
graph
.
_used_names
,
(
"The name(%s) is already in use. Please try a different one again."
%
(
new_name
)
)
new_name
=
graph
.
_create_unique_name
(
new_name
)
self
.
_name
=
new_name
self
.
_orig_name
=
new_name
@
property
def
top_graph
(
self
):
r
"""Get the parent graph of this Node."""
if
self
.
_top_graph
:
return
self
.
_top_graph
()
return
None
@
classmethod
def
set_format_spec
(
cls
,
str
):
def
_
set_format_spec
(
cls
,
str
):
old_format_spec
=
cls
.
_format_spec
cls
.
_format_spec
=
str
return
old_format_spec
@
classmethod
def
get_total
_id
(
cls
):
def
_get_next
_id
(
cls
):
return
cls
.
__total_id
@
classmethod
def
set_total
_id
(
cls
,
id
:
int
=
0
):
def
_set_next
_id
(
cls
,
id
:
int
=
0
):
assert
isinstance
(
id
,
int
)
cls
.
__total_id
=
id
...
...
@@ -99,6 +117,7 @@ class ModuleNode(Node):
r
"""``ModuleNode`` represents the Module objects."""
module_type
=
Module
# type: Type[Module]
r
"""The type of the Module correspending to the ModuleNode."""
_owner
=
None
# type: weakref.ReferenceType
def
__init__
(
self
,
expr
:
"Expr"
,
name
:
str
=
None
,
orig_name
:
str
=
None
):
...
...
@@ -116,6 +135,11 @@ class ModuleNode(Node):
@
property
def
owner
(
self
):
r
"""Get the ``Module`` corresponding to this ``ModuleNode``.
Returns:
An :calss:`~.Module`.
"""
if
self
.
_owner
:
return
self
.
_owner
()
return
None
...
...
@@ -145,6 +169,7 @@ class TensorNode(Node):
@
property
def
shape
(
self
):
r
"""Get the shape of this Node."""
return
self
.
_shape
@
shape
.
setter
...
...
@@ -153,6 +178,7 @@ class TensorNode(Node):
@
property
def
dtype
(
self
):
r
"""Get the dtype of this Node."""
return
self
.
_dtype
@
dtype
.
setter
...
...
@@ -161,6 +187,7 @@ class TensorNode(Node):
@
property
def
device
(
self
):
r
"""Get the device of this Node pointed Tensor."""
return
self
.
_device
@
device
.
setter
...
...
@@ -169,6 +196,7 @@ class TensorNode(Node):
@
property
def
qparams
(
self
):
r
"""Get the :calss:`QParams` of this Node."""
return
self
.
_qparams
@
qparams
.
setter
...
...
@@ -177,10 +205,16 @@ class TensorNode(Node):
@
property
def
value
(
self
):
r
"""Get the bound Tensor of this Node."""
return
self
.
_value
@
value
.
setter
def
value
(
self
,
value
):
r
"""Bind a Tensor to this Node.
Args:
value: A :class:`Tensor`.
"""
if
isinstance
(
value
,
RawTensor
)
and
NodeMixin
.
get
(
value
,
None
)
is
not
None
:
setattr
(
value
,
"_NodeMixin__node"
,
None
)
self
.
_value
=
value
...
...
imperative/python/megengine/traced_module/pytree.py
浏览文件 @
fb20cb36
...
...
@@ -150,6 +150,9 @@ def tree_flatten(
is_leaf
:
Callable
=
_is_leaf
,
is_const_leaf
:
Callable
=
_is_const_leaf
,
):
r
"""Flattens a object into a list of values and a :calss:`TreeDef` that can be used
to reconstruct the object.
"""
if
type
(
values
)
not
in
SUPPORTED_TYPE
:
assert
is_leaf
(
values
),
values
node
=
LeafDef
(
leaf_type
(
values
))
...
...
@@ -169,6 +172,15 @@ def tree_flatten(
class
TreeDef
:
r
"""A ``TreeDef`` represents the structure of a pytree.
Args:
type: the type of root Node of the pytree.
aux_data: some const data that is useful in unflattening the pytree.
children_defs: ``TreeDef`` for each child of the root Node.
num_leaves: the number of leaves.
"""
def
__init__
(
self
,
type
,
aux_data
,
children_defs
):
self
.
type
=
type
self
.
aux_data
=
aux_data
...
...
@@ -176,6 +188,9 @@ class TreeDef:
self
.
num_leaves
=
sum
(
ch
.
num_leaves
for
ch
in
children_defs
)
def
unflatten
(
self
,
leaves
):
r
"""Given a list of values and a ``TreeDef``, builds a object.
This is the inverse operation of ``tree_flatten``.
"""
assert
len
(
leaves
)
==
self
.
num_leaves
start
=
0
children
=
[]
...
...
@@ -196,13 +211,10 @@ class TreeDef:
)
)
def
__lt__
(
self
,
other
):
return
self
.
__hash__
()
<
other
.
__hash__
()
def
__gt__
(
self
,
other
):
return
self
.
__hash__
()
>
other
.
__hash__
()
def
__ne__
(
self
,
other
)
->
bool
:
return
not
self
.
__eq__
(
other
)
def
__eq__
(
self
,
other
):
def
__eq__
(
self
,
other
)
->
bool
:
return
(
self
.
type
==
other
.
type
and
self
.
aux_data
==
other
.
aux_data
...
...
@@ -227,6 +239,9 @@ class LeafDef(TreeDef):
assert
isinstance
(
leaves
[
0
],
self
.
type
),
self
.
type
return
leaves
[
0
]
def
__ne__
(
self
,
other
)
->
bool
:
return
not
self
.
__eq__
(
other
)
def
__eq__
(
self
,
other
):
if
isinstance
(
self
.
const_val
,
np
.
ndarray
):
return
self
.
type
==
other
.
type
and
(
self
.
const_val
==
other
.
const_val
).
all
()
...
...
imperative/python/megengine/traced_module/traced_module.py
浏览文件 @
fb20cb36
此差异已折叠。
点击以展开。
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录