提交 db9dd264 编写于 作者: M mindspore-ci-bot 提交者: Gitee

!653 add LARSUpdate example in comments

Merge pull request !653 from gziyan/add_lars_example
......@@ -2478,6 +2478,27 @@ class LARSUpdate(PrimitiveWithInfer):
Outputs:
Tensor, representing the new gradient.
Examples:
>>> from mindspore import Tensor
>>> from mindspore.ops import operations as P
>>> from mindspore.ops import functional as F
>>> import mindspore.nn as nn
>>> import numpy as np
>>> class Net(nn.Cell):
>>> def __init__(self):
>>> super(Net, self).__init__()
>>> self.lars = P.LARSUpdate()
>>> self.reduce = P.ReduceSum()
>>> def construct(self, weight, gradient):
>>> w_square_sum = self.reduce(F.square(weight))
>>> grad_square_sum = self.reduce(F.square(gradient))
>>> grad_t = self.lars(weight, gradient, w_square_sum, grad_square_sum, 0.0, 1.0)
>>> return grad_t
>>> weight = np.random.random(size=(2, 3)).astype(np.float32)
>>> gradient = np.random.random(size=(2, 3)).astype(np.float32)
>>> net = Net()
>>> ms_output = net(Tensor(weight), Tensor(gradient))
"""
@prim_attr_register
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册