提交 aedd6de6 编写于 作者: M mindspore-ci-bot 提交者: Gitee

!4927 fix bug for identity

Merge pull request !4927 from flywind/pynative_identity
......@@ -853,8 +853,15 @@ class Cell:
self.add_flags_recursive(**flags)
return self
def set_grad(self, mode=True):
self.requires_grad = mode
def set_grad(self, requires_grad=True):
"""
Sets the cell flag for gradient.
Args:
requires_grad (bool): Specifies if the net need to grad, if it is
True, cell will construct backward network in pynative mode. Default: True.
"""
self.requires_grad = requires_grad
return self
def set_train(self, mode=True):
......
......@@ -82,6 +82,7 @@ pack = P.Pack()
partial = P.Partial()
# depend: mount a node to another node
depend = P.Depend()
identity = P.identity()
tuple_setitem = Primitive('tuple_setitem')
tuple_getitem = Primitive('tuple_getitem')
......@@ -135,7 +136,6 @@ broadcast_gradient_args = Primitive('BroadcastGradientArgs')
dot = Primitive('dot')
array_reduce = Primitive('array_reduce')
zeros_like = P.ZerosLike()
identity = Primitive('identity')
distribute = Primitive('distribute')
embed = Primitive('embed')
ref_to_embed = _grad_ops.RefToEmbed()
......
......@@ -83,7 +83,7 @@ from .nn_ops import (LSTM, SGD, Adam, FusedSparseAdam, FusedSparseLazyAdam, Appl
from . import _quant_ops
from ._quant_ops import *
from .other_ops import (Assign, IOU, BoundingBoxDecode, BoundingBoxEncode, PopulationCount,
CheckValid, MakeRefKey, Partial, Depend, CheckBprop, Push, Pull)
CheckValid, MakeRefKey, Partial, Depend, identity, CheckBprop, Push, Pull)
from ._thor_ops import (CusBatchMatMul, CusCholeskyTrsm, CusFusedAbsMax1, CusImg2Col, CusMatMulCubeDenseLeft,
CusMatMulCubeFraczRightMul, CusMatMulCube, CusMatrixCombine, CusTranspose02314,
CusMatMulCubeDenseRight,
......@@ -268,6 +268,7 @@ __all__ = [
'MakeRefKey',
'Partial',
'Depend',
'identity',
'AvgPool',
# Back Primitive
'Equal',
......
......@@ -560,3 +560,21 @@ class Pull(PrimitiveWithInfer):
def infer_dtype(self, key_dtype, weight_dtype):
return mstype.float32
class identity(Primitive):
"""
Make a identify primitive, used for pynative mode.
Inputs:
- **x** (Any) - identity input value.
Outputs:
The same as input.
"""
@prim_attr_register
def __init__(self):
pass
def __call__(self, x):
return x
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册