提交 72531f2b 编写于 作者: M Megvii Engine Team 提交者: huangxinda

test(autograd): add more tests for higher order grad

GitOrigin-RevId: 5fc308f87a6c4cb2de9b8edb654fe7a2416333ec
上级 522e556b
...@@ -345,12 +345,12 @@ class GradManager: ...@@ -345,12 +345,12 @@ class GradManager:
def __exit__(self, exc_type, exc_val, exc_tb): def __exit__(self, exc_type, exc_val, exc_tb):
self.release() self.release()
def __and__(self, other): def __or__(self, other):
if isinstance(other, GradManager): if isinstance(other, GradManager):
return GradManagerGroup([self, other]) return GradManagerGroup([self, other])
return NotImplemented return NotImplemented
__rand__ = __and__ __ror__ = __or__
class GradManagerGroup: class GradManagerGroup:
...@@ -364,8 +364,6 @@ class GradManagerGroup: ...@@ -364,8 +364,6 @@ class GradManagerGroup:
return NotImplemented return NotImplemented
return GradManagerGroup([*self._gms, *other._gms]) return GradManagerGroup([*self._gms, *other._gms])
__and__ = merge_with
__rand__ = merge_with
__or__ = merge_with __or__ = merge_with
__ror__ = merge_with __ror__ = merge_with
......
...@@ -468,7 +468,7 @@ PyObject* GradKeyWrapper::get_priority() { ...@@ -468,7 +468,7 @@ PyObject* GradKeyWrapper::get_priority() {
} }
void GradKeyWrapper::set_priority(pybind11::handle priority) { void GradKeyWrapper::set_priority(pybind11::handle priority) {
m_key->name = py::cast<int>(priority); m_key->priority = py::cast<int>(priority);
} }
void GradKeyWrapper::attach(PyObject*const* args, size_t nargs) { void GradKeyWrapper::attach(PyObject*const* args, size_t nargs) {
...@@ -535,7 +535,7 @@ void GradKey::backward(std::vector<TensorWrapper*> tensors, std::vector<TensorWr ...@@ -535,7 +535,7 @@ void GradKey::backward(std::vector<TensorWrapper*> tensors, std::vector<TensorWr
size_t priority_backup; size_t priority_backup;
CleanupGuard(GradKey* this_) : owner(this_) { CleanupGuard(GradKey* this_) : owner(this_) {
priority_backup = sm_min_priority; priority_backup = sm_min_priority;
sm_min_priority = owner->priority; sm_min_priority = owner->priority + 1;
} }
~CleanupGuard() { ~CleanupGuard() {
owner->cleanup(); owner->cleanup();
...@@ -636,7 +636,7 @@ PyObject* GradKeyWrapper::is_attached_to(PyObject*const* args, size_t nargs) { ...@@ -636,7 +636,7 @@ PyObject* GradKeyWrapper::is_attached_to(PyObject*const* args, size_t nargs) {
Py_RETURN_FALSE; Py_RETURN_FALSE;
} }
int GradKey::sm_min_priority = 0; int GradKey::sm_min_priority = std::numeric_limits<int>::min();
GradKey::~GradKey() { GradKey::~GradKey() {
cleanup(); cleanup();
......
...@@ -966,6 +966,7 @@ void init_tensor(py::module m) { ...@@ -966,6 +966,7 @@ void init_tensor(py::module m) {
.def<&GradKeyWrapper::attach>("attach") .def<&GradKeyWrapper::attach>("attach")
.def<&GradKeyWrapper::is_attached_to>("is_attached_to") .def<&GradKeyWrapper::is_attached_to>("is_attached_to")
.def_getset<&GradKeyWrapper::get_name, &GradKeyWrapper::set_name>("name") .def_getset<&GradKeyWrapper::get_name, &GradKeyWrapper::set_name>("name")
.def_getset<&GradKeyWrapper::get_priority, &GradKeyWrapper::set_priority>("priority")
.finalize(); .finalize();
if (!grad_key_type) throw py::error_already_set(); if (!grad_key_type) throw py::error_already_set();
py::setattr(m, "GradKey", grad_key_type); py::setattr(m, "GradKey", grad_key_type);
......
...@@ -279,3 +279,69 @@ def test_broadcast_grad(trace_mode): ...@@ -279,3 +279,69 @@ def test_broadcast_grad(trace_mode):
func() func()
worker() worker()
def test_2nd_grad_with_manager():
x_np = np.random.rand(10).astype("float32")
x = mge.tensor(x_np)
gm = GradManager().attach([x])
gm2 = GradManager().attach([x])
with gm:
with gm2:
y = F.cos(x)
gm2.backward(y)
np.testing.assert_almost_equal(x.grad.numpy(), -np.sin(x_np), decimal=5)
gm.backward(x.grad)
np.testing.assert_almost_equal(
x.grad.numpy(), -np.sin(x_np) - np.cos(x_np), decimal=5
)
def test_grad_manager_group():
x_np = np.random.rand(10).astype("float32")
x = mge.tensor(x_np)
gm = GradManager().attach([x])
gm2 = GradManager().attach([x])
with gm | gm2:
y = F.cos(x)
gm.backward(y)
gm2.backward(y)
np.testing.assert_almost_equal(x.grad.numpy(), -2 * np.sin(x_np), decimal=5)
x.grad = None
def test_grad_manager_group_visibility():
x_np = np.random.rand(10).astype("float32")
x = mge.tensor(x_np)
gm = GradManager().attach([x])
gm2 = GradManager().attach([x])
with gm | gm2:
y = F.cos(x)
gm2.backward(y)
np.testing.assert_almost_equal(x.grad.numpy(), -np.sin(x_np), decimal=5)
gm.backward(x.grad)
np.testing.assert_almost_equal(x.grad.numpy(), -np.sin(x_np), decimal=5)
def test_grad_manager_visibility_by_order():
x_np = np.random.rand(10).astype("float32")
x = mge.tensor(x_np)
gm = GradManager().attach([x])
gm2 = GradManager().attach([x])
with gm2:
with gm:
y = F.cos(x)
gm2.backward(y)
np.testing.assert_almost_equal(x.grad.numpy(), -np.sin(x_np), decimal=5)
gm.backward(x.grad)
np.testing.assert_almost_equal(x.grad.numpy(), -np.sin(x_np), decimal=5)
...@@ -126,7 +126,7 @@ def test_2nd_grad(): ...@@ -126,7 +126,7 @@ def test_2nd_grad():
x.grad = None x.grad = None
grad2(z, ones) grad2(z, ones)
np.testing.assert_almost_equal(x.grad.numpy(), -np.cos(x_np)) np.testing.assert_almost_equal(x.grad.numpy(), -np.cos(x_np), decimal=5)
def test_grad_with_tensor_wrapper(): def test_grad_with_tensor_wrapper():
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册