提交 6cd01d5a 编写于 作者: M Megvii Engine Team 提交者: huangxinda

feat(imperative/functional): let elemwise support empty IO & add some tests

GitOrigin-RevId: a5dc3b997ca3c721c985d3eff6a54d0bea771914
上级 dea52781
...@@ -7,12 +7,14 @@ ...@@ -7,12 +7,14 @@
# software distributed under the License is distributed on an # software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import numpy as np import numpy as np
import pytest
import megengine.functional as F import megengine.functional as F
import megengine.functional.elemwise as elemwise import megengine.functional.elemwise as elemwise
from megengine import tensor from megengine import tensor
from megengine.core.tensor import dtype from megengine.core.tensor import dtype
from megengine.functional.elemwise import Elemwise, _elwise from megengine.functional.elemwise import Elemwise, _elwise
from megengine.jit import trace
def test_abs(): def test_abs():
...@@ -180,3 +182,80 @@ def test_int32_input(): ...@@ -180,3 +182,80 @@ def test_int32_input():
inp = (x,) * nargs inp = (x,) * nargs
y = op(*inp) y = op(*inp)
y.numpy() y.numpy()
@pytest.mark.parametrize("is_trace", [True, False])
def test_empty_tensor(is_trace):
binary_func = []
unary_func = []
for op_name in elemwise.__all__:
op = getattr(elemwise, op_name)
nargs = op.__code__.co_argcount
if op_name == "clip":
unary_func.append(["clip", lambda x, f=op: f(x, lower=0, upper=1)])
elif op_name.endswith("_shift"):
unary_func.append(
[op_name, lambda x, f=op: f(tensor(x.numpy(), dtype="int32"), 1)]
)
elif op_name.startswith("logical_"): # logical_xxx op only accept boolean type
if nargs == 1:
unary_func.append(
[op_name, lambda x, f=op: f(tensor(x.numpy(), dtype="bool"))]
)
else:
assert nargs == 2
binary_func.append(
[
op_name,
lambda x, y, f=op: f(
tensor(x.numpy(), dtype="bool"),
tensor(y.numpy(), dtype="bool"),
),
]
)
elif nargs == 1:
unary_func.append([op_name, op])
elif nargs == 2:
binary_func.append([op_name, op])
else:
print(nargs)
raise NotImplementedError
def run_test(func, args, ref_shape, is_trace, sym=False):
args = [tensor(t, dtype="float32") for t in args]
if is_trace:
func = trace(symbolic=sym)(func)
for _ in range(3):
out = func(*args)
assert out.numpy().shape == ref_shape
else:
out = func(*args)
assert out.numpy().shape == ref_shape
print(out.numpy().shape)
inps = [
np.array([]).astype("float32"),
np.random.randn(2, 0, 3).astype("float32"),
123,
]
for op_name, op in unary_func:
if is_trace:
for sym in [True, False]:
run_test(op, [inps[0],], inps[0].shape, True, sym)
run_test(op, [inps[1],], inps[1].shape, True, sym)
else:
run_test(op, [inps[0],], inps[0].shape, False)
run_test(op, [inps[1],], inps[1].shape, False)
for op_name, op in binary_func:
if is_trace:
for sym in [True, False]:
run_test(op, [inps[0], inps[0]], (inps[0] + inps[0]).shape, True, sym)
run_test(op, [inps[1], inps[1]], (inps[1] + inps[1]).shape, True, sym)
run_test(op, [inps[0], inps[2]], (inps[0] + inps[2]).shape, True, sym)
run_test(op, [inps[1], inps[2]], (inps[1] + inps[2]).shape, True, sym)
else:
run_test(op, [inps[0], inps[0]], (inps[0] + inps[0]).shape, False)
run_test(op, [inps[1], inps[1]], (inps[1] + inps[1]).shape, False)
run_test(op, [inps[0], inps[2]], (inps[0] + inps[2]).shape, False)
run_test(op, [inps[1], inps[2]], (inps[1] + inps[2]).shape, False)
...@@ -19,6 +19,7 @@ from megengine.core._trace_option import use_symbolic_shape ...@@ -19,6 +19,7 @@ from megengine.core._trace_option import use_symbolic_shape
from megengine.core.tensor import megbrain_graph as G from megengine.core.tensor import megbrain_graph as G
from megengine.core.tensor.utils import astensor1d from megengine.core.tensor.utils import astensor1d
from megengine.distributed.helper import get_device_count_by_fork from megengine.distributed.helper import get_device_count_by_fork
from megengine.jit import trace
from megengine.utils.network import Network, set_symbolic_shape from megengine.utils.network import Network, set_symbolic_shape
from megengine.utils.network_node import VarNode from megengine.utils.network_node import VarNode
...@@ -177,6 +178,48 @@ def test_reshape(is_varnode): ...@@ -177,6 +178,48 @@ def test_reshape(is_varnode):
np.testing.assert_equal(yy.numpy(), y) np.testing.assert_equal(yy.numpy(), y)
@pytest.mark.parametrize("is_trace", [True, False])
def test_reshape_on_empty_tensor(is_trace):
input1_shape = (100, 0, 1)
output1_shape = (100, 0, 10)
data1 = tensor(np.random.random(input1_shape).astype(np.float32))
input2_shape = (10, 0)
output2_shape = (0,)
data2 = tensor(np.random.random(input2_shape).astype(np.float32))
input3_shape = (10, 0, 10)
output3_shape = (0, 1, 2, 3)
data3 = tensor(np.random.random(input3_shape).astype(np.float32))
def comp(out, target_shp):
assert out._tuple_shape == target_shp
def func(x, shp):
return F.reshape(x, shp)
cases = [
[data1, output1_shape],
[data2, output2_shape],
[data3, output3_shape],
]
def test(func, inp, comp, target_shp):
out = func(inp, target_shp)
comp(out, target_shp)
if is_trace:
for symbolic in [False, True]:
for inp, target_shp in cases:
func_traced = trace(symbolic=symbolic)(func)
test(func_traced, inp, comp, target_shp)
test(func_traced, inp, comp, target_shp)
test(func_traced, inp, comp, target_shp)
else:
for inp, target_shp in cases:
test(func, inp, comp, target_shp)
@pytest.mark.parametrize("is_varnode", [True, False]) @pytest.mark.parametrize("is_varnode", [True, False])
def test_reshape_shape_inference(is_varnode): def test_reshape_shape_inference(is_varnode):
if is_varnode: if is_varnode:
...@@ -480,6 +523,48 @@ def test_broadcast(is_varnode): ...@@ -480,6 +523,48 @@ def test_broadcast(is_varnode):
F.broadcast_to(x, (1, 3)) F.broadcast_to(x, (1, 3))
@pytest.mark.parametrize("is_trace", [True, False])
def test_broadcast_on_empty_tensor(is_trace):
input1_shape = (100, 0, 1)
output1_shape = (100, 0, 10)
data1 = tensor(np.random.random(input1_shape).astype(np.float32))
input2_shape = (10, 0)
output2_shape = (10, 10, 0)
data2 = tensor(np.random.random(input2_shape).astype(np.float32))
input3_shape = (0, 0, 1, 10)
output3_shape = (10, 0, 0, 10, 10)
data3 = tensor(np.random.random(input3_shape).astype(np.float32))
def comp(out, target_shp):
assert out._tuple_shape == target_shp
def func(x, shp):
return F.broadcast_to(x, shp)
cases = [
[data1, output1_shape],
[data2, output2_shape],
[data3, output3_shape],
]
def test(func, inp, comp, target_shp):
out = func(inp, target_shp)
comp(out, target_shp)
if is_trace:
for symbolic in [False, True]:
for inp, target_shp in cases:
func_traced = trace(symbolic=symbolic)(func)
test(func_traced, inp, comp, target_shp)
test(func_traced, inp, comp, target_shp)
test(func_traced, inp, comp, target_shp)
else:
for inp, target_shp in cases:
test(func, inp, comp, target_shp)
@pytest.mark.parametrize("is_varnode", [True, False]) @pytest.mark.parametrize("is_varnode", [True, False])
def test_utils_astensor1d(is_varnode): def test_utils_astensor1d(is_varnode):
if is_varnode: if is_varnode:
......
...@@ -259,6 +259,10 @@ void Elemwise::perform( ...@@ -259,6 +259,10 @@ void Elemwise::perform(
mgb_assert(t.comp_node() == out_cn); mgb_assert(t.comp_node() == out_cn);
mgb_assert(t.dtype() == out_dt); mgb_assert(t.dtype() == out_dt);
} }
if (t.shape().is_empty()) {
mgb_assert(dest.empty());
return;
}
inp_shapes[i] = t.shape(); inp_shapes[i] = t.shape();
} }
if (!opr) { if (!opr) {
......
...@@ -1064,4 +1064,37 @@ TEST(TestOprBasicArithElemwise, EmptyInputOutputBinary) { ...@@ -1064,4 +1064,37 @@ TEST(TestOprBasicArithElemwise, EmptyInputOutputBinary) {
MGB_ASSERT_SHAPE_EQ(host_z.shape(), TensorShape({0, 8, 1, 7})); MGB_ASSERT_SHAPE_EQ(host_z.shape(), TensorShape({0, 8, 1, 7}));
} }
TEST(TestOprBasicArithElemwise, PerformEmptyIO) {
auto cn = CompNode::load("xpu0");
HostTensorGenerator<> gen;
auto host_x1 = gen({2, 0, 3, 4}),
host_x2 = gen({1});
auto dev_x1 = std::make_shared<DeviceTensorND>(cn),
dev_x2 = std::make_shared<DeviceTensorND>(cn);
dev_x1->copy_from(*host_x1);
dev_x2->copy_from(*host_x2);
auto dev_y = std::make_shared<DeviceTensorND>(cn, dev_x1->dtype());
dev_y->resize(dev_x1->shape());
auto&& dnn_opr = opr::intl::create_megdnn_opr<megdnn::Elemwise>(cn);
// test unary mode
for (auto mode: {Mode::NEGATE, Mode::EXP, Mode::LOG}) {
SmallVector<DeviceTensorND> inputs = {*dev_x1};
ASSERT_NO_THROW(opr::Elemwise::perform(mode, *dev_y, inputs, dnn_opr));
ASSERT_TRUE(dev_y->empty());
ASSERT_TRUE(dev_y->shape().is_empty());
MGB_ASSERT_SHAPE_EQ(dev_y->shape(), dev_x1->shape());
}
// test binary mode
for (auto mode: {Mode::ADD, Mode::MUL, Mode::LT}) {
SmallVector<DeviceTensorND> inputs = {*dev_x1, *dev_x2};
ASSERT_NO_THROW(opr::Elemwise::perform(mode, *dev_y, inputs, dnn_opr));
ASSERT_TRUE(dev_y->empty());
ASSERT_TRUE(dev_y->shape().is_empty());
MGB_ASSERT_SHAPE_EQ(dev_y->shape(), dev_x1->shape());
}
}
// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册