Commit 60c44b08 authored by Megvii Engine Team

refactor(mge): refactor to prepare for custom grad rules

GitOrigin-RevId: 4bd8850fdfe28069f68a7248e6a7efb3d46a7384
Parent 61f65cd4
......@@ -14,7 +14,10 @@
#include "megbrain/imperative/ops/autogen.h"
#include "megbrain/utils/mempool.h"
#include "range/v3/all.hpp"
namespace py = pybind11;
namespace views = ranges::views;
namespace mgb::imperative::python {
......@@ -25,6 +28,152 @@ struct GradSlotWeakPtr {
size_t idx;
};
struct BackwardGraphCache : std::unordered_map<size_t, std::shared_ptr<BackwardGraphResult>>, CompNodeDepedentObject {
std::shared_ptr<void> on_comp_node_finalize() override {
clear();
return {};
}
} backward_graph_cache;
std::shared_ptr<BackwardGraphResult> make_backward_graph(
ApplyContext& ctx, const apply_result_t& outputs) {
// hash
static_assert(alignof(size_t) % alignof(bool) == 0);
size_t buf_size = (1 + ctx.nargs * 2) * sizeof(size_t) + ctx.nargs * sizeof(bool);
alignas(alignof(size_t)) std::byte buf[buf_size];
size_t* size_t_ptr = reinterpret_cast<size_t*>(buf);
bool* bool_ptr = reinterpret_cast<bool*>(size_t_ptr + (1 + ctx.nargs * 2));
bool* bool_ptr0 = bool_ptr;
*(size_t_ptr++) = ctx.op->hash();
for (size_t i = 0; i < ctx.nargs; ++i) {
*(size_t_ptr++) = mgb::hash(ctx.args[i]->dtype().handle());
*(size_t_ptr++) = mgb::hash(ctx.args[i]->comp_node());
*(bool_ptr++) = bool(ctx.args[i]->m_grad_info.grad_fn);
}
mgb_assert(bool_ptr0 == reinterpret_cast<bool*>(size_t_ptr) &&
bool_ptr == reinterpret_cast<bool*>(buf + buf_size));
size_t key = XXHash{}.update(buf, buf_size).digest();
auto&& iter = backward_graph_cache.find(key);
if (iter != backward_graph_cache.end()) {
return iter->second;
}
// slow path
SmallVector<LogicalTensorDesc> inputs(ctx.nargs);
SmallVector<bool> input_requires_grad(ctx.nargs, false);
SmallVector<bool> output_has_grad(outputs.size(), true);
for (size_t i = 0; i < ctx.nargs; ++i) {
inputs[i].comp_node = ctx.args[i]->comp_node();
inputs[i].layout.dtype = ctx.args[i]->dtype();
input_requires_grad[i] = bool(ctx.args[i]->m_grad_info.grad_fn);
}
auto result = std::make_shared<BackwardGraphResult>(
proxy_graph_detail::make_backward_graph(
*ctx.op, inputs, input_requires_grad, output_has_grad));
if (!result->backward) {
result.reset();
}
backward_graph_cache.emplace(key, result);
return result;
}
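// Illustrative layout of the hashed key buffer above, assuming a 2-input op
// (nargs == 2); the sizes follow directly from the code, the picture is only
// a reading aid:
//
//   size_t region, 1 + 2 * nargs entries:
//     [ op->hash() | hash(dtype0) | hash(cn0) | hash(dtype1) | hash(cn1) ]
//   bool region, nargs entries:
//     [ requires_grad(arg0) | requires_grad(arg1) ]
//
// Two applies therefore reuse a cached BackwardGraphResult only when op,
// input dtypes, comp_nodes and the requires-grad pattern all match.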
struct BackwardGraphWithClosure {
std::shared_ptr<BackwardGraphResult> backward_graph;
SmallVector<std::shared_ptr<Tensor>> closure;
size_t output_mask_offset;
size_t grad_mask_offset;
BackwardGraphWithClosure(std::shared_ptr<BackwardGraphResult> backward_graph_,
ApplyContext& ctx, const apply_result_t& outputs)
: backward_graph(backward_graph_),
output_mask_offset(ctx.nargs),
grad_mask_offset(ctx.nargs + outputs.size()) {
// save_for_backward[0:nargs]:
// whether input is kept for backward
//
// save_for_backward[nargs:nargs+outputs.size()]:
// whether output is kept for backward
//
// save_for_backward[-outputs.size():]:
// whether gradient of output can propagate to any input
//
// Example:
// perform c = a * b, with a.requires_grad == True and
// b.requires_grad == False, save_for_backward = [0, 1, 0, 1]
auto& save_for_backward = backward_graph->save_for_backward;
mgb_assert(save_for_backward.size() == ctx.nargs + 2 * outputs.size());
closure.reserve(std::count_if(save_for_backward.begin(),
save_for_backward.end(),
ranges::identity{}));
for (size_t i = 0; i < ctx.nargs; ++i) {
if (save_for_backward[i]) {
closure.push_back(ctx.args[i]->shared_from_this());
}
}
for (size_t i = 0; i < outputs.size(); ++i) {
if (save_for_backward[ctx.nargs + i]) {
closure.push_back(outputs[i]);
}
}
}
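// Continuing the c = a * b example above (save_for_backward = [0, 1, 0, 1],
// nargs == 2, outputs.size() == 1), the values below are derived from the
// constructor, not stated in the original comment:
//   output_mask_offset == 2, grad_mask_offset == 3
//   closure == {b}    // dc/da == grad_c * b, so only b must be kept alive
// operator() below will then assemble args == {b, grad_c}.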
template <typename T, typename R>
void operator()(T&& grads, R&& receiver) {
Tensor* args[closure.size() + grads.size()];
size_t nargs = 0;
for (auto&& t : closure) {
args[nargs++] = t.get();
}
bool null_grad = false;
for (size_t i = 0; i < grads.size(); ++i) {
if (backward_graph->save_for_backward[grad_mask_offset + i]) {
if (grads[i]) {
if (null_grad) {
PyErr_SetString(PyExc_NotImplementedError, "report to devs");
throw py::error_already_set();
}
args[nargs++] = grads[i];
} else {
null_grad = true;
}
}
}
if (null_grad) return;
ApplyContext ctx;
ctx.op = backward_graph->backward;
ctx.flags = is_tracing ? Tensor::Flags::TRACE : 0;
ctx.nargs = nargs;
ctx.args = args;
for (size_t i = 0; i < nargs; ++i) {
ctx.flags |= args[i]->m_flags;
mgb_assert(args[i]);
}
auto igrads = apply(ctx);
auto&& it = igrads.begin();
for (auto [i, p] : views::enumerate(backward_graph->input_has_grad)) {
if (p) {
receiver(i, std::move(*it));
++it;
}
}
}
bool input_has_grad(size_t i) {
return backward_graph->input_has_grad[i];
}
bool output_requires_grad(size_t i) {
return backward_graph->save_for_backward[grad_mask_offset + i];
}
bool output_captured(size_t i) {
return backward_graph->save_for_backward[output_mask_offset + i];
}
};
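// The std::visit call sites in apply_grad and GradKey::backward only rely on
// the interface below, so a future custom grad rule could plug in by adding
// another alternative to the GradFn::backward variant. Hypothetical sketch,
// not part of this commit:
//
//     struct CustomBackward {
//         template <typename T, typename R>
//         void operator()(T&& grads, R&& receiver);  // emit grad of input i via receiver(i, grad)
//         bool input_has_grad(size_t i);        // does input i receive a grad at all
//         bool output_requires_grad(size_t i);  // is the grad of output i consumed
//         bool output_captured(size_t i);       // is output i itself captured (reference-cycle risk)
//     };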
} // namespace
struct GradProducerRecord : intrusive_list::Node<GradProducerRecord> {
......@@ -54,10 +203,15 @@ struct GradFn : std::enable_shared_from_this<GradFn> {
static MemPool<GradFn> pool;
std::weak_ptr<GradKey> key;
// slots for receiving and accumulating grads
// same length as outputs (of forward op)
SmallVector<GradSlot> slots;
// where to send and accumulate grads
// same length as inputs (of forward op)
SmallVector<GradSlotProducerPtr> dsts;
// encapsulates actual function to compute gradient
std::variant<std::monostate, BackwardGraphWithClosure> backward;
// a flag used during backward
bool in_ref_keeper = false;
static void deleter(GradFn* ptr) {
......@@ -72,8 +226,7 @@ struct GradFn : std::enable_shared_from_this<GradFn> {
key.reset();
slots.clear();
dsts.clear();
backward.emplace<std::monostate>();
}
};
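// State machine implied by the variant above:
//   backward == std::monostate            -> no grad rule attached / already cleared
//   backward == BackwardGraphWithClosure  -> grads flow through the cached backward graph
// clear() resets the variant to std::monostate, which also releases the
// tensors captured in BackwardGraphWithClosure::closure.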
......@@ -83,54 +236,36 @@ GradSlot* GradSlotPtr::operator->() {
namespace {
class GradFnHelper {
    std::shared_ptr<GradFn> grad_fn;

    GradFn* get() {
        if (!grad_fn) {
            grad_fn = std::make_shared<GradFn>();
        }
        return grad_fn.get();
    }

    friend apply_result_t imperative::python::apply_grad(ApplyContext&);

public:
    template<typename T, typename... Args>
    auto& emplace(Args&&... args) {
        return get()->backward.emplace<T>(std::forward<Args>(args)...);
    }
};

apply_result_t backward_graph_grad_rule(ApplyContext& ctx, GradFnHelper& ret_grad_fn) {
    auto outputs = apply(ctx);

    auto backward_graph = make_backward_graph(ctx, outputs);
    if (!backward_graph) {
        return outputs;
    }

    ret_grad_fn.emplace<BackwardGraphWithClosure>(std::move(backward_graph), ctx, outputs);

    return outputs;
}
} // namespace
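// A custom grad rule would follow the same shape as backward_graph_grad_rule:
// perform (or replace) the forward apply, then emplace its backward state into
// the helper. Hypothetical sketch, not part of this commit, assuming a
// CustomBackward alternative has been added to GradFn::backward:
//
//     apply_result_t custom_grad_rule(ApplyContext& ctx, GradFnHelper& ret_grad_fn) {
//         auto outputs = apply(ctx);
//         ret_grad_fn.emplace<CustomBackward>(/* state captured for backward */);
//         return outputs;
//     }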
......@@ -164,76 +299,53 @@ apply_result_t apply_grad(ApplyContext& ctx) {
ctx.flags &= ~Tensor::Flags::GRAD;
if (!grad_key) {
    return apply(ctx);
}

GradFnHelper grad_fn_holder;
auto outputs = backward_graph_grad_rule(ctx, grad_fn_holder);

auto& grad_fn = grad_fn_holder.grad_fn;
if (!grad_fn) {
    return outputs;
}

grad_fn->key = grad_key;
grad_fn->slots.resize(outputs.size());
grad_fn->dsts.reserve(ctx.nargs);

std::visit([&](auto& backward) {
    using T = std::decay_t<decltype(backward)>;
    if constexpr (std::is_same_v<T, std::monostate>) {
        mgb_assert(0);
    } else {
        for (size_t i = 0; i < ctx.nargs; ++i) {
            if (backward.input_has_grad(i)) {
                auto& input_grad_info = ctx.args[i]->m_grad_info;
                grad_fn->dsts.emplace_back(input_grad_info);
                // register as grad producer
                grad_fn->dsts.back().producer_record.insert_after(input_grad_info->producer_head);
            } else {
                grad_fn->dsts.emplace_back();
            }
        }
        for (size_t i = 0; i < outputs.size(); ++i) {
            if (backward.output_requires_grad(i)) {
                if (backward.output_captured(i)) {
                    // avoid reference cycle [Tensor <-> GradFn]
                    outputs[i] = outputs[i]->copy();
                }
                // populate grad info of output tensor
                auto& grad_info = outputs[i]->m_grad_info;
                grad_info.grad_fn = grad_fn;
                grad_info.idx = i;
                grad_info.insert_after(grad_key->free_vars_head);
                outputs[i]->m_flags |= Tensor::Flags::GRAD;
            }
        }
    }
}, grad_fn->backward);
// record forward history
grad_key->tape.emplace_back(grad_fn);
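// Resulting control flow of apply_grad after this refactor:
//   1. no input tracks grad        -> plain apply(ctx)
//   2. backward_graph_grad_rule    -> forward apply, then try to attach a
//                                     BackwardGraphWithClosure via GradFnHelper
//   3. a grad rule was attached    -> std::visit wires input dsts and marks
//                                     outputs through the rule's interface
//   4. grad_fn is recorded on the tape for the later backward pass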
......@@ -334,54 +446,30 @@ void GradKey::backward(std::vector<TensorWrapper*> tensors, std::vector<TensorWr
for (std::ptrdiff_t k = tape.size() - 1; k >= 0; --k) {
auto&& grad_fn = tape[k].lock();
if (!grad_fn) continue;
auto grad_receiver = [&](size_t i, auto&& g) {
    accum_grad(grad_fn->dsts[i]->grad, std::forward<decltype(g)>(g));
};
std::visit([&](auto&& backward) {
    using T = std::decay_t<decltype(backward)>;
    if constexpr (std::is_same_v<T, std::monostate>) {
        mgb_assert(0);
    } else {
        auto&& grads = views::transform(grad_fn->slots,
                [](auto&& slot) { return slot.grad.get(); });
        backward(std::forward<decltype(grads)>(grads), grad_receiver);
    }
}, grad_fn->backward);
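// Note on the grads argument built above: views::transform yields a lazy range
// of raw Tensor* (slot.grad.get()), which is what the backward functor indexes
// as grads[i]; a null entry marks an output slot that accumulated no gradient.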
for (auto&& dst : grad_fn->dsts) {
if (!dst.grad_fn) continue;
if (!dst.grad_fn->in_ref_keeper) {
// after grad_fn is cleared, refcnt of subsequent grad_fn
// could drop to 0
dst.grad_fn->in_ref_keeper = true;
ref_keeper.push_back(dst.grad_fn);
}
// grad_fn->clear will unlink current dst.producer_record
// such that if dst.producer_record.next is false, dst accumulates
// all the gradients
if (!dst.producer_record.next && dst->callback && dst->grad) {
// I'm the last grad producer, invoke callback
dst->callback(TensorWrapper::make(pytype, dst->grad));
}
}