grad_override.cpp 14.6 KB
Newer Older
1 2 3 4
/**
 * \file imperative/python/src/grad_override.cpp
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 */

#include "./grad.h"
#include "megbrain/imperative/ops/autogen.h"
#include "megbrain/imperative/transformations/grad.h"

#include <algorithm>
#include <array>
#include <climits>
#include <cstring>
#include <functional>
#include <optional>

namespace mgb::imperative::python {
17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66

class CustomGradMaker {
    bool output_size_set = false, input_has_grad_initialized = false;
    CustomBackward& target;
    size_t nr_inputs;
    void init_input_has_grad() {
        if (!input_has_grad_initialized) {
            input_has_grad_initialized = true;
            target.m_input_has_grad.resize(nr_inputs, true);
        }
    }

public:
    CustomGradMaker(CustomBackward& target, size_t nr_inputs)
            : target(target), nr_inputs(nr_inputs) {}

    CustomGradMaker& backward(CustomBackward::BackwardFn f) {
        mgb_assert(!target.m_backward);
        target.m_backward = f;
        return *this;
    }
    // mandatory
    CustomGradMaker& output_size(size_t sz) {
        mgb_assert(!output_size_set);
        output_size_set = true;
        target.m_output_attrs.resize(sz);
        return *this;
    }
    // optional, defaults to all true
    CustomGradMaker& input_has_grad(size_t i, bool v) {
        init_input_has_grad();
        target.m_input_has_grad.at(i) = v;
        return *this;
    }
    // optional, defaults to all true
    CustomGradMaker& output_requires_grad(size_t i, bool v) {
        target.m_output_attrs.at(i).requires_grad = v;
        return *this;
    }
    // optional, defaults to all true
    CustomGradMaker& output_captured(size_t i, bool v) {
        target.m_output_attrs.at(i).captured = v;
        return *this;
    }
    void finalize() {
        mgb_assert(output_size_set);
        init_input_has_grad();
    }
};

67 68
namespace {

69
ValueRef get_shape(ValueRef x) {
70
    static auto op = GetVarShape::make();
71
    return imperative::apply(*op, x)[0];
72 73
}

74
/// Reduces `x` down to the target shape `s` with a default-constructed Reduce.
/// NOTE(review): gradient callers rely on Reduce's default mode being SUM —
/// confirm against the Reduce param defaults.
ValueRef reduce_to(ValueRef x, ValueRef s) {
    static auto reduce_op = Reduce::make();
    auto outputs = imperative::apply(*reduce_op, x, s);
    return outputs[0];
}

/// Reshapes `x` to the target shape `s` using a default-constructed Reshape.
ValueRef reshape_to(ValueRef x, ValueRef s) {
    static auto reshape_op = Reshape::make();
    auto outputs = imperative::apply(*reshape_op, x, s);
    return outputs[0];
}

/// Broadcasts `x` to the target shape `s` using a default-constructed
/// Broadcast op.
ValueRef broadcast_to(ValueRef x, ValueRef s) {
    static auto broadcast_op = Broadcast::make();
    auto outputs = imperative::apply(*broadcast_op, x, s);
    return outputs[0];
}

/// Despite its name, builds a ZERO-filled tensor of the given shape.
///
/// One element of `dtype` is zeroed in host memory and uploaded to `device`
/// via CreateTensor with a default ValueShape (presumably a scalar — confirm).
/// The result is then broadcast to `shape`.
ValueRef make_empty_tensor(
        CompNodeValue::ref_t device, ValueRef shape, DTypeValue::ref_t dtype) {
    const auto elem_bytes = dtype->size();
    HostTensorStorage host_buf(*device);
    host_buf.ensure_size(elem_bytes);
    std::memset(host_buf.ptr(), 0, elem_bytes);
    auto scalar = imperative::apply(
            CreateTensor(CreateTensor::Unique, *device, *dtype, ValueShape()),
            HostStorage::make(host_buf))[0];
    return broadcast_to(scalar, shape);
}

std::optional<ValueRefList> elemwise_grad_rule(
102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119
        const OpDef& op, Span<ValueRef> inputs, Span<bool> inputs_require_grad,
        CustomBackward& backward) {
    auto& elemwise = op.cast_final_safe<Elemwise>();
    if (elemwise.mode != Elemwise::Mode::ADD) {
        return {};
    }
    mgb_assert(inputs.size() == 2);
    std::array<ValueRef, 2> input_shapes;
    for (size_t i = 0; i < 2; ++i) {
        if (inputs_require_grad[i]) {
            input_shapes[i] = get_shape(inputs[i]);
        }
    }
    auto maker = CustomGradMaker(backward, inputs.size());
    maker.output_size(1).output_captured(0, false);
    maker.backward([shapes = std::move(input_shapes)](Span<ValueRef> grads) {
        mgb_assert(grads.size() == 1);
        ValueRef grad = grads[0];
120
        SmallVector<ValueRef> ret(2);
121 122 123
        if (!grad) {
            return ret;
        }
124
        for (size_t i = 0; i < 2; ++i) {
125 126
            if (shapes[i]) {
                ret[i] = reduce_to(grad, shapes[i]);
127 128
            }
        }
129 130 131 132
        return ret;
    });
    maker.finalize();
    return imperative::apply(ApplyOp(op), inputs);
133 134
}

135
std::optional<ValueRefList> reshape_grad_rule(
136 137
        const OpDef& op, Span<ValueRef> inputs, Span<bool> inputs_require_grad,
        CustomBackward& backward) {
138 139
    mgb_assert(inputs.size() == 1 || inputs.size() == 2);
    size_t nr_inp = inputs.size();
140
    std::array<ValueRef, 2> input_shapes;
141
    for (size_t i = 0; i < nr_inp; ++i) {
142 143
        if (inputs_require_grad[i]) {
            input_shapes[i] = get_shape(inputs[i]);
144 145
        }
    }
146
    auto maker = CustomGradMaker(backward, inputs.size());
147
    maker.output_size(1).output_captured(0, false);
148
    maker.backward([shapes = std::move(input_shapes), nr_inp](Span<ValueRef> grads) {
149 150
        mgb_assert(grads.size() == 1);
        ValueRef grad = grads[0];
151
        SmallVector<ValueRef> ret(nr_inp);
152 153 154
        if (!grad) {
            return ret;
        }
155
        for (size_t i = 0; i < nr_inp; ++i) {
156
            if (shapes[i]) {
157
                ret[i] = reshape_to(grad, shapes[i]);
158 159 160 161
            }
        }
        return ret;
    });
162 163
    maker.finalize();
    return imperative::apply(ApplyOp(op), inputs);
164 165
}

166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196
std::optional<ValueRefList> broadcast_grad_rule(
        const OpDef& op, Span<ValueRef> inputs, Span<bool> inputs_require_grad,
        CustomBackward& backward) {
    mgb_assert(inputs.size() == 1 || inputs.size() == 2);
    size_t nr_inp = inputs.size();
    std::array<ValueRef, 2> input_shapes;
    for (size_t i = 0; i < nr_inp; ++i) {
        if (inputs_require_grad[i]) {
            input_shapes[i] = get_shape(inputs[i]);
        }
    }
    auto maker = CustomGradMaker(backward, inputs.size());
    maker.output_size(1).output_captured(0, false);
    maker.backward([shapes = std::move(input_shapes), nr_inp](Span<ValueRef> grads) {
        mgb_assert(grads.size() == 1);
        ValueRef grad = grads[0];
        SmallVector<ValueRef> ret(nr_inp);
        if (!grad) {
            return ret;
        }
        for (size_t i = 0; i < nr_inp; ++i) {
            if (shapes[i]) {
                ret[i] = reduce_to(grad, shapes[i]);
            }
        }
        return ret;
    });
    maker.finalize();
    return imperative::apply(ApplyOp(op), inputs);
}

197
std::optional<ValueRefList> subtensor_grad_rule(
198 199 200 201 202 203 204 205 206
        const OpDef& op, Span<ValueRef> inputs, Span<bool> inputs_require_grad,
        CustomBackward& backward) {
    auto&& subtensor = op.cast_final_safe<Subtensor>();
    auto&& grad_op = SetSubtensor::make(subtensor.items);
    SmallVector<ValueRef> inputs2;
    if (inputs_require_grad[0]) {
        inputs2.push_back(get_shape(inputs[0]));
        for (size_t i = 1; i < inputs.size(); ++i) {
            inputs2.push_back(inputs[i]);
207 208
        }
    }
209
    auto maker = CustomGradMaker(backward, inputs.size());
210
    maker.output_size(1).output_captured(0, false);
211 212 213 214
    maker.backward([inputs = std::move(inputs2),
                    grad_op_ = std::move(grad_op)](Span<ValueRef> grads) {
        mgb_assert(grads.size() == 1);
        ValueRef grad = grads[0];
215
        SmallVector<ValueRef> ret(1);
216
        if (grad && inputs[0]) {
217
            ValueRefList args_(inputs.size() + 1);
218 219
            auto&& zeros = make_empty_tensor(grad.device(), inputs[0], grad.dtype());
            args_[0] = zeros;
220 221
            args_[1] = grad;
            for (size_t i = 1; i < inputs.size(); ++i) {
222
                args_[i + 1] = inputs[i];
223
            }
224
            ret[0] = imperative::apply(ApplyOp(*grad_op_), args_)[0];
225 226 227
        }
        return ret;
    });
228 229
    maker.finalize();
    return imperative::apply(ApplyOp(op), inputs);
230 231
}

232
std::optional<ValueRefList> indexingMultiAxisVec_grad_rule(
233 234 235 236 237 238 239 240 241
        const OpDef& op, Span<ValueRef> inputs, Span<bool> inputs_require_grad,
        CustomBackward& backward) {
    auto&& indexingMultiAxisVec = op.cast_final_safe<IndexingMultiAxisVec>();
    auto&& grad_op = IndexingSetMultiAxisVec::make(indexingMultiAxisVec.items);
    SmallVector<ValueRef> inputs2;
    if (inputs_require_grad[0]) {
        inputs2.push_back(get_shape(inputs[0]));
        for (size_t i = 1; i < inputs.size(); ++i) {
            inputs2.push_back(inputs[i]);
242 243
        }
    }
244
    auto maker = CustomGradMaker(backward, inputs.size());
245
    maker.output_size(1).output_captured(0, false);
246 247 248 249
    maker.backward([inputs = std::move(inputs2),
                    grad_op_ = std::move(grad_op)](Span<ValueRef> grads) {
        mgb_assert(grads.size() == 1);
        ValueRef grad = grads[0];
250
        SmallVector<ValueRef> ret(1);
251
        if (grad && inputs[0]) {
252
            ValueRefList args_(inputs.size() + 1);
253 254
            auto&& zeros = make_empty_tensor(grad.device(), inputs[0], grad.dtype());
            args_[0] = zeros;
255 256
            args_[1] = grad;
            for (size_t i = 1; i < inputs.size(); ++i) {
257
                args_[i + 1] = inputs[i];
258
            }
259
            ret[0] = imperative::apply(ApplyOp(*grad_op_), args_)[0];
260 261 262
        }
        return ret;
    });
263 264
    maker.finalize();
    return imperative::apply(ApplyOp(op), inputs);
265 266
}

267
std::optional<ValueRefList> reduce_grad_rule(
268 269 270 271 272 273
        const OpDef& op, Span<ValueRef> inputs, Span<bool> inputs_require_grad,
        CustomBackward& backward) {
    auto& reduce = op.cast_final_safe<Reduce>();
    if (reduce.mode != Reduce::Mode::SUM) {
        return {};
    }
274 275
    auto axis = reduce.axis;
    if (inputs.size() != 1 || axis == INT_MAX) {
276 277 278 279 280
        return {};
    }
    std::array<ValueRef, 1> input_shapes;
    if (inputs_require_grad[0]) {
        input_shapes[0] = get_shape(inputs[0]);
281
    }
282 283 284
    if (axis < 0) {
        axis = (*inputs[0].shape()).ndim + axis;
    }
285
    auto maker = CustomGradMaker(backward, inputs.size());
286
    auto keepdim = reduce.keepdim || axis == INT_MAX;
287
    maker.output_size(1).output_captured(0, false);
288 289 290 291 292 293 294 295 296 297 298 299 300 301
    maker.backward(
            [shapes = std::move(input_shapes), axis, keepdim](Span<ValueRef> grads) {
                mgb_assert(grads.size() == 1);
                ValueRef grad = grads[0];
                if (!keepdim) {
                    auto&& grad_op = AddAxis::make(std::vector<int32_t>({axis}));
                    grad = imperative::apply(*grad_op, grad)[0];
                }
                SmallVector<ValueRef> ret(1);
                if (grad && shapes[0]) {
                    ret[0] = broadcast_to(grad, shapes[0]);
                }
                return ret;
            });
302 303
    maker.finalize();
    return imperative::apply(ApplyOp(op), inputs);
304 305
}

306
std::optional<ValueRefList> addAxis_grad_rule(
307 308 309 310 311 312
        const OpDef& op, Span<ValueRef> inputs, Span<bool> inputs_require_grad,
        CustomBackward& backward) {
    auto&& addAxis = op.cast_final_safe<AddAxis>();
    mgb_assert(inputs.size() == 1);
    bool flag = inputs_require_grad[0];
    auto&& grad_op = RemoveAxis::make(addAxis.axis);
313
    std::sort(grad_op->axis.begin(), grad_op->axis.end(), std::greater<int32_t>());
314
    auto maker = CustomGradMaker(backward, inputs.size());
315
    maker.output_size(1).output_captured(0, false);
316 317 318
    maker.backward([grad_op_ = std::move(grad_op), flag_ = flag](Span<ValueRef> grads) {
        mgb_assert(grads.size() == 1);
        ValueRef grad = grads[0];
319
        SmallVector<ValueRef> ret(1);
320
        if (grad && flag_) {
321
            ret[0] = imperative::apply(*grad_op_, grad)[0];
322
        }
323 324
        return ret;
    });
325 326
    maker.finalize();
    return imperative::apply(op, inputs);
327 328
}

329
std::optional<ValueRefList> removeAxis_grad_rule(
330 331 332 333 334 335
        const OpDef& op, Span<ValueRef> inputs, Span<bool> inputs_require_grad,
        CustomBackward& backward) {
    auto&& removeAxis = op.cast_final_safe<RemoveAxis>();
    mgb_assert(inputs.size() == 1);
    bool flag = inputs_require_grad[0];
    auto&& grad_op = AddAxis::make(removeAxis.axis);
336
    std::sort(grad_op->axis.begin(), grad_op->axis.end());
337
    auto maker = CustomGradMaker(backward, inputs.size());
338
    maker.output_size(1).output_captured(0, false);
339 340 341
    maker.backward([grad_op_ = std::move(grad_op), flag_ = flag](Span<ValueRef> grads) {
        mgb_assert(grads.size() == 1);
        ValueRef grad = grads[0];
342
        SmallVector<ValueRef> ret(1);
343
        if (grad && flag_) {
344
            ret[0] = imperative::apply(*grad_op_, grad)[0];
345
        }
346 347
        return ret;
    });
348 349
    maker.finalize();
    return imperative::apply(op, inputs);
350 351
}

352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373
std::optional<ValueRefList> pixelShuffle_grad_rule(
        const OpDef& op, Span<ValueRef> inputs, Span<bool> inputs_require_grad,
        CustomBackward& backward) {
    auto&& pixelShuffle = op.cast_final_safe<PixelShuffle>();
    mgb_assert(inputs.size() == 1);
    bool flag = inputs_require_grad[0];
    auto&& grad_op = PixelShuffleBackward::make(pixelShuffle.factor);
    auto maker = CustomGradMaker(backward, inputs.size());
    maker.output_size(1).output_captured(0, false);
    maker.backward([grad_op_ = std::move(grad_op), flag_ = flag](Span<ValueRef> grads) {
        mgb_assert(grads.size() == 1);
        ValueRef grad = grads[0];
        SmallVector<ValueRef> ret(1);
        if (grad && flag_) {
            ret[0] = imperative::apply(*grad_op_, grad)[0];
        }
        return ret;
    });
    maker.finalize();
    return imperative::apply(op, inputs);
}

374
std::optional<ValueRefList> fastpathcopy_grad_rule(
375 376 377 378
        const OpDef& op, Span<ValueRef> inputs, Span<bool> inputs_require_grad,
        CustomBackward& backward) {
    mgb_assert(inputs.size() == 1);
    auto maker = CustomGradMaker(backward, inputs.size());
379
    maker.output_size(1).output_captured(0, false);
380 381 382
    maker.backward([](Span<ValueRef> grads) {
        mgb_assert(grads.size() == 1);
        ValueRef grad = grads[0];
383
        SmallVector<ValueRef> ret(1);
384
        if (grad) {
385
            ret[0] = grad;
386 387 388
        }
        return ret;
    });
389 390
    maker.finalize();
    return imperative::apply(op, inputs);
391 392
}

393 394
struct Init {
    Init() {
395 396
        CustomBackward::register_grad_rule(Elemwise::typeinfo(), elemwise_grad_rule);
        CustomBackward::register_grad_rule(Reshape::typeinfo(), reshape_grad_rule);
397
        CustomBackward::register_grad_rule(Broadcast::typeinfo(), broadcast_grad_rule);
398 399 400 401 402 403 404 405 406
        CustomBackward::register_grad_rule(Subtensor::typeinfo(), subtensor_grad_rule);
        CustomBackward::register_grad_rule(
                IndexingMultiAxisVec::typeinfo(), indexingMultiAxisVec_grad_rule);
        CustomBackward::register_grad_rule(Reduce::typeinfo(), reduce_grad_rule);
        CustomBackward::register_grad_rule(AddAxis::typeinfo(), addAxis_grad_rule);
        CustomBackward::register_grad_rule(
                RemoveAxis::typeinfo(), removeAxis_grad_rule);
        CustomBackward::register_grad_rule(
                FastpathCopy::typeinfo(), fastpathcopy_grad_rule);
407 408
        CustomBackward::register_grad_rule(
                PixelShuffle::typeinfo(), pixelShuffle_grad_rule);
409 410 411
    }
} _;

M
Megvii Engine Team 已提交
412 413
}  // namespace
}  // namespace mgb::imperative::python