提交 0555bdb0 编写于 作者: M mindspore-ci-bot 提交者: Gitee

!701 upload resnet101 scripts

Merge pull request !701 from 梅晓蔚/master
# ResNet101 Example
## Description
This is an example of training ResNet101 with ImageNet dataset in MindSpore.
## Requirements
- Install [MindSpore](https://www.mindspore.cn/install/en).
- Download the dataset [ImageNet](http://image-net.org/download).
> Unzip the ImageNet dataset to any path you want, the folder should include train and eval dataset as follows:
```
.
└─dataset
├─ilsvrc
└─validation_preprocess
```
## Example structure
```shell
.
├── crossentropy.py # CrossEntropy loss function
├── var_init.py # weight initial
├── config.py # parameter configuration
├── dataset.py # data preprocessing
├── eval.py # eval net
├── lr_generator.py # generate learning rate
├── run_distribute_train.sh # launch distributed training(8p)
├── run_infer.sh # launch evaluating
├── run_standalone_train.sh # launch standalone training(1p)
└── train.py # train net
```
## Parameter configuration
Parameters for both training and evaluating can be set in config.py.
```
"class_num": 1001, # dataset class number
"batch_size": 32, # batch size of input tensor
"loss_scale": 1024, # loss scale
"momentum": 0.9, # momentum optimizer
"weight_decay": 1e-4, # weight decay
"epoch_size": 120, # epoch sizes for training
"buffer_size": 1000, # number of queue size in data preprocessing
"image_height": 224, # image height
"image_width": 224, # image width
"save_checkpoint": True, # whether save checkpoint or not
"save_checkpoint_steps": 500, # the step interval between two checkpoints. By default, the last checkpoint will be saved after the last step
"keep_checkpoint_max": 40, # only keep the last keep_checkpoint_max checkpoint
"save_checkpoint_path": "./", # path to save checkpoint relative to the executed path
"warmup_epochs": 0, # number of warmup epoch
"lr_decay_mode": "cosine" # decay mode for generating learning rate
"label_smooth": 1, # label_smooth
"label_smooth_factor": 0.1, # label_smooth_factor
"lr": 0.1 # base learning rate
```
## Running the example
### Train
#### Usage
```
# distributed training
sh run_distribute_train.sh [MINDSPORE_HCCL_CONFIG_PATH] [DATASET_PATH]
# standalone training
sh run_standalone_train.sh [DATASET_PATH]
```
#### Launch
```bash
# distributed training example(8p)
sh run_distribute_train.sh rank_table_8p.json dataset/ilsvrc
# standalone training example(1p)
sh run_standalone_train.sh dataset/ilsvrc
```
> About rank_table.json, you can refer to the [distributed training tutorial](https://www.mindspore.cn/tutorial/en/master/advanced_use/distributed_training.html).
#### Result
Training result will be stored in the example path, whose folder name begins with "train" or "train_parallel". You can find checkpoint file together with result like the followings in log.
```
# distribute training result(8p)
epoch: 1 step: 5004, loss is 4.805483
epoch: 2 step: 5004, loss is 3.2121816
epoch: 3 step: 5004, loss is 3.429647
epoch: 4 step: 5004, loss is 3.3667371
epoch: 5 step: 5004, loss is 3.1718972
...
epoch: 67 step: 5004, loss is 2.2768745
epoch: 68 step: 5004, loss is 1.7223864
epoch: 69 step: 5004, loss is 2.0665488
epoch: 70 step: 5004, loss is 1.8717369
...
```
### Infer
#### Usage
```
# infer
sh run_infer.sh [VALIDATION_DATASET_PATH] [CHECKPOINT_PATH]
```
#### Launch
```bash
# infer with checkpoint
sh run_infer.sh dataset/validation_preprocess/ train_parallel0/resnet-120_5004.ckpt
```
> checkpoint can be produced in training process.
#### Result
Inference result will be stored in the example path, whose folder name is "infer". Under this, you can find result like the followings in log.
```
result: {'top_5_accuracy': 0.9429417413572343, 'top_1_accuracy': 0.7853513124199744} ckpt=train_parallel0/resnet-120_5004.ckpt
```
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
network config setting, will be used in train.py and eval.py
"""
from easydict import EasyDict as ed
config = ed({
"class_num": 1001,
"batch_size": 32,
"loss_scale": 1024,
"momentum": 0.9,
"weight_decay": 1e-4,
"epoch_size": 120,
"buffer_size": 1000,
"image_height": 224,
"image_width": 224,
"save_checkpoint": True,
"save_checkpoint_steps": 500,
"keep_checkpoint_max": 40,
"save_checkpoint_path": "./",
"warmup_epochs": 0,
"lr_decay_mode": "cosine",
"label_smooth": 1,
"label_smooth_factor": 0.1,
"lr": 0.1
})
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""define loss function for network"""
from mindspore.nn.loss.loss import _Loss
from mindspore.ops import operations as P
from mindspore.ops import functional as F
from mindspore import Tensor
from mindspore.common import dtype as mstype
import mindspore.nn as nn
class CrossEntropy(_Loss):
"""the redefined loss function with SoftmaxCrossEntropyWithLogits"""
def __init__(self, smooth_factor=0., num_classes=1001):
super(CrossEntropy, self).__init__()
self.onehot = P.OneHot()
self.on_value = Tensor(1.0 - smooth_factor, mstype.float32)
self.off_value = Tensor(1.0 * smooth_factor / (num_classes -1), mstype.float32)
self.ce = nn.SoftmaxCrossEntropyWithLogits()
self.mean = P.ReduceMean(False)
def construct(self, logit, label):
one_hot_label = self.onehot(label, F.shape(logit)[1], self.on_value, self.off_value)
loss = self.ce(logit, one_hot_label)
loss = self.mean(loss, 0)
return loss
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
create train or eval dataset.
"""
import os
import mindspore.common.dtype as mstype
import mindspore.dataset.engine as de
import mindspore.dataset.transforms.vision.c_transforms as C
import mindspore.dataset.transforms.c_transforms as C2
from config import config
def create_dataset(dataset_path, do_train, repeat_num=1, batch_size=32):
"""
create a train or evaluate dataset
Args:
dataset_path(string): the path of dataset.
do_train(bool): whether dataset is used for train or eval.
repeat_num(int): the repeat times of dataset. Default: 1
batch_size(int): the batch size of dataset. Default: 32
Returns:
dataset
"""
device_num = int(os.getenv("RANK_SIZE"))
rank_id = int(os.getenv("RANK_ID"))
if device_num == 1:
ds = de.ImageFolderDatasetV2(dataset_path, num_parallel_workers=8, shuffle=True)
else:
ds = de.ImageFolderDatasetV2(dataset_path, num_parallel_workers=8, shuffle=True,
num_shards=device_num, shard_id=rank_id)
resize_height = 224
rescale = 1.0 / 255.0
shift = 0.0
# define map operations
decode_op = C.Decode()
random_resize_crop_op = C.RandomResizedCrop(resize_height, (0.08, 1.0), (0.75, 1.33), max_attempts=100)
horizontal_flip_op = C.RandomHorizontalFlip(rank_id / (rank_id + 1))
resize_op_256 = C.Resize((256, 256))
center_crop = C.CenterCrop(224)
rescale_op = C.Rescale(rescale, shift)
normalize_op = C.Normalize((0.475, 0.451, 0.392), (0.275, 0.267, 0.278))
changeswap_op = C.HWC2CHW()
trans = []
if do_train:
trans = [decode_op,
random_resize_crop_op,
horizontal_flip_op,
rescale_op,
normalize_op,
changeswap_op]
else:
trans = [decode_op,
resize_op_256,
center_crop,
rescale_op,
normalize_op,
changeswap_op]
type_cast_op = C2.TypeCast(mstype.int32)
ds = ds.map(input_columns="image", operations=trans)
ds = ds.map(input_columns="label", operations=type_cast_op)
# apply shuffle operations
ds = ds.shuffle(buffer_size=config.buffer_size)
# apply batch operations
ds = ds.batch(batch_size, drop_remainder=True)
# apply dataset repeat operation
ds = ds.repeat(repeat_num)
return ds
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
eval.
"""
import os
import argparse
import random
import numpy as np
from dataset import create_dataset
from config import config
from mindspore import context
from mindspore.model_zoo.resnet import resnet101
from mindspore.parallel._auto_parallel_context import auto_parallel_context
from mindspore.train.model import Model, ParallelMode
from mindspore.train.serialization import load_checkpoint, load_param_into_net
import mindspore.dataset.engine as de
from mindspore.communication.management import init
from crossentropy import CrossEntropy
random.seed(1)
np.random.seed(1)
de.config.set_seed(1)
parser = argparse.ArgumentParser(description='Image classification')
parser.add_argument('--run_distribute', type=bool, default=False, help='Run distribute')
parser.add_argument('--device_num', type=int, default=1, help='Device num.')
parser.add_argument('--do_train', type=bool, default=False, help='Do train or not.')
parser.add_argument('--do_eval', type=bool, default=True, help='Do eval or not.')
parser.add_argument('--checkpoint_path', type=str, default=None, help='Checkpoint file path')
parser.add_argument('--dataset_path', type=str, default=None, help='Dataset path')
args_opt = parser.parse_args()
device_id = int(os.getenv('DEVICE_ID'))
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=False, device_id=device_id)
context.set_context(enable_task_sink=True)
context.set_context(enable_loop_sink=True)
context.set_context(enable_mem_reuse=True)
if __name__ == '__main__':
if args_opt.do_eval:
context.set_context(enable_hccl=False)
else:
if args_opt.run_distribute:
context.set_context(enable_hccl=True)
context.set_auto_parallel_context(device_num=args_opt.device_num, parallel_mode=ParallelMode.DATA_PARALLEL,
mirror_mean=True, parameter_broadcast=True)
auto_parallel_context().set_all_reduce_fusion_split_indices([140])
init()
else:
context.set_context(enable_hccl=False)
epoch_size = config.epoch_size
net = resnet101(class_num=config.class_num)
if not config.label_smooth:
config.label_smooth_factor = 0.0
loss = CrossEntropy(smooth_factor=config.label_smooth_factor, num_classes=config.class_num)
if args_opt.do_eval:
dataset = create_dataset(dataset_path=args_opt.dataset_path, do_train=False, batch_size=config.batch_size)
step_size = dataset.get_dataset_size()
if args_opt.checkpoint_path:
param_dict = load_checkpoint(args_opt.checkpoint_path)
load_param_into_net(net, param_dict)
net.set_train(False)
model = Model(net, loss_fn=loss, metrics={'top_1_accuracy', 'top_5_accuracy'})
res = model.eval(dataset)
print("result:", res, "ckpt=", args_opt.checkpoint_path)
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""learning rate generator"""
import math
import numpy as np
def linear_warmup_lr(current_step, warmup_steps, base_lr, init_lr):
lr_inc = (float(base_lr) - float(init_lr)) / float(warmup_steps)
lr = float(init_lr) + lr_inc * current_step
return lr
def warmup_cosine_annealing_lr(lr, steps_per_epoch, warmup_epochs, max_epoch):
"""
generate learning rate array with cosine
Args:
lr(float): base learning rate
steps_per_epoch(int): steps size of one epoch
warmup_epochs(int): number of warmup epochs
max_epoch(int): total epochs of training
Returns:
np.array, learning rate array
"""
base_lr = lr
warmup_init_lr = 0
total_steps = int(max_epoch * steps_per_epoch)
warmup_steps = int(warmup_epochs * steps_per_epoch)
decay_steps = total_steps - warmup_steps
lr_each_step = []
for i in range(total_steps):
if i < warmup_steps:
lr = linear_warmup_lr(i + 1, warmup_steps, base_lr, warmup_init_lr)
else:
linear_decay = (total_steps - i) / decay_steps
cosine_decay = 0.5 * (1 + math.cos(math.pi * 2 * 0.47 * i / decay_steps))
decayed = linear_decay * cosine_decay + 0.00001
lr = base_lr * decayed
lr_each_step.append(lr)
return np.array(lr_each_step).astype(np.float32)
#!/bin/bash
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
if [ $# != 2 ]
then
echo "Usage: sh run_distribute_train.sh [MINDSPORE_HCCL_CONFIG_PATH] [DATASET_PATH]"
exit 1
fi
if [ ! -f $1 ]
then
echo "error: DMINDSPORE_HCCL_CONFIG_PATH=$1 is not a file"
exit 1
fi
if [ ! -d $2 ]
then
echo "error: DATASET_PATH=$2 is not a directory"
exit 1
fi
ulimit -u unlimited
export DEVICE_NUM=8
export RANK_SIZE=8
export MINDSPORE_HCCL_CONFIG_PATH=$1
export RANK_TABLE_FILE=$1
for((i=0; i<${DEVICE_NUM}; i++))
do
export DEVICE_ID=$i
export RANK_ID=$i
rm -rf ./train_parallel$i
mkdir ./train_parallel$i
cp *.py ./train_parallel$i
cp *.sh ./train_parallel$i
cd ./train_parallel$i || exit
echo "start training for rank $RANK_ID, device $DEVICE_ID"
env > env.log
python train.py --do_train=True --run_distribute=True --device_num=$DEVICE_NUM --dataset_path=$2 &> log &
cd ..
done
#!/bin/bash
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
if [ $# != 2 ]
then
echo "Usage: sh run_infer.sh [DATASET_PATH] [CHECKPOINT_PATH]"
exit 1
fi
if [ ! -d $1 ]
then
echo "error: DATASET_PATH=$1 is not a directory"
exit 1
fi
if [ ! -f $2 ]
then
echo "error: CHECKPOINT_PATH=$2 is not a file"
exit 1
fi
ulimit -u unlimited
export DEVICE_NUM=1
export DEVICE_ID=0
export RANK_SIZE=$DEVICE_NUM
export RANK_ID=0
if [ -d "infer" ];
then
rm -rf ./infer
fi
mkdir ./infer
cp *.py ./infer
cp *.sh ./infer
cd ./infer || exit
env > env.log
echo "start infering for device $DEVICE_ID"
python eval.py --do_eval=True --dataset_path=$1 --checkpoint_path=$2 &> log &
cd ..
#!/bin/bash
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
if [ $# != 1 ]
then
echo "Usage: sh run_standalone_train.sh [DATASET_PATH]"
exit 1
fi
if [ ! -d $1 ]
then
echo "error: DATASET_PATH=$1 is not a directory"
exit 1
fi
ulimit -u unlimited
export DEVICE_NUM=1
export DEVICE_ID=0
export RANK_ID=0
export RANK_SIZE=1
if [ -d "train" ];
then
rm -rf ./train
fi
mkdir ./train
cp *.py ./train
cp *.sh ./train
cd ./train || exit
echo "start training for device $DEVICE_ID"
env > env.log
python train.py --do_train=True --dataset_path=$1 &> log &
cd ..
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""train_imagenet."""
import os
import math
import argparse
import random
import numpy as np
from dataset import create_dataset
from lr_generator import warmup_cosine_annealing_lr
from config import config
from mindspore import context
from mindspore import Tensor
from mindspore.model_zoo.resnet import resnet101
from mindspore.parallel._auto_parallel_context import auto_parallel_context
from mindspore.nn.optim.momentum import Momentum
from mindspore.train.model import Model, ParallelMode
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor, TimeMonitor
from mindspore.train.loss_scale_manager import FixedLossScaleManager
import mindspore.dataset.engine as de
from mindspore.communication.management import init
import mindspore.nn as nn
import mindspore.common.initializer as weight_init
from crossentropy import CrossEntropy
from var_init import default_recurisive_init, KaimingNormal
random.seed(1)
np.random.seed(1)
de.config.set_seed(1)
parser = argparse.ArgumentParser(description='Image classification')
parser.add_argument('--run_distribute', type=bool, default=False, help='Run distribute')
parser.add_argument('--device_num', type=int, default=1, help='Device num.')
parser.add_argument('--do_train', type=bool, default=True, help='Do train or not.')
parser.add_argument('--do_eval', type=bool, default=False, help='Do eval or not.')
parser.add_argument('--dataset_path', type=str, default=None, help='Dataset path')
args_opt = parser.parse_args()
device_id = int(os.getenv('DEVICE_ID'))
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=False, device_id=device_id)
context.set_context(enable_task_sink=True)
context.set_context(enable_loop_sink=True)
context.set_context(enable_mem_reuse=True)
if __name__ == '__main__':
if args_opt.do_eval:
context.set_context(enable_hccl=False)
else:
if args_opt.run_distribute:
context.set_context(enable_hccl=True)
context.set_auto_parallel_context(device_num=args_opt.device_num, parallel_mode=ParallelMode.DATA_PARALLEL,
mirror_mean=True, parameter_broadcast=True)
auto_parallel_context().set_all_reduce_fusion_split_indices([140])
init()
else:
context.set_context(enable_hccl=False)
epoch_size = config.epoch_size
net = resnet101(class_num=config.class_num)
# weight init
default_recurisive_init(net)
for _, cell in net.cells_and_names():
if isinstance(cell, nn.Conv2d):
cell.weight.default_input = weight_init.initializer(KaimingNormal(a=math.sqrt(5),
mode='fan_out', nonlinearity='relu'),
cell.weight.default_input.shape(),
cell.weight.default_input.dtype())
if not config.label_smooth:
config.label_smooth_factor = 0.0
loss = CrossEntropy(smooth_factor=config.label_smooth_factor, num_classes=config.class_num)
if args_opt.do_train:
dataset = create_dataset(dataset_path=args_opt.dataset_path, do_train=True,
repeat_num=epoch_size, batch_size=config.batch_size)
step_size = dataset.get_dataset_size()
loss_scale = FixedLossScaleManager(config.loss_scale, drop_overflow_update=False)
# learning rate strategy with cosine
lr = Tensor(warmup_cosine_annealing_lr(config.lr, step_size, config.warmup_epochs, config.epoch_size))
opt = Momentum(filter(lambda x: x.requires_grad, net.get_parameters()), lr, config.momentum,
config.weight_decay, config.loss_scale)
model = Model(net, loss_fn=loss, optimizer=opt, amp_level='O2', keep_batchnorm_fp32=False,
loss_scale_manager=loss_scale, metrics={'acc'})
time_cb = TimeMonitor(data_size=step_size)
loss_cb = LossMonitor()
cb = [time_cb, loss_cb]
if config.save_checkpoint:
config_ck = CheckpointConfig(save_checkpoint_steps=config.save_checkpoint_steps,
keep_checkpoint_max=config.keep_checkpoint_max)
ckpt_cb = ModelCheckpoint(prefix="resnet", directory=config.save_checkpoint_path, config=config_ck)
cb += [ckpt_cb]
model.train(epoch_size, dataset, callbacks=cb)
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""weight initial"""
import math
import numpy as np
from mindspore.common import initializer as init
import mindspore.nn as nn
from mindspore import Tensor
def calculate_gain(nonlinearity, param=None):
r"""Return the recommended gain value for the given nonlinearity function.
The values are as follows:
================= ====================================================
nonlinearity gain
================= ====================================================
Linear / Identity :math:`1`
Conv{1,2,3}D :math:`1`
Sigmoid :math:`1`
Tanh :math:`\frac{5}{3}`
ReLU :math:`\sqrt{2}`
Leaky Relu :math:`\sqrt{\frac{2}{1 + \text{negative\_slope}^2}}`
================= ====================================================
Args:
nonlinearity: the non-linear function (`nn.functional` name)
param: optional parameter for the non-linear function
"""
linear_fns = ['linear', 'conv1d', 'conv2d', 'conv3d', 'conv_transpose1d', 'conv_transpose2d', 'conv_transpose3d']
gain = 0
if nonlinearity in linear_fns or nonlinearity == 'sigmoid':
gain = 1
elif nonlinearity == 'tanh':
gain = 5.0 / 3
elif nonlinearity == 'relu':
gain = math.sqrt(2.0)
elif nonlinearity == 'leaky_relu':
if param is None:
negative_slope = 0.01
elif not isinstance(param, bool) and isinstance(param, int) or isinstance(param, float):
# True/False are instances of int, hence check above
negative_slope = param
else:
raise ValueError("negative_slope {} not a valid number".format(param))
gain = math.sqrt(2.0 / (1 + negative_slope ** 2))
else:
raise ValueError("Unsupported nonlinearity {}".format(nonlinearity))
return gain
def _calculate_correct_fan(array, mode):
mode = mode.lower()
valid_modes = ['fan_in', 'fan_out']
if mode not in valid_modes:
raise ValueError("Mode {} not supported, please use one of {}".format(mode, valid_modes))
fan_in, fan_out = _calculate_fan_in_and_fan_out(array)
return fan_in if mode == 'fan_in' else fan_out
def kaiming_uniform_(array, a=0, mode='fan_in', nonlinearity='leaky_relu'):
r"""Fills the input `Tensor` with values according to the method
described in `Delving deep into rectifiers: Surpassing human-level
performance on ImageNet classification` - He, K. et al. (2015), using a
uniform distribution. The resulting tensor will have values sampled from
:math:`\mathcal{U}(-\text{bound}, \text{bound})` where
.. math::
\text{bound} = \text{gain} \times \sqrt{\frac{3}{\text{fan\_mode}}}
Also known as He initialization.
Args:
array: an n-dimensional `tensor`
a: the negative slope of the rectifier used after this layer (only
used with ``'leaky_relu'``)
mode: either ``'fan_in'`` (default) or ``'fan_out'``. Choosing ``'fan_in'``
preserves the magnitude of the variance of the weights in the
forward pass. Choosing ``'fan_out'`` preserves the magnitudes in the
backwards pass.
nonlinearity: the non-linear function (`nn.functional` name),
recommended to use only with ``'relu'`` or ``'leaky_relu'`` (default).
"""
fan = _calculate_correct_fan(array, mode)
gain = calculate_gain(nonlinearity, a)
std = gain / math.sqrt(fan)
bound = math.sqrt(3.0) * std # Calculate uniform bounds from standard deviation
return np.random.uniform(-bound, bound, array.shape)
def kaiming_normal_(array, a=0, mode='fan_in', nonlinearity='leaky_relu'):
r"""Fills the input `Tensor` with values according to the method
described in `Delving deep into rectifiers: Surpassing human-level
performance on ImageNet classification` - He, K. et al. (2015), using a
normal distribution. The resulting tensor will have values sampled from
:math:`\mathcal{N}(0, \text{std}^2)` where
.. math::
\text{std} = \frac{\text{gain}}{\sqrt{\text{fan\_mode}}}
Also known as He initialization.
Args:
array: an n-dimensional `tensor`
a: the negative slope of the rectifier used after this layer (only
used with ``'leaky_relu'``)
mode: either ``'fan_in'`` (default) or ``'fan_out'``. Choosing ``'fan_in'``
preserves the magnitude of the variance of the weights in the
forward pass. Choosing ``'fan_out'`` preserves the magnitudes in the
backwards pass.
nonlinearity: the non-linear function (`nn.functional` name),
recommended to use only with ``'relu'`` or ``'leaky_relu'`` (default).
"""
fan = _calculate_correct_fan(array, mode)
gain = calculate_gain(nonlinearity, a)
std = gain / math.sqrt(fan)
return np.random.normal(0, std, array.shape)
def _calculate_fan_in_and_fan_out(array):
"""calculate the fan_in and fan_out for input array"""
dimensions = len(array.shape)
if dimensions < 2:
raise ValueError("Fan in and fan out can not be computed for array with fewer than 2 dimensions")
num_input_fmaps = array.shape[1]
num_output_fmaps = array.shape[0]
receptive_field_size = 1
if dimensions > 2:
receptive_field_size = array[0][0].size
fan_in = num_input_fmaps * receptive_field_size
fan_out = num_output_fmaps * receptive_field_size
return fan_in, fan_out
def assignment(arr, num):
"""Assign the value of num to arr"""
if arr.shape == ():
arr = arr.reshape((1))
arr[:] = num
arr = arr.reshape(())
else:
if isinstance(num, np.ndarray):
arr[:] = num[:]
else:
arr[:] = num
return arr
class KaimingUniform(init.Initializer):
def __init__(self, a=0, mode='fan_in', nonlinearity='leaky_relu'):
super(KaimingUniform, self).__init__()
self.a = a
self.mode = mode
self.nonlinearity = nonlinearity
def _initialize(self, arr):
tmp = kaiming_uniform_(arr, self.a, self.mode, self.nonlinearity)
assignment(arr, tmp)
class KaimingNormal(init.Initializer):
def __init__(self, a=0, mode='fan_in', nonlinearity='leaky_relu'):
super(KaimingNormal, self).__init__()
self.a = a
self.mode = mode
self.nonlinearity = nonlinearity
def _initialize(self, arr):
tmp = kaiming_normal_(arr, self.a, self.mode, self.nonlinearity)
assignment(arr, tmp)
def default_recurisive_init(custom_cell):
"""weight init for conv2d and dense"""
for _, cell in custom_cell.cells_and_names():
if isinstance(cell, nn.Conv2d):
cell.weight.default_input = init.initializer(KaimingUniform(a=math.sqrt(5)),
cell.weight.default_input.shape(),
cell.weight.default_input.dtype())
if cell.bias is not None:
fan_in, _ = _calculate_fan_in_and_fan_out(cell.weight.default_input.asnumpy())
bound = 1 / math.sqrt(fan_in)
cell.bias.default_input = Tensor(np.random.uniform(-bound, bound,
cell.bias.default_input.shape()),
cell.bias.default_input.dtype())
elif isinstance(cell, nn.Dense):
cell.weight.default_input = init.initializer(KaimingUniform(a=math.sqrt(5)),
cell.weight.default_input.shape(),
cell.weight.default_input.dtype())
if cell.bias is not None:
fan_in, _ = _calculate_fan_in_and_fan_out(cell.weight.default_input.asnumpy())
bound = 1 / math.sqrt(fan_in)
cell.bias.default_input = Tensor(np.random.uniform(-bound, bound,
cell.bias.default_input.shape()),
cell.bias.default_input.dtype())
elif isinstance(cell, (nn.BatchNorm2d, nn.BatchNorm1d)):
pass
......@@ -260,3 +260,23 @@ def resnet50(class_num=10):
[256, 512, 1024, 2048],
[1, 2, 2, 2],
class_num)
def resnet101(class_num=1001):
"""
Get ResNet101 neural network.
Args:
class_num (int): Class number.
Returns:
Cell, cell instance of ResNet101 neural network.
Examples:
>>> net = resnet101(1001)
"""
return ResNet(ResidualBlock,
[3, 4, 23, 3],
[64, 256, 512, 1024],
[256, 512, 1024, 2048],
[1, 2, 2, 2],
class_num)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册