Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
MegEngine 天元
MegEngine
提交
c34a75d0
MegEngine
项目概览
MegEngine 天元
/
MegEngine
1 年多 前同步成功
通知
403
Star
4705
Fork
582
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
MegEngine
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
提交
c34a75d0
编写于
1月 28, 2022
作者:
M
Megvii Engine Team
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
fix(trace): assume result is not scalar when shape is valid
GitOrigin-RevId: beee2d0f28620cc3410d5c4172e0413e012114fd
上级
bebb2cf4
变更
3
隐藏空白更改
内联
并排
Showing
3 changed file
with
26 addition
and
19 deletion
+26
-19
imperative/python/megengine/core/tensor/indexing.py
imperative/python/megengine/core/tensor/indexing.py
+9
-5
imperative/python/src/tensor.cpp
imperative/python/src/tensor.cpp
+3
-7
imperative/src/impl/transformations/scalar.cpp
imperative/src/impl/transformations/scalar.cpp
+14
-7
未找到文件。
imperative/python/megengine/core/tensor/indexing.py
浏览文件 @
c34a75d0
...
...
@@ -119,12 +119,16 @@ def unpack_getitem(inp, tuple_val, *, allow_newaxis=True):
else
1
)
else
:
if
ndim_indexed
>
inp
.
ndim
:
raise
IndexError
(
"too many indices for tensor: tensor is {}-dimensional, but {} were indexed"
.
format
(
inp
.
ndim
,
len
(
tuple_val
)
try
:
if
ndim_indexed
>
inp
.
ndim
:
raise
IndexError
(
"too many indices for tensor: tensor is {}-dimensional, but {} were indexed"
.
format
(
inp
.
ndim
,
len
(
tuple_val
)
)
)
)
except
ValueError
:
# ignore
pass
tuple_val
=
remove_ellipsis
(
inp
,
tuple_val
)
use_subtensor
=
True
...
...
imperative/python/src/tensor.cpp
浏览文件 @
c34a75d0
...
...
@@ -272,16 +272,12 @@ PyObject* TensorWrapper::device() {
PyObject
*
TensorWrapper
::
numpy
()
{
auto
hv
=
m_tensor
->
numpy
();
// if (!hv) {
// PyErr_SetString(PyExc_ValueError, "tensor invalid");
// return nullptr;
// }
auto
arr
=
py
::
reinterpret_steal
<
py
::
array
>
(
npy
::
ndarray_from_tensor
(
hv
->
as_nd
(
true
),
npy
::
ShareType
::
TRY_SHARE
));
if
(
!
arr
)
{
if
(
!
hv
)
{
PyErr_SetString
(
PyExc_ValueError
,
"tensor invalid"
);
return
nullptr
;
}
auto
arr
=
py
::
reinterpret_steal
<
py
::
array
>
(
npy
::
ndarray_from_tensor
(
hv
->
as_nd
(
true
),
npy
::
ShareType
::
TRY_SHARE
));
if
(
hv
->
shape
().
is_scalar
())
{
mgb_assert
(
PyArray_Check
(
arr
.
ptr
()));
return
PyArray_Squeeze
(
reinterpret_cast
<
PyArrayObject
*>
(
arr
.
ptr
()));
...
...
imperative/src/impl/transformations/scalar.cpp
浏览文件 @
c34a75d0
...
...
@@ -51,6 +51,7 @@ bool is_scalar_shape(ValueRef shape) {
if
(
shape
.
is
<
ScalarValue
>
())
{
return
false
;
}
// may have performance issue
auto
shape_of_shape
=
shape
.
shape
();
if
(
!
shape_of_shape
)
{
// assume not scalar
...
...
@@ -211,14 +212,21 @@ std::vector<ValueRef> subtensor_rule(
const
Subtensor
&
subtensor
,
Span
<
ValueRef
>
inputs
)
{
mgb_assert
(
inputs
.
size
()
>=
1
);
auto
input
=
inputs
[
0
];
size_t
ndim
=
input
.
is
<
ScalarValue
>
()
?
0
:
input
.
shape
()
->
ndim
;
for
(
auto
&&
[
axis
,
begin
,
end
,
step
,
idx
]
:
subtensor
.
items
)
{
if
(
idx
)
{
ndim
--
;
bool
is_scalar
;
mgb_assert
(
!
input
.
is
<
ScalarValue
>
(),
"subtensor shouldn't have scalar input"
);
if
(
auto
shape
=
input
.
shape
())
{
size_t
ndim
=
input
.
shape
()
->
ndim
;
for
(
auto
&&
[
axis
,
begin
,
end
,
step
,
idx
]
:
subtensor
.
items
)
{
if
(
idx
)
{
ndim
--
;
}
}
is_scalar
=
ndim
==
0
;
}
else
{
is_scalar
=
false
;
}
auto
output
=
imperative
::
apply
(
subtensor
,
unwrap_inputs
(
inputs
))[
0
];
if
(
!
ndim
)
{
if
(
is_scalar
)
{
return
{
ScalarValue
::
make
(
output
)};
}
else
{
return
{
output
};
...
...
@@ -261,8 +269,7 @@ std::vector<ValueRef> fastpath_copy_rule(
std
::
vector
<
ValueRef
>
reshape_rule
(
const
Reshape
&
reshape
,
Span
<
ValueRef
>
inputs
)
{
mgb_assert
(
inputs
.
size
()
==
2
);
bool
is_scalar
=
(
!
inputs
[
1
].
is
<
ScalarValue
>
())
&&
*
inputs
[
1
].
shape
()
==
ValueShape
{
0
};
bool
is_scalar
=
is_scalar_shape
(
inputs
[
1
]);
auto
unwrapped_input
=
inputs
[
0
].
is
<
ScalarValue
>
()
?
inputs
[
0
].
cast
<
ScalarValue
>
().
value
()
:
inputs
[
0
];
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录