grad_override.cpp 1.9 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63
/**
 * \file imperative/python/src/grad_override.cpp
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 */

#include "./grad.h"
#include "megbrain/imperative/ops/autogen.h"

namespace mgb::imperative::python {
namespace {

/**
 * Builds a tensor describing the shape of \p x by applying a GetVarShape op.
 *
 * The op instance is created lazily on the first call (function-local
 * static) and reused afterwards; only the first output of the apply is
 * returned.
 */
std::shared_ptr<Tensor> get_shape(Tensor* x) {
    static auto shape_op = GetVarShape::make();
    auto outputs = python::apply(shape_op, x);
    return outputs[0];
}

/**
 * Applies a default-configured Reduce op to the pair (\p x, \p s) and
 * returns its first output.
 *
 * NOTE(review): \p s is presumably a shape tensor (e.g. from get_shape) and
 * the default Reduce mode presumably sums \p x down to that shape — the
 * default mode is not visible in this file; confirm.
 *
 * The op instance is created lazily on the first call (function-local
 * static) and reused afterwards.
 */
std::shared_ptr<Tensor> reduce_to(Tensor* x, Tensor* s) {
    static auto reduce_op = Reduce::make();
    auto outputs = python::apply(reduce_op, x, s);
    return outputs[0];
}

/**
 * Custom gradient rule for Elemwise ops.
 *
 * Only Elemwise::Mode::ADD (which must have exactly two inputs) is handled.
 * At forward time the shape of every input that requires grad is captured.
 * The backward closure then routes the single incoming output gradient to
 * each such input through reduce_to() with that input's captured shape
 * (presumably to undo broadcasting — depends on Reduce's default mode).
 * Inputs that do not require grad receive a null gradient.
 *
 * @param ctx   apply context of the forward Elemwise call
 * @param maker builder used to configure outputs and attach the backward
 * @return the forward result of apply(ctx)
 * @throws GradRuleFallback for any mode other than ADD (presumably telling
 *         the caller to use its generic gradient path)
 */
apply_result_t elemwise_grad_rule(ApplyContext& ctx, CustomBackward::Maker& maker) {
    auto& op = ctx.op->cast_final_safe<Elemwise>();
    if (op.mode == Elemwise::Mode::ADD) {
        mgb_assert(ctx.nargs == 2);
        // Null entries mark inputs that do not require grad.
        std::array<std::shared_ptr<Tensor>, 2> input_shapes;
        for (size_t i = 0; i < 2; ++i) {
            if (input_requires_grad(ctx, i)) {
                input_shapes[i] = get_shape(ctx.args[i]);
            }
        }
        // The backward does not need the forward output value.
        maker.output_size(1).output_captured(0, false);
        maker.backward([shapes=std::move(input_shapes)](BackwardContext&, Tensor*const* grads, size_t ngrads) {
            mgb_assert(ngrads == 1);
            Tensor* grad = grads[0];
            apply_result_t ret(2);
            // The output gradient may be null when nothing flows back into
            // this op; leave both input gradients null rather than handing
            // a null tensor to reduce_to().
            if (!grad) {
                return ret;
            }
            for (size_t i = 0; i < 2; ++i) {
                if (shapes[i]) {
                    ret[i] = reduce_to(grad, shapes[i].get());
                }
            }
            return ret;
        });
        return apply(ctx);
    }
    throw GradRuleFallback();
}

/**
 * Registrar whose single file-local instance `_` is constructed during
 * static initialization of this translation unit. Its constructor installs
 * elemwise_grad_rule into grad_rule_registry(), keyed by Elemwise's typeinfo.
 */
struct Init {
    Init() {
        grad_rule_registry().emplace(Elemwise::typeinfo(), elemwise_grad_rule);
    }
} _;

} // namespace
} // namespace mgb::imperative::python