diff --git a/imperative/python/megengine/core/tensor/utils.py b/imperative/python/megengine/core/tensor/utils.py index 784eceea5ba305223bbb92360911c2d2f28b96f9..9a9a0f97d64c22c27f020ff033b2f26584371a40 100644 --- a/imperative/python/megengine/core/tensor/utils.py +++ b/imperative/python/megengine/core/tensor/utils.py @@ -11,6 +11,7 @@ from typing import Iterable, Union import numpy as np +from .._imperative_rt import VarNode from .._imperative_rt.core2 import Tensor, apply, dtype_promotion, get_device from ..ops import builtin from ..ops.special import Const @@ -59,7 +60,7 @@ def astype(x, dtype): def convert_single_value(v, *, dtype=None, device=None): - if isinstance(v, Tensor): + if isinstance(v, (Tensor, VarNode)): if not is_quantize(v.dtype): v = astype(v, dtype) else: diff --git a/imperative/python/megengine/functional/elemwise.py b/imperative/python/megengine/functional/elemwise.py index 47de381d791b63d970f97a612dca63d5b1f1489d..9e943d5b93a6848e42e8ad1f566e5e7b54c484cc 100644 --- a/imperative/python/megengine/functional/elemwise.py +++ b/imperative/python/megengine/functional/elemwise.py @@ -12,11 +12,12 @@ import functools import numpy as np from ..core._imperative_rt.core2 import apply +from ..core._imperative_rt.graph import VarNode from ..core.ops import builtin from ..core.ops.builtin import Elemwise from ..core.tensor import utils from ..core.tensor.array_method import _elwise_apply -from ..core.tensor.utils import isscalar, setscalar +from ..core.tensor.utils import astype, isscalar, setscalar from ..device import get_default_device from ..jit.tracing import is_tracing from ..tensor import Tensor @@ -77,7 +78,7 @@ __all__ = [ def _elwise(*args, mode): - tensor_args = list(filter(lambda x: isinstance(x, Tensor), args)) + tensor_args = list(filter(lambda x: isinstance(x, (Tensor, VarNode)), args)) if len(tensor_args) == 0: dtype = utils.dtype_promotion(args) first_arg = Tensor(args[0], dtype=dtype, device=get_default_device()) @@ -109,7 +110,7 @@ def _elwise(*args, mode): Elemwise.Mode.ROUND, ) and np.issubdtype(args[0].dtype, np.integer): return args[0] - args = tuple(map(lambda x: x.astype("float32"), args)) + args = tuple(map(lambda x: astype(x, "float32"), args)) return _elwise_apply(args, mode) diff --git a/imperative/python/megengine/utils/comp_graph_tools.py b/imperative/python/megengine/utils/comp_graph_tools.py index 20bc1e7d560957658edc0d435c990ac57ef4e58b..80ff11138ba6a51a81d4b8c4b4f3703083b5e368 100644 --- a/imperative/python/megengine/utils/comp_graph_tools.py +++ b/imperative/python/megengine/utils/comp_graph_tools.py @@ -65,7 +65,6 @@ def get_owner_opr_inputs(var: VarNode) -> List[VarNode]: """ Gets the inputs of owner opr of a variable. """ - assert isinstance(var, VarNode) return var.owner.inputs @@ -74,7 +73,6 @@ def get_owner_opr_type(var: VarNode) -> str: Gets the type of owner opr of a variable. """ - assert isinstance(var, VarNode) return var.owner.type @@ -109,7 +107,7 @@ def graph_traversal(outputs: VarNode): var2oprs = collections.defaultdict(list) opr2receivers = collections.defaultdict(list) - queue = list(map(lambda x: x.owner, outputs)) + queue = list(set(map(lambda x: x.owner, outputs))) visited = set(map(lambda x: x.id, queue)) # iterate through whole comp_graph, fill in meta information @@ -143,12 +141,15 @@ def graph_traversal(outputs: VarNode): return map_oprs, map_vars, var2oprs, opr2receivers, indegree2opr, opr2indegree -def get_oprs_seq(outputs: List[VarNode], prune_reshape=False) -> List[OperatorNode]: +def get_oprs_seq( + outputs: List[VarNode], prune_reshape=False, prune_immtensor=True +) -> List[OperatorNode]: """ Gets oprs in some topological order for a dumped model. :param outputs: model outputs. - :param prune_reshape: whether to prune the useless operators during inference. + :param prune_reshape: whether to prune the useless operators used by Reshape opr during inference. + :param prune_immtensor: whether to prune the ImmutableTensor opr. :return: opr list with some correct execution order. """ @@ -160,9 +161,7 @@ def get_oprs_seq(outputs: List[VarNode], prune_reshape=False) -> List[OperatorNo opr_id = indegree2opr[0].pop() opr = map_oprs[opr_id] nr_remain -= 1 - - # skip const value generation operator - if get_opr_type(opr) != "ImmutableTensor": + if opr.type != "ImmutableTensor" or not prune_immtensor: oprs_seq.append(opr) for post_id in opr2receivers[opr_id]: diff --git a/imperative/python/megengine/utils/network.py b/imperative/python/megengine/utils/network.py new file mode 100644 index 0000000000000000000000000000000000000000..f26e8086becfb933a025c3661e3b06b03a6166e1 --- /dev/null +++ b/imperative/python/megengine/utils/network.py @@ -0,0 +1,682 @@ +# -*- coding: utf-8 -*- +# MegEngine is Licensed under the Apache License, Version 2.0 (the "License") +# +# Copyright (c) 2014-2021 Megvii Inc. All rights reserved. +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +import collections +import fnmatch +import itertools +import re +from collections import OrderedDict +from typing import Dict, List + +import numpy as np + +from ..core._imperative_rt import ComputingGraph +from ..core.tensor import megbrain_graph as G +from .comp_graph_tools import get_dep_vars, get_opr_type, get_oprs_seq +from .network_node import ( + NetworkNode, + Host2DeviceCopy, + ImmutableTensor, + OpNode, + VarNode, + str_to_mge_class, +) + + +class Network: + def __init__(self): + self.input_vars = [] # input var of graph + self._orig_inputs = [] + self.output_vars = [] # output var of graph + self._orig_outputs = [] + self.all_oprs_map = OrderedDict() + self.all_vars_map = OrderedDict() + self.graph = ComputingGraph() + + @classmethod + def load(cls, model_path: str, outspec: List[str] = None): + """ + Loads a computing graph as a Network object. + :param model_path: file path of mge model. + :param outspec: only load the subgraph with outspec as its endpoints. + """ + self = cls() + _, _, outputs = G.load_graph(model_path) + if outspec is not None: + output_spec = outspec.copy() + all_vars = get_dep_vars(outputs) + outputs + new_outputs = {} + for i in all_vars: + if i.name in output_spec: + new_outputs[i.name] = i + output_spec.remove(i.name) + assert len(output_spec) == 0, "Can not find {} in this model".format( + output_spec + ) + outputs = [new_outputs[i] for i in outspec] + self._orig_outputs = outputs + self.add_dep_oprs(*outputs) + for x in self._orig_inputs: + self.input_vars.append(self._get_var(x)) + + for x in self._orig_outputs: + self.output_vars.append(self._get_var(x)) + self.graph = self._orig_outputs[0].graph + return self + + def _compile(self): + self.all_oprs_map = {} + self.all_vars_map = {} + for opr in self.all_oprs: + if isinstance(opr, (ImmutableTensor, Host2DeviceCopy)): + opr.compile(self.graph) + else: + opr.compile() + if opr.name is not None: + opr._opr.name = opr.name + self.all_oprs_map[opr._opr.id] = opr + for o in opr.outputs: + self.all_vars_map[o.var.id] = o + + def dump( + self, + file, + *, + keep_var_name: int = 1, + keep_opr_name: bool = False, + keep_param_name: bool = False, + keep_opr_priority: bool = False, + strip_info_file=None, + append_json=False, + optimize_for_inference=True, + append=False, + **kwargs + ): + """ + Serializes graph to file. + + :param file: output file, could be file object or filename. + :param append: whether output is appended to ``file``. + Only works when ``file`` is str. + :param keep_var_name: level for keeping variable names: + + * 0: none of the names are kept + * 1: (default)keep names of output vars + * 2: keep names of all (output and internal) vars + :param keep_opr_name: whether to keep operator names. + :param keep_param_name: whether to keep param names, so param values can be + easily manipulated after loading model + :param keep_opr_priority: whether to keep priority setting for operators + :param strip_info_file: a string for path or a file handler. if is not None, + then the dump information for code strip would be written to ``strip_info_file`` + :param append_json: will be check when `strip_info_file` is not None. if set + true, the information for code strip will be append to strip_info_file. + if set false, will rewrite strip_info_file + :param optimize_for_inference: enbale optmizations, + will skip all optimize options if this is False. Default: True + + :Keyword Arguments: + + * enable_io16xc32 -- + whether to use float16 for I/O between oprs and use + float32 as internal computation precision. Note the output var would be + changed to float16. + * enable_ioc16 -- + whether to use float16 for both I/O and computation + precision. + + * enable_hwcd4 -- + whether to use NHWCD4 data layout. This is faster on some + OpenCL backend. + * enable_nchw88 -- + whether to use NCHW88 data layout, currently + used in X86 AVX backend. + * enable_nchw44 -- + whether to use NCHW44 data layout, currently + used in arm backend. + * enable_nchw44_dot -- + whether to use NCHW44_dot data layout, currently + used in armv8.2+dotprod backend. + * enable_nchw4 -- + whether to use NCHW4 data layout, currently + used in nvidia backend(based on cudnn). + * enable_nchw32 -- + whether to use NCHW32 data layout, currently + used in nvidia backend with tensorcore(based on cudnn). + * enable_chwn4 -- + whether to use CHWN4 data layout, currently + used in nvidia backend with tensorcore. + + * enable_fuse_conv_bias_nonlinearity: whether to fuse conv+bias+nonlinearty + into one opr. + * enable_fuse_conv_bias_with_z: whether to fuse conv_bias with z + input for inference on nvidia backend(this optimization pass will + result in mismatch of the precision of output of training and + inference) + """ + + self._compile() + out = [G.VarNode(var.var) for var in self.output_vars] + + if optimize_for_inference: + out = G.optimize_for_inference(out, **kwargs) + + dump_content, _ = G.dump_graph( + out, + keep_var_name=keep_var_name, + keep_opr_name=keep_opr_name, + keep_param_name=keep_param_name, + keep_opr_priority=keep_opr_priority, + strip_info_file=strip_info_file, + append_json=append_json, + ) + if isinstance(file, str): + permission = "wb" if append == False else "ab" + file = open(file, permission) + file.write(dump_content) + + def make_const(self, data, name=None, device=None): + """Makes an ImmutableTensor OpNode to provide a parameter for the network. + """ + node = ImmutableTensor(data, name, device, self.graph) + node.compile(self.graph) + return node.outputs[0] + + def make_input_node(self, shape, dtype, name=None, device=None): + """Makes a Host2DeviceCopy OpNode to provide an input varnode for the network. + """ + node = Host2DeviceCopy(shape, dtype, name, device) + node.compile(self.graph) + return node.outputs[0] + + def add_output(self, *vars: VarNode): + """Adds vars into the network output node list + """ + for var in vars: + if var not in self.output_vars: + self.output_vars.append(var) + + def remove_output(self, *vars: VarNode): + """Removes vars from the network output node list. + """ + for var in vars: + if var in self.output_vars: + self.output_vars.remove(var) + + def add_dep_oprs(self, *vars): + """Adds dependent opnodes and varnodes of vars into network + """ + oprs = get_oprs_seq(vars, False, False) + for mge_opr in oprs: + if get_opr_type(mge_opr) == "Host2DeviceCopy": + self._orig_inputs.extend(mge_opr.outputs) + opr = self._add_opr(mge_opr) + if opr is not None: + for x in mge_opr.inputs: + opr.add_inp_var(self._get_var(x)) + # set out var + for x in mge_opr.outputs: + opr.add_out_var(self._get_var(x)) + + return [self.all_vars_map[var.id] for var in vars] + + def modify_opr_names(self, modifier): + """Modifies names of operators **inplace**; useful for merging loaded + network into another network + + :param modifier: a string to be prepended to the name, or a function + that maps from name to name + :type modifier: str or callable + """ + if isinstance(modifier, str): + om = modifier + modifier = lambda v: "{}.{}".format(om, v) + assert isinstance(modifier, collections.Callable) + for i in self.all_oprs: + v0 = i.name + v1 = modifier(v0) + assert isinstance(v1, str) + i.name = v1 + + def reset_batch_size(self, batchsize, *, blacklist=()): + """Helper for reset batch size; first dimension of all data providers + not in blacklist are assumed to be the batch size + + :param blacklist: data provider names whose first dimension is not + batchbatch size + """ + blacklist = set(blacklist) + prev_batchsize = None + for i in self.data_providers_filter: + if i.name in blacklist: + blacklist.remove(i.name) + else: + shp = list(i.shape) + if prev_batchsize is None: + prev_batchsize = shp[0] + else: + assert prev_batchsize == shp[0], ( + "batchsize mismatch: batchsize={} " + "shape={} dp={}".format(prev_batchsize, shp, i.name) + ) + shp[0] = batchsize + i.shape = tuple(shp) + + assert prev_batchsize is not None, "no data provider found" + assert not blacklist, "unused items in blacklist: {}".format(blacklist) + + def replace_vars(self, repl_dict: Dict[VarNode, VarNode]): + """ + Replaces vars in the graph. + :param repl_dict: the map {old_var: new_var} that specifies how to replace the vars. + """ + for var in self.all_vars: + if var in repl_dict: + repl_var = repl_dict[var] + owner = repl_var.owner + idx = owner.outputs.index(repl_var) + owner.outputs[idx] = var + var.__dict__.update(repl_var.__dict__) + + def replace_oprs(self, repl_dict: Dict[OpNode, OpNode]): + """ + Replaces operators in the graph. + :param oprmap: the map {old_opr: new_opr} that specifies how to replace the operators. + """ + for opr in self.all_oprs: + if opr in repl_dict: + assert len(opr.outputs) == len( + repl_dict[opr].outputs + ), "can not replace {} with {}".format(type(opr), type(repl_dict[opr])) + repl_dict[opr].outputs = opr.outputs + for ind, var in enumerate(opr.outputs): + var.owner = repl_dict[opr] + var.__dict__.update(repl_dict[opr].outputs[ind].__dict__) + + def get_opr_by_type(self, oprcls, unique=True): + assert issubclass(oprcls, OpNode) + rst = self.opr_filter.type(oprcls).as_list() + if unique: + assert len(rst) == 1, "{} operators of type {} found".format( + len(rst), oprcls + ) + (rst,) = rst + return rst + + def get_opr_by_name(self, name, unique=True): + rst = self.opr_filter.name(name).as_list() + if unique: + assert len(rst) == 1, "{} operators of type {} found".format(len(rst), name) + (rst,) = rst + return rst + + def get_var_by_name(self, name, unique=True): + rst = self.var_filter.name(name).as_list() + if unique: + assert len(rst) == 1, "{} operators of type {} found".format(len(rst), name) + (rst,) = rst + return rst + + def get_var_receive_oprs(self, var): + """ Gets all oprs which use var as input + """ + return self.opr_filter.has_input(var).as_list() + + def get_dep_oprs(self, var): + """Gets dependent oprs of var + """ + return get_oprs_seq(var, False, False) + + @property + def opr_filter(self): + """Filter on all opnodes of the Network. + """ + oprs = self.all_oprs + return NodeFilter(itertools.islice(oprs, len(oprs))) + + @property + def var_filter(self): + """Filter on all varnode of the Network. + """ + vars = self.all_vars + return NodeFilter(itertools.islice(vars, len(vars))) + + @property + def params_filter(self): # all immutable tensor + """Filter on all parameters (ImmutableTensor Opr) of the Network + """ + return self.opr_filter.param_provider() + + @property + def data_providers_filter(self): # all host2devicecopy + """Filter on all input nodes (Host2DeviceCopy Opr) of the Network + """ + return self.opr_filter.data_provider() + + @property + def dest_vars(self): + """Output varnodes of the Network. + """ + return self.output_vars + + @property + def all_oprs(self): + return get_oprs_seq(self.output_vars, False, False) + + @property + def all_vars(self): + return get_dep_vars(self.output_vars) + + @property + def all_vars_dict(self): + return self.var_filter.as_dict() + + @property + def all_oprs_dict(self): + return self.opr_filter.as_dict() + + # used for loading and building graph + def _add_opr(self, x): + # TODO: use megbrain C++ RTTI to replace type string + if x.id not in self.all_oprs_map: + self.all_oprs_map[x.id] = str_to_mge_class(get_opr_type(x)).load(x) + return self.all_oprs_map[x.id] + else: + return None + + def _get_opr(self, x): + if x.id in self.all_oprs_map: + return self.all_oprs_map[x.id] + else: + return None + + def _get_var(self, x): + # auto convert to VarNode of Network + if x.id not in self.all_vars_map: + self.all_vars_map[x.id] = VarNode.load(x, self._get_opr(x.owner)) + return self.all_vars_map[x.id] + + +def as_varnode(obj): + """convert a :class:`.VarNode` compatible object to :class:`.VarNode`. + + :param obj: it must be one of the following: + + 1. a :class:`.VarNode` object + 2. a :class:`.OpNode` object that has unique output + 3. an iterable that produces either type 1 or 2, with length 1 + + :rtype: :class:`.VarNode` + """ + if type(obj) is VarNode: + return obj + + if isinstance(obj, OpNode): + assert len(obj.outputs) == 1, ( + "operator {} must have one output to be converted to VarNode; " + "got {} actually".format(obj, len(obj.outputs)) + ) + ret = obj.outputs[0] + assert type(ret) is VarNode + return ret + + assert isinstance( + obj, collections.Iterable + ), "{} is not compatible with VarNode".format(obj) + + val = list(obj) + assert ( + len(val) == 1 + ), "can not convert sequence of length {} to VarNode ({})".format( + len(val), (lambda s: s if len(s) < 50 else s[:50] + " ...")(str(val)) + ) + return as_varnode(val[0]) + + +def as_oprnode(obj): + """convert a :class:`.OpNode` compatible object to + :class:`.OpNode`; it works like :func:`as_varnode`.""" + if type(obj) is VarNode: + return obj.owner + + if isinstance(obj, OpNode): + return obj + + assert isinstance( + obj, collections.Iterable + ), "{} is not compatible with OpNode".format(obj) + + val = list(obj) + assert ( + len(val) == 1 + ), "can not convert sequence of length {} to " "OpNode({})".format(len(val), val) + return as_oprnode(val[0]) + + +class NodeFilter: + """Filter on node iterator. This class is an iterator of + :class:`.NetworkNode` objects and multiple filtering conditions and + mappers can be chained. + + Example:: + + # find all :class:`.ImmutableTensor` nodes + for i in NodeFilter(node_iter).param_provider(): + print(i) + + # find all :class:`.ImmutableTensor` nodes that end with ':W' + for i in NodeFilter(node_iter).param_provider().name('*:W'): + print(i) + + # number of inputs + nr_input = NodeFilter(node_iter).data_provider().as_count() + + """ + + _iter = None + + def __init__(self, node_iter): + """ + :param node_iter: iterator to :class:`.NetworkNode`, or a + :class:`.VarNode`-compatible object; in the later case, its + dependent oprs would be used + """ + if isinstance(node_iter, VarNode): + oprs = get_oprs_seq(node_iter, False, False) + node_iter = itertools.islice(oprs, len(oprs) - 1) + if isinstance(node_iter, OpNode): + oprs = get_oprs_seq(node_iter.inputs, False, False) + node_iter = itertools.islice(oprs, len(oprs) - 1) + + assert isinstance(node_iter, collections.Iterable) + if (not isinstance(node_iter, NodeFilter)) and type( + self + ) is not NodeFilterCheckType: + node_iter = NodeFilterCheckType(node_iter, NetworkNode) + self._iter = node_iter + + @classmethod + def make_all_deps(cls, *dest_vars): + """make a :class:`NodeFilter` that contains all deps of given vars""" + return cls(list(get_oprs_seq(dest_vars, False, False))) + + def __iter__(self): + """to be overwritten by subclass to implement filters""" + return iter(self._iter) + + def type(self, node_type): + """filter by specific node type + + :param node_type: node type class + :return: a new :class:`NodeFilter` object + """ + return NodeFilterType(self, node_type) + + def check_type(self, node_type): + """assert that all oprs produced by this iterator are instances of + certain type + + :param node_type: node type class + :return: a new :class:`NodeFilter` object + :raises TypeError: if type check failed + """ + return NodeFilterCheckType(self, node_type) + + def not_type(self, node_type): + """remove oprs of specific type + + :param node_type: node type class + :return: a new :class:`NodeFilter` object + """ + return NodeFilterNotType(self, node_type) + + def param_provider(self): + """get :class:`.ParamProvider` oprs; shorthand for + ``.type(ParamProvider)``""" + + return self.type(ImmutableTensor) + + def data_provider(self): + """get :class:`.DataProvider` oprs; shorthand for + ``.type(DataProvider)``""" + + return self.type(Host2DeviceCopy) + + def name(self, pattern, ignorecase=True): + """filter by node name + + :param pattern: a string in glob syntax that can contain ``?`` and + ``*`` to match a single or arbitrary characters. + :type pattern: :class:`str` + :param ignorecase: whether to ignroe case + :type ignorecase: bool + :return: a new :class:`NodeFilter` object + """ + return NodeFilterName(self, pattern, ignorecase) + + def has_input(self, var): + """an opr is kept if it has given var as one of its inputs + + :param var: var node to checked + :return: a new :class:`NodeFilter` object + """ + return NodeFilterHasInput(self, var) + + def as_list(self): + """consume this iterator and return its content as a list + + :rtype: [:class:`.GraphNodeBase`] + """ + return list(self) + + def as_unique(self): + """assert that this iterator yields only one node and return it + + :return: the unique node + :rtype: :class:`.GraphNodeBase` + :raises ValueError: if this iterator does not yield a unique node + """ + (opr,) = self + return opr + + def as_dict(self): + """construct an ordered dict to map from node names to objects in + this iterator + + :rtype: :class:`OrderedDict` + """ + return collections.OrderedDict((i.name, i) for i in self) + + def as_count(self): + """consume this iterator and get the number of elements + + :rtype: int + """ + return sum(1 for _ in self) + + +class NodeFilterType(NodeFilter): + """see :meth:`NodeFilter.type`""" + + _node_type = None + + def __init__(self, node_iter, node_type): + assert issubclass(node_type, NetworkNode), "bad opr type: {}".format( + node_type + ) + super().__init__(node_iter) + self._node_type = node_type + + def __iter__(self): + for i in self._iter: + if isinstance(i, self._node_type): + yield i + + +class NodeFilterNotType(NodeFilterType): + """see :meth:`NodeFilter.not_type`""" + + def __iter__(self): + for i in self._iter: + if not isinstance(i, self._node_type): + yield i + + +class NodeFilterCheckType(NodeFilterType): + """see :meth:`NodeFilter.check_type`""" + + def __iter__(self): + for i in self._iter: + if not isinstance(i, self._node_type): + raise TypeError( + "all nodes should be {}; got {!r}".format(self._node_type, i) + ) + yield i + + +class NodeFilterHasInput(NodeFilter): + """see :meth:`NodeFilter.has_input`""" + + _var = None + + def __init__(self, node_iter, var): + var = as_varnode(var) + super().__init__(node_iter) + self.var = var + + def __iter__(self): + for i in self._iter: + assert isinstance( + i, OpNode + ), "has_input() must be used with OpNode; " "got {!r}".format(i) + if self.var in i.inputs: + yield i + + +class NodeFilterName(NodeFilter): + """see :meth:`NodeFilter.name`""" + + _re = None + + def __init__(self, node_iter, pattern, ignorecase): + super().__init__(node_iter) + self._re = self.make_re(pattern, ignorecase) + + @classmethod + def make_re(cls, pattern, ignorecase=True): + assert isinstance(pattern, str), "bad pattern: {!r}".format(pattern) + assert isinstance(ignorecase, bool) + flags = 0 + if ignorecase: + flags |= re.IGNORECASE + return re.compile(fnmatch.translate(pattern), flags=flags) + + def __iter__(self): + for i in self._iter: + if self._re.match(i.name): + yield i diff --git a/imperative/python/megengine/utils/network_node.py b/imperative/python/megengine/utils/network_node.py new file mode 100644 index 0000000000000000000000000000000000000000..ada2df466fad1e23f45f4df628e7f888a86e4ff1 --- /dev/null +++ b/imperative/python/megengine/utils/network_node.py @@ -0,0 +1,628 @@ +# -*- coding: utf-8 -*- +# MegEngine is Licensed under the Apache License, Version 2.0 (the "License") +# +# Copyright (c) 2014-2021 Megvii Inc. All rights reserved. +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +import json +import sys +from typing import Callable + +from ..core import _imperative_rt as rt +from ..core._wrap import Device +from ..core.ops import builtin +from ..core.tensor.megbrain_graph import InputNode +from ..tensor import Tensor +from .comp_graph_tools import replace_vars + + +class NetworkNode: + pass + + +class VarNode(NetworkNode): + def __init__(self, owner_opr=None, name=None): + self.var = None + self.owner = owner_opr + self.name = name + self.id = id(self) + + @classmethod + def load(cls, sym_var, owner_opr): + obj = cls() + obj.var = sym_var # mgb varnode + obj.name = sym_var.name + obj.owner = owner_opr + return obj + + @property + def shape(self): + rst = None + if self.var: + try: + rst = self.var.shape + except: + rst = None + return rst + + @property + def dtype(self): + return self.var.dtype if self.var else None + + def set_owner_opr(self, owner_opr): + self.owner_opr = owner_opr + + +class OpNode(NetworkNode): + + opdef = None + type = None + + def __init__(self): + self.inputs = [] + self.outputs = [] + self.params = {} + self._opr = None # mgb opnode + self.id = id(self) + + @classmethod + def load(cls, opr): + obj = cls() + obj.params = json.loads(opr.params) + obj.name = opr.name + obj._opr = opr + return obj + + def compile(self, graph=None): + op = self.opdef(**self.params) + args = [i.var for i in self.inputs] + outputs = rt.invoke_op(op, args) + assert len(outputs) == len(self.outputs) + self._opr = outputs[0].owner + for i in range(len(self.outputs)): + self.outputs[i].var = outputs[i] + self.outputs[i].var.name = self.outputs[i].name + assert self.outputs[i].owner is self + + def add_inp_var(self, x): + self.inputs.append(x) + + def add_out_var(self, x): + self.outputs.append(x) + + +def str_to_mge_class(classname): + # TODO: use megbrain C++ RTTI to replace type string + if classname == "RNGOpr": + classname = "RNGOpr" + oprcls = getattr(sys.modules[__name__], classname, None) + return oprcls if oprcls else ReadOnlyOpNode + + +class Host2DeviceCopy(OpNode): + type = "Host2DeviceCopy" + + def __init__(self, shape=None, dtype=None, name=None, device=None): + super().__init__() + self.shape = shape + self.dtype = dtype + self.name = name + self.device = Device(device).to_c() if device else Device("xpux").to_c() + self.outputs = [] + + @classmethod + def load(cls, opr): + self = cls() + self.outputs = [] + assert len(opr.outputs) == 1, "wrong number of outputs" + self.shape = opr.outputs[0].shape + self.dtype = opr.outputs[0].dtype + self.name = opr.outputs[0].name + self.device = opr.outputs[0].comp_node + self._opr = opr + return self + + def compile(self, graph): + outputs = rt.make_h2d(graph, self.device, self.dtype, self.shape, self.name) + self._opr = outputs.owner + if len(self.outputs) == 0: + self.outputs.append(VarNode(self, self.name)) + self.outputs[0].var = outputs + assert self.outputs[0].owner is self + + +class ImmutableTensor(OpNode): + type = "ImmutableTensor" + + def __init__(self, data=None, name=None, device=None, graph=None): + super().__init__() + self.name = name + self.outputs = [] + self.graph = graph + if data is not None: + self.set_value(data, device) + + @property + def device(self): + return self._opr.outputs[0].comp_node if self._opr else None + + @device.setter + def device(self, device): + self.set_value(self.numpy(), device) + + @property + def shape(self): + return self.outputs[0].shape + + @property + def dtype(self): + return self._opr.outputs[0].dtype if self._opr else None + + def numpy(self): + return self._opr.outputs[0].value if self._opr else None + + def set_value(self, data, device=None): + assert self.graph is not None + cn = device if device else self.device + assert isinstance(data, (int, float, np.ndarray)) + if isinstance(data, (int, float)): + data = np.array(data) + if data.dtype == np.float64: + data = data.astype(np.float32) + elif data.dtype == np.int64: + data = data.astype(np.int32) + varnode = rt.make_const(self.graph, data, cn, data.dtype, self.name) + if len(self.outputs) == 0: + self.outputs.append(VarNode(self, self.name)) + self.outputs[0].var = varnode + self._opr = varnode.owner + + @classmethod + def load(cls, opr): + self = cls() + self.outputs = [] + self._opr = opr + self.name = opr.outputs[0].name + self.graph = opr.graph + return self + + def compile(self, graph): + assert self.outputs[0].var is self._opr.outputs[0] + assert self.outputs[0].owner is self + if self.graph != graph: + self.graph = graph + self.set_value(self.numpy()) + if self.name is not None: + self.outputs[0].var.name = self.name + + +class ReadOnlyOpNode(OpNode): + @classmethod + def load(cls, opr): + obj = super(ReadOnlyOpNode, cls).load(opr) + obj.type = opr.type + return obj + + def compile(self): + assert self._opr is not None + assert len(self.inputs) == len(self._opr.inputs) + assert len(self.outputs) == len(self._opr.outputs) + repl_dict = {} + for ind, i in enumerate(self.inputs): + if i.var != self._opr.inputs[ind]: + repl_dict[self._opr.inputs[ind]] = i.var + if bool(repl_dict): + out_vars = replace_vars(self._opr.outputs, repl_dict) + for ind, o in enumerate(self.outputs): + o.var = out_vars[ind] + + +class Elemwise(OpNode): + type = "Elemwise" + opdef = builtin.Elemwise + + +class Reduce(OpNode): + type = "Reduce" + opdef = builtin.Reduce + + +class TypeCvt(OpNode): + type = "TypeCvt" + opdef = builtin.TypeCvt + + @classmethod + def load(cls, opr): + obj = super(TypeCvt, cls).load(opr) + t_dtype = opr.outputs[0].dtype + obj.params["dtype"] = t_dtype + return obj + + +class MatrixInverse(OpNode): + type = "MatrixInverse" + opdef = builtin.MatrixInverse + + +class MatrixMul(OpNode): + type = "MatrixMul" + opdef = builtin.MatrixMul + + +class BatchedMatrixMul(OpNode): + type = "BatchedMatmul" + opdef = builtin.BatchedMatrixMul + + +class Dot(OpNode): + type = "Dot" + opdef = builtin.Dot + + +class SVD(OpNode): + type = "SVD" + opdef = builtin.SVD + + +class ConvolutionForward(OpNode): + type = "Convolution" + opdef = builtin.Convolution + + +class ConvolutionBackwardData(OpNode): + type = "ConvTranspose" + opdef = builtin.ConvolutionBackwardData + + +class DeformableConvForward(OpNode): + type = "DeformableConv" + opdef = builtin.DeformableConv + + +class GroupLocalForward(OpNode): + type = "GroupLocal" + opdef = builtin.GroupLocal + + +class PoolingForward(OpNode): + type = "Pooling" + opdef = builtin.Pooling + + +class AdaptivePoolingForward(OpNode): + type = "AdaptivePooling" + opdef = builtin.AdaptivePooling + + +class ROIPoolingForward(OpNode): + type = "ROIPooling" + opdef = builtin.ROIPooling + + +class DeformablePSROIPoolingForward(OpNode): + type = "DeformablePSROIPooling" + opdef = builtin.DeformablePSROIPooling + + +class ConvBiasForward(OpNode): + type = "ConvBias" + opdef = builtin.ConvBias + + @classmethod + def load(cls, opr): + obj = super(ConvBiasForward, cls).load(opr) + obj.params["dtype"] = opr.outputs[0].dtype + return obj + + +class BatchConvBiasForward(OpNode): + type = "BatchConvBias" + opdef = builtin.BatchConvBias + + @classmethod + def load(cls, opr): + obj = super(BatchConvBiasForward, cls).load(opr) + obj.params["dtype"] = opr.outputs[0].dtype + return obj + + +class BatchNormForward(OpNode): + type = "BatchNorm" + opdef = builtin.BatchNorm + + +class ROIAlignForward(OpNode): + type = "ROIAlign" + opdef = builtin.ROIAlign + + +class WarpPerspectiveForward(OpNode): + type = "WarpPerspective" + opdef = builtin.WarpPerspective + + +class WarpAffineForward(OpNode): + type = "WarpAffine" + opdef = builtin.WarpAffine + + +class RemapForward(OpNode): + type = "Remap" + opdef = builtin.Remap + + +class ResizeForward(OpNode): + type = "Resize" + opdef = builtin.Resize + + +class IndexingOneHot(OpNode): + type = "IndexingOneHot" + opdef = builtin.IndexingOneHot + + +class IndexingSetOneHot(OpNode): + type = "IndexingSetOneHot" + opdef = builtin.IndexingSetOneHot + + +class Copy(OpNode): + type = "Copy" + opdef = builtin.Copy + + @classmethod + def load(cls, opr): + obj = super(Copy, cls).load(opr) + obj.params["comp_node"] = opr.outputs[0].comp_node + return obj + + +class ArgsortForward(OpNode): + type = "Argsort" + opdef = builtin.Argsort + + +class Argmax(OpNode): + type = "Argmax" + opdef = builtin.Argmax + + +class Argmin(OpNode): + type = "Argmin" + opdef = builtin.Argmin + + +class CondTake(OpNode): + type = "CondTake" + opdef = builtin.CondTake + + +class TopK(OpNode): + type = "TopK" + opdef = builtin.TopK + + +class NvOf(OpNode): + type = "NvOf" + opdef = builtin.NvOf + + +class RNGOpr(OpNode): + @classmethod + def load(cls, opr): + obj = super(RNGOpr, cls).load(opr) + if len(obj.params) == 3: + obj.opdef = builtin.GaussianRNG + obj.type = "GaussianRNG" + else: + obj.opdef = builtin.UniformRNG + obj.type = "UniformRNG" + return obj + + +class Linspace(OpNode): + type = "Linspace" + opdef = builtin.Linspace + + @classmethod + def load(cls, opr): + obj = super(Linspace, cls).load(opr) + obj.params["comp_node"] = opr.outputs[0].comp_node + return obj + + +class Eye(OpNode): + type = "Eye" + opdef = builtin.Eye + + @classmethod + def load(cls, opr): + obj = super(Eye, cls).load(opr) + obj.params["dtype"] = opr.outputs[0].dtype + obj.params["comp_node"] = opr.outputs[0].comp_node + return obj + + +class GetVarShape(OpNode): + type = "GetVarShape" + opdef = builtin.GetVarShape + + +class Concat(OpNode): + type = "Concat" + opdef = builtin.Concat + + @classmethod + def load(cls, opr): + obj = super(Concat, cls).load(opr) + obj.params["comp_node"] = Device("xpux").to_c() + return obj + + +class Broadcast(OpNode): + type = "Broadcast" + opdef = builtin.Broadcast + + +class Identity(OpNode): + type = "Identity" + opdef = builtin.Identity + + +class NMSKeep(OpNode): + type = "NMSKeep" + opdef = builtin.NMSKeep + + +# class ParamPackSplit +# class ParamPackConcat + + +class Dimshuffle(OpNode): + type = "Dimshuffle" + opdef = builtin.Dimshuffle + + @classmethod + def load(cls, opr): + obj = super(Dimshuffle, cls).load(opr) + del obj.params["ndim"] + return obj + + +class Reshape(OpNode): + type = "Reshape" + opdef = builtin.Reshape + + +class AxisAddRemove(OpNode): + type = "AxisAddRemove" + + @classmethod + def load(cls, opr): + obj = cls() + obj.name = opr.name + obj._opr = opr + params = json.loads(opr.params) + desc = params["desc"] + method = None + axis = [] + for i in desc: + if method is None: + method = i["method"] + assert method == i["method"] + axis.append(i["axisnum"]) + obj.params = {"axis": axis} + obj.opdef = builtin.AddAxis if desc[0]["method"] == 0 else builtin.RemoveAxis + return obj + + +class IndexingBase(OpNode): + @classmethod + def load(cls, opr): + obj = cls() + obj.name = opr.name + obj._opr = opr + params = json.loads(opr.params) + items = [ + [ + p["axis"], + bool(p["begin"]), + bool(p["end"]), + bool(p["step"]), + bool(p["idx"]), + ] + for p in params + ] + obj.params["items"] = items + return obj + + +class Subtensor(IndexingBase): + type = "Subtensor" + opdef = builtin.Subtensor + + +class SetSubtensor(IndexingBase): + type = "SetSubtensor" + opdef = builtin.SetSubtensor + + +class IncrSubtensor(IndexingBase): + type = "IncrSubtensor" + opdef = builtin.IncrSubtensor + + +class IndexingMultiAxisVec(IndexingBase): + type = "IndexingMultiAxisVec" + opdef = builtin.IndexingMultiAxisVec + + +class IndexingSetMultiAxisVec(IndexingBase): + type = "IndexingSetMultiAxisVec" + opdef = builtin.IndexingSetMultiAxisVec + + +class IndexingIncrMultiAxisVec(IndexingBase): + type = "IndexingIncrMultiAxisVec" + opdef = builtin.IndexingIncrMultiAxisVec + + +class MeshIndexing(IndexingBase): + type = "MeshIndexing" + opdef = builtin.MeshIndexing + + +class SetMeshIndexing(IndexingBase): + type = "SetMeshIndexing" + opdef = builtin.SetMeshIndexing + + +class IncrMeshIndexing(IndexingBase): + type = "IncrMeshIndexing" + opdef = builtin.IncrMeshIndexing + + +class BatchedMeshIndexing(IndexingBase): + type = "BatchedMeshIndexing" + opdef = builtin.BatchedMeshIndexing + + +class BatchedSetMeshIndexing(IndexingBase): + type = "BatchedSetMeshIndexing" + opdef = builtin.BatchedSetMeshIndexing + + +class BatchedIncrMeshIndexing(IndexingBase): + type = "BatchedIncrMeshIndexing" + opdef = builtin.BatchedIncrMeshIndexing + + +# class CollectiveComm +# class RemoteSend +# class RemoteRecv +# class TQT +# class FakeQuant +# class InplaceAdd + + +class AssertEqual(OpNode): + type = "AssertEqual" + opdef = builtin.AssertEqual + + +class ElemwiseMultiType(OpNode): + type = "ElemwiseMultiType" + opdef = builtin.ElemwiseMultiType + + @classmethod + def load(cls, opr): + obj = super(ElemwiseMultiType, cls).load(opr) + obj.params["dtype"] = opr.outputs[0].dtype + return obj + + +class CvtColorForward(OpNode): + type = "CvtColor" + opdef = builtin.CvtColor diff --git a/imperative/python/src/tensor.cpp b/imperative/python/src/tensor.cpp index 453003598af7f2ac23f4022cb22c3ef7b457ce0e..f5eac698e00003af002b1ee1551d771e2069dda5 100644 --- a/imperative/python/src/tensor.cpp +++ b/imperative/python/src/tensor.cpp @@ -160,6 +160,16 @@ PyObject* py_apply(PyObject* self, PyObject*const* args, size_t nargs/* , PyObje if (ctx.op->same_type()) { ctx.backward = true; } + + + if (py::isinstance(py::handle(args[0]))){ + SmallVector vinputs(nargs); + for (size_t i = 0; i < nargs; ++i) { + vinputs[i] = py::handle(args[i]).cast(); + } + auto op = ctx.op.get(); + return to_tuple(OpDef::apply_on_var_node(*op, vinputs)).release().ptr(); + } for (size_t i = 0; i < nargs; ++i) { if (TensorWrapper* tw = TensorWrapper::try_cast(args[i])) { @@ -675,6 +685,16 @@ PyArray_Descr* _dtype_promotion(PyObject*const* args, size_t nargs) { tensors.emplace_back(descr); continue; } + + if (py::isinstance(py::handle(handle))){ + auto var = py::handle(handle).cast(); + mgb::DType type = var->dtype(); + auto && descr = npy::dtype_mgb2np_descr(type); + Py_INCREF(descr.get()); + tensors.emplace_back(descr.get()); + continue; + } + PyArray_Descr* descr = scalar2dtype(handle); if (descr) { scalars.emplace_back(descr); @@ -719,12 +739,14 @@ CompNode _get_device(PyObject*const* args, size_t nargs) { for (size_t i = 0; i < nargs; ++i) { PyObject* handle = is_tuple ? PyTuple_GetItem(tuple, i): args[i]; TensorWrapper* tw = TensorWrapper::try_cast(handle); - if (tw) { + + bool is_var = py::isinstance(py::handle(handle)); + if (tw || is_var) { if (!valid) { - cn = tw->m_tensor->comp_node(); + cn = tw ? tw->m_tensor->comp_node() : py::handle(handle).cast()->comp_node(); valid = true; } else { - CompNode cn1 = tw->m_tensor->comp_node(); + CompNode cn1 = tw ? tw->m_tensor->comp_node() : py::handle(handle).cast()->comp_node(); if (cn1 != cn) { throw py::value_error(ssprintf("ambiguous device: %s vs %s", cn.to_string().c_str(), cn1.to_string().c_str())); diff --git a/imperative/python/test/unit/utils/test_network.py b/imperative/python/test/unit/utils/test_network.py new file mode 100644 index 0000000000000000000000000000000000000000..e398fd64abd7be912f86b73b241ac146e8a93a84 --- /dev/null +++ b/imperative/python/test/unit/utils/test_network.py @@ -0,0 +1,351 @@ +import io + +import numpy as np + +import megengine.core.tensor.megbrain_graph as G +import megengine.functional as F +import megengine.module as M +import megengine.utils.network_node as N +from megengine.jit.tracing import trace +from megengine.tensor import Tensor +from megengine.utils.comp_graph_tools import GraphInference +from megengine.utils.network import Network as Net +from megengine.utils.network import as_oprnode +from megengine.utils.network_node import Host2DeviceCopy, VarNode + + +def test_replace_var(): + + a = Tensor([1, 2]) + b = Tensor([3, 4]) + + @trace(symbolic=True, capture_as_const=True) + def fwd(a, b): + return (a + b) * 2 + + fwd(a, b) + orig_model = io.BytesIO() + fwd.dump( + orig_model, arg_names=["a", "b"], output_names="o", optimize_for_inference=False + ) + orig_model.seek(0) + + graph = Net.load(orig_model) + vara = graph.var_filter.name("a").as_unique() + varb = graph.var_filter.name("b").as_unique() + + out = F.mul(vara.var, varb.var) + out = F.relu(out) + + var_list = graph.add_dep_oprs(out) + + opnode = list(graph.opr_filter.has_input(vara)) + repl_dict = {opnode[0].outputs[0]: var_list[0]} + graph.replace_vars(repl_dict) + + modified_model = io.BytesIO() + graph.dump(modified_model) + modified_model.seek(0) + load_graph = GraphInference(modified_model) + + out = load_graph.run(a, b) + np.testing.assert_equal(out["o"], [6, 16]) + + +def test_replace_opr(): + + a = Tensor([1, 2]) + b = Tensor([3, 4]) + + @trace(symbolic=True, capture_as_const=True) + def fwd(a, b): + return (a + b) * 2 + + fwd(a, b) + orig_model = io.BytesIO() + fwd.dump( + orig_model, arg_names=["a", "b"], output_names="o", optimize_for_inference=False + ) + orig_model.seek(0) + + graph = Net.load(orig_model) + vara = graph.var_filter.name("a").as_unique() + varb = graph.var_filter.name("b").as_unique() + + out1 = F.sub(vara.var, varb.var) + out1 = F.relu(out1) + + var_list = graph.add_dep_oprs(out1) + repl_opr = as_oprnode(var_list) + orig_opr = graph.opr_filter.has_input(vara).as_unique() + + repl_dict = {orig_opr: repl_opr} + graph.replace_oprs(repl_dict) + modified_model1 = io.BytesIO() + graph.dump(modified_model1) + modified_model1.seek(0) + + load_graph = GraphInference(modified_model1) + out = load_graph.run(a, b) + np.testing.assert_equal(out["o"], [0, 0]) + + +def test_modify_params(): + + a = Tensor([1, 2]) + b = Tensor([3, 4]) + + @trace(symbolic=True, capture_as_const=True) + def fwd(a, b): + return (a + b) * 2 + + fwd(a, b) + orig_model = io.BytesIO() + fwd.dump( + orig_model, arg_names=["a", "b"], output_names="o", optimize_for_inference=False + ) + orig_model.seek(0) + + graph = Net.load(orig_model) + param_const = graph.params_filter.as_unique() + param_const.set_value(3) + + modified_model = io.BytesIO() + graph.dump(modified_model) + modified_model.seek(0) + load_graph = GraphInference(modified_model) + + out = load_graph.run(a, b) + np.testing.assert_equal(out["o"], [12, 18]) + + +def test_make_const(): + + a = Tensor([1, 2]) + b = Tensor([3, 4]) + + @trace(symbolic=True, capture_as_const=True) + def fwd(a, b): + return (a + b) * 2 + + fwd(a, b) + orig_model = io.BytesIO() + fwd.dump( + orig_model, arg_names=["a", "b"], output_names="o", optimize_for_inference=False + ) + orig_model.seek(0) + + graph = Net.load(orig_model) + const_b = graph.make_const(np.array([0.0, 0.0]), name="b") + varb = graph.var_filter.name("b").as_unique() + + repl_dict = {varb: const_b} + graph.replace_vars(repl_dict) + + modified_model = io.BytesIO() + graph.dump(modified_model) + modified_model.seek(0) + load_graph = GraphInference(modified_model) + + out = load_graph.run(a) + np.testing.assert_equal(out["o"], [2, 4]) + + +def test_add_input(): + + a = Tensor([1, 2]) + b = Tensor([3, 4]) + + @trace(symbolic=True, capture_as_const=True) + def fwd(a, b): + return (a + b) * 2 + + fwd(a, b) + orig_model = io.BytesIO() + fwd.dump( + orig_model, arg_names=["a", "b"], output_names="o", optimize_for_inference=False + ) + orig_model.seek(0) + + graph = Net.load(orig_model) + inp_c = graph.make_input_node((2,), np.int32, name="c") + varo = graph.var_filter.name("o").as_unique() + + out = F.add(varo.var, inp_c.var) + out = graph.add_dep_oprs(out)[0] + out.name = "o1" + graph.remove_output(varo) + graph.add_output(out) + modified_model = io.BytesIO() + + graph.dump(modified_model) + modified_model.seek(0) + load_graph = GraphInference(modified_model) + + out = load_graph.run(a, b, a) + np.testing.assert_equal(out["o1"], ((a + b) * 2 + a).numpy()) + + +def test_add_output(): + + a = Tensor([1.0, 2.0]) + b = Tensor([3.0, 4.0]) + + @trace(symbolic=True, capture_as_const=True) + def fwd(a, b): + return (a + b) * 2 + + fwd(a, b) + orig_model = io.BytesIO() + fwd.dump( + orig_model, arg_names=["a", "b"], output_names="o", optimize_for_inference=False + ) + orig_model.seek(0) + + net = Net.load(orig_model) + var_a = net.var_filter.name("a").as_unique() + var_b = net.var_filter.name("b").as_unique() + + y = F.add(var_a.var, var_b.var) + y = F.sigmoid(y) + + new_vars = net.add_dep_oprs(y)[0] + new_vars.name = "o1" + net.add_output(new_vars) + + modified_model = io.BytesIO() + net.dump(modified_model) + modified_model.seek(0) + + g = GraphInference(modified_model) + out = g.run(a.numpy(), b.numpy()) + + np.testing.assert_equal(out["o"], ((a + b) * 2).numpy()) + np.testing.assert_equal(out["o1"], (F.sigmoid((a + b))).numpy()) + + +def test_query(): + class Model(M.Module): + def __init__(self): + super().__init__() + self.conv1 = M.Conv2d(3, 32, 3) + self.conv2 = M.Conv2d(32, 32, 3) + self.conv3 = M.Conv2d(32, 32, 3) + + def forward(self, data): + x = self.conv1(data) + x = self.conv2(x) + x = self.conv3(x) + return x + + n = Model() + + @trace(symbolic=True, capture_as_const=True) + def fwd(data): + return n(data) + + fwd(Tensor(np.random.random((1, 3, 224, 224)))) + orig_model = io.BytesIO() + fwd.dump( + orig_model, + arg_names=["data"], + output_names="o", + keep_opr_name=True, + keep_var_name=True, + optimize_for_inference=False, + ) + orig_model.seek(0) + + graph = Net.load(orig_model) + + r = graph.data_providers_filter.as_count() + assert r == 1 + + opr = graph.get_opr_by_type(Host2DeviceCopy) + assert isinstance(opr, Host2DeviceCopy) + + r1 = graph.params_filter.as_count() + assert r1 == 6 + + r2 = graph.opr_filter.type(N.ConvolutionForward).as_count() + assert r2 == 3 + + r3 = graph.opr_filter.not_type(N.ConvolutionForward).as_count() + assert r3 == len(graph.all_oprs) - r2 + + var = graph.var_filter.name("data").as_unique() + r4 = graph.opr_filter.has_input(var).as_count() + assert r4 == 1 + + r5 = graph.opr_filter.name("data").as_count() + assert r5 == 1 + + opr = graph.get_opr_by_name("data") + assert isinstance(opr, Host2DeviceCopy) + + var = graph.get_var_by_name("data") + assert isinstance(var, VarNode) + + r6 = graph.var_filter.name("*bias").as_count() + assert r6 == 3 + + +def test_optimize_for_inference(): + @trace(symbolic=True, capture_as_const=True) + def f(x): + return F.exp(x) + + orig_model = io.BytesIO() + f(Tensor(5.0)) + f.dump(orig_model, optimize_for_inference=False) + orig_model.seek(0) + + optimize_model = io.BytesIO() + net = Net.load(orig_model) + net.dump(optimize_model, enable_io16xc32=True) + optimize_model.seek(0) + + res = G.load_graph(optimize_model) + computing_input = res.output_vars_list[0].owner.inputs[0] + assert computing_input.dtype == np.float16 + + +def test_reset_batchsize(): + @trace(symbolic=True, capture_as_const=True) + def f(x): + return F.exp(x) + + orig_model = io.BytesIO() + f(Tensor(np.random.random((3, 3, 224, 224)))) + f.dump(orig_model, optimize_for_inference=False) + orig_model.seek(0) + + modified_model = io.BytesIO() + net = Net.load(orig_model) + net.reset_batch_size(1) + net.dump(modified_model, optimize_for_inference=False) + modified_model.seek(0) + + net1 = Net.load(modified_model) + assert net1.data_providers_filter.as_unique().shape[0] == 1 + + +def test_modify_opr_name(): + @trace(symbolic=True, capture_as_const=True) + def f(x): + return F.exp(x) + + orig_model = io.BytesIO() + f(Tensor(np.random.random((3, 3, 224, 224)))) + f.dump(orig_model, arg_names=["a"], optimize_for_inference=False) + orig_model.seek(0) + + modified_model = io.BytesIO() + net = Net.load(orig_model) + net.modify_opr_names("net") + net.modify_opr_names(lambda x: "net1." + x) + net.dump(modified_model, optimize_for_inference=False) + modified_model.seek(0) + + net1 = Net.load(modified_model) + assert net1.data_providers_filter.as_unique().name == "net1.net.a" diff --git a/imperative/python/test/unit/utils/test_opr.py b/imperative/python/test/unit/utils/test_opr.py new file mode 100644 index 0000000000000000000000000000000000000000..008be7bdf3824f27da97711dd039473e5b960e6c --- /dev/null +++ b/imperative/python/test/unit/utils/test_opr.py @@ -0,0 +1,712 @@ +import io +import os +import platform + +import numpy as np +import pytest + +import megengine.core.tensor.dtype as dtype +import megengine.core.tensor.megbrain_graph as G +import megengine.functional as F +import megengine.module as M +import megengine.random as rand +from megengine.core._imperative_rt.core2 import apply +from megengine.core._wrap import Device +from megengine.core.ops import builtin +from megengine.device import is_cuda_available +from megengine.functional.external import tensorrt_runtime_opr +from megengine.jit.tracing import trace +from megengine.tensor import Tensor +from megengine.utils.comp_graph_tools import GraphInference +from megengine.utils.network import Network as Net + + +def check_pygraph_dump(trace_func, inp_data, expect_results): + orig_model = io.BytesIO() + inp_size = len(inp_data) + out_size = len(expect_results) + arg_names = ["arg_{}".format(i) for i in range(inp_size)] + output_names = ["out_{}".format(i) for i in range(out_size)] + trace_func.dump( + orig_model, + arg_names=arg_names, + output_names=output_names, + optimize_for_inference=False, + ) + orig_model.seek(0) + + net = Net.load(orig_model) + file = io.BytesIO() + net.dump(file, optimize_for_inference=False) + file.seek(0) + graph = GraphInference(file) + + inp_dict = dict([(arg_names[i], inp_data[i].numpy()) for i in range(inp_size)]) + results = graph.run(inp_dict=inp_dict) + + for ind, tensor in enumerate(expect_results): + np.testing.assert_equal(tensor.numpy(), results[output_names[ind]]) + assert tensor.dtype == results[output_names[ind]].dtype + + +def test_elemwise(): + @trace(symbolic=True, capture_as_const=True) + def fwd(x, y): + z1 = x * y + z2 = x + y + z3 = z1 / z2 + z3 = z3 ** 3 + return z3 + + x = Tensor([1.0, 2.0]) + y = Tensor([3.0, 5.0]) + result = fwd(x, y) + check_pygraph_dump(fwd, [x, y], [result]) + + +def test_reduce(): + @trace(symbolic=True, capture_as_const=True) + def fwd(data): + x = data.sum(axis=2) + x = x.mean(axis=1) + return x + + data = Tensor(np.random.random((1, 32, 32))) + result = fwd(data) + check_pygraph_dump(fwd, [data], [result]) + + +def test_typecvt(): + @trace(symbolic=True, capture_as_const=True) + def fwd(data): + return data.astype(dtype.qint8(0.8)) + + x = Tensor(np.random.random((2, 3)) * 255) + result = fwd(x) + check_pygraph_dump(fwd, [x], [result]) + + +def test_matinv(): + @trace(symbolic=True, capture_as_const=True) + def fwd(data): + return F.matinv(data) + + data = Tensor(np.random.random((5, 5))) + result = fwd(data) + check_pygraph_dump(fwd, [data], [result]) + + +def test_matmul(): + @trace(symbolic=True, capture_as_const=True) + def fwd(data1, data2): + return F.matmul(data1, data2) + + data1 = Tensor(np.random.random((32, 64))) + data2 = Tensor(np.random.random((64, 16))) + result = fwd(data1, data2) + check_pygraph_dump(fwd, [data1, data2], [result]) + + +def test_batchmatmul(): + @trace(symbolic=True, capture_as_const=True) + def fwd(x, y): + return F.matmul(x, y) + + x = Tensor(np.random.random((3, 3, 5))) + y = Tensor(np.random.random((3, 5, 3))) + result = fwd(x, y) + check_pygraph_dump(fwd, [x, y], [result]) + + +def test_dot(): + @trace(symbolic=True, capture_as_const=True) + def fwd(x, y): + return F.dot(x, y) + + x = Tensor([1.0, 2.0, 3.0]) + y = Tensor([3.0, 4.0, 5.0]) + result = fwd(x, y) + check_pygraph_dump(fwd, [x, y], [result]) + + +def test_svd(): + @trace(symbolic=True, capture_as_const=True) + def fwd(data): + _, out, _ = F.svd(data) + return out + + input = Tensor(np.random.random((1, 1, 3, 3))) + result = fwd(input) + check_pygraph_dump(fwd, [input], [result]) + + +def test_conv(): + conv = M.Conv2d(3, 32, 3) + + @trace(symbolic=True, capture_as_const=True) + def fwd(data): + return conv(data) + + data = Tensor(np.random.random((1, 3, 32, 32))) + result = fwd(data) + check_pygraph_dump(fwd, [data], [result]) + + +def test_deformable_conv(): + if not is_cuda_available(): + return + conv = M.DeformableConv2d(3, 32, 3) + + @trace(symbolic=True, capture_as_const=True) + def fwd(data, offset, mask): + return conv(data, offset, mask) + + data = Tensor(np.random.random((1, 3, 32, 32))) + offset = Tensor(np.ones((32, 3 * 3 * 2, 30, 30)).astype("int32") * 5) + mask = Tensor(np.ones((32, 3 * 3, 30, 30)).astype("int32")) + out = fwd(data, offset, mask) + check_pygraph_dump(fwd, [data, offset, mask], [out]) + + +def test_convtranspose(): + deconv = M.ConvTranspose2d(32, 32, 3) + + @trace(symbolic=True, capture_as_const=True) + def fwd(data): + return deconv(data) + + data = Tensor(np.random.random((1, 32, 32, 32))) + result = fwd(data) + check_pygraph_dump(fwd, [data], [result]) + + +@pytest.mark.skip(reason="pytest aborted") +def test_grouplocal(): + n = M.LocalConv2d(3, 32, 32, 32, 3) + + @trace(symbolic=True, capture_as_const=True) + def fwd(data): + return n(data) + + input = Tensor(np.random.random((1, 3, 32, 32))) + result = fwd(input) + check_pygraph_dump(fwd, [input], [result]) + + +def test_pooling(): + @trace(symbolic=True, capture_as_const=True) + def fwd(data): + out = F.max_pool2d(data, 2, 2) + out = F.avg_pool2d(out, 2, 2) + return out + + data = Tensor(np.random.random((1, 3, 64, 64))) + result = fwd(data) + check_pygraph_dump(fwd, [data], [result]) + + +def test_adaptivepooling(): + pool1 = M.AdaptiveMaxPool2d((2, 2)) + pool2 = M.AdaptiveAvgPool2d((2, 2)) + + @trace(symbolic=True, capture_as_const=True) + def fwd(data): + out = pool1(data) + out = pool2(out) + return out + + input = Tensor(np.random.random((1, 3, 32, 32))) + result = fwd(input) + check_pygraph_dump(fwd, [input], [result]) + + +def test_roipooling(): + inp = Tensor(np.random.random((1, 1, 128, 128))) + rois = Tensor(np.random.random((4, 5))) + + @trace(symbolic=True, capture_as_const=True) + def fwd(inp, rois): + return F.nn.roi_pooling(inp, rois, (2, 2), scale=2.0) + + output = fwd(inp, rois) + check_pygraph_dump(fwd, [inp, rois], [output]) + + +def test_deformable_ps_roi_pooling(): + inp = Tensor(np.random.random((1, 256, 64, 64)).astype("float32")) + rois = Tensor(np.random.random((1, 5)).astype("float32")) + trans = Tensor(np.random.random((24, 2, 7, 7)).astype("float32")) + + pooled_h = 7 + pooled_w = 7 + sample_per_part = 4 + no_trans = False + part_size = 7 + spatial_scale = 1.0 / 64 + trans_std = 0.1 + + @trace(symbolic=True, capture_as_const=True) + def fwd(inp, rois, trans): + y = F.deformable_psroi_pooling( + inp, + rois, + trans, + no_trans, + part_size, + pooled_h, + pooled_w, + sample_per_part, + spatial_scale, + trans_std, + ) + return y + + result = fwd(inp, rois, trans) + check_pygraph_dump(fwd, [inp, rois, trans], [result]) + + +def test_convbias(): + @trace(symbolic=True, capture_as_const=True) + def fwd(inp, weight, bias): + return F.quantized.conv_bias_activation( + inp, weight, bias, dtype=dtype.qint8(scale=1.0), nonlinear_mode="RELU" + ) + + inp = Tensor(np.random.random((1, 3, 64, 64)), dtype=dtype.qint8(scale=1.0)) + weight = Tensor(np.random.random((32, 3, 3, 3)), dtype=dtype.qint8(scale=1.0)) + bias = Tensor(np.random.random((1, 32, 1, 1)), dtype=dtype.qint32(scale=1.0)) + result = fwd(inp, weight, bias) + check_pygraph_dump(fwd, [inp, weight, bias], [result]) + + +def test_batch_convbias(): + if is_cuda_available(): + return + + @trace(symbolic=True, capture_as_const=True) + def fwd(inp, weight, bias): + return F.quantized.batch_conv_bias_activation( + inp, weight, bias, dtype=dtype.qint8(scale=1.0), nonlinear_mode="RELU" + ) + + inp = Tensor(np.random.random((1, 3, 64, 64)), dtype=dtype.qint8(scale=1.0)) + weight = Tensor(np.random.random((1, 32, 3, 3, 3)), dtype=dtype.qint8(scale=1.0)) + bias = Tensor(np.random.random((1, 32, 1, 1)), dtype=dtype.qint32(scale=1.0)) + result = fwd(inp, weight, bias) + check_pygraph_dump(fwd, [inp, weight, bias], [result]) + + +def test_batchnorm(): + bn = M.BatchNorm2d(32) + bn.eval() + + @trace(symbolic=True, capture_as_const=True) + def fwd(data): + return bn(data) + + data = Tensor(np.random.random((1, 32, 32, 32))) + result = fwd(data) + check_pygraph_dump(fwd, [data], [result]) + + +def test_roialign(): + inp = Tensor(np.random.randn(1, 1, 128, 128)) + rois = Tensor(np.random.random((4, 5))) + + @trace(symbolic=True, capture_as_const=True) + def fwd(inp, rois): + return F.nn.roi_align(inp, rois, (2, 2)) + + output = fwd(inp, rois) + check_pygraph_dump(fwd, [inp, rois], [output]) + + +def test_warpperspective(): + inp_shape = (1, 1, 4, 4) + x = Tensor(np.arange(16, dtype=np.float32).reshape(inp_shape)) + M_shape = (1, 3, 3) + # M defines a translation: dst(1, 1, h, w) = rst(1, 1, h+1, w+1) + M = Tensor( + np.array( + [[1.0, 0.0, 1.0], [0.0, 1.0, 1.0], [0.0, 0.0, 1.0]], dtype=np.float32 + ).reshape(M_shape) + ) + + @trace(symbolic=True, capture_as_const=True) + def fwd(x, M): + return F.warp_perspective(x, M, (2, 2)) + + result = fwd(x, M) + check_pygraph_dump(fwd, [x, M], [result]) + + +def test_warpaffine(): + inp_shape = (1, 3, 3, 3) + x = Tensor(np.arange(27, dtype=np.float32).reshape(inp_shape)) + weightv = Tensor([[[1.26666667, 0.6, -83.33333333], [-0.33333333, 1, 66.66666667]]]) + + @trace(symbolic=True, capture_as_const=True) + def fwd(x, weightv): + return F.warp_affine(x, weightv, (2, 2), border_mode="WRAP") + + outp = fwd(x, weightv) + check_pygraph_dump(fwd, [x, weightv], [outp]) + + +def test_remap(): + inp_shape = (1, 1, 4, 4) + inp = Tensor(np.arange(16, dtype=np.float32).reshape(inp_shape)) + map_xy_shape = (1, 2, 2, 2) + map_xy = Tensor( + np.array( + [[[1.0, 0.0], [0.0, 1.0]], [[0.0, 1.0], [0.0, 1.0]]], dtype=np.float32 + ).reshape(map_xy_shape) + ) + + @trace(symbolic=True, capture_as_const=True) + def fwd(inp, map_xy): + return F.remap(inp, map_xy) + + out = fwd(inp, map_xy) + check_pygraph_dump(fwd, [inp, map_xy], [out]) + + +def test_resize(): + x = Tensor(np.random.randn(10, 3, 32, 32)) + + @trace(symbolic=True, capture_as_const=True) + def fwd(x): + return F.nn.interpolate(x, size=(16, 16), mode="BILINEAR") + + out = fwd(x) + check_pygraph_dump(fwd, [x], [out]) + + +def test_index_onehot(): + src = Tensor([[1.0, 2.0]]) + index = Tensor([0]) + + @trace(symbolic=True, capture_as_const=True) + def fwd(src, index): + return F.indexing_one_hot(src, index) + + out = fwd(src, index) + check_pygraph_dump(fwd, [src, index], [out]) + + +def test_set_onehot(): + x = Tensor(np.arange(1, 4, dtype=np.int32)) + + @trace(symbolic=True, capture_as_const=True) + def fwd(x): + return F.one_hot(x, num_classes=4) + + out = fwd(x) + check_pygraph_dump(fwd, [x], [out]) + + +def test_copy(): + x = Tensor([1, 2, 3]) + + @trace(symbolic=True, capture_as_const=True) + def fwd(x): + return x.to("cpu0:0") + + o = fwd(x) + check_pygraph_dump(fwd, [x], [o]) + + +def test_argsort(): + @trace(symbolic=True, capture_as_const=True) + def fwd(data): + return F.argsort(data, True) + + data = Tensor([1.0, 2.0, 3.0, 5.0]) + result = fwd(data) + check_pygraph_dump(fwd, [data], [result]) + + +def test_argmax_min(): + @trace(symbolic=True, capture_as_const=True) + def fwd(data): + return F.argmax(data), F.argmin(data) + + data = Tensor(np.random.random((10, 10))) + result = fwd(data) + check_pygraph_dump(fwd, [data], result) + + +def test_condtake(): + mask = Tensor(np.array([[True, False], [False, True]], dtype=np.bool_)) + x = Tensor(np.array([[1, np.inf], [np.nan, 4]], dtype=np.float32)) + + @trace(symbolic=True, capture_as_const=True) + def fwd(mask, x): + v, index = F.cond_take(mask, x) + return v, index + + v, index = fwd(mask, x) + check_pygraph_dump(fwd, [mask, x], [v, index]) + + +def test_topk(): + x = Tensor(np.array([2, 4, 6, 8, 7, 5, 3, 1], dtype=np.float32)) + + @trace(symbolic=True, capture_as_const=True) + def fwd(x): + top, indices = F.topk(x, 5) + return top, indices + + top, indices = fwd(x) + check_pygraph_dump(fwd, [x], [top, indices]) + + + + +def test_random(): + @trace(symbolic=True, capture_as_const=True) + def fwd(): + x = rand.uniform(size=(2, 2)) + y = rand.normal(size=(1, 3, 3, 3)) + return x, y + + x, y = fwd() + check_pygraph_dump(fwd, [], [x, y]) + + +def test_tensor_gen(): + @trace(symbolic=True, capture_as_const=True) + def fwd(): + a = F.linspace(3, 10, 3, device=Device("xpux").to_c()) + b = F.eye(3, device=Device("xpux").to_c()) + return a, b + + a, b = fwd() + check_pygraph_dump(fwd, [], [a, b]) + + +def test_getvarshape(): + op = builtin.GetVarShape(axis=1) + + @trace(symbolic=True, capture_as_const=True) + def fwd(data): + return apply(op, data)[0] + + data = Tensor(np.random.random((1, 2, 3, 4))) + result = fwd(data) + check_pygraph_dump(fwd, [data], [result]) + + +def test_concat(): + @trace(symbolic=True, capture_as_const=True) + def fwd(data1, data2): + return F.concat([data1, data2], axis=1) + + x = Tensor(np.random.random((2, 3))) + y = Tensor(np.random.random((2, 5))) + result = fwd(x, y) + check_pygraph_dump(fwd, [x, y], [result]) + + +def test_broadcast(): + inp = Tensor([[1], [2], [3], [4]]) + + @trace(symbolic=True, capture_as_const=True) + def fwd(inp): + return F.broadcast_to(inp, (4, 4)) + + out = fwd(inp) + check_pygraph_dump(fwd, [inp], [out]) + + +def test_identity(): + @trace(symbolic=True, capture_as_const=True) + def fwd(data): + return F.copy(data) + + data = Tensor([1.0, 2.0]) + result = fwd(data) + check_pygraph_dump(fwd, [data], [result]) + + +@pytest.mark.skip(reason="advance indexing trace error") +def test_nms(): + x = np.zeros((100, 4)) + np.random.seed(42) + x[:, :2] = np.random.rand(100, 2) * 20 + x[:, 2:] = np.random.rand(100, 2) * 20 + 100 + scores = Tensor(np.random.rand(100)) + inp = Tensor(x) + + @trace(symbolic=True, capture_as_const=True) + def fwd(inp, scores): + return F.nn.nms(inp, scores, iou_thresh=0.7, max_output=3) + + result = fwd(inp, scores) + check_pygraph_dump(fwd, [inp, scores], [result]) + + +def test_dimshuffle(): + inp = Tensor([1, 2, 3, 4]) + + @trace(symbolic=True, capture_as_const=True) + def fwd(inp): + return inp.T + + out = fwd(inp) + check_pygraph_dump(fwd, [inp], [out]) + + +def test_reshape(): + @trace(symbolic=True, capture_as_const=True) + def fwd(data): + return data.reshape((1, 8)) + + data = Tensor(np.random.random((1, 2, 2, 2))) + result = fwd(data) + check_pygraph_dump(fwd, [data], [result]) + + +def test_add_remove_axis(): + @trace(symbolic=True, capture_as_const=True) + def fwd(data): + x = F.expand_dims(data, [0, 0]) + y = F.squeeze(x, 0) + return y + + data = Tensor([1.0, 2.0]) + result = fwd(data) + check_pygraph_dump(fwd, [data], [result]) + + +@pytest.mark.parametrize("mode", ["get", "set", "inc"]) +def test_subtensor(mode): + items = [[0, True, True, True, False], [1, False, False, False, True]] + data = [Tensor(np.random.random((5, 5))), Tensor(np.random.random(2))] + if mode == "get": + op = builtin.Subtensor(items) + data = data[:1] + if mode == "set": + op = builtin.SetSubtensor(items) + if mode == "inc": + op = builtin.IncrSubtensor(items) + tensors = [Tensor(0), Tensor(4), Tensor(2), Tensor(3)] + + @trace(symbolic=True, capture_as_const=True) + def fwd(*tensors): + return apply(op, *tensors)[0] + + result = fwd(*data, *tensors) + check_pygraph_dump(fwd, data + tensors, [result]) + + +@pytest.mark.parametrize("mode", ["get", "set", "inc"]) +def test_advance_indexing(mode): + items = [[0, False, False, False, True]] + tensors = [Tensor([0, 4, 2])] + data = [Tensor(np.random.random((5, 5))), Tensor(np.random.random((3, 5)))] + if mode == "get": + op = builtin.IndexingMultiAxisVec(items) + data = data[:1] + if mode == "set": + op = builtin.IndexingSetMultiAxisVec(items) + if mode == "inc": + op = builtin.IndexingIncrMultiAxisVec(items) + + @trace(symbolic=True, capture_as_const=True) + def fwd(*tensors): + return apply(op, *tensors)[0] + + result = fwd(*data, *tensors) + check_pygraph_dump(fwd, data + tensors, [result]) + + +@pytest.mark.parametrize("mode", ["get", "set", "inc"]) +def test_mesh_indexing(mode): + items = [[0, True, True, True, False], [1, False, False, False, True]] + tensors = [Tensor(0), Tensor(5), Tensor(2), Tensor([1, 3])] + data = [Tensor(np.random.random((5, 5))), Tensor(np.random.random((3, 2)))] + if mode == "get": + op = builtin.IndexingMultiAxisVec(items) + data = data[:1] + if mode == "set": + op = builtin.IndexingSetMultiAxisVec(items) + if mode == "inc": + op = builtin.IndexingIncrMultiAxisVec(items) + + @trace(symbolic=True, capture_as_const=True) + def fwd(*tensors): + return apply(op, *tensors)[0] + + result = fwd(*data, *tensors) + check_pygraph_dump(fwd, data + tensors, [result]) + + +@pytest.mark.parametrize("mode", ["get", "set", "inc"]) +def test_batch_mesh_indexing(mode): + items = [[1, False, False, False, True], [2, False, False, False, True]] + tensors = [Tensor([[0, 2], [0, 2]]), Tensor([[0, 1, 2], [1, 2, 3]])] + data = [Tensor(np.random.random((2, 3, 4))), Tensor(np.random.random((2, 2, 3)))] + if mode == "get": + op = builtin.BatchedMeshIndexing(items) + data = data[:1] + if mode == "set": + op = builtin.BatchedSetMeshIndexing(items) + if mode == "inc": + op = builtin.BatchedIncrMeshIndexing(items) + + @trace(symbolic=True, capture_as_const=True) + def fwd(*tensors): + return apply(op, *tensors)[0] + + result = fwd(*data, *tensors) + check_pygraph_dump(fwd, data + tensors, [result]) + + +@pytest.mark.skip(reason="tmp skip") +def test_assert_equal(): + g = G.Graph() + inp1 = g.make_h2d(dtype=np.float32, device="xpux") + inp2 = g.make_h2d(dtype=np.float32, device="xpux") + op = builtin.AssertEqual(maxerr=1e-5) + out = G.apply_normal_varnode(op, inp1._node, inp2._node)[0] + print(out) + g.compile(out) + file = io.BytesIO() + out_model = G.dump_graph([out]) + file.write(out_model[0]) + file.seek(0) + net = Net.load(file) + + dump_file = io.BytesIO() + net.dump(dump_file) + dump_file.seek(0) + g = GraphInference(dump_file) + g.run(np.array([1.0, 2.0]), np.array([1.0, 2.0])) + + +def test_elemwise_multitype(): + op = builtin.ElemwiseMultiType(mode="QADD", dtype=dtype.qint32(2.0)) + + @trace(symbolic=True, capture_as_const=True) + def fwd(x, y): + return apply(op, x, y)[0] + + x = Tensor(np.random.random(10) * 10, dtype=dtype.qint8(2.0)) + y = Tensor(np.random.random(10) * 10, dtype=dtype.qint8(2.0)) + result = fwd(x, y) + check_pygraph_dump(fwd, [x, y], [result]) + + + + +def test_cvtcolor(): + inp = np.random.randn(3, 3, 3, 3).astype(np.float32) + x = Tensor(inp) + + @trace(symbolic=True, capture_as_const=True) + def fwd(inp): + return F.img_proc.cvt_color(inp, mode="RGB2GRAY") + + result = fwd(x) + check_pygraph_dump(fwd, [x], [result]) diff --git a/src/plugin/impl/opr_footprint.cpp b/src/plugin/impl/opr_footprint.cpp index cabb61c966433eebf2947f25e04c5fa22c589c17..cc29f84be3e4efb4cb64ea9d544e92af69d6e26f 100644 --- a/src/plugin/impl/opr_footprint.cpp +++ b/src/plugin/impl/opr_footprint.cpp @@ -17,9 +17,20 @@ #include "megbrain/opr/dnn/local.h" #include "megbrain/opr/dnn/lrn.h" #include "megbrain/opr/dnn/pooling.h" +#include "megbrain/opr/dnn/adaptive_pooling.h" +#include "megbrain/opr/dnn/roi_pooling.h" +#include "megbrain/opr/dnn/roi_align.h" #include "megbrain/opr/imgproc.h" +#include "megbrain/opr/standalone/nms_opr.h" #include "megbrain/opr/io.h" #include "megbrain/opr/tensor_manip.h" +#include "megbrain/opr/rand.h" +#include "megbrain/opr/dnn/batch_norm.h" +#include "megbrain/opr/misc.h" +#include "megbrain/opr/indexing.h" +#include "megbrain/opr/internal/indexing_helper.h" +#include "megbrain/opr/nn_int.h" +#include "megbrain/opr/tensor_gen.h" #if MGB_ENABLE_JSON #include "megdnn/opr_param_json.h" #endif @@ -354,7 +365,7 @@ uint64_t opr_footprint_func( auto&& out_shape = opr->output()[0]->shape(); auto&& filter_shape = opr->input()[1]->shape(); using Param = opr::DeformableConvForward::Param; - auto&& param = opr->cast_final_safe().param(); + auto&& param = opr->cast_final_safe().param(); size_t fh, fw, icpg; mgb_assert(param.format == Param::Format::NCHW); if (param.sparse == Param::Sparse::GROUP) { @@ -425,9 +436,11 @@ uint64_t opr_footprint_func( auto&& filter_shape = opr->input()[1]->shape(); using Param = opr::BatchConvBiasForward::Param; auto&& param = opr->cast_final_safe().param(); - mgb_assert(param.format == Param::Format::NCHW4); - size_t packed_channels = 4; + size_t packed_channels = 1; size_t kern_spatial_pos = 3; + if (param.format == Param::Format::NCHW4) { + packed_channels = 4; + } size_t fh = filter_shape[kern_spatial_pos], fw = filter_shape[kern_spatial_pos + 1]; return out_shape.total_nr_elems() * fh * fw * src_shape[1] * @@ -508,7 +521,29 @@ REGISTE_PARAM_JSON_FUNC(LocalShareBackwardFilter) REGISTE_PARAM_JSON_FUNC(DeformableConvForward) REGISTE_PARAM_JSON_FUNC(DeformableConvBackwardFilter) REGISTE_PARAM_JSON_FUNC(DeformableConvBackwardData) +REGISTE_PARAM_JSON_FUNC(DeformablePSROIPoolingForward) REGISTE_PARAM_JSON_FUNC(BatchConvBiasForward) +REGISTE_PARAM_JSON_FUNC(BatchNormForward) +REGISTE_PARAM_JSON_FUNC(ElemwiseMultiType) +REGISTE_PARAM_JSON_FUNC(Argsort) +REGISTE_PARAM_JSON_FUNC(Argmax) +REGISTE_PARAM_JSON_FUNC(Argmin) +REGISTE_PARAM_JSON_FUNC(AdaptivePooling) +REGISTE_PARAM_JSON_FUNC(ROIPooling) +REGISTE_PARAM_JSON_FUNC(ROIAlign) +REGISTE_PARAM_JSON_FUNC(WarpPerspective) +REGISTE_PARAM_JSON_FUNC(WarpAffine) +REGISTE_PARAM_JSON_FUNC(Remap) +REGISTE_PARAM_JSON_FUNC(Resize) +REGISTE_PARAM_JSON_FUNC(IndexingOneHot) +REGISTE_PARAM_JSON_FUNC(IndexingSetOneHot) +REGISTE_PARAM_JSON_FUNC(TopK) +REGISTE_PARAM_JSON_FUNC(UniformRNG) +REGISTE_PARAM_JSON_FUNC(GaussianRNG) +REGISTE_PARAM_JSON_FUNC(Linspace) +REGISTE_PARAM_JSON_FUNC(Eye) +REGISTE_PARAM_JSON_FUNC(CvtColor) + template <> std::shared_ptr opr_param_json_func( @@ -547,24 +582,83 @@ std::shared_ptr opr_param_json_func( }); } +std::shared_ptr indexing_param_to_json( + const std::vector& indices) { + auto desc = json::Array::make(); + for (auto& index : indices) { + desc->add(json::Object::make({ + {"axis", json::NumberInt::make(index.axis.get_raw())}, + {"begin", + json::NumberInt::make(index.begin.node() != nullptr)}, + {"end", json::NumberInt::make(index.end.node() != nullptr)}, + {"step", + json::NumberInt::make(index.step.node() != nullptr)}, + {"idx", json::NumberInt::make(index.idx.node() != nullptr)}, + })); + } + return desc; +} + +#define REGISTE_INDEXING_PARAM_JSON_FUNC(cls) \ + template <> \ + std::shared_ptr opr_param_json_func( \ + cg::OperatorNodeBase * opr) { \ + auto indices = opr->cast_final_safe().index_desc(); \ + return indexing_param_to_json(indices); \ + } + +REGISTE_INDEXING_PARAM_JSON_FUNC(Subtensor); +REGISTE_INDEXING_PARAM_JSON_FUNC(SetSubtensor); +REGISTE_INDEXING_PARAM_JSON_FUNC(IncrSubtensor); +REGISTE_INDEXING_PARAM_JSON_FUNC(IndexingMultiAxisVec); +REGISTE_INDEXING_PARAM_JSON_FUNC(IndexingSetMultiAxisVec); +REGISTE_INDEXING_PARAM_JSON_FUNC(IndexingIncrMultiAxisVec); +REGISTE_INDEXING_PARAM_JSON_FUNC(MeshIndexing); +REGISTE_INDEXING_PARAM_JSON_FUNC(IncrMeshIndexing); +REGISTE_INDEXING_PARAM_JSON_FUNC(SetMeshIndexing); +REGISTE_INDEXING_PARAM_JSON_FUNC(BatchedMeshIndexing); +REGISTE_INDEXING_PARAM_JSON_FUNC(BatchedIncrMeshIndexing); +REGISTE_INDEXING_PARAM_JSON_FUNC(BatchedSetMeshIndexing); + template <> -std::shared_ptr opr_param_json_func( +std::shared_ptr opr_param_json_func( cg::OperatorNodeBase * opr) { auto desc = json::Array::make(); - auto indices = opr->cast_final_safe().index_desc(); - for (auto &index : indices){ - desc->add( - json::Object::make({ - {"axis", json::NumberInt::make(index.axis.get_raw())}, - {"begin", json::NumberInt::make(index.begin.node() != nullptr)}, - {"end", json::NumberInt::make(index.end.node() != nullptr)}, - {"step", json::NumberInt::make(index.step.node() != nullptr)}, - {"idx", json::NumberInt::make(index.idx.node() != nullptr)}, - })); + auto axis_param = opr->cast_final_safe().param(); + if (axis_param.axis != axis_param.MAX_NDIM){ + return json::Object::make({ + {"axis", json::NumberInt::make(axis_param.axis)}, + }); + } else { + return json::Object::make(); } + } - return desc; +template <> +std::shared_ptr opr_param_json_func( + cg::OperatorNodeBase * opr) { + auto desc = json::Array::make(); + auto axis_param = opr->cast_final_safe().param(); + if (axis_param.axis != axis_param.MAX_NDIM){ + return json::Object::make({ + {"axis", json::NumberInt::make(axis_param.axis)}, + }); + } else { + return json::Object::make(); + } } + +template <> +std::shared_ptr opr_param_json_func( + cg::OperatorNodeBase * opr) { + auto nms_param = opr->cast_final_safe().param(); + return json::Object::make({ + {"iou_thresh", json::Number::make(nms_param.iou_thresh)}, + {"max_output", json::Number::make(nms_param.max_output)}, + }); + } + + #endif // MGB_ENABLE_JSON } // namespace @@ -632,6 +726,17 @@ void OprFootprint::init_all_footprints() { add_single_param_json(); add_single_param_json(); add_single_param_json(); + add_single_param_json(); + add_single_param_json(); + add_single_param_json(); + add_single_param_json(); + add_single_param_json(); + add_single_param_json(); + add_single_param_json(); + add_single_param_json(); + add_single_param_json(); + add_single_param_json(); + add_single_param_json(); add_single_param_json(); add_single_param_json(); add_single_param_json(); @@ -639,7 +744,31 @@ void OprFootprint::init_all_footprints() { add_single_param_json(); add_single_param_json(); add_single_param_json(); + add_single_param_json(); add_single_param_json(); + add_single_param_json(); + add_single_param_json(); + add_single_param_json(); + add_single_param_json(); + add_single_param_json(); + add_single_param_json(); + add_single_param_json(); + add_single_param_json(); + add_single_param_json(); + add_single_param_json(); + add_single_param_json(); + add_single_param_json(); + add_single_param_json(); + add_single_param_json(); + add_single_param_json(); + add_single_param_json(); + add_single_param_json(); + add_single_param_json(); + add_single_param_json(); + add_single_param_json(); + add_single_param_json(); + add_single_param_json(); + add_single_param_json(); #endif }