From bb527bc5cf544e80db0d0dd47d507648afff45e0 Mon Sep 17 00:00:00 2001
From: Ziyan
Date: Fri, 24 Apr 2020 15:47:10 +0800
Subject: [PATCH] add LARSUpdate example

---
 mindspore/ops/operations/nn_ops.py | 21 +++++++++++++++++++++
 1 file changed, 21 insertions(+)

diff --git a/mindspore/ops/operations/nn_ops.py b/mindspore/ops/operations/nn_ops.py
index 9827975fd..55bf88fcd 100644
--- a/mindspore/ops/operations/nn_ops.py
+++ b/mindspore/ops/operations/nn_ops.py
@@ -2488,6 +2488,27 @@ class LARSUpdate(PrimitiveWithInfer):
 
     Outputs:
         Tensor, representing the new gradient.
+
+    Examples:
+        >>> from mindspore import Tensor
+        >>> from mindspore.ops import operations as P
+        >>> from mindspore.ops import functional as F
+        >>> import mindspore.nn as nn
+        >>> import numpy as np
+        >>> class Net(nn.Cell):
+        >>>     def __init__(self):
+        >>>         super(Net, self).__init__()
+        >>>         self.lars = P.LARSUpdate()
+        >>>         self.reduce = P.ReduceSum()
+        >>>     def construct(self, weight, gradient):
+        >>>         w_square_sum = self.reduce(F.square(weight))
+        >>>         grad_square_sum = self.reduce(F.square(gradient))
+        >>>         grad_t = self.lars(weight, gradient, w_square_sum, grad_square_sum, 0.0, 1.0)
+        >>>         return grad_t
+        >>> weight = np.random.random(size=(2, 3)).astype(np.float32)
+        >>> gradient = np.random.random(size=(2, 3)).astype(np.float32)
+        >>> net = Net()
+        >>> ms_output = net(Tensor(weight), Tensor(gradient))
     """
 
     @prim_attr_register
-- 
GitLab