// grad_override.cpp
#include "./grad.h"
#include "megbrain/imperative/ops/autogen.h"
#include "megbrain/imperative/transformations/grad.h"

namespace mgb::imperative::python {

// Builder that fills in a CustomBackward: the backward closure, the number and
// attributes of the forward outputs, and which inputs receive gradients.
// output_size() must be called before finalize(); everything else is optional.
class CustomGradMaker {
    bool output_size_set = false, input_has_grad_initialized = false;
    CustomBackward& target;
    size_t nr_inputs;

    // Lazily sizes target.m_input_has_grad to nr_inputs entries, all true.
    void init_input_has_grad() {
        if (!input_has_grad_initialized) {
            input_has_grad_initialized = true;
            target.m_input_has_grad.resize(nr_inputs, true);
        }
    }

public:
    CustomGradMaker(CustomBackward& target, size_t nr_inputs)
            : target(target), nr_inputs(nr_inputs) {}

    // Installs the backward closure; may only be set once.
    CustomGradMaker& backward(CustomBackward::BackwardFn f) {
        mgb_assert(!target.m_backward);
        // `f` is a by-value sink: move it so the closure's captures (shape
        // values, grad ops, small vectors) are not copied a second time.
        target.m_backward = std::move(f);
        return *this;
    }
    // mandatory: number of forward outputs; may only be set once.
    CustomGradMaker& output_size(size_t sz) {
        mgb_assert(!output_size_set);
        output_size_set = true;
        target.m_output_attrs.resize(sz);
        return *this;
    }
    // optional, defaults to all true (see init_input_has_grad).
    CustomGradMaker& input_has_grad(size_t i, bool v) {
        init_input_has_grad();
        target.m_input_has_grad.at(i) = v;
        return *this;
    }
    // optional; requires output_size() first (bounds-checked via at()).
    // Default comes from the output attr's default value (documented as true).
    CustomGradMaker& output_requires_grad(size_t i, bool v) {
        target.m_output_attrs.at(i).requires_grad = v;
        return *this;
    }
    // optional; requires output_size() first (bounds-checked via at()).
    // Default comes from the output attr's default value (documented as true).
    CustomGradMaker& output_captured(size_t i, bool v) {
        target.m_output_attrs.at(i).captured = v;
        return *this;
    }
    // Validates that output_size() was called and materializes input flags.
    void finalize() {
        mgb_assert(output_size_set);
        init_input_has_grad();
    }
};

namespace {

// Returns the shape of `x` as a value, obtained by applying GetVarShape.
ValueRef get_shape(ValueRef x) {
    // GetVarShape::make() takes no arguments, so one shared instance is reused.
    static auto shape_op = GetVarShape::make();
    auto outputs = imperative::apply(*shape_op, x);
    return outputs[0];
}

// Reduces `x` down to the shape value `s` using a default-constructed Reduce
// op (assumed to default to SUM mode — confirm against Reduce's param defaults).
ValueRef reduce_to(ValueRef x, ValueRef s) {
    static auto reduce_op = Reduce::make();
    auto outputs = imperative::apply(*reduce_op, x, s);
    return outputs[0];
}

// Reshapes `x` to the shape value `s` with a shared, parameterless Reshape op.
ValueRef reshape_to(ValueRef x, ValueRef s) {
    static auto reshape_op = Reshape::make();
    auto outputs = imperative::apply(*reshape_op, x, s);
    return outputs[0];
}

// Broadcasts `x` to the shape value `s` with a shared, parameterless Broadcast op.
ValueRef broadcast_to(ValueRef x, ValueRef s) {
    static auto broadcast_op = Broadcast::make();
    auto outputs = imperative::apply(*broadcast_op, x, s);
    return outputs[0];
}

// Creates a tensor with the given `shape` on `device`, with every element zero.
// Despite the name, the result is zero-filled, not uninitialized: one zeroed
// scalar of `dtype` is built from host memory and then broadcast to `shape`,
// so only dtype->size() bytes are written on the host.
ValueRef make_empty_tensor(
        CompNodeValue::ref_t device, ValueRef shape, DTypeValue::ref_t dtype) {
    const auto elem_bytes = dtype->size();
    HostTensorStorage zero_scalar(*device);
    zero_scalar.ensure_size(elem_bytes);
    std::memset(zero_scalar.ptr(), 0, elem_bytes);
    auto scalar = imperative::apply(
            CreateTensor(CreateTensor::Unique, *device, *dtype, ValueShape()),
            HostStorage::make(zero_scalar))[0];
    return broadcast_to(scalar, shape);
}

// Custom backward rule for MatrixMul: out = op(A) @ op(B), where op() is an
// optional transpose selected by param.transposeA / param.transposeB.
// Each gradient is itself computed with a MatrixMul whose transpose flags and
// dims are chosen per case (see the inline formulas below).
std::optional<ValueRefList> matrix_mul_grad_rule(
        const OpDef& op, Span<ValueRef> inputs, Span<bool> inputs_require_grad,
        CustomBackward& backward) {
    auto&& matmul = op.cast_final_safe<MatrixMul>();
    size_t dimA = matmul.dimA;
    size_t dimB = matmul.dimB;
    auto&& param = matmul.param();
    auto&& policy = matmul.policy();
    mgb_assert(inputs.size() == 2);
    std::array<ValueRef, 2> inps, input_shapes;
    for (size_t i = 0; i < 2; ++i) {
        // Operand i is only kept when the *other* operand (i ^ 1) needs a
        // gradient: dA is computed from B, and dB from A.
        if (inputs_require_grad[i ^ 1]) {
            inps[i] = inputs[i];
            input_shapes[i] = get_shape(inputs[i]);
        }
    }
    auto maker = CustomGradMaker(backward, inputs.size());
    // The backward closure never reads the forward output, so don't capture it.
    maker.output_size(1).output_captured(0, false);
    // NOTE(review): input_shapes_ is captured but never read in this rule
    // (batched_matrix_mul_grad_rule uses its counterpart for reduce_to).
    maker.backward([inps_ = std::move(inps), input_shapes_ = std::move(input_shapes),
                    param, policy, dimA, dimB](Span<ValueRef> grads) {
        mgb_assert(grads.size() == 1);
        ValueRef grad = grads[0];
        SmallVector<ValueRef> ret(2);
        // No incoming gradient: neither operand receives one.
        if (!grad) {
            return ret;
        }
        // Rank used for the output gradient: the larger operand rank.
        size_t dimG = std::max(dimA, dimB);
        // Gradient of A (input 0), computed from B.
        if (inps_[1]) {
            if (param.transposeA) {
                // A^T(2) @ B(2) = G(2), A'(2) = B'(2) @ G'^T(2) -> MatrixMul
                auto&& grad_op = MatrixMul::make(
                        param.transposeB, true, param.compute_mode, param.format,
                        policy.strategy, policy.workspace_limit, dimB, dimG);
                ret[0] = imperative::apply(*grad_op, inps_[1], grad)[0];
            } else {
                // A(>=2) @ B(2) = G(>=2), A'(>=2) = G'(>=2) @ B(2) -> MatrixMul
                auto&& grad_op = MatrixMul::make(
                        false, !param.transposeB, param.compute_mode, param.format,
                        policy.strategy, policy.workspace_limit, dimG, dimB);
                ret[0] = imperative::apply(*grad_op, grad, inps_[1])[0];
            }
        }
        // Gradient of B (input 1), computed from A.
        if (inps_[0]) {
            if (param.transposeB) {
                // A(>=2) @ B^T(2) = G(>=2), B'(2) = G'^T(>=2) @ A(>=2) -> MatrixMul
                // (specialized)
                auto&& grad_op = MatrixMul::make(
                        true, param.transposeA, param.compute_mode, param.format,
                        policy.strategy, policy.workspace_limit, dimG, dimA);
                ret[1] = imperative::apply(*grad_op, grad, inps_[0])[0];
            } else {
                // A(>=2) @ B(2) = G(>=2), B'(2) = G'(>=2) @ A(>=2) -> MatrixMul
                // (specialized)
                auto&& grad_op = MatrixMul::make(
                        !param.transposeA, false, param.compute_mode, param.format,
                        policy.strategy, policy.workspace_limit, dimA, dimG);
                ret[1] = imperative::apply(*grad_op, inps_[0], grad)[0];
            }
        }
        return ret;
    });
    maker.finalize();
    return imperative::apply(ApplyOp(op), inputs);
}

// Custom backward rule for BatchedMatrixMul: out = op(A) @ op(B), where op() is
// an optional transpose selected by param.transposeA / param.transposeB.
// dA is computed from B and dB from A. When an operand's rank is below
// dimG = max(dimA, dimB), its gradient is reduced back to the operand's own
// shape with reduce_to().
std::optional<ValueRefList> batched_matrix_mul_grad_rule(
        const OpDef& op, Span<ValueRef> inputs, Span<bool> inputs_require_grad,
        CustomBackward& backward) {
    auto&& bmm = op.cast_final_safe<BatchedMatrixMul>();
    size_t dimA = bmm.dimA;
    size_t dimB = bmm.dimB;
    auto&& param = bmm.param();
    auto&& policy = bmm.policy();
    mgb_assert(inputs.size() == 2);
    // inps[i]: operand i, needed when the *other* operand requires grad.
    // input_shapes[i]: shape of operand i, needed when operand i *itself*
    //     requires grad, because reduce_to() restores operand i's own shape.
    //     (Tying it to the other operand's flag left it null when only one
    //     operand required grad.)
    std::array<ValueRef, 2> inps, input_shapes;
    for (size_t i = 0; i < 2; ++i) {
        if (inputs_require_grad[i ^ 1]) {
            inps[i] = inputs[i];
        }
        if (inputs_require_grad[i]) {
            input_shapes[i] = get_shape(inputs[i]);
        }
    }
    auto maker = CustomGradMaker(backward, inputs.size());
    // The backward closure never reads the forward output.
    maker.output_size(1).output_captured(0, false);
    maker.backward([inps_ = std::move(inps), input_shapes_ = std::move(input_shapes),
                    param, policy, dimA, dimB](Span<ValueRef> grads) {
        mgb_assert(grads.size() == 1);
        ValueRef grad = grads[0];
        SmallVector<ValueRef> ret(2);
        if (!grad) {
            return ret;
        }
        size_t dimG = std::max(dimA, dimB);
        // Gradient of A (input 0): inps_[1] is set iff A requires grad.
        if (inps_[1]) {
            if (param.transposeA) {
                auto&& grad_op = BatchedMatrixMul::make(
                        param.transposeB, true, param.compute_mode, param.format,
                        policy.strategy, policy.workspace_limit, dimB, dimG);
                ret[0] = imperative::apply(*grad_op, inps_[1], grad)[0];
            } else {
                auto&& grad_op = BatchedMatrixMul::make(
                        false, !param.transposeB, param.compute_mode, param.format,
                        policy.strategy, policy.workspace_limit, dimG, dimB);
                ret[0] = imperative::apply(*grad_op, grad, inps_[1])[0];
            }
            if (dimG != dimA) {
                // A was broadcast along batch dims; sum back to A's shape.
                ret[0] = reduce_to(ret[0], input_shapes_[0]);
            }
        }
        // Gradient of B (input 1): inps_[0] is set iff B requires grad.
        if (inps_[0]) {
            if (param.transposeB) {
                auto&& grad_op = BatchedMatrixMul::make(
                        true, param.transposeA, param.compute_mode, param.format,
                        policy.strategy, policy.workspace_limit, dimG, dimA);
                ret[1] = imperative::apply(*grad_op, grad, inps_[0])[0];
            } else {
                auto&& grad_op = BatchedMatrixMul::make(
                        !param.transposeA, false, param.compute_mode, param.format,
                        policy.strategy, policy.workspace_limit, dimA, dimG);
                ret[1] = imperative::apply(*grad_op, inps_[0], grad)[0];
            }
            if (dimG != dimB) {
                // B was broadcast along batch dims; sum back to B's shape.
                ret[1] = reduce_to(ret[1], input_shapes_[1]);
            }
        }
        return ret;
    });
    maker.finalize();
    return imperative::apply(ApplyOp(op), inputs);
}

std::optional<ValueRefList> elemwise_grad_rule(
221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238
        const OpDef& op, Span<ValueRef> inputs, Span<bool> inputs_require_grad,
        CustomBackward& backward) {
    auto& elemwise = op.cast_final_safe<Elemwise>();
    if (elemwise.mode != Elemwise::Mode::ADD) {
        return {};
    }
    mgb_assert(inputs.size() == 2);
    std::array<ValueRef, 2> input_shapes;
    for (size_t i = 0; i < 2; ++i) {
        if (inputs_require_grad[i]) {
            input_shapes[i] = get_shape(inputs[i]);
        }
    }
    auto maker = CustomGradMaker(backward, inputs.size());
    maker.output_size(1).output_captured(0, false);
    maker.backward([shapes = std::move(input_shapes)](Span<ValueRef> grads) {
        mgb_assert(grads.size() == 1);
        ValueRef grad = grads[0];
239
        SmallVector<ValueRef> ret(2);
240 241 242
        if (!grad) {
            return ret;
        }
243
        for (size_t i = 0; i < 2; ++i) {
244 245
            if (shapes[i]) {
                ret[i] = reduce_to(grad, shapes[i]);
246 247
            }
        }
248 249 250 251
        return ret;
    });
    maker.finalize();
    return imperative::apply(ApplyOp(op), inputs);
252 253
}

std::optional<ValueRefList> reshape_grad_rule(
255 256
        const OpDef& op, Span<ValueRef> inputs, Span<bool> inputs_require_grad,
        CustomBackward& backward) {
257 258
    mgb_assert(inputs.size() == 1 || inputs.size() == 2);
    size_t nr_inp = inputs.size();
259
    std::array<ValueRef, 2> input_shapes;
260
    for (size_t i = 0; i < nr_inp; ++i) {
261 262
        if (inputs_require_grad[i]) {
            input_shapes[i] = get_shape(inputs[i]);
263 264
        }
    }
265
    auto maker = CustomGradMaker(backward, inputs.size());
266
    maker.output_size(1).output_captured(0, false);
267
    maker.backward([shapes = std::move(input_shapes), nr_inp](Span<ValueRef> grads) {
268 269
        mgb_assert(grads.size() == 1);
        ValueRef grad = grads[0];
270
        SmallVector<ValueRef> ret(nr_inp);
271 272 273
        if (!grad) {
            return ret;
        }
274
        for (size_t i = 0; i < nr_inp; ++i) {
275
            if (shapes[i]) {
276
                ret[i] = reshape_to(grad, shapes[i]);
277 278 279 280
            }
        }
        return ret;
    });
281 282
    maker.finalize();
    return imperative::apply(ApplyOp(op), inputs);
283 284
}

std::optional<ValueRefList> broadcast_grad_rule(
        const OpDef& op, Span<ValueRef> inputs, Span<bool> inputs_require_grad,
        CustomBackward& backward) {
    mgb_assert(inputs.size() == 1 || inputs.size() == 2);
    size_t nr_inp = inputs.size();
    std::array<ValueRef, 2> input_shapes;
    for (size_t i = 0; i < nr_inp; ++i) {
        if (inputs_require_grad[i]) {
            input_shapes[i] = get_shape(inputs[i]);
        }
    }
    auto maker = CustomGradMaker(backward, inputs.size());
    maker.output_size(1).output_captured(0, false);
    maker.backward([shapes = std::move(input_shapes), nr_inp](Span<ValueRef> grads) {
        mgb_assert(grads.size() == 1);
        ValueRef grad = grads[0];
        SmallVector<ValueRef> ret(nr_inp);
        if (!grad) {
            return ret;
        }
        for (size_t i = 0; i < nr_inp; ++i) {
            if (shapes[i]) {
                ret[i] = reduce_to(grad, shapes[i]);
            }
        }
        return ret;
    });
    maker.finalize();
    return imperative::apply(ApplyOp(op), inputs);
}

std::optional<ValueRefList> subtensor_grad_rule(
317 318 319 320 321 322 323 324 325
        const OpDef& op, Span<ValueRef> inputs, Span<bool> inputs_require_grad,
        CustomBackward& backward) {
    auto&& subtensor = op.cast_final_safe<Subtensor>();
    auto&& grad_op = SetSubtensor::make(subtensor.items);
    SmallVector<ValueRef> inputs2;
    if (inputs_require_grad[0]) {
        inputs2.push_back(get_shape(inputs[0]));
        for (size_t i = 1; i < inputs.size(); ++i) {
            inputs2.push_back(inputs[i]);
326 327
        }
    }
328
    auto maker = CustomGradMaker(backward, inputs.size());
329
    maker.output_size(1).output_captured(0, false);
330 331 332 333
    maker.backward([inputs = std::move(inputs2),
                    grad_op_ = std::move(grad_op)](Span<ValueRef> grads) {
        mgb_assert(grads.size() == 1);
        ValueRef grad = grads[0];
334
        SmallVector<ValueRef> ret(1);
335
        if (grad && inputs[0]) {
336
            ValueRefList args_(inputs.size() + 1);
337 338
            auto&& zeros = make_empty_tensor(grad.device(), inputs[0], grad.dtype());
            args_[0] = zeros;
339 340
            args_[1] = grad;
            for (size_t i = 1; i < inputs.size(); ++i) {
341
                args_[i + 1] = inputs[i];
342
            }
343
            ret[0] = imperative::apply(ApplyOp(*grad_op_), args_)[0];
344 345 346
        }
        return ret;
    });
347 348
    maker.finalize();
    return imperative::apply(ApplyOp(op), inputs);
349 350
}

std::optional<ValueRefList> indexingMultiAxisVec_grad_rule(
352 353 354
        const OpDef& op, Span<ValueRef> inputs, Span<bool> inputs_require_grad,
        CustomBackward& backward) {
    auto&& indexingMultiAxisVec = op.cast_final_safe<IndexingMultiAxisVec>();
355
    auto&& grad_op = IndexingIncrMultiAxisVec::make(indexingMultiAxisVec.items);
356 357 358 359 360
    SmallVector<ValueRef> inputs2;
    if (inputs_require_grad[0]) {
        inputs2.push_back(get_shape(inputs[0]));
        for (size_t i = 1; i < inputs.size(); ++i) {
            inputs2.push_back(inputs[i]);
361 362
        }
    }
363
    auto maker = CustomGradMaker(backward, inputs.size());
364
    maker.output_size(1).output_captured(0, false);
365 366 367 368
    maker.backward([inputs = std::move(inputs2),
                    grad_op_ = std::move(grad_op)](Span<ValueRef> grads) {
        mgb_assert(grads.size() == 1);
        ValueRef grad = grads[0];
369
        SmallVector<ValueRef> ret(1);
370
        if (grad && inputs[0]) {
371
            ValueRefList args_(inputs.size() + 1);
372 373
            auto&& zeros = make_empty_tensor(grad.device(), inputs[0], grad.dtype());
            args_[0] = zeros;
374 375
            args_[1] = grad;
            for (size_t i = 1; i < inputs.size(); ++i) {
376
                args_[i + 1] = inputs[i];
377
            }
378
            ret[0] = imperative::apply(ApplyOp(*grad_op_), args_)[0];
379 380 381
        }
        return ret;
    });
382 383
    maker.finalize();
    return imperative::apply(ApplyOp(op), inputs);
384 385
}

std::optional<ValueRefList> reduce_grad_rule(
387 388 389 390 391 392
        const OpDef& op, Span<ValueRef> inputs, Span<bool> inputs_require_grad,
        CustomBackward& backward) {
    auto& reduce = op.cast_final_safe<Reduce>();
    if (reduce.mode != Reduce::Mode::SUM) {
        return {};
    }
393 394
    auto axis = reduce.axis;
    if (inputs.size() != 1 || axis == INT_MAX) {
395 396 397 398 399
        return {};
    }
    std::array<ValueRef, 1> input_shapes;
    if (inputs_require_grad[0]) {
        input_shapes[0] = get_shape(inputs[0]);
400
    }
401 402 403
    if (axis < 0) {
        axis = (*inputs[0].shape()).ndim + axis;
    }
404
    auto maker = CustomGradMaker(backward, inputs.size());
405
    auto keepdim = reduce.keepdim || axis == INT_MAX;
406
    maker.output_size(1).output_captured(0, false);
407 408 409 410
    maker.backward(
            [shapes = std::move(input_shapes), axis, keepdim](Span<ValueRef> grads) {
                mgb_assert(grads.size() == 1);
                ValueRef grad = grads[0];
411
                if (!keepdim && grad) {
412 413 414 415 416 417 418 419 420
                    auto&& grad_op = AddAxis::make(std::vector<int32_t>({axis}));
                    grad = imperative::apply(*grad_op, grad)[0];
                }
                SmallVector<ValueRef> ret(1);
                if (grad && shapes[0]) {
                    ret[0] = broadcast_to(grad, shapes[0]);
                }
                return ret;
            });
421 422
    maker.finalize();
    return imperative::apply(ApplyOp(op), inputs);
423 424
}

std::optional<ValueRefList> addAxis_grad_rule(
426 427 428 429 430 431
        const OpDef& op, Span<ValueRef> inputs, Span<bool> inputs_require_grad,
        CustomBackward& backward) {
    auto&& addAxis = op.cast_final_safe<AddAxis>();
    mgb_assert(inputs.size() == 1);
    bool flag = inputs_require_grad[0];
    auto&& grad_op = RemoveAxis::make(addAxis.axis);
432
    std::sort(grad_op->axis.begin(), grad_op->axis.end(), std::greater<int32_t>());
433
    auto maker = CustomGradMaker(backward, inputs.size());
434
    maker.output_size(1).output_captured(0, false);
435 436 437
    maker.backward([grad_op_ = std::move(grad_op), flag_ = flag](Span<ValueRef> grads) {
        mgb_assert(grads.size() == 1);
        ValueRef grad = grads[0];
438
        SmallVector<ValueRef> ret(1);
439
        if (grad && flag_) {
440
            ret[0] = imperative::apply(*grad_op_, grad)[0];
441
        }
442 443
        return ret;
    });
444 445
    maker.finalize();
    return imperative::apply(op, inputs);
446 447
}

std::optional<ValueRefList> removeAxis_grad_rule(
449 450 451 452 453 454
        const OpDef& op, Span<ValueRef> inputs, Span<bool> inputs_require_grad,
        CustomBackward& backward) {
    auto&& removeAxis = op.cast_final_safe<RemoveAxis>();
    mgb_assert(inputs.size() == 1);
    bool flag = inputs_require_grad[0];
    auto&& grad_op = AddAxis::make(removeAxis.axis);
455
    std::sort(grad_op->axis.begin(), grad_op->axis.end());
456
    auto maker = CustomGradMaker(backward, inputs.size());
457
    maker.output_size(1).output_captured(0, false);
458 459 460
    maker.backward([grad_op_ = std::move(grad_op), flag_ = flag](Span<ValueRef> grads) {
        mgb_assert(grads.size() == 1);
        ValueRef grad = grads[0];
461
        SmallVector<ValueRef> ret(1);
462
        if (grad && flag_) {
463
            ret[0] = imperative::apply(*grad_op_, grad)[0];
464
        }
465 466
        return ret;
    });
467 468
    maker.finalize();
    return imperative::apply(op, inputs);
469 470
}

std::optional<ValueRefList> pixelShuffle_grad_rule(
        const OpDef& op, Span<ValueRef> inputs, Span<bool> inputs_require_grad,
        CustomBackward& backward) {
    auto&& pixelShuffle = op.cast_final_safe<PixelShuffle>();
    mgb_assert(inputs.size() == 1);
    bool flag = inputs_require_grad[0];
    auto&& grad_op = PixelShuffleBackward::make(pixelShuffle.factor);
    auto maker = CustomGradMaker(backward, inputs.size());
    maker.output_size(1).output_captured(0, false);
    maker.backward([grad_op_ = std::move(grad_op), flag_ = flag](Span<ValueRef> grads) {
        mgb_assert(grads.size() == 1);
        ValueRef grad = grads[0];
        SmallVector<ValueRef> ret(1);
        if (grad && flag_) {
            ret[0] = imperative::apply(*grad_op_, grad)[0];
        }
        return ret;
    });
    maker.finalize();
    return imperative::apply(op, inputs);
}

// Custom backward for IndexingOneHot(x, index): apply IndexingSetOneHot (same
// axis/ndim) to a zero tensor shaped like x, with the index and the output
// gradient as the remaining arguments. Only x (input 0) receives a gradient.
std::optional<ValueRefList> indexing_grad_rule(
        const OpDef& op, Span<ValueRef> inputs, Span<bool> inputs_require_grad,
        CustomBackward& backward) {
    auto&& indexing = op.cast_final_safe<IndexingOneHot>();
    mgb_assert(inputs.size() == 2);
    bool flag = inputs_require_grad[0];
    auto&& grad_op = IndexingSetOneHot::make(indexing.axis, indexing.ndim);
    // [shape(x), index] when x requires grad; otherwise left empty.
    SmallVector<ValueRef> inputs2;
    if (flag) {
        inputs2.push_back(get_shape(inputs[0]));
        for (size_t i = 1; i < inputs.size(); ++i) {
            inputs2.push_back(inputs[i]);
        }
    }
    auto maker = CustomGradMaker(backward, inputs.size());
    maker.output_size(1).output_captured(0, false);
    maker.backward([inputs = std::move(inputs2),
                    grad_op_ = std::move(grad_op)](Span<ValueRef> grads) {
        mgb_assert(grads.size() == 1);
        ValueRef grad = grads[0];
        SmallVector<ValueRef> ret(1);
        // `inputs` is empty when x did not require grad, so check emptiness
        // before reading inputs[0] (previously an out-of-bounds access).
        if (grad && !inputs.empty() && inputs[0]) {
            ValueRefList args_(inputs.size() + 1);
            args_[0] = make_empty_tensor(grad.device(), inputs[0], grad.dtype());
            args_[1] = inputs[1];
            args_[2] = grad;
            ret[0] = imperative::apply(*grad_op_, args_)[0];
        }
        return ret;
    });
    maker.finalize();
    return imperative::apply(op, inputs);
}

std::optional<ValueRefList> indexing_set_one_hot_grad_rule(
        const OpDef& op, Span<ValueRef> inputs, Span<bool> inputs_require_grad,
        CustomBackward& backward) {
    auto&& indexingSetOneHot = op.cast_final_safe<IndexingSetOneHot>();
    mgb_assert(inputs.size() == 3);
    SmallVector<ValueRef> inputs2;
    inputs2.push_back(get_shape(inputs[0]));
    inputs2.push_back(inputs[1]);
    inputs2.push_back(get_shape(inputs[2]));
    auto maker = CustomGradMaker(backward, inputs.size());
    maker.output_size(1).output_captured(0, false);
    maker.backward([inputs = std::move(inputs2),
                    &indexingSetOneHot](Span<ValueRef> grads) {
        mgb_assert(grads.size() == 1);
        ValueRef grad = grads[0];
        SmallVector<ValueRef> ret(3);
        if (!grad) {
            return ret;
        }
        if (inputs[0]) {
            auto&& grad_op = IndexingSetOneHot::make(
                    indexingSetOneHot.axis, indexingSetOneHot.ndim);
            ValueRefList args_(inputs.size());
            auto&& zeros = make_empty_tensor(grad.device(), inputs[2], grad.dtype());
            args_[0] = grads[0];
            args_[1] = inputs[1];
            args_[2] = zeros;
            ret[0] = imperative::apply(*grad_op, args_)[0];
        }
        if (inputs[2]) {
            auto&& grad_op = IndexingOneHot::make(
                    indexingSetOneHot.axis, indexingSetOneHot.ndim);
            ValueRefList args_(inputs.size() - 1);
            args_[0] = grads[0];
            args_[1] = inputs[1];
            ret[2] = imperative::apply(*grad_op, args_)[0];
        }
        return ret;
    });
    maker.finalize();
    return imperative::apply(op, inputs);
}

std::optional<ValueRefList> fastpathcopy_grad_rule(
572 573 574 575
        const OpDef& op, Span<ValueRef> inputs, Span<bool> inputs_require_grad,
        CustomBackward& backward) {
    mgb_assert(inputs.size() == 1);
    auto maker = CustomGradMaker(backward, inputs.size());
576
    maker.output_size(1).output_captured(0, false);
577 578 579
    maker.backward([](Span<ValueRef> grads) {
        mgb_assert(grads.size() == 1);
        ValueRef grad = grads[0];
580
        SmallVector<ValueRef> ret(1);
581
        if (grad) {
582
            ret[0] = grad;
583 584 585
        }
        return ret;
    });
586 587
    maker.finalize();
    return imperative::apply(op, inputs);
588 589
}

struct Init {
    Init() {
592 593
        CustomBackward::register_grad_rule(Elemwise::typeinfo(), elemwise_grad_rule);
        CustomBackward::register_grad_rule(Reshape::typeinfo(), reshape_grad_rule);
594
        CustomBackward::register_grad_rule(Broadcast::typeinfo(), broadcast_grad_rule);
595 596 597 598 599 600 601
        CustomBackward::register_grad_rule(Subtensor::typeinfo(), subtensor_grad_rule);
        CustomBackward::register_grad_rule(
                IndexingMultiAxisVec::typeinfo(), indexingMultiAxisVec_grad_rule);
        CustomBackward::register_grad_rule(Reduce::typeinfo(), reduce_grad_rule);
        CustomBackward::register_grad_rule(AddAxis::typeinfo(), addAxis_grad_rule);
        CustomBackward::register_grad_rule(
                RemoveAxis::typeinfo(), removeAxis_grad_rule);
602 603 604 605
        CustomBackward::register_grad_rule(
                IndexingOneHot::typeinfo(), indexing_grad_rule);
        CustomBackward::register_grad_rule(
                IndexingSetOneHot::typeinfo(), indexing_set_one_hot_grad_rule);
606 607
        CustomBackward::register_grad_rule(
                FastpathCopy::typeinfo(), fastpathcopy_grad_rule);
608 609
        CustomBackward::register_grad_rule(
                PixelShuffle::typeinfo(), pixelShuffle_grad_rule);
610 611 612
        CustomBackward::register_grad_rule(MatrixMul::typeinfo(), matrix_mul_grad_rule);
        CustomBackward::register_grad_rule(
                BatchedMatrixMul::typeinfo(), batched_matrix_mul_grad_rule);
613 614 615
    }
} _;

}  // namespace
}  // namespace mgb::imperative::python