Commit b11d4430 authored by Megvii Engine Team

feat(mge/jit): support trace without_host mode

GitOrigin-RevId: 09b29e3dac44a4e4330f2ceb10da7d55df772466
Parent 3116e9f7
......@@ -8,6 +8,8 @@ _use_symbolic_shape = False
if os.environ.get("MEGENGINE_USE_SYMBOLIC_SHAPE"):
_use_symbolic_shape = True
_use_xla_backend = False
def use_symbolic_shape() -> bool:
r"""Returns whether tensor.shape returns a tensor instead of a tuple"""
......@@ -22,4 +24,15 @@ def set_symbolic_shape(option: bool):
return _org
def use_xla_backend() -> bool:
return _use_xla_backend
def set_use_xla_backend(option: bool) -> bool:
global _use_xla_backend
_org = _use_xla_backend
_use_xla_backend = option
return _org
set_cpp_use_symbolic_shape(use_symbolic_shape)
# -*- coding: utf-8 -*-
from .dtr_config import DTRConfig
from .graph_opt_config import GraphOptimizationConfig
from .partial_tracing import partial_trace
from .sublinear_memory_config import SublinearMemoryConfig
from .tracing import TraceError, exclude_from_trace, trace
from .xla_backend import xla_trace
from collections import OrderedDict
from typing import Sequence
from ..core._imperative_rt.core2 import add_backward_callback as _add_backward_callback
from ..core._imperative_rt.core2 import get_grad_slot, get_handle_id
from ..tensor import Tensor
from .tracing import trace
from .xla_backend import xla_trace
def _process_fwd_bwd_trace_result(fwd, bwd, inp_grad_map, out_grad_map):
# partial_trace records separate op sequences for forward and backward, yielding two TraceResult objects after tracing.
# But the inputs/outputs of the backward graph are unknown; this function determines them.
# var.handle_id is the id of the underlying value ref; it is used to find tensors that appear in both the forward and backward computation (a toy sketch follows this function).
# inp_grad_map: key: handle id of a forward input; value: handle id of the grad of that input.
# out_grad_map: key: handle id of a forward output; value: handle id of the grad of that output.
fwd_features = set([t.handle_id for t in fwd._trace.vars])
bwd_features = set([t.handle_id for t in bwd._trace.vars])
keep_vars = fwd_features.intersection(
bwd_features
) # intermediate vars produced by forward that will be used in backward
current = max(fwd.out_list) + 1
saved_feature_map = OrderedDict()
saved_features = []
# mark keep_vars as forward outputs
for var in fwd._trace.vars:
if (
var.handle_id in keep_vars
and var.data_required
and len(var.out_mark) == 0
and var.kind not in ["const", "external"]
):
keep_vars.remove(var.handle_id)
fwd._trace.mark_output(current, var.id)
saved_feature_map[var.handle_id] = current
saved_features.append(current)
current += 1
fwd.keeped_activation = saved_features
bwd_inp_idx = 0
bwd_out_idx = 0
bwd_dys = []
bwd_inps = [-1] * len(saved_feature_map)
saved_feature_handle_id = list(saved_feature_map.keys())
dy_ids = list(out_grad_map.values()) # handle_id of grad of forward output
inp_grad_ids = list(inp_grad_map.values()) # handle_id of grad of forward input
bwd_dys = [-1] * len(dy_ids)
bwd_outputs = [-1] * len(inp_grad_ids)
# dy_ids + saved_feature_map are backward inputs
# inp_grad_ids are backward outputs
# mark inputs/outputs for backward
for var in bwd._trace.vars:
if var.handle_id in dy_ids and var.kind == "external":
bwd._trace.mark_input(bwd_inp_idx, var.id)
idx = dy_ids.index(var.handle_id)
bwd_dys[idx] = bwd_inp_idx
bwd_inp_idx += 1
elif var.handle_id in saved_feature_map and var.kind == "external":
bwd._trace.mark_input(bwd_inp_idx, var.id)
bwd_inps[saved_feature_handle_id.index(var.handle_id)] = bwd_inp_idx
bwd_inp_idx += 1
if var.handle_id in inp_grad_ids and var.data_required:
bwd_outputs[inp_grad_ids.index(var.handle_id)] = bwd_out_idx
bwd._trace.mark_output(bwd_out_idx, var.id)
bwd_out_idx += 1
# assert -1 not in bwd_dys
assert -1 not in bwd_inps
for var in fwd._trace.vars:
if not var.out_mark:
var.data_required = False
# assert -1 not in bwd_outputs
bwd.setup_io_without_trace(bwd_dys + bwd_inps, bwd_outputs)
bwd.setup_without_host()
def check_external(trace_obj):
for var in trace_obj.vars:
if var.kind == "external" and not var.inp_mark:
raise RuntimeError("unknown external input in trace result")
check_external(fwd)
check_external(bwd)
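A toy sketch of the handle-id intersection described above (the ids are made up for illustration): vars whose handle ids appear in both traces are intermediate activations that the forward graph must additionally output so the backward graph can consume them.

```python
# Hypothetical handle ids, for illustration only.
fwd_ids = {1, 2, 3, 5}         # handle ids seen while tracing forward
bwd_ids = {3, 5, 7}            # handle ids seen while tracing backward
keep_vars = fwd_ids & bwd_ids  # {3, 5}: forward must also output these
assert keep_vars == {3, 5}
```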
JIT_BACKEND = {"default": trace, "xla": xla_trace}
def partial_trace(func=None, *, backend="default", without_host=True, **trace_options):
assert backend in JIT_BACKEND
assert without_host, "partial_trace only supports without_host mode currently!"
def wrapper(func):
trace_obj = JIT_BACKEND[backend](
func, without_host=without_host, **trace_options
)
trace_options["capture_as_const"] = False
backward_trace_obj = JIT_BACKEND[backend](
None, without_host=without_host, **trace_options
)
backward_trace_obj.check_external = (
False # whether to check for unknown external vars after tracing
)
trace_obj.overall = False # whether to trace the overall train step
backward_trace_obj.overall = False
trace_obj._trace.remove_unused_data_required = False
backward_trace_obj._trace.remove_unused_data_required = False
inp_grad_maps = OrderedDict() # x, dx map
out_grad_maps = OrderedDict() # y, dy map
traced = False # whether the wrapped function has been traced
compiled = False # whether the wrapped function has been compiled
custom_autodiff = None
outdef = None # treedef of forward return value
from ..core.autodiff.grad import Function
class CustomAutodiff(Function):
def __init__(self, fwd, bwd):
self.fwd = fwd
self.bwd = bwd
del fwd.outdef
self.keeped_features = []
def forward(self, *args):
rst = self.fwd(*args)
keeped_features = rst[-1]
if not isinstance(keeped_features, Sequence):
keeped_features = tuple([keeped_features])
else:
keeped_features = tuple(keeped_features)
self.keeped_features = keeped_features
return rst[0]
def get_keeped_features(self):
rst = self.keeped_features
del self.keeped_features
return rst
def backward(self, *output_grads):
output_grads = tuple([i for i in output_grads if i is not None])
return self.bwd(*(output_grads + self.get_keeped_features()))
class CustomFwd:
def __init__(self, fwd, bwd):
self.fwd = fwd
self.bwd = bwd
def __call__(self, *args):
rst = self.fwd(*args)
if self.fwd.keeped_activation:
keeped_features = rst[-1]
if not isinstance(keeped_features, Sequence):
keeped_features = tuple([keeped_features])
else:
keeped_features = tuple(keeped_features)
self.keeped_features = keeped_features
return rst[0]
else:
return rst
def wrapped_func(*args, **kwargs):
from ..traced_module.pytree import tree_flatten
from ..module import Module
nonlocal traced
nonlocal compiled
nonlocal custom_autodiff
nonlocal outdef
if not traced:
traced = True
fargs = trace_obj.flatten_inputs(*args, **kwargs)
for t in fargs:
inp_grad_maps[t] = get_grad_slot(t)
del fargs
def exit_trace():
backward_trace_obj._trace.exit()
new_dict = {}
for k, v in inp_grad_maps.items():
if v is not None:
new_dict[get_handle_id(k)] = get_handle_id(v.grad)
else:
new_dict[get_handle_id(k)] = -1
inp_grad_maps.clear()
inp_grad_maps.update(new_dict)
_add_backward_callback(exit_trace)
ret = trace_obj(*args)
rlist, outdef = tree_flatten(ret)
for t in rlist:
out_grad_maps[t] = get_grad_slot(t)
def enter_trace():
new_dict = {}
for k, v in out_grad_maps.items():
if v is not None:
new_dict[get_handle_id(k)] = get_handle_id(v.grad)
out_grad_maps.clear()
out_grad_maps.update(new_dict)
backward_trace_obj._trace.enter()
_add_backward_callback(enter_trace)
return ret
elif not compiled:
if custom_autodiff is None:
_process_fwd_bwd_trace_result(
trace_obj, backward_trace_obj, inp_grad_maps, out_grad_maps
)
if len(backward_trace_obj._trace.ops) > 0:
custom_autodiff = CustomAutodiff(trace_obj, backward_trace_obj)
else:
custom_autodiff = CustomFwd(trace_obj, backward_trace_obj)
fargs = trace_obj.flatten_inputs(*args, **kwargs)
del args
del kwargs
if outdef is None:
return custom_autodiff(*fargs)
else:
return outdef.unflatten(custom_autodiff(*fargs))
return wrapped_func
if func is None:
return wrapper
else:
return wrapper(func)
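A minimal usage sketch under the defaults above (backend="default", without_host=True); it mirrors the fwd-only test added later in this diff, with made-up tensor values:

```python
import numpy as np
from megengine import tensor
from megengine.jit import partial_trace

@partial_trace
def fwd_only(a, b):
    return a * b + a / b

x, y = tensor([1.0]), tensor([2.0])
ref = fwd_only(x, y)   # first call: runs Python and records the trace
for _ in range(2):     # later calls: replay the recorded graph
    np.testing.assert_equal(ref.numpy(), fwd_only(x, y).numpy())
```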
......@@ -9,6 +9,7 @@ import pickle
import re
import struct
import sys
from collections import OrderedDict, defaultdict
from typing import Any, Sequence
import cv2
......@@ -16,9 +17,22 @@ import numpy as np
from .. import tensor
from ..core import _imperative_rt as rt
from ..core._imperative_rt import GraphProfiler, GraphProfiler2, SerializationMetadata
from ..core._imperative_rt import (
CompNode,
GraphProfiler,
GraphProfiler2,
SerializationMetadata,
)
from ..core._imperative_rt.core2 import Tensor as RawTensor
from ..core._imperative_rt.core2 import Trace, TraceError, name_tensor # skip_tracing,
from ..core._imperative_rt.core2 import Trace, TraceError # skip_tracing,
from ..core._imperative_rt.core2 import add_backward_callback as _add_backward_callback
from ..core._imperative_rt.core2 import (
get_marked_input_tensor,
get_marked_output_tensor,
get_marked_tensor,
marked_input_tensor,
name_tensor,
)
from ..core._imperative_rt.graph import _set_priority_to_id
from ..core._imperative_rt.ops import (
AssertEqual,
......@@ -31,6 +45,7 @@ from ..core._imperative_rt.ops import (
from ..core._trace_option import set_symbolic_shape
from ..core.tensor import megbrain_graph as G
from ..logger import get_logger
from ..tensor import Tensor
from ..utils import comp_graph_tools as cgtools
from ..utils.naming import AutoNaming
from ..utils.profiler import is_profiling
......@@ -94,8 +109,13 @@ class trace:
opt_level: optimization level for compiling trace. Default: 2
graph_opt_config: configuration for graph optimization. Default: None
symbolic_shape: whether to use symbolic shape for tracing. Default: True
without_host: if True, the Python code of the wrapped function runs only on the first call,
and subsequent calls run the compiled graph/function directly. If False, the Python code
runs every time. Default: False
"""
third_party_backend = False
def __new__(cls, *args, **kwargs):
if not args:
return functools.partial(cls, **kwargs)
......@@ -113,6 +133,7 @@ class trace:
opt_level: int = 2,
graph_opt_config: GraphOptimizationConfig = None,
symbolic_shape: bool = True,
without_host: bool = False,
):
self.__wrapped__ = function
self._capture_as_const = capture_as_const or record_only
......@@ -150,6 +171,7 @@ class trace:
graph_options["graph_opt.jit_config.fuse_reduce"] = mapping[
graph_opt_config.jit_fuse_reduce
]
if sublinear_memory_config is not None:
graph_options["enable_sublinear_memory_opt"] = True
graph_options[
......@@ -186,8 +208,114 @@ class trace:
self._trace.profile = profiling
self._trace.array_comparator = array_comparator
self._trace.record_input_shapes = _input_node_use_static_shape()
self._trace.without_host = without_host
self.check_external = True
self.traced = False
self.input_num = 0
self.output_num = 0
self.arg_list = []
self.out_list = []
self.overall = True
# activations kept by the forward pass
self.keeped_activation = []
self.third_party_backend_compiled = False
@property
def check_external(self):
return self._trace.check_external
@check_external.setter
def check_external(self, flag):
self._trace.check_external = flag
@property
def without_host(self):
return self._trace.without_host
def flatten_inputs(self, *args, **kwargs):
from ..traced_module.pytree import tree_flatten
from ..module import Module
tensor_args = []
modules = []
fargs, _ = tree_flatten((args, kwargs))
for a in fargs:
if isinstance(a, RawTensor):
tensor_args.append(a)
elif isinstance(a, Module):
modules.append(a)
for m in modules:
tensor_args.extend(list(m.parameters()))
return tensor_args
def compile(self):
raise NotImplementedError
def execute(self, *args, **kwargs):
raise NotImplementedError
def setup_env(self):
pass
def unset_env(self):
pass
def compile_and_exec(self, *args, **kwargs):
if not self.third_party_backend_compiled:
self.compile()
self.third_party_backend_compiled = True
return self.execute(*args, **kwargs)
def convert_optimizer_state_to_tensor(self, *args, **kwargs):
from ..traced_module.pytree import tree_flatten, SUPPORTED_LEAF_CLS
from ..optimizer import Optimizer
from ..tensor import Tensor
if Optimizer not in SUPPORTED_LEAF_CLS:
SUPPORTED_LEAF_CLS.append(Optimizer)
args, _ = tree_flatten((args, kwargs))
for arg in args:
if isinstance(arg, Optimizer):
arg._disable_type_convert = False
for param_group in arg.param_groups:
for k, v in param_group.items():
if not isinstance(v, (Tensor, Sequence)):
param_group[k] = Tensor(v)
elif isinstance(v, Sequence) and not isinstance(v[0], Tensor):
new_v = []
for i in range(len(v)):
new_v.append(Tensor(v[i]))
param_group[k] = new_v
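The effect of the conversion above, sketched on a bare param-group dict (the Optimizer machinery is elided and the values are illustrative): scalar hyperparameters become Tensors so they can be fed to the compiled graph as inputs instead of being baked in as constants.

```python
from megengine import Tensor

param_group = {"lr": 1e-4, "betas": [0.9, 0.999]}  # illustrative group
for k, v in param_group.items():
    if not isinstance(v, (Tensor, list, tuple)):
        param_group[k] = Tensor(v)               # scalar -> Tensor input
    elif isinstance(v, (list, tuple)) and not isinstance(v[0], Tensor):
        param_group[k] = [Tensor(i) for i in v]  # sequence of scalars
assert isinstance(param_group["lr"], Tensor)
```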
def setup_io_without_trace(self, inputs, outputs):
self.traced = True
self.arg_list = [i for i in inputs if i != -1]
self.out_list = outputs
self.input_num = len(self.arg_list)
self.output_num = len([i for i in outputs if i != -1])
def setup_without_host(self):
self.inp_modules = set()
self.module_tensors = set()
self.tensor_to_attr = dict()
self.attr_to_key = dict()
self.update_param_dict = dict()
self.update_opt_param_dict = dict()
self.capture_optimizer_state = set()
self.opt_param_dict = dict()
def __call__(self, *args, **kwargs):
if not self.without_host:
return self.trace_normal(*args, **kwargs)
elif self.overall:
return self.trace_without_host_overall(*args, **kwargs)
else:
return self.trace_without_host(*args, **kwargs)
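A hedged summary of this dispatch: partial_trace sets overall = False on the trace objects it builds, so the three paths are reached roughly as follows (illustrative, not an exhaustive matrix):

```python
from megengine.jit import partial_trace, trace

@trace                     # without_host=False      -> trace_normal
def f(x): ...

@trace(without_host=True)  # overall stays True      -> trace_without_host_overall
def train_step(model, opt, data): ...

@partial_trace             # forces overall = False  -> trace_without_host
def fwd(a, b): ...
```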
def trace_normal(self, *args, **kwargs):
global active_trace
symbolic_shape = None
outputs = None
......@@ -214,6 +342,270 @@ class trace:
raise
return outputs
def trace_without_host(self, *args, **kwargs):
from ..traced_module.pytree import tree_flatten, SUPPORTED_LEAF_CLS
from ..module import Module
from ..utils.module_utils import get_expand_structure
from ..tensor import Tensor
from ..optimizer import Optimizer
assert self.without_host and not self.overall
global active_trace
symbolic_shape = None
outputs = None
if self.traced and self.third_party_backend:
return self.compile_and_exec(*args, **kwargs)
try:
active_trace = self
self._trace.enter()
if self._trace.compiled():
arglist = self.flatten_inputs(*args, **kwargs)
idx = 0
inp_dict = {}
for a in arglist:
if isinstance(a, RawTensor):
inp_dict[self.arg_list[idx]] = a
idx += 1
self._trace.put_datas(inp_dict)
outlist = []
for i in self.out_list:
if i == -1:
if not hasattr(self, "outdef"):
outlist.append(None)
else:
outlist.append(self._trace.get_data(i))
keep_vars = []
for i in self.keeped_activation:
keep_vars.append(self._trace.get_data(i))
outputs = (
self.outdef.unflatten(outlist)
if hasattr(self, "outdef")
else outlist
)
if keep_vars:
return outputs, keep_vars
else:
return outputs
arg_list = self.flatten_inputs(*args, **kwargs)
for i, arg in enumerate(arg_list):
arg_list[i]._reset(get_marked_input_tensor(self.input_num, arg))
self.arg_list.append(self.input_num)
self.input_num += 1
del arg_list
symbolic_shape = set_symbolic_shape(self._symbolic_shape)
if self.third_party_backend:
self.setup_env()
outputs = self.__wrapped__(*args, **kwargs)
finally:
handling_exc = sys.exc_info() != (None,) * 3
active_trace = None
if symbolic_shape is not None:
symbolic_shape = set_symbolic_shape(symbolic_shape)
assert symbolic_shape == self._symbolic_shape
if self.third_party_backend:
self.unset_env()
if (
self._capture_as_const
and (outputs is not None)
and not self.without_host
):
self._process_outputs(outputs)
if not self._trace.compiled():
outlist, self.outdef = tree_flatten(outputs)
for i, out in enumerate(outlist):
assert isinstance(out, RawTensor), type(out)
outlist[i] = get_marked_output_tensor(self.output_num, out)
del out
self.out_list.append(self.output_num)
self.output_num += 1
outputs = self.outdef.unflatten(outlist)
try:
# may raise TraceError
self._trace.exit()
except Exception as e:
if isinstance(e, TraceError):
if not handling_exc:
raise
else:
self._trace.set_execption(str(e))
raise
self.traced = True
return outputs
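On compiled calls this method reduces to a feed/fetch protocol keyed by marks: put_datas receives each mark in arg_list paired with the corresponding flattened input, and get_data fetches every non -1 entry of out_list. A toy sketch of that bookkeeping with plain dicts and made-up marks:

```python
arg_list = [0, 1, 2]       # input marks assigned while tracing
tensors = ["a", "b", "c"]  # stand-ins for flattened input tensors
inp_dict = dict(zip(arg_list, tensors))  # what put_datas is given
out_list = [3, -1, 4]      # -1: slot not produced by the graph
outs = [None if i == -1 else f"data({i})" for i in out_list]
assert outs == ["data(3)", None, "data(4)"]
```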
def trace_without_host_overall(self, *args, **kwargs):
# record the overall train step (forward, backward, and parameter update) in a single trace object
from ..traced_module.pytree import tree_flatten, SUPPORTED_LEAF_CLS
from ..module import Module
from ..utils.module_utils import get_expand_structure
from ..tensor import Tensor
from ..optimizer import Optimizer
assert self.without_host
global active_trace
symbolic_shape = None
outputs = None
if self.traced and self.third_party_backend:
return self.compile_and_exec(*args, **kwargs)
try:
active_trace = self
if not self.traced:
self.convert_optimizer_state_to_tensor(*args, **kwargs)
self._trace.enter()
if self._trace.compiled():
arglist, _ = tree_flatten((args, kwargs))
idx = 0
inp_dict = {}
for a in arglist:
if isinstance(a, RawTensor):
inp_dict[self.arg_list[idx]] = a
idx += 1
for t, key in self.opt_param_dict.items():
inp_dict[key] = t
self._trace.put_datas(inp_dict)
for attr, key in self.attr_to_key.items():
param = get_expand_structure(attr[0], attr[1])
self._trace.put_data(key, param)
outlist = []
for i in self.out_list:
if i == -1:
if not hasattr(self, "outdef"):
outlist.append(None)
else:
outlist.append(self._trace.get_data(i))
for attr, key in self.update_param_dict.items():
param = get_expand_structure(attr[0], attr[1])
param._reset(self._trace.get_data(key))
for state, key in self.update_opt_param_dict.items():
state._reset(self._trace.get_data(key))
keep_vars = []
for i in self.keeped_activation:
keep_vars.append(self._trace.get_data(i))
outputs = (
self.outdef.unflatten(outlist)
if hasattr(self, "outdef")
else outlist
)
if keep_vars:
return outputs, keep_vars
else:
return outputs
self.setup_without_host()
def get_attr_hook(obj, attr):
rst = object.__getattribute__(obj, attr)
if isinstance(rst, RawTensor):
assert rst in self.tensor_to_attr
attr = self.tensor_to_attr[rst]
if attr not in self.attr_to_key:
self.attr_to_key[attr] = self.input_num
self.input_num += 1
marked_input_tensor(self.attr_to_key[attr], rst)
return rst
origin_reset = Tensor._reset
self.update_param_num = 0
def tensor_wrapper_resethook(obj, other):
if obj in self.tensor_to_attr:
attr = self.tensor_to_attr[obj]
other = get_marked_output_tensor(self.output_num, other)
self.update_param_num += 1
self.update_param_dict[attr] = self.output_num
self.output_num += 1
elif obj in self.capture_optimizer_state:
other = get_marked_output_tensor(self.output_num, other)
self.update_opt_param_dict[obj] = self.output_num
self.output_num += 1
origin_reset(obj, other)
arg_list, self.argdef = tree_flatten((args, kwargs))
for i, arg in enumerate(arg_list):
if isinstance(arg, Module):
for k, v in arg.named_tensors():
if v not in self.tensor_to_attr:
self.tensor_to_attr[v] = (arg, k)
self.inp_modules.add(arg)
elif isinstance(arg, RawTensor):
arg_list[i] = get_marked_input_tensor(self.input_num, arg)
self.arg_list.append(self.input_num)
self.input_num += 1
elif isinstance(arg, Optimizer):
opt_params, _ = tree_flatten(arg.state_dict(keep_var=True))
for p in opt_params:
if isinstance(p, Tensor):
self.capture_optimizer_state.add(p)
self.opt_param_dict = {}
for t in self.capture_optimizer_state:
if t not in self.tensor_to_attr: # not a module parameter
mark_param = get_marked_input_tensor(self.input_num, t)
self.opt_param_dict[t] = self.input_num
t[...] = mark_param
self.input_num += 1
args, kwargs = self.argdef.unflatten(arg_list)
Module.__getattribute__ = get_attr_hook
Tensor._reset = tensor_wrapper_resethook
symbolic_shape = set_symbolic_shape(self._symbolic_shape)
if self.third_party_backend:
self.setup_env()
outputs = self.__wrapped__(*args, **kwargs)
del arg_list
del args
del kwargs
Module.__getattribute__ = object.__getattribute__
Tensor._reset = origin_reset
for attr, key in self.attr_to_key.items():
param = get_expand_structure(attr[0], attr[1])
finally:
handling_exc = sys.exc_info() != (None,) * 3
active_trace = None
if symbolic_shape is not None:
symbolic_shape = set_symbolic_shape(symbolic_shape)
assert symbolic_shape == self._symbolic_shape
if self.third_party_backend:
self.unset_env()
if (
self._capture_as_const
and (outputs is not None)
and not self.without_host
):
self._process_outputs(outputs)
if not self._trace.compiled():
outlist, self.outdef = tree_flatten(outputs)
for i, out in enumerate(outlist):
assert isinstance(out, RawTensor)
outlist[i] = get_marked_output_tensor(self.output_num, out)
del out
self.out_list.append(self.output_num)
self.output_num += 1
outputs = self.outdef.unflatten(outlist)
try:
# may raise TraceError
self._trace.exit()
except Exception as e:
if isinstance(e, TraceError):
if not handling_exc:
raise
else:
self._trace.set_execption(str(e))
raise
self.traced = True
return outputs
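The overall path works by temporarily swapping Module.__getattribute__ and Tensor._reset to observe which parameters the step reads and which get updated. A toy sketch of that hook pattern in plain Python (names are made up; no MegEngine internals involved):

```python
seen = []
orig_getattr = object.__getattribute__

def get_attr_hook(obj, name):
    val = orig_getattr(obj, name)
    if isinstance(val, float):  # stand-in for "value is a parameter"
        seen.append(name)
    return val

class Net:
    w = 0.5

Net.__getattribute__ = get_attr_hook
try:
    _ = Net().w                 # the access is observed by the hook
finally:
    Net.__getattribute__ = object.__getattribute__
assert seen == ["w"]
```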
@property
def ops(self):
return self._trace.ops
@property
def vars(self):
return self._trace.vars
def _process_inputs(self, *args, **kwargs):
for i, arg in enumerate(args):
assert isinstance(
......
from collections import OrderedDict, defaultdict
from ..core._imperative_rt import CompNode
from ..core._imperative_rt.core2 import Tensor as RawTensor
from ..core._trace_option import set_use_xla_backend
from ..device import get_default_device
from ..utils.dlpack import from_dlpack, to_dlpack
from .tracing import trace
try:
from ..xla.lib import xla_client as xc
except ImportError:
pass
class xla_trace(trace):
r"""Wraps a callable, and provides accelerated evaluation compiled by xla.
Currently it is an experimental feature.
Refer to :class:`~.jit.tracing.trace` for more information.
Examples:
.. code-block:: python
import megengine.functional as F
from basecls.models.resnet import resnet18
from megengine.autodiff.grad_manager import GradManager
from megengine.jit import xla_trace
from megengine.optimizer import Adam
model = resnet18()
gm = GradManager()
opt = Adam(model.parameters(), lr=1e-4)
gm.attach(model.parameters())
# Only tensors in the wrapped function's args/kwargs will be treated as graph
# inputs; other tensors will be captured as const values.
# Module, optimizer, and train data/label should be arguments of the wrapped function.
@xla_trace(capture_as_const=True)
def train_step(model, opt, data, label):
with gm:
pred = model(data)
loss = F.loss.cross_entropy(pred, label)
gm.backward(loss)
opt.step().clear_grad()
return loss
"""
third_party_backend = True
def __init__(self, function, *, without_host=True, symbolic_shape=False, **kwargs):
assert without_host, "xla_trace only supports without_host mode"
assert not symbolic_shape, "xla doesn't support dynamic shape currently"
super().__init__(
function, without_host=without_host, symbolic_shape=symbolic_shape, **kwargs
)
def setup_env(self):
self.orig_use_xla = set_use_xla_backend(True)
def unset_env(self):
set_use_xla_backend(self.orig_use_xla)
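setup_env/unset_env rely on the save/restore idiom enabled by set_use_xla_backend returning the previous flag (both helpers are added to _trace_option.py in this commit):

```python
from megengine.core._trace_option import set_use_xla_backend, use_xla_backend

prev = set_use_xla_backend(True)  # returns the old value
assert use_xla_backend()
set_use_xla_backend(prev)         # restore on exit
```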
def compile(self):
from ..xla import build_xla
from ..traced_module.pytree import SUPPORTED_LEAF_TYPE, register_supported_type
from ..utils.module_utils import get_expand_structure
from ..xla.device import get_xla_backend_and_device
from ..tensor import Tensor
assert self.traced
if self.overall:
for attr, _ in self.attr_to_key.items():
param = get_expand_structure(attr[0], attr[1])
param._reset(param.to("cpux"))
for tensor, _ in self.opt_param_dict.items():
tensor._reset(tensor.to("cpux"))
self.xla_exec, self.inp_ids, self.out_ids = build_xla(
self, return_with_io=True, return_device_array=True
)
id2inpidx = defaultdict(list)
id2outidx = defaultdict(list)
for idx, id in enumerate(self.inp_ids):
id2inpidx[id].append(idx)
for idx, id in enumerate(self.out_ids):
id2outidx[id].append(idx)
self.inpkey2idx = {}
self.outkey2idx = {}
if self.input_num == len(set(self.inp_ids)) - 1:
self.has_randomstate = True
self.random_seed = Tensor([[1, 2], [3, 4]], dtype="int32")
else:
assert self.input_num == len(set(self.inp_ids))
self.has_randomstate = False
inpmark2id = dict()
outmark2id = dict()
for var in self.vars:
if var.kind == "external":
for mark in var.inp_mark:
inpmark2id[mark] = var.id
elif var.data_required and var.out_mark:
for mark in var.out_mark:
outmark2id[mark] = var.id
for k, v in inpmark2id.items():
for idx in id2inpidx[v]:
self.inpkey2idx[k] = idx
for k, v in outmark2id.items():
for idx in id2outidx[v]:
self.outkey2idx[k] = idx
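compile() ends by composing two lookups: a trace mark resolves to a traced var id (inpmark2id/outmark2id), and that id resolves to positions in the XLA executable's argument list (id2inpidx/id2outidx), with the last position winning when an id occupies several slots. A toy sketch with made-up ids:

```python
from collections import defaultdict

inp_ids = [11, 12, 11]       # var id per XLA input slot (made up)
id2inpidx = defaultdict(list)
for idx, var_id in enumerate(inp_ids):
    id2inpidx[var_id].append(idx)

inpmark2id = {0: 11, 1: 12}  # trace mark -> traced var id
inpkey2idx = {m: id2inpidx[v][-1] for m, v in inpmark2id.items()}
assert inpkey2idx == {0: 2, 1: 1}
```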
def prepare_xla_inputs(self, tensors):
from ..utils.module_utils import get_expand_structure
inp_count = 0
inp_list = [0] * self.input_num
for idx, t in enumerate(tensors):
inp = self.inpkey2idx[self.arg_list[idx]]
inp_list[inp] = t
inp_count += 1
if self.overall:
for attr, key in self.attr_to_key.items():
param = get_expand_structure(attr[0], attr[1])
inp = self.inpkey2idx[key]
inp_list[inp] = param
inp_count += 1
for tensor, k in self.opt_param_dict.items():
inp = self.inpkey2idx[k]
inp_list[inp] = tensor
inp_count += 1
assert inp_count == self.input_num
if self.has_randomstate:
inp_list.append(self.random_seed)
return inp_list
def to_dlpack(self, x, take_ownership: bool = True):
return xc._xla.buffer_to_dlpack_managed_tensor(x, take_ownership=take_ownership)
def execute(self, *args, **kwargs):
from ..traced_module.pytree import tree_flatten
from ..tensor import Tensor
from ..utils.module_utils import get_expand_structure
inputs, _ = tree_flatten((args, kwargs))
arrays = []
cn = CompNode(get_default_device())
stream = dict(self.xla_exec.backend.get_compute_compnode())
device_kind, device_id, stream_id = cn.physical_locator
xla_stream = stream[device_id]
xla_comp_cn = "gpu{}:{}".format(device_id, xla_stream)
for t in inputs:
if isinstance(t, RawTensor):
assert cn == t.device
arrays.append(t.to(xla_comp_cn, _borrow=True))
arrays = self.prepare_xla_inputs(arrays)
outputs = self.xla_exec(*arrays)
return_vals = []
for i in self.out_list:
if i == -1:
if not hasattr(self, "outdef"):
return_vals.append(None)
else:
return_vals.append(outputs[self.outkey2idx[i]])
keeped_features = []
for i in self.keeped_activation:
capsule = self.to_dlpack(outputs[self.outkey2idx[i]])
t = from_dlpack(capsule, xla_stream).to(cn, _borrow=True)
keeped_features.append(t)
out_tensors = []
for array in return_vals:
if array is not None:
capsule = self.to_dlpack(array)
t = from_dlpack(capsule, xla_stream)
out_tensors.append(t.to(cn, _borrow=True))
else:
out_tensors.append(array)
if self.overall:
for attr, key in self.update_param_dict.items():
param = get_expand_structure(attr[0], attr[1])
xla_array = outputs[self.outkey2idx[key]]
capsule = self.to_dlpack(xla_array)
param._reset(from_dlpack(capsule).to(cn, _borrow=True))
for state, key in self.update_opt_param_dict.items():
xla_array = outputs[self.outkey2idx[key]]
capsule = self.to_dlpack(xla_array)
state._reset(from_dlpack(capsule).to(cn, _borrow=True))
rst = (
self.outdef.unflatten(out_tensors)
if hasattr(self, "outdef")
else out_tensors
)
if keeped_features:
return rst, keeped_features
else:
return rst
......@@ -49,6 +49,8 @@ def _access_structure(obj, key, callback=None):
cur = cur[k]
else:
cur = getattr(cur, k)
if callback is None:
return cur
return callback(parent, k, cur)
......
......@@ -115,11 +115,9 @@ void _set_priority_to_id(const std::vector<mgb::cg::VarNode*>& dest_vars) {
}
py::object Py_Varnode = py::none();
const std::unique_ptr<mgb::OprFootprint> _imperative_sm_opr_footprint_ptr{
std::make_unique<mgb::OprFootprint>()};
void init_graph_rt(py::module m) {
static const std::unique_ptr<mgb::OprFootprint> _imperative_sm_opr_footprint_ptr{
std::make_unique<mgb::OprFootprint>()};
def_rendezvous<DeviceTensorND>(m, "DeviceTensorNDRendezvous");
def_rendezvous<HostNDWithEvent>(m, "HostTensorNDRendezvous");
......
......@@ -10,7 +10,7 @@
namespace py = pybind11;
extern py::object Py_Varnode;
extern const std::unique_ptr<mgb::OprFootprint> _imperative_sm_opr_footprint_ptr;
template <typename T>
class GraphNodePtr {
std::shared_ptr<mgb::cg::ComputingGraph> m_graph;
......
This diff is collapsed.
......@@ -74,6 +74,7 @@ public:
inline ValueRef data() const { return m_data.unwrap(); }
bool is_scalar() { return data().is_scalar(); }
inline std::string name() { return m_name; }
inline size_t value_id() { return m_data.id(); }
inline void set_name(std::string name) {
m_name = name;
if (!name.empty()) {
......@@ -128,6 +129,7 @@ public:
void reset(PyObject*);
PyObject* detach();
PyObject* isscalar();
PyObject* value_id();
PyObject* _dev_tensor();
void _drop();
PyObject* varnode();
......
......@@ -18,7 +18,13 @@ from megengine.core.ops import builtin as ops
from megengine.core.ops.builtin import Elemwise
from megengine.core.tensor.utils import isscalar
from megengine.functional import exp, log
from megengine.jit import GraphOptimizationConfig, TraceError, exclude_from_trace, trace
from megengine.jit import (
GraphOptimizationConfig,
TraceError,
exclude_from_trace,
partial_trace,
trace,
)
from megengine.module import Module
from megengine.random import normal, uniform
from megengine.utils.naming import AutoNaming
......@@ -803,3 +809,87 @@ def test_dump_without_output_error():
str(e)
== "the traced function without return values cannot be dumped, the traced function should return List[Tensor] or Dict[str, Tensor]"
)
@pytest.mark.parametrize("trace_mode", [False, True])
def test_trace_without_host(trace_mode):
@trace(symbolic=trace_mode, without_host=True)
def fwd(a, b, c):
x = a + b
y = a + c
z = x * y
z1 = x / y
return [z, z1]
a = tensor([1.0])
b = tensor([2.0])
c = tensor([3.0])
rst = fwd(a, b, c)
for _ in range(2):
trace_rst = fwd(a, b, c)
np.testing.assert_equal(rst[0], trace_rst[0])
np.testing.assert_equal(rst[1], trace_rst[1])
def test_trace_without_error():
const = tensor([8.0])
@trace(symbolic=False, without_host=True)
def fwd(a, b, c):
x = a + b
y = a + c
z = x * y
z1 = x / y + const
return [z, z1]
try:
a = tensor([1.0])
b = tensor([2.0])
c = tensor([3.0])
fwd(a, b, c)
except Exception as e:
assert str(e) == "have some unknown input tensors in trace result"
else:
assert False
def test_partial_trace_fwd_bwd():
class Simple(Module):
def __init__(self):
super().__init__()
self.a = Parameter([1.0], dtype=np.float32)
self.b = Parameter([2.0], dtype=np.float32)
@partial_trace
def forward(self, x):
x = x * self.a + x / self.b
x = F.exp(x)
return x
def clear_grad(self):
self.a.grad = None
self.b.grad = None
@partial_trace
def fwd_only(a, b):
return a * b + a / b
m = Simple()
gm = GradManager()
gm.attach(m.parameters())
def func(x):
with gm:
x = x * 3
x = m(x)
x = x * 2
gm.backward(x)
a = m.a.grad
b = m.b.grad
m.clear_grad()
return fwd_only(a, b) + a + b
gt = func(tensor(1.0))
for _ in range(3):
out = func(tensor(1.0))
np.testing.assert_equal(gt.numpy(), out.numpy())
......@@ -105,6 +105,10 @@ std::string IsScalar::to_string() const {
return "IsScalar";
}
std::string GetId::to_string() const {
return "GetId";
}
std::string GetFormat::to_string() const {
return "GetFormat{}";
}
......
......@@ -15,6 +15,10 @@ std::string BoolValue::to_string() const {
return (*this) ? "true" : "false";
}
std::string IntegerValue::to_string() const {
return std::to_string((int)*this);
}
std::string HostStorage::to_string() const {
return ssprintf("HostStorage{device=%s}", comp_node().to_string().c_str());
}
......
......@@ -142,6 +142,10 @@ const std::string OpDef::make_name() const {
return m_scope + "." + trait()->make_name(*this);
}
const std::string OpDef::type_name() const {
return trait()->name;
}
static thread_local OpDef::allocator_t local_allocator;
void OpDef::set_allocator(allocator_t allocator) {
......
......@@ -13,6 +13,10 @@ namespace imperative {
namespace {
class OprParamsLoadContext final : public serialization::OprLoadContextRawPOD {
public:
bool strict = true;
private:
const OprAttr::Param& m_param;
size_t m_pos = 0;
ComputingGraph* m_graph;
......@@ -40,7 +44,8 @@ public:
m_graph(graph) {}
~OprParamsLoadContext() {
mgb_assert(m_pos == m_param.size(), "param not fully consumed");
if (strict)
mgb_assert(m_pos == m_param.size(), "param not fully consumed");
}
ComputingGraph& graph() override { return *m_graph; }
......@@ -126,7 +131,9 @@ std::shared_ptr<OpDef> make_from_op_node(cg::OperatorNodeBase* opr) {
if (get_type2policy().find(opr->dyn_typeinfo()) != get_type2policy().end()) {
policy = get_type2policy().at(opr->dyn_typeinfo()).first(opr);
}
return OprAttr::make(registry->name, std::move(ctx.m_param), policy, opr->config());
return OprAttr::make(
registry->name, std::move(ctx.m_param), policy, opr->config(),
opr->dyn_typeinfo());
}
std::vector<std::pair<const char*, std::string>> props(const OpDef& def) {
......@@ -168,6 +175,12 @@ size_t OprAttr::hash() const {
config.hash());
}
std::shared_ptr<json::Value> OprAttr::mgb_param(OprFootprint* footprint) {
OprParamsLoadContext ctx{param, nullptr};
ctx.strict = false;
return footprint->get_serial_param_json(mgb_opr_type, ctx);
};
MGB_DYN_TYPE_OBJ_FINAL_IMPL(OprAttr);
} // namespace imperative
......
......@@ -130,6 +130,11 @@ ValueRefList InterpreterTransformation::apply_transformation(
} else {
return {ValueRef()};
}
} else if (op.is<GetId>()) {
auto& val = inputs[0].cast(m_value_type);
int64_t id = val.id();
return {IntegerValue::make(id)};
} else if (op.is<DupTensor>()) {
auto& input = inputs[0].cast(m_value_type);
DeviceTensorND dev_tensor;
......
......@@ -7,6 +7,7 @@
#include "megbrain/imperative/profiler.h"
#include "megbrain/imperative/resource_manager.h"
#include <range/v3/all.hpp>
namespace mgb {
namespace imperative {
......@@ -226,7 +227,7 @@ void GradKey::backward() {
if constexpr (std::is_same_v<T, std::monostate>) {
mgb_throw(AssertionError, "invalid backward");
} else {
mgb_assert(grad_fn->m_slots.size() > 0);
// mgb_assert(grad_fn->m_slots.size() > 0);
SmallVector<ValueRef> grads (grad_fn->m_slots.size());
auto iter = grads.begin();
for (auto&& slot : grad_fn->m_slots) {
......@@ -419,6 +420,23 @@ ValueRefList GradTransformation::apply_transformation(
mgb_assert(!grad_fn->m_slots.empty());
m_key->m_tape.push_back({grad_fn, op_val->op().shared_from_this()});
return outputs;
} else if (auto* igc = op.as<InsertGradCallback>()) {
auto grad_fn = LocalPtr<GradFn>::make();
auto& backward =
std::get<CustomBackward>(grad_fn->m_backward = CustomBackward());
auto id = inputs[0];
backward.m_backward = [id, callback = igc->callback()](
Span<ValueRef> inputs) -> SmallVector<ValueRef> {
callback({&id, (size_t)1});
return {};
};
m_key->m_side_effects.push_back(grad_fn);
m_key->m_tape.push_back({grad_fn, nullptr});
auto next_id = IntegerValue::make((int)id.cast<IntegerValue>() + 1);
auto prev_count =
imperative::apply(InsertGradCallback(igc->callback()), next_id)[0];
auto count = IntegerValue::make((int)prev_count.cast<IntegerValue>() + 1);
return {count};
} else if (op.is<CreateTensor>()) {
return imperative::apply(op, inputs);
} else if (auto* attach_grad = op.as<AttachGrad>()) {
......@@ -514,6 +532,13 @@ ValueRefList GradTransformation::apply_transformation(
}
}
return imperative::apply(op, inputs);
} else if (op.is<GetGradSlot>()) {
mgb_assert(inputs.size() == 1);
if (auto&& grad_value = as_grad_value(inputs[0])) {
return {GradSlotValue::make(grad_value->slot())};
} else {
return {};
}
} else if (op.kind() == Operator::IdentityLike) {
mgb_assert(inputs.size() == 1);
if (auto&& grad_value = as_grad_value(inputs[0])) {
......
......@@ -53,6 +53,9 @@ ValueRefList LazyEvalTransformation::apply_transformation(
outputs[i] = record_var(output_nodes[i]);
}
return outputs;
} else if (op.is<GetId>()) {
int64_t id = inputs[0].id();
return {IntegerValue::make(id)};
} else if (auto* create_tensor = op.as<CreateTensor>()) {
auto&& args = create_tensor->parse(inputs);
auto get_dev_val = [&] {
......
......@@ -83,7 +83,7 @@ VarNodeArray TraceResult::dump(
std::unordered_map<std::string, std::vector<cg::OperatorNodeBase*>> name2ops;
// iterate over opr_seq
for (auto&& item : seq) {
auto&& [op, inputs, outputs] = item;
auto&& [op, inputs, outputs, type] = item;
VarNodeArray input_nodes;
for (auto&& input : inputs) {
auto& node = nodes[input];
......@@ -207,7 +207,8 @@ ValueRefList TracingTransformation::apply_transformation(
auto wrapped_output = record_var(outputs[0], as_const, VarKind::Internal);
auto input_id = wrapped_input->id();
auto output_id = wrapped_output->id();
m_seq.push_back({{}, {input_id}, {output_id}});
m_seq.push_back({{}, {input_id}, {output_id}, OpKind::CreateTensor});
return {wrapped_output};
} else if (auto* get_attr = op.as<GetAttr>()) {
auto unwrapped_input = unwrap_var(inputs[0]);
......@@ -246,7 +247,30 @@ ValueRefList TracingTransformation::apply_transformation(
}
auto output = record_var(input, false, VarKind::Internal);
m_vars[output->id()].mark = trace_mark_var->mark();
m_seq.push_back({{}, {tracing_var->id()}, {output->id()}});
m_seq.push_back(
{{}, {tracing_var->id()}, {output->id()}, OpKind::TraceMarkVar});
return {output};
} else if (auto* iomarker = op.as<IOMarkVar>()) {
mgb_assert(inputs.size() == 1, "IOMarkVar expects exactly one input");
auto input = inputs[0];
auto tracing_var = input.as_ref(m_value_type);
if (!tracing_var) {
bool is_input = iomarker->kind() == IOMarkVar::Kind::Input;
if (is_input) {
tracing_var = record_var(input, false, VarKind::External);
} else {
tracing_var = record_var(input, m_capture_as_const, VarKind::External);
}
} else {
input = tracing_var->value();
}
auto output = record_var(input, false, VarKind::Internal);
if (iomarker->kind() == IOMarkVar::Kind::Input)
m_vars[tracing_var->id()].inp_marker.insert(iomarker->mark());
else
m_vars[output->id()].out_marker.insert(iomarker->mark());
m_seq.push_back({{}, {tracing_var->id()}, {output->id()}, OpKind::IOMarkVar});
return {output};
} else if (auto* trace_name_var = op.as<RenameValue>()) {
mgb_assert(inputs.size() == 1, "RenameValue expects exactly one input");
......@@ -259,7 +283,7 @@ ValueRefList TracingTransformation::apply_transformation(
}
auto output = record_var(input, false, VarKind::Internal);
m_vars[output->id()].name = trace_name_var->name();
m_seq.push_back({{}, {tracing_var->id()}, {output->id()}});
m_seq.push_back({{}, {tracing_var->id()}, {output->id()}, OpKind::Rename});
return {output};
} else if (op.is<GetName>()) {
mgb_assert(inputs.size() == 1, "GetName expects exactly one input");
......@@ -279,6 +303,78 @@ ValueRefList TracingTransformation::apply_transformation(
}
}
void TracingTransformation::postprocess_trace_result() {
std::unordered_map<size_t, size_t> identity_oi_map, identity_io_map;
for (auto&& op : m_seq) {
if (op.op == nullptr && op.inputs.size() == 1 && op.outputs.size() == 1) {
identity_oi_map[op.outputs[0]] = op.inputs[0];
identity_io_map[op.inputs[0]] = op.outputs[0];
}
}
for (auto&& op : m_seq) {
if (op.kind == OpKind::IOMarkVar) {
auto&& inpvar = m_vars[op.inputs[0]];
auto&& outvar = m_vars[op.outputs[0]];
if (inpvar.inp_marker.size() > 0) {
auto id = inpvar.id;
if (inpvar.kind != VarKind::External) {
while (identity_oi_map.find(id) != identity_oi_map.end()) {
id = identity_oi_map[id];
}
if (m_vars[id].kind == VarKind::External) {
for (auto mark : inpvar.inp_marker) {
mgb_assert(
inpmark_to_id.find(mark) == inpmark_to_id.end() ||
inpmark_to_id[mark] == id,
"two nodes have same mark");
inpmark_to_id[mark] = id;
m_vars[id].inp_marker.insert(mark);
}
inpvar.inp_marker.clear();
}
} else {
for (auto mark : inpvar.inp_marker) {
mgb_assert(
inpmark_to_id.find(mark) == inpmark_to_id.end() ||
inpmark_to_id[mark] == id,
"two nodes have same mark");
inpmark_to_id[mark] = id;
}
}
} else {
mgb_assert(outvar.out_marker.size() > 0);
auto id = outvar.id;
if (!outvar.data_required) {
while (identity_io_map.find(id) != identity_io_map.end()) {
id = identity_io_map[id];
}
if (m_vars[id].data_required) {
for (auto mark : outvar.out_marker) {
mgb_assert(
outmark_to_id.find(mark) == outmark_to_id.end() ||
outmark_to_id[mark] == id,
"two nodes have same mark");
outmark_to_id[mark] = id;
m_vars[id].out_marker.insert(mark);
}
outvar.out_marker.clear();
}
} else {
for (auto mark : outvar.out_marker) {
mgb_assert(
outmark_to_id.find(mark) == outmark_to_id.end() ||
outmark_to_id[mark] == id,
"two nodes have same mark");
outmark_to_id[mark] = id;
}
}
}
}
}
}
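The chasing loops above, reduced to a Python toy (the real logic is the C++ just shown; ids are made up): a mark placed past a chain of identity-like ops is migrated back to the underlying external var.

```python
identity_oi_map = {5: 3, 3: 1}  # identity op: output id -> input id

def resolve_to_source(var_id):
    # follow identity links until the original (external) var is reached
    while var_id in identity_oi_map:
        var_id = identity_oi_map[var_id]
    return var_id

assert resolve_to_source(5) == 1
```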
void TracingTransformation::on_unregister() noexcept {
for (auto&& weak_var : m_weak_vars) {
if (auto tracing_value = weak_var.lock()) {
......@@ -526,7 +622,10 @@ void CompiledTransformation::trace_input(size_t id, ValueRef value) {
var.device->to_string().c_str(),
device.to_string().c_str());
}
var_accessor.data_setter(value.dev_tensor()->as_nd());
if (m_setted_extern.find(id) == m_setted_extern.end()) {
var_accessor.data_setter(value.dev_tensor()->as_nd());
m_setted_extern.insert(id);
}
break;
}
case VarKind::Constant: {
......@@ -732,6 +831,7 @@ void CompiledTransformation::wait() {
m_pc = 0;
std::exception_ptr graph_exc;
std::swap(m_graph_exc, graph_exc);
m_setted_extern.clear();
if (graph_exc) {
// graph with exception cannot be reused
recompile();
......
......@@ -127,6 +127,10 @@ bool ValueRef::watching() const {
return this->storage()->m_watching;
}
int ValueRef::handle_id() const {
return imperative::apply(GetId(), *this)[0].cast<IntegerValue>();
}
ValueRef ValueRef::make(ValueRef::storage_t storage) {
if (recording_values) {
recorded_values.push_back({storage});
......
......@@ -141,6 +141,14 @@ public:
ValueRefList fallback(Span<ValueRef> inputs) const override { return {ValueRef()}; }
};
class GetId final : public OperatorImpl<GetId, Operator::GetAttrLike> {
public:
std::string to_string() const override;
std::string raw_type() const { return "GetId"; }
ValueRefList fallback(Span<ValueRef> inputs) const override { return {ValueRef()}; }
};
/**
* \brief return a value with new name
*
......
......@@ -48,6 +48,25 @@ public:
std::string to_string() const override;
};
class Integer {
private:
int64_t m_value;
public:
Integer() = default;
Integer(int64_t value) : m_value(value) {}
operator int64_t() const { return m_value; }
};
// TODO: override factory method
class IntegerValue final : public PrimitiveValue<IntegerValue, Integer> {
public:
using PrimitiveValue::PrimitiveValue;
std::string to_string() const override;
};
class HostStorage final : public PrimitiveValue<HostStorage, HostTensorStorage> {
public:
using PrimitiveValue::PrimitiveValue;
......
......@@ -80,6 +80,8 @@ public:
const std::string make_name() const;
virtual const std::string type_name() const;
void set_scope(const std::string& scope);
virtual size_t hash() const;
......
......@@ -2,6 +2,7 @@
#include "megbrain/imperative/op_def.h"
#include "megbrain/opr/param_defs.h"
#include "megbrain/plugin/opr_footprint.h"
namespace mgb {
namespace imperative {
......@@ -28,6 +29,7 @@ public:
Type type;
Param param;
Typeinfo* mgb_opr_type;
megdnn::param::ExecutionPolicy policy;
cg::OperatorNodeConfig config;
......@@ -36,13 +38,14 @@ public:
OprAttr(const Type& t, const Param& p, const cg::OperatorNodeConfig& c)
: type(t), param(p), config(c) {}
OprAttr(const Type& t, const Param& p, const megdnn::param::ExecutionPolicy ps,
const cg::OperatorNodeConfig& c)
: type(t), param(p), policy(ps), config(c) {}
const cg::OperatorNodeConfig& c, Typeinfo* optype)
: type(t), param(p), policy(ps), config(c), mgb_opr_type(optype) {}
std::string repr() const;
std::shared_ptr<json::Value> mgb_param(OprFootprint*);
bool is_same_st(const Hashable& rhs) const override;
size_t hash() const override;
const std::string type_name() const override { return type; }
};
} // namespace imperative
......
......@@ -107,7 +107,7 @@ private:
public:
std::string to_string() const;
ValueRef grad() const { return m_grad; }
friend class GradKey;
friend class GradSlotProducerPtr;
friend class GradTransformation;
......@@ -224,6 +224,7 @@ public:
class GradKey : public std::enable_shared_from_this<GradKey> {
private:
std::string m_name;
std::vector<LocalPtr<GradFn>> m_side_effects;
std::vector<std::pair<LocalWeakPtr<GradFn>, std::shared_ptr<OpDef>>> m_tape;
std::vector<std::pair<LocalPtr<GradFn>, std::shared_ptr<OpDef>>> m_frozen_tape;
bool m_frozen = false;
......@@ -253,6 +254,13 @@ public:
}
};
class GradSlotValue final : public PrimitiveValue<GradSlotValue, GradSlotPtr> {
public:
using PrimitiveValue::PrimitiveValue;
std::string to_string() const override { return ssprintf("GradSlot{}"); }
};
class GradTransformation final : public Transformation {
private:
ObjectType<GradValue> m_value_type{"GradValue"};
......@@ -404,6 +412,28 @@ public:
ValueRefList fallback(Span<ValueRef> inputs) const override { return {ValueRef()}; }
};
class GetGradSlot : public OperatorImpl<GetGradSlot, Operator::GetAttrLike> {
public:
GetGradSlot() = default;
std::string to_string() const override { return ssprintf("GetGradSlot{}"); }
std::string raw_type() const { return "GetGradSlot"; };
ValueRefList fallback(Span<ValueRef> inputs) const override { return {}; }
};
class InsertGradCallback : public OperatorImpl<InsertGradCallback, Operator::Other> {
public:
GenericFunction m_callback;
public:
InsertGradCallback(GenericFunction callback) : m_callback(callback) {}
GenericFunction callback() const { return m_callback; }
std::string to_string() const override { return ssprintf("InsertGradCallback{}"); }
std::string raw_type() const { return "InsertGradCallback"; }
};
class GetBackwardColsure
: public OperatorImpl<GetBackwardColsure, Operator::GetAttrLike> {
private:
......@@ -420,4 +450,19 @@ public:
std::string raw_type() const { return "GetBackwardClosure"; }
};
class GradTransformationGuard final : public Transformation {
ValueRefList apply_transformation(
const Operator& op, Span<ValueRef> inputs) override {
if (auto* igc = op.as<InsertGradCallback>()) {
auto count = IntegerValue::make(0);
return {count};
}
return imperative::apply(op, inputs);
}
ValueRef unwrap(ValueRef value) override { return value; };
std::string name() const override { return "GradTransformationGuard"; };
};
} // namespace mgb::imperative
......@@ -2,8 +2,8 @@
#include <chrono>
#include <future>
#include <set>
#include <variant>
#include "megbrain/gopt/inference.h"
#include "megbrain/imperative/dispatch.h"
#include "megbrain/imperative/interpreter.h"
......@@ -17,11 +17,21 @@ namespace mgb::imperative {
struct TraceResult {
struct SeqItem {
enum OpKind {
Unknown,
TraceMarkVar,
Rename,
IOMarkVar,
CreateTensor,
};
std::shared_ptr<OpDef> op;
SmallVector<size_t> inputs;
SmallVector<size_t> outputs;
OpKind kind = OpKind::Unknown;
};
using OpKind = SeqItem::OpKind;
struct VarInfo {
enum Kind {
External, // End point of traced graph, its value is received from
......@@ -41,12 +51,14 @@ struct TraceResult {
ValueRef bound_data;
std::string mark;
std::string name;
int handle_id;
Kind kind;
bool value_required = false;
bool data_required = false;
bool shape_required = false;
std::set<size_t> inp_marker;
std::set<size_t> out_marker;
TensorShape shape;
};
......@@ -91,6 +103,27 @@ public:
std::string raw_type() const { return "TraceMarkVar"; }
};
class IOMarkVar : public OperatorImpl<IOMarkVar, Operator::IdentityLike> {
public:
enum Kind {
Input,
Output,
};
private:
size_t m_mark;
Kind m_kind;
public:
IOMarkVar(size_t mark, Kind kind) : m_mark(mark), m_kind(kind) {}
size_t mark() const { return m_mark; }
Kind kind() const { return m_kind; }
std::string to_string() const override { return ssprintf("IOMarkVar"); }
std::string raw_type() const override { return "IOMarkVar"; }
};
class TracingValue final : public ObjectValue<TracingValue> {
private:
ValueRef m_value = {};
......@@ -125,15 +158,22 @@ class TracingTransformation final : public Transformation {
public:
using VarInfo = TraceResult::VarInfo;
using VarKind = VarInfo::Kind;
using OpKind = TraceResult::SeqItem::OpKind;
private:
std::vector<TraceResult::SeqItem> m_seq;
std::vector<TraceResult::VarInfo> m_vars;
std::vector<TracingValue::weak_ref_t> m_weak_vars;
std::unordered_map<size_t, size_t> extern_var_to_id;
bool m_capture_as_const = false;
bool m_record_input_shapes = false;
bool m_record_all_shapes = false;
ObjectType<TracingValue> m_value_type{"TracingValue"};
public:
std::unordered_map<size_t, size_t> inpmark_to_id;
std::unordered_map<size_t, size_t> outmark_to_id;
public:
TracingTransformation(bool capture_as_const, bool record_input_shapes)
: m_capture_as_const(capture_as_const),
......@@ -148,7 +188,14 @@ public:
* \return TypedValueRef<TracingValue> traced value
*/
TypedValueRef<TracingValue> record_var(ValueRef value, bool capture, VarKind kind) {
if (kind == VarKind::External &&
extern_var_to_id.find(value.id()) != extern_var_to_id.end()) {
return m_value_type.make(value, extern_var_to_id[value.id()]);
}
size_t id = m_vars.size();
if (kind == VarKind::External) {
extern_var_to_id[value.id()] = id;
}
auto wrapped_value = m_value_type.make(value, id);
m_vars.push_back({id, value.dtype(), value.device()});
auto& var = m_vars.back();
......@@ -156,9 +203,12 @@ public:
var.bound_data = value;
}
var.kind = kind;
if (m_record_input_shapes && kind != VarKind::Internal) {
if ((m_record_input_shapes && kind != VarKind::Internal) ||
m_record_all_shapes) {
var.shape = value.shape()->as_tensor_shape();
}
if (m_record_all_shapes)
var.handle_id = value.handle_id();
if (auto name = value.name()) {
var.name = *name;
}
......@@ -185,8 +235,9 @@ public:
std::string name() const override { return "TracingTransformation"; }
void on_unregister() noexcept override;
void postprocess_trace_result();
TraceResult get_result() { return {m_seq, m_vars}; }
void enable_record_all_shapes() { m_record_all_shapes = true; }
};
class TraceError : public std::exception {
......@@ -211,6 +262,7 @@ class CompiledTransformation final : public Transformation {
public:
using VarInfo = TraceResult::VarInfo;
using VarKind = VarInfo::Kind;
using OpKind = TraceResult::SeqItem::OpKind;
struct VarAccessor {
VarNode* node;
......@@ -254,6 +306,7 @@ private:
std::vector<TraceResult::SeqItem> m_seq;
std::vector<TraceResult::VarInfo> m_vars;
std::vector<VarAccessor> m_var_accessors;
std::unordered_map<std::string, size_t> mark2id;
size_t m_pc = 0;
std::shared_ptr<ComputingGraph> m_graph;
std::unique_ptr<cg::AsyncExecutable> m_executable;
......@@ -268,6 +321,7 @@ private:
std::vector<std::shared_ptr<BoxBase>> m_boxes;
ComputingGraph::OutputSpec m_output_spec;
ObjectType<TracedValue> m_value_type{"TracedValue"};
std::set<size_t> m_setted_extern;
public:
CompiledTransformation(TraceResult result, bool input_shape_static)
......@@ -360,8 +414,10 @@ public:
return value;
}
std::string name() const override { return "CompiledTransformation"; }
VarAccessor& get_accessor_by_id(size_t id) { return m_var_accessors[id]; }
std::string name() const override { return "CompiledTransformation"; }
void set_pc_to_end() { m_pc = m_seq.size(); }
void execute();
void wait();
......
......@@ -222,6 +222,7 @@ public:
TypedValueRef<DTypeValue> dtype() const;
TypedValueRef<FormatValue> format() const;
TypedValueRef<StringValue> name() const;
int handle_id() const;
bool is_scalar() const;
void watch() const;
......@@ -298,7 +299,7 @@ protected:
public:
const IType& type() const { return *m_type; }
uint64_t id() const { return m_id; }
static void register_value(ValueRef value);
static ValueRef get_value_by_id(uint64_t id);
static void begin_record_values();
......@@ -538,11 +539,11 @@ public:
const ValueRef* data() const { return m_data; }
bool empty() const { return m_size == 0; }
ValueRef& front() {
mgb_assert(m_size > 1);
mgb_assert(m_size >= 1);
return m_data[0];
}
ValueRef& back() {
mgb_assert(m_size > 1);
mgb_assert(m_size >= 1);
return m_data[m_size - 1];
}
};
......