Commit ab309eb5 authored by: Megvii Engine Team

feat(mgb/opr): let Split support empty IO

GitOrigin-RevId: aad6dc06bfe9b95889b924e0a26f3ea33c52319a
Parent: a8292704
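In user-facing terms, this change lets megengine.functional.split both consume and produce zero-sized tensors; the C++ Split operator previously asserted that every partition be non-empty (see the removed mgb_assert below). A minimal eager-mode sketch of the intended behaviour, with shapes borrowed from the new test and the printed shapes given as expectations rather than captured output:

import numpy as np
import megengine as mge
import megengine.functional as F

# An input whose leading axis is empty: every piece keeps the 0 dimension.
x = mge.tensor(np.random.random((0, 4, 5, 6)).astype(np.float32))
pieces = F.split(x, 2, axis=3)
print([p.numpy().shape for p in pieces])   # expected: [(0, 4, 5, 3), (0, 4, 5, 3)]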
@@ -20,7 +20,7 @@
 from ..core.tensor.array_method import _broadcast, _remove_axis
 from ..core.tensor.utils import astensor1d, convert_inputs, get_device
 from ..device import get_default_device
 from ..tensor import Tensor
-from .elemwise import ceil, floor_div
+from .elemwise import ceil

 __all__ = [
     "arange",
@@ -442,10 +442,10 @@ def split(inp, nsplits_or_sections, axis=0):
     Ntotal = inp.shape[axis]
-    try:
+    if isinstance(nsplits_or_sections, Sequence):
         Nsections = len(nsplits_or_sections) + 1
         is_array = True
-    except TypeError:
+    else:
         Nsections = int(nsplits_or_sections)
         is_array = False
@@ -465,27 +465,19 @@ def split(inp, nsplits_or_sections, axis=0):
                 Ntotal, axis, Nsections
             )
         )
-    func = (
-        floor_div
-        if isinstance(Nsections, (SymbolVar, Tensor))
-        else lambda x, y: x // y
-    )
-    div_points = [0] + [
-        func(Ntotal + Nsections - i - 1, Nsections) for i in range(Nsections)
-    ]
-    for i in range(2, Nsections + 1):
-        div_points[i] = div_points[i - 1] + div_points[i]
-
-    sub_tensors = []
-    for i in range(Nsections):
-        l = div_points[i]
-        r = div_points[i + 1]
-        slices = tuple(
-            [slice(None)] * axis + [slice(l, r)] + [slice(None)] * (ndim - axis - 1)
-        )
-        sub_tensors.append(inp[slices])
-    return sub_tensors
+    partitions = []
+    for i in range(Nsections):
+        section_size = (Ntotal + Nsections - i - 1) // Nsections
+        partitions.append(section_size)
+
+    partitions = [
+        part
+        if isinstance(part, (SymbolVar, Tensor))
+        else Const(part, dtype="int32", device=inp.device)(inp)[0]
+        for part in partitions
+    ]
+    op = builtin.Split(axis=axis)
+    return apply(op, inp, *partitions)


 def _get_idx(index, axis):
...
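The rewritten Python path no longer slices the input itself; it computes explicit partition sizes with the ceiling-division expression (Ntotal + Nsections - i - 1) // Nsections and hands them to the builtin Split op, so an empty input simply yields all-zero partition sizes. A quick standalone check of that formula (plain Python, values chosen only for illustration):

# Distributes any remainder over the leading pieces; sizes always sum to Ntotal.
def partition_sizes(Ntotal, Nsections):
    return [(Ntotal + Nsections - i - 1) // Nsections for i in range(Nsections)]

assert partition_sizes(10, 3) == [4, 3, 3]
assert partition_sizes(6, 2) == [3, 3]
assert partition_sizes(0, 2) == [0, 0]      # empty input -> all pieces empty
assert sum(partition_sizes(10, 3)) == 10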
@@ -178,6 +178,21 @@ def test_regression_1762():
     gm.backward(loss)


+def test_empty_grad_in_backward():
+    x = mge.Parameter(F.full(100, 0.5))
+    y = mge.Parameter(F.ones(100))
+    gm = GradManager()
+    gm.attach([x, y])
+    with gm:
+        z = F.where(x > 0.7, x, y)
+        loss = z.sum()
+        gm.backward(loss)
+        assert np.all(x.grad.numpy() == 0)
+        assert np.all(y.grad.numpy() == 1)
+
+
 @pytest.mark.require_ngpu(2)
 @pytest.mark.isolated_distributed
 @pytest.mark.parametrize(
...
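The new GradManager test fills x with 0.5 so that x > 0.7 is false everywhere: F.where then always takes the y branch, and the gradient routed back to x presumably flows through an empty selection, which is the empty-IO case this commit enables in backward. The expected gradients can be checked against plain numpy (a reference sketch of the where gradient, not MegEngine's autodiff itself):

import numpy as np

x = np.full(100, 0.5, dtype=np.float32)
y = np.ones(100, dtype=np.float32)
mask = x > 0.7                        # all False here
z = np.where(mask, x, y)

# d(sum(z))/dx is 1 where mask is True and 0 elsewhere; the reverse for y.
grad_x = mask.astype(np.float32)
grad_y = (~mask).astype(np.float32)
assert np.all(grad_x == 0) and np.all(grad_y == 1)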
@@ -119,7 +119,7 @@ def test_stack(is_varnode):

 @pytest.mark.parametrize("is_varnode", [True, False])
-def test_split(is_varnode):
+def test_split_basic(is_varnode):
     if is_varnode:
         network = Network()
         saved_symbolic_shape = set_symbolic_shape(False)
@@ -150,15 +150,48 @@ def test_split(is_varnode):
         pass

     try:
-        F.split(inp, [3, 3, 5], axis=3)
+        F.split(inp, [3, 2, 5], axis=3)
         assert False
     except ValueError as e:
-        assert str(e) == "Invalid nsplits_or_secions: [3, 3, 5]"
+        assert str(e) == "Invalid nsplits_or_secions: [3, 2, 5]"

     if is_varnode:
         set_symbolic_shape(saved_symbolic_shape)


+@pytest.mark.parametrize("symbolic", [None, False, True])
+def test_split(symbolic):
+    inp1 = np.random.random((3, 4, 5, 6)).astype(np.float32)
+    inp2 = np.random.random((0, 4, 5, 6)).astype(np.float32)
+
+    def ref(inp, nsplits_or_sections, axis):
+        return np.split(inp, nsplits_or_sections, axis)
+
+    def func(inp, nsplits_or_sections, axis):
+        return F.split(inp, nsplits_or_sections, axis)
+
+    cases = [
+        (inp1, 2, 3),
+        (inp1, [3], 3),
+        (inp1, [3, 3, 5], 3),
+        (inp2, 2, 3),
+        (inp2, [3], 3),
+        (inp2, [3, 3, 5], 3),
+    ]
+
+    for case in cases:
+        if symbolic is None:
+            fn = func
+        else:
+            fn = trace(symbolic=symbolic)(func)
+        for i in range(3 if symbolic is not None else 1):
+            ref_out = ref(*case)
+            out = fn(tensor(case[0]), case[1], case[2])
+            assert len(ref_out) == len(out)
+            for idx in range(len(ref_out)):
+                np.testing.assert_equal(ref_out[idx], out[idx].numpy())
+
+
 @pytest.mark.parametrize("is_varnode", [True, False])
 def test_reshape(is_varnode):
     if is_varnode:
...
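The new test uses numpy as the reference, so F.split's sections argument follows np.split semantics: a list gives cut indices along the axis (which may yield empty pieces), and a zero-length leading axis simply propagates into every piece. Mirroring the inp2 cases above in plain numpy:

import numpy as np

inp2 = np.random.random((0, 4, 5, 6)).astype(np.float32)
print([p.shape for p in np.split(inp2, 2, axis=3)])
# [(0, 4, 5, 3), (0, 4, 5, 3)]
print([p.shape for p in np.split(inp2, [3, 3, 5], axis=3)])
# [(0, 4, 5, 3), (0, 4, 5, 0), (0, 4, 5, 2), (0, 4, 5, 1)]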
@@ -987,7 +987,8 @@ Split::Split(VarNode *inp, const Options &opt, const OperatorNodeConfig &config)
     }

     for (size_t i = 0; i < m_opt.nr_part; ++ i)
-        add_output(ssprintf("o%zd", i))->dtype(inp->dtype());
+        add_output(ssprintf("o%zd", i))->dtype(inp->dtype())
+                .add_flag(VarNode::Flag::ALLOW_EMPTY_SHAPE);

     m_output_spec.resize(m_opt.nr_part);
 }
@@ -1060,10 +1061,6 @@ bool Split::infer_shape(size_t out_idx, TensorShape &dest,
     size_t size = 0;
     for (size_t i = 0; i < m_opt.nr_part; ++ i) {
         auto p = partition[i];
-        mgb_assert(p,
-                "got zero partition size at part %zu, tot_size=%zu",
-                i, ishp.shape[axis]);
         size += p;
         auto &&cur = m_output_spec[i].shape;
@@ -1126,6 +1123,7 @@ cg::OperatorNodeBase::NodeProp* Split::do_make_node_prop() const {
     auto rst = OperatorNodeBase::do_make_node_prop();
     rst->add_flag(NodeProp::Flag::CROSS_COMP_NODE_MEMORY);
     outshape_by_symvar_reset_node_dep_type(rst);
+    rst->add_dep_type_existing_var(input(0), NodeProp::DepType::VALUE_ALLOW_EMPTY);
     return rst;
 }
@@ -1141,6 +1139,10 @@ void Split::do_execute(ExecEnv &env) {
         auto &&in = input(0)->dev_tensor();
         auto &&out = output(idx)->dev_tensor();
         auto &&spec = m_output_spec.at(idx);
+        if (out.layout().is_empty()) {
+            mgb_assert(spec.subspec.layout().is_empty());
+            return;
+        }
         owner_graph()->event().signal_inplace<cg::event::BeforeKernel>(
                 this, out.comp_node());
         if (spec.mem_fwd_success) {
...
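Taken together, the graph-side edits follow the usual empty-tensor pattern in MegBrain operators: each output is flagged ALLOW_EMPTY_SHAPE, the input dependency becomes VALUE_ALLOW_EMPTY, the zero-partition assert is dropped from shape inference, and do_execute returns early for an empty output instead of launching a copy. A rough Python-level check, assuming a build that includes this commit:

import numpy as np
import megengine as mge
import megengine.functional as F

x = mge.tensor(np.arange(6, dtype=np.float32).reshape(1, 6))
head, empty, tail = F.split(x, [3, 3], axis=1)   # cuts at 3 and 3 -> middle piece is empty
print(empty.numpy().shape)                       # (1, 0): a zero-sized piece comes back without error
print(head.numpy(), tail.numpy())                # [[0. 1. 2.]] [[3. 4. 5.]]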