Commit 99bbb3a3 authored by: meixiaowei

modify scripts for pylint

Parent: f1cec60d
@@ -12,15 +12,16 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ============================================================================
+"""define loss function for network"""
 from mindspore.nn.loss.loss import _Loss
 from mindspore.ops import operations as P
 from mindspore.ops import functional as F
 from mindspore import Tensor
 from mindspore.common import dtype as mstype
 import mindspore.nn as nn
-"""define loss function for network"""

 class CrossEntropy(_Loss):
+    """the redefined loss function with SoftmaxCrossEntropyWithLogits"""
     def __init__(self, smooth_factor=0., num_classes=1001):
         super(CrossEntropy, self).__init__()
         self.onehot = P.OneHot()
@@ -28,7 +29,6 @@ class CrossEntropy(_Loss):
         self.off_value = Tensor(1.0 * smooth_factor / (num_classes -1), mstype.float32)
         self.ce = nn.SoftmaxCrossEntropyWithLogits()
         self.mean = P.ReduceMean(False)
-
     def construct(self, logit, label):
         one_hot_label = self.onehot(label, F.shape(logit)[1], self.on_value, self.off_value)
         loss = self.ce(logit, one_hot_label)
...
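Note: the `on_value`/`off_value` pair above implements label smoothing: the true class gets `1 - smooth_factor` and the other `num_classes - 1` classes share `smooth_factor` evenly, so the smoothed target still sums to 1. A minimal plain-numpy sketch (illustrative only, not part of the commit):

```python
# Illustrative sketch, not part of the commit: the label-smoothing
# values that CrossEntropy.__init__ builds, checked with plain numpy.
import numpy as np

num_classes = 5
smooth_factor = 0.1
on_value = 1.0 - smooth_factor                 # weight of the true class
off_value = smooth_factor / (num_classes - 1)  # weight shared by the rest

label = 2
one_hot = np.full(num_classes, off_value, dtype=np.float32)
one_hot[label] = on_value
print(one_hot)                       # [0.025 0.025 0.9   0.025 0.025]
assert np.isclose(one_hot.sum(), 1.0)
```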
@@ -57,7 +57,7 @@ def create_dataset(dataset_path, do_train, repeat_num=1, batch_size=32):
     normalize_op = C.Normalize((0.475, 0.451, 0.392), (0.275, 0.267, 0.278))
     changeswap_op = C.HWC2CHW()

-    trans=[]
+    trans = []
     if do_train:
         trans = [decode_op,
                  random_resize_crop_op,
...
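Note: `normalize_op` and `changeswap_op` above standardize each channel and swap the layout from HWC to CHW for the network. A plain-numpy sketch of the equivalent arithmetic, assuming (an assumption, since the rest of the pipeline is folded away) that pixel values reach this op already rescaled to [0, 1]:

```python
# Sketch of what per-channel normalization plus HWC2CHW amounts to,
# assuming pixel values already rescaled to [0, 1].
import numpy as np

mean = np.array([0.475, 0.451, 0.392], dtype=np.float32)
std = np.array([0.275, 0.267, 0.278], dtype=np.float32)

img_hwc = np.random.rand(224, 224, 3).astype(np.float32)  # H x W x C
normalized = (img_hwc - mean) / std      # broadcast over the channel axis
img_chw = normalized.transpose(2, 0, 1)  # HWC -> CHW
print(img_chw.shape)                     # (3, 224, 224)
```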
@@ -13,9 +13,8 @@
 # limitations under the License.
 # ============================================================================
 """learning rate generator"""
-import numpy as np
 import math
+import numpy as np
-

 def linear_warmup_lr(current_step, warmup_steps, base_lr, init_lr):
     lr_inc = (float(base_lr) - float(init_lr)) / float(warmup_steps)
@@ -50,7 +49,7 @@ def warmup_cosine_annealing_lr(lr, steps_per_epoch, warmup_epochs, max_epoch):
             decayed = linear_decay * cosine_decay + 0.00001
             lr = base_lr * decayed
         lr_each_step.append(lr)
     return np.array(lr_each_step).astype(np.float32)

 def get_lr(global_step, lr_init, lr_end, lr_max, warmup_epochs, total_epochs, steps_per_epoch, lr_decay_mode):
     """
...
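For context, the hunk above only shows the tail of `warmup_cosine_annealing_lr`. Below is a self-contained sketch of a schedule in the same family, assuming linear warmup followed by plain cosine annealing; the repo's version additionally multiplies in a `linear_decay` term and adds the `0.00001` floor visible above, which this sketch omits:

```python
# Sketch of a warmup + cosine-annealing schedule (an approximation of
# the repo's warmup_cosine_annealing_lr, not its exact formula).
import math
import numpy as np

def warmup_cosine_lr(base_lr, steps_per_epoch, warmup_epochs, max_epoch):
    warmup_steps = steps_per_epoch * warmup_epochs
    total_steps = steps_per_epoch * max_epoch
    lr_each_step = []
    for step in range(total_steps):
        if step < warmup_steps:
            lr = base_lr * (step + 1) / warmup_steps      # linear warmup
        else:
            progress = (step - warmup_steps) / (total_steps - warmup_steps)
            lr = base_lr * 0.5 * (1 + math.cos(math.pi * progress))
        lr_each_step.append(lr)
    return np.array(lr_each_step).astype(np.float32)

lr = warmup_cosine_lr(0.1, steps_per_epoch=100, warmup_epochs=5, max_epoch=90)
print(lr[0], lr[499], lr[-1])  # ramps up to 0.1, then decays toward 0
```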
@@ -14,11 +14,12 @@
 # ============================================================================
 """train_imagenet."""
 import os
+import math
 import argparse
 import random
 import numpy as np
 from dataset import create_dataset
-from lr_generator import get_lr
+from lr_generator import get_lr, warmup_cosine_annealing_lr
 from config import config
 from mindspore import context
 from mindspore import Tensor
@@ -33,7 +34,7 @@ from mindspore.communication.management import init
 import mindspore.nn as nn
 from crossentropy import CrossEntropy
 from var_init import default_recurisive_init, KaimingNormal
-from mindspore.common import initializer as weight_init
+import mindspore.common.initializer as weight_init

 random.seed(1)
 np.random.seed(1)
@@ -69,23 +70,20 @@ if __name__ == '__main__':
     epoch_size = config.epoch_size
     net = resnet101(class_num=config.class_num)

     # weight init
     default_recurisive_init(net)
     for name, cell in net.cells_and_names():
         if isinstance(cell, nn.Conv2d):
             cell.weight.default_input = weight_init.initializer(KaimingNormal(a=math.sqrt(5),
                                                                               mode='fan_out', nonlinearity='relu'),
                                                                 cell.weight.default_input.shape(),
                                                                 cell.weight.default_input.dtype())
-

     if not config.label_smooth:
         config.label_smooth_factor = 0.0
     loss = CrossEntropy(smooth_factor=config.label_smooth_factor, num_classes=config.class_num)
-
-
     if args_opt.do_train:
         dataset = create_dataset(dataset_path=args_opt.dataset_path, do_train=True,
                                  repeat_num=epoch_size, batch_size=config.batch_size)
         step_size = dataset.get_dataset_size()
         loss_scale = FixedLossScaleManager(config.loss_scale, drop_overflow_update=False)
@@ -96,12 +94,10 @@ if __name__ == '__main__':
         lr = Tensor(get_lr(global_step=0, lr_init=config.lr_init, lr_end=config.lr_end, lr_max=config.lr_max,
                            warmup_epochs=config.warmup_epochs, total_epochs=epoch_size, steps_per_epoch=step_size,
                            lr_decay_mode='poly'))
-
         opt = Momentum(filter(lambda x: x.requires_grad, net.get_parameters()), lr, config.momentum,
                        config.weight_decay, config.loss_scale)
-
-        model = Model(net, loss_fn=loss, optimizer=opt, amp_level='O2', keep_batchnorm_fp32=False, loss_scale_manager=loss_scale, metrics={'acc'})
+        model = Model(net, loss_fn=loss, optimizer=opt, amp_level='O2', keep_batchnorm_fp32=False,
+                      loss_scale_manager=loss_scale, metrics={'acc'})
         time_cb = TimeMonitor(data_size=step_size)
         loss_cb = LossMonitor()
         cb = [time_cb, loss_cb]
...
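`get_lr(..., lr_decay_mode='poly')` above picks a polynomial-decay branch whose body is folded out of this diff. A sketch of one common shape for that branch, assuming linear warmup from `lr_init` to `lr_max`, then a squared decay down to `lr_end` (an assumed form, not the repo's exact code):

```python
# Sketch (assumed form) of a 'poly' decay schedule matching get_lr's
# signature in the hunk above: warmup, then squared polynomial decay.
import numpy as np

def poly_lr(lr_init, lr_end, lr_max, warmup_epochs, total_epochs, steps_per_epoch):
    total_steps = steps_per_epoch * total_epochs
    warmup_steps = steps_per_epoch * warmup_epochs
    lr_each_step = []
    for i in range(total_steps):
        if i < warmup_steps:
            lr = lr_init + (lr_max - lr_init) * i / warmup_steps
        else:
            frac = (total_steps - i) / (total_steps - warmup_steps)
            lr = (lr_max - lr_end) * frac * frac + lr_end
        lr_each_step.append(lr)
    return np.array(lr_each_step).astype(np.float32)

lr = poly_lr(0.01, 0.0, 0.1, warmup_epochs=5, total_epochs=90, steps_per_epoch=100)
print(lr[0], lr[500], lr[-1])  # 0.01 -> 0.1 -> ~0.0
```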
@@ -18,12 +18,10 @@ import numpy as np
 from mindspore.common import initializer as init
 import mindspore.nn as nn
 from mindspore import Tensor
-

 def calculate_gain(nonlinearity, param=None):
     r"""Return the recommended gain value for the given nonlinearity function.
     The values are as follows:
-
     ================= ====================================================
     nonlinearity      gain
     ================= ====================================================
@@ -34,11 +32,9 @@ def calculate_gain(nonlinearity, param=None):
     ReLU              :math:`\sqrt{2}`
     Leaky Relu        :math:`\sqrt{\frac{2}{1 + \text{negative\_slope}^2}}`
     ================= ====================================================
-
     Args:
         nonlinearity: the non-linear function (`nn.functional` name)
         param: optional parameter for the non-linear function
-
     """
     linear_fns = ['linear', 'conv1d', 'conv2d', 'conv3d', 'conv_transpose1d', 'conv_transpose2d', 'conv_transpose3d']
     if nonlinearity in linear_fns or nonlinearity == 'sigmoid':
@@ -57,17 +53,15 @@ def calculate_gain(nonlinearity, param=None):
             raise ValueError("negative_slope {} not a valid number".format(param))
         return math.sqrt(2.0 / (1 + negative_slope ** 2))
     else:
         raise ValueError("Unsupported nonlinearity {}".format(nonlinearity))
-

 def _calculate_correct_fan(array, mode):
     mode = mode.lower()
     valid_modes = ['fan_in', 'fan_out']
     if mode not in valid_modes:
         raise ValueError("Mode {} not supported, please use one of {}".format(mode, valid_modes))
     fan_in, fan_out = _calculate_fan_in_and_fan_out(array)
     return fan_in if mode == 'fan_in' else fan_out
-

 def kaiming_uniform_(array, a=0, mode='fan_in', nonlinearity='leaky_relu'):
     r"""Fills the input `Tensor` with values according to the method
@@ -75,12 +69,10 @@ def kaiming_uniform_(array, a=0, mode='fan_in', nonlinearity='leaky_relu'):
     performance on ImageNet classification` - He, K. et al. (2015), using a
     uniform distribution. The resulting tensor will have values sampled from
     :math:`\mathcal{U}(-\text{bound}, \text{bound})` where
-
     .. math::
         \text{bound} = \text{gain} \times \sqrt{\frac{3}{\text{fan\_mode}}}

     Also known as He initialization.
-
     Args:
         array: an n-dimensional `tensor`
         a: the negative slope of the rectifier used after this layer (only
@@ -91,8 +83,7 @@ def kaiming_uniform_(array, a=0, mode='fan_in', nonlinearity='leaky_relu'):
             backwards pass.
         nonlinearity: the non-linear function (`nn.functional` name),
             recommended to use only with ``'relu'`` or ``'leaky_relu'`` (default).
-
     """
     fan = _calculate_correct_fan(array, mode)
     gain = calculate_gain(nonlinearity, a)
     std = gain / math.sqrt(fan)
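Taken together: with `std = gain / sqrt(fan)`, the docstring's `bound = gain * sqrt(3 / fan_mode)` is exactly `sqrt(3) * std`, the half-width a uniform distribution needs in order to have that standard deviation. A worked example with assumed shapes:

```python
# Worked example of the kaiming_uniform_ bound, with assumed shapes:
# a 3x3 conv over 64 input channels, fan_in mode, default a = sqrt(5).
import math

fan_in = 64 * 3 * 3                                # 576
gain = math.sqrt(2.0 / (1 + math.sqrt(5) ** 2))    # leaky_relu gain, a = sqrt(5)
std = gain / math.sqrt(fan_in)
bound = math.sqrt(3.0) * std                       # U(-bound, bound) has std `std`
print(round(bound, 5))                             # ~0.04167
```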
@@ -129,6 +120,7 @@ def kaiming_normal_(array, a=0, mode='fan_in', nonlinearity='leaky_relu'):
     return np.random.normal(0, std, array.shape)

 def _calculate_fan_in_and_fan_out(array):
+    """calculate the fan_in and fan_out for input array"""
     dimensions = len(array.shape)
     if dimensions < 2:
         raise ValueError("Fan in and fan out can not be computed for array with fewer than 2 dimensions")
@@ -166,18 +158,27 @@ class KaimingNormal(init.Initializer):
         init._assignment(arr, tmp)

 def default_recurisive_init(custom_cell):
+    """weight init for conv2d and dense"""
     for name, cell in custom_cell.cells_and_names():
         if isinstance(cell, nn.Conv2d):
-            cell.weight.default_input = init.initializer(KaimingUniform(a=math.sqrt(5)), cell.weight.default_input.shape(), cell.weight.default_input.dtype())
+            cell.weight.default_input = init.initializer(KaimingUniform(a=math.sqrt(5)),
+                                                         cell.weight.default_input.shape(),
+                                                         cell.weight.default_input.dtype())
             if cell.bias is not None:
                 fan_in, _ = _calculate_fan_in_and_fan_out(cell.weight.default_input.asnumpy())
                 bound = 1 / math.sqrt(fan_in)
-                cell.bias.default_input = Tensor(np.random.uniform(-bound, bound, cell.bias.default_input.shape()), cell.bias.default_input.dtype())
+                cell.bias.default_input = Tensor(np.random.uniform(-bound, bound,
+                                                                   cell.bias.default_input.shape()),
+                                                 cell.bias.default_input.dtype())
         elif isinstance(cell, nn.Dense):
-            cell.weight.default_input = init.initializer(KaimingUniform(a=math.sqrt(5)), cell.weight.default_input.shape(), cell.weight.default_input.dtype())
+            cell.weight.default_input = init.initializer(KaimingUniform(a=math.sqrt(5)),
+                                                         cell.weight.default_input.shape(),
+                                                         cell.weight.default_input.dtype())
             if cell.bias is not None:
                 fan_in, _ = _calculate_fan_in_and_fan_out(cell.weight.default_input.asnumpy())
                 bound = 1 / math.sqrt(fan_in)
-                cell.bias.default_input = Tensor(np.random.uniform(-bound, bound, cell.bias.default_input.shape()), cell.bias.default_input.dtype())
-        elif isinstance(cell, nn.BatchNorm2d) or isinstance(cell, nn.BatchNorm1d):
+                cell.bias.default_input = Tensor(np.random.uniform(-bound, bound,
+                                                                   cell.bias.default_input.shape()),
+                                                 cell.bias.default_input.dtype())
+        elif isinstance(cell, (nn.BatchNorm2d, nn.BatchNorm1d)):
             pass
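The bias branches repeated above both draw from `U(-b, b)` with `b = 1 / sqrt(fan_in)` of the matching weight. A standalone check with assumed dimensions:

```python
# Standalone check of the bias bound used in both branches above,
# with assumed dimensions (64 input channels, 3x3 kernel, 128 biases).
import math
import numpy as np

fan_in = 64 * 3 * 3
bound = 1 / math.sqrt(fan_in)                       # = 1/24 ~ 0.04167
bias = np.random.uniform(-bound, bound, 128).astype(np.float32)
assert bias.min() >= -bound and bias.max() <= bound
print(round(bound, 5))
```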
@@ -279,5 +279,4 @@ def resnet101(class_num=1001):
                   [64, 256, 512, 1024],
                   [256, 512, 1024, 2048],
                   [1, 2, 2, 2],
-                  class_num)
-
+                  class_num)
\ No newline at end of file