Commit 9fb5581f authored by Megvii Engine Team

refactor(mge): add specialized grad rule support

GitOrigin-RevId: 141ff0a24f0f843ff5457c06e303332eb4276ef6
Parent 645fc6f0
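In short, this commit adds a C++ fast path for computing gradients: grad_rule_registry() maps an op's Typeinfo to a GradRuleFn; apply_grad() consults that registry before falling back to the generic backward graph; CustomBackward (configured through its nested Maker) stores the rule's backward closure; and a rule may throw GradRuleFallback at any point to hand the op back to the generic path. The first specialized rule, for Elemwise ADD, is registered in the new grad_override.cpp below.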
@@ -70,7 +70,7 @@ std::shared_ptr<BackwardGraphResult> make_backward_graph(
for (size_t i = 0; i < ctx.nargs; ++i) {
inputs[i].comp_node = ctx.args[i]->comp_node();
inputs[i].layout.dtype = ctx.args[i]->dtype();
- input_requires_grad[i] = bool(ctx.args[i]->m_grad_info.grad_fn);
+ input_requires_grad[i] = python::input_requires_grad(ctx, i);
}
auto result = std::make_shared<BackwardGraphResult>(
proxy_graph_detail::make_backward_graph(
@@ -82,21 +82,6 @@ std::shared_ptr<BackwardGraphResult> make_backward_graph(
return result;
}
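(python::input_requires_grad is a new inline helper declared in grad.h below; it performs the same grad_fn check as the line it replaces, but can now be shared with the specialized grad rules.)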
- struct BackwardContext {
- PyTypeObject* pytype = nullptr;
- auto wrap_tensor(std::shared_ptr<Tensor> t) {
- if (pytype) {
- return TensorWrapper::make(pytype, std::move(t));
- }
- return TensorWrapper::make(std::move(t));
- }
- auto wrap_tensor(Tensor* t) {
- return wrap_tensor(t->shared_from_this());
- }
- };
struct BackwardGraphWithClosure {
std::shared_ptr<BackwardGraphResult> backward_graph;
SmallVector<std::shared_ptr<Tensor>> closure;
@@ -270,7 +255,7 @@ struct GradFn : std::enable_shared_from_this<GradFn> {
// same length as inputs (of forward op)
SmallVector<GradSlotProducerPtr> dsts;
// encapsulates actual function to compute gradient
- std::variant<std::monostate, BackwardGraphWithClosure, PythonBackward> backward;
+ std::variant<std::monostate, BackwardGraphWithClosure, PythonBackward, CustomBackward> backward;
// a flag used during backward
bool in_ref_keeper = false;
@@ -335,8 +320,7 @@ apply_result_t python_grad_rule(ApplyContext& ctx, GradFnHelper& ret_grad_fn) {
pyin[i] = TensorWrapper::make(ctx.pytype, ctx.args[i]->shared_from_this());
}
auto grad_rule = py::getattr(op->obj, "_grad_rule");
- auto pyret = (scoped_disable(Flags::GRAD),
- py::reinterpret_steal<py::object>(PyObject_Call(grad_rule.ptr(), pyin.ptr(), nullptr))); // comma expression
+ auto pyret = py::reinterpret_steal<py::object>(PyObject_Call(grad_rule.ptr(), pyin.ptr(), nullptr));
auto [outputs, backward] = py::cast<std::tuple<py::object, py::function>>(pyret);
ret_grad_fn.emplace<PythonBackward>(std::move(backward), ctx.nargs);
if (auto* tw = TensorWrapper::try_cast(outputs.ptr())) {
@@ -388,9 +372,25 @@ apply_result_t apply_grad(ApplyContext& ctx) {
}
GradFnHelper grad_fn_holder;
- auto outputs = ctx.op->same_type<GenericPyOp>() ?
- python_grad_rule(ctx, grad_fn_holder) :
- backward_graph_grad_rule(ctx, grad_fn_holder);
+ auto outputs = [&]() {
+ auto _ = scoped_disable(Flags::GRAD);
+ if (ctx.op->same_type<GenericPyOp>()) {
+ return python_grad_rule(ctx, grad_fn_holder);
+ }
+ auto&& registry = grad_rule_registry();
+ auto&& it = registry.find(ctx.op->dyn_typeinfo());
+ if (it != registry.end()) {
+ auto&& maker = grad_fn_holder.emplace<CustomBackward>().maker(ctx);
+ try {
+ auto ret = it->second(ctx, maker);
+ maker.finalize();
+ return ret;
+ } catch (GradRuleFallback&) {
+ grad_fn_holder.emplace<std::monostate>();
+ }
+ }
+ return backward_graph_grad_rule(ctx, grad_fn_holder);
+ }();
auto& grad_fn = grad_fn_holder.grad_fn;
if (!grad_fn) {
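Note that the lambda above runs all three dispatch paths with Flags::GRAD disabled once, replacing the per-branch scoped_disable that was removed from python_grad_rule. A registered rule opts out by throwing GradRuleFallback, which discards the partially built CustomBackward (via emplace<std::monostate>()) and falls through to backward_graph_grad_rule. The Elemwise rule added below specializes only Mode::ADD; its fallback is equivalent to guarding with:

if (op.mode != Elemwise::Mode::ADD) {
    throw GradRuleFallback(); // generic backward graph takes over
}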
@@ -407,7 +407,7 @@ apply_result_t apply_grad(ApplyContext& ctx) {
mgb_assert(0);
} else {
for (size_t i = 0; i < ctx.nargs; ++i) {
- if (backward.input_has_grad(i)) {
+ if (backward.input_has_grad(i) && input_requires_grad(ctx, i)) {
auto& input_grad_info = ctx.args[i]->m_grad_info;
grad_fn->dsts.emplace_back(input_grad_info);
// register as grad producer
@@ -487,18 +487,8 @@ void accum_grad(std::shared_ptr<Tensor>& grad, T&& delta) {
grad = std::forward<T>(delta);
return;
}
- static ApplyContext ctx;
- if (!ctx.op) {
- ctx.op = std::shared_ptr<OpDef>(new Elemwise(Elemwise::Mode::ADD));
- ctx.nargs = 2;
- }
- Tensor* args[2] = {grad.get(), delta.get()};
- ctx.args = args;
- ctx.flags = grad->m_flags | delta->m_flags;
- if (is_tracing) {
- ctx.flags |= Flags::TRACE;
- }
- grad = apply(ctx)[0];
+ static std::shared_ptr<OpDef> op = std::shared_ptr<OpDef>(new Elemwise(Elemwise::Mode::ADD));
+ grad = apply(op, grad, std::forward<T>(delta))[0];
}
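accum_grad now routes through the new variadic apply helper added to tensor.h (last hunk below); roughly, apply(op, a, b) performs the ApplyContext setup that was previously written out by hand here. A sketch of the expansion for two shared_ptr<Tensor> arguments:

// Sketch of what apply(op, a, b) sets up (see the tensor.h hunk below)
ApplyContext ctx;
Tensor* args[] = {a.get(), b.get()};  // raw pointers via resolve_arrow
ctx.flags = a->m_flags | b->m_flags;
ctx.flags |= is_tracing ? Tensor::Flags::TRACE : 0;
ctx.args = args;
ctx.nargs = 2;
ctx.op = std::move(op);
return apply(ctx);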
void GradKey::backward(std::vector<TensorWrapper*> tensors, std::vector<TensorWrapper*> grads) {
@@ -582,4 +572,9 @@ GradKey::~GradKey() {
cleanup();
}
std::unordered_map<Typeinfo*, GradRuleFn>& grad_rule_registry() {
static std::unordered_map<Typeinfo*, GradRuleFn> registry;
return registry;
}
} // namespace mgb::imperative::python
@@ -45,6 +45,117 @@ struct GradKeyWrapper {
void backward(std::vector<TensorWrapper*>, std::vector<TensorWrapper*>);
};
struct BackwardContext {
PyTypeObject* pytype = nullptr;
auto wrap_tensor(std::shared_ptr<Tensor> t) {
if (pytype) {
return TensorWrapper::make(pytype, std::move(t));
}
return TensorWrapper::make(std::move(t));
}
auto wrap_tensor(Tensor* t) {
return wrap_tensor(t->shared_from_this());
}
};
struct CustomBackward {
using BackwardFn = std::function<apply_result_t(BackwardContext&, Tensor*const*, size_t)>;
BackwardFn m_backward;
SmallVector<bool, 8> m_input_has_grad;
struct OutputAttr {bool requires_grad = true, captured = true;};
SmallVector<OutputAttr> m_output_attrs;
public:
template<typename T, typename R>
void operator()(BackwardContext& ctx, T&& grads, R&& receiver) {
size_t nargs = grads.size();
Tensor* args[nargs];
for (size_t i = 0; i < nargs; ++i) {
args[i] = grads[i];
}
auto ret = m_backward(ctx, args, nargs);
for (size_t i = 0; i < ret.size(); ++i) {
if (auto&& t = ret[i]) {
receiver(i, std::move(t));
}
}
}
bool input_has_grad(size_t i) {return m_input_has_grad[i];}
bool output_requires_grad(size_t i) {return m_output_attrs[i].requires_grad;}
bool output_captured(size_t i) {return m_output_attrs[i].captured;}
class Maker {
bool output_size_set = false, input_has_grad_initialized = false;
CustomBackward& target;
ApplyContext& ctx;
void init_input_has_grad() {
if (!input_has_grad_initialized) {
input_has_grad_initialized = true;
target.m_input_has_grad.resize(ctx.nargs, true);
}
}
public:
Maker(CustomBackward& target_, ApplyContext& ctx_) : target(target_), ctx(ctx_) {}
template<typename F>
Maker& backward(F&& f) {
mgb_assert(!target.m_backward);
target.m_backward = std::forward<F>(f);
return *this;
}
// mandatory
Maker& output_size(size_t sz) {
mgb_assert(!output_size_set);
output_size_set = true;
target.m_output_attrs.resize(sz);
return *this;
}
// optional, defaults to all true
Maker& input_has_grad(size_t i, bool v) {
init_input_has_grad();
target.m_input_has_grad.at(i) = v;
return *this;
}
// optional, defaults to all true
Maker& output_requires_grad(size_t i, bool v) {
target.m_output_attrs.at(i).requires_grad = v;
return *this;
}
// optional, defaults to all true
Maker& output_captured(size_t i, bool v) {
target.m_output_attrs.at(i).captured = v;
return *this;
}
void finalize() {
mgb_assert(output_size_set);
init_input_has_grad();
}
};
Maker maker(ApplyContext& ctx) {return {*this, ctx};}
};
using GradRuleFn = std::function<apply_result_t(ApplyContext&, CustomBackward::Maker&)>;
std::unordered_map<Typeinfo*, GradRuleFn>& grad_rule_registry();
inline bool input_requires_grad(const ApplyContext& ctx, size_t i) {
return bool(ctx.args[i]->m_grad_info.grad_fn);
}
struct GradRuleFallback : std::exception {};
template<typename T>
bool register_grad_rule(Typeinfo* typeinfo, T&& rule) {
return grad_rule_registry().emplace(typeinfo, std::forward<T>(rule)).second;
}
} // namespace mgb::imperative::python
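A hedged sketch of how a grad rule configures its CustomBackward through the Maker; the shape of the rule follows elemwise_grad_rule below, but the specific choices (two inputs, input 1 opting out of grad) and the op MyOp are illustrative only:

apply_result_t my_grad_rule(ApplyContext& ctx, CustomBackward::Maker& maker) {
    maker.output_size(1)             // mandatory: number of forward outputs
         .output_captured(0, false)  // backward does not need output 0's value
         .input_has_grad(1, false);  // hypothetical: input 1 never receives a grad
    maker.backward([](BackwardContext&, Tensor*const* grads, size_t ngrads) {
        apply_result_t ret(2);       // one slot per forward input
        // ... fill ret[0] from grads[0]; leave ret[1] empty ...
        return ret;
    });
    return apply(ctx);               // forward result, computed as usual
}
// registered once at startup, e.g.:
// register_grad_rule(MyOp::typeinfo(), my_grad_rule);

maker.finalize() is not called by the rule itself; apply_grad calls it after the rule returns.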
namespace pybind11::detail {
......
/**
* \file imperative/python/src/grad_override.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
software distributed under the License is distributed on an
"AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
#include "./grad.h"
#include "megbrain/imperative/ops/autogen.h"
namespace mgb::imperative::python {
namespace {
std::shared_ptr<Tensor> get_shape(Tensor* x) {
static auto op = GetVarShape::make();
return python::apply(op, x)[0];
}
std::shared_ptr<Tensor> reduce_to(Tensor* x, Tensor* s) {
static auto op = Reduce::make();
return python::apply(op, x, s)[0];
}
apply_result_t elemwise_grad_rule(ApplyContext& ctx, CustomBackward::Maker& maker) {
auto& op = ctx.op->cast_final_safe<Elemwise>();
if (op.mode == Elemwise::Mode::ADD) {
mgb_assert(ctx.nargs == 2);
std::array<std::shared_ptr<Tensor>, 2> input_shapes;
for (size_t i = 0; i < 2; ++i) {
if (input_requires_grad(ctx, i)) {
input_shapes[i] = get_shape(ctx.args[i]);
}
}
maker.output_size(1).output_captured(0, false);
maker.backward([shapes=std::move(input_shapes)](BackwardContext&, Tensor*const* grads, size_t ngrads) {
mgb_assert(ngrads == 1);
Tensor* grad = grads[0];
apply_result_t ret(2);
for (size_t i = 0; i < 2; ++i) {
if (shapes[i]) {
ret[i] = reduce_to(grad, shapes[i].get());
}
}
return ret;
});
return apply(ctx);
}
throw GradRuleFallback();
}
struct Init {
Init() {
auto& reg = grad_rule_registry();
reg.emplace(Elemwise::typeinfo(), elemwise_grad_rule);
}
} _;
} // namespace
} // namespace mgb::imperative::python
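The ADD rule captures only the two input shapes (via GetVarShape) at forward time, so its closure retains no input or output tensors, and output_captured(0, false) tells the machinery not to keep the forward output either. reduce_to then reduces the incoming gradient down to each recorded shape, which keeps the rule correct under broadcasting: e.g. a grad of shape (4, 3) is reduced to (1, 3) for an input that was broadcast along the first axis. Inputs that do not require grad get no shape and thus no gradient.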
@@ -199,12 +199,59 @@ using apply_result_t = SmallVector<std::shared_ptr<Tensor>, 8>;
apply_result_t apply(ApplyContext& ctx);
void init_tensor(pybind11::module);
template <typename T>
decltype(auto) resolve_arrow(T&& p) {
if constexpr (std::is_pointer_v<std::remove_reference_t<T>>) {
auto* ret = p;
return ret;
} else {
auto probe = [](auto&& p) -> decltype(p.operator->()) {};
if constexpr (std::is_invocable_v<decltype(probe), decltype(p)>) {
return resolve_arrow(p.operator->());
} else {
return p;
}
}
}
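resolve_arrow recursively unwraps anything with an operator-> until a raw pointer (or a non-pointer leaf) remains; a minimal illustration of what it yields:

// Sketch: both a raw pointer and a smart pointer resolve to Tensor*
static_assert(std::is_same_v<decltype(resolve_arrow(std::declval<Tensor*>())), Tensor*>);
static_assert(std::is_same_v<decltype(resolve_arrow(std::declval<std::shared_ptr<Tensor>>())), Tensor*>);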
template <typename... Args>
constexpr bool is_all_tensor_ptr = (... && std::is_same_v<decltype(resolve_arrow(std::declval<Args>())), Tensor*>);
- extern bool is_tracing;
+ extern bool is_tracing; // FIXME: should use ApplyContext::global_enable
extern bool is_symbolic;
extern bool is_compiled;
template <typename... Args, std::enable_if_t<is_all_tensor_ptr<Args...>, int> = 0>
apply_result_t apply(std::shared_ptr<OpDef> op, Args&&... args) {
ApplyContext ctx;
Tensor* arg_arr[] = {resolve_arrow(args)...};
ctx.flags = (0 | ... | args->m_flags);
ctx.flags |= is_tracing ? Tensor::Flags::TRACE : 0;
ctx.args = arg_arr;
ctx.nargs = sizeof...(args);
ctx.op = std::move(op);
return apply(ctx);
}
template <typename T>
auto apply(std::shared_ptr<OpDef> op, T&& tensors)
-> std::enable_if_t<std::is_same_v<decltype(resolve_arrow(tensors[0])), Tensor*>,
apply_result_t> {
ApplyContext ctx;
ctx.op = std::move(op);
ctx.flags = is_tracing ? Tensor::Flags::TRACE : 0;
ctx.nargs = tensors.size();
Tensor* args[ctx.nargs];
ctx.args = args;
for (size_t i = 0; i < ctx.nargs; ++i) {
args[i] = resolve_arrow(tensors[i]);
ctx.flags |= args[i]->m_flags;
}
return apply(ctx);
}
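Hedged usage sketches for the two overloads (variable names are illustrative); the variadic form is what the new accum_grad uses, while the container form accepts anything indexable whose elements resolve to Tensor*:

// a, b: std::shared_ptr<Tensor>; op constructed as elsewhere in this commit
std::shared_ptr<OpDef> op(new Elemwise(Elemwise::Mode::ADD));
auto r1 = apply(op, a, b);                        // variadic overload
SmallVector<std::shared_ptr<Tensor>> ts = {a, b};
auto r2 = apply(op, ts);                          // container overload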
void init_tensor(pybind11::module);
extern pybind11::object cpp_apply_with_tracing, cpp_apply_compiled_mode;
extern pybind11::object cpp_apply_backward_varnode;
......