From 9ce1f0f5d1170d81082695ab08d0466e5db706ce Mon Sep 17 00:00:00 2001 From: Megvii Engine Team Date: Fri, 14 Jan 2022 13:23:37 +0800 Subject: [PATCH] refactor(dispatch): implement grad GitOrigin-RevId: d8367f9587093919c4dcb40361c7f91a9589f6c7 --- imperative/src/impl/transformations/grad.cpp | 543 ++++++++++++++++++ .../imperative/transformations/grad.h | 411 +++++++++++++ 2 files changed, 954 insertions(+) create mode 100644 imperative/src/impl/transformations/grad.cpp create mode 100644 imperative/src/include/megbrain/imperative/transformations/grad.h diff --git a/imperative/src/impl/transformations/grad.cpp b/imperative/src/impl/transformations/grad.cpp new file mode 100644 index 000000000..5a3023ae8 --- /dev/null +++ b/imperative/src/impl/transformations/grad.cpp @@ -0,0 +1,543 @@ +/** + * \file imperative/src/impl/transformations/grad.cpp + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ */
+
+#include "megbrain/imperative/transformations/grad.h"
+
+#include "megbrain/imperative/graph_cache.h"
+
+#include
+
+namespace mgb {
+namespace imperative {
+// Returns the backward graph for `op`, memoized per thread; null when the built graph is empty.
+static std::shared_ptr make_optimized_backward_graph(
+        std::shared_ptr op, Span inputs, Span outputs,
+        Span inputs_require_grad) {
+    // hash: cache key = op + per-input dtype/comp_node + the require-grad mask
+    using OptimizedBackwardGraphCache = OpMethResultCache<
+            std::shared_ptr, SmallVector>;
+    thread_local auto cache = std::make_unique();
+    OptimizedBackwardGraphCache::key_t cache_key{op};
+    SmallVector& input_descs = cache_key.inputs;
+    std::get<0>(cache_key.extras) = inputs_require_grad.copy_into>();
+    input_descs.resize(inputs.size());
+    for (size_t i = 0; i < inputs.size(); ++i) {
+        input_descs[i].layout.dtype = inputs[i].dtype().cast();
+        input_descs[i].comp_node = inputs[i].device().cast();
+    }
+    // fast path: reuse an earlier result (a cached null is returned as-is)
+    auto iter = cache->find(cache_key);
+    if (iter != cache->end()) {
+        return iter->second;
+    }
+
+    // slow path: build the graph treating every output as having a grad, then cache it
+    SmallVector output_has_grad(outputs.size(), true);
+    std::shared_ptr ret;
+    auto bg = OpDef::make_backward_graph(
+            *op, input_descs, std::get<0>(cache_key.extras), output_has_grad);
+    if (!bg.graph.empty()) {
+        ret = std::make_shared(bg);
+    }
+    cache->emplace(cache_key, ret);  // ret stays null for an empty graph and is cached too
+    return ret;
+}
+// Fills `closure` with precomp results followed by the inputs/outputs flagged in save_for_backward.
+BackwardGraphWithClosure::BackwardGraphWithClosure(
+        std::shared_ptr backward_graph,
+        std::shared_ptr op, Span inputs, Span outputs)
+        : backward_graph(backward_graph),
+          output_mask_offset(inputs.size()),
+          grad_mask_offset(inputs.size() + outputs.size()) {
+    auto& save_for_backward = backward_graph->save_for_backward;
+    mgb_assert(save_for_backward.size() == inputs.size() + 2 * outputs.size());  // mask = [inputs | outputs | output grads]
+    size_t count = std::count_if(  // counts every set mask entry (grad section included), so it is an upper bound
+            save_for_backward.begin(), save_for_backward.end(), ranges::identity{});
+    if (!backward_graph->precomp.empty()) {
+        SmallVector inputs_and_outputs;
+        for (auto&& input : inputs) {
+            inputs_and_outputs.push_back(input);
+        }
+        for (auto&& output : outputs) {
+            inputs_and_outputs.push_back(output);
+        }
+        auto precomp = imperative::apply(backward_graph->precomp, inputs_and_outputs);
+        closure.reserve(precomp.size() + count);
+        std::copy(precomp.begin(), precomp.end(), std::back_inserter(closure));
+    } else {
+        closure.reserve(count);
+    }
+    for (size_t i = 0; i < inputs.size(); ++i) {
+        if (save_for_backward[i]) {
+            closure.push_back(inputs[i]);
+        }
+    }
+    for (size_t i = 0; i < outputs.size(); ++i) {
+        if (save_for_backward[inputs.size() + i]) {
+            closure.push_back(outputs[i]);
+        }
+    }
+}
+void BackwardGraphWithClosure::operator()(  // runs the backward graph on closure + grads, reporting input grads via receiver
+        std::vector grads, std::function receiver) {
+    ValueRef args[closure.size() + grads.size()];  // NOTE(review): runtime-sized array is a compiler extension, not ISO C++
+    size_t nargs = 0;
+    for (auto&& value : closure) {
+        args[nargs++] = value;
+    }
+    bool null_grad = false;
+    for (size_t i = 0; i < grads.size(); ++i) {
+        if (backward_graph->save_for_backward[grad_mask_offset + i]) {
+            if (grads[i]) {
+                mgb_assert(!null_grad, "null_grad");  // a present grad after a missing one is rejected
+                args[nargs++] = grads[i];
+            } else {
+                null_grad = true;
+            }
+        }
+    }
+    if (null_grad) {  // a required output grad is absent: skip backward entirely
+        return;
+    }
+    auto igrads = imperative::apply(backward_graph->backward, Span(args, nargs));
+    auto&& iter = igrads.begin();
+    for (auto [i, p] : ranges::views::enumerate(backward_graph->input_has_grad)) {
+        if (p) {  // igrads is consumed only for inputs whose input_has_grad flag is set
+            receiver(i, std::move(*iter));
+            ++iter;
+        }
+    }
+}
+// Calls m_backward on the output grads and forwards each non-empty result to receiver(i, grad).
+void CustomBackward::operator()(
+        std::vector grads, std::function receiver) {
+    size_t nargs = grads.size();
+    ValueRef args[nargs];  // NOTE(review): runtime-sized array is a compiler extension, not ISO C++
+    for (size_t i = 0; i < nargs; ++i) {
+        args[i] = grads[i];
+    }
+    auto ret = m_backward({args, nargs});
+    for (size_t i = 0; i < ret.size(); ++i) {
+        if (auto&& t = ret[i]) {
+            receiver(i, std::move(t));
+        }
+    }
+}
+// Debug string: the slot's grad and whether a callback is set.
+std::string GradSlot::to_string() const {
+    bool has_callback = bool(callback);
+    return ssprintf(
+            "GradSlot{grad=%s, has_callback=%d}", m_grad.to_string().c_str(),
+            (int)has_callback);
+}
+// Debug string listing the destination slots.
+std::string GradFn::to_string() const {
+    return ssprintf("GradFn{dests=%s}", imperative::to_string(m_dests).c_str());
+}
+// Debug string of the pointed-to slot; when m_fn is null the literal below is returned instead.
+std::string GradSlotPtr::to_string() const {
+    if (!m_fn) {
+        return "";
+    }
+    return (*this)->to_string();
+}
+// Debug string: key name, slot and wrapped value.
+std::string GradValue::to_string() const {
+    return ssprintf(
+            "GradValue{key=\"%s\", slot=%s, value=%s}", m_key->name().c_str(),
+            m_slot.to_string().c_str(), m_value.to_string().c_str());
+}
+// Registry of custom backward rules keyed by Typeinfo*, held in a function-local static.
+static std::unordered_map&
+get_backward_rule_storage() {
+    static std::unordered_map sl_storage;
+    return sl_storage;
+}
+// Registers `rule` for `typeinfo`; returns the `.second` of insert(), i.e. false when a rule already exists.
+bool CustomBackward::register_grad_rule(Typeinfo* typeinfo, BackwardRule rule) {
+    return get_backward_rule_storage().insert({typeinfo, rule}).second;
+}
+// Returns the rule registered for `typeinfo`, or a default-constructed BackwardRule when none exists.
+auto CustomBackward::lookup_grad_rule(Typeinfo* typeinfo) -> BackwardRule {
+    auto iter = get_backward_rule_storage().find(typeinfo);
+    if (iter == get_backward_rule_storage().end()) {
+        return {};
+    }
+    return iter->second;
+}
+// Walks the frozen tape from last entry to first, accumulating grads into dest slots and firing callbacks.
+void GradKey::backward() {
+    mgb_assert(m_frozen);  // freeze() must have been called first
+    auto& tape = m_frozen_tape;
+    for (std::ptrdiff_t k = tape.size() - 1; k >= 0; --k) {
+        auto& [grad_fn, op] = tape[k];
+        auto grad_receiver = [&, grad_fn = grad_fn](size_t i, ValueRef grad) {  // stores the first grad, sums later ones with Elemwise ADD
+            auto& dest = grad_fn->m_dests[i];
+            if (dest) {
+                auto& existing_grad = dest->m_grad;
+                if (!existing_grad) {
+                    existing_grad = grad;
+                } else {
+                    existing_grad = imperative::apply(
+                            ApplyOp(*Elemwise::make(Elemwise::Mode::ADD)),
+                            existing_grad, grad)[0];
+                }
+            }
+        };
+        // clang-format off
+        std::visit([&, grad_fn = grad_fn, op = op](auto&& backward) {
+            using T = std::decay_t;
+            if constexpr (std::is_same_v) {
+                mgb_throw(AssertionError, "invalid backward");
+            } else {
+                mgb_assert(grad_fn->m_slots.size() > 0);
+                std::vector grads;
+                for (auto&& slot : grad_fn->m_slots) {
+                    grads.push_back(slot.m_grad);
+                }
+                backward(grads, grad_receiver);
+            }
+        }, grad_fn->m_backward);
+        // clang-format on
+        for (auto&& dest : grad_fn->m_dests) {
+            if (!dest) {
+                continue;
+            }
+            if (!dest.m_producer_record.next && dest->callback && dest->m_grad) {
+                // I'm the last grad producer, invoke callback
+                dest->callback(dest->m_grad);
+            }
+        }
+        grad_fn->clear();  // drop this entry's key, slots, dests and backward once processed
+    }
+    tape.clear();
+}
+
+GradValue::ref_t GradKey::attach(  // wraps `tensor` as a GradValue of this key (if not already) and sets its grad callback
+        ValueRef tensor, std::function callback) {
+    auto grad_value = tensor.as_ref();
+    if (grad_value && grad_value->has_key(shared_from_this())) {  // already attached to this key: a second callback is rejected
+        mgb_assert(
+                !tensor.cast().slot_for(shared_from_this())->callback,
+                "callback exists");
+    } else {  // otherwise wrap the tensor with a fresh one-slot GradFn
+        GradSlotPtr grad_slot;
+        auto& grad_fn = grad_slot.m_fn;
+        grad_fn = std::make_shared();
+        grad_fn->m_key = shared_from_this();
+        grad_fn->m_slots.resize(1);
+        grad_slot.m_index = 0;
+        grad_value = GradValue::make(tensor, shared_from_this(), grad_slot);
+    }
+    grad_value->slot_for(shared_from_this()).m_fn->m_slots[0].callback = callback;  // NOTE(review): writes slot 0, not m_index — confirm for the already-attached case
+    return grad_value;
+}
+// Moves the tape entries that can still be lock()ed into m_frozen_tape and marks the key frozen.
+void GradKey::freeze() {
+    mgb_assert(m_frozen_tape.empty() && !m_frozen);  // may only be called once per key
+    for (auto&& [grad_fn, op] : m_tape) {
+        if (auto valid_grad_fn = grad_fn.lock()) {
+            m_frozen_tape.push_back({valid_grad_fn, op});
+        }
+    }
+    m_tape.clear();
+    m_frozen = true;
+}
+// Dispatches `op`. Operator template arguments are missing in this copy; branch names below are inferred from variable names.
+std::vector GradTransformation::apply_transformation(
+        const Operator& op, Span inputs) {
+    auto unwrap_inputs = [this](Span inputs) -> SmallVector {  // replaces this key's GradValues by their wrapped values
+        SmallVector unwrapped_inputs;
+        for (auto&& input : inputs) {
+            if (auto grad_value = as_grad_value(input)) {
+                unwrapped_inputs.push_back(grad_value->m_value);
+            } else {
+                unwrapped_inputs.push_back(input);
+            }
+        }
+        return unwrapped_inputs;
+    };
+    if (m_suppressed) {  // recording suppressed: just run the op on unwrapped inputs
+        return imperative::apply(op, unwrap_inputs(inputs));
+    }
+    if (auto* op_val = op.as()) {  // presumably ApplyOp: run forward and record a GradFn on the tape
+        size_t nr_require_grad = 0;
+        SmallVector require_grads;
+        for (auto&& input : inputs) {
+            if (is_grad_value(input)) {
+                nr_require_grad++;
+                require_grads.push_back(true);
+            } else {
+                require_grads.push_back(false);
+            }
+        }
+        if (nr_require_grad == 0) {  // no input is tracked by this key: nothing to record
+            return imperative::apply(op, inputs);
+        }
+        SmallVector captured_inputs;
+        SmallVector inputs_require_grad;  // same flags as require_grads, rebuilt in the loop below
+        // capture value so that trace could assume input as same
+        auto capture_value = [](ValueRef value) {
+            // TODO: fastpath copy shouldn't be an OpDef
+            return imperative::apply(ApplyOp(*FastpathCopy::make()), {&value, 1})[0];
+        };
+        for (auto& input : inputs) {
+            if (auto grad_value = as_grad_value(input)) {
+                captured_inputs.push_back(capture_value(grad_value->m_value));
+                inputs_require_grad.push_back(true);
+            } else {
+                captured_inputs.push_back(capture_value(input));
+                inputs_require_grad.push_back(false);
+            }
+        }
+        decltype(std::declval().m_backward) backward_storage;  // later assigned to grad_fn->m_backward
+        auto outputs = [&] {  // runs forward; fills backward_storage when a rule or backward graph exists
+            auto backward_rule =
+                    CustomBackward::lookup_grad_rule(op_val->op().dyn_typeinfo());
+            if (backward_rule) {
+                CustomBackward backward;
+                auto optional_outputs = backward_rule(
+                        op_val->op(), {captured_inputs.data(), captured_inputs.size()},
+                        {inputs_require_grad.data(), inputs_require_grad.size()},
+                        backward);
+                if (optional_outputs) {  // an empty optional falls through to the backward-graph path
+                    backward_storage = backward;
+                    // backward by rule
+                    return *optional_outputs;
+                }
+            }
+            auto outputs = imperative::apply(
+                    op, {captured_inputs.begin(), captured_inputs.end()});
+            auto backward_graph = make_optimized_backward_graph(
+                    op.cast().op().shared_from_this(),
+                    {captured_inputs.begin(), captured_inputs.end()},
+                    {outputs.data(), outputs.size()},
+                    {inputs_require_grad.data(), inputs_require_grad.size()});
+            if (backward_graph) {
+                backward_storage = BackwardGraphWithClosure(
+                        backward_graph, op.cast().op().shared_from_this(),
+                        {captured_inputs.begin(), captured_inputs.end()},
+                        {outputs.data(), outputs.size()});
+                // backward by make_backward_graph
+                return outputs;
+            } else {
+                // no backward
+                return outputs;
+            }
+        }();
+        if (std::holds_alternative(backward_storage)) {  // presumably the empty alternative: no backward, return plain outputs
+            return outputs;
+        }
+        auto grad_fn = std::make_shared();
+        grad_fn->m_key = m_key;
+        grad_fn->m_slots.resize(outputs.size());
+        grad_fn->m_backward = backward_storage;
+        mgb_assert(!outputs.empty());
+        grad_fn->m_dests.reserve(inputs.size());
+        // clang-format off
+        std::visit([&](auto& backward) {
+            using T = std::decay_t;
+            if constexpr (std::is_same_v) {
+                mgb_throw(AssertionError, "invalid backward");
+            } else {
+                for (size_t i = 0; i < inputs.size(); ++i) {
+                    if (backward.input_has_grad(i) && require_grads[i]) {
+                        auto& input_grad_slot =
+                                inputs[i].cast().slot_for(m_key);
+                        grad_fn->m_dests.emplace_back(input_grad_slot);
+                        grad_fn->m_dests.back().m_producer_record.insert_after(  // NOTE(review): the converting ctor in grad.h already receives m_producer_head — confirm this second link is needed
+                                input_grad_slot->m_producer_head);
+                    } else {
+                        grad_fn->m_dests.emplace_back();
+                    }
+                }
+                for (size_t i = 0; i < outputs.size(); ++i) {
+                    if (backward.output_requires_grad(i)) {
+                        auto grad_value = GradValue::make(outputs[i], m_key, GradSlotPtr{grad_fn, i});
+                        outputs[i] = record_grad(grad_value);
+                    }
+                }
+            }
+        }, grad_fn->m_backward);
+        // clang-format on
+        mgb_assert(!grad_fn->m_slots.empty());
+        m_key->m_tape.push_back({grad_fn, op_val->op().shared_from_this()});
+        return outputs;
+    } else if (auto* attach_grad = op.as()) {  // presumably AttachGrad: inputs = [tensor, callback]
+        if (!has_key(attach_grad->key())) {
+            return imperative::apply(op, unwrap_inputs(inputs));
+        }
+        auto tensor = inputs[0];
+        GenericFunction callback = (GenericFunction&)inputs[1].cast();
+        auto output = attach_grad->key()->attach(tensor, [callback](ValueRef grad) {
+            auto ret = callback({&grad, 1});
+            assert(ret.empty());  // NOTE(review): plain assert vanishes under NDEBUG; the rest of this file uses mgb_assert
+        });
+        return {record_grad(output)};
+    } else if (auto* grad_backward = op.as()) {  // presumably GradBackward: inputs = [values..., grads...] in equal halves
+        if (!has_key(grad_backward->key())) {
+            return imperative::apply(op, unwrap_inputs(inputs));
+        }
+        size_t nr_grads = inputs.size() / 2;
+        mgb_assert(nr_grads * 2 == inputs.size());
+        auto values = inputs.sub(0, nr_grads);
+        auto grads = inputs.sub(nr_grads, nr_grads);
+        make_backward_closure(values)(grads);
+        return {};
+    } else if (auto* is_attached_to = op.as()) {  // presumably IsAttachedTo: true only for this key's GradValues
+        if (has_key(is_attached_to->key())) {
+            if (auto grad_value = as_grad_value(inputs[0])) {
+                // TODO: assert grad_fn
+                return {BoolValue::make(true)};
+            }
+        }
+        return {BoolValue::make(false)};
+    } else if (auto* set_grad = op.as()) {  // presumably SetGrad: inputs = [nr_inputs op inputs..., outputs...]
+        // TODO: merge SetGrad and ApplyOp
+        auto grad_fn = std::make_shared();
+        auto& backward =
+                std::get(grad_fn->m_backward = CustomBackward());
+        size_t nr_inputs = set_grad->nr_inputs();
+        mgb_assert(inputs.size() > nr_inputs);  // at least one output must follow the inputs
+        size_t nr_outputs = inputs.size() - nr_inputs;
+        Span inputs_ = {inputs.data(), nr_inputs};
+        Span outputs_ = {inputs.data() + nr_inputs, nr_outputs};
+        backward.m_input_has_grad = SmallVector(nr_inputs, true);
+        backward.m_output_attrs =
+                SmallVector(nr_outputs, CustomBackward::OutputAttr{true, true});
+        backward.m_backward = set_grad->grad_fn();
+        std::vector outputs;
+        grad_fn->m_key = m_key;
+        grad_fn->m_slots.resize(nr_outputs);
+        grad_fn->m_dests.reserve(nr_inputs);
+        for (size_t i = 0; i < nr_inputs; ++i) {
+            if (auto grad_value = as_grad_value(inputs_[i])) {
+                auto& input_grad_slot = grad_value->m_slot;
+                grad_fn->m_dests.emplace_back(grad_value->m_slot);
+                grad_fn->m_dests.back().m_producer_record.insert_after(
+                        input_grad_slot->m_producer_head);
+            } else {
+                grad_fn->m_dests.emplace_back();
+            }
+        }
+        for (size_t i = 0; i < nr_outputs; ++i) {  // every output is (re)wrapped so that it points at the new grad_fn
+            auto& output = outputs_[i];
+            auto grad_value = as_grad_value(output);
+            if (grad_value) {
+                grad_value = GradValue::make(
+                        grad_value->m_value, m_key, GradSlotPtr(grad_fn, i));
+            } else {
+                grad_value = GradValue::make(output, m_key, GradSlotPtr(grad_fn, i));
+            }
+            outputs.push_back(record_grad(grad_value));
+        }
+        m_key->m_tape.push_back({grad_fn, nullptr});  // no op is recorded for this tape entry
+        return outputs;
+    } else if (auto* gbc = op.as()) {  // presumably GetBackwardColsure: hand back the backward closure as a value
+        if (gbc->key() != m_key) {
+            return imperative::apply(op, unwrap_inputs(inputs));
+        }
+        return {FunctionValue::make(make_backward_closure(inputs))};
+    } else if (op.is()) {  // presumably DetachGrad: strip this key's wrapper from inputs[0]
+        if (auto grad_value = as_grad_value(inputs[0])) {
+            return {grad_value->m_value};
+        } else {
+            return {inputs[0]};
+        }
+    } else if (op.is()) {  // presumably GetGradKey: key of the first input tracked by this transformation
+        for (auto&& input : inputs) {
+            if (auto grad_value = as_grad_value(input)) {
+                return {GradKeyValue::make(grad_value->m_key)};
+            }
+        }
+        return imperative::apply(op, inputs);
+    } else if (op.kind() == Operator::IdentityLike) {  // keep the same slot on the op's output
+        mgb_assert(inputs.size() == 1);
+        if (auto grad_value = as_grad_value(inputs[0])) {
+            auto output = imperative::apply(op, grad_value->m_value)[0];
+            auto grad_output = GradValue::make(
+                    output, grad_value->key(), grad_value->slot_for(m_key));
+            return {record_grad(grad_output)};
+        } else {
+            return imperative::apply(op, inputs);
+        }
+    } else if (op.is()) {  // operator type not visible in this copy; inputs pass through with wrappers intact
+        return imperative::apply(op, inputs);
+    } else {  // anything else: same unwrapping as unwrap_inputs above, then apply
+        SmallVector unwrapped_inputs;
+        for (auto&& input : inputs) {
+            if (auto grad_value = as_grad_value(input)) {
+                unwrapped_inputs.push_back(grad_value->m_value);
+            } else {
+                unwrapped_inputs.push_back(input);
+            }
+        }
+        auto outputs = imperative::apply(
+                op, {unwrapped_inputs.data(), unwrapped_inputs.size()});
+        mgb_assert(op.kind() == Operator::GetAttrLike || outputs.empty());  // only GetAttrLike ops may produce outputs on this path
+        return outputs;
+    }
+}
+// Returns a closure that seeds the grads of `ys` and runs backward; freezing and cleanup happen here, not in the closure.
+GenericFunction GradTransformation::make_backward_closure(Span ys) {
+    // reset GradKey
+    auto grad_key = m_key;
+    std::vector y_slots;
+    for (auto&& y : ys) {
+        if (auto grad_value = as_grad_value(y)) {
+            y_slots.push_back(grad_value->slot_for(grad_key));
+        } else {
+            y_slots.emplace_back();  // untracked y: empty slot, its dy is ignored
+        }
+    }
+    GenericFunction closure = [grad_key,
+                               y_slots](Span dys) -> std::vector {
+        size_t nr_grads = y_slots.size();
+        mgb_assert(dys.size() == nr_grads);
+        for (size_t i = 0; i < nr_grads; ++i) {
+            if (y_slots[i]) {
+                y_slots[i]->m_grad = dys[i];
+            }
+        }
+        grad_key->backward();
+        return {};
+    };
+    grad_key->freeze();
+    cleanup();  // unwraps recorded values and clears m_key; the closure holds its own copy of grad_key
+    return closure;
+}
+// NOTE(review): cleanup() contains mgb_assert; if that can throw, it would terminate under noexcept — confirm.
+void GradTransformation::on_unregister() noexcept {
+    cleanup();
+}
+// Resets every still-alive recorded GradValue to its wrapped value, then forgets the values and the key.
+void GradTransformation::cleanup() {
+    for (auto&& weak_value : m_weak_values) {
+        auto grad_value = weak_value.lock();
+        if (grad_value) {
+            mgb_assert(grad_value->m_key == m_key);
+            grad_value.reset(grad_value->m_value);  // presumably rebinds the shared reference to the plain value — confirm reset() semantics
+        }
+    }
+    m_weak_values.clear();
+    m_key = {};
+}
+// Increments the counter that makes apply_transformation bypass grad recording.
+void GradTransformation::suppress() {
+    m_suppressed++;
+}
+// Decrements the counter. NOTE(review): m_suppressed is a size_t, so an unmatched resume() wraps around.
+void GradTransformation::resume() {
+    m_suppressed--;
+}
+
+}  // namespace imperative
+}  // namespace mgb
diff --git a/imperative/src/include/megbrain/imperative/transformations/grad.h 
b/imperative/src/include/megbrain/imperative/transformations/grad.h
new file mode 100644
index 000000000..d5f2e399f
--- /dev/null
+++ b/imperative/src/include/megbrain/imperative/transformations/grad.h
@@ -0,0 +1,411 @@
+/**
+ * \file imperative/src/include/megbrain/imperative/transformations/grad.h
+ * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
+ *
+ * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ */
+
+#pragma once
+
+#include
+
+#include "megbrain/imperative/backward_graph_opt.h"
+#include "megbrain/imperative/dispatch.h"
+#include "megbrain/imperative/interpreter.h"
+#include "megbrain/imperative/opr_utility.h"
+#include "megbrain/imperative/ops/autogen.h"
+#include "megbrain/imperative/utils/helper.h"
+#include "megbrain/imperative/utils/intrusive_list.h"
+#include "megbrain/imperative/utils/to_string.h"
+
+namespace mgb::imperative {
+// Backward graph bundled with the forward-time values it needs as inputs.
+struct BackwardGraphWithClosure {
+    std::shared_ptr backward_graph;
+    SmallVector closure;  // precomp results followed by the saved inputs and outputs
+    size_t output_mask_offset;  // index of the first output entry in save_for_backward
+    size_t grad_mask_offset;  // index of the first output-grad entry in save_for_backward
+
+    BackwardGraphWithClosure(
+            std::shared_ptr backward_graph,
+            std::shared_ptr op, Span inputs, Span outputs);
+    // Runs the backward graph; `receiver` is called with (input index, grad) for each produced grad.
+    void operator()(
+            std::vector grads,
+            std::function receiver);
+
+    bool input_has_grad(size_t i) { return backward_graph->input_has_grad[i]; }
+
+    bool output_requires_grad(size_t i) {
+        return backward_graph->save_for_backward[grad_mask_offset + i];
+    }
+    bool output_captured(size_t i) {
+        return backward_graph->save_for_backward[output_mask_offset + i];
+    }
+};
+
+struct CustomBackward;
+
+using GradRuleFn =
+        std::function(Span inputs, CustomBackward&)>;
+// Backward implemented by a stored function (m_backward) instead of a backward graph.
+struct CustomBackward {
+    using BackwardFn = std::function(Span)>;
+    using BackwardRule = std::function>(
+            const OpDef&, Span, Span, CustomBackward&)>;
+    BackwardFn m_backward;
+    SmallVector m_input_has_grad;
+    struct OutputAttr {
+        bool requires_grad = true, captured = true;
+    };
+    SmallVector m_output_attrs;
+
+public:
+    void operator()(  // calls m_backward and forwards each non-empty result to receiver
+            std::vector grads,
+            std::function receiver);
+
+    bool input_has_grad(size_t i) { return m_input_has_grad[i]; }
+    bool output_requires_grad(size_t i) { return m_output_attrs[i].requires_grad; }
+    bool output_captured(size_t i) { return m_output_attrs[i].captured; }
+    // Per-Typeinfo rule registry; register returns false when a rule is already present.
+    static bool register_grad_rule(Typeinfo* typeinfo, BackwardRule rule);
+    static BackwardRule lookup_grad_rule(Typeinfo* typeinfo);
+};
+
+class GradSlot;
+class GradSlotPtr;
+class GradSlotProducerPtr;
+class GradFn;
+class GradKey;
+// Intrusive-list node; the head_t constructor links it after the given list head.
+struct GradProducerRecord : utils::intrusive_list::Node {
+    using Node = utils::intrusive_list::Node;
+
+    GradProducerRecord() = default;
+    GradProducerRecord(head_t& head) : Node(utils::intrusive_list::after_t{}, head) {}
+};
+// Gradient storage for one value: accumulated grad, producer list head and optional callback.
+class GradSlot {
+private:
+    ValueRef m_grad;
+    GradProducerRecord::head_t m_producer_head;
+    std::function callback;  // invoked by GradKey::backward once the last producer has run
+
+public:
+    std::string to_string() const;
+
+    friend class GradKey;
+    friend class GradSlotProducerPtr;
+    friend class GradTransformation;
+};
+
+template <>
+struct ToStringTrait {
+    std::string operator()(const GradSlot& value) const { return value.to_string(); }
+};
+// Tape node: output slots, destination (input) slots and the backward to run.
+class GradFn {
+private:
+    std::weak_ptr m_key;
+    std::vector m_slots;  // one slot per output
+    std::vector m_dests;  // one entry per input; empty when that input receives no grad
+    std::variant m_backward;
+
+public:
+    void clear() {  // releases everything held; m_backward is reset via emplace()
+        m_key.reset();
+        m_slots.clear();
+        m_dests.clear();
+        m_backward.emplace();
+    }
+
+    std::string to_string() const;
+
+    friend class GradSlotPtr;
+    friend class GradKey;
+    friend class GradTransformation;
+};
+// Shared handle to slot m_index of a GradFn; converts to false when m_fn is null.
+class GradSlotPtr {
+private:
+    std::shared_ptr m_fn;
+    size_t m_index = 0;
+
+public:
+    GradSlotPtr(std::shared_ptr fn, size_t index) : m_fn(fn), m_index(index) {}
+    GradSlotPtr() = default;
+    GradSlot* operator->() const { return &m_fn->m_slots[m_index]; }  // no null check: m_fn must be set
+
+    operator bool() const { return bool(m_fn); }
+
+    std::string to_string() const;
+
+    friend class GradKey;
+    friend class GradTransformation;
+};
+
+template <>
+struct ToStringTrait {
+    std::string operator()(const GradSlotPtr& value) const { return value.to_string(); }
+};
+// GradSlotPtr that also carries a producer record linked into the target slot's producer list.
+class GradSlotProducerPtr : public GradSlotPtr {
+private:
+    GradProducerRecord m_producer_record;
+    bool dirty = false;
+
+public:
+    GradSlotProducerPtr(const GradSlotPtr& info)  // links the record after info's m_producer_head
+            : GradSlotPtr(info), m_producer_record(info->m_producer_head) {}
+    GradSlotProducerPtr() = default;
+    GradSlotProducerPtr(GradSlotProducerPtr&&) = default;
+    ~GradSlotProducerPtr() { dirty = true; }  // NOTE(review): nothing visible reads `dirty`; possibly debugging residue
+    friend class GradKey;
+    friend class GradTransformation;
+};
+
+template <>
+struct ToStringTrait {
+    std::string operator()(const GradSlotProducerPtr& value) const {
+        return value.to_string();
+    }
+};
+// Value wrapper that tags m_value with the GradKey and slot tracking its gradient.
+class GradValue final : public ValueImpl {
+private:
+    ValueRef m_value;
+    std::shared_ptr m_key;
+    GradSlotPtr m_slot;
+
+public:
+    GradValue(ValueRef value, std::shared_ptr key, GradSlotPtr slot = {})
+            : m_value(value), m_key(key), m_slot(slot) {}
+
+    std::string to_string() const override;
+
+    bool has_key(std::shared_ptr key) const { return m_key == key; }
+    // Returns the slot; asserts that `key` is the key this value belongs to.
+    const GradSlotPtr& slot_for(std::shared_ptr key) const {
+        mgb_assert(m_key == key);
+        return m_slot;
+    }
+
+    std::shared_ptr key() const { return m_key; }
+
+    void clear() override {
+        m_slot = {};
+        m_value = {};
+        m_key = nullptr;
+    }
+    // watch/unwatch are forwarded to the wrapped value
+    void on_watch() override { m_value.watch(); }
+
+    void on_unwatch() override { m_value.unwatch(); }
+
+    friend class GradKey;
+    friend class GradTransformation;
+};
+// Holds the tape of GradFns recorded under one key and runs backward over it.
+class GradKey : public std::enable_shared_from_this {
+private:
+    std::string m_name;
+    std::vector, std::shared_ptr>> m_tape;  // entries are lock()ed in freeze(), i.e. held weakly
+    std::vector, std::shared_ptr>>
+            m_frozen_tape;  // filled by freeze(), consumed by backward()
+    bool m_frozen = false;
+
+public:
+    void backward();
+    GradValue::ref_t attach(ValueRef tensor, std::function callback);
+    const std::string& name() const { return m_name; }
+    void name(std::string name) { m_name = std::move(name); }
+    void freeze();
+
+    friend class GradTransformation;
+};
+// Value wrapper around a GradKey handle; prints the key's name.
+class GradKeyValue final
+        : public MixinValueImpl> {
+public:
+    using MixinValueImpl::MixinValueImpl;
+
+    std::string to_string() const override {
+        return ssprintf("GradKey{%s}", (*this)->name().c_str());
+    }
+};
+// Transformation that records gradient information for values attached to m_key.
+class GradTransformation final : public Transformation {
+private:
+    std::shared_ptr m_key;
+    std::vector m_weak_values;  // values handed out by record_grad, revisited in cleanup()
+    size_t m_suppressed = 0;  // non-zero disables recording in apply_transformation
+
+public:
+    GradTransformation(std::shared_ptr key) : m_key(key) {}
+    // Remembers `tensor` in m_weak_values and returns it unchanged.
+    auto record_grad(GradValue::ref_t tensor) {
+        m_weak_values.push_back(tensor);
+        return tensor;
+    }
+    // True when `value` is a GradValue belonging to this transformation's key.
+    bool is_grad_value(ValueRef value) {
+        if (auto* grad_value = value.as()) {
+            if (grad_value->has_key(m_key)) {
+                return true;
+            }
+        }
+        return false;
+    }
+
+    /**
+     * \brief test whether value is related to this GradTransformation
+     *
+     * there may be multiple grad transformations, so simply using value.is()
+     * is unsafe
+     *
+     * \param value the value to test
+     * \return the GradValue ref when it belongs to m_key, otherwise an empty ref
+     */
+    GradValue::ref_t as_grad_value(ValueRef value) {
+        if (auto grad_value = value.as_ref()) {
+            if (grad_value->has_key(m_key)) {
+                return grad_value;
+            }
+        }
+        return {};
+    }
+    // True when `key` is this transformation's key.
+    bool has_key(std::shared_ptr key) {
+        if (key == m_key) {
+            return true;
+        }
+        return false;
+    }
+
+    std::vector apply_transformation(
+            const Operator& op, Span inputs) override;
+    // Returns the wrapped value for this key's GradValues, any other value unchanged.
+    ValueRef unwrap(ValueRef value) override {
+        if (auto grad_val = as_grad_value(value)) {
+            return grad_val->m_value;
+        }
+        return value;
+    }
+
+    std::string name() const override { return "GradTransformation"; }
+
+    GenericFunction make_backward_closure(Span ys);
+
+    void on_unregister() noexcept override;
+
+    void cleanup();
+    void suppress();
+    void resume();
+};
+// Operator whose fallback returns its single input unchanged.
+class DetachGrad : public OperatorImpl {
+private:
+    // TODO: identified by GradKey
+public:
+    std::string to_string() const override { return "DetachValue"; }
+
+    std::vector fallback(Span inputs) const override {
+        return {inputs.as_array<1>()[0]};
+    }
+};
+
+class AttachGrad : public OperatorImpl {  // operator carrying the GradKey a tensor is to be attached to
+private:
+    std::shared_ptr m_key;
+
+public:
+    AttachGrad(std::shared_ptr key) : m_key(key) {}
+    std::shared_ptr key() { return m_key; }
+
+    std::string to_string() const override {
+        return ssprintf("AttachGradValue{key=%s}", m_key->name().c_str());
+    }
+};
+// Operator carrying the GradKey on which backward is requested.
+class GradBackward : public OperatorImpl {
+private:
+    std::shared_ptr m_key;
+
+public:
+    GradBackward(std::shared_ptr key) : m_key(key) {}
+
+    std::shared_ptr key() { return m_key; }
+
+    std::string to_string() const override {
+        return ssprintf("GradBackwardValue{key=%s}", m_key->name().c_str());
+    }
+};
+// Query operator for a given key; its fallback answers false.
+class IsAttachedTo : public OperatorImpl {
+private:
+    std::shared_ptr m_key;
+
+public:
+    IsAttachedTo(std::shared_ptr key) : m_key(key) {}
+    std::shared_ptr key() { return m_key; }
+
+    std::string to_string() const override {
+        return ssprintf("IsAttachedToValue{key=%s}", m_key->name().c_str());
+    }
+
+    std::vector fallback(Span inputs) const override {
+        return {BoolValue::make(false)};
+    }
+};
+// Operator carrying a user grad function and the number of leading inputs that are op inputs.
+class SetGrad : public OperatorImpl {
+private:
+    std::shared_ptr m_key;  // only used by to_string(); no key() accessor is exposed
+    GenericFunction m_grad_fn;
+    size_t m_nr_inputs;
+
+public:
+    SetGrad(std::shared_ptr key, GenericFunction grad_fn, size_t nr_inputs)
+            : m_key(key), m_grad_fn(grad_fn), m_nr_inputs(nr_inputs) {}
+
+    GenericFunction grad_fn() { return m_grad_fn; }
+
+    size_t nr_inputs() { return m_nr_inputs; }
+
+    std::string to_string() const override {
+        return ssprintf("SetGradValue{key=%s}", m_key->name().c_str());
+    }
+};
+// Key-less query operator; its fallback returns a single empty ValueRef.
+class GetGradKey : public OperatorImpl {
+public:
+    GetGradKey() = default;
+
+    std::string to_string() const override { return ssprintf("GetGradKeyValue{}"); }
+
+    std::vector fallback(Span inputs) const override {
+        return {ValueRef()};
+    }
+};
+// NOTE(review): class name misspells "Closure"; to_string() below uses the correct spelling.
+class GetBackwardColsure
+        : public OperatorImpl {
+private:
+    std::shared_ptr m_key;
+
+public:
+    GetBackwardColsure(std::shared_ptr key) : m_key(key) {}
+
+    std::shared_ptr key() { return m_key; }
+
+    std::string to_string() const override {
+        return ssprintf("GetBackwardClosure{key=%s}", m_key->name().c_str());
+    }
+};
+
+}  // namespace mgb::imperative
-- GitLab