提交 ad9ac521 编写于 作者: M Megvii Engine Team

refactor(mge/imperative): remove abandoned code

GitOrigin-RevId: 0178bb56848caacbbca40a76d09847ba4d0da001
上级 03320a05
# -*- coding: utf-8 -*-
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import functools
import io
import weakref
class partial(functools.partial):
def __get__(self, instance, owner=None):
if instance is None:
return self
return functools.partial(self, instance)
def hook(f):
def decorator(impl):
return functools.update_wrapper(partial(f, impl), impl)
return decorator
def on_input(impl, value):
tensor = impl(value)
trace = get_trace()
if trace:
var = trace.get_var(tensor)
event = InputEvent(var)
trace.append(event)
return tensor
def on_read_dtype(impl, self):
trace = get_trace()
if trace:
var = trace.get_var(self)
event = ReadDtypeEvent(var)
trace.append(event)
return impl(self)
def on_read_device(impl, self):
trace = get_trace()
if trace:
var = trace.get_var(self)
event = ReadDeviceEvent(var)
trace.append(event)
return impl(self)
def on_read_shape(impl, self):
trace = get_trace()
if trace:
var = trace.get_var(self)
event = ReadShapeEvent(var)
trace.append(event)
return impl(self)
def on_read_value(impl, self):
trace = get_trace()
if trace:
var = trace.get_var(self)
event = ReadValueEvent(var)
trace.append(event)
return impl(self)
def on_builtin_op(impl, op, *args):
outputs = impl(op, *args)
trace = get_trace()
if trace:
input_vars = tuple(map(trace.get_var, args))
output_vars = outputs and tuple(map(trace.get_var, outputs))
event = OpEvent(op, input_vars, output_vars)
trace.append(event)
return outputs
def on_del(impl, self):
trace = get_trace()
if trace:
var = trace.get_var(self)
event = DelEvent(var)
trace.append(event)
return impl(self)
class Trace(list):
def __init__(self):
self._var_id = 1
self._t2v = weakref.WeakKeyDictionary()
self._v2t = weakref.WeakValueDictionary()
def get_var(self, x):
v = self._t2v.get(x)
if v:
return v
v = self._var_id
self._var_id += 1
self._t2v[x] = v
self._v2t[v] = x
return v
def __bool__(self):
return True
def __enter__(self):
global _current_trace
if hasattr(self, "_prev_trace"):
raise RuntimeError
self._prev_trace = _current_trace
_current_trace = self
return self
def __exit__(self, *_):
global _current_trace
if _current_trace is not self:
raise RuntimeError
_current_trace = self._prev_trace
del self._prev_trace
class Event:
pass
class InputEvent(Event):
def __init__(self, var):
self.var = var
class ReadEvent(Event):
def __init__(self, var):
self.var = var
class ReadDtypeEvent(ReadEvent):
pass
class ReadDeviceEvent(ReadEvent):
pass
class ReadShapeEvent(ReadEvent):
pass
class ReadValueEvent(ReadEvent):
pass
class OpEvent(Event):
def __init__(self, op, inputs, outputs):
self.op = op
self.inputs = inputs
self.outputs = outputs
class DelEvent(Event):
def __init__(self, var):
self.var = var
_current_trace = None
def get_trace() -> Trace:
global _current_trace
return _current_trace
def format_trace(trace):
buf = io.StringIO()
active_vars = set()
def write(fmt, *args, **kwargs):
print(fmt.format(*args, **kwargs), file=buf)
def init_vars(*args):
for i in args:
if i in active_vars:
continue
active_vars.add(i)
write("_{} = input()", i)
for event in trace:
if isinstance(event, InputEvent):
init_vars(event.var)
elif isinstance(event, ReadDtypeEvent):
init_vars(event.var)
write("output(_{}.dtype)", event.var)
elif isinstance(event, ReadDeviceEvent):
init_vars(event.var)
write("output(_{}.device)", event.var)
elif isinstance(event, ReadShapeEvent):
init_vars(event.var)
write("output(_{}.shape)", event.var)
elif isinstance(event, ReadValueEvent):
init_vars(event.var)
write("output(_{}.dtype)", event.var)
elif isinstance(event, ReadValueEvent):
init_vars(event.var)
write("output(_{}.value)", event.var)
elif isinstance(event, OpEvent):
init_vars(*event.inputs)
active_vars.update(event.outputs)
ovars = ", ".join(map("_{}".format, event.outputs))
ivars = ", ".join(map("_{}".format, event.inputs))
if ovars:
write("{} = {}({})", ovars, repr(event.op), ivars)
else:
write("{}({})", repr(event.op), ivars)
elif isinstance(event, DelEvent):
init_vars(event.var)
write("del _{}", event.var)
else:
raise TypeError(type(event))
return buf.getvalue()
def compile_trace(trace):
trace = list(trace)
def static_function(f):
trace = None
@functools.wraps(f)
def wrapper(*args, **kwargs):
nonlocal trace
if trace is None:
with Trace() as trace:
return f(*args, **kwargs)
return f(*args, **kwargs)
return wrapper
# -*- coding: utf-8 -*-
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import functools
import weakref
# Concepts
#
# * Internal tensor
# Tensor produced by the static sequence
#
# * External tensor
# Tensor not produced, but used as input, by the static sequence
#
# * Irrelevant tensor
# Tensor not present in input/output of any op
#
# * Escape
# An internal tensor is said to escape if it is still alive
# at the end of the sequence
# JIT-ed execution
#
# 1. read attr (dtype, device, shape)
# a. internal tensor
# read out as soon as tensor is produced
# b. external or irrelevant tensor
# fallback
#
# 2. apply op
# bind external tensors in input
#
# 3. del
class Action:
pass
class ReadAttrAction(Action):
def __init__(self, var, name, getter):
self.var = var
self.name = name
self.getter = getter
class ReadValueAction(Action):
def __init__(self, var, getter):
self.var = var
self.getter = getter
class GetTensorAction(Action):
def __init__(self, var, getter):
self.var = var
self.getter = getter
class OpAction(Action):
def __init__(self, op, inputs, outputs, input_receivers):
self.op = op
self.inputs = inputs
self.outputs = outputs
self.input_receivers = input_receivers
class TensorAttr:
def __init__(self):
self.shape = None
self.dtype = None
self.device = None
class Bailout(Exception):
pass
class Fallback(Exception):
pass
def handle_bailout_fallback_finalize(f):
@functools.wraps(f)
def wrapper(self, impl, *args, **kwargs):
try:
return f(*args, **kwargs)
except Bailout:
self.bailout()
except Fallback:
pass
finally:
if self.pc == len(self):
self.finalize()
return impl(*args, **kwargs)
return wrapper
class ExecTrajectory(list):
def __init__(self):
super().__init__()
self.reset()
def __bool__(self):
return True
def __enter__(self):
global _current_trajectory
if hasattr(self, "_prev_trajectory"):
raise RuntimeError
self._prev_trajectory = _current_trajectory
_current_trajectory = self
self._exited = False
return self
def __exit__(self, *exc_info):
# cleanup should be done at completion,
# which is before exiting context manager
assert self._exited == (exc_info == (None, None, None))
if not self._exited:
assert self.pc < len(self)
self.bailout()
def _exit(self):
# clean up self and global varaible
assert not self._exited
self.reset()
global _current_trajectory
if _current_trajectory is not self:
raise RuntimeError
_current_trajectory = self._prev_trajectory
del self._prev_trajectory
def reset(self):
self._exited = True
self.pc = 0
self.attr_cache = weakref.WeakKeyDictionary()
### Internal and External Tensor ###
# internal tensors are those produced by us
# external tensors are those received from outside
# during JIT-ed execution, internal tensors are just placeholders.
# var_to_tensor is the binding table for all tensors
self.var_to_tensor = {} # var -> weakref[tensor]
# tensor_to_var is the reverse binding table for internal tensors
# note that external tensors could map to >1 vars.
self.tensor_to_var = weakref.WeakKeyDictionary()
# internal tensor will be materialized if its .data is accessed from outside
# after being meterialized, an intern tensor is much like an external tensor
def finalize(self):
assert self.pc == len(self)
self._exit()
def bailout(self):
self._exit()
raise NotImplementedError
def next_action(self):
assert not self._exited
assert self.pc < len(self)
return self[self.pc]
@handle_bailout_fallback_finalize
def read_attr(self, tensor, name):
attrs = self.attr_cache.setdefault(tensor, TensorAttr())
value = getattr(attrs, name, None)
if value is None:
action = self.next_action()
if not isinstance(action, ReadAttrAction):
raise Bailout
if name != action.name:
raise Bailout
value = action.getter()
setattr(attrs, name, value)
return value
@handle_bailout_fallback_finalize
def read_value(self, impl, tensor):
# possibilities:
# 1. internal tensor
# 2. external tensor
# 3. irrelevant tensor (not an input / output of any op)
if tensor not in self.tensor_to_var:
raise Fallback
assert tensor._data is None
action = self.next_action()
if not isinstance(action, ReadValueAction):
raise Bailout
return action.getter()
@handle_bailout_fallback_finalize
def apply_op(self, impl, op, *args):
from . import RawTensor
action = self.next_action()
if not isinstance(action, OpAction):
raise Bailout
if len(args) != len(action.inputs):
raise Bailout
assert len(actions.inputs) == len(action.input_receivers)
for v, t, r in zip(action.inputs, args, action.input_receivers):
if v in self.var_to_tensor:
assert r is None
if t is not self.var_to_tensor[v]():
raise Bailout
else:
# NOTE: not checking for aliasing (>=2 vars map to 1 tensor)
# the static execution backend must handle this
self.var_to_tensor[v] = weakref.ref(t)
r(t)
outputs = []
for v in action.outputs:
assert v not in self.var_to_tensor
t = RawTensor()
t._data_getter = functools.partial(self.get_data, v)
outputs.append(t)
self.var_to_tensor[v] = weakref.ref(t)
return tuple(outputs)
def get_data(self, var):
tensor = self.var_to_tensor[var]()
assert tensor is not None
assert tensor._data is None
assert tensor in self.tensor_to_var
action = self.next_action()
if not isinstance(action, GetTensorAction):
self.bailout()
elif action.var != var:
self.bailout()
else:
tensor._data = action.getter()
del tensor._data_getter
del self.tensor_to_var[tensor]
assert "_data_getter" not in tensor.__dict__
return tensor._data_getter()
_current_trajectory = None
def get_trajectory():
return _current_trajectory
def compile_trace(trace):
from .jit import ReadDTypeEvent, ReadDeviceEvent, ReadShapeEvent, OpEvent, DelEvent
traj = ExecutionTrajectory()
active_vars = set()
for event in trace:
if isinstance(event, ReadDTypeEvent):
traj.append(ReadAttrAction())
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册