提交 f4927db2 编写于 作者: M Megvii Engine Team

feat(mge/functional): support where func

GitOrigin-RevId: 9df6421ebee174e6a688a31845cf8072832352cd
上级 56cb5d6a
...@@ -558,7 +558,7 @@ def where(mask: Tensor, x: Tensor, y: Tensor) -> Tensor: ...@@ -558,7 +558,7 @@ def where(mask: Tensor, x: Tensor, y: Tensor) -> Tensor:
from megengine import tensor from megengine import tensor
import megengine.functional as F import megengine.functional as F
mask = tensor(np.array([[1, 0], [0, 1]], dtype=np.int32)) mask = tensor(np.array([[True, False], [False, True]], dtype=np.bool))
x = tensor(np.array([[1, np.inf], [np.nan, 4]], x = tensor(np.array([[1, np.inf], [np.nan, 4]],
dtype=np.float32)) dtype=np.float32))
y = tensor(np.array([[5, 6], [7, 8]], dtype=np.float32)) y = tensor(np.array([[5, 6], [7, 8]], dtype=np.float32))
...@@ -572,19 +572,33 @@ def where(mask: Tensor, x: Tensor, y: Tensor) -> Tensor: ...@@ -572,19 +572,33 @@ def where(mask: Tensor, x: Tensor, y: Tensor) -> Tensor:
[[1. 6.] [[1. 6.]
[7. 4.]] [7. 4.]]
""" """
raise NotImplementedError
# v0, index0 = mgb.opr.cond_take( x, y = convert_inputs(x, y)
# x, mask, mode=P.CondTake.Mode.EQ, val=1 if not isinstance(x, (TensorWrapperBase, TensorBase)):
# ) raise TypeError("input x must be a tensor")
# v1, index1 = mgb.opr.cond_take( if not isinstance(y, (TensorWrapperBase, TensorBase)):
# y, mask, mode=P.CondTake.Mode.EQ, val=0 raise TypeError("input y must be a tensor")
# ) if not isinstance(mask, (TensorWrapperBase, TensorBase)):
# out = x.flatten() raise TypeError("mask must be a tensor")
# index = mgb.opr.concat(index0, index1, axis=0) if mask.dtype != np.bool_:
# v = mgb.opr.concat(v0, v1, axis=0) raise ValueError("mask must be bool")
# out = mgb.opr.set_advanced_indexing(out, v)[index] if x.device != mask.device:
# out = out.reshape(x.shape) raise ValueError("ambiguous device: {} vs {}".format(x.device, mask.device))
# return out
v0, index0 = cond_take(mask, x)
v1, index1 = cond_take(~mask, y)
if v0.shape == (0,):
out = v1
elif v1.shape == (0,):
out = v0
else:
out = concat([v0, v1])
out[index0] = v0
out[index1] = v1
out = out.reshape(x.shape)
return out
def cond_take(mask: Tensor, x: Tensor) -> Tensor: def cond_take(mask: Tensor, x: Tensor) -> Tensor:
......
...@@ -122,34 +122,34 @@ def test_flatten(): ...@@ -122,34 +122,34 @@ def test_flatten():
opr_test(cases, F.flatten, compare_fn=compare_fn, start_axis=1, end_axis=2) opr_test(cases, F.flatten, compare_fn=compare_fn, start_axis=1, end_axis=2)
# def test_where(): def test_where():
# maskv0 = np.array([[1, 0], [0, 1]], dtype=np.int32) maskv0 = np.array([[1, 0], [0, 1]], dtype=np.bool_)
# xv0 = np.array([[1, np.inf], [np.nan, 4]], dtype=np.float32) xv0 = np.array([[1, np.inf], [np.nan, 4]], dtype=np.float32)
# yv0 = np.array([[5, 6], [7, 8]], dtype=np.float32) yv0 = np.array([[5, 6], [7, 8]], dtype=np.float32)
# maskv1 = np.array([[1, 0, 1], [1, 0, 0], [1, 1, 0]], dtype=np.int32) maskv1 = np.array([[1, 0, 1], [1, 0, 0], [1, 1, 0]], dtype=np.bool_)
# xv1 = np.array([[1, np.inf, 2], [0, np.nan, 4], [1, 5, 7]], dtype=np.float32) xv1 = np.array([[1, np.inf, 2], [0, np.nan, 4], [1, 5, 7]], dtype=np.float32)
# yv1 = np.array([[5, 6, 9], [2, 7, 8], [2, 1, 9]], dtype=np.float32) yv1 = np.array([[5, 6, 9], [2, 7, 8], [2, 1, 9]], dtype=np.float32)
# cases = [ cases = [
# {"input": [maskv0, xv0, yv0]}, {"input": [maskv0, xv0, yv0]},
# {"input": [maskv1, xv1, yv1]}, {"input": [maskv1, xv1, yv1]},
# ] ]
# opr_test(cases, F.where, ref_fn=np.where) opr_test(cases, F.where, ref_fn=np.where)
# maskv2 = np.array([1, 1, 1], dtype=np.int32) maskv2 = np.array([1, 1, 1], dtype=np.bool_)
# xv2 = np.array([1, 3, 2], dtype=np.float32) xv2 = np.array([1, 3, 2], dtype=np.float32)
# yv2 = np.array([5, 6, 9], dtype=np.float32) yv2 = np.array([5, 6, 9], dtype=np.float32)
# maskv3 = np.array([0, 0, 0], dtype=np.int32) maskv3 = np.array([0, 0, 0], dtype=np.bool_)
# xv3 = np.array([1, 3, 2], dtype=np.float32) xv3 = np.array([1, 3, 2], dtype=np.float32)
# yv3 = np.array([5, 6, 9], dtype=np.float32) yv3 = np.array([5, 6, 9], dtype=np.float32)
# cases = [ cases = [
# {"input": [maskv2, xv2, yv2]}, {"input": [maskv2, xv2, yv2]},
# {"input": [maskv3, xv3, yv3]}, {"input": [maskv3, xv3, yv3]},
# ] ]
# opr_test(cases, F.where, ref_fn=np.where) opr_test(cases, F.where, ref_fn=np.where)
def test_matmul(): def test_matmul():
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册