Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
MegEngine 天元
MegEngine
提交
ad9ac521
MegEngine
项目概览
MegEngine 天元
/
MegEngine
1 年多 前同步成功
通知
403
Star
4705
Fork
582
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
MegEngine
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
提交
ad9ac521
编写于
8月 18, 2020
作者:
M
Megvii Engine Team
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
refactor(mge/imperative): remove abandoned code
GitOrigin-RevId: 0178bb56848caacbbca40a76d09847ba4d0da001
上级
03320a05
变更
2
隐藏空白更改
内联
并排
Showing
2 changed file
with
0 addition
and
514 deletion
+0
-514
imperative/python/megengine/core/tensor/raw_tensor/jit.py
imperative/python/megengine/core/tensor/raw_tensor/jit.py
+0
-251
imperative/python/megengine/core/tensor/raw_tensor/trace_exec.py
...ive/python/megengine/core/tensor/raw_tensor/trace_exec.py
+0
-263
未找到文件。
imperative/python/megengine/core/tensor/raw_tensor/jit.py
已删除
100644 → 0
浏览文件 @
03320a05
# -*- coding: utf-8 -*-
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import
functools
import
io
import
weakref
class
partial
(
functools
.
partial
):
def
__get__
(
self
,
instance
,
owner
=
None
):
if
instance
is
None
:
return
self
return
functools
.
partial
(
self
,
instance
)
def
hook
(
f
):
def
decorator
(
impl
):
return
functools
.
update_wrapper
(
partial
(
f
,
impl
),
impl
)
return
decorator
def
on_input
(
impl
,
value
):
tensor
=
impl
(
value
)
trace
=
get_trace
()
if
trace
:
var
=
trace
.
get_var
(
tensor
)
event
=
InputEvent
(
var
)
trace
.
append
(
event
)
return
tensor
def
on_read_dtype
(
impl
,
self
):
trace
=
get_trace
()
if
trace
:
var
=
trace
.
get_var
(
self
)
event
=
ReadDtypeEvent
(
var
)
trace
.
append
(
event
)
return
impl
(
self
)
def
on_read_device
(
impl
,
self
):
trace
=
get_trace
()
if
trace
:
var
=
trace
.
get_var
(
self
)
event
=
ReadDeviceEvent
(
var
)
trace
.
append
(
event
)
return
impl
(
self
)
def
on_read_shape
(
impl
,
self
):
trace
=
get_trace
()
if
trace
:
var
=
trace
.
get_var
(
self
)
event
=
ReadShapeEvent
(
var
)
trace
.
append
(
event
)
return
impl
(
self
)
def
on_read_value
(
impl
,
self
):
trace
=
get_trace
()
if
trace
:
var
=
trace
.
get_var
(
self
)
event
=
ReadValueEvent
(
var
)
trace
.
append
(
event
)
return
impl
(
self
)
def
on_builtin_op
(
impl
,
op
,
*
args
):
outputs
=
impl
(
op
,
*
args
)
trace
=
get_trace
()
if
trace
:
input_vars
=
tuple
(
map
(
trace
.
get_var
,
args
))
output_vars
=
outputs
and
tuple
(
map
(
trace
.
get_var
,
outputs
))
event
=
OpEvent
(
op
,
input_vars
,
output_vars
)
trace
.
append
(
event
)
return
outputs
def
on_del
(
impl
,
self
):
trace
=
get_trace
()
if
trace
:
var
=
trace
.
get_var
(
self
)
event
=
DelEvent
(
var
)
trace
.
append
(
event
)
return
impl
(
self
)
class
Trace
(
list
):
def
__init__
(
self
):
self
.
_var_id
=
1
self
.
_t2v
=
weakref
.
WeakKeyDictionary
()
self
.
_v2t
=
weakref
.
WeakValueDictionary
()
def
get_var
(
self
,
x
):
v
=
self
.
_t2v
.
get
(
x
)
if
v
:
return
v
v
=
self
.
_var_id
self
.
_var_id
+=
1
self
.
_t2v
[
x
]
=
v
self
.
_v2t
[
v
]
=
x
return
v
def
__bool__
(
self
):
return
True
def
__enter__
(
self
):
global
_current_trace
if
hasattr
(
self
,
"_prev_trace"
):
raise
RuntimeError
self
.
_prev_trace
=
_current_trace
_current_trace
=
self
return
self
def
__exit__
(
self
,
*
_
):
global
_current_trace
if
_current_trace
is
not
self
:
raise
RuntimeError
_current_trace
=
self
.
_prev_trace
del
self
.
_prev_trace
class
Event
:
pass
class
InputEvent
(
Event
):
def
__init__
(
self
,
var
):
self
.
var
=
var
class
ReadEvent
(
Event
):
def
__init__
(
self
,
var
):
self
.
var
=
var
class
ReadDtypeEvent
(
ReadEvent
):
pass
class
ReadDeviceEvent
(
ReadEvent
):
pass
class
ReadShapeEvent
(
ReadEvent
):
pass
class
ReadValueEvent
(
ReadEvent
):
pass
class
OpEvent
(
Event
):
def
__init__
(
self
,
op
,
inputs
,
outputs
):
self
.
op
=
op
self
.
inputs
=
inputs
self
.
outputs
=
outputs
class
DelEvent
(
Event
):
def
__init__
(
self
,
var
):
self
.
var
=
var
_current_trace
=
None
def
get_trace
()
->
Trace
:
global
_current_trace
return
_current_trace
def
format_trace
(
trace
):
buf
=
io
.
StringIO
()
active_vars
=
set
()
def
write
(
fmt
,
*
args
,
**
kwargs
):
print
(
fmt
.
format
(
*
args
,
**
kwargs
),
file
=
buf
)
def
init_vars
(
*
args
):
for
i
in
args
:
if
i
in
active_vars
:
continue
active_vars
.
add
(
i
)
write
(
"_{} = input()"
,
i
)
for
event
in
trace
:
if
isinstance
(
event
,
InputEvent
):
init_vars
(
event
.
var
)
elif
isinstance
(
event
,
ReadDtypeEvent
):
init_vars
(
event
.
var
)
write
(
"output(_{}.dtype)"
,
event
.
var
)
elif
isinstance
(
event
,
ReadDeviceEvent
):
init_vars
(
event
.
var
)
write
(
"output(_{}.device)"
,
event
.
var
)
elif
isinstance
(
event
,
ReadShapeEvent
):
init_vars
(
event
.
var
)
write
(
"output(_{}.shape)"
,
event
.
var
)
elif
isinstance
(
event
,
ReadValueEvent
):
init_vars
(
event
.
var
)
write
(
"output(_{}.dtype)"
,
event
.
var
)
elif
isinstance
(
event
,
ReadValueEvent
):
init_vars
(
event
.
var
)
write
(
"output(_{}.value)"
,
event
.
var
)
elif
isinstance
(
event
,
OpEvent
):
init_vars
(
*
event
.
inputs
)
active_vars
.
update
(
event
.
outputs
)
ovars
=
", "
.
join
(
map
(
"_{}"
.
format
,
event
.
outputs
))
ivars
=
", "
.
join
(
map
(
"_{}"
.
format
,
event
.
inputs
))
if
ovars
:
write
(
"{} = {}({})"
,
ovars
,
repr
(
event
.
op
),
ivars
)
else
:
write
(
"{}({})"
,
repr
(
event
.
op
),
ivars
)
elif
isinstance
(
event
,
DelEvent
):
init_vars
(
event
.
var
)
write
(
"del _{}"
,
event
.
var
)
else
:
raise
TypeError
(
type
(
event
))
return
buf
.
getvalue
()
def
compile_trace
(
trace
):
trace
=
list
(
trace
)
def
static_function
(
f
):
trace
=
None
@
functools
.
wraps
(
f
)
def
wrapper
(
*
args
,
**
kwargs
):
nonlocal
trace
if
trace
is
None
:
with
Trace
()
as
trace
:
return
f
(
*
args
,
**
kwargs
)
return
f
(
*
args
,
**
kwargs
)
return
wrapper
imperative/python/megengine/core/tensor/raw_tensor/trace_exec.py
已删除
100644 → 0
浏览文件 @
03320a05
# -*- coding: utf-8 -*-
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import
functools
import
weakref
# Concepts
#
# * Internal tensor
# Tensor produced by the static sequence
#
# * External tensor
# Tensor not produced, but used as input, by the static sequence
#
# * Irrelevant tensor
# Tensor not present in input/output of any op
#
# * Escape
# An internal tensor is said to escape if it is still alive
# at the end of the sequence
# JIT-ed execution
#
# 1. read attr (dtype, device, shape)
# a. internal tensor
# read out as soon as tensor is produced
# b. external or irrelevant tensor
# fallback
#
# 2. apply op
# bind external tensors in input
#
# 3. del
class
Action
:
pass
class
ReadAttrAction
(
Action
):
def
__init__
(
self
,
var
,
name
,
getter
):
self
.
var
=
var
self
.
name
=
name
self
.
getter
=
getter
class
ReadValueAction
(
Action
):
def
__init__
(
self
,
var
,
getter
):
self
.
var
=
var
self
.
getter
=
getter
class
GetTensorAction
(
Action
):
def
__init__
(
self
,
var
,
getter
):
self
.
var
=
var
self
.
getter
=
getter
class
OpAction
(
Action
):
def
__init__
(
self
,
op
,
inputs
,
outputs
,
input_receivers
):
self
.
op
=
op
self
.
inputs
=
inputs
self
.
outputs
=
outputs
self
.
input_receivers
=
input_receivers
class
TensorAttr
:
def
__init__
(
self
):
self
.
shape
=
None
self
.
dtype
=
None
self
.
device
=
None
class
Bailout
(
Exception
):
pass
class
Fallback
(
Exception
):
pass
def
handle_bailout_fallback_finalize
(
f
):
@
functools
.
wraps
(
f
)
def
wrapper
(
self
,
impl
,
*
args
,
**
kwargs
):
try
:
return
f
(
*
args
,
**
kwargs
)
except
Bailout
:
self
.
bailout
()
except
Fallback
:
pass
finally
:
if
self
.
pc
==
len
(
self
):
self
.
finalize
()
return
impl
(
*
args
,
**
kwargs
)
return
wrapper
class
ExecTrajectory
(
list
):
def
__init__
(
self
):
super
().
__init__
()
self
.
reset
()
def
__bool__
(
self
):
return
True
def
__enter__
(
self
):
global
_current_trajectory
if
hasattr
(
self
,
"_prev_trajectory"
):
raise
RuntimeError
self
.
_prev_trajectory
=
_current_trajectory
_current_trajectory
=
self
self
.
_exited
=
False
return
self
def
__exit__
(
self
,
*
exc_info
):
# cleanup should be done at completion,
# which is before exiting context manager
assert
self
.
_exited
==
(
exc_info
==
(
None
,
None
,
None
))
if
not
self
.
_exited
:
assert
self
.
pc
<
len
(
self
)
self
.
bailout
()
def
_exit
(
self
):
# clean up self and global varaible
assert
not
self
.
_exited
self
.
reset
()
global
_current_trajectory
if
_current_trajectory
is
not
self
:
raise
RuntimeError
_current_trajectory
=
self
.
_prev_trajectory
del
self
.
_prev_trajectory
def
reset
(
self
):
self
.
_exited
=
True
self
.
pc
=
0
self
.
attr_cache
=
weakref
.
WeakKeyDictionary
()
### Internal and External Tensor ###
# internal tensors are those produced by us
# external tensors are those received from outside
# during JIT-ed execution, internal tensors are just placeholders.
# var_to_tensor is the binding table for all tensors
self
.
var_to_tensor
=
{}
# var -> weakref[tensor]
# tensor_to_var is the reverse binding table for internal tensors
# note that external tensors could map to >1 vars.
self
.
tensor_to_var
=
weakref
.
WeakKeyDictionary
()
# internal tensor will be materialized if its .data is accessed from outside
# after being meterialized, an intern tensor is much like an external tensor
def
finalize
(
self
):
assert
self
.
pc
==
len
(
self
)
self
.
_exit
()
def
bailout
(
self
):
self
.
_exit
()
raise
NotImplementedError
def
next_action
(
self
):
assert
not
self
.
_exited
assert
self
.
pc
<
len
(
self
)
return
self
[
self
.
pc
]
@
handle_bailout_fallback_finalize
def
read_attr
(
self
,
tensor
,
name
):
attrs
=
self
.
attr_cache
.
setdefault
(
tensor
,
TensorAttr
())
value
=
getattr
(
attrs
,
name
,
None
)
if
value
is
None
:
action
=
self
.
next_action
()
if
not
isinstance
(
action
,
ReadAttrAction
):
raise
Bailout
if
name
!=
action
.
name
:
raise
Bailout
value
=
action
.
getter
()
setattr
(
attrs
,
name
,
value
)
return
value
@
handle_bailout_fallback_finalize
def
read_value
(
self
,
impl
,
tensor
):
# possibilities:
# 1. internal tensor
# 2. external tensor
# 3. irrelevant tensor (not an input / output of any op)
if
tensor
not
in
self
.
tensor_to_var
:
raise
Fallback
assert
tensor
.
_data
is
None
action
=
self
.
next_action
()
if
not
isinstance
(
action
,
ReadValueAction
):
raise
Bailout
return
action
.
getter
()
@
handle_bailout_fallback_finalize
def
apply_op
(
self
,
impl
,
op
,
*
args
):
from
.
import
RawTensor
action
=
self
.
next_action
()
if
not
isinstance
(
action
,
OpAction
):
raise
Bailout
if
len
(
args
)
!=
len
(
action
.
inputs
):
raise
Bailout
assert
len
(
actions
.
inputs
)
==
len
(
action
.
input_receivers
)
for
v
,
t
,
r
in
zip
(
action
.
inputs
,
args
,
action
.
input_receivers
):
if
v
in
self
.
var_to_tensor
:
assert
r
is
None
if
t
is
not
self
.
var_to_tensor
[
v
]():
raise
Bailout
else
:
# NOTE: not checking for aliasing (>=2 vars map to 1 tensor)
# the static execution backend must handle this
self
.
var_to_tensor
[
v
]
=
weakref
.
ref
(
t
)
r
(
t
)
outputs
=
[]
for
v
in
action
.
outputs
:
assert
v
not
in
self
.
var_to_tensor
t
=
RawTensor
()
t
.
_data_getter
=
functools
.
partial
(
self
.
get_data
,
v
)
outputs
.
append
(
t
)
self
.
var_to_tensor
[
v
]
=
weakref
.
ref
(
t
)
return
tuple
(
outputs
)
def
get_data
(
self
,
var
):
tensor
=
self
.
var_to_tensor
[
var
]()
assert
tensor
is
not
None
assert
tensor
.
_data
is
None
assert
tensor
in
self
.
tensor_to_var
action
=
self
.
next_action
()
if
not
isinstance
(
action
,
GetTensorAction
):
self
.
bailout
()
elif
action
.
var
!=
var
:
self
.
bailout
()
else
:
tensor
.
_data
=
action
.
getter
()
del
tensor
.
_data_getter
del
self
.
tensor_to_var
[
tensor
]
assert
"_data_getter"
not
in
tensor
.
__dict__
return
tensor
.
_data_getter
()
_current_trajectory
=
None
def
get_trajectory
():
return
_current_trajectory
def
compile_trace
(
trace
):
from
.jit
import
ReadDTypeEvent
,
ReadDeviceEvent
,
ReadShapeEvent
,
OpEvent
,
DelEvent
traj
=
ExecutionTrajectory
()
active_vars
=
set
()
for
event
in
trace
:
if
isinstance
(
event
,
ReadDTypeEvent
):
traj
.
append
(
ReadAttrAction
())
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录