diff --git a/tutorials/source_zh_cn/advanced_use/gradient_accumulation.md b/tutorials/source_zh_cn/advanced_use/gradient_accumulation.md
index 7dac737e2555610bbca0e8ea992d2fc8b204c484..88e8953754ec0b24949dad6a50a0f97c0871b141
--- a/tutorials/source_zh_cn/advanced_use/gradient_accumulation.md
+++ b/tutorials/source_zh_cn/advanced_use/gradient_accumulation.md
@@ -28,7 +28,7 @@

 The ultimate goal is to achieve almost the same effect as training directly with N*Mini-batch data at once.

-> This tutorial is intended for GPU and Ascend 910 AI processors.
+> This tutorial is intended for GPU and Ascend 910 AI processors. You can download the complete sample code here:

 ## Creating a Gradient Accumulation Model

@@ -232,7 +232,7 @@ class TrainClear(Cell):
 ```
 ### Defining the Training Process
-- Each Mini-batch computes the loss and gradients through forward and backward training, and mini_steps controls the number of accumulation steps before each parameter update. Once the accumulation count is reached, the parameters are updated and
+Each Mini-batch computes the loss and gradients through forward and backward training, and mini_steps controls the number of accumulation steps before each parameter update. Once the accumulation count is reached, the parameters are updated and
 the accumulated gradient variable is cleared.

 ```python
@@ -305,7 +305,7 @@
 ```python
 if __name__ == "__main__":
     parser = argparse.ArgumentParser(description='MindSpore Grad Cumulative Example')
-    parser.add_argument('--device_target', type=str, default="Ascend", choices=['Ascend', 'GPU', 'CPU'],
+    parser.add_argument('--device_target', type=str, default="Ascend", choices=['Ascend', 'GPU'],
                         help='device where the code will be implemented (default: Ascend)')
     parser.add_argument('--data_path', type=str, default="./Data",
                         help='path where the dataset is saved')
@@ -346,7 +346,7 @@

 2. Check the saved checkpoint file.

-   The checkpoint file gradient_accumulation.ckpt, i.e. the model file, is saved during training.
+   The checkpoint file `gradient_accumulation.ckpt`, i.e. the model file, is saved during training.

 **Validate the Model**

diff --git a/tutorials/tutorial_code/gradient_accumulation/train.py b/tutorials/tutorial_code/gradient_accumulation/train.py
new file mode 100644
index 0000000000000000000000000000000000000000..c99a6a876a4dd6ded366d4d2fa2664925ff917b2
--- /dev/null
+++ b/tutorials/tutorial_code/gradient_accumulation/train.py
@@ -0,0 +1,158 @@
+import argparse
+import os
+from collections.abc import Iterable
+
+import mindspore.nn as nn
+from mindspore import ParameterTuple
+from mindspore import context
+from mindspore.nn import Cell
+from mindspore.ops import composite as C
+from mindspore.ops import functional as F
+from mindspore.ops import operations as P
+from mindspore.train.dataset_helper import DatasetHelper
+from mindspore.train.serialization import _exec_save_checkpoint
+from model_zoo.official.cv.lenet.src.dataset import create_dataset
+from model_zoo.official.cv.lenet.src.lenet import LeNet5
+
+_sum_op = C.MultitypeFuncGraph("grad_sum_op")
+_clear_op = C.MultitypeFuncGraph("clear_op")
+
+
+@_sum_op.register("Tensor", "Tensor")
+def _cumulative_grad(grad_sum, grad):
+    """Add a newly computed gradient into the running gradient sum."""
+    add = P.AssignAdd()
+    return add(grad_sum, grad)
+
+
+@_clear_op.register("Tensor", "Tensor")
+def _clear_grad_sum(grad_sum, zero):
+    """Assign zero to grad_sum to clear it."""
+    success = True
+    success = F.depend(success, F.assign(grad_sum, zero))
+    return success
+
+
+class TrainForwardBackward(Cell):
+    """Run one forward/backward pass and accumulate the gradients into grad_sum."""
+
+    def __init__(self, network, optimizer, grad_sum, sens=1.0):
+        super(TrainForwardBackward, self).__init__(auto_prefix=False)
+        self.network = network
+        self.network.add_flags(defer_inline=True)
+        self.weights = ParameterTuple(network.trainable_params())
+        self.optimizer = optimizer
+        self.grad_sum = grad_sum
+        self.grad = C.GradOperation('grad', get_by_list=True, sens_param=True)
+        self.sens = sens
+        self.hyper_map = C.HyperMap()
+
+    def construct(self, *inputs):
+        weights = self.weights
+        loss = self.network(*inputs)
+        sens = P.Fill()(P.DType()(loss), P.Shape()(loss), self.sens)
+        grads = self.grad(self.network, weights)(*inputs, sens)
+        return F.depend(loss, self.hyper_map(F.partial(_sum_op), self.grad_sum, grads))
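+
+# How the pieces fit together (summarizing GradientAccumulation.train_process below):
+# TrainForwardBackward above runs once per mini-batch and adds the new gradients into
+# grad_sum; TrainOptim applies the accumulated sum to the weights once every mini_steps
+# batches; TrainClear then zeroes grad_sum so the next accumulation window starts clean.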
+
+
+class TrainOptim(Cell):
+    """Apply the accumulated gradient sum through the optimizer."""
+
+    def __init__(self, optimizer, grad_sum):
+        super(TrainOptim, self).__init__(auto_prefix=False)
+        self.optimizer = optimizer
+        self.grad_sum = grad_sum
+
+    def construct(self):
+        return self.optimizer(self.grad_sum)
+
+
+class TrainClear(Cell):
+    """Reset the accumulated gradient sum to zero."""
+
+    def __init__(self, grad_sum, zeros):
+        super(TrainClear, self).__init__(auto_prefix=False)
+        self.grad_sum = grad_sum
+        self.zeros = zeros
+        self.hyper_map = C.HyperMap()
+
+    def construct(self):
+        success = self.hyper_map(F.partial(_clear_op), self.grad_sum, self.zeros)
+        return success
+
+
+class GradientAccumulation:
+    """Drive training, updating the weights once every mini_steps mini-batches."""
+
+    def __init__(self, network, loss_fn, optimizer):
+        self._network = network
+        self._loss_fn = loss_fn
+        self._optimizer = optimizer
+
+        params = self._optimizer.parameters
+        self._grad_sum = params.clone(prefix="grad_sum", init='zeros')
+        self._zeros = params.clone(prefix="zeros", init='zeros')
+        self._train_forward_backward = self._build_train_forward_backward_network()
+        self._train_optim = self._build_train_optim()
+        self._train_clear = self._build_train_clear()
+
+    @staticmethod
+    def _transform_callbacks(callbacks):
+        """Transform callback to a list."""
+        if callbacks is None:
+            return []
+
+        if isinstance(callbacks, Iterable):
+            return list(callbacks)
+
+        return [callbacks]
+
+    def _build_train_forward_backward_network(self):
+        """Build the forward and backward network."""
+        network = self._network
+        network = nn.WithLossCell(network, self._loss_fn)
+        loss_scale = 1.0
+        network = TrainForwardBackward(network, self._optimizer, self._grad_sum, loss_scale).set_train()
+        return network
+
+    def _build_train_optim(self):
+        """Build the optimizer network."""
+        network = TrainOptim(self._optimizer, self._grad_sum).set_train()
+        return network
+
+    def _build_train_clear(self):
+        """Build the clear network."""
+        network = TrainClear(self._grad_sum, self._zeros).set_train()
+        return network
+
+    def train_process(self, epoch, train_dataset, mini_steps=None):
+        """
+        Training process. The data is passed to the network directly.
+        """
+        dataset_helper = DatasetHelper(train_dataset, dataset_sink_mode=False, epoch_num=epoch)
+
+        for i in range(epoch):
+            step = 0
+            for k, next_element in enumerate(dataset_helper):
+                loss = self._train_forward_backward(*next_element)
+                if (k + 1) % mini_steps == 0:
+                    step += 1
+                    print("epoch:", i + 1, "step:", step, "loss is", loss)
+                    self._train_optim()
+                    self._train_clear()
+
+            train_dataset.reset()
+
+        _exec_save_checkpoint(self._train_forward_backward, "gradient_accumulation.ckpt")
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser(description='MindSpore Grad Cumulative Example')
+    parser.add_argument('--device_target', type=str, default="Ascend", choices=['Ascend', 'GPU'],
+                        help='device where the code will be implemented (default: Ascend)')
+    parser.add_argument('--data_path', type=str, default="./Data",
+                        help='path where the dataset is saved')
+    args = parser.parse_args()
+
+    context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target)
+    ds_train = create_dataset(os.path.join(args.data_path, "train"), 32)
+
+    network = LeNet5(10)
+    net_loss = nn.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True, reduction="mean")
+    net_opt = nn.Momentum(network.trainable_params(), 0.01, 0.9)
+    model = GradientAccumulation(network, net_loss, net_opt)
+
+    print("============== Starting Training ==============")
+    model.train_process(10, ds_train, mini_steps=4)
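
The tutorial's "Validate the Model" step is only referenced above; its code is not part of this diff. A minimal sketch of what validation could look like follows, assuming the same model_zoo LeNet helpers imported by train.py, an MNIST test split under `./Data/test`, and the 0.x-era MindSpore APIs used above (`Model`, `load_checkpoint`, and `load_param_into_net` are standard MindSpore interfaces; the file name `eval.py` and the metric key `"Accuracy"` are illustrative):

```python
"""eval.py -- illustrative sketch, not part of the diff above."""
import mindspore.nn as nn
from mindspore import context
from mindspore.nn.metrics import Accuracy
from mindspore.train import Model
from mindspore.train.serialization import load_checkpoint, load_param_into_net
from model_zoo.official.cv.lenet.src.dataset import create_dataset
from model_zoo.official.cv.lenet.src.lenet import LeNet5

context.set_context(mode=context.GRAPH_MODE, device_target="GPU")

# Rebuild the bare network and restore the weights trained with gradient accumulation.
# The checkpoint also holds the grad_sum/zeros clones saved with the wrapper cell;
# entries that do not match a LeNet5 parameter are ignored when loading.
network = LeNet5(10)
load_param_into_net(network, load_checkpoint("gradient_accumulation.ckpt"))

# Same loss construction as in train.py.
net_loss = nn.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True, reduction="mean")
model = Model(network, net_loss, metrics={"Accuracy": Accuracy()})

# Assumption: the MNIST test split lives under ./Data/test, mirroring ./Data/train.
ds_eval = create_dataset("./Data/test", 32)
print("accuracy:", model.eval(ds_eval, dataset_sink_mode=False))
```

Run it after train.py completes; if the accumulated updates were applied correctly, accuracy should be well above chance after the 10 training epochs.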