/**
 * \file imperative/python/src/grad_override.cpp
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 */

#include "./grad.h"
#include "megbrain/imperative/ops/autogen.h"

namespace mgb::imperative::python {
namespace {

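// Helper: get the runtime shape of x as a tensor via GetVarShape.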
std::shared_ptr<Tensor> get_shape(Tensor* x) {
    static auto op = GetVarShape::make();
    return python::apply(op, x)[0];
}

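// Helper: reduce x to the target shape s (used in backward passes to
// undo implicit broadcasting).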
std::shared_ptr<Tensor> reduce_to(Tensor* x, Tensor* s) {
    static auto op = Reduce::make();
    return python::apply(op, x, s)[0];
}

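// Helper: reshape x to the shape held in tensor s.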
std::shared_ptr<Tensor> reshape_to(Tensor* x, Tensor* s) {
    static auto op = Reshape::make();
    return python::apply(op, x, s)[0];
}

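// Helper: broadcast x to the shape held in tensor s.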
std::shared_ptr<Tensor> broadcast_to(Tensor* x, Tensor* s) {
    static auto op = Broadcast::make();
    return python::apply(op, x, s)[0];
}

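// Helper: create a Float32 scalar of value v on comp node cn and
// broadcast it to the given shape (used below to build zero-filled
// tensors for scatter-style backward ops).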
std::shared_ptr<Tensor> make_tensor(CompNode cn, Tensor* shape, float v = 0) {
    HostTensorND scalar{cn, {{1}, dtype::Float32()}};
    scalar.ptr<float>()[0] = v;
    interpreter::Interpreter::Handle handle = interpreter_for_py->put(scalar, false);
    auto&& t = std::make_shared<Tensor>(handle);
    auto res = broadcast_to(t.get(), shape);
    return res;
}

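// Custom gradient for elementwise ops. Only ADD is handled here: the
// gradient of each input is the output gradient reduced back to that
// input's shape, undoing implicit broadcasting. Shapes are captured in
// the forward pass only for inputs that require grad. Any other mode
// throws GradRuleFallback so the generic backward path is used instead.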
apply_result_t elemwise_grad_rule(ApplyContext& ctx, CustomBackward::Maker& maker) {
    auto& op = ctx.op->cast_final_safe<Elemwise>();
    if (op.mode == Elemwise::Mode::ADD) {
        mgb_assert(ctx.nargs == 2);
        std::array<std::shared_ptr<Tensor>, 2> input_shapes;
        for (size_t i = 0; i < 2; ++i) {
            if (input_requires_grad(ctx, i)) {
                input_shapes[i] = get_shape(ctx.args[i]);
            }
        }
        maker.output_size(1).output_captured(0, false);
        maker.backward([shapes=std::move(input_shapes)](BackwardContext&, Tensor*const* grads, size_t ngrads) {
            mgb_assert(ngrads == 1);
            Tensor* grad = grads[0];
            apply_result_t ret(2);
            if (!grad) {
                return ret;
            }
            for (size_t i = 0; i < 2; ++i) {
                if (shapes[i]) {
                    ret[i] = reduce_to(grad, shapes[i].get());
                }
            }
            return ret;
        });
        return apply(ctx);
    }
    throw GradRuleFallback();
}

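// Gradient of Reshape: the output gradient is reshaped back to the
// original shape of each input that requires grad, captured during the
// forward pass.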
apply_result_t reshape_grad_rule(ApplyContext& ctx, CustomBackward::Maker& maker) {
    mgb_assert(ctx.nargs == 2);
    std::array<std::shared_ptr<Tensor>, 2> input_shapes;
    for (size_t i = 0; i < 2; ++i) {
        if (input_requires_grad(ctx, i)) {
            input_shapes[i] = get_shape(ctx.args[i]);
        }
    }
    maker.output_size(1).output_captured(0, false);
    maker.backward([shapes=std::move(input_shapes)](BackwardContext&, Tensor*const* grads, size_t ngrads) {
        mgb_assert(ngrads == 1);
        Tensor* grad = grads[0];
        apply_result_t ret(2);
        if (!grad) {
            return ret;
        }
        for (size_t i = 0; i < 2; ++i) {
            if (shapes[i]) {
                ret[i] = reshape_to(grad, shapes[i].get());
            }
        }
        return ret;
    });
    return apply(ctx);
}

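// Gradient of Subtensor: scatter the output gradient back into a
// zero-filled tensor of the original input shape using SetSubtensor
// with the same slicing items. The index arguments are copied in the
// forward pass so they remain alive for the backward closure.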
apply_result_t subtensor_grad_rule(ApplyContext& ctx, CustomBackward::Maker& maker) {
    auto&& op = ctx.op->cast_final_safe<Subtensor>();
    auto&& grad_op = SetSubtensor::make(op.items);
    SmallVector<std::shared_ptr<Tensor>> inputs;
    if (input_requires_grad(ctx, 0)) {
        inputs.push_back(get_shape(ctx.args[0]));
        for (size_t i = 1; i < ctx.nargs; ++i) {
            inputs.push_back(ctx.args[i]->copy());
        }
    }
    maker.output_size(1).output_captured(0, false);
    maker.backward([inputs=std::move(inputs), grad_op_=std::move(grad_op)](BackwardContext&, Tensor*const* grads, size_t ngrads) {
        mgb_assert(ngrads == 1);
        Tensor* grad = grads[0];
        apply_result_t ret(1);
        if (grad && inputs[0]) {
            SmallVector<Tensor*> args_(inputs.size()+1);
            auto&& zeros = make_tensor(grad->comp_node(), inputs[0].get());
            args_[0] = zeros.get();
            args_[1] = grad;
            for (size_t i = 1; i < inputs.size(); ++i) {
                args_[i+1] = inputs[i].get();
            }
            ret[0] = python::apply(grad_op_, args_)[0];
        }
        return ret;
    });
    return apply(ctx);
}

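// Gradient of IndexingMultiAxisVec: same pattern as Subtensor, with
// IndexingSetMultiAxisVec writing the output gradient into a
// zero-filled tensor of the original input shape.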
apply_result_t indexingMultiAxisVec_grad_rule(ApplyContext& ctx, CustomBackward::Maker& maker) {
    auto&& op = ctx.op->cast_final_safe<IndexingMultiAxisVec>();
    auto&& grad_op = IndexingSetMultiAxisVec::make(op.items);
    SmallVector<std::shared_ptr<Tensor>> inputs;
    if (input_requires_grad(ctx, 0)) {
        inputs.push_back(get_shape(ctx.args[0]));
        for (size_t i = 1; i < ctx.nargs; ++i) {
            inputs.push_back(ctx.args[i]->copy());
        }
    }
    maker.output_size(1).output_captured(0, false);
    maker.backward([inputs=std::move(inputs), grad_op_=std::move(grad_op)](BackwardContext&, Tensor*const* grads, size_t ngrads) {
        mgb_assert(ngrads == 1);
        Tensor* grad = grads[0];
        apply_result_t ret(1);
        if (grad && inputs[0]) {
            SmallVector<Tensor*> args_(inputs.size()+1);
            auto&& zeros = make_tensor(grad->comp_node(), inputs[0].get());
            args_[0] = zeros.get();
            args_[1] = grad;
            for (size_t i = 1; i < inputs.size(); ++i) {
                args_[i+1] = inputs[i].get();
            }
            ret[0] = python::apply(grad_op_, args_)[0];
        }
        return ret;
    });
    return apply(ctx);
}

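// Custom gradient for Reduce. Only SUM is handled: the output gradient
// is broadcast back to the input's original shape. Other reduce modes
// fall back to the generic backward path via GradRuleFallback.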
apply_result_t reduce_grad_rule(ApplyContext& ctx, CustomBackward::Maker& maker) {
    auto& op = ctx.op->cast_final_safe<Reduce>();
    if (op.mode == Reduce::Mode::SUM) {
        mgb_assert(ctx.nargs == 1);
        std::array<std::shared_ptr<Tensor>, 1> input_shapes;
        if (input_requires_grad(ctx, 0)) {
            input_shapes[0] = get_shape(ctx.args[0]);
        }
        maker.output_size(1).output_captured(0, false);
        maker.backward([shapes=std::move(input_shapes)](BackwardContext&, Tensor*const* grads, size_t ngrads) {
            mgb_assert(ngrads == 1);
            Tensor* grad = grads[0];
            apply_result_t ret(1);
            if (grad && shapes[0]) {
                ret[0] = broadcast_to(grad, shapes[0].get());
            }
            return ret;
        });
        return apply(ctx);
    }
    throw GradRuleFallback();
}

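// Gradient of AddAxis: remove the inserted axes again. The axes are
// sorted in descending order so that removing one axis does not shift
// the indices of the axes still to be removed.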
apply_result_t addAxis_grad_rule(ApplyContext& ctx, CustomBackward::Maker& maker) {
    auto&& op = ctx.op->cast_final_safe<AddAxis>();
    mgb_assert(ctx.nargs == 1);
    bool flag = input_requires_grad(ctx, 0);
    auto&& grad_op = RemoveAxis::make(op.axis);
    std::sort(grad_op->axis.begin(), grad_op->axis.end(), std::greater<int32_t>());
    maker.output_size(1).output_captured(0, false);
    maker.backward([grad_op_=std::move(grad_op), flag_=flag](BackwardContext&, Tensor*const* grads, size_t ngrads) {
        mgb_assert(ngrads == 1);
        Tensor* grad = grads[0];
        apply_result_t ret(1);
        if (grad && flag_) {
            ret[0] = python::apply(grad_op_, grad)[0];
        }
        return ret;
    });
    return apply(ctx);
}

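// Gradient of RemoveAxis: re-insert the removed axes. The axes are
// sorted in ascending order so each insertion lands at the intended
// position in the restored shape.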
apply_result_t removeAxis_grad_rule(ApplyContext& ctx, CustomBackward::Maker& maker) {
    auto&& op = ctx.op->cast_final_safe<RemoveAxis>();
    mgb_assert(ctx.nargs == 1);
    bool flag = input_requires_grad(ctx, 0);
    auto&& grad_op = AddAxis::make(op.axis);
    std::sort(grad_op->axis.begin(), grad_op->axis.end());
    maker.output_size(1).output_captured(0, false);
    maker.backward([grad_op_=std::move(grad_op), flag_=flag](BackwardContext&, Tensor*const* grads, size_t ngrads) {
        mgb_assert(ngrads == 1);
        Tensor* grad = grads[0];
        apply_result_t ret(1);
        if (grad && flag_) {
            ret[0] = python::apply(grad_op_, grad)[0];
        }
        return ret;
    });
    return apply(ctx);
}

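// Register all custom gradient rules with the global registry at
// static initialization time.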
struct Init {
    Init() {
        auto& reg = grad_rule_registry();
        reg.emplace(Elemwise::typeinfo(), elemwise_grad_rule);
        reg.emplace(Reshape::typeinfo(), reshape_grad_rule);
        reg.emplace(Subtensor::typeinfo(), subtensor_grad_rule);
        reg.emplace(IndexingMultiAxisVec::typeinfo(), indexingMultiAxisVec_grad_rule);
        reg.emplace(Reduce::typeinfo(), reduce_grad_rule);
        reg.emplace(AddAxis::typeinfo(), addAxis_grad_rule);
        reg.emplace(RemoveAxis::typeinfo(), removeAxis_grad_rule);
    }
} _;

} // namespace
} // namespace mgb::imperative::python