提交 b11d4430 编写于 作者: M Megvii Engine Team

feat(mge/jit): support trace without_host mode

GitOrigin-RevId: 09b29e3dac44a4e4330f2ceb10da7d55df772466
上级 3116e9f7
...@@ -8,6 +8,8 @@ _use_symbolic_shape = False ...@@ -8,6 +8,8 @@ _use_symbolic_shape = False
if os.environ.get("MEGENGINE_USE_SYMBOLIC_SHAPE"): if os.environ.get("MEGENGINE_USE_SYMBOLIC_SHAPE"):
_use_symbolic_shape = True _use_symbolic_shape = True
_use_xla_backend = False
def use_symbolic_shape() -> bool: def use_symbolic_shape() -> bool:
r"""Returns whether tensor.shape returns a tensor instead of a tuple""" r"""Returns whether tensor.shape returns a tensor instead of a tuple"""
...@@ -22,4 +24,15 @@ def set_symbolic_shape(option: bool): ...@@ -22,4 +24,15 @@ def set_symbolic_shape(option: bool):
return _org return _org
def use_xla_backend() -> bool:
    r"""Returns the current value of the module-level ``_use_xla_backend`` flag"""
    current_flag = _use_xla_backend
    return current_flag
def set_use_xla_backend(option: bool) -> bool:
    r"""Sets the module-level ``_use_xla_backend`` flag and returns its previous value"""
    global _use_xla_backend
    # Swap in one step so the old value can be handed back to the caller,
    # which lets callers restore it later.
    previous, _use_xla_backend = _use_xla_backend, option
    return previous
set_cpp_use_symbolic_shape(use_symbolic_shape) set_cpp_use_symbolic_shape(use_symbolic_shape)
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
from .dtr_config import DTRConfig from .dtr_config import DTRConfig
from .graph_opt_config import GraphOptimizationConfig from .graph_opt_config import GraphOptimizationConfig
from .partial_tracing import partial_trace
from .sublinear_memory_config import SublinearMemoryConfig from .sublinear_memory_config import SublinearMemoryConfig
from .tracing import TraceError, exclude_from_trace, trace from .tracing import TraceError, exclude_from_trace, trace
from .xla_backend import xla_trace
from collections import OrderedDict
from typing import Sequence
from ..core._imperative_rt.core2 import add_backward_callback as _add_backward_callback
from ..core._imperative_rt.core2 import get_grad_slot, get_handle_id
from ..tensor import Tensor
from .tracing import trace
from .xla_backend import xla_trace
def _process_fwd_bwd_trace_result(fwd, bwd, inp_grad_map, out_grad_map):
    """Determine the inputs and outputs of the backward graph after partial tracing.

    ``partial_trace`` records the op sequences of forward and backward separately
    and ends up with two trace objects, but the inputs/outputs of the backward
    graph are unknown at that point. ``var.handle_id`` is used to find the
    tensors that appear in both the forward and the backward recording.

    Args:
        fwd: forward trace object; its ``_trace.vars``, ``out_list`` and
            ``keeped_activation`` are read/written here.
        bwd: backward trace object; its inputs/outputs are marked here and
            ``setup_io_without_trace`` / ``setup_without_host`` are called on it.
        inp_grad_map: key: handle id of a forward input,
            value: handle id of the grad of that forward input.
        out_grad_map: key: handle id of a forward output,
            value: handle id of the grad of that forward output.

    Raises:
        RuntimeError: if, after marking, either trace still contains an
            "external" var without an input mark.
        AssertionError: if a saved forward feature never shows up as an
            external var of the backward trace.
    """

    def first_positions(handle_ids):
        # Same answer as ``list.index`` (first occurrence wins) but O(1) per lookup.
        positions = {}
        for pos, handle in enumerate(handle_ids):
            positions.setdefault(handle, pos)
        return positions

    fwd_handles = {var.handle_id for var in fwd._trace.vars}
    bwd_handles = {var.handle_id for var in bwd._trace.vars}
    # some intermediate vars produced by forward, and will be used in backward.
    keep_vars = fwd_handles & bwd_handles

    next_mark = max(fwd.out_list) + 1
    saved_feature_map = OrderedDict()  # handle id -> forward output mark
    # mark keep_vars as forward outputs
    for var in fwd._trace.vars:
        if (
            var.handle_id in keep_vars
            and var.data_required
            and len(var.out_mark) == 0
            and var.kind not in ("const", "external")
        ):
            # each shared handle is exported at most once
            keep_vars.remove(var.handle_id)
            fwd._trace.mark_output(next_mark, var.id)
            saved_feature_map[var.handle_id] = next_mark
            next_mark += 1
    fwd.keeped_activation = list(saved_feature_map.values())

    dy_ids = list(out_grad_map.values())  # handle_id of grad of forward output
    inp_grad_ids = list(inp_grad_map.values())  # handle_id of grad of forward input
    dy_pos = first_positions(dy_ids)
    saved_pos = first_positions(saved_feature_map)
    grad_pos = first_positions(inp_grad_ids)

    bwd_dys = [-1] * len(dy_ids)
    bwd_inps = [-1] * len(saved_feature_map)
    bwd_outputs = [-1] * len(inp_grad_ids)
    bwd_inp_idx = 0
    bwd_out_idx = 0
    # dy_ids + saved_feature_map are backward inputs
    # inp_grad_ids are backward outputs
    # mark inputs/outputs for backward
    for var in bwd._trace.vars:
        handle = var.handle_id
        if var.kind == "external":
            if handle in dy_pos:
                bwd._trace.mark_input(bwd_inp_idx, var.id)
                bwd_dys[dy_pos[handle]] = bwd_inp_idx
                bwd_inp_idx += 1
            elif handle in saved_pos:
                bwd._trace.mark_input(bwd_inp_idx, var.id)
                bwd_inps[saved_pos[handle]] = bwd_inp_idx
                bwd_inp_idx += 1
        if handle in grad_pos and var.data_required:
            bwd_outputs[grad_pos[handle]] = bwd_out_idx
            bwd._trace.mark_output(bwd_out_idx, var.id)
            bwd_out_idx += 1

    # bwd_dys / bwd_outputs may legitimately keep -1 entries; saved features may not.
    assert -1 not in bwd_inps
    for var in fwd._trace.vars:
        if not var.out_mark:
            var.data_required = False
    bwd.setup_io_without_trace(bwd_dys + bwd_inps, bwd_outputs)
    bwd.setup_without_host()

    def check_external(trace_obj):
        for var in trace_obj.vars:
            if var.kind == "external" and not var.inp_mark:
                raise RuntimeError("have unknown input in trace result")

    check_external(fwd)
    check_external(bwd)
# Maps the ``backend`` argument of ``partial_trace`` to the trace class used.
JIT_BACKEND = {"default": trace, "xla": xla_trace}


def partial_trace(func=None, *, backend="default", without_host=True, **trace_options):
    """Trace the forward and the backward pass of ``func`` as two separate graphs.

    Can be used as ``@partial_trace`` or ``@partial_trace(backend=..., ...)``.

    Behaviour of the returned wrapper, as implemented below:

    * first call: ``func`` is executed under a forward trace object while
      backward callbacks enter/exit a second (backward) trace object;
    * later calls: the two recordings are connected once by
      ``_process_fwd_bwd_trace_result`` and the call is dispatched through a
      ``Function`` subclass (or a plain forward wrapper when the backward trace
      recorded no ops).

    Args:
        func: the callable to trace. When ``None`` a decorator is returned.
        backend: key of ``JIT_BACKEND`` selecting the trace implementation.
        without_host: must be truthy; only without_host mode is supported.
        trace_options: forwarded to the trace constructor. The backward trace
            object is always built with ``capture_as_const=False``.

    Returns:
        The wrapped function, or a decorator when ``func`` is ``None``.
    """
    assert backend in JIT_BACKEND
    assert without_host, "partial_trace only support without_host mode currently!"

    def wrapper(func):
        trace_obj = JIT_BACKEND[backend](
            func, without_host=without_host, **trace_options
        )
        trace_options["capture_as_const"] = False
        backward_trace_obj = JIT_BACKEND[backend](
            None, without_host=without_host, **trace_options
        )
        # check if there are unknown external vars after tracing.
        backward_trace_obj.check_external = False
        trace_obj.overall = False  # if trace overall train step
        backward_trace_obj.overall = False
        trace_obj._trace.remove_unused_data_required = False
        backward_trace_obj._trace.remove_unused_data_required = False
        inp_grad_maps = OrderedDict()  # x, dx map
        out_grad_maps = OrderedDict()  # y, dy map
        traced = False  # if wrapped function has been traced
        compiled = False  # if wrapped function has been compiled
        custom_autodiff = None
        outdef = None  # treedef of forward return value
        from ..core.autodiff.grad import Function

        def as_feature_tuple(features):
            # The saved activations come back either as one object or as a sequence.
            if not isinstance(features, Sequence):
                return (features,)
            return tuple(features)

        class CustomAutodiff(Function):
            def __init__(self, fwd, bwd):
                self.fwd = fwd
                self.bwd = bwd
                # without ``outdef`` the forward trace returns a flat output list
                del fwd.outdef
                self.keeped_features = []

            def forward(self, *args):
                rst = self.fwd(*args)
                # last element holds the activations saved for backward
                self.keeped_features = as_feature_tuple(rst[-1])
                return rst[0]

            def get_keeped_features(self):
                # hand the saved activations over exactly once
                rst = self.keeped_features
                del self.keeped_features
                return rst

            def backward(self, *output_grads):
                output_grads = tuple([i for i in output_grads if i is not None])
                return self.bwd(*(output_grads + self.get_keeped_features()))

        class CustomFwd:
            def __init__(self, fwd, bwd):
                self.fwd = fwd
                self.bwd = bwd

            def __call__(self, *args):
                rst = self.fwd(*args)
                if self.fwd.keeped_activation:
                    self.keeped_features = as_feature_tuple(rst[-1])
                    return rst[0]
                else:
                    return rst

        def wrapped_func(*args, **kwargs):
            from ..traced_module.pytree import tree_flatten
            from ..module import Module

            nonlocal traced
            nonlocal compiled
            nonlocal custom_autodiff
            nonlocal outdef

            if not traced:
                traced = True
                fargs = trace_obj.flatten_inputs(*args, **kwargs)
                for t in fargs:
                    inp_grad_maps[t] = get_grad_slot(t)
                del fargs

                def exit_trace():
                    backward_trace_obj._trace.exit()
                    # replace tensor -> grad slot by handle id -> grad handle id
                    new_dict = {}
                    for k, v in inp_grad_maps.items():
                        if v is not None:
                            new_dict[get_handle_id(k)] = get_handle_id(v.grad)
                        else:
                            new_dict[get_handle_id(k)] = -1
                    inp_grad_maps.clear()
                    inp_grad_maps.update(new_dict)

                _add_backward_callback(exit_trace)
                # Forward keyword arguments too: the inputs were flattened from
                # (args, kwargs) above and are again on every later call, so the
                # tracing call must see the same arguments.
                ret = trace_obj(*args, **kwargs)
                rlist, outdef = tree_flatten(ret)
                for t in rlist:
                    out_grad_maps[t] = get_grad_slot(t)

                def enter_trace():
                    new_dict = {}
                    for k, v in out_grad_maps.items():
                        if v is not None:
                            new_dict[get_handle_id(k)] = get_handle_id(v.grad)
                    out_grad_maps.clear()
                    out_grad_maps.update(new_dict)
                    backward_trace_obj._trace.enter()

                _add_backward_callback(enter_trace)
                return ret
            elif not compiled:
                if custom_autodiff is None:
                    _process_fwd_bwd_trace_result(
                        trace_obj, backward_trace_obj, inp_grad_maps, out_grad_maps
                    )
                    if len(backward_trace_obj._trace.ops) > 0:
                        custom_autodiff = CustomAutodiff(trace_obj, backward_trace_obj)
                    else:
                        custom_autodiff = CustomFwd(trace_obj, backward_trace_obj)
                fargs = trace_obj.flatten_inputs(*args, **kwargs)
                del args
                del kwargs
                if outdef is None:
                    return custom_autodiff(*fargs)
                else:
                    return outdef.unflatten(custom_autodiff(*fargs))

        return wrapped_func

    if func is None:
        return wrapper
    else:
        return wrapper(func)
...@@ -9,6 +9,7 @@ import pickle ...@@ -9,6 +9,7 @@ import pickle
import re import re
import struct import struct
import sys import sys
from collections import OrderedDict, defaultdict
from typing import Any, Sequence from typing import Any, Sequence
import cv2 import cv2
...@@ -16,9 +17,22 @@ import numpy as np ...@@ -16,9 +17,22 @@ import numpy as np
from .. import tensor from .. import tensor
from ..core import _imperative_rt as rt from ..core import _imperative_rt as rt
from ..core._imperative_rt import GraphProfiler, GraphProfiler2, SerializationMetadata from ..core._imperative_rt import (
CompNode,
GraphProfiler,
GraphProfiler2,
SerializationMetadata,
)
from ..core._imperative_rt.core2 import Tensor as RawTensor from ..core._imperative_rt.core2 import Tensor as RawTensor
from ..core._imperative_rt.core2 import Trace, TraceError, name_tensor # skip_tracing, from ..core._imperative_rt.core2 import Trace, TraceError # skip_tracing,
from ..core._imperative_rt.core2 import add_backward_callback as _add_backward_callback
from ..core._imperative_rt.core2 import (
get_marked_input_tensor,
get_marked_output_tensor,
get_marked_tensor,
marked_input_tensor,
name_tensor,
)
from ..core._imperative_rt.graph import _set_priority_to_id from ..core._imperative_rt.graph import _set_priority_to_id
from ..core._imperative_rt.ops import ( from ..core._imperative_rt.ops import (
AssertEqual, AssertEqual,
...@@ -31,6 +45,7 @@ from ..core._imperative_rt.ops import ( ...@@ -31,6 +45,7 @@ from ..core._imperative_rt.ops import (
from ..core._trace_option import set_symbolic_shape from ..core._trace_option import set_symbolic_shape
from ..core.tensor import megbrain_graph as G from ..core.tensor import megbrain_graph as G
from ..logger import get_logger from ..logger import get_logger
from ..tensor import Tensor
from ..utils import comp_graph_tools as cgtools from ..utils import comp_graph_tools as cgtools
from ..utils.naming import AutoNaming from ..utils.naming import AutoNaming
from ..utils.profiler import is_profiling from ..utils.profiler import is_profiling
...@@ -94,8 +109,13 @@ class trace: ...@@ -94,8 +109,13 @@ class trace:
opt_level: optimization level for compiling trace. Default: 2 opt_level: optimization level for compiling trace. Default: 2
graph_opt_config: configuration for graph optimization. Default: None graph_opt_config: configuration for graph optimization. Default: None
symbolic_shape: whether to use symbolic shape for tracing. Default: True symbolic_shape: whether to use symbolic shape for tracing. Default: True
without_host: if True, will run python code of wrapped function on the first call,
and run the compiled graph/function on subsequent calls. if False, will run python code every time.
Default: False
""" """
third_party_backend = False
def __new__(cls, *args, **kwargs): def __new__(cls, *args, **kwargs):
if not args: if not args:
return functools.partial(cls, **kwargs) return functools.partial(cls, **kwargs)
...@@ -113,6 +133,7 @@ class trace: ...@@ -113,6 +133,7 @@ class trace:
opt_level: int = 2, opt_level: int = 2,
graph_opt_config: GraphOptimizationConfig = None, graph_opt_config: GraphOptimizationConfig = None,
symbolic_shape: bool = True, symbolic_shape: bool = True,
without_host: bool = False,
): ):
self.__wrapped__ = function self.__wrapped__ = function
self._capture_as_const = capture_as_const or record_only self._capture_as_const = capture_as_const or record_only
...@@ -150,6 +171,7 @@ class trace: ...@@ -150,6 +171,7 @@ class trace:
graph_options["graph_opt.jit_config.fuse_reduce"] = mapping[ graph_options["graph_opt.jit_config.fuse_reduce"] = mapping[
graph_opt_config.jit_fuse_reduce graph_opt_config.jit_fuse_reduce
] ]
if sublinear_memory_config is not None: if sublinear_memory_config is not None:
graph_options["enable_sublinear_memory_opt"] = True graph_options["enable_sublinear_memory_opt"] = True
graph_options[ graph_options[
...@@ -186,8 +208,114 @@ class trace: ...@@ -186,8 +208,114 @@ class trace:
self._trace.profile = profiling self._trace.profile = profiling
self._trace.array_comparator = array_comparator self._trace.array_comparator = array_comparator
self._trace.record_input_shapes = _input_node_use_static_shape() self._trace.record_input_shapes = _input_node_use_static_shape()
self._trace.without_host = without_host
self.check_external = True
self.traced = False
self.input_num = 0
self.output_num = 0
self.arg_list = []
self.out_list = []
self.overall = True
# forward keeped activation
self.keeped_activation = []
self.third_party_backend_compiled = False
@property
def check_external(self):
    """Value of ``check_external`` on the underlying ``self._trace`` object."""
    underlying = self._trace
    return underlying.check_external

@check_external.setter
def check_external(self, flag):
    # Stored on the underlying trace object rather than on this wrapper.
    setattr(self._trace, "check_external", flag)
@property
def without_host(self):
    """Value of ``without_host`` on the underlying ``self._trace`` object (read-only)."""
    underlying = self._trace
    return underlying.without_host
def flatten_inputs(self, *args, **kwargs):
    """Collect the tensors reachable from the call arguments.

    ``(args, kwargs)`` is flattened with ``tree_flatten``. Leaves that are
    ``RawTensor`` instances come first, in leaf order, followed by the
    ``parameters()`` of every ``Module`` leaf, in leaf order.

    Returns:
        list of tensors.
    """
    from ..traced_module.pytree import tree_flatten
    from ..module import Module

    leaves, _ = tree_flatten((args, kwargs))
    collected = [leaf for leaf in leaves if isinstance(leaf, RawTensor)]
    for leaf in leaves:
        # a leaf that is already counted as a tensor is not treated as a module
        if isinstance(leaf, RawTensor) or not isinstance(leaf, Module):
            continue
        collected.extend(list(leaf.parameters()))
    return collected
def compile(self):
    """Not implemented here; ``compile_and_exec`` calls it once before ``execute``."""
    raise NotImplementedError()
def execute(self, *args, **kwargs):
    """Not implemented here; ``compile_and_exec`` returns whatever this returns."""
    raise NotImplementedError()
def setup_env(self):
    """No-op hook called before the wrapped function runs when ``third_party_backend`` is set."""
    return None
def unset_env(self):
    """No-op hook called in the ``finally`` path of tracing when ``third_party_backend`` is set."""
    return None
def compile_and_exec(self, *args, **kwargs):
    """Call ``compile()`` once, then delegate to ``execute`` with the given arguments.

    ``third_party_backend_compiled`` records that compilation already happened,
    so later calls go straight to ``execute``.
    """
    if self.third_party_backend_compiled:
        return self.execute(*args, **kwargs)
    self.compile()
    self.third_party_backend_compiled = True
    return self.execute(*args, **kwargs)
def convert_optimizer_state_to_tensor(self, *args, **kwargs):
    """Convert non-tensor entries of optimizer param groups into ``Tensor`` objects.

    ``Optimizer`` is registered in ``SUPPORTED_LEAF_CLS`` (once) so that
    ``tree_flatten`` yields optimizers found in ``(args, kwargs)`` as leaves.
    For every such optimizer, ``_disable_type_convert`` is set to ``False`` and
    each param-group value is rewritten in place:

    * a value that is neither a ``Tensor`` nor a ``Sequence`` becomes ``Tensor(value)``;
    * a non-empty ``Sequence`` whose first item is not a ``Tensor`` becomes a
      list of ``Tensor(item)``;
    * everything else, including empty sequences, is left untouched.
    """
    from ..traced_module.pytree import tree_flatten, SUPPORTED_LEAF_CLS
    from ..optimizer import Optimizer
    from ..tensor import Tensor

    if Optimizer not in SUPPORTED_LEAF_CLS:
        SUPPORTED_LEAF_CLS.append(Optimizer)
    leaves, _ = tree_flatten((args, kwargs))
    for leaf in leaves:
        if not isinstance(leaf, Optimizer):
            continue
        leaf._disable_type_convert = False
        for param_group in leaf.param_groups:
            # Only existing keys are reassigned, so the dict size is stable
            # while iterating over items().
            for key, value in param_group.items():
                if not isinstance(value, (Tensor, Sequence)):
                    param_group[key] = Tensor(value)
                elif (
                    isinstance(value, Sequence)
                    and len(value) > 0  # value[0] on an empty sequence would raise
                    and not isinstance(value[0], Tensor)
                ):
                    param_group[key] = [Tensor(item) for item in value]
def setup_io_without_trace(self, inputs, outputs):
    """Install input/output marks directly and flag the object as traced.

    Entries equal to ``-1`` are dropped from ``arg_list`` and are not counted
    in ``output_num``; ``out_list`` keeps ``outputs`` as given.
    """
    valid_inputs = []
    for mark in inputs:
        if mark != -1:
            valid_inputs.append(mark)
    self.traced = True
    self.arg_list = valid_inputs
    self.out_list = outputs
    self.input_num = len(valid_inputs)
    self.output_num = sum(1 for mark in outputs if mark != -1)
def setup_without_host(self):
    """Reset the bookkeeping containers used by without-host tracing to empty."""
    # set-valued bookkeeping
    for set_attr in ("inp_modules", "module_tensors"):
        setattr(self, set_attr, set())
    # dict-valued bookkeeping
    for dict_attr in (
        "tensor_to_attr",
        "attr_to_key",
        "update_param_dict",
        "update_opt_param_dict",
    ):
        setattr(self, dict_attr, {})
    self.capture_optimizer_state = set()
    self.opt_param_dict = {}
def __call__(self, *args, **kwargs):
    """Dispatch the call according to ``without_host`` and ``overall``.

    * ``without_host`` falsy: ``trace_normal``
    * ``without_host`` and ``overall``: ``trace_without_host_overall``
    * ``without_host`` and not ``overall``: ``trace_without_host``
    """
    if not self.without_host:
        handler = self.trace_normal
    elif self.overall:
        handler = self.trace_without_host_overall
    else:
        handler = self.trace_without_host
    return handler(*args, **kwargs)
def trace_normal(self, *args, **kwargs):
global active_trace global active_trace
symbolic_shape = None symbolic_shape = None
outputs = None outputs = None
...@@ -214,6 +342,270 @@ class trace: ...@@ -214,6 +342,270 @@ class trace:
raise raise
return outputs return outputs
def trace_without_host(self, *args, **kwargs):
    """Run one call in without-host mode when the trace is not an overall trace.

    Paths, in the order they are checked:

    * already traced and ``third_party_backend`` set: delegate to
      ``compile_and_exec``;
    * ``self._trace.compiled()`` is true: feed the flattened tensor inputs with
      ``put_datas`` and read the outputs back with ``get_data``; the wrapped
      Python function is not called. Returns ``outputs``, or
      ``(outputs, keep_vars)`` when ``keeped_activation`` is non-empty;
    * otherwise: mark every flattened input, run the wrapped function, and, in
      the ``finally`` block, mark every flattened output before exiting the trace.
    """
    from ..traced_module.pytree import tree_flatten, SUPPORTED_LEAF_CLS
    from ..module import Module
    from ..utils.module_utils import get_expand_structure
    from ..tensor import Tensor
    from ..optimizer import Optimizer

    assert self.without_host and not self.overall
    global active_trace
    symbolic_shape = None
    outputs = None
    if self.traced and self.third_party_backend:
        return self.compile_and_exec(*args, **kwargs)
    try:
        active_trace = self
        self._trace.enter()
        if self._trace.compiled():
            # Replay path: bind the i-th tensor argument to the i-th recorded input mark.
            arglist = self.flatten_inputs(*args, **kwargs)
            idx = 0
            inp_dict = {}
            for a in arglist:
                if isinstance(a, RawTensor):
                    inp_dict[self.arg_list[idx]] = a
                    idx += 1
            self._trace.put_datas(inp_dict)
            outlist = []
            for i in self.out_list:
                if i == -1:
                    # -1 marks an output slot without data; it is only kept
                    # (as None) when there is no treedef to unflatten into.
                    if not hasattr(self, "outdef"):
                        outlist.append(None)
                else:
                    outlist.append(self._trace.get_data(i))
            keep_vars = []
            for i in self.keeped_activation:
                keep_vars.append(self._trace.get_data(i))
            outputs = (
                self.outdef.unflatten(outlist)
                if hasattr(self, "outdef")
                else outlist
            )
            # NOTE: returning from inside ``try`` still runs the ``finally`` block below.
            if keep_vars:
                return outputs, keep_vars
            else:
                return outputs
        # Recording path: replace each input in place by its marked version
        # and remember the mark that was assigned to it.
        arg_list = self.flatten_inputs(*args, **kwargs)
        for i, arg in enumerate(arg_list):
            arg_list[i]._reset(get_marked_input_tensor(self.input_num, arg))
            self.arg_list.append(self.input_num)
            self.input_num += 1
        del arg_list
        symbolic_shape = set_symbolic_shape(self._symbolic_shape)
        if self.third_party_backend:
            self.setup_env()
        outputs = self.__wrapped__(*args, **kwargs)
    finally:
        # True when the finally block runs because an exception is propagating.
        handling_exc = sys.exc_info() != (None,) * 3
        active_trace = None
        if symbolic_shape is not None:
            # restore the previous symbolic-shape setting
            symbolic_shape = set_symbolic_shape(symbolic_shape)
            assert symbolic_shape == self._symbolic_shape
        if self.third_party_backend:
            self.unset_env()
        if (
            self._capture_as_const
            and (outputs is not None)
            and not self.without_host
        ):
            self._process_outputs(outputs)
        if not self._trace.compiled():
            # Recording path: mark every output and remember its treedef for replay.
            outlist, self.outdef = tree_flatten(outputs)
            for i, out in enumerate(outlist):
                assert isinstance(out, RawTensor), type(out)
                outlist[i] = get_marked_output_tensor(self.output_num, out)
                del out
                self.out_list.append(self.output_num)
                self.output_num += 1
            outputs = self.outdef.unflatten(outlist)
        try:
            # may raise TraceError
            self._trace.exit()
        except Exception as e:
            if isinstance(e, TraceError):
                # a TraceError is swallowed when another exception is already propagating
                if not handling_exc:
                    raise
            else:
                self._trace.set_execption(str(e))
                raise
        self.traced = True
    return outputs
def trace_without_host_overall(self, *args, **kwargs):
    """Record an overall train step (forward, backward, param update) in one trace object.

    Paths, in the order they are checked:

    * already traced and ``third_party_backend`` set: delegate to
      ``compile_and_exec``;
    * ``self._trace.compiled()`` is true: feed tensor arguments, captured
      optimizer state and module tensors into the trace, read the outputs back,
      and write updated parameters / optimizer state back with ``_reset``;
    * otherwise: record. While the wrapped function runs,
      ``Module.__getattribute__`` and ``Tensor._reset`` are replaced by hooks
      that mark module tensors as inputs and updated tensors as outputs. The
      hooks are always removed again, even when the wrapped function raises.

    Returns:
        ``outputs``, or ``(outputs, keep_vars)`` on the replay path when
        ``keeped_activation`` is non-empty.
    """
    from ..traced_module.pytree import tree_flatten, SUPPORTED_LEAF_CLS
    from ..module import Module
    from ..utils.module_utils import get_expand_structure
    from ..tensor import Tensor
    from ..optimizer import Optimizer

    assert self.without_host
    global active_trace
    symbolic_shape = None
    outputs = None
    if self.traced and self.third_party_backend:
        return self.compile_and_exec(*args, **kwargs)
    try:
        active_trace = self
        if not self.traced:
            self.convert_optimizer_state_to_tensor(*args, **kwargs)
        self._trace.enter()
        if self._trace.compiled():
            # Replay path: bind the i-th tensor argument to the i-th recorded input mark.
            arglist, _ = tree_flatten((args, kwargs))
            idx = 0
            inp_dict = {}
            for a in arglist:
                if isinstance(a, RawTensor):
                    inp_dict[self.arg_list[idx]] = a
                    idx += 1
            for t, key in self.opt_param_dict.items():
                inp_dict[key] = t
            self._trace.put_datas(inp_dict)
            for attr, key in self.attr_to_key.items():
                param = get_expand_structure(attr[0], attr[1])
                self._trace.put_data(key, param)
            outlist = []
            for i in self.out_list:
                if i == -1:
                    # -1 marks an output slot without data; it is only kept
                    # (as None) when there is no treedef to unflatten into.
                    if not hasattr(self, "outdef"):
                        outlist.append(None)
                else:
                    outlist.append(self._trace.get_data(i))
            # write updated parameters and optimizer state back
            for attr, key in self.update_param_dict.items():
                param = get_expand_structure(attr[0], attr[1])
                param._reset(self._trace.get_data(key))
            for state, key in self.update_opt_param_dict.items():
                state._reset(self._trace.get_data(key))
            keep_vars = []
            for i in self.keeped_activation:
                keep_vars.append(self._trace.get_data(i))
            outputs = (
                self.outdef.unflatten(outlist)
                if hasattr(self, "outdef")
                else outlist
            )
            # NOTE: returning from inside ``try`` still runs the ``finally`` block below.
            if keep_vars:
                return outputs, keep_vars
            else:
                return outputs

        self.setup_without_host()

        def get_attr_hook(obj, attr):
            # Attribute access on a Module: a tensor read for the first time
            # gets an input mark.
            rst = object.__getattribute__(obj, attr)
            if isinstance(rst, RawTensor):
                assert rst in self.tensor_to_attr
                attr = self.tensor_to_attr[rst]
                if attr not in self.attr_to_key:
                    self.attr_to_key[attr] = self.input_num
                    self.input_num += 1
                    marked_input_tensor(self.attr_to_key[attr], rst)
            return rst

        origin_reset = Tensor._reset
        self.update_param_num = 0

        def tensor_wrapper_resethook(obj, other):
            # ``_reset`` on a module tensor or captured optimizer state: the new
            # value gets an output mark so it can be written back on replay.
            if obj in self.tensor_to_attr:
                attr = self.tensor_to_attr[obj]
                other = get_marked_output_tensor(self.output_num, other)
                self.update_param_num += 1
                self.update_param_dict[attr] = self.output_num
                self.output_num += 1
            elif obj in self.capture_optimizer_state:
                other = get_marked_output_tensor(self.output_num, other)
                self.update_opt_param_dict[obj] = self.output_num
                self.output_num += 1
            origin_reset(obj, other)

        arg_list, self.argdef = tree_flatten((args, kwargs))
        for i, arg in enumerate(arg_list):
            if isinstance(arg, Module):
                for k, v in arg.named_tensors():
                    if v not in self.tensor_to_attr:
                        self.tensor_to_attr[v] = (arg, k)
                self.inp_modules.add(arg)
            elif isinstance(arg, RawTensor):
                arg_list[i] = get_marked_input_tensor(self.input_num, arg)
                self.arg_list.append(self.input_num)
                self.input_num += 1
            elif isinstance(arg, Optimizer):
                opt_params, _ = tree_flatten(arg.state_dict(keep_var=True))
                for p in opt_params:
                    if isinstance(p, Tensor):
                        self.capture_optimizer_state.add(p)
        self.opt_param_dict = {}
        for t in self.capture_optimizer_state:
            if t not in self.tensor_to_attr:  # not module parameter
                mark_param = get_marked_input_tensor(self.input_num, t)
                self.opt_param_dict[t] = self.input_num
                t[...] = mark_param
                self.input_num += 1
        args, kwargs = self.argdef.unflatten(arg_list)
        Module.__getattribute__ = get_attr_hook
        Tensor._reset = tensor_wrapper_resethook
        try:
            symbolic_shape = set_symbolic_shape(self._symbolic_shape)
            if self.third_party_backend:
                self.setup_env()
            outputs = self.__wrapped__(*args, **kwargs)
            del arg_list
            del args
            del kwargs
        finally:
            # The two hooks patch classes globally; undo them even when the
            # wrapped function raises, otherwise every Module/Tensor in the
            # process keeps running the tracing hooks.
            Module.__getattribute__ = object.__getattribute__
            Tensor._reset = origin_reset

        # Resolves every recorded attribute path once; the result is unused here.
        for attr, key in self.attr_to_key.items():
            param = get_expand_structure(attr[0], attr[1])
    finally:
        # True when the finally block runs because an exception is propagating.
        handling_exc = sys.exc_info() != (None,) * 3
        active_trace = None
        if symbolic_shape is not None:
            # restore the previous symbolic-shape setting
            symbolic_shape = set_symbolic_shape(symbolic_shape)
            assert symbolic_shape == self._symbolic_shape
        if self.third_party_backend:
            self.unset_env()
        if (
            self._capture_as_const
            and (outputs is not None)
            and not self.without_host
        ):
            self._process_outputs(outputs)
        if not self._trace.compiled():
            # Recording path: mark every output and remember its treedef for replay.
            outlist, self.outdef = tree_flatten(outputs)
            for i, out in enumerate(outlist):
                assert isinstance(out, RawTensor)
                outlist[i] = get_marked_output_tensor(self.output_num, out)
                del out
                self.out_list.append(self.output_num)
                self.output_num += 1
            outputs = self.outdef.unflatten(outlist)
        try:
            # may raise TraceError
            self._trace.exit()
        except Exception as e:
            if isinstance(e, TraceError):
                # a TraceError is swallowed when another exception is already propagating
                if not handling_exc:
                    raise
            else:
                self._trace.set_execption(str(e))
                raise
        self.traced = True
    return outputs
@property
def ops(self):
    """The ``ops`` attribute of the underlying ``self._trace`` object."""
    underlying = self._trace
    return underlying.ops
@property
def vars(self):
    """The ``vars`` attribute of the underlying ``self._trace`` object."""
    underlying = self._trace
    return underlying.vars
def _process_inputs(self, *args, **kwargs): def _process_inputs(self, *args, **kwargs):
for i, arg in enumerate(args): for i, arg in enumerate(args):
assert isinstance( assert isinstance(
......
from collections import OrderedDict, defaultdict
from ..core._imperative_rt import CompNode
from ..core._imperative_rt.core2 import Tensor as RawTensor
from ..core._trace_option import set_use_xla_backend
from ..device import get_default_device
from ..utils.dlpack import from_dlpack, to_dlpack
from .tracing import trace
try:
from ..xla.lib import xla_client as xc
except ImportError:
pass
class xla_trace(trace):
    r"""Wraps a callable, and provides accelerated evaluation compiled by xla.
    Currently it is an experimental feature.
    Refer to :class:`~.jit.tracing.trace` for more information.

    Examples:

        .. code-block:: python

            import numpy as np
            from basecls.models.resnet import resnet18
            import megengine.functional as F
            from megengine.autodiff.grad_manager import GradManager
            from megengine.jit import xla_trace
            from megengine.optimizer import Adam

            model = resnet18()
            gm = GradManager()
            opt = Adam(model.parameters(), lr=1e-4)
            gm.attach(model.parameters())

            # Only tensors in wrapped func args/kwargs will be treated as graph inputs,
            # and other tensors will be captured as const value.
            # Module, optimizer, and train data/label should be arguments of the wrapped function.
            @xla_trace(capture_as_const=True)
            def train_step(model, opt, data, label):
                with gm:
                    pred = model(data)
                    loss = F.loss.cross_entropy(pred, label)
                    gm.backward(loss)
                opt.step().clear_grad()
                return loss
    """

    third_party_backend = True

    def __init__(self, function, *, without_host=True, symbolic_shape=False, **kwargs):
        assert without_host, "xla trace only support without host mode"
        assert not symbolic_shape, "xla doesn't support dynamic shape currently"
        super().__init__(
            function, without_host=without_host, symbolic_shape=symbolic_shape, **kwargs
        )

    def setup_env(self):
        """Turn the xla backend flag on and remember its previous value."""
        self.orig_use_xla = set_use_xla_backend(True)

    def unset_env(self):
        """Restore the xla backend flag saved by ``setup_env``.

        This runs from a ``finally`` block that can be reached before
        ``setup_env`` was ever called (tracing failed early). In that case there
        is nothing to restore, and raising AttributeError here would mask the
        original error.
        """
        if hasattr(self, "orig_use_xla"):
            set_use_xla_backend(self.orig_use_xla)

    def compile(self):
        """Build the xla executable and the mark -> executable-index tables.

        Fills ``xla_exec``, ``inp_ids``, ``out_ids``, ``inpkey2idx``,
        ``outkey2idx`` and ``has_randomstate`` (plus ``random_seed`` when the
        executable has exactly one more distinct input than was marked).
        """
        from ..xla import build_xla
        from ..traced_module.pytree import SUPPORTED_LEAF_TYPE, register_supported_type
        from ..utils.module_utils import get_expand_structure
        from ..xla.device import get_xla_backend_and_device
        from ..tensor import Tensor

        assert self.traced
        if self.overall:
            for attr, _ in self.attr_to_key.items():
                param = get_expand_structure(attr[0], attr[1])
                param._reset(param.to("cpux"))
            for tensor, _ in self.opt_param_dict.items():
                tensor._reset(tensor.to("cpux"))
        self.xla_exec, self.inp_ids, self.out_ids = build_xla(
            self, return_with_io=True, return_device_array=True
        )
        # var id -> positions of that var in the executable's input/output lists
        id2inpidx = defaultdict(list)
        id2outidx = defaultdict(list)
        for idx, var_id in enumerate(self.inp_ids):
            id2inpidx[var_id].append(idx)
        for idx, var_id in enumerate(self.out_ids):
            id2outidx[var_id].append(idx)
        self.inpkey2idx = {}
        self.outkey2idx = {}
        if self.input_num == len(set(self.inp_ids)) - 1:
            # one extra executable input beyond the marked ones is treated as
            # the random state
            self.has_randomstate = True
            self.random_seed = Tensor([[1, 2], [3, 4]], dtype="int32")
        else:
            assert self.input_num == len(set(self.inp_ids))
            self.has_randomstate = False
        inpmark2id = dict()
        outmark2id = dict()
        for var in self.vars:
            if var.kind == "external":
                for mark in var.inp_mark:
                    inpmark2id[mark] = var.id
            elif var.data_required and var.out_mark:
                for mark in var.out_mark:
                    outmark2id[mark] = var.id
        # when a var occurs at several positions, the last position wins
        for k, v in inpmark2id.items():
            for idx in id2inpidx[v]:
                self.inpkey2idx[k] = idx
        for k, v in outmark2id.items():
            for idx in id2outidx[v]:
                self.outkey2idx[k] = idx

    def prepare_xla_inputs(self, tensors):
        """Order ``tensors`` (plus module/optimizer tensors in overall mode) as the executable expects."""
        from ..utils.module_utils import get_expand_structure

        inp_count = 0
        inp_list = [0] * self.input_num
        for idx, t in enumerate(tensors):
            inp = self.inpkey2idx[self.arg_list[idx]]
            inp_list[inp] = t
            inp_count += 1
        if self.overall:
            for attr, key in self.attr_to_key.items():
                param = get_expand_structure(attr[0], attr[1])
                inp = self.inpkey2idx[key]
                inp_list[inp] = param
                inp_count += 1
            for tensor, k in self.opt_param_dict.items():
                inp = self.inpkey2idx[k]
                inp_list[inp] = tensor
                inp_count += 1
        assert inp_count == self.input_num
        if self.has_randomstate:
            inp_list.append(self.random_seed)
        return inp_list

    def to_dlpack(self, x, take_ownership: bool = True):
        """Export an xla buffer as a DLPack capsule via ``xla_client``."""
        return xc._xla.buffer_to_dlpack_managed_tensor(x, take_ownership=take_ownership)

    def execute(self, *args, **kwargs):
        """Run the compiled xla executable on the tensors found in ``(args, kwargs)``.

        Returns:
            the outputs (unflattened with ``outdef`` when present), or
            ``(outputs, keeped_features)`` when ``keeped_activation`` is non-empty.
        """
        from ..traced_module.pytree import tree_flatten
        from ..tensor import Tensor
        from ..utils.module_utils import get_expand_structure

        inputs, _ = tree_flatten((args, kwargs))
        arrays = []
        cn = CompNode(get_default_device())
        stream = dict(self.xla_exec.backend.get_compute_compnode())
        device_kind, device_id, stream_id = cn.physical_locator
        xla_stream = stream[device_id]
        xla_comp_cn = "gpu{}:{}".format(device_id, xla_stream)
        for t in inputs:
            if isinstance(t, RawTensor):
                assert cn == t.device
                arrays.append(t.to(xla_comp_cn, _borrow=True))
        arrays = self.prepare_xla_inputs(arrays)
        outputs = self.xla_exec(*arrays)
        return_vals = []
        for i in self.out_list:
            if i == -1:
                # -1 marks an output slot without data; it is only kept
                # (as None) when there is no treedef to unflatten into.
                if not hasattr(self, "outdef"):
                    return_vals.append(None)
            else:
                return_vals.append(outputs[self.outkey2idx[i]])
        keeped_features = []
        for i in self.keeped_activation:
            capsule = self.to_dlpack(outputs[self.outkey2idx[i]])
            t = from_dlpack(capsule, xla_stream).to(cn, _borrow=True)
            keeped_features.append(t)
        out_tensors = []
        for array in return_vals:
            if array is not None:
                capsule = self.to_dlpack(array)
                t = from_dlpack(capsule, xla_stream)
                out_tensors.append(t.to(cn, _borrow=True))
            else:
                out_tensors.append(array)
        if self.overall:
            # write updated parameters and optimizer state back
            for attr, key in self.update_param_dict.items():
                param = get_expand_structure(attr[0], attr[1])
                xla_array = outputs[self.outkey2idx[key]]
                capsule = self.to_dlpack(xla_array)
                param._reset(from_dlpack(capsule).to(cn, _borrow=True))
            for state, key in self.update_opt_param_dict.items():
                xla_array = outputs[self.outkey2idx[key]]
                capsule = self.to_dlpack(xla_array)
                state._reset(from_dlpack(capsule).to(cn, _borrow=True))
        rst = (
            self.outdef.unflatten(out_tensors)
            if hasattr(self, "outdef")
            else out_tensors
        )
        if keeped_features:
            return rst, keeped_features
        else:
            return rst
...@@ -49,6 +49,8 @@ def _access_structure(obj, key, callback=None): ...@@ -49,6 +49,8 @@ def _access_structure(obj, key, callback=None):
cur = cur[k] cur = cur[k]
else: else:
cur = getattr(cur, k) cur = getattr(cur, k)
if callable is None:
return cur
return callback(parent, k, cur) return callback(parent, k, cur)
......
...@@ -115,11 +115,9 @@ void _set_priority_to_id(const std::vector<mgb::cg::VarNode*>& dest_vars) { ...@@ -115,11 +115,9 @@ void _set_priority_to_id(const std::vector<mgb::cg::VarNode*>& dest_vars) {
} }
py::object Py_Varnode = py::none(); py::object Py_Varnode = py::none();
const std::unique_ptr<mgb::OprFootprint> _imperative_sm_opr_footprint_ptr{
std::make_unique<mgb::OprFootprint>()};
void init_graph_rt(py::module m) { void init_graph_rt(py::module m) {
static const std::unique_ptr<mgb::OprFootprint> _imperative_sm_opr_footprint_ptr{
std::make_unique<mgb::OprFootprint>()};
def_rendezvous<DeviceTensorND>(m, "DeviceTensorNDRendezvous"); def_rendezvous<DeviceTensorND>(m, "DeviceTensorNDRendezvous");
def_rendezvous<HostNDWithEvent>(m, "HostTensorNDRendezvous"); def_rendezvous<HostNDWithEvent>(m, "HostTensorNDRendezvous");
......
...@@ -10,7 +10,7 @@ ...@@ -10,7 +10,7 @@
namespace py = pybind11; namespace py = pybind11;
extern py::object Py_Varnode; extern py::object Py_Varnode;
extern const std::unique_ptr<mgb::OprFootprint> _imperative_sm_opr_footprint_ptr;
template <typename T> template <typename T>
class GraphNodePtr { class GraphNodePtr {
std::shared_ptr<mgb::cg::ComputingGraph> m_graph; std::shared_ptr<mgb::cg::ComputingGraph> m_graph;
......
此差异已折叠。
...@@ -74,6 +74,7 @@ public: ...@@ -74,6 +74,7 @@ public:
inline ValueRef data() const { return m_data.unwrap(); } inline ValueRef data() const { return m_data.unwrap(); }
bool is_scalar() { return data().is_scalar(); } bool is_scalar() { return data().is_scalar(); }
inline std::string name() { return m_name; } inline std::string name() { return m_name; }
inline size_t value_id() { return m_data.id(); }
inline void set_name(std::string name) { inline void set_name(std::string name) {
m_name = name; m_name = name;
if (!name.empty()) { if (!name.empty()) {
...@@ -128,6 +129,7 @@ public: ...@@ -128,6 +129,7 @@ public:
void reset(PyObject*); void reset(PyObject*);
PyObject* detach(); PyObject* detach();
PyObject* isscalar(); PyObject* isscalar();
PyObject* value_id();
PyObject* _dev_tensor(); PyObject* _dev_tensor();
void _drop(); void _drop();
PyObject* varnode(); PyObject* varnode();
......
...@@ -18,7 +18,13 @@ from megengine.core.ops import builtin as ops ...@@ -18,7 +18,13 @@ from megengine.core.ops import builtin as ops
from megengine.core.ops.builtin import Elemwise from megengine.core.ops.builtin import Elemwise
from megengine.core.tensor.utils import isscalar from megengine.core.tensor.utils import isscalar
from megengine.functional import exp, log from megengine.functional import exp, log
from megengine.jit import GraphOptimizationConfig, TraceError, exclude_from_trace, trace from megengine.jit import (
GraphOptimizationConfig,
TraceError,
exclude_from_trace,
partial_trace,
trace,
)
from megengine.module import Module from megengine.module import Module
from megengine.random import normal, uniform from megengine.random import normal, uniform
from megengine.utils.naming import AutoNaming from megengine.utils.naming import AutoNaming
...@@ -803,3 +809,87 @@ def test_dump_without_output_error(): ...@@ -803,3 +809,87 @@ def test_dump_without_output_error():
str(e) str(e)
== "the traced function without return values cannot be dumped, the traced function should return List[Tensor] or Dict[str, Tensor]" == "the traced function without return values cannot be dumped, the traced function should return List[Tensor] or Dict[str, Tensor]"
) )
@pytest.mark.parametrize("trace_mode", [False, True])
def test_trace_without_host(trace_mode):
    """Check that a ``without_host`` trace reproduces its first-call results.

    The traced function is called once to obtain reference outputs, then
    called twice more; each later call must match the reference element-wise.
    """

    @trace(symbolic=trace_mode, without_host=True)
    def fwd(a, b, c):
        lhs = a + b
        rhs = a + c
        product = lhs * rhs
        quotient = lhs / rhs
        return [product, quotient]

    a = tensor([1.0])
    b = tensor([2.0])
    c = tensor([3.0])

    rst = fwd(a, b, c)
    for _ in range(2):
        trace_rst = fwd(a, b, c)
        for index in (0, 1):
            np.testing.assert_equal(rst[index], trace_rst[index])
def test_trace_without_error():
    """Check the error raised when a ``without_host`` trace sees an outside tensor.

    ``fwd`` reads ``const`` from the enclosing scope instead of receiving it
    as an argument; calling it is expected to raise with the message asserted
    below.
    """
    const = tensor([8.0])

    @trace(symbolic=False, without_host=True)
    def fwd(a, b, c):
        x = a + b
        y = a + c
        z = x * y
        z1 = x / y + const
        return [z, z1]

    a = tensor([1.0])
    b = tensor([2.0])
    c = tensor([3.0])

    # Only the traced call is guarded, so a failure while building the inputs
    # is reported as such; pytest.raises also fails the test when no exception
    # is raised at all.
    with pytest.raises(Exception) as exc_info:
        fwd(a, b, c)
    assert str(exc_info.value) == "have some unknown input tensors in trace result"
# Runs a forward/backward computation through `partial_trace`-decorated
# callables several times and checks every run against the first result.
# NOTE(review): indentation is missing in this copy of the test, so the
# nesting described in the comments below is inferred - confirm against the
# real file.
def test_partial_trace_fwd_bwd():
# Small module with two scalar parameters; its forward is partially traced.
class Simple(Module):
def __init__(self):
super().__init__()
self.a = Parameter([1.0], dtype=np.float32)
self.b = Parameter([2.0], dtype=np.float32)
@partial_trace
def forward(self, x):
x = x * self.a + x / self.b
x = F.exp(x)
return x
# Reset both parameter gradients so every run starts from a clean state.
def clear_grad(self):
self.a.grad = None
self.b.grad = None
# A partially traced function that is only ever used in the forward direction.
@partial_trace
def fwd_only(a, b):
return a * b + a / b
m = Simple()
gm = GradManager()
gm.attach(m.parameters())
# One full step: scale the input, run the module, backpropagate, then combine
# the two parameter gradients through `fwd_only`.
def func(x):
with gm:
x = x * 3
x = m(x)
x = x * 2
gm.backward(x)
a = m.a.grad
b = m.b.grad
m.clear_grad()
return fwd_only(a, b) + a + b
# The first call provides the reference value; the repeated calls must match it.
gt = func(tensor(1.0))
for _ in range(3):
out = func(tensor(1.0))
np.testing.assert_equal(gt.numpy(), out.numpy())
...@@ -105,6 +105,10 @@ std::string IsScalar::to_string() const { ...@@ -105,6 +105,10 @@ std::string IsScalar::to_string() const {
return "IsScalar"; return "IsScalar";
} }
// Human-readable name of the GetId operator.
std::string GetId::to_string() const {
    return std::string{"GetId"};
}
std::string GetFormat::to_string() const { std::string GetFormat::to_string() const {
return "GetFormat{}"; return "GetFormat{}";
} }
......
...@@ -15,6 +15,10 @@ std::string BoolValue::to_string() const { ...@@ -15,6 +15,10 @@ std::string BoolValue::to_string() const {
return (*this) ? "true" : "false"; return (*this) ? "true" : "false";
} }
// Decimal rendering of the wrapped integer.
std::string IntegerValue::to_string() const {
    // The payload is an int64_t; converting at full width keeps values outside
    // the range of int from being truncated before they are formatted.
    return std::to_string(static_cast<int64_t>(*this));
}
std::string HostStorage::to_string() const { std::string HostStorage::to_string() const {
return ssprintf("HostStorage{device=%s}", comp_node().to_string().c_str()); return ssprintf("HostStorage{device=%s}", comp_node().to_string().c_str());
} }
......
...@@ -142,6 +142,10 @@ const std::string OpDef::make_name() const { ...@@ -142,6 +142,10 @@ const std::string OpDef::make_name() const {
return m_scope + "." + trait()->make_name(*this); return m_scope + "." + trait()->make_name(*this);
} }
const std::string OpDef::type_name() const {
return trait()->name;
}
static thread_local OpDef::allocator_t local_allocator; static thread_local OpDef::allocator_t local_allocator;
void OpDef::set_allocator(allocator_t allocator) { void OpDef::set_allocator(allocator_t allocator) {
......
...@@ -13,6 +13,10 @@ namespace imperative { ...@@ -13,6 +13,10 @@ namespace imperative {
namespace { namespace {
class OprParamsLoadContext final : public serialization::OprLoadContextRawPOD { class OprParamsLoadContext final : public serialization::OprLoadContextRawPOD {
public:
bool strict = true;
private:
const OprAttr::Param& m_param; const OprAttr::Param& m_param;
size_t m_pos = 0; size_t m_pos = 0;
ComputingGraph* m_graph; ComputingGraph* m_graph;
...@@ -40,7 +44,8 @@ public: ...@@ -40,7 +44,8 @@ public:
m_graph(graph) {} m_graph(graph) {}
~OprParamsLoadContext() { ~OprParamsLoadContext() {
mgb_assert(m_pos == m_param.size(), "param not fully consumed"); if (strict)
mgb_assert(m_pos == m_param.size(), "param not fully consumed");
} }
ComputingGraph& graph() override { return *m_graph; } ComputingGraph& graph() override { return *m_graph; }
...@@ -126,7 +131,9 @@ std::shared_ptr<OpDef> make_from_op_node(cg::OperatorNodeBase* opr) { ...@@ -126,7 +131,9 @@ std::shared_ptr<OpDef> make_from_op_node(cg::OperatorNodeBase* opr) {
if (get_type2policy().find(opr->dyn_typeinfo()) != get_type2policy().end()) { if (get_type2policy().find(opr->dyn_typeinfo()) != get_type2policy().end()) {
policy = get_type2policy().at(opr->dyn_typeinfo()).first(opr); policy = get_type2policy().at(opr->dyn_typeinfo()).first(opr);
} }
return OprAttr::make(registry->name, std::move(ctx.m_param), policy, opr->config()); return OprAttr::make(
registry->name, std::move(ctx.m_param), policy, opr->config(),
opr->dyn_typeinfo());
} }
std::vector<std::pair<const char*, std::string>> props(const OpDef& def) { std::vector<std::pair<const char*, std::string>> props(const OpDef& def) {
...@@ -168,6 +175,12 @@ size_t OprAttr::hash() const { ...@@ -168,6 +175,12 @@ size_t OprAttr::hash() const {
config.hash()); config.hash());
} }
// Render this op's serialized param as JSON through the given footprint.
std::shared_ptr<json::Value> OprAttr::mgb_param(OprFootprint* footprint) {
    // The load context is created without a graph, and with the
    // "param fully consumed" check in its destructor switched off.
    OprParamsLoadContext load_ctx{param, nullptr};
    load_ctx.strict = false;
    auto param_json = footprint->get_serial_param_json(mgb_opr_type, load_ctx);
    return param_json;
}
MGB_DYN_TYPE_OBJ_FINAL_IMPL(OprAttr); MGB_DYN_TYPE_OBJ_FINAL_IMPL(OprAttr);
} // namespace imperative } // namespace imperative
......
...@@ -130,6 +130,11 @@ ValueRefList InterpreterTransformation::apply_transformation( ...@@ -130,6 +130,11 @@ ValueRefList InterpreterTransformation::apply_transformation(
} else { } else {
return {ValueRef()}; return {ValueRef()};
} }
} else if (op.is<GetId>()) {
auto& val = inputs[0].cast(m_value_type);
int64_t id = val.id();
return {IntegerValue::make(id)};
} else if (op.is<DupTensor>()) { } else if (op.is<DupTensor>()) {
auto& input = inputs[0].cast(m_value_type); auto& input = inputs[0].cast(m_value_type);
DeviceTensorND dev_tensor; DeviceTensorND dev_tensor;
......
...@@ -7,6 +7,7 @@ ...@@ -7,6 +7,7 @@
#include "megbrain/imperative/profiler.h" #include "megbrain/imperative/profiler.h"
#include "megbrain/imperative/resource_manager.h" #include "megbrain/imperative/resource_manager.h"
#include <range/v3/all.hpp>
namespace mgb { namespace mgb {
namespace imperative { namespace imperative {
...@@ -226,7 +227,7 @@ void GradKey::backward() { ...@@ -226,7 +227,7 @@ void GradKey::backward() {
if constexpr (std::is_same_v<T, std::monostate>) { if constexpr (std::is_same_v<T, std::monostate>) {
mgb_throw(AssertionError, "invalid backward"); mgb_throw(AssertionError, "invalid backward");
} else { } else {
mgb_assert(grad_fn->m_slots.size() > 0); // mgb_assert(grad_fn->m_slots.size() > 0);
SmallVector<ValueRef> grads (grad_fn->m_slots.size()); SmallVector<ValueRef> grads (grad_fn->m_slots.size());
auto iter = grads.begin(); auto iter = grads.begin();
for (auto&& slot : grad_fn->m_slots) { for (auto&& slot : grad_fn->m_slots) {
...@@ -419,6 +420,23 @@ ValueRefList GradTransformation::apply_transformation( ...@@ -419,6 +420,23 @@ ValueRefList GradTransformation::apply_transformation(
mgb_assert(!grad_fn->m_slots.empty()); mgb_assert(!grad_fn->m_slots.empty());
m_key->m_tape.push_back({grad_fn, op_val->op().shared_from_this()}); m_key->m_tape.push_back({grad_fn, op_val->op().shared_from_this()});
return outputs; return outputs;
} else if (auto* igc = op.as<InsertGradCallback>()) {
auto grad_fn = LocalPtr<GradFn>::make();
auto& backward =
std::get<CustomBackward>(grad_fn->m_backward = CustomBackward());
auto id = inputs[0];
backward.m_backward = [id, callback = igc->callback()](
Span<ValueRef> inputs) -> SmallVector<ValueRef> {
callback({&id, (size_t)1});
return {};
};
m_key->m_side_effects.push_back(grad_fn);
m_key->m_tape.push_back({grad_fn, nullptr});
auto next_id = IntegerValue::make((int)id.cast<IntegerValue>() + 1);
auto prev_count =
imperative::apply(InsertGradCallback(igc->callback()), next_id)[0];
auto count = IntegerValue::make((int)prev_count.cast<IntegerValue>() + 1);
return {count};
} else if (op.is<CreateTensor>()) { } else if (op.is<CreateTensor>()) {
return imperative::apply(op, inputs); return imperative::apply(op, inputs);
} else if (auto* attach_grad = op.as<AttachGrad>()) { } else if (auto* attach_grad = op.as<AttachGrad>()) {
...@@ -514,6 +532,13 @@ ValueRefList GradTransformation::apply_transformation( ...@@ -514,6 +532,13 @@ ValueRefList GradTransformation::apply_transformation(
} }
} }
return imperative::apply(op, inputs); return imperative::apply(op, inputs);
} else if (op.is<GetGradSlot>()) {
mgb_assert(inputs.size() == 1);
if (auto&& grad_value = as_grad_value(inputs[0])) {
return {GradSlotValue::make(grad_value->slot())};
} else {
return {};
}
} else if (op.kind() == Operator::IdentityLike) { } else if (op.kind() == Operator::IdentityLike) {
mgb_assert(inputs.size() == 1); mgb_assert(inputs.size() == 1);
if (auto&& grad_value = as_grad_value(inputs[0])) { if (auto&& grad_value = as_grad_value(inputs[0])) {
......
...@@ -53,6 +53,9 @@ ValueRefList LazyEvalTransformation::apply_transformation( ...@@ -53,6 +53,9 @@ ValueRefList LazyEvalTransformation::apply_transformation(
outputs[i] = record_var(output_nodes[i]); outputs[i] = record_var(output_nodes[i]);
} }
return outputs; return outputs;
} else if (op.is<GetId>()) {
int64_t id = inputs[0].id();
return {IntegerValue::make(id)};
} else if (auto* create_tensor = op.as<CreateTensor>()) { } else if (auto* create_tensor = op.as<CreateTensor>()) {
auto&& args = create_tensor->parse(inputs); auto&& args = create_tensor->parse(inputs);
auto get_dev_val = [&] { auto get_dev_val = [&] {
......
...@@ -83,7 +83,7 @@ VarNodeArray TraceResult::dump( ...@@ -83,7 +83,7 @@ VarNodeArray TraceResult::dump(
std::unordered_map<std::string, std::vector<cg::OperatorNodeBase*>> name2ops; std::unordered_map<std::string, std::vector<cg::OperatorNodeBase*>> name2ops;
// iterate over opr_seq // iterate over opr_seq
for (auto&& item : seq) { for (auto&& item : seq) {
auto&& [op, inputs, outputs] = item; auto&& [op, inputs, outputs, type] = item;
VarNodeArray input_nodes; VarNodeArray input_nodes;
for (auto&& input : inputs) { for (auto&& input : inputs) {
auto& node = nodes[input]; auto& node = nodes[input];
...@@ -207,7 +207,8 @@ ValueRefList TracingTransformation::apply_transformation( ...@@ -207,7 +207,8 @@ ValueRefList TracingTransformation::apply_transformation(
auto wrapped_output = record_var(outputs[0], as_const, VarKind::Internal); auto wrapped_output = record_var(outputs[0], as_const, VarKind::Internal);
auto input_id = wrapped_input->id(); auto input_id = wrapped_input->id();
auto output_id = wrapped_output->id(); auto output_id = wrapped_output->id();
m_seq.push_back({{}, {input_id}, {output_id}});
m_seq.push_back({{}, {input_id}, {output_id}, OpKind::CreateTensor});
return {wrapped_output}; return {wrapped_output};
} else if (auto* get_attr = op.as<GetAttr>()) { } else if (auto* get_attr = op.as<GetAttr>()) {
auto unwrapped_input = unwrap_var(inputs[0]); auto unwrapped_input = unwrap_var(inputs[0]);
...@@ -246,7 +247,30 @@ ValueRefList TracingTransformation::apply_transformation( ...@@ -246,7 +247,30 @@ ValueRefList TracingTransformation::apply_transformation(
} }
auto output = record_var(input, false, VarKind::Internal); auto output = record_var(input, false, VarKind::Internal);
m_vars[output->id()].mark = trace_mark_var->mark(); m_vars[output->id()].mark = trace_mark_var->mark();
m_seq.push_back({{}, {tracing_var->id()}, {output->id()}}); m_seq.push_back(
{{}, {tracing_var->id()}, {output->id()}, OpKind::TraceMarkVar});
return {output};
} else if (auto* iomarker = op.as<IOMarkVar>()) {
mgb_assert(inputs.size() == 1, "IOMarkVar expects exactly one input");
auto input = inputs[0];
auto tracing_var = input.as_ref(m_value_type);
if (!tracing_var) {
bool is_input = iomarker->kind() == IOMarkVar::Kind::Input;
if (is_input) {
tracing_var = record_var(input, false, VarKind::External);
} else {
tracing_var = record_var(input, m_capture_as_const, VarKind::External);
}
} else {
input = tracing_var->value();
}
auto output = record_var(input, false, VarKind::Internal);
if (iomarker->kind() == IOMarkVar::Kind::Input)
m_vars[tracing_var->id()].inp_marker.insert(iomarker->mark());
else
m_vars[output->id()].out_marker.insert(iomarker->mark());
m_seq.push_back({{}, {tracing_var->id()}, {output->id()}, OpKind::IOMarkVar});
return {output}; return {output};
} else if (auto* trace_name_var = op.as<RenameValue>()) { } else if (auto* trace_name_var = op.as<RenameValue>()) {
mgb_assert(inputs.size() == 1, "RenameValue expects exactly one input"); mgb_assert(inputs.size() == 1, "RenameValue expects exactly one input");
...@@ -259,7 +283,7 @@ ValueRefList TracingTransformation::apply_transformation( ...@@ -259,7 +283,7 @@ ValueRefList TracingTransformation::apply_transformation(
} }
auto output = record_var(input, false, VarKind::Internal); auto output = record_var(input, false, VarKind::Internal);
m_vars[output->id()].name = trace_name_var->name(); m_vars[output->id()].name = trace_name_var->name();
m_seq.push_back({{}, {tracing_var->id()}, {output->id()}}); m_seq.push_back({{}, {tracing_var->id()}, {output->id()}, OpKind::Rename});
return {output}; return {output};
} else if (op.is<GetName>()) { } else if (op.is<GetName>()) {
mgb_assert(inputs.size() == 1, "GetName expects exactly one input"); mgb_assert(inputs.size() == 1, "GetName expects exactly one input");
...@@ -279,6 +303,78 @@ ValueRefList TracingTransformation::apply_transformation( ...@@ -279,6 +303,78 @@ ValueRefList TracingTransformation::apply_transformation(
} }
} }
// Resolve the marks recorded by IOMarkVar items into the public
// inpmark_to_id / outmark_to_id tables. Asserts if one mark would end up
// bound to two different var ids.
void TracingTransformation::postprocess_trace_result() {
// Items without an op and with exactly one input and one output are treated as
// identity edges; remember them in both directions (output->input, input->output).
std::unordered_map<size_t, size_t> identity_oi_map, identity_io_map;
for (auto&& op : m_seq) {
if (op.op == nullptr && op.inputs.size() == 1 && op.outputs.size() == 1) {
identity_oi_map[op.outputs[0]] = op.inputs[0];
identity_io_map[op.inputs[0]] = op.outputs[0];
}
}
for (auto&& op : m_seq) {
if (op.kind == OpKind::IOMarkVar) {
auto&& inpvar = m_vars[op.inputs[0]];
auto&& outvar = m_vars[op.outputs[0]];
// An item whose input var carries input marks is handled as an input marker;
// otherwise its output var must carry output marks.
if (inpvar.inp_marker.size() > 0) {
auto id = inpvar.id;
if (inpvar.kind != VarKind::External) {
// Follow identity edges backwards until no further producer is recorded.
while (identity_oi_map.find(id) != identity_oi_map.end()) {
id = identity_oi_map[id];
}
// Only when the walk ends at an External var are the marks moved onto it;
// otherwise they stay on inpvar and nothing is registered.
if (m_vars[id].kind == VarKind::External) {
for (auto mark : inpvar.inp_marker) {
mgb_assert(
inpmark_to_id.find(mark) == inpmark_to_id.end() ||
inpmark_to_id[mark] == id,
"two nodes have same mark");
inpmark_to_id[mark] = id;
m_vars[id].inp_marker.insert(mark);
}
inpvar.inp_marker.clear();
}
} else {
// The marked var is itself External: register its marks directly.
for (auto mark : inpvar.inp_marker) {
mgb_assert(
inpmark_to_id.find(mark) == inpmark_to_id.end() ||
inpmark_to_id[mark] == id,
"two nodes have same mark");
inpmark_to_id[mark] = id;
}
}
} else {
mgb_assert(outvar.out_marker.size() > 0);
auto id = outvar.id;
if (!outvar.data_required) {
// Follow identity edges forwards until no further consumer is recorded.
while (identity_io_map.find(id) != identity_io_map.end()) {
id = identity_io_map[id];
}
// Only when the walk ends at a var with data_required set are the marks
// moved onto it; otherwise they stay on outvar and nothing is registered.
if (m_vars[id].data_required) {
for (auto mark : outvar.out_marker) {
mgb_assert(
outmark_to_id.find(mark) == outmark_to_id.end() ||
outmark_to_id[mark] == id,
"two nodes have same mark");
outmark_to_id[mark] = id;
m_vars[id].out_marker.insert(mark);
}
outvar.out_marker.clear();
}
} else {
// The marked var already has data_required set: register its marks directly.
for (auto mark : outvar.out_marker) {
mgb_assert(
outmark_to_id.find(mark) == outmark_to_id.end() ||
outmark_to_id[mark] == id,
"two nodes have same mark");
outmark_to_id[mark] = id;
}
}
}
}
}
}
void TracingTransformation::on_unregister() noexcept { void TracingTransformation::on_unregister() noexcept {
for (auto&& weak_var : m_weak_vars) { for (auto&& weak_var : m_weak_vars) {
if (auto tracing_value = weak_var.lock()) { if (auto tracing_value = weak_var.lock()) {
...@@ -526,7 +622,10 @@ void CompiledTransformation::trace_input(size_t id, ValueRef value) { ...@@ -526,7 +622,10 @@ void CompiledTransformation::trace_input(size_t id, ValueRef value) {
var.device->to_string().c_str(), var.device->to_string().c_str(),
device.to_string().c_str()); device.to_string().c_str());
} }
var_accessor.data_setter(value.dev_tensor()->as_nd()); if (m_setted_extern.find(id) == m_setted_extern.end()) {
var_accessor.data_setter(value.dev_tensor()->as_nd());
m_setted_extern.insert(id);
}
break; break;
} }
case VarKind::Constant: { case VarKind::Constant: {
...@@ -732,6 +831,7 @@ void CompiledTransformation::wait() { ...@@ -732,6 +831,7 @@ void CompiledTransformation::wait() {
m_pc = 0; m_pc = 0;
std::exception_ptr graph_exc; std::exception_ptr graph_exc;
std::swap(m_graph_exc, graph_exc); std::swap(m_graph_exc, graph_exc);
m_setted_extern.clear();
if (graph_exc) { if (graph_exc) {
// graph with exception cannot be reused // graph with exception cannot be reused
recompile(); recompile();
......
...@@ -127,6 +127,10 @@ bool ValueRef::watching() const { ...@@ -127,6 +127,10 @@ bool ValueRef::watching() const {
return this->storage()->m_watching; return this->storage()->m_watching;
} }
// Ask the active transformations for this value's id by dispatching GetId.
int ValueRef::handle_id() const {
    auto outputs = imperative::apply(GetId(), *this);
    // The first output is expected to be an IntegerValue holding the id.
    return outputs[0].cast<IntegerValue>();
}
ValueRef ValueRef::make(ValueRef::storage_t storage) { ValueRef ValueRef::make(ValueRef::storage_t storage) {
if (recording_values) { if (recording_values) {
recorded_values.push_back({storage}); recorded_values.push_back({storage});
......
...@@ -141,6 +141,14 @@ public: ...@@ -141,6 +141,14 @@ public:
ValueRefList fallback(Span<ValueRef> inputs) const override { return {ValueRef()}; } ValueRefList fallback(Span<ValueRef> inputs) const override { return {ValueRef()}; }
}; };
/**
 * \brief operator requesting the id of its input value
 *
 * When nothing handles the request, the fallback yields a single empty ValueRef.
 */
class GetId final : public OperatorImpl<GetId, Operator::GetAttrLike> {
public:
    ValueRefList fallback(Span<ValueRef> inputs) const override { return {ValueRef()}; }

    std::string raw_type() const { return "GetId"; }

    std::string to_string() const override;
};
/** /**
* \brief return a value with new name * \brief return a value with new name
* *
......
...@@ -48,6 +48,25 @@ public: ...@@ -48,6 +48,25 @@ public:
std::string to_string() const override; std::string to_string() const override;
}; };
/**
 * \brief thin wrapper around a signed 64-bit integer
 *
 * Implicitly constructible from, and convertible to, int64_t.
 */
class Integer {
private:
    // Zero-initialised so that a default-constructed Integer never exposes an
    // indeterminate value through the conversion operator.
    int64_t m_value = 0;

public:
    Integer() = default;
    Integer(int64_t value) : m_value(value) {}
    operator int64_t() const { return m_value; }
};
// TODO: override factory method
// Value type whose payload is an Integer; constructors are inherited from
// PrimitiveValue and to_string() is defined out of line.
class IntegerValue final : public PrimitiveValue<IntegerValue, Integer> {
public:
using PrimitiveValue::PrimitiveValue;
std::string to_string() const override;
};
class HostStorage final : public PrimitiveValue<HostStorage, HostTensorStorage> { class HostStorage final : public PrimitiveValue<HostStorage, HostTensorStorage> {
public: public:
using PrimitiveValue::PrimitiveValue; using PrimitiveValue::PrimitiveValue;
......
...@@ -80,6 +80,8 @@ public: ...@@ -80,6 +80,8 @@ public:
const std::string make_name() const; const std::string make_name() const;
virtual const std::string type_name() const;
void set_scope(const std::string& scope); void set_scope(const std::string& scope);
virtual size_t hash() const; virtual size_t hash() const;
......
...@@ -2,6 +2,7 @@ ...@@ -2,6 +2,7 @@
#include "megbrain/imperative/op_def.h" #include "megbrain/imperative/op_def.h"
#include "megbrain/opr/param_defs.h" #include "megbrain/opr/param_defs.h"
#include "megbrain/plugin/opr_footprint.h"
namespace mgb { namespace mgb {
namespace imperative { namespace imperative {
...@@ -28,6 +29,7 @@ public: ...@@ -28,6 +29,7 @@ public:
Type type; Type type;
Param param; Param param;
Typeinfo* mgb_opr_type;
megdnn::param::ExecutionPolicy policy; megdnn::param::ExecutionPolicy policy;
cg::OperatorNodeConfig config; cg::OperatorNodeConfig config;
...@@ -36,13 +38,14 @@ public: ...@@ -36,13 +38,14 @@ public:
OprAttr(const Type& t, const Param& p, const cg::OperatorNodeConfig& c) OprAttr(const Type& t, const Param& p, const cg::OperatorNodeConfig& c)
: type(t), param(p), config(c) {} : type(t), param(p), config(c) {}
OprAttr(const Type& t, const Param& p, const megdnn::param::ExecutionPolicy ps, OprAttr(const Type& t, const Param& p, const megdnn::param::ExecutionPolicy ps,
const cg::OperatorNodeConfig& c) const cg::OperatorNodeConfig& c, Typeinfo* optype)
: type(t), param(p), policy(ps), config(c) {} : type(t), param(p), policy(ps), config(c), mgb_opr_type(optype) {}
std::string repr() const; std::string repr() const;
std::shared_ptr<json::Value> mgb_param(OprFootprint*);
bool is_same_st(const Hashable& rhs) const override; bool is_same_st(const Hashable& rhs) const override;
size_t hash() const override; size_t hash() const override;
const std::string type_name() const override { return type; }
}; };
} // namespace imperative } // namespace imperative
......
...@@ -107,7 +107,7 @@ private: ...@@ -107,7 +107,7 @@ private:
public: public:
std::string to_string() const; std::string to_string() const;
ValueRef grad() const { return m_grad; }
friend class GradKey; friend class GradKey;
friend class GradSlotProducerPtr; friend class GradSlotProducerPtr;
friend class GradTransformation; friend class GradTransformation;
...@@ -224,6 +224,7 @@ public: ...@@ -224,6 +224,7 @@ public:
class GradKey : public std::enable_shared_from_this<GradKey> { class GradKey : public std::enable_shared_from_this<GradKey> {
private: private:
std::string m_name; std::string m_name;
std::vector<LocalPtr<GradFn>> m_side_effects;
std::vector<std::pair<LocalWeakPtr<GradFn>, std::shared_ptr<OpDef>>> m_tape; std::vector<std::pair<LocalWeakPtr<GradFn>, std::shared_ptr<OpDef>>> m_tape;
std::vector<std::pair<LocalPtr<GradFn>, std::shared_ptr<OpDef>>> m_frozen_tape; std::vector<std::pair<LocalPtr<GradFn>, std::shared_ptr<OpDef>>> m_frozen_tape;
bool m_frozen = false; bool m_frozen = false;
...@@ -253,6 +254,13 @@ public: ...@@ -253,6 +254,13 @@ public:
} }
}; };
/**
 * \brief value type whose payload is a GradSlotPtr
 */
class GradSlotValue final : public PrimitiveValue<GradSlotValue, GradSlotPtr> {
public:
    using PrimitiveValue::PrimitiveValue;

    std::string to_string() const override {
        // fixed tag; the slot's contents are not printed
        auto repr = ssprintf("GradSlot{}");
        return repr;
    }
};
class GradTransformation final : public Transformation { class GradTransformation final : public Transformation {
private: private:
ObjectType<GradValue> m_value_type{"GradValue"}; ObjectType<GradValue> m_value_type{"GradValue"};
...@@ -404,6 +412,28 @@ public: ...@@ -404,6 +412,28 @@ public:
ValueRefList fallback(Span<ValueRef> inputs) const override { return {ValueRef()}; } ValueRefList fallback(Span<ValueRef> inputs) const override { return {ValueRef()}; }
}; };
/**
 * \brief operator requesting the grad slot of its input
 *
 * The fallback yields an empty result list.
 */
class GetGradSlot : public OperatorImpl<GetGradSlot, Operator::GetAttrLike> {
public:
    GetGradSlot() = default;

    ValueRefList fallback(Span<ValueRef> inputs) const override { return {}; }

    std::string raw_type() const { return "GetGradSlot"; }

    std::string to_string() const override { return ssprintf("GetGradSlot{}"); }
};
// Operator that carries a callback (a GenericFunction) to be inserted by
// whichever transformation handles it.
class InsertGradCallback : public OperatorImpl<InsertGradCallback, Operator::Other> {
public:
// The stored callback; also reachable through callback().
GenericFunction m_callback;
public:
InsertGradCallback(GenericFunction callback) : m_callback(callback) {}
// Returns a copy of the stored callback.
GenericFunction callback() const { return m_callback; }
std::string to_string() const override { return ssprintf("InsertGradCallback{}"); }
std::string raw_type() const { return "InsertGradCallback"; }
};
class GetBackwardColsure class GetBackwardColsure
: public OperatorImpl<GetBackwardColsure, Operator::GetAttrLike> { : public OperatorImpl<GetBackwardColsure, Operator::GetAttrLike> {
private: private:
...@@ -420,4 +450,19 @@ public: ...@@ -420,4 +450,19 @@ public:
std::string raw_type() const { return "GetBackwardClosure"; } std::string raw_type() const { return "GetBackwardClosure"; }
}; };
/**
 * \brief transformation that answers InsertGradCallback with a zero count and
 * forwards every other operator unchanged
 */
class GradTransformationGuard final : public Transformation {
    ValueRefList apply_transformation(
            const Operator& op, Span<ValueRef> inputs) override {
        if (!op.is<InsertGradCallback>()) {
            return imperative::apply(op, inputs);
        }
        // InsertGradCallback is not forwarded any further: report a count of zero
        return {IntegerValue::make(0)};
    }

    ValueRef unwrap(ValueRef value) override { return value; }

    std::string name() const override { return "GradTransformationGuard"; }
};
} // namespace mgb::imperative } // namespace mgb::imperative
...@@ -2,8 +2,8 @@ ...@@ -2,8 +2,8 @@
#include <chrono> #include <chrono>
#include <future> #include <future>
#include <set>
#include <variant> #include <variant>
#include "megbrain/gopt/inference.h" #include "megbrain/gopt/inference.h"
#include "megbrain/imperative/dispatch.h" #include "megbrain/imperative/dispatch.h"
#include "megbrain/imperative/interpreter.h" #include "megbrain/imperative/interpreter.h"
...@@ -17,11 +17,21 @@ namespace mgb::imperative { ...@@ -17,11 +17,21 @@ namespace mgb::imperative {
struct TraceResult { struct TraceResult {
struct SeqItem { struct SeqItem {
enum OpKind {
Unknown,
TraceMarkVar,
Rename,
IOMarkVar,
CreateTensor,
};
std::shared_ptr<OpDef> op; std::shared_ptr<OpDef> op;
SmallVector<size_t> inputs; SmallVector<size_t> inputs;
SmallVector<size_t> outputs; SmallVector<size_t> outputs;
OpKind kind = OpKind::Unknown;
}; };
using OpKind = SeqItem::OpKind;
struct VarInfo { struct VarInfo {
enum Kind { enum Kind {
External, // End point of traced graph, its value is received from External, // End point of traced graph, its value is received from
...@@ -41,12 +51,14 @@ struct TraceResult { ...@@ -41,12 +51,14 @@ struct TraceResult {
ValueRef bound_data; ValueRef bound_data;
std::string mark; std::string mark;
std::string name; std::string name;
int handle_id;
Kind kind; Kind kind;
bool value_required = false; bool value_required = false;
bool data_required = false; bool data_required = false;
bool shape_required = false; bool shape_required = false;
std::set<size_t> inp_marker;
std::set<size_t> out_marker;
TensorShape shape; TensorShape shape;
}; };
...@@ -91,6 +103,27 @@ public: ...@@ -91,6 +103,27 @@ public:
std::string raw_type() const { return "TraceMarkVar"; } std::string raw_type() const { return "TraceMarkVar"; }
}; };
// Identity-like operator that tags a value with a numeric mark, either as an
// input or as an output.
class IOMarkVar : public OperatorImpl<IOMarkVar, Operator::IdentityLike> {
public:
// Whether the mark designates an input or an output.
enum Kind {
Input,
Output,
};
private:
// Numeric mark attached to the value.
size_t m_mark;
// Side (Input/Output) this mark belongs to.
Kind m_kind;
public:
IOMarkVar(size_t mark, Kind kind) : m_mark(mark), m_kind(kind) {}
size_t mark() const { return m_mark; }
Kind kind() const { return m_kind; }
std::string to_string() const override { return ssprintf("IOMarkVar"); }
std::string raw_type() const override { return "IOMarkVar"; }
};
class TracingValue final : public ObjectValue<TracingValue> { class TracingValue final : public ObjectValue<TracingValue> {
private: private:
ValueRef m_value = {}; ValueRef m_value = {};
...@@ -125,15 +158,22 @@ class TracingTransformation final : public Transformation { ...@@ -125,15 +158,22 @@ class TracingTransformation final : public Transformation {
public: public:
using VarInfo = TraceResult::VarInfo; using VarInfo = TraceResult::VarInfo;
using VarKind = VarInfo::Kind; using VarKind = VarInfo::Kind;
using OpKind = TraceResult::SeqItem::OpKind;
private: private:
std::vector<TraceResult::SeqItem> m_seq; std::vector<TraceResult::SeqItem> m_seq;
std::vector<TraceResult::VarInfo> m_vars; std::vector<TraceResult::VarInfo> m_vars;
std::vector<TracingValue::weak_ref_t> m_weak_vars; std::vector<TracingValue::weak_ref_t> m_weak_vars;
std::unordered_map<size_t, size_t> extern_var_to_id;
bool m_capture_as_const = false; bool m_capture_as_const = false;
bool m_record_input_shapes = false; bool m_record_input_shapes = false;
bool m_record_all_shapes = false;
ObjectType<TracingValue> m_value_type{"TracingValue"}; ObjectType<TracingValue> m_value_type{"TracingValue"};
public:
std::unordered_map<size_t, size_t> inpmark_to_id;
std::unordered_map<size_t, size_t> outmark_to_id;
public: public:
TracingTransformation(bool capture_as_const, bool record_input_shapes) TracingTransformation(bool capture_as_const, bool record_input_shapes)
: m_capture_as_const(capture_as_const), : m_capture_as_const(capture_as_const),
...@@ -148,7 +188,14 @@ public: ...@@ -148,7 +188,14 @@ public:
* \return TypedValueRef<TracingValue> traced value * \return TypedValueRef<TracingValue> traced value
*/ */
TypedValueRef<TracingValue> record_var(ValueRef value, bool capture, VarKind kind) { TypedValueRef<TracingValue> record_var(ValueRef value, bool capture, VarKind kind) {
if (kind == VarKind::External &&
extern_var_to_id.find(value.id()) != extern_var_to_id.end()) {
return m_value_type.make(value, extern_var_to_id[value.id()]);
}
size_t id = m_vars.size(); size_t id = m_vars.size();
if (kind == VarKind::External) {
extern_var_to_id[value.id()] = id;
}
auto wrapped_value = m_value_type.make(value, id); auto wrapped_value = m_value_type.make(value, id);
m_vars.push_back({id, value.dtype(), value.device()}); m_vars.push_back({id, value.dtype(), value.device()});
auto& var = m_vars.back(); auto& var = m_vars.back();
...@@ -156,9 +203,12 @@ public: ...@@ -156,9 +203,12 @@ public:
var.bound_data = value; var.bound_data = value;
} }
var.kind = kind; var.kind = kind;
if (m_record_input_shapes && kind != VarKind::Internal) { if ((m_record_input_shapes && kind != VarKind::Internal) ||
m_record_all_shapes) {
var.shape = value.shape()->as_tensor_shape(); var.shape = value.shape()->as_tensor_shape();
} }
if (m_record_all_shapes)
var.handle_id = value.handle_id();
if (auto name = value.name()) { if (auto name = value.name()) {
var.name = *name; var.name = *name;
} }
...@@ -185,8 +235,9 @@ public: ...@@ -185,8 +235,9 @@ public:
std::string name() const override { return "TracingTransformation"; } std::string name() const override { return "TracingTransformation"; }
void on_unregister() noexcept override; void on_unregister() noexcept override;
void postprocess_trace_result();
TraceResult get_result() { return {m_seq, m_vars}; } TraceResult get_result() { return {m_seq, m_vars}; }
void enable_record_all_shapes() { m_record_all_shapes = true; }
}; };
class TraceError : public std::exception { class TraceError : public std::exception {
...@@ -211,6 +262,7 @@ class CompiledTransformation final : public Transformation { ...@@ -211,6 +262,7 @@ class CompiledTransformation final : public Transformation {
public: public:
using VarInfo = TraceResult::VarInfo; using VarInfo = TraceResult::VarInfo;
using VarKind = VarInfo::Kind; using VarKind = VarInfo::Kind;
using OpKind = TraceResult::SeqItem::OpKind;
struct VarAccessor { struct VarAccessor {
VarNode* node; VarNode* node;
...@@ -254,6 +306,7 @@ private: ...@@ -254,6 +306,7 @@ private:
std::vector<TraceResult::SeqItem> m_seq; std::vector<TraceResult::SeqItem> m_seq;
std::vector<TraceResult::VarInfo> m_vars; std::vector<TraceResult::VarInfo> m_vars;
std::vector<VarAccessor> m_var_accessors; std::vector<VarAccessor> m_var_accessors;
std::unordered_map<std::string, size_t> mark2id;
size_t m_pc = 0; size_t m_pc = 0;
std::shared_ptr<ComputingGraph> m_graph; std::shared_ptr<ComputingGraph> m_graph;
std::unique_ptr<cg::AsyncExecutable> m_executable; std::unique_ptr<cg::AsyncExecutable> m_executable;
...@@ -268,6 +321,7 @@ private: ...@@ -268,6 +321,7 @@ private:
std::vector<std::shared_ptr<BoxBase>> m_boxes; std::vector<std::shared_ptr<BoxBase>> m_boxes;
ComputingGraph::OutputSpec m_output_spec; ComputingGraph::OutputSpec m_output_spec;
ObjectType<TracedValue> m_value_type{"TracedValue"}; ObjectType<TracedValue> m_value_type{"TracedValue"};
std::set<size_t> m_setted_extern;
public: public:
CompiledTransformation(TraceResult result, bool input_shape_static) CompiledTransformation(TraceResult result, bool input_shape_static)
...@@ -360,8 +414,10 @@ public: ...@@ -360,8 +414,10 @@ public:
return value; return value;
} }
std::string name() const override { return "CompiledTransformation"; } VarAccessor& get_accessor_by_id(size_t id) { return m_var_accessors[id]; }
std::string name() const override { return "CompiledTransformation"; }
void set_pc_to_end() { m_pc = m_seq.size(); }
void execute(); void execute();
void wait(); void wait();
......
...@@ -222,6 +222,7 @@ public: ...@@ -222,6 +222,7 @@ public:
TypedValueRef<DTypeValue> dtype() const; TypedValueRef<DTypeValue> dtype() const;
TypedValueRef<FormatValue> format() const; TypedValueRef<FormatValue> format() const;
TypedValueRef<StringValue> name() const; TypedValueRef<StringValue> name() const;
int handle_id() const;
bool is_scalar() const; bool is_scalar() const;
void watch() const; void watch() const;
...@@ -298,7 +299,7 @@ protected: ...@@ -298,7 +299,7 @@ protected:
public: public:
const IType& type() const { return *m_type; } const IType& type() const { return *m_type; }
uint64_t id() const { return m_id; }
static void register_value(ValueRef value); static void register_value(ValueRef value);
static ValueRef get_value_by_id(uint64_t id); static ValueRef get_value_by_id(uint64_t id);
static void begin_record_values(); static void begin_record_values();
...@@ -538,11 +539,11 @@ public: ...@@ -538,11 +539,11 @@ public:
const ValueRef* data() const { return m_data; } const ValueRef* data() const { return m_data; }
bool empty() const { return m_size == 0; } bool empty() const { return m_size == 0; }
ValueRef& front() { ValueRef& front() {
mgb_assert(m_size > 1); mgb_assert(m_size >= 1);
return m_data[0]; return m_data[0];
} }
ValueRef& back() { ValueRef& back() {
mgb_assert(m_size > 1); mgb_assert(m_size >= 1);
return m_data[m_size - 1]; return m_data[m_size - 1];
} }
}; };
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册