Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
MegEngine 天元
MegEngine
提交
e027dcbf
MegEngine
项目概览
MegEngine 天元
/
MegEngine
1 年多 前同步成功
通知
403
Star
4705
Fork
582
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
MegEngine
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
提交
e027dcbf
编写于
9月 01, 2020
作者:
M
Megvii Engine Team
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
chore(mge): improve symbolic tracing value/shape inference
GitOrigin-RevId: d1a6baac741726604c799752b19d2ed90e399639
上级
e6e29748
变更
1
隐藏空白更改
内联
并排
Showing
1 changed file
with
25 addition
and
3 deletion
+25
-3
imperative/python/megengine/jit/tracing.py
imperative/python/megengine/jit/tracing.py
+25
-3
未找到文件。
imperative/python/megengine/jit/tracing.py
浏览文件 @
e027dcbf
...
...
@@ -186,6 +186,9 @@ class trace:
self
.
_seq
.
append
((
op
,
tuple
(
ihandles
),
tuple
(
ohandles
)))
self
.
_active_tensors
.
update
(
outputs
)
def
_record_const
(
self
,
op
,
outputs
):
pass
@
contextlib
.
contextmanager
def
_setup
(
self
):
global
active_trace
...
...
@@ -195,8 +198,10 @@ class trace:
if
self
.
_untraced
:
apply
.
enable
(
apply_with_tracing
)
apply
.
enable
(
apply_const_with_tracing
)
if
self
.
_symbolic
:
apply
.
enable
(
apply_symbolic_mode
)
apply
.
enable
(
apply_const_symbolic_mode
)
self
.
_lazy_eval_graph
=
G
.
Graph
()
else
:
apply
.
enable
(
apply_compiled_mode
)
...
...
@@ -239,7 +244,9 @@ class trace:
self
.
_pc
=
0
apply
.
disable
(
apply_with_tracing
)
apply
.
disable
(
apply_const_with_tracing
)
apply
.
disable
(
apply_symbolic_mode
)
apply
.
disable
(
apply_const_symbolic_mode
)
apply
.
disable
(
apply_compiled_mode
)
active_trace
=
None
...
...
@@ -477,6 +484,16 @@ def apply_symbolic_mode(op: OpDef, *args: RawTensor):
apply
.
disable
(
apply_symbolic_mode
)
@
apply
.
register
()
def
apply_const_symbolic_mode
(
op
:
Const
,
*
args
:
RawTensor
):
graph
=
active_trace
.
_lazy_eval_graph
ret
=
LazyEvalTensor
(
graph
.
make_const
(
op
.
value
,
dtype
=
op
.
dtype
,
device
=
op
.
device
))
return
(
ret
,)
apply
.
disable
(
apply_const_symbolic_mode
)
@
apply
.
register
()
def
apply_compiled_mode
(
op
:
OpDef
,
*
args
:
RawTensor
):
if
skip_tracing
:
...
...
@@ -502,9 +519,14 @@ def apply_with_tracing(op: OpDef, *args: RawTensor):
apply
.
disable
(
apply_with_tracing
)
# @apply.register()
# def _(op: Const, *args: RawTensor):
# return active_trace._apply_const(op, args)
@
apply
.
register
()
def
apply_const_with_tracing
(
op
:
Const
,
*
args
:
RawTensor
):
outputs
=
apply
.
super
(
op
,
*
args
)
active_trace
.
_record_const
(
op
,
outputs
)
return
outputs
apply
.
disable
(
apply_const_with_tracing
)
class
BrokenRawTensor
(
RawTensor
):
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录