Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
MegEngine 天元
MegEngine
提交
60c7d62a
MegEngine
项目概览
MegEngine 天元
/
MegEngine
大约 1 年 前同步成功
通知
395
Star
4704
Fork
582
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
MegEngine
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
体验新版 GitCode,发现更多精彩内容 >>
提交
60c7d62a
编写于
12月 25, 2020
作者:
M
Megvii Engine Team
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
refactor(imperative): remove multidispatch, raw_tensor, register
GitOrigin-RevId: ca5a6ed8eb6c8089b758eb84bf26d6e928ea4d41
上级
b5e46ae9
变更
20
隐藏空白更改
内联
并排
Showing
20 changed file
with
9 addition
and
1706 deletion
+9
-1706
imperative/python/megengine/core/autodiff/builtin_op_utils.py
...rative/python/megengine/core/autodiff/builtin_op_utils.py
+1
-31
imperative/python/megengine/core/autodiff/grad.py
imperative/python/megengine/core/autodiff/grad.py
+0
-142
imperative/python/megengine/core/ops/special.py
imperative/python/megengine/core/ops/special.py
+0
-3
imperative/python/megengine/core/tensor/core.py
imperative/python/megengine/core/tensor/core.py
+2
-24
imperative/python/megengine/core/tensor/function.py
imperative/python/megengine/core/tensor/function.py
+0
-154
imperative/python/megengine/core/tensor/multipledispatch/__init__.py
...python/megengine/core/tensor/multipledispatch/__init__.py
+0
-53
imperative/python/megengine/core/tensor/multipledispatch/conflict.py
...python/megengine/core/tensor/multipledispatch/conflict.py
+0
-165
imperative/python/megengine/core/tensor/multipledispatch/core.py
...ive/python/megengine/core/tensor/multipledispatch/core.py
+0
-130
imperative/python/megengine/core/tensor/multipledispatch/dispatcher.py
...thon/megengine/core/tensor/multipledispatch/dispatcher.py
+0
-445
imperative/python/megengine/core/tensor/multipledispatch/utils.py
...ve/python/megengine/core/tensor/multipledispatch/utils.py
+0
-210
imperative/python/megengine/core/tensor/multipledispatch/variadic.py
...python/megengine/core/tensor/multipledispatch/variadic.py
+0
-140
imperative/python/megengine/core/tensor/raw_tensor/__init__.py
...ative/python/megengine/core/tensor/raw_tensor/__init__.py
+0
-136
imperative/python/megengine/distributed/functional.py
imperative/python/megengine/distributed/functional.py
+2
-9
imperative/python/megengine/distributed/helper.py
imperative/python/megengine/distributed/helper.py
+1
-1
imperative/python/megengine/functional/nn.py
imperative/python/megengine/functional/nn.py
+1
-1
imperative/python/megengine/quantization/utils.py
imperative/python/megengine/quantization/utils.py
+1
-1
imperative/python/test/conftest.py
imperative/python/test/conftest.py
+1
-1
imperative/python/test/integration/test_save_load.py
imperative/python/test/integration/test_save_load.py
+0
-1
imperative/python/test/unit/core/test_dtype_quant.py
imperative/python/test/unit/core/test_dtype_quant.py
+0
-1
imperative/python/test/unit/test_dispatch.py
imperative/python/test/unit/test_dispatch.py
+0
-58
未找到文件。
imperative/python/megengine/core/autodiff/builtin_op_utils.py
浏览文件 @
60c7d62a
...
...
@@ -12,6 +12,7 @@ import itertools
import
numpy
as
np
from
.._imperative_rt
import
TensorAttr
,
imperative
from
.._imperative_rt.core2
import
apply
from
..ops.builtin
import
(
Broadcast
,
Elemwise
,
...
...
@@ -25,37 +26,6 @@ from ..ops.builtin import (
Subtensor
,
)
from
..ops.special
import
Const
from
..tensor.core
import
apply
from
..tensor.function
import
Function
@
functools
.
singledispatch
def
builtin_op_get_backward_fn
(
op
:
OpDef
,
inputs
,
outputs
,
input_requires_grad
):
assert
0
@
builtin_op_get_backward_fn
.
register
(
OpDef
)
def
_
(
op
:
OpDef
,
inputs
,
outputs
,
input_requires_grad
):
if
isinstance
(
op
,
Reshape
):
grad_fn
=
reshape_grad_fn
elif
isinstance
(
op
,
Subtensor
):
grad_fn
=
subtensor_grad_fn
elif
isinstance
(
op
,
IndexingMultiAxisVec
):
grad_fn
=
indexingMultiAxisVec_grad_fn
elif
isinstance
(
op
,
Broadcast
)
or
(
isinstance
(
op
,
Elemwise
)
and
op
.
mode
==
Elemwise
.
Mode
.
ADD
):
grad_fn
=
elemwise_add_grad_fn
elif
isinstance
(
op
,
Reduce
)
and
op
.
mode
==
Reduce
.
Mode
.
SUM
:
grad_fn
=
reduce_sum_grad_fn
else
:
grad_fn
=
default_grad_fn
return
grad_fn
(
op
,
inputs
,
outputs
,
input_requires_grad
)
@
builtin_op_get_backward_fn
.
register
(
Function
)
def
_
(
op
:
Function
,
inputs
,
outputs
,
input_requires_grad
):
return
op
.
get_backward_fn
(),
[
True
,]
*
len
(
outputs
)
def
default_grad_fn
(
op
,
inputs
,
outputs
,
input_requires_grad
):
...
...
imperative/python/megengine/core/autodiff/grad.py
浏览文件 @
60c7d62a
...
...
@@ -19,8 +19,6 @@ import megengine as mge
from
.._imperative_rt
import
core2
,
ops
from
..ops.builtin
import
Elemwise
,
OpDef
,
RemoteSend
from
..ops.special
import
Const
from
..tensor.core
import
TensorBase
,
TensorWrapperBase
,
apply
from
..tensor.function
import
Function
from
.
import
builtin_op_utils
""" Some notes:
...
...
@@ -48,146 +46,6 @@ def get_grad_managers():
return
[
_grad_manager_dict
[
key
]
for
key
in
_grad_manager_dict
]
def
add
(
a
,
b
):
(
c
,)
=
apply
(
Elemwise
(
Elemwise
.
Mode
.
ADD
),
a
,
b
)
return
c
def
get_tensor
(
x
):
# use recursion to avoid infinite loop
if
isinstance
(
x
,
Tensor
):
return
x
try
:
x
=
x
.
__wrapped__
except
AttributeError
:
raise
TypeError
(
type
(
x
))
return
get_tensor
(
x
)
class
clearable
:
__cleared
=
False
def
__bool__
(
self
):
return
not
self
.
__cleared
def
clear
(
self
):
self
.
__dict__
.
clear
()
self
.
__cleared
=
True
class
OpNode
(
clearable
):
""" OpNode saves all the information to form the computational graph.
"""
def
__init__
(
self
):
self
.
id
=
None
self
.
inputs
=
None
# Could be VariableNode
self
.
outputs
=
None
# Could be VariableNode
self
.
backward
=
None
self
.
has_grad_fn
=
None
self
.
backward_allow_noinput
=
False
class
VariableNode
(
clearable
):
""" VariableNode saves OpNode and callback.
FIXME!!! Explain manager and owner
"""
def
__init__
(
self
,
manager
,
owner
,
opnode
=
None
,
callback
=
None
):
# manager is Grad type
self
.
manager
=
weakref
.
ref
(
manager
)
# owner is Tensor type
self
.
owner
=
weakref
.
ref
(
owner
)
self
.
opnode
=
opnode
self
.
callback
=
callback
class
Tracer
(
clearable
,
TensorBase
):
def
__init__
(
self
,
node
=
None
):
""" type(node) is VariableNode
"""
self
.
node
=
node
@
functools
.
singledispatch
def
check_backward_allow_noinput
(
op
:
OpDef
):
return
False
@
functools
.
singledispatch
def
get_op_has_grad_fn
(
op
:
OpDef
):
assert
0
@
get_op_has_grad_fn
.
register
(
OpDef
)
def
_
(
op
:
OpDef
):
return
default_has_grad_fn
@
get_op_has_grad_fn
.
register
(
Function
)
def
_
(
op
:
Function
):
return
default_has_grad_fn
def
default_has_grad_fn
(
opnode
,
reached
):
for
v
in
opnode
.
outputs
:
if
v
()
in
reached
:
return
True
return
False
@
apply
.
register
()
def
tracer_apply
(
op
:
(
OpDef
,
Function
),
*
args
:
typing
.
Optional
[
Tracer
]):
args
=
tuple
(
i
if
isinstance
(
i
,
Tracer
)
else
None
for
i
in
args
)
input_requires_grad
=
list
(
map
(
bool
,
args
))
if
not
any
(
input_requires_grad
):
return
ctx
=
get_context
()
manager
=
None
assert
len
(
ctx
.
inputs
)
==
len
(
args
)
for
i
,
j
in
zip
(
ctx
.
inputs
,
args
):
if
j
:
j
=
j
.
node
assert
i
is
j
.
owner
()
if
manager
is
None
:
manager
=
j
.
manager
()
assert
manager
else
:
assert
manager
is
j
.
manager
()
if
not
manager
.
_enabled
:
return
# register backward method
# tuple of backward functions corresponding to dy / dx_i
# None means y is not a function of x_i
backward
,
output_need_grad
=
builtin_op_utils
.
builtin_op_get_backward_fn
(
op
,
ctx
.
inputs
,
ctx
.
outputs
,
input_requires_grad
)
assert
len
(
ctx
.
outputs
)
==
len
(
output_need_grad
)
if
not
any
(
output_need_grad
):
return
opnode
,
outputs
=
manager
.
_new_opnode
([
i
and
i
.
node
for
i
in
args
],
ctx
.
outputs
)
if
isinstance
(
op
,
RemoteSend
):
manager
.
remote_send_cache
.
append
(
opnode
)
opnode
.
backward
=
backward
outputs
=
[
x
if
y
else
None
for
(
x
,
y
)
in
zip
(
outputs
,
output_need_grad
)]
opnode
.
backward_allow_noinput
=
check_backward_allow_noinput
(
op
)
opnode
.
has_grad_fn
=
get_op_has_grad_fn
(
op
)
return
tuple
(
outputs
)
@
apply
.
register
()
def
_
(
op
:
Const
,
*
_
:
typing
.
Optional
[
Tracer
]):
return
None
class
Grad
:
def
__init__
(
self
):
self
.
_impl
=
core2
.
GradKey
()
...
...
imperative/python/megengine/core/ops/special.py
浏览文件 @
60c7d62a
...
...
@@ -8,9 +8,6 @@
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import
numpy
as
np
# from .._imperative_rt.core2 import Tensor
from
..tensor.core
import
OpBase
,
TensorBase
,
apply
class
Const
:
def
__init__
(
self
,
value
=
None
,
*
,
dtype
=
None
,
device
=
None
):
...
...
imperative/python/megengine/core/tensor/core.py
浏览文件 @
60c7d62a
...
...
@@ -13,12 +13,9 @@ import sys
import
typing
from
abc
import
ABC
from
.multipledispatch
import
Dispatcher
class
OpBase
(
ABC
):
def
__call__
(
self
,
*
args
):
return
apply
(
self
,
*
args
)
class
OpBase
:
pass
class
TensorBase
:
...
...
@@ -27,22 +24,3 @@ class TensorBase:
class
TensorWrapperBase
:
pass
apply
=
Dispatcher
(
"apply"
)
OpBase
.
apply
=
apply
@
apply
.
register
()
def
_
(
op
:
OpBase
,
*
args
:
TensorBase
):
raise
NotImplementedError
@
apply
.
register
()
def
_
(
op
:
OpBase
,
*
args
:
TensorWrapperBase
):
assert
args
Wrapper
=
type
(
args
[
0
])
outputs
=
apply
(
op
,
*
(
i
.
__wrapped__
for
i
in
args
))
assert
isinstance
(
outputs
,
tuple
)
return
tuple
(
map
(
Wrapper
,
outputs
))
imperative/python/megengine/core/tensor/function.py
已删除
100644 → 0
浏览文件 @
b5e46ae9
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
from
..ops.builtin
import
OpDef
from
.core
import
TensorBase
,
TensorWrapperBase
,
apply
class
Function
:
"""
Defines a block of operations with customizable differentiation.
The computation should be defined in ``forward`` method, with gradient
computation defined in ``backward`` method.
Each instance of ``Function`` should be used only once during forwardding.
Examples:
.. testcode::
class Sigmoid(Function):
def forward(self, x):
y = 1 / (1 + F.exp(-x))
self.y = y
return y
def backward(self, output_grads):
y = self.y
return output_grads * y * (1-y)
"""
def
__init__
(
self
,
*
args
,
**
kwargs
):
pass
def
__call__
(
self
,
*
args
):
ret
=
apply
(
self
,
*
args
)
if
type
(
ret
)
==
tuple
and
len
(
ret
)
==
1
:
return
ret
[
0
]
return
ret
def
forward
(
self
,
*
args
,
**
kwargs
):
"""
Applies operations to ``inputs`` and returns results. It must be overriden by all subclasses.
:param input: input tensors.
:return: a tuple of Tensor or a single Tensor.
.. note::
This method should return a tuple of Tensor or a single Tensor representing the output
of the function.
"""
raise
NotImplementedError
def
backward
(
self
,
*
output_grads
):
"""
Compute the gradient of the forward function. It must be overriden by all subclasses.
:param output_grads: gradients of outputs that are returned by :meth:`~.function.Function.forward`.
.. note::
In case when some tensors of outputs are not related to loss function, the corresponding
values in ``output_grads`` would be ``None``.
.. note::
This method should return a tuple which containing the gradients of all inputs, in the same order
as the ``inputs`` argument of :meth:`~.function.Function.forward` . A ``Tensor`` could be returned
instead if there is only one input. If users want to stop the propagation of some gradients,
the corresponding returned values should be set ``None`` .
"""
raise
NotImplementedError
def
get_backward_fn
(
self
):
if
self
.
backward
is
None
:
return
None
def
_backward
(
*
output_grads
):
if
type
(
output_grads
)
is
tuple
:
_output_grads
=
[
TensorWrapper
(
i
)
if
i
is
not
None
else
i
for
i
in
output_grads
]
else
:
_output_grads
=
(
TensorWrapper
(
output_grads
)
if
output_grads
is
not
None
else
output_grads
,
)
ret
=
self
.
backward
(
*
_output_grads
)
if
type
(
ret
)
is
not
tuple
:
ret
=
(
ret
,)
ret
=
tuple
(
i
.
__wrapped__
if
isinstance
(
i
,
TensorWrapper
)
else
i
for
i
in
ret
)
return
ret
return
_backward
Function
.
apply
=
Function
.
__call__
@
apply
.
register
()
def
_
(
op
:
Function
,
*
args
:
TensorWrapperBase
):
assert
args
Wrapper
=
type
(
args
[
0
])
# compute the value for self define function
extra_data_dic
=
{}
for
arg
in
args
:
extra_data_dic
[
arg
.
__wrapped__
]
=
arg
.
__wrapped__
.
_extra_data
arg
.
__wrapped__
.
_extra_data
=
{}
rets
=
op
.
forward
(
*
args
)
for
arg
in
args
:
arg
.
__wrapped__
.
_extra_data
=
extra_data_dic
[
arg
.
__wrapped__
]
# update the gradient information for self define function
inputs
=
tuple
(
map
(
lambda
i
:
i
.
__wrapped__
,
args
))
outputs
=
(
tuple
(
map
(
lambda
i
:
i
.
__wrapped__
,
rets
))
if
type
(
rets
)
is
tuple
else
(
rets
.
__wrapped__
,)
)
for
output
in
outputs
:
if
output
not
in
inputs
:
output
.
_extra_data
=
{}
with
push_context
()
as
ctx
:
ctx
.
inputs
=
inputs
ctx
.
outputs
=
outputs
for
k
in
set
().
union
(
*
(
i
.
_extra_data
for
i
in
inputs
if
isinstance
(
i
,
Tensor
))):
ctx
.
key
=
k
data
=
tuple
(
i
.
_extra_data
.
get
(
k
)
if
isinstance
(
i
,
Tensor
)
else
i
for
i
in
inputs
)
# data are instances of Tracer
# dispatched to apply.add@grad.py
rets
=
apply
(
op
,
*
data
)
if
rets
is
not
None
:
assert
len
(
outputs
)
==
len
(
rets
)
for
t
,
i
in
zip
(
outputs
,
rets
):
t
.
_extra_data
[
k
]
=
i
return
tuple
(
map
(
Wrapper
,
outputs
))
imperative/python/megengine/core/tensor/multipledispatch/__init__.py
已删除
100644 → 0
浏览文件 @
b5e46ae9
# Copyright (c) 2014 Matthew Rocklin
#
# All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# a. Redistributions of source code must retain the above copyright notice,
# this list of conditions and the following disclaimer.
# b. Redistributions in binary form must reproduce the above copyright
# notice, this list of conditions and the following disclaimer in the
# documentation and/or other materials provided with the distribution.
# c. Neither the name of multipledispatch nor the names of its contributors
# may be used to endorse or promote products derived from this software
# without specific prior written permission.
#
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
# ARE DISCLAIMED. IN NO EVENT SHALL THE REGENTS OR CONTRIBUTORS BE LIABLE FOR
# ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
# LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY
# OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH
# DAMAGE.
#
# --------------------------------------------------------------------------------------
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#
# This file has been modified by Megvii ("Megvii Modifications").
# All Megvii Modifications are Copyright (C) 2014-2020 Megvii Inc. All rights reserved.
# --------------------------------------------------------------------------------------
# This directory is a fork of multipledispatch.
#
# Repo: https://github.com/mrocklin/multipledispatch
# Commit: 9e3c87d0cee57972fd5cc33fe5cacde77c781834
# Authors: Matthew Rocklin et al.
#
# The original LICENSE file is included in the ACKNOWLEDGEMENT file under
# MegEngine root directory.
from
.core
import
dispatch
from
.dispatcher
import
Dispatcher
imperative/python/megengine/core/tensor/multipledispatch/conflict.py
已删除
100644 → 0
浏览文件 @
b5e46ae9
# Copyright (c) 2014 Matthew Rocklin
#
# All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# a. Redistributions of source code must retain the above copyright notice,
# this list of conditions and the following disclaimer.
# b. Redistributions in binary form must reproduce the above copyright
# notice, this list of conditions and the following disclaimer in the
# documentation and/or other materials provided with the distribution.
# c. Neither the name of multipledispatch nor the names of its contributors
# may be used to endorse or promote products derived from this software
# without specific prior written permission.
#
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
# ARE DISCLAIMED. IN NO EVENT SHALL THE REGENTS OR CONTRIBUTORS BE LIABLE FOR
# ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
# LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY
# OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH
# DAMAGE.
#
# --------------------------------------------------------------------------------------
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#
# This file has been modified by Megvii ("Megvii Modifications").
# All Megvii Modifications are Copyright (C) 2014-2020 Megvii Inc. All rights reserved.
# --------------------------------------------------------------------------------------
from
collections
import
OrderedDict
from
.utils
import
_toposort
,
groupby
from
.variadic
import
isvariadic
class
AmbiguityWarning
(
Warning
):
pass
def
supercedes
(
a
,
b
):
""" A is consistent and strictly more specific than B """
if
len
(
a
)
<
len
(
b
):
# only case is if a is empty and b is variadic
return
not
a
and
len
(
b
)
==
1
and
isvariadic
(
b
[
-
1
])
elif
len
(
a
)
==
len
(
b
):
return
all
(
map
(
issubclass
,
a
,
b
))
else
:
# len(a) > len(b)
p1
=
0
p2
=
0
while
p1
<
len
(
a
)
and
p2
<
len
(
b
):
cur_a
=
a
[
p1
]
cur_b
=
b
[
p2
]
if
not
(
isvariadic
(
cur_a
)
or
isvariadic
(
cur_b
)):
if
not
issubclass
(
cur_a
,
cur_b
):
return
False
p1
+=
1
p2
+=
1
elif
isvariadic
(
cur_a
):
assert
p1
==
len
(
a
)
-
1
return
p2
==
len
(
b
)
-
1
and
issubclass
(
cur_a
,
cur_b
)
elif
isvariadic
(
cur_b
):
assert
p2
==
len
(
b
)
-
1
if
not
issubclass
(
cur_a
,
cur_b
):
return
False
p1
+=
1
return
p2
==
len
(
b
)
-
1
and
p1
==
len
(
a
)
def
consistent
(
a
,
b
):
""" It is possible for an argument list to satisfy both A and B """
# Need to check for empty args
if
not
a
:
return
not
b
or
isvariadic
(
b
[
0
])
if
not
b
:
return
not
a
or
isvariadic
(
a
[
0
])
# Non-empty args check for mutual subclasses
if
len
(
a
)
==
len
(
b
):
return
all
(
issubclass
(
aa
,
bb
)
or
issubclass
(
bb
,
aa
)
for
aa
,
bb
in
zip
(
a
,
b
))
else
:
p1
=
0
p2
=
0
while
p1
<
len
(
a
)
and
p2
<
len
(
b
):
cur_a
=
a
[
p1
]
cur_b
=
b
[
p2
]
if
not
issubclass
(
cur_b
,
cur_a
)
and
not
issubclass
(
cur_a
,
cur_b
):
return
False
if
not
(
isvariadic
(
cur_a
)
or
isvariadic
(
cur_b
)):
p1
+=
1
p2
+=
1
elif
isvariadic
(
cur_a
):
p2
+=
1
elif
isvariadic
(
cur_b
):
p1
+=
1
# We only need to check for variadic ends
# Variadic types are guaranteed to be the last element
return
isvariadic
(
cur_a
)
and
p2
==
len
(
b
)
or
isvariadic
(
cur_b
)
and
p1
==
len
(
a
)
def
ambiguous
(
a
,
b
):
""" A is consistent with B but neither is strictly more specific """
return
consistent
(
a
,
b
)
and
not
(
supercedes
(
a
,
b
)
or
supercedes
(
b
,
a
))
def
ambiguities
(
signatures
):
""" All signature pairs such that A is ambiguous with B """
signatures
=
list
(
map
(
tuple
,
signatures
))
return
set
(
(
a
,
b
)
for
a
in
signatures
for
b
in
signatures
if
hash
(
a
)
<
hash
(
b
)
and
ambiguous
(
a
,
b
)
and
not
any
(
supercedes
(
c
,
a
)
and
supercedes
(
c
,
b
)
for
c
in
signatures
)
)
def
super_signature
(
signatures
):
""" A signature that would break ambiguities """
n
=
len
(
signatures
[
0
])
assert
all
(
len
(
s
)
==
n
for
s
in
signatures
)
return
[
max
([
type
.
mro
(
sig
[
i
])
for
sig
in
signatures
],
key
=
len
)[
0
]
for
i
in
range
(
n
)]
def
edge
(
a
,
b
,
tie_breaker
=
hash
):
""" A should be checked before B
Tie broken by tie_breaker, defaults to ``hash``
"""
# A either supercedes B and B does not supercede A or if B does then call
# tie_breaker
return
supercedes
(
a
,
b
)
and
(
not
supercedes
(
b
,
a
)
or
tie_breaker
(
a
)
>
tie_breaker
(
b
)
)
def
ordering
(
signatures
):
""" A sane ordering of signatures to check, first to last
Topoological sort of edges as given by ``edge`` and ``supercedes``
"""
signatures
=
list
(
map
(
tuple
,
signatures
))
edges
=
[(
a
,
b
)
for
a
in
signatures
for
b
in
signatures
if
edge
(
a
,
b
)]
edges
=
groupby
(
lambda
x
:
x
[
0
],
edges
)
for
s
in
signatures
:
if
s
not
in
edges
:
edges
[
s
]
=
[]
edges
=
OrderedDict
((
k
,
[
b
for
a
,
b
in
v
])
for
k
,
v
in
edges
.
items
())
return
_toposort
(
edges
)
imperative/python/megengine/core/tensor/multipledispatch/core.py
已删除
100644 → 0
浏览文件 @
b5e46ae9
# Copyright (c) 2014 Matthew Rocklin
#
# All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# a. Redistributions of source code must retain the above copyright notice,
# this list of conditions and the following disclaimer.
# b. Redistributions in binary form must reproduce the above copyright
# notice, this list of conditions and the following disclaimer in the
# documentation and/or other materials provided with the distribution.
# c. Neither the name of multipledispatch nor the names of its contributors
# may be used to endorse or promote products derived from this software
# without specific prior written permission.
#
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
# ARE DISCLAIMED. IN NO EVENT SHALL THE REGENTS OR CONTRIBUTORS BE LIABLE FOR
# ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
# LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY
# OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH
# DAMAGE.
#
# --------------------------------------------------------------------------------------
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#
# This file has been modified by Megvii ("Megvii Modifications").
# All Megvii Modifications are Copyright (C) 2014-2020 Megvii Inc. All rights reserved.
# --------------------------------------------------------------------------------------
import
inspect
import
sys
from
.dispatcher
import
Dispatcher
,
MethodDispatcher
,
ambiguity_warn
global_namespace
=
dict
()
def
dispatch
(
*
types
,
**
kwargs
):
""" Dispatch function on the types of the inputs
Supports dispatch on all non-keyword arguments.
Collects implementations based on the function name. Ignores namespaces.
If ambiguous type signatures occur a warning is raised when the function is
defined suggesting the additional method to break the ambiguity.
Examples
--------
>>> @dispatch(int)
... def f(x):
... return x + 1
>>> @dispatch(float)
... def f(x):
... return x - 1
>>> f(3)
4
>>> f(3.0)
2.0
Specify an isolated namespace with the namespace keyword argument
>>> my_namespace = dict()
>>> @dispatch(int, namespace=my_namespace)
... def foo(x):
... return x + 1
Dispatch on instance methods within classes
>>> class MyClass(object):
... @dispatch(list)
... def __init__(self, data):
... self.data = data
... @dispatch(int)
... def __init__(self, datum):
... self.data = [datum]
"""
namespace
=
kwargs
.
get
(
"namespace"
,
global_namespace
)
types
=
tuple
(
types
)
def
_df
(
func
):
name
=
func
.
__name__
if
ismethod
(
func
):
dispatcher
=
inspect
.
currentframe
().
f_back
.
f_locals
.
get
(
name
,
MethodDispatcher
(
name
),
)
else
:
if
name
not
in
namespace
:
namespace
[
name
]
=
Dispatcher
(
name
)
dispatcher
=
namespace
[
name
]
dispatcher
.
add
(
types
,
func
)
return
dispatcher
return
_df
def
ismethod
(
func
):
""" Is func a method?
Note that this has to work as the method is defined but before the class is
defined. At this stage methods look like functions.
"""
if
hasattr
(
inspect
,
"signature"
):
signature
=
inspect
.
signature
(
func
)
return
signature
.
parameters
.
get
(
"self"
,
None
)
is
not
None
else
:
if
sys
.
version_info
.
major
<
3
:
spec
=
inspect
.
getargspec
(
func
)
else
:
spec
=
inspect
.
getfullargspec
(
func
)
return
spec
and
spec
.
args
and
spec
.
args
[
0
]
==
"self"
imperative/python/megengine/core/tensor/multipledispatch/dispatcher.py
已删除
100644 → 0
浏览文件 @
b5e46ae9
# Copyright (c) 2014 Matthew Rocklin
#
# All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# a. Redistributions of source code must retain the above copyright notice,
# this list of conditions and the following disclaimer.
# b. Redistributions in binary form must reproduce the above copyright
# notice, this list of conditions and the following disclaimer in the
# documentation and/or other materials provided with the distribution.
# c. Neither the name of multipledispatch nor the names of its contributors
# may be used to endorse or promote products derived from this software
# without specific prior written permission.
#
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
# ARE DISCLAIMED. IN NO EVENT SHALL THE REGENTS OR CONTRIBUTORS BE LIABLE FOR
# ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
# LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY
# OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH
# DAMAGE.
#
# --------------------------------------------------------------------------------------
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#
# This file has been modified by Megvii ("Megvii Modifications").
# All Megvii Modifications are Copyright (C) 2014-2020 Megvii Inc. All rights reserved.
# --------------------------------------------------------------------------------------
import
copy
import
inspect
import
itertools
as
itl
from
warnings
import
warn
from
..._imperative_rt.dispatcher
import
Dispatcher
as
CDispatcher
from
.conflict
import
AmbiguityWarning
,
ambiguities
,
ordering
,
super_signature
from
.utils
import
expand_tuples
,
parse_union
from
.variadic
import
Variadic
,
isvariadic
def
ambiguity_warn
(
dispatcher
,
ambiguities
):
""" Raise warning when ambiguity is detected
Parameters
----------
dispatcher : Dispatcher
The dispatcher on which the ambiguity was detected
ambiguities : set
Set of type signature pairs that are ambiguous within this dispatcher
See Also:
Dispatcher.add
warning_text
"""
warn
(
warning_text
(
dispatcher
.
name
,
ambiguities
),
AmbiguityWarning
)
def
variadic_signature_matches_iter
(
types
,
full_signature
):
"""
Check if a set of input types matches a variadic signature.
Notes
-----
The algorithm is as follows:
Initialize the current signature to the first in the sequence
For each type in `types`:
If the current signature is variadic
If the type matches the signature
yield True
Else
Try to get the next signature
If no signatures are left we can't possibly have a match
so yield False
Else
yield True if the type matches the current signature
Get the next signature
"""
sigiter
=
iter
(
full_signature
)
sig
=
next
(
sigiter
)
for
typ
in
types
:
matches
=
issubclass
(
typ
,
sig
)
yield
matches
if
not
isvariadic
(
sig
):
# we're not matching a variadic argument, so move to the next
# element in the signature
sig
=
next
(
sigiter
)
else
:
try
:
sig
=
next
(
sigiter
)
except
StopIteration
:
assert
isvariadic
(
sig
)
yield
True
else
:
# We have signature items left over, so all of our arguments
# haven't matched
yield
False
def
variadic_signature_matches
(
types
,
full_signature
):
# No arguments always matches a variadic signature
assert
full_signature
return
all
(
variadic_signature_matches_iter
(
types
,
full_signature
))
def
get_func_signature
(
function
):
sig
=
inspect
.
signature
(
function
)
types
=
[]
for
p
in
sig
.
parameters
.
values
():
ann
=
p
.
annotation
ann
=
parse_union
(
ann
)
or
ann
if
p
.
kind
in
(
inspect
.
Parameter
.
POSITIONAL_ONLY
,
inspect
.
Parameter
.
POSITIONAL_OR_KEYWORD
,
):
types
.
append
(
ann
)
if
p
.
kind
==
inspect
.
Parameter
.
VAR_POSITIONAL
:
types
.
append
([
ann
])
return
tuple
(
types
)
class
Frame
:
__slots__
=
"args"
,
"types"
,
"mro"
,
"mro_offset"
class
Dispatcher
(
CDispatcher
):
""" Dispatch methods based on type signature
Use ``dispatch`` to add implementations
Examples
--------
>>> from multipledispatch import dispatch
>>> @dispatch(int)
... def f(x):
... return x + 1
>>> @dispatch(float)
... def f(x):
... return x - 1
>>> f(3)
4
>>> f(3.0)
2.0
"""
__slots__
=
"__name__"
,
"name"
,
"funcs"
,
"_ordering"
,
"doc"
def
__init__
(
self
,
name
,
doc
=
None
):
self
.
name
=
self
.
__name__
=
name
self
.
funcs
=
{}
self
.
doc
=
doc
def
register
(
self
,
*
types
,
**
kwargs
):
""" register dispatcher with new implementation
>>> f = Dispatcher('f')
>>> @f.register(int)
... def inc(x):
... return x + 1
>>> @f.register(float)
... def dec(x):
... return x - 1
>>> @f.register(list)
... @f.register(tuple)
... def reverse(x):
... return x[::-1]
>>> f(1)
2
>>> f(1.0)
0.0
>>> f([1, 2, 3])
[3, 2, 1]
"""
def
_df
(
func
):
self
.
add
(
types
,
func
,
**
kwargs
)
return
func
return
_df
def
add
(
self
,
signature
,
func
):
""" Add new types/method pair to dispatcher
>>> D = Dispatcher('add')
>>> D.add((int, int), lambda x, y: x + y)
>>> D.add((float, float), lambda x, y: x + y)
>>> D(1, 2)
3
>>> D(1, 2.0)
Traceback (most recent call last):
...
NotImplementedError: Could not find signature for add: <int, float>
When ``add`` detects a warning it calls the ``on_ambiguity`` callback
with a dispatcher/itself, and a set of ambiguous type signature pairs
as inputs. See ``ambiguity_warn`` for an example.
"""
# Handle annotations
if
not
signature
:
signature
=
get_func_signature
(
func
)
# Handle union types
if
any
(
isinstance
(
typ
,
tuple
)
for
typ
in
signature
):
for
typs
in
expand_tuples
(
signature
):
self
.
add
(
typs
,
func
)
return
new_signature
=
[]
for
index
,
typ
in
enumerate
(
signature
,
start
=
1
):
if
not
isinstance
(
typ
,
(
type
,
list
)):
str_sig
=
", "
.
join
(
c
.
__name__
if
isinstance
(
c
,
type
)
else
str
(
c
)
for
c
in
signature
)
raise
TypeError
(
"Tried to dispatch on non-type: %s
\n
"
"In signature: <%s>
\n
"
"In function: %s"
%
(
typ
,
str_sig
,
self
.
name
)
)
# handle variadic signatures
if
isinstance
(
typ
,
list
):
if
index
!=
len
(
signature
):
raise
TypeError
(
"Variadic signature must be the last element"
)
if
len
(
typ
)
!=
1
:
raise
TypeError
(
"Variadic signature must contain exactly one element. "
"To use a variadic union type place the desired types "
"inside of a tuple, e.g., [(int, str)]"
)
new_signature
.
append
(
Variadic
[
typ
[
0
]])
else
:
new_signature
.
append
(
typ
)
l
=
self
.
funcs
.
setdefault
(
tuple
(
new_signature
),
[])
for
i
in
l
:
if
i
is
func
:
raise
ValueError
(
"already registered"
)
l
.
append
(
func
)
self
.
enable
(
func
)
self
.
clear_cache
()
try
:
del
self
.
_ordering
except
AttributeError
:
pass
@
property
def
ordering
(
self
):
try
:
return
self
.
_ordering
except
AttributeError
:
return
self
.
reorder
()
def
reorder
(
self
,
on_ambiguity
=
ambiguity_warn
):
self
.
_ordering
=
od
=
ordering
(
self
.
funcs
)
amb
=
ambiguities
(
self
.
funcs
)
if
amb
:
on_ambiguity
(
self
,
amb
)
return
od
def
__str__
(
self
):
return
"<dispatched %s>"
%
self
.
name
__repr__
=
__str__
def
dispatch
(
self
,
*
types
):
"""
Deterimine appropriate implementation for this type signature
This method is internal. Users should call this object as a function.
Implementation resolution occurs within the ``__call__`` method.
>>> from multipledispatch import dispatch
>>> @dispatch(int)
... def inc(x):
... return x + 1
>>> implementation = inc.dispatch(int)
>>> implementation(3)
4
>>> print(inc.dispatch(float))
None
See Also:
``multipledispatch.conflict`` - module to determine resolution order
"""
if
types
in
self
.
funcs
:
return
self
.
funcs
[
types
][
-
1
]
for
f
in
self
.
dispatch_iter
(
*
types
):
return
f
def
dispatch_iter
(
self
,
*
types
):
n
=
len
(
types
)
for
signature
in
self
.
ordering
:
if
(
len
(
signature
)
==
n
and
all
(
map
(
issubclass
,
types
,
signature
))
or
len
(
signature
)
and
isvariadic
(
signature
[
-
1
])
and
variadic_signature_matches
(
types
,
signature
)
):
yield
from
self
.
funcs
[
signature
][::
-
1
]
def
__getstate__
(
self
):
return
{
"name"
:
self
.
name
,
"funcs"
:
self
.
funcs
}
def
__setstate__
(
self
,
d
):
self
.
name
=
d
[
"name"
]
self
.
funcs
=
d
[
"funcs"
]
self
.
_ordering
=
ordering
(
self
.
funcs
)
self
.
_cache
=
dict
()
@
property
def
__doc__
(
self
):
docs
=
[
"Multiply dispatched method: %s"
%
self
.
name
]
if
self
.
doc
:
docs
.
append
(
self
.
doc
)
other
=
[]
for
sig
in
self
.
ordering
[::
-
1
]:
funcs
=
self
.
funcs
[
sig
]
s
=
"Inputs: <%s>
\n
"
%
str_signature
(
sig
)
sep
=
"-"
*
len
(
s
)
+
"
\n
"
for
i
,
func
in
enumerate
(
funcs
):
s
+=
sep
if
len
(
funcs
)
>
1
:
s
+=
"[Handler %d]
\n\n
"
%
(
i
+
1
)
if
i
:
s
+=
"
\n\n
"
if
func
.
__doc__
:
s
+=
func
.
__doc__
.
strip
()
else
:
s
+=
repr
(
func
)
+
"
\n
"
docs
.
append
(
s
)
return
"
\n\n
"
.
join
(
docs
)
def
_help
(
self
,
*
args
):
return
self
.
dispatch
(
*
map
(
type
,
args
)).
__doc__
def
help
(
self
,
*
args
,
**
kwargs
):
""" Print docstring for the function corresponding to inputs """
print
(
self
.
_help
(
*
args
))
def
_source
(
self
,
*
args
):
func
=
self
.
dispatch
(
*
map
(
type
,
args
))
if
not
func
:
raise
TypeError
(
"No function found"
)
return
source
(
func
)
def
source
(
self
,
*
args
,
**
kwargs
):
""" Print source code for the function corresponding to inputs """
print
(
self
.
_source
(
*
args
))
def
source
(
func
):
s
=
"File: %s
\n\n
"
%
inspect
.
getsourcefile
(
func
)
s
=
s
+
inspect
.
getsource
(
func
)
return
s
class
MethodDispatcher
(
Dispatcher
):
""" Dispatch methods based on type signature
See Also:
Dispatcher
"""
__slots__
=
(
"obj"
,
"cls"
)
@
classmethod
def
get_func_params
(
cls
,
func
):
if
hasattr
(
inspect
,
"signature"
):
sig
=
inspect
.
signature
(
func
)
return
itl
.
islice
(
sig
.
parameters
.
values
(),
1
,
None
)
def
__get__
(
self
,
instance
,
owner
):
self
.
obj
=
instance
self
.
cls
=
owner
return
self
def
__call__
(
self
,
*
args
,
**
kwargs
):
types
=
tuple
([
type
(
arg
)
for
arg
in
args
])
func
=
self
.
dispatch
(
*
types
)
if
not
func
:
raise
NotImplementedError
(
"Could not find signature for %s: <%s>"
%
(
self
.
name
,
str_signature
(
types
))
)
return
func
(
self
.
obj
,
*
args
,
**
kwargs
)
def
str_signature
(
sig
):
""" String representation of type signature
>>> str_signature((int, float))
'int, float'
"""
return
", "
.
join
(
cls
.
__name__
for
cls
in
sig
)
def
warning_text
(
name
,
amb
):
""" The text for ambiguity warnings """
text
=
"
\n
Ambiguities exist in dispatched function %s
\n\n
"
%
(
name
)
text
+=
"The following signatures may result in ambiguous behavior:
\n
"
for
pair
in
amb
:
text
+=
"
\t
"
+
", "
.
join
(
"["
+
str_signature
(
s
)
+
"]"
for
s
in
pair
)
+
"
\n
"
text
+=
"
\n\n
Consider making the following additions:
\n\n
"
text
+=
"
\n\n
"
.
join
(
[
"@dispatch("
+
str_signature
(
super_signature
(
s
))
+
")
\n
def %s(...)"
%
name
for
s
in
amb
]
)
return
text
imperative/python/megengine/core/tensor/multipledispatch/utils.py
已删除
100644 → 0
浏览文件 @
b5e46ae9
# Copyright (c) 2014 Matthew Rocklin
#
# All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# a. Redistributions of source code must retain the above copyright notice,
# this list of conditions and the following disclaimer.
# b. Redistributions in binary form must reproduce the above copyright
# notice, this list of conditions and the following disclaimer in the
# documentation and/or other materials provided with the distribution.
# c. Neither the name of multipledispatch nor the names of its contributors
# may be used to endorse or promote products derived from this software
# without specific prior written permission.
#
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
# ARE DISCLAIMED. IN NO EVENT SHALL THE REGENTS OR CONTRIBUTORS BE LIABLE FOR
# ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
# LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY
# OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH
# DAMAGE.
#
# --------------------------------------------------------------------------------------
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#
# This file has been modified by Megvii ("Megvii Modifications").
# All Megvii Modifications are Copyright (C) 2014-2020 Megvii Inc. All rights reserved.
# --------------------------------------------------------------------------------------
import
sys
import
typing
from
collections
import
OrderedDict
def
raises
(
err
,
lamda
):
try
:
lamda
()
return
False
except
err
:
return
True
def
expand_tuples
(
L
):
"""
>>> expand_tuples([1, (2, 3)])
[(1, 2), (1, 3)]
>>> expand_tuples([1, 2])
[(1, 2)]
"""
if
not
L
:
return
[()]
elif
not
isinstance
(
L
[
0
],
tuple
):
rest
=
expand_tuples
(
L
[
1
:])
return
[(
L
[
0
],)
+
t
for
t
in
rest
]
else
:
rest
=
expand_tuples
(
L
[
1
:])
return
[(
item
,)
+
t
for
t
in
rest
for
item
in
L
[
0
]]
# Taken from theano/theano/gof/sched.py
# Avoids licensing issues because this was written by Matthew Rocklin
def
_toposort
(
edges
):
""" Topological sort algorithm by Kahn [1] - O(nodes + vertices)
inputs:
edges - a dict of the form {a: {b, c}} where b and c depend on a
outputs:
L - an ordered list of nodes that satisfy the dependencies of edges
>>> _toposort({1: (2, 3), 2: (3, )})
[1, 2, 3]
Closely follows the wikipedia page [2]
[1] Kahn, Arthur B. (1962), "Topological sorting of large networks",
Communications of the ACM
[2] http://en.wikipedia.org/wiki/Toposort#Algorithms
"""
incoming_edges
=
reverse_dict
(
edges
)
incoming_edges
=
OrderedDict
((
k
,
set
(
val
))
for
k
,
val
in
incoming_edges
.
items
())
S
=
OrderedDict
.
fromkeys
(
v
for
v
in
edges
if
v
not
in
incoming_edges
)
L
=
[]
while
S
:
n
,
_
=
S
.
popitem
()
L
.
append
(
n
)
for
m
in
edges
.
get
(
n
,
()):
assert
n
in
incoming_edges
[
m
]
incoming_edges
[
m
].
remove
(
n
)
if
not
incoming_edges
[
m
]:
S
[
m
]
=
None
if
any
(
incoming_edges
.
get
(
v
,
None
)
for
v
in
edges
):
raise
ValueError
(
"Input has cycles"
)
return
L
def
reverse_dict
(
d
):
"""
Reverses direction of dependence dict
>>> d = {'a': (1, 2), 'b': (2, 3), 'c':()}
>>> reverse_dict(d) # doctest: +SKIP
{1: ('a',), 2: ('a', 'b'), 3: ('b',)}
:note: dict order are not deterministic. As we iterate on the
input dict, it make the output of this function depend on the
dict order. So this function output order should be considered
as undeterministic.
"""
result
=
OrderedDict
()
for
key
in
d
:
for
val
in
d
[
key
]:
result
[
val
]
=
result
.
get
(
val
,
tuple
())
+
(
key
,)
return
result
# Taken from toolz
# Avoids licensing issues because this version was authored by Matthew Rocklin
def
groupby
(
func
,
seq
):
""" Group a collection by a key function
>>> names = ['Alice', 'Bob', 'Charlie', 'Dan', 'Edith', 'Frank']
>>> groupby(len, names) # doctest: +SKIP
{3: ['Bob', 'Dan'], 5: ['Alice', 'Edith', 'Frank'], 7: ['Charlie']}
>>> iseven = lambda x: x % 2 == 0
>>> groupby(iseven, [1, 2, 3, 4, 5, 6, 7, 8]) # doctest: +SKIP
{False: [1, 3, 5, 7], True: [2, 4, 6, 8]}
See Also:
``countby``
"""
d
=
OrderedDict
()
for
item
in
seq
:
key
=
func
(
item
)
if
key
not
in
d
:
d
[
key
]
=
list
()
d
[
key
].
append
(
item
)
return
d
def
typename
(
type
):
"""
Get the name of `type`.
Parameters
----------
type : Union[Type, Tuple[Type]]
Returns
-------
str
The name of `type` or a tuple of the names of the types in `type`.
Examples
--------
>>> typename(int)
'int'
>>> typename((int, float))
'(int, float)'
"""
try
:
return
type
.
__name__
except
AttributeError
:
if
len
(
type
)
==
1
:
return
typename
(
*
type
)
return
"(%s)"
%
", "
.
join
(
map
(
typename
,
type
))
# parse typing.Union
def
parse_union
(
ann
):
if
hasattr
(
typing
,
"UnionMeta"
):
if
type
(
ann
)
is
not
typing
.
UnionMeta
:
return
return
ann
.
__union_params__
elif
hasattr
(
typing
,
"_Union"
):
if
type
(
ann
)
is
not
typing
.
_Union
:
return
return
ann
.
__args__
elif
hasattr
(
typing
,
"_GenericAlias"
):
if
type
(
ann
)
is
not
typing
.
_GenericAlias
:
if
type
(
ann
)
is
not
typing
.
Union
:
return
else
:
if
ann
.
__origin__
is
not
typing
.
Union
:
return
return
ann
.
__args__
elif
hasattr
(
typing
,
"Union"
):
if
typing
.
get_origin
(
ann
)
is
not
typing
.
Union
:
return
return
typing
.
get_args
(
ann
)
else
:
raise
NotImplementedError
(
"unsupported Python version"
)
imperative/python/megengine/core/tensor/multipledispatch/variadic.py
已删除
100644 → 0
浏览文件 @
b5e46ae9
# Copyright (c) 2014 Matthew Rocklin
#
# All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# a. Redistributions of source code must retain the above copyright notice,
# this list of conditions and the following disclaimer.
# b. Redistributions in binary form must reproduce the above copyright
# notice, this list of conditions and the following disclaimer in the
# documentation and/or other materials provided with the distribution.
# c. Neither the name of multipledispatch nor the names of its contributors
# may be used to endorse or promote products derived from this software
# without specific prior written permission.
#
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
# ARE DISCLAIMED. IN NO EVENT SHALL THE REGENTS OR CONTRIBUTORS BE LIABLE FOR
# ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
# LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY
# OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH
# DAMAGE.
#
# --------------------------------------------------------------------------------------
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#
# This file has been modified by Megvii ("Megvii Modifications").
# All Megvii Modifications are Copyright (C) 2014-2020 Megvii Inc. All rights reserved.
# --------------------------------------------------------------------------------------
from
.utils
import
typename
class
VariadicSignatureType
(
type
):
# checking if subclass is a subclass of self
def
__subclasscheck__
(
self
,
subclass
):
other_type
=
subclass
.
variadic_type
if
isvariadic
(
subclass
)
else
(
subclass
,)
return
subclass
is
self
or
all
(
issubclass
(
other
,
self
.
variadic_type
)
for
other
in
other_type
)
def
__eq__
(
self
,
other
):
"""
Return True if other has the same variadic type
Parameters
----------
other : object (type)
The object (type) to check
Returns
-------
bool
Whether or not `other` is equal to `self`
"""
return
isvariadic
(
other
)
and
set
(
self
.
variadic_type
)
==
set
(
other
.
variadic_type
)
def
__hash__
(
self
):
return
hash
((
type
(
self
),
frozenset
(
self
.
variadic_type
)))
def
isvariadic
(
obj
):
"""
Check whether the type `obj` is variadic.
Parameters
----------
obj : type
The type to check
Returns
-------
bool
Whether or not `obj` is variadic
Examples
--------
>>> isvariadic(int)
False
>>> isvariadic(Variadic[int])
True
"""
return
isinstance
(
obj
,
VariadicSignatureType
)
class
VariadicSignatureMeta
(
type
):
"""
A metaclass that overrides ``__getitem__`` on the class. This is used to
generate a new type for Variadic signatures. See the Variadic class for
examples of how this behaves.
"""
def
__getitem__
(
self
,
variadic_type
):
if
not
(
isinstance
(
variadic_type
,
(
type
,
tuple
))
or
type
(
variadic_type
)):
raise
ValueError
(
"Variadic types must be type or tuple of types"
" (Variadic[int] or Variadic[(int, float)]"
)
if
not
isinstance
(
variadic_type
,
tuple
):
variadic_type
=
(
variadic_type
,)
return
VariadicSignatureType
(
"Variadic[%s]"
%
typename
(
variadic_type
),
(),
dict
(
variadic_type
=
variadic_type
,
__slots__
=
()),
)
class
Variadic
(
metaclass
=
VariadicSignatureMeta
):
"""
A class whose getitem method can be used to generate a new type
representing a specific variadic signature.
Examples
--------
>>> Variadic[int] # any number of int arguments
<class 'multipledispatch.variadic.Variadic[int]'>
>>> Variadic[(int, str)] # any number of one of int or str arguments
<class 'multipledispatch.variadic.Variadic[(int, str)]'>
>>> issubclass(int, Variadic[int])
True
>>> issubclass(int, Variadic[(int, str)])
True
>>> issubclass(str, Variadic[(int, str)])
True
>>> issubclass(float, Variadic[(int, str)])
False
"""
imperative/python/megengine/core/tensor/raw_tensor/__init__.py
已删除
100644 → 0
浏览文件 @
b5e46ae9
# -*- coding: utf-8 -*-
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import
functools
import
numpy
as
np
from
..._imperative_rt
import
CompNode
,
DeviceTensorND
from
..._imperative_rt.imperative
import
(
_drop
,
_get_dev_tensor
,
_swap_in
,
_swap_out
,
apply_op
,
delete
,
get_device
,
get_dtype
,
get_shape
,
get_value
,
put
,
)
from
..._wrap
import
device
as
as_device
from
...ops.builtin
import
Copy
,
OpDef
,
TypeCvt
from
...ops.special
import
Const
from
..core
import
OpBase
,
TensorBase
,
apply
class
RawTensor
(
TensorBase
):
_init_cb
=
None
_del_cb
=
None
_handle
=
None
def
__init__
(
self
,
handle
=
None
,
isscalar
=
False
):
self
.
_handle
=
handle
self
.
_isscalar
=
isscalar
if
handle
is
not
None
:
if
self
.
_init_cb
:
self
.
_init_cb
()
@
property
def
dtype
(
self
):
return
get_dtype
(
self
.
_handle
)
@
property
def
device
(
self
):
return
as_device
(
get_device
(
self
.
_handle
))
@
property
def
shape
(
self
):
if
self
.
_isscalar
:
return
()
return
get_shape
(
self
.
_handle
)
def
numpy
(
self
):
ret
=
get_value
(
self
.
_handle
)
if
self
.
_isscalar
:
ret
=
ret
.
squeeze
()
return
ret
def
_dev_tensor
(
self
):
return
_get_dev_tensor
(
self
.
_handle
)
def
_drop
(
self
):
_drop
(
self
.
_handle
)
def
_swap_in
(
self
):
_swap_in
(
self
.
_handle
)
def
_swap_out
(
self
):
_swap_out
(
self
.
_handle
)
def
__repr__
(
self
):
return
"{}({}, device='{}')"
.
format
(
type
(
self
).
__qualname__
,
repr
(
self
.
numpy
()),
self
.
device
)
def
__del__
(
self
):
if
self
.
_handle
is
not
None
:
if
self
.
_del_cb
:
self
.
_del_cb
()
delete
(
self
.
_handle
)
@
apply
.
register
()
def
_
(
op
:
OpDef
,
*
args
:
RawTensor
):
outputs
=
apply_op
(
op
,
tuple
(
i
.
_handle
for
i
in
args
))
return
tuple
(
map
(
RawTensor
,
outputs
))
@
apply
.
register
()
def
_
(
op
:
Const
,
*
args
:
RawTensor
):
dtype
=
op
.
dtype
device
=
as_device
(
op
.
device
).
to_c
()
return
(
as_raw_tensor
(
op
.
value
,
dtype
=
dtype
,
device
=
device
),)
@
functools
.
singledispatch
def
as_raw_tensor
(
obj
,
dtype
=
None
,
device
=
None
):
obj
=
np
.
asarray
(
obj
,
dtype
=
dtype
)
if
obj
.
dtype
==
np
.
float64
:
obj
=
obj
.
astype
(
np
.
float32
)
if
obj
.
dtype
==
np
.
int64
:
obj
=
obj
.
astype
(
np
.
int32
)
return
as_raw_tensor
(
obj
,
device
=
device
)
@
as_raw_tensor
.
register
(
DeviceTensorND
)
def
_
(
data
:
DeviceTensorND
):
return
RawTensor
(
put
(
data
))
@
as_raw_tensor
.
register
(
np
.
ndarray
)
def
_
(
array
:
np
.
ndarray
,
dtype
=
None
,
device
=
None
):
device
=
None
if
device
is
None
else
as_device
(
device
).
to_c
()
if
0
in
array
.
strides
:
array
=
array
.
squeeze
().
reshape
(
array
.
shape
)
return
RawTensor
(
put
(
array
,
dtype
=
dtype
,
device
=
device
),
isscalar
=
(
array
.
ndim
==
0
))
@
as_raw_tensor
.
register
(
RawTensor
)
def
_
(
tensor
:
RawTensor
,
dtype
=
None
,
device
=
None
):
if
dtype
is
not
None
:
dtype
=
np
.
dtype
(
dtype
)
if
dtype
!=
tensor
.
dtype
:
(
tensor
,)
=
apply
(
TypeCvt
(
dtype
=
dtype
),
tensor
)
if
device
is
not
None
:
device
=
as_device
(
device
)
if
device
!=
tensor
.
device
:
(
tensor
,)
=
apply
(
Copy
(
comp_node
=
device
.
to_c
()),
tensor
)
return
tensor
imperative/python/megengine/distributed/functional.py
浏览文件 @
60c7d62a
...
...
@@ -9,14 +9,7 @@
from
typing
import
Optional
,
Tuple
from
..core._imperative_rt.core2
import
apply
from
..core.autodiff.builtin_op_utils
import
builtin_op_get_backward_fn
from
..core.autodiff.grad
import
(
Tracer
,
check_backward_allow_noinput
,
get_grad_managers
,
get_op_has_grad_fn
,
tracer_apply
,
)
from
..core.autodiff.grad
import
get_grad_managers
from
..core.ops.builtin
import
CollectiveComm
,
Copy
,
RemoteRecv
,
RemoteSend
from
..device
import
get_default_device
from
..tensor
import
Tensor
...
...
@@ -236,7 +229,7 @@ def remote_recv(
device
=
get_default_device
()
# dummy input
if
inp
==
None
:
inp
=
t
ensor
([
0
],
device
=
device
)
inp
=
T
ensor
([
0
],
device
=
device
)
tracer_set
=
get_client
().
check_remote_tracer
(
key
)
for
grad_manager
in
get_grad_managers
():
if
grad_manager
.
name
in
tracer_set
:
...
...
imperative/python/megengine/distributed/helper.py
浏览文件 @
60c7d62a
...
...
@@ -67,7 +67,7 @@ def param_pack_split(inp: Tensor, offsets: list, shapes: list):
outputs
=
apply
(
op
,
inp
)
for
s
,
x
in
zip
(
shapes
,
outputs
):
if
not
s
:
x
.
_isscalar
=
True
x
.
setscalar
()
return
outputs
...
...
imperative/python/megengine/functional/nn.py
浏览文件 @
60c7d62a
...
...
@@ -10,7 +10,7 @@
from
typing
import
Optional
,
Sequence
,
Tuple
,
Union
from
..core._imperative_rt
import
CompNode
from
..core._imperative_rt.core2
import
Tensor
,
apply
from
..core._imperative_rt.core2
import
apply
from
..core._trace_option
import
use_symbolic_shape
from
..core.ops
import
builtin
from
..core.ops.builtin
import
BatchNorm
...
...
imperative/python/megengine/quantization/utils.py
浏览文件 @
60c7d62a
...
...
@@ -12,10 +12,10 @@ from typing import Dict
import
numpy
as
np
from
..
import
functional
as
F
from
..core._imperative_rt.core2
import
apply
from
..core.autodiff.grad
import
Function
from
..core.ops
import
builtin
from
..core.tensor
import
megbrain_graph
from
..core.tensor.core
import
apply
from
..core.tensor.dtype
import
_metadata_dict
from
..tensor
import
Tensor
...
...
imperative/python/test/conftest.py
浏览文件 @
60c7d62a
...
...
@@ -3,7 +3,7 @@ import sys
import
pytest
from
megengine.core._imperative_rt.
imperative
import
sync
from
megengine.core._imperative_rt.
core2
import
sync
sys
.
path
.
append
(
os
.
path
.
join
(
os
.
path
.
dirname
(
__file__
),
"helpers"
))
...
...
imperative/python/test/integration/test_save_load.py
浏览文件 @
60c7d62a
...
...
@@ -4,7 +4,6 @@ import megengine as mge
import
megengine.autodiff
as
ad
import
megengine.optimizer
as
optimizer
from
megengine
import
Parameter
,
tensor
from
megengine.core.tensor.raw_tensor
import
RawTensor
from
megengine.module
import
Module
...
...
imperative/python/test/unit/core/test_dtype_quant.py
浏览文件 @
60c7d62a
...
...
@@ -13,7 +13,6 @@ import pytest
import
megengine.core.tensor.megbrain_graph
as
G
from
megengine.core.ops
import
builtin
as
ops
from
megengine.core.tensor.core
import
apply
from
megengine.core.tensor.dtype
import
(
_metadata_dict
,
convert_from_qint4
,
...
...
imperative/python/test/unit/test_dispatch.py
已删除
100644 → 0
浏览文件 @
b5e46ae9
from
megengine.core.tensor.multipledispatch
import
Dispatcher
def
test_register_many
():
f
=
Dispatcher
(
"f"
)
log
=
[]
@
f
.
register
()
def
_
(
x
:
int
):
log
.
append
(
"a"
)
return
log
[
-
1
]
@
f
.
register
()
def
_
(
x
:
int
):
log
.
append
(
"b"
)
return
log
[
-
1
]
assert
f
(
0
)
==
"b"
assert
log
==
[
"b"
]
def
test_return_not_implemented
():
f
=
Dispatcher
(
"f"
)
log
=
[]
@
f
.
register
()
def
_
(
x
:
int
):
log
.
append
(
"a"
)
return
log
[
-
1
]
@
f
.
register
()
def
_
(
x
:
int
):
log
.
append
(
"b"
)
return
NotImplemented
assert
f
(
0
)
==
"a"
assert
log
==
[
"b"
,
"a"
]
def
test_super
():
f
=
Dispatcher
(
"f"
)
log
=
[]
@
f
.
register
()
def
_
(
x
:
int
):
log
.
append
(
"a"
)
return
log
[
-
1
]
@
f
.
register
()
def
_
(
x
:
int
):
log
.
append
(
"b"
)
return
f
.
super
(
x
)
assert
f
(
0
)
==
"a"
assert
log
==
[
"b"
,
"a"
]
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录