Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
MegEngine 天元
MegEngine
提交
7252825c
MegEngine
项目概览
MegEngine 天元
/
MegEngine
1 年多 前同步成功
通知
404
Star
4705
Fork
582
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
MegEngine
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
提交
7252825c
编写于
6月 20, 2022
作者:
M
Megvii Engine Team
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
fix(functional): broadcast_to supports mutable target shape
GitOrigin-RevId: ff79456d5d2d669d20112d57fdeb255ae837e868
上级
2484cd27
变更
2
显示空白变更内容
内联
并排
Showing
2 changed file
with
84 addition
and
61 deletion
+84
-61
imperative/python/src/tensor_utils.cpp
imperative/python/src/tensor_utils.cpp
+50
-61
imperative/python/test/unit/functional/test_tensor.py
imperative/python/test/unit/functional/test_tensor.py
+34
-0
未找到文件。
imperative/python/src/tensor_utils.cpp
浏览文件 @
7252825c
...
...
@@ -924,78 +924,67 @@ bool enable_fastpath(py::handle inp) {
return
true
;
}
py
::
object
_broadcast_cpp
(
py
::
handle
inp_hdl
,
py
::
handle
args
)
{
py
::
object
shape_hdl
=
_expand_args
(
args
);
bool
auto_infer
=
false
;
py
::
list
lis
;
py
::
list
new_shape
;
if
(
PyList_Check
(
shape_hdl
.
ptr
())
||
PyTuple_Check
(
shape_hdl
.
ptr
()))
{
lis
=
py
::
reinterpret_steal
<
py
::
list
>
(
PySequence_List
(
shape_hdl
.
ptr
()));
for
(
size_t
i
=
0
;
i
<
lis
.
size
();
++
i
)
{
if
(
lis
[
i
].
is_none
())
{
auto_infer
=
true
;
size_t
right
=
lis
.
size
()
-
i
;
py
::
object
tshp
=
getattr
(
inp_hdl
,
"_tuple_shape"
);
if
(
tshp
.
is_none
())
{
throw
py
::
index_error
(
"does not support `None` with unknown shape"
);
}
py
::
tuple
inp_shape
=
py
::
reinterpret_borrow
<
py
::
tuple
>
(
tshp
);
if
(
inp_shape
.
size
()
>=
right
)
{
if
(
enable_fastpath
(
inp_hdl
))
{
lis
[
i
]
=
inp_shape
[
inp_shape
.
size
()
-
right
];
}
new_shape
.
append
(
inp_shape
[
inp_shape
.
size
()
-
right
]);
}
else
{
throw
py
::
value_error
(
"invalid broadcast shape"
);
py
::
object
_broadcast_cpp
(
py
::
handle
input
,
py
::
handle
args
)
{
py
::
object
shape
=
_expand_args
(
args
);
py
::
list
dims
;
bool
all_imm
;
if
(
PyList_Check
(
shape
.
ptr
())
||
PyTuple_Check
(
shape
.
ptr
()))
{
dims
=
py
::
reinterpret_steal
<
py
::
list
>
(
PySequence_List
(
shape
.
ptr
()));
mgb_assert
(
!
dims
.
is_none
());
all_imm
=
true
;
py
::
object
inp_shape
=
py
::
none
();
size_t
inp_ndim
;
for
(
size_t
i
=
0
;
i
<
dims
.
size
();
++
i
)
{
py
::
object
dim
=
dims
[
i
];
if
(
dim
.
is_none
())
{
ptrdiff_t
right
=
(
ptrdiff_t
)
i
-
dims
.
size
();
if
(
inp_shape
.
is_none
())
{
inp_shape
=
input
.
attr
(
"shape"
);
mgb_assert
(
!
inp_shape
.
is_none
());
inp_ndim
=
py
::
len
(
inp_shape
);
}
if
((
ptrdiff_t
)
inp_ndim
+
right
<
0
)
{
throw
py
::
value_error
(
"size connot be `None` for new axis"
);
}
dim
=
inp_shape
.
attr
(
"__getitem__"
)(
right
);
dims
[
i
]
=
dim
;
}
if
(
py
::
int_
::
check_
(
dim
))
{
if
(
dim
.
cast
<
long
>
()
<
0
)
{
throw
py
::
value_error
(
ssprintf
(
"expect shape[%zu] >= 0 or use `None` to auto infer, got "
"%s"
,
i
,
py
::
repr
(
dims
[
i
]).
cast
<
std
::
string
>
().
c_str
()));
}
}
else
{
new_shape
.
append
(
lis
[
i
]);
if
(
PyLong_Check
(
lis
[
i
].
ptr
()))
{
int32_t
s
=
lis
[
i
].
cast
<
int32_t
>
();
if
(
s
<
0
)
{
throw
py
::
value_error
(
"expect shape["
+
std
::
to_string
(
i
)
+
"] >= 0 or use `None` to auto infer, got "
+
std
::
to_string
(
s
));
}
all_imm
=
false
;
}
}
}
}
if
(
auto_infer
)
{
if
(
enable_fastpath
(
inp_hdl
))
{
shape_hdl
=
py
::
reinterpret_borrow
<
py
::
tuple
>
(
lis
);
shape
=
dims
;
}
else
{
shape_hdl
=
_astensor1d_cpp
(
new_shape
,
py
::
cast
((
mgb
::
DType
)
dtype
::
Int32
()),
getattr
(
inp_hdl
,
"device"
),
inp_hdl
);
all_imm
=
false
;
}
bool
fastpath
=
all_imm
&&
enable_fastpath
(
input
);
if
((
!
fastpath
)
&&
(
!
is_tensor
(
shape
)))
{
shape
=
_astensor1d_cpp
(
shape
,
py
::
cast
((
mgb
::
DType
)
dtype
::
Int32
()),
input
.
attr
(
"device"
),
input
);
}
py
::
object
shape_tuple
;
try
{
shape_tuple
=
_make_shape_tuple
(
shape_hdl
);
}
catch
(
py
::
error_already_set
&
err
)
{
shape_tuple
=
py
::
reinterpret_borrow
<
py
::
object
>
(
shape_hdl
);
}
auto
[
shape
,
fastpath
]
=
tuple2vector
(
shape_tuple
);
fastpath
&=
enable_fastpath
(
inp_hdl
);
std
::
shared_ptr
<
OpDef
>
op
;
std
::
vector
<
PyObject
*>
p
;
py
::
object
shape_tensor
;
SmallVector
<
PyObject
*>
p
(
2
);
if
(
fastpath
)
{
op
=
Broadcast
::
make
(
shape
);
p
.
resize
(
2
);
std
::
vector
<
int32_t
>
shape_vec
;
for
(
auto
&&
dim
:
dims
)
{
shape_vec
.
push_back
(
dim
.
cast
<
long
>
());
}
op
=
Broadcast
::
make
(
shape_vec
);
}
else
{
op
=
Broadcast
::
make
();
shape_tensor
=
_astensor1d_cpp
(
shape_hdl
,
py
::
cast
((
mgb
::
DType
)
dtype
::
Int32
()),
getattr
(
inp_hdl
,
"device"
),
inp_hdl
);
p
.
resize
(
3
);
p
[
2
]
=
shape_tensor
.
ptr
();
p
.
push_back
(
shape
.
ptr
());
}
py
::
object
O
p
=
py
::
cast
(
op
);
p
[
0
]
=
O
p
.
ptr
();
p
[
1
]
=
inp
_hdl
.
ptr
();
py
::
object
py_o
p
=
py
::
cast
(
op
);
p
[
0
]
=
py_o
p
.
ptr
();
p
[
1
]
=
inp
ut
.
ptr
();
py
::
tuple
ret
=
py
::
reinterpret_steal
<
py
::
object
>
(
py_apply
(
NULL
,
p
.
data
(),
p
.
size
()));
return
ret
[
0
];
...
...
imperative/python/test/unit/functional/test_tensor.py
浏览文件 @
7252825c
...
...
@@ -753,6 +753,40 @@ def test_broadcast_on_empty_tensor(is_trace):
test
(
func
,
inp
,
comp
,
target_shp
)
@
pytest
.
mark
.
parametrize
(
"input_shape, target_shapes"
,
[
((
3
,),
[(
2
,
1
,
3
),
(
1
,
2
,
3
),
(
2
,
2
,
3
)]),
((
1
,
3
,
1
),
[(
2
,
None
,
3
),
(
3
,
None
,
3
),
(
1
,
None
,
1
)]),
],
)
@
pytest
.
mark
.
parametrize
(
"is_symbolic"
,
[
True
,
False
])
def
test_broadcast_on_trace
(
is_symbolic
,
input_shape
,
target_shapes
):
x
=
F
.
ones
(
input_shape
)
@
trace
(
symbolic
=
is_symbolic
)
def
broadcast
(
inp
,
shape
):
return
F
.
broadcast_to
(
inp
,
shape
)
for
target_shape
in
target_shapes
:
if
None
in
target_shape
:
symbolic_target_shape
=
tuple
(
map
(
lambda
x
:
None
if
x
is
None
else
Tensor
(
x
),
target_shape
)
)
output
=
broadcast
(
x
,
symbolic_target_shape
)
for
i
in
range
(
len
(
target_shape
)):
if
target_shape
[
i
]
is
not
None
:
assert
output
.
_tuple_shape
[
i
]
==
target_shape
[
i
]
else
:
assert
(
output
.
_tuple_shape
[
i
]
==
x
.
_tuple_shape
[
i
-
len
(
target_shape
)]
)
else
:
symbolic_target_shape
=
Tensor
(
target_shape
)
output
=
broadcast
(
x
,
symbolic_target_shape
)
assert
output
.
_tuple_shape
==
target_shape
@
pytest
.
mark
.
parametrize
(
"is_varnode"
,
[
True
,
False
])
def
test_utils_astensor1d
(
is_varnode
):
if
is_varnode
:
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录