From ad9ac521ff7c029564392cff335cd21dd371b997 Mon Sep 17 00:00:00 2001 From: Megvii Engine Team Date: Tue, 18 Aug 2020 20:39:43 +0800 Subject: [PATCH] refactor(mge/imperative): remove abandoned code GitOrigin-RevId: 0178bb56848caacbbca40a76d09847ba4d0da001 --- .../megengine/core/tensor/raw_tensor/jit.py | 251 ----------------- .../core/tensor/raw_tensor/trace_exec.py | 263 ------------------ 2 files changed, 514 deletions(-) delete mode 100644 imperative/python/megengine/core/tensor/raw_tensor/jit.py delete mode 100644 imperative/python/megengine/core/tensor/raw_tensor/trace_exec.py diff --git a/imperative/python/megengine/core/tensor/raw_tensor/jit.py b/imperative/python/megengine/core/tensor/raw_tensor/jit.py deleted file mode 100644 index 091b3789d..000000000 --- a/imperative/python/megengine/core/tensor/raw_tensor/jit.py +++ /dev/null @@ -1,251 +0,0 @@ -# -*- coding: utf-8 -*- -# MegEngine is Licensed under the Apache License, Version 2.0 (the "License") -# -# Copyright (c) 2014-2020 Megvii Inc. All rights reserved. -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -import functools -import io -import weakref - - -class partial(functools.partial): - def __get__(self, instance, owner=None): - if instance is None: - return self - return functools.partial(self, instance) - - -def hook(f): - def decorator(impl): - return functools.update_wrapper(partial(f, impl), impl) - - return decorator - - -def on_input(impl, value): - tensor = impl(value) - trace = get_trace() - if trace: - var = trace.get_var(tensor) - event = InputEvent(var) - trace.append(event) - return tensor - - -def on_read_dtype(impl, self): - trace = get_trace() - if trace: - var = trace.get_var(self) - event = ReadDtypeEvent(var) - trace.append(event) - - return impl(self) - - -def on_read_device(impl, self): - trace = get_trace() - if trace: - var = trace.get_var(self) - event = ReadDeviceEvent(var) - trace.append(event) - - return impl(self) - - -def on_read_shape(impl, self): - trace = get_trace() - if trace: - var = trace.get_var(self) - event = ReadShapeEvent(var) - trace.append(event) - - return impl(self) - - -def on_read_value(impl, self): - trace = get_trace() - if trace: - var = trace.get_var(self) - event = ReadValueEvent(var) - trace.append(event) - - return impl(self) - - -def on_builtin_op(impl, op, *args): - outputs = impl(op, *args) - - trace = get_trace() - if trace: - input_vars = tuple(map(trace.get_var, args)) - output_vars = outputs and tuple(map(trace.get_var, outputs)) - event = OpEvent(op, input_vars, output_vars) - trace.append(event) - - return outputs - - -def on_del(impl, self): - trace = get_trace() - if trace: - var = trace.get_var(self) - event = DelEvent(var) - trace.append(event) - - return impl(self) - - -class Trace(list): - def __init__(self): - self._var_id = 1 - self._t2v = weakref.WeakKeyDictionary() - self._v2t = weakref.WeakValueDictionary() - - def get_var(self, x): - v = self._t2v.get(x) - if v: - return v - v = self._var_id - self._var_id += 1 - self._t2v[x] = v - self._v2t[v] = x - return v - - def __bool__(self): - return True - - def __enter__(self): - global _current_trace - if hasattr(self, "_prev_trace"): - raise RuntimeError - self._prev_trace = _current_trace - _current_trace = self - return self - - def __exit__(self, *_): - global _current_trace - if _current_trace is not self: - raise RuntimeError - _current_trace = self._prev_trace - del self._prev_trace - - -class Event: - pass - - -class InputEvent(Event): - def __init__(self, var): - self.var = var - - -class ReadEvent(Event): - def __init__(self, var): - self.var = var - - -class ReadDtypeEvent(ReadEvent): - pass - - -class ReadDeviceEvent(ReadEvent): - pass - - -class ReadShapeEvent(ReadEvent): - pass - - -class ReadValueEvent(ReadEvent): - pass - - -class OpEvent(Event): - def __init__(self, op, inputs, outputs): - self.op = op - self.inputs = inputs - self.outputs = outputs - - -class DelEvent(Event): - def __init__(self, var): - self.var = var - - -_current_trace = None - - -def get_trace() -> Trace: - global _current_trace - return _current_trace - - -def format_trace(trace): - buf = io.StringIO() - active_vars = set() - - def write(fmt, *args, **kwargs): - print(fmt.format(*args, **kwargs), file=buf) - - def init_vars(*args): - for i in args: - if i in active_vars: - continue - active_vars.add(i) - write("_{} = input()", i) - - for event in trace: - if isinstance(event, InputEvent): - init_vars(event.var) - elif isinstance(event, ReadDtypeEvent): - init_vars(event.var) - write("output(_{}.dtype)", event.var) - elif isinstance(event, ReadDeviceEvent): - init_vars(event.var) - write("output(_{}.device)", event.var) - elif isinstance(event, ReadShapeEvent): - init_vars(event.var) - write("output(_{}.shape)", event.var) - elif isinstance(event, ReadValueEvent): - init_vars(event.var) - write("output(_{}.dtype)", event.var) - elif isinstance(event, ReadValueEvent): - init_vars(event.var) - write("output(_{}.value)", event.var) - elif isinstance(event, OpEvent): - init_vars(*event.inputs) - active_vars.update(event.outputs) - ovars = ", ".join(map("_{}".format, event.outputs)) - ivars = ", ".join(map("_{}".format, event.inputs)) - if ovars: - write("{} = {}({})", ovars, repr(event.op), ivars) - else: - write("{}({})", repr(event.op), ivars) - elif isinstance(event, DelEvent): - init_vars(event.var) - write("del _{}", event.var) - else: - raise TypeError(type(event)) - - return buf.getvalue() - - -def compile_trace(trace): - trace = list(trace) - - -def static_function(f): - trace = None - - @functools.wraps(f) - def wrapper(*args, **kwargs): - nonlocal trace - if trace is None: - with Trace() as trace: - return f(*args, **kwargs) - return f(*args, **kwargs) - - return wrapper diff --git a/imperative/python/megengine/core/tensor/raw_tensor/trace_exec.py b/imperative/python/megengine/core/tensor/raw_tensor/trace_exec.py deleted file mode 100644 index d16a6ef06..000000000 --- a/imperative/python/megengine/core/tensor/raw_tensor/trace_exec.py +++ /dev/null @@ -1,263 +0,0 @@ -# -*- coding: utf-8 -*- -# MegEngine is Licensed under the Apache License, Version 2.0 (the "License") -# -# Copyright (c) 2014-2020 Megvii Inc. All rights reserved. -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -import functools -import weakref - -# Concepts -# -# * Internal tensor -# Tensor produced by the static sequence -# -# * External tensor -# Tensor not produced, but used as input, by the static sequence -# -# * Irrelevant tensor -# Tensor not present in input/output of any op -# -# * Escape -# An internal tensor is said to escape if it is still alive -# at the end of the sequence - -# JIT-ed execution -# -# 1. read attr (dtype, device, shape) -# a. internal tensor -# read out as soon as tensor is produced -# b. external or irrelevant tensor -# fallback -# -# 2. apply op -# bind external tensors in input -# -# 3. del - - -class Action: - pass - - -class ReadAttrAction(Action): - def __init__(self, var, name, getter): - self.var = var - self.name = name - self.getter = getter - - -class ReadValueAction(Action): - def __init__(self, var, getter): - self.var = var - self.getter = getter - - -class GetTensorAction(Action): - def __init__(self, var, getter): - self.var = var - self.getter = getter - - -class OpAction(Action): - def __init__(self, op, inputs, outputs, input_receivers): - self.op = op - self.inputs = inputs - self.outputs = outputs - self.input_receivers = input_receivers - - -class TensorAttr: - def __init__(self): - self.shape = None - self.dtype = None - self.device = None - - -class Bailout(Exception): - pass - - -class Fallback(Exception): - pass - - -def handle_bailout_fallback_finalize(f): - @functools.wraps(f) - def wrapper(self, impl, *args, **kwargs): - try: - return f(*args, **kwargs) - except Bailout: - self.bailout() - except Fallback: - pass - finally: - if self.pc == len(self): - self.finalize() - return impl(*args, **kwargs) - - return wrapper - - -class ExecTrajectory(list): - def __init__(self): - super().__init__() - self.reset() - - def __bool__(self): - return True - - def __enter__(self): - global _current_trajectory - if hasattr(self, "_prev_trajectory"): - raise RuntimeError - self._prev_trajectory = _current_trajectory - _current_trajectory = self - self._exited = False - return self - - def __exit__(self, *exc_info): - # cleanup should be done at completion, - # which is before exiting context manager - assert self._exited == (exc_info == (None, None, None)) - if not self._exited: - assert self.pc < len(self) - self.bailout() - - def _exit(self): - # clean up self and global varaible - assert not self._exited - self.reset() - - global _current_trajectory - if _current_trajectory is not self: - raise RuntimeError - _current_trajectory = self._prev_trajectory - del self._prev_trajectory - - def reset(self): - self._exited = True - self.pc = 0 - self.attr_cache = weakref.WeakKeyDictionary() - - ### Internal and External Tensor ### - # internal tensors are those produced by us - # external tensors are those received from outside - # during JIT-ed execution, internal tensors are just placeholders. - # var_to_tensor is the binding table for all tensors - self.var_to_tensor = {} # var -> weakref[tensor] - # tensor_to_var is the reverse binding table for internal tensors - # note that external tensors could map to >1 vars. - self.tensor_to_var = weakref.WeakKeyDictionary() - # internal tensor will be materialized if its .data is accessed from outside - # after being meterialized, an intern tensor is much like an external tensor - - def finalize(self): - assert self.pc == len(self) - self._exit() - - def bailout(self): - self._exit() - raise NotImplementedError - - def next_action(self): - assert not self._exited - assert self.pc < len(self) - return self[self.pc] - - @handle_bailout_fallback_finalize - def read_attr(self, tensor, name): - attrs = self.attr_cache.setdefault(tensor, TensorAttr()) - value = getattr(attrs, name, None) - if value is None: - action = self.next_action() - if not isinstance(action, ReadAttrAction): - raise Bailout - if name != action.name: - raise Bailout - value = action.getter() - setattr(attrs, name, value) - return value - - @handle_bailout_fallback_finalize - def read_value(self, impl, tensor): - # possibilities: - # 1. internal tensor - # 2. external tensor - # 3. irrelevant tensor (not an input / output of any op) - if tensor not in self.tensor_to_var: - raise Fallback - assert tensor._data is None - action = self.next_action() - if not isinstance(action, ReadValueAction): - raise Bailout - return action.getter() - - @handle_bailout_fallback_finalize - def apply_op(self, impl, op, *args): - from . import RawTensor - - action = self.next_action() - if not isinstance(action, OpAction): - raise Bailout - if len(args) != len(action.inputs): - raise Bailout - assert len(actions.inputs) == len(action.input_receivers) - - for v, t, r in zip(action.inputs, args, action.input_receivers): - if v in self.var_to_tensor: - assert r is None - if t is not self.var_to_tensor[v](): - raise Bailout - else: - # NOTE: not checking for aliasing (>=2 vars map to 1 tensor) - # the static execution backend must handle this - self.var_to_tensor[v] = weakref.ref(t) - r(t) - - outputs = [] - for v in action.outputs: - assert v not in self.var_to_tensor - t = RawTensor() - t._data_getter = functools.partial(self.get_data, v) - outputs.append(t) - self.var_to_tensor[v] = weakref.ref(t) - - return tuple(outputs) - - def get_data(self, var): - tensor = self.var_to_tensor[var]() - assert tensor is not None - assert tensor._data is None - assert tensor in self.tensor_to_var - action = self.next_action() - if not isinstance(action, GetTensorAction): - self.bailout() - elif action.var != var: - self.bailout() - else: - tensor._data = action.getter() - del tensor._data_getter - del self.tensor_to_var[tensor] - assert "_data_getter" not in tensor.__dict__ - return tensor._data_getter() - - -_current_trajectory = None - - -def get_trajectory(): - return _current_trajectory - - -def compile_trace(trace): - from .jit import ReadDTypeEvent, ReadDeviceEvent, ReadShapeEvent, OpEvent, DelEvent - - traj = ExecutionTrajectory() - active_vars = set() - - for event in trace: - if isinstance(event, ReadDTypeEvent): - traj.append(ReadAttrAction()) -- GitLab