# -*- coding: utf-8 -*-
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import functools
import heapq
import itertools
import typing
import weakref

import numpy as np

from ..ops.builtin import Elemwise, OpDef
from ..ops.special import Const
from ..tensor.core import TensorBase, TensorWrapperBase, apply
from ..tensor.function import Function
from ..tensor.tensor import Tensor, get_context
from . import builtin_op_utils

""" Some notes:
1. Initialize the optimizer:
    for each trainable parameter:
        call wrt(param, callback)
    Each parameter tensor will be associated with a Tracer object saved in Tensor._extra_data
2. Tracer has one member: node, which is a VariableNode
3. VariableNode has an OpNode member: opnode
4. OpNode has six members:
    a. id
    b. inputs, a tuple of VariableNode
    c. outputs, which are weakrefs to VariableNode
    d. backward: callback function that computes the input gradients
    e. has_grad_fn: call has_grad_fn(opnode, reached) to check whether a grad exists
    f. backward_allow_noinput: whether backward may be called with all output grads being None
"""

_grad_count = 0
_grad_manager_dict = weakref.WeakValueDictionary()


def get_grad_managers():
    return [_grad_manager_dict[key] for key in _grad_manager_dict]


def add(a, b):
    (c,) = apply(Elemwise(Elemwise.Mode.ADD), a, b)
    return c


def get_tensor(x):
    # use recursion to avoid infinite loop
    if isinstance(x, Tensor):
        return x
    try:
        x = x.__wrapped__
    except AttributeError:
        raise TypeError(type(x))
    return get_tensor(x)


class Grad:
    def __init__(self, name=None):
        if name is None:
            global _grad_count
            self._name = "grad_" + str(_grad_count)
            _grad_count += 1
        else:
            self._name = name
        assert self._name not in _grad_manager_dict, "grad manager name duplicated"
        _grad_manager_dict[self._name] = self

        # list of all x in partial(y) / partial(x)
        self.xs = []

        # contains weak references of all OpNodes created during forward;
        # an OpNode holds its inputs, outputs and backward function,
        # so ops forms the computational graph
        self.ops = []

        self._attached_tensors = weakref.WeakSet()
        self._enabled = True

    @property
    def name(self):
        return self._name

    def wrt(self, *args: Tensor, callback=None):
        """
        Indicates the loss is a function of the input tensors (usually the net
        trainable parameters), i.e., d(loss) / d(Tensor) != 0.

        callback is used to perform additional operations after the gradient is
        obtained in backward, e.g., copy the grad to a particular place.

        A VariableNode will be created and saved (wrapped in a Tracer) in the
        tensor's _extra_data slot.
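
        A minimal usage sketch (param, loss and dloss are placeholder tensors,
        not names defined in this module)::

            grad = Grad()
            grad.wrt(param, callback=lambda owner, g: print(g))
            # ... forward computation producing loss ...
            grad(loss, dloss)  # backward; callback receives d(loss)/d(param)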
""" for x in map(get_tensor, args): v = self._new_variable(x, callback=callback) assert self not in x._extra_data x._extra_data[self] = Tracer(v) self.xs.append(v) return self def _new_variable(self, owner, opnode=None, callback=None): self._attached_tensors.add(owner) return VariableNode(self, owner, opnode=opnode, callback=callback) def _new_opnode(self, inputs, outputs): inputs = tuple(inputs) for i in inputs: assert i is None or isinstance(i, VariableNode) o = OpNode() o.inputs = inputs o.outputs = [] tracers = [] for i in outputs: assert isinstance(i, Tensor) v = self._new_variable(i, o) o.outputs.append(weakref.ref(v)) tracers.append(Tracer(v)) self.ops.append(weakref.ref(o)) return o, tracers def copy(self): raise NotImplementedError def __enter__(self): return self def _exit(self): """clear all resources""" self._enabled = False for o in self.ops: o = o() if o: o.clear() for i in self._attached_tensors: i._extra_data.pop(self, None) def __exit__(self, *_): self._exit() def __call__(self, ys, dys): """ Defines Grad(). :param ys: outputs of forward operators, e.g., the loss tensor :type ys: list of Tensor or TensorWrapperBase :param dys: delta of outputs, physically equivalent to sensitivity of outputs to the loss, e.g., one for the loss itself :type dys: list of Tensor or TensorWrapperBase """ assert self._enabled self._enabled = False def check_wrapper(): if isinstance(dys, TensorWrapperBase): return type(dys) if isinstance(dys, TensorBase): return assert isinstance(dys, (tuple, list)) for i in dys: if isinstance(i, TensorWrapperBase): return type(i) Wrapper = check_wrapper() def aslist(x): if isinstance(x, (Tensor, TensorWrapperBase)): x = [x] else: x = list(x) x = [i.__wrapped__ if isinstance(i, TensorWrapperBase) else i for i in x] for i in x: assert isinstance(i, Tensor) return x ys = aslist(ys) dys = aslist(dys) assert len(ys) == len(dys) ids = [i for i, y in enumerate(ys) if self in y._extra_data.keys()] ys = [y for i, y in enumerate(ys) if i in ids] dys = [dy for i, dy in enumerate(dys) if i in ids] # ys is changed to a list of VariableNode which contains more information # such as OpNode, callback, etc. 
        ys = [i._extra_data[self].node for i in ys]

        # NOTE: callback is called only if grad is not None

        # the OpNode sequence in backward
        op_seq = []

        # VariableNode -> (i, j), where i is the time stamp in backward and j
        # means the jth input
        last_written_to = {}

        def schedule():
            reached = set(ys)
            # i is the time stamp in backward
            i = 0
            for o in self.ops[::-1]:
                o = o()
                if o is None:
                    continue
                if not o.has_grad_fn(o, reached):
                    continue
                op_seq.append(o)
                for j, v in enumerate(o.inputs):
                    reached.add(v)
                    last_written_to[v] = i, j
                i += 1

        schedule()

        # VariableNode -> Tensor
        cache = {}

        def initialize():
            for y, dy in zip(ys, dys):
                cache[y] = dy
                if y not in last_written_to and y.callback:
                    y.callback(y.owner(), dy)

        initialize()

        # NOTE: None is used to mark that a node has been consumed

        for seqno, opnode in enumerate(op_seq):
            input_nodes = opnode.inputs
            output_nodes = [i() for i in opnode.outputs]
            backward = opnode.backward
            backward_allow_noinput = opnode.backward_allow_noinput
            opnode.clear()

            output_grads = []
            for i in output_nodes:
                if i is not None:
                    if i in cache:
                        assert cache[i] is not None
                        output_grads.append(cache[i])
                    else:
                        output_grads.append(None)
                    # read by backward, mark consumed
                    cache[i] = None
                else:
                    output_grads.append(None)
            if (
                any([grad is not None for grad in output_grads])
                or backward_allow_noinput
            ):
                input_grads = backward(*output_grads)
            else:
                input_grads = [None] * len(input_nodes)

            assert len(input_nodes) == len(input_grads)
            for i, (v, g) in enumerate(zip(input_nodes, input_grads)):
                if v is None:
                    continue
                if v in cache:
                    assert cache[v]
                    if g is not None:
                        cache[v] = add(cache[v], g)
                elif g is not None:
                    cache[v] = g
                if last_written_to[v] == (seqno, i):
                    if v.callback:
                        v.callback(
                            v.owner(), Wrapper(cache[v]) if Wrapper else cache[v]
                        )
                    if v.opnode is None:
                        # won't be read by backward, mark consumed
                        cache[v] = None

        for v in cache.values():
            assert v is None

        self._exit()

    def __del__(self):
        self._exit()


class clearable:
    __cleared = False

    def __bool__(self):
        return not self.__cleared

    def clear(self):
        self.__dict__.clear()
        self.__cleared = True


class OpNode(clearable):
    """
    OpNode saves all the information to form the computational graph.
    """

    def __init__(self):
        self.id = None
        self.inputs = None  # tuple of VariableNode (or None)
        self.outputs = None  # list of weakrefs to VariableNode
        self.backward = None
        self.has_grad_fn = None
        self.backward_allow_noinput = False


class VariableNode(clearable):
    """
    VariableNode saves OpNode and callback.
    manager is the Grad instance this node belongs to and owner is the Tensor
    it tracks; both are held by weak reference.
    """

    def __init__(self, manager, owner, opnode=None, callback=None):
        # manager is Grad type
        self.manager = weakref.ref(manager)
        # owner is Tensor type
        self.owner = weakref.ref(owner)
        self.opnode = opnode
        self.callback = callback


class Tracer(clearable, TensorBase):
    def __init__(self, node=None):
        """
        type(node) is VariableNode
        """
        self.node = node


@functools.singledispatch
def check_backward_allow_noinput(op: OpDef):
    return False


@functools.singledispatch
def get_op_has_grad_fn(op: OpDef):
    assert 0


@get_op_has_grad_fn.register(OpDef)
def _(op: OpDef):
    return default_has_grad_fn


@get_op_has_grad_fn.register(Function)
def _(op: Function):
    return default_has_grad_fn


def default_has_grad_fn(opnode, reached):
    for v in opnode.outputs:
        if v() in reached:
            return True
    return False


@apply.register()
def tracer_apply(op: (OpDef, Function), *args: typing.Optional[Tracer]):
    args = tuple(i if isinstance(i, Tracer) else None for i in args)
    input_requires_grad = list(map(bool, args))
    if not any(input_requires_grad):
        return

    ctx = get_context()
    manager = None
    assert len(ctx.inputs) == len(args)
    for i, j in zip(ctx.inputs, args):
        if j:
            j = j.node
            assert i is j.owner()
            if manager is None:
                manager = j.manager()
                assert manager
            else:
                assert manager is j.manager()

    if not manager._enabled:
        return

    opnode, outputs = manager._new_opnode([i and i.node for i in args], ctx.outputs)

    # register backward method
    # tuple of backward functions corresponding to dy / dx_i
    # None means y is not a function of x_i
    opnode.backward, output_need_grad = builtin_op_utils.builtin_op_get_backward_fn(
        op, ctx.inputs, ctx.outputs, input_requires_grad
    )
    assert len(outputs) == len(output_need_grad)
    outputs = [x if y else None for (x, y) in zip(outputs, output_need_grad)]

    opnode.backward_allow_noinput = check_backward_allow_noinput(op)

    opnode.has_grad_fn = get_op_has_grad_fn(op)

    return tuple(outputs)


@apply.register()
def _(op: Const, *_: typing.Optional[Tracer]):
    return None
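

# Rough picture of how the pieces above fit together (w, y, dy and accum are
# illustrative placeholders, not names defined in this module):
#
#   1. Grad().wrt(w, callback=accum) stores a Tracer (holding a VariableNode)
#      in w._extra_data, keyed by the Grad instance.
#   2. During the forward computation, tracer_apply() is invoked with the
#      inputs' Tracers and records an OpNode (inputs, outputs, backward) in
#      Grad.ops.
#   3. grad(y, dy) replays the recorded OpNodes in reverse order and finally
#      calls accum(w, dw) with the accumulated gradient of w.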