Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
MegEngine 天元
MegEngine
提交
b3d0affa
MegEngine
项目概览
MegEngine 天元
/
MegEngine
1 年多 前同步成功
通知
404
Star
4705
Fork
582
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
MegEngine
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
提交
b3d0affa
编写于
8月 12, 2021
作者:
M
Megvii Engine Team
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
feat(traced_module): support trace custom qat module
GitOrigin-RevId: 49f70a5f467e93ff58fc5152499f04733258fd0d
上级
15712807
变更
6
隐藏空白更改
内联
并排
Showing
6 changed file
with
199 addition
and
22 deletion
+199
-22
imperative/python/megengine/experimental/traced_module/expr.py
...ative/python/megengine/experimental/traced_module/expr.py
+36
-5
imperative/python/megengine/experimental/traced_module/fake_quant.py
...python/megengine/experimental/traced_module/fake_quant.py
+48
-0
imperative/python/megengine/experimental/traced_module/module_tracer.py
...hon/megengine/experimental/traced_module/module_tracer.py
+3
-1
imperative/python/megengine/experimental/traced_module/node.py
...ative/python/megengine/experimental/traced_module/node.py
+2
-2
imperative/python/megengine/experimental/traced_module/pytree.py
...ive/python/megengine/experimental/traced_module/pytree.py
+5
-4
imperative/python/megengine/experimental/traced_module/traced_module.py
...hon/megengine/experimental/traced_module/traced_module.py
+105
-10
未找到文件。
imperative/python/megengine/experimental/traced_module/expr.py
浏览文件 @
b3d0affa
...
...
@@ -17,12 +17,14 @@ from typing import Callable, Dict, List
from
...core._imperative_rt
import
OpDef
from
...core._imperative_rt.core2
import
Tensor
as
RawTensor
from
...core._imperative_rt.core2
import
apply
,
set_module_tracing
,
unset_module_tracing
from
...core.ops.builtin
import
FakeQuant
from
...core.ops.special
import
Const
from
...module
import
Module
from
...tensor
import
Parameter
,
Tensor
from
.module_tracer
import
active_module_tracer
,
module_tracer
from
.node
import
ModuleNode
,
Node
,
NodeMixin
,
TensorNode
from
.pytree
import
ArgsIndex
,
TreeDef
,
tree_flatten
from
.pytree
import
ArgsIndex
,
TreeDef
,
_is_const_leaf
,
_is_leaf
,
tree_flatten
from
.serialization
import
get_opdef_state
,
load_opdef_from_state
def
rstrip
(
s
:
str
,
__chars
:
str
):
...
...
@@ -76,6 +78,7 @@ class Expr:
node
.
users
.
append
(
self
)
else
:
assert
node
is
None
assert
_is_leaf
(
val
)
and
_is_const_leaf
(
val
)
idx
=
len
(
self
.
inputs
)
+
len
(
self
.
const_val
)
self
.
const_val
.
append
((
idx
,
val
))
...
...
@@ -154,6 +157,11 @@ class Expr:
return
self
.
_top_graph
()
return
None
def
__getstate__
(
self
):
state
=
self
.
__dict__
.
copy
()
state
.
pop
(
"_top_graph"
,
None
)
return
state
# expr: None (i.e. fake expression which is used to mark input)
class
Input
(
Expr
):
...
...
@@ -321,14 +329,36 @@ class Apply(Expr):
", "
.
join
(
str
(
i
)
for
i
in
self
.
inputs
),
)
def __getstate__(self):
    """Return the pickling state with ``opdef`` in its serialized form.

    Starts from the parent class's ``__getstate__`` result and replaces
    the ``"opdef"`` entry with whatever ``get_opdef_state`` returns for it.
    """
    pickled = super().__getstate__()
    opdef_state = get_opdef_state(pickled["opdef"])
    pickled["opdef"] = opdef_state
    return pickled
def __setstate__(self, state):
    """Restore the instance from a state dict produced by ``__getstate__``.

    The ``"opdef"`` entry is rebuilt with ``load_opdef_from_state`` and
    written back into ``state`` (the dict passed in is modified in place);
    every entry is then assigned onto ``self`` via ``setattr``.

    :param state: mapping of attribute name to value; must contain ``"opdef"``.
    """
    restored_opdef = load_opdef_from_state(state["opdef"])
    state["opdef"] = restored_opdef
    for attr_name, attr_value in state.items():
        setattr(self, attr_name, attr_value)
@
classmethod
def
apply_module_trace_hook
(
cls
,
opdef
,
*
inputs
):
for
i
in
inputs
:
node
=
NodeMixin
.
get
(
i
,
None
)
if
node
is
None
:
# capture as constant
NodeMixin
.
wrap_safe
(
i
,
Constant
.
make
(
i
))
apply_node
=
cls
.
make
(
opdef
)
apply_node
.
add_inputs
(
inputs
)
if
isinstance
(
opdef
,
FakeQuant
):
inp_nodes
=
[
NodeMixin
.
get
(
inputs
[
0
])]
for
i
in
inputs
[
1
:]:
node
=
Constant
.
make
(
i
)
inp_nodes
.
append
(
node
)
apply_node
=
cls
.
make
(
opdef
)
for
n
in
inp_nodes
:
n
.
users
.
append
(
apply_node
)
apply_node
.
inputs
=
inp_nodes
else
:
apply_node
=
cls
.
make
(
opdef
)
apply_node
.
add_inputs
(
inputs
)
assert
not
apply_node
.
const_val
unset_module_tracing
()
...
...
@@ -387,7 +417,7 @@ class Constant(Expr):
super
().
__init__
()
assert
isinstance
(
c
,
(
RawTensor
,
Module
))
if
isinstance
(
c
,
Module
):
assert
module_tracer
.
is_builtin
(
c
)
assert
module_tracer
.
is_builtin
(
c
)
or
c
.
is_qat
self
.
value
=
c
self
.
name
=
name
self
.
inputs
=
[]
...
...
@@ -395,6 +425,7 @@ class Constant(Expr):
self
.
outputs
=
[
node_cls
(
self
,
name
=
name
),
]
self
.
outputs
[
0
].
_name
=
name
if
name
else
"const_"
+
str
(
self
.
_id
)
@
classmethod
def
make
(
cls
,
*
args
,
**
kwargs
):
...
...
@@ -422,7 +453,7 @@ class Constant(Expr):
)
def
__getstate__
(
self
):
state
=
s
elf
.
__dict__
.
copy
()
state
=
s
uper
().
__getstate__
()
if
isinstance
(
self
.
value
,
RawTensor
):
state
[
"value"
]
=
Tensor
(
self
.
value
)
return
state
imperative/python/megengine/experimental/traced_module/fake_quant.py
0 → 100644
浏览文件 @
b3d0affa
from
copy
import
deepcopy
from
typing
import
Union
from
...core.tensor.dtype
import
QuantDtypeMeta
from
...quantization.fake_quant
import
QParamsModuleMixin
,
_FakeQuantize
from
...quantization.utils
import
QParams
,
QuantMode
,
fake_quant_tensor
class FakeQuantize(_FakeQuantize, QParamsModuleMixin):
    """Fake-quantize module that keeps its quantization parameters on itself.

    The parameters live in ``self.qparams``; they start out as ``None`` and
    are supplied through :meth:`set_qparams`.
    """

    def __init__(
        self, dtype: Union[str, QuantDtypeMeta], enable: bool = True, **kwargs
    ):
        """
        :param dtype: forwarded unchanged to the parent constructor.
        :param enable: forwarded unchanged to the parent constructor.
        :param kwargs: extra keyword arguments forwarded to the parent constructor.
        """
        super().__init__(dtype, enable, **kwargs)
        # Nothing is stored until ``set_qparams`` is called.
        self.qparams = None

    def _assert_dtype_matches(self, qparams: QParams):
        """Assert that ``qparams.dtype_meta`` is the very object ``self.dtype``."""
        assert (
            qparams.dtype_meta is self.dtype
        ), "input qparams' dtype is not equal to self.dtype.\nqparams.dtype_meta={}\nself.dtype={}".format(
            qparams.dtype_meta, self.dtype
        )

    def fake_quant_forward(self, inp, qparams: QParams = None):
        """Apply ``fake_quant_tensor`` to ``inp``.

        :param inp: value handed to ``fake_quant_tensor``.
        :param qparams: parameters to use; when ``None`` the result of
            :meth:`get_qparams` is used instead.
        :return: the result of ``fake_quant_tensor(inp, qparams)``.
        """
        effective = self.get_qparams() if qparams is None else qparams
        self._assert_dtype_matches(effective)
        return fake_quant_tensor(inp, effective)

    def get_qparams(self):
        """Return the stored ``self.qparams`` (``None`` if never set)."""
        return self.qparams

    def set_qparams(self, qparams: QParams):
        """
        :param qparams: used to set initial scale.

        Raises ``AssertionError`` when ``qparams.scale`` is ``None``. When
        ``qparams.dtype_meta`` is ``None`` it is filled in with ``self.dtype``
        on the object passed in; otherwise it must be ``self.dtype``. A new
        ``QParams`` built from the four fields is stored in ``self.qparams``.
        """
        if qparams.scale is None:
            raise AssertionError("Can not get an initialized scale")

        if qparams.dtype_meta is not None:
            self._assert_dtype_matches(qparams)
        else:
            # Note: this writes to the caller's object, not to a copy.
            qparams.dtype_meta = self.dtype

        self.qparams = QParams(
            qparams.mode, qparams.dtype_meta, qparams.scale, qparams.zero_point
        )
imperative/python/megengine/experimental/traced_module/module_tracer.py
浏览文件 @
b3d0affa
...
...
@@ -12,6 +12,7 @@ from ... import Tensor
from
...
import
functional
as
F
from
...core.tensor.array_method
import
ArrayMethodMixin
from
...module
import
Module
from
...module.qat
import
QATModule
_active_module_tracer
=
None
...
...
@@ -68,7 +69,7 @@ BUILTIN_ARRAY_METHOD = [
"__iand__"
,
"__ior__"
,
"__ixor__"
,
"
T
"
,
"
transpose
"
,
"astype"
,
"reshape"
,
"_broadcast"
,
...
...
@@ -180,6 +181,7 @@ class Patcher:
self
.
patch_method
(
ArrayMethodMixin
,
meth
,
self
.
wrap_fn
)
self
.
patch_method
(
Tensor
,
"detach"
,
self
.
wrap_fn
)
self
.
patch_method
(
Tensor
,
"__new__"
,
self
.
wrap_fn
)
self
.
patch_method
(
QATModule
,
"_apply_fakequant_with_observer"
,
self
.
wrap_fn
)
for
i
,
j
in
self
.
_builtin_functions
:
if
id
(
i
)
not
in
self
.
visited_frames_ids
:
self
.
patch_function
(
i
,
j
,
self
.
wrap_fn
)
...
...
imperative/python/megengine/experimental/traced_module/node.py
浏览文件 @
b3d0affa
...
...
@@ -127,7 +127,7 @@ class TensorNode(Node):
shape
=
None
# type: Tuple[int]
dtype
=
None
# type: numpy.dtype
qparam
=
None
qparam
s
=
None
device
=
None
def
__getstate__
(
self
):
...
...
@@ -135,7 +135,7 @@ class TensorNode(Node):
"expr"
:
self
.
expr
,
"users"
:
self
.
users
,
"_id"
:
self
.
_id
,
"qparam
"
:
self
.
qparam
,
"qparam
s"
:
self
.
qparams
,
"shape"
:
self
.
shape
,
"dtype"
:
self
.
dtype
,
"device"
:
self
.
device
,
...
...
imperative/python/megengine/experimental/traced_module/pytree.py
浏览文件 @
b3d0affa
...
...
@@ -155,10 +155,7 @@ def tree_flatten(
assert
is_leaf
(
values
),
values
node
=
LeafDef
(
leaf_type
(
values
))
if
is_const_leaf
(
values
):
if
isinstance
(
values
,
np
.
ndarray
):
node
.
const_val
=
str
(
values
)
else
:
node
.
const_val
=
values
node
.
const_val
=
values
return
[
values
,],
node
rst
=
[]
...
...
@@ -232,9 +229,13 @@ class LeafDef(TreeDef):
return
leaves
[
0
]
def
__eq__
(
self
,
other
):
if
isinstance
(
self
.
const_val
,
np
.
ndarray
):
return
self
.
type
==
other
.
type
and
(
self
.
const_val
==
other
.
const_val
).
all
()
return
self
.
type
==
other
.
type
and
self
.
const_val
==
other
.
const_val
def
__hash__
(
self
):
if
isinstance
(
self
.
const_val
,
np
.
ndarray
):
return
hash
(
tuple
([
self
.
type
,
str
(
self
.
const_val
)]))
return
hash
(
tuple
([
self
.
type
,
self
.
const_val
]))
def
__repr__
(
self
):
...
...
imperative/python/megengine/experimental/traced_module/traced_module.py
浏览文件 @
b3d0affa
...
...
@@ -29,14 +29,20 @@ from ...core._imperative_rt.core2 import (
from
...core._trace_option
import
set_symbolic_shape
from
...core.tensor.array_method
import
ArrayMethodMixin
from
...module
import
Module
from
...quantization.fake_quant
import
LSQ
,
TQT
,
FakeQuantize
from
...module.qat
import
QATModule
from
...quantization.fake_quant
import
LSQ
,
TQT
,
FakeQuantize
,
_FakeQuantize
from
...quantization.observer
import
(
ExponentialMovingAverageObserver
,
HistogramObserver
,
MinMaxObserver
,
Observer
,
PassiveObserver
,
SyncExponentialMovingAverageObserver
,
SyncMinMaxObserver
,
)
from
...tensor
import
Tensor
from
.expr
import
Apply
,
CallFunction
,
CallMethod
,
Constant
,
Expr
,
GetAttr
,
Input
from
.fake_quant
import
FakeQuantize
as
TM_FakeQuant
from
.module_tracer
import
(
Patcher
,
active_module_tracer
,
...
...
@@ -613,7 +619,8 @@ def _wrapped_function(orig_func):
if
isinstance
(
i
,
(
RawTensor
,
NodeMixin
)):
NodeMixin
.
wrap_safe
(
i
,
Constant
.
make
(
i
))
meth_name
=
_get_meth_name
(
args
[
0
],
wrapped_fn
)
if
args
else
None
if
meth_name
:
arg_type
=
args
[
0
]
if
isinstance
(
args
[
0
],
type
)
else
type
(
args
[
0
])
if
meth_name
and
issubclass
(
arg_type
,
RawTensor
):
self
=
inputs
[
0
]
if
meth_name
==
"__new__"
:
if
all
([
not
isinstance
(
i
,
RawTensor
)
for
i
in
inputs
]):
...
...
@@ -680,7 +687,15 @@ class TracedModuleBuilder(NodeMixin):
self
.
_mod
=
mod
self
.
_body
=
None
self
.
_is_top
=
is_top_module
self
.
_is_builtin
=
module_tracer
.
is_builtin
(
mod
)
self
.
_is_builtin
=
(
True
if
isinstance
(
mod
,
(
Observer
,
_FakeQuantize
))
else
module_tracer
.
is_builtin
(
mod
)
)
if
isinstance
(
self
.
_mod
,
QATModule
):
unset_module_tracing
()
self
.
_check_qat_module
(
self
.
_mod
)
set_module_tracing
()
self
.
_argdef_graph_map
=
{}
self
.
_argdef_outdef_map
=
{}
...
...
@@ -693,15 +708,65 @@ class TracedModuleBuilder(NodeMixin):
dict
(
TracedModuleBuilder
.
__dict__
),
)
def _check_qat_module(self, qat_module):
    """Swap non-builtin observers / fake-quantizers for ``TM_FakeQuant``.

    The activation side is handled first (only if ``qat_module.with_act`` is
    truthy), then the weight side (only if ``qat_module.with_weight`` is
    truthy). For each side, when the observer or the fake-quant module is
    neither ``None`` nor reported as builtin by ``module_tracer.is_builtin``:

    * qparams and dtype are read from the observer when it has the
      corresponding attribute, otherwise from the fake-quant module;
    * the observer attribute is set to ``None``;
    * the fake-quant attribute is replaced by ``TM_FakeQuant(dtype)`` and
      given the captured qparams through ``set_qparams``.

    :param qat_module: module whose ``*_observer`` / ``*_fake_quant``
        attributes are inspected and possibly replaced in place.
    """

    def _is_none_or_builtin(mod):
        return mod is None or module_tracer.is_builtin(mod)

    for flag_name, side in (("with_act", "act"), ("with_weight", "weight")):
        if not getattr(qat_module, flag_name):
            continue

        observer_attr = side + "_observer"
        fake_quant_attr = side + "_fake_quant"
        observer = getattr(qat_module, observer_attr)
        fake_quant = getattr(qat_module, fake_quant_attr)

        # Both parts are already traceable: nothing to replace on this side.
        if _is_none_or_builtin(observer) and _is_none_or_builtin(fake_quant):
            continue

        if hasattr(observer, "get_qparams"):
            qparams = observer.get_qparams()
        else:
            qparams = fake_quant.get_qparams()

        if hasattr(observer, "dtype"):
            dtype = observer.dtype
        else:
            dtype = fake_quant.dtype

        setattr(qat_module, observer_attr, None)
        setattr(qat_module, fake_quant_attr, TM_FakeQuant(dtype))
        # Re-read the attribute from the module before configuring it.
        getattr(qat_module, fake_quant_attr).set_qparams(qparams)
def
build
(
self
):
if
self
.
_is_builtin
or
isinstance
(
self
.
_mod
,
TracedModule
):
if
module_tracer
.
is_builtin
(
self
.
_mod
)
or
isinstance
(
self
.
_mod
,
TracedModule
):
mod_type
=
type
(
self
.
_mod
)
else
:
assert
isinstance
(
self
.
_mod
,
(
Observer
,
_FakeQuantize
))
mod_type
=
(
Observer
if
isinstance
(
self
.
_mod
,
Observer
)
else
_FakeQuantize
)
for
node
in
self
.
nodes
:
node
.
module_type
=
type
(
self
.
_mod
)
# node._owner = weakref.ref(self._mod)
node
.
module_type
=
mod_type
return
self
.
_mod
else
:
is_qat
=
isinstance
(
self
.
_mod
,
QATModule
)
traced_module
=
TracedModule
(
self
.
_is_top
,
self
.
_argdef_graph_map
,
self
.
_argdef_outdef_map
self
.
_is_top
,
self
.
_argdef_graph_map
,
self
.
_argdef_outdef_map
,
is_qat
)
for
_
,
g
in
self
.
_argdef_graph_map
.
items
():
g
.
compile
()
...
...
@@ -712,6 +777,20 @@ class TracedModuleBuilder(NodeMixin):
v
=
v
.
build
()
setattr
(
traced_module
,
k
,
v
)
if
isinstance
(
self
.
_mod
,
QATModule
):
unset_module_tracing
()
traced_module
.
with_act
=
self
.
_mod
.
with_act
traced_module
.
with_weight
=
self
.
_mod
.
with_weight
if
not
hasattr
(
traced_module
,
"act_fake_quant"
):
traced_module
.
act_fakequant
=
None
if
not
hasattr
(
traced_module
,
"act_observer"
):
traced_module
.
act_observer
=
None
if
not
hasattr
(
traced_module
,
"weight_fake_quant"
):
traced_module
.
weight_fakequant
=
None
if
not
hasattr
(
traced_module
,
"weight_observer"
):
traced_module
.
weight_observer
=
None
set_module_tracing
()
return
traced_module
def
_record_wrapped_nodes
(
self
,
node
):
...
...
@@ -846,7 +925,8 @@ class TracedModuleBuilder(NodeMixin):
attr
=
getattr
(
self
.
_mod
,
name
)
if
isinstance
(
attr
,
Module
):
attr
=
TracedModuleBuilder
(
attr
)
setattr
(
self
,
name
,
attr
)
if
isinstance
(
attr
,
(
Module
,
RawTensor
)):
setattr
(
self
,
name
,
attr
)
NodeMixin
.
wrap
(
attr
,
lambda
:
GetAttr
.
make
(
...
...
@@ -1066,7 +1146,7 @@ class TracedModule(Module):
argdef_graph_map
=
None
argdef_outdef_map
=
None
def
__init__
(
self
,
is_top
,
argdef_graph_map
,
argdef_outdef_map
):
def
__init__
(
self
,
is_top
,
argdef_graph_map
,
argdef_outdef_map
,
is_qat
=
False
):
super
(
TracedModule
,
self
).
__init__
()
self
.
argdef_graph_map
=
argdef_graph_map
self
.
argdef_outdef_map
=
argdef_outdef_map
...
...
@@ -1074,6 +1154,7 @@ class TracedModule(Module):
self
.
watch_points
=
[]
self
.
watch_node_value
=
{}
self
.
end_points
=
[]
self
.
is_qat
=
is_qat
def
forward
(
self
,
*
args
,
**
kwargs
):
inputs
,
treedef
=
tree_flatten
(((
self
,
*
args
),
kwargs
))
...
...
@@ -1195,8 +1276,8 @@ class TracedModule(Module):
):
if
graph
is
not
None
and
prefix_name
and
prefix_name
[
-
1
]
!=
"_"
:
prefix_name
+=
"_"
if
graph
is
None
:
assert
not
isinstance
(
module
,
TracedModule
)
if
graph
is
None
or
module
.
is_qat
:
assert
not
isinstance
(
module
,
TracedModule
)
or
module
.
is_qat
const
=
Constant
(
module
,
"self.%s"
%
module2name
[
id
(
module
)])
m_node
=
call
.
inputs
[
0
]
if
m_node
.
top_graph
!=
active_module_tracer
().
current_scope
():
...
...
@@ -1326,9 +1407,23 @@ def _register_all_builtin_module():
isclass
(
m
[
1
])
and
issubclass
(
m
[
1
],
M
.
Module
)
and
m
[
1
]
is
not
M
.
Sequential
and
m
[
1
]
is
not
M
.
ModuleList
):
module_tracer
.
register_as_builtin
(
m
[
1
])
module_tracer
.
register_as_builtin
(
Observer
)
module_tracer
.
register_as_builtin
(
MinMaxObserver
)
module_tracer
.
register_as_builtin
(
SyncMinMaxObserver
)
module_tracer
.
register_as_builtin
(
ExponentialMovingAverageObserver
)
module_tracer
.
register_as_builtin
(
SyncExponentialMovingAverageObserver
)
module_tracer
.
register_as_builtin
(
HistogramObserver
)
module_tracer
.
register_as_builtin
(
PassiveObserver
)
module_tracer
.
register_as_builtin
(
LSQ
)
module_tracer
.
register_as_builtin
(
TQT
)
module_tracer
.
register_as_builtin
(
FakeQuantize
)
module_tracer
.
register_as_builtin
(
TM_FakeQuant
)
def
trace_module
(
mod
:
Module
,
*
args
:
Tensor
,
**
kwargs
:
Tensor
)
->
TracedModule
:
"""
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录