提交 0a3ca253 编写于 作者: M Megvii Engine Team

fix(mge): fix backward graph optimization

GitOrigin-RevId: 28fd00ac548e7cd663a87ce74070862e9a24b55f
上级 ea8eb4cf
...@@ -19,7 +19,7 @@ import megengine.functional as F ...@@ -19,7 +19,7 @@ import megengine.functional as F
from megengine.core._imperative_rt import CompNode, TensorAttr, imperative from megengine.core._imperative_rt import CompNode, TensorAttr, imperative
from megengine.core._imperative_rt.core2 import TensorWeakRef, apply, sync from megengine.core._imperative_rt.core2 import TensorWeakRef, apply, sync
from megengine.core.autodiff.grad import Grad from megengine.core.autodiff.grad import Grad
from megengine.core.ops.builtin import Elemwise from megengine.core.ops.builtin import Elemwise, Identity
from megengine.distributed.helper import get_device_count_by_fork from megengine.distributed.helper import get_device_count_by_fork
from megengine.functional.distributed import remote_recv, remote_send from megengine.functional.distributed import remote_recv, remote_send
...@@ -193,6 +193,20 @@ def test_grad_inplace(): ...@@ -193,6 +193,20 @@ def test_grad_inplace():
np.testing.assert_almost_equal(x.grad.numpy(), 4 * x_np ** 3, decimal=6) np.testing.assert_almost_equal(x.grad.numpy(), 4 * x_np ** 3, decimal=6)
def test_identity():
x_np = np.random.rand(10).astype("float32")
x = mge.Tensor(x_np)
dy_np = np.random.rand(*x.shape).astype("float32")
dy = mge.Tensor(dy_np)
grad = Grad().wrt(x, callback=save_to(x))
(y,) = apply(Identity(), x)
grad(y, dy)
np.testing.assert_array_equal(x.grad.numpy(), dy_np)
def test_elemwise_add(): def test_elemwise_add():
x_np = np.random.rand(10).astype("float32") x_np = np.random.rand(10).astype("float32")
y_np = np.random.rand(10, 10).astype("float32") y_np = np.random.rand(10, 10).astype("float32")
......
...@@ -58,7 +58,7 @@ OptimizedBackwardGraphResult::OptimizedBackwardGraphResult(const BackwardGraphRe ...@@ -58,7 +58,7 @@ OptimizedBackwardGraphResult::OptimizedBackwardGraphResult(const BackwardGraphRe
// should be marked as always appears in backward // should be marked as always appears in backward
for (size_t i = 0, j = 0; i < mask.size(); ++i) { for (size_t i = 0, j = 0; i < mask.size(); ++i) {
if (!mask[i]) continue; if (!mask[i]) continue;
if (i > input_size + output_size) { if (i >= input_size + output_size) {
vinfo[graph.inputs[j]].appears_in_backward = true; vinfo[graph.inputs[j]].appears_in_backward = true;
} }
++j; ++j;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册