#include "megbrain/imperative/transformations/grad.h"

#include <variant>

#include "megbrain/imperative/graph_cache.h"
#include "megbrain/imperative/resource_manager.h"

#include <range/v3/all.hpp>

namespace mgb {
namespace imperative {

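// Build, or fetch from a thread-local cache, the optimized backward graph
// for `op`. The cache key combines the op, the input descs (dtype and
// comp_node) and the requires-grad mask; a null result is cached as well,
// meaning the op has no backward graph.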
static std::shared_ptr<OptimizedBackwardGraphResult> make_optimized_backward_graph(
        const OpDef& op, Span<ValueRef> inputs, Span<ValueRef> outputs,
        Span<bool> inputs_require_grad) {
    // build the cache key: op, input descs (dtype/comp_node) and requires-grad mask
    using OptimizedBackwardGraphCache = OpMethResultCache<
            std::shared_ptr<OptimizedBackwardGraphResult>, SmallVector<bool>>;
    thread_local auto& cache =
            *ResourceManager::create_local<OptimizedBackwardGraphCache>();
    OptimizedBackwardGraphCache::key_t cache_key{op.shared_from_this()};
    SmallVector<LogicalTensorDesc>& input_descs = cache_key.inputs;
    cache_key.extra<0>() = inputs_require_grad.copy_into<SmallVector<bool>>();
    input_descs.resize(inputs.size());
    // some overhead; consider simplifying LogicalTensorDesc
    for (size_t i = 0; i < inputs.size(); ++i) {
        input_descs[i].layout.dtype = *inputs[i].dtype();
        input_descs[i].comp_node = *inputs[i].device();
    }

    auto iter = cache.find(cache_key);
    if (iter != cache.end()) {
        return iter->second;
    }

    // slow path
    SmallVector<bool> output_has_grad(outputs.size(), true);
    std::shared_ptr<OptimizedBackwardGraphResult> ret;
    auto bg = OpDef::make_backward_graph(
            op, input_descs, std::get<0>(cache_key.extras), output_has_grad);
    if (!bg.graph.empty()) {
        ret = std::make_shared<OptimizedBackwardGraphResult>(bg);
    }
    cache.emplace(cache_key, ret);
    return ret;
}

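// Capture everything the backward graph needs later: the precomputed values
// (if any) plus the subset of inputs/outputs selected by save_for_backward.
// With multiple outputs, also record each output's symbolic shape, dtype and
// device so that missing output grads can be replaced by zeros.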
BackwardGraphWithClosure::BackwardGraphWithClosure(
        std::shared_ptr<OptimizedBackwardGraphResult> backward_graph,
        std::shared_ptr<OpDef> op, Span<ValueRef> inputs, Span<ValueRef> outputs)
        : backward_graph(backward_graph),
          output_mask_offset(inputs.size()),
          grad_mask_offset(inputs.size() + outputs.size()),
          op(op) {
    auto& save_for_backward = backward_graph->save_for_backward;
    mgb_assert(save_for_backward.size() == inputs.size() + 2 * outputs.size());
    size_t count = std::count_if(
            save_for_backward.begin(), save_for_backward.end(), ranges::identity{});
    if (!backward_graph->precomp.empty()) {
        SmallVector<ValueRef> inputs_and_outputs(inputs.size() + outputs.size());
        auto it = inputs_and_outputs.begin();
        for (auto&& input : inputs) {
            *it++ = input;
        }
        for (auto&& output : outputs) {
            *it++ = output;
        }
        auto precomp = imperative::apply(backward_graph->precomp, inputs_and_outputs);
        closure.reserve(precomp.size() + count);
        std::copy(precomp.begin(), precomp.end(), std::back_inserter(closure));
    } else {
        closure.reserve(count);
    }
    for (size_t i = 0; i < inputs.size(); ++i) {
        if (save_for_backward[i]) {
            closure.push_back(inputs[i]);
        }
    }
    for (size_t i = 0; i < outputs.size(); ++i) {
        if (save_for_backward[inputs.size() + i]) {
            closure.push_back(outputs[i]);
        }
    }
    if (outputs.size() > 1) {
        output_descs.reserve(outputs.size());
        for (auto&& output : outputs) {
            auto symbolic_shape = imperative::apply(*GetVarShape::make(), output)[0];
            output_descs.push_back({symbolic_shape, output.dtype(), output.device()});
        }
    }
}
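
// Run the backward graph: assemble [closure..., grads...] as arguments,
// substituting a broadcast zero tensor for each missing grad, and route
// every produced input grad to `receiver`. Bails out early when no
// incoming grad is valid.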
void BackwardGraphWithClosure::operator()(
        Span<ValueRef> grads, std::function<void(size_t, ValueRef)> receiver) {
    ValueRef args[closure.size() + grads.size()];
    size_t nargs = 0;
    for (auto&& value : closure) {
        args[nargs++] = value;
    }
    size_t null_grad = 0;
    size_t valid_grad = 0;
    for (size_t i = 0; i < grads.size(); ++i) {
        if (backward_graph->save_for_backward[grad_mask_offset + i]) {
            if (grads[i]) {
                valid_grad++;
                args[nargs++] = grads[i];
            } else {
                null_grad++;
                nargs++;
            }
        }
    }
    if (valid_grad == 0) {
        return;
    }
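    // some grads are missing: refill the argument span, substituting a zero
    // tensor broadcast to the recorded output shape for each missing grad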
    if (null_grad > 0) {
        auto zeros_like = [](const OutputDesc& desc) {
            HostTensorStorage storage(*desc.device);
            storage.ensure_size(desc.dtype->size());
            std::memset(storage.ptr(), 0, desc.dtype->size());
            auto t = imperative::apply(
                    CreateTensor(
                            CreateTensor::Unique, *desc.device, *desc.dtype,
                            ValueShape()),
                    HostStorage::make(storage))[0];
            auto res = imperative::apply(*Broadcast::make(), t, desc.shape)[0];
            return res;
        };
        nargs = closure.size();
        for (size_t i = 0; i < grads.size(); ++i) {
            if (backward_graph->save_for_backward[grad_mask_offset + i]) {
                if (!grads[i]) {
                    args[nargs] = zeros_like(output_descs[i]);
                }
                nargs++;
            }
        }
    }
    auto igrads = imperative::apply(backward_graph->backward, Span(args, nargs));
    auto&& iter = igrads.begin();
    for (auto [i, p] : ranges::views::enumerate(backward_graph->input_has_grad)) {
        if (p) {
            receiver(i, std::move(*iter));
            ++iter;
        }
    }
}

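// Invoke the user-registered backward rule on the incoming grads and forward
// each non-null result to `receiver`.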
void CustomBackward::operator()(
        Span<ValueRef> grads, std::function<void(size_t, ValueRef)> receiver) {
    size_t nargs = grads.size();
    ValueRef args[nargs];
    for (size_t i = 0; i < nargs; ++i) {
        args[i] = grads[i];
    }
    auto ret = m_backward({args, nargs});
    for (size_t i = 0; i < ret.size(); ++i) {
        if (auto&& t = ret[i]) {
            receiver(i, std::move(t));
        }
    }
}

std::string GradSlot::to_string() const {
    bool has_callback = bool(callback);
    return ssprintf(
            "GradSlot{grad=%s, has_callback=%d}", m_grad.to_string().c_str(),
            (int)has_callback);
}

std::string GradFn::to_string() const {
    return ssprintf("GradFn{dests=%s}", imperative::to_string(m_dests).c_str());
}

std::string GradSlotPtr::to_string() const {
    if (!m_fn) {
        return "<empty>";
    }
    return (*this)->to_string();
}

std::string GradValue::to_string() const {
    return ssprintf(
            "GradValue{key=\"%s\", slot=%s, value=%s}", m_key->name().c_str(),
            m_slot.to_string().c_str(), m_value.to_string().c_str());
}

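// Global registry mapping an op's Typeinfo to its custom backward rule.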
static std::unordered_map<Typeinfo*, CustomBackward::BackwardRule>&
get_backward_rule_storage() {
    static std::unordered_map<Typeinfo*, CustomBackward::BackwardRule> sl_storage;
    return sl_storage;
}

bool CustomBackward::register_grad_rule(Typeinfo* typeinfo, BackwardRule rule) {
    return get_backward_rule_storage().insert({typeinfo, rule}).second;
}

auto CustomBackward::lookup_grad_rule(Typeinfo* typeinfo) -> BackwardRule {
    auto iter = get_backward_rule_storage().find(typeinfo);
    if (iter == get_backward_rule_storage().end()) {
        return {};
    }
    return iter->second;
}

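// Walk the frozen tape in reverse. For each grad_fn, feed the grads of its
// slots to its backward, accumulate the results into the destination slots
// (summing with Elemwise ADD on conflict), and fire a slot's callback once
// its last producer has run.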
void GradKey::backward() {
    mgb_assert(m_frozen);
    auto& tape = m_frozen_tape;
    for (std::ptrdiff_t k = tape.size() - 1; k >= 0; --k) {
        auto& [grad_fn, op] = tape[k];
        auto grad_receiver = [&, grad_fn = grad_fn](size_t i, ValueRef grad) {
            auto& dest = grad_fn->m_dests[i];
            if (dest) {
                auto& existing_grad = dest->m_grad;
                if (!existing_grad) {
                    existing_grad = grad;
                } else {
                    existing_grad = imperative::apply(
                            ApplyOp(*Elemwise::make(Elemwise::Mode::ADD)),
                            existing_grad, grad)[0];
                }
            }
        };
        // clang-format off
        std::visit([&, grad_fn = grad_fn, op = op](auto&& backward) {
            using T = std::decay_t<decltype(backward)>;
            if constexpr (std::is_same_v<T, std::monostate>) {
                mgb_throw(AssertionError, "invalid backward");
            } else {
                mgb_assert(grad_fn->m_slots.size() > 0);
                SmallVector<ValueRef> grads(grad_fn->m_slots.size());
                auto iter = grads.begin();
                for (auto&& slot : grad_fn->m_slots) {
                    *iter++ = slot.m_grad;
                }
                backward(grads, grad_receiver);
            }
        }, grad_fn->m_backward);
        // clang-format on
        for (auto&& dest : grad_fn->m_dests) {
            if (!dest) {
                continue;
            }
            if (!dest.m_producer_record.next && dest->callback) {
                // this slot's last grad producer has run; invoke the callback
                if (dest->m_grad) {
                    dest->callback(dest->m_grad);
                }
            }
        }
        grad_fn->clear();
    }
    tape.clear();
}

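// Attach `callback` to `tensor` under this key via a fresh single-slot
// grad_fn. If the tensor is already a GradValue, an IdentityBackward entry
// is put on the tape so that the existing slot keeps receiving grads.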
GradValue::ref_t GradKey::attach(
        ValueRef tensor, std::function<void(ValueRef)> callback) {
    // always create a new grad value
    GradSlotPtr grad_slot;
    auto& grad_fn = grad_slot.m_fn;
    grad_fn = LocalPtr<GradFn>::make();
    grad_fn->m_key = shared_from_this();
    grad_fn->m_slots.resize(1);
    grad_fn->m_slots[0].callback = callback;
    grad_slot.m_index = 0;
    if (auto&& grad_value = tensor.as_ref(m_value_type)) {
        grad_fn->m_backward.emplace<IdentityBackward>();
        grad_fn->m_dests.push_back(grad_value->m_slot);
        tensor = grad_value->m_value;
        m_tape.emplace_back(grad_fn, nullptr);
    }
    return m_value_type.make(tensor, shared_from_this(), grad_slot);
}

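// Drop tape entries whose grad_fn has already expired and lock the tape;
// backward() requires a frozen key.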
void GradKey::freeze() {
    mgb_assert(m_frozen_tape.empty() && !m_frozen);
    for (auto&& [grad_fn, op] : m_tape) {
        if (auto valid_grad_fn = grad_fn.lock()) {
            m_frozen_tape.push_back({valid_grad_fn, op});
        }
    }
    m_tape.clear();
    m_frozen = true;
}

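// Central dispatch of the grad transformation. ApplyOp records a grad_fn on
// the tape, preferring a registered custom rule and falling back to an
// optimized backward graph; the other operators implement attaching,
// querying and detaching grads. Anything unhandled falls back to applying
// the op on unwrapped inputs.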
ValueRefList GradTransformation::apply_transformation(
        const Operator& op, Span<ValueRef> inputs) {
    auto fallback = [&] {
        SmallVector<ValueRef> unwrapped_inputs(inputs.size());
        {
            // unwrap GradValue inputs (minor overhead)
            for (size_t i = 0; i < inputs.size(); ++i) {
                if (auto&& grad_value = as_grad_value(inputs[i])) {
                    unwrapped_inputs[i] = grad_value->m_value;
                } else {
                    unwrapped_inputs[i] = inputs[i];
                }
            }
        }
        return imperative::apply(op, unwrapped_inputs);
    };
    if (op.is<GetAttr>()) {
        // minor overhead: unwrap the single GradValue input
        if (auto&& grad_value = as_grad_value(inputs.item())) {
            return imperative::apply(op, grad_value->m_value);
        } else {
            return imperative::apply(op, inputs);
        }
    }
    if (m_suppressed) {
        return fallback();
    }
    if (auto* op_val = op.as<ApplyOp>()) {
        size_t nr_require_grad = 0;
        SmallVector<bool> require_grads(inputs.size());
        for (size_t i = 0; i < inputs.size(); ++i) {
            if (is_grad_value(inputs[i])) {
                nr_require_grad++;
                require_grads[i] = true;
            } else {
                require_grads[i] = false;
            }
        }
        if (nr_require_grad == 0) {
            return imperative::apply(op, inputs);
        }
        SmallVector<ValueRef> captured_inputs(inputs.size());
        SmallVector<bool> inputs_require_grad(inputs.size());
        // capture values so that trace can assume the inputs stay the same
        auto capture_value = [](const ValueRef& value) {
            // TODO: fastpath copy shouldn't be an OpDef
            static auto fastpath_copy = FastpathCopy::make();
            return imperative::apply(ApplyOp(*fastpath_copy), value)[0];
        };
        for (size_t i = 0; i < inputs.size(); ++i) {
            auto& input = inputs[i];
            if (auto&& grad_value = as_grad_value(input)) {
                captured_inputs[i] = capture_value(grad_value->m_value);
                inputs_require_grad[i] = true;
            } else {
                captured_inputs[i] = capture_value(input);
                inputs_require_grad[i] = false;
            }
        }
        // copying grad_fn->m_backward is expensive
        auto grad_fn = LocalPtr<GradFn>::make();
        auto& backward_storage = grad_fn->m_backward;
        auto outputs = [&] {
            auto backward_rule =
                    CustomBackward::lookup_grad_rule(op_val->op().dyn_typeinfo());
            if (backward_rule) {
                CustomBackward backward;
                auto optional_outputs = backward_rule(
                        op_val->op(), captured_inputs, inputs_require_grad, backward);
                if (optional_outputs) {
                    backward_storage = backward;
                    // backward by rule
                    return *optional_outputs;
                }
            }
            auto outputs = imperative::apply(op, captured_inputs);
            auto backward_graph = make_optimized_backward_graph(
                    op_val->op(), captured_inputs, outputs, inputs_require_grad);
            if (backward_graph) {
                backward_storage = BackwardGraphWithClosure(
                        backward_graph, op_val->op().shared_from_this(),
                        {captured_inputs.begin(), captured_inputs.end()},
                        {outputs.data(), outputs.size()});
                // backward by make_backward_graph
                return outputs;
            } else {
                // no backward
                return outputs;
            }
        }();
        if (std::holds_alternative<std::monostate>(backward_storage)) {
            return outputs;
        }
        grad_fn->m_key = m_key;
        grad_fn->m_slots.resize(outputs.size());
        mgb_assert(!outputs.empty());
        grad_fn->m_dests.reserve(inputs.size());
        // clang-format off
        auto visitor = [&](auto& backward) {
            using T = std::decay_t<decltype(backward)>;
            if constexpr (std::is_same_v<T, std::monostate>) {
                mgb_throw(AssertionError, "invalid backward");
            } else {
                // little overhead
                for (size_t i = 0; i < inputs.size(); ++i) {
                    if (backward.input_has_grad(i) && require_grads[i]) {
                        auto& input_grad_slot =
                                inputs[i].cast(m_value_type).slot();
                        grad_fn->m_dests.emplace_back(input_grad_slot);
                        grad_fn->m_dests.back().m_producer_record.insert_after(
                                input_grad_slot->m_producer_head);
                    } else {
                        grad_fn->m_dests.emplace_back();
                    }
                }
                for (size_t i = 0; i < outputs.size(); ++i) {
                    if (backward.output_requires_grad(i)) {
                        // little overhead: Value::make
                        auto grad_value = m_value_type.make(outputs[i], m_key, GradSlotPtr{grad_fn, i});
                        outputs[i] = record_grad(grad_value);
                    }
                }
            }
        };
        // std::visit may be slightly slower than a direct branch
        std::visit(visitor, backward_storage);
        // clang-format on
        mgb_assert(!grad_fn->m_slots.empty());
        m_key->m_tape.push_back({grad_fn, op_val->op().shared_from_this()});
        return outputs;
    } else if (op.is<CreateTensor>()) {
        return imperative::apply(op, inputs);
    } else if (auto* attach_grad = op.as<AttachGrad>()) {
        if (!has_key(attach_grad->key())) {
            return fallback();
        } else {
            GenericFunction callback =
                    (GenericFunction&)inputs[1].cast<FunctionValue>();
            auto output =
                    attach_grad->key()->attach(inputs[0], [callback](ValueRef grad) {
                        auto ret = callback({&grad, 1});
                        mgb_assert(ret.empty());
                    });
            return {record_grad(output)};
        }
    } else if (auto* grad_backward = op.as<GradBackward>()) {
        if (!has_key(grad_backward->key())) {
            return fallback();
        }
        size_t nr_grads = inputs.size() / 2;
        mgb_assert(nr_grads * 2 == inputs.size());
        auto values = inputs.sub(0, nr_grads);
        auto grads = inputs.sub(nr_grads, nr_grads);
        make_backward_closure(values)(grads);
        return {};
    } else if (auto* is_attached_to = op.as<IsAttachedTo>()) {
        if (has_key(is_attached_to->key())) {
            if (auto&& grad_value = as_grad_value(inputs[0])) {
                // TODO: assert grad_fn
                return {BoolValue::make(true)};
            }
        }
        return {BoolValue::make(false)};
    } else if (auto* set_grad = op.as<SetGrad>()) {
        // TODO: merge SetGrad and ApplyOp
        auto grad_fn = LocalPtr<GradFn>::make();
        auto& backward =
                std::get<CustomBackward>(grad_fn->m_backward = CustomBackward());
        size_t nr_inputs = set_grad->nr_inputs();
        mgb_assert(inputs.size() > nr_inputs);
        size_t nr_outputs = inputs.size() - nr_inputs;
        Span<ValueRef> inputs_ = {inputs.data(), nr_inputs};
        auto outputs_ = fallback();
        backward.m_input_has_grad.resize(nr_inputs, true);
        backward.m_output_attrs.resize(
                nr_outputs, CustomBackward::OutputAttr{true, true});
        backward.m_backward = [fn = set_grad->grad_fn()](Span<ValueRef> inputs) {
            auto result = fn(inputs);
            return SmallVector<ValueRef>(result.begin(), result.end());
        };
        ValueRefList outputs(nr_outputs);
        grad_fn->m_key = m_key;
        grad_fn->m_slots.resize(nr_outputs);
        grad_fn->m_dests.reserve(nr_inputs);
        for (size_t i = 0; i < nr_inputs; ++i) {
            if (auto&& grad_value = as_grad_value(inputs_[i])) {
                auto& input_grad_slot = grad_value->m_slot;
                grad_fn->m_dests.emplace_back(grad_value->m_slot);
                grad_fn->m_dests.back().m_producer_record.insert_after(
                        input_grad_slot->m_producer_head);
            } else {
                grad_fn->m_dests.emplace_back();
            }
        }
        for (size_t i = 0; i < nr_outputs; ++i) {
            auto& output = outputs_[i];
            auto grad_value = as_grad_value(output);
            if (grad_value) {
                grad_value = m_value_type.make(
                        grad_value->m_value, m_key, GradSlotPtr(grad_fn, i));
            } else {
                grad_value = m_value_type.make(output, m_key, GradSlotPtr(grad_fn, i));
            }
            outputs[i] = record_grad(grad_value);
        }
        m_key->m_tape.push_back({grad_fn, nullptr});
        return outputs;
    } else if (auto* gbc = op.as<GetBackwardColsure>()) {
        if (gbc->key() != m_key) {
            return fallback();
        }
        return {FunctionValue::make(make_backward_closure(inputs))};
    } else if (op.is<DetachGrad>()) {
        if (auto&& grad_value = as_grad_value(inputs[0])) {
            return {grad_value->m_value};
        } else {
            return {inputs[0]};
        }
    } else if (op.is<GetGradKey>()) {
        for (auto&& input : inputs) {
            if (auto&& grad_value = as_grad_value(input)) {
                return {GradKeyValue::make(grad_value->m_key)};
            }
        }
        return imperative::apply(op, inputs);
    } else if (op.kind() == Operator::IdentityLike) {
        mgb_assert(inputs.size() == 1);
        if (auto&& grad_value = as_grad_value(inputs[0])) {
            auto output = imperative::apply(op, grad_value->m_value)[0];
            auto grad_output = m_value_type.make(output, m_key, grad_value->slot());
            return {record_grad(grad_output)};
        } else {
            return imperative::apply(op, inputs);
        }
    } else {
        return fallback();
    }
}

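// Collect the grad slots of `ys` and return a closure that seeds them with
// the incoming dys and runs backward() on the key. The key is frozen and
// this transformation cleaned up, so no further values are tracked.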
GenericFunction GradTransformation::make_backward_closure(Span<ValueRef> ys) {
    // keep the key alive locally; cleanup() below clears m_key
    auto grad_key = m_key;
    std::vector<GradSlotPtr> y_slots;
    for (auto&& y : ys) {
        if (auto&& grad_value = as_grad_value(y)) {
            y_slots.push_back(grad_value->slot());
        } else {
            y_slots.emplace_back();
        }
    }
    GenericFunction closure = [grad_key, y_slots](Span<ValueRef> dys) -> ValueRefList {
        size_t nr_grads = y_slots.size();
        mgb_assert(dys.size() == nr_grads);
        for (size_t i = 0; i < nr_grads; ++i) {
            if (y_slots[i]) {
                y_slots[i]->m_grad = dys[i];
            }
        }
        grad_key->backward();
        return {};
    };
    grad_key->freeze();
    cleanup();
    return closure;
}

void GradTransformation::on_unregister() noexcept {
    cleanup();
}

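// Detach every still-live GradValue recorded by this transformation by
// replacing it with its wrapped value, then drop the key.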
void GradTransformation::cleanup() {
    for (auto&& weak_value : m_weak_values) {
        auto grad_value = weak_value.lock();
        if (grad_value) {
            mgb_assert(grad_value->m_key == m_key);
            grad_value.reset(grad_value->m_value);
        }
    }
    m_weak_values.clear();
    m_key = {};
}

void GradTransformation::suppress() {
    m_suppressed++;
}

void GradTransformation::resume() {
    m_suppressed--;
}

}  // namespace imperative
}  // namespace mgb