Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
MegEngine 天元
MegEngine
提交
b43f6a26
MegEngine
项目概览
MegEngine 天元
/
MegEngine
1 年多 前同步成功
通知
403
Star
4705
Fork
582
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
MegEngine
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
提交
b43f6a26
编写于
8月 13, 2020
作者:
M
Megvii Engine Team
提交者:
Xinran Xu
8月 25, 2020
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
fix(mge/quantization): handle empty Observer in QATModule
GitOrigin-RevId: e8a62297bc513a30be900743c3c199ccc2b30273
上级
13e8f00a
变更
1
隐藏空白更改
内联
并排
Showing
1 changed file
with
32 addition
and
24 deletion
+32
-24
python_module/megengine/module/qat/module.py
python_module/megengine/module/qat/module.py
+32
-24
未找到文件。
python_module/megengine/module/qat/module.py
浏览文件 @
b43f6a26
...
@@ -70,18 +70,22 @@ class QATModule(Module):
...
@@ -70,18 +70,22 @@ class QATModule(Module):
def
_apply_fakequant_with_observer
(
def
_apply_fakequant_with_observer
(
self
,
target
:
Tensor
,
fake_quant
:
FakeQuantize
,
observer
:
Observer
self
,
target
:
Tensor
,
fake_quant
:
FakeQuantize
,
observer
:
Observer
):
):
# do observer
if
observer
is
None
:
if
observer
is
None
:
return
target
q_dict
=
None
oup
=
observer
(
target
)
oup
=
target
q_dict
=
observer
.
get_qparams
()
else
:
q_dict
=
observer
.
get_qparams
()
oup
=
observer
(
target
)
# do fake quant
# do fake quant
if
fake_quant
is
not
None
:
if
fake_quant
is
not
None
:
oup
=
fake_quant
(
oup
,
q_dict
)
oup
=
fake_quant
(
oup
,
q_dict
)
# use qparams of fake_quant if have.
# use qparams of fake_quant if have.
if
hasattr
(
fake_quant
,
"get_qparams"
):
if
hasattr
(
fake_quant
,
"get_qparams"
):
q_dict
=
fake_quant
.
get_qparams
()
q_dict
=
fake_quant
.
get_qparams
()
# set to tensor qparams.
# set to tensor qparams.
oup
.
q_dict
.
update
(
q_dict
)
if
q_dict
is
not
None
:
oup
.
q_dict
.
update
(
q_dict
)
return
oup
return
oup
def
apply_quant_weight
(
self
,
target
:
Tensor
):
def
apply_quant_weight
(
self
,
target
:
Tensor
):
...
@@ -100,42 +104,46 @@ class QATModule(Module):
...
@@ -100,42 +104,46 @@ class QATModule(Module):
target
,
self
.
act_fake_quant
,
self
.
act_observer
target
,
self
.
act_fake_quant
,
self
.
act_observer
)
)
def
_get_method_result
(
self
,
method
:
str
,
fake_quant
:
FakeQuantize
,
observer
:
Observer
):
if
hasattr
(
fake_quant
,
method
):
return
getattr
(
fake_quant
,
method
)()
elif
hasattr
(
observer
,
method
):
return
getattr
(
observer
,
method
)()
return
None
def
get_weight_dtype
(
self
):
def
get_weight_dtype
(
self
):
r
"""
r
"""
Get weight's quantization dtype as the method from ``qconfig``.
Get weight's quantization dtype as the method from ``qconfig``.
"""
"""
if
hasattr
(
self
.
weight_fake_quant
,
"get_dtype"
):
return
self
.
_get_method_result
(
return
self
.
weight_fake_quant
.
get_dtype
()
"get_dtype"
,
self
.
weight_fake_quant
,
self
.
weight_observer
else
:
)
return
self
.
weight_observer
.
get_dtype
()
def
get_activation_dtype
(
self
):
def
get_activation_dtype
(
self
):
r
"""
r
"""
Get activation's quantization dtype as the method from ``qconfig``.
Get activation's quantization dtype as the method from ``qconfig``.
"""
"""
if
hasattr
(
self
.
act_fake_quant
,
"get_dtype"
):
return
self
.
_get_method_result
(
return
self
.
act_fake_quant
.
get_dtype
()
"get_dtype"
,
self
.
act_fake_quant
,
self
.
act_observer
else
:
)
return
self
.
act_observer
.
get_dtype
()
def
_get_qparams
(
self
,
fake_quant
:
FakeQuantize
,
observer
:
Observer
):
if
hasattr
(
fake_quant
,
"get_qparams"
):
return
fake_quant
.
get_qparams
()
elif
observer
is
not
None
:
return
observer
.
get_qparams
()
return
None
def
get_weight_qparams
(
self
):
def
get_weight_qparams
(
self
):
r
"""
r
"""
Get weight's quantization parameters.
Get weight's quantization parameters.
"""
"""
return
self
.
_get_qparams
(
self
.
weight_fake_quant
,
self
.
weight_observer
)
return
self
.
_get_method_result
(
"get_qparams"
,
self
.
weight_fake_quant
,
self
.
weight_observer
)
def
get_activation_qparams
(
self
):
def
get_activation_qparams
(
self
):
r
"""
r
"""
Get activation's quantization parameters.
Get activation's quantization parameters.
"""
"""
return
self
.
_get_qparams
(
self
.
act_fake_quant
,
self
.
act_observer
)
return
self
.
_get_method_result
(
"get_qparams"
,
self
.
act_fake_quant
,
self
.
act_observer
)
@
classmethod
@
classmethod
@
abstractmethod
@
abstractmethod
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录