Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
MegEngine 天元
MegEngine
提交
b8d8886e
MegEngine
项目概览
MegEngine 天元
/
MegEngine
1 年多 前同步成功
通知
403
Star
4705
Fork
582
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
MegEngine
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
提交
b8d8886e
编写于
6月 08, 2020
作者:
M
Megvii Engine Team
提交者:
Xu Xinran
6月 19, 2020
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
refactor(mge/tensor): combine Dict and TensorDict
GitOrigin-RevId: 6b6c03c04b7c97c29d30831e07c21e8afd3c3f40
上级
7751a067
变更
1
隐藏空白更改
内联
并排
Showing
1 changed file
with
31 addition
and
39 deletion
+31
-39
python_module/megengine/core/tensor.py
python_module/megengine/core/tensor.py
+31
-39
未找到文件。
python_module/megengine/core/tensor.py
浏览文件 @
b8d8886e
...
...
@@ -425,7 +425,7 @@ class Tensor:
def
__getitem__
(
self
,
idx
):
return
wrap_io_tensor
(
self
.
_symvar
.
__getitem__
)(
_wrap_idx
(
idx
))
def
set_subtensor
(
self
,
val
:
"Tensor"
):
def
set_subtensor
(
self
,
val
:
"Tensor"
)
->
_MGBIndexWrapper
:
r
"""
Return a object which supports using ``__getitem__`` to set subtensor.
...
...
@@ -433,7 +433,7 @@ class Tensor:
"""
return
_MGBIndexWrapper
(
self
,
mgb
.
opr
.
set_subtensor
,
val
)
def
incr_subtensor
(
self
,
val
:
"Tensor"
):
def
incr_subtensor
(
self
,
val
:
"Tensor"
)
->
_MGBIndexWrapper
:
r
"""
Return a object which supports using ``__getitem__`` to increase subtensor.
...
...
@@ -442,7 +442,7 @@ class Tensor:
return
_MGBIndexWrapper
(
self
,
mgb
.
opr
.
incr_subtensor
,
val
)
@
property
def
ai
(
self
):
def
ai
(
self
)
->
_MGBIndexWrapper
:
r
"""
Return a object which supports complex index method to get subtensor.
...
...
@@ -465,20 +465,20 @@ class Tensor:
"""
return
_MGBIndexWrapper
(
self
,
mgb
.
opr
.
advanced_indexing
)
def
set_ai
(
self
,
val
:
"Tensor"
):
def
set_ai
(
self
,
val
:
"Tensor"
)
->
_MGBIndexWrapper
:
r
"""
Equal to :meth:`~.Tensor.set_subtensor` which supports advanced indexing.
"""
return
_MGBIndexWrapper
(
self
,
mgb
.
opr
.
set_advanced_indexing
,
val
)
def
incr_ai
(
self
,
val
:
"Tensor"
):
def
incr_ai
(
self
,
val
:
"Tensor"
)
->
_MGBIndexWrapper
:
r
"""
Equal to :meth:`~.Tensor.incr_subtensor` which supports advanced indexing.
"""
return
_MGBIndexWrapper
(
self
,
mgb
.
opr
.
incr_advanced_indexing
,
val
)
@
property
def
mi
(
self
):
def
mi
(
self
)
->
_MGBIndexWrapper
:
r
"""
Return a object which supports getting subtensor by
the coordinates which is Cartesian product of given index.
...
...
@@ -502,20 +502,20 @@ class Tensor:
"""
return
_MGBIndexWrapper
(
self
,
mgb
.
opr
.
mesh_indexing
)
def
set_mi
(
self
,
val
:
"Tensor"
):
def
set_mi
(
self
,
val
:
"Tensor"
)
->
_MGBIndexWrapper
:
r
"""
Equal to :meth:`~.Tensor.set_subtensor` which using mesh indexing.
"""
return
_MGBIndexWrapper
(
self
,
mgb
.
opr
.
set_mesh_indexing
,
val
)
def
incr_mi
(
self
,
val
:
"Tensor"
):
def
incr_mi
(
self
,
val
:
"Tensor"
)
->
_MGBIndexWrapper
:
r
"""
Equal to :meth:`~.Tensor.incr_subtensor` which using mesh indexing.
"""
return
_MGBIndexWrapper
(
self
,
mgb
.
opr
.
incr_mesh_indexing
,
val
)
@
property
def
batched_mi
(
self
):
def
batched_mi
(
self
)
->
_MGBIndexWrapper
:
r
"""
Return a object which supports getting subtensor by
batched mesh indexing.
...
...
@@ -555,13 +555,13 @@ class Tensor:
"""
return
_MGBIndexWrapper
(
self
,
mgb
.
opr
.
batched_mesh_indexing
)
def
batched_set_mi
(
self
,
val
:
"Tensor"
):
def
batched_set_mi
(
self
,
val
:
"Tensor"
)
->
_MGBIndexWrapper
:
r
"""
Equal to :meth:`~.Tensor.incr_subtensor` which using batched mesh indexing.
"""
return
_MGBIndexWrapper
(
self
,
mgb
.
opr
.
batched_set_mesh_indexing
,
val
)
def
batched_incr_mi
(
self
,
val
:
"Tensor"
):
def
batched_incr_mi
(
self
,
val
:
"Tensor"
)
->
_MGBIndexWrapper
:
r
"""
Equal to :meth:`~.Tensor.incr_subtensor` which using batched mesh indexing.
"""
...
...
@@ -680,18 +680,31 @@ def tensor(
return
Tensor
(
shared_nd
,
requires_grad
=
requires_grad
)
class
Dict
(
collections
.
MutableMapping
):
def
__init__
(
self
,
*
args
,
key
=
None
,
**
kwargs
):
class
TensorDict
(
collections
.
MutableMapping
):
r
"""
A helper class to maintain dict with Tensor key.
"""
def
__init__
(
self
,
*
args
,
**
kwargs
):
self
.
data
=
{}
if
key
:
self
.
keyfn
=
key
for
i
in
args
:
self
.
update
(
i
)
self
.
update
(
**
kwargs
)
@
staticmethod
def
keyfn
(
key
):
# pylint: disable=method-hidden
return
key
class
keyfn
:
def
__new__
(
cls
,
x
:
Tensor
):
if
not
isinstance
(
x
,
Tensor
):
return
x
return
super
().
__new__
(
cls
)
def
__init__
(
self
,
x
:
Tensor
):
self
.
_data
=
x
# do not save id directly to make pickle work
def
__hash__
(
self
):
return
id
(
self
.
_data
)
def
__eq__
(
self
,
other
):
return
isinstance
(
other
,
type
(
self
))
and
id
(
self
.
_data
)
==
id
(
other
.
_data
)
def
__getitem__
(
self
,
key
):
_
,
v
=
self
.
data
[
self
.
keyfn
(
key
)]
...
...
@@ -709,24 +722,3 @@ class Dict(collections.MutableMapping):
def
__len__
(
self
):
return
len
(
self
.
data
)
class
TensorDict
(
Dict
):
# pylint: disable=too-many-ancestors
class
keyfn
:
def
__new__
(
cls
,
x
:
Tensor
):
if
not
isinstance
(
x
,
Tensor
):
return
x
return
super
().
__new__
(
cls
)
def
__init__
(
self
,
x
:
Tensor
):
self
.
_data
=
x
# do not save id directly to make pickle work
def
__hash__
(
self
):
return
id
(
self
.
_data
)
def
__eq__
(
self
,
other
):
# pylint: disable=undefined-variable
return
isinstance
(
other
,
__class__
)
and
id
(
self
.
_data
)
==
id
(
other
.
_data
)
def
__init__
(
self
,
*
args
):
super
().
__init__
(
*
args
)
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录