diff --git a/imperative/python/megengine/core/autodiff/builtin_op_utils.py b/imperative/python/megengine/core/autodiff/builtin_op_utils.py
deleted file mode 100644
index 9ae1e9f6998ce38f5944bbc45978b40819050291..0000000000000000000000000000000000000000
--- a/imperative/python/megengine/core/autodiff/builtin_op_utils.py
+++ /dev/null
@@ -1,173 +0,0 @@
-# -*- coding: utf-8 -*-
-# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
-#
-# Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-import functools
-import itertools
-
-import numpy as np
-
-from .._imperative_rt import TensorAttr, imperative
-from .._imperative_rt.core2 import apply
-from ..ops.builtin import (
-    Broadcast,
-    Elemwise,
-    GetVarShape,
-    IndexingMultiAxisVec,
-    IndexingSetMultiAxisVec,
-    OpDef,
-    Reduce,
-    Reshape,
-    SetSubtensor,
-    Subtensor,
-)
-from ..ops.special import Const
-
-
-def default_grad_fn(op, inputs, outputs, input_requires_grad):
-    def get_tensor_attr(x):
-        attr = TensorAttr()
-        attr.dtype = x.dtype
-        attr.comp_node = x.device.to_c()
-        return attr
-
-    output_has_grads = [True,] * len(outputs)
-    result = imperative.make_backward_graph(
-        op, list(map(get_tensor_attr, inputs)), input_requires_grad, output_has_grads
-    )
-    if result is None:
-        nr_inputs = len(inputs)
-        nr_outputs = len(outputs)
-
-        def backward(*args):
-            return nr_inputs * [
-                None,
-            ]
-
-        return backward, nr_outputs * [False,]
-    backward_graph, save_for_backward_mask, input_has_grad = result
-
-    intput_output_mask = save_for_backward_mask[: len(inputs + outputs) :]
-    output_grad_mask = save_for_backward_mask[len(inputs + outputs) :]
-    save_for_backward = tuple(
-        val for val, mask in zip(inputs + outputs, intput_output_mask) if mask
-    )
-
-    del inputs
-    del outputs
-
-    def backward(*args):
-        output_grads = tuple(val for val, mask in zip(args, output_grad_mask) if mask)
-        assert None not in output_grads
-        ret = iter(apply(backward_graph, *(save_for_backward + output_grads)))
-        return tuple(next(ret) if mask else None for mask in input_has_grad)
-
-    return backward, output_grad_mask
-
-
-def get_shape(x):
-    (s,) = apply(GetVarShape(), x._data)
-    return Tensor(s)
-
-
-# override for Elemwise.add
-def elemwise_add_grad_fn(op, inputs, outputs, input_requires_grad):
-    assert len(inputs) == len(input_requires_grad) == 2
-
-    input_shapes = [
-        get_shape(x) if i else None for i, x in zip(input_requires_grad, inputs)
-    ]
-
-    def reduce_to(x, s):
-        (y,) = apply(Reduce(), x, s)
-        return y
-
-    def backward(dy):
-        return tuple(
-            reduce_to(dy, s) if i else None
-            for i, s in zip(input_requires_grad, input_shapes)
-        )
-
-    return backward, [True]
-
-
-# override for Reshape
-def reshape_grad_fn(op, inputs, outputs, input_requires_grad):
-    assert len(inputs) == len(input_requires_grad) == 2
-
-    input_shapes = [
-        get_shape(x) if i else None for i, x in zip(input_requires_grad, inputs)
-    ]
-
-    def reshape_to(dy, s):
-        (dx,) = apply(Reshape(), dy, s)
-        return dx
-
-    def backward(dy):
-        return tuple(
-            reshape_to(dy, s) if i else None
-            for i, s in zip(input_requires_grad, input_shapes)
-        )
-
-    return backward, [True]
-
-
-# override for Subtensor
-def subtensor_grad_fn(op, inputs, outputs, input_requires_grad):
-    grad_op = SetSubtensor(op.items)
-
-    input_shape = get_shape(inputs[0])
-    params = inputs[1:]
-
-    def make_grad(grad_op, dy):
-        (_z,) = Const(0, dtype=dy.dtype, device=dy.device)(dy)
-        (grad,) = apply(Broadcast(), _z, input_shape)
-        (dx,) = apply(grad_op, grad, dy, *params)
-        return dx
-
-    def backward(dy):
-        return tuple(
-            make_grad(grad_op, dy) if mask else None for mask in input_requires_grad
-        )
-
-    return backward, [True]
-
-
-# override for IndexingMultiAxisVec
-def indexingMultiAxisVec_grad_fn(op, inputs, outputs, input_requires_grad):
-    grad_op = IndexingSetMultiAxisVec(op.items)
-
-    input_shape = get_shape(inputs[0])
-    params = inputs[1:]
-
-    def make_grad(grad_op, dy):
-        (_z,) = Const(0, dtype=dy.dtype, device=dy.device)(dy)
-        (grad,) = apply(Broadcast(), _z, input_shape)
-        (dx,) = apply(grad_op, grad, dy, *params)
-        return dx
-
-    def backward(dy):
-        return tuple(
-            make_grad(grad_op, dy) if mask else None for mask in input_requires_grad
-        )
-
-    return backward, [True]
-
-
-# override for Reduce.sum
-def reduce_sum_grad_fn(op, inputs, outputs, input_requires_grad):
-    assert len(inputs) == len(input_requires_grad) == 1
-    input_shape = get_shape(inputs[0])
-
-    def broadcast_to(dy, s):
-        (dx,) = apply(Broadcast(), dy, s)
-        return dx
-
-    def backward(dy):
-        return (broadcast_to(dy, input_shape) if input_requires_grad[0] else None,)
-
-    return backward, [True]
diff --git a/imperative/python/megengine/core/autodiff/grad.py b/imperative/python/megengine/core/autodiff/grad.py
index bf5ab0da92c00c36eb874eeb6d64b17a080f905c..5f887780a11316e0ae9aca7eb8e06616fe885748 100644
--- a/imperative/python/megengine/core/autodiff/grad.py
+++ b/imperative/python/megengine/core/autodiff/grad.py
@@ -19,7 +19,6 @@ import megengine as mge
 from .._imperative_rt import core2, ops
 from ..ops.builtin import Elemwise, OpDef, RemoteSend
 from ..ops.special import Const
-from . import builtin_op_utils
 
 """ Some notes:
 1. Initialize the optimizer:
diff --git a/imperative/python/src/grad_override.cpp b/imperative/python/src/grad_override.cpp
index 46fa7269a8c6844131323ba91748cb9d2ea5eb28..9691290d3f82080216f80c86e059cc646ca13976 100644
--- a/imperative/python/src/grad_override.cpp
+++ b/imperative/python/src/grad_override.cpp
@@ -25,6 +25,25 @@ std::shared_ptr<Tensor> reduce_to(Tensor* x, Tensor* s) {
     return python::apply(op, x, s)[0];
 }
 
+std::shared_ptr<Tensor> reshape_to(Tensor* x, Tensor* s) {
+    static auto op = Reshape::make();
+    return python::apply(op, x, s)[0];
+}
+
+std::shared_ptr<Tensor> broadcast_to(Tensor* x, Tensor* s) {
+    static auto op = Broadcast::make();
+    return python::apply(op, x, s)[0];
+}
+
+std::shared_ptr<Tensor> make_tensor(CompNode cn, Tensor* shape, float v = 0) {
+    HostTensorND scalar{cn, {{1}, dtype::Float32()}};
+    scalar.ptr<float>()[0] = v;
+    interpreter::Interpreter::Handle handle = interpreter_for_py->put(scalar);
+    auto&& t = std::make_shared<Tensor>(handle);
+    auto&& res = broadcast_to(t.get(), shape);
+    return res;
+}
+
 apply_result_t elemwise_grad_rule(ApplyContext& ctx, CustomBackward::Maker& maker) {
     auto& op = ctx.op->cast_final_safe<Elemwise>();
     if (op.mode == Elemwise::Mode::ADD) {
@@ -52,10 +71,138 @@ apply_result_t elemwise_grad_rule(ApplyContext& ctx, CustomBackward::Maker& make
     throw GradRuleFallback();
 }
 
+apply_result_t reshape_grad_rule(ApplyContext& ctx, CustomBackward::Maker& maker) {
+    mgb_assert(ctx.nargs == 2);
+    std::array<std::shared_ptr<Tensor>, 2> input_shapes;
+    for (size_t i = 0; i < 2; ++i) {
+        if (input_requires_grad(ctx, i)) {
+            input_shapes[i] = get_shape(ctx.args[i]);
+        }
+    }
+    maker.output_size(1).output_captured(0, false);
+    maker.backward([shapes=std::move(input_shapes)](BackwardContext&, Tensor*const* grads, size_t ngrads) {
+        mgb_assert(ngrads == 1);
+        Tensor* grad = grads[0];
+        apply_result_t ret(2);
+        for (size_t i = 0; i < 2; ++i) {
+            if (shapes[i]) {
+                ret[i] = reshape_to(grad, shapes[i].get());
+            }
+        }
+        return ret;
+    });
+    return apply(ctx);
+}
+
+apply_result_t subtensor_grad_rule(ApplyContext& ctx, CustomBackward::Maker& maker) {
+    auto&& op = ctx.op->cast_final_safe<Subtensor>();
+    auto&& grad_op = SetSubtensor::make(op.items);
+    SmallVector<std::shared_ptr<Tensor>> inputs;
+    if (input_requires_grad(ctx, 0)) {
+        inputs.push_back(get_shape(ctx.args[0]));
+        for (size_t i = 1; i < ctx.nargs; ++i) {
+            inputs.push_back(ctx.args[i]->copy());
+        }
+    }
+    maker.output_size(1).output_captured(0, false);
+    maker.backward([inputs=std::move(inputs), grad_op_=std::move(grad_op)](BackwardContext&, Tensor*const* grads, size_t ngrads) {
+        mgb_assert(ngrads == 1);
+        apply_result_t ret(1);
+        if (inputs[0]) {
+            SmallVector<Tensor*> args_(inputs.size()+1);
+            Tensor* grad = grads[0];
+            auto&& zeros = make_tensor(grad->comp_node(), inputs[0].get());
+            args_[0] = zeros.get();
+            args_[1] = grad;
+            for (size_t i = 1; i < inputs.size(); ++i) {
+                args_[i+1] = inputs[i].get();
+            }
+            ret[0] = python::apply(grad_op_, args_)[0];
+        }
+        return ret;
+    });
+    return apply(ctx);
+}
+
+apply_result_t indexingMultiAxisVec_grad_rule(ApplyContext& ctx, CustomBackward::Maker& maker) {
+    auto&& op = ctx.op->cast_final_safe<IndexingMultiAxisVec>();
+    auto&& grad_op = IndexingSetMultiAxisVec::make(op.items);
+    SmallVector<std::shared_ptr<Tensor>> inputs;
+    if (input_requires_grad(ctx, 0)) {
+        inputs.push_back(get_shape(ctx.args[0]));
+        for (size_t i = 1; i < ctx.nargs; ++i) {
+            inputs.push_back(ctx.args[i]->copy());
+        }
+    }
+    maker.output_size(1).output_captured(0, false);
+    maker.backward([inputs=std::move(inputs), grad_op_=std::move(grad_op)](BackwardContext&, Tensor*const* grads, size_t ngrads) {
+        mgb_assert(ngrads == 1);
+        apply_result_t ret(1);
+        if (inputs[0]) {
+            SmallVector<Tensor*> args_(inputs.size()+1);
+            Tensor* grad = grads[0];
+            auto&& zeros = make_tensor(grad->comp_node(), inputs[0].get());
+            args_[0] = zeros.get();
+            args_[1] = grad;
+            for (size_t i = 1; i < inputs.size(); ++i) {
+                args_[i+1] = inputs[i].get();
+            }
+            ret[0] = python::apply(grad_op_, args_)[0];
+        }
+        return ret;
+    });
+    return apply(ctx);
+}
+
+apply_result_t reduce_grad_rule(ApplyContext& ctx, CustomBackward::Maker& maker) {
+    auto& op = ctx.op->cast_final_safe<Reduce>();
+    if (op.mode == Reduce::Mode::SUM) {
+        mgb_assert(ctx.nargs == 1);
+        std::array<std::shared_ptr<Tensor>, 1> input_shapes;
+        if (input_requires_grad(ctx, 0)) {
+            input_shapes[0] = get_shape(ctx.args[0]);
+        }
+        maker.output_size(1).output_captured(0, false);
+        maker.backward([shapes=std::move(input_shapes)](BackwardContext&, Tensor*const* grads, size_t ngrads) {
+            mgb_assert(ngrads == 1);
+            Tensor* grad = grads[0];
+            apply_result_t ret(1);
+            if (shapes[0]) {
+                ret[0] = broadcast_to(grad, shapes[0].get());
+            }
+            return ret;
+        });
+        return apply(ctx);
+    }
+    throw GradRuleFallback();
+}
+
+template<typename T, typename U>
+apply_result_t axisAddRemove_grad_rule(ApplyContext& ctx, CustomBackward::Maker& maker) {
+    auto&& op = ctx.op->cast_final_safe<T>();
+    mgb_assert(ctx.nargs == 1);
+    auto&& grad_op = U::make(op.axis);
+    maker.output_size(1).output_captured(0, false);
+    maker.backward([grad_op_=std::move(grad_op)](BackwardContext&, Tensor*const* grads, size_t ngrads) {
+        mgb_assert(ngrads == 1);
+        Tensor* grad = grads[0];
+        apply_result_t ret(1);
+        ret[0] = python::apply(grad_op_, grad)[0];
+        return ret;
+    });
+    return apply(ctx);
+}
+
 struct Init {
     Init() {
         auto& reg = grad_rule_registry();
         reg.emplace(Elemwise::typeinfo(), elemwise_grad_rule);
+        reg.emplace(Reshape::typeinfo(), reshape_grad_rule);
+        reg.emplace(Subtensor::typeinfo(), subtensor_grad_rule);
+        reg.emplace(IndexingMultiAxisVec::typeinfo(), indexingMultiAxisVec_grad_rule);
+        reg.emplace(Reduce::typeinfo(), reduce_grad_rule);
+        reg.emplace(AddAxis::typeinfo(), axisAddRemove_grad_rule<AddAxis, RemoveAxis>);
+        reg.emplace(RemoveAxis::typeinfo(), axisAddRemove_grad_rule<RemoveAxis, AddAxis>);
     }
 } _;
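
Note (not part of the patch): a minimal sanity-check sketch of the code paths this change affects, assuming the public MegEngine 1.x Python API (megengine.tensor, megengine.functional, megengine.autodiff.GradManager). It differentiates through a reshape, a slice (Subtensor) and a sum reduction, which are now backed by the C++ grad rules registered in grad_override.cpp instead of the deleted builtin_op_utils.py overrides.

    # hypothetical usage sketch, not part of this change
    import numpy as np
    import megengine as mge
    import megengine.functional as F
    from megengine.autodiff import GradManager

    x = mge.tensor(np.arange(6, dtype="float32").reshape(2, 3))
    gm = GradManager()
    gm.attach([x])
    with gm:
        # reshape -> subtensor (slice) -> reduce(SUM): each step hits a grad rule above
        y = F.sum(x.reshape(3, 2)[1:])
        gm.backward(y)   # y is a scalar, so no explicit output gradient is needed
    print(x.grad)        # expected: [[0. 0. 1.] [1. 1. 1.]]

The zero/one pattern of x.grad comes from SetSubtensor writing the incoming gradient into a zero tensor of the input shape (see make_tensor/broadcast_to in subtensor_grad_rule), followed by the reshape rule mapping it back to the original shape.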