提交 9ce1f0f5 编写于 作者: M Megvii Engine Team

refactor(dispatch): implement grad

GitOrigin-RevId: d8367f9587093919c4dcb40361c7f91a9589f6c7
上级 c609c031
/**
* \file imperative/src/impl/transformations/grad.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
#include "megbrain/imperative/transformations/grad.h"
#include "megbrain/imperative/graph_cache.h"
#include <range/v3/all.hpp>
namespace mgb {
namespace imperative {
static std::shared_ptr<OptimizedBackwardGraphResult> make_optimized_backward_graph(
std::shared_ptr<OpDef> op, Span<ValueRef> inputs, Span<ValueRef> outputs,
Span<bool> inputs_require_grad) {
// hash
using OptimizedBackwardGraphCache = OpMethResultCache<
std::shared_ptr<OptimizedBackwardGraphResult>, SmallVector<bool>>;
thread_local auto cache = std::make_unique<OptimizedBackwardGraphCache>();
OptimizedBackwardGraphCache::key_t cache_key{op};
SmallVector<LogicalTensorDesc>& input_descs = cache_key.inputs;
std::get<0>(cache_key.extras) = inputs_require_grad.copy_into<SmallVector<bool>>();
input_descs.resize(inputs.size());
for (size_t i = 0; i < inputs.size(); ++i) {
input_descs[i].layout.dtype = inputs[i].dtype().cast<DTypeValue>();
input_descs[i].comp_node = inputs[i].device().cast<CompNodeValue>();
}
auto iter = cache->find(cache_key);
if (iter != cache->end()) {
return iter->second;
}
// slow path
SmallVector<bool> output_has_grad(outputs.size(), true);
std::shared_ptr<OptimizedBackwardGraphResult> ret;
auto bg = OpDef::make_backward_graph(
*op, input_descs, std::get<0>(cache_key.extras), output_has_grad);
if (!bg.graph.empty()) {
ret = std::make_shared<OptimizedBackwardGraphResult>(bg);
}
cache->emplace(cache_key, ret);
return ret;
}
BackwardGraphWithClosure::BackwardGraphWithClosure(
std::shared_ptr<OptimizedBackwardGraphResult> backward_graph,
std::shared_ptr<OpDef> op, Span<ValueRef> inputs, Span<ValueRef> outputs)
: backward_graph(backward_graph),
output_mask_offset(inputs.size()),
grad_mask_offset(inputs.size() + outputs.size()) {
auto& save_for_backward = backward_graph->save_for_backward;
mgb_assert(save_for_backward.size() == inputs.size() + 2 * outputs.size());
size_t count = std::count_if(
save_for_backward.begin(), save_for_backward.end(), ranges::identity{});
if (!backward_graph->precomp.empty()) {
SmallVector<ValueRef> inputs_and_outputs;
for (auto&& input : inputs) {
inputs_and_outputs.push_back(input);
}
for (auto&& output : outputs) {
inputs_and_outputs.push_back(output);
}
auto precomp = imperative::apply(backward_graph->precomp, inputs_and_outputs);
closure.reserve(precomp.size() + count);
std::copy(precomp.begin(), precomp.end(), std::back_inserter(closure));
} else {
closure.reserve(count);
}
for (size_t i = 0; i < inputs.size(); ++i) {
if (save_for_backward[i]) {
closure.push_back(inputs[i]);
}
}
for (size_t i = 0; i < outputs.size(); ++i) {
if (save_for_backward[inputs.size() + i]) {
closure.push_back(outputs[i]);
}
}
}
void BackwardGraphWithClosure::operator()(
std::vector<ValueRef> grads, std::function<void(size_t, ValueRef)> receiver) {
ValueRef args[closure.size() + grads.size()];
size_t nargs = 0;
for (auto&& value : closure) {
args[nargs++] = value;
}
bool null_grad = false;
for (size_t i = 0; i < grads.size(); ++i) {
if (backward_graph->save_for_backward[grad_mask_offset + i]) {
if (grads[i]) {
mgb_assert(!null_grad, "null_grad");
args[nargs++] = grads[i];
} else {
null_grad = true;
}
}
}
if (null_grad) {
return;
}
auto igrads = imperative::apply(backward_graph->backward, Span(args, nargs));
auto&& iter = igrads.begin();
for (auto [i, p] : ranges::views::enumerate(backward_graph->input_has_grad)) {
if (p) {
receiver(i, std::move(*iter));
++iter;
}
}
}
void CustomBackward::operator()(
std::vector<ValueRef> grads, std::function<void(size_t, ValueRef)> receiver) {
size_t nargs = grads.size();
ValueRef args[nargs];
for (size_t i = 0; i < nargs; ++i) {
args[i] = grads[i];
}
auto ret = m_backward({args, nargs});
for (size_t i = 0; i < ret.size(); ++i) {
if (auto&& t = ret[i]) {
receiver(i, std::move(t));
}
}
}
std::string GradSlot::to_string() const {
bool has_callback = bool(callback);
return ssprintf(
"GradSlot{grad=%s, has_callback=%d}", m_grad.to_string().c_str(),
(int)has_callback);
}
std::string GradFn::to_string() const {
return ssprintf("GradFn{dests=%s}", imperative::to_string(m_dests).c_str());
}
std::string GradSlotPtr::to_string() const {
if (!m_fn) {
return "<empty>";
}
return (*this)->to_string();
}
std::string GradValue::to_string() const {
return ssprintf(
"GradValue{key=\"%s\", slot=%s, value=%s}", m_key->name().c_str(),
m_slot.to_string().c_str(), m_value.to_string().c_str());
}
static std::unordered_map<Typeinfo*, CustomBackward::BackwardRule>&
get_backward_rule_storage() {
static std::unordered_map<Typeinfo*, CustomBackward::BackwardRule> sl_storage;
return sl_storage;
}
bool CustomBackward::register_grad_rule(Typeinfo* typeinfo, BackwardRule rule) {
return get_backward_rule_storage().insert({typeinfo, rule}).second;
}
auto CustomBackward::lookup_grad_rule(Typeinfo* typeinfo) -> BackwardRule {
auto iter = get_backward_rule_storage().find(typeinfo);
if (iter == get_backward_rule_storage().end()) {
return {};
}
return iter->second;
}
void GradKey::backward() {
mgb_assert(m_frozen);
auto& tape = m_frozen_tape;
for (std::ptrdiff_t k = tape.size() - 1; k >= 0; --k) {
auto& [grad_fn, op] = tape[k];
auto grad_receiver = [&, grad_fn = grad_fn](size_t i, ValueRef grad) {
auto& dest = grad_fn->m_dests[i];
if (dest) {
auto& existing_grad = dest->m_grad;
if (!existing_grad) {
existing_grad = grad;
} else {
existing_grad = imperative::apply(
ApplyOp(*Elemwise::make(Elemwise::Mode::ADD)),
existing_grad, grad)[0];
}
}
};
// clang-format off
std::visit([&, grad_fn = grad_fn, op = op](auto&& backward) {
using T = std::decay_t<decltype(backward)>;
if constexpr (std::is_same_v<T, std::monostate>) {
mgb_throw(AssertionError, "invalid backward");
} else {
mgb_assert(grad_fn->m_slots.size() > 0);
std::vector<ValueRef> grads;
for (auto&& slot : grad_fn->m_slots) {
grads.push_back(slot.m_grad);
}
backward(grads, grad_receiver);
}
}, grad_fn->m_backward);
// clang-format on
for (auto&& dest : grad_fn->m_dests) {
if (!dest) {
continue;
}
if (!dest.m_producer_record.next && dest->callback && dest->m_grad) {
// I'm the last grad producer, invoke callback
dest->callback(dest->m_grad);
}
}
grad_fn->clear();
}
tape.clear();
}
GradValue::ref_t GradKey::attach(
ValueRef tensor, std::function<void(ValueRef)> callback) {
auto grad_value = tensor.as_ref<GradValue>();
if (grad_value && grad_value->has_key(shared_from_this())) {
mgb_assert(
!tensor.cast<GradValue>().slot_for(shared_from_this())->callback,
"callback exists");
} else {
GradSlotPtr grad_slot;
auto& grad_fn = grad_slot.m_fn;
grad_fn = std::make_shared<GradFn>();
grad_fn->m_key = shared_from_this();
grad_fn->m_slots.resize(1);
grad_slot.m_index = 0;
grad_value = GradValue::make(tensor, shared_from_this(), grad_slot);
}
grad_value->slot_for(shared_from_this()).m_fn->m_slots[0].callback = callback;
return grad_value;
}
void GradKey::freeze() {
mgb_assert(m_frozen_tape.empty() && !m_frozen);
for (auto&& [grad_fn, op] : m_tape) {
if (auto valid_grad_fn = grad_fn.lock()) {
m_frozen_tape.push_back({valid_grad_fn, op});
}
}
m_tape.clear();
m_frozen = true;
}
std::vector<ValueRef> GradTransformation::apply_transformation(
const Operator& op, Span<ValueRef> inputs) {
auto unwrap_inputs = [this](Span<ValueRef> inputs) -> SmallVector<ValueRef> {
SmallVector<ValueRef> unwrapped_inputs;
for (auto&& input : inputs) {
if (auto grad_value = as_grad_value(input)) {
unwrapped_inputs.push_back(grad_value->m_value);
} else {
unwrapped_inputs.push_back(input);
}
}
return unwrapped_inputs;
};
if (m_suppressed) {
return imperative::apply(op, unwrap_inputs(inputs));
}
if (auto* op_val = op.as<ApplyOp>()) {
size_t nr_require_grad = 0;
SmallVector<bool> require_grads;
for (auto&& input : inputs) {
if (is_grad_value(input)) {
nr_require_grad++;
require_grads.push_back(true);
} else {
require_grads.push_back(false);
}
}
if (nr_require_grad == 0) {
return imperative::apply(op, inputs);
}
SmallVector<ValueRef> captured_inputs;
SmallVector<bool> inputs_require_grad;
// capture value so that trace could assume input as same
auto capture_value = [](ValueRef value) {
// TODO: fastpath copy shouldn't be an OpDef
return imperative::apply(ApplyOp(*FastpathCopy::make()), {&value, 1})[0];
};
for (auto& input : inputs) {
if (auto grad_value = as_grad_value(input)) {
captured_inputs.push_back(capture_value(grad_value->m_value));
inputs_require_grad.push_back(true);
} else {
captured_inputs.push_back(capture_value(input));
inputs_require_grad.push_back(false);
}
}
decltype(std::declval<GradFn>().m_backward) backward_storage;
auto outputs = [&] {
auto backward_rule =
CustomBackward::lookup_grad_rule(op_val->op().dyn_typeinfo());
if (backward_rule) {
CustomBackward backward;
auto optional_outputs = backward_rule(
op_val->op(), {captured_inputs.data(), captured_inputs.size()},
{inputs_require_grad.data(), inputs_require_grad.size()},
backward);
if (optional_outputs) {
backward_storage = backward;
// backward by rule
return *optional_outputs;
}
}
auto outputs = imperative::apply(
op, {captured_inputs.begin(), captured_inputs.end()});
auto backward_graph = make_optimized_backward_graph(
op.cast<ApplyOp>().op().shared_from_this(),
{captured_inputs.begin(), captured_inputs.end()},
{outputs.data(), outputs.size()},
{inputs_require_grad.data(), inputs_require_grad.size()});
if (backward_graph) {
backward_storage = BackwardGraphWithClosure(
backward_graph, op.cast<ApplyOp>().op().shared_from_this(),
{captured_inputs.begin(), captured_inputs.end()},
{outputs.data(), outputs.size()});
// backward by make_backward_graph
return outputs;
} else {
// no backward
return outputs;
}
}();
if (std::holds_alternative<std::monostate>(backward_storage)) {
return outputs;
}
auto grad_fn = std::make_shared<GradFn>();
grad_fn->m_key = m_key;
grad_fn->m_slots.resize(outputs.size());
grad_fn->m_backward = backward_storage;
mgb_assert(!outputs.empty());
grad_fn->m_dests.reserve(inputs.size());
// clang-format off
std::visit([&](auto& backward) {
using T = std::decay_t<decltype(backward)>;
if constexpr (std::is_same_v<T, std::monostate>) {
mgb_throw(AssertionError, "invalid backward");
} else {
for (size_t i = 0; i < inputs.size(); ++i) {
if (backward.input_has_grad(i) && require_grads[i]) {
auto& input_grad_slot =
inputs[i].cast<GradValue>().slot_for(m_key);
grad_fn->m_dests.emplace_back(input_grad_slot);
grad_fn->m_dests.back().m_producer_record.insert_after(
input_grad_slot->m_producer_head);
} else {
grad_fn->m_dests.emplace_back();
}
}
for (size_t i = 0; i < outputs.size(); ++i) {
if (backward.output_requires_grad(i)) {
auto grad_value = GradValue::make(outputs[i], m_key, GradSlotPtr{grad_fn, i});
outputs[i] = record_grad(grad_value);
}
}
}
}, grad_fn->m_backward);
// clang-format on
mgb_assert(!grad_fn->m_slots.empty());
m_key->m_tape.push_back({grad_fn, op_val->op().shared_from_this()});
return outputs;
} else if (auto* attach_grad = op.as<AttachGrad>()) {
if (!has_key(attach_grad->key())) {
return imperative::apply(op, unwrap_inputs(inputs));
}
auto tensor = inputs[0];
GenericFunction callback = (GenericFunction&)inputs[1].cast<FunctionValue>();
auto output = attach_grad->key()->attach(tensor, [callback](ValueRef grad) {
auto ret = callback({&grad, 1});
assert(ret.empty());
});
return {record_grad(output)};
} else if (auto* grad_backward = op.as<GradBackward>()) {
if (!has_key(grad_backward->key())) {
return imperative::apply(op, unwrap_inputs(inputs));
}
size_t nr_grads = inputs.size() / 2;
mgb_assert(nr_grads * 2 == inputs.size());
auto values = inputs.sub(0, nr_grads);
auto grads = inputs.sub(nr_grads, nr_grads);
make_backward_closure(values)(grads);
return {};
} else if (auto* is_attached_to = op.as<IsAttachedTo>()) {
if (has_key(is_attached_to->key())) {
if (auto grad_value = as_grad_value(inputs[0])) {
// TODO: assert grad_fn
return {BoolValue::make(true)};
}
}
return {BoolValue::make(false)};
} else if (auto* set_grad = op.as<SetGrad>()) {
// TODO: merge SetGrad and ApplyOp
auto grad_fn = std::make_shared<GradFn>();
auto& backward =
std::get<CustomBackward>(grad_fn->m_backward = CustomBackward());
size_t nr_inputs = set_grad->nr_inputs();
mgb_assert(inputs.size() > nr_inputs);
size_t nr_outputs = inputs.size() - nr_inputs;
Span<ValueRef> inputs_ = {inputs.data(), nr_inputs};
Span<ValueRef> outputs_ = {inputs.data() + nr_inputs, nr_outputs};
backward.m_input_has_grad = SmallVector(nr_inputs, true);
backward.m_output_attrs =
SmallVector(nr_outputs, CustomBackward::OutputAttr{true, true});
backward.m_backward = set_grad->grad_fn();
std::vector<ValueRef> outputs;
grad_fn->m_key = m_key;
grad_fn->m_slots.resize(nr_outputs);
grad_fn->m_dests.reserve(nr_inputs);
for (size_t i = 0; i < nr_inputs; ++i) {
if (auto grad_value = as_grad_value(inputs_[i])) {
auto& input_grad_slot = grad_value->m_slot;
grad_fn->m_dests.emplace_back(grad_value->m_slot);
grad_fn->m_dests.back().m_producer_record.insert_after(
input_grad_slot->m_producer_head);
} else {
grad_fn->m_dests.emplace_back();
}
}
for (size_t i = 0; i < nr_outputs; ++i) {
auto& output = outputs_[i];
auto grad_value = as_grad_value(output);
if (grad_value) {
grad_value = GradValue::make(
grad_value->m_value, m_key, GradSlotPtr(grad_fn, i));
} else {
grad_value = GradValue::make(output, m_key, GradSlotPtr(grad_fn, i));
}
outputs.push_back(record_grad(grad_value));
}
m_key->m_tape.push_back({grad_fn, nullptr});
return outputs;
} else if (auto* gbc = op.as<GetBackwardColsure>()) {
if (gbc->key() != m_key) {
return imperative::apply(op, unwrap_inputs(inputs));
}
return {FunctionValue::make(make_backward_closure(inputs))};
} else if (op.is<DetachGrad>()) {
if (auto grad_value = as_grad_value(inputs[0])) {
return {grad_value->m_value};
} else {
return {inputs[0]};
}
} else if (op.is<GetGradKey>()) {
for (auto&& input : inputs) {
if (auto grad_value = as_grad_value(input)) {
return {GradKeyValue::make(grad_value->m_key)};
}
}
return imperative::apply(op, inputs);
} else if (op.kind() == Operator::IdentityLike) {
mgb_assert(inputs.size() == 1);
if (auto grad_value = as_grad_value(inputs[0])) {
auto output = imperative::apply(op, grad_value->m_value)[0];
auto grad_output = GradValue::make(
output, grad_value->key(), grad_value->slot_for(m_key));
return {record_grad(grad_output)};
} else {
return imperative::apply(op, inputs);
}
} else if (op.is<CreateTensor>()) {
return imperative::apply(op, inputs);
} else {
SmallVector<ValueRef> unwrapped_inputs;
for (auto&& input : inputs) {
if (auto grad_value = as_grad_value(input)) {
unwrapped_inputs.push_back(grad_value->m_value);
} else {
unwrapped_inputs.push_back(input);
}
}
auto outputs = imperative::apply(
op, {unwrapped_inputs.data(), unwrapped_inputs.size()});
mgb_assert(op.kind() == Operator::GetAttrLike || outputs.empty());
return outputs;
}
}
GenericFunction GradTransformation::make_backward_closure(Span<ValueRef> ys) {
// reset GradKey
auto grad_key = m_key;
std::vector<GradSlotPtr> y_slots;
for (auto&& y : ys) {
if (auto grad_value = as_grad_value(y)) {
y_slots.push_back(grad_value->slot_for(grad_key));
} else {
y_slots.emplace_back();
}
}
GenericFunction closure = [grad_key,
y_slots](Span<ValueRef> dys) -> std::vector<ValueRef> {
size_t nr_grads = y_slots.size();
mgb_assert(dys.size() == nr_grads);
for (size_t i = 0; i < nr_grads; ++i) {
if (y_slots[i]) {
y_slots[i]->m_grad = dys[i];
}
}
grad_key->backward();
return {};
};
grad_key->freeze();
cleanup();
return closure;
}
void GradTransformation::on_unregister() noexcept {
cleanup();
}
void GradTransformation::cleanup() {
for (auto&& weak_value : m_weak_values) {
auto grad_value = weak_value.lock();
if (grad_value) {
mgb_assert(grad_value->m_key == m_key);
grad_value.reset(grad_value->m_value);
}
}
m_weak_values.clear();
m_key = {};
}
void GradTransformation::suppress() {
m_suppressed++;
}
void GradTransformation::resume() {
m_suppressed--;
}
} // namespace imperative
} // namespace mgb
/**
* \file imperative/src/include/megbrain/imperative/grad.h
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
#pragma once
#include <variant>
#include "megbrain/imperative/backward_graph_opt.h"
#include "megbrain/imperative/dispatch.h"
#include "megbrain/imperative/interpreter.h"
#include "megbrain/imperative/opr_utility.h"
#include "megbrain/imperative/ops/autogen.h"
#include "megbrain/imperative/utils/helper.h"
#include "megbrain/imperative/utils/intrusive_list.h"
#include "megbrain/imperative/utils/to_string.h"
namespace mgb::imperative {
struct BackwardGraphWithClosure {
std::shared_ptr<OptimizedBackwardGraphResult> backward_graph;
SmallVector<ValueRef> closure;
size_t output_mask_offset;
size_t grad_mask_offset;
BackwardGraphWithClosure(
std::shared_ptr<OptimizedBackwardGraphResult> backward_graph,
std::shared_ptr<OpDef> op, Span<ValueRef> inputs, Span<ValueRef> outputs);
void operator()(
std::vector<ValueRef> grads,
std::function<void(size_t, ValueRef)> receiver);
bool input_has_grad(size_t i) { return backward_graph->input_has_grad[i]; }
bool output_requires_grad(size_t i) {
return backward_graph->save_for_backward[grad_mask_offset + i];
}
bool output_captured(size_t i) {
return backward_graph->save_for_backward[output_mask_offset + i];
}
};
struct CustomBackward;
using GradRuleFn =
std::function<std::vector<ValueRef>(Span<ValueRef> inputs, CustomBackward&)>;
struct CustomBackward {
using BackwardFn = std::function<std::vector<ValueRef>(Span<ValueRef>)>;
using BackwardRule = std::function<std::optional<std::vector<ValueRef>>(
const OpDef&, Span<ValueRef>, Span<bool>, CustomBackward&)>;
BackwardFn m_backward;
SmallVector<bool, 8> m_input_has_grad;
struct OutputAttr {
bool requires_grad = true, captured = true;
};
SmallVector<OutputAttr> m_output_attrs;
public:
void operator()(
std::vector<ValueRef> grads,
std::function<void(size_t, ValueRef)> receiver);
bool input_has_grad(size_t i) { return m_input_has_grad[i]; }
bool output_requires_grad(size_t i) { return m_output_attrs[i].requires_grad; }
bool output_captured(size_t i) { return m_output_attrs[i].captured; }
static bool register_grad_rule(Typeinfo* typeinfo, BackwardRule rule);
static BackwardRule lookup_grad_rule(Typeinfo* typeinfo);
};
class GradSlot;
class GradSlotPtr;
class GradSlotProducerPtr;
class GradFn;
class GradKey;
struct GradProducerRecord : utils::intrusive_list::Node<GradProducerRecord> {
using Node = utils::intrusive_list::Node<GradProducerRecord>;
GradProducerRecord() = default;
GradProducerRecord(head_t& head) : Node(utils::intrusive_list::after_t{}, head) {}
};
class GradSlot {
private:
ValueRef m_grad;
GradProducerRecord::head_t m_producer_head;
std::function<void(ValueRef)> callback;
public:
std::string to_string() const;
friend class GradKey;
friend class GradSlotProducerPtr;
friend class GradTransformation;
};
template <>
struct ToStringTrait<GradSlot> {
std::string operator()(const GradSlot& value) const { return value.to_string(); }
};
class GradFn {
private:
std::weak_ptr<GradKey> m_key;
std::vector<GradSlot> m_slots;
std::vector<GradSlotProducerPtr> m_dests;
std::variant<std::monostate, BackwardGraphWithClosure, CustomBackward> m_backward;
public:
void clear() {
m_key.reset();
m_slots.clear();
m_dests.clear();
m_backward.emplace<std::monostate>();
}
std::string to_string() const;
friend class GradSlotPtr;
friend class GradKey;
friend class GradTransformation;
};
class GradSlotPtr {
private:
std::shared_ptr<GradFn> m_fn;
size_t m_index = 0;
public:
GradSlotPtr(std::shared_ptr<GradFn> fn, size_t index) : m_fn(fn), m_index(index) {}
GradSlotPtr() = default;
GradSlot* operator->() const { return &m_fn->m_slots[m_index]; }
operator bool() const { return bool(m_fn); }
std::string to_string() const;
friend class GradKey;
friend class GradTransformation;
};
template <>
struct ToStringTrait<GradSlotPtr> {
std::string operator()(const GradSlotPtr& value) const { return value.to_string(); }
};
class GradSlotProducerPtr : public GradSlotPtr {
private:
GradProducerRecord m_producer_record;
bool dirty = false;
public:
GradSlotProducerPtr(const GradSlotPtr& info)
: GradSlotPtr(info), m_producer_record(info->m_producer_head) {}
GradSlotProducerPtr() = default;
GradSlotProducerPtr(GradSlotProducerPtr&&) = default;
~GradSlotProducerPtr() { dirty = true; }
friend class GradKey;
friend class GradTransformation;
};
template <>
struct ToStringTrait<GradSlotProducerPtr> {
std::string operator()(const GradSlotProducerPtr& value) const {
return value.to_string();
}
};
class GradValue final : public ValueImpl<GradValue> {
private:
ValueRef m_value;
std::shared_ptr<GradKey> m_key;
GradSlotPtr m_slot;
public:
GradValue(ValueRef value, std::shared_ptr<GradKey> key, GradSlotPtr slot = {})
: m_value(value), m_key(key), m_slot(slot) {}
std::string to_string() const override;
bool has_key(std::shared_ptr<GradKey> key) const { return m_key == key; }
const GradSlotPtr& slot_for(std::shared_ptr<GradKey> key) const {
mgb_assert(m_key == key);
return m_slot;
}
std::shared_ptr<GradKey> key() const { return m_key; }
void clear() override {
m_slot = {};
m_value = {};
m_key = nullptr;
}
void on_watch() override { m_value.watch(); }
void on_unwatch() override { m_value.unwatch(); }
friend class GradKey;
friend class GradTransformation;
};
class GradKey : public std::enable_shared_from_this<GradKey> {
private:
std::string m_name;
std::vector<std::pair<std::weak_ptr<GradFn>, std::shared_ptr<OpDef>>> m_tape;
std::vector<std::pair<std::shared_ptr<GradFn>, std::shared_ptr<OpDef>>>
m_frozen_tape;
bool m_frozen = false;
public:
void backward();
GradValue::ref_t attach(ValueRef tensor, std::function<void(ValueRef)> callback);
const std::string& name() const { return m_name; }
void name(std::string name) { m_name = std::move(name); }
void freeze();
friend class GradTransformation;
};
class GradKeyValue final
: public MixinValueImpl<GradKeyValue, std::shared_ptr<GradKey>> {
public:
using MixinValueImpl::MixinValueImpl;
std::string to_string() const override {
return ssprintf("GradKey{%s}", (*this)->name().c_str());
}
};
class GradTransformation final : public Transformation {
private:
std::shared_ptr<GradKey> m_key;
std::vector<GradValue::weak_ref_t> m_weak_values;
size_t m_suppressed = 0;
public:
GradTransformation(std::shared_ptr<GradKey> key) : m_key(key) {}
auto record_grad(GradValue::ref_t tensor) {
m_weak_values.push_back(tensor);
return tensor;
}
bool is_grad_value(ValueRef value) {
if (auto* grad_value = value.as<GradValue>()) {
if (grad_value->has_key(m_key)) {
return true;
}
}
return false;
}
/**
* \brief test whether value is related to this GradTransformation
*
* there may be multiple grad transformations, so simply using value.is<GradValue>()
* is unsafe
*
* \param value
* \return GradValue::ref_t
*/
GradValue::ref_t as_grad_value(ValueRef value) {
if (auto grad_value = value.as_ref<GradValue>()) {
if (grad_value->has_key(m_key)) {
return grad_value;
}
}
return {};
}
bool has_key(std::shared_ptr<GradKey> key) {
if (key == m_key) {
return true;
}
return false;
}
std::vector<ValueRef> apply_transformation(
const Operator& op, Span<ValueRef> inputs) override;
ValueRef unwrap(ValueRef value) override {
if (auto grad_val = as_grad_value(value)) {
return grad_val->m_value;
}
return value;
}
std::string name() const override { return "GradTransformation"; }
GenericFunction make_backward_closure(Span<ValueRef> ys);
void on_unregister() noexcept override;
void cleanup();
void suppress();
void resume();
};
class DetachGrad : public OperatorImpl<DetachGrad, Operator::IdentityLike> {
private:
// TODO: identified by GradKey
public:
std::string to_string() const override { return "DetachValue"; }
std::vector<ValueRef> fallback(Span<ValueRef> inputs) const override {
return {inputs.as_array<1>()[0]};
}
};
class AttachGrad : public OperatorImpl<AttachGrad> {
private:
std::shared_ptr<GradKey> m_key;
public:
AttachGrad(std::shared_ptr<GradKey> key) : m_key(key) {}
std::shared_ptr<GradKey> key() { return m_key; }
std::string to_string() const override {
return ssprintf("AttachGradValue{key=%s}", m_key->name().c_str());
}
};
class GradBackward : public OperatorImpl<GradBackward, Operator::GetAttrLike> {
private:
std::shared_ptr<GradKey> m_key;
public:
GradBackward(std::shared_ptr<GradKey> key) : m_key(key) {}
std::shared_ptr<GradKey> key() { return m_key; }
std::string to_string() const override {
return ssprintf("GradBackwardValue{key=%s}", m_key->name().c_str());
}
};
class IsAttachedTo : public OperatorImpl<IsAttachedTo, Operator::GetAttrLike> {
private:
std::shared_ptr<GradKey> m_key;
public:
IsAttachedTo(std::shared_ptr<GradKey> key) : m_key(key) {}
std::shared_ptr<GradKey> key() { return m_key; }
std::string to_string() const override {
return ssprintf("IsAttachedToValue{key=%s}", m_key->name().c_str());
}
std::vector<ValueRef> fallback(Span<ValueRef> inputs) const override {
return {BoolValue::make(false)};
}
};
class SetGrad : public OperatorImpl<SetGrad> {
private:
std::shared_ptr<GradKey> m_key;
GenericFunction m_grad_fn;
size_t m_nr_inputs;
public:
SetGrad(std::shared_ptr<GradKey> key, GenericFunction grad_fn, size_t nr_inputs)
: m_key(key), m_grad_fn(grad_fn), m_nr_inputs(nr_inputs) {}
GenericFunction grad_fn() { return m_grad_fn; }
size_t nr_inputs() { return m_nr_inputs; }
std::string to_string() const override {
return ssprintf("SetGradValue{key=%s}", m_key->name().c_str());
}
};
class GetGradKey : public OperatorImpl<GetGradKey, Operator::GetAttrLike> {
public:
GetGradKey() = default;
std::string to_string() const override { return ssprintf("GetGradKeyValue{}"); }
std::vector<ValueRef> fallback(Span<ValueRef> inputs) const override {
return {ValueRef()};
}
};
class GetBackwardColsure
: public OperatorImpl<GetBackwardColsure, Operator::GetAttrLike> {
private:
std::shared_ptr<GradKey> m_key;
public:
GetBackwardColsure(std::shared_ptr<GradKey> key) : m_key(key) {}
std::shared_ptr<GradKey> key() { return m_key; }
std::string to_string() const override {
return ssprintf("GetBackwardClosure{key=%s}", m_key->name().c_str());
}
};
} // namespace mgb::imperative
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册