Commit 522e556b authored by Megvii Engine Team, committed by huangxinda

feat(autodiff): support higher order grad

GitOrigin-RevId: 86390d217940d2240d6908a29a6956b90f3b7b2e
Parent 5198b783
@@ -20,6 +20,9 @@ class AttachSpec:
     __slots__ = "tensor", "callbacks"


+_global_priority = 0
+
+
 class GradManager:
     r"""
     GradManager computes gradients or more generally, vector-Jacobian product, by reverse mode
@@ -118,6 +121,7 @@ class GradManager:
         self._grad = None
         self._after_backward_callback = []
         self._gradients = {}
+        self._priority = None

     def attach(self, tensors: Iterable[Tensor], callbacks=None):
         r"""
@@ -293,6 +297,7 @@ class GradManager:
         After this call, you will be able to call :meth:`backward`.
         """
+        global _global_priority
         if self._recording:
             raise RuntimeError("already recording")
         grad = Grad()
@@ -300,6 +305,9 @@ class GradManager:
         self._grad = grad
         for spec in self._attach_specs.values():
             self._do_record(spec)
+        if self._priority is None:
+            grad._priority = _global_priority
+            _global_priority -= 1
         grad.__enter__()

     def _do_record(self, spec):
@@ -321,11 +329,14 @@ class GradManager:
         After this call, you will not be able to call :meth:`backward`.
         """
+        global _global_priority
         if self._grad is not None:
             self._grad.__exit__(None, None, None)
             self._grad = None
         self._recording = False
         self._gradients = dict()
+        if self._priority is None:
+            _global_priority += 1

     def __enter__(self):
         self.record()
@@ -333,3 +344,41 @@ class GradManager:
     def __exit__(self, exc_type, exc_val, exc_tb):
         self.release()
+
+    def __and__(self, other):
+        if isinstance(other, GradManager):
+            return GradManagerGroup([self, other])
+        return NotImplemented
+
+    __rand__ = __and__
+
+
+class GradManagerGroup:
+    def __init__(self, gms) -> None:
+        self._gms = list(gms)
+
+    def merge_with(self, other):
+        if isinstance(other, GradManager):
+            other = GradManagerGroup([other])
+        elif not isinstance(other, GradManagerGroup):
+            return NotImplemented
+        return GradManagerGroup([*self._gms, *other._gms])
+
+    __and__ = merge_with
+    __rand__ = merge_with
+    __or__ = merge_with
+    __ror__ = merge_with
+
+    def __enter__(self):
+        global _global_priority
+        _global_priority += 1
+        for gm in self._gms:
+            gm._priority = _global_priority
+            gm.record()
+
+    def __exit__(self, exc_type, exc_val, exc_tb):
+        global _global_priority
+        _global_priority -= 1
+        for gm in self._gms:
+            gm.release()
+            gm._priority = None
@@ -47,6 +47,14 @@ class Grad:
         self._impl = GradKey(name)
         _grad_manager_dict[self._name] = self

+    @property
+    def _priority(self):
+        return self._impl.priority
+
+    @_priority.setter
+    def _priority(self, priority):
+        self._impl.priority = priority
+
     @property
     def _name(self):
         return self._impl.name
......
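This `_priority` pass-through is the low-level hook the updated `test_2nd_grad` at the end of this diff relies on; a minimal sketch, assuming `Grad` is importable from `megengine.core.autodiff.grad`:

```python
# Order two live Grad keys explicitly -- a sketch mirroring the test below.
from megengine.core.autodiff.grad import Grad

grad_inner = Grad()
grad_inner._priority = -1  # lower: its backward is visible to higher keys
grad_outer = Grad()
grad_outer._priority = 0   # higher: can record grad_inner's backward pass
```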
@@ -54,7 +54,7 @@ std::shared_ptr<OptimizedBackwardGraphResult> make_backward_graph(
     for (size_t i = 0; i < ctx.nargs; ++i) {
         *(size_t_ptr++) = mgb::hash(ctx.args[i]->dtype().handle());
         *(size_t_ptr++) = mgb::hash(ctx.args[i]->comp_node());
-        *(bool_ptr++) = bool(ctx.args[i]->m_grad_info.grad_fn);
+        *(bool_ptr++) = !ctx.args[i]->m_grad_info_dict.empty();
     }
     mgb_assert(bool_ptr0 == reinterpret_cast<bool*>(size_t_ptr) &&
                bool_ptr == reinterpret_cast<bool*>(buf + buf_size));

@@ -321,7 +321,7 @@ apply_result_t backward_graph_grad_rule(ApplyContext& ctx, GradFnHelper& ret_grad_fn) {
     for (size_t i = 0; i < ctx.nargs; ++i) {
         inputs_copy.push_back(python::apply(FastpathCopy::make(), ctx.args[i]->shared_from_this())[0]);
         inputs_copy_weak.push_back(inputs_copy.back().get());
-        inputs_copy.back()->m_grad_info = ctx.args[i]->m_grad_info;
+        inputs_copy.back()->m_grad_info_dict = ctx.args[i]->m_grad_info_dict;
     }
     ApplyContext ctx_dup = ctx;
     ctx_dup.args = inputs_copy_weak.data();
@@ -365,25 +365,19 @@ apply_result_t python_grad_rule(ApplyContext& ctx, GradFnHelper& ret_grad_fn) {
 } // namespace

 apply_result_t apply_grad(ApplyContext& ctx) {
-    std::shared_ptr<GradKey> grad_key;
+    std::unordered_set<std::shared_ptr<GradKey>> grad_keys;
     for (size_t i = 0; i < ctx.nargs; ++i) {
         auto* tensor = ctx.args[i];
-        if (tensor->m_grad_info.grad_fn) {
-            auto&& input_grad_key = tensor->m_grad_info.grad_fn->key.lock();
-            // tensor is attached to a live GradKey
-            if (input_grad_key && input_grad_key->active) {
-                if (grad_key) {
-                    if (grad_key != input_grad_key) {
-                        PyErr_SetString(PyExc_NotImplementedError, "second order grad");
-                        throw pyext17::py_err_set();
-                    }
-                } else {
-                    grad_key = std::move(input_grad_key);
-                }
-            } else {
-                // cleanup stale grad info
-                // under what condition?
-                tensor->m_grad_info = {};
+        if (!tensor->m_grad_info_dict.empty()) {
+            size_t grad_cnt = 0;
+            for (auto&& grad_info: tensor->m_grad_info_dict) {
+                auto input_grad_key = grad_info.grad_fn->key.lock();
+                if (input_grad_key && input_grad_key->active && !input_grad_key->is_blocked()) {
+                    grad_keys.insert(input_grad_key);
+                    grad_cnt++;
+                }
+            }
+            if (!grad_cnt) {
                 tensor->m_flags &= ~Flags::GRAD;
             }
         } else {

@@ -393,7 +387,7 @@ apply_result_t apply_grad(ApplyContext& ctx) {
     ctx.flags &= ~Flags::GRAD;

-    if (!grad_key) {
+    if (grad_keys.empty()) {
         return apply(ctx);
     }
@@ -418,54 +412,65 @@ apply_result_t apply_grad(ApplyContext& ctx) {
         return backward_graph_grad_rule(ctx, grad_fn_holder);
     }();

-    auto& grad_fn = grad_fn_holder.grad_fn;
-    if (!grad_fn) {
+    if (!grad_fn_holder.grad_fn) {
         return outputs;
     }

-    grad_fn->key = grad_key;
-    grad_fn->slots.resize(outputs.size());
-    grad_fn->dsts.reserve(ctx.nargs);
+    for (auto&& grad_key: grad_keys) {
+        auto grad_fn = std::make_shared<GradFn>();
+        grad_fn->backward = grad_fn_holder.grad_fn->backward;
+        grad_fn->key = grad_key;
+        grad_fn->slots.resize(outputs.size());
+        grad_fn->dsts.reserve(ctx.nargs);

-    std::visit([&](auto& backward) {
-        using T = std::decay_t<decltype(backward)>;
-        if constexpr (std::is_same_v<T, std::monostate>) {
-            mgb_assert(0);
-        } else {
-            for (size_t i = 0; i < ctx.nargs; ++i) {
-                if (backward.input_has_grad(i) && input_requires_grad(ctx, i)) {
-                    auto& input_grad_info = ctx.args[i]->m_grad_info;
-                    grad_fn->dsts.emplace_back(input_grad_info);
-                    // register as grad producer
-                    grad_fn->dsts.back().producer_record.insert_after(input_grad_info->producer_head);
-                } else {
-                    grad_fn->dsts.emplace_back();
-                }
-            }
-            for (size_t i = 0; i < outputs.size(); ++i) {
-                if (backward.output_requires_grad(i)) {
-                    if (backward.output_captured(i)) {
-                        // avoid reference cycle [Tensor <-> GradFn]
-                        static std::shared_ptr<OpDef> op = std::shared_ptr<OpDef>(new FastpathCopy());
-                        outputs[i] = python::apply(op, outputs[i])[0];
-                    }
-                    // populate grad info of output tensor
-                    auto& grad_info = outputs[i]->m_grad_info;
-                    grad_info.grad_fn = grad_fn;
-                    grad_info.idx = i;
-                    grad_info.insert_after(grad_key->free_vars_head);
-                    outputs[i]->m_flags |= Flags::GRAD;
-                }
-            }
-        }
-    }, grad_fn->backward);
+        std::visit([&](auto& backward) {
+            using T = std::decay_t<decltype(backward)>;
+            if constexpr (std::is_same_v<T, std::monostate>) {
+                mgb_assert(0);
+            } else {
+                for (size_t i = 0; i < ctx.nargs; ++i) {
+                    if (backward.input_has_grad(i) && input_requires_grad(ctx, i) && ctx.args[i]->m_grad_info_dict.count(grad_key.get())) {
+                        auto& input_grad_info = ctx.args[i]->m_grad_info_dict.at(grad_key.get());
+                        grad_fn->dsts.emplace_back(input_grad_info);
+                        // register as grad producer
+                        grad_fn->dsts.back().producer_record.insert_after(input_grad_info->producer_head);
+                    } else {
+                        grad_fn->dsts.emplace_back();
+                    }
+                }
+                for (size_t i = 0; i < outputs.size(); ++i) {
+                    if (backward.output_requires_grad(i)) {
+                        if (backward.output_captured(i)) {
+                            // avoid reference cycle [Tensor <-> GradFn]
+                            static std::shared_ptr<OpDef> op = std::make_shared<FastpathCopy>();
+                            outputs[i] = python::apply(op, outputs[i])[0];
+                        }
+                        // populate grad info of output tensor
+                        auto& grad_info = outputs[i]->m_grad_info_dict[grad_key.get()];
+                        grad_info.grad_fn = grad_fn;
+                        grad_info.idx = i;
+                        grad_info.insert_after(grad_key->free_vars_head);
+                        outputs[i]->m_flags |= Flags::GRAD;
+                    }
+                }
+            }
+        }, grad_fn->backward);

-    // record forward history
-    grad_key->tape.emplace_back(grad_fn);
+        // record forward history
+        grad_key->tape.emplace_back(grad_fn);
+    }

     return outputs;
 }
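`apply_grad` now loops over every live, unblocked key, so a tensor attached to two managers gets a separate `GradFn` on each key's tape for the same forward op, instead of raising the old `NotImplementedError: second order grad`. A behavioral sketch at the Python layer (same API assumptions as above; either backward leaves the other key's tape usable):

```python
# One forward op, two tapes -- a hedged sketch, not from the commit.
import numpy as np
import megengine as mge
from megengine.autodiff import GradManager

x = mge.tensor(np.ones(4, dtype="float32"))
gm_a = GradManager()
gm_a.attach([x])
gm_b = GradManager()
gm_b.attach([x])

with gm_a:
    with gm_b:
        y = (x * x).sum()     # recorded once per live key
        gm_b.backward(y)      # x.grad = 2*x via key B
        x.grad = None
    gm_a.backward(y)          # key A's tape is intact: x.grad = 2*x again
```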
+PyObject* GradKeyWrapper::get_priority() {
+    return py::cast(m_key->priority).release().ptr();
+}
+
+void GradKeyWrapper::set_priority(pybind11::handle priority) {
+    m_key->priority = py::cast<int>(priority);
+}
+
 void GradKeyWrapper::attach(PyObject*const* args, size_t nargs) {
     if (nargs != 2) {
         throw py::type_error("expect 2 arguments");
@@ -488,24 +493,21 @@ void GradKey::attach(Tensor* tensor, pybind11::object callback) {
         throw py::value_error("grad key finalized");
     }

-    if (tensor->m_grad_info.grad_fn) {
-        if (tensor->m_grad_info.grad_fn->key.lock().get() != this) {
-            PyErr_SetString(PyExc_NotImplementedError, "second order grad");
-            throw pyext17::py_err_set();
-        }
-        if (tensor->m_grad_info->callback) {
+    if (tensor->m_grad_info_dict.count(this)) {
+        if (tensor->m_grad_info_dict.at(this)->callback) {
             throw py::value_error("callback already set on this tensor");
         }
     } else {
-        tensor->m_grad_info.idx = 0;
-        auto& grad_fn = tensor->m_grad_info.grad_fn;
+        auto& grad_info = tensor->m_grad_info_dict[this];
+        grad_info.idx = 0;
+        auto& grad_fn = grad_info.grad_fn;
         grad_fn = std::make_shared<GradFn>();
         grad_fn->key = shared_from_this();
         grad_fn->slots.resize(1);
-        tensor->m_grad_info.insert_after(free_vars_head);
+        grad_info.insert_after(free_vars_head);
         tensor->m_flags |= Flags::GRAD;
     }
-    tensor->m_grad_info.grad_fn->slots[0].callback = std::move(callback);
+    tensor->m_grad_info_dict.at(this).grad_fn->slots[0].callback = std::move(callback);
 }

 template<typename T>
@@ -530,8 +532,15 @@ void GradKey::backward(std::vector<TensorWrapper*> tensors, std::vector<TensorWrapper*> grads) {
     active = false;

     struct CleanupGuard {
         GradKey* owner;
-        CleanupGuard(GradKey* this_) : owner(this_) {}
-        ~CleanupGuard() {owner->cleanup();}
+        int priority_backup;
+        CleanupGuard(GradKey* this_) : owner(this_) {
+            priority_backup = sm_min_priority;
+            sm_min_priority = owner->priority;
+        }
+        ~CleanupGuard() {
+            owner->cleanup();
+            sm_min_priority = priority_backup;
+        }
     } _cleanup_guard(this);

     if (tape.empty()) return;
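The guard exists because `backward` itself applies ops: while a key with priority `p` is backpropagating, `sm_min_priority` is raised to `p`, so strictly lower-priority keys are blocked from recording those ops, while equal- or higher-priority keys still see them and can differentiate the result. A plain-Python model of the rule from `GradKey::is_blocked` (hypothetical names, not the real binding):

```python
# blocked == "do not record ops during the currently running backward"
def is_blocked(key_priority: int, min_priority: int) -> bool:
    return key_priority < min_priority

# while a key with priority -1 runs its backward (min_priority == -1):
assert not is_blocked(0, -1)   # higher-priority key records -> 2nd order
assert not is_blocked(-1, -1)  # the running key itself is not blocked
assert is_blocked(-2, -1)      # later-created keys are skipped
```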
@@ -542,14 +551,16 @@ void GradKey::backward(std::vector<TensorWrapper*> tensors, std::vector<TensorWrapper*> grads) {
     }

     for (size_t i = 0; i < tensors.size(); ++i) {
-        auto& grad_info = tensors[i]->m_tensor->m_grad_info;
-        if (grad_info.grad_fn && grad_info.grad_fn->key.lock().get() == this) {
-            grad_info->grad = grads[i]->m_tensor;
+        if (tensors[i]->m_tensor->m_grad_info_dict.count(this) == 0) {
+            continue;
         }
+        auto& grad_info = tensors[i]->m_tensor->m_grad_info_dict.at(this);
+        grad_info->grad = grads[i]->m_tensor;
     }

     std::vector<std::shared_ptr<GradFn>> ref_keeper;
     ref_keeper.reserve(tape.size());

     // back-propagation in reverse order
     for (std::ptrdiff_t k = tape.size() - 1; k >= 0; --k) {
         auto&& grad_fn = tape[k].lock();
@@ -619,13 +630,14 @@ PyObject* GradKeyWrapper::is_attached_to(PyObject*const* args, size_t nargs) {
         PyErr_SetString(PyExc_TypeError, "expect Tensor");
         return nullptr;
     }
-    auto&& grad_fn = tw->m_tensor->m_grad_info.grad_fn;
-    if (grad_fn && grad_fn->key.lock() == m_key) {
+    if (tw->m_tensor->m_grad_info_dict.count(m_key.get())) {
         Py_RETURN_TRUE;
     }
     Py_RETURN_FALSE;
 }

+int GradKey::sm_min_priority = 0;
+
 GradKey::~GradKey() {
     cleanup();
 }
@@ -635,4 +647,41 @@ std::unordered_map<Typeinfo*, GradRuleFn>& grad_rule_registry() {
     return registry;
 }

+void GradInfoCollection::_shrink() {
+    auto pred = [](GradInfo& info){ return !(info.grad_fn) || info.grad_fn->key.expired(); };
+    auto iter = std::remove_if(m_storage.begin(), m_storage.end(), pred);
+    m_storage.erase(iter, m_storage.end());
+}
+
+bool GradInfoCollection::contains(GradKey* key) {
+    _shrink();
+    for (auto&& grad_info: m_storage) {
+        if (grad_info.grad_fn->key.lock().get() == key) {
+            return true;
+        }
+    }
+    return false;
+}
+
+GradInfo& GradInfoCollection::operator[](GradKey* key) {
+    _shrink();
+    for (auto&& grad_info: m_storage) {
+        if (grad_info.grad_fn->key.lock().get() == key) {
+            return grad_info;
+        }
+    }
+    m_storage.emplace_back();
+    return m_storage.back();
+}
+
+GradInfo& GradInfoCollection::at(GradKey* key) {
+    _shrink();
+    for (auto&& grad_info: m_storage) {
+        if (grad_info.grad_fn->key.lock().get() == key) {
+            return grad_info;
+        }
+    }
+    mgb_assert(false);
+}
+
 } // namespace mgb::imperative::python
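`GradInfoCollection` is effectively a tiny map from live `GradKey`s to `GradInfo`s that prunes dead entries before every lookup. A rough Python analogy (weakrefs standing in for `std::weak_ptr`; all names hypothetical):

```python
import weakref

class GradInfoDict:
    """Association list that drops entries whose key has died (~ _shrink)."""

    def __init__(self):
        self._storage = []  # list of (weakref-to-key, info) pairs

    def _shrink(self):
        self._storage = [(r, v) for r, v in self._storage if r() is not None]

    def __getitem__(self, key):  # ~ operator[]: inserts a fresh entry on miss
        self._shrink()
        for r, v in self._storage:
            if r() is key:
                return v
        self._storage.append((weakref.ref(key), {}))
        return self._storage[-1][1]

    def count(self, key):  # ~ count(): 0 or 1, like the C++ helper
        self._shrink()
        return int(any(r() is key for r, _ in self._storage))
```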
@@ -26,12 +26,18 @@ struct GradKey : std::enable_shared_from_this<GradKey>, NonCopyableObj {
     bool active = true;
     GradInfo::head_t free_vars_head;
     std::vector<std::weak_ptr<GradFn>> tape;
+    int priority = 0;

     ~GradKey();

     void attach(Tensor* tensor, pybind11::object callback);
     void backward(std::vector<TensorWrapper*>, std::vector<TensorWrapper*>);
     void cleanup();
+    bool is_blocked() const {
+        return priority < sm_min_priority;
+    }
+
+private:
+    static int sm_min_priority;
 };

 struct GradKeyWrapper {
@@ -44,6 +50,8 @@ struct GradKeyWrapper {
     PyObject* get_name();
     void set_name(pybind11::handle name);
+    PyObject* get_priority();
+    void set_priority(pybind11::handle priority);
     void attach(PyObject*const* args, size_t nargs);
     void backward(std::vector<TensorWrapper*>, std::vector<TensorWrapper*>);
     PyObject* is_attached_to(PyObject*const* args, size_t nargs);
@@ -150,7 +158,7 @@
 std::unordered_map<Typeinfo*, GradRuleFn>& grad_rule_registry();

 inline bool input_requires_grad(const ApplyContext& ctx, size_t i) {
-    return bool(ctx.args[i]->m_grad_info.grad_fn);
+    return !ctx.args[i]->m_grad_info_dict.empty();
 }

 struct GradRuleFallback : std::exception {};
......
@@ -15,6 +15,7 @@

 namespace mgb::imperative::python {

+struct GradKey;
 struct GradFn;
 struct GradSlot;

@@ -32,6 +33,10 @@ struct GradInfo
     GradInfo(GradInfo&&) = default;
     GradInfo& operator=(GradInfo&) = default;
     GradInfo& operator=(GradInfo&&) = default;
+    GradInfo(const GradInfo& rhs): GradInfo(const_cast<GradInfo&>(rhs)) {}
+    GradInfo& operator=(const GradInfo& rhs) {
+        return *this = const_cast<GradInfo&>(rhs);
+    }
 };

 } // namespace mgb::imperative::python
@@ -182,7 +182,7 @@ PyObject* py_apply(PyObject* self, PyObject*const* args, size_t nargs)
         if (py::isinstance<PySymbolVar>(py::handle(args[0]))){
             SmallVector<cg::VarNode*> vinputs(nargs);
             for (size_t i = 0; i < nargs; ++i) {
-                vinputs[i] = py::handle(args[i]).cast<PySymbolVar*>()->m_node;
+                vinputs[i] = py::handle(args[i]).cast<PySymbolVar*>()->m_node;
             }
             auto op = ctx.op.get();
             auto rst = OpDef::apply_on_var_node(*op, vinputs);
......
@@ -17,6 +17,7 @@
 #include "megbrain/imperative/interpreter.h"
 #include "pybind11/pybind11.h"
 #include <string>
+#include <unordered_map>

 #include "./pyext17.h"
@@ -36,6 +37,8 @@ struct ObjectPtr : B {

 namespace mgb::imperative::python {

+struct GradKey;
+
 extern interpreter::Interpreter::Channel* interpreter_for_py;

 class SharedHandle {
@@ -58,6 +61,34 @@ public:
 };

+// impl in grad.cpp
+class GradInfoCollection {
+private:
+    SmallVector<GradInfo> m_storage;
+protected:
+    void _shrink();
+public:
+    bool contains(GradKey* key);
+    GradInfo& operator[](GradKey* key);
+    GradInfo& at(GradKey* key);
+    bool empty() {
+        _shrink();
+        return m_storage.empty();
+    }
+    auto begin() {
+        _shrink();
+        return m_storage.begin();
+    }
+    auto end() {
+        _shrink();
+        return m_storage.end();
+    }
+    size_t count(GradKey* key) {
+        return contains(key) ? 1 : 0;
+    }
+};
+
 struct Tensor : std::enable_shared_from_this<Tensor>, NonCopyableObj {
     using flags_t = uint64_t;
@@ -69,7 +100,7 @@ struct Tensor : std::enable_shared_from_this<Tensor>, NonCopyableObj {
     flags_t m_flags = 0;

-    GradInfo m_grad_info;
+    GradInfoCollection m_grad_info_dict;
     TraceInfo m_trace_info;
     SharedHandle m_handle;
     std::string user_custom_name;
@@ -88,7 +119,7 @@ struct Tensor : std::enable_shared_from_this<Tensor>, NonCopyableObj {
     inline std::shared_ptr<Tensor> copy() {
         auto ret = std::make_shared<Tensor>(m_handle);
         ret->m_flags = m_flags;
-        ret->m_grad_info = m_grad_info;
+        ret->m_grad_info_dict = m_grad_info_dict;
         ret->m_trace_info = m_trace_info;
         ret->m_var = m_var;
         return ret;
......
@@ -108,21 +108,24 @@ def test_grad_2():
     np.testing.assert_almost_equal(x.grad.numpy(), 4 * x_np ** 3, decimal=6)


-@pytest.mark.skip(reason="high order gradient was not implemented yet")
 def test_2nd_grad():
     x_np = np.random.rand(10).astype("float32")
     x = as_tensor(x_np)
     ones = as_tensor(np.ones_like(x_np))

     grad = Grad().wrt(x, callback=save_to(x))
+    grad._priority = -1
     grad2 = Grad().wrt(x, callback=save_to(x))
+    grad2._priority = 0

     y = cos(x)

     grad(y, ones)
+    z = x.grad
     np.testing.assert_almost_equal(x.grad.numpy(), -np.sin(x_np), decimal=5)

-    grad2(x.grad, ones)
+    x.grad = None
+    grad2(z, ones)

     np.testing.assert_almost_equal(x.grad.numpy(), -np.cos(x_np))
......
@@ -398,20 +398,6 @@ OP_TRAIT_REG(Copy, Copy)
     .fallback();
 }} // copy

-namespace { namespace identity {
-auto apply_on_var_node(
-        const OpDef& def,
-        const VarNodeArray& inputs) {
-    auto&& op = def.cast_final_safe<Identity>();
-    mgb_assert(inputs.size() == 1);
-    OperatorNodeConfig config{op.make_name()};
-    return opr::Identity::make(inputs[0], config);
-}
-
-OP_TRAIT_REG(Identity, Identity)
-    .apply_on_var_node(apply_on_var_node)
-    .fallback();
-}} // identity
-
 namespace { namespace assert_equal {
 auto apply_on_var_node(
         const OpDef& def,
......
@@ -9,6 +9,7 @@
  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  */

+#include "megbrain/imperative/ops/autogen.h"
 #include "megbrain/imperative/ops/utility.h"
 #include "megbrain/imperative/ops/opr_attr.h"
 #include "megbrain/opr/utility.h"
@@ -32,4 +33,25 @@ OP_TRAIT_REG(FastpathCopy,FastpathCopy)

 MGB_DYN_TYPE_OBJ_FINAL_IMPL(FastpathCopy);

+namespace { namespace identity {
+auto apply_on_var_node(
+        const OpDef& def,
+        const VarNodeArray& inputs) {
+    auto&& op = def.cast_final_safe<Identity>();
+    mgb_assert(inputs.size() == 1);
+    OperatorNodeConfig config{op.make_name()};
+    return opr::Identity::make(inputs[0], config);
+}
+
+auto apply_on_physical_tensor(
+        const OpDef& def,
+        const SmallVector<TensorPtr>& inputs) {
+    return SmallVector<TensorPtr>{inputs[0]};
+}
+
+OP_TRAIT_REG(Identity, Identity)
+    .apply_on_var_node(apply_on_var_node)
+    .apply_on_physical_tensor(apply_on_physical_tensor)
+    .fallback();
+}} // identity
+
 } // namespace mgb::imperative