grad_override.cpp 13.2 KB
Newer Older
1 2 3 4
/**
 * \file imperative/python/src/grad_override.cpp
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
5
 * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
6 7 8 9 10 11 12 13
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 */

#include "./grad.h"
#include "megbrain/imperative/ops/autogen.h"
14
#include "megbrain/imperative/transformations/grad.h"
15 16

namespace mgb::imperative::python {
17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66

class CustomGradMaker {
    bool output_size_set = false, input_has_grad_initialized = false;
    CustomBackward& target;
    size_t nr_inputs;
    void init_input_has_grad() {
        if (!input_has_grad_initialized) {
            input_has_grad_initialized = true;
            target.m_input_has_grad.resize(nr_inputs, true);
        }
    }

public:
    CustomGradMaker(CustomBackward& target, size_t nr_inputs)
            : target(target), nr_inputs(nr_inputs) {}

    CustomGradMaker& backward(CustomBackward::BackwardFn f) {
        mgb_assert(!target.m_backward);
        target.m_backward = f;
        return *this;
    }
    // mandatory
    CustomGradMaker& output_size(size_t sz) {
        mgb_assert(!output_size_set);
        output_size_set = true;
        target.m_output_attrs.resize(sz);
        return *this;
    }
    // optional, defaults to all true
    CustomGradMaker& input_has_grad(size_t i, bool v) {
        init_input_has_grad();
        target.m_input_has_grad.at(i) = v;
        return *this;
    }
    // optional, defaults to all true
    CustomGradMaker& output_requires_grad(size_t i, bool v) {
        target.m_output_attrs.at(i).requires_grad = v;
        return *this;
    }
    // optional, defaults to all true
    CustomGradMaker& output_captured(size_t i, bool v) {
        target.m_output_attrs.at(i).captured = v;
        return *this;
    }
    void finalize() {
        mgb_assert(output_size_set);
        init_input_has_grad();
    }
};

67 68
namespace {

69
// Applies a GetVarShape op to x and returns the first output.
ValueRef get_shape(ValueRef x) {
    // the op object is created on first use and reused by later calls
    static auto shape_op = GetVarShape::make();
    auto outputs = imperative::apply(*shape_op, x);
    return outputs[0];
}

74
// Applies a default-made Reduce op to (x, s) and returns the first output.
ValueRef reduce_to(ValueRef x, ValueRef s) {
    // the op object is created on first use and reused by later calls
    static auto reduce_op = Reduce::make();
    auto outputs = imperative::apply(*reduce_op, x, s);
    return outputs[0];
}

79
// Applies a default-made Reshape op to (x, s) and returns the first output.
ValueRef reshape_to(ValueRef x, ValueRef s) {
    // the op object is created on first use and reused by later calls
    static auto reshape_op = Reshape::make();
    auto outputs = imperative::apply(*reshape_op, x, s);
    return outputs[0];
}

84
// Applies a default-made Broadcast op to (x, s) and returns the first output.
ValueRef broadcast_to(ValueRef x, ValueRef s) {
    // the op object is created on first use and reused by later calls
    static auto broadcast_op = Broadcast::make();
    auto outputs = imperative::apply(*broadcast_op, x, s);
    return outputs[0];
}

89 90 91 92 93 94 95 96 97
/**
 * Creates a zero-filled tensor and broadcasts it to the given shape.
 *
 * Host storage for exactly one element of dtype is allocated for device and
 * cleared with memset. It is turned into a tensor through CreateTensor with a
 * default-constructed ValueShape, and the result is broadcast to shape.
 */
ValueRef make_empty_tensor(
        CompNodeValue::ref_t device, ValueRef shape, DTypeValue::ref_t dtype) {
    HostTensorStorage storage(*device);
    const auto elem_bytes = dtype->size();
    storage.ensure_size(elem_bytes);
    std::memset(storage.ptr(), 0, elem_bytes);
    auto zero_elem = imperative::apply(
            CreateTensor(CreateTensor::Unique, *device, *dtype, ValueShape()),
            HostStorage::make(storage))[0];
    return broadcast_to(zero_elem, shape);
}

101
std::optional<ValueRefList> elemwise_grad_rule(
102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119
        const OpDef& op, Span<ValueRef> inputs, Span<bool> inputs_require_grad,
        CustomBackward& backward) {
    auto& elemwise = op.cast_final_safe<Elemwise>();
    if (elemwise.mode != Elemwise::Mode::ADD) {
        return {};
    }
    mgb_assert(inputs.size() == 2);
    std::array<ValueRef, 2> input_shapes;
    for (size_t i = 0; i < 2; ++i) {
        if (inputs_require_grad[i]) {
            input_shapes[i] = get_shape(inputs[i]);
        }
    }
    auto maker = CustomGradMaker(backward, inputs.size());
    maker.output_size(1).output_captured(0, false);
    maker.backward([shapes = std::move(input_shapes)](Span<ValueRef> grads) {
        mgb_assert(grads.size() == 1);
        ValueRef grad = grads[0];
120
        SmallVector<ValueRef> ret(2);
121 122 123
        if (!grad) {
            return ret;
        }
124
        for (size_t i = 0; i < 2; ++i) {
125 126
            if (shapes[i]) {
                ret[i] = reduce_to(grad, shapes[i]);
127 128
            }
        }
129 130 131 132
        return ret;
    });
    maker.finalize();
    return imperative::apply(ApplyOp(op), inputs);
133 134
}

135
std::optional<ValueRefList> reshape_grad_rule(
136 137
        const OpDef& op, Span<ValueRef> inputs, Span<bool> inputs_require_grad,
        CustomBackward& backward) {
138 139
    mgb_assert(inputs.size() == 1 || inputs.size() == 2);
    size_t nr_inp = inputs.size();
140
    std::array<ValueRef, 2> input_shapes;
141
    for (size_t i = 0; i < nr_inp; ++i) {
142 143
        if (inputs_require_grad[i]) {
            input_shapes[i] = get_shape(inputs[i]);
144 145
        }
    }
146
    auto maker = CustomGradMaker(backward, inputs.size());
147
    maker.output_size(1).output_captured(0, false);
148
    maker.backward([shapes = std::move(input_shapes), nr_inp](Span<ValueRef> grads) {
149 150
        mgb_assert(grads.size() == 1);
        ValueRef grad = grads[0];
151
        SmallVector<ValueRef> ret(nr_inp);
152 153 154
        if (!grad) {
            return ret;
        }
155
        for (size_t i = 0; i < nr_inp; ++i) {
156
            if (shapes[i]) {
157
                ret[i] = reshape_to(grad, shapes[i]);
158 159 160 161
            }
        }
        return ret;
    });
162 163
    maker.finalize();
    return imperative::apply(ApplyOp(op), inputs);
164 165
}

166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196
std::optional<ValueRefList> broadcast_grad_rule(
        const OpDef& op, Span<ValueRef> inputs, Span<bool> inputs_require_grad,
        CustomBackward& backward) {
    mgb_assert(inputs.size() == 1 || inputs.size() == 2);
    size_t nr_inp = inputs.size();
    std::array<ValueRef, 2> input_shapes;
    for (size_t i = 0; i < nr_inp; ++i) {
        if (inputs_require_grad[i]) {
            input_shapes[i] = get_shape(inputs[i]);
        }
    }
    auto maker = CustomGradMaker(backward, inputs.size());
    maker.output_size(1).output_captured(0, false);
    maker.backward([shapes = std::move(input_shapes), nr_inp](Span<ValueRef> grads) {
        mgb_assert(grads.size() == 1);
        ValueRef grad = grads[0];
        SmallVector<ValueRef> ret(nr_inp);
        if (!grad) {
            return ret;
        }
        for (size_t i = 0; i < nr_inp; ++i) {
            if (shapes[i]) {
                ret[i] = reduce_to(grad, shapes[i]);
            }
        }
        return ret;
    });
    maker.finalize();
    return imperative::apply(ApplyOp(op), inputs);
}

197
std::optional<ValueRefList> subtensor_grad_rule(
198 199 200 201 202 203 204 205 206
        const OpDef& op, Span<ValueRef> inputs, Span<bool> inputs_require_grad,
        CustomBackward& backward) {
    auto&& subtensor = op.cast_final_safe<Subtensor>();
    auto&& grad_op = SetSubtensor::make(subtensor.items);
    SmallVector<ValueRef> inputs2;
    if (inputs_require_grad[0]) {
        inputs2.push_back(get_shape(inputs[0]));
        for (size_t i = 1; i < inputs.size(); ++i) {
            inputs2.push_back(inputs[i]);
207 208
        }
    }
209
    auto maker = CustomGradMaker(backward, inputs.size());
210
    maker.output_size(1).output_captured(0, false);
211 212 213 214
    maker.backward([inputs = std::move(inputs2),
                    grad_op_ = std::move(grad_op)](Span<ValueRef> grads) {
        mgb_assert(grads.size() == 1);
        ValueRef grad = grads[0];
215
        SmallVector<ValueRef> ret(1);
216
        if (grad && inputs[0]) {
217
            ValueRefList args_(inputs.size() + 1);
218 219
            auto&& zeros = make_empty_tensor(grad.device(), inputs[0], grad.dtype());
            args_[0] = zeros;
220 221
            args_[1] = grad;
            for (size_t i = 1; i < inputs.size(); ++i) {
222
                args_[i + 1] = inputs[i];
223
            }
224
            ret[0] = imperative::apply(ApplyOp(*grad_op_), args_)[0];
225 226 227
        }
        return ret;
    });
228 229
    maker.finalize();
    return imperative::apply(ApplyOp(op), inputs);
230 231
}

232
std::optional<ValueRefList> indexingMultiAxisVec_grad_rule(
233 234 235 236 237 238 239 240 241
        const OpDef& op, Span<ValueRef> inputs, Span<bool> inputs_require_grad,
        CustomBackward& backward) {
    auto&& indexingMultiAxisVec = op.cast_final_safe<IndexingMultiAxisVec>();
    auto&& grad_op = IndexingSetMultiAxisVec::make(indexingMultiAxisVec.items);
    SmallVector<ValueRef> inputs2;
    if (inputs_require_grad[0]) {
        inputs2.push_back(get_shape(inputs[0]));
        for (size_t i = 1; i < inputs.size(); ++i) {
            inputs2.push_back(inputs[i]);
242 243
        }
    }
244
    auto maker = CustomGradMaker(backward, inputs.size());
245
    maker.output_size(1).output_captured(0, false);
246 247 248 249
    maker.backward([inputs = std::move(inputs2),
                    grad_op_ = std::move(grad_op)](Span<ValueRef> grads) {
        mgb_assert(grads.size() == 1);
        ValueRef grad = grads[0];
250
        SmallVector<ValueRef> ret(1);
251
        if (grad && inputs[0]) {
252
            ValueRefList args_(inputs.size() + 1);
253 254
            auto&& zeros = make_empty_tensor(grad.device(), inputs[0], grad.dtype());
            args_[0] = zeros;
255 256
            args_[1] = grad;
            for (size_t i = 1; i < inputs.size(); ++i) {
257
                args_[i + 1] = inputs[i];
258
            }
259
            ret[0] = imperative::apply(ApplyOp(*grad_op_), args_)[0];
260 261 262
        }
        return ret;
    });
263 264
    maker.finalize();
    return imperative::apply(ApplyOp(op), inputs);
265 266
}

267
std::optional<ValueRefList> reduce_grad_rule(
268 269 270 271 272 273 274 275 276 277 278 279
        const OpDef& op, Span<ValueRef> inputs, Span<bool> inputs_require_grad,
        CustomBackward& backward) {
    auto& reduce = op.cast_final_safe<Reduce>();
    if (reduce.mode != Reduce::Mode::SUM) {
        return {};
    }
    if (inputs.size() != 1) {
        return {};
    }
    std::array<ValueRef, 1> input_shapes;
    if (inputs_require_grad[0]) {
        input_shapes[0] = get_shape(inputs[0]);
280
    }
281 282 283 284 285
    auto maker = CustomGradMaker(backward, inputs.size());
    maker.output_size(1).output_captured(0, false);
    maker.backward([shapes = std::move(input_shapes)](Span<ValueRef> grads) {
        mgb_assert(grads.size() == 1);
        ValueRef grad = grads[0];
286
        SmallVector<ValueRef> ret(1);
287 288 289 290 291 292 293
        if (grad && shapes[0]) {
            ret[0] = broadcast_to(grad, shapes[0]);
        }
        return ret;
    });
    maker.finalize();
    return imperative::apply(ApplyOp(op), inputs);
294 295
}

296
std::optional<ValueRefList> addAxis_grad_rule(
297 298 299 300 301 302
        const OpDef& op, Span<ValueRef> inputs, Span<bool> inputs_require_grad,
        CustomBackward& backward) {
    auto&& addAxis = op.cast_final_safe<AddAxis>();
    mgb_assert(inputs.size() == 1);
    bool flag = inputs_require_grad[0];
    auto&& grad_op = RemoveAxis::make(addAxis.axis);
303
    std::sort(grad_op->axis.begin(), grad_op->axis.end(), std::greater<int32_t>());
304
    auto maker = CustomGradMaker(backward, inputs.size());
305
    maker.output_size(1).output_captured(0, false);
306 307 308
    maker.backward([grad_op_ = std::move(grad_op), flag_ = flag](Span<ValueRef> grads) {
        mgb_assert(grads.size() == 1);
        ValueRef grad = grads[0];
309
        SmallVector<ValueRef> ret(1);
310
        if (grad && flag_) {
311
            ret[0] = imperative::apply(*grad_op_, grad)[0];
312
        }
313 314
        return ret;
    });
315 316
    maker.finalize();
    return imperative::apply(op, inputs);
317 318
}

319
std::optional<ValueRefList> removeAxis_grad_rule(
320 321 322 323 324 325
        const OpDef& op, Span<ValueRef> inputs, Span<bool> inputs_require_grad,
        CustomBackward& backward) {
    auto&& removeAxis = op.cast_final_safe<RemoveAxis>();
    mgb_assert(inputs.size() == 1);
    bool flag = inputs_require_grad[0];
    auto&& grad_op = AddAxis::make(removeAxis.axis);
326
    std::sort(grad_op->axis.begin(), grad_op->axis.end());
327
    auto maker = CustomGradMaker(backward, inputs.size());
328
    maker.output_size(1).output_captured(0, false);
329 330 331
    maker.backward([grad_op_ = std::move(grad_op), flag_ = flag](Span<ValueRef> grads) {
        mgb_assert(grads.size() == 1);
        ValueRef grad = grads[0];
332
        SmallVector<ValueRef> ret(1);
333
        if (grad && flag_) {
334
            ret[0] = imperative::apply(*grad_op_, grad)[0];
335
        }
336 337
        return ret;
    });
338 339
    maker.finalize();
    return imperative::apply(op, inputs);
340 341
}

342
std::optional<ValueRefList> fastpathcopy_grad_rule(
343 344 345 346
        const OpDef& op, Span<ValueRef> inputs, Span<bool> inputs_require_grad,
        CustomBackward& backward) {
    mgb_assert(inputs.size() == 1);
    auto maker = CustomGradMaker(backward, inputs.size());
347
    maker.output_size(1).output_captured(0, false);
348 349 350
    maker.backward([](Span<ValueRef> grads) {
        mgb_assert(grads.size() == 1);
        ValueRef grad = grads[0];
351
        SmallVector<ValueRef> ret(1);
352
        if (grad) {
353
            ret[0] = grad;
354 355 356
        }
        return ret;
    });
357 358
    maker.finalize();
    return imperative::apply(op, inputs);
359 360
}

361 362
struct Init {
    Init() {
363 364
        CustomBackward::register_grad_rule(Elemwise::typeinfo(), elemwise_grad_rule);
        CustomBackward::register_grad_rule(Reshape::typeinfo(), reshape_grad_rule);
365
        CustomBackward::register_grad_rule(Broadcast::typeinfo(), broadcast_grad_rule);
366 367 368 369 370 371 372 373 374
        CustomBackward::register_grad_rule(Subtensor::typeinfo(), subtensor_grad_rule);
        CustomBackward::register_grad_rule(
                IndexingMultiAxisVec::typeinfo(), indexingMultiAxisVec_grad_rule);
        CustomBackward::register_grad_rule(Reduce::typeinfo(), reduce_grad_rule);
        CustomBackward::register_grad_rule(AddAxis::typeinfo(), addAxis_grad_rule);
        CustomBackward::register_grad_rule(
                RemoveAxis::typeinfo(), removeAxis_grad_rule);
        CustomBackward::register_grad_rule(
                FastpathCopy::typeinfo(), fastpathcopy_grad_rule);
375 376 377
    }
} _;

M
Megvii Engine Team 已提交
378 379
}  // namespace
}  // namespace mgb::imperative::python