Commit ab309eb5 authored by Megvii Engine Team

feat(mgb/opr): let Split support empty IO

GitOrigin-RevId: aad6dc06bfe9b95889b924e0a26f3ea33c52319a
Parent a8292704
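In effect, `Split` may now produce and consume tensors with zero-size shapes. A minimal sketch of the intended behavior, assuming a MegEngine build that includes this patch (only the functional API touched below is used):

```python
import numpy as np
import megengine.functional as F
from megengine import tensor

# Input that is empty along axis 0 but still splittable along axis 1;
# before this change, Split rejected zero-size outputs.
inp = tensor(np.zeros((0, 4), dtype=np.float32))
a, b = F.split(inp, 2, axis=1)
print(a.shape, b.shape)  # expected: (0, 2) (0, 2)
```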
@@ -20,7 +20,7 @@ from ..core.tensor.array_method import _broadcast, _remove_axis
 from ..core.tensor.utils import astensor1d, convert_inputs, get_device
 from ..device import get_default_device
 from ..tensor import Tensor
-from .elemwise import ceil, floor_div
+from .elemwise import ceil

 __all__ = [
     "arange",
@@ -442,10 +442,10 @@ def split(inp, nsplits_or_sections, axis=0):
     Ntotal = inp.shape[axis]

-    try:
+    if isinstance(nsplits_or_sections, Sequence):
         Nsections = len(nsplits_or_sections) + 1
         is_array = True
-    except TypeError:
+    else:
         Nsections = int(nsplits_or_sections)
         is_array = False
@@ -465,27 +465,19 @@ def split(inp, nsplits_or_sections, axis=0):
                     Ntotal, axis, Nsections
                 )
             )
-    func = (
-        floor_div
-        if isinstance(Nsections, (SymbolVar, Tensor))
-        else lambda x, y: x // y
-    )
-    div_points = [0] + [
-        func(Ntotal + Nsections - i - 1, Nsections) for i in range(Nsections)
-    ]
-    for i in range(2, Nsections + 1):
-        div_points[i] = div_points[i - 1] + div_points[i]
-    sub_tensors = []
-    for i in range(Nsections):
-        l = div_points[i]
-        r = div_points[i + 1]
-        slices = tuple(
-            [slice(None)] * axis + [slice(l, r)] + [slice(None)] * (ndim - axis - 1)
-        )
-        sub_tensors.append(inp[slices])
-    return sub_tensors
+    partitions = []
+    for i in range(Nsections):
+        section_size = (Ntotal + Nsections - i - 1) // Nsections
+        partitions.append(section_size)
+    partitions = [
+        part
+        if isinstance(part, (SymbolVar, Tensor))
+        else Const(part, dtype="int32", device=inp.device)(inp)[0]
+        for part in partitions
+    ]
+    op = builtin.Split(axis=axis)
+    return apply(op, inp, *partitions)

 def _get_idx(index, axis):
......
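The rewritten implementation above hands per-section sizes directly to the builtin `Split` op instead of slicing the input. The size of section `i` is `ceil((Ntotal - i) / Nsections)`, written as an integer floor division so that earlier sections absorb the remainder. A small self-contained sketch of just that formula (the helper name is hypothetical):

```python
def partition_sizes(Ntotal, Nsections):
    # Section i gets ceil((Ntotal - i) / Nsections) elements, expressed
    # as a floor division; earlier sections absorb the remainder.
    return [(Ntotal + Nsections - i - 1) // Nsections for i in range(Nsections)]

assert partition_sizes(6, 3) == [2, 2, 2]
assert partition_sizes(5, 2) == [3, 2]
assert partition_sizes(0, 3) == [0, 0, 0]  # empty input: every section is empty
```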
@@ -178,6 +178,21 @@ def test_regression_1762():
         gm.backward(loss)


+def test_empty_grad_in_backward():
+    x = mge.Parameter(F.full(100, 0.5))
+    y = mge.Parameter(F.ones(100))
+
+    gm = GradManager()
+    gm.attach([x, y])
+
+    with gm:
+        z = F.where(x > 0.7, x, y)
+        loss = z.sum()
+        gm.backward(loss)
+        assert np.all(x.grad.numpy() == 0)
+        assert np.all(y.grad.numpy() == 1)
+
+
 @pytest.mark.require_ngpu(2)
 @pytest.mark.isolated_distributed
 @pytest.mark.parametrize(
......
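The new test relies on `where(cond, x, y)` routing no elements through the `x` branch when `cond` is everywhere false, so the backward pass must handle empty intermediate tensors and deliver an all-zero gradient to `x` and an all-one gradient to `y`. A NumPy sketch of those expected gradients:

```python
import numpy as np

x = np.full(100, 0.5, dtype=np.float32)
y = np.ones(100, dtype=np.float32)
cond = x > 0.7  # all False: the x branch is never selected

# For z = where(cond, x, y) and loss = z.sum(), d loss / dx is the mask
# itself and d loss / dy is its complement.
grad_x = cond.astype(np.float32)
grad_y = (~cond).astype(np.float32)
assert not grad_x.any() and grad_y.all()
```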
@@ -119,7 +119,7 @@ def test_stack(is_varnode):
 @pytest.mark.parametrize("is_varnode", [True, False])
-def test_split(is_varnode):
+def test_split_basic(is_varnode):
     if is_varnode:
         network = Network()
         saved_symbolic_shape = set_symbolic_shape(False)
@@ -150,15 +150,48 @@ def test_split(is_varnode):
         pass

     try:
-        F.split(inp, [3, 3, 5], axis=3)
+        F.split(inp, [3, 2, 5], axis=3)
         assert False
     except ValueError as e:
-        assert str(e) == "Invalid nsplits_or_secions: [3, 3, 5]"
+        assert str(e) == "Invalid nsplits_or_secions: [3, 2, 5]"

     if is_varnode:
         set_symbolic_shape(saved_symbolic_shape)


+@pytest.mark.parametrize("symbolic", [None, False, True])
+def test_split(symbolic):
+    inp1 = np.random.random((3, 4, 5, 6)).astype(np.float32)
+    inp2 = np.random.random((0, 4, 5, 6)).astype(np.float32)
+
+    def ref(inp, nsplits_or_sections, axis):
+        return np.split(inp, nsplits_or_sections, axis)
+
+    def func(inp, nsplits_or_sections, axis):
+        return F.split(inp, nsplits_or_sections, axis)
+
+    cases = [
+        (inp1, 2, 3),
+        (inp1, [3], 3),
+        (inp1, [3, 3, 5], 3),
+        (inp2, 2, 3),
+        (inp2, [3], 3),
+        (inp2, [3, 3, 5], 3),
+    ]
+
+    for case in cases:
+        if symbolic is None:
+            fn = func
+        else:
+            fn = trace(symbolic=symbolic)(func)
+        for i in range(3 if symbolic is not None else 1):
+            ref_out = ref(*case)
+            out = fn(tensor(case[0]), case[1], case[2])
+            assert len(ref_out) == len(out)
+            for idx in range(len(ref_out)):
+                np.testing.assert_equal(ref_out[idx], out[idx].numpy())
+
+
 @pytest.mark.parametrize("is_varnode", [True, False])
 def test_reshape(is_varnode):
     if is_varnode:
......
@@ -987,7 +987,8 @@ Split::Split(VarNode *inp, const Options &opt, const OperatorNodeConfig &config)
     }
     for (size_t i = 0; i < m_opt.nr_part; ++ i)
-        add_output(ssprintf("o%zd", i))->dtype(inp->dtype());
+        add_output(ssprintf("o%zd", i))->dtype(inp->dtype())
+                .add_flag(VarNode::Flag::ALLOW_EMPTY_SHAPE);
     m_output_spec.resize(m_opt.nr_part);
 }
@@ -1060,10 +1061,6 @@ bool Split::infer_shape(size_t out_idx, TensorShape &dest,
     size_t size = 0;
     for (size_t i = 0; i < m_opt.nr_part; ++ i) {
         auto p = partition[i];
-        mgb_assert(p,
-                "got zero partition size at part %zu, tot_size=%zu",
-                i, ishp.shape[axis]);
         size += p;
         auto &&cur = m_output_spec[i].shape;
@@ -1126,6 +1123,7 @@ cg::OperatorNodeBase::NodeProp* Split::do_make_node_prop() const {
     auto rst = OperatorNodeBase::do_make_node_prop();
     rst->add_flag(NodeProp::Flag::CROSS_COMP_NODE_MEMORY);
     outshape_by_symvar_reset_node_dep_type(rst);
+    rst->add_dep_type_existing_var(input(0), NodeProp::DepType::VALUE_ALLOW_EMPTY);
     return rst;
 }
@@ -1141,6 +1139,10 @@ void Split::do_execute(ExecEnv &env) {
         auto &&in = input(0)->dev_tensor();
         auto &&out = output(idx)->dev_tensor();
        auto &&spec = m_output_spec.at(idx);
+        if (out.layout().is_empty()) {
+            mgb_assert(spec.subspec.layout().is_empty());
+            return;
+        }
         owner_graph()->event().signal_inplace<cg::event::BeforeKernel>(
                 this, out.comp_node());
         if (spec.mem_fwd_success) {
......
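Zero-size sections do not require an empty input tensor: repeated split points already produce them, which is why the `mgb_assert` on zero partition sizes was removed and `do_execute` now skips the kernel when an output layout is empty. A NumPy illustration of how such sections arise:

```python
import numpy as np

# Split points [3, 3, 5] on an axis of length 6 yield sections of sizes
# 3, 0, 2, 1 -- the repeated point 3 produces an empty middle section.
parts = np.split(np.arange(6), [3, 3, 5])
assert [p.size for p in parts] == [3, 0, 2, 1]
```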