Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
MegEngine 天元
MegEngine
提交
47db29aa
MegEngine
项目概览
MegEngine 天元
/
MegEngine
接近 2 年 前同步成功
通知
414
Star
4708
Fork
583
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
MegEngine
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
提交
47db29aa
编写于
2月 07, 2021
作者:
M
Megvii Engine Team
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
fix(mge/module): add kwargs param for all modules
GitOrigin-RevId: 7245e669a7d5bcf718d448d9a59e7b31e8ec52d2
上级
6fb19b66
变更
12
隐藏空白更改
内联
并排
Showing
12 changed file
with
42 addition
and
31 deletion
+42
-31
imperative/python/megengine/module/activation.py
imperative/python/megengine/module/activation.py
+6
-6
imperative/python/megengine/module/adaptive_pooling.py
imperative/python/megengine/module/adaptive_pooling.py
+2
-4
imperative/python/megengine/module/batchnorm.py
imperative/python/megengine/module/batchnorm.py
+4
-2
imperative/python/megengine/module/conv.py
imperative/python/megengine/module/conv.py
+10
-1
imperative/python/megengine/module/conv_bn.py
imperative/python/megengine/module/conv_bn.py
+2
-0
imperative/python/megengine/module/dropout.py
imperative/python/megengine/module/dropout.py
+2
-2
imperative/python/megengine/module/elemwise.py
imperative/python/megengine/module/elemwise.py
+2
-2
imperative/python/megengine/module/embedding.py
imperative/python/megengine/module/embedding.py
+2
-1
imperative/python/megengine/module/external.py
imperative/python/megengine/module/external.py
+2
-4
imperative/python/megengine/module/normalization.py
imperative/python/megengine/module/normalization.py
+6
-6
imperative/python/megengine/module/pooling.py
imperative/python/megengine/module/pooling.py
+2
-1
imperative/python/megengine/module/sequential.py
imperative/python/megengine/module/sequential.py
+2
-2
未找到文件。
imperative/python/megengine/module/activation.py
浏览文件 @
47db29aa
...
...
@@ -48,8 +48,8 @@ class Softmax(Module):
"""
def
__init__
(
self
,
axis
=
None
):
super
().
__init__
()
def
__init__
(
self
,
axis
=
None
,
**
kwargs
):
super
().
__init__
(
**
kwargs
)
self
.
axis
=
axis
def
forward
(
self
,
inputs
):
...
...
@@ -167,8 +167,8 @@ class PReLU(Module):
"""
def
__init__
(
self
,
num_parameters
:
int
=
1
,
init
:
float
=
0.25
):
super
().
__init__
()
def
__init__
(
self
,
num_parameters
:
int
=
1
,
init
:
float
=
0.25
,
**
kwargs
):
super
().
__init__
(
**
kwargs
)
self
.
num_parameters
=
num_parameters
if
num_parameters
>
1
:
# Assume format is NCHW
...
...
@@ -225,8 +225,8 @@ class LeakyReLU(Module):
"""
def
__init__
(
self
,
negative_slope
:
float
=
0.01
):
super
().
__init__
()
def
__init__
(
self
,
negative_slope
:
float
=
0.01
,
**
kwargs
):
super
().
__init__
(
**
kwargs
)
self
.
negative_slope
=
negative_slope
def
forward
(
self
,
inputs
):
...
...
imperative/python/megengine/module/adaptive_pooling.py
浏览文件 @
47db29aa
...
...
@@ -15,10 +15,8 @@ from .module import Module
class
_AdaptivePoolNd
(
Module
):
def
__init__
(
self
,
oshp
:
Union
[
Tuple
[
int
,
int
],
int
,
Tensor
],
):
super
(
_AdaptivePoolNd
,
self
).
__init__
()
def
__init__
(
self
,
oshp
:
Union
[
Tuple
[
int
,
int
],
int
,
Tensor
],
**
kwargs
):
super
(
_AdaptivePoolNd
,
self
).
__init__
(
**
kwargs
)
self
.
oshp
=
oshp
@
abstractmethod
...
...
imperative/python/megengine/module/batchnorm.py
浏览文件 @
47db29aa
...
...
@@ -26,8 +26,9 @@ class _BatchNorm(Module):
affine
=
True
,
track_running_stats
=
True
,
freeze
=
False
,
**
kwargs
):
super
(
_BatchNorm
,
self
).
__init__
()
super
(
_BatchNorm
,
self
).
__init__
(
**
kwargs
)
self
.
num_features
=
num_features
self
.
eps
=
eps
self
.
momentum
=
momentum
...
...
@@ -151,9 +152,10 @@ class SyncBatchNorm(_BatchNorm):
track_running_stats
=
True
,
freeze
=
False
,
group
:
Optional
[
Group
]
=
WORLD
,
**
kwargs
)
->
None
:
super
().
__init__
(
num_features
,
eps
,
momentum
,
affine
,
track_running_stats
,
freeze
num_features
,
eps
,
momentum
,
affine
,
track_running_stats
,
freeze
,
**
kwargs
)
self
.
group
=
group
...
...
imperative/python/megengine/module/conv.py
浏览文件 @
47db29aa
...
...
@@ -37,8 +37,9 @@ class _ConvNd(Module):
dilation
:
Union
[
int
,
Tuple
[
int
,
int
]],
groups
:
int
,
bias
:
bool
=
True
,
**
kwargs
):
super
().
__init__
()
super
().
__init__
(
**
kwargs
)
if
in_channels
%
groups
!=
0
:
raise
ValueError
(
"in_channels must be divisible by groups"
)
if
out_channels
%
groups
!=
0
:
...
...
@@ -176,6 +177,7 @@ class Conv1d(_ConvNd):
bias
:
bool
=
True
,
conv_mode
:
str
=
"CROSS_CORRELATION"
,
compute_mode
:
str
=
"DEFAULT"
,
**
kwargs
):
kernel_size
=
kernel_size
stride
=
stride
...
...
@@ -192,6 +194,7 @@ class Conv1d(_ConvNd):
dilation
,
groups
,
bias
,
**
kwargs
,
)
def
_get_fanin
(
self
):
...
...
@@ -334,6 +337,7 @@ class Conv2d(_ConvNd):
bias
:
bool
=
True
,
conv_mode
:
str
=
"CROSS_CORRELATION"
,
compute_mode
:
str
=
"DEFAULT"
,
**
kwargs
):
kernel_size
=
_pair_nonzero
(
kernel_size
)
stride
=
_pair_nonzero
(
stride
)
...
...
@@ -350,6 +354,7 @@ class Conv2d(_ConvNd):
dilation
,
groups
,
bias
,
**
kwargs
,
)
def
_get_fanin
(
self
):
...
...
@@ -444,6 +449,7 @@ class ConvTranspose2d(_ConvNd):
bias
:
bool
=
True
,
conv_mode
:
str
=
"CROSS_CORRELATION"
,
compute_mode
:
str
=
"DEFAULT"
,
**
kwargs
):
kernel_size
=
_pair_nonzero
(
kernel_size
)
stride
=
_pair_nonzero
(
stride
)
...
...
@@ -460,6 +466,7 @@ class ConvTranspose2d(_ConvNd):
dilation
,
groups
,
bias
,
**
kwargs
,
)
def
_get_fanin
(
self
):
...
...
@@ -536,6 +543,7 @@ class LocalConv2d(Conv2d):
dilation
:
Union
[
int
,
Tuple
[
int
,
int
]]
=
1
,
groups
:
int
=
1
,
conv_mode
:
str
=
"CROSS_CORRELATION"
,
**
kwargs
):
self
.
input_height
=
input_height
self
.
input_width
=
input_width
...
...
@@ -548,6 +556,7 @@ class LocalConv2d(Conv2d):
dilation
,
groups
,
bias
=
False
,
**
kwargs
,
)
def
_infer_weight_shape
(
self
):
...
...
imperative/python/megengine/module/conv_bn.py
浏览文件 @
47db29aa
...
...
@@ -30,6 +30,7 @@ class _ConvBnActivation2d(Module):
momentum
=
0.9
,
affine
=
True
,
track_running_stats
=
True
,
**
kwargs
):
super
().
__init__
()
self
.
conv
=
Conv2d
(
...
...
@@ -43,6 +44,7 @@ class _ConvBnActivation2d(Module):
bias
,
conv_mode
,
compute_mode
,
**
kwargs
,
)
self
.
bn
=
BatchNorm2d
(
out_channels
,
eps
,
momentum
,
affine
,
track_running_stats
)
...
...
imperative/python/megengine/module/dropout.py
浏览文件 @
47db29aa
...
...
@@ -20,8 +20,8 @@ class Dropout(Module):
:param drop_prob: The probability to drop (set to zero) each single element
"""
def
__init__
(
self
,
drop_prob
=
0.0
):
super
().
__init__
()
def
__init__
(
self
,
drop_prob
=
0.0
,
**
kwargs
):
super
().
__init__
(
**
kwargs
)
self
.
drop_prob
=
drop_prob
def
forward
(
self
,
inputs
):
...
...
imperative/python/megengine/module/elemwise.py
浏览文件 @
47db29aa
...
...
@@ -72,8 +72,8 @@ class Elemwise(Module):
* "NOT": bool unary: ~x
"""
def
__init__
(
self
,
method
):
super
().
__init__
()
def
__init__
(
self
,
method
,
**
kwargs
):
super
().
__init__
(
**
kwargs
)
self
.
method
=
method
def
forward
(
self
,
*
inps
):
...
...
imperative/python/megengine/module/embedding.py
浏览文件 @
47db29aa
...
...
@@ -64,8 +64,9 @@ class Embedding(Module):
norm_type
:
Optional
[
float
]
=
None
,
initial_weight
:
Parameter
=
None
,
freeze
:
bool
=
False
,
**
kwargs
):
super
().
__init__
()
super
().
__init__
(
**
kwargs
)
if
padding_idx
is
not
None
:
raise
ValueError
(
"Not support padding index now."
)
if
max_norm
is
not
None
or
norm_type
is
not
None
:
...
...
imperative/python/megengine/module/external.py
浏览文件 @
47db29aa
...
...
@@ -19,10 +19,8 @@ class TensorrtRuntimeSubgraph(Module):
See :func:`~.tensorrt_runtime_opr` for more details.
"""
def
__init__
(
self
,
data
,
):
super
(
TensorrtRuntimeSubgraph
,
self
).
__init__
()
def
__init__
(
self
,
data
,
**
kwargs
):
super
(
TensorrtRuntimeSubgraph
,
self
).
__init__
(
**
kwargs
)
self
.
_data
=
data
@
property
...
...
imperative/python/megengine/module/normalization.py
浏览文件 @
47db29aa
...
...
@@ -20,8 +20,8 @@ class GroupNorm(Module):
Reference: https://arxiv.org/pdf/1803.08494.pdf.
"""
def
__init__
(
self
,
num_groups
,
num_channels
,
eps
=
1e-5
,
affine
=
True
):
super
().
__init__
()
def
__init__
(
self
,
num_groups
,
num_channels
,
eps
=
1e-5
,
affine
=
True
,
**
kwargs
):
super
().
__init__
(
**
kwargs
)
assert
num_channels
%
num_groups
==
0
self
.
num_groups
=
num_groups
self
.
num_channels
=
num_channels
...
...
@@ -70,8 +70,8 @@ class InstanceNorm(Module):
Note that InstanceNorm equals using GroupNome with num_groups=num_channels.
"""
def
__init__
(
self
,
num_channels
,
eps
=
1e-05
,
affine
=
True
):
super
().
__init__
()
def
__init__
(
self
,
num_channels
,
eps
=
1e-05
,
affine
=
True
,
**
kwargs
):
super
().
__init__
(
**
kwargs
)
self
.
num_channels
=
num_channels
self
.
eps
=
eps
self
.
affine
=
affine
...
...
@@ -114,8 +114,8 @@ class LayerNorm(Module):
Note that LayerNorm equals using GroupNorm with num_groups=1.
"""
def
__init__
(
self
,
num_channels
,
eps
=
1e-05
,
affine
=
True
):
super
().
__init__
()
def
__init__
(
self
,
num_channels
,
eps
=
1e-05
,
affine
=
True
,
**
kwargs
):
super
().
__init__
(
**
kwargs
)
self
.
num_channels
=
num_channels
self
.
eps
=
eps
self
.
affine
=
affine
...
...
imperative/python/megengine/module/pooling.py
浏览文件 @
47db29aa
...
...
@@ -19,8 +19,9 @@ class _PoolNd(Module):
kernel_size
:
Union
[
int
,
Tuple
[
int
,
int
]],
stride
:
Union
[
int
,
Tuple
[
int
,
int
]]
=
None
,
padding
:
Union
[
int
,
Tuple
[
int
,
int
]]
=
0
,
**
kwargs
):
super
(
_PoolNd
,
self
).
__init__
()
super
(
_PoolNd
,
self
).
__init__
(
**
kwargs
)
self
.
kernel_size
=
kernel_size
self
.
stride
=
stride
or
kernel_size
self
.
padding
=
padding
...
...
imperative/python/megengine/module/sequential.py
浏览文件 @
47db29aa
...
...
@@ -46,8 +46,8 @@ class Sequential(Module):
pred1 = net1(data)
"""
def
__init__
(
self
,
*
args
):
super
().
__init__
()
def
__init__
(
self
,
*
args
,
**
kwargs
):
super
().
__init__
(
**
kwargs
)
self
.
layer_keys
=
[]
if
len
(
args
)
==
1
and
isinstance
(
args
[
0
],
OrderedDict
):
for
key
,
module
in
args
[
0
].
items
():
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录