diff --git a/imperative/python/megengine/core/_trace_option.py b/imperative/python/megengine/core/_trace_option.py index d5f65c34ee2adda4d66b2999f922baf8fb6b6079..638c142a12249cc9b7381b3c378d5b01f5b5ff9e 100644 --- a/imperative/python/megengine/core/_trace_option.py +++ b/imperative/python/megengine/core/_trace_option.py @@ -8,6 +8,8 @@ _use_symbolic_shape = False if os.environ.get("MEGENGINE_USE_SYMBOLIC_SHAPE"): _use_symbolic_shape = True +_use_xla_backend = False + def use_symbolic_shape() -> bool: r"""Returns whether tensor.shape returns a tensor instead of a tuple""" @@ -22,4 +24,15 @@ def set_symbolic_shape(option: bool): return _org +def use_xla_backend() -> bool: + return _use_xla_backend + + +def set_use_xla_backend(option: bool) -> bool: + global _use_xla_backend + _org = _use_xla_backend + _use_xla_backend = option + return _org + + set_cpp_use_symbolic_shape(use_symbolic_shape) diff --git a/imperative/python/megengine/jit/__init__.py b/imperative/python/megengine/jit/__init__.py index 56666ae08521b6b055228135be8b52797f48dd1e..e3ff5740ceec57e1e8780b5d0909e58a103294af 100644 --- a/imperative/python/megengine/jit/__init__.py +++ b/imperative/python/megengine/jit/__init__.py @@ -1,5 +1,7 @@ # -*- coding: utf-8 -*- from .dtr_config import DTRConfig from .graph_opt_config import GraphOptimizationConfig +from .partial_tracing import partial_trace from .sublinear_memory_config import SublinearMemoryConfig from .tracing import TraceError, exclude_from_trace, trace +from .xla_backend import xla_trace diff --git a/imperative/python/megengine/jit/partial_tracing.py b/imperative/python/megengine/jit/partial_tracing.py new file mode 100644 index 0000000000000000000000000000000000000000..1c5240cec30da647bd95b3453431c2cd58ea9a14 --- /dev/null +++ b/imperative/python/megengine/jit/partial_tracing.py @@ -0,0 +1,224 @@ +from collections import OrderedDict +from typing import Sequence + +from ..core._imperative_rt.core2 import add_backward_callback as _add_backward_callback +from ..core._imperative_rt.core2 import get_grad_slot, get_handle_id +from ..tensor import Tensor +from .tracing import trace +from .xla_backend import xla_trace + + +def _process_fwd_bwd_trace_result(fwd, bwd, inp_grad_map, out_grad_map): + # partial_trace will record op sequences for forward/backward respectively, and get two TraceResult objects after tracing. + # But the inputs/outputs of backward graph are unknown. This function will determine the inputs and outputs of the backward graph + # var.handle_id is id of value ref. It's used to find the tensors used in both forward and backward calculation. + # inp_grad_map, key: handle id of forward inputs, value: handle id of grads of forward inputs. + # out_grad_map, key: handle id of foward outputs, value: handle id of grads of forward outputs. + fwd_features = set([t.handle_id for t in fwd._trace.vars]) + bwd_features = set([t.handle_id for t in bwd._trace.vars]) + keep_vars = fwd_features.intersection( + bwd_features + ) # some intermediate vars produced by forward, and will be used in backward. + current = max(fwd.out_list) + 1 + saved_feature_map = OrderedDict() + saved_featrues = [] + # mark keep_vars as forward outputs + for var in fwd._trace.vars: + if ( + var.handle_id in keep_vars + and var.data_required + and len(var.out_mark) == 0 + and var.kind not in ["const", "external"] + ): + keep_vars.remove(var.handle_id) + fwd._trace.mark_output(current, var.id) + saved_feature_map[var.handle_id] = current + saved_featrues.append(current) + current += 1 + fwd.keeped_activation = saved_featrues + + bwd_inp_idx = 0 + bwd_out_idx = 0 + bwd_dys = [] + bwd_inps = [-1] * len(saved_feature_map) + saved_feature_handle_id = list(saved_feature_map.keys()) + dy_ids = list(out_grad_map.values()) # handle_id of grad of forward output + inp_grad_ids = list(inp_grad_map.values()) # handle_id of grad of forward input + bwd_dys = [-1] * len(dy_ids) + bwd_outputs = [-1] * len(inp_grad_ids) + # dy_ids + saved_feature_map are backward inputs + # inp_grad_ids are backward outputs + # mark inputs/outputs for backward + for var in bwd._trace.vars: + if var.handle_id in dy_ids and var.kind == "external": + bwd._trace.mark_input(bwd_inp_idx, var.id) + idx = dy_ids.index(var.handle_id) + bwd_dys[idx] = bwd_inp_idx + bwd_inp_idx += 1 + elif var.handle_id in saved_feature_map and var.kind == "external": + bwd._trace.mark_input(bwd_inp_idx, var.id) + bwd_inps[saved_feature_handle_id.index(var.handle_id)] = bwd_inp_idx + bwd_inp_idx += 1 + if var.handle_id in inp_grad_ids and var.data_required: + bwd_outputs[inp_grad_ids.index(var.handle_id)] = bwd_out_idx + bwd._trace.mark_output(bwd_out_idx, var.id) + bwd_out_idx += 1 + # assert -1 not in bwd_dys + assert -1 not in bwd_inps + for var in fwd._trace.vars: + if not var.out_mark: + var.data_required = False + # assert -1 not in bwd_outputs + bwd.setup_io_without_trace(bwd_dys + bwd_inps, bwd_outputs) + bwd.setup_without_host() + + def check_external(trace_obj): + for var in trace_obj.vars: + if var.kind == "external" and not var.inp_mark: + raise RuntimeError("have unknown input in trace result") + + check_external(fwd) + check_external(bwd) + + +JIT_BACKEND = {"default": trace, "xla": xla_trace} + + +def partial_trace(func=None, *, backend="default", without_host=True, **trace_options): + assert backend in JIT_BACKEND + assert without_host, "partial_trace only support without_host mode currently!" + + def wrapper(func): + trace_obj = JIT_BACKEND[backend]( + func, without_host=without_host, **trace_options + ) + trace_options["capture_as_const"] = False + backward_trace_obj = JIT_BACKEND[backend]( + None, without_host=without_host, **trace_options + ) + backward_trace_obj.check_external = ( + False # check if there are unknown external vars after tracing. + ) + trace_obj.overall = False # if trace overall train step + backward_trace_obj.overall = False + trace_obj._trace.remove_unused_data_required = False + backward_trace_obj._trace.remove_unused_data_required = False + inp_grad_maps = OrderedDict() # x, dx map + out_grad_maps = OrderedDict() # y, dy map + traced = False # if wrapped function has been traced + compiled = False # if wrapped function has been compiled + custom_autodiff = None + outdef = None # treedef of forward return value + from ..core.autodiff.grad import Function + + class CustomAutodiff(Function): + def __init__(self, fwd, bwd): + self.fwd = fwd + self.bwd = bwd + del fwd.outdef + self.keeped_features = [] + + def forward(self, *args): + rst = self.fwd(*args) + keeped_features = rst[-1] + if not isinstance(keeped_features, Sequence): + keeped_features = tuple([keeped_features]) + else: + keeped_features = tuple(keeped_features) + self.keeped_features = keeped_features + return rst[0] + + def get_keeped_features(self): + rst = self.keeped_features + del self.keeped_features + return rst + + def backward(self, *output_grads): + output_grads = tuple([i for i in output_grads if i is not None]) + return self.bwd(*(output_grads + self.get_keeped_features())) + + class CustomFwd: + def __init__(self, fwd, bwd): + self.fwd = fwd + self.bwd = bwd + + def __call__(self, *args): + rst = self.fwd(*args) + if self.fwd.keeped_activation: + keeped_features = rst[-1] + if not isinstance(keeped_features, Sequence): + keeped_features = tuple([keeped_features]) + else: + keeped_features = tuple(keeped_features) + self.keeped_features = keeped_features + return rst[0] + else: + return rst + + def wrapped_func(*args, **kwargs): + from ..traced_module.pytree import tree_flatten + from ..module import Module + + nonlocal traced + nonlocal compiled + nonlocal custom_autodiff + nonlocal outdef + + if not traced: + traced = True + fargs = trace_obj.flatten_inputs(*args, **kwargs) + for t in fargs: + inp_grad_maps[t] = get_grad_slot(t) + del fargs + + def exit_trace(): + backward_trace_obj._trace.exit() + new_dict = {} + for k, v in inp_grad_maps.items(): + if v is not None: + new_dict[get_handle_id(k)] = get_handle_id(v.grad) + else: + new_dict[get_handle_id(k)] = -1 + inp_grad_maps.clear() + inp_grad_maps.update(new_dict) + + _add_backward_callback(exit_trace) + ret = trace_obj(*args) + rlist, outdef = tree_flatten(ret) + for t in rlist: + out_grad_maps[t] = get_grad_slot(t) + + def enter_trace(): + new_dict = {} + for k, v in out_grad_maps.items(): + if v is not None: + new_dict[get_handle_id(k)] = get_handle_id(v.grad) + out_grad_maps.clear() + out_grad_maps.update(new_dict) + backward_trace_obj._trace.enter() + + _add_backward_callback(enter_trace) + return ret + elif not compiled: + if custom_autodiff is None: + _process_fwd_bwd_trace_result( + trace_obj, backward_trace_obj, inp_grad_maps, out_grad_maps + ) + if len(backward_trace_obj._trace.ops) > 0: + custom_autodiff = CustomAutodiff(trace_obj, backward_trace_obj) + else: + custom_autodiff = CustomFwd(trace_obj, backward_trace_obj) + fargs = trace_obj.flatten_inputs(*args, **kwargs) + del args + del kwargs + if outdef is None: + return custom_autodiff(*fargs) + else: + return outdef.unflatten(custom_autodiff(*fargs)) + + return wrapped_func + + if func is None: + return wrapper + else: + return wrapper(func) diff --git a/imperative/python/megengine/jit/tracing.py b/imperative/python/megengine/jit/tracing.py index f3a88eef23f7a464ec2eeab72409c4136b019056..6f2b22d52614ae78ccb757d78ca93430494e6cc2 100644 --- a/imperative/python/megengine/jit/tracing.py +++ b/imperative/python/megengine/jit/tracing.py @@ -9,6 +9,7 @@ import pickle import re import struct import sys +from collections import OrderedDict, defaultdict from typing import Any, Sequence import cv2 @@ -16,9 +17,22 @@ import numpy as np from .. import tensor from ..core import _imperative_rt as rt -from ..core._imperative_rt import GraphProfiler, GraphProfiler2, SerializationMetadata +from ..core._imperative_rt import ( + CompNode, + GraphProfiler, + GraphProfiler2, + SerializationMetadata, +) from ..core._imperative_rt.core2 import Tensor as RawTensor -from ..core._imperative_rt.core2 import Trace, TraceError, name_tensor # skip_tracing, +from ..core._imperative_rt.core2 import Trace, TraceError # skip_tracing, +from ..core._imperative_rt.core2 import add_backward_callback as _add_backward_callback +from ..core._imperative_rt.core2 import ( + get_marked_input_tensor, + get_marked_output_tensor, + get_marked_tensor, + marked_input_tensor, + name_tensor, +) from ..core._imperative_rt.graph import _set_priority_to_id from ..core._imperative_rt.ops import ( AssertEqual, @@ -31,6 +45,7 @@ from ..core._imperative_rt.ops import ( from ..core._trace_option import set_symbolic_shape from ..core.tensor import megbrain_graph as G from ..logger import get_logger +from ..tensor import Tensor from ..utils import comp_graph_tools as cgtools from ..utils.naming import AutoNaming from ..utils.profiler import is_profiling @@ -94,8 +109,13 @@ class trace: opt_level: optimization level for compiling trace. Default: 2 graph_opt_config: configuration for graph optimization. Default: None symbolic_shape: whether to use symbolic shape for tracing. Default: True + without_host: if True, will run python code of wrapped function on the first call, + and run the compiled graph/function on subsequent calls. if False, will run python code every time. + Default: False """ + third_party_backend = False + def __new__(cls, *args, **kwargs): if not args: return functools.partial(cls, **kwargs) @@ -113,6 +133,7 @@ class trace: opt_level: int = 2, graph_opt_config: GraphOptimizationConfig = None, symbolic_shape: bool = True, + without_host: bool = False, ): self.__wrapped__ = function self._capture_as_const = capture_as_const or record_only @@ -150,6 +171,7 @@ class trace: graph_options["graph_opt.jit_config.fuse_reduce"] = mapping[ graph_opt_config.jit_fuse_reduce ] + if sublinear_memory_config is not None: graph_options["enable_sublinear_memory_opt"] = True graph_options[ @@ -186,8 +208,114 @@ class trace: self._trace.profile = profiling self._trace.array_comparator = array_comparator self._trace.record_input_shapes = _input_node_use_static_shape() + self._trace.without_host = without_host + self.check_external = True + self.traced = False + self.input_num = 0 + self.output_num = 0 + self.arg_list = [] + self.out_list = [] + + self.overall = True + + # forward keeped activation + self.keeped_activation = [] + + self.third_party_backend_compiled = False + + @property + def check_external(self): + return self._trace.check_external + + @check_external.setter + def check_external(self, flag): + self._trace.check_external = flag + + @property + def without_host(self): + return self._trace.without_host + + def flatten_inputs(self, *args, **kwargs): + from ..traced_module.pytree import tree_flatten + from ..module import Module + + tensor_args = [] + modules = [] + fargs, _ = tree_flatten((args, kwargs)) + for a in fargs: + if isinstance(a, RawTensor): + tensor_args.append(a) + elif isinstance(a, Module): + modules.append(a) + for m in modules: + tensor_args.extend(list(m.parameters())) + return tensor_args + + def compile(self): + raise NotImplementedError + + def execute(self, *args, **kwargs): + raise NotImplementedError + + def setup_env(self): + pass + + def unset_env(self): + pass + + def compile_and_exec(self, *args, **kwargs): + if not self.third_party_backend_compiled: + self.compile() + self.third_party_backend_compiled = True + return self.execute(*args, **kwargs) + + def convert_optimizer_state_to_tensor(self, *args, **kwargs): + from ..traced_module.pytree import tree_flatten, SUPPORTED_LEAF_CLS + from ..optimizer import Optimizer + from ..tensor import Tensor + + if Optimizer not in SUPPORTED_LEAF_CLS: + SUPPORTED_LEAF_CLS.append(Optimizer) + args, _ = tree_flatten((args, kwargs)) + for arg in args: + if isinstance(arg, Optimizer): + arg._disable_type_convert = False + for param_group in arg.param_groups: + for k, v in param_group.items(): + if not isinstance(v, (Tensor, Sequence)): + param_group[k] = Tensor(v) + elif isinstance(v, Sequence) and not isinstance(v[0], Tensor): + new_v = [] + for i in range(len(v)): + new_v.append(Tensor(v[i])) + param_group[k] = new_v + + def setup_io_without_trace(self, inputs, outputs): + self.traced = True + self.arg_list = [i for i in inputs if i != -1] + self.out_list = outputs + self.input_num = len(self.arg_list) + self.output_num = len([i for i in outputs if i != -1]) + + def setup_without_host(self): + self.inp_modules = set() + self.module_tensors = set() + self.tensor_to_attr = dict() + self.attr_to_key = dict() + self.update_param_dict = dict() + self.update_opt_param_dict = dict() + self.capture_optimizer_state = set() + self.opt_param_dict = dict() def __call__(self, *args, **kwargs): + if not self.without_host: + return self.trace_normal(*args, **kwargs) + elif self.overall: + return self.trace_without_host_overall(*args, **kwargs) + else: + return self.trace_without_host(*args, **kwargs) + + def trace_normal(self, *args, **kwargs): global active_trace symbolic_shape = None outputs = None @@ -214,6 +342,270 @@ class trace: raise return outputs + def trace_without_host(self, *args, **kwargs): + from ..traced_module.pytree import tree_flatten, SUPPORTED_LEAF_CLS + from ..module import Module + from ..utils.module_utils import get_expand_structure + from ..tensor import Tensor + from ..optimizer import Optimizer + + assert self.without_host and not self.overall + global active_trace + symbolic_shape = None + outputs = None + if self.traced and self.third_party_backend: + return self.compile_and_exec(*args, **kwargs) + try: + active_trace = self + self._trace.enter() + if self._trace.compiled(): + arglist = self.flatten_inputs(*args, **kwargs) + idx = 0 + inp_dict = {} + for a in arglist: + if isinstance(a, RawTensor): + inp_dict[self.arg_list[idx]] = a + idx += 1 + self._trace.put_datas(inp_dict) + outlist = [] + for i in self.out_list: + if i == -1: + if not hasattr(self, "outdef"): + outlist.append(None) + else: + outlist.append(self._trace.get_data(i)) + keep_vars = [] + for i in self.keeped_activation: + keep_vars.append(self._trace.get_data(i)) + + outputs = ( + self.outdef.unflatten(outlist) + if hasattr(self, "outdef") + else outlist + ) + if keep_vars: + return outputs, keep_vars + else: + return outputs + + arg_list = self.flatten_inputs(*args, **kwargs) + for i, arg in enumerate(arg_list): + arg_list[i]._reset(get_marked_input_tensor(self.input_num, arg)) + self.arg_list.append(self.input_num) + self.input_num += 1 + del arg_list + symbolic_shape = set_symbolic_shape(self._symbolic_shape) + if self.third_party_backend: + self.setup_env() + outputs = self.__wrapped__(*args, **kwargs) + + finally: + handling_exc = sys.exc_info() != (None,) * 3 + active_trace = None + if symbolic_shape is not None: + symbolic_shape = set_symbolic_shape(symbolic_shape) + assert symbolic_shape == self._symbolic_shape + if self.third_party_backend: + self.unset_env() + if ( + self._capture_as_const + and (outputs is not None) + and not self.without_host + ): + self._process_outputs(outputs) + if not self._trace.compiled(): + outlist, self.outdef = tree_flatten(outputs) + for i, out in enumerate(outlist): + assert isinstance(out, RawTensor), type(out) + outlist[i] = get_marked_output_tensor(self.output_num, out) + del out + self.out_list.append(self.output_num) + self.output_num += 1 + outputs = self.outdef.unflatten(outlist) + try: + # may raise TraceError + self._trace.exit() + except Exception as e: + if isinstance(e, TraceError): + if not handling_exc: + raise + else: + self._trace.set_execption(str(e)) + raise + self.traced = True + return outputs + + def trace_without_host_overall(self, *args, **kwargs): + # record overall train step include forward, backward, param update in a single trace object + from ..traced_module.pytree import tree_flatten, SUPPORTED_LEAF_CLS + from ..module import Module + from ..utils.module_utils import get_expand_structure + from ..tensor import Tensor + from ..optimizer import Optimizer + + assert self.without_host + global active_trace + symbolic_shape = None + outputs = None + if self.traced and self.third_party_backend: + return self.compile_and_exec(*args, **kwargs) + try: + active_trace = self + if not self.traced: + self.convert_optimizer_state_to_tensor(*args, **kwargs) + self._trace.enter() + if self._trace.compiled(): + arglist, _ = tree_flatten((args, kwargs)) + idx = 0 + inp_dict = {} + for a in arglist: + if isinstance(a, RawTensor): + inp_dict[self.arg_list[idx]] = a + idx += 1 + for t, key in self.opt_param_dict.items(): + inp_dict[key] = t + self._trace.put_datas(inp_dict) + for attr, key in self.attr_to_key.items(): + param = get_expand_structure(attr[0], attr[1]) + self._trace.put_data(key, param) + outlist = [] + for i in self.out_list: + if i == -1: + if not hasattr(self, "outdef"): + outlist.append(None) + else: + outlist.append(self._trace.get_data(i)) + for attr, key in self.update_param_dict.items(): + param = get_expand_structure(attr[0], attr[1]) + param._reset(self._trace.get_data(key)) + for state, key in self.update_opt_param_dict.items(): + state._reset(self._trace.get_data(key)) + keep_vars = [] + for i in self.keeped_activation: + keep_vars.append(self._trace.get_data(i)) + + outputs = ( + self.outdef.unflatten(outlist) + if hasattr(self, "outdef") + else outlist + ) + if keep_vars: + return outputs, keep_vars + else: + return outputs + + self.setup_without_host() + + def get_attr_hook(obj, attr): + rst = object.__getattribute__(obj, attr) + if isinstance(rst, RawTensor): + assert rst in self.tensor_to_attr + attr = self.tensor_to_attr[rst] + if attr not in self.attr_to_key: + self.attr_to_key[attr] = self.input_num + self.input_num += 1 + marked_input_tensor(self.attr_to_key[attr], rst) + return rst + + origin_reset = Tensor._reset + self.update_param_num = 0 + + def tensor_wrapper_resethook(obj, other): + if obj in self.tensor_to_attr: + attr = self.tensor_to_attr[obj] + other = get_marked_output_tensor(self.output_num, other) + self.update_param_num += 1 + self.update_param_dict[attr] = self.output_num + self.output_num += 1 + elif obj in self.capture_optimizer_state: + other = get_marked_output_tensor(self.output_num, other) + self.update_opt_param_dict[obj] = self.output_num + self.output_num += 1 + origin_reset(obj, other) + + arg_list, self.argdef = tree_flatten((args, kwargs)) + for i, arg in enumerate(arg_list): + if isinstance(arg, Module): + for k, v in arg.named_tensors(): + if v not in self.tensor_to_attr: + self.tensor_to_attr[v] = (arg, k) + self.inp_modules.add(arg) + elif isinstance(arg, RawTensor): + arg_list[i] = get_marked_input_tensor(self.input_num, arg) + self.arg_list.append(self.input_num) + self.input_num += 1 + elif isinstance(arg, Optimizer): + opt_params, _ = tree_flatten(arg.state_dict(keep_var=True)) + for p in opt_params: + if isinstance(p, Tensor): + self.capture_optimizer_state.add(p) + self.opt_param_dict = {} + for t in self.capture_optimizer_state: + if t not in self.tensor_to_attr: # not module parameter + mark_param = get_marked_input_tensor(self.input_num, t) + self.opt_param_dict[t] = self.input_num + t[...] = mark_param + self.input_num += 1 + args, kwargs = self.argdef.unflatten(arg_list) + Module.__getattribute__ = get_attr_hook + Tensor._reset = tensor_wrapper_resethook + symbolic_shape = set_symbolic_shape(self._symbolic_shape) + if self.third_party_backend: + self.setup_env() + outputs = self.__wrapped__(*args, **kwargs) + del arg_list + del args + del kwargs + + Module.__getattribute__ = object.__getattribute__ + Tensor._reset = origin_reset + + for attr, key in self.attr_to_key.items(): + param = get_expand_structure(attr[0], attr[1]) + finally: + handling_exc = sys.exc_info() != (None,) * 3 + active_trace = None + if symbolic_shape is not None: + symbolic_shape = set_symbolic_shape(symbolic_shape) + assert symbolic_shape == self._symbolic_shape + if self.third_party_backend: + self.unset_env() + if ( + self._capture_as_const + and (outputs is not None) + and not self.without_host + ): + self._process_outputs(outputs) + if not self._trace.compiled(): + outlist, self.outdef = tree_flatten(outputs) + for i, out in enumerate(outlist): + assert isinstance(out, RawTensor) + outlist[i] = get_marked_output_tensor(self.output_num, out) + del out + self.out_list.append(self.output_num) + self.output_num += 1 + outputs = self.outdef.unflatten(outlist) + try: + # may raise TraceError + self._trace.exit() + except Exception as e: + if isinstance(e, TraceError): + if not handling_exc: + raise + else: + self._trace.set_execption(str(e)) + raise + self.traced = True + return outputs + + @property + def ops(self): + return self._trace.ops + + @property + def vars(self): + return self._trace.vars + def _process_inputs(self, *args, **kwargs): for i, arg in enumerate(args): assert isinstance( diff --git a/imperative/python/megengine/jit/xla_backend.py b/imperative/python/megengine/jit/xla_backend.py new file mode 100644 index 0000000000000000000000000000000000000000..31a47a149da77a98a785ad84daac86490c3fd323 --- /dev/null +++ b/imperative/python/megengine/jit/xla_backend.py @@ -0,0 +1,201 @@ +from collections import OrderedDict, defaultdict + +from ..core._imperative_rt import CompNode +from ..core._imperative_rt.core2 import Tensor as RawTensor +from ..core._trace_option import set_use_xla_backend +from ..device import get_default_device +from ..utils.dlpack import from_dlpack, to_dlpack +from .tracing import trace + +try: + from ..xla.lib import xla_client as xc +except ImportError: + pass + + +class xla_trace(trace): + r"""Wraps a callable, and provides accelerated evaluation compiled by xla. + Currently it is an experimental feature. + Refer to :class:`~.jit.tracing.trace` for more information. + + + Examples: + + .. code-block:: python + + import numpy as np + from basecls.models.resnet import resnet18 + from megengine.autodiff.grad_manager import GradManager + from megengine.jit import xla_trace + from megengine.optimizer import Adam + + model = resnet18() + gm = GradManager() + opt = Adam(model.parameters(), lr=1e-4) + gm.attach(model.parameters()) + + # Only tensors in wrapped func args/kwargs will be treated as graph inputs, + # and other tensors will be captured as const value. + # Module, optimizer, and train data/label should be arguments of the wrapped function. + @xla_trace(capture_as_const=True) + def train_step(model, opt, data, label): + with gm: + pred = model(data) + loss = F.loss.cross_entropy(pred, label) + gm.backward(loss) + opt.step().clear_grad() + return loss + + """ + + third_party_backend = True + + def __init__(self, function, *, without_host=True, symbolic_shape=False, **kwargs): + assert without_host, "xla trace only support without host mode" + assert not symbolic_shape, "xla doesn't support dynamic shape currently" + super().__init__( + function, without_host=without_host, symbolic_shape=symbolic_shape, **kwargs + ) + + def setup_env(self): + self.orig_use_xla = set_use_xla_backend(True) + + def unset_env(self): + set_use_xla_backend(self.orig_use_xla) + + def compile(self): + from ..xla import build_xla + from ..traced_module.pytree import SUPPORTED_LEAF_TYPE, register_supported_type + from ..utils.module_utils import get_expand_structure + from ..xla.device import get_xla_backend_and_device + from ..tensor import Tensor + + assert self.traced + if self.overall: + for attr, _ in self.attr_to_key.items(): + param = get_expand_structure(attr[0], attr[1]) + param._reset(param.to("cpux")) + + for tensor, _ in self.opt_param_dict.items(): + tensor._reset(tensor.to("cpux")) + self.xla_exec, self.inp_ids, self.out_ids = build_xla( + self, return_with_io=True, return_device_array=True + ) + id2inpidx = defaultdict(list) + id2outidx = defaultdict(list) + for idx, id in enumerate(self.inp_ids): + id2inpidx[id].append(idx) + for idx, id in enumerate(self.out_ids): + id2outidx[id].append(idx) + self.inpkey2idx = {} + self.outkey2idx = {} + if self.input_num == len(set(self.inp_ids)) - 1: + self.has_randomstate = True + self.random_seed = Tensor([[1, 2], [3, 4]], dtype="int32") + else: + assert self.input_num == len(set(self.inp_ids)) + self.has_randomstate = False + inpmark2id = dict() + outmark2id = dict() + for var in self.vars: + if var.kind == "external": + for mark in var.inp_mark: + inpmark2id[mark] = var.id + elif var.data_required and var.out_mark: + for mark in var.out_mark: + outmark2id[mark] = var.id + for k, v in inpmark2id.items(): + for idx in id2inpidx[v]: + self.inpkey2idx[k] = idx + + for k, v in outmark2id.items(): + for idx in id2outidx[v]: + self.outkey2idx[k] = idx + + def prepare_xla_inputs(self, tensors): + from ..utils.module_utils import get_expand_structure + + inp_count = 0 + inp_list = [0] * self.input_num + for idx, t in enumerate(tensors): + inp = self.inpkey2idx[self.arg_list[idx]] + inp_list[inp] = t + inp_count += 1 + if self.overall: + for attr, key in self.attr_to_key.items(): + param = get_expand_structure(attr[0], attr[1]) + inp = self.inpkey2idx[key] + inp_list[inp] = param + inp_count += 1 + for tensor, k in self.opt_param_dict.items(): + inp = self.inpkey2idx[k] + inp_list[inp] = tensor + inp_count += 1 + assert inp_count == self.input_num + if self.has_randomstate: + inp_list.append(self.random_seed) + return inp_list + + def to_dlpack(self, x, take_ownership: bool = True): + return xc._xla.buffer_to_dlpack_managed_tensor(x, take_ownership=take_ownership) + + def execute(self, *args, **kwargs): + from ..traced_module.pytree import tree_flatten + from ..tensor import Tensor + from ..utils.module_utils import get_expand_structure + + inputs, _ = tree_flatten((args, kwargs)) + arrays = [] + cn = CompNode(get_default_device()) + stream = dict(self.xla_exec.backend.get_compute_compnode()) + device_kind, device_id, stream_id = cn.physical_locator + + xla_stream = stream[device_id] + xla_comp_cn = "gpu{}:{}".format(device_id, xla_stream) + for t in inputs: + if isinstance(t, RawTensor): + assert cn == t.device + arrays.append(t.to(xla_comp_cn, _borrow=True)) + + arrays = self.prepare_xla_inputs(arrays) + outputs = self.xla_exec(*arrays) + return_vals = [] + for i in self.out_list: + if i == -1: + if not hasattr(self, "outdef"): + return_vals.append(None) + else: + return_vals.append(outputs[self.outkey2idx[i]]) + keeped_features = [] + for i in self.keeped_activation: + capsule = self.to_dlpack(outputs[self.outkey2idx[i]]) + t = from_dlpack(capsule, xla_stream).to(cn, _borrow=True) + keeped_features.append(t) + out_tensors = [] + for array in return_vals: + if array is not None: + capsule = self.to_dlpack(array) + t = from_dlpack(capsule, xla_stream) + out_tensors.append(t.to(cn, _borrow=True)) + else: + out_tensors.append(array) + if self.overall: + for attr, key in self.update_param_dict.items(): + param = get_expand_structure(attr[0], attr[1]) + xla_array = outputs[self.outkey2idx[key]] + capsule = self.to_dlpack(xla_array) + param._reset(from_dlpack(capsule).to(cn, _borrow=True)) + + for state, key in self.update_opt_param_dict.items(): + xla_array = outputs[self.outkey2idx[key]] + capsule = self.to_dlpack(xla_array) + state._reset(from_dlpack(capsule).to(cn, _borrow=True)) + rst = ( + self.outdef.unflatten(out_tensors) + if hasattr(self, "outdef") + else out_tensors + ) + if keeped_features: + return rst, keeped_features + else: + return rst diff --git a/imperative/python/megengine/module/module.py b/imperative/python/megengine/module/module.py index 8dcd8d5316395673cf20be95e1a2bcf9e3aab850..6b170541c94cdc2aa2d1d202ab8eb73dc1fcf2be 100644 --- a/imperative/python/megengine/module/module.py +++ b/imperative/python/megengine/module/module.py @@ -49,6 +49,8 @@ def _access_structure(obj, key, callback=None): cur = cur[k] else: cur = getattr(cur, k) + if callable is None: + return cur return callback(parent, k, cur) diff --git a/imperative/python/src/graph_rt.cpp b/imperative/python/src/graph_rt.cpp index a6085aef1b405939551452c4c879e44e20901307..03d2b2c005b53a658af2e629ce2349cb9d6474c2 100644 --- a/imperative/python/src/graph_rt.cpp +++ b/imperative/python/src/graph_rt.cpp @@ -115,11 +115,9 @@ void _set_priority_to_id(const std::vector& dest_vars) { } py::object Py_Varnode = py::none(); - +const std::unique_ptr _imperative_sm_opr_footprint_ptr{ + std::make_unique()}; void init_graph_rt(py::module m) { - static const std::unique_ptr _imperative_sm_opr_footprint_ptr{ - std::make_unique()}; - def_rendezvous(m, "DeviceTensorNDRendezvous"); def_rendezvous(m, "HostTensorNDRendezvous"); diff --git a/imperative/python/src/graph_rt.h b/imperative/python/src/graph_rt.h index 8dd004e22cd1c3c941471f5fd741a3c7b69fb1f1..423b44c35376385da493b4b279eaec3fad3bf64d 100644 --- a/imperative/python/src/graph_rt.h +++ b/imperative/python/src/graph_rt.h @@ -10,7 +10,7 @@ namespace py = pybind11; extern py::object Py_Varnode; - +extern const std::unique_ptr _imperative_sm_opr_footprint_ptr; template class GraphNodePtr { std::shared_ptr m_graph; diff --git a/imperative/python/src/tensor.cpp b/imperative/python/src/tensor.cpp index fa71f37745f493e9ef0fd7f2c738320a26b4e30b..194829cd4979828e21e76124c09473cb6837c02d 100644 --- a/imperative/python/src/tensor.cpp +++ b/imperative/python/src/tensor.cpp @@ -5,6 +5,7 @@ #include "megbrain/imperative/dispatch.h" #include "megbrain/imperative/ops/autogen.h" #include "megbrain/imperative/ops/backward_graph.h" +#include "megbrain/imperative/ops/opr_attr.h" #include "megbrain/imperative/ops/utility.h" #include "megbrain/imperative/profiler.h" #include "megbrain/imperative/transformation.h" @@ -50,6 +51,8 @@ #include "../../src/impl/mgb_cg_impl.h" #include "./backtrace.h" +#include + namespace py = pybind11; namespace views = ranges::views; @@ -729,6 +732,10 @@ PyObject* TensorWrapper::isscalar() { } } +PyObject* TensorWrapper::value_id() { + return py::cast(m_tensor->value_id()).release().ptr(); +} + PyObject* TensorWrapper::_var() { TypedValueRef value = imperative::apply(GetVarVal(), m_tensor->data())[0].as_ref(); @@ -863,6 +870,10 @@ void init_tensor(py::module m) { .register_at( std::make_shared()) .release()); + MGB_MARK_USED_VAR(transformations + .register_at( + std::make_shared()) + .release()); auto format_trans = std::make_shared(); MGB_MARK_USED_VAR( transformations.register_at(format_trans).release()); @@ -919,6 +930,7 @@ void init_tensor(py::module m) { .def<&TensorWrapper::_watch>("_watch") .def<&TensorWrapper::_var>("var") .def<&TensorWrapper::_graph>("graph") + .def<&TensorWrapper::value_id>("value_id") .def_getset< &TensorWrapper::module_trace_info, &TensorWrapper::set_module_trace_info>("_NodeMixin__node") @@ -1081,6 +1093,14 @@ void init_tensor(py::module m) { sync_py_task_q(); }); + py::class_(m, "GradSlot") + .def_property_readonly("grad", [](GradSlotPtr& self) -> py::object { + if (self->grad()) + return TensorWrapper::make(py_tensor_type, self->grad()); + else + return py::none(); + }); + // GradTransformation py::handle grad_key_type = GradKeyWrapper::wrap_t::type() @@ -1098,6 +1118,29 @@ void init_tensor(py::module m) { py::setattr(m, "GradKey", grad_key_type); m.def("backward", &GradKeyWrapper::backward); m.def("get_backward_closure", &GradKeyWrapper::get_backward_closure); + m.def("get_grad_slot", [](py::object tensor) -> py::object { + auto* tw = TensorWrapper::try_cast(tensor.ptr()); + if (tw) { + auto rst = imperative::apply(GetGradSlot(), tw->m_tensor->data()); + if (rst.size() == 1) { + GradSlotPtr slot = rst[0].cast(); + return py::cast(slot); + } else { + return py::none(); + } + } + + return py::none(); + }); + m.def("get_handle_id", [](py::object tensor) -> py::object { + auto* tw = TensorWrapper::try_cast(tensor.ptr()); + if (tw) { + auto rst = imperative::apply(GetId(), tw->m_tensor->data()); + int id = rst[0].cast(); + return py::cast(id); + } + return py::none(); + }); m.def("set_py_tensor_type", [](py::object type_obj) { py_tensor_type = reinterpret_cast(type_obj.inc_ref().ptr()); @@ -1120,6 +1163,9 @@ void init_tensor(py::module m) { bool capture_as_const = false; bool profile = false; bool record_input_shapes = false; + bool without_host = false; + bool check_external = true; + bool remove_unused_data_required = true; py::function options_visitor; std::shared_ptr tracing; std::shared_ptr compiled; @@ -1130,6 +1176,8 @@ void init_tensor(py::module m) { std::unique_ptr> tracing_guard; std::unique_ptr> compiled_guard; std::unique_ptr> lazy_eval_guard; + std::unordered_map inpmark_to_id; + std::unordered_map outmark_to_id; bool compare_value(ValueRef lhs, ValueRef rhs) { auto lvalue = lhs.cast_ref(); @@ -1149,11 +1197,23 @@ void init_tensor(py::module m) { return array_comparator(larr, rarr); } + void mark_input(size_t mark, size_t id) { + trace_result->vars[id].inp_marker.insert(mark); + mgb_assert(inpmark_to_id.find(mark) == inpmark_to_id.end()); + inpmark_to_id[mark] = id; + } + void mark_output(size_t mark, size_t id) { + trace_result->vars[id].out_marker.insert(mark); + mgb_assert(outmark_to_id.find(mark) == outmark_to_id.end()); + outmark_to_id[mark] = id; + } void enter() { auto& self = *this; if (!self.trace_result) { // untraced self.tracing = std::make_shared( self.capture_as_const, self.record_input_shapes); + if (self.without_host) + self.tracing->enable_record_all_shapes(); if (self.symbolic) { self.lazy_eval = std::make_shared(self.no_exec); @@ -1183,8 +1243,11 @@ void init_tensor(py::module m) { std::make_shared(¤t_graph)); } } - compiled_guard = - transformations.register_at(self.compiled); + if (!without_host) + compiled_guard = + transformations.register_at(self.compiled); + else + self.compiled->set_pc_to_end(); // start execute because InputCallback depends self.compiled->execute(); } else if (self.tracing) { @@ -1203,7 +1266,31 @@ void init_tensor(py::module m) { auto& self = *this; if (self.tracing) { tracing_guard.reset(); + if (self.without_host) { + self.tracing->postprocess_trace_result(); + self.inpmark_to_id = self.tracing->inpmark_to_id; + self.outmark_to_id = self.tracing->outmark_to_id; + } self.trace_result = self.tracing->get_result(); + if (self.without_host) { + for (auto&& var : self.trace_result->vars) { + var.shape_required = false; + var.value_required = false; + if (var.data_required && var.out_marker.empty() && + remove_unused_data_required) + var.data_required = false; + if (var.inp_marker.empty() && + var.kind == TraceResult::VarInfo::Kind::External) { + if (var.bound_data) { + var.kind = TraceResult::VarInfo::Kind::Constant; + } else if (self.check_external) { + throw std::runtime_error( + "have some unknown input tensors in trace " + "result"); + } + } + } + } self.tracing.reset(); if (self.lazy_eval) { auto lazy_eval = std::move(self.lazy_eval); @@ -1211,7 +1298,8 @@ void init_tensor(py::module m) { lazy_eval->check_exception(); } } else if (self.compiled) { - compiled_guard.reset(); + if (!without_host) + compiled_guard.reset(); self.compiled->wait(); } else { mgb_throw(MegBrainError, "invalid state: neither tracing nor compiled"); @@ -1255,6 +1343,10 @@ void init_tensor(py::module m) { .def_readwrite("record_input_shapes", &Trace::record_input_shapes) .def_readwrite("array_comparator", &Trace::array_comparator) .def_readwrite("profile", &Trace::profile) + .def_readwrite("without_host", &Trace::without_host) + .def_readwrite("check_external", &Trace::check_external) + .def_readwrite( + "remove_unused_data_required", &Trace::remove_unused_data_required) .def_property_readonly( "options", [](Trace& self) { @@ -1281,6 +1373,65 @@ void init_tensor(py::module m) { .def("enter", &Trace::enter) .def("exit", &Trace::exit) .def("dump", &Trace::dump) + .def("set_execption", + [](Trace& self, std::string error) { + if (self.compiled) { + auto exc = std::make_exception_ptr(TraceError(error)); + self.compiled->set_exception(exc); + } + }) + .def("compiled", [](Trace& self) { return bool(self.compiled); }) + .def("put_data", + [](Trace& self, int mark, py::object tensor) { + auto id = self.inpmark_to_id[mark]; + auto&& varinfo = self.trace_result->vars[id]; + mgb_assert(varinfo.kind == TraceResult::VarInfo::Kind::External); + auto&& accessor = self.compiled->get_accessor_by_id(id); + auto* tw = TensorWrapper::try_cast(tensor.ptr()); + mgb_assert(tw); + accessor.data_setter( + tw->m_tensor->data().dev_tensor()->as_nd(true)); + }) + .def("put_datas", + [](Trace& self, std::unordered_map inps) { + for (auto&& inp : inps) { + auto&& mark = inp.first; + auto&& tensor = inp.second; + auto id = self.inpmark_to_id[mark]; + auto&& varinfo = self.trace_result->vars[id]; + mgb_assert( + varinfo.kind == TraceResult::VarInfo::Kind::External); + auto&& accessor = self.compiled->get_accessor_by_id(id); + auto* tw = TensorWrapper::try_cast(tensor.ptr()); + mgb_assert(tw); + accessor.data_setter( + tw->m_tensor->data().dev_tensor()->as_nd(true)); + } + }) + .def("get_data", + [](Trace& self, int mark) { + auto id = self.outmark_to_id[mark]; + auto&& varinfo = self.trace_result->vars[id]; + mgb_assert(varinfo.data_required); + auto&& accessor = self.compiled->get_accessor_by_id(id); + mgb_assert(accessor.data_getter); + auto dev_value = DeviceValue::make(accessor.data_getter()); + return TensorWrapper::make( + py_tensor_type, + imperative::apply( + CreateTensor( + CreateTensor::Common, dev_value->device(), + dev_value->dtype(), dev_value->shape()), + DeviceStorage::make(dev_value->storage()))[0]); + }) + .def_property_readonly( + "ops", [](Trace& self) { return self.trace_result->seq; }) + .def_property_readonly( + "vars", [](Trace& self) { return self.trace_result->vars; }) + .def_property_readonly( + "inpmark_to_id", [](Trace& self) { return self.inpmark_to_id; }) + .def_property_readonly( + "outmark_to_id", [](Trace& self) { return self.outmark_to_id; }) .def("begin_excluded_region", [](Trace& self) { mgb_assert(bool(self.tracing) ^ bool(self.compiled)); @@ -1290,17 +1441,137 @@ void init_tensor(py::module m) { self.compiled_guard.reset(); } }) - .def("end_excluded_region", [](Trace& self) { - mgb_assert(bool(self.tracing) ^ bool(self.compiled)); - if (self.tracing) { - self.tracing_guard = - transformations.register_at(self.tracing); - } else if (self.compiled) { - self.compiled_guard = - transformations.register_at(self.compiled); + .def("end_excluded_region", + [](Trace& self) { + mgb_assert(bool(self.tracing) ^ bool(self.compiled)); + if (self.tracing) { + self.tracing_guard = + transformations.register_at( + self.tracing); + } else if (self.compiled) { + self.compiled_guard = + transformations.register_at( + self.compiled); + } + }) + .def("mark_output", &Trace::mark_output) + .def("mark_input", &Trace::mark_input); + using VarInfo = TraceResult::VarInfo; + using OpKind = TraceResult::SeqItem::OpKind; + std::unordered_map kind2str = { + {VarInfo::Kind::Internal, "internal"}, + {VarInfo::Kind::External, "external"}, + {VarInfo::Kind::Constant, "const"}, + }; + std::unordered_map opkind2str = { + {OpKind::Unknown, "unknown"}, + {OpKind::TraceMarkVar, "trace_mark_var"}, + {OpKind::IOMarkVar, "io_mark_var"}, + {OpKind::CreateTensor, "create_tensor"}, + {OpKind::Rename, "rename"} + + }; + py::class_(m, "VarInfo") + .def_property_readonly("shape", [](VarInfo& self) { return self.shape; }) + .def_property_readonly( + "value_required", [](VarInfo& self) { return self.value_required; }) + .def_property_readonly( + "shape_required", [](VarInfo& self) { return self.shape_required; }) + .def_readwrite("data_required", &VarInfo::data_required) + .def("set_external", + [](VarInfo& self) { self.kind = VarInfo::Kind::External; }) + .def_property_readonly( + "bound_data", + [](VarInfo& self) -> py::object { + if (self.bound_data) + return py::reinterpret_steal( + npy::ndarray_from_tensor( + self.bound_data.numpy()->as_nd(true), + npy::ShareType::TRY_SHARE)); + return py::none(); + }) + .def_property_readonly( + "dtype", + [](VarInfo& self) { + auto ret = static_cast(*self.dtype); + if (ret == dtype::Byte()) { + ret = dtype::Uint8(); + } + return ret; + }) + .def_property_readonly( + "device", + [](VarInfo& self) { return static_cast(*self.device); }) + .def_property_readonly("id", [](VarInfo& self) { return self.id; }) + .def_property_readonly( + "handle_id", [](VarInfo& self) { return self.handle_id; }) + .def_property_readonly("name", [](VarInfo& self) { return self.name; }) + .def_property_readonly("mark", [](VarInfo& self) { return self.mark; }) + .def_property_readonly( + "inp_mark", [](VarInfo& self) { return self.inp_marker; }) + .def_property_readonly( + "out_mark", [](VarInfo& self) { return self.out_marker; }) + .def_property_readonly("kind", [kind2str](VarInfo& self) { + return kind2str.find(self.kind)->second; + }); + using SeqItem = TraceResult::SeqItem; + auto json = py::module::import("json"); + + py::class_(m, "OpInfo") + .def(py::init([opkind2str]( + std::shared_ptr op, + const SmallVector& inputs, + const SmallVector& outputs, + const std::string& op_kind) { + SeqItem::OpKind enum_op_kind = SeqItem::OpKind::Unknown; + for (auto&& kv : opkind2str) { + if (op_kind == kv.second) { + enum_op_kind = kv.first; + } + } + return SeqItem{op, inputs, outputs, enum_op_kind}; + })) + .def_property_readonly( + "op", + [opkind2str](SeqItem& self) -> py::object { + if (self.op) { + if (auto* op = self.op->try_cast_final()) { + return py::cast(op->type); + } + return py::cast(self.op); + } else + return py::cast(opkind2str.find(self.kind)->second); + }) + .def_property_readonly("inputs", [](SeqItem& self) { return self.inputs; }) + .def_property_readonly( + "outputs", [](SeqItem& self) { return self.outputs; }) + .def_property_readonly( + "type", + [opkind2str](SeqItem& self) -> py::object { + if (self.op) + return py::cast(self.op->type_name()); + else + return py::cast(opkind2str.find(self.kind)->second); + }) + .def_property_readonly( + "kind", + [opkind2str](SeqItem& self) { + return opkind2str.find(self.kind)->second; + }) + .def_property_readonly("param", [json](SeqItem& self) -> py::object { + if (self.op) { + if (auto* op = self.op->try_cast_final()) { + auto param = + op->mgb_param(_imperative_sm_opr_footprint_ptr.get()) + ->to_string(); + return json.attr("loads")(py::cast(param)); + } else { + auto pyop = py::cast(self.op); + return pyop.attr("__getstate__")(); + } } + return py::dict(); }); - m.def("name_tensor", [](std::string name, py::object tensor) { auto* tw = TensorWrapper::try_cast(tensor.ptr()); mgb_assert(tw, "Arg_1 shoud be Tensor!"); @@ -1308,6 +1579,33 @@ void init_tensor(py::module m) { tw->m_tensor->reset(output); }); + m.def("get_marked_tensor", [](std::string name, py::object tensor) { + auto* tw = TensorWrapper::try_cast(tensor.ptr()); + auto output = imperative::apply(TraceMarkVar(name), tw->m_tensor->data())[0]; + return TensorWrapper::make(py_tensor_type, output); + }); + + m.def("get_marked_input_tensor", [](int mark, py::object tensor) { + auto* tw = TensorWrapper::try_cast(tensor.ptr()); + auto output = imperative::apply( + IOMarkVar(mark, IOMarkVar::Kind::Input), tw->m_tensor->data())[0]; + return TensorWrapper::make(py_tensor_type, output); + }); + + m.def("marked_input_tensor", [](int mark, py::object tensor) { + auto* tw = TensorWrapper::try_cast(tensor.ptr()); + auto output = imperative::apply( + IOMarkVar(mark, IOMarkVar::Kind::Input), tw->m_tensor->data())[0]; + tw->m_tensor->reset(output); + }); + + m.def("get_marked_output_tensor", [](int mark, py::object tensor) { + auto* tw = TensorWrapper::try_cast(tensor.ptr()); + auto output = imperative::apply( + IOMarkVar(mark, IOMarkVar::Kind::Output), tw->m_tensor->data())[0]; + return TensorWrapper::make(py_tensor_type, output); + }); + m.def("is_grad_attached", [](std::vector tensors) -> bool { SmallVector values(tensors.size()); for (size_t i = 0; i < tensors.size(); ++i) { @@ -1379,6 +1677,17 @@ void init_tensor(py::module m) { return wrapped_outputs; }); + m.def("add_backward_callback", [](py::function callback) { + ValueRef id = IntegerValue::make(0); + GenericFunction generic_function = + [callback](Span inputs) -> ValueRefList { + callback(); + return {}; + }; + auto output_values = + imperative::apply(InsertGradCallback(generic_function), id); + }); + // ModuleTraceTransformation static py::function module_trace_hook; diff --git a/imperative/python/src/tensor.h b/imperative/python/src/tensor.h index 477832347aaa8b2e461a779b62f2ecd63526d98b..945c6321924fd41c7a544f509834a279e970e95d 100644 --- a/imperative/python/src/tensor.h +++ b/imperative/python/src/tensor.h @@ -74,6 +74,7 @@ public: inline ValueRef data() const { return m_data.unwrap(); } bool is_scalar() { return data().is_scalar(); } inline std::string name() { return m_name; } + inline size_t value_id() { return m_data.id(); } inline void set_name(std::string name) { m_name = name; if (!name.empty()) { @@ -128,6 +129,7 @@ public: void reset(PyObject*); PyObject* detach(); PyObject* isscalar(); + PyObject* value_id(); PyObject* _dev_tensor(); void _drop(); PyObject* varnode(); diff --git a/imperative/python/test/unit/jit/test_tracing.py b/imperative/python/test/unit/jit/test_tracing.py index f0eb376db2eb4309a85d542759886ad1fc0951f8..6c68d52c29ee833ddfd6366f04803b385b154da3 100644 --- a/imperative/python/test/unit/jit/test_tracing.py +++ b/imperative/python/test/unit/jit/test_tracing.py @@ -18,7 +18,13 @@ from megengine.core.ops import builtin as ops from megengine.core.ops.builtin import Elemwise from megengine.core.tensor.utils import isscalar from megengine.functional import exp, log -from megengine.jit import GraphOptimizationConfig, TraceError, exclude_from_trace, trace +from megengine.jit import ( + GraphOptimizationConfig, + TraceError, + exclude_from_trace, + partial_trace, + trace, +) from megengine.module import Module from megengine.random import normal, uniform from megengine.utils.naming import AutoNaming @@ -803,3 +809,87 @@ def test_dump_without_output_error(): str(e) == "the traced function without return values cannot be dumped, the traced function should return List[Tensor] or Dict[str, Tensor]" ) + + +@pytest.mark.parametrize("trace_mode", [False, True]) +def test_trace_without_host(trace_mode): + @trace(symbolic=trace_mode, without_host=True) + def fwd(a, b, c): + x = a + b + y = a + c + z = x * y + z1 = x / y + return [z, z1] + + a = tensor([1.0]) + b = tensor([2.0]) + c = tensor([3.0]) + rst = fwd(a, b, c) + for _ in range(2): + trace_rst = fwd(a, b, c) + np.testing.assert_equal(rst[0], trace_rst[0]) + np.testing.assert_equal(rst[1], trace_rst[1]) + + +def test_trace_without_error(): + const = tensor([8.0]) + + @trace(symbolic=False, without_host=True) + def fwd(a, b, c): + x = a + b + y = a + c + z = x * y + z1 = x / y + const + return [z, z1] + + try: + a = tensor([1.0]) + b = tensor([2.0]) + c = tensor([3.0]) + fwd(a, b, c) + except Exception as e: + assert str(e) == "have some unknown input tensors in trace result" + else: + assert False + + +def test_partial_trace_fwd_bwd(): + class Simple(Module): + def __init__(self): + super().__init__() + self.a = Parameter([1.0], dtype=np.float32) + self.b = Parameter([2.0], dtype=np.float32) + + @partial_trace + def forward(self, x): + x = x * self.a + x / self.b + x = F.exp(x) + return x + + def clear_grad(self): + self.a.grad = None + self.b.grad = None + + @partial_trace + def fwd_only(a, b): + return a * b + a / b + + m = Simple() + gm = GradManager() + gm.attach(m.parameters()) + + def func(x): + with gm: + x = x * 3 + x = m(x) + x = x * 2 + gm.backward(x) + a = m.a.grad + b = m.b.grad + m.clear_grad() + return fwd_only(a, b) + a + b + + gt = func(tensor(1.0)) + for _ in range(3): + out = func(tensor(1.0)) + np.testing.assert_equal(gt.numpy(), out.numpy()) diff --git a/imperative/src/impl/basic_operators.cpp b/imperative/src/impl/basic_operators.cpp index 52f48b64c0b2a8d43393bc61a540def485f13484..e3c4c034cb34712264b914cb6d7b784406766047 100644 --- a/imperative/src/impl/basic_operators.cpp +++ b/imperative/src/impl/basic_operators.cpp @@ -105,6 +105,10 @@ std::string IsScalar::to_string() const { return "IsScalar"; } +std::string GetId::to_string() const { + return "GetId"; +} + std::string GetFormat::to_string() const { return "GetFormat{}"; } diff --git a/imperative/src/impl/basic_values.cpp b/imperative/src/impl/basic_values.cpp index 01b8b3e70ac22b40f57735d195f1406bef2486a3..9f0dbecc5ad761ebe993982b10fcd1b5d80bade2 100644 --- a/imperative/src/impl/basic_values.cpp +++ b/imperative/src/impl/basic_values.cpp @@ -15,6 +15,10 @@ std::string BoolValue::to_string() const { return (*this) ? "true" : "false"; } +std::string IntegerValue::to_string() const { + return std::to_string((int)*this); +} + std::string HostStorage::to_string() const { return ssprintf("HostStorage{device=%s}", comp_node().to_string().c_str()); } diff --git a/imperative/src/impl/op_def.cpp b/imperative/src/impl/op_def.cpp index 7fbc3ca612d580971141d8ae24a6a5204fdbf509..e84862cf9047b4163af7af264d7099ffa0f3d09b 100644 --- a/imperative/src/impl/op_def.cpp +++ b/imperative/src/impl/op_def.cpp @@ -142,6 +142,10 @@ const std::string OpDef::make_name() const { return m_scope + "." + trait()->make_name(*this); } +const std::string OpDef::type_name() const { + return trait()->name; +} + static thread_local OpDef::allocator_t local_allocator; void OpDef::set_allocator(allocator_t allocator) { diff --git a/imperative/src/impl/ops/opr_attr.cpp b/imperative/src/impl/ops/opr_attr.cpp index 30c416680900d232d1e102588d676bf7ee22d973..d3b30f0d841621bfd0dfe3668e6605ce5bf0185a 100644 --- a/imperative/src/impl/ops/opr_attr.cpp +++ b/imperative/src/impl/ops/opr_attr.cpp @@ -13,6 +13,10 @@ namespace imperative { namespace { class OprParamsLoadContext final : public serialization::OprLoadContextRawPOD { +public: + bool strict = true; + +private: const OprAttr::Param& m_param; size_t m_pos = 0; ComputingGraph* m_graph; @@ -40,7 +44,8 @@ public: m_graph(graph) {} ~OprParamsLoadContext() { - mgb_assert(m_pos == m_param.size(), "param not fully consumed"); + if (strict) + mgb_assert(m_pos == m_param.size(), "param not fully consumed"); } ComputingGraph& graph() override { return *m_graph; } @@ -126,7 +131,9 @@ std::shared_ptr make_from_op_node(cg::OperatorNodeBase* opr) { if (get_type2policy().find(opr->dyn_typeinfo()) != get_type2policy().end()) { policy = get_type2policy().at(opr->dyn_typeinfo()).first(opr); } - return OprAttr::make(registry->name, std::move(ctx.m_param), policy, opr->config()); + return OprAttr::make( + registry->name, std::move(ctx.m_param), policy, opr->config(), + opr->dyn_typeinfo()); } std::vector> props(const OpDef& def) { @@ -168,6 +175,12 @@ size_t OprAttr::hash() const { config.hash()); } +std::shared_ptr OprAttr::mgb_param(OprFootprint* footprint) { + OprParamsLoadContext ctx{param, nullptr}; + ctx.strict = false; + return footprint->get_serial_param_json(mgb_opr_type, ctx); +}; + MGB_DYN_TYPE_OBJ_FINAL_IMPL(OprAttr); } // namespace imperative diff --git a/imperative/src/impl/transformations/eval.cpp b/imperative/src/impl/transformations/eval.cpp index 400a2264550161a895cead2d98743aee834100da..47713e3873d05954925479d98754fc2c18a80347 100644 --- a/imperative/src/impl/transformations/eval.cpp +++ b/imperative/src/impl/transformations/eval.cpp @@ -130,6 +130,11 @@ ValueRefList InterpreterTransformation::apply_transformation( } else { return {ValueRef()}; } + } else if (op.is()) { + auto& val = inputs[0].cast(m_value_type); + int64_t id = val.id(); + return {IntegerValue::make(id)}; + } else if (op.is()) { auto& input = inputs[0].cast(m_value_type); DeviceTensorND dev_tensor; diff --git a/imperative/src/impl/transformations/grad.cpp b/imperative/src/impl/transformations/grad.cpp index 8603bcd18b96d29e30374b3244c66a8af69cb3af..e13c018658fecc7c85acd77fc223f5b9df3c02bb 100644 --- a/imperative/src/impl/transformations/grad.cpp +++ b/imperative/src/impl/transformations/grad.cpp @@ -7,6 +7,7 @@ #include "megbrain/imperative/profiler.h" #include "megbrain/imperative/resource_manager.h" +#include namespace mgb { namespace imperative { @@ -226,7 +227,7 @@ void GradKey::backward() { if constexpr (std::is_same_v) { mgb_throw(AssertionError, "invalid backward"); } else { - mgb_assert(grad_fn->m_slots.size() > 0); + // mgb_assert(grad_fn->m_slots.size() > 0); SmallVector grads (grad_fn->m_slots.size()); auto iter = grads.begin(); for (auto&& slot : grad_fn->m_slots) { @@ -419,6 +420,23 @@ ValueRefList GradTransformation::apply_transformation( mgb_assert(!grad_fn->m_slots.empty()); m_key->m_tape.push_back({grad_fn, op_val->op().shared_from_this()}); return outputs; + } else if (auto* igc = op.as()) { + auto grad_fn = LocalPtr::make(); + auto& backward = + std::get(grad_fn->m_backward = CustomBackward()); + auto id = inputs[0]; + backward.m_backward = [id, callback = igc->callback()]( + Span inputs) -> SmallVector { + callback({&id, (size_t)1}); + return {}; + }; + m_key->m_side_effects.push_back(grad_fn); + m_key->m_tape.push_back({grad_fn, nullptr}); + auto next_id = IntegerValue::make((int)id.cast() + 1); + auto prev_count = + imperative::apply(InsertGradCallback(igc->callback()), next_id)[0]; + auto count = IntegerValue::make((int)prev_count.cast() + 1); + return {count}; } else if (op.is()) { return imperative::apply(op, inputs); } else if (auto* attach_grad = op.as()) { @@ -514,6 +532,13 @@ ValueRefList GradTransformation::apply_transformation( } } return imperative::apply(op, inputs); + } else if (op.is()) { + mgb_assert(inputs.size() == 1); + if (auto&& grad_value = as_grad_value(inputs[0])) { + return {GradSlotValue::make(grad_value->slot())}; + } else { + return {}; + } } else if (op.kind() == Operator::IdentityLike) { mgb_assert(inputs.size() == 1); if (auto&& grad_value = as_grad_value(inputs[0])) { diff --git a/imperative/src/impl/transformations/lazy.cpp b/imperative/src/impl/transformations/lazy.cpp index 33e23f54c22c584987c86c8c6402cdd13bb09aa0..88a0deda9874cc76c95e19824c36dc3cc463938b 100644 --- a/imperative/src/impl/transformations/lazy.cpp +++ b/imperative/src/impl/transformations/lazy.cpp @@ -53,6 +53,9 @@ ValueRefList LazyEvalTransformation::apply_transformation( outputs[i] = record_var(output_nodes[i]); } return outputs; + } else if (op.is()) { + int64_t id = inputs[0].id(); + return {IntegerValue::make(id)}; } else if (auto* create_tensor = op.as()) { auto&& args = create_tensor->parse(inputs); auto get_dev_val = [&] { diff --git a/imperative/src/impl/transformations/trace.cpp b/imperative/src/impl/transformations/trace.cpp index d1c7a56c35cc67206942f746b188c72fa5a7e501..5fd3a5c32dfc13533b81b48040d98c51d4c46f99 100644 --- a/imperative/src/impl/transformations/trace.cpp +++ b/imperative/src/impl/transformations/trace.cpp @@ -83,7 +83,7 @@ VarNodeArray TraceResult::dump( std::unordered_map> name2ops; // iterate over opr_seq for (auto&& item : seq) { - auto&& [op, inputs, outputs] = item; + auto&& [op, inputs, outputs, type] = item; VarNodeArray input_nodes; for (auto&& input : inputs) { auto& node = nodes[input]; @@ -207,7 +207,8 @@ ValueRefList TracingTransformation::apply_transformation( auto wrapped_output = record_var(outputs[0], as_const, VarKind::Internal); auto input_id = wrapped_input->id(); auto output_id = wrapped_output->id(); - m_seq.push_back({{}, {input_id}, {output_id}}); + + m_seq.push_back({{}, {input_id}, {output_id}, OpKind::CreateTensor}); return {wrapped_output}; } else if (auto* get_attr = op.as()) { auto unwrapped_input = unwrap_var(inputs[0]); @@ -246,7 +247,30 @@ ValueRefList TracingTransformation::apply_transformation( } auto output = record_var(input, false, VarKind::Internal); m_vars[output->id()].mark = trace_mark_var->mark(); - m_seq.push_back({{}, {tracing_var->id()}, {output->id()}}); + m_seq.push_back( + + {{}, {tracing_var->id()}, {output->id()}, OpKind::TraceMarkVar}); + return {output}; + } else if (auto* iomarker = op.as()) { + mgb_assert(inputs.size() == 1, "IOMarkVar expects exactly one input"); + auto input = inputs[0]; + auto tracing_var = input.as_ref(m_value_type); + if (!tracing_var) { + bool is_input = iomarker->kind() == IOMarkVar::Kind::Input; + if (is_input) { + tracing_var = record_var(input, false, VarKind::External); + } else { + tracing_var = record_var(input, m_capture_as_const, VarKind::External); + } + } else { + input = tracing_var->value(); + } + auto output = record_var(input, false, VarKind::Internal); + if (iomarker->kind() == IOMarkVar::Kind::Input) + m_vars[tracing_var->id()].inp_marker.insert(iomarker->mark()); + else + m_vars[output->id()].out_marker.insert(iomarker->mark()); + m_seq.push_back({{}, {tracing_var->id()}, {output->id()}, OpKind::IOMarkVar}); return {output}; } else if (auto* trace_name_var = op.as()) { mgb_assert(inputs.size() == 1, "RenameValue expects exactly one input"); @@ -259,7 +283,7 @@ ValueRefList TracingTransformation::apply_transformation( } auto output = record_var(input, false, VarKind::Internal); m_vars[output->id()].name = trace_name_var->name(); - m_seq.push_back({{}, {tracing_var->id()}, {output->id()}}); + m_seq.push_back({{}, {tracing_var->id()}, {output->id()}, OpKind::Rename}); return {output}; } else if (op.is()) { mgb_assert(inputs.size() == 1, "GetName expects exactly one input"); @@ -279,6 +303,78 @@ ValueRefList TracingTransformation::apply_transformation( } } +void TracingTransformation::postprocess_trace_result() { + std::unordered_map identity_oi_map, identity_io_map; + for (auto&& op : m_seq) { + if (op.op == nullptr && op.inputs.size() == 1 && op.outputs.size() == 1) { + identity_oi_map[op.outputs[0]] = op.inputs[0]; + identity_io_map[op.inputs[0]] = op.outputs[0]; + } + } + + for (auto&& op : m_seq) { + if (op.kind == OpKind::IOMarkVar) { + auto&& inpvar = m_vars[op.inputs[0]]; + auto&& outvar = m_vars[op.outputs[0]]; + if (inpvar.inp_marker.size() > 0) { + auto id = inpvar.id; + if (inpvar.kind != VarKind::External) { + while (identity_oi_map.find(id) != identity_oi_map.end()) { + id = identity_oi_map[id]; + } + if (m_vars[id].kind == VarKind::External) { + for (auto mark : inpvar.inp_marker) { + mgb_assert( + inpmark_to_id.find(mark) == inpmark_to_id.end() || + inpmark_to_id[mark] == id, + "two nodes have same mark"); + inpmark_to_id[mark] = id; + m_vars[id].inp_marker.insert(mark); + } + inpvar.inp_marker.clear(); + } + } else { + for (auto mark : inpvar.inp_marker) { + mgb_assert( + inpmark_to_id.find(mark) == inpmark_to_id.end() || + inpmark_to_id[mark] == id, + "two nodes have same mark"); + inpmark_to_id[mark] = id; + } + } + } else { + mgb_assert(outvar.out_marker.size() > 0); + auto id = outvar.id; + if (!outvar.data_required) { + while (identity_io_map.find(id) != identity_io_map.end()) { + id = identity_io_map[id]; + } + + if (m_vars[id].data_required) { + for (auto mark : outvar.out_marker) { + mgb_assert( + outmark_to_id.find(mark) == outmark_to_id.end() || + outmark_to_id[mark] == id, + "two nodes have same mark"); + outmark_to_id[mark] = id; + m_vars[id].out_marker.insert(mark); + } + outvar.out_marker.clear(); + } + } else { + for (auto mark : outvar.out_marker) { + mgb_assert( + outmark_to_id.find(mark) == outmark_to_id.end() || + outmark_to_id[mark] == id, + "two nodes have same mark"); + outmark_to_id[mark] = id; + } + } + } + } + } +} + void TracingTransformation::on_unregister() noexcept { for (auto&& weak_var : m_weak_vars) { if (auto tracing_value = weak_var.lock()) { @@ -526,7 +622,10 @@ void CompiledTransformation::trace_input(size_t id, ValueRef value) { var.device->to_string().c_str(), device.to_string().c_str()); } - var_accessor.data_setter(value.dev_tensor()->as_nd()); + if (m_setted_extern.find(id) == m_setted_extern.end()) { + var_accessor.data_setter(value.dev_tensor()->as_nd()); + m_setted_extern.insert(id); + } break; } case VarKind::Constant: { @@ -732,6 +831,7 @@ void CompiledTransformation::wait() { m_pc = 0; std::exception_ptr graph_exc; std::swap(m_graph_exc, graph_exc); + m_setted_extern.clear(); if (graph_exc) { // graph with exception cannot be reused recompile(); diff --git a/imperative/src/impl/value.cpp b/imperative/src/impl/value.cpp index 7bd35f0e8f4c0720818c92cb614087280f3b1553..3feb9cf3b71afc081886941947d1976326329831 100644 --- a/imperative/src/impl/value.cpp +++ b/imperative/src/impl/value.cpp @@ -127,6 +127,10 @@ bool ValueRef::watching() const { return this->storage()->m_watching; } +int ValueRef::handle_id() const { + return imperative::apply(GetId(), *this)[0].cast(); +} + ValueRef ValueRef::make(ValueRef::storage_t storage) { if (recording_values) { recorded_values.push_back({storage}); diff --git a/imperative/src/include/megbrain/imperative/basic_operators.h b/imperative/src/include/megbrain/imperative/basic_operators.h index cca005516c02079302bf9a59853d54602417081c..9b065fe90e596f654b471473a8c1831666da848d 100644 --- a/imperative/src/include/megbrain/imperative/basic_operators.h +++ b/imperative/src/include/megbrain/imperative/basic_operators.h @@ -141,6 +141,14 @@ public: ValueRefList fallback(Span inputs) const override { return {ValueRef()}; } }; +class GetId final : public OperatorImpl { +public: + std::string to_string() const override; + std::string raw_type() const { return "GetId"; } + + ValueRefList fallback(Span inputs) const override { return {ValueRef()}; } +}; + /** * \brief return a value with new name * diff --git a/imperative/src/include/megbrain/imperative/basic_values.h b/imperative/src/include/megbrain/imperative/basic_values.h index ad4adae56a6094f050ada76a46772064939b4f13..b8d80b0f404c1bebfe096a4e019a09f20f6b3d6c 100644 --- a/imperative/src/include/megbrain/imperative/basic_values.h +++ b/imperative/src/include/megbrain/imperative/basic_values.h @@ -48,6 +48,25 @@ public: std::string to_string() const override; }; +class Integer { +private: + int64_t m_value; + +public: + Integer() = default; + Integer(int64_t value) : m_value(value) {} + + operator int64_t() const { return m_value; } +}; + +// TODO: override factory method +class IntegerValue final : public PrimitiveValue { +public: + using PrimitiveValue::PrimitiveValue; + + std::string to_string() const override; +}; + class HostStorage final : public PrimitiveValue { public: using PrimitiveValue::PrimitiveValue; diff --git a/imperative/src/include/megbrain/imperative/op_def.h b/imperative/src/include/megbrain/imperative/op_def.h index 9142c8e45f8cd42649aa33e054faff86764b0dad..9391b4461a996d950450ae5de1ee4cfe118b114f 100644 --- a/imperative/src/include/megbrain/imperative/op_def.h +++ b/imperative/src/include/megbrain/imperative/op_def.h @@ -80,6 +80,8 @@ public: const std::string make_name() const; + virtual const std::string type_name() const; + void set_scope(const std::string& scope); virtual size_t hash() const; diff --git a/imperative/src/include/megbrain/imperative/ops/opr_attr.h b/imperative/src/include/megbrain/imperative/ops/opr_attr.h index 5ca709bf34fd5569abecdb3c3458c159ec273973..9798cb68af233d8a408d77131db539b9c768aaa2 100644 --- a/imperative/src/include/megbrain/imperative/ops/opr_attr.h +++ b/imperative/src/include/megbrain/imperative/ops/opr_attr.h @@ -2,6 +2,7 @@ #include "megbrain/imperative/op_def.h" #include "megbrain/opr/param_defs.h" +#include "megbrain/plugin/opr_footprint.h" namespace mgb { namespace imperative { @@ -28,6 +29,7 @@ public: Type type; Param param; + Typeinfo* mgb_opr_type; megdnn::param::ExecutionPolicy policy; cg::OperatorNodeConfig config; @@ -36,13 +38,14 @@ public: OprAttr(const Type& t, const Param& p, const cg::OperatorNodeConfig& c) : type(t), param(p), config(c) {} OprAttr(const Type& t, const Param& p, const megdnn::param::ExecutionPolicy ps, - const cg::OperatorNodeConfig& c) - : type(t), param(p), policy(ps), config(c) {} + const cg::OperatorNodeConfig& c, Typeinfo* optype) + : type(t), param(p), policy(ps), config(c), mgb_opr_type(optype) {} std::string repr() const; - + std::shared_ptr mgb_param(OprFootprint*); bool is_same_st(const Hashable& rhs) const override; size_t hash() const override; + const std::string type_name() const override { return type; } }; } // namespace imperative diff --git a/imperative/src/include/megbrain/imperative/transformations/grad.h b/imperative/src/include/megbrain/imperative/transformations/grad.h index 42422d99eae482cd10ce6afdd0ea094a59fa0e3c..c46269e1b114ff8c0d6db4cea9796956b58e76ba 100644 --- a/imperative/src/include/megbrain/imperative/transformations/grad.h +++ b/imperative/src/include/megbrain/imperative/transformations/grad.h @@ -107,7 +107,7 @@ private: public: std::string to_string() const; - + ValueRef grad() const { return m_grad; } friend class GradKey; friend class GradSlotProducerPtr; friend class GradTransformation; @@ -224,6 +224,7 @@ public: class GradKey : public std::enable_shared_from_this { private: std::string m_name; + std::vector> m_side_effects; std::vector, std::shared_ptr>> m_tape; std::vector, std::shared_ptr>> m_frozen_tape; bool m_frozen = false; @@ -253,6 +254,13 @@ public: } }; +class GradSlotValue final : public PrimitiveValue { +public: + using PrimitiveValue::PrimitiveValue; + + std::string to_string() const override { return ssprintf("GradSlot{}"); } +}; + class GradTransformation final : public Transformation { private: ObjectType m_value_type{"GradValue"}; @@ -404,6 +412,28 @@ public: ValueRefList fallback(Span inputs) const override { return {ValueRef()}; } }; +class GetGradSlot : public OperatorImpl { +public: + GetGradSlot() = default; + + std::string to_string() const override { return ssprintf("GetGradSlot{}"); } + std::string raw_type() const { return "GetGradSlot"; }; + ValueRefList fallback(Span inputs) const override { return {}; } +}; + +class InsertGradCallback : public OperatorImpl { +public: + GenericFunction m_callback; + +public: + InsertGradCallback(GenericFunction callback) : m_callback(callback) {} + + GenericFunction callback() const { return m_callback; } + + std::string to_string() const override { return ssprintf("InsertGradCallback{}"); } + std::string raw_type() const { return "InsertGradCallback"; } +}; + class GetBackwardColsure : public OperatorImpl { private: @@ -420,4 +450,19 @@ public: std::string raw_type() const { return "GetBackwardClosure"; } }; +class GradTransformationGuard final : public Transformation { + ValueRefList apply_transformation( + const Operator& op, Span inputs) override { + if (auto* igc = op.as()) { + auto count = IntegerValue::make(0); + return {count}; + } + return imperative::apply(op, inputs); + } + + ValueRef unwrap(ValueRef value) override { return value; }; + + std::string name() const override { return "GradTransformationGuard"; }; +}; + } // namespace mgb::imperative diff --git a/imperative/src/include/megbrain/imperative/transformations/trace.h b/imperative/src/include/megbrain/imperative/transformations/trace.h index 9ad558ece7553d83bc79d8c399556021ca26c136..1be4d54c10e6f1ca4147b7361890b493ffc418ce 100644 --- a/imperative/src/include/megbrain/imperative/transformations/trace.h +++ b/imperative/src/include/megbrain/imperative/transformations/trace.h @@ -2,8 +2,8 @@ #include #include +#include #include - #include "megbrain/gopt/inference.h" #include "megbrain/imperative/dispatch.h" #include "megbrain/imperative/interpreter.h" @@ -17,11 +17,21 @@ namespace mgb::imperative { struct TraceResult { struct SeqItem { + enum OpKind { + Unknown, + TraceMarkVar, + Rename, + IOMarkVar, + CreateTensor, + }; std::shared_ptr op; SmallVector inputs; SmallVector outputs; + OpKind kind = OpKind::Unknown; }; + using OpKind = SeqItem::OpKind; + struct VarInfo { enum Kind { External, // End point of traced graph, its value is received from @@ -41,12 +51,14 @@ struct TraceResult { ValueRef bound_data; std::string mark; std::string name; + int handle_id; Kind kind; bool value_required = false; bool data_required = false; bool shape_required = false; - + std::set inp_marker; + std::set out_marker; TensorShape shape; }; @@ -91,6 +103,27 @@ public: std::string raw_type() const { return "TraceMarkVar"; } }; +class IOMarkVar : public OperatorImpl { +public: + enum Kind { + Input, + Output, + }; + +private: + size_t m_mark; + Kind m_kind; + +public: + IOMarkVar(size_t mark, Kind kind) : m_mark(mark), m_kind(kind) {} + + size_t mark() const { return m_mark; } + Kind kind() const { return m_kind; } + + std::string to_string() const override { return ssprintf("IOMarkVar"); } + std::string raw_type() const override { return "IOMarkVar"; } +}; + class TracingValue final : public ObjectValue { private: ValueRef m_value = {}; @@ -125,15 +158,22 @@ class TracingTransformation final : public Transformation { public: using VarInfo = TraceResult::VarInfo; using VarKind = VarInfo::Kind; + using OpKind = TraceResult::SeqItem::OpKind; private: std::vector m_seq; std::vector m_vars; std::vector m_weak_vars; + std::unordered_map extern_var_to_id; bool m_capture_as_const = false; bool m_record_input_shapes = false; + bool m_record_all_shapes = false; ObjectType m_value_type{"TracingValue"}; +public: + std::unordered_map inpmark_to_id; + std::unordered_map outmark_to_id; + public: TracingTransformation(bool capture_as_const, bool record_input_shapes) : m_capture_as_const(capture_as_const), @@ -148,7 +188,14 @@ public: * \return TypedValueRef traced value */ TypedValueRef record_var(ValueRef value, bool capture, VarKind kind) { + if (kind == VarKind::External && + extern_var_to_id.find(value.id()) != extern_var_to_id.end()) { + return m_value_type.make(value, extern_var_to_id[value.id()]); + } size_t id = m_vars.size(); + if (kind == VarKind::External) { + extern_var_to_id[value.id()] = id; + } auto wrapped_value = m_value_type.make(value, id); m_vars.push_back({id, value.dtype(), value.device()}); auto& var = m_vars.back(); @@ -156,9 +203,12 @@ public: var.bound_data = value; } var.kind = kind; - if (m_record_input_shapes && kind != VarKind::Internal) { + if ((m_record_input_shapes && kind != VarKind::Internal) || + m_record_all_shapes) { var.shape = value.shape()->as_tensor_shape(); } + if (m_record_all_shapes) + var.handle_id = value.handle_id(); if (auto name = value.name()) { var.name = *name; } @@ -185,8 +235,9 @@ public: std::string name() const override { return "TracingTransformation"; } void on_unregister() noexcept override; - + void postprocess_trace_result(); TraceResult get_result() { return {m_seq, m_vars}; } + void enable_record_all_shapes() { m_record_all_shapes = true; } }; class TraceError : public std::exception { @@ -211,6 +262,7 @@ class CompiledTransformation final : public Transformation { public: using VarInfo = TraceResult::VarInfo; using VarKind = VarInfo::Kind; + using OpKind = TraceResult::SeqItem::OpKind; struct VarAccessor { VarNode* node; @@ -254,6 +306,7 @@ private: std::vector m_seq; std::vector m_vars; std::vector m_var_accessors; + std::unordered_map mark2id; size_t m_pc = 0; std::shared_ptr m_graph; std::unique_ptr m_executable; @@ -268,6 +321,7 @@ private: std::vector> m_boxes; ComputingGraph::OutputSpec m_output_spec; ObjectType m_value_type{"TracedValue"}; + std::set m_setted_extern; public: CompiledTransformation(TraceResult result, bool input_shape_static) @@ -360,8 +414,10 @@ public: return value; } - std::string name() const override { return "CompiledTransformation"; } + VarAccessor& get_accessor_by_id(size_t id) { return m_var_accessors[id]; } + std::string name() const override { return "CompiledTransformation"; } + void set_pc_to_end() { m_pc = m_seq.size(); } void execute(); void wait(); diff --git a/imperative/src/include/megbrain/imperative/value.h b/imperative/src/include/megbrain/imperative/value.h index ea37bf0e9d16407387fbb62cff0eba7822d52bd6..2c80fc9ebdb04fe777678ab140ad6ad169da87be 100644 --- a/imperative/src/include/megbrain/imperative/value.h +++ b/imperative/src/include/megbrain/imperative/value.h @@ -222,6 +222,7 @@ public: TypedValueRef dtype() const; TypedValueRef format() const; TypedValueRef name() const; + int handle_id() const; bool is_scalar() const; void watch() const; @@ -298,7 +299,7 @@ protected: public: const IType& type() const { return *m_type; } - + uint64_t id() const { return m_id; } static void register_value(ValueRef value); static ValueRef get_value_by_id(uint64_t id); static void begin_record_values(); @@ -538,11 +539,11 @@ public: const ValueRef* data() const { return m_data; } bool empty() const { return m_size == 0; } ValueRef& front() { - mgb_assert(m_size > 1); + mgb_assert(m_size >= 1); return m_data[0]; } ValueRef& back() { - mgb_assert(m_size > 1); + mgb_assert(m_size >= 1); return m_data[m_size - 1]; } };