Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
MegEngine 天元
MegEngine
提交
f4927db2
MegEngine
项目概览
MegEngine 天元
/
MegEngine
1 年多 前同步成功
通知
404
Star
4705
Fork
582
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
MegEngine
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
提交
f4927db2
编写于
8月 26, 2020
作者:
M
Megvii Engine Team
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
feat(mge/functional): support where func
GitOrigin-RevId: 9df6421ebee174e6a688a31845cf8072832352cd
上级
56cb5d6a
变更
2
显示空白变更内容
内联
并排
Showing
2 changed file
with
51 addition
and
37 deletion
+51
-37
imperative/python/megengine/functional/tensor.py
imperative/python/megengine/functional/tensor.py
+28
-14
imperative/python/test/unit/functional/test_functional.py
imperative/python/test/unit/functional/test_functional.py
+23
-23
未找到文件。
imperative/python/megengine/functional/tensor.py
浏览文件 @
f4927db2
...
...
@@ -558,7 +558,7 @@ def where(mask: Tensor, x: Tensor, y: Tensor) -> Tensor:
from megengine import tensor
import megengine.functional as F
mask = tensor(np.array([[
1, 0], [0, 1]], dtype=np.int32
))
mask = tensor(np.array([[
True, False], [False, True]], dtype=np.bool
))
x = tensor(np.array([[1, np.inf], [np.nan, 4]],
dtype=np.float32))
y = tensor(np.array([[5, 6], [7, 8]], dtype=np.float32))
...
...
@@ -572,19 +572,33 @@ def where(mask: Tensor, x: Tensor, y: Tensor) -> Tensor:
[[1. 6.]
[7. 4.]]
"""
raise
NotImplementedError
# v0, index0 = mgb.opr.cond_take(
# x, mask, mode=P.CondTake.Mode.EQ, val=1
# )
# v1, index1 = mgb.opr.cond_take(
# y, mask, mode=P.CondTake.Mode.EQ, val=0
# )
# out = x.flatten()
# index = mgb.opr.concat(index0, index1, axis=0)
# v = mgb.opr.concat(v0, v1, axis=0)
# out = mgb.opr.set_advanced_indexing(out, v)[index]
# out = out.reshape(x.shape)
# return out
x
,
y
=
convert_inputs
(
x
,
y
)
if
not
isinstance
(
x
,
(
TensorWrapperBase
,
TensorBase
)):
raise
TypeError
(
"input x must be a tensor"
)
if
not
isinstance
(
y
,
(
TensorWrapperBase
,
TensorBase
)):
raise
TypeError
(
"input y must be a tensor"
)
if
not
isinstance
(
mask
,
(
TensorWrapperBase
,
TensorBase
)):
raise
TypeError
(
"mask must be a tensor"
)
if
mask
.
dtype
!=
np
.
bool_
:
raise
ValueError
(
"mask must be bool"
)
if
x
.
device
!=
mask
.
device
:
raise
ValueError
(
"ambiguous device: {} vs {}"
.
format
(
x
.
device
,
mask
.
device
))
v0
,
index0
=
cond_take
(
mask
,
x
)
v1
,
index1
=
cond_take
(
~
mask
,
y
)
if
v0
.
shape
==
(
0
,):
out
=
v1
elif
v1
.
shape
==
(
0
,):
out
=
v0
else
:
out
=
concat
([
v0
,
v1
])
out
[
index0
]
=
v0
out
[
index1
]
=
v1
out
=
out
.
reshape
(
x
.
shape
)
return
out
def
cond_take
(
mask
:
Tensor
,
x
:
Tensor
)
->
Tensor
:
...
...
imperative/python/test/unit/functional/test_functional.py
浏览文件 @
f4927db2
...
...
@@ -122,34 +122,34 @@ def test_flatten():
opr_test
(
cases
,
F
.
flatten
,
compare_fn
=
compare_fn
,
start_axis
=
1
,
end_axis
=
2
)
#
def test_where():
# maskv0 = np.array([[1, 0], [0, 1]], dtype=np.int32
)
#
xv0 = np.array([[1, np.inf], [np.nan, 4]], dtype=np.float32)
#
yv0 = np.array([[5, 6], [7, 8]], dtype=np.float32)
def
test_where
():
maskv0
=
np
.
array
([[
1
,
0
],
[
0
,
1
]],
dtype
=
np
.
bool_
)
xv0
=
np
.
array
([[
1
,
np
.
inf
],
[
np
.
nan
,
4
]],
dtype
=
np
.
float32
)
yv0
=
np
.
array
([[
5
,
6
],
[
7
,
8
]],
dtype
=
np
.
float32
)
# maskv1 = np.array([[1, 0, 1], [1, 0, 0], [1, 1, 0]], dtype=np.int32
)
#
xv1 = np.array([[1, np.inf, 2], [0, np.nan, 4], [1, 5, 7]], dtype=np.float32)
#
yv1 = np.array([[5, 6, 9], [2, 7, 8], [2, 1, 9]], dtype=np.float32)
maskv1
=
np
.
array
([[
1
,
0
,
1
],
[
1
,
0
,
0
],
[
1
,
1
,
0
]],
dtype
=
np
.
bool_
)
xv1
=
np
.
array
([[
1
,
np
.
inf
,
2
],
[
0
,
np
.
nan
,
4
],
[
1
,
5
,
7
]],
dtype
=
np
.
float32
)
yv1
=
np
.
array
([[
5
,
6
,
9
],
[
2
,
7
,
8
],
[
2
,
1
,
9
]],
dtype
=
np
.
float32
)
#
cases = [
#
{"input": [maskv0, xv0, yv0]},
#
{"input": [maskv1, xv1, yv1]},
#
]
#
opr_test(cases, F.where, ref_fn=np.where)
cases
=
[
{
"input"
:
[
maskv0
,
xv0
,
yv0
]},
{
"input"
:
[
maskv1
,
xv1
,
yv1
]},
]
opr_test
(
cases
,
F
.
where
,
ref_fn
=
np
.
where
)
# maskv2 = np.array([1, 1, 1], dtype=np.int32
)
#
xv2 = np.array([1, 3, 2], dtype=np.float32)
#
yv2 = np.array([5, 6, 9], dtype=np.float32)
maskv2
=
np
.
array
([
1
,
1
,
1
],
dtype
=
np
.
bool_
)
xv2
=
np
.
array
([
1
,
3
,
2
],
dtype
=
np
.
float32
)
yv2
=
np
.
array
([
5
,
6
,
9
],
dtype
=
np
.
float32
)
# maskv3 = np.array([0, 0, 0], dtype=np.int32
)
#
xv3 = np.array([1, 3, 2], dtype=np.float32)
#
yv3 = np.array([5, 6, 9], dtype=np.float32)
maskv3
=
np
.
array
([
0
,
0
,
0
],
dtype
=
np
.
bool_
)
xv3
=
np
.
array
([
1
,
3
,
2
],
dtype
=
np
.
float32
)
yv3
=
np
.
array
([
5
,
6
,
9
],
dtype
=
np
.
float32
)
#
cases = [
#
{"input": [maskv2, xv2, yv2]},
#
{"input": [maskv3, xv3, yv3]},
#
]
#
opr_test(cases, F.where, ref_fn=np.where)
cases
=
[
{
"input"
:
[
maskv2
,
xv2
,
yv2
]},
{
"input"
:
[
maskv3
,
xv3
,
yv3
]},
]
opr_test
(
cases
,
F
.
where
,
ref_fn
=
np
.
where
)
def
test_matmul
():
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录