提交 99a85c40 编写于 作者: M Megvii Engine Team

fix(mge): fix advanced indexing grad

GitOrigin-RevId: 8033c9322dd79db2b72a8cc7df66cdaf270b0b60
上级 409c9881
......@@ -233,7 +233,7 @@ std::optional<ValueRefList> indexingMultiAxisVec_grad_rule(
const OpDef& op, Span<ValueRef> inputs, Span<bool> inputs_require_grad,
CustomBackward& backward) {
auto&& indexingMultiAxisVec = op.cast_final_safe<IndexingMultiAxisVec>();
auto&& grad_op = IndexingSetMultiAxisVec::make(indexingMultiAxisVec.items);
auto&& grad_op = IndexingIncrMultiAxisVec::make(indexingMultiAxisVec.items);
SmallVector<ValueRef> inputs2;
if (inputs_require_grad[0]) {
inputs2.push_back(get_shape(inputs[0]));
......
......@@ -316,7 +316,7 @@ def test_IndexingMultiAxisVec():
def f(x):
x = x * 1
y = x[[0, 2], [0, 2]]
y = x[[0, 0, 2, 1], [2, 2, 1, 0]]
refs["x"] = TensorWeakRef(x)
return y
......@@ -326,7 +326,7 @@ def test_IndexingMultiAxisVec():
grad(y, F.ones_like(y))
np.testing.assert_equal(
np.array([[1, 0, 0], [0, 0, 0], [0, 0, 1]], dtype=np.float32), x.grad.numpy()
np.array([[0, 0, 2], [1, 0, 0], [0, 1, 0]], dtype=np.float32), x.grad.numpy()
)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册