提交 b8e25c38 编写于 作者: M mindspore-ci-bot 提交者: Gitee

!1315 for second order codes

Merge pull request !1315 from zongha/master
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
network config setting, will be used in train.py and eval.py
"""
from easydict import EasyDict as ed
config = ed({
"class_num": 1000,
"batch_size": 32,
"loss_scale": 128,
"momentum": 0.9,
"weight_decay": 5e-4,
"epoch_size": 50,
"buffer_size": 1000,
"image_height": 224,
"image_width": 224,
"save_checkpoint": True,
"save_checkpoint_steps": 5004,
"keep_checkpoint_max": 20,
"save_checkpoint_path": "./",
"lr_init": 0.01,
"lr_end": 0.00001,
"lr_max": 0.1,
"warmup_epochs": 0,
"lr_decay_mode": "cosine",
"label_smooth": 1,
"label_smooth_factor": 0.1,
"lr": 0.1,
"T_max": 90,
"eta_min": 0,
"frequency": 278
})
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""CrossEntropy"""
import mindspore.nn as nn
from mindspore import Tensor
from mindspore.common import dtype as mstype
from mindspore.nn.loss.loss import _Loss
from mindspore.ops import functional as F
from mindspore.ops import operations as P
class CrossEntropy(_Loss):
"""CrossEntropy"""
def __init__(self, smooth_factor=0., num_classes=1000):
super(CrossEntropy, self).__init__()
self.onehot = P.OneHot()
self.on_value = Tensor(1.0 - smooth_factor, mstype.float32)
self.off_value = Tensor(1.0 * smooth_factor / (num_classes - 1), mstype.float32)
# self.cast = P.Cast()
self.ce = nn.SoftmaxCrossEntropyWithLogits()
self.mean = P.ReduceMean(False)
def construct(self, logit, label):
# one_hot_label = self.onehot(self.cast(label, mstype.int32),
# F.shape(logit)[1], self.on_value, self.off_value)、
one_hot_label = self.onehot(label, F.shape(logit)[1], self.on_value, self.off_value)
loss = self.ce(logit, one_hot_label)
loss = self.mean(loss, 0)
return loss
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
create train or eval dataset.
"""
import os
import mindspore.common.dtype as mstype
import mindspore.dataset.engine as de
import mindspore.dataset.transforms.c_transforms as C2
import mindspore.dataset.transforms.vision.c_transforms as V_C
def create_dataset(dataset_path, do_train, repeat_num=1, batch_size=32):
"""
create a train or eval dataset
Args:
dataset_path(string): the path of dataset.
do_train(bool): whether dataset is used for train or eval.
repeat_num(int): the repeat times of dataset. Default: 1
batch_size(int): the batch size of dataset. Default: 32
Returns:
dataset
"""
device_num = int(os.getenv("RANK_SIZE"))
rank_id = int(os.getenv("RANK_ID"))
if device_num == 1:
ds = de.ImageFolderDatasetV2(dataset_path, num_parallel_workers=8, shuffle=False)
else:
ds = de.ImageFolderDatasetV2(dataset_path, num_parallel_workers=8, shuffle=True,
num_shards=device_num, shard_id=rank_id)
image_size = 224
mean = [0.485 * 255, 0.456 * 255, 0.406 * 255]
std = [0.229 * 255, 0.224 * 255, 0.225 * 255]
if do_train:
transform_img = [
V_C.RandomCropDecodeResize(image_size, scale=(0.08, 1.0), ratio=(0.75, 1.333)),
V_C.RandomHorizontalFlip(prob=0.5),
V_C.Normalize(mean=mean, std=std),
V_C.HWC2CHW()
]
else:
transform_img = [
V_C.Decode(),
V_C.Resize((256, 256)),
V_C.CenterCrop(image_size),
V_C.Normalize(mean=mean, std=std),
V_C.HWC2CHW()
]
# type_cast_op = C2.TypeCast(mstype.float16)
type_cast_op = C2.TypeCast(mstype.int32)
ds = ds.map(input_columns="image", operations=transform_img, num_parallel_workers=8)
ds = ds.map(input_columns="label", operations=type_cast_op, num_parallel_workers=8)
# apply shuffle operations
# ds = ds.shuffle(buffer_size=config.buffer_size)
# apply batch operations
ds = ds.batch(batch_size, drop_remainder=True)
# apply dataset repeat operation
ds = ds.repeat(repeat_num)
return ds
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""learning rate generator"""
import math
import numpy as np
def linear_warmup_lr(current_step, warmup_steps, base_lr, init_lr):
"""linear_warmup_lr"""
lr_inc = (float(base_lr) - float(init_lr)) / float(warmup_steps)
lr = float(init_lr) + lr_inc * current_step
return lr
def cosine_annealing_lr(lr, steps_per_epoch, warmup_epochs, max_epoch, T_max, eta_min=0, num_periods=0.5):
"""linear_warmup_lr"""
base_lr = lr
warmup_init_lr = 0
total_steps = int(max_epoch * steps_per_epoch)
warmup_steps = int(warmup_epochs * steps_per_epoch)
decay_steps = total_steps - warmup_steps
lr_each_step = []
for i in range(total_steps):
if i < warmup_steps:
lr = linear_warmup_lr(i + 1, warmup_steps, base_lr, warmup_init_lr)
else:
# linear_decay = (total_steps - i) / decay_steps
cosine_decay = 0.5 * (1 + math.cos(math.pi * i / decay_steps))
decayed = cosine_decay
lr = base_lr * decayed
lr_each_step.append(lr)
return np.array(lr_each_step).astype(np.float32)
def warmup_cosine_annealing_lr(lr, steps_per_epoch, warmup_epochs, max_epoch, T_max, eta_min=0, num_periods=0.5):
"""warmup_cosine_annealing_lr"""
base_lr = lr
warmup_init_lr = 0
total_steps = int(max_epoch * steps_per_epoch * 0.99)
warmup_steps = int(warmup_epochs * steps_per_epoch)
decay_steps = total_steps - warmup_steps
lr_each_step = []
for i in range(total_steps):
if i < warmup_steps:
lr = linear_warmup_lr(i + 1, warmup_steps, base_lr, warmup_init_lr)
else:
linear_decay = (total_steps - i) / decay_steps
cosine_decay = 0.5 * (1 + math.cos(math.pi * 2 * num_periods * i / decay_steps))
decayed = linear_decay * cosine_decay
lr = base_lr * decayed + 0.000005
lr_each_step.append(lr)
return np.array(lr_each_step).astype(np.float32)
def get_lr(global_step, lr_init, lr_end, lr_max, warmup_epochs, total_epochs, steps_per_epoch, lr_decay_mode):
"""
generate learning rate array
Args:
global_step(int): total steps of the training
lr_init(float): init learning rate
lr_end(float): end learning rate
lr_max(float): max learning rate
warmup_epochs(int): number of warmup epochs
total_epochs(int): total epoch of training
steps_per_epoch(int): steps of one epoch
lr_decay_mode(string): learning rate decay mode, including steps, poly or default
Returns:
np.array, learning rate array
"""
lr_each_step = []
total_steps = steps_per_epoch * total_epochs
warmup_steps = steps_per_epoch * warmup_epochs
if lr_decay_mode == 'steps':
decay_epoch_index = [0.3 * total_steps, 0.6 * total_steps, 0.8 * total_steps]
for i in range(total_steps):
if i < decay_epoch_index[0]:
lr = lr_max
elif i < decay_epoch_index[1]:
lr = lr_max * 0.1
elif i < decay_epoch_index[2]:
lr = lr_max * 0.01
else:
lr = lr_max * 0.001
lr_each_step.append(lr)
elif lr_decay_mode == 'poly':
if warmup_steps != 0:
inc_each_step = (float(lr_max) - float(lr_init)) / float(warmup_steps)
else:
inc_each_step = 0
for i in range(total_steps):
if i < warmup_steps:
lr = float(lr_init) + inc_each_step * float(i)
else:
base = (1.0 - (float(i) - float(warmup_steps)) / (float(total_steps) - float(warmup_steps)))
lr = float(lr_max) * base * base
if lr < 0.0:
lr = 0.0
lr_each_step.append(lr)
else:
for i in range(total_steps):
if i < warmup_steps:
lr = lr_init + (lr_max - lr_init) * i / warmup_steps
else:
lr = lr_max - (lr_max - lr_end) * (i - warmup_steps) / (total_steps - warmup_steps)
lr_each_step.append(lr)
current_step = global_step
lr_each_step = np.array(lr_each_step).astype(np.float32)
learning_rate = lr_each_step[current_step:]
return learning_rate
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Dataset help for minddata dataset"""
from mindspore import context
from mindspore._checkparam import check_bool
from mindspore.nn.wrap import GetNextSingleOp
from mindspore.parallel._utils import _get_device_num, _get_global_rank, _get_parallel_mode
from mindspore.train._utils import _exec_datagraph, _get_types_and_shapes, _to_tensor, \
_construct_tensor_list, _to_full_shapes, _to_full_tensor
from mindspore.train.parallel_utils import ParallelMode
class DatasetHelper:
"""
Help function to use the Minddata dataset.
According to different context, change the iter of dataset, to use the same for loop in different context.
Note:
The iter of DatasetHelper will give one epoch data.
Args:
dataset (DataSet): The dataset.
dataset_sink_mode (bool): If true use GetNext to fetch the data, or else feed the data from host.
Default: True.
Examples:
>>> dataset_helper = DatasetHelper(dataset)
>>> for inputs in dataset_helper:
>>> outputs = network(*inputs)
"""
def __init__(self, dataset, first_order_iter=0, dataset_sink_mode=True):
check_bool(dataset_sink_mode)
iterclass = _DatasetIterGE
if not dataset_sink_mode:
iterclass = _DatasetIterFeed
elif not context.get_context("enable_ge"):
if context.get_context("enable_loop_sink"):
iterclass = _DatasetIterMSLoopSink
else:
iterclass = _DatasetIterMS
self.iter = iterclass(dataset, first_order_iter)
def __iter__(self):
return self.iter.__iter__()
# A temp solution for loop sink. Delete later
def types_shapes(self):
"""Get the types and shapes from dataset on current config."""
return self.iter.types_shapes()
def loop_size(self):
"""Get loop_size for every iteration."""
return self.iter.loop_size
class _DatasetIter:
"""Base iter for dataset help"""
def __init__(self, dataset):
self.loop_size = 1
if not hasattr(dataset, '__ME_INITED__'):
if not hasattr(dataset, '__loop_size__'):
self.loop_size = dataset.get_dataset_size()
else:
self.loop_size = dataset.__loop_size__
dataset.__ME_INITED__ = _exec_datagraph(dataset, self.loop_size).queue_name
self.ind = 0
self.dataset = dataset
dataset_types, dataset_shapes = _get_types_and_shapes(dataset)
self.dataset_types, self.dataset_shapes = dataset_types, dataset_shapes
# for self._parallel_mode equal to semi_auto_parallel or auto_parallel, use a complete tensor to
# compile, and slice tensor to run. The batch dimension of tensors for compile is device_number
# times the batch dimension of tensors for run
if _get_parallel_mode() in (ParallelMode.SEMI_AUTO_PARALLEL, ParallelMode.AUTO_PARALLEL):
device_num = _get_device_num()
self.dataset_shapes = _to_full_shapes(dataset_shapes, device_num)
def __iter__(self):
self.ind = 0
return self
def __next__(self):
if self.ind >= self.loop_count:
raise StopIteration()
self.ind += 1
return self.op()
def types_shapes(self):
return self.dataset_types, self.dataset_shapes
def get_loop_count(self, dataset):
loop_count = 1
if hasattr(dataset, '__loop_size__'):
loop_size = dataset.__loop_size__
loop_count = int(dataset.get_dataset_size() / loop_size)
return loop_count
class _DatasetIterMSLoopSink(_DatasetIter):
"""Iter for context (enable_loop_sink=True)"""
def __init__(self, dataset, first_order_iter):
super(_DatasetIterMSLoopSink, self).__init__(dataset)
# self.loop_count = self.get_loop_count(dataset)
loop_size = dataset.__loop_size__ + first_order_iter
self.loop_count = int(dataset.get_dataset_size() / loop_size) * 2
def op():
return tuple()
self.op = op
class _DatasetIterMS(_DatasetIter):
"""Iter for context (enable_loop_sink=False)"""
def __init__(self, dataset, first_order_order):
super(_DatasetIterMS, self).__init__(dataset)
self.loop_count = dataset.get_dataset_size()
self.loop_size = 1
queue_name = dataset.__ME_INITED__
self.op = GetNextSingleOp(self.dataset_types, self.dataset_shapes, queue_name)
class _DatasetIterGE(_DatasetIter):
"""Iter for ge"""
def __init__(self, dataset):
super(_DatasetIterGE, self).__init__(dataset)
self.loop_count = self.get_loop_count(dataset)
parallel_mode = _get_parallel_mode()
self.need_to_full = parallel_mode in (ParallelMode.SEMI_AUTO_PARALLEL, ParallelMode.AUTO_PARALLEL)
batch_expand_num = 1
if self.need_to_full:
batch_expand_num = _get_device_num()
tensor_list_run = _construct_tensor_list(self.dataset_types, self.dataset_shapes, batch_expand_num)
def op():
return tensor_list_run
self.op = op
class _DatasetIterFeed:
"""Iter for feed data"""
def __init__(self, dataset, first_order_order):
self.dataset = dataset
self.device_num = _get_device_num()
self.global_rank = _get_global_rank()
self.repeat_count = dataset.get_repeat_count()
self.repeat_ind = 0
self.loop_count = dataset.get_dataset_size()
self.ind = 0
parallel_mode = context.get_auto_parallel_context("parallel_mode")
self.need_to_full = parallel_mode in (ParallelMode.SEMI_AUTO_PARALLEL, ParallelMode.AUTO_PARALLEL)
def __iter__(self):
if self.repeat_ind % self.repeat_count == 0:
self.iter = self.dataset.__iter__()
self.repeat_ind += 1
self.ind = 0
return self
def __next__(self):
if self.ind >= self.loop_count:
raise StopIteration()
self.ind += 1
data = self.iter.__next__()
if self.need_to_full:
return _to_full_tensor(data, self.device_num, self.global_rank)
return _to_tensor(data)
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""grad_reducer_thor"""
import mindspore.common.dtype as mstype
from mindspore.communication.management import GlobalComm, get_group_size
from mindspore.nn.cell import Cell
from mindspore.ops import functional as F, composite as C, operations as P
from mindspore.ops.operations.comm_ops import AllReduce, ReduceOp
reduce_opt = C.MultitypeFuncGraph("reduce_opt")
_all_reduce_A = AllReduce()
def _init_optimizer_allreduce(group):
global _all_reduce_A
_all_reduce_A = AllReduce(ReduceOp.SUM, GlobalComm.WORLD_COMM_GROUP)
_all_reduce_A.add_prim_attr('fusion', group)
@reduce_opt.register("Function", "Number", "Tensor")
def _tensors_allreduce_mean(mul, degree, grad):
degree = F.scalar_cast(degree, F.dtype(grad))
grad = _all_reduce_A(grad)
cast_op = P.Cast()
return mul(grad, cast_op(F.scalar_to_array(1.0 / degree), F.dtype(grad)))
@reduce_opt.register("Bool", "Tensor")
def _tensors_allreduce(allreduce_filter, grad):
if allreduce_filter:
return _all_reduce_A(grad)
return grad
_get_datatype = C.MultitypeFuncGraph("_get_datatype")
@_get_datatype.register("Tensor")
def _tensors_get_datatype(grad):
"""
Acquire gradient datatype.
Args:
grad (Tensor): The gradient tensor before operation.
Returns:
mstype, the datatype of gradient.
"""
return F.dtype(grad)
_cast_datatype = C.MultitypeFuncGraph("_cast_datatype")
@_cast_datatype.register("TypeType", "Tensor")
def _tensors_cast_datatype(datatype, grad):
"""
Cast gradient to datatype.
Args:
datatype (mstype): the destination datatype of gradient.
grad (Tensor): The gradient tensor before operation.
Returns:
Tensor, the gradient tensor after operation.
"""
return F.cast(grad, datatype)
class DistributedGradReducerThor(Cell):
"""
A distributed optimizer.
Constructs a gradient reducer Cell, which applies communication and average operations on
single-process gradient values.
Args:
parameters (list): the parameters to be updated.
mean (bool): When mean is true, the mean coefficient (degree) would apply on gradients. Default: False.
degree (int): The mean coefficient. Usually it equals to device number. Default: None.
Raises:
ValueError: If degree is not a int or less than 0.
Examples:
>>> from mindspore.communication import init, get_group_size
>>> from mindspore.ops import composite as C
>>> from mindspore.ops import operations as P
>>> from mindspore.ops import functional as F
>>> from mindspore import context
>>> from mindspore import nn
>>> from mindspore import ParallelMode, ParameterTuple
>>>
>>> device_id = int(os.environ["DEVICE_ID"])
>>> context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=True,
>>> device_id=int(device_id), enable_hccl=True)
>>> init()
>>> context.reset_auto_parallel_context()
>>> context.set_auto_parallel_context(parallel_mode=ParallelMode.DATA_PARALLEL)
>>>
>>>
>>> class TrainingWrapper(nn.Cell):
>>> def __init__(self, network, optimizer, sens=1.0):
>>> super(TrainingWrapper, self).__init__(auto_prefix=False)
>>> self.network = network
>>> self.network.add_flags(defer_inline=True)
>>> self.weights = ParameterTuple(network.trainable_params())
>>> self.optimizer = optimizer
>>> self.grad = C.GradOperation('grad', get_by_list=True, sens_param=True)
>>> self.sens = sens
>>> self.reducer_flag = False
>>> self.grad_reducer = None
>>> self.parallel_mode = context.get_auto_parallel_context("parallel_mode")
>>> if self.parallel_mode in [ParallelMode.DATA_PARALLEL,
>>> ParallelMode.HYBRID_PARALLEL]:
>>> self.reducer_flag = True
>>> if self.reducer_flag:
>>> mean = context.get_auto_parallel_context("mirror_mean")
>>> if mean.get_device_num_is_set():
>>> degree = context.get_auto_parallel_context("device_num")
>>> else:
>>> degree = get_group_size()
>>> self.grad_reducer = nn.DistributedGradReducer(optimizer.parameters, mean, degree)
>>>
>>> def construct(self, *args):
>>> weights = self.weights
>>> loss = self.network(*args)
>>> sens = P.Fill()(P.DType()(loss), P.Shape()(loss), self.sens)
>>> grads = self.grad(self.network, weights)(*args, sens)
>>> if self.reducer_flag:
>>> # apply grad reducer on grads
>>> grads = self.grad_reducer(grads)
>>> return F.depend(loss, self.optimizer(grads))
>>>
>>> network = Net()
>>> optimizer = nn.Momentum(network.trainable_params(), learning_rate=0.1, momentum=0.9)
>>> train_cell = TrainingWrapper(network, optimizer)
>>> inputs = Tensor(np.ones([16, 16]).astype(np.float32))
>>> label = Tensor(np.zeros([16, 16]).astype(np.float32))
>>> grads = train_cell(inputs, label)
"""
def __init__(self, parameters, group, mean=True, degree=None):
super(DistributedGradReducerThor, self).__init__(auto_prefix=False)
self.hyper_map = C.HyperMap()
self.mul = P.Mul()
if degree is None:
self.degree = get_group_size()
else:
if not isinstance(degree, int) or degree <= 0:
raise ValueError("Parameter 'degree' in DistributedGradReducer should large than 0 and be int")
self.degree = degree
self.mean = mean
self.allreduce_filter = tuple(x.layerwise_parallel is False for x in parameters)
_init_optimizer_allreduce(group)
def construct(self, grads):
# In some circumstances, the data precision of grads could be mixed with float16 and float32. Thus, the
# result of AllReduce is unreliable. To solve the problem, grads should be cast to float32 before AllReduce,
# and cast back after the operation.
datatypes = self.hyper_map(F.partial(_get_datatype), grads)
grads = self.hyper_map(F.partial(_cast_datatype, mstype.float32), grads)
if self.mean:
new_grad = self.hyper_map(F.partial(reduce_opt, self.mul, self.degree), grads)
else:
new_grad = self.hyper_map(F.partial(reduce_opt), self.allreduce_filter, grads)
new_grad = self.hyper_map(F.partial(_cast_datatype), datatypes, new_grad)
return new_grad
此差异已折叠。
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""ResNet."""
import math
import mindspore.nn as nn
import numpy as np
from mindspore.common.tensor import Tensor
from mindspore.ops import operations as P
from second_order.thor_layer import Conv2d_Thor, Dense_Thor
def calculate_gain(nonlinearity, param=None):
linear_fns = ['linear', 'conv1d', 'conv2d', 'conv3d', 'conv_transpose1d', 'conv_transpose2d', 'conv_transpose3d']
if nonlinearity in linear_fns or nonlinearity == 'sigmoid':
return 1
elif nonlinearity == 'tanh':
return 5.0 / 3
elif nonlinearity == 'relu':
return math.sqrt(2.0)
elif nonlinearity == 'leaky_relu':
if param is None:
negative_slope = 0.01
elif not isinstance(param, bool) and isinstance(param, int) or isinstance(param, float):
# True/False are instances of int, hence check above
negative_slope = param
else:
raise ValueError("negative_slope {} not a valid number".format(param))
return math.sqrt(2.0 / (1 + negative_slope ** 2))
else:
raise ValueError("Unsupported nonlinearity {}".format(nonlinearity))
def _calculate_fan_in_and_fan_out(tensor):
dimensions = len(tensor)
if dimensions < 2:
raise ValueError("Fan in and fan out can not be computed for tensor with fewer than 2 dimensions")
if dimensions == 2: # Linear
fan_in = tensor[1]
fan_out = tensor[0]
else:
num_input_fmaps = tensor[1]
num_output_fmaps = tensor[0]
receptive_field_size = 1
if dimensions > 2:
receptive_field_size = tensor[2] * tensor[3]
fan_in = num_input_fmaps * receptive_field_size
fan_out = num_output_fmaps * receptive_field_size
return fan_in, fan_out
def _calculate_correct_fan(tensor, mode):
mode = mode.lower()
valid_modes = ['fan_in', 'fan_out']
if mode not in valid_modes:
raise ValueError("Mode {} not supported, please use one of {}".format(mode, valid_modes))
fan_in, fan_out = _calculate_fan_in_and_fan_out(tensor)
return fan_in if mode == 'fan_in' else fan_out
def kaiming_normal(inputs_shape, a=0, mode='fan_in', nonlinearity='leaky_relu'):
fan = _calculate_correct_fan(inputs_shape, mode)
gain = calculate_gain(nonlinearity, a)
std = gain / math.sqrt(fan)
return np.random.normal(0, std, size=inputs_shape).astype(np.float32)
def kaiming_uniform(inputs_shape, a=0, mode='fan_in', nonlinearity='leaky_relu'):
fan = _calculate_correct_fan(inputs_shape, mode)
gain = calculate_gain(nonlinearity, a)
std = gain / math.sqrt(fan)
bound = math.sqrt(3.0) * std # Calculate uniform bounds from standard deviation
return np.random.uniform(-bound, bound, size=inputs_shape).astype(np.float32)
def _conv3x3(in_channel, out_channel, stride=1, damping=0.03, loss_scale=1, frequency=278):
weight_shape = (out_channel, in_channel, 3, 3)
weight = Tensor(kaiming_normal(weight_shape, mode="fan_out", nonlinearity='relu'))
return Conv2d_Thor(in_channel, out_channel,
kernel_size=3, stride=stride, padding=0, pad_mode='same', weight_init=weight,
damping=damping, loss_scale=loss_scale, frequency=frequency)
# return nn.Conv2d(in_channel, out_channel,
# kernel_size=3, stride=stride, padding=0, pad_mode='same', weight_init=weight)
def _conv1x1(in_channel, out_channel, stride=1, damping=0.03, loss_scale=1, frequency=278):
weight_shape = (out_channel, in_channel, 1, 1)
weight = Tensor(kaiming_normal(weight_shape, mode="fan_out", nonlinearity='relu'))
return Conv2d_Thor(in_channel, out_channel,
kernel_size=1, stride=stride, padding=0, pad_mode='same', weight_init=weight,
damping=damping, loss_scale=loss_scale, frequency=frequency)
def _conv7x7(in_channel, out_channel, stride=1, damping=0.03, loss_scale=1, frequency=278):
weight_shape = (out_channel, in_channel, 7, 7)
weight = Tensor(kaiming_normal(weight_shape, mode="fan_out", nonlinearity='relu'))
return Conv2d_Thor(in_channel, out_channel,
kernel_size=7, stride=stride, padding=0, pad_mode='same', weight_init=weight,
damping=damping, loss_scale=loss_scale, frequency=frequency)
def _bn(channel):
return nn.BatchNorm2d(channel, eps=1e-4, momentum=0.9,
gamma_init=1, beta_init=0, moving_mean_init=0, moving_var_init=1)
def _bn_last(channel):
return nn.BatchNorm2d(channel, eps=1e-4, momentum=0.9,
gamma_init=1, beta_init=0, moving_mean_init=0, moving_var_init=1)
def _fc(in_channel, out_channel, damping, loss_scale, frequency):
weight_shape = (out_channel, in_channel)
weight = Tensor(kaiming_uniform(weight_shape, a=math.sqrt(5))
return Dense_Thor(in_channel, out_channel, has_bias=False, weight_init=weight,
bias_init=0, damping=damping, loss_scale=loss_scale, frequency=frequency)
class ResidualBlock(nn.Cell):
"""
ResNet V1 residual block definition.
Args:
in_channel (int): Input channel.
out_channel (int): Output channel.
stride (int): Stride size for the first convolutional layer. Default: 1.
Returns:
Tensor, output tensor.
Examples:
>>> ResidualBlock(3, 256, stride=2)
"""
expansion = 4
def __init__(self,
in_channel,
out_channel,
stride=1,
damping=0.03,
loss_scale=1,
frequency=278):
super(ResidualBlock, self).__init__()
channel = out_channel // self.expansion
self.conv1 = _conv1x1(in_channel, channel, stride=1, damping=damping, loss_scale=loss_scale,
frequency=frequency)
self.bn1 = _bn(channel)
self.conv2 = _conv3x3(channel, channel, stride=stride, damping=damping, loss_scale=loss_scale,
frequency=frequency)
self.bn2 = _bn(channel)
self.conv3 = _conv1x1(channel, out_channel, stride=1, damping=damping, loss_scale=loss_scale,
frequency=frequency)
self.bn3 = _bn_last(out_channel)
self.relu = nn.ReLU()
self.down_sample = False
if stride != 1 or in_channel != out_channel:
self.down_sample = True
self.down_sample_layer = None
if self.down_sample:
self.down_sample_layer = nn.SequentialCell([_conv1x1(in_channel, out_channel, stride,
damping=damping, loss_scale=loss_scale,
frequency=frequency),
_bn(out_channel)])
self.add = P.TensorAdd()
def construct(self, x):
identity = x
out = self.conv1(x)
out = self.bn1(out)
out = self.relu(out)
out = self.conv2(out)
out = self.bn2(out)
out = self.relu(out)
out = self.conv3(out)
out = self.bn3(out)
if self.down_sample:
identity = self.down_sample_layer(identity)
out = self.add(out, identity)
out = self.relu(out)
return out
class ResNet(nn.Cell):
"""
ResNet architecture.
Args:
block (Cell): Block for network.
layer_nums (list): Numbers of block in different layers.
in_channels (list): Input channel in each layer.
out_channels (list): Output channel in each layer.
strides (list): Stride size in each layer.
num_classes (int): The number of classes that the training images are belonging to.
Returns:
Tensor, output tensor.
Examples:
>>> ResNet(ResidualBlock,
>>> [3, 4, 6, 3],
>>> [64, 256, 512, 1024],
>>> [256, 512, 1024, 2048],
>>> [1, 2, 2, 2],
>>> 10)
"""
def __init__(self,
block,
layer_nums,
in_channels,
out_channels,
strides,
num_classes,
damping,
loss_scale,
frequency):
super(ResNet, self).__init__()
if not len(layer_nums) == len(in_channels) == len(out_channels) == 4:
raise ValueError("the length of layer_num, in_channels, out_channels list must be 4!")
self.conv1 = _conv7x7(3, 64, stride=2, damping=damping, loss_scale=loss_scale, frequency=frequency)
self.bn1 = _bn(64)
self.relu = P.ReLU()
self.maxpool = P.MaxPoolWithArgmax(padding="same", ksize=3, strides=2)
self.layer1 = self._make_layer(block,
layer_nums[0],
in_channel=in_channels[0],
out_channel=out_channels[0],
stride=strides[0],
damping=damping,
loss_scale=loss_scale,
frequency=frequency)
self.layer2 = self._make_layer(block,
layer_nums[1],
in_channel=in_channels[1],
out_channel=out_channels[1],
stride=strides[1],
damping=damping,
loss_scale=loss_scale,
frequency=frequency)
self.layer3 = self._make_layer(block,
layer_nums[2],
in_channel=in_channels[2],
out_channel=out_channels[2],
stride=strides[2], damping=damping,
loss_scale=loss_scale,
frequency=frequency)
self.layer4 = self._make_layer(block,
layer_nums[3],
in_channel=in_channels[3],
out_channel=out_channels[3],
stride=strides[3],
damping=damping,
loss_scale=loss_scale,
frequency=frequency)
self.mean = P.ReduceMean(keep_dims=True)
self.flatten = nn.Flatten()
self.end_point = _fc(out_channels[3], num_classes, damping=damping, loss_scale=loss_scale, frequency=frequency)
def _make_layer(self, block, layer_num, in_channel, out_channel, stride,
damping, loss_scale, frequency):
"""
Make stage network of ResNet.
Args:
block (Cell): Resnet block.
layer_num (int): Layer number.
in_channel (int): Input channel.
out_channel (int): Output channel.
stride (int): Stride size for the first convolutional layer.
Returns:
SequentialCell, the output layer.
Examples:
>>> _make_layer(ResidualBlock, 3, 128, 256, 2)
"""
layers = []
resnet_block = block(in_channel, out_channel, stride=stride,
damping=damping, loss_scale=loss_scale, frequency=frequency)
layers.append(resnet_block)
for _ in range(1, layer_num):
resnet_block = block(out_channel, out_channel, stride=1,
damping=damping, loss_scale=loss_scale, frequency=frequency)
layers.append(resnet_block)
return nn.SequentialCell(layers)
def construct(self, x):
x = self.conv1(x)
x = self.bn1(x)
x = self.relu(x)
c1, argmax = self.maxpool(x)
c2 = self.layer1(c1)
c3 = self.layer2(c2)
c4 = self.layer3(c3)
c5 = self.layer4(c4)
out = self.mean(c5, (2, 3))
out = self.flatten(out)
out = self.end_point(out)
return out
def resnet50(class_num=10, damping=0.03, loss_scale=1, frequency=278):
"""
Get ResNet50 neural network.
Args:
class_num (int): Class number.
Returns:
Cell, cell instance of ResNet50 neural network.
Examples:
>>> net = resnet50(10)
"""
return ResNet(ResidualBlock,
[3, 4, 6, 3],
[64, 256, 512, 1024],
[256, 512, 1024, 2048],
[1, 2, 2, 2],
class_num,
damping,
loss_scale,
frequency)
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""momentum"""
import mindspore.common.dtype as mstype
from mindspore.common.initializer import initializer
from mindspore.common.parameter import Parameter
from mindspore.common.parameter import ParameterTuple
from mindspore.common.tensor import Tensor
from mindspore.nn.optim.optimizer import Optimizer
from mindspore.ops import functional as F, composite as C, operations as P
from mindspore.parallel._utils import _get_device_num, _get_mirror_mean
from cus_ops.cus_matmul_cube_dense_right import CusMatMulCubeDenseRight
from cus_ops.cus_matmul_cube_fracz_left_cast import CusMatMulCubeFraczLeftCast
from cus_ops.cus_matmul_cube_dense_left import CusMatMulCubeDenseLeft
from cus_ops.cus_matmul_cube_fracz_right_mul import CusMatMulCubeFraczRightMul
from model.grad_reducer_thor import DistributedGradReducerThor
momentum_opt = C.MultitypeFuncGraph("momentum_opt")
@momentum_opt.register("Function", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor")
def _tensor_run_opt_ext(opt, learning_rate, momentum, gradient, weight, moment):
"""Apply momentum optimizer to the weight parameter using Tensor."""
success = True
success = F.depend(success, opt(weight, moment, learning_rate, gradient, momentum))
return success
op_add = P.AddN()
apply_decay = C.MultitypeFuncGraph("apply_decay")
@apply_decay.register("Number", "Bool", "Tensor", "Tensor")
def _tensor_apply_decay(weight_decay, if_apply, weight, gradient):
"""Get grad with weight_decay."""
if if_apply:
return op_add((weight * weight_decay, gradient))
return gradient
class THOR(Optimizer):
"""THOR"""
def __init__(self, params, learning_rate, momentum, matrix_A, matrix_G, A_inv_max, G_inv_max, weight_decay=0.0,
loss_scale=1.0,
decay_filter=lambda x: x.name not in []):
super(THOR, self).__init__(learning_rate, params, weight_decay, loss_scale)
if isinstance(momentum, float) and momentum < 0.0:
raise ValueError("momentum should be at least 0.0, but got momentum {}".format(momentum))
self.momentum = Parameter(Tensor(momentum, mstype.float32), name="momentum")
self.params = self.parameters
self.moments = self.params.clone(prefix="moments", init='zeros')
self.hyper_map = C.HyperMap()
self.opt = P.ApplyMomentum()
self.matrix_A = ParameterTuple(matrix_A)
self.matrix_G = ParameterTuple(matrix_G)
self.A_inv_max = ParameterTuple(A_inv_max)
self.G_inv_max = ParameterTuple(G_inv_max)
self.cube_matmul_left = CusMatMulCubeFraczLeftCast()
self.cube_matmul_left_fc = CusMatMulCubeDenseLeft()
self.cube_matmul_right_fc = CusMatMulCubeDenseRight()
self.cube_matmul_right_mul = CusMatMulCubeFraczRightMul()
self.transpose = P.Transpose()
self.shape = P.Shape()
self.reshape = P.Reshape()
self.mul = P.Mul()
self.weight_idx = []
for i in range(len(self.params)):
if "conv" in self.params[i].name or "end_point" in self.params[i].name:
self.weight_idx.append(i)
self.weight_idx.append(len(self.params))
self.feature_map = [1.0 / 12544, 1.0 / 3136, 1.0 / 3136, 1.0 / 3136, 1.0 / 3136, 1.0 / 3136, 1.0 / 3136,
1.0 / 3136, 1.0 / 3136, 1.0 / 3136, 1.0 / 3136, 1.0 / 3136,
1.0 / 784, 1.0 / 784, 1.0 / 784, 1.0 / 784, 1.0 / 784, 1.0 / 784, 1.0 / 784, 1.0 / 784,
1.0 / 784, 1.0 / 784, 1.0 / 784, 1.0 / 784, 1.0 / 784,
1.0 / 196, 1.0 / 196, 1.0 / 196, 1.0 / 196, 1.0 / 196, 1.0 / 196, 1.0 / 196, 1.0 / 196,
1.0 / 196, 1.0 / 196, 1.0 / 196, 1.0 / 196, 1.0 / 196, 1.0 / 196, 1.0 / 196, 1.0 / 196,
1.0 / 196, 1.0 / 196, 1.0 / 196,
1.0 / 49, 1.0 / 49, 1.0 / 49, 1.0 / 49, 1.0 / 49, 1.0 / 49, 1.0 / 49, 1.0 / 49, 1.0 / 49,
1.0]
mean = _get_mirror_mean()
degree = _get_device_num()
self.grad_reducer_Amax = DistributedGradReducerThor(self.parameters, 2, mean, degree)
self.grad_reducer_Gmax = DistributedGradReducerThor(self.parameters, 5, mean, degree)
self.grad_reducer_A = DistributedGradReducerThor(self.parameters, 3, mean, degree)
self.grad_reducer_G = DistributedGradReducerThor(self.parameters, 4, mean, degree)
self.matrix_A_inv = ()
self.matrix_G_inv = ()
self.matrix_max_inv = ()
for i in range(54):
self.matrix_max_inv = self.matrix_max_inv + (
Parameter(initializer(1, [1], mstype.float32), name="matrix_max" + str(i), requires_grad=False),)
self.log = P.Log()
self.exp = P.Exp()
self.sqrt = P.Sqrt()
self.matrix_max_inv = ParameterTuple(self.matrix_max_inv)
self.assign = P.Assign()
self.cast = P.Cast()
self.thor = True
self.weight_decay = weight_decay * loss_scale
self.decay_flags = tuple(decay_filter(x) for x in self.parameters)
def construct(self, gradients):
params = self.params
moments = self.moments
if self.thor:
matrix_A_allreduce = ()
matrix_G_allreduce = ()
matrix_A_max_allreduce = ()
matrix_G_max_allreduce = ()
for i in range(54):
g = gradients[i * 3]
matrix_A = self.matrix_A[i]
matrix_G = self.matrix_G[i]
A_max = self.A_inv_max[i]
G_max = self.G_inv_max[i]
matrix_A = F.depend(matrix_A, g)
matrix_G = F.depend(matrix_G, g)
A_max = F.depend(A_max, g)
G_max = F.depend(G_max, g)
matrix_A_allreduce = matrix_A_allreduce + (matrix_A,)
matrix_G_allreduce = matrix_G_allreduce + (matrix_G,)
matrix_A_max_allreduce = matrix_A_max_allreduce + (A_max,)
matrix_G_max_allreduce = matrix_G_max_allreduce + (G_max,)
matrix_A_allreduce = self.grad_reducer_A(matrix_A_allreduce)
matrix_G_allreduce = self.grad_reducer_G(matrix_G_allreduce)
matrix_A_max_allreduce = self.grad_reducer_Amax(matrix_A_max_allreduce)
matrix_G_max_allreduce = self.grad_reducer_Gmax(matrix_G_max_allreduce)
new_grads = ()
for i in range(54):
g = gradients[i * 3]
temp_a = matrix_A_allreduce[i]
temp_g = matrix_G_allreduce[i]
temp_a = self.cast(temp_a, mstype.float32)
temp_g = self.cast(temp_g, mstype.float32)
matrix_A_inv_max = self.log(matrix_A_max_allreduce[i])
matrix_A_inv_max = self.mul(matrix_A_inv_max, -1)
matrix_A_inv_max = self.exp(matrix_A_inv_max)
temp_a = self.mul(temp_a, matrix_A_inv_max)
matrix_G_inv_max = self.log(matrix_G_max_allreduce[i])
matrix_G_inv_max = self.mul(matrix_G_inv_max, -1)
matrix_G_inv_max = self.exp(matrix_G_inv_max)
temp_g = self.mul(temp_g, matrix_G_inv_max)
temp_max = self.mul(matrix_A_max_allreduce[i], matrix_G_max_allreduce[i])
temp_max = self.mul(temp_max, self.feature_map[i])
if i == 53:
g = self.cube_matmul_left_fc(temp_g, g)
g = self.cube_matmul_right_fc(g, temp_a, temp_max)
else:
g = self.cube_matmul_left(temp_g, g)
g = self.cube_matmul_right_mul(g, temp_a, temp_max)
fake_A = self.assign(self.matrix_A[i], temp_a)
fake_G = self.assign(self.matrix_G[i], temp_g)
fake_max = self.assign(self.matrix_max_inv[i], temp_max)
g = F.depend(g, fake_A)
g = F.depend(g, fake_G)
g = F.depend(g, fake_max)
if i == 53:
new_grads = new_grads + (g,)
else:
new_grads = new_grads + (g, gradients[i * 3 + 1], gradients[i * 3 + 2])
gradients = new_grads
else:
new_grads = ()
for i in range(54):
g = gradients[i * 3]
matrix_A = self.matrix_A[i]
matrix_G = self.matrix_G[i]
matrix_max = self.matrix_max_inv[i]
matrix_A = F.depend(matrix_A, g)
matrix_G = F.depend(matrix_G, g)
matrix_max = F.depend(matrix_max, g)
if i == 53:
g = self.cube_matmul_left_fc(matrix_G, g)
g = self.cube_matmul_right_fc(g, matrix_A, matrix_max)
new_grads = new_grads + (g,)
else:
g = self.cube_matmul_left(matrix_G, g)
g = self.cube_matmul_right_mul(g, matrix_A, matrix_max)
new_grads = new_grads + (g, gradients[i * 3 + 1], gradients[i * 3 + 2])
gradients = new_grads
if self.weight_decay > 0:
gradients = self.hyper_map(F.partial(apply_decay, self.weight_decay), self.decay_flags,
params, gradients)
gradients = self.scale_grad(gradients)
lr = self.get_lr()
success = self.hyper_map(F.partial(momentum_opt, self.opt, lr, self.momentum), gradients, params, moments)
return success
此差异已折叠。
#!/bin/bash
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
if [ $# != 3 ]
then
echo "Usage: sh run_distribute_train.sh [MINDSPORE_HCCL_CONFIG_PATH] [DATASET_PATH] [DEVICE_NUM]"
exit 1
fi
if [ ! -f $1 ]
then
echo "error: DMINDSPORE_HCCL_CONFIG_PATH=$1 is not a file"
exit 1
fi
if [ ! -d $2 ]
then
echo "error: DATASET_PATH=$2 is not a directory"
exit 1
fi
ulimit -u unlimited
export DEVICE_NUM=$3
export RANK_SIZE=$3
export MINDSPORE_HCCL_CONFIG_PATH=$1
for((i=0; i<${DEVICE_NUM}; i++))
do
export DEVICE_ID=$i
export RANK_ID=$i
rm -rf ./train_parallel$i
mkdir ./train_parallel$i
cp *.py ./train_parallel$i
cp *.sh ./train_parallel$i
cp -r second_order ./train_parallel$i/second_order
cp -r test_ops ./train_parallel$i/test_ops
cd ./train_parallel$i || exit
echo "start training for rank $RANK_ID, device $DEVICE_ID"
env > env.log
python train_0517_1.py --do_train=True --run_distribute=True --device_num=$DEVICE_NUM --dataset_path=$2 > log 2>&1 &
cd ..
done
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""train_imagenet."""
import argparse
import os
import random
import mindspore.dataset.engine as de
from mindspore import Tensor
from mindspore import context
from mindspore.communication.management import init
from mindspore.parallel._auto_parallel_context import auto_parallel_context
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor, TimeMonitor
from mindspore.train.loss_scale_manager import FixedLossScaleManager
from mindspore.train.model import ParallelMode
from second_order.model_second_order import Model
from second_order.resnet import resnet50
from second_order.thor import THOR
import numpy as np
from config_imagenet import config
from crossentropy import CrossEntropy
from dataset_imagenet import create_dataset
from lr_generator import warmup_cosine_annealing_lr
random.seed(1)
np.random.seed(1)
de.config.set_seed(1)
parser = argparse.ArgumentParser(description='Image classification')
parser.add_argument('--run_distribute', type=bool, default=False, help='Run distribute')
parser.add_argument('--device_num', type=int, default=1, help='Device num.')
parser.add_argument('--do_train', type=bool, default=True, help='Do train or not.')
parser.add_argument('--do_eval', type=bool, default=False, help='Do eval or not.')
parser.add_argument('--dataset_path', type=str, default=None, help='Dataset path')
args_opt = parser.parse_args()
device_id = int(os.getenv('DEVICE_ID'))
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=True, device_id=device_id)
context.set_context(enable_task_sink=True)
context.set_context(enable_loop_sink=True)
context.set_context(enable_mem_reuse=True)
def get_second_order_lr(global_step, lr_init, decay, total_epochs, steps_per_epoch):
"""get_second_order_lr"""
lr_each_step = []
total_steps = steps_per_epoch * total_epochs
for i in range(total_steps):
epoch = (i + 1) / steps_per_epoch
base = (1.0 - float(epoch) / total_epochs) ** decay
lr_local = lr_init * base
lr_each_step.append(lr_local)
current_step = global_step
lr_each_step = np.array(lr_each_step).astype(np.float32)
print("learning_rate_is=====", lr_each_step)
learning_rate = lr_each_step[current_step:]
return learning_rate
def get_second_order_damping(global_step, damping_init, decay_rate, total_epochs, steps_per_epoch):
"""get_second_order_damping"""
damping_each_step = []
total_steps = steps_per_epoch * total_epochs
for step in range(total_steps):
epoch = (step + 1) / steps_per_epoch
damping_here = damping_init * (decay_rate ** (epoch / 10))
damping_each_step.append(damping_here)
current_step = global_step
damping_each_step = np.array(damping_each_step).astype(np.float32)
damping_now = damping_each_step[current_step:]
print("damping_is=========", damping_now)
return damping_now
if __name__ == '__main__':
if args_opt.do_eval:
print("eval")
else:
if args_opt.run_distribute:
context.set_auto_parallel_context(device_num=args_opt.device_num, parallel_mode=ParallelMode.DATA_PARALLEL,
mirror_mean=True, parameter_broadcast=True)
auto_parallel_context().set_all_reduce_fusion_split_indices([80], "hccl_world_groupsum1")
auto_parallel_context().set_all_reduce_fusion_split_indices([27], "hccl_world_groupsum3")
auto_parallel_context().set_all_reduce_fusion_split_indices([27], "hccl_world_groupsum4")
init()
else:
print(" ")
epoch_size = config.epoch_size
damping = get_second_order_damping(0, 0.03, 0.87, 50, 5004)
net = resnet50(class_num=config.class_num, damping=damping, loss_scale=config.loss_scale,
frequency=config.frequency)
if not config.label_smooth:
config.label_smooth_factor = 0.0
loss = CrossEntropy(smooth_factor=config.label_smooth_factor, num_classes=config.class_num)
if args_opt.do_train:
dataset = create_dataset(dataset_path=args_opt.dataset_path, do_train=True,
repeat_num=epoch_size, batch_size=config.batch_size)
step_size = dataset.get_dataset_size()
loss_scale = FixedLossScaleManager(config.loss_scale, drop_overflow_update=False)
lr = Tensor(warmup_cosine_annealing_lr(0.035,
step_size,
config.warmup_epochs,
50,
config.T_max,
config.eta_min))
opt = THOR(filter(lambda x: x.requires_grad, net.get_parameters()), lr,
config.momentum, damping, config.frequency,
filter(lambda x: 'matrix_A' in x.name, net.get_parameters()),
filter(lambda x: 'matrix_G' in x.name, net.get_parameters()),
filter(lambda x: 'spatial_norm' in x.name, net.get_parameters()),
config.weight_decay, config.loss_scale)
model = Model(net, loss_fn=loss, optimizer=opt, amp_level='O2', loss_scale_manager=loss_scale,
keep_batchnorm_fp32=False, metrics={'acc'}, frequency=config.frequency)
time_cb = TimeMonitor(data_size=step_size)
loss_cb = LossMonitor()
cb = [time_cb, loss_cb]
if config.save_checkpoint:
config_ck = CheckpointConfig(save_checkpoint_steps=config.save_checkpoint_steps,
keep_checkpoint_max=config.keep_checkpoint_max)
ckpt_cb = ModelCheckpoint(prefix="resnet", directory=config.save_checkpoint_path, config=config_ck)
cb += [ckpt_cb]
model.train(epoch_size, dataset, callbacks=cb)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册