"""Gradient accumulation training example for LeNet5 on MindSpore.

Gradients of several consecutive mini-batches are summed into a set of
``grad_sum`` parameters.  The optimizer is then applied once to that sum,
and the accumulator is reset to zero.
"""
import argparse
import os

import mindspore.nn as nn
from mindspore import ParameterTuple
from mindspore import context
from mindspore.nn import Cell
from mindspore.ops import composite as C
from mindspore.ops import functional as F
from mindspore.ops import operations as P
from mindspore.train.dataset_helper import DatasetHelper
from mindspore.train.serialization import save_checkpoint

from model_zoo.official.cv.lenet.src.dataset import create_dataset
from model_zoo.official.cv.lenet.src.lenet import LeNet5

# Per-element operations applied through C.HyperMap over parameter tuples.
_sum_op = C.MultitypeFuncGraph("grad_sum_op")
_clear_op = C.MultitypeFuncGraph("clear_op")


@_sum_op.register("Tensor", "Tensor")
def _cumulative_gard(grad_sum, grad):
    """Add ``grad`` into the accumulator ``grad_sum`` with AssignAdd."""
    add = P.AssignAdd()
    return add(grad_sum, grad)


@_clear_op.register("Tensor", "Tensor")
def _clear_grad_sum(grad_sum, zero):
    """Reset ``grad_sum`` by assigning ``zero`` to it."""
    success = True
    # depend() makes the returned value depend on the assign, so the
    # assignment is kept in the compiled graph.
    success = F.depend(success, F.assign(grad_sum, zero))
    return success


class TrainForwardBackward(Cell):
    """Run forward + backward and add the gradients into ``grad_sum``.

    Args:
        network (Cell): Network that returns the loss (e.g. a WithLossCell).
        optimizer: Optimizer; stored but not used in ``construct``.
        grad_sum (ParameterTuple): Accumulators, one per trainable parameter.
        sens (float): Value used to fill the sensitivity tensor passed to
            the gradient operation (acts as a loss scale).
    """

    def __init__(self, network, optimizer, grad_sum, sens=1.0):
        super(TrainForwardBackward, self).__init__(auto_prefix=False)
        self.network = network
        self.network.set_grad()
        # NOTE(review): defer_inline is a MindSpore graph-compilation hint;
        # exact effect not visible here.
        self.network.add_flags(defer_inline=True)
        self.weights = ParameterTuple(network.trainable_params())
        self.optimizer = optimizer
        self.grad_sum = grad_sum
        self.grad = C.GradOperation(get_by_list=True, sens_param=True)
        self.sens = sens
        self.hyper_map = C.HyperMap()

    def construct(self, *inputs):
        """Return the loss; gradients are accumulated as a side effect."""
        weights = self.weights
        loss = self.network(*inputs)
        # Sensitivity tensor with the loss's dtype/shape, filled with self.sens.
        sens = P.Fill()(P.DType()(loss), P.Shape()(loss), self.sens)
        grads = self.grad(self.network, weights)(*inputs, sens)
        # depend() returns the loss while forcing the accumulation to run.
        return F.depend(loss, self.hyper_map(F.partial(_sum_op), self.grad_sum, grads))


class TrainOptim(Cell):
    """Apply ``optimizer`` to the accumulated gradients in ``grad_sum``."""

    def __init__(self, optimizer, grad_sum):
        super(TrainOptim, self).__init__(auto_prefix=False)
        self.optimizer = optimizer
        self.grad_sum = grad_sum

    def construct(self):
        return self.optimizer(self.grad_sum)


class TrainClear(Cell):
    """Reset every accumulator in ``grad_sum`` to the matching ``zeros`` entry."""

    def __init__(self, grad_sum, zeros):
        super(TrainClear, self).__init__(auto_prefix=False)
        self.grad_sum = grad_sum
        self.zeros = zeros
        self.hyper_map = C.HyperMap()

    def construct(self):
        success = self.hyper_map(F.partial(_clear_op), self.grad_sum, self.zeros)
        return success


class GradientAccumulation:
    """Train ``network`` by summing gradients over several mini-batches.

    Args:
        network (Cell): Model producing logits.
        loss_fn (Cell): Loss function combined with ``network`` via WithLossCell.
        optimizer (Cell): Optimizer whose ``parameters`` are cloned into
            zero-initialised ``grad_sum`` accumulators and ``zeros`` buffers.
    """

    def __init__(self, network, loss_fn, optimizer):
        self._network = network
        self._loss_fn = loss_fn
        self._optimizer = optimizer

        params = self._optimizer.parameters
        self._grad_sum = params.clone(prefix="grad_sum", init='zeros')
        self._zeros = params.clone(prefix="zeros", init='zeros')
        self._train_forward_backward = self._build_train_forward_backward_network()
        self._train_optim = self._build_train_optim()
        self._train_clear = self._build_train_clear()

    def _build_train_forward_backward_network(self):
        """Build forward and backward network"""
        network = self._network
        network = nn.WithLossCell(network, self._loss_fn)
        loss_scale = 1.0
        network = TrainForwardBackward(network, self._optimizer, self._grad_sum, loss_scale).set_train()
        return network

    def _build_train_optim(self):
        """Build optimizer network"""
        network = TrainOptim(self._optimizer, self._grad_sum).set_train()
        return network

    def _build_train_clear(self):
        """Build clear network"""
        network = TrainClear(self._grad_sum, self._zeros).set_train()
        return network

    def train_process(self, epoch, train_dataset, mini_steps=None):
        """
        Training process. The data would be passed to network directly.

        Args:
            epoch (int): Number of passes over ``train_dataset``.
            train_dataset: Dataset iterated through DatasetHelper (non-sink
                mode); ``reset()`` is called on it after each epoch.
            mini_steps (int, optional): Number of mini-batches whose gradients
                are summed before each optimizer update. ``None`` (default)
                means 1, i.e. update after every mini-batch.

        Raises:
            ValueError: If ``mini_steps`` is not a positive integer.
        """
        # Previously a None default crashed on the modulo below; treat it as 1.
        if mini_steps is None:
            mini_steps = 1
        if not isinstance(mini_steps, int) or mini_steps < 1:
            raise ValueError("mini_steps must be a positive integer, got {!r}".format(mini_steps))

        dataset_helper = DatasetHelper(train_dataset, dataset_sink_mode=False, epoch_num=epoch)
        for i in range(epoch):
            step = 0
            # NOTE: k restarts at 0 every epoch and grad_sum is cleared only on
            # a window boundary, so gradients from a trailing partial window
            # (batches per epoch not divisible by mini_steps) carry over into
            # the first update of the next epoch.
            for k, next_element in enumerate(dataset_helper):
                loss = self._train_forward_backward(*next_element)
                if (k + 1) % mini_steps == 0:
                    step += 1
                    print("epoch:", i + 1, "step:", step, "loss is ", loss)
                    self._train_optim()
                    self._train_clear()

            train_dataset.reset()

        save_checkpoint(self._train_forward_backward, "gradient_accumulation.ckpt")


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description='MindSpore Gard Cumulative Example')
    parser.add_argument('--device_target', type=str, default="Ascend", choices=['Ascend', 'GPU'],
                        help='device where the code will be implemented (default: Ascend)')
    parser.add_argument('--data_path', type=str, default="./Data",
                        help='path where the dataset is saved')
    args = parser.parse_args()

    context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target)
    ds_train = create_dataset(os.path.join(args.data_path, "train"), 32)

    network = LeNet5(10)
    net_loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction="mean")
    net_opt = nn.Momentum(network.trainable_params(), 0.01, 0.9)
    model = GradientAccumulation(network, net_loss, net_opt)

    print("============== Starting Training ==============")
    model.train_process(10, ds_train, mini_steps=4)