magicwindyyd / mindspore · commit 3cb692be (forked from MindSpore / mindspore)
Commit 3cb692be, authored Apr 26, 2020 by meixiaowei
Parent commit: 99bbb3a3

modify resnet101 scripts for pylint
Showing 6 changed files with 49 additions and 112 deletions (+49, -112). Whitespace-only changes are hidden in this view, so some hunks below contain only context lines.
example/resnet101_imagenet/README.md (+0, -3)
example/resnet101_imagenet/config.py (+0, -3)
example/resnet101_imagenet/lr_generator.py (+0, -60)
example/resnet101_imagenet/train.py (+6, -11)
example/resnet101_imagenet/var_init.py (+42, -34)
mindspore/model_zoo/resnet.py (+1, -1)
example/resnet101_imagenet/README.md

@@ -54,9 +54,6 @@ Parameters for both training and evaluating can be set in config.py.
 "save_checkpoint_steps": 500,    # the step interval between two checkpoints. By default, the last checkpoint will be saved after the last step
 "keep_checkpoint_max": 40,       # only keep the last keep_checkpoint_max checkpoint
 "save_checkpoint_path": "./",    # path to save checkpoint relative to the executed path
-"lr_init": 0.01,                 # initial learning rate
-"lr_end": 0.00001,               # final learning rate
-"lr_max": 0.1,                   # maximum learning rate
 "warmup_epochs": 0,              # number of warmup epoch
 "lr_decay_mode": "cosine"        # decay mode for generating learning rate
 "label_smooth": 1,               # label_smooth
example/resnet101_imagenet/config.py

@@ -31,9 +31,6 @@ config = ed({
     "save_checkpoint_steps": 500,
     "keep_checkpoint_max": 40,
     "save_checkpoint_path": "./",
-    "lr_init": 0.01,
-    "lr_end": 0.00001,
-    "lr_max": 0.1,
     "warmup_epochs": 0,
     "lr_decay_mode": "cosine",
     "label_smooth": 1,
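With train.py now hard-wired to the cosine schedule (see its diff below), `lr_init`, `lr_end` and `lr_max` fed only the deleted `get_lr` and become dead keys. A sketch of the trimmed block, assuming `ed` is `easydict.EasyDict` as the `config = ed({` hunk header suggests; only the keys visible in this hunk are shown:

```python
# Hypothetical trimmed config block; keys and values are the ones visible in
# the hunk above, the real config.py contains more entries.
from easydict import EasyDict as ed

config = ed({
    "save_checkpoint_steps": 500,
    "keep_checkpoint_max": 40,
    "save_checkpoint_path": "./",
    "warmup_epochs": 0,
    "lr_decay_mode": "cosine",
    "label_smooth": 1,
})
```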
example/resnet101_imagenet/lr_generator.py

@@ -50,63 +50,3 @@ def warmup_cosine_annealing_lr(lr, steps_per_epoch, warmup_epochs, max_epoch):
         lr = base_lr * decayed
         lr_each_step.append(lr)
     return np.array(lr_each_step).astype(np.float32)
-
-
-def get_lr(global_step, lr_init, lr_end, lr_max, warmup_epochs, total_epochs, steps_per_epoch, lr_decay_mode):
-    """
-    generate learning rate array
-
-    Args:
-       global_step(int): total steps of the training
-       lr_init(float): init learning rate
-       lr_end(float): end learning rate
-       lr_max(float): max learning rate
-       warmup_epochs(int): number of warmup epochs
-       total_epochs(int): total epoch of training
-       steps_per_epoch(int): steps of one epoch
-       lr_decay_mode(string): learning rate decay mode, including steps, poly or default
-
-    Returns:
-       np.array, learning rate array
-    """
-    lr_each_step = []
-    total_steps = steps_per_epoch * total_epochs
-    warmup_steps = steps_per_epoch * warmup_epochs
-    if lr_decay_mode == 'steps':
-        decay_epoch_index = [0.3 * total_steps, 0.6 * total_steps, 0.8 * total_steps]
-        for i in range(total_steps):
-            if i < decay_epoch_index[0]:
-                lr = lr_max
-            elif i < decay_epoch_index[1]:
-                lr = lr_max * 0.1
-            elif i < decay_epoch_index[2]:
-                lr = lr_max * 0.01
-            else:
-                lr = lr_max * 0.001
-            lr_each_step.append(lr)
-    elif lr_decay_mode == 'poly':
-        if warmup_steps != 0:
-            inc_each_step = (float(lr_max) - float(lr_init)) / float(warmup_steps)
-        else:
-            inc_each_step = 0
-        for i in range(total_steps):
-            if i < warmup_steps:
-                lr = float(lr_init) + inc_each_step * float(i)
-            else:
-                base = (1.0 - (float(i) - float(warmup_steps)) / (float(total_steps) - float(warmup_steps)))
-                lr = float(lr_max) * base * base
-                if lr < 0.0:
-                    lr = 0.0
-            lr_each_step.append(lr)
-    else:
-        for i in range(total_steps):
-            if i < warmup_steps:
-                lr = lr_init + (lr_max - lr_init) * i / warmup_steps
-            else:
-                lr = lr_max - (lr_max - lr_end) * (i - warmup_steps) / (total_steps - warmup_steps)
-            lr_each_step.append(lr)
-
-    current_step = global_step
-    lr_each_step = np.array(lr_each_step).astype(np.float32)
-    learning_rate = lr_each_step[current_step:]
-
-    return learning_rate
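`get_lr` goes away wholesale, leaving `warmup_cosine_annealing_lr` as the only schedule; the diff shows just its tail. For reference, a schedule consistent with that tail (`lr = base_lr * decayed` followed by the float32 cast) looks roughly like the sketch below; the exact body upstream may differ in details such as the warmup shape:

```python
# Hedged reconstruction of the surviving schedule: linear warmup, then cosine
# annealing from base_lr down to 0, one entry per global training step.
import math
import numpy as np

def warmup_cosine_annealing_lr(lr, steps_per_epoch, warmup_epochs, max_epoch):
    base_lr = lr
    total_steps = steps_per_epoch * max_epoch
    warmup_steps = steps_per_epoch * warmup_epochs
    lr_each_step = []
    for i in range(total_steps):
        if warmup_steps and i < warmup_steps:
            lr = base_lr * (i + 1) / warmup_steps           # linear warmup
        else:
            decayed = 0.5 * (1.0 + math.cos(math.pi * (i - warmup_steps) /
                                            (total_steps - warmup_steps)))
            lr = base_lr * decayed                          # cosine annealing
        lr_each_step.append(lr)
    return np.array(lr_each_step).astype(np.float32)
```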
example/resnet101_imagenet/train.py

@@ -19,7 +19,7 @@ import argparse
 import random
 import numpy as np
 from dataset import create_dataset
-from lr_generator import get_lr, warmup_cosine_annealing_lr
+from lr_generator import warmup_cosine_annealing_lr
 from config import config
 from mindspore import context
 from mindspore import Tensor
@@ -32,9 +32,9 @@ from mindspore.train.loss_scale_manager import FixedLossScaleManager
 import mindspore.dataset.engine as de
 from mindspore.communication.management import init
 import mindspore.nn as nn
-import mindspore.common.initializer as weight_init
 from crossentropy import CrossEntropy
 from var_init import default_recurisive_init, KaimingNormal
+import mindspore.common.initializer as weight_init

 random.seed(1)
 np.random.seed(1)
@@ -72,7 +72,7 @@ if __name__ == '__main__':
     net = resnet101(class_num=config.class_num)
     # weight init
     default_recurisive_init(net)
-    for name, cell in net.cells_and_names():
+    for _, cell in net.cells_and_names():
         if isinstance(cell, nn.Conv2d):
             cell.weight.default_input = weight_init.initializer(KaimingNormal(a=math.sqrt(5),
                                                                               mode='fan_out', nonlinearity='relu'),
@@ -83,17 +83,12 @@ if __name__ == '__main__':
     loss = CrossEntropy(smooth_factor=config.label_smooth_factor, num_classes=config.class_num)
     if args_opt.do_train:
         dataset = create_dataset(dataset_path=args_opt.dataset_path, do_train=True,
                                  repeat_num=epoch_size, batch_size=config.batch_size)
         step_size = dataset.get_dataset_size()
         loss_scale = FixedLossScaleManager(config.loss_scale, drop_overflow_update=False)
-        # learning rate strategy
-        if config.lr_decay_mode == 'cosine':
-            lr = Tensor(warmup_cosine_annealing_lr(config.lr, step_size, config.warmup_epochs, config.epoch_size))
-        else:
-            lr = Tensor(get_lr(global_step=0, lr_init=config.lr_init, lr_end=config.lr_end,
-                               lr_max=config.lr_max, warmup_epochs=config.warmup_epochs,
-                               total_epochs=epoch_size, steps_per_epoch=step_size,
-                               lr_decay_mode='poly'))
+        # learning rate strategy with cosine
+        lr = Tensor(warmup_cosine_annealing_lr(config.lr, step_size, config.warmup_epochs, config.epoch_size))
         opt = Momentum(filter(lambda x: x.requires_grad, net.get_parameters()), lr, config.momentum,
                        config.weight_decay, config.loss_scale)
         model = Model(net, loss_fn=loss, optimizer=opt, amp_level='O2', keep_batchnorm_fp32=False,
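With the `else` branch gone, `config.lr_decay_mode` no longer gates anything here and the schedule is always cosine. A quick shape check using the sketch above and illustrative numbers (not values from config.py): the array must carry one learning rate per global step, since `Momentum` consumes a per-step `Tensor` of rates.

```python
# Illustrative shape check for the per-step schedule.
from lr_generator import warmup_cosine_annealing_lr

step_size = 1250     # stand-in for dataset.get_dataset_size()
epoch_size = 90      # stand-in for config.epoch_size
lr_array = warmup_cosine_annealing_lr(0.1, step_size,
                                      warmup_epochs=0, max_epoch=epoch_size)
assert lr_array.shape == (step_size * epoch_size,)
```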
example/resnet101_imagenet/var_init.py

@@ -18,10 +18,10 @@ import numpy as np
 from mindspore.common import initializer as init
 import mindspore.nn as nn
 from mindspore import Tensor


 def calculate_gain(nonlinearity, param=None):
     r"""Return the recommended gain value for the given nonlinearity function.
     The values are as follows:

     ================= ====================================================
     nonlinearity      gain
     ================= ====================================================
@@ -37,12 +37,13 @@ def calculate_gain(nonlinearity, param=None):
         param: optional parameter for the non-linear function
     """
     linear_fns = ['linear', 'conv1d', 'conv2d', 'conv3d', 'conv_transpose1d', 'conv_transpose2d', 'conv_transpose3d']
+    gain = 0
     if nonlinearity in linear_fns or nonlinearity == 'sigmoid':
-        return 1
+        gain = 1
     elif nonlinearity == 'tanh':
-        return 5.0 / 3
+        gain = 5.0 / 3
     elif nonlinearity == 'relu':
-        return math.sqrt(2.0)
+        gain = math.sqrt(2.0)
     elif nonlinearity == 'leaky_relu':
         if param is None:
             negative_slope = 0.01
@@ -51,15 +52,16 @@ def calculate_gain(nonlinearity, param=None):
             negative_slope = param
         else:
             raise ValueError("negative_slope {} not a valid number".format(param))
-        return math.sqrt(2.0 / (1 + negative_slope ** 2))
+        gain = math.sqrt(2.0 / (1 + negative_slope ** 2))
     else:
         raise ValueError("Unsupported nonlinearity {}".format(nonlinearity))
+    return gain


 def _calculate_correct_fan(array, mode):
     mode = mode.lower()
     valid_modes = ['fan_in', 'fan_out']
     if mode not in valid_modes:
         raise ValueError("Mode {} not supported, please use one of {}".format(mode, valid_modes))
     fan_in, fan_out = _calculate_fan_in_and_fan_out(array)
     return fan_in if mode == 'fan_in' else fan_out
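The `calculate_gain` rewrite replaces the early `return`s in the if/elif chain with a single `gain` variable and one exit point, the shape pylint's no-else-return check pushes toward; the new `gain = 0` gives the variable a defined value on every path. Behaviour should be unchanged; a few spot checks against the gain table in the docstring, as a hedged sketch:

```python
# Spot checks for the refactored calculate_gain; expected values follow the
# gain table in its docstring. These gains later feed std = gain / sqrt(fan)
# in the Kaiming initializers below.
import math
from var_init import calculate_gain  # assumes example/resnet101_imagenet is on sys.path

assert calculate_gain('linear') == 1
assert calculate_gain('tanh') == 5.0 / 3
assert calculate_gain('relu') == math.sqrt(2.0)
assert calculate_gain('leaky_relu', 0.1) == math.sqrt(2.0 / (1 + 0.1 ** 2))
```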
@@ -83,13 +85,12 @@ def kaiming_uniform_(array, a=0, mode='fan_in', nonlinearity='leaky_relu'):
             backwards pass.
         nonlinearity: the non-linear function (`nn.functional` name),
             recommended to use only with ``'relu'`` or ``'leaky_relu'`` (default).
     """
     fan = _calculate_correct_fan(array, mode)
     gain = calculate_gain(nonlinearity, a)
     std = gain / math.sqrt(fan)
     bound = math.sqrt(3.0) * std  # Calculate uniform bounds from standard deviation
     return np.random.uniform(-bound, bound, array.shape)


 def kaiming_normal_(array, a=0, mode='fan_in', nonlinearity='leaky_relu'):
     r"""Fills the input `Tensor` with values according to the method
@@ -97,12 +98,10 @@ def kaiming_normal_(array, a=0, mode='fan_in', nonlinearity='leaky_relu'):
     performance on ImageNet classification` - He, K. et al. (2015), using a
     normal distribution. The resulting tensor will have values sampled from
     :math:`\mathcal{N}(0, \text{std}^2)` where

     .. math::
         \text{std} = \frac{\text{gain}}{\sqrt{\text{fan\_mode}}}

     Also known as He initialization.

     Args:
         array: an n-dimensional `tensor`
         a: the negative slope of the rectifier used after this layer (only
@@ -118,13 +117,12 @@ def kaiming_normal_(array, a=0, mode='fan_in', nonlinearity='leaky_relu'):
     gain = calculate_gain(nonlinearity, a)
     std = gain / math.sqrt(fan)
     return np.random.normal(0, std, array.shape)


 def _calculate_fan_in_and_fan_out(array):
     """calculate the fan_in and fan_out for input array"""
     dimensions = len(array.shape)
     if dimensions < 2:
         raise ValueError("Fan in and fan out can not be computed for array with fewer than 2 dimensions")
     num_input_fmaps = array.shape[1]
     num_output_fmaps = array.shape[0]
     receptive_field_size = 1
@@ -132,19 +130,30 @@ def _calculate_fan_in_and_fan_out(array):
         receptive_field_size = array[0][0].size
     fan_in = num_input_fmaps * receptive_field_size
     fan_out = num_output_fmaps * receptive_field_size
     return fan_in, fan_out


+def assignment(arr, num):
+    """Assign the value of num to arr"""
+    if arr.shape == ():
+        arr = arr.reshape((1))
+        arr[:] = num
+        arr = arr.reshape(())
+    else:
+        if isinstance(num, np.ndarray):
+            arr[:] = num[:]
+        else:
+            arr[:] = num
+    return arr
+
+
 class KaimingUniform(init.Initializer):
     def __init__(self, a=0, mode='fan_in', nonlinearity='leaky_relu'):
         super(KaimingUniform, self).__init__()
         self.a = a
         self.mode = mode
         self.nonlinearity = nonlinearity

     def _initialize(self, arr):
         tmp = kaiming_uniform_(arr, self.a, self.mode, self.nonlinearity)
-        init._assignment(arr, tmp)
+        assignment(arr, tmp)


 class KaimingNormal(init.Initializer):
     def __init__(self, a=0, mode='fan_in', nonlinearity='leaky_relu'):
@@ -152,33 +161,32 @@ class KaimingNormal(init.Initializer):
         self.a = a
         self.mode = mode
         self.nonlinearity = nonlinearity

     def _initialize(self, arr):
         tmp = kaiming_normal_(arr, self.a, self.mode, self.nonlinearity)
-        init._assignment(arr, tmp)
+        assignment(arr, tmp)


 def default_recurisive_init(custom_cell):
     """weight init for conv2d and dense"""
-    for name, cell in custom_cell.cells_and_names():
+    for _, cell in custom_cell.cells_and_names():
         if isinstance(cell, nn.Conv2d):
             cell.weight.default_input = init.initializer(KaimingUniform(a=math.sqrt(5)),
                                                          cell.weight.default_input.shape(),
                                                          cell.weight.default_input.dtype())
             if cell.bias is not None:
                 fan_in, _ = _calculate_fan_in_and_fan_out(cell.weight.default_input.asnumpy())
                 bound = 1 / math.sqrt(fan_in)
                 cell.bias.default_input = Tensor(np.random.uniform(-bound, bound,
                                                                    cell.bias.default_input.shape()),
                                                  cell.bias.default_input.dtype())
         elif isinstance(cell, nn.Dense):
             cell.weight.default_input = init.initializer(KaimingUniform(a=math.sqrt(5)),
                                                          cell.weight.default_input.shape(),
                                                          cell.weight.default_input.dtype())
             if cell.bias is not None:
                 fan_in, _ = _calculate_fan_in_and_fan_out(cell.weight.default_input.asnumpy())
                 bound = 1 / math.sqrt(fan_in)
                 cell.bias.default_input = Tensor(np.random.uniform(-bound, bound,
                                                                    cell.bias.default_input.shape()),
                                                  cell.bias.default_input.dtype())
         elif isinstance(cell, (nn.BatchNorm2d, nn.BatchNorm1d)):
             pass
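Two pylint-driven changes run through these hunks: the `_initialize` methods now call the new module-level `assignment` helper instead of the private `init._assignment` (protected-access), and the unused loop variable `name` becomes `_`. The 0-d branch of the helper works because NumPy's `reshape` returns a view here, so writing through the view mutates the original array. A small demo of both branches:

```python
# Demo of the new assignment helper; import assumes the script's module layout.
import numpy as np
from var_init import assignment

scalar = np.array(0.0)
assignment(scalar, 3.14)         # 0-d branch: reshape to (1,), write, reshape back
print(scalar)                    # 3.14

buf = np.zeros(4)
assignment(buf, np.arange(4.0))  # ndarray branch: element-wise copy
print(buf)                       # [0. 1. 2. 3.]
```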
mindspore/model_zoo/resnet.py

@@ -279,4 +279,4 @@ def resnet101(class_num=1001):
                   [64, 256, 512, 1024],
                   [256, 512, 1024, 2048],
                   [1, 2, 2, 2],
-                  class_num)
\ No newline at end of file
+                  class_num)
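The only change to resnet.py is restoring the final newline (pylint's missing-final-newline check); the hunk happens to show the tail of `resnet101`, the constructor train.py builds the network with. Minimal usage, with the module path read off the file path above:

```python
# Build the network this commit's scripts train; class_num defaults to 1001
# per the signature in the hunk header.
from mindspore.model_zoo.resnet import resnet101

net = resnet101(class_num=1001)
```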