Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
MegEngine 天元
MegEngine
提交
9f4bffbd
MegEngine
项目概览
MegEngine 天元
/
MegEngine
1 年多 前同步成功
通知
403
Star
4705
Fork
582
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
MegEngine
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
提交
9f4bffbd
编写于
10月 12, 2020
作者:
M
Megvii Engine Team
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
fix(mge/tensor): fix valid_broadcast
GitOrigin-RevId: 562b7664e23cd336d942568203df03958b67a4b7
上级
af349d61
变更
3
隐藏空白更改
内联
并排
Showing
3 changed file
with
19 addition
and
4 deletion
+19
-4
imperative/python/megengine/core/tensor/indexing.py
imperative/python/megengine/core/tensor/indexing.py
+1
-1
imperative/python/megengine/core/tensor/tensor_wrapper.py
imperative/python/megengine/core/tensor/tensor_wrapper.py
+3
-3
imperative/python/test/unit/test_tracing.py
imperative/python/test/unit/test_tracing.py
+15
-0
未找到文件。
imperative/python/megengine/core/tensor/indexing.py
浏览文件 @
9f4bffbd
...
@@ -173,7 +173,7 @@ def unpack_getitem(inp, tuple_val, *, allow_newaxis=True):
...
@@ -173,7 +173,7 @@ def unpack_getitem(inp, tuple_val, *, allow_newaxis=True):
item
.
append
(
True
)
item
.
append
(
True
)
v
=
get_index
(
v
)
v
=
get_index
(
v
)
assert
np
.
issubdtype
(
v
.
dtype
,
np
.
integer
)
or
np
.
issubdtype
(
assert
np
.
issubdtype
(
v
.
dtype
,
np
.
integer
)
or
np
.
issubdtype
(
v
.
dtype
,
np
.
bool
v
.
dtype
,
np
.
bool
_
),
"var type in the subscript must be int or bool"
),
"var type in the subscript must be int or bool"
tensors
.
append
(
v
)
tensors
.
append
(
v
)
...
...
imperative/python/megengine/core/tensor/tensor_wrapper.py
浏览文件 @
9f4bffbd
...
@@ -65,10 +65,10 @@ def _broadcast(inp, shape):
...
@@ -65,10 +65,10 @@ def _broadcast(inp, shape):
)
)
)
)
if
isinstance
(
src
,
(
Tensor
,
TensorWrapperBase
)):
if
isinstance
(
src
,
(
Tensor
Base
,
TensorWrapperBase
)):
src
=
src
.
numpy
()
src
=
src
.
numpy
()
if
isinstance
(
tar
,
(
Tensor
,
TensorWrapperBase
)):
if
isinstance
(
tar
,
(
Tensor
Base
,
TensorWrapperBase
)):
tar
=
tar
.
numpy
()
tar
=
tar
.
numpy
()
if
len
(
src
)
>
len
(
tar
):
if
len
(
src
)
>
len
(
tar
):
...
@@ -78,8 +78,8 @@ def _broadcast(inp, shape):
...
@@ -78,8 +78,8 @@ def _broadcast(inp, shape):
if
src
[
-
i
-
1
]
!=
1
and
src
[
-
i
-
1
]
!=
tar
[
-
i
-
1
]:
if
src
[
-
i
-
1
]
!=
1
and
src
[
-
i
-
1
]
!=
tar
[
-
i
-
1
]:
failed
()
failed
()
valid_broadcast
(
inp
.
shape
,
shape
)
shape
=
utils
.
astensor1d
(
shape
,
inp
,
dtype
=
"int32"
,
device
=
inp
.
device
)
shape
=
utils
.
astensor1d
(
shape
,
inp
,
dtype
=
"int32"
,
device
=
inp
.
device
)
valid_broadcast
(
inp
.
shape
,
shape
)
(
result
,)
=
apply
(
builtin
.
Broadcast
(),
inp
,
shape
)
(
result
,)
=
apply
(
builtin
.
Broadcast
(),
inp
,
shape
)
return
result
return
result
...
...
imperative/python/test/unit/test_tracing.py
浏览文件 @
9f4bffbd
...
@@ -379,3 +379,18 @@ def test_trace_nms():
...
@@ -379,3 +379,18 @@ def test_trace_nms():
f
(
*
make_inputs
(
10
))
f
(
*
make_inputs
(
10
))
f
(
*
make_inputs
(
20
))
f
(
*
make_inputs
(
20
))
f
(
*
make_inputs
(
30
))
f
(
*
make_inputs
(
30
))
def
test_trace_valid_broadcast
():
set_tensor_shape
(
True
)
x1
=
tensor
(
np
.
random
.
randn
(
1
,
1
))
x2
=
tensor
(
np
.
random
.
randn
(
1
,
2
))
shape
=
(
tensor
([
2
]),
tensor
([
2
]))
@
trace
(
symbolic
=
False
)
def
f
(
x
,
shape
):
y
=
F
.
broadcast_to
(
x
,
shape
)
return
y
f
(
x1
,
shape
)
f
(
x2
,
shape
)
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录