Commit 99a85c40 authored by Megvii Engine Team

fix(mge): fix advanced indexing grad

GitOrigin-RevId: 8033c9322dd79db2b72a8cc7df66cdaf270b0b60
Parent 409c9881
@@ -233,7 +233,7 @@ std::optional<ValueRefList> indexingMultiAxisVec_grad_rule(
         const OpDef& op, Span<ValueRef> inputs, Span<bool> inputs_require_grad,
         CustomBackward& backward) {
     auto&& indexingMultiAxisVec = op.cast_final_safe<IndexingMultiAxisVec>();
-    auto&& grad_op = IndexingSetMultiAxisVec::make(indexingMultiAxisVec.items);
+    auto&& grad_op = IndexingIncrMultiAxisVec::make(indexingMultiAxisVec.items);
     SmallVector<ValueRef> inputs2;
     if (inputs_require_grad[0]) {
         inputs2.push_back(get_shape(inputs[0]));
......
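The hunk above swaps the backward op from IndexingSetMultiAxisVec to IndexingIncrMultiAxisVec. The backward pass of advanced indexing has to scatter the incoming gradient back into a zero tensor shaped like the input; when the same position is selected more than once, its contributions must be accumulated, whereas a "set"-style scatter silently overwrites them. A minimal NumPy sketch of the two semantics (the variable names are illustrative, not MegEngine internals):

import numpy as np

rows, cols = [0, 0, 2, 1], [2, 2, 1, 0]   # indices used by the forward gather x[rows, cols]
dy = np.ones(4, dtype=np.float32)         # incoming gradient, one value per gathered element

# "Set" semantics: duplicated indices overwrite each other, losing gradient mass.
dx_set = np.zeros((3, 3), dtype=np.float32)
dx_set[rows, cols] = dy                   # position (0, 2) ends up with 1 instead of 2

# "Incr" semantics: duplicated indices accumulate, the correct adjoint of a gather.
dx_incr = np.zeros((3, 3), dtype=np.float32)
np.add.at(dx_incr, (rows, cols), dy)      # position (0, 2) correctly receives 2

print(dx_set)
print(dx_incr)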
@@ -316,7 +316,7 @@ def test_IndexingMultiAxisVec():
     def f(x):
         x = x * 1
-        y = x[[0, 2], [0, 2]]
+        y = x[[0, 0, 2, 1], [2, 2, 1, 0]]
         refs["x"] = TensorWeakRef(x)
         return y
@@ -326,7 +326,7 @@ def test_IndexingMultiAxisVec():
     grad(y, F.ones_like(y))
     np.testing.assert_equal(
-        np.array([[1, 0, 0], [0, 0, 0], [0, 0, 1]], dtype=np.float32), x.grad.numpy()
+        np.array([[0, 0, 2], [1, 0, 0], [0, 1, 0]], dtype=np.float32), x.grad.numpy()
     )
......
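The updated test indexes with duplicated coordinates so that the old "set" backward would fail: the forward gathers (0, 2), (0, 2), (2, 1) and (1, 0), so backpropagating a tensor of ones must deposit 2 at (0, 2) and 1 at the other two positions, which is exactly the new expected array. Below is a sketch of the same check written against MegEngine's public GradManager API rather than the internal `grad` helper and `TensorWeakRef` used in the test; treat it as an illustration, not the test's actual code:

import numpy as np
import megengine as mge
import megengine.functional as F
from megengine.autodiff import GradManager

x = mge.Tensor(np.arange(9, dtype=np.float32).reshape(3, 3))
gm = GradManager().attach([x])
with gm:
    y = x[[0, 0, 2, 1], [2, 2, 1, 0]]     # advanced indexing with a repeated coordinate
    gm.backward(y, F.ones_like(y))        # dy = ones, so x.grad counts how often each cell was read

expected = np.array([[0, 0, 2], [1, 0, 0], [0, 1, 0]], dtype=np.float32)
np.testing.assert_equal(expected, x.grad.numpy())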