Commit 60c44b08 authored by Megvii Engine Team

refactor(mge): refactor to prepare for custom grad rules

GitOrigin-RevId: 4bd8850fdfe28069f68a7248e6a7efb3d46a7384
Parent 61f65cd4
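This commit collapses `GradFn`'s `closure` + `backward_graph` pair into a single `std::variant` of backward implementations and routes both graph setup (`apply_grad`) and gradient propagation (`GradKey::backward`) through `std::visit`, so additional grad rules can later be added as new variant alternatives. The following self-contained sketch (illustrative names only — `DoubleRule`, `run_backward` and this `GradFn` are not the MegEngine API) shows the dispatch pattern the diff below adopts:

```cpp
#include <cassert>
#include <cstddef>
#include <cstdio>
#include <type_traits>
#include <variant>
#include <vector>

// A backward rule is any callable invoked with the output grads and a
// receiver that accumulates gradients into the inputs.
struct DoubleRule {
    template <typename Grads, typename Receiver>
    void operator()(Grads&& grads, Receiver&& receiver) {
        for (std::size_t i = 0; i < grads.size(); ++i) {
            receiver(i, grads[i] * 2.0);  // stand-in for a real backward computation
        }
    }
};

struct GradFn {
    // std::monostate marks "no backward rule attached"; new rule types can be
    // appended to the variant without changing the backward pass below.
    std::variant<std::monostate, DoubleRule> backward;
};

template <typename Grads, typename Receiver>
void run_backward(GradFn& fn, Grads&& grads, Receiver&& receiver) {
    std::visit([&](auto&& rule) {
        using T = std::decay_t<decltype(rule)>;
        if constexpr (std::is_same_v<T, std::monostate>) {
            assert(!"backward invoked without a rule");
        } else {
            rule(grads, receiver);
        }
    }, fn.backward);
}

int main() {
    GradFn fn;
    fn.backward.emplace<DoubleRule>();
    std::vector<double> grads{1.0, 3.0};
    run_backward(fn, grads, [](std::size_t i, double g) {
        std::printf("input %zu receives grad %g\n", i, g);
    });
}
```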
@@ -14,7 +14,10 @@
 #include "megbrain/imperative/ops/autogen.h"
 #include "megbrain/utils/mempool.h"
 
+#include "range/v3/all.hpp"
+
 namespace py = pybind11;
+namespace views = ranges::views;
 
 namespace mgb::imperative::python {
@@ -25,6 +28,152 @@ struct GradSlotWeakPtr {
     size_t idx;
 };
 
+struct BackwardGraphCache : std::unordered_map<size_t, std::shared_ptr<BackwardGraphResult>>, CompNodeDepedentObject {
+    std::shared_ptr<void> on_comp_node_finalize() override {
+        clear();
+        return {};
+    }
+} backward_graph_cache;
+
+std::shared_ptr<BackwardGraphResult> make_backward_graph(
+        ApplyContext& ctx, const apply_result_t& outputs) {
+    // hash
+    static_assert(alignof(size_t) % alignof(bool) == 0);
+    size_t buf_size = (1 + ctx.nargs * 2) * sizeof(size_t) + ctx.nargs * sizeof(bool);
+    alignas(alignof(size_t)) std::byte buf[buf_size];
+    size_t* size_t_ptr = reinterpret_cast<size_t*>(buf);
+    bool* bool_ptr = reinterpret_cast<bool*>(size_t_ptr + (1 + ctx.nargs * 2));
+    bool* bool_ptr0 = bool_ptr;
+    *(size_t_ptr++) = ctx.op->hash();
+    for (size_t i = 0; i < ctx.nargs; ++i) {
+        *(size_t_ptr++) = mgb::hash(ctx.args[i]->dtype().handle());
+        *(size_t_ptr++) = mgb::hash(ctx.args[i]->comp_node());
+        *(bool_ptr++) = bool(ctx.args[i]->m_grad_info.grad_fn);
+    }
+    mgb_assert(bool_ptr0 == reinterpret_cast<bool*>(size_t_ptr) &&
+               bool_ptr == reinterpret_cast<bool*>(buf + buf_size));
+    size_t key = XXHash{}.update(buf, buf_size).digest();
+
+    auto&& iter = backward_graph_cache.find(key);
+    if (iter != backward_graph_cache.end()) {
+        return iter->second;
+    }
+
+    // slow path
+    SmallVector<LogicalTensorDesc> inputs(ctx.nargs);
+    SmallVector<bool> input_requires_grad(ctx.nargs, false);
+    SmallVector<bool> output_has_grad(outputs.size(), true);
+    for (size_t i = 0; i < ctx.nargs; ++i) {
+        inputs[i].comp_node = ctx.args[i]->comp_node();
+        inputs[i].layout.dtype = ctx.args[i]->dtype();
+        input_requires_grad[i] = bool(ctx.args[i]->m_grad_info.grad_fn);
+    }
+    auto result = std::make_shared<BackwardGraphResult>(
+            proxy_graph_detail::make_backward_graph(
+                    *ctx.op, inputs, input_requires_grad, output_has_grad));
+    if (!result->backward) {
+        result.reset();
+    }
+    backward_graph_cache.emplace(key, result);
+    return result;
+}
+
+struct BackwardGraphWithClosure {
+    std::shared_ptr<BackwardGraphResult> backward_graph;
+    SmallVector<std::shared_ptr<Tensor>> closure;
+    size_t output_mask_offset;
+    size_t grad_mask_offset;
+
+    BackwardGraphWithClosure(std::shared_ptr<BackwardGraphResult> backward_graph_,
+                             ApplyContext& ctx, const apply_result_t& outputs)
+            : backward_graph(backward_graph_),
+              output_mask_offset(ctx.nargs),
+              grad_mask_offset(ctx.nargs + outputs.size()) {
+        // save_for_backward[0:nargs]:
+        //     whether input is kept for backward
+        //
+        // save_for_backward[nargs:nargs+outputs.size()]:
+        //     whether output is kept for backward
+        //
+        // save_for_backward[-outputs.size():]:
+        //     whether gradient of output can propagate to any input
+        //
+        // Example:
+        //     perform c = a * b, with a.requires_grad == True and
+        //     b.requires_grad == False, save_for_backward = [0, 1, 0, 1]
+        auto& save_for_backward = backward_graph->save_for_backward;
+        mgb_assert(save_for_backward.size() == ctx.nargs + 2 * outputs.size());
+        closure.reserve(std::count_if(save_for_backward.begin(),
+                                      save_for_backward.end(),
+                                      ranges::identity{}));
+        for (size_t i = 0; i < ctx.nargs; ++i) {
+            if (save_for_backward[i]) {
+                closure.push_back(ctx.args[i]->shared_from_this());
+            }
+        }
+        for (size_t i = 0; i < outputs.size(); ++i) {
+            if (save_for_backward[ctx.nargs + i]) {
+                closure.push_back(outputs[i]);
+            }
+        }
+    }
+
+    template <typename T, typename R>
+    void operator()(T&& grads, R&& receiver) {
+        Tensor* args[closure.size() + grads.size()];
+        size_t nargs = 0;
+        for (auto&& t : closure) {
+            args[nargs++] = t.get();
+        }
+        bool null_grad = false;
+        for (size_t i = 0; i < grads.size(); ++i) {
+            if (backward_graph->save_for_backward[grad_mask_offset + i]) {
+                if (grads[i]) {
+                    if (null_grad) {
+                        PyErr_SetString(PyExc_NotImplementedError, "report to devs");
+                        throw py::error_already_set();
+                    }
+                    args[nargs++] = grads[i];
+                } else {
+                    null_grad = true;
+                }
+            }
+        }
+        if (null_grad) return;
+
+        ApplyContext ctx;
+        ctx.op = backward_graph->backward;
+        ctx.flags = is_tracing ? Tensor::Flags::TRACE : 0;
+        ctx.nargs = nargs;
+        ctx.args = args;
+        for (size_t i = 0; i < nargs; ++i) {
+            ctx.flags |= args[i]->m_flags;
+            mgb_assert(args[i]);
+        }
+
+        auto igrads = apply(ctx);
+        auto&& it = igrads.begin();
+        for (auto [i, p] : views::enumerate(backward_graph->input_has_grad)) {
+            if (p) {
+                receiver(i, std::move(*it));
+                ++it;
+            }
+        }
+    }
+
+    bool input_has_grad(size_t i) {
+        return backward_graph->input_has_grad[i];
+    }
+
+    bool output_requires_grad(size_t i) {
+        return backward_graph->save_for_backward[grad_mask_offset + i];
+    }
+
+    bool output_captured(size_t i) {
+        return backward_graph->save_for_backward[output_mask_offset + i];
+    }
+};
+
 } // namespace
 
 struct GradProducerRecord : intrusive_list::Node<GradProducerRecord> {
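The `save_for_backward` mask documented in the `BackwardGraphWithClosure` constructor above drives everything that follows: which tensors go into the closure and which outputs publish grad info. The small helper below (illustrative only, not part of the diff) decodes the mask the same way `output_mask_offset` / `grad_mask_offset` do, using the `c = a * b` example from the comment:

```cpp
#include <cstddef>
#include <vector>

// For an op with `nargs` inputs and `nout` outputs the mask has
// nargs + 2 * nout entries:
//   [0, nargs)            -> input is kept for backward
//   [nargs, nargs + nout) -> output is kept for backward
//   [nargs + nout, end)   -> output's gradient can propagate to some input
struct SaveForBackwardView {
    std::vector<bool> mask;
    std::size_t nargs, nout;

    bool input_captured(std::size_t i) const { return mask[i]; }
    bool output_captured(std::size_t i) const { return mask[nargs + i]; }            // output_mask_offset
    bool output_requires_grad(std::size_t i) const { return mask[nargs + nout + i]; } // grad_mask_offset
};

// c = a * b with a.requires_grad == True, b.requires_grad == False:
// mask = {0, 1, 0, 1} -> only b is captured, c itself is not needed,
// and c's gradient propagates back (to a).
static const SaveForBackwardView example{{false, true, false, true}, 2, 1};
```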
@@ -54,10 +203,15 @@ struct GradFn : std::enable_shared_from_this<GradFn> {
     static MemPool<GradFn> pool;
 
     std::weak_ptr<GradKey> key;
+    // slots for receiving and accumulating grads
+    // same length as outputs (of forward op)
     SmallVector<GradSlot> slots;
+    // where to send and accumulate grads
+    // same length as inputs (of forward op)
     SmallVector<GradSlotProducerPtr> dsts;
-    SmallVector<std::shared_ptr<Tensor>> closure;
-    std::shared_ptr<BackwardGraphResult> backward_graph;
+    // encapsulates actual function to compute gradient
+    std::variant<std::monostate, BackwardGraphWithClosure> backward;
+    // a flag used during backward
     bool in_ref_keeper = false;
 
     static void deleter(GradFn* ptr) {
@@ -72,8 +226,7 @@ struct GradFn : std::enable_shared_from_this<GradFn> {
         key.reset();
         slots.clear();
         dsts.clear();
-        closure.clear();
-        backward_graph.reset();
+        backward.emplace<std::monostate>();
     }
 };
@@ -83,54 +236,36 @@ GradSlot* GradSlotPtr::operator->() {
 
 namespace {
 
-struct BackwardGraphCache : std::unordered_map<size_t, std::shared_ptr<BackwardGraphResult>>, CompNodeDepedentObject {
-    std::shared_ptr<void> on_comp_node_finalize() override {
-        clear();
-        return {};
-    }
-} backward_graph_cache;
-
-std::shared_ptr<BackwardGraphResult> make_backward_graph(
-        ApplyContext& ctx, const apply_result_t& outputs) {
-    // hash
-    static_assert(alignof(size_t) % alignof(bool) == 0);
-    size_t buf_size = (1 + ctx.nargs * 2) * sizeof(size_t) + ctx.nargs * sizeof(bool);
-    alignas(alignof(size_t)) std::byte buf[buf_size];
-    size_t* size_t_ptr = reinterpret_cast<size_t*>(buf);
-    bool* bool_ptr = reinterpret_cast<bool*>(size_t_ptr + (1 + ctx.nargs * 2));
-    bool* bool_ptr0 = bool_ptr;
-    *(size_t_ptr++) = ctx.op->hash();
-    for (size_t i = 0; i < ctx.nargs; ++i) {
-        *(size_t_ptr++) = mgb::hash(ctx.args[i]->dtype().handle());
-        *(size_t_ptr++) = mgb::hash(ctx.args[i]->comp_node());
-        *(bool_ptr++) = bool(ctx.args[i]->m_grad_info.grad_fn);
-    }
-    mgb_assert(bool_ptr0 == reinterpret_cast<bool*>(size_t_ptr) &&
-               bool_ptr == reinterpret_cast<bool*>(buf + buf_size));
-    size_t key = XXHash{}.update(buf, buf_size).digest();
-
-    auto&& iter = backward_graph_cache.find(key);
-    if (iter != backward_graph_cache.end()) {
-        return iter->second;
-    }
-
-    // slow path
-    SmallVector<LogicalTensorDesc> inputs(ctx.nargs);
-    SmallVector<bool> input_requires_grad(ctx.nargs, false);
-    SmallVector<bool> output_has_grad(outputs.size(), true);
-    for (size_t i = 0; i < ctx.nargs; ++i) {
-        inputs[i].comp_node = ctx.args[i]->comp_node();
-        inputs[i].layout.dtype = ctx.args[i]->dtype();
-        input_requires_grad[i] = bool(ctx.args[i]->m_grad_info.grad_fn);
-    }
-    auto result = std::make_shared<BackwardGraphResult>(
-            proxy_graph_detail::make_backward_graph(
-                    *ctx.op, inputs, input_requires_grad, output_has_grad));
-    if (!result->backward) {
-        result.reset();
-    }
-    backward_graph_cache.emplace(key, result);
-    return result;
+class GradFnHelper {
+    std::shared_ptr<GradFn> grad_fn;
+
+    GradFn* get() {
+        if (!grad_fn) {
+            grad_fn = std::make_shared<GradFn>();
+        }
+        return grad_fn.get();
+    }
+
+    friend apply_result_t imperative::python::apply_grad(ApplyContext&);
+
+public:
+    template<typename T, typename... Args>
+    auto& emplace(Args&&... args) {
+        return get()->backward.emplace<T>(std::forward<Args>(args)...);
+    }
+};
+
+apply_result_t backward_graph_grad_rule(ApplyContext& ctx, GradFnHelper& ret_grad_fn) {
+    auto outputs = apply(ctx);
+
+    auto backward_graph = make_backward_graph(ctx, outputs);
+    if (!backward_graph) {
+        return outputs;
+    }
+
+    ret_grad_fn.emplace<BackwardGraphWithClosure>(std::move(backward_graph), ctx, outputs);
+
+    return outputs;
 }
 
 } // namespace
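`backward_graph_grad_rule` is the first and, in this commit, only grad rule: run the forward apply, derive a backward graph, and stash a `BackwardGraphWithClosure` into the `GradFn` through `GradFnHelper`. A hypothetical hand-written rule following the same shape is sketched below — it is not part of this commit, and it assumes `GradFn::backward` is later extended to `std::variant<std::monostate, BackwardGraphWithClosure, AddBackward>`:

```cpp
// Hypothetical custom rule for a two-input elementwise add, whose backward
// needs no graph: the output grad is forwarded to both inputs unchanged.
struct AddBackward {
    template <typename T, typename R>
    void operator()(T&& grads, R&& receiver) {
        if (!grads[0]) return;  // no gradient flowed into the single output
        for (size_t i = 0; i < 2; ++i) {
            receiver(i, grads[0]->shared_from_this());
        }
    }
    bool input_has_grad(size_t) { return true; }
    bool output_requires_grad(size_t) { return true; }
    bool output_captured(size_t) { return false; }  // nothing saved for backward
};

apply_result_t add_grad_rule(ApplyContext& ctx, GradFnHelper& ret_grad_fn) {
    ret_grad_fn.emplace<AddBackward>();
    return apply(ctx);
}
```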
@@ -164,76 +299,53 @@ apply_result_t apply_grad(ApplyContext& ctx) {
     ctx.flags &= ~Tensor::Flags::GRAD;
 
-    // perform forward apply_op or trace
-    auto outputs = apply(ctx);
-
     if (!grad_key) {
-        return outputs;
+        return apply(ctx);
     }
 
-    auto backward_graph = make_backward_graph(ctx, outputs);
-    if (!backward_graph) {
+    GradFnHelper grad_fn_holder;
+    auto outputs = backward_graph_grad_rule(ctx, grad_fn_holder);
+
+    auto& grad_fn = grad_fn_holder.grad_fn;
+    if (!grad_fn) {
         return outputs;
     }
 
-    auto grad_fn = std::make_shared<GradFn>();
     grad_fn->key = grad_key;
     grad_fn->slots.resize(outputs.size());
-    grad_fn->backward_graph = std::move(backward_graph);
     grad_fn->dsts.reserve(ctx.nargs);
-    for (size_t i = 0; i < ctx.nargs; ++i) {
-        if (grad_fn->backward_graph->input_has_grad[i]) {
-            auto& input_grad_info = ctx.args[i]->m_grad_info;
-            grad_fn->dsts.emplace_back(input_grad_info);
-            grad_fn->dsts.back().producer_record.insert_after(input_grad_info->producer_head);
-        } else {
-            grad_fn->dsts.emplace_back();
-        }
-    }
-
-    auto& save_for_backward = grad_fn->backward_graph->save_for_backward;
-    grad_fn->closure.reserve(std::count_if(save_for_backward.begin(), save_for_backward.end(), [](bool p){return p;}));
-
-    // given op, taking gradient of output_tensor_list wrt input_tensor_list:
-    //
-    // save_for_backward[0:nargs-1]: whether input tensor requires gradient,
-    // i.e., whether it is in input_tensor_list
-    //
-    // save_for_backward[nargs:nargs+outputs.size()-1]: whether output tensor is
-    // needed to calculate gradients
-    //
-    // save_for_backward[-outputs.size():]: whether output tensor is in
-    // output_tensor_list
-    //
-    // Example: perform c = a * b, where a is input data, b is parameter to be
-    // optimized, save_for_backward = [1, 1, 0, 1]
-    mgb_assert(ctx.nargs + 2 * outputs.size() == save_for_backward.size());
-
-    // record input tensors needed to take grad
-    for (size_t i = 0; i < ctx.nargs; ++i) {
-        if (save_for_backward[i]) {
-            grad_fn->closure.push_back(ctx.args[i]->shared_from_this());
-        }
-    }
-    // record output tensors needed to take grad
-    for (size_t i = 0; i < outputs.size(); ++i) {
-        bool requires_grad = save_for_backward[ctx.nargs + outputs.size() + i];
-        if (save_for_backward[ctx.nargs + i]) {
-            grad_fn->closure.push_back(outputs[i]);
-            if (requires_grad) {
-                // avoid reference cycle [Tensor <-> GradFn]
-                outputs[i] = outputs[i]->copy();
-            }
-        }
-        if (requires_grad) {
-            auto& grad_info = outputs[i]->m_grad_info;
-            grad_info.grad_fn = grad_fn;
-            grad_info.idx = i;
-            grad_info.insert_after(grad_key->free_vars_head);
-            outputs[i]->m_flags |= Tensor::Flags::GRAD;
-        }
-    }
+
+    std::visit([&](auto& backward) {
+        using T = std::decay_t<decltype(backward)>;
+        if constexpr (std::is_same_v<T, std::monostate>) {
+            mgb_assert(0);
+        } else {
+            for (size_t i = 0; i < ctx.nargs; ++i) {
+                if (backward.input_has_grad(i)) {
+                    auto& input_grad_info = ctx.args[i]->m_grad_info;
+                    grad_fn->dsts.emplace_back(input_grad_info);
+                    // register as grad producer
+                    grad_fn->dsts.back().producer_record.insert_after(input_grad_info->producer_head);
+                } else {
+                    grad_fn->dsts.emplace_back();
+                }
+            }
+            for (size_t i = 0; i < outputs.size(); ++i) {
+                if (backward.output_requires_grad(i)) {
+                    if (backward.output_captured(i)) {
+                        // avoid reference cycle [Tensor <-> GradFn]
+                        outputs[i] = outputs[i]->copy();
+                    }
+                    // populate grad info of output tensor
+                    auto& grad_info = outputs[i]->m_grad_info;
+                    grad_info.grad_fn = grad_fn;
+                    grad_info.idx = i;
+                    grad_info.insert_after(grad_key->free_vars_head);
+                    outputs[i]->m_flags |= Tensor::Flags::GRAD;
+                }
+            }
+        }
+    }, grad_fn->backward);
 
     // record forward history
     grad_key->tape.emplace_back(grad_fn);
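The `output_captured` branch above copies the output before publishing its grad info because the closure inside `BackwardGraphWithClosure` holds a `shared_ptr` to that output, while the output's grad info holds a `shared_ptr` back to the `GradFn`. A minimal sketch of the cycle being avoided (illustrative types, not the real ones):

```cpp
#include <memory>

struct GradFn;

struct Tensor {
    std::shared_ptr<GradFn> grad_fn;  // stands in for m_grad_info.grad_fn
};

struct GradFn {
    std::shared_ptr<Tensor> captured_output;  // stands in for the closure entry
};

int main() {
    auto t = std::make_shared<Tensor>();
    auto fn = std::make_shared<GradFn>();
    fn->captured_output = t;  // GradFn -> Tensor
    t->grad_fn = fn;          // Tensor -> GradFn: neither refcount ever reaches 0
    // The commit breaks this by handing out outputs[i]->copy() instead, so the
    // tensor captured in the closure is not the one that points back at GradFn.
}
```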
@@ -334,54 +446,30 @@ void GradKey::backward(std::vector<TensorWrapper*> tensors, std::vector<TensorWrapper*> grads) {
     for (std::ptrdiff_t k = tape.size() - 1; k >= 0; --k) {
         auto&& grad_fn = tape[k].lock();
         if (!grad_fn) continue;
-        if (grad_fn->backward_graph) {
-            for (size_t i = 0; i < grad_fn->slots.size(); ++i) {
-                // grad_fn->dsts correspond to input tensors during forward
-                // calculation, grad_fn->slots correspond to output tensors.
-                // condition true means the output tensor has gradient for
-                // back-propagation
-                if (grad_fn->backward_graph->save_for_backward[grad_fn->dsts.size() + grad_fn->slots.size() + i]) {
-                    grad_fn->closure.push_back(std::move(grad_fn->slots[i].grad));
-                }
-            }
-            ApplyContext ctx;
-            ctx.op = grad_fn->backward_graph->backward;
-            ctx.flags = 0;
-            ctx.nargs = grad_fn->closure.size();
-            Tensor* args[ctx.nargs];
-            for (size_t i = 0; i < ctx.nargs; ++i) {
-                args[i] = grad_fn->closure[i].get();
-                mgb_assert(args[i]);
-                ctx.flags |= args[i]->m_flags;
-            }
-            ctx.args = args;
-            if (is_tracing)
-                ctx.flags |= Tensor::Flags::TRACE;
-
-            auto grads = apply(ctx);
-
-            size_t j = 0;
-            for (size_t i = 0; i < grad_fn->dsts.size(); ++i) {
-                if (grad_fn->backward_graph->input_has_grad[i]) {
-                    auto& dst = grad_fn->dsts[i];
-                    // grads[j] is consumed in accum_grad
-                    accum_grad(dst->grad, std::move(grads[j]));
-                    ++j;
-                }
-            }
-            mgb_assert(j == grads.size());
-        }
+
+        auto grad_receiver = [&](size_t i, auto&& g) {
+            accum_grad(grad_fn->dsts[i]->grad, std::forward<decltype(g)>(g));
+        };
+        std::visit([&](auto&& backward) {
+            using T = std::decay_t<decltype(backward)>;
+            if constexpr (std::is_same_v<T, std::monostate>) {
+                mgb_assert(0);
+            } else {
+                auto&& grads = views::transform(grad_fn->slots, [](auto&& slot) {return slot.grad.get();});
+                backward(std::forward<decltype(grads)>(grads), grad_receiver);
+            }
+        }, grad_fn->backward);
 
         for (auto&& dst : grad_fn->dsts) {
             if (!dst.grad_fn) continue;
             if (!dst.grad_fn->in_ref_keeper) {
+                // after grad_fn is cleared, refcnt of subsequent grad_fn
+                // could drop to 0
                 dst.grad_fn->in_ref_keeper = true;
                 ref_keeper.push_back(dst.grad_fn);
             }
+            // grad_fn->clear will unlink current dst.producer_record
+            // such that if dst.producer_record.next is false, dst accumulates
+            // all the gradients
             if (!dst.producer_record.next && dst->callback && dst->grad) {
+                // I'm the last grad producer, invoke callback
                 dst->callback(TensorWrapper::make(pytype, dst->grad));
             }
         }
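After this hunk, a backward implementation no longer receives materialized gradient tensors: it gets a lazy range over the output slots (`views::transform`) plus a receiver callback that accumulates into the matching `dsts` entry. A standalone sketch of that calling convention using range-v3, the library this commit adds (`Slot` and `run_rule` are illustrative names):

```cpp
#include <cstddef>
#include <cstdio>
#include <vector>
#include "range/v3/all.hpp"

namespace views = ranges::views;

struct Slot { double grad; };

template <typename Rule>
void run_rule(Rule&& rule, std::vector<Slot>& slots) {
    // one receiver call per input that gets a gradient, mirroring accum_grad
    auto receiver = [](std::size_t input_idx, double g) {
        std::printf("accumulate %g into input %zu\n", g, input_idx);
    };
    // lazily view the slots as a range of grads, one element per output,
    // like views::transform(grad_fn->slots, ...) in GradKey::backward
    auto grads = views::transform(slots, [](Slot& s) { return s.grad; });
    rule(grads, receiver);
}

int main() {
    std::vector<Slot> slots{{1.5}, {2.5}};
    run_rule([](auto&& grads, auto&& receiver) {
        std::size_t i = 0;
        for (double g : grads) receiver(i++, g);
    }, slots);
}
```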