Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
MegEngine 天元
MegEngine
提交
7751a067
MegEngine
项目概览
MegEngine 天元
/
MegEngine
1 年多 前同步成功
通知
404
Star
4705
Fork
582
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
MegEngine
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
提交
7751a067
编写于
6月 08, 2020
作者:
M
Megvii Engine Team
提交者:
Xu Xinran
6月 19, 2020
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
docs(mge/tensor): add advanced index related docs
GitOrigin-RevId: 31735ddac487826aa604c2a2aec2d72b572fa609
上级
7b0dbe6a
变更
1
隐藏空白更改
内联
并排
Showing
1 changed file
with
158 addition
and
32 deletion
+158
-32
python_module/megengine/core/tensor.py
python_module/megengine/core/tensor.py
+158
-32
未找到文件。
python_module/megengine/core/tensor.py
浏览文件 @
7751a067
# -*- coding: utf-8 -*-
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
...
...
@@ -10,7 +9,7 @@ import collections
import
functools
import
itertools
import
weakref
from
typing
import
Union
from
typing
import
Callable
,
Tuple
,
Union
import
numpy
as
np
...
...
@@ -68,24 +67,38 @@ def _wrap_symbolvar_binary_op(f):
return
wrapped
def
wrap_slice
(
inp
):
def
_wrap_slice
(
inp
:
slice
):
r
"""
A wrapper to handle Tensor values in ``inp`` slice.
"""
start
=
inp
.
start
.
_symvar
if
isinstance
(
inp
.
start
,
Tensor
)
else
inp
.
start
stop
=
inp
.
stop
.
_symvar
if
isinstance
(
inp
.
stop
,
Tensor
)
else
inp
.
stop
step
=
inp
.
step
.
_symvar
if
isinstance
(
inp
.
step
,
Tensor
)
else
inp
.
step
return
slice
(
start
,
stop
,
step
)
def
wrap_idx
(
idx
):
def
_wrap_idx
(
idx
:
Tuple
[
Union
[
int
,
"Tensor"
]]):
r
"""
A wrapper to handle Tensor values in ``idx``.
"""
if
not
isinstance
(
idx
,
tuple
):
idx
=
(
idx
,)
idx
=
tuple
(
i
.
_symvar
if
isinstance
(
i
,
Tensor
)
else
i
for
i
in
idx
)
idx
=
tuple
(
wrap_slice
(
i
)
if
isinstance
(
i
,
slice
)
else
i
for
i
in
idx
)
idx
=
tuple
(
_
wrap_slice
(
i
)
if
isinstance
(
i
,
slice
)
else
i
for
i
in
idx
)
return
idx
class
MGBIndexWrapper
:
def
__init__
(
self
,
dest
,
mgb_index
,
val
=
None
):
class
_MGBIndexWrapper
:
r
"""
A wrapper class to handle ``__getitem__`` for index containing Tensor values.
:param dest: a destination Tensor to do indexing on.
:param mgb_index: an ``_internal`` helper function indicating how to index.
:param val: a optional Tensor parameter used for ``mgb_index``.
"""
def
__init__
(
self
,
dest
:
"Tensor"
,
mgb_index
:
Callable
,
val
=
None
):
self
.
dest
=
dest
self
.
val
=
val
self
.
mgb_index
=
mgb_index
...
...
@@ -93,16 +106,22 @@ class MGBIndexWrapper:
def
__getitem__
(
self
,
idx
):
if
self
.
val
is
None
:
return
wrap_io_tensor
(
self
.
mgb_index
(
self
.
dest
.
_symvar
).
__getitem__
)(
wrap_idx
(
idx
)
_
wrap_idx
(
idx
)
)
else
:
return
wrap_io_tensor
(
self
.
mgb_index
(
self
.
dest
.
_symvar
,
self
.
val
.
_symvar
).
__getitem__
)(
wrap_idx
(
idx
))
)(
_wrap_idx
(
idx
))
class
_Guard
:
r
"""
A wrapper class with custom ``__del__`` method calling ``deleter``.
:param deleter: a function to be called in ``__del__``.
"""
class
Guard
:
def
__init__
(
self
,
deleter
):
def
__init__
(
self
,
deleter
:
Callable
):
self
.
deleter
=
deleter
def
__del__
(
self
):
...
...
@@ -161,6 +180,7 @@ class Tensor:
return
self
.
__sym
.
inferred_value
def
item
(
self
):
r
"""If tensor only has only one value, return it."""
return
self
.
numpy
().
item
()
def
_attach
(
self
,
comp_graph
,
*
,
volatile
=
True
):
...
...
@@ -204,7 +224,7 @@ class Tensor:
if
self
is
not
None
:
self
.
__sym_override
=
None
deleters
.
add
(
Guard
(
restore
))
deleters
.
add
(
_
Guard
(
restore
))
self
.
__sym_override
=
symvar
@
property
...
...
@@ -403,43 +423,149 @@ class Tensor:
# mgb indexing family
def
__getitem__
(
self
,
idx
):
return
wrap_io_tensor
(
self
.
_symvar
.
__getitem__
)(
wrap_idx
(
idx
))
return
wrap_io_tensor
(
self
.
_symvar
.
__getitem__
)(
_wrap_idx
(
idx
))
def
set_subtensor
(
self
,
val
:
"Tensor"
):
r
"""
Return a object which supports using ``__getitem__`` to set subtensor.
def
set_subtensor
(
self
,
val
):
return
MGBIndexWrapper
(
self
,
mgb
.
opr
.
set_subtensor
,
val
)
``c = a.set_subtensor(b)[idx]`` is equivalent to ``c = a.copy()`` and ``c[idx] = b``.
"""
return
_MGBIndexWrapper
(
self
,
mgb
.
opr
.
set_subtensor
,
val
)
def
incr_subtensor
(
self
,
val
):
return
MGBIndexWrapper
(
self
,
mgb
.
opr
.
incr_subtensor
,
val
)
def
incr_subtensor
(
self
,
val
:
"Tensor"
):
r
"""
Return a object which supports using ``__getitem__`` to increase subtensor.
``c = a.incr_subtensor(b)[idx]`` is equivalent to ``c = a.copy()`` and ``c[idx] += b``.
"""
return
_MGBIndexWrapper
(
self
,
mgb
.
opr
.
incr_subtensor
,
val
)
@
property
def
ai
(
self
):
return
MGBIndexWrapper
(
self
,
mgb
.
opr
.
advanced_indexing
)
r
"""
Return a object which supports complex index method to get subtensor.
def
set_ai
(
self
,
val
):
return
MGBIndexWrapper
(
self
,
mgb
.
opr
.
set_advanced_indexing
,
val
)
Examples:
def
incr_ai
(
self
,
val
):
return
MGBIndexWrapper
(
self
,
mgb
.
opr
.
incr_advanced_indexing
,
val
)
.. testcode::
from megengine import tensor
a = tensor(np.arange(16, dtype=np.float32).reshape((4, 4)))
print(a.ai[:, [2, 3]])
Outputs:
.. testoutput::
Tensor([[ 2. 3.]
[ 6. 7.]
[10. 11.]
[14. 15.]])
"""
return
_MGBIndexWrapper
(
self
,
mgb
.
opr
.
advanced_indexing
)
def
set_ai
(
self
,
val
:
"Tensor"
):
r
"""
Equal to :meth:`~.Tensor.set_subtensor` which supports advanced indexing.
"""
return
_MGBIndexWrapper
(
self
,
mgb
.
opr
.
set_advanced_indexing
,
val
)
def
incr_ai
(
self
,
val
:
"Tensor"
):
r
"""
Equal to :meth:`~.Tensor.incr_subtensor` which supports advanced indexing.
"""
return
_MGBIndexWrapper
(
self
,
mgb
.
opr
.
incr_advanced_indexing
,
val
)
@
property
def
mi
(
self
):
return
MGBIndexWrapper
(
self
,
mgb
.
opr
.
mesh_indexing
)
r
"""
Return a object which supports getting subtensor by
the coordinates which is Cartesian product of given index.
Examples:
def
set_mi
(
self
,
val
):
return
MGBIndexWrapper
(
self
,
mgb
.
opr
.
set_mesh_indexing
,
val
)
.. testcode::
from megengine import tensor
a = tensor(np.arange(16, dtype=np.float32).reshape((4, 4)))
print(a.mi[[1, 2], [2, 3]])
# is equal to elements on [1, 2] * [2, 3] = [[(1,2), (1, 3)], [(2, 2), (2, 3)]]
# a[1,2] = 6, a[1,3] = 7, a[2,2] = 10, a[2,3] = 11
Outputs:
.. testoutput::
Tensor([[ 6. 7.]
[10. 11.]])
"""
return
_MGBIndexWrapper
(
self
,
mgb
.
opr
.
mesh_indexing
)
def
set_mi
(
self
,
val
:
"Tensor"
):
r
"""
Equal to :meth:`~.Tensor.set_subtensor` which using mesh indexing.
"""
return
_MGBIndexWrapper
(
self
,
mgb
.
opr
.
set_mesh_indexing
,
val
)
def
incr_mi
(
self
,
val
):
return
MGBIndexWrapper
(
self
,
mgb
.
opr
.
incr_mesh_indexing
,
val
)
def
incr_mi
(
self
,
val
:
"Tensor"
):
r
"""
Equal to :meth:`~.Tensor.incr_subtensor` which using mesh indexing.
"""
return
_MGBIndexWrapper
(
self
,
mgb
.
opr
.
incr_mesh_indexing
,
val
)
@
property
def
batched_mi
(
self
):
return
MGBIndexWrapper
(
self
,
mgb
.
opr
.
batched_mesh_indexing
)
r
"""
Return a object which supports getting subtensor by
batched mesh indexing.
def
batched_set_mi
(
self
,
val
):
return
MGBIndexWrapper
(
self
,
mgb
.
opr
.
batched_set_mesh_indexing
,
val
)
For Tensor ``a`` and index ``idx``, each value of the ``idx`` need to be a 2-dim matrix or slice.
Cartesian product ``... * idx[k-1][i] * idx[k][i] * idx[k+1][i] * ...`` will be a subtensor from ``a[i]``.
Each matrix ``idx[k]`` should have the size of ``batched_dim`` rows as ``idx[0]`` indicated.
And for slice value, it will apply same slice for each ``batched_dim``. For more details see the example below.
def
batched_incr_mi
(
self
,
val
):
return
MGBIndexWrapper
(
self
,
mgb
.
opr
.
batched_incr_mesh_indexing
,
val
)
Examples:
.. testcode::
from megengine import tensor
a = tensor(np.arange(144, dtype=np.float32).reshape((3, 3, 4, 4)))
print(a.batched_mi[:2, [[0],[1]],[[0,1],[2,3]],[[0],[1]]])
# is equal to elements from a[0] with ``[0] * [0,1] * [0] = [[[(0,0,0)], [(0,1,0)]]]``(shape is [1,2,1])
# and from a[1] with ``[1] * [2,3] * [1] = [[[(1,2,1)], [(1,3,1)]]]``(shape is also [1,2,1])
# a[0,0,0,0] = 0, a[0,0,1,0] = 4, a[1,1,2,1] = 73, a[1,1,3,1] = 77
print(a.batched_mi[:2, [[0],[1]], :2, :1])
# is equal to ``a.batched_mi[:2, [[0],[1]], [[0,1],[0,1]],[[0],[0]]]``
Outputs:
.. testoutput::
Tensor([[[[ 0.]
[ 4.]]]
[[[73.]
[77.]]]])
Tensor([[[[ 0.]
[ 4.]]]
[[[64.]
[68.]]]])
"""
return
_MGBIndexWrapper
(
self
,
mgb
.
opr
.
batched_mesh_indexing
)
def
batched_set_mi
(
self
,
val
:
"Tensor"
):
r
"""
Equal to :meth:`~.Tensor.incr_subtensor` which using batched mesh indexing.
"""
return
_MGBIndexWrapper
(
self
,
mgb
.
opr
.
batched_set_mesh_indexing
,
val
)
def
batched_incr_mi
(
self
,
val
:
"Tensor"
):
r
"""
Equal to :meth:`~.Tensor.incr_subtensor` which using batched mesh indexing.
"""
return
_MGBIndexWrapper
(
self
,
mgb
.
opr
.
batched_incr_mesh_indexing
,
val
)
def
__array__
(
self
,
dtype
=
None
):
if
dtype
is
None
:
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录