Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
MegEngine 天元
MegEngine
提交
a926878c
MegEngine
项目概览
MegEngine 天元
/
MegEngine
1 年多 前同步成功
通知
404
Star
4705
Fork
582
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
MegEngine
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
提交
a926878c
编写于
5月 15, 2022
作者:
M
Megvii Engine Team
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
feat(imperative): remove symbolvar of imperative
GitOrigin-RevId: 16da6d1491526b707ea6851fb68e330c02cc788a
上级
14813d13
变更
25
显示空白变更内容
内联
并排
Showing
25 changed file
with
356 addition
and
398 deletion
+356
-398
imperative/python/megengine/core/tensor/array_method.py
imperative/python/megengine/core/tensor/array_method.py
+5
-7
imperative/python/megengine/core/tensor/utils.py
imperative/python/megengine/core/tensor/utils.py
+2
-3
imperative/python/megengine/functional/elemwise.py
imperative/python/megengine/functional/elemwise.py
+1
-1
imperative/python/megengine/functional/math.py
imperative/python/megengine/functional/math.py
+1
-1
imperative/python/megengine/functional/nn.py
imperative/python/megengine/functional/nn.py
+2
-2
imperative/python/megengine/functional/tensor.py
imperative/python/megengine/functional/tensor.py
+16
-20
imperative/python/megengine/functional/tensor_cache.py
imperative/python/megengine/functional/tensor_cache.py
+2
-2
imperative/python/megengine/tensor.py
imperative/python/megengine/tensor.py
+2
-0
imperative/python/megengine/traced_module/expr.py
imperative/python/megengine/traced_module/expr.py
+1
-1
imperative/python/megengine/utils/network.py
imperative/python/megengine/utils/network.py
+1
-1
imperative/python/megengine/utils/network_node.py
imperative/python/megengine/utils/network_node.py
+69
-43
imperative/python/src/graph_rt.cpp
imperative/python/src/graph_rt.cpp
+40
-34
imperative/python/src/graph_rt.h
imperative/python/src/graph_rt.h
+3
-0
imperative/python/src/tensor.cpp
imperative/python/src/tensor.cpp
+42
-158
imperative/python/src/tensor.h
imperative/python/src/tensor.h
+5
-7
imperative/python/src/tensor_utils.cpp
imperative/python/src/tensor_utils.cpp
+67
-105
imperative/python/src/transformation.h
imperative/python/src/transformation.h
+2
-1
imperative/python/test/helpers/utils.py
imperative/python/test/helpers/utils.py
+1
-1
imperative/python/test/unit/functional/test_tensor.py
imperative/python/test/unit/functional/test_tensor.py
+12
-0
imperative/python/test/unit/utils/test_network.py
imperative/python/test/unit/utils/test_network.py
+4
-2
imperative/src/impl/basic_operators.cpp
imperative/src/impl/basic_operators.cpp
+8
-0
imperative/src/include/megbrain/imperative/basic_operators.h
imperative/src/include/megbrain/imperative/basic_operators.h
+17
-0
imperative/src/include/megbrain/imperative/basic_values.h
imperative/src/include/megbrain/imperative/basic_values.h
+19
-0
imperative/src/include/megbrain/imperative/transformations/symbol.h
.../src/include/megbrain/imperative/transformations/symbol.h
+33
-9
imperative/src/include/megbrain/imperative/value.h
imperative/src/include/megbrain/imperative/value.h
+1
-0
未找到文件。
imperative/python/megengine/core/tensor/array_method.py
浏览文件 @
a926878c
...
@@ -7,9 +7,7 @@ from typing import Union
...
@@ -7,9 +7,7 @@ from typing import Union
import
numpy
as
np
import
numpy
as
np
from
..
import
_config
from
..
import
_config
from
.._imperative_rt.common
import
CompNode
from
.._imperative_rt.core2
import
(
from
.._imperative_rt.core2
import
(
SymbolVar
,
Tensor
,
Tensor
,
apply
,
apply
,
astype_cpp
,
astype_cpp
,
...
@@ -17,9 +15,11 @@ from .._imperative_rt.core2 import (
...
@@ -17,9 +15,11 @@ from .._imperative_rt.core2 import (
broadcast_cpp
,
broadcast_cpp
,
getitem_cpp
,
getitem_cpp
,
matmul_cpp
,
matmul_cpp
,
reshape_cpp
,
setitem_cpp
,
squeeze_cpp
,
transpose_cpp
,
)
)
from
.._imperative_rt.core2
import
reduce_to_scalar
as
_reduce_to_scalar
from
.._imperative_rt.core2
import
reshape_cpp
,
setitem_cpp
,
squeeze_cpp
,
transpose_cpp
from
..ops
import
builtin
from
..ops
import
builtin
from
.
import
amp
from
.
import
amp
from
.utils
import
_normalize_axis
,
astensor1d
,
cast_tensors
,
make_shape_tuple
,
subgraph
from
.utils
import
_normalize_axis
,
astensor1d
,
cast_tensors
,
make_shape_tuple
,
subgraph
...
@@ -189,9 +189,7 @@ def _todo(*_):
...
@@ -189,9 +189,7 @@ def _todo(*_):
def
_expand_args
(
args
):
def
_expand_args
(
args
):
if
len
(
args
)
==
1
:
if
len
(
args
)
==
1
:
if
isinstance
(
if
isinstance
(
args
[
0
],
(
collections
.
abc
.
Sequence
,
Tensor
,
np
.
ndarray
),):
args
[
0
],
(
collections
.
abc
.
Sequence
,
Tensor
,
SymbolVar
,
np
.
ndarray
),
):
args
=
args
[
0
]
args
=
args
[
0
]
return
args
return
args
...
...
imperative/python/megengine/core/tensor/utils.py
浏览文件 @
a926878c
...
@@ -8,7 +8,6 @@ import numpy as np
...
@@ -8,7 +8,6 @@ import numpy as np
from
.._imperative_rt
import
make_const
from
.._imperative_rt
import
make_const
from
.._imperative_rt.core2
import
(
from
.._imperative_rt.core2
import
(
Const
,
Const
,
SymbolVar
,
Tensor
,
Tensor
,
_get_convert_inputs
,
_get_convert_inputs
,
_set_convert_inputs
,
_set_convert_inputs
,
...
@@ -77,7 +76,7 @@ def result_type(*args):
...
@@ -77,7 +76,7 @@ def result_type(*args):
def
isscalar
(
x
):
def
isscalar
(
x
):
if
isinstance
(
x
,
(
Tensor
,
SymbolVar
)
):
if
isinstance
(
x
,
Tensor
):
return
x
.
_isscalar
()
return
x
.
_isscalar
()
return
np
.
isscalar
(
x
)
return
np
.
isscalar
(
x
)
...
@@ -283,7 +282,7 @@ def interpret_subgraph(func, dtype, device):
...
@@ -283,7 +282,7 @@ def interpret_subgraph(func, dtype, device):
return
results
return
results
def
apply_const
(
value
,
dtype
=
dtype
,
device
=
device
):
def
apply_const
(
value
,
dtype
=
dtype
,
device
=
device
):
return
Const
(
value
,
dtype
,
device
,
None
)
return
Const
(
value
,
dtype
,
device
)
outputs
,
outputs_has_grad
=
func
(
args
,
apply_expr
,
apply_const
)
outputs
,
outputs_has_grad
=
func
(
args
,
apply_expr
,
apply_const
)
outputs
=
[
outputs
=
[
...
...
imperative/python/megengine/functional/elemwise.py
浏览文件 @
a926878c
...
@@ -2,7 +2,7 @@
...
@@ -2,7 +2,7 @@
# pylint: disable=unused-argument,invalid-name,redefined-builtin,arguments-out-of-order
# pylint: disable=unused-argument,invalid-name,redefined-builtin,arguments-out-of-order
import
numpy
as
np
import
numpy
as
np
from
..core._imperative_rt.core2
import
SymbolVar
,
apply
from
..core._imperative_rt.core2
import
apply
from
..core.ops
import
builtin
from
..core.ops
import
builtin
from
..core.ops.builtin
import
Elemwise
from
..core.ops.builtin
import
Elemwise
from
..core.tensor.array_method
import
_elwise
from
..core.tensor.array_method
import
_elwise
...
...
imperative/python/megengine/functional/math.py
浏览文件 @
a926878c
...
@@ -538,7 +538,7 @@ def topk(
...
@@ -538,7 +538,7 @@ def topk(
op
=
builtin
.
TopK
(
mode
=
mode
)
op
=
builtin
.
TopK
(
mode
=
mode
)
if
not
isinstance
(
k
,
Tensor
):
if
not
isinstance
(
k
,
Tensor
):
k
=
Const
(
k
,
"int32"
,
inp
.
device
,
None
)
k
=
Const
(
k
,
"int32"
,
inp
.
device
)
if
len
(
inp
.
shape
)
==
1
:
if
len
(
inp
.
shape
)
==
1
:
if
kth_only
:
if
kth_only
:
...
...
imperative/python/megengine/functional/nn.py
浏览文件 @
a926878c
...
@@ -1222,7 +1222,7 @@ def batch_norm(
...
@@ -1222,7 +1222,7 @@ def batch_norm(
raise
ValueError
(
"Invalid param_dim {}"
.
format
(
param_dim
))
raise
ValueError
(
"Invalid param_dim {}"
.
format
(
param_dim
))
if
x
is
None
:
if
x
is
None
:
x
=
Const
(
value
,
inp
.
dtype
,
inp
.
device
,
None
)
x
=
Const
(
value
,
inp
.
dtype
,
inp
.
device
)
shape
=
astensor1d
(
pshape
,
inp
,
dtype
=
"int32"
,
device
=
inp
.
device
)
shape
=
astensor1d
(
pshape
,
inp
,
dtype
=
"int32"
,
device
=
inp
.
device
)
(
result
,)
=
apply
(
builtin
.
Broadcast
(),
x
,
shape
)
(
result
,)
=
apply
(
builtin
.
Broadcast
(),
x
,
shape
)
return
result
return
result
...
@@ -1446,7 +1446,7 @@ def sync_batch_norm(
...
@@ -1446,7 +1446,7 @@ def sync_batch_norm(
def
_make_full_if_none
(
x
,
value
):
def
_make_full_if_none
(
x
,
value
):
if
x
is
None
:
if
x
is
None
:
x
=
Const
(
value
,
inp
.
dtype
,
_device
,
None
)
x
=
Const
(
value
,
inp
.
dtype
,
_device
)
(
result
,)
=
apply
(
builtin
.
Broadcast
(),
x
,
reduce_shape
)
(
result
,)
=
apply
(
builtin
.
Broadcast
(),
x
,
reduce_shape
)
return
result
return
result
elif
x
.
ndim
==
1
:
elif
x
.
ndim
==
1
:
...
...
imperative/python/megengine/functional/tensor.py
浏览文件 @
a926878c
...
@@ -7,7 +7,6 @@ import numpy as np
...
@@ -7,7 +7,6 @@ import numpy as np
from
..core._imperative_rt
import
CompNode
from
..core._imperative_rt
import
CompNode
from
..core._imperative_rt.core2
import
(
from
..core._imperative_rt.core2
import
(
Const
,
Const
,
SymbolVar
,
apply
,
apply
,
broadcast_cpp
,
broadcast_cpp
,
dtype_promotion
,
dtype_promotion
,
...
@@ -151,7 +150,7 @@ def full(
...
@@ -151,7 +150,7 @@ def full(
shape
=
(
shape
,)
shape
=
(
shape
,)
if
device
is
None
:
if
device
is
None
:
device
=
get_default_device
()
device
=
get_default_device
()
x
=
Const
(
value
,
dtype
,
device
,
None
)
x
=
Const
(
value
,
dtype
,
device
)
if
type
(
shape
)
in
(
list
,
tuple
)
and
len
(
shape
)
==
0
:
if
type
(
shape
)
in
(
list
,
tuple
)
and
len
(
shape
)
==
0
:
return
x
return
x
return
broadcast_to
(
x
,
shape
)
return
broadcast_to
(
x
,
shape
)
...
@@ -216,7 +215,7 @@ def zeros(
...
@@ -216,7 +215,7 @@ def zeros(
return
full
(
shape
,
0.0
,
dtype
=
dtype
,
device
=
device
)
return
full
(
shape
,
0.0
,
dtype
=
dtype
,
device
=
device
)
def
zeros_like
(
inp
:
Union
[
Tensor
,
SymbolVar
])
->
Union
[
Tensor
,
SymbolVar
]
:
def
zeros_like
(
inp
:
Tensor
)
->
Tensor
:
r
"""Returns a tensor filled with zeros with the same shape and data type as input tensor.
r
"""Returns a tensor filled with zeros with the same shape and data type as input tensor.
Args:
Args:
...
@@ -235,7 +234,7 @@ def zeros_like(inp: Union[Tensor, SymbolVar]) -> Union[Tensor, SymbolVar]:
...
@@ -235,7 +234,7 @@ def zeros_like(inp: Union[Tensor, SymbolVar]) -> Union[Tensor, SymbolVar]:
return
full_like
(
inp
,
0.0
)
return
full_like
(
inp
,
0.0
)
def
ones_like
(
inp
:
Union
[
Tensor
,
SymbolVar
])
->
Union
[
Tensor
,
SymbolVar
]
:
def
ones_like
(
inp
:
Tensor
)
->
Tensor
:
r
"""Returns a tensor filled with ones with the same shape and data type as input tensor.
r
"""Returns a tensor filled with ones with the same shape and data type as input tensor.
Args:
Args:
...
@@ -253,9 +252,7 @@ def ones_like(inp: Union[Tensor, SymbolVar]) -> Union[Tensor, SymbolVar]:
...
@@ -253,9 +252,7 @@ def ones_like(inp: Union[Tensor, SymbolVar]) -> Union[Tensor, SymbolVar]:
return
full_like
(
inp
,
1.0
)
return
full_like
(
inp
,
1.0
)
def
full_like
(
def
full_like
(
inp
:
Tensor
,
value
:
Union
[
int
,
float
])
->
Tensor
:
inp
:
Union
[
Tensor
,
SymbolVar
],
value
:
Union
[
int
,
float
]
)
->
Union
[
Tensor
,
SymbolVar
]:
r
"""Returns a tensor filled with given value with the same shape as input tensor.
r
"""Returns a tensor filled with given value with the same shape as input tensor.
Args:
Args:
...
@@ -272,7 +269,7 @@ def full_like(
...
@@ -272,7 +269,7 @@ def full_like(
Tensor([[2 2 2]
Tensor([[2 2 2]
[2 2 2]], dtype=int32, device=xpux:0)
[2 2 2]], dtype=int32, device=xpux:0)
"""
"""
x
=
Const
(
value
,
inp
.
dtype
,
inp
.
device
,
inp
)
x
=
Const
(
value
,
inp
.
dtype
,
inp
.
device
)
if
inp
.
ndim
==
0
:
if
inp
.
ndim
==
0
:
return
x
return
x
return
broadcast_to
(
x
,
inp
.
shape
)
return
broadcast_to
(
x
,
inp
.
shape
)
...
@@ -668,9 +665,9 @@ def cond_take(mask: Tensor, x: Tensor) -> Tensor:
...
@@ -668,9 +665,9 @@ def cond_take(mask: Tensor, x: Tensor) -> Tensor:
>>> print(v.numpy(), index.numpy())
>>> print(v.numpy(), index.numpy())
[1. 4.] [0 3]
[1. 4.] [0 3]
"""
"""
if
not
isinstance
(
x
,
(
Tensor
,
SymbolVar
)
):
if
not
isinstance
(
x
,
Tensor
):
raise
TypeError
(
"input must be a tensor"
)
raise
TypeError
(
"input must be a tensor"
)
if
not
isinstance
(
mask
,
(
Tensor
,
SymbolVar
)
):
if
not
isinstance
(
mask
,
Tensor
):
raise
TypeError
(
"mask must be a tensor"
)
raise
TypeError
(
"mask must be a tensor"
)
if
mask
.
dtype
!=
np
.
bool_
:
if
mask
.
dtype
!=
np
.
bool_
:
raise
ValueError
(
"mask must be bool"
)
raise
ValueError
(
"mask must be bool"
)
...
@@ -843,15 +840,11 @@ def linspace(
...
@@ -843,15 +840,11 @@ def linspace(
if
not
(
cur_device
is
None
or
device
==
cur_device
):
if
not
(
cur_device
is
None
or
device
==
cur_device
):
raise
(
"ambiguous device for linspace opr"
)
raise
(
"ambiguous device for linspace opr"
)
is_symbolvar
=
list
(
isinstance
(
x
,
SymbolVar
)
for
x
in
[
start
,
stop
,
num
])
if
not
isinstance
(
start
,
Tensor
):
if
any
(
is_symbolvar
)
and
not
all
(
is_symbolvar
):
raise
TypeError
(
"start, stop and num should all be VarNode or none of them"
)
if
not
isinstance
(
start
,
(
Tensor
,
SymbolVar
)):
start
=
Tensor
(
start
,
device
=
device
)
start
=
Tensor
(
start
,
device
=
device
)
if
not
isinstance
(
stop
,
(
Tensor
,
SymbolVar
)
):
if
not
isinstance
(
stop
,
Tensor
):
stop
=
Tensor
(
stop
,
device
=
device
)
stop
=
Tensor
(
stop
,
device
=
device
)
if
not
isinstance
(
num
,
(
Tensor
,
SymbolVar
)
):
if
not
isinstance
(
num
,
Tensor
):
num
=
Tensor
(
num
,
device
=
device
)
num
=
Tensor
(
num
,
device
=
device
)
op
=
builtin
.
Linspace
(
comp_node
=
device
)
op
=
builtin
.
Linspace
(
comp_node
=
device
)
...
@@ -901,8 +894,11 @@ def arange(
...
@@ -901,8 +894,11 @@ def arange(
if
stop
is
None
:
if
stop
is
None
:
start
,
stop
=
0
,
start
start
,
stop
=
0
,
start
if
not
isinstance
(
start
,
Tensor
):
start
=
Tensor
(
start
,
dtype
=
"float32"
)
start
=
Tensor
(
start
,
dtype
=
"float32"
)
if
not
isinstance
(
stop
,
Tensor
):
stop
=
Tensor
(
stop
,
dtype
=
"float32"
)
stop
=
Tensor
(
stop
,
dtype
=
"float32"
)
if
not
isinstance
(
step
,
Tensor
):
step
=
Tensor
(
step
,
dtype
=
"float32"
)
step
=
Tensor
(
step
,
dtype
=
"float32"
)
num
=
ceil
((
stop
-
start
)
/
step
)
num
=
ceil
((
stop
-
start
)
/
step
)
...
...
imperative/python/megengine/functional/tensor_cache.py
浏览文件 @
a926878c
...
@@ -7,11 +7,11 @@ small_tensor_cache = {}
...
@@ -7,11 +7,11 @@ small_tensor_cache = {}
def
_get_scalar_tensor_with_value
(
value
,
dtype
=
None
,
device
=
None
):
def
_get_scalar_tensor_with_value
(
value
,
dtype
=
None
,
device
=
None
):
global
small_tensor_cache
global
small_tensor_cache
if
is_tracing
():
if
is_tracing
():
ret
=
Const
(
value
,
dtype
,
device
,
None
)
ret
=
Const
(
value
,
dtype
,
device
)
else
:
else
:
cache_key
=
(
value
,
dtype
,
device
)
cache_key
=
(
value
,
dtype
,
device
)
if
cache_key
not
in
small_tensor_cache
:
if
cache_key
not
in
small_tensor_cache
:
ret
=
Const
(
value
,
dtype
,
device
,
None
)
ret
=
Const
(
value
,
dtype
,
device
)
small_tensor_cache
[
cache_key
]
=
ret
small_tensor_cache
[
cache_key
]
=
ret
else
:
else
:
ret
=
small_tensor_cache
[
cache_key
]
ret
=
small_tensor_cache
[
cache_key
]
...
...
imperative/python/megengine/tensor.py
浏览文件 @
a926878c
...
@@ -154,6 +154,8 @@ class Tensor(_Tensor, ArrayMethodMixin):
...
@@ -154,6 +154,8 @@ class Tensor(_Tensor, ArrayMethodMixin):
@
name
.
setter
@
name
.
setter
def
name
(
self
,
name
):
def
name
(
self
,
name
):
self
.
_custom_name
=
name
self
.
_custom_name
=
name
if
name
==
None
:
name
=
""
self
.
_name
=
self
.
_prefix
+
"."
+
name
if
self
.
_prefix
else
name
self
.
_name
=
self
.
_prefix
+
"."
+
name
if
self
.
_prefix
else
name
self
.
_set_name
(
self
.
_name
)
self
.
_set_name
(
self
.
_name
)
...
...
imperative/python/megengine/traced_module/expr.py
浏览文件 @
a926878c
...
@@ -756,7 +756,7 @@ class Constant(Expr):
...
@@ -756,7 +756,7 @@ class Constant(Expr):
def
interpret
(
self
,
*
inputs
):
def
interpret
(
self
,
*
inputs
):
if
isinstance
(
self
.
value
,
RawTensor
):
if
isinstance
(
self
.
value
,
RawTensor
):
return
(
Const
(
self
.
value
.
numpy
(),
None
,
None
,
None
),)
return
(
Const
(
self
.
value
.
numpy
(),
None
,
None
),)
return
(
self
.
value
,)
return
(
self
.
value
,)
def
__repr__
(
self
):
def
__repr__
(
self
):
...
...
imperative/python/megengine/utils/network.py
浏览文件 @
a926878c
...
@@ -395,7 +395,7 @@ class Network:
...
@@ -395,7 +395,7 @@ class Network:
for
ind
,
var
in
enumerate
(
opr
.
outputs
):
for
ind
,
var
in
enumerate
(
opr
.
outputs
):
var
.
owner
=
repl_dict
[
opr
]
var
.
owner
=
repl_dict
[
opr
]
var
.
__dict__
.
update
(
repl_dict
[
opr
].
outputs
[
ind
].
__dict__
)
var
.
__dict__
.
update
(
repl_dict
[
opr
].
outputs
[
ind
].
__dict__
)
var
.
var
=
repl_dict
[
opr
].
outputs
[
ind
].
var
var
.
_reset_var
(
repl_dict
[
opr
].
outputs
[
ind
].
var
)
repl_dict
[
opr
].
outputs
=
opr
.
outputs
repl_dict
[
opr
].
outputs
=
opr
.
outputs
self
.
_compile
()
self
.
_compile
()
...
...
imperative/python/megengine/utils/network_node.py
浏览文件 @
a926878c
...
@@ -6,11 +6,11 @@ from typing import Sequence
...
@@ -6,11 +6,11 @@ from typing import Sequence
import
numpy
as
np
import
numpy
as
np
from
..core
import
_imperative_rt
as
rt
from
..core
import
_imperative_rt
as
rt
from
..core._imperative_rt.core2
import
SymbolVar
,
apply
from
..core._imperative_rt.core2
import
apply
,
set_py_varnode_type
from
..core._trace_option
import
use_symbolic_shape
from
..core._trace_option
import
use_symbolic_shape
from
..core._wrap
import
Device
from
..core._wrap
import
Device
from
..core.ops
import
builtin
from
..core.ops
import
builtin
from
..
core.tensor.array_method
import
ArrayMethodMixin
from
..
tensor
import
Tensor
from
.comp_graph_tools
import
replace_vars
from
.comp_graph_tools
import
replace_vars
from
.module_stats
import
(
from
.module_stats
import
(
preprocess_receptive_field
,
preprocess_receptive_field
,
...
@@ -23,26 +23,72 @@ class NetworkNode:
...
@@ -23,26 +23,72 @@ class NetworkNode:
pass
pass
class
VarNodeMeta
(
type
(
SymbolVar
),
type
(
ArrayMethodMixin
)):
class
VarNode
(
NetworkNode
,
Tensor
):
pass
_users
=
None
_owner
=
None
_name
=
None
_id
=
None
def
__new__
(
cls
,
var
,
*
,
owner_opr
=
None
,
name
=
None
):
obj
=
Tensor
.
__new__
(
cls
,
var
)
return
obj
class
VarNode
(
NetworkNode
,
SymbolVar
,
ArrayMethodMixin
,
metaclass
=
VarNodeMeta
):
def
__init__
(
self
,
var
,
*
,
owner_opr
=
None
,
name
=
None
):
def
__init__
(
self
,
var
=
None
,
*
,
owner_opr
=
None
,
name
=
None
):
self
.
_owner
=
owner_opr
SymbolVar
.
__init__
(
self
,
var
)
self
.
users
=
[]
# List[OpNode]
self
.
owner
=
owner_opr
self
.
name
=
name
self
.
name
=
name
self
.
id
=
id
(
self
)
@
classmethod
@
classmethod
def
load
(
cls
,
sym_var
,
owner_opr
):
def
load
(
cls
,
sym_var
,
owner_opr
):
obj
=
cls
()
obj
=
cls
(
sym_var
)
obj
.
var
=
sym_var
# mgb varnode
obj
.
var
=
sym_var
# mgb varnode
obj
.
name
=
sym_var
.
name
obj
.
name
=
sym_var
.
name
obj
.
owner
=
owner_opr
obj
.
owner
=
owner_opr
return
obj
return
obj
@
property
def
users
(
self
):
if
self
.
_users
is
None
:
self
.
_users
=
[]
return
self
.
_users
@
property
def
owner
(
self
):
return
self
.
_owner
@
owner
.
setter
def
owner
(
self
,
owner
):
self
.
_owner
=
owner
@
property
def
id
(
self
):
if
self
.
_id
is
None
:
self
.
_id
=
id
(
self
)
return
self
.
_id
@
property
def
var
(
self
):
return
super
().
var
()
@
var
.
setter
def
var
(
self
,
var
):
self
.
_reset
(
var
)
def
_reset
(
self
,
other
):
if
not
isinstance
(
other
,
Tensor
):
other
=
VarNode
(
other
)
super
().
_reset
(
other
)
self
.
owner
=
None
def
_reset_var
(
self
,
var
):
origin_owner
=
self
.
owner
self
.
var
=
var
self
.
var
.
name
=
self
.
name
self
.
owner
=
origin_owner
@
property
def
graph
(
self
):
return
super
().
graph
()
def
_get_var_shape
(
self
,
axis
=
None
):
def
_get_var_shape
(
self
,
axis
=
None
):
opdef
=
(
opdef
=
(
builtin
.
GetVarShape
()
if
axis
is
None
else
builtin
.
GetVarShape
(
axis
=
axis
)
builtin
.
GetVarShape
()
if
axis
is
None
else
builtin
.
GetVarShape
(
axis
=
axis
)
...
@@ -77,14 +123,6 @@ class VarNode(NetworkNode, SymbolVar, ArrayMethodMixin, metaclass=VarNodeMeta):
...
@@ -77,14 +123,6 @@ class VarNode(NetworkNode, SymbolVar, ArrayMethodMixin, metaclass=VarNodeMeta):
return
rst
return
rst
return
self
.
_get_var_shape
()
if
self
.
var
else
None
return
self
.
_get_var_shape
()
if
self
.
var
else
None
@
property
def
dtype
(
self
):
return
self
.
var
.
dtype
if
self
.
var
else
None
@
property
def
ndim
(
self
):
return
super
().
ndim
def
__bool__
(
self
):
def
__bool__
(
self
):
return
False
return
False
...
@@ -92,27 +130,11 @@ class VarNode(NetworkNode, SymbolVar, ArrayMethodMixin, metaclass=VarNodeMeta):
...
@@ -92,27 +130,11 @@ class VarNode(NetworkNode, SymbolVar, ArrayMethodMixin, metaclass=VarNodeMeta):
__int__
=
None
__int__
=
None
__float__
=
None
__float__
=
None
__complex__
=
None
__complex__
=
None
__repr__
=
lambda
self
:
"VarNode:"
+
self
.
name
def
__hash__
(
self
):
def
__hash__
(
self
):
return
id
(
self
)
return
id
(
self
)
def
numpy
(
self
):
return
super
().
numpy
()
def
_reset
(
self
,
other
):
if
not
isinstance
(
other
,
VarNode
):
assert
self
.
graph
,
"VarNode _reset must have graph"
node
=
ImmutableTensor
(
other
,
graph
=
self
.
graph
)
node
.
compile
(
self
.
graph
)
other
=
node
.
outputs
[
0
]
if
self
.
owner
is
not
None
:
idx
=
self
.
owner
.
outputs
.
index
(
self
)
self
.
owner
.
outputs
[
idx
]
=
VarNode
(
self
.
var
,
owner_opr
=
self
.
owner
,
name
=
self
.
var
.
name
)
self
.
var
=
other
.
var
self
.
owner
=
None
def
set_owner_opr
(
self
,
owner_opr
):
def
set_owner_opr
(
self
,
owner_opr
):
self
.
owner
=
owner_opr
self
.
owner
=
owner_opr
...
@@ -158,8 +180,7 @@ class OpNode(NetworkNode):
...
@@ -158,8 +180,7 @@ class OpNode(NetworkNode):
assert
len
(
outputs
)
==
len
(
self
.
outputs
)
assert
len
(
outputs
)
==
len
(
self
.
outputs
)
self
.
_opr
=
outputs
[
0
].
owner
self
.
_opr
=
outputs
[
0
].
owner
for
i
in
range
(
len
(
self
.
outputs
)):
for
i
in
range
(
len
(
self
.
outputs
)):
self
.
outputs
[
i
].
var
=
outputs
[
i
]
self
.
outputs
[
i
].
_reset_var
(
outputs
[
i
])
self
.
outputs
[
i
].
var
.
name
=
self
.
outputs
[
i
].
name
assert
self
.
outputs
[
i
].
owner
is
self
assert
self
.
outputs
[
i
].
owner
is
self
def
add_inp_var
(
self
,
x
):
def
add_inp_var
(
self
,
x
):
...
@@ -214,8 +235,9 @@ class Host2DeviceCopy(OpNode):
...
@@ -214,8 +235,9 @@ class Host2DeviceCopy(OpNode):
outputs
=
rt
.
make_h2d
(
graph
,
self
.
device
,
self
.
dtype
,
self
.
shape
,
self
.
name
)
outputs
=
rt
.
make_h2d
(
graph
,
self
.
device
,
self
.
dtype
,
self
.
shape
,
self
.
name
)
self
.
_opr
=
outputs
.
owner
self
.
_opr
=
outputs
.
owner
if
len
(
self
.
outputs
)
==
0
:
if
len
(
self
.
outputs
)
==
0
:
self
.
outputs
.
append
(
VarNode
(
owner_opr
=
self
,
name
=
self
.
name
))
self
.
outputs
.
append
(
VarNode
(
outputs
,
owner_opr
=
self
,
name
=
self
.
name
))
self
.
outputs
[
0
].
var
=
outputs
else
:
self
.
outputs
[
0
].
_reset_var
(
outputs
)
assert
self
.
outputs
[
0
].
owner
is
self
assert
self
.
outputs
[
0
].
owner
is
self
...
@@ -262,8 +284,9 @@ class ConstOpBase(OpNode):
...
@@ -262,8 +284,9 @@ class ConstOpBase(OpNode):
data
=
data
.
astype
(
np
.
int32
)
data
=
data
.
astype
(
np
.
int32
)
varnode
=
type
(
self
).
rt_fun
(
self
.
graph
,
data
,
cn
,
data
.
dtype
,
self
.
name
)
varnode
=
type
(
self
).
rt_fun
(
self
.
graph
,
data
,
cn
,
data
.
dtype
,
self
.
name
)
if
len
(
self
.
outputs
)
==
0
:
if
len
(
self
.
outputs
)
==
0
:
self
.
outputs
.
append
(
VarNode
(
owner_opr
=
self
,
name
=
self
.
name
))
self
.
outputs
.
append
(
VarNode
(
varnode
,
owner_opr
=
self
,
name
=
self
.
name
))
self
.
outputs
[
0
].
var
=
varnode
else
:
self
.
outputs
[
0
].
_reset_var
(
varnode
)
self
.
_opr
=
varnode
.
owner
self
.
_opr
=
varnode
.
owner
@
classmethod
@
classmethod
...
@@ -313,7 +336,7 @@ class ReadOnlyOpNode(OpNode):
...
@@ -313,7 +336,7 @@ class ReadOnlyOpNode(OpNode):
if
bool
(
repl_dict
):
if
bool
(
repl_dict
):
out_vars
=
replace_vars
(
self
.
_opr
.
outputs
,
repl_dict
)
out_vars
=
replace_vars
(
self
.
_opr
.
outputs
,
repl_dict
)
for
ind
,
o
in
enumerate
(
self
.
outputs
):
for
ind
,
o
in
enumerate
(
self
.
outputs
):
o
.
var
=
out_vars
[
ind
]
o
.
_reset_var
(
out_vars
[
ind
])
class
Elemwise
(
OpNode
):
class
Elemwise
(
OpNode
):
...
@@ -785,3 +808,6 @@ class AssertEqual(OpNode):
...
@@ -785,3 +808,6 @@ class AssertEqual(OpNode):
class
CvtColorForward
(
OpNode
):
class
CvtColorForward
(
OpNode
):
type
=
"CvtColor"
type
=
"CvtColor"
opdef
=
builtin
.
CvtColor
opdef
=
builtin
.
CvtColor
set_py_varnode_type
(
VarNode
)
imperative/python/src/graph_rt.cpp
浏览文件 @
a926878c
...
@@ -114,6 +114,8 @@ void _set_priority_to_id(const std::vector<mgb::cg::VarNode*>& dest_vars) {
...
@@ -114,6 +114,8 @@ void _set_priority_to_id(const std::vector<mgb::cg::VarNode*>& dest_vars) {
}
}
}
}
py
::
object
Py_Varnode
=
py
::
none
();
void
init_graph_rt
(
py
::
module
m
)
{
void
init_graph_rt
(
py
::
module
m
)
{
static
const
std
::
unique_ptr
<
mgb
::
OprFootprint
>
_imperative_sm_opr_footprint_ptr
{
static
const
std
::
unique_ptr
<
mgb
::
OprFootprint
>
_imperative_sm_opr_footprint_ptr
{
std
::
make_unique
<
mgb
::
OprFootprint
>
()};
std
::
make_unique
<
mgb
::
OprFootprint
>
()};
...
@@ -124,6 +126,7 @@ void init_graph_rt(py::module m) {
...
@@ -124,6 +126,7 @@ void init_graph_rt(py::module m) {
def_rendezvous
<
TensorAttr
>
(
m
,
"TensorAttrRendezvous"
);
def_rendezvous
<
TensorAttr
>
(
m
,
"TensorAttrRendezvous"
);
Py_Varnode
=
py
::
class_
<
cg
::
VarNode
,
GraphNodePtr
<
cg
::
VarNode
>>
(
m
,
"VarNode"
)
py
::
class_
<
cg
::
VarNode
,
GraphNodePtr
<
cg
::
VarNode
>>
(
m
,
"VarNode"
)
.
def_property_readonly
(
.
def_property_readonly
(
"owner"
,
[](
cg
::
VarNode
*
v
)
{
return
v
->
owner_opr
();
})
"owner"
,
[](
cg
::
VarNode
*
v
)
{
return
v
->
owner_opr
();
})
...
@@ -132,7 +135,8 @@ void init_graph_rt(py::module m) {
...
@@ -132,7 +135,8 @@ void init_graph_rt(py::module m) {
.
def_property
(
.
def_property
(
"name"
,
py
::
overload_cast
<>
(
&
VarNode
::
name
,
py
::
const_
),
"name"
,
py
::
overload_cast
<>
(
&
VarNode
::
name
,
py
::
const_
),
py
::
overload_cast
<
std
::
string
>
(
&
VarNode
::
name
))
py
::
overload_cast
<
std
::
string
>
(
&
VarNode
::
name
))
.
def_property_readonly
(
"dtype"
,
[](
cg
::
VarNode
*
v
)
{
return
v
->
dtype
();
})
.
def_property_readonly
(
"dtype"
,
[](
cg
::
VarNode
*
v
)
{
return
v
->
dtype
();
})
.
def_property_readonly
(
.
def_property_readonly
(
"comp_node"
,
[](
cg
::
VarNode
*
v
)
{
return
v
->
comp_node
();
})
"comp_node"
,
[](
cg
::
VarNode
*
v
)
{
return
v
->
comp_node
();
})
.
def_property_readonly
(
.
def_property_readonly
(
...
@@ -147,7 +151,8 @@ void init_graph_rt(py::module m) {
...
@@ -147,7 +151,8 @@ void init_graph_rt(py::module m) {
auto
&&
mgr
=
v
->
owner_graph
()
->
static_infer_manager
();
auto
&&
mgr
=
v
->
owner_graph
()
->
static_infer_manager
();
auto
&&
type
=
mgr
.
get_infer_type
(
v
);
auto
&&
type
=
mgr
.
get_infer_type
(
v
);
using
InferType
=
cg
::
static_infer
::
InferType
;
using
InferType
=
cg
::
static_infer
::
InferType
;
if
(
!
(
type
.
value
&
(
InferType
::
CONST
|
InferType
::
RT_STATIC
)))
{
if
(
!
(
type
.
value
&
(
InferType
::
CONST
|
InferType
::
RT_STATIC
)))
{
return
py
::
none
();
return
py
::
none
();
}
}
auto
*
val
=
mgr
.
infer_value_fallible
(
v
);
auto
*
val
=
mgr
.
infer_value_fallible
(
v
);
...
@@ -156,7 +161,8 @@ void init_graph_rt(py::module m) {
...
@@ -156,7 +161,8 @@ void init_graph_rt(py::module m) {
}
}
return
py
::
cast
(
*
val
).
attr
(
"numpy"
)();
return
py
::
cast
(
*
val
).
attr
(
"numpy"
)();
})
})
.
def_property_readonly
(
"id"
,
[](
cg
::
VarNode
*
v
)
{
return
(
v
->
id
());
})
.
def_property_readonly
(
"id"
,
[](
cg
::
VarNode
*
v
)
{
return
(
v
->
id
());
})
.
def
(
"__repr__"
,
[](
cg
::
VarNode
*
v
)
{
return
"Var:"
+
v
->
name
();
});
.
def
(
"__repr__"
,
[](
cg
::
VarNode
*
v
)
{
return
"Var:"
+
v
->
name
();
});
py
::
class_
<
cg
::
OperatorNodeBase
,
GraphNodePtr
<
cg
::
OperatorNodeBase
>>
(
py
::
class_
<
cg
::
OperatorNodeBase
,
GraphNodePtr
<
cg
::
OperatorNodeBase
>>
(
...
...
imperative/python/src/graph_rt.h
浏览文件 @
a926878c
...
@@ -8,6 +8,9 @@
...
@@ -8,6 +8,9 @@
#include "megbrain/graph.h"
#include "megbrain/graph.h"
#include "megbrain/plugin/opr_footprint.h"
#include "megbrain/plugin/opr_footprint.h"
namespace
py
=
pybind11
;
extern
py
::
object
Py_Varnode
;
template
<
typename
T
>
template
<
typename
T
>
class
GraphNodePtr
{
class
GraphNodePtr
{
std
::
shared_ptr
<
mgb
::
cg
::
ComputingGraph
>
m_graph
;
std
::
shared_ptr
<
mgb
::
cg
::
ComputingGraph
>
m_graph
;
...
...
imperative/python/src/tensor.cpp
浏览文件 @
a926878c
...
@@ -48,58 +48,11 @@ namespace mgb::imperative::python {
...
@@ -48,58 +48,11 @@ namespace mgb::imperative::python {
namespace
{
namespace
{
WeakKeyMap
<
ValueWeakRef
,
py
::
object
>
module_trace_info_map
;
WeakKeyMap
<
ValueWeakRef
,
py
::
object
>
module_trace_info_map
;
struct
SymbolVarContext
{
TransformationContext
context
;
std
::
shared_ptr
<
SymbolTransformation
>
symbol_tsf
;
std
::
shared_ptr
<
ScalarTransformation
>
scalar_tsf
;
std
::
shared_ptr
<
DTypePromoteTransformation
>
dtype_promote_tsf
;
std
::
shared_ptr
<
DimExpansionTransformation
>
dim_expansion_tsf
;
SymbolVarContext
(
cg
::
ComputingGraph
*
graph
)
{
symbol_tsf
=
std
::
make_shared
<
SymbolTransformation
>
(
graph
);
scalar_tsf
=
std
::
make_shared
<
ScalarTransformation
>
();
dtype_promote_tsf
=
std
::
make_shared
<
DTypePromoteTransformation
>
();
dim_expansion_tsf
=
std
::
make_shared
<
DimExpansionTransformation
>
();
Transformation
::
swap_context
(
context
);
}
void
init
()
{
symbol_tsf
->
register_at
(
Transformation
::
top
());
scalar_tsf
->
register_at
(
Transformation
::
top
());
dtype_promote_tsf
->
register_at
(
Transformation
::
top
());
dim_expansion_tsf
->
register_at
(
Transformation
::
top
());
}
ValueRef
symvar2val
(
py
::
handle
py_symbol_var
)
{
auto
*
symbol_var
=
py_symbol_var
.
cast
<
PySymbolVar
*>
();
ValueRef
value
=
symbol_tsf
->
value_type
().
make
(
symbol_var
->
m_node
);
if
(
symbol_var
->
is_scalar
)
{
value
=
scalar_tsf
->
value_type
().
make
(
value
);
}
return
value
;
}
py
::
object
val2symvar
(
py
::
handle
typeobj
,
ValueRef
value
)
{
bool
is_scalar
=
false
;
if
(
auto
*
scalar_value
=
value
.
as
(
scalar_tsf
->
value_type
()))
{
value
=
scalar_value
->
value
();
is_scalar
=
true
;
}
auto
*
node
=
value
.
cast
(
symbol_tsf
->
value_type
()).
node
();
auto
py_symbol_var
=
typeobj
(
pybind11
::
cast
(
node
,
pybind11
::
return_value_policy
::
automatic
));
py_symbol_var
.
cast
<
PySymbolVar
*>
()
->
is_scalar
=
is_scalar
;
return
py_symbol_var
;
}
~
SymbolVarContext
()
{
Transformation
::
swap_context
(
context
);
}
};
}
// namespace
}
// namespace
interpreter
::
Interpreter
::
Channel
*
interpreter_for_py
=
nullptr
;
interpreter
::
Interpreter
::
Channel
*
interpreter_for_py
=
nullptr
;
PyTypeObject
*
py_tensor_type
=
nullptr
;
PyTypeObject
*
py_tensor_type
=
nullptr
;
PyTypeObject
*
py_varnode_type
=
nullptr
;
pybind11
::
handle
py_device_type
=
nullptr
;
pybind11
::
handle
py_device_type
=
nullptr
;
PyObject
*
cpp_use_symbolic_shape
;
PyObject
*
cpp_use_symbolic_shape
;
...
@@ -136,22 +89,6 @@ PyObject* py_apply(
...
@@ -136,22 +89,6 @@ PyObject* py_apply(
auto
op
=
py
::
handle
(
py_op
).
cast
<
std
::
shared_ptr
<
OpDef
>>
();
auto
op
=
py
::
handle
(
py_op
).
cast
<
std
::
shared_ptr
<
OpDef
>>
();
SmallVector
<
ValueRef
,
8
>
tensors
(
nargs
);
SmallVector
<
ValueRef
,
8
>
tensors
(
nargs
);
SmallVector
<
bool
,
8
>
is_symbol_var
(
nargs
,
false
);
ComputingGraph
*
cg
=
nullptr
;
for
(
size_t
i
=
0
;
i
<
nargs
;
++
i
)
{
if
((
!
TensorWrapper
::
try_cast
(
args
[
i
]))
&&
py
::
isinstance
<
PySymbolVar
>
(
py
::
handle
(
args
[
i
])))
{
is_symbol_var
[
i
]
=
true
;
ComputingGraph
*
cur_cg
=
py
::
handle
(
args
[
i
]).
cast
<
PySymbolVar
*>
()
->
m_node
->
owner_graph
();
if
(
cg
==
nullptr
)
{
cg
=
cur_cg
;
}
else
{
mgb_assert
(
cg
==
cur_cg
);
}
}
}
mgb
::
CompNode
target_cn
;
mgb
::
CompNode
target_cn
;
mgb
::
DType
target_dtype
;
mgb
::
DType
target_dtype
;
...
@@ -174,35 +111,11 @@ PyObject* py_apply(
...
@@ -174,35 +111,11 @@ PyObject* py_apply(
}
}
};
};
if
(
cg
!=
nullptr
)
{
bool
is_varnode_apply
=
false
;
// swap to a special context to reuse scalar handle
size_t
symbol_var_idx
=
8
;
SymbolVarContext
context
(
cg
);
context
.
init
();
for
(
size_t
i
=
0
;
i
<
nargs
;
++
i
)
{
for
(
size_t
i
=
0
;
i
<
nargs
;
++
i
)
{
if
(
is_symbol_var
[
i
])
{
if
(
PyObject_TypeCheck
(
args
[
i
],
py_varnode_type
))
{
symbol_var_idx
=
i
;
is_varnode_apply
=
true
;
tensors
[
i
]
=
context
.
symvar2val
(
args
[
i
]);
}
else
if
(
DTypePromoteCfg
::
convert_input_enabled
&&
op
->
same_type
<
Elemwise
>
())
{
tensors
[
i
]
=
convert_pyinput_to_tensor
(
i
);
}
else
{
PyErr_SetString
(
PyExc_TypeError
,
"py_apply expects tensor as inputs"
);
return
nullptr
;
}
}
auto
outputs
=
imperative
::
apply
(
*
op
,
tensors
);
auto
ret
=
pybind11
::
tuple
(
outputs
.
size
());
auto
typeobj
=
py
::
handle
(
args
[
symbol_var_idx
]).
get_type
();
for
(
size_t
i
=
0
;
i
<
outputs
.
size
();
++
i
)
{
ret
[
i
]
=
context
.
val2symvar
(
typeobj
,
outputs
[
i
]);
}
}
return
ret
.
release
().
ptr
();
}
for
(
size_t
i
=
0
;
i
<
nargs
;
++
i
)
{
if
(
TensorWrapper
*
tw
=
TensorWrapper
::
try_cast
(
args
[
i
]))
{
if
(
TensorWrapper
*
tw
=
TensorWrapper
::
try_cast
(
args
[
i
]))
{
tensors
[
i
]
=
tw
->
m_tensor
->
data
();
tensors
[
i
]
=
tw
->
m_tensor
->
data
();
}
else
if
(
}
else
if
(
...
@@ -218,8 +131,9 @@ PyObject* py_apply(
...
@@ -218,8 +131,9 @@ PyObject* py_apply(
auto
outputs
=
[
&
]
{
return
imperative
::
apply
(
*
op
,
tensors
);
}();
auto
outputs
=
[
&
]
{
return
imperative
::
apply
(
*
op
,
tensors
);
}();
size_t
nout
=
outputs
.
size
();
size_t
nout
=
outputs
.
size
();
auto
ret
=
py
::
tuple
(
nout
);
auto
ret
=
py
::
tuple
(
nout
);
PyTypeObject
*
py_type
=
is_varnode_apply
?
py_varnode_type
:
py_tensor_type
;
for
(
size_t
i
=
0
;
i
<
nout
;
++
i
)
{
for
(
size_t
i
=
0
;
i
<
nout
;
++
i
)
{
ret
[
i
]
=
TensorWrapper
::
make
(
py_t
ensor_t
ype
,
std
::
move
(
outputs
[
i
]));
ret
[
i
]
=
TensorWrapper
::
make
(
py_type
,
std
::
move
(
outputs
[
i
]));
}
}
return
ret
.
release
().
ptr
();
return
ret
.
release
().
ptr
();
}
}
...
@@ -622,9 +536,17 @@ TensorWrapper::TensorWrapper(PyObject* args, PyObject* kwargs) {
...
@@ -622,9 +536,17 @@ TensorWrapper::TensorWrapper(PyObject* args, PyObject* kwargs) {
CreateTensor
::
Kind
kind
=
is_const
?
CreateTensor
::
Const
CreateTensor
::
Kind
kind
=
is_const
?
CreateTensor
::
Const
:
no_cache
?
CreateTensor
::
Unique
:
no_cache
?
CreateTensor
::
Unique
:
CreateTensor
::
Common
;
:
CreateTensor
::
Common
;
ValueRef
val
;
if
(
py
::
isinstance
(
data
,
Py_Varnode
))
{
cg
::
VarNode
*
m_node
=
py
::
handle
(
data
).
cast
<
cg
::
VarNode
*>
();
val
=
imperative
::
apply
(
CreateNode
(
m_node
),
Span
<
ValueRef
>
(
nullptr
,
nullptr
))[
0
];
}
else
{
auto
&&
hval
=
pyobj2hval
(
data
,
cn
,
dtype
);
auto
&&
hval
=
pyobj2hval
(
data
,
cn
,
dtype
);
auto
val
=
imperative
::
apply
(
val
=
imperative
::
apply
(
CreateTensor
(
kind
,
cn
,
hval
.
dtype
,
hval
.
shape
),
hval
.
storage
)[
0
];
CreateTensor
(
kind
,
cn
,
hval
.
dtype
,
hval
.
shape
),
hval
.
storage
)[
0
];
}
m_tensor
.
emplace
(
val
);
m_tensor
.
emplace
(
val
);
}
}
...
@@ -734,6 +656,20 @@ PyObject* TensorWrapper::isscalar() {
...
@@ -734,6 +656,20 @@ PyObject* TensorWrapper::isscalar() {
}
}
}
}
PyObject
*
TensorWrapper
::
_var
()
{
TypedValueRef
<
NodeValue
>
value
=
imperative
::
apply
(
GetVarVal
(),
m_tensor
->
data
())[
0
].
as_ref
<
NodeValue
>
();
auto
*
node
=
value
->
node
();
return
py
::
cast
(
node
).
release
().
ptr
();
}
PyObject
*
TensorWrapper
::
_graph
()
{
TypedValueRef
<
NodeValue
>
value
=
imperative
::
apply
(
GetVarVal
(),
m_tensor
->
data
())[
0
].
as_ref
<
NodeValue
>
();
auto
*
graph
=
value
->
graph
();
return
py
::
cast
(
graph
).
release
().
ptr
();
}
struct
TensorWeakRef
{
struct
TensorWeakRef
{
ValueWeakRef
data
;
ValueWeakRef
data
;
...
@@ -807,6 +743,10 @@ void init_tensor(py::module m) {
...
@@ -807,6 +743,10 @@ void init_tensor(py::module m) {
.
register_at
<
Segment
::
Scalar
>
(
.
register_at
<
Segment
::
Scalar
>
(
std
::
make_shared
<
ScalarTransformation
>
())
std
::
make_shared
<
ScalarTransformation
>
())
.
release
());
.
release
());
MGB_MARK_USED_VAR
(
transformations
.
register_at
<
Segment
::
Symbol
>
(
std
::
make_shared
<
SymbolTransformation
>
())
.
release
());
MGB_MARK_USED_VAR
(
transformations
MGB_MARK_USED_VAR
(
transformations
.
register_at
<
Segment
::
DTypePromote
>
(
.
register_at
<
Segment
::
DTypePromote
>
(
std
::
make_shared
<
DTypePromoteTransformation
>
())
std
::
make_shared
<
DTypePromoteTransformation
>
())
...
@@ -863,6 +803,8 @@ void init_tensor(py::module m) {
...
@@ -863,6 +803,8 @@ void init_tensor(py::module m) {
.
def
<&
TensorWrapper
::
_detail
>
(
"_detail"
)
.
def
<&
TensorWrapper
::
_detail
>
(
"_detail"
)
.
def
<&
TensorWrapper
::
_set_name
>
(
"_set_name"
)
.
def
<&
TensorWrapper
::
_set_name
>
(
"_set_name"
)
.
def
<&
TensorWrapper
::
_watch
>
(
"_watch"
)
.
def
<&
TensorWrapper
::
_watch
>
(
"_watch"
)
.
def
<&
TensorWrapper
::
_var
>
(
"var"
)
.
def
<&
TensorWrapper
::
_graph
>
(
"graph"
)
.
def_getset
<
.
def_getset
<
&
TensorWrapper
::
module_trace_info
,
&
TensorWrapper
::
module_trace_info
,
&
TensorWrapper
::
set_module_trace_info
>
(
"_NodeMixin__node"
)
&
TensorWrapper
::
set_module_trace_info
>
(
"_NodeMixin__node"
)
...
@@ -875,43 +817,6 @@ void init_tensor(py::module m) {
...
@@ -875,43 +817,6 @@ void init_tensor(py::module m) {
.
def
(
py
::
init
<
const
TensorWrapper
&>
())
.
def
(
py
::
init
<
const
TensorWrapper
&>
())
.
def
(
"__call__"
,
&
TensorWeakRef
::
operator
());
.
def
(
"__call__"
,
&
TensorWeakRef
::
operator
());
py
::
class_
<
PySymbolVar
,
std
::
shared_ptr
<
PySymbolVar
>>
(
m
,
"SymbolVar"
)
.
def_property_readonly
(
"dtype"
,
[](
PySymbolVar
*
v
)
{
return
v
->
m_node
->
dtype
();
})
.
def_property
(
"var"
,
[](
PySymbolVar
*
v
)
{
return
v
->
m_node
;
},
[](
PySymbolVar
*
s
,
cg
::
VarNode
*
v
)
{
s
->
m_node
=
v
;
})
.
def_property_readonly
(
"device"
,
[](
PySymbolVar
*
v
)
{
return
v
->
m_node
->
comp_node
();
})
.
def_property_readonly
(
"graph"
,
[](
PySymbolVar
*
v
)
{
return
v
->
m_node
->
owner_graph
();
})
.
def_property_readonly
(
"shape"
,
[](
PySymbolVar
*
v
)
->
const
TensorShape
*
{
auto
&&
mgr
=
v
->
m_node
->
owner_graph
()
->
static_infer_manager
();
return
mgr
.
infer_shape_fallible
(
v
->
m_node
);
})
.
def
(
"numpy"
,
[](
PySymbolVar
*
v
)
{
auto
&&
mgr
=
v
->
m_node
->
owner_graph
()
->
static_infer_manager
();
auto
&&
type
=
mgr
.
get_infer_type
(
v
->
m_node
);
using
InferType
=
cg
::
static_infer
::
InferType
;
if
(
!
(
type
.
value
&
(
InferType
::
CONST
|
InferType
::
RT_STATIC
)))
{
throw
py
::
value_error
(
"value invalid!"
);
}
auto
*
val
=
mgr
.
infer_value_fallible
(
v
->
m_node
);
if
(
!
val
)
{
throw
py
::
value_error
(
"value invalid!"
);
}
auto
np_val
=
py
::
cast
(
*
val
).
attr
(
"numpy"
)();
return
np_val
;
})
.
def
(
"_isscalar"
,
[](
PySymbolVar
*
v
)
{
return
v
->
is_scalar
;
})
.
def
(
py
::
init
([](
cg
::
VarNode
*
node
)
{
return
std
::
make_shared
<
PySymbolVar
>
(
node
);
}),
py
::
arg
()
=
nullptr
);
static
PyMethodDef
method_defs
[]
=
{
static
PyMethodDef
method_defs
[]
=
{
MGE_PY_INTERFACE
(
apply
,
py_apply
),
MGE_PY_INTERFACE
(
apply
,
py_apply
),
MGE_PY_INTERFACE
(
dtype_promotion
,
dtype_promotion
),
MGE_PY_INTERFACE
(
dtype_promotion
,
dtype_promotion
),
...
@@ -1027,6 +932,10 @@ void init_tensor(py::module m) {
...
@@ -1027,6 +932,10 @@ void init_tensor(py::module m) {
py_tensor_type
=
reinterpret_cast
<
PyTypeObject
*>
(
type_obj
.
inc_ref
().
ptr
());
py_tensor_type
=
reinterpret_cast
<
PyTypeObject
*>
(
type_obj
.
inc_ref
().
ptr
());
});
});
m
.
def
(
"set_py_varnode_type"
,
[](
py
::
object
type_obj
)
{
py_varnode_type
=
reinterpret_cast
<
PyTypeObject
*>
(
type_obj
.
inc_ref
().
ptr
());
});
m
.
def
(
"set_py_device_type"
,
m
.
def
(
"set_py_device_type"
,
[](
py
::
object
type_obj
)
{
py_device_type
=
type_obj
.
inc_ref
();
});
[](
py
::
object
type_obj
)
{
py_device_type
=
type_obj
.
inc_ref
();
});
...
@@ -1217,31 +1126,6 @@ void init_tensor(py::module m) {
...
@@ -1217,31 +1126,6 @@ void init_tensor(py::module m) {
}
}
});
});
m
.
def
(
"reduce_to_scalar"
,
[](
py
::
object
op
,
py
::
object
tensor
)
->
py
::
object
{
auto
reduce_to_scalar
=
[](
const
OpDef
&
op
,
const
ValueRef
&
input
)
{
auto
make_scalar_shape
=
[
&
](
CompNode
device
)
{
return
imperative
::
apply
(
CreateTensor
(
CreateTensor
::
Const
,
device
,
dtype
::
Int32
(),
{
0
}),
HostStorage
::
make
(
device
))[
0
];
};
return
imperative
::
apply
(
op
,
input
,
make_scalar_shape
(
*
input
.
device
()))[
0
];
};
if
(
py
::
isinstance
<
PySymbolVar
>
(
tensor
))
{
auto
*
graph
=
tensor
.
cast
<
PySymbolVar
*>
()
->
m_node
->
owner_graph
();
SymbolVarContext
context
(
graph
);
context
.
init
();
auto
output
=
reduce_to_scalar
(
*
op
.
cast
<
std
::
shared_ptr
<
OpDef
>>
(),
context
.
symvar2val
(
tensor
));
auto
typeobj
=
tensor
.
get_type
();
return
context
.
val2symvar
(
typeobj
,
output
);
}
else
{
auto
*
tw
=
TensorWrapper
::
try_cast
(
tensor
.
ptr
());
auto
output
=
reduce_to_scalar
(
*
op
.
cast
<
std
::
shared_ptr
<
OpDef
>>
(),
tw
->
m_tensor
->
data
());
return
TensorWrapper
::
make
(
py_tensor_type
,
output
);
}
});
m
.
def
(
"name_tensor"
,
[](
std
::
string
name
,
py
::
object
tensor
)
{
m
.
def
(
"name_tensor"
,
[](
std
::
string
name
,
py
::
object
tensor
)
{
auto
*
tw
=
TensorWrapper
::
try_cast
(
tensor
.
ptr
());
auto
*
tw
=
TensorWrapper
::
try_cast
(
tensor
.
ptr
());
auto
output
=
imperative
::
apply
(
TraceMarkVar
(
name
),
tw
->
m_tensor
->
data
())[
0
];
auto
output
=
imperative
::
apply
(
TraceMarkVar
(
name
),
tw
->
m_tensor
->
data
())[
0
];
...
...
imperative/python/src/tensor.h
浏览文件 @
a926878c
...
@@ -10,6 +10,8 @@
...
@@ -10,6 +10,8 @@
#include "./pyext17.h"
#include "./pyext17.h"
#include "megbrain/imperative/dispatch.h"
#include "megbrain/imperative/dispatch.h"
#include "megbrain/imperative/transformations/scalar.h"
#include "megbrain/imperative/transformations/symbol.h"
#include "megbrain/imperative/utils/span.h"
#include "megbrain/imperative/utils/span.h"
namespace
mgb
::
imperative
::
python
{
namespace
mgb
::
imperative
::
python
{
...
@@ -27,6 +29,7 @@ namespace mgb::imperative::python {
...
@@ -27,6 +29,7 @@ namespace mgb::imperative::python {
extern
interpreter
::
Interpreter
::
Channel
*
interpreter_for_py
;
extern
interpreter
::
Interpreter
::
Channel
*
interpreter_for_py
;
extern
PyTypeObject
*
py_tensor_type
;
extern
PyTypeObject
*
py_tensor_type
;
extern
PyTypeObject
*
py_varnode_type
;
extern
pybind11
::
handle
py_device_type
;
extern
pybind11
::
handle
py_device_type
;
extern
PyObject
*
cpp_use_symbolic_shape
;
extern
PyObject
*
cpp_use_symbolic_shape
;
extern
PyObject
*
cpp_astensor1d
;
extern
PyObject
*
cpp_astensor1d
;
...
@@ -126,16 +129,11 @@ public:
...
@@ -126,16 +129,11 @@ public:
void
set_module_trace_info
(
PyObject
*
);
void
set_module_trace_info
(
PyObject
*
);
void
_set_name
(
PyObject
*
);
void
_set_name
(
PyObject
*
);
PyObject
*
_detail
();
PyObject
*
_detail
();
PyObject
*
_var
();
PyObject
*
_graph
();
void
_watch
();
void
_watch
();
};
};
struct
PySymbolVar
{
cg
::
VarNode
*
m_node
=
nullptr
;
bool
is_scalar
=
false
;
PySymbolVar
()
=
default
;
PySymbolVar
(
VarNode
*
m
)
:
m_node
(
m
)
{}
};
PyObject
*
py_apply
(
PyObject
*
py_apply
(
PyObject
*
self
,
PyObject
*
const
*
args
,
size_t
nargs
/* , PyObject* kwnames */
);
PyObject
*
self
,
PyObject
*
const
*
args
,
size_t
nargs
/* , PyObject* kwnames */
);
...
...
imperative/python/src/tensor_utils.cpp
浏览文件 @
a926878c
...
@@ -146,15 +146,6 @@ PyArray_Descr* _dtype_promotion(PyObject* const* args, size_t nargs) {
...
@@ -146,15 +146,6 @@ PyArray_Descr* _dtype_promotion(PyObject* const* args, size_t nargs) {
continue
;
continue
;
}
}
if
(
py
::
isinstance
<
PySymbolVar
>
(
py
::
handle
(
handle
)))
{
auto
var
=
py
::
handle
(
handle
).
cast
<
PySymbolVar
*>
();
mgb
::
DType
type
=
var
->
m_node
->
dtype
();
auto
&&
descr
=
npy
::
dtype_mgb2np_descr
(
type
);
Py_INCREF
(
descr
.
get
());
tensors
.
emplace_back
(
descr
.
get
());
continue
;
}
PyArray_Descr
*
descr
=
scalar2dtype
(
handle
);
PyArray_Descr
*
descr
=
scalar2dtype
(
handle
);
if
(
descr
)
{
if
(
descr
)
{
scalars
.
emplace_back
(
descr
);
scalars
.
emplace_back
(
descr
);
...
@@ -204,17 +195,12 @@ CompNode _get_device(PyObject* const* args, size_t nargs) {
...
@@ -204,17 +195,12 @@ CompNode _get_device(PyObject* const* args, size_t nargs) {
PyObject
*
handle
=
is_tuple
?
PyTuple_GetItem
(
tuple
,
i
)
:
args
[
i
];
PyObject
*
handle
=
is_tuple
?
PyTuple_GetItem
(
tuple
,
i
)
:
args
[
i
];
TensorWrapper
*
tw
=
TensorWrapper
::
try_cast
(
handle
);
TensorWrapper
*
tw
=
TensorWrapper
::
try_cast
(
handle
);
bool
is_symvar
=
py
::
isinstance
<
PySymbolVar
>
(
py
::
handle
(
handle
));
if
(
tw
)
{
if
(
tw
||
is_symvar
)
{
if
(
!
valid
)
{
if
(
!
valid
)
{
cn
=
tw
?
tw
->
m_tensor
->
comp_node
()
cn
=
tw
->
m_tensor
->
comp_node
();
:
py
::
handle
(
handle
).
cast
<
PySymbolVar
*>
()
->
m_node
->
comp_node
();
valid
=
true
;
valid
=
true
;
}
else
{
}
else
{
CompNode
cn1
=
tw
?
tw
->
m_tensor
->
comp_node
()
CompNode
cn1
=
tw
->
m_tensor
->
comp_node
();
:
py
::
handle
(
handle
)
.
cast
<
PySymbolVar
*>
()
->
m_node
->
comp_node
();
if
(
cn1
!=
cn
)
{
if
(
cn1
!=
cn
)
{
throw
py
::
value_error
(
ssprintf
(
throw
py
::
value_error
(
ssprintf
(
"ambiguous device: %s (from %s) vs %s (from %s)"
,
"ambiguous device: %s (from %s) vs %s (from %s)"
,
...
@@ -258,10 +244,6 @@ PyObject* get_device(PyObject* self, PyObject* const* args, size_t nargs) {
...
@@ -258,10 +244,6 @@ PyObject* get_device(PyObject* self, PyObject* const* args, size_t nargs) {
}
}
bool
is_scalar
(
PyObject
*
tensor
)
{
bool
is_scalar
(
PyObject
*
tensor
)
{
if
(
py
::
isinstance
<
PySymbolVar
>
(
py
::
handle
(
tensor
)))
{
auto
var
=
py
::
handle
(
tensor
).
cast
<
PySymbolVar
*>
();
return
var
->
is_scalar
;
}
auto
*
tw
=
TensorWrapper
::
try_cast
(
tensor
);
auto
*
tw
=
TensorWrapper
::
try_cast
(
tensor
);
if
(
tw
)
{
if
(
tw
)
{
return
tw
->
m_tensor
->
is_scalar
();
return
tw
->
m_tensor
->
is_scalar
();
...
@@ -319,8 +301,7 @@ py::object device2obj(py::handle device, bool mapping = false) {
...
@@ -319,8 +301,7 @@ py::object device2obj(py::handle device, bool mapping = false) {
}
}
}
}
py
::
object
_Const
(
py
::
object
_Const
(
py
::
handle
value
,
py
::
handle
dtype
,
py
::
handle
device
)
{
py
::
handle
value
,
py
::
handle
dtype
,
py
::
handle
device
,
py
::
handle
ref_hdl
)
{
py
::
object
val
=
py
::
reinterpret_borrow
<
py
::
object
>
(
value
);
py
::
object
val
=
py
::
reinterpret_borrow
<
py
::
object
>
(
value
);
if
(
PyArray_Check
(
value
.
ptr
()))
{
if
(
PyArray_Check
(
value
.
ptr
()))
{
py
::
tuple
strides
=
py
::
tuple
strides
=
...
@@ -338,32 +319,6 @@ py::object _Const(
...
@@ -338,32 +319,6 @@ py::object _Const(
val
=
val
.
attr
(
"reshape"
)(
orig_shp
);
val
=
val
.
attr
(
"reshape"
)(
orig_shp
);
}
}
}
}
py
::
object
ref
;
if
(
py
::
isinstance
<
py
::
tuple
>
(
ref_hdl
))
{
py
::
tuple
tup
=
py
::
reinterpret_borrow
<
py
::
tuple
>
(
ref_hdl
);
if
(
tup
.
size
())
{
ref
=
tup
[
0
];
}
else
{
ref
=
py
::
none
();
}
}
else
{
ref
=
py
::
reinterpret_borrow
<
py
::
object
>
(
ref_hdl
);
}
if
(
py
::
isinstance
<
PySymbolVar
>
(
ref
))
{
auto
ref_var
=
ref
.
cast
<
PySymbolVar
*>
();
auto
*
graph
=
ref_var
->
m_node
->
owner_graph
();
CompNode
cn
;
if
(
device
.
ptr
()
==
Py_None
)
{
cn
=
ref_var
->
m_node
->
comp_node
();
}
else
{
cn
=
device2obj
(
device
).
cast
<
CompNode
>
();
}
OperatorNodeConfig
config
(
cn
);
auto
hv
=
npy
::
np2tensor
(
val
.
ptr
(),
npy
::
Meth
::
borrow
(
cn
),
dtype
.
cast
<
mgb
::
DType
>
());
auto
typeobj
=
ref
.
get_type
();
return
typeobj
(
opr
::
ImmutableTensor
::
make
(
*
graph
,
hv
,
config
).
node
());
}
py
::
object
device_obj
=
device2obj
(
device
,
true
);
py
::
object
device_obj
=
device2obj
(
device
,
true
);
py
::
tuple
tup
=
py
::
make_tuple
(
val
,
dtype
,
device_obj
,
true
,
false
,
py
::
none
());
py
::
tuple
tup
=
py
::
make_tuple
(
val
,
dtype
,
device_obj
,
true
,
false
,
py
::
none
());
return
TensorWrapper
::
make
(
py_tensor_type
,
tup
.
ptr
(),
nullptr
);
return
TensorWrapper
::
make
(
py_tensor_type
,
tup
.
ptr
(),
nullptr
);
...
@@ -373,7 +328,7 @@ py::tuple _make_shape_tuple(py::handle shape) {
...
@@ -373,7 +328,7 @@ py::tuple _make_shape_tuple(py::handle shape) {
py
::
list
orig
;
py
::
list
orig
;
py
::
list
ret
(
0
);
py
::
list
ret
(
0
);
auto
solve_one
=
[
&
](
py
::
handle
val
)
{
auto
solve_one
=
[
&
](
py
::
handle
val
)
{
if
(
TensorWrapper
::
try_cast
(
val
.
ptr
())
||
py
::
isinstance
<
PySymbolVar
>
(
val
)
)
{
if
(
TensorWrapper
::
try_cast
(
val
.
ptr
()))
{
py
::
object
np
=
getattr
(
val
,
"numpy"
)();
py
::
object
np
=
getattr
(
val
,
"numpy"
)();
PyArrayObject
*
arr
=
(
PyArrayObject
*
)
np
.
ptr
();
PyArrayObject
*
arr
=
(
PyArrayObject
*
)
np
.
ptr
();
PyObject
*
maybe_list
=
PyArray_ToList
(
arr
);
PyObject
*
maybe_list
=
PyArray_ToList
(
arr
);
...
@@ -415,25 +370,53 @@ py::tuple _make_shape_tuple(py::handle shape) {
...
@@ -415,25 +370,53 @@ py::tuple _make_shape_tuple(py::handle shape) {
return
py
::
reinterpret_steal
<
py
::
tuple
>
(
PyList_AsTuple
(
ret
.
ptr
()));
return
py
::
reinterpret_steal
<
py
::
tuple
>
(
PyList_AsTuple
(
ret
.
ptr
()));
}
}
bool
is_tensor
_or_symbolvar
(
py
::
handle
arg
)
{
bool
is_tensor
(
py
::
handle
arg
)
{
return
bool
(
TensorWrapper
::
try_cast
(
arg
.
ptr
()))
||
py
::
isinstance
<
PySymbolVar
>
(
arg
)
;
return
bool
(
TensorWrapper
::
try_cast
(
arg
.
ptr
()));
}
}
bool
is_py_sequence
(
py
::
handle
arg
)
{
bool
is_py_sequence
(
py
::
handle
arg
)
{
if
(
PyArray_Check
(
arg
.
ptr
())
||
TensorWrapper
::
try_cast
(
arg
.
ptr
())
||
if
(
PyArray_Check
(
arg
.
ptr
())
||
TensorWrapper
::
try_cast
(
arg
.
ptr
()))
{
py
::
isinstance
<
PySymbolVar
>
(
arg
))
{
return
false
;
return
false
;
}
}
return
PySequence_Check
(
arg
.
ptr
());
return
PySequence_Check
(
arg
.
ptr
());
}
}
mgb
::
DType
_get_dtype
(
py
::
handle
tensor
)
{
py
::
object
get_res_by_refhdl
(
if
(
auto
tw
=
TensorWrapper
::
try_cast
(
tensor
.
ptr
()))
{
py
::
handle
value
,
py
::
handle
dtype
,
py
::
handle
device
,
py
::
handle
ref_hdl
)
{
return
tw
->
m_tensor
->
dtype
();
py
::
object
res
=
_Const
(
value
,
dtype
,
device
);
py
::
object
ref
;
if
(
py
::
isinstance
<
py
::
tuple
>
(
ref_hdl
))
{
py
::
tuple
tup
=
py
::
reinterpret_borrow
<
py
::
tuple
>
(
ref_hdl
);
if
(
tup
.
size
())
{
ref
=
tup
[
0
];
}
else
{
}
else
{
auto
var
=
tensor
.
cast
<
PySymbolVar
*>
();
ref
=
py
::
none
();
return
var
->
m_node
->
dtype
();
}
}
else
{
ref
=
py
::
reinterpret_borrow
<
py
::
object
>
(
ref_hdl
);
}
if
(
PyObject_TypeCheck
(
ref
.
ptr
(),
py_varnode_type
))
{
auto
temp
=
dtype
.
cast
<
mgb
::
DType
>
();
ComputingGraph
*
graph
=
getattr
(
ref
,
"graph"
).
cast
<
ComputingGraph
*>
();
cg
::
VarNode
*
node
=
getattr
(
ref
,
"var"
).
cast
<
cg
::
VarNode
*>
();
CompNode
cn
;
if
(
device
.
ptr
()
==
Py_None
)
{
cn
=
node
->
comp_node
();
}
else
{
cn
=
device2obj
(
device
).
cast
<
CompNode
>
();
}
OperatorNodeConfig
config
(
cn
);
auto
hv
=
npy
::
np2tensor
(
value
.
ptr
(),
npy
::
Meth
::
borrow
(
cn
),
dtype
.
cast
<
mgb
::
DType
>
());
auto
typeobj
=
ref
.
get_type
();
return
typeobj
(
opr
::
ImmutableTensor
::
make
(
*
graph
,
hv
,
config
).
node
());
}
}
return
res
;
}
mgb
::
DType
_get_dtype
(
py
::
handle
tensor
)
{
auto
tw
=
TensorWrapper
::
try_cast
(
tensor
.
ptr
());
return
tw
->
m_tensor
->
dtype
();
}
}
py
::
object
_astype_cpp
(
py
::
handle
tensor
,
py
::
handle
dtype_hdl
)
{
py
::
object
_astype_cpp
(
py
::
handle
tensor
,
py
::
handle
dtype_hdl
)
{
...
@@ -457,12 +440,12 @@ py::object _astype_cpp(py::handle tensor, py::handle dtype_hdl) {
...
@@ -457,12 +440,12 @@ py::object _astype_cpp(py::handle tensor, py::handle dtype_hdl) {
py
::
object
_convert_single_value_cpp
(
py
::
object
_convert_single_value_cpp
(
py
::
handle
value
,
py
::
handle
dtype
,
py
::
handle
device
)
{
py
::
handle
value
,
py
::
handle
dtype
,
py
::
handle
device
)
{
if
(
is_tensor
_or_symbolvar
(
value
))
{
if
(
is_tensor
(
value
))
{
if
(
_get_dtype
(
value
).
category
()
!=
DTypeCategory
::
QUANTIZED
)
{
if
(
_get_dtype
(
value
).
category
()
!=
DTypeCategory
::
QUANTIZED
)
{
return
_astype_cpp
(
value
,
dtype
);
return
_astype_cpp
(
value
,
dtype
);
}
}
}
else
{
}
else
{
return
_Const
(
value
,
dtype
,
device
,
py
::
none
()
);
return
_Const
(
value
,
dtype
,
device
);
}
}
return
py
::
reinterpret_borrow
<
py
::
object
>
(
value
);
return
py
::
reinterpret_borrow
<
py
::
object
>
(
value
);
}
}
...
@@ -475,28 +458,8 @@ py::object _convert_inputs_cpp(
...
@@ -475,28 +458,8 @@ py::object _convert_inputs_cpp(
for
(
size_t
i
=
0
;
i
<
nargs
;
++
i
)
{
for
(
size_t
i
=
0
;
i
<
nargs
;
++
i
)
{
py
::
handle
h
=
py
::
handle
(
args
[
i
]);
py
::
handle
h
=
py
::
handle
(
args
[
i
]);
lis
.
append
(
h
);
lis
.
append
(
h
);
if
(
py
::
isinstance
<
PySymbolVar
>
(
h
))
{
auto
var
=
h
.
cast
<
PySymbolVar
*>
();
auto
g
=
var
->
m_node
->
owner_graph
();
if
(
!
graph
)
{
graph
=
g
;
typeobj
=
h
.
get_type
();
}
else
{
mgb_assert
(
graph
==
g
);
}
}
}
if
(
graph
)
{
CompNode
cn
=
device2obj
(
device
).
cast
<
CompNode
>
();
for
(
size_t
i
=
0
;
i
<
nargs
;
++
i
)
{
OperatorNodeConfig
config
(
cn
);
auto
hv
=
npy
::
np2tensor
(
lis
[
i
].
ptr
(),
npy
::
Meth
::
borrow
(
cn
),
dtype
.
cast
<
mgb
::
DType
>
());
if
(
!
py
::
isinstance
<
PySymbolVar
>
(
lis
[
i
]))
{
lis
[
i
]
=
typeobj
(
opr
::
ImmutableTensor
::
make
(
*
graph
,
hv
,
config
).
node
());
}
}
}
}
auto
convert
=
[
&
](
py
::
object
value
)
{
auto
convert
=
[
&
](
py
::
object
value
)
{
if
(
value
.
is_none
())
{
if
(
value
.
is_none
())
{
return
value
;
return
value
;
...
@@ -517,7 +480,8 @@ py::object _astensor1d_cpp(
...
@@ -517,7 +480,8 @@ py::object _astensor1d_cpp(
if
(
device
.
ptr
()
!=
Py_None
)
{
if
(
device
.
ptr
()
!=
Py_None
)
{
device_obj
=
device2obj
(
device
);
device_obj
=
device2obj
(
device
);
}
}
if
(
py
::
isinstance
<
PySymbolVar
>
(
value
))
{
if
(
PyObject_TypeCheck
(
value
.
ptr
(),
py_varnode_type
))
{
try
{
try
{
getattr
(
value
,
"ndim"
);
getattr
(
value
,
"ndim"
);
}
catch
(
py
::
error_already_set
&
err
)
{
}
catch
(
py
::
error_already_set
&
err
)
{
...
@@ -537,14 +501,15 @@ py::object _astensor1d_cpp(
...
@@ -537,14 +501,15 @@ py::object _astensor1d_cpp(
return
ret
;
return
ret
;
}
}
}
}
size_t
ndim
=
999
;
size_t
ndim
=
999
;
if
(
hasattr
(
value
,
"ndim"
))
{
if
(
hasattr
(
value
,
"ndim"
))
{
ndim
=
getattr
(
value
,
"ndim"
).
cast
<
size_t
>
();
ndim
=
getattr
(
value
,
"ndim"
).
cast
<
size_t
>
();
if
(
ndim
!=
0
&&
ndim
!=
1
)
{
if
(
ndim
!=
0
&&
ndim
!=
1
)
{
throw
py
::
value_error
(
"ndim != 1 or 0, get : "
+
std
::
to_string
(
ndim
));
throw
py
::
value_error
(
"ndim != 1 or 0, get : "
+
std
::
to_string
(
ndim
));
}
}
if
(
!
is_tensor
_or_symbolvar
(
value
))
{
if
(
!
is_tensor
(
value
))
{
return
_Const
(
value
,
dtype
,
device
,
ref
);
return
get_res_by_refhdl
(
value
,
dtype
,
device
,
ref
);
}
else
{
}
else
{
return
py
::
reinterpret_borrow
<
py
::
object
>
(
value
);
return
py
::
reinterpret_borrow
<
py
::
object
>
(
value
);
}
}
...
@@ -555,13 +520,13 @@ py::object _astensor1d_cpp(
...
@@ -555,13 +520,13 @@ py::object _astensor1d_cpp(
py
::
list
lis
=
py
::
reinterpret_steal
<
py
::
list
>
(
PySequence_List
(
value
.
ptr
()));
py
::
list
lis
=
py
::
reinterpret_steal
<
py
::
list
>
(
PySequence_List
(
value
.
ptr
()));
bool
need_concat
=
false
;
bool
need_concat
=
false
;
for
(
size_t
i
=
0
;
i
<
lis
.
size
();
++
i
)
{
for
(
size_t
i
=
0
;
i
<
lis
.
size
();
++
i
)
{
if
(
is_tensor
_or_symbolvar
(
lis
[
i
]))
{
if
(
is_tensor
(
lis
[
i
]))
{
need_concat
=
true
;
need_concat
=
true
;
break
;
break
;
}
}
}
}
if
(
!
need_concat
)
{
if
(
!
need_concat
)
{
return
_Const
(
value
,
dtype
,
device
,
ref
);
return
get_res_by_refhdl
(
value
,
dtype
,
device
,
ref
);
}
}
if
(
lis
.
size
()
>
1
)
{
if
(
lis
.
size
()
>
1
)
{
std
::
vector
<
PyObject
*>
c_args
(
lis
.
size
()
+
1
);
std
::
vector
<
PyObject
*>
c_args
(
lis
.
size
()
+
1
);
...
@@ -600,10 +565,9 @@ py::object _astensor1d_cpp(
...
@@ -600,10 +565,9 @@ py::object _astensor1d_cpp(
}
}
py
::
object
_get_index
(
py
::
object
tensor
,
py
::
object
src
)
{
py
::
object
_get_index
(
py
::
object
tensor
,
py
::
object
src
)
{
if
(
!
TensorWrapper
::
try_cast
(
tensor
.
ptr
())
&&
if
(
!
TensorWrapper
::
try_cast
(
tensor
.
ptr
()))
{
!
py
::
isinstance
<
PySymbolVar
>
(
tensor
))
{
auto
get_const
=
[
&
](
mgb
::
DType
dtype
)
->
py
::
object
{
auto
get_const
=
[
&
](
mgb
::
DType
dtype
)
->
py
::
object
{
return
_Const
(
tensor
,
py
::
cast
(
dtype
),
src
.
attr
(
"device"
)
,
src
);
return
_Const
(
tensor
,
py
::
cast
(
dtype
),
src
.
attr
(
"device"
));
};
};
if
(
is_bool_list
(
tensor
.
ptr
())
||
is_bool_dtype
(
tensor
.
ptr
()))
{
if
(
is_bool_list
(
tensor
.
ptr
())
||
is_bool_dtype
(
tensor
.
ptr
()))
{
tensor
=
get_const
(
dtype
::
Bool
());
tensor
=
get_const
(
dtype
::
Bool
());
...
@@ -636,9 +600,8 @@ py::tuple _try_cond_take(py::handle tensor, py::handle index) {
...
@@ -636,9 +600,8 @@ py::tuple _try_cond_take(py::handle tensor, py::handle index) {
}
}
py
::
object
iobj
;
py
::
object
iobj
;
if
(
PyArray_Check
(
index
.
ptr
()))
{
if
(
PyArray_Check
(
index
.
ptr
()))
{
iobj
=
iobj
=
_Const
(
_Const
(
index
,
py
::
cast
((
mgb
::
DType
)
dtype
::
Bool
()),
index
,
py
::
cast
((
mgb
::
DType
)
dtype
::
Bool
()),
getattr
(
tensor
,
"device"
));
getattr
(
tensor
,
"device"
),
tensor
);
}
else
{
}
else
{
iobj
=
py
::
reinterpret_borrow
<
py
::
object
>
(
index
);
iobj
=
py
::
reinterpret_borrow
<
py
::
object
>
(
index
);
}
}
...
@@ -920,8 +883,8 @@ py::object _expand_args(py::handle args) {
...
@@ -920,8 +883,8 @@ py::object _expand_args(py::handle args) {
return
py
::
reinterpret_borrow
<
py
::
object
>
(
args
);
return
py
::
reinterpret_borrow
<
py
::
object
>
(
args
);
}
}
py
::
tuple
args_tup
=
py
::
reinterpret_borrow
<
py
::
tuple
>
(
args
.
ptr
());
py
::
tuple
args_tup
=
py
::
reinterpret_borrow
<
py
::
tuple
>
(
args
.
ptr
());
if
(
args_tup
.
size
()
==
1
&&
(
PySequence_Check
(
args_tup
[
0
].
ptr
())
||
if
(
args_tup
.
size
()
==
1
&&
is_tensor_or_symbolva
r
(
args_tup
[
0
].
ptr
())))
{
(
PySequence_Check
(
args_tup
[
0
].
ptr
())
||
is_tenso
r
(
args_tup
[
0
].
ptr
())))
{
return
py
::
reinterpret_borrow
<
py
::
object
>
(
args_tup
[
0
]);
return
py
::
reinterpret_borrow
<
py
::
object
>
(
args_tup
[
0
]);
}
else
{
}
else
{
return
py
::
reinterpret_steal
<
py
::
list
>
(
PySequence_List
(
args_tup
.
ptr
()));
return
py
::
reinterpret_steal
<
py
::
list
>
(
PySequence_List
(
args_tup
.
ptr
()));
...
@@ -948,7 +911,8 @@ std::tuple<std::vector<int32_t>, bool> tuple2vector(py::object shape) {
...
@@ -948,7 +911,8 @@ std::tuple<std::vector<int32_t>, bool> tuple2vector(py::object shape) {
bool
enable_fastpath
(
py
::
handle
inp
)
{
bool
enable_fastpath
(
py
::
handle
inp
)
{
auto
&&
tm_tr
=
TransformationManager
::
get_instance
()
auto
&&
tm_tr
=
TransformationManager
::
get_instance
()
.
segments
[
TransformationManager
::
Segment
::
ModuleTrace
];
.
segments
[
TransformationManager
::
Segment
::
ModuleTrace
];
if
(
!
TensorWrapper
::
try_cast
(
inp
.
ptr
())
||
bool
is_varnode
=
PyObject_TypeCheck
(
inp
.
ptr
(),
py_varnode_type
);
if
(
is_varnode
||
TransformationManager
::
get_instance
()
TransformationManager
::
get_instance
()
.
segments
[
TransformationManager
::
Segment
::
Trace
]
.
segments
[
TransformationManager
::
Segment
::
Trace
]
.
size
()
>
0
||
.
size
()
>
0
||
...
@@ -1181,10 +1145,8 @@ py::object _getitem_cpp(py::handle inp_hdl, py::handle idx_hdl) {
...
@@ -1181,10 +1145,8 @@ py::object _getitem_cpp(py::handle inp_hdl, py::handle idx_hdl) {
py
::
object
_setitem_cpp
(
py
::
handle
inp_hdl
,
py
::
handle
idx_hdl
,
py
::
handle
val_hdl
)
{
py
::
object
_setitem_cpp
(
py
::
handle
inp_hdl
,
py
::
handle
idx_hdl
,
py
::
handle
val_hdl
)
{
py
::
object
org_shape
=
getattr
(
inp_hdl
,
"shape"
);
py
::
object
org_shape
=
getattr
(
inp_hdl
,
"shape"
);
py
::
object
val
=
py
::
reinterpret_borrow
<
py
::
object
>
(
val_hdl
);
py
::
object
val
=
py
::
reinterpret_borrow
<
py
::
object
>
(
val_hdl
);
if
(
!
TensorWrapper
::
try_cast
(
val
.
ptr
())
&&
!
py
::
isinstance
<
PySymbolVar
>
(
val
))
{
if
(
!
TensorWrapper
::
try_cast
(
val
.
ptr
()))
{
val
=
val
=
_Const
(
val_hdl
,
getattr
(
inp_hdl
,
"dtype"
),
getattr
(
inp_hdl
,
"device"
));
_Const
(
val_hdl
,
getattr
(
inp_hdl
,
"dtype"
),
getattr
(
inp_hdl
,
"device"
),
inp_hdl
);
}
}
py
::
tuple
up
=
_unpack_indexes
(
inp_hdl
,
idx_hdl
);
py
::
tuple
up
=
_unpack_indexes
(
inp_hdl
,
idx_hdl
);
...
@@ -1308,12 +1270,12 @@ py::object _split_cpp(
...
@@ -1308,12 +1270,12 @@ py::object _split_cpp(
repr
(
nsplits_or_sections_hdl
).
cast
<
std
::
string
>
());
repr
(
nsplits_or_sections_hdl
).
cast
<
std
::
string
>
());
}
}
py
::
object
pos
=
div_points
[
i
]
-
div_points
[
i
-
1
];
py
::
object
pos
=
div_points
[
i
]
-
div_points
[
i
-
1
];
if
(
is_tensor
_or_symbolvar
(
pos
))
{
if
(
is_tensor
(
pos
))
{
partitions
.
append
(
pos
);
partitions
.
append
(
pos
);
}
else
{
}
else
{
partitions
.
append
(
partitions
.
append
(
_Const
(
pos
,
py
::
cast
((
mgb
::
DType
)
dtype
::
Int32
()),
_Const
(
pos
,
py
::
cast
((
mgb
::
DType
)
dtype
::
Int32
()),
getattr
(
inp_hdl
,
"device"
)
,
inp_hdl
));
getattr
(
inp_hdl
,
"device"
)));
}
}
}
}
op
=
Split
::
make
(
axis
,
0
);
op
=
Split
::
make
(
axis
,
0
);
...
@@ -1438,7 +1400,7 @@ py::object _squeeze_cpp(py::handle inp_hdl, py::handle axis_hdl) {
...
@@ -1438,7 +1400,7 @@ py::object _squeeze_cpp(py::handle inp_hdl, py::handle axis_hdl) {
py
::
object
_transpose_cpp
(
py
::
handle
inp_hdl
,
py
::
handle
args
)
{
py
::
object
_transpose_cpp
(
py
::
handle
inp_hdl
,
py
::
handle
args
)
{
py
::
object
obj
=
_expand_args
(
args
);
py
::
object
obj
=
_expand_args
(
args
);
py
::
list
lis
;
py
::
list
lis
;
if
(
!
is_tensor
_or_symbolvar
(
obj
.
ptr
())
&&
PySequence_Check
(
obj
.
ptr
()))
{
if
(
!
is_tensor
(
obj
.
ptr
())
&&
PySequence_Check
(
obj
.
ptr
()))
{
lis
=
py
::
reinterpret_steal
<
py
::
list
>
(
PySequence_List
(
obj
.
ptr
()));
lis
=
py
::
reinterpret_steal
<
py
::
list
>
(
PySequence_List
(
obj
.
ptr
()));
}
else
{
}
else
{
py
::
object
np
=
getattr
(
obj
,
"numpy"
)();
py
::
object
np
=
getattr
(
obj
,
"numpy"
)();
...
@@ -1631,7 +1593,7 @@ PyObject* pixel_shuffle_cpp(PyObject* self, PyObject* const* args, size_t nargs)
...
@@ -1631,7 +1593,7 @@ PyObject* pixel_shuffle_cpp(PyObject* self, PyObject* const* args, size_t nargs)
PyObject
*
Const
(
PyObject
*
self
,
PyObject
*
const
*
args
,
size_t
nargs
)
{
PyObject
*
Const
(
PyObject
*
self
,
PyObject
*
const
*
args
,
size_t
nargs
)
{
try
{
try
{
return
_Const
(
args
[
0
],
args
[
1
],
args
[
2
]
,
args
[
3
]
).
release
().
ptr
();
return
_Const
(
args
[
0
],
args
[
1
],
args
[
2
]).
release
().
ptr
();
}
}
PYEXT17_TRANSLATE_EXC_RET
(
nullptr
)
PYEXT17_TRANSLATE_EXC_RET
(
nullptr
)
}
}
...
...
imperative/python/src/transformation.h
浏览文件 @
a926878c
...
@@ -20,11 +20,12 @@ public:
...
@@ -20,11 +20,12 @@ public:
DimExpansion
,
DimExpansion
,
Grad
,
Grad
,
Scalar
,
Scalar
,
Symbol
,
Trace
,
Trace
,
Eval
,
Eval
,
};
};
std
::
array
<
std
::
vector
<
std
::
shared_ptr
<
Transformation
>>
,
7
>
segments
;
std
::
array
<
std
::
vector
<
std
::
shared_ptr
<
Transformation
>>
,
8
>
segments
;
private:
private:
template
<
Segment
segment
>
template
<
Segment
segment
>
...
...
imperative/python/test/helpers/utils.py
浏览文件 @
a926878c
...
@@ -11,7 +11,7 @@ from megengine.utils.network_node import VarNode
...
@@ -11,7 +11,7 @@ from megengine.utils.network_node import VarNode
def
_default_compare_fn
(
x
,
y
):
def
_default_compare_fn
(
x
,
y
):
if
isinstance
(
x
,
tensor
):
if
isinstance
(
x
,
tensor
)
and
not
isinstance
(
x
,
VarNode
)
:
x
=
x
.
numpy
()
x
=
x
.
numpy
()
elif
not
isinstance
(
x
,
np
.
ndarray
):
elif
not
isinstance
(
x
,
np
.
ndarray
):
x
=
get_var_value
(
x
)
x
=
get_var_value
(
x
)
...
...
imperative/python/test/unit/functional/test_tensor.py
浏览文件 @
a926878c
...
@@ -679,6 +679,18 @@ def test_utils_astensor1d(is_varnode):
...
@@ -679,6 +679,18 @@ def test_utils_astensor1d(is_varnode):
assert
isinstance
(
xx
,
type
(
reference
))
assert
isinstance
(
xx
,
type
(
reference
))
np
.
testing
.
assert_equal
(
xx
.
numpy
(),
[
1
,
2
,
3
])
np
.
testing
.
assert_equal
(
xx
.
numpy
(),
[
1
,
2
,
3
])
# varnode
if
is_varnode
:
a
=
np
.
array
([[
1
,
2
,
3
],
[
4
,
5
,
6
]]).
astype
(
"float32"
)
b
=
np
.
array
([[
True
,
False
,
True
],
[
False
,
True
,
True
]])
aa
=
make_tensor
(
a
,
network
)
bb
=
make_tensor
(
b
,
network
)
x
,
y
=
F
.
cond_take
(
bb
,
aa
)
for
dtype
in
[
None
,
"float32"
]:
xx
=
astensor1d
(
x
,
reference
,
dtype
=
dtype
)
assert
isinstance
(
xx
,
type
(
reference
))
np
.
testing
.
assert_equal
(
get_var_value
(
xx
),
get_var_value
(
x
))
def
test_device
():
def
test_device
():
x
=
tensor
([
1
,
2
,
3
],
dtype
=
"float32"
)
x
=
tensor
([
1
,
2
,
3
],
dtype
=
"float32"
)
...
...
imperative/python/test/unit/utils/test_network.py
浏览文件 @
a926878c
...
@@ -114,8 +114,10 @@ def test_replace_opr():
...
@@ -114,8 +114,10 @@ def test_replace_opr():
vara
=
graph
.
var_filter
.
name
(
"a"
).
as_unique
()
vara
=
graph
.
var_filter
.
name
(
"a"
).
as_unique
()
varb
=
graph
.
var_filter
.
name
(
"b"
).
as_unique
()
varb
=
graph
.
var_filter
.
name
(
"b"
).
as_unique
()
out1
=
F
.
sub
(
vara
,
varb
)
out1
=
F
.
mul
(
vara
,
varb
)
out1
=
F
.
relu
(
out1
)
out1
=
F
.
relu
(
out1
)
out1
+=
2
out1
*=
3
out1
=
graph
.
add_dep_oprs
(
out1
)
out1
=
graph
.
add_dep_oprs
(
out1
)
orig_opr
=
graph
.
opr_filter
.
has_input
(
vara
).
as_unique
()
orig_opr
=
graph
.
opr_filter
.
has_input
(
vara
).
as_unique
()
...
@@ -135,7 +137,7 @@ def test_replace_opr():
...
@@ -135,7 +137,7 @@ def test_replace_opr():
load_graph
=
GraphInference
(
modified_model1
)
load_graph
=
GraphInference
(
modified_model1
)
out
=
load_graph
.
run
(
a
,
b
)
out
=
load_graph
.
run
(
a
,
b
)
np
.
testing
.
assert_equal
(
out
[
"o"
],
[
0
,
0
])
np
.
testing
.
assert_equal
(
out
[
"o"
],
[
30
,
6
0
])
def
test_splice_network
():
def
test_splice_network
():
...
...
imperative/src/impl/basic_operators.cpp
浏览文件 @
a926878c
...
@@ -82,6 +82,10 @@ std::string DTRCommand::to_string() const {
...
@@ -82,6 +82,10 @@ std::string DTRCommand::to_string() const {
return
ssprintf
(
"DTRCommandValue{kind=%d}"
,
(
int
)
m_kind
);
return
ssprintf
(
"DTRCommandValue{kind=%d}"
,
(
int
)
m_kind
);
}
}
std
::
string
CreateNode
::
to_string
()
const
{
return
"CreateNode"
;
}
std
::
string
GetName
::
to_string
()
const
{
std
::
string
GetName
::
to_string
()
const
{
return
"GetName{}"
;
return
"GetName{}"
;
}
}
...
@@ -94,5 +98,9 @@ std::string IsScalar::to_string() const {
...
@@ -94,5 +98,9 @@ std::string IsScalar::to_string() const {
return
"IsScalar"
;
return
"IsScalar"
;
}
}
std
::
string
GetVarVal
::
to_string
()
const
{
return
"GetVarVal"
;
}
}
// namespace imperative
}
// namespace imperative
}
// namespace mgb
}
// namespace mgb
imperative/src/include/megbrain/imperative/basic_operators.h
浏览文件 @
a926878c
...
@@ -157,5 +157,22 @@ public:
...
@@ -157,5 +157,22 @@ public:
std
::
string
to_string
()
const
override
;
std
::
string
to_string
()
const
override
;
};
};
class
GetVarVal
final
:
public
OperatorImpl
<
GetVarVal
,
Operator
::
GetAttrLike
>
{
public:
std
::
string
to_string
()
const
override
;
};
class
CreateNode
final
:
public
OperatorImpl
<
CreateNode
>
{
private:
cg
::
VarNode
*
m_node
;
public:
CreateNode
(
cg
::
VarNode
*
node
)
:
m_node
(
node
)
{}
cg
::
VarNode
*
node
()
const
{
return
m_node
;
}
std
::
string
to_string
()
const
override
;
};
}
// namespace imperative
}
// namespace imperative
}
// namespace mgb
}
// namespace mgb
imperative/src/include/megbrain/imperative/basic_values.h
浏览文件 @
a926878c
...
@@ -173,5 +173,24 @@ public:
...
@@ -173,5 +173,24 @@ public:
std
::
string
to_string
()
const
override
;
std
::
string
to_string
()
const
override
;
};
};
class
NodeStorage
{
private:
cg
::
VarNode
*
m_node
;
public:
NodeStorage
()
=
default
;
NodeStorage
(
VarNode
*
node
)
:
m_node
(
node
)
{}
VarNode
*
node
()
const
{
return
m_node
;
}
ComputingGraph
*
graph
()
const
{
return
m_node
->
owner_graph
();
}
std
::
string
to_string
()
const
{
return
m_node
->
name
();
}
};
class
NodeValue
final
:
public
PrimitiveValue
<
NodeValue
,
NodeStorage
>
{
public:
using
PrimitiveValue
::
PrimitiveValue
;
std
::
string
to_string
()
const
override
{
return
NodeStorage
::
to_string
();
}
};
}
// namespace imperative
}
// namespace imperative
}
// namespace mgb
}
// namespace mgb
imperative/src/include/megbrain/imperative/transformations/symbol.h
浏览文件 @
a926878c
...
@@ -39,29 +39,49 @@ private:
...
@@ -39,29 +39,49 @@ private:
ObjectType
<
SymbolValue
>
m_value_type
{
"SymbolValue"
};
ObjectType
<
SymbolValue
>
m_value_type
{
"SymbolValue"
};
public:
public:
SymbolTransformation
(
ComputingGraph
*
graph
)
:
m_graph
(
graph
)
{}
SymbolTransformation
()
{}
ValueRefList
apply_transformation
(
ValueRefList
apply_transformation
(
const
Operator
&
op
,
Span
<
ValueRef
>
inputs
)
override
{
const
Operator
&
op
,
Span
<
ValueRef
>
inputs
)
override
{
ComputingGraph
*
cg
=
nullptr
;
if
(
auto
*
node_value
=
op
.
as
<
CreateNode
>
())
{
return
{
m_value_type
.
make
(
node_value
->
node
())};
}
for
(
auto
&&
input
:
inputs
)
{
if
(
auto
*
val
=
input
.
as
(
m_value_type
))
{
auto
*
node
=
val
->
node
();
ComputingGraph
*
cur_cg
=
node
->
owner_graph
();
if
(
cg
==
nullptr
)
{
cg
=
cur_cg
;
}
else
{
mgb_assert
(
cg
==
cur_cg
,
"input varnode gragh should be the same"
);
}
}
}
if
(
!
cg
)
{
return
imperative
::
apply
(
op
,
inputs
);
}
if
(
auto
*
apply_op
=
op
.
as
<
ApplyOp
>
())
{
if
(
auto
*
apply_op
=
op
.
as
<
ApplyOp
>
())
{
SmallVector
<
VarNode
*>
input_nodes
;
SmallVector
<
VarNode
*>
input_nodes
;
for
(
auto
&&
input
:
inputs
)
{
for
(
auto
&&
input
:
inputs
)
{
if
(
!
input
.
is
(
m_value_type
))
{
auto
*
node
=
opr
::
ImmutableTensor
::
make
(
*
cg
,
input
.
numpy
()
->
as_nd
(
true
),
{})
.
node
();
input_nodes
.
push_back
(
node
);
}
else
{
input_nodes
.
push_back
(
input
.
cast
(
m_value_type
).
node
());
input_nodes
.
push_back
(
input
.
cast
(
m_value_type
).
node
());
}
}
}
auto
output_nodes
=
OpDef
::
apply_on_var_node
(
apply_op
->
op
(),
input_nodes
);
auto
output_nodes
=
OpDef
::
apply_on_var_node
(
apply_op
->
op
(),
input_nodes
);
ValueRefList
outputs
(
output_nodes
.
size
());
ValueRefList
outputs
(
output_nodes
.
size
());
for
(
size_t
i
=
0
;
i
<
output_nodes
.
size
();
++
i
)
{
for
(
size_t
i
=
0
;
i
<
output_nodes
.
size
();
++
i
)
{
outputs
[
i
]
=
m_value_type
.
make
(
output_nodes
[
i
]);
outputs
[
i
]
=
m_value_type
.
make
(
output_nodes
[
i
]);
}
}
return
outputs
;
return
outputs
;
}
else
if
(
auto
*
create_tensor
=
op
.
as
<
CreateTensor
>
())
{
auto
&&
args
=
create_tensor
->
parse
(
inputs
);
mgb_assert
(
args
.
kind
==
CreateTensor
::
Const
,
"only const value is allowed here"
);
auto
*
node
=
opr
::
ImmutableTensor
::
make
(
*
m_graph
,
*
args
.
host
,
{}).
node
();
return
{
m_value_type
.
make
(
node
)};
}
else
if
(
auto
*
get_attr
=
op
.
as
<
GetAttr
>
())
{
}
else
if
(
auto
*
get_attr
=
op
.
as
<
GetAttr
>
())
{
auto
*
node
=
inputs
.
item
().
cast
(
m_value_type
).
node
();
auto
*
node
=
inputs
.
item
().
cast
(
m_value_type
).
node
();
auto
*
m_graph
=
node
->
owner_graph
();
switch
(
get_attr
->
attr
())
{
switch
(
get_attr
->
attr
())
{
case
GetAttr
::
DType
:
case
GetAttr
::
DType
:
return
{
DTypeValue
::
make
(
node
->
dtype
())};
return
{
DTypeValue
::
make
(
node
->
dtype
())};
...
@@ -105,6 +125,10 @@ public:
...
@@ -105,6 +125,10 @@ public:
MegBrainError
,
"Symbol: malformed GetAttr: %s"
,
MegBrainError
,
"Symbol: malformed GetAttr: %s"
,
op
.
to_string
().
c_str
());
op
.
to_string
().
c_str
());
}
}
}
else
if
(
auto
*
get_attr
=
op
.
as
<
GetVarVal
>
())
{
cg
::
VarNode
*
node
=
inputs
.
item
().
cast
(
m_value_type
).
node
();
NodeStorage
inp_var
=
NodeStorage
(
node
);
return
{
NodeValue
::
make
(
inp_var
)};
}
else
{
}
else
{
return
op
.
fallback
(
inputs
);
return
op
.
fallback
(
inputs
);
}
}
...
...
imperative/src/include/megbrain/imperative/value.h
浏览文件 @
a926878c
...
@@ -33,6 +33,7 @@ class ShapeValue;
...
@@ -33,6 +33,7 @@ class ShapeValue;
class
DTypeValue
;
class
DTypeValue
;
class
CompNodeValue
;
class
CompNodeValue
;
class
StringValue
;
class
StringValue
;
class
NodeValue
;
class
Operator
;
class
Operator
;
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录