提交 c850c7eb 编写于 作者: M Megvii Engine Team

fix(masked_fill): fix error and add some tests

GitOrigin-RevId: 225861f30a77103a3b988968d45054bd28260590
上级 1f0ea02f
...@@ -5,6 +5,7 @@ from tempfile import NamedTemporaryFile ...@@ -5,6 +5,7 @@ from tempfile import NamedTemporaryFile
import numpy as np import numpy as np
import pytest import pytest
import torch
from utils import make_tensor from utils import make_tensor
import megengine import megengine
...@@ -13,6 +14,7 @@ import megengine.functional as F ...@@ -13,6 +14,7 @@ import megengine.functional as F
import megengine.jit as jit import megengine.jit as jit
import megengine.random as rand import megengine.random as rand
import megengine.utils.comp_graph_tools as cgtools import megengine.utils.comp_graph_tools as cgtools
from megengine.autodiff import GradManager
from megengine.core._imperative_rt.core2 import apply from megengine.core._imperative_rt.core2 import apply
from megengine.core._trace_option import use_symbolic_shape from megengine.core._trace_option import use_symbolic_shape
from megengine.core.ops import builtin from megengine.core.ops import builtin
...@@ -622,6 +624,25 @@ def test_advance_indexing_with_bool(test_varnode): ...@@ -622,6 +624,25 @@ def test_advance_indexing_with_bool(test_varnode):
) )
def test_advance_indexing_autodiff():
x = Tensor([2, 2, 3, 4, 5, 6, 7, 8, 2], dtype="float32")
gm = GradManager()
gm.attach(x)
with gm:
a = x + 1
a[x > 3] = 0.3
b = a + 1
gm.backward(b.sum())
torch_x = torch.tensor(
[2, 2, 3, 4, 5, 6, 7, 8, 2], dtype=torch.float32, requires_grad=True
)
a = torch_x + 1
a[torch_x > 3] = 0.3
b = a + 1
(b.sum()).backward()
np.testing.assert_equal(x.grad.numpy(), torch_x.grad.numpy())
@pytest.mark.parametrize("symbolic", [True, False, None]) @pytest.mark.parametrize("symbolic", [True, False, None])
def test_subtensor_on_empty_tensor(symbolic): def test_subtensor_on_empty_tensor(symbolic):
np_x = np.array([], dtype=np.float32).reshape(10, 0, 10) np_x = np.array([], dtype=np.float32).reshape(10, 0, 10)
......
...@@ -306,7 +306,7 @@ cg::OperatorNodeBase* apply_on_var_node(const OpDef& def, const VarNodeArray& in ...@@ -306,7 +306,7 @@ cg::OperatorNodeBase* apply_on_var_node(const OpDef& def, const VarNodeArray& in
SmallVector<VarNode::LayoutConstraintCallback> get_input_layout_constraint( SmallVector<VarNode::LayoutConstraintCallback> get_input_layout_constraint(
const OpDef& def, const SmallVector<TensorPtr>& inputs) { const OpDef& def, const SmallVector<TensorPtr>& inputs) {
SmallVector<VarNode::LayoutConstraintCallback> layout_checker(inputs.size()); SmallVector<VarNode::LayoutConstraintCallback> layout_checker(inputs.size());
layout_checker[0] = [](const TensorLayout& layout) { layout_checker[1] = [](const TensorLayout& layout) {
return layout.is_contiguous(); return layout.is_contiguous();
}; };
return layout_checker; return layout_checker;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册