Commit 81ce714b authored by Ziyan

Replace Square and ReduceSum with SquareSumAll in LARS

Parent 18c94950
@@ -30,9 +30,8 @@ lars_opt = C.MultitypeFuncGraph("lars_opt")
 def _tensor_run_opt(lars, weight_decay, learning_rate, gradient, weight, decay_flag, lars_flag):
     """Apply lars optimizer to the weight parameter."""
     if lars_flag:
-        op_reduce = P.ReduceSum()
-        w_square_sum = op_reduce(F.square(weight))
-        grad_square_sum = op_reduce(F.square(gradient))
+        op_reduce_sum = P.SquareSumAll()
+        w_square_sum, grad_square_sum = op_reduce_sum(weight, gradient)
         if decay_flag:
             grad_t = lars(weight, gradient, w_square_sum, grad_square_sum, weight_decay, learning_rate)
         else:
...
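For context, the two paths should be numerically equivalent: SquareSumAll computes the sum of squares of both of its inputs in one fused operator, where the old code launched separate Square and ReduceSum operators per tensor (presumably the motivation for this change). The snippet below is a minimal sketch of that equivalence check, not part of the commit; it assumes a MindSpore environment with a backend that implements P.SquareSumAll (mainly Ascend/GPU), and the tensor values are made up for illustration.

import numpy as np
from mindspore import Tensor, context
from mindspore.ops import operations as P
from mindspore.ops import functional as F

# Eager execution so the primitives can be called directly.
context.set_context(mode=context.PYNATIVE_MODE)

weight = Tensor(np.array([1.0, 2.0, 3.0], np.float32))
gradient = Tensor(np.array([0.1, 0.2, 0.3], np.float32))

# Old path: one Square plus one ReduceSum launch per tensor (four ops total).
op_reduce = P.ReduceSum()
w_square_sum_old = op_reduce(F.square(weight))
grad_square_sum_old = op_reduce(F.square(gradient))

# New path: a single SquareSumAll call returns both scalar sums at once.
op_square_sum_all = P.SquareSumAll()
w_square_sum_new, grad_square_sum_new = op_square_sum_all(weight, gradient)

# Both paths should agree numerically.
assert np.allclose(w_square_sum_old.asnumpy(), w_square_sum_new.asnumpy())
assert np.allclose(grad_square_sum_old.asnumpy(), grad_square_sum_new.asnumpy())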