From f4927db2fb8089460d266885ebc8fa8dd19f3950 Mon Sep 17 00:00:00 2001 From: Megvii Engine Team Date: Wed, 26 Aug 2020 13:32:05 +0800 Subject: [PATCH] feat(mge/functional): support where func GitOrigin-RevId: 9df6421ebee174e6a688a31845cf8072832352cd --- .../python/megengine/functional/tensor.py | 42 +++++++++++------ .../test/unit/functional/test_functional.py | 46 +++++++++---------- 2 files changed, 51 insertions(+), 37 deletions(-) diff --git a/imperative/python/megengine/functional/tensor.py b/imperative/python/megengine/functional/tensor.py index 15c26bf0a..1d1c68930 100644 --- a/imperative/python/megengine/functional/tensor.py +++ b/imperative/python/megengine/functional/tensor.py @@ -558,7 +558,7 @@ def where(mask: Tensor, x: Tensor, y: Tensor) -> Tensor: from megengine import tensor import megengine.functional as F - mask = tensor(np.array([[1, 0], [0, 1]], dtype=np.int32)) + mask = tensor(np.array([[True, False], [False, True]], dtype=np.bool)) x = tensor(np.array([[1, np.inf], [np.nan, 4]], dtype=np.float32)) y = tensor(np.array([[5, 6], [7, 8]], dtype=np.float32)) @@ -572,19 +572,33 @@ def where(mask: Tensor, x: Tensor, y: Tensor) -> Tensor: [[1. 6.] [7. 4.]] """ - raise NotImplementedError - # v0, index0 = mgb.opr.cond_take( - # x, mask, mode=P.CondTake.Mode.EQ, val=1 - # ) - # v1, index1 = mgb.opr.cond_take( - # y, mask, mode=P.CondTake.Mode.EQ, val=0 - # ) - # out = x.flatten() - # index = mgb.opr.concat(index0, index1, axis=0) - # v = mgb.opr.concat(v0, v1, axis=0) - # out = mgb.opr.set_advanced_indexing(out, v)[index] - # out = out.reshape(x.shape) - # return out + + x, y = convert_inputs(x, y) + if not isinstance(x, (TensorWrapperBase, TensorBase)): + raise TypeError("input x must be a tensor") + if not isinstance(y, (TensorWrapperBase, TensorBase)): + raise TypeError("input y must be a tensor") + if not isinstance(mask, (TensorWrapperBase, TensorBase)): + raise TypeError("mask must be a tensor") + if mask.dtype != np.bool_: + raise ValueError("mask must be bool") + if x.device != mask.device: + raise ValueError("ambiguous device: {} vs {}".format(x.device, mask.device)) + + v0, index0 = cond_take(mask, x) + v1, index1 = cond_take(~mask, y) + + if v0.shape == (0,): + out = v1 + elif v1.shape == (0,): + out = v0 + else: + out = concat([v0, v1]) + + out[index0] = v0 + out[index1] = v1 + out = out.reshape(x.shape) + return out def cond_take(mask: Tensor, x: Tensor) -> Tensor: diff --git a/imperative/python/test/unit/functional/test_functional.py b/imperative/python/test/unit/functional/test_functional.py index beaff6484..d83dc2024 100644 --- a/imperative/python/test/unit/functional/test_functional.py +++ b/imperative/python/test/unit/functional/test_functional.py @@ -122,34 +122,34 @@ def test_flatten(): opr_test(cases, F.flatten, compare_fn=compare_fn, start_axis=1, end_axis=2) -# def test_where(): -# maskv0 = np.array([[1, 0], [0, 1]], dtype=np.int32) -# xv0 = np.array([[1, np.inf], [np.nan, 4]], dtype=np.float32) -# yv0 = np.array([[5, 6], [7, 8]], dtype=np.float32) +def test_where(): + maskv0 = np.array([[1, 0], [0, 1]], dtype=np.bool_) + xv0 = np.array([[1, np.inf], [np.nan, 4]], dtype=np.float32) + yv0 = np.array([[5, 6], [7, 8]], dtype=np.float32) -# maskv1 = np.array([[1, 0, 1], [1, 0, 0], [1, 1, 0]], dtype=np.int32) -# xv1 = np.array([[1, np.inf, 2], [0, np.nan, 4], [1, 5, 7]], dtype=np.float32) -# yv1 = np.array([[5, 6, 9], [2, 7, 8], [2, 1, 9]], dtype=np.float32) + maskv1 = np.array([[1, 0, 1], [1, 0, 0], [1, 1, 0]], dtype=np.bool_) + xv1 = np.array([[1, np.inf, 2], [0, np.nan, 4], [1, 5, 7]], dtype=np.float32) + yv1 = np.array([[5, 6, 9], [2, 7, 8], [2, 1, 9]], dtype=np.float32) -# cases = [ -# {"input": [maskv0, xv0, yv0]}, -# {"input": [maskv1, xv1, yv1]}, -# ] -# opr_test(cases, F.where, ref_fn=np.where) + cases = [ + {"input": [maskv0, xv0, yv0]}, + {"input": [maskv1, xv1, yv1]}, + ] + opr_test(cases, F.where, ref_fn=np.where) -# maskv2 = np.array([1, 1, 1], dtype=np.int32) -# xv2 = np.array([1, 3, 2], dtype=np.float32) -# yv2 = np.array([5, 6, 9], dtype=np.float32) + maskv2 = np.array([1, 1, 1], dtype=np.bool_) + xv2 = np.array([1, 3, 2], dtype=np.float32) + yv2 = np.array([5, 6, 9], dtype=np.float32) -# maskv3 = np.array([0, 0, 0], dtype=np.int32) -# xv3 = np.array([1, 3, 2], dtype=np.float32) -# yv3 = np.array([5, 6, 9], dtype=np.float32) + maskv3 = np.array([0, 0, 0], dtype=np.bool_) + xv3 = np.array([1, 3, 2], dtype=np.float32) + yv3 = np.array([5, 6, 9], dtype=np.float32) -# cases = [ -# {"input": [maskv2, xv2, yv2]}, -# {"input": [maskv3, xv3, yv3]}, -# ] -# opr_test(cases, F.where, ref_fn=np.where) + cases = [ + {"input": [maskv2, xv2, yv2]}, + {"input": [maskv3, xv3, yv3]}, + ] + opr_test(cases, F.where, ref_fn=np.where) def test_matmul(): -- GitLab