diff --git a/mindspore/ops/operations/nn_ops.py b/mindspore/ops/operations/nn_ops.py
index 9827975fd0539c8160fd3215b7a8f09fe9d4033f..55bf88fcd8240d8f31215153d53ceee0654a35c1 100644
--- a/mindspore/ops/operations/nn_ops.py
+++ b/mindspore/ops/operations/nn_ops.py
@@ -2488,6 +2488,27 @@ class LARSUpdate(PrimitiveWithInfer):
 
     Outputs:
         Tensor, representing the new gradient.
+
+    Examples:
+        >>> from mindspore import Tensor
+        >>> from mindspore.ops import operations as P
+        >>> from mindspore.ops import functional as F
+        >>> import mindspore.nn as nn
+        >>> import numpy as np
+        >>> class Net(nn.Cell):
+        ...     def __init__(self):
+        ...         super(Net, self).__init__()
+        ...         self.lars = P.LARSUpdate()
+        ...         self.reduce = P.ReduceSum()
+        ...     def construct(self, weight, gradient):
+        ...         w_square_sum = self.reduce(F.square(weight))
+        ...         grad_square_sum = self.reduce(F.square(gradient))
+        ...         grad_t = self.lars(weight, gradient, w_square_sum, grad_square_sum, 0.0, 1.0)
+        ...         return grad_t
+        >>> weight = np.random.random(size=(2, 3)).astype(np.float32)
+        >>> gradient = np.random.random(size=(2, 3)).astype(np.float32)
+        >>> net = Net()
+        >>> ms_output = net(Tensor(weight), Tensor(gradient))
     """
 
     @prim_attr_register