# -*- coding: utf-8 -*-
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import collections
import collections.abc
import json
import threading
import weakref
from concurrent.futures import Future, ThreadPoolExecutor

import numpy as np

from .. import _imperative_rt
from .._imperative_rt import GraphOptimizeOptions
from .._imperative_rt.ops import BackwardGraph
from .._wrap import device as as_device
from ..ops.builtin import OpDef
from .core import OpBase, TensorBase, apply


class Graph(_imperative_rt.ComputingGraph):
    def __init__(self):
        super().__init__()
        self._var_cache = weakref.WeakKeyDictionary()
        self._op_cache = weakref.WeakKeyDictionary()
        self._executor = ThreadPoolExecutor(1)
        self._function = None
        self._future = None

    def _wrap(self, obj):
        if type(obj) is _imperative_rt.VarNode:
            wrapper, cache = VarNode, self._var_cache
        elif type(obj) is _imperative_rt.OperatorNode:
            wrapper, cache = OpNode, self._op_cache
        else:
            raise TypeError(type(obj))
        if obj not in cache:
            cache[obj] = wrapper(obj)
        return cache[obj]

    def compile(self, *args):
        self._function = super().compile(_unwrap(args))
        return self

    def execute(self, *args):
        assert self._future is None
        self._future = self._executor.submit(self._function.execute, *args)

    def wait(self):
        assert self._future is not None
        self._future.exception()
        self._function.wait()
        try:
            return self._future.result()
        finally:
            self._future = None

    def __call__(self, *args):
        self.execute(*args)
        return self.wait()

    def make_const(self, data, dtype=None, device=None):
        if isinstance(data, _imperative_rt.DeviceTensorND):
            assert dtype is None and device is None
            return self._wrap(_imperative_rt.make_shared(self, data))
        else:
            data = np.asarray(data, dtype=dtype)
            # implicitly narrow 64-bit host data to the 32-bit dtypes used by the graph
            if data.dtype == np.float64:
                data = data.astype(np.float32)
            elif data.dtype == np.int64:
                data = data.astype(np.int32)
            device = as_device(device).to_c()
            return self._wrap(_imperative_rt.make_const(self, data, device, dtype))

    def make_input(self, *args: "VarNode", device=None, dtype=None, shape=None):
        opnode = InputNode(*args, device=device, dtype=dtype, shape=shape, graph=self)
        return opnode.outputs[0]

    def make_h2d(self, *, dtype, device, shape=None, name=None):
        device = as_device(device).to_c()
        return self._wrap(_imperative_rt.make_h2d(self, device, dtype, shape, name))


def optimize_for_inference(dest_vars, **kwargs):
    r"""Applies optimize_for_inference pass for computing graph.

        :param dest_vars: list of output vars in the computing graph

        :Keyword Arguments:

            * enable_io16xc32 --
                whether to use float16 for I/O between oprs and use
                float32 as internal computation precision. Note that the output
                vars would be changed to float16.
            * enable_ioc16 --
                whether to use float16 for both I/O and computation
                precision.

            * enable_hwcd4 --
                whether to use NHWCD4 data layout. This is faster on some
                OpenCL backends.
            * enable_nchw88 --
                whether to use NCHW88 data layout, currently
                used in the x86 AVX backend.
            * enable_nchw44 --
                whether to use NCHW44 data layout, currently
                used in the ARM backend.
            * enable_nchw44_dot --
                whether to use NCHW44_DOT data layout, currently
                used in the ARMv8.2+dotprod backend.
            * enable_nchw4 --
                whether to use NCHW4 data layout, currently
                used in the NVIDIA backend (based on cuDNN).
            * enable_nchw32 --
                whether to use NCHW32 data layout, currently
                used in the NVIDIA backend with TensorCore (based on cuDNN).
            * enable_chwn4 --
                whether to use CHWN4 data layout, currently
                used in the NVIDIA backend with TensorCore.

            * enable_fuse_conv_bias_nonlinearity --
                whether to fuse conv+bias+nonlinearity into one opr.
            * enable_fuse_conv_bias_with_z --
                whether to fuse conv_bias with the z input for inference on the
                NVIDIA backend (this optimization pass will result in a precision
                mismatch between the outputs of training and inference).
    """
    inference_options = GraphOptimizeOptions()
    inference_optimize_layout_transform_map = {
        "enable_hwcd4": GraphOptimizeOptions.LayoutTransform.NHWCD4,
        "enable_nchw4": GraphOptimizeOptions.LayoutTransform.NCHW4,
        "enable_nchw88": GraphOptimizeOptions.LayoutTransform.NCHW88,
        "enable_nchw32": GraphOptimizeOptions.LayoutTransform.NCHW32,
        "enable_nchw44": GraphOptimizeOptions.LayoutTransform.NCHW44,
        "enable_nchw44_dot": GraphOptimizeOptions.LayoutTransform.NCHW44_DOT,
        "enable_chwn4": GraphOptimizeOptions.LayoutTransform.CHWN4,
    }

    for k, v in inference_optimize_layout_transform_map.items():
        if kwargs.pop(k, False):
            inference_options.layout_transform = v

    if kwargs.pop("enable_io16xc32", False):
        inference_options.f16_io_f32_comp = True
    if kwargs.pop("enable_ioc16", False):
        inference_options.f16_io_comp = True
    if kwargs.pop("enable_fuse_conv_bias_nonlinearity", False):
        inference_options.fuse_conv_bias_nonlinearity = True
    if kwargs.pop("enable_fuse_conv_bias_with_z", False):
        inference_options.fuse_conv_bias_with_z = True

    if kwargs:
        raise ValueError("unknown options: %s" % list(kwargs))

    res_vars = _imperative_rt.optimize_for_inference(
        [i._node for i in dest_vars], inference_options
    )
    return [VarNode(i) for i in res_vars]
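
# A minimal usage sketch for optimize_for_inference (hypothetical variable
# names; assumes ``net_outputs`` is a list of VarNode obtained from a traced
# network and that float16 I/O is desired):
#
#     optimized_outputs = optimize_for_inference(net_outputs, enable_io16xc32=True)
#     blob = dump_graph(*optimized_outputs)  # serialize the optimized graph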


def dump_graph(*args):
    return _imperative_rt.dump_graph([i._node for i in args])


CompGraphLoadResult = collections.namedtuple(
    "CompGraphLoadResult", ["graph", "output_vars_dict", "output_vars_list"]
)


def load_graph(fpath):
    """Load a serialized computing graph from file.

    :param fpath: path or handle of the input file
    :return: An instance of namedtuple :class:`CompGraphLoadResult`,
        whose fields are:

            * ``graph`` -- the loaded computing graph
            * ``output_vars_dict`` -- a Python dict, mapping names to output SymbolVars
            * ``output_vars_list`` -- a Python list, containing output vars in the
              order they were passed to :func:`dump_graph`
    """
    output_vars_map = []
    output_vars_list = []
    if isinstance(fpath, str):
        with open(fpath, "rb") as f:
            buf = f.read()
    else:
        buf = fpath.read()
    cg = _imperative_rt.load_graph(buf, output_vars_map, output_vars_list)
    return CompGraphLoadResult(cg, dict(output_vars_map), output_vars_list)
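
# A minimal usage sketch for load_graph (the file name and output name below
# are hypothetical):
#
#     graph, output_vars_dict, output_vars_list = load_graph("model.mge")
#     out_var = output_vars_dict["prediction"]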


class VarNode(TensorBase):
    def __init__(self, node: _imperative_rt.VarNode):
        self._node = node
        if hasattr(self.graph, "_var_cache"):
            self.graph._var_cache[node] = self

    @property
    def graph(self) -> Graph:
        return self._node.graph

    @property
    def op(self):
        if hasattr(self.graph, "_wrap"):
            return self.graph._wrap(self._node.owner)
        else:
            return self._node.owner

    @property
    def name(self):
        return self._node.name

    @property
    def id(self):
        return self._node.id

    @name.setter
    def name(self, name):
        self._node.name = name

    @property
    def dtype(self):
        return self._node.dtype

    @property
    def device(self):
        return as_device(self._node.comp_node)

    @property
    def shape(self):
        return self._node.shape

    @property
    def value(self):
        return self._node.value


class OpNode:
    def __init__(self, node: _imperative_rt.OperatorNode):
        self._node = node
        if hasattr(self.graph, "_op_cache"):
            self.graph._op_cache[node] = self

    @property
    def graph(self) -> Graph:
        return self._node.graph

    @property
    def name(self):
        return self._node.name

    @property
    def id(self):
        return self._node.id

    @name.setter
    def name(self, name):
        self._node.name = name

    @property
    def inputs(self):
        if hasattr(self.graph, "_wrap"):
            return tuple(map(self.graph._wrap, self._node.inputs))
        else:
            return self._node.inputs

    @property
    def outputs(self):
        if hasattr(self.graph, "_wrap"):
            return tuple(map(self.graph._wrap, self._node.outputs))
        else:
            return self._node.outputs

    @property
    def params(self):
        return json.loads(self._node.params)

    @property
    def type(self):
        return self._node.type


def _wrap(x):
    if isinstance(x, collections.abc.Sequence):
        return type(x)(map(_wrap, x))
    if hasattr(x.graph, "_wrap"):
        return x.graph._wrap(x)
    else:
        return x


def _unwrap(x):
    if isinstance(x, collections.abc.Sequence):
        return type(x)(map(_unwrap, x))
    if isinstance(x, VarNode):
        return x._node
    else:
        return x


@apply.register()
def _(op: OpDef, *args: VarNode):
    outputs = _imperative_rt.invoke_op(op, _unwrap(args))
    return _wrap(outputs)


@apply.register()
def _(op: BackwardGraph, *args: VarNode):
    assert args
    graph = args[0].graph
    return op.interpret(lambda op, args: apply(op, *args), graph.make_const, args)


def input_callback(callback, *args, device=None, dtype=None, shape=None, graph=None):
    outputs = _imperative_rt.input_callback(
        callback, as_device(device).to_c(), dtype, shape, _unwrap(args), graph=graph
    )
    value, dummy = _wrap(outputs)
    return value, dummy
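
# The node classes below bridge eager values and the lazy graph through
# Rendezvous objects: InputNode feeds a DeviceTensorND into the graph via
# ``set_value``, and the OutputNode variants fetch results (device values,
# host values, or tensor attributes) via ``get_value`` after execution.
# A rough, hypothetical sketch (``some_op`` stands for any OpDef and
# ``device_tensor`` for a DeviceTensorND; both are assumptions):
#
#     inp = InputNode(device="xpux", dtype="float32", graph=graph)
#     out = OutputNode(apply(some_op, inp.outputs[0])[0])
#     graph.compile(out.outputs[0])
#     inp.set_value(device_tensor)
#     graph()                      # execute and wait
#     result = out.get_value()     # a DeviceTensorND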


class InputNode(OpNode):
    def __init__(self, *args: VarNode, device=None, dtype=None, shape=None, graph=None):
        r = _imperative_rt.DeviceTensorNDRendezvous()
        if device is not None:
            device = as_device(device).to_c()
        outputs = _imperative_rt.input_callback(
            r, device, dtype, shape, _unwrap(args), graph=graph
        )
        super().__init__(outputs[0].owner)
        self._rendezvous = r

    def set_value(self, value):
        assert isinstance(value, _imperative_rt.DeviceTensorND)
        self._rendezvous.set(value)

    def reset(self):
        self._rendezvous.reset()

    @property
    def device(self):
        return self.outputs[0].device

    @property
    def dtype(self):
        return self.outputs[0].dtype


def output_callback(callback, var, *args):
    args = (var,) + args
    dummy = _imperative_rt.output_callback(callback, _unwrap(args))
    return _wrap(dummy)


class OutputNode(OpNode):
    def __init__(self, var, *args):
        args = (var,) + args
        r = _imperative_rt.DeviceTensorNDRendezvous()
        dummy = _imperative_rt.output_callback(r, _unwrap(args))
        super().__init__(dummy.owner)
        self._rendezvous = r

    def get_value(self):
        return self._rendezvous.get()

    def drop_value(self):
        self._rendezvous.drop()

    def reset(self):
        self._rendezvous.reset()


class ValueOutputNode(OpNode):
    def __init__(self, var, *args):
        args = (var,) + args
        r = _imperative_rt.HostTensorNDRendezvous()
        dummy = _imperative_rt.value_output_callback(r, _unwrap(args))
        super().__init__(dummy.owner)
        self._rendezvous = r

    def get_value(self):
        hostnd, event = self._rendezvous.get()
        event.wait()
        return hostnd.numpy()

    def drop_value(self):
        self._rendezvous.drop()

    def reset(self):
        self._rendezvous.reset()


class TensorAttr:
    def __init__(self, shape, dtype, device):
        self.shape = shape
        self.dtype = dtype
        self.device = device


class AttrOutputNode(OpNode):
    def __init__(self, var, *args):
        args = (var,) + args
        r = _imperative_rt.TensorAttrRendezvous()
        dummy = _imperative_rt.attr_output_callback(r, _unwrap(args))
        super().__init__(dummy.owner)
        self._rendezvous = r

    def get_value(self):
        attr = self._rendezvous.get()
        return TensorAttr(attr.shape, attr.dtype, as_device(attr.comp_node))

    def drop_value(self):
        self._rendezvous.drop()

    def reset(self):
        self._rendezvous.reset()