Commit ae3123b3
Authored on Feb 05, 2021 by Megvii Engine Team
Parent commit: c82d8875

feat(mge): add python graph for mgb graph editing

GitOrigin-RevId: 6a9d5beba2eb0ebce3908e785f041ba0aa7085a4
Showing 9 changed files with 2555 additions and 30 deletions (+2555, −30):
imperative/python/megengine/core/tensor/utils.py        +2   −1
imperative/python/megengine/functional/elemwise.py      +4   −3
imperative/python/megengine/utils/comp_graph_tools.py   +7   −8
imperative/python/megengine/utils/network.py            +682 −0
imperative/python/megengine/utils/network_node.py       +628 −0
imperative/python/src/tensor.cpp                        +25  −3
imperative/python/test/unit/utils/test_network.py       +351 −0
imperative/python/test/unit/utils/test_opr.py           +712 −0
src/plugin/impl/opr_footprint.cpp                       +144 −15
imperative/python/megengine/core/tensor/utils.py
@@ -11,6 +11,7 @@ from typing import Iterable, Union
 import numpy as np

+from .._imperative_rt import VarNode
 from .._imperative_rt.core2 import Tensor, apply, dtype_promotion, get_device
 from ..ops import builtin
 from ..ops.special import Const

@@ -59,7 +60,7 @@ def astype(x, dtype):
 def convert_single_value(v, *, dtype=None, device=None):
-    if isinstance(v, Tensor):
+    if isinstance(v, (Tensor, VarNode)):
         if not is_quantize(v.dtype):
             v = astype(v, dtype)
     else:
imperative/python/megengine/functional/elemwise.py
@@ -12,11 +12,12 @@ import functools
 import numpy as np

 from ..core._imperative_rt.core2 import apply
+from ..core._imperative_rt.graph import VarNode
 from ..core.ops import builtin
 from ..core.ops.builtin import Elemwise
 from ..core.tensor import utils
 from ..core.tensor.array_method import _elwise_apply
-from ..core.tensor.utils import isscalar, setscalar
+from ..core.tensor.utils import astype, isscalar, setscalar
 from ..device import get_default_device
 from ..jit.tracing import is_tracing
 from ..tensor import Tensor

@@ -77,7 +78,7 @@ __all__ = [
 def _elwise(*args, mode):
-    tensor_args = list(filter(lambda x: isinstance(x, Tensor), args))
+    tensor_args = list(filter(lambda x: isinstance(x, (Tensor, VarNode)), args))
     if len(tensor_args) == 0:
         dtype = utils.dtype_promotion(args)
         first_arg = Tensor(args[0], dtype=dtype, device=get_default_device())

@@ -109,7 +110,7 @@ def _elwise(*args, mode):
             Elemwise.Mode.ROUND,
         ) and np.issubdtype(args[0].dtype, np.integer):
             return args[0]
-        args = tuple(map(lambda x: x.astype("float32"), args))
+        args = tuple(map(lambda x: astype(x, "float32"), args))
     return _elwise_apply(args, mode)
imperative/python/megengine/utils/comp_graph_tools.py
@@ -65,7 +65,6 @@ def get_owner_opr_inputs(var: VarNode) -> List[VarNode]:
     """
     Gets the inputs of owner opr of a variable.
     """
-    assert isinstance(var, VarNode)
     return var.owner.inputs

@@ -74,7 +73,6 @@ def get_owner_opr_type(var: VarNode) -> str:
     Gets the type of owner opr of a variable.
     """
-    assert isinstance(var, VarNode)
     return var.owner.type

@@ -109,7 +107,7 @@ def graph_traversal(outputs: VarNode):
     var2oprs = collections.defaultdict(list)
     opr2receivers = collections.defaultdict(list)

-    queue = list(map(lambda x: x.owner, outputs))
+    queue = list(set(map(lambda x: x.owner, outputs)))
     visited = set(map(lambda x: x.id, queue))

     # iterate through whole comp_graph, fill in meta information

@@ -143,12 +141,15 @@ def graph_traversal(outputs: VarNode):
     return map_oprs, map_vars, var2oprs, opr2receivers, indegree2opr, opr2indegree


-def get_oprs_seq(outputs: List[VarNode], prune_reshape=False) -> List[OperatorNode]:
+def get_oprs_seq(
+    outputs: List[VarNode], prune_reshape=False, prune_immtensor=True
+) -> List[OperatorNode]:
     """
     Gets oprs in some topological order for a dumped model.

     :param outputs: model outputs.
-    :param prune_reshape: whether to prune the useless operators during inference.
+    :param prune_reshape: whether to prune the useless operators used by Reshape opr during inference.
+    :param prune_immtensor: whether to prune the ImmutableTensor opr.
     :return: opr list with some correct execution order.
     """

@@ -160,9 +161,7 @@ def get_oprs_seq(outputs: List[VarNode], prune_reshape=False) -> List[OperatorNo
         opr_id = indegree2opr[0].pop()
         opr = map_oprs[opr_id]
         nr_remain -= 1

-        # skip const value generation operator
-        if get_opr_type(opr) != "ImmutableTensor":
+        if opr.type != "ImmutableTensor" or not prune_immtensor:
             oprs_seq.append(opr)

         for post_id in opr2receivers[opr_id]:
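For reference, a minimal sketch of the new prune_immtensor switch in use. It assumes model_bytes is an io.BytesIO holding an already-dumped model and loads it the same way the new tests do, via megbrain_graph.load_graph; only the last call is the point here:

    import megengine.core.tensor.megbrain_graph as G
    from megengine.utils.comp_graph_tools import get_oprs_seq

    res = G.load_graph(model_bytes)  # model_bytes: io.BytesIO with a dumped model (assumed)
    # keep ImmutableTensor oprs in the topological sequence; the old code always skipped them
    oprs = get_oprs_seq(res.output_vars_list, prune_reshape=False, prune_immtensor=False)
    for opr in oprs:
        print(opr.type, opr.name)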
imperative/python/megengine/utils/network.py
New file (mode 100644). The diff for this file is collapsed in this view and its contents are not shown.
imperative/python/megengine/utils/network_node.py
New file (mode 100644):
# -*- coding: utf-8 -*-
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import json
import sys
from typing import Callable

import numpy as np

from ..core import _imperative_rt as rt
from ..core._wrap import Device
from ..core.ops import builtin
from ..core.tensor.megbrain_graph import InputNode
from ..tensor import Tensor
from .comp_graph_tools import replace_vars


class NetworkNode:
    pass


class VarNode(NetworkNode):
    def __init__(self, owner_opr=None, name=None):
        self.var = None
        self.owner = owner_opr
        self.name = name
        self.id = id(self)

    @classmethod
    def load(cls, sym_var, owner_opr):
        obj = cls()
        obj.var = sym_var  # mgb varnode
        obj.name = sym_var.name
        obj.owner = owner_opr
        return obj

    @property
    def shape(self):
        rst = None
        if self.var:
            try:
                rst = self.var.shape
            except:
                rst = None
        return rst

    @property
    def dtype(self):
        return self.var.dtype if self.var else None

    def set_owner_opr(self, owner_opr):
        self.owner_opr = owner_opr


class OpNode(NetworkNode):
    opdef = None
    type = None

    def __init__(self):
        self.inputs = []
        self.outputs = []
        self.params = {}
        self._opr = None  # mgb opnode
        self.id = id(self)

    @classmethod
    def load(cls, opr):
        obj = cls()
        obj.params = json.loads(opr.params)
        obj.name = opr.name
        obj._opr = opr
        return obj

    def compile(self, graph=None):
        op = self.opdef(**self.params)
        args = [i.var for i in self.inputs]
        outputs = rt.invoke_op(op, args)
        assert len(outputs) == len(self.outputs)
        self._opr = outputs[0].owner
        for i in range(len(self.outputs)):
            self.outputs[i].var = outputs[i]
            self.outputs[i].var.name = self.outputs[i].name
            assert self.outputs[i].owner is self

    def add_inp_var(self, x):
        self.inputs.append(x)

    def add_out_var(self, x):
        self.outputs.append(x)


def str_to_mge_class(classname):
    # TODO: use megbrain C++ RTTI to replace type string
    if classname == "RNGOpr<MegDNNOpr>":
        classname = "RNGOpr"
    oprcls = getattr(sys.modules[__name__], classname, None)
    return oprcls if oprcls else ReadOnlyOpNode


class Host2DeviceCopy(OpNode):
    type = "Host2DeviceCopy"

    def __init__(self, shape=None, dtype=None, name=None, device=None):
        super().__init__()
        self.shape = shape
        self.dtype = dtype
        self.name = name
        self.device = Device(device).to_c() if device else Device("xpux").to_c()
        self.outputs = []

    @classmethod
    def load(cls, opr):
        self = cls()
        self.outputs = []
        assert len(opr.outputs) == 1, "wrong number of outputs"
        self.shape = opr.outputs[0].shape
        self.dtype = opr.outputs[0].dtype
        self.name = opr.outputs[0].name
        self.device = opr.outputs[0].comp_node
        self._opr = opr
        return self

    def compile(self, graph):
        outputs = rt.make_h2d(graph, self.device, self.dtype, self.shape, self.name)
        self._opr = outputs.owner
        if len(self.outputs) == 0:
            self.outputs.append(VarNode(self, self.name))
        self.outputs[0].var = outputs
        assert self.outputs[0].owner is self


class ImmutableTensor(OpNode):
    type = "ImmutableTensor"

    def __init__(self, data=None, name=None, device=None, graph=None):
        super().__init__()
        self.name = name
        self.outputs = []
        self.graph = graph
        if data is not None:
            self.set_value(data, device)

    @property
    def device(self):
        return self._opr.outputs[0].comp_node if self._opr else None

    @device.setter
    def device(self, device):
        self.set_value(self.numpy(), device)

    @property
    def shape(self):
        return self.outputs[0].shape

    @property
    def dtype(self):
        return self._opr.outputs[0].dtype if self._opr else None

    def numpy(self):
        return self._opr.outputs[0].value if self._opr else None

    def set_value(self, data, device=None):
        assert self.graph is not None
        cn = device if device else self.device
        assert isinstance(data, (int, float, np.ndarray))
        if isinstance(data, (int, float)):
            data = np.array(data)
            if data.dtype == np.float64:
                data = data.astype(np.float32)
            elif data.dtype == np.int64:
                data = data.astype(np.int32)
        varnode = rt.make_const(self.graph, data, cn, data.dtype, self.name)
        if len(self.outputs) == 0:
            self.outputs.append(VarNode(self, self.name))
        self.outputs[0].var = varnode
        self._opr = varnode.owner

    @classmethod
    def load(cls, opr):
        self = cls()
        self.outputs = []
        self._opr = opr
        self.name = opr.outputs[0].name
        self.graph = opr.graph
        return self

    def compile(self, graph):
        assert self.outputs[0].var is self._opr.outputs[0]
        assert self.outputs[0].owner is self
        if self.graph != graph:
            self.graph = graph
            self.set_value(self.numpy())
        if self.name is not None:
            self.outputs[0].var.name = self.name


class ReadOnlyOpNode(OpNode):
    @classmethod
    def load(cls, opr):
        obj = super(ReadOnlyOpNode, cls).load(opr)
        obj.type = opr.type
        return obj

    def compile(self):
        assert self._opr is not None
        assert len(self.inputs) == len(self._opr.inputs)
        assert len(self.outputs) == len(self._opr.outputs)
        repl_dict = {}
        for ind, i in enumerate(self.inputs):
            if i.var != self._opr.inputs[ind]:
                repl_dict[self._opr.inputs[ind]] = i.var
        if bool(repl_dict):
            out_vars = replace_vars(self._opr.outputs, repl_dict)
            for ind, o in enumerate(self.outputs):
                o.var = out_vars[ind]


class Elemwise(OpNode):
    type = "Elemwise"
    opdef = builtin.Elemwise


class Reduce(OpNode):
    type = "Reduce"
    opdef = builtin.Reduce


class TypeCvt(OpNode):
    type = "TypeCvt"
    opdef = builtin.TypeCvt

    @classmethod
    def load(cls, opr):
        obj = super(TypeCvt, cls).load(opr)
        t_dtype = opr.outputs[0].dtype
        obj.params["dtype"] = t_dtype
        return obj


class MatrixInverse(OpNode):
    type = "MatrixInverse"
    opdef = builtin.MatrixInverse


class MatrixMul(OpNode):
    type = "MatrixMul"
    opdef = builtin.MatrixMul


class BatchedMatrixMul(OpNode):
    type = "BatchedMatmul"
    opdef = builtin.BatchedMatrixMul


class Dot(OpNode):
    type = "Dot"
    opdef = builtin.Dot


class SVD(OpNode):
    type = "SVD"
    opdef = builtin.SVD


class ConvolutionForward(OpNode):
    type = "Convolution"
    opdef = builtin.Convolution


class ConvolutionBackwardData(OpNode):
    type = "ConvTranspose"
    opdef = builtin.ConvolutionBackwardData


class DeformableConvForward(OpNode):
    type = "DeformableConv"
    opdef = builtin.DeformableConv


class GroupLocalForward(OpNode):
    type = "GroupLocal"
    opdef = builtin.GroupLocal


class PoolingForward(OpNode):
    type = "Pooling"
    opdef = builtin.Pooling


class AdaptivePoolingForward(OpNode):
    type = "AdaptivePooling"
    opdef = builtin.AdaptivePooling


class ROIPoolingForward(OpNode):
    type = "ROIPooling"
    opdef = builtin.ROIPooling


class DeformablePSROIPoolingForward(OpNode):
    type = "DeformablePSROIPooling"
    opdef = builtin.DeformablePSROIPooling


class ConvBiasForward(OpNode):
    type = "ConvBias"
    opdef = builtin.ConvBias

    @classmethod
    def load(cls, opr):
        obj = super(ConvBiasForward, cls).load(opr)
        obj.params["dtype"] = opr.outputs[0].dtype
        return obj


class BatchConvBiasForward(OpNode):
    type = "BatchConvBias"
    opdef = builtin.BatchConvBias

    @classmethod
    def load(cls, opr):
        obj = super(BatchConvBiasForward, cls).load(opr)
        obj.params["dtype"] = opr.outputs[0].dtype
        return obj


class BatchNormForward(OpNode):
    type = "BatchNorm"
    opdef = builtin.BatchNorm


class ROIAlignForward(OpNode):
    type = "ROIAlign"
    opdef = builtin.ROIAlign


class WarpPerspectiveForward(OpNode):
    type = "WarpPerspective"
    opdef = builtin.WarpPerspective


class WarpAffineForward(OpNode):
    type = "WarpAffine"
    opdef = builtin.WarpAffine


class RemapForward(OpNode):
    type = "Remap"
    opdef = builtin.Remap


class ResizeForward(OpNode):
    type = "Resize"
    opdef = builtin.Resize


class IndexingOneHot(OpNode):
    type = "IndexingOneHot"
    opdef = builtin.IndexingOneHot


class IndexingSetOneHot(OpNode):
    type = "IndexingSetOneHot"
    opdef = builtin.IndexingSetOneHot


class Copy(OpNode):
    type = "Copy"
    opdef = builtin.Copy

    @classmethod
    def load(cls, opr):
        obj = super(Copy, cls).load(opr)
        obj.params["comp_node"] = opr.outputs[0].comp_node
        return obj


class ArgsortForward(OpNode):
    type = "Argsort"
    opdef = builtin.Argsort


class Argmax(OpNode):
    type = "Argmax"
    opdef = builtin.Argmax


class Argmin(OpNode):
    type = "Argmin"
    opdef = builtin.Argmin


class CondTake(OpNode):
    type = "CondTake"
    opdef = builtin.CondTake


class TopK(OpNode):
    type = "TopK"
    opdef = builtin.TopK


class NvOf(OpNode):
    type = "NvOf"
    opdef = builtin.NvOf


class RNGOpr(OpNode):
    @classmethod
    def load(cls, opr):
        obj = super(RNGOpr, cls).load(opr)
        if len(obj.params) == 3:
            obj.opdef = builtin.GaussianRNG
            obj.type = "GaussianRNG"
        else:
            obj.opdef = builtin.UniformRNG
            obj.type = "UniformRNG"
        return obj


class Linspace(OpNode):
    type = "Linspace"
    opdef = builtin.Linspace

    @classmethod
    def load(cls, opr):
        obj = super(Linspace, cls).load(opr)
        obj.params["comp_node"] = opr.outputs[0].comp_node
        return obj


class Eye(OpNode):
    type = "Eye"
    opdef = builtin.Eye

    @classmethod
    def load(cls, opr):
        obj = super(Eye, cls).load(opr)
        obj.params["dtype"] = opr.outputs[0].dtype
        obj.params["comp_node"] = opr.outputs[0].comp_node
        return obj


class GetVarShape(OpNode):
    type = "GetVarShape"
    opdef = builtin.GetVarShape


class Concat(OpNode):
    type = "Concat"
    opdef = builtin.Concat

    @classmethod
    def load(cls, opr):
        obj = super(Concat, cls).load(opr)
        obj.params["comp_node"] = Device("xpux").to_c()
        return obj


class Broadcast(OpNode):
    type = "Broadcast"
    opdef = builtin.Broadcast


class Identity(OpNode):
    type = "Identity"
    opdef = builtin.Identity


class NMSKeep(OpNode):
    type = "NMSKeep"
    opdef = builtin.NMSKeep


# class ParamPackSplit
# class ParamPackConcat


class Dimshuffle(OpNode):
    type = "Dimshuffle"
    opdef = builtin.Dimshuffle

    @classmethod
    def load(cls, opr):
        obj = super(Dimshuffle, cls).load(opr)
        del obj.params["ndim"]
        return obj


class Reshape(OpNode):
    type = "Reshape"
    opdef = builtin.Reshape


class AxisAddRemove(OpNode):
    type = "AxisAddRemove"

    @classmethod
    def load(cls, opr):
        obj = cls()
        obj.name = opr.name
        obj._opr = opr
        params = json.loads(opr.params)
        desc = params["desc"]
        method = None
        axis = []
        for i in desc:
            if method is None:
                method = i["method"]
            assert method == i["method"]
            axis.append(i["axisnum"])
        obj.params = {"axis": axis}
        obj.opdef = builtin.AddAxis if desc[0]["method"] == 0 else builtin.RemoveAxis
        return obj


class IndexingBase(OpNode):
    @classmethod
    def load(cls, opr):
        obj = cls()
        obj.name = opr.name
        obj._opr = opr
        params = json.loads(opr.params)
        items = [
            [
                p["axis"],
                bool(p["begin"]),
                bool(p["end"]),
                bool(p["step"]),
                bool(p["idx"]),
            ]
            for p in params
        ]
        obj.params["items"] = items
        return obj


class Subtensor(IndexingBase):
    type = "Subtensor"
    opdef = builtin.Subtensor


class SetSubtensor(IndexingBase):
    type = "SetSubtensor"
    opdef = builtin.SetSubtensor


class IncrSubtensor(IndexingBase):
    type = "IncrSubtensor"
    opdef = builtin.IncrSubtensor


class IndexingMultiAxisVec(IndexingBase):
    type = "IndexingMultiAxisVec"
    opdef = builtin.IndexingMultiAxisVec


class IndexingSetMultiAxisVec(IndexingBase):
    type = "IndexingSetMultiAxisVec"
    opdef = builtin.IndexingSetMultiAxisVec


class IndexingIncrMultiAxisVec(IndexingBase):
    type = "IndexingIncrMultiAxisVec"
    opdef = builtin.IndexingIncrMultiAxisVec


class MeshIndexing(IndexingBase):
    type = "MeshIndexing"
    opdef = builtin.MeshIndexing


class SetMeshIndexing(IndexingBase):
    type = "SetMeshIndexing"
    opdef = builtin.SetMeshIndexing


class IncrMeshIndexing(IndexingBase):
    type = "IncrMeshIndexing"
    opdef = builtin.IncrMeshIndexing


class BatchedMeshIndexing(IndexingBase):
    type = "BatchedMeshIndexing"
    opdef = builtin.BatchedMeshIndexing


class BatchedSetMeshIndexing(IndexingBase):
    type = "BatchedSetMeshIndexing"
    opdef = builtin.BatchedSetMeshIndexing


class BatchedIncrMeshIndexing(IndexingBase):
    type = "BatchedIncrMeshIndexing"
    opdef = builtin.BatchedIncrMeshIndexing


# class CollectiveComm
# class RemoteSend
# class RemoteRecv
# class TQT
# class FakeQuant
# class InplaceAdd


class AssertEqual(OpNode):
    type = "AssertEqual"
    opdef = builtin.AssertEqual


class ElemwiseMultiType(OpNode):
    type = "ElemwiseMultiType"
    opdef = builtin.ElemwiseMultiType

    @classmethod
    def load(cls, opr):
        obj = super(ElemwiseMultiType, cls).load(opr)
        obj.params["dtype"] = opr.outputs[0].dtype
        return obj


class CvtColorForward(OpNode):
    type = "CvtColor"
    opdef = builtin.CvtColor
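How these wrappers get instantiated is only implied here (the loader presumably lives in the collapsed network.py), but the dispatch is visible in str_to_mge_class: a serialized operator's type string picks the wrapper class, and that class's load classmethod rebuilds the Python-side node from the opr's JSON params. A rough sketch, where mgb_opr stands in for an operator taken from a loaded graph:

    from megengine.utils.network_node import str_to_mge_class

    cls = str_to_mge_class(mgb_opr.type)  # e.g. "Elemwise" -> Elemwise; unknown types -> ReadOnlyOpNode
    node = cls.load(mgb_opr)              # params are parsed from the opr's serialized JSON
    print(type(node).__name__, node.params)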
imperative/python/src/tensor.cpp
@@ -160,6 +160,16 @@ PyObject* py_apply(PyObject* self, PyObject*const* args, size_t nargs/* , PyObje
     if (ctx.op->same_type<BackwardGraph>()) {
         ctx.backward = true;
     }

+    if (py::isinstance<cg::VarNode>(py::handle(args[0]))){
+        SmallVector<cg::VarNode*> vinputs(nargs);
+        for (size_t i = 0; i < nargs; ++i) {
+            vinputs[i] = py::handle(args[i]).cast<cg::VarNode *>();
+        }
+        auto op = ctx.op.get();
+        return to_tuple(OpDef::apply_on_var_node(*op, vinputs)).release().ptr();
+    }
+
     for (size_t i = 0; i < nargs; ++i) {
         if (TensorWrapper* tw = TensorWrapper::try_cast(args[i])) {

@@ -675,6 +685,16 @@ PyArray_Descr* _dtype_promotion(PyObject*const* args, size_t nargs) {
             tensors.emplace_back(descr);
             continue;
         }

+        if (py::isinstance<cg::VarNode>(py::handle(handle))){
+            auto var = py::handle(handle).cast<cg::VarNode *>();
+            mgb::DType type = var->dtype();
+            auto&& descr = npy::dtype_mgb2np_descr(type);
+            Py_INCREF(descr.get());
+            tensors.emplace_back(descr.get());
+            continue;
+        }
+
         PyArray_Descr* descr = scalar2dtype(handle);
         if (descr) {
             scalars.emplace_back(descr);

@@ -719,12 +739,14 @@ CompNode _get_device(PyObject*const* args, size_t nargs) {
     for (size_t i = 0; i < nargs; ++i) {
         PyObject* handle = is_tuple ? PyTuple_GetItem(tuple, i) : args[i];
         TensorWrapper* tw = TensorWrapper::try_cast(handle);
-        if (tw) {
+        bool is_var = py::isinstance<cg::VarNode>(py::handle(handle));
+        if (tw || is_var) {
             if (!valid) {
-                cn = tw->m_tensor->comp_node();
+                cn = tw ? tw->m_tensor->comp_node()
+                        : py::handle(handle).cast<cg::VarNode*>()->comp_node();
                 valid = true;
             } else {
-                CompNode cn1 = tw->m_tensor->comp_node();
+                CompNode cn1 = tw ? tw->m_tensor->comp_node()
+                                  : py::handle(handle).cast<cg::VarNode*>()->comp_node();
                 if (cn1 != cn) {
                     throw py::value_error(ssprintf("ambiguous device: %s vs %s",
                         cn.to_string().c_str(), cn1.to_string().c_str()));
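Together with the elemwise.py change above, this is what lets ordinary functional ops run on graph VarNodes: py_apply, _dtype_promotion and _get_device now accept cg::VarNode arguments and dispatch through OpDef::apply_on_var_node, so an expression like the following (mirroring the new tests, where vara/varb come from a loaded Network) builds new graph nodes instead of eager tensors:

    # vara.var / varb.var are mgb VarNodes; F.mul / F.relu return symbolic VarNodes here
    out = F.relu(F.mul(vara.var, varb.var))
    new_vars = net.add_dep_oprs(out)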
imperative/python/test/unit/utils/test_network.py
New file (mode 100644):
import io

import numpy as np

import megengine.core.tensor.megbrain_graph as G
import megengine.functional as F
import megengine.module as M
import megengine.utils.network_node as N
from megengine.jit.tracing import trace
from megengine.tensor import Tensor
from megengine.utils.comp_graph_tools import GraphInference
from megengine.utils.network import Network as Net
from megengine.utils.network import as_oprnode
from megengine.utils.network_node import Host2DeviceCopy, VarNode


def test_replace_var():
    a = Tensor([1, 2])
    b = Tensor([3, 4])

    @trace(symbolic=True, capture_as_const=True)
    def fwd(a, b):
        return (a + b) * 2

    fwd(a, b)
    orig_model = io.BytesIO()
    fwd.dump(
        orig_model, arg_names=["a", "b"], output_names="o", optimize_for_inference=False
    )
    orig_model.seek(0)
    graph = Net.load(orig_model)

    vara = graph.var_filter.name("a").as_unique()
    varb = graph.var_filter.name("b").as_unique()

    out = F.mul(vara.var, varb.var)
    out = F.relu(out)

    var_list = graph.add_dep_oprs(out)

    opnode = list(graph.opr_filter.has_input(vara))
    repl_dict = {opnode[0].outputs[0]: var_list[0]}
    graph.replace_vars(repl_dict)

    modified_model = io.BytesIO()
    graph.dump(modified_model)
    modified_model.seek(0)
    load_graph = GraphInference(modified_model)

    out = load_graph.run(a, b)
    np.testing.assert_equal(out["o"], [6, 16])


def test_replace_opr():
    a = Tensor([1, 2])
    b = Tensor([3, 4])

    @trace(symbolic=True, capture_as_const=True)
    def fwd(a, b):
        return (a + b) * 2

    fwd(a, b)
    orig_model = io.BytesIO()
    fwd.dump(
        orig_model, arg_names=["a", "b"], output_names="o", optimize_for_inference=False
    )
    orig_model.seek(0)
    graph = Net.load(orig_model)

    vara = graph.var_filter.name("a").as_unique()
    varb = graph.var_filter.name("b").as_unique()

    out1 = F.sub(vara.var, varb.var)
    out1 = F.relu(out1)
    var_list = graph.add_dep_oprs(out1)
    repl_opr = as_oprnode(var_list)

    orig_opr = graph.opr_filter.has_input(vara).as_unique()
    repl_dict = {orig_opr: repl_opr}
    graph.replace_oprs(repl_dict)

    modified_model1 = io.BytesIO()
    graph.dump(modified_model1)
    modified_model1.seek(0)
    load_graph = GraphInference(modified_model1)

    out = load_graph.run(a, b)
    np.testing.assert_equal(out["o"], [0, 0])


def test_modify_params():
    a = Tensor([1, 2])
    b = Tensor([3, 4])

    @trace(symbolic=True, capture_as_const=True)
    def fwd(a, b):
        return (a + b) * 2

    fwd(a, b)
    orig_model = io.BytesIO()
    fwd.dump(
        orig_model, arg_names=["a", "b"], output_names="o", optimize_for_inference=False
    )
    orig_model.seek(0)
    graph = Net.load(orig_model)

    param_const = graph.params_filter.as_unique()
    param_const.set_value(3)

    modified_model = io.BytesIO()
    graph.dump(modified_model)
    modified_model.seek(0)
    load_graph = GraphInference(modified_model)

    out = load_graph.run(a, b)
    np.testing.assert_equal(out["o"], [12, 18])


def test_make_const():
    a = Tensor([1, 2])
    b = Tensor([3, 4])

    @trace(symbolic=True, capture_as_const=True)
    def fwd(a, b):
        return (a + b) * 2

    fwd(a, b)
    orig_model = io.BytesIO()
    fwd.dump(
        orig_model, arg_names=["a", "b"], output_names="o", optimize_for_inference=False
    )
    orig_model.seek(0)
    graph = Net.load(orig_model)

    const_b = graph.make_const(np.array([0.0, 0.0]), name="b")
    varb = graph.var_filter.name("b").as_unique()

    repl_dict = {varb: const_b}
    graph.replace_vars(repl_dict)

    modified_model = io.BytesIO()
    graph.dump(modified_model)
    modified_model.seek(0)
    load_graph = GraphInference(modified_model)

    out = load_graph.run(a)
    np.testing.assert_equal(out["o"], [2, 4])


def test_add_input():
    a = Tensor([1, 2])
    b = Tensor([3, 4])

    @trace(symbolic=True, capture_as_const=True)
    def fwd(a, b):
        return (a + b) * 2

    fwd(a, b)
    orig_model = io.BytesIO()
    fwd.dump(
        orig_model, arg_names=["a", "b"], output_names="o", optimize_for_inference=False
    )
    orig_model.seek(0)
    graph = Net.load(orig_model)

    inp_c = graph.make_input_node((2,), np.int32, name="c")
    varo = graph.var_filter.name("o").as_unique()

    out = F.add(varo.var, inp_c.var)
    out = graph.add_dep_oprs(out)[0]
    out.name = "o1"
    graph.remove_output(varo)
    graph.add_output(out)

    modified_model = io.BytesIO()
    graph.dump(modified_model)
    modified_model.seek(0)
    load_graph = GraphInference(modified_model)

    out = load_graph.run(a, b, a)
    np.testing.assert_equal(out["o1"], ((a + b) * 2 + a).numpy())


def test_add_output():
    a = Tensor([1.0, 2.0])
    b = Tensor([3.0, 4.0])

    @trace(symbolic=True, capture_as_const=True)
    def fwd(a, b):
        return (a + b) * 2

    fwd(a, b)
    orig_model = io.BytesIO()
    fwd.dump(
        orig_model, arg_names=["a", "b"], output_names="o", optimize_for_inference=False
    )
    orig_model.seek(0)
    net = Net.load(orig_model)

    var_a = net.var_filter.name("a").as_unique()
    var_b = net.var_filter.name("b").as_unique()

    y = F.add(var_a.var, var_b.var)
    y = F.sigmoid(y)

    new_vars = net.add_dep_oprs(y)[0]
    new_vars.name = "o1"
    net.add_output(new_vars)

    modified_model = io.BytesIO()
    net.dump(modified_model)
    modified_model.seek(0)
    g = GraphInference(modified_model)

    out = g.run(a.numpy(), b.numpy())
    np.testing.assert_equal(out["o"], ((a + b) * 2).numpy())
    np.testing.assert_equal(out["o1"], (F.sigmoid((a + b))).numpy())


def test_query():
    class Model(M.Module):
        def __init__(self):
            super().__init__()
            self.conv1 = M.Conv2d(3, 32, 3)
            self.conv2 = M.Conv2d(32, 32, 3)
            self.conv3 = M.Conv2d(32, 32, 3)

        def forward(self, data):
            x = self.conv1(data)
            x = self.conv2(x)
            x = self.conv3(x)
            return x

    n = Model()

    @trace(symbolic=True, capture_as_const=True)
    def fwd(data):
        return n(data)

    fwd(Tensor(np.random.random((1, 3, 224, 224))))
    orig_model = io.BytesIO()
    fwd.dump(
        orig_model,
        arg_names=["data"],
        output_names="o",
        keep_opr_name=True,
        keep_var_name=True,
        optimize_for_inference=False,
    )
    orig_model.seek(0)
    graph = Net.load(orig_model)

    r = graph.data_providers_filter.as_count()
    assert r == 1

    opr = graph.get_opr_by_type(Host2DeviceCopy)
    assert isinstance(opr, Host2DeviceCopy)

    r1 = graph.params_filter.as_count()
    assert r1 == 6

    r2 = graph.opr_filter.type(N.ConvolutionForward).as_count()
    assert r2 == 3

    r3 = graph.opr_filter.not_type(N.ConvolutionForward).as_count()
    assert r3 == len(graph.all_oprs) - r2

    var = graph.var_filter.name("data").as_unique()
    r4 = graph.opr_filter.has_input(var).as_count()
    assert r4 == 1

    r5 = graph.opr_filter.name("data").as_count()
    assert r5 == 1

    opr = graph.get_opr_by_name("data")
    assert isinstance(opr, Host2DeviceCopy)

    var = graph.get_var_by_name("data")
    assert isinstance(var, VarNode)

    r6 = graph.var_filter.name("*bias").as_count()
    assert r6 == 3


def test_optimize_for_inference():
    @trace(symbolic=True, capture_as_const=True)
    def f(x):
        return F.exp(x)

    orig_model = io.BytesIO()
    f(Tensor(5.0))
    f.dump(orig_model, optimize_for_inference=False)
    orig_model.seek(0)

    optimize_model = io.BytesIO()
    net = Net.load(orig_model)
    net.dump(optimize_model, enable_io16xc32=True)
    optimize_model.seek(0)

    res = G.load_graph(optimize_model)
    computing_input = res.output_vars_list[0].owner.inputs[0]
    assert computing_input.dtype == np.float16


def test_reset_batchsize():
    @trace(symbolic=True, capture_as_const=True)
    def f(x):
        return F.exp(x)

    orig_model = io.BytesIO()
    f(Tensor(np.random.random((3, 3, 224, 224))))
    f.dump(orig_model, optimize_for_inference=False)
    orig_model.seek(0)

    modified_model = io.BytesIO()
    net = Net.load(orig_model)
    net.reset_batch_size(1)
    net.dump(modified_model, optimize_for_inference=False)
    modified_model.seek(0)

    net1 = Net.load(modified_model)
    assert net1.data_providers_filter.as_unique().shape[0] == 1


def test_modify_opr_name():
    @trace(symbolic=True, capture_as_const=True)
    def f(x):
        return F.exp(x)

    orig_model = io.BytesIO()
    f(Tensor(np.random.random((3, 3, 224, 224))))
    f.dump(orig_model, arg_names=["a"], optimize_for_inference=False)
    orig_model.seek(0)

    modified_model = io.BytesIO()
    net = Net.load(orig_model)
    net.modify_opr_names("net")
    net.modify_opr_names(lambda x: "net1." + x)
    net.dump(modified_model, optimize_for_inference=False)
    modified_model.seek(0)

    net1 = Net.load(modified_model)
    assert net1.data_providers_filter.as_unique().name == "net1.net.a"
imperative/python/test/unit/utils/test_opr.py
New file (mode 100644):
import io
import os
import platform

import numpy as np
import pytest

import megengine.core.tensor.dtype as dtype
import megengine.core.tensor.megbrain_graph as G
import megengine.functional as F
import megengine.module as M
import megengine.random as rand
from megengine.core._imperative_rt.core2 import apply
from megengine.core._wrap import Device
from megengine.core.ops import builtin
from megengine.device import is_cuda_available
from megengine.functional.external import tensorrt_runtime_opr
from megengine.jit.tracing import trace
from megengine.tensor import Tensor
from megengine.utils.comp_graph_tools import GraphInference
from megengine.utils.network import Network as Net


def check_pygraph_dump(trace_func, inp_data, expect_results):
    orig_model = io.BytesIO()
    inp_size = len(inp_data)
    out_size = len(expect_results)
    arg_names = ["arg_{}".format(i) for i in range(inp_size)]
    output_names = ["out_{}".format(i) for i in range(out_size)]
    trace_func.dump(
        orig_model,
        arg_names=arg_names,
        output_names=output_names,
        optimize_for_inference=False,
    )
    orig_model.seek(0)

    net = Net.load(orig_model)
    file = io.BytesIO()
    net.dump(file, optimize_for_inference=False)
    file.seek(0)
    graph = GraphInference(file)

    inp_dict = dict([(arg_names[i], inp_data[i].numpy()) for i in range(inp_size)])
    results = graph.run(inp_dict=inp_dict)

    for ind, tensor in enumerate(expect_results):
        np.testing.assert_equal(tensor.numpy(), results[output_names[ind]])
        assert tensor.dtype == results[output_names[ind]].dtype


def test_elemwise():
    @trace(symbolic=True, capture_as_const=True)
    def fwd(x, y):
        z1 = x * y
        z2 = x + y
        z3 = z1 / z2
        z3 = z3 ** 3
        return z3

    x = Tensor([1.0, 2.0])
    y = Tensor([3.0, 5.0])
    result = fwd(x, y)
    check_pygraph_dump(fwd, [x, y], [result])


def test_reduce():
    @trace(symbolic=True, capture_as_const=True)
    def fwd(data):
        x = data.sum(axis=2)
        x = x.mean(axis=1)
        return x

    data = Tensor(np.random.random((1, 32, 32)))
    result = fwd(data)
    check_pygraph_dump(fwd, [data], [result])


def test_typecvt():
    @trace(symbolic=True, capture_as_const=True)
    def fwd(data):
        return data.astype(dtype.qint8(0.8))

    x = Tensor(np.random.random((2, 3)) * 255)
    result = fwd(x)
    check_pygraph_dump(fwd, [x], [result])


def test_matinv():
    @trace(symbolic=True, capture_as_const=True)
    def fwd(data):
        return F.matinv(data)

    data = Tensor(np.random.random((5, 5)))
    result = fwd(data)
    check_pygraph_dump(fwd, [data], [result])


def test_matmul():
    @trace(symbolic=True, capture_as_const=True)
    def fwd(data1, data2):
        return F.matmul(data1, data2)

    data1 = Tensor(np.random.random((32, 64)))
    data2 = Tensor(np.random.random((64, 16)))
    result = fwd(data1, data2)
    check_pygraph_dump(fwd, [data1, data2], [result])


def test_batchmatmul():
    @trace(symbolic=True, capture_as_const=True)
    def fwd(x, y):
        return F.matmul(x, y)

    x = Tensor(np.random.random((3, 3, 5)))
    y = Tensor(np.random.random((3, 5, 3)))
    result = fwd(x, y)
    check_pygraph_dump(fwd, [x, y], [result])


def test_dot():
    @trace(symbolic=True, capture_as_const=True)
    def fwd(x, y):
        return F.dot(x, y)

    x = Tensor([1.0, 2.0, 3.0])
    y = Tensor([3.0, 4.0, 5.0])
    result = fwd(x, y)
    check_pygraph_dump(fwd, [x, y], [result])


def test_svd():
    @trace(symbolic=True, capture_as_const=True)
    def fwd(data):
        _, out, _ = F.svd(data)
        return out

    input = Tensor(np.random.random((1, 1, 3, 3)))
    result = fwd(input)
    check_pygraph_dump(fwd, [input], [result])


def test_conv():
    conv = M.Conv2d(3, 32, 3)

    @trace(symbolic=True, capture_as_const=True)
    def fwd(data):
        return conv(data)

    data = Tensor(np.random.random((1, 3, 32, 32)))
    result = fwd(data)
    check_pygraph_dump(fwd, [data], [result])


def test_deformable_conv():
    if not is_cuda_available():
        return
    conv = M.DeformableConv2d(3, 32, 3)

    @trace(symbolic=True, capture_as_const=True)
    def fwd(data, offset, mask):
        return conv(data, offset, mask)

    data = Tensor(np.random.random((1, 3, 32, 32)))
    offset = Tensor(np.ones((32, 3 * 3 * 2, 30, 30)).astype("int32") * 5)
    mask = Tensor(np.ones((32, 3 * 3, 30, 30)).astype("int32"))
    out = fwd(data, offset, mask)
    check_pygraph_dump(fwd, [data, offset, mask], [out])


def test_convtranspose():
    deconv = M.ConvTranspose2d(32, 32, 3)

    @trace(symbolic=True, capture_as_const=True)
    def fwd(data):
        return deconv(data)

    data = Tensor(np.random.random((1, 32, 32, 32)))
    result = fwd(data)
    check_pygraph_dump(fwd, [data], [result])


@pytest.mark.skip(reason="pytest aborted")
def test_grouplocal():
    n = M.LocalConv2d(3, 32, 32, 32, 3)

    @trace(symbolic=True, capture_as_const=True)
    def fwd(data):
        return n(data)

    input = Tensor(np.random.random((1, 3, 32, 32)))
    result = fwd(input)
    check_pygraph_dump(fwd, [input], [result])


def test_pooling():
    @trace(symbolic=True, capture_as_const=True)
    def fwd(data):
        out = F.max_pool2d(data, 2, 2)
        out = F.avg_pool2d(out, 2, 2)
        return out

    data = Tensor(np.random.random((1, 3, 64, 64)))
    result = fwd(data)
    check_pygraph_dump(fwd, [data], [result])


def test_adaptivepooling():
    pool1 = M.AdaptiveMaxPool2d((2, 2))
    pool2 = M.AdaptiveAvgPool2d((2, 2))

    @trace(symbolic=True, capture_as_const=True)
    def fwd(data):
        out = pool1(data)
        out = pool2(out)
        return out

    input = Tensor(np.random.random((1, 3, 32, 32)))
    result = fwd(input)
    check_pygraph_dump(fwd, [input], [result])


def test_roipooling():
    inp = Tensor(np.random.random((1, 1, 128, 128)))
    rois = Tensor(np.random.random((4, 5)))

    @trace(symbolic=True, capture_as_const=True)
    def fwd(inp, rois):
        return F.nn.roi_pooling(inp, rois, (2, 2), scale=2.0)

    output = fwd(inp, rois)
    check_pygraph_dump(fwd, [inp, rois], [output])


def test_deformable_ps_roi_pooling():
    inp = Tensor(np.random.random((1, 256, 64, 64)).astype("float32"))
    rois = Tensor(np.random.random((1, 5)).astype("float32"))
    trans = Tensor(np.random.random((24, 2, 7, 7)).astype("float32"))

    pooled_h = 7
    pooled_w = 7
    sample_per_part = 4
    no_trans = False
    part_size = 7
    spatial_scale = 1.0 / 64
    trans_std = 0.1

    @trace(symbolic=True, capture_as_const=True)
    def fwd(inp, rois, trans):
        y = F.deformable_psroi_pooling(
            inp,
            rois,
            trans,
            no_trans,
            part_size,
            pooled_h,
            pooled_w,
            sample_per_part,
            spatial_scale,
            trans_std,
        )
        return y

    result = fwd(inp, rois, trans)
    check_pygraph_dump(fwd, [inp, rois, trans], [result])


def test_convbias():
    @trace(symbolic=True, capture_as_const=True)
    def fwd(inp, weight, bias):
        return F.quantized.conv_bias_activation(
            inp, weight, bias, dtype=dtype.qint8(scale=1.0), nonlinear_mode="RELU"
        )

    inp = Tensor(np.random.random((1, 3, 64, 64)), dtype=dtype.qint8(scale=1.0))
    weight = Tensor(np.random.random((32, 3, 3, 3)), dtype=dtype.qint8(scale=1.0))
    bias = Tensor(np.random.random((1, 32, 1, 1)), dtype=dtype.qint32(scale=1.0))
    result = fwd(inp, weight, bias)
    check_pygraph_dump(fwd, [inp, weight, bias], [result])


def test_batch_convbias():
    if is_cuda_available():
        return

    @trace(symbolic=True, capture_as_const=True)
    def fwd(inp, weight, bias):
        return F.quantized.batch_conv_bias_activation(
            inp, weight, bias, dtype=dtype.qint8(scale=1.0), nonlinear_mode="RELU"
        )

    inp = Tensor(np.random.random((1, 3, 64, 64)), dtype=dtype.qint8(scale=1.0))
    weight = Tensor(np.random.random((1, 32, 3, 3, 3)), dtype=dtype.qint8(scale=1.0))
    bias = Tensor(np.random.random((1, 32, 1, 1)), dtype=dtype.qint32(scale=1.0))
    result = fwd(inp, weight, bias)
    check_pygraph_dump(fwd, [inp, weight, bias], [result])


def test_batchnorm():
    bn = M.BatchNorm2d(32)
    bn.eval()

    @trace(symbolic=True, capture_as_const=True)
    def fwd(data):
        return bn(data)

    data = Tensor(np.random.random((1, 32, 32, 32)))
    result = fwd(data)
    check_pygraph_dump(fwd, [data], [result])


def test_roialign():
    inp = Tensor(np.random.randn(1, 1, 128, 128))
    rois = Tensor(np.random.random((4, 5)))

    @trace(symbolic=True, capture_as_const=True)
    def fwd(inp, rois):
        return F.nn.roi_align(inp, rois, (2, 2))

    output = fwd(inp, rois)
    check_pygraph_dump(fwd, [inp, rois], [output])


def test_warpperspective():
    inp_shape = (1, 1, 4, 4)
    x = Tensor(np.arange(16, dtype=np.float32).reshape(inp_shape))
    M_shape = (1, 3, 3)
    # M defines a translation: dst(1, 1, h, w) = rst(1, 1, h+1, w+1)
    M = Tensor(
        np.array(
            [[1.0, 0.0, 1.0], [0.0, 1.0, 1.0], [0.0, 0.0, 1.0]], dtype=np.float32
        ).reshape(M_shape)
    )

    @trace(symbolic=True, capture_as_const=True)
    def fwd(x, M):
        return F.warp_perspective(x, M, (2, 2))

    result = fwd(x, M)
    check_pygraph_dump(fwd, [x, M], [result])


def test_warpaffine():
    inp_shape = (1, 3, 3, 3)
    x = Tensor(np.arange(27, dtype=np.float32).reshape(inp_shape))
    weightv = Tensor([[[1.26666667, 0.6, -83.33333333], [-0.33333333, 1, 66.66666667]]])

    @trace(symbolic=True, capture_as_const=True)
    def fwd(x, weightv):
        return F.warp_affine(x, weightv, (2, 2), border_mode="WRAP")

    outp = fwd(x, weightv)
    check_pygraph_dump(fwd, [x, weightv], [outp])


def test_remap():
    inp_shape = (1, 1, 4, 4)
    inp = Tensor(np.arange(16, dtype=np.float32).reshape(inp_shape))
    map_xy_shape = (1, 2, 2, 2)
    map_xy = Tensor(
        np.array(
            [[[1.0, 0.0], [0.0, 1.0]], [[0.0, 1.0], [0.0, 1.0]]], dtype=np.float32
        ).reshape(map_xy_shape)
    )

    @trace(symbolic=True, capture_as_const=True)
    def fwd(inp, map_xy):
        return F.remap(inp, map_xy)

    out = fwd(inp, map_xy)
    check_pygraph_dump(fwd, [inp, map_xy], [out])


def test_resize():
    x = Tensor(np.random.randn(10, 3, 32, 32))

    @trace(symbolic=True, capture_as_const=True)
    def fwd(x):
        return F.nn.interpolate(x, size=(16, 16), mode="BILINEAR")

    out = fwd(x)
    check_pygraph_dump(fwd, [x], [out])


def test_index_onehot():
    src = Tensor([[1.0, 2.0]])
    index = Tensor([0])

    @trace(symbolic=True, capture_as_const=True)
    def fwd(src, index):
        return F.indexing_one_hot(src, index)

    out = fwd(src, index)
    check_pygraph_dump(fwd, [src, index], [out])


def test_set_onehot():
    x = Tensor(np.arange(1, 4, dtype=np.int32))

    @trace(symbolic=True, capture_as_const=True)
    def fwd(x):
        return F.one_hot(x, num_classes=4)

    out = fwd(x)
    check_pygraph_dump(fwd, [x], [out])


def test_copy():
    x = Tensor([1, 2, 3])

    @trace(symbolic=True, capture_as_const=True)
    def fwd(x):
        return x.to("cpu0:0")

    o = fwd(x)
    check_pygraph_dump(fwd, [x], [o])


def test_argsort():
    @trace(symbolic=True, capture_as_const=True)
    def fwd(data):
        return F.argsort(data, True)

    data = Tensor([1.0, 2.0, 3.0, 5.0])
    result = fwd(data)
    check_pygraph_dump(fwd, [data], [result])


def test_argmax_min():
    @trace(symbolic=True, capture_as_const=True)
    def fwd(data):
        return F.argmax(data), F.argmin(data)

    data = Tensor(np.random.random((10, 10)))
    result = fwd(data)
    check_pygraph_dump(fwd, [data], result)


def test_condtake():
    mask = Tensor(np.array([[True, False], [False, True]], dtype=np.bool_))
    x = Tensor(np.array([[1, np.inf], [np.nan, 4]], dtype=np.float32))

    @trace(symbolic=True, capture_as_const=True)
    def fwd(mask, x):
        v, index = F.cond_take(mask, x)
        return v, index

    v, index = fwd(mask, x)
    check_pygraph_dump(fwd, [mask, x], [v, index])


def test_topk():
    x = Tensor(np.array([2, 4, 6, 8, 7, 5, 3, 1], dtype=np.float32))

    @trace(symbolic=True, capture_as_const=True)
    def fwd(x):
        top, indices = F.topk(x, 5)
        return top, indices

    top, indices = fwd(x)
    check_pygraph_dump(fwd, [x], [top, indices])


def test_random():
    @trace(symbolic=True, capture_as_const=True)
    def fwd():
        x = rand.uniform(size=(2, 2))
        y = rand.normal(size=(1, 3, 3, 3))
        return x, y

    x, y = fwd()
    check_pygraph_dump(fwd, [], [x, y])


def test_tensor_gen():
    @trace(symbolic=True, capture_as_const=True)
    def fwd():
        a = F.linspace(3, 10, 3, device=Device("xpux").to_c())
        b = F.eye(3, device=Device("xpux").to_c())
        return a, b

    a, b = fwd()
    check_pygraph_dump(fwd, [], [a, b])


def test_getvarshape():
    op = builtin.GetVarShape(axis=1)

    @trace(symbolic=True, capture_as_const=True)
    def fwd(data):
        return apply(op, data)[0]

    data = Tensor(np.random.random((1, 2, 3, 4)))
    result = fwd(data)
    check_pygraph_dump(fwd, [data], [result])


def test_concat():
    @trace(symbolic=True, capture_as_const=True)
    def fwd(data1, data2):
        return F.concat([data1, data2], axis=1)

    x = Tensor(np.random.random((2, 3)))
    y = Tensor(np.random.random((2, 5)))
    result = fwd(x, y)
    check_pygraph_dump(fwd, [x, y], [result])


def test_broadcast():
    inp = Tensor([[1], [2], [3], [4]])

    @trace(symbolic=True, capture_as_const=True)
    def fwd(inp):
        return F.broadcast_to(inp, (4, 4))

    out = fwd(inp)
    check_pygraph_dump(fwd, [inp], [out])


def test_identity():
    @trace(symbolic=True, capture_as_const=True)
    def fwd(data):
        return F.copy(data)

    data = Tensor([1.0, 2.0])
    result = fwd(data)
    check_pygraph_dump(fwd, [data], [result])


@pytest.mark.skip(reason="advance indexing trace error")
def test_nms():
    x = np.zeros((100, 4))
    np.random.seed(42)
    x[:, :2] = np.random.rand(100, 2) * 20
    x[:, 2:] = np.random.rand(100, 2) * 20 + 100
    scores = Tensor(np.random.rand(100))
    inp = Tensor(x)

    @trace(symbolic=True, capture_as_const=True)
    def fwd(inp, scores):
        return F.nn.nms(inp, scores, iou_thresh=0.7, max_output=3)

    result = fwd(inp, scores)
    check_pygraph_dump(fwd, [inp, scores], [result])


def test_dimshuffle():
    inp = Tensor([1, 2, 3, 4])

    @trace(symbolic=True, capture_as_const=True)
    def fwd(inp):
        return inp.T

    out = fwd(inp)
    check_pygraph_dump(fwd, [inp], [out])


def test_reshape():
    @trace(symbolic=True, capture_as_const=True)
    def fwd(data):
        return data.reshape((1, 8))

    data = Tensor(np.random.random((1, 2, 2, 2)))
    result = fwd(data)
    check_pygraph_dump(fwd, [data], [result])


def test_add_remove_axis():
    @trace(symbolic=True, capture_as_const=True)
    def fwd(data):
        x = F.expand_dims(data, [0, 0])
        y = F.squeeze(x, 0)
        return y

    data = Tensor([1.0, 2.0])
    result = fwd(data)
    check_pygraph_dump(fwd, [data], [result])


@pytest.mark.parametrize("mode", ["get", "set", "inc"])
def test_subtensor(mode):
    items = [[0, True, True, True, False], [1, False, False, False, True]]
    data = [Tensor(np.random.random((5, 5))), Tensor(np.random.random(2))]
    if mode == "get":
        op = builtin.Subtensor(items)
        data = data[:1]
    if mode == "set":
        op = builtin.SetSubtensor(items)
    if mode == "inc":
        op = builtin.IncrSubtensor(items)
    tensors = [Tensor(0), Tensor(4), Tensor(2), Tensor(3)]

    @trace(symbolic=True, capture_as_const=True)
    def fwd(*tensors):
        return apply(op, *tensors)[0]

    result = fwd(*data, *tensors)
    check_pygraph_dump(fwd, data + tensors, [result])


@pytest.mark.parametrize("mode", ["get", "set", "inc"])
def test_advance_indexing(mode):
    items = [[0, False, False, False, True]]
    tensors = [Tensor([0, 4, 2])]
    data = [Tensor(np.random.random((5, 5))), Tensor(np.random.random((3, 5)))]
    if mode == "get":
        op = builtin.IndexingMultiAxisVec(items)
        data = data[:1]
    if mode == "set":
        op = builtin.IndexingSetMultiAxisVec(items)
    if mode == "inc":
        op = builtin.IndexingIncrMultiAxisVec(items)

    @trace(symbolic=True, capture_as_const=True)
    def fwd(*tensors):
        return apply(op, *tensors)[0]

    result = fwd(*data, *tensors)
    check_pygraph_dump(fwd, data + tensors, [result])


@pytest.mark.parametrize("mode", ["get", "set", "inc"])
def test_mesh_indexing(mode):
    items = [[0, True, True, True, False], [1, False, False, False, True]]
    tensors = [Tensor(0), Tensor(5), Tensor(2), Tensor([1, 3])]
    data = [Tensor(np.random.random((5, 5))), Tensor(np.random.random((3, 2)))]
    if mode == "get":
        op = builtin.IndexingMultiAxisVec(items)
        data = data[:1]
    if mode == "set":
        op = builtin.IndexingSetMultiAxisVec(items)
    if mode == "inc":
        op = builtin.IndexingIncrMultiAxisVec(items)

    @trace(symbolic=True, capture_as_const=True)
    def fwd(*tensors):
        return apply(op, *tensors)[0]

    result = fwd(*data, *tensors)
    check_pygraph_dump(fwd, data + tensors, [result])


@pytest.mark.parametrize("mode", ["get", "set", "inc"])
def test_batch_mesh_indexing(mode):
    items = [[1, False, False, False, True], [2, False, False, False, True]]
    tensors = [Tensor([[0, 2], [0, 2]]), Tensor([[0, 1, 2], [1, 2, 3]])]
    data = [Tensor(np.random.random((2, 3, 4))), Tensor(np.random.random((2, 2, 3)))]
    if mode == "get":
        op = builtin.BatchedMeshIndexing(items)
        data = data[:1]
    if mode == "set":
        op = builtin.BatchedSetMeshIndexing(items)
    if mode == "inc":
        op = builtin.BatchedIncrMeshIndexing(items)

    @trace(symbolic=True, capture_as_const=True)
    def fwd(*tensors):
        return apply(op, *tensors)[0]

    result = fwd(*data, *tensors)
    check_pygraph_dump(fwd, data + tensors, [result])


@pytest.mark.skip(reason="tmp skip")
def test_assert_equal():
    g = G.Graph()
    inp1 = g.make_h2d(dtype=np.float32, device="xpux")
    inp2 = g.make_h2d(dtype=np.float32, device="xpux")
    op = builtin.AssertEqual(maxerr=1e-5)
    out = G.apply_normal_varnode(op, inp1._node, inp2._node)[0]
    print(out)
    g.compile(out)
    file = io.BytesIO()
    out_model = G.dump_graph([out])
    file.write(out_model[0])
    file.seek(0)
    net = Net.load(file)

    dump_file = io.BytesIO()
    net.dump(dump_file)
    dump_file.seek(0)
    g = GraphInference(dump_file)
    g.run(np.array([1.0, 2.0]), np.array([1.0, 2.0]))


def test_elemwise_multitype():
    op = builtin.ElemwiseMultiType(mode="QADD", dtype=dtype.qint32(2.0))

    @trace(symbolic=True, capture_as_const=True)
    def fwd(x, y):
        return apply(op, x, y)[0]

    x = Tensor(np.random.random(10) * 10, dtype=dtype.qint8(2.0))
    y = Tensor(np.random.random(10) * 10, dtype=dtype.qint8(2.0))
    result = fwd(x, y)
    check_pygraph_dump(fwd, [x, y], [result])


def test_cvtcolor():
    inp = np.random.randn(3, 3, 3, 3).astype(np.float32)
    x = Tensor(inp)

    @trace(symbolic=True, capture_as_const=True)
    def fwd(inp):
        return F.img_proc.cvt_color(inp, mode="RGB2GRAY")

    result = fwd(x)
    check_pygraph_dump(fwd, [x], [result])
src/plugin/impl/opr_footprint.cpp
@@ -17,9 +17,20 @@
 #include "megbrain/opr/dnn/local.h"
 #include "megbrain/opr/dnn/lrn.h"
 #include "megbrain/opr/dnn/pooling.h"
 #include "megbrain/opr/dnn/adaptive_pooling.h"
 #include "megbrain/opr/dnn/roi_pooling.h"
 #include "megbrain/opr/dnn/roi_align.h"
 #include "megbrain/opr/imgproc.h"
 #include "megbrain/opr/standalone/nms_opr.h"
 #include "megbrain/opr/io.h"
 #include "megbrain/opr/tensor_manip.h"
 #include "megbrain/opr/rand.h"
 #include "megbrain/opr/dnn/batch_norm.h"
 #include "megbrain/opr/misc.h"
 #include "megbrain/opr/indexing.h"
 #include "megbrain/opr/internal/indexing_helper.h"
 #include "megbrain/opr/nn_int.h"
 #include "megbrain/opr/tensor_gen.h"
 #if MGB_ENABLE_JSON
 #include "megdnn/opr_param_json.h"
 #endif

@@ -354,7 +365,7 @@ uint64_t opr_footprint_func<opr::DeformableConvForward>(
     auto&& out_shape = opr->output()[0]->shape();
     auto&& filter_shape = opr->input()[1]->shape();
     using Param = opr::DeformableConvForward::Param;
-    auto&& param = opr->cast_final_safe<opr::Convolution>().param();
+    auto&& param = opr->cast_final_safe<opr::DeformableConvForward>().param();
     size_t fh, fw, icpg;
     mgb_assert(param.format == Param::Format::NCHW);
     if (param.sparse == Param::Sparse::GROUP) {

@@ -425,9 +436,11 @@ uint64_t opr_footprint_func<opr::BatchConvBiasForward>(
     auto&& filter_shape = opr->input()[1]->shape();
     using Param = opr::BatchConvBiasForward::Param;
     auto&& param = opr->cast_final_safe<opr::BatchConvBiasForward>().param();
-    mgb_assert(param.format == Param::Format::NCHW4);
-    size_t packed_channels = 4;
+    size_t packed_channels = 1;
     size_t kern_spatial_pos = 3;
+    if (param.format == Param::Format::NCHW4) {
+        packed_channels = 4;
+    }
     size_t fh = filter_shape[kern_spatial_pos], fw = filter_shape[kern_spatial_pos + 1];
     return out_shape.total_nr_elems() * fh * fw * src_shape[1] *

@@ -508,7 +521,29 @@ REGISTE_PARAM_JSON_FUNC(LocalShareBackwardFilter)
 REGISTE_PARAM_JSON_FUNC(DeformableConvForward)
 REGISTE_PARAM_JSON_FUNC(DeformableConvBackwardFilter)
 REGISTE_PARAM_JSON_FUNC(DeformableConvBackwardData)
 REGISTE_PARAM_JSON_FUNC(DeformablePSROIPoolingForward)
 REGISTE_PARAM_JSON_FUNC(BatchConvBiasForward)
 REGISTE_PARAM_JSON_FUNC(BatchNormForward)
 REGISTE_PARAM_JSON_FUNC(ElemwiseMultiType)
 REGISTE_PARAM_JSON_FUNC(Argsort)
 REGISTE_PARAM_JSON_FUNC(Argmax)
 REGISTE_PARAM_JSON_FUNC(Argmin)
 REGISTE_PARAM_JSON_FUNC(AdaptivePooling)
 REGISTE_PARAM_JSON_FUNC(ROIPooling)
 REGISTE_PARAM_JSON_FUNC(ROIAlign)
 REGISTE_PARAM_JSON_FUNC(WarpPerspective)
 REGISTE_PARAM_JSON_FUNC(WarpAffine)
 REGISTE_PARAM_JSON_FUNC(Remap)
 REGISTE_PARAM_JSON_FUNC(Resize)
 REGISTE_PARAM_JSON_FUNC(IndexingOneHot)
 REGISTE_PARAM_JSON_FUNC(IndexingSetOneHot)
 REGISTE_PARAM_JSON_FUNC(TopK)
 REGISTE_PARAM_JSON_FUNC(UniformRNG)
 REGISTE_PARAM_JSON_FUNC(GaussianRNG)
 REGISTE_PARAM_JSON_FUNC(Linspace)
 REGISTE_PARAM_JSON_FUNC(Eye)
 REGISTE_PARAM_JSON_FUNC(CvtColor)

 template <>
 std::shared_ptr<json::Value> opr_param_json_func<opr::Dimshuffle>(

@@ -547,24 +582,83 @@ std::shared_ptr<json::Value> opr_param_json_func<opr::AxisAddRemove>(
     });
 }

+std::shared_ptr<json::Value> indexing_param_to_json(
+        const std::vector<opr::indexing::AxisIndexer>& indices) {
+    auto desc = json::Array::make();
+    for (auto& index : indices) {
+        desc->add(json::Object::make({
+                {"axis", json::NumberInt::make(index.axis.get_raw())},
+                {"begin", json::NumberInt::make(index.begin.node() != nullptr)},
+                {"end", json::NumberInt::make(index.end.node() != nullptr)},
+                {"step", json::NumberInt::make(index.step.node() != nullptr)},
+                {"idx", json::NumberInt::make(index.idx.node() != nullptr)},
+        }));
+    }
+    return desc;
+}
+
+#define REGISTE_INDEXING_PARAM_JSON_FUNC(cls)                          \
+    template <>                                                        \
+    std::shared_ptr<json::Value> opr_param_json_func<opr::cls>(        \
+            cg::OperatorNodeBase * opr) {                              \
+        auto indices = opr->cast_final_safe<opr::cls>().index_desc();  \
+        return indexing_param_to_json(indices);                        \
+    }
+
+REGISTE_INDEXING_PARAM_JSON_FUNC(Subtensor);
+REGISTE_INDEXING_PARAM_JSON_FUNC(SetSubtensor);
+REGISTE_INDEXING_PARAM_JSON_FUNC(IncrSubtensor);
+REGISTE_INDEXING_PARAM_JSON_FUNC(IndexingMultiAxisVec);
+REGISTE_INDEXING_PARAM_JSON_FUNC(IndexingSetMultiAxisVec);
+REGISTE_INDEXING_PARAM_JSON_FUNC(IndexingIncrMultiAxisVec);
+REGISTE_INDEXING_PARAM_JSON_FUNC(MeshIndexing);
+REGISTE_INDEXING_PARAM_JSON_FUNC(IncrMeshIndexing);
+REGISTE_INDEXING_PARAM_JSON_FUNC(SetMeshIndexing);
+REGISTE_INDEXING_PARAM_JSON_FUNC(BatchedMeshIndexing);
+REGISTE_INDEXING_PARAM_JSON_FUNC(BatchedIncrMeshIndexing);
+REGISTE_INDEXING_PARAM_JSON_FUNC(BatchedSetMeshIndexing);
+
 template <>
-std::shared_ptr<json::Value> opr_param_json_func<opr::Subtensor>(
+std::shared_ptr<json::Value> opr_param_json_func<opr::Reshape>(
         cg::OperatorNodeBase* opr) {
-    auto desc = json::Array::make();
-    auto indices = opr->cast_final_safe<opr::Subtensor>().index_desc();
-    for (auto& index : indices){
-        desc->add(json::Object::make({
-                {"axis", json::NumberInt::make(index.axis.get_raw())},
-                {"begin", json::NumberInt::make(index.begin.node() != nullptr)},
-                {"end", json::NumberInt::make(index.end.node() != nullptr)},
-                {"step", json::NumberInt::make(index.step.node() != nullptr)},
-                {"idx", json::NumberInt::make(index.idx.node() != nullptr)},
-        }));
-    }
-    return desc;
+    auto axis_param = opr->cast_final_safe<opr::Reshape>().param();
+    if (axis_param.axis != axis_param.MAX_NDIM){
+        return json::Object::make({
+                {"axis", json::NumberInt::make(axis_param.axis)},
+        });
+    } else {
+        return json::Object::make();
+    }
 }

+template <>
+std::shared_ptr<json::Value> opr_param_json_func<opr::GetVarShape>(
+        cg::OperatorNodeBase* opr) {
+    auto axis_param = opr->cast_final_safe<opr::GetVarShape>().param();
+    if (axis_param.axis != axis_param.MAX_NDIM){
+        return json::Object::make({
+                {"axis", json::NumberInt::make(axis_param.axis)},
+        });
+    } else {
+        return json::Object::make();
+    }
+}
+
+template <>
+std::shared_ptr<json::Value> opr_param_json_func<opr::standalone::NMSKeep>(
+        cg::OperatorNodeBase* opr) {
+    auto nms_param = opr->cast_final_safe<opr::standalone::NMSKeep>().param();
+    return json::Object::make({
+            {"iou_thresh", json::Number::make(nms_param.iou_thresh)},
+            {"max_output", json::Number::make(nms_param.max_output)},
+    });
+}
+
 #endif  // MGB_ENABLE_JSON
 }  // namespace

@@ -632,6 +726,17 @@ void OprFootprint::init_all_footprints() {
     add_single_param_json<opr::Dimshuffle>();
     add_single_param_json<opr::AxisAddRemove>();
     add_single_param_json<opr::Subtensor>();
     add_single_param_json<opr::SetSubtensor>();
     add_single_param_json<opr::IncrSubtensor>();
     add_single_param_json<opr::IndexingMultiAxisVec>();
     add_single_param_json<opr::IndexingSetMultiAxisVec>();
     add_single_param_json<opr::IndexingIncrMultiAxisVec>();
     add_single_param_json<opr::MeshIndexing>();
     add_single_param_json<opr::SetMeshIndexing>();
     add_single_param_json<opr::IncrMeshIndexing>();
     add_single_param_json<opr::BatchedMeshIndexing>();
     add_single_param_json<opr::BatchedSetMeshIndexing>();
     add_single_param_json<opr::BatchedIncrMeshIndexing>();
     add_single_param_json<opr::Reduce>();
     add_single_param_json<opr::LocalShareForward>();
     add_single_param_json<opr::LocalShareBackwardData>();

@@ -639,7 +744,31 @@ void OprFootprint::init_all_footprints() {
     add_single_param_json<opr::DeformableConvForward>();
     add_single_param_json<opr::DeformableConvBackwardFilter>();
     add_single_param_json<opr::DeformableConvBackwardData>();
     add_single_param_json<opr::DeformablePSROIPoolingForward>();
     add_single_param_json<opr::BatchConvBiasForward>();
     add_single_param_json<opr::BatchNormForward>();
     add_single_param_json<opr::Reshape>();
     add_single_param_json<opr::GetVarShape>();
     add_single_param_json<opr::Argsort>();
     add_single_param_json<opr::Argmin>();
     add_single_param_json<opr::Argmax>();
     add_single_param_json<opr::ElemwiseMultiType>();
     add_single_param_json<opr::AdaptivePooling>();
     add_single_param_json<opr::ROIPooling>();
     add_single_param_json<opr::ROIAlign>();
     add_single_param_json<opr::WarpPerspective>();
     add_single_param_json<opr::Remap>();
     add_single_param_json<opr::Resize>();
     add_single_param_json<opr::IndexingOneHot>();
     add_single_param_json<opr::IndexingSetOneHot>();
     add_single_param_json<opr::WarpAffine>();
     add_single_param_json<opr::TopK>();
     add_single_param_json<opr::UniformRNG>();
     add_single_param_json<opr::GaussianRNG>();
     add_single_param_json<opr::Linspace>();
     add_single_param_json<opr::Eye>();
     add_single_param_json<opr::standalone::NMSKeep>();
     add_single_param_json<opr::CvtColor>();
 #endif
 }
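The JSON emitted by indexing_param_to_json is exactly what IndexingBase.load in network_node.py consumes: one object per axis indexer, with begin/end/step/idx recorded as 0/1 presence flags. Roughly, for a two-axis subtensor the parsed params and the rebuilt items would look like this (values illustrative):

    # what json.loads(opr.params) might yield for a Subtensor opr (illustrative values)
    params = [
        {"axis": 0, "begin": 1, "end": 1, "step": 1, "idx": 0},  # sliced on axis 0
        {"axis": 1, "begin": 0, "end": 0, "step": 0, "idx": 1},  # indexed on axis 1
    ]
    # IndexingBase.load turns each entry into [axis, begin, end, step, idx] with booleans
    items = [
        [p["axis"], bool(p["begin"]), bool(p["end"]), bool(p["step"]), bool(p["idx"])]
        for p in params
    ]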