提交 522e556b 编写于 作者: M Megvii Engine Team 提交者: huangxinda

feat(autodiff): support higher order grad

GitOrigin-RevId: 86390d217940d2240d6908a29a6956b90f3b7b2e
上级 5198b783
......@@ -20,6 +20,9 @@ class AttachSpec:
__slots__ = "tensor", "callbacks"
_global_priority = 0
class GradManager:
r"""
GradManager computes gradients or more generally, vector-Jacobian product, by reverse mode
......@@ -118,6 +121,7 @@ class GradManager:
self._grad = None
self._after_backward_callback = []
self._gradients = {}
self._priority = None
def attach(self, tensors: Iterable[Tensor], callbacks=None):
r"""
......@@ -293,6 +297,7 @@ class GradManager:
After this call, you will be able to call :meth:`backward`.
"""
global _global_priority
if self._recording:
raise RuntimeError("already recording")
grad = Grad()
......@@ -300,6 +305,9 @@ class GradManager:
self._grad = grad
for spec in self._attach_specs.values():
self._do_record(spec)
if self._priority is None:
grad._priority = _global_priority
_global_priority -= 1
grad.__enter__()
def _do_record(self, spec):
......@@ -321,11 +329,14 @@ class GradManager:
After this call, you will not be able to call :meth:`backward`.
"""
global _global_priority
if self._grad is not None:
self._grad.__exit__(None, None, None)
self._grad = None
self._recording = False
self._gradients = dict()
if self._priority is None:
_global_priority += 1
def __enter__(self):
self.record()
......@@ -333,3 +344,41 @@ class GradManager:
def __exit__(self, exc_type, exc_val, exc_tb):
self.release()
def __and__(self, other):
if isinstance(other, GradManager):
return GradManagerGroup([self, other])
return NotImplemented
__rand__ = __and__
class GradManagerGroup:
def __init__(self, gms) -> None:
self._gms = list(gms)
def merge_with(self, other):
if isinstance(other, GradManager):
other = GradManagerGroup([other])
elif not isinstance(other, GradManagerGroup):
return NotImplemented
return GradManagerGroup([*self._gms, *other._gms])
__and__ = merge_with
__rand__ = merge_with
__or__ = merge_with
__ror__ = merge_with
def __enter__(self):
global _global_priority
_global_priority += 1
for gm in self._gms:
gm._priority = _global_priority
gm.record()
def __exit__(self, exc_type, exc_val, exc_tb):
global _global_priority
_global_priority -= 1
for gm in self._gms:
gm.release()
gm._priority = None
......@@ -47,6 +47,14 @@ class Grad:
self._impl = GradKey(name)
_grad_manager_dict[self._name] = self
@property
def _priority(self):
return self._impl.priority
@_priority.setter
def _priority(self, priority):
self._impl.priority = priority
@property
def _name(self):
return self._impl.name
......
......@@ -54,7 +54,7 @@ std::shared_ptr<OptimizedBackwardGraphResult> make_backward_graph(
for (size_t i = 0; i < ctx.nargs; ++i) {
*(size_t_ptr++) = mgb::hash(ctx.args[i]->dtype().handle());
*(size_t_ptr++) = mgb::hash(ctx.args[i]->comp_node());
*(bool_ptr++) = bool(ctx.args[i]->m_grad_info.grad_fn);
*(bool_ptr++) = !ctx.args[i]->m_grad_info_dict.empty();
}
mgb_assert(bool_ptr0 == reinterpret_cast<bool*>(size_t_ptr) &&
bool_ptr == reinterpret_cast<bool*>(buf + buf_size));
......@@ -321,7 +321,7 @@ apply_result_t backward_graph_grad_rule(ApplyContext& ctx, GradFnHelper& ret_gra
for (size_t i = 0; i < ctx.nargs; ++i) {
inputs_copy.push_back(python::apply(FastpathCopy::make(), ctx.args[i]->shared_from_this())[0]);
inputs_copy_weak.push_back(inputs_copy.back().get());
inputs_copy.back()->m_grad_info = ctx.args[i]->m_grad_info;
inputs_copy.back()->m_grad_info_dict = ctx.args[i]->m_grad_info_dict;
}
ApplyContext ctx_dup = ctx;
ctx_dup.args = inputs_copy_weak.data();
......@@ -365,25 +365,19 @@ apply_result_t python_grad_rule(ApplyContext& ctx, GradFnHelper& ret_grad_fn) {
} // namespace
apply_result_t apply_grad(ApplyContext& ctx) {
std::shared_ptr<GradKey> grad_key;
std::unordered_set<std::shared_ptr<GradKey>> grad_keys;
for (size_t i = 0; i < ctx.nargs; ++i) {
auto* tensor = ctx.args[i];
if (tensor->m_grad_info.grad_fn) {
auto&& input_grad_key = tensor->m_grad_info.grad_fn->key.lock();
// tensor is attached to a live GradKey
if (input_grad_key && input_grad_key->active) {
if (grad_key) {
if (grad_key != input_grad_key) {
PyErr_SetString(PyExc_NotImplementedError, "second order grad");
throw pyext17::py_err_set();
}
} else {
grad_key = std::move(input_grad_key);
if (!tensor->m_grad_info_dict.empty()) {
size_t grad_cnt = 0;
for (auto&& grad_info: tensor->m_grad_info_dict) {
auto input_grad_key = grad_info.grad_fn->key.lock();
if (input_grad_key && input_grad_key->active && !input_grad_key->is_blocked()) {
grad_keys.insert(input_grad_key);
grad_cnt++;
}
} else {
// cleanup stale grad info
// under what condition?
tensor->m_grad_info = {};
}
if (!grad_cnt) {
tensor->m_flags &= ~Flags::GRAD;
}
} else {
......@@ -393,7 +387,7 @@ apply_result_t apply_grad(ApplyContext& ctx) {
ctx.flags &= ~Flags::GRAD;
if (!grad_key) {
if (grad_keys.empty()) {
return apply(ctx);
}
......@@ -418,54 +412,65 @@ apply_result_t apply_grad(ApplyContext& ctx) {
return backward_graph_grad_rule(ctx, grad_fn_holder);
}();
auto& grad_fn = grad_fn_holder.grad_fn;
if (!grad_fn) {
if (!grad_fn_holder.grad_fn) {
return outputs;
}
grad_fn->key = grad_key;
grad_fn->slots.resize(outputs.size());
grad_fn->dsts.reserve(ctx.nargs);
for (auto&& grad_key: grad_keys) {
auto grad_fn = std::make_shared<GradFn>();
grad_fn->backward = grad_fn_holder.grad_fn->backward;
grad_fn->key = grad_key;
grad_fn->slots.resize(outputs.size());
grad_fn->dsts.reserve(ctx.nargs);
std::visit([&](auto& backward) {
using T = std::decay_t<decltype(backward)>;
if constexpr (std::is_same_v<T, std::monostate>) {
mgb_assert(0);
} else {
for (size_t i = 0; i < ctx.nargs; ++i) {
if (backward.input_has_grad(i) && input_requires_grad(ctx, i)) {
auto& input_grad_info = ctx.args[i]->m_grad_info;
grad_fn->dsts.emplace_back(input_grad_info);
// register as grad producer
grad_fn->dsts.back().producer_record.insert_after(input_grad_info->producer_head);
} else {
grad_fn->dsts.emplace_back();
std::visit([&](auto& backward) {
using T = std::decay_t<decltype(backward)>;
if constexpr (std::is_same_v<T, std::monostate>) {
mgb_assert(0);
} else {
for (size_t i = 0; i < ctx.nargs; ++i) {
if (backward.input_has_grad(i) && input_requires_grad(ctx, i) && ctx.args[i]->m_grad_info_dict.count(grad_key.get())) {
auto& input_grad_info = ctx.args[i]->m_grad_info_dict.at(grad_key.get());
grad_fn->dsts.emplace_back(input_grad_info);
// register as grad producer
grad_fn->dsts.back().producer_record.insert_after(input_grad_info->producer_head);
} else {
grad_fn->dsts.emplace_back();
}
}
}
for (size_t i = 0; i < outputs.size(); ++i) {
if (backward.output_requires_grad(i)) {
if (backward.output_captured(i)) {
// avoid reference cycle [Tensor <-> GradFn]
static std::shared_ptr<OpDef> op = std::shared_ptr<OpDef>(new FastpathCopy());
outputs[i] = python::apply(op, outputs[i])[0];
for (size_t i = 0; i < outputs.size(); ++i) {
if (backward.output_requires_grad(i)) {
if (backward.output_captured(i)) {
// avoid reference cycle [Tensor <-> GradFn]
static std::shared_ptr<OpDef> op = std::make_shared<FastpathCopy>();
outputs[i] = python::apply(op, outputs[i])[0];
}
// populate grad info of output tensor
auto& grad_info = outputs[i]->m_grad_info_dict[grad_key.get()];
grad_info.grad_fn = grad_fn;
grad_info.idx = i;
grad_info.insert_after(grad_key->free_vars_head);
outputs[i]->m_flags |= Flags::GRAD;
}
// populate grad info of output tensor
auto& grad_info = outputs[i]->m_grad_info;
grad_info.grad_fn = grad_fn;
grad_info.idx = i;
grad_info.insert_after(grad_key->free_vars_head);
outputs[i]->m_flags |= Flags::GRAD;
}
}
}
}, grad_fn->backward);
}, grad_fn->backward);
// record forward history
grad_key->tape.emplace_back(grad_fn);
// record forward history
grad_key->tape.emplace_back(grad_fn);
}
return outputs;
}
PyObject* GradKeyWrapper::get_priority() {
return py::cast(m_key->priority).release().ptr();
}
void GradKeyWrapper::set_priority(pybind11::handle priority) {
m_key->name = py::cast<int>(priority);
}
void GradKeyWrapper::attach(PyObject*const* args, size_t nargs) {
if (nargs != 2) {
throw py::type_error("expect 2 arguments");
......@@ -488,24 +493,21 @@ void GradKey::attach(Tensor* tensor, pybind11::object callback) {
throw py::value_error("grad key finalized");
}
if (tensor->m_grad_info.grad_fn) {
if (tensor->m_grad_info.grad_fn->key.lock().get() != this) {
PyErr_SetString(PyExc_NotImplementedError, "second order grad");
throw pyext17::py_err_set();
}
if (tensor->m_grad_info->callback) {
if (tensor->m_grad_info_dict.count(this)) {
if (tensor->m_grad_info_dict.at(this)->callback) {
throw py::value_error("callback already set on this tensor");
}
} else {
tensor->m_grad_info.idx = 0;
auto& grad_fn = tensor->m_grad_info.grad_fn;
auto& grad_info = tensor->m_grad_info_dict[this];
grad_info.idx = 0;
auto& grad_fn = grad_info.grad_fn;
grad_fn = std::make_shared<GradFn>();
grad_fn->key = shared_from_this();
grad_fn->slots.resize(1);
tensor->m_grad_info.insert_after(free_vars_head);
grad_info.insert_after(free_vars_head);
tensor->m_flags |= Flags::GRAD;
}
tensor->m_grad_info.grad_fn->slots[0].callback = std::move(callback);
tensor->m_grad_info_dict.at(this).grad_fn->slots[0].callback = std::move(callback);
}
template<typename T>
......@@ -530,8 +532,15 @@ void GradKey::backward(std::vector<TensorWrapper*> tensors, std::vector<TensorWr
active = false;
struct CleanupGuard {
GradKey* owner;
CleanupGuard(GradKey* this_) : owner(this_) {}
~CleanupGuard() {owner->cleanup();}
size_t priority_backup;
CleanupGuard(GradKey* this_) : owner(this_) {
priority_backup = sm_min_priority;
sm_min_priority = owner->priority;
}
~CleanupGuard() {
owner->cleanup();
sm_min_priority = priority_backup;
}
} _cleanup_guard(this);
if (tape.empty()) return;
......@@ -542,14 +551,16 @@ void GradKey::backward(std::vector<TensorWrapper*> tensors, std::vector<TensorWr
}
for (size_t i = 0; i < tensors.size(); ++i) {
auto& grad_info = tensors[i]->m_tensor->m_grad_info;
if (grad_info.grad_fn && grad_info.grad_fn->key.lock().get() == this) {
grad_info->grad = grads[i]->m_tensor;
if (tensors[i]->m_tensor->m_grad_info_dict.count(this) == 0) {
continue;
}
auto& grad_info = tensors[i]->m_tensor->m_grad_info_dict.at(this);
grad_info->grad = grads[i]->m_tensor;
}
std::vector<std::shared_ptr<GradFn>> ref_keeper;
ref_keeper.reserve(tape.size());
// back-propagation in reverse order
for (std::ptrdiff_t k = tape.size() - 1; k >= 0; --k) {
auto&& grad_fn = tape[k].lock();
......@@ -619,13 +630,14 @@ PyObject* GradKeyWrapper::is_attached_to(PyObject*const* args, size_t nargs) {
PyErr_SetString(PyExc_TypeError, "expect Tensor");
return nullptr;
}
auto&& grad_fn = tw->m_tensor->m_grad_info.grad_fn;
if (grad_fn && grad_fn->key.lock() == m_key) {
if (tw->m_tensor->m_grad_info_dict.count(m_key.get())) {
Py_RETURN_TRUE;
}
Py_RETURN_FALSE;
}
int GradKey::sm_min_priority = 0;
GradKey::~GradKey() {
cleanup();
}
......@@ -635,4 +647,41 @@ std::unordered_map<Typeinfo*, GradRuleFn>& grad_rule_registry() {
return registry;
}
void GradInfoCollection::_shrink() {
auto pred = [](GradInfo& info){ return !(info.grad_fn) || info.grad_fn->key.expired(); };
auto iter = std::remove_if(m_storage.begin(), m_storage.end(), pred);
m_storage.erase(iter, m_storage.end());
}
bool GradInfoCollection::contains(GradKey* key) {
_shrink();
for (auto&& grad_info: m_storage) {
if (grad_info.grad_fn->key.lock().get() == key) {
return true;
}
}
return false;
}
GradInfo& GradInfoCollection::operator[](GradKey* key) {
_shrink();
for (auto&& grad_info: m_storage) {
if (grad_info.grad_fn->key.lock().get() == key) {
return grad_info;
}
}
m_storage.emplace_back();
return m_storage.back();
}
GradInfo& GradInfoCollection::at(GradKey* key) {
_shrink();
for (auto&& grad_info: m_storage) {
if (grad_info.grad_fn->key.lock().get() == key) {
return grad_info;
}
}
mgb_assert(false);
}
} // namespace mgb::imperative::python
......@@ -26,12 +26,18 @@ struct GradKey : std::enable_shared_from_this<GradKey>, NonCopyableObj {
bool active = true;
GradInfo::head_t free_vars_head;
std::vector<std::weak_ptr<GradFn>> tape;
int priority = 0;
~GradKey();
void attach(Tensor* tensor, pybind11::object callback);
void backward(std::vector<TensorWrapper*>, std::vector<TensorWrapper*>);
void cleanup();
bool is_blocked() const {
return priority < sm_min_priority;
}
private:
static int sm_min_priority;
};
struct GradKeyWrapper {
......@@ -44,6 +50,8 @@ struct GradKeyWrapper {
PyObject* get_name();
void set_name(pybind11::handle name);
PyObject* get_priority();
void set_priority(pybind11::handle priority);
void attach(PyObject*const* args, size_t nargs);
void backward(std::vector<TensorWrapper*>, std::vector<TensorWrapper*>);
PyObject* is_attached_to(PyObject*const* args, size_t nargs);
......@@ -150,7 +158,7 @@ using GradRuleFn = std::function<apply_result_t(ApplyContext&, CustomBackward::M
std::unordered_map<Typeinfo*, GradRuleFn>& grad_rule_registry();
inline bool input_requires_grad(const ApplyContext& ctx, size_t i) {
return bool(ctx.args[i]->m_grad_info.grad_fn);
return !ctx.args[i]->m_grad_info_dict.empty();
}
struct GradRuleFallback : std::exception {};
......
......@@ -15,6 +15,7 @@
namespace mgb::imperative::python {
struct GradKey;
struct GradFn;
struct GradSlot;
......@@ -32,6 +33,10 @@ struct GradInfo : GradSlotPtr, intrusive_list::Node<GradInfo, intrusive_list::be
GradInfo(GradInfo&&) = default;
GradInfo& operator=(GradInfo&) = default;
GradInfo& operator=(GradInfo&&) = default;
GradInfo(const GradInfo& rhs): GradInfo(const_cast<GradInfo&>(rhs)){}
GradInfo& operator=(const GradInfo& rhs) {
return *this = const_cast<GradInfo&>(rhs);
}
};
} // namespace mgb::imperative::python
......@@ -182,7 +182,7 @@ PyObject* py_apply(PyObject* self, PyObject*const* args, size_t nargs/* , PyObje
if (py::isinstance<PySymbolVar>(py::handle(args[0]))){
SmallVector<cg::VarNode*> vinputs(nargs);
for (size_t i = 0; i < nargs; ++i) {
vinputs[i] = py::handle(args[i]).cast<PySymbolVar*>()->m_node;
vinputs[i] = py::handle(args[i]).cast<PySymbolVar*>()->m_node;
}
auto op = ctx.op.get();
auto rst = OpDef::apply_on_var_node(*op, vinputs);
......
......@@ -17,6 +17,7 @@
#include "megbrain/imperative/interpreter.h"
#include "pybind11/pybind11.h"
#include <string>
#include <unordered_map>
#include "./pyext17.h"
......@@ -36,6 +37,8 @@ struct ObjectPtr : B {
namespace mgb::imperative::python {
struct GradKey;
extern interpreter::Interpreter::Channel* interpreter_for_py;
class SharedHandle {
......@@ -58,6 +61,34 @@ public:
};
// impl in grad.cpp
class GradInfoCollection {
private:
SmallVector<GradInfo> m_storage;
protected:
void _shrink();
public:
bool contains(GradKey* key);
GradInfo& operator[](GradKey* key);
GradInfo& at(GradKey* key);
bool empty() {
_shrink();
return m_storage.empty();
}
auto begin() {
_shrink();
return m_storage.begin();
}
auto end() {
_shrink();
return m_storage.end();
}
size_t count(GradKey* key) {
return contains(key) ? 1 : 0;
}
};
struct Tensor : std::enable_shared_from_this<Tensor>, NonCopyableObj {
using flags_t = uint64_t;
......@@ -69,7 +100,7 @@ struct Tensor : std::enable_shared_from_this<Tensor>, NonCopyableObj {
flags_t m_flags = 0;
GradInfo m_grad_info;
GradInfoCollection m_grad_info_dict;
TraceInfo m_trace_info;
SharedHandle m_handle;
std::string user_custom_name;
......@@ -88,7 +119,7 @@ struct Tensor : std::enable_shared_from_this<Tensor>, NonCopyableObj {
inline std::shared_ptr<Tensor> copy() {
auto ret = std::make_shared<Tensor>(m_handle);
ret->m_flags = m_flags;
ret->m_grad_info = m_grad_info;
ret->m_grad_info_dict = m_grad_info_dict;
ret->m_trace_info = m_trace_info;
ret->m_var = m_var;
return ret;
......
......@@ -108,21 +108,24 @@ def test_grad_2():
np.testing.assert_almost_equal(x.grad.numpy(), 4 * x_np ** 3, decimal=6)
@pytest.mark.skip(reason="high order gradient was not implemented yet")
def test_2nd_grad():
x_np = np.random.rand(10).astype("float32")
x = as_tensor(x_np)
ones = as_tensor(np.ones_like(x_np))
grad = Grad().wrt(x, callback=save_to(x))
grad._priority = -1
grad2 = Grad().wrt(x, callback=save_to(x))
grad2._priority = 0
y = cos(x)
grad(y, ones)
z = x.grad
np.testing.assert_almost_equal(x.grad.numpy(), -np.sin(x_np), decimal=5)
grad2(x.grad, ones)
x.grad = None
grad2(z, ones)
np.testing.assert_almost_equal(x.grad.numpy(), -np.cos(x_np))
......
......@@ -398,20 +398,6 @@ OP_TRAIT_REG(Copy, Copy)
.fallback();
}} // copy
namespace { namespace identity {
auto apply_on_var_node(
const OpDef& def,
const VarNodeArray& inputs) {
auto&& op = def.cast_final_safe<Identity>();
mgb_assert(inputs.size() == 1);
OperatorNodeConfig config{op.make_name()};
return opr::Identity::make(inputs[0], config);
}
OP_TRAIT_REG(Identity, Identity)
.apply_on_var_node(apply_on_var_node)
.fallback();
}} // identity
namespace { namespace assert_equal {
auto apply_on_var_node(
const OpDef& def,
......
......@@ -9,6 +9,7 @@
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
#include "megbrain/imperative/ops/autogen.h"
#include "megbrain/imperative/ops/utility.h"
#include "megbrain/imperative/ops/opr_attr.h"
#include "megbrain/opr/utility.h"
......@@ -32,4 +33,25 @@ OP_TRAIT_REG(FastpathCopy,FastpathCopy)
MGB_DYN_TYPE_OBJ_FINAL_IMPL(FastpathCopy);
namespace { namespace identity {
auto apply_on_var_node(
const OpDef& def,
const VarNodeArray& inputs) {
auto&& op = def.cast_final_safe<Identity>();
mgb_assert(inputs.size() == 1);
OperatorNodeConfig config{op.make_name()};
return opr::Identity::make(inputs[0], config);
}
auto apply_on_physical_tensor(
const OpDef& def,
const SmallVector<TensorPtr>& inputs) {
return SmallVector<TensorPtr>{inputs[0]};
}
OP_TRAIT_REG(Identity, Identity)
.apply_on_var_node(apply_on_var_node)
.apply_on_physical_tensor(apply_on_physical_tensor)
.fallback();
}} // identity
} // namespace mgb::imperative
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册