Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
MegEngine 天元
MegEngine
提交
6c1dbd40
MegEngine
项目概览
MegEngine 天元
/
MegEngine
接近 2 年 前同步成功
通知
414
Star
4708
Fork
583
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
MegEngine
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
提交
6c1dbd40
编写于
9月 10, 2020
作者:
M
Megvii Engine Team
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
docs(mge): fix doctest
GitOrigin-RevId: 131fed87337816c909791719921c122356b5ffc7
上级
e6715910
变更
11
隐藏空白更改
内联
并排
Showing
11 changed file
with
108 addition
and
64 deletion
+108
-64
imperative/python/megengine/core/tensor/function.py
imperative/python/megengine/core/tensor/function.py
+1
-1
imperative/python/megengine/data/transform/vision/transform.py
...ative/python/megengine/data/transform/vision/transform.py
+2
-2
imperative/python/megengine/functional/elemwise.py
imperative/python/megengine/functional/elemwise.py
+7
-5
imperative/python/megengine/functional/math.py
imperative/python/megengine/functional/math.py
+24
-4
imperative/python/megengine/functional/nn.py
imperative/python/megengine/functional/nn.py
+33
-16
imperative/python/megengine/functional/tensor.py
imperative/python/megengine/functional/tensor.py
+17
-7
imperative/python/megengine/functional/utils.py
imperative/python/megengine/functional/utils.py
+0
-13
imperative/python/megengine/module/activation.py
imperative/python/megengine/module/activation.py
+1
-1
imperative/python/megengine/module/batchnorm.py
imperative/python/megengine/module/batchnorm.py
+5
-3
imperative/python/megengine/module/sequential.py
imperative/python/megengine/module/sequential.py
+12
-10
imperative/python/megengine/random/distribution.py
imperative/python/megengine/random/distribution.py
+6
-2
未找到文件。
imperative/python/megengine/core/tensor/function.py
浏览文件 @
6c1dbd40
...
...
@@ -31,7 +31,7 @@ class Function:
self.y = y
return y
def backward(self
.
output_grads):
def backward(self
,
output_grads):
y = self.y
return output_grads * y * (1-y)
...
...
imperative/python/megengine/data/transform/vision/transform.py
浏览文件 @
6c1dbd40
...
...
@@ -194,9 +194,9 @@ class Compose(VisionTransform):
will be random shuffled, the 2nd and 4th transform will also be shuffled.
:param order: The same with :class:`VisionTransform`
Example:
Example
s
:
..testcode::
..
testcode::
from megengine.data.transform import RandomHorizontalFlip, RandomVerticalFlip, CenterCrop, ToMode, Compose
...
...
imperative/python/megengine/functional/elemwise.py
浏览文件 @
6c1dbd40
...
...
@@ -197,8 +197,8 @@ def sqrt(inp: Tensor) -> Tensor:
.. testoutput::
[[0.
1. 1.4142]
[1.7321
2. 2.2361
]]
[[0. 1. 1.4142]
[1.7321
2. 2.2361
]]
"""
return
inp
**
0.5
...
...
@@ -227,8 +227,8 @@ def square(inp: Tensor) -> Tensor:
.. testoutput::
[[
0. 1.
4.]
[
9. 16.
25.]]
[[
0. 1.
4.]
[
9. 16.
25.]]
"""
return
inp
**
2
...
...
@@ -437,7 +437,7 @@ def clamp(inp: Tensor, lower=None, upper=None) -> Tensor:
:param lower: lower-bound of the range to be clamped to
:param upper: upper-bound of the range to be clamped to
Example:
Example
s
:
.. testcode::
...
...
@@ -452,6 +452,8 @@ def clamp(inp: Tensor, lower=None, upper=None) -> Tensor:
print(F.clamp(a, upper=3).numpy())
Outputs:
.. testoutput::
[2 2 2 3 4]
...
...
imperative/python/megengine/functional/math.py
浏览文件 @
6c1dbd40
...
...
@@ -58,6 +58,8 @@ def isnan(inp: Tensor) -> Tensor:
print(F.isnan(x).numpy())
Outputs:
.. testoutput::
[False True False]
...
...
@@ -83,6 +85,8 @@ def isinf(inp: Tensor) -> Tensor:
print(F.isinf(x).numpy())
Outputs:
.. testoutput::
[False True False]
...
...
@@ -141,7 +145,9 @@ def sum(
data = tensor(np.arange(1, 7, dtype=np.int32).reshape(2, 3))
out = F.sum(data)
print(out.numpy())
Outputs:
.. testoutput::
[21]
...
...
@@ -208,6 +214,8 @@ def mean(
out = F.mean(data)
print(out.numpy())
Outputs:
.. testoutput::
[3.5]
...
...
@@ -250,9 +258,11 @@ def var(
out = F.var(data)
print(out.numpy())
Outputs:
.. testoutput::
[2.916
666
7]
[2.9167]
"""
if
axis
is
None
:
m
=
mean
(
inp
,
axis
=
axis
,
keepdims
=
False
)
...
...
@@ -288,9 +298,11 @@ def std(
out = F.std(data, axis=1)
print(out.numpy())
Outputs:
.. testoutput::
[0.816
4966 0.8164966
]
[0.816
5 0.8165
]
"""
return
var
(
inp
,
axis
=
axis
,
keepdims
=
keepdims
)
**
0.5
...
...
@@ -354,6 +366,8 @@ def max(
y = F.max(x)
print(y.numpy())
Outputs:
.. testoutput::
[6]
...
...
@@ -388,9 +402,11 @@ def norm(
y = F.norm(x)
print(y.numpy())
Outputs:
.. testoutput::
[4.358
89
9]
[4.3589]
"""
if
p
==
0
:
...
...
@@ -426,6 +442,8 @@ def argmin(
y = F.argmin(x)
print(y.numpy())
Outputs:
.. testoutput::
[0]
...
...
@@ -479,6 +497,8 @@ def argmax(
x = tensor(np.arange(1, 7, dtype=np.int32).reshape(2,3))
y = F.argmax(x)
print(y.numpy())
Outputs:
.. testoutput::
...
...
imperative/python/megengine/functional/nn.py
浏览文件 @
6c1dbd40
...
...
@@ -372,10 +372,12 @@ def softplus(inp: Tensor) -> Tensor:
x = tensor(np.arange(-3, 3, dtype=np.float32))
y = F.softplus(x)
print(y.numpy())
Outputs:
.. testoutput::
.. output::
[0.04858735 0.126928 0.3132617 0.6931472 1.3132617 2.126928 ]
[0.0486 0.1269 0.3133 0.6931 1.3133 2.1269]
"""
return
log1p
(
exp
(
-
abs
(
inp
)))
+
relu
(
inp
)
...
...
@@ -411,10 +413,12 @@ def log_softmax(inp: Tensor, axis: Union[int, Sequence[int]]) -> Tensor:
y = F.log_softmax(x, axis=1)
print(y.numpy())
.. output::
Outputs:
.. testoutput::
[[-4.4519
143 -3.4519143 -2.4519143 -1.4519144 -0.4519144
]
[-4.4519
143 -3.4519143 -2.4519143 -1.4519144 -0.4519144
]]
[[-4.4519
-3.4519 -2.4519 -1.4519 -0.4519
]
[-4.4519
-3.4519 -2.4519 -1.4519 -0.4519
]]
"""
return
inp
-
logsumexp
(
inp
,
axis
,
keepdims
=
True
)
...
...
@@ -432,6 +436,7 @@ def logsigmoid(inp: Tensor) -> Tensor:
:param inp: The input tensor
Examples:
.. testcode::
import numpy as np
...
...
@@ -442,9 +447,12 @@ def logsigmoid(inp: Tensor) -> Tensor:
y = F.logsigmoid(x)
print(y.numpy())
.. output::
Outputs:
.. testoutput::
[-5.0067153 -4.01815 -3.0485873 -2.126928 -1.3132617 -0.6931472 -0.3132617 -0.126928 -0.04858735 -0.01814993]
[-5.0067 -4.0181 -3.0486 -2.1269 -1.3133 -0.6931 -0.3133 -0.1269 -0.0486
-0.0181]
"""
return
-
softplus
(
-
inp
)
...
...
@@ -478,6 +486,7 @@ def logsumexp(
:param keepdims: whether to retain :attr:`axis` or not for the output tensor.
Examples:
.. testcode::
import numpy as np
...
...
@@ -488,9 +497,11 @@ def logsumexp(
y = F.logsumexp(x, axis=1, keepdims=False)
print(y.numpy())
.. output::
Outputs:
.. testoutput::
[-0.548
0856 4.4519143
]
[-0.548
1 4.4519
]
"""
max_value
=
max
(
inp
,
axis
,
keepdims
=
True
)
...
...
@@ -577,8 +588,9 @@ def softmax(inp: Tensor, axis: Optional[int] = None) -> Tensor:
Outputs:
.. testoutput::
[[0.01165623 0.03168492 0.08612854 0.23412167 0.6364086 ]
[0.01165623 0.03168492 0.08612854 0.23412167 0.6364086 ]]
[[0.0117 0.0317 0.0861 0.2341 0.6364]
[0.0117 0.0317 0.0861 0.2341 0.6364]]
"""
if
axis
is
None
:
...
...
@@ -1026,7 +1038,7 @@ def dot(inp1: Tensor, inp2: Tensor) -> Tensor:
Examples:
.. te
e
stcode::
.. testcode::
import numpy as np
from megengine import tensor
...
...
@@ -1039,9 +1051,10 @@ def dot(inp1: Tensor, inp2: Tensor) -> Tensor:
Outputs:
.. testoutput::
[55.]
.. testoutputs::
"""
op
=
builtin
.
Dot
()
inp1
,
inp2
=
utils
.
convert_inputs
(
inp1
,
inp2
)
...
...
@@ -1058,7 +1071,7 @@ def svd(inp: Tensor, full_matrices=False, compute_uv=True) -> Tensor:
Examples:
.. te
e
stcode::
.. testcode::
import numpy as np
from megengine import tensor
...
...
@@ -1070,7 +1083,9 @@ def svd(inp: Tensor, full_matrices=False, compute_uv=True) -> Tensor:
Outputs:
[7.348, 1.]
.. testoutput::
[7.3485 1. ]
"""
op
=
builtin
.
SVD
(
full_matrices
=
full_matrices
,
compute_uv
=
compute_uv
)
...
...
@@ -1445,6 +1460,8 @@ def indexing_one_hot(
val = F.indexing_one_hot(src, index)
print(val.numpy())
Outputs:
.. testoutput::
[1.]
...
...
imperative/python/megengine/functional/tensor.py
浏览文件 @
6c1dbd40
...
...
@@ -60,7 +60,7 @@ __all__ = [
]
def
eye
(
n
:
int
,
*
,
dtype
=
None
,
device
:
Optional
[
CompNode
]
=
None
)
->
Tensor
:
def
eye
(
n
:
int
,
*
,
dtype
=
"float32"
,
device
:
Optional
[
CompNode
]
=
None
)
->
Tensor
:
"""
Returns a 2D tensor with ones on the diagonal and zeros elsewhere.
...
...
@@ -80,7 +80,7 @@ def eye(n: int, *, dtype=None, device: Optional[CompNode] = None) -> Tensor:
data_shape = (4, 6)
n, m = data_shape
out = F.eye(
n, m
, dtype=np.float32)
out = F.eye(
[n, m]
, dtype=np.float32)
print(out.numpy())
Outputs:
...
...
@@ -135,6 +135,8 @@ def zeros_like(inp: Tensor) -> Tensor:
out = F.zeros_like(inp)
print(out.numpy())
Outputs:
.. testoutput::
[[0 0 0]
...
...
@@ -638,7 +640,7 @@ def cond_take(mask: Tensor, x: Tensor) -> Tensor:
.. testoutput::
Tensor([1. 4.]) Tensor([0 3], dtype=int32)
[1. 4.] [0 3]
"""
if
not
isinstance
(
x
,
(
TensorWrapperBase
,
TensorBase
)):
...
...
@@ -888,6 +890,8 @@ def linspace(
a = F.linspace(3,10,5)
print(a.numpy())
Outputs:
.. testoutput::
[ 3. 4.75 6.5 8.25 10. ]
...
...
@@ -930,6 +934,8 @@ def arange(
a = F.arange(5)
print(a.numpy())
Outputs:
.. testoutput::
...
...
@@ -977,7 +983,9 @@ def param_pack_split(inp: Tensor, offsets: List, shapes: List) -> Tensor:
b, c = F.param_pack_split(a, [0, 1, 1, 10], [(1,), (3, 3)])
print(b.numpy())
print(c.numpy())
Outputs:
.. testoutput::
[1]
...
...
@@ -1000,7 +1008,7 @@ def param_pack_concat(inps: List, offsets: Tensor, offsets_val: List) -> Tensor:
:param offsets: device value of offsets
:param offsets_val: offsets of inputs, length of 2 * n,
format [begin0, end0, begin1, end1].
:return:
spli
t tensors
:return:
conca
t tensors
Examples:
...
...
@@ -1013,10 +1021,12 @@ def param_pack_concat(inps: List, offsets: Tensor, offsets_val: List) -> Tensor:
a = tensor(np.ones((1,), np.int32))
b = tensor(np.ones((3, 3), np.int32))
offsets_val = [0, 1, 1, 10]
offsets = tensor(offsets, np.int32)
offsets = tensor(offsets
_val
, np.int32)
c = F.param_pack_concat([a, b], offsets, offsets_val)
print(c.numpy())
Outputs:
.. testoutput::
[1 1 1 1 1 1 1 1 1 1]
...
...
imperative/python/megengine/functional/utils.py
浏览文件 @
6c1dbd40
...
...
@@ -63,19 +63,6 @@ def accuracy(
return
accs
def
zero_grad
(
inp
:
Tensor
)
->
Tensor
:
r
"""
Returns a tensor which is treated as constant during backward gradient calcuation,
i.e. its gradient is zero.
:param inp: Input tensor.
See implementation of :func:`~.softmax` for example.
"""
print
(
"zero_grad is obsoleted, please use detach instead"
)
raise
NotImplementedError
def
copy
(
inp
,
cn
):
r
"""
Copy tensor to another device.
...
...
imperative/python/megengine/module/activation.py
浏览文件 @
6c1dbd40
...
...
@@ -219,7 +219,7 @@ class LeakyReLU(Module):
.. testoutput::
[-0.08
-0.12 6. 10. ]
[-0.08 -0.12 6. 10. ]
"""
...
...
imperative/python/megengine/module/batchnorm.py
浏览文件 @
6c1dbd40
...
...
@@ -267,15 +267,17 @@ class BatchNorm2d(_BatchNorm):
m = M.BatchNorm2d(4)
inp = mge.tensor(np.random.rand(1, 4, 3, 3).astype("float32"))
oup = m(inp)
print(m.weight
, m.bias
)
print(m.weight
.numpy(), m.bias.numpy()
)
# Without Learnable Parameters
m = M.BatchNorm2d(4, affine=False)
oup = m(inp)
print(m.weight, m.bias)
Outputs:
.. testoutput::
Tensor([1. 1. 1. 1.]) Tensor([0. 0. 0. 0.])
[1. 1. 1. 1.] [0. 0. 0. 0.]
None None
"""
...
...
imperative/python/megengine/module/sequential.py
浏览文件 @
6c1dbd40
...
...
@@ -17,23 +17,25 @@ class Sequential(Module):
Alternatively, an ordered dict of modules can also be passed in.
To make it easier to understand, here is a small example:
Examples:
.. testcode::
import numpy as np
import megengine.nn as nn
import megengine.
nn.
functional as F
from megengine import tensor
import megengine.functional as F
batch_size = 64
data =
nn.Input("data", shape=(batch_size, 1, 28, 28), dtype=np.float32, value=np.zeros((batch_size, 1, 28, 28))
)
label =
nn.Input("label", shape=(batch_size,), dtype=np.int32, value=np.zeros(batch_size,)
)
data =
tensor(np.zeros((batch_size, 1, 28, 28)), dtype=np.float32
)
label =
tensor(np.zeros(batch_size,), dtype=np.int32
)
data = data.reshape(batch_size, -1)
net =
nn
.Sequential(
nn
.Linear(28 * 28, 320),
nn
.Linear(320, 500),
nn
.Linear(500, 320),
nn
.Linear(320, 10)
net =
M
.Sequential(
M
.Linear(28 * 28, 320),
M
.Linear(320, 500),
M
.Linear(500, 320),
M
.Linear(320, 10)
)
pred = net(data)
...
...
imperative/python/megengine/random/distribution.py
浏览文件 @
6c1dbd40
...
...
@@ -37,7 +37,9 @@ def normal(
x = rand.normal(mean=0, std=1, size=(2, 2))
print(x.numpy())
Outputs:
.. testoutput::
:options: +SKIP
...
...
@@ -73,7 +75,9 @@ def uniform(
x = rand.uniform(size=(2, 2))
print(x.numpy())
Outputs:
.. testoutput::
:options: +SKIP
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录