/**
 * \file imperative/python/src/grad_override.cpp
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 */

#include "./grad.h"

#include <algorithm>
#include <array>
#include <cstring>
#include <functional>
#include <optional>
#include <utility>

#include "megbrain/imperative/ops/autogen.h"
#include "megbrain/imperative/transformations/grad.h"

namespace mgb::imperative::python {
/**
 * Builder used by the grad rules in this file to configure a CustomBackward.
 *
 * Typical use: call output_size() exactly once (mandatory; it must precede
 * any per-output setter), optionally adjust per-input / per-output flags,
 * install the closure with backward(), and finally call finalize().
 */
class CustomGradMaker {
    bool output_size_set = false, input_has_grad_initialized = false;
    CustomBackward& target;
    size_t nr_inputs;

    // On first use, sizes target.m_input_has_grad to nr_inputs with every
    // entry set to true; later calls are no-ops.
    void init_input_has_grad() {
        if (!input_has_grad_initialized) {
            input_has_grad_initialized = true;
            target.m_input_has_grad.resize(nr_inputs, true);
        }
    }

public:
    CustomGradMaker(CustomBackward& target, size_t nr_inputs)
            : target(target), nr_inputs(nr_inputs) {}

    // Installs the backward closure; may be called at most once. `f` is a
    // sink parameter, so it is moved into place rather than copied again.
    CustomGradMaker& backward(CustomBackward::BackwardFn f) {
        mgb_assert(!target.m_backward);
        target.m_backward = std::move(f);
        return *this;
    }
    // mandatory: number of forward outputs; may be called at most once
    CustomGradMaker& output_size(size_t sz) {
        mgb_assert(!output_size_set);
        output_size_set = true;
        target.m_output_attrs.resize(sz);
        return *this;
    }
    // optional, defaults to all true
    CustomGradMaker& input_has_grad(size_t i, bool v) {
        init_input_has_grad();
        target.m_input_has_grad.at(i) = v;
        return *this;
    }
    // optional, defaults to all true; output_size() must have been called
    CustomGradMaker& output_requires_grad(size_t i, bool v) {
        mgb_assert(output_size_set);
        target.m_output_attrs.at(i).requires_grad = v;
        return *this;
    }
    // optional, defaults to all true; output_size() must have been called
    CustomGradMaker& output_captured(size_t i, bool v) {
        mgb_assert(output_size_set);
        target.m_output_attrs.at(i).captured = v;
        return *this;
    }
    // Checks that output_size() was called and fills in the default
    // input_has_grad flags if none were set explicitly.
    void finalize() {
        mgb_assert(output_size_set);
        init_input_has_grad();
    }
};

namespace {

// Applies GetVarShape to `x` and returns its single output, i.e. the shape of
// `x` as a value that can be fed to the *_to helpers below.
ValueRef get_shape(ValueRef x) {
    static auto shape_op = GetVarShape::make();
    auto outputs = imperative::apply(*shape_op, x);
    return outputs[0];
}

// Applies a default-constructed Reduce op to (x, s) and returns its single
// output. elemwise_grad_rule uses this to sum a gradient down to the target
// shape `s`; that the default Reduce mode is SUM is assumed — TODO confirm.
ValueRef reduce_to(ValueRef x, ValueRef s) {
    static auto reduce_op = Reduce::make();
    auto outputs = imperative::apply(*reduce_op, x, s);
    return outputs[0];
}

// Reshapes `x` to the shape held in the value `s` (for example one produced
// by get_shape) and returns the single output.
ValueRef reshape_to(ValueRef x, ValueRef s) {
    static auto reshape_op = Reshape::make();
    auto outputs = imperative::apply(*reshape_op, x, s);
    return outputs[0];
}

// Broadcasts `x` to the shape held in the value `s` (for example one produced
// by get_shape) and returns the single output.
ValueRef broadcast_to(ValueRef x, ValueRef s) {
    static auto broadcast_op = Broadcast::make();
    auto outputs = imperative::apply(*broadcast_op, x, s);
    return outputs[0];
}

// Returns a tensor of `shape` whose elements are all zero (despite the name,
// it is zero-filled, not uninitialised). A single zero element of `dtype` is
// written into host storage on `device`, turned into a tensor with a default
// ValueShape (presumably a scalar), and then broadcast to `shape`.
ValueRef make_empty_tensor(
        CompNodeValue::ref_t device, ValueRef shape, DTypeValue::ref_t dtype) {
    const auto elem_bytes = dtype->size();
    HostTensorStorage host_buf(*device);
    host_buf.ensure_size(elem_bytes);
    std::memset(host_buf.ptr(), 0, elem_bytes);
    auto zero_scalar = imperative::apply(
            CreateTensor(CreateTensor::Unique, *device, *dtype, ValueShape()),
            HostStorage::make(host_buf))[0];
    return broadcast_to(zero_scalar, shape);
}

std::optional<ValueRefList> elemwise_grad_rule(
102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119
        const OpDef& op, Span<ValueRef> inputs, Span<bool> inputs_require_grad,
        CustomBackward& backward) {
    auto& elemwise = op.cast_final_safe<Elemwise>();
    if (elemwise.mode != Elemwise::Mode::ADD) {
        return {};
    }
    mgb_assert(inputs.size() == 2);
    std::array<ValueRef, 2> input_shapes;
    for (size_t i = 0; i < 2; ++i) {
        if (inputs_require_grad[i]) {
            input_shapes[i] = get_shape(inputs[i]);
        }
    }
    auto maker = CustomGradMaker(backward, inputs.size());
    maker.output_size(1).output_captured(0, false);
    maker.backward([shapes = std::move(input_shapes)](Span<ValueRef> grads) {
        mgb_assert(grads.size() == 1);
        ValueRef grad = grads[0];
120
        ValueRefList ret(2);
121 122 123
        if (!grad) {
            return ret;
        }
124
        for (size_t i = 0; i < 2; ++i) {
125 126
            if (shapes[i]) {
                ret[i] = reduce_to(grad, shapes[i]);
127 128
            }
        }
129 130 131 132
        return ret;
    });
    maker.finalize();
    return imperative::apply(ApplyOp(op), inputs);
133 134
}

135
std::optional<ValueRefList> reshape_grad_rule(
136 137 138 139
        const OpDef& op, Span<ValueRef> inputs, Span<bool> inputs_require_grad,
        CustomBackward& backward) {
    mgb_assert(inputs.size() == 2);
    std::array<ValueRef, 2> input_shapes;
140
    for (size_t i = 0; i < 2; ++i) {
141 142
        if (inputs_require_grad[i]) {
            input_shapes[i] = get_shape(inputs[i]);
143 144
        }
    }
145
    auto maker = CustomGradMaker(backward, inputs.size());
146
    maker.output_size(1).output_captured(0, false);
147 148 149
    maker.backward([shapes = std::move(input_shapes)](Span<ValueRef> grads) {
        mgb_assert(grads.size() == 1);
        ValueRef grad = grads[0];
150
        ValueRefList ret(2);
151 152 153
        if (!grad) {
            return ret;
        }
154 155
        for (size_t i = 0; i < 2; ++i) {
            if (shapes[i]) {
156
                ret[i] = reshape_to(grad, shapes[i]);
157 158 159 160
            }
        }
        return ret;
    });
161 162
    maker.finalize();
    return imperative::apply(ApplyOp(op), inputs);
163 164
}

165
std::optional<ValueRefList> subtensor_grad_rule(
166 167 168 169 170 171 172 173 174
        const OpDef& op, Span<ValueRef> inputs, Span<bool> inputs_require_grad,
        CustomBackward& backward) {
    auto&& subtensor = op.cast_final_safe<Subtensor>();
    auto&& grad_op = SetSubtensor::make(subtensor.items);
    SmallVector<ValueRef> inputs2;
    if (inputs_require_grad[0]) {
        inputs2.push_back(get_shape(inputs[0]));
        for (size_t i = 1; i < inputs.size(); ++i) {
            inputs2.push_back(inputs[i]);
175 176
        }
    }
177
    auto maker = CustomGradMaker(backward, inputs.size());
178
    maker.output_size(1).output_captured(0, false);
179 180 181 182
    maker.backward([inputs = std::move(inputs2),
                    grad_op_ = std::move(grad_op)](Span<ValueRef> grads) {
        mgb_assert(grads.size() == 1);
        ValueRef grad = grads[0];
183
        ValueRefList ret(1);
184
        if (grad && inputs[0]) {
185
            ValueRefList args_(inputs.size() + 1);
186 187
            auto&& zeros = make_empty_tensor(grad.device(), inputs[0], grad.dtype());
            args_[0] = zeros;
188 189
            args_[1] = grad;
            for (size_t i = 1; i < inputs.size(); ++i) {
190
                args_[i + 1] = inputs[i];
191
            }
192
            ret[0] = imperative::apply(ApplyOp(*grad_op_), args_)[0];
193 194 195
        }
        return ret;
    });
196 197
    maker.finalize();
    return imperative::apply(ApplyOp(op), inputs);
198 199
}

200
std::optional<ValueRefList> indexingMultiAxisVec_grad_rule(
201 202 203 204 205 206 207 208 209
        const OpDef& op, Span<ValueRef> inputs, Span<bool> inputs_require_grad,
        CustomBackward& backward) {
    auto&& indexingMultiAxisVec = op.cast_final_safe<IndexingMultiAxisVec>();
    auto&& grad_op = IndexingSetMultiAxisVec::make(indexingMultiAxisVec.items);
    SmallVector<ValueRef> inputs2;
    if (inputs_require_grad[0]) {
        inputs2.push_back(get_shape(inputs[0]));
        for (size_t i = 1; i < inputs.size(); ++i) {
            inputs2.push_back(inputs[i]);
210 211
        }
    }
212
    auto maker = CustomGradMaker(backward, inputs.size());
213
    maker.output_size(1).output_captured(0, false);
214 215 216 217
    maker.backward([inputs = std::move(inputs2),
                    grad_op_ = std::move(grad_op)](Span<ValueRef> grads) {
        mgb_assert(grads.size() == 1);
        ValueRef grad = grads[0];
218
        ValueRefList ret(1);
219
        if (grad && inputs[0]) {
220
            ValueRefList args_(inputs.size() + 1);
221 222
            auto&& zeros = make_empty_tensor(grad.device(), inputs[0], grad.dtype());
            args_[0] = zeros;
223 224
            args_[1] = grad;
            for (size_t i = 1; i < inputs.size(); ++i) {
225
                args_[i + 1] = inputs[i];
226
            }
227
            ret[0] = imperative::apply(ApplyOp(*grad_op_), args_)[0];
228 229 230
        }
        return ret;
    });
231 232
    maker.finalize();
    return imperative::apply(ApplyOp(op), inputs);
233 234
}

235
std::optional<ValueRefList> reduce_grad_rule(
236 237 238 239 240 241 242 243 244 245 246 247
        const OpDef& op, Span<ValueRef> inputs, Span<bool> inputs_require_grad,
        CustomBackward& backward) {
    auto& reduce = op.cast_final_safe<Reduce>();
    if (reduce.mode != Reduce::Mode::SUM) {
        return {};
    }
    if (inputs.size() != 1) {
        return {};
    }
    std::array<ValueRef, 1> input_shapes;
    if (inputs_require_grad[0]) {
        input_shapes[0] = get_shape(inputs[0]);
248
    }
249 250 251 252 253
    auto maker = CustomGradMaker(backward, inputs.size());
    maker.output_size(1).output_captured(0, false);
    maker.backward([shapes = std::move(input_shapes)](Span<ValueRef> grads) {
        mgb_assert(grads.size() == 1);
        ValueRef grad = grads[0];
254
        ValueRefList ret(1);
255 256 257 258 259 260 261
        if (grad && shapes[0]) {
            ret[0] = broadcast_to(grad, shapes[0]);
        }
        return ret;
    });
    maker.finalize();
    return imperative::apply(ApplyOp(op), inputs);
262 263
}

264
std::optional<ValueRefList> addAxis_grad_rule(
265 266 267 268 269 270
        const OpDef& op, Span<ValueRef> inputs, Span<bool> inputs_require_grad,
        CustomBackward& backward) {
    auto&& addAxis = op.cast_final_safe<AddAxis>();
    mgb_assert(inputs.size() == 1);
    bool flag = inputs_require_grad[0];
    auto&& grad_op = RemoveAxis::make(addAxis.axis);
271
    std::sort(grad_op->axis.begin(), grad_op->axis.end(), std::greater<int32_t>());
272
    auto maker = CustomGradMaker(backward, inputs.size());
273
    maker.output_size(1).output_captured(0, false);
274 275 276
    maker.backward([grad_op_ = std::move(grad_op), flag_ = flag](Span<ValueRef> grads) {
        mgb_assert(grads.size() == 1);
        ValueRef grad = grads[0];
277
        ValueRefList ret(1);
278
        if (grad && flag_) {
279
            ret[0] = imperative::apply(*grad_op_, grad)[0];
280
        }
281 282
        return ret;
    });
283 284
    maker.finalize();
    return imperative::apply(op, inputs);
285 286
}

287
std::optional<ValueRefList> removeAxis_grad_rule(
288 289 290 291 292 293
        const OpDef& op, Span<ValueRef> inputs, Span<bool> inputs_require_grad,
        CustomBackward& backward) {
    auto&& removeAxis = op.cast_final_safe<RemoveAxis>();
    mgb_assert(inputs.size() == 1);
    bool flag = inputs_require_grad[0];
    auto&& grad_op = AddAxis::make(removeAxis.axis);
294
    std::sort(grad_op->axis.begin(), grad_op->axis.end());
295
    auto maker = CustomGradMaker(backward, inputs.size());
296
    maker.output_size(1).output_captured(0, false);
297 298 299
    maker.backward([grad_op_ = std::move(grad_op), flag_ = flag](Span<ValueRef> grads) {
        mgb_assert(grads.size() == 1);
        ValueRef grad = grads[0];
300
        ValueRefList ret(1);
301
        if (grad && flag_) {
302
            ret[0] = imperative::apply(*grad_op_, grad)[0];
303
        }
304 305
        return ret;
    });
306 307
    maker.finalize();
    return imperative::apply(op, inputs);
308 309
}

310
std::optional<ValueRefList> fastpathcopy_grad_rule(
311 312 313 314
        const OpDef& op, Span<ValueRef> inputs, Span<bool> inputs_require_grad,
        CustomBackward& backward) {
    mgb_assert(inputs.size() == 1);
    auto maker = CustomGradMaker(backward, inputs.size());
315
    maker.output_size(1).output_captured(0, false);
316 317 318
    maker.backward([](Span<ValueRef> grads) {
        mgb_assert(grads.size() == 1);
        ValueRef grad = grads[0];
319
        ValueRefList ret(1);
320
        if (grad) {
321
            ret[0] = grad;
322 323 324
        }
        return ret;
    });
325 326
    maker.finalize();
    return imperative::apply(op, inputs);
327 328
}

329 330
struct Init {
    Init() {
331 332 333 334 335 336 337 338 339 340 341
        CustomBackward::register_grad_rule(Elemwise::typeinfo(), elemwise_grad_rule);
        CustomBackward::register_grad_rule(Reshape::typeinfo(), reshape_grad_rule);
        CustomBackward::register_grad_rule(Subtensor::typeinfo(), subtensor_grad_rule);
        CustomBackward::register_grad_rule(
                IndexingMultiAxisVec::typeinfo(), indexingMultiAxisVec_grad_rule);
        CustomBackward::register_grad_rule(Reduce::typeinfo(), reduce_grad_rule);
        CustomBackward::register_grad_rule(AddAxis::typeinfo(), addAxis_grad_rule);
        CustomBackward::register_grad_rule(
                RemoveAxis::typeinfo(), removeAxis_grad_rule);
        CustomBackward::register_grad_rule(
                FastpathCopy::typeinfo(), fastpathcopy_grad_rule);
342 343 344
    }
} _;

M
Megvii Engine Team 已提交
345 346
}  // namespace
}  // namespace mgb::imperative::python