提交 1a54785f 作者: panyifeng

Remove the `name` argument from `GradOperation`

上级 b5ed5466
...@@ -117,7 +117,7 @@ class WithGradCell(Cell): ...@@ -117,7 +117,7 @@ class WithGradCell(Cell):
self.network = network self.network = network
self.loss_fn = loss_fn self.loss_fn = loss_fn
self.weights = ParameterTuple(network.trainable_params()) self.weights = ParameterTuple(network.trainable_params())
self.grad = C.GradOperation('grad', get_by_list=True, sens_param=(sens is not None)) self.grad = C.GradOperation(get_by_list=True, sens_param=(sens is not None))
self.sens = sens self.sens = sens
if loss_fn is None: if loss_fn is None:
self.network_with_loss = network self.network_with_loss = network
...@@ -182,7 +182,7 @@ class TrainOneStepCell(Cell): ...@@ -182,7 +182,7 @@ class TrainOneStepCell(Cell):
self.network.add_flags(defer_inline=True) self.network.add_flags(defer_inline=True)
self.weights = optimizer.parameters self.weights = optimizer.parameters
self.optimizer = optimizer self.optimizer = optimizer
self.grad = C.GradOperation('grad', get_by_list=True, sens_param=True) self.grad = C.GradOperation(get_by_list=True, sens_param=True)
self.sens = sens self.sens = sens
self.reducer_flag = False self.reducer_flag = False
self.grad_reducer = None self.grad_reducer = None
......
...@@ -269,7 +269,7 @@ class DistributedGradReducer(Cell): ...@@ -269,7 +269,7 @@ class DistributedGradReducer(Cell):
>>> self.network.add_flags(defer_inline=True) >>> self.network.add_flags(defer_inline=True)
>>> self.weights = optimizer.parameters >>> self.weights = optimizer.parameters
>>> self.optimizer = optimizer >>> self.optimizer = optimizer
>>> self.grad = C.GradOperation('grad', get_by_list=True, sens_param=True) >>> self.grad = C.GradOperation(get_by_list=True, sens_param=True)
>>> self.sens = sens >>> self.sens = sens
>>> self.reducer_flag = False >>> self.reducer_flag = False
>>> self.grad_reducer = None >>> self.grad_reducer = None
......
...@@ -210,7 +210,7 @@ class TrainOneStepWithLossScaleCell(Cell): ...@@ -210,7 +210,7 @@ class TrainOneStepWithLossScaleCell(Cell):
self.network.add_flags(defer_inline=True) self.network.add_flags(defer_inline=True)
self.weights = optimizer.parameters self.weights = optimizer.parameters
self.optimizer = optimizer self.optimizer = optimizer
self.grad = C.GradOperation('grad', get_by_list=True, sens_param=True) self.grad = C.GradOperation(get_by_list=True, sens_param=True)
self.hyper_map = C.HyperMap() self.hyper_map = C.HyperMap()
if context.get_context("device_target") == "GPU": if context.get_context("device_target") == "GPU":
self.gpu_target = True self.gpu_target = True
......
...@@ -106,12 +106,11 @@ class GradOperation(GradOperation_): ...@@ -106,12 +106,11 @@ class GradOperation(GradOperation_):
a 'ones_like(outputs)' sensitivity will be attached automatically. Default: False. a 'ones_like(outputs)' sensitivity will be attached automatically. Default: False.
""" """
def __init__(self, name, def __init__(self, get_all=False, get_by_list=False, sens_param=False):
get_all=False, get_by_list=False, sens_param=False):
self.get_all = get_all self.get_all = get_all
self.get_by_list = get_by_list self.get_by_list = get_by_list
self.sens_param = sens_param self.sens_param = sens_param
GradOperation_.__init__(self, name, get_all, get_by_list, sens_param) GradOperation_.__init__(self, 'grad', get_all, get_by_list, sens_param)
self.grad_fn = None self.grad_fn = None
self.fn = None self.fn = None
self.need_forward = False self.need_forward = False
...@@ -139,7 +138,7 @@ class GradOperation(GradOperation_): ...@@ -139,7 +138,7 @@ class GradOperation(GradOperation_):
fn.already_run = False fn.already_run = False
def __call__(self, fn, weights=None): def __call__(self, fn, weights=None):
grad_ = GradOperation('grad', self.get_all, self.get_by_list, self.sens_param) grad_ = GradOperation(self.get_all, self.get_by_list, self.sens_param)
if self.grad_fn is None or self.fn != fn: if self.grad_fn is None or self.fn != fn:
if context.get_context("mode") == context.GRAPH_MODE: if context.get_context("mode") == context.GRAPH_MODE:
if self.get_by_list: if self.get_by_list:
......
...@@ -216,7 +216,7 @@ class InsertGradientOf(PrimitiveWithInfer): ...@@ -216,7 +216,7 @@ class InsertGradientOf(PrimitiveWithInfer):
>>> return ret >>> return ret
>>> >>>
>>> clip = P.InsertGradientOf(clip_gradient) >>> clip = P.InsertGradientOf(clip_gradient)
>>> grad_all = C.GradOperation('get_all', get_all=True) >>> grad_all = C.GradOperation(get_all=True)
>>> def InsertGradientOfClipDemo(): >>> def InsertGradientOfClipDemo():
>>> def clip_test(x, y): >>> def clip_test(x, y):
>>> x = clip(x) >>> x = clip(x)
...@@ -268,7 +268,7 @@ class HookBackward(PrimitiveWithInfer): ...@@ -268,7 +268,7 @@ class HookBackward(PrimitiveWithInfer):
>>> def hook_fn(grad_out): >>> def hook_fn(grad_out):
>>> print(grad_out) >>> print(grad_out)
>>> >>>
>>> grad_all = GradOperation('get_all', get_all=True) >>> grad_all = GradOperation(get_all=True)
>>> hook = P.HookBackward(hook_fn) >>> hook = P.HookBackward(hook_fn)
>>> >>>
>>> def hook_test(x, y): >>> def hook_test(x, y):
......
...@@ -163,8 +163,7 @@ class TrainOneStepCell(nn.Cell): ...@@ -163,8 +163,7 @@ class TrainOneStepCell(nn.Cell):
self.backbone = network_backbone self.backbone = network_backbone
self.weights = ParameterTuple(network.trainable_params()) self.weights = ParameterTuple(network.trainable_params())
self.optimizer = optimizer self.optimizer = optimizer
self.grad = C.GradOperation('grad', self.grad = C.GradOperation(get_by_list=True,
get_by_list=True,
sens_param=True) sens_param=True)
self.sens = Tensor((np.ones((1,)) * sens).astype(np.float16)) self.sens = Tensor((np.ones((1,)) * sens).astype(np.float16))
self.reduce_flag = reduce_flag self.reduce_flag = reduce_flag
......
...@@ -171,8 +171,7 @@ class TrainOneStepCell(nn.Cell): ...@@ -171,8 +171,7 @@ class TrainOneStepCell(nn.Cell):
self.backbone = network_backbone self.backbone = network_backbone
self.weights = ParameterTuple(network.trainable_params()) self.weights = ParameterTuple(network.trainable_params())
self.optimizer = optimizer self.optimizer = optimizer
self.grad = C.GradOperation('grad', self.grad = C.GradOperation(get_by_list=True,
get_by_list=True,
sens_param=True) sens_param=True)
self.sens = Tensor((np.ones((1,)) * sens).astype(np.float16)) self.sens = Tensor((np.ones((1,)) * sens).astype(np.float16))
self.reduce_flag = reduce_flag self.reduce_flag = reduce_flag
......
...@@ -119,7 +119,7 @@ class DistributedGradReducerThor(Cell): ...@@ -119,7 +119,7 @@ class DistributedGradReducerThor(Cell):
>>> self.network.add_flags(defer_inline=True) >>> self.network.add_flags(defer_inline=True)
>>> self.weights = ParameterTuple(network.trainable_params()) >>> self.weights = ParameterTuple(network.trainable_params())
>>> self.optimizer = optimizer >>> self.optimizer = optimizer
>>> self.grad = C.GradOperation('grad', get_by_list=True, sens_param=True) >>> self.grad = C.GradOperation(get_by_list=True, sens_param=True)
>>> self.sens = sens >>> self.sens = sens
>>> self.reducer_flag = False >>> self.reducer_flag = False
>>> self.grad_reducer = None >>> self.grad_reducer = None
......
...@@ -383,7 +383,7 @@ class TrainingWrapper(nn.Cell): ...@@ -383,7 +383,7 @@ class TrainingWrapper(nn.Cell):
self.network = network self.network = network
self.weights = ms.ParameterTuple(network.trainable_params()) self.weights = ms.ParameterTuple(network.trainable_params())
self.optimizer = optimizer self.optimizer = optimizer
self.grad = C.GradOperation('grad', get_by_list=True, sens_param=True) self.grad = C.GradOperation(get_by_list=True, sens_param=True)
self.sens = sens self.sens = sens
self.reducer_flag = False self.reducer_flag = False
self.grad_reducer = None self.grad_reducer = None
......
...@@ -77,7 +77,7 @@ class TrainOneStepCellWithGradClip(Cell): ...@@ -77,7 +77,7 @@ class TrainOneStepCellWithGradClip(Cell):
self.network.add_flags(defer_inline=True) self.network.add_flags(defer_inline=True)
self.weights = optimizer.parameters self.weights = optimizer.parameters
self.optimizer = optimizer self.optimizer = optimizer
self.grad = C.GradOperation('grad', get_by_list=True, sens_param=True) self.grad = C.GradOperation(get_by_list=True, sens_param=True)
self.sens = sens self.sens = sens
self.reducer_flag = False self.reducer_flag = False
self.grad_reducer = None self.grad_reducer = None
......
...@@ -412,7 +412,7 @@ class TrainingWrapper(nn.Cell): ...@@ -412,7 +412,7 @@ class TrainingWrapper(nn.Cell):
self.network = network self.network = network
self.weights = optimizer.parameters self.weights = optimizer.parameters
self.optimizer = optimizer self.optimizer = optimizer
self.grad = C.GradOperation('grad', get_by_list=True, sens_param=True) self.grad = C.GradOperation(get_by_list=True, sens_param=True)
self.sens = sens self.sens = sens
self.reducer_flag = False self.reducer_flag = False
self.grad_reducer = None self.grad_reducer = None
......
...@@ -412,7 +412,7 @@ class TrainingWrapper(nn.Cell): ...@@ -412,7 +412,7 @@ class TrainingWrapper(nn.Cell):
self.network = network self.network = network
self.weights = optimizer.parameters self.weights = optimizer.parameters
self.optimizer = optimizer self.optimizer = optimizer
self.grad = C.GradOperation('grad', get_by_list=True, sens_param=True) self.grad = C.GradOperation(get_by_list=True, sens_param=True)
self.sens = sens self.sens = sens
self.reducer_flag = False self.reducer_flag = False
self.grad_reducer = None self.grad_reducer = None
......
...@@ -647,7 +647,7 @@ class TrainingWrapper(nn.Cell): ...@@ -647,7 +647,7 @@ class TrainingWrapper(nn.Cell):
self.network = network self.network = network
self.weights = ms.ParameterTuple(network.trainable_params()) self.weights = ms.ParameterTuple(network.trainable_params())
self.optimizer = optimizer self.optimizer = optimizer
self.grad = C.GradOperation('grad', get_by_list=True, sens_param=True) self.grad = C.GradOperation(get_by_list=True, sens_param=True)
self.sens = sens self.sens = sens
self.reducer_flag = False self.reducer_flag = False
self.grad_reducer = None self.grad_reducer = None
......
...@@ -141,7 +141,7 @@ class TrainOneStepCell(nn.Cell): ...@@ -141,7 +141,7 @@ class TrainOneStepCell(nn.Cell):
self.network.add_flags(defer_inline=True) self.network.add_flags(defer_inline=True)
self.weights = ParameterTuple(network.trainable_params()) self.weights = ParameterTuple(network.trainable_params())
self.optimizer = optimizer self.optimizer = optimizer
self.grad = C.GradOperation('grad', get_by_list=True, sens_param=True) self.grad = C.GradOperation(get_by_list=True, sens_param=True)
self.sens = sens self.sens = sens
def construct(self): def construct(self):
......
...@@ -150,7 +150,7 @@ class TrainOneStepCell(nn.Cell): ...@@ -150,7 +150,7 @@ class TrainOneStepCell(nn.Cell):
self.network.add_flags(defer_inline=True) self.network.add_flags(defer_inline=True)
self.weights = ParameterTuple(network.trainable_params()) self.weights = ParameterTuple(network.trainable_params())
self.optimizer = optimizer self.optimizer = optimizer
self.grad = C.GradOperation('grad', get_by_list=True, sens_param=True) self.grad = C.GradOperation(get_by_list=True, sens_param=True)
self.sens = sens self.sens = sens
def construct(self): def construct(self):
......
...@@ -57,8 +57,7 @@ class BertFinetuneCell(nn.Cell): ...@@ -57,8 +57,7 @@ class BertFinetuneCell(nn.Cell):
self.network = network self.network = network
self.weights = optimizer.parameters self.weights = optimizer.parameters
self.optimizer = optimizer self.optimizer = optimizer
self.grad = C.GradOperation('grad', self.grad = C.GradOperation(get_by_list=True,
get_by_list=True,
sens_param=True) sens_param=True)
self.reducer_flag = False self.reducer_flag = False
self.allreduce = P.AllReduce() self.allreduce = P.AllReduce()
...@@ -160,7 +159,7 @@ class BertSquadCell(nn.Cell): ...@@ -160,7 +159,7 @@ class BertSquadCell(nn.Cell):
self.network = network self.network = network
self.weights = optimizer.parameters self.weights = optimizer.parameters
self.optimizer = optimizer self.optimizer = optimizer
self.grad = C.GradOperation('grad', get_by_list=True, sens_param=True) self.grad = C.GradOperation(get_by_list=True, sens_param=True)
self.reducer_flag = False self.reducer_flag = False
self.allreduce = P.AllReduce() self.allreduce = P.AllReduce()
self.parallel_mode = context.get_auto_parallel_context("parallel_mode") self.parallel_mode = context.get_auto_parallel_context("parallel_mode")
......
...@@ -274,7 +274,7 @@ class BertTrainOneStepCell(nn.Cell): ...@@ -274,7 +274,7 @@ class BertTrainOneStepCell(nn.Cell):
self.network = network self.network = network
self.weights = optimizer.parameters self.weights = optimizer.parameters
self.optimizer = optimizer self.optimizer = optimizer
self.grad = C.GradOperation('grad', get_by_list=True, sens_param=True) self.grad = C.GradOperation(get_by_list=True, sens_param=True)
self.sens = sens self.sens = sens
self.reducer_flag = False self.reducer_flag = False
self.parallel_mode = context.get_auto_parallel_context("parallel_mode") self.parallel_mode = context.get_auto_parallel_context("parallel_mode")
...@@ -353,8 +353,7 @@ class BertTrainOneStepWithLossScaleCell(nn.Cell): ...@@ -353,8 +353,7 @@ class BertTrainOneStepWithLossScaleCell(nn.Cell):
self.network = network self.network = network
self.weights = optimizer.parameters self.weights = optimizer.parameters
self.optimizer = optimizer self.optimizer = optimizer
self.grad = C.GradOperation('grad', self.grad = C.GradOperation(get_by_list=True,
get_by_list=True,
sens_param=True) sens_param=True)
self.reducer_flag = False self.reducer_flag = False
self.allreduce = P.AllReduce() self.allreduce = P.AllReduce()
......
...@@ -293,7 +293,7 @@ class BertTrainOneStepCell(nn.Cell): ...@@ -293,7 +293,7 @@ class BertTrainOneStepCell(nn.Cell):
self.network = network self.network = network
self.weights = optimizer.parameters self.weights = optimizer.parameters
self.optimizer = optimizer self.optimizer = optimizer
self.grad = C.GradOperation('grad', get_by_list=True, sens_param=True) self.grad = C.GradOperation(get_by_list=True, sens_param=True)
self.sens = sens self.sens = sens
self.reducer_flag = False self.reducer_flag = False
self.parallel_mode = context.get_auto_parallel_context("parallel_mode") self.parallel_mode = context.get_auto_parallel_context("parallel_mode")
...@@ -373,8 +373,7 @@ class BertTrainOneStepWithLossScaleCell(nn.Cell): ...@@ -373,8 +373,7 @@ class BertTrainOneStepWithLossScaleCell(nn.Cell):
self.network = network self.network = network
self.weights = optimizer.parameters self.weights = optimizer.parameters
self.optimizer = optimizer self.optimizer = optimizer
self.grad = C.GradOperation('grad', self.grad = C.GradOperation(get_by_list=True,
get_by_list=True,
sens_param=True) sens_param=True)
self.reducer_flag = False self.reducer_flag = False
self.allreduce = P.AllReduce() self.allreduce = P.AllReduce()
......
...@@ -119,7 +119,7 @@ class DistributedGradReducerThor(Cell): ...@@ -119,7 +119,7 @@ class DistributedGradReducerThor(Cell):
>>> self.network.add_flags(defer_inline=True) >>> self.network.add_flags(defer_inline=True)
>>> self.weights = ParameterTuple(network.trainable_params()) >>> self.weights = ParameterTuple(network.trainable_params())
>>> self.optimizer = optimizer >>> self.optimizer = optimizer
>>> self.grad = C.GradOperation('grad', get_by_list=True, sens_param=True) >>> self.grad = C.GradOperation(get_by_list=True, sens_param=True)
>>> self.sens = sens >>> self.sens = sens
>>> self.reducer_flag = False >>> self.reducer_flag = False
>>> self.grad_reducer = None >>> self.grad_reducer = None
......
...@@ -239,7 +239,7 @@ class TransformerTrainOneStepWithLossScaleCell(nn.Cell): ...@@ -239,7 +239,7 @@ class TransformerTrainOneStepWithLossScaleCell(nn.Cell):
self.network.add_flags(defer_inline=True) self.network.add_flags(defer_inline=True)
self.weights = optimizer.parameters self.weights = optimizer.parameters
self.optimizer = optimizer self.optimizer = optimizer
self.grad = C.GradOperation('grad', get_by_list=True, self.grad = C.GradOperation(get_by_list=True,
sens_param=True) sens_param=True)
self.reducer_flag = False self.reducer_flag = False
self.all_reduce = P.AllReduce() self.all_reduce = P.AllReduce()
......
...@@ -218,8 +218,7 @@ class BertTrainWithLossScaleCell(nn.Cell): ...@@ -218,8 +218,7 @@ class BertTrainWithLossScaleCell(nn.Cell):
self.network = network self.network = network
self.weights = optimizer.parameters self.weights = optimizer.parameters
self.optimizer = optimizer self.optimizer = optimizer
self.grad = C.GradOperation('grad', self.grad = C.GradOperation(get_by_list=True,
get_by_list=True,
sens_param=True) sens_param=True)
self.reducer_flag = False self.reducer_flag = False
self.allreduce = P.AllReduce() self.allreduce = P.AllReduce()
...@@ -310,8 +309,7 @@ class BertTrainCell(nn.Cell): ...@@ -310,8 +309,7 @@ class BertTrainCell(nn.Cell):
self.weights = optimizer.parameters self.weights = optimizer.parameters
self.optimizer = optimizer self.optimizer = optimizer
self.sens = sens self.sens = sens
self.grad = C.GradOperation('grad', self.grad = C.GradOperation(get_by_list=True,
get_by_list=True,
sens_param=True) sens_param=True)
self.reducer_flag = False self.reducer_flag = False
self.parallel_mode = context.get_auto_parallel_context("parallel_mode") self.parallel_mode = context.get_auto_parallel_context("parallel_mode")
...@@ -474,8 +472,7 @@ class BertEvaluationWithLossScaleCell(nn.Cell): ...@@ -474,8 +472,7 @@ class BertEvaluationWithLossScaleCell(nn.Cell):
self.network = network self.network = network
self.weights = optimizer.parameters self.weights = optimizer.parameters
self.optimizer = optimizer self.optimizer = optimizer
self.grad = C.GradOperation('grad', self.grad = C.GradOperation(get_by_list=True,
get_by_list=True,
sens_param=True) sens_param=True)
self.reducer_flag = False self.reducer_flag = False
self.allreduce = P.AllReduce() self.allreduce = P.AllReduce()
...@@ -562,8 +559,7 @@ class BertEvaluationCell(nn.Cell): ...@@ -562,8 +559,7 @@ class BertEvaluationCell(nn.Cell):
self.weights = optimizer.parameters self.weights = optimizer.parameters
self.optimizer = optimizer self.optimizer = optimizer
self.sens = sens self.sens = sens
self.grad = C.GradOperation('grad', self.grad = C.GradOperation(get_by_list=True,
get_by_list=True,
sens_param=True) sens_param=True)
self.reducer_flag = False self.reducer_flag = False
self.parallel_mode = context.get_auto_parallel_context("parallel_mode") self.parallel_mode = context.get_auto_parallel_context("parallel_mode")
......
...@@ -158,7 +158,7 @@ class TransformerTrainOneStepCell(nn.Cell): ...@@ -158,7 +158,7 @@ class TransformerTrainOneStepCell(nn.Cell):
self.network = network self.network = network
self.weights = ParameterTuple(network.trainable_params()) self.weights = ParameterTuple(network.trainable_params())
self.optimizer = optimizer self.optimizer = optimizer
self.grad = C.GradOperation('grad', get_by_list=True, sens_param=True) self.grad = C.GradOperation(get_by_list=True, sens_param=True)
self.sens = sens self.sens = sens
self.reducer_flag = False self.reducer_flag = False
self.parallel_mode = context.get_auto_parallel_context("parallel_mode") self.parallel_mode = context.get_auto_parallel_context("parallel_mode")
...@@ -244,8 +244,7 @@ class TransformerTrainOneStepWithLossScaleCell(nn.Cell): ...@@ -244,8 +244,7 @@ class TransformerTrainOneStepWithLossScaleCell(nn.Cell):
self.network.add_flags(defer_inline=True) self.network.add_flags(defer_inline=True)
self.weights = ParameterTuple(network.trainable_params()) self.weights = ParameterTuple(network.trainable_params())
self.optimizer = optimizer self.optimizer = optimizer
self.grad = C.GradOperation('grad', self.grad = C.GradOperation(get_by_list=True,
get_by_list=True,
sens_param=True) sens_param=True)
self.reducer_flag = False self.reducer_flag = False
self.allreduce = P.AllReduce() self.allreduce = P.AllReduce()
......
...@@ -286,7 +286,7 @@ class TrainStepWrap(nn.Cell): ...@@ -286,7 +286,7 @@ class TrainStepWrap(nn.Cell):
self.weights = ParameterTuple(network.trainable_params()) self.weights = ParameterTuple(network.trainable_params())
self.optimizer = Adam(self.weights, learning_rate=lr, eps=eps, loss_scale=loss_scale) self.optimizer = Adam(self.weights, learning_rate=lr, eps=eps, loss_scale=loss_scale)
self.hyper_map = C.HyperMap() self.hyper_map = C.HyperMap()
self.grad = C.GradOperation('grad', get_by_list=True, sens_param=True) self.grad = C.GradOperation(get_by_list=True, sens_param=True)
self.sens = loss_scale self.sens = loss_scale
def construct(self, batch_ids, batch_wts, label): def construct(self, batch_ids, batch_wts, label):
......
...@@ -337,9 +337,9 @@ class TrainStepWrap(nn.Cell): ...@@ -337,9 +337,9 @@ class TrainStepWrap(nn.Cell):
self.optimizer_w = FTRL(learning_rate=5e-2, params=self.weights_w, self.optimizer_w = FTRL(learning_rate=5e-2, params=self.weights_w,
l1=1e-8, l2=1e-8, initial_accum=1.0, loss_scale=sens) l1=1e-8, l2=1e-8, initial_accum=1.0, loss_scale=sens)
self.hyper_map = C.HyperMap() self.hyper_map = C.HyperMap()
self.grad_w = C.GradOperation('grad_w', get_by_list=True, self.grad_w = C.GradOperation(get_by_list=True,
sens_param=True) sens_param=True)
self.grad_d = C.GradOperation('grad_d', get_by_list=True, self.grad_d = C.GradOperation(get_by_list=True,
sens_param=True) sens_param=True)
self.sens = sens self.sens = sens
self.loss_net_w = IthOutputCell(network, output_index=0) self.loss_net_w = IthOutputCell(network, output_index=0)
......
...@@ -537,11 +537,9 @@ class TrainStepWrap(nn.Cell): ...@@ -537,11 +537,9 @@ class TrainStepWrap(nn.Cell):
self.hyper_map = C.HyperMap() self.hyper_map = C.HyperMap()
self.grad_w = C.GradOperation('grad_w', self.grad_w = C.GradOperation(get_by_list=True,
get_by_list=True,
sens_param=True) sens_param=True)
self.grad_d = C.GradOperation('grad_d', self.grad_d = C.GradOperation(get_by_list=True,
get_by_list=True,
sens_param=True) sens_param=True)
self.sens = sens self.sens = sens
......
...@@ -46,5 +46,5 @@ class CompileBackwardBlockWrtInputsBC(IBuilderComponent): ...@@ -46,5 +46,5 @@ class CompileBackwardBlockWrtInputsBC(IBuilderComponent):
""" """
def __call__(self): def __call__(self):
grad_op = GradOperation('grad', get_all=True, sens_param=True) grad_op = GradOperation(get_all=True, sens_param=True)
return create_funcs(self.verification_set, gen_grad_net, compile_block, grad_op) return create_funcs(self.verification_set, gen_grad_net, compile_block, grad_op)
...@@ -46,5 +46,5 @@ class CompileBackwardBlockWrtParamsBC(IBuilderComponent): ...@@ -46,5 +46,5 @@ class CompileBackwardBlockWrtParamsBC(IBuilderComponent):
""" """
def __call__(self, verification_set): def __call__(self, verification_set):
grad_op = GradOperation('grad', get_by_list=True, sens_param=True) grad_op = GradOperation(get_by_list=True, sens_param=True)
return create_funcs(self.verification_set, gen_grad_net, compile_block, grad_op) return create_funcs(self.verification_set, gen_grad_net, compile_block, grad_op)
...@@ -22,5 +22,5 @@ from ...utils.block_util import run_block, gen_grad_net, create_funcs, get_unifo ...@@ -22,5 +22,5 @@ from ...utils.block_util import run_block, gen_grad_net, create_funcs, get_unifo
class RunBackwardBlockWrtInputsWithRandParamBC(IBuilderComponent): class RunBackwardBlockWrtInputsWithRandParamBC(IBuilderComponent):
def __call__(self): def __call__(self):
grad_op = GradOperation('grad', get_all=True, sens_param=True) grad_op = GradOperation(get_all=True, sens_param=True)
return create_funcs(self.verification_set, gen_grad_net, run_block, grad_op, get_uniform_with_shape) return create_funcs(self.verification_set, gen_grad_net, run_block, grad_op, get_uniform_with_shape)
...@@ -22,5 +22,5 @@ from ...utils.block_util import run_block, gen_grad_net, create_funcs, get_unifo ...@@ -22,5 +22,5 @@ from ...utils.block_util import run_block, gen_grad_net, create_funcs, get_unifo
class RunBackwardBlockWrtParamsWithRandParamBC(IBuilderComponent): class RunBackwardBlockWrtParamsWithRandParamBC(IBuilderComponent):
def __call__(self): def __call__(self):
grad_op = GradOperation('grad', get_by_list=True, sens_param=True) grad_op = GradOperation(get_by_list=True, sens_param=True)
return create_funcs(self.verification_set, gen_grad_net, run_block, grad_op, get_uniform_with_shape) return create_funcs(self.verification_set, gen_grad_net, run_block, grad_op, get_uniform_with_shape)
...@@ -22,5 +22,5 @@ from ...utils.block_util import run_block, gen_grad_net, create_funcs ...@@ -22,5 +22,5 @@ from ...utils.block_util import run_block, gen_grad_net, create_funcs
class RunBackwardBlockWrtInputsBC(IBuilderComponent): class RunBackwardBlockWrtInputsBC(IBuilderComponent):
def __call__(self): def __call__(self):
grad_op = GradOperation('grad', get_all=True, sens_param=True) grad_op = GradOperation(get_all=True, sens_param=True)
return create_funcs(self.verification_set, gen_grad_net, run_block, grad_op) return create_funcs(self.verification_set, gen_grad_net, run_block, grad_op)
...@@ -22,5 +22,5 @@ from ...utils.block_util import run_block, gen_grad_net, create_funcs ...@@ -22,5 +22,5 @@ from ...utils.block_util import run_block, gen_grad_net, create_funcs
class RunBackwardBlockWrtParamsBC(IBuilderComponent): class RunBackwardBlockWrtParamsBC(IBuilderComponent):
def __call__(self): def __call__(self):
grad_op = GradOperation('grad', get_by_list=True, sens_param=True) grad_op = GradOperation(get_by_list=True, sens_param=True)
return create_funcs(self.verification_set, gen_grad_net, run_block, grad_op) return create_funcs(self.verification_set, gen_grad_net, run_block, grad_op)
...@@ -331,7 +331,7 @@ def create_funcs(verification_set, block_generator, block_runner, grad_op=None, ...@@ -331,7 +331,7 @@ def create_funcs(verification_set, block_generator, block_runner, grad_op=None,
# gradient # gradient
if grad_op: if grad_op:
if num_outputs == 0: if num_outputs == 0:
grad_op_ = GradOperation('grad', get_all=grad_op.get_all, grad_op_ = GradOperation(get_all=grad_op.get_all,
get_by_list=grad_op.get_by_list, sens_param=False) get_by_list=grad_op.get_by_list, sens_param=False)
b = block_generator(block, grad_op_, len(inputs), desc_const=desc_const, b = block_generator(block, grad_op_, len(inputs), desc_const=desc_const,
const_first=const_first, add_fake_input=add_fake_input) const_first=const_first, add_fake_input=add_fake_input)
......
...@@ -85,7 +85,7 @@ def bprop(func, *inputs, grads_wrt_outputs=None, wrt: list = None, params: list ...@@ -85,7 +85,7 @@ def bprop(func, *inputs, grads_wrt_outputs=None, wrt: list = None, params: list
if not params: if not params:
params = func.trainable_params() params = func.trainable_params()
grad_op = GradOperation(name='grad', get_all=wrt_inputs, get_by_list=wrt_params, sens_param=with_sens_param) grad_op = GradOperation(get_all=wrt_inputs, get_by_list=wrt_params, sens_param=with_sens_param)
grad = Bprop(func, wrt_params, params, grad_op, grads_wrt_outputs) grad = Bprop(func, wrt_params, params, grad_op, grads_wrt_outputs)
if context.get_context("mode") == context.PYNATIVE_MODE: if context.get_context("mode") == context.PYNATIVE_MODE:
......
...@@ -315,7 +315,7 @@ class ScalarGradChecker(_GradChecker): ...@@ -315,7 +315,7 @@ class ScalarGradChecker(_GradChecker):
output_selector=None, output_selector=None,
sampling_times=-1, sampling_times=-1,
reduce_output=False) -> None: reduce_output=False) -> None:
grad_op = GradOperation('grad', get_all=True, sens_param=True) grad_op = GradOperation(get_all=True, sens_param=True)
super(ScalarGradChecker, self).__init__(fn, grad_op, args, delta, max_error, input_selector, \ super(ScalarGradChecker, self).__init__(fn, grad_op, args, delta, max_error, input_selector, \
output_selector, sampling_times, reduce_output) output_selector, sampling_times, reduce_output)
...@@ -358,7 +358,7 @@ class OperationGradChecker(_GradChecker): ...@@ -358,7 +358,7 @@ class OperationGradChecker(_GradChecker):
output_selector=None, output_selector=None,
sampling_times=-1, sampling_times=-1,
reduce_output=False) -> None: reduce_output=False) -> None:
grad_op = GradOperation('grad', get_all=True, sens_param=True) grad_op = GradOperation(get_all=True, sens_param=True)
super(OperationGradChecker, self).__init__(fn, grad_op, args, delta, max_error, input_selector, \ super(OperationGradChecker, self).__init__(fn, grad_op, args, delta, max_error, input_selector, \
output_selector, sampling_times, reduce_output) output_selector, sampling_times, reduce_output)
...@@ -390,7 +390,7 @@ class NNGradChecker(_GradChecker): ...@@ -390,7 +390,7 @@ class NNGradChecker(_GradChecker):
output_selector=None, output_selector=None,
sampling_times=-1, sampling_times=-1,
reduce_output=False) -> None: reduce_output=False) -> None:
grad_op = GradOperation('grad', get_by_list=True, sens_param=True) grad_op = GradOperation(get_by_list=True, sens_param=True)
self.params = ParameterTuple(fn.trainable_params()) self.params = ParameterTuple(fn.trainable_params())
super(NNGradChecker, self).__init__(fn, grad_op, args, delta, max_error, input_selector, \ super(NNGradChecker, self).__init__(fn, grad_op, args, delta, max_error, input_selector, \
output_selector, sampling_times, reduce_output) output_selector, sampling_times, reduce_output)
......
...@@ -23,7 +23,7 @@ from mindspore import Tensor ...@@ -23,7 +23,7 @@ from mindspore import Tensor
from mindspore.common.api import _executor from mindspore.common.api import _executor
grad_all_with_sens = C.GradOperation('grad_all_with_sens', get_all=True, sens_param=True) grad_all_with_sens = C.GradOperation(get_all=True, sens_param=True)
class InputBackward(nn.Cell): class InputBackward(nn.Cell):
......
...@@ -27,7 +27,7 @@ from mindspore.common.api import _executor ...@@ -27,7 +27,7 @@ from mindspore.common.api import _executor
context.set_context(mode=context.GRAPH_MODE) context.set_context(mode=context.GRAPH_MODE)
grad_all_with_sens = C.GradOperation('grad_all_with_sens', get_all=True, sens_param=True) grad_all_with_sens = C.GradOperation(get_all=True, sens_param=True)
batch_size = 1 batch_size = 1
channel = 1 channel = 1
......
...@@ -28,8 +28,8 @@ from mindspore.ops import operations as P ...@@ -28,8 +28,8 @@ from mindspore.ops import operations as P
# context.set_context(save_graphs=True) # context.set_context(save_graphs=True)
grad_by_list = C.GradOperation('get_by_list', get_by_list=True) grad_by_list = C.GradOperation(get_by_list=True)
grad_all = C.GradOperation('get_all', get_all=True) grad_all = C.GradOperation(get_all=True)
def test_while_forward(): def test_while_forward():
......
...@@ -25,7 +25,7 @@ from mindspore.common.api import _executor ...@@ -25,7 +25,7 @@ from mindspore.common.api import _executor
context.set_context(mode=context.GRAPH_MODE) context.set_context(mode=context.GRAPH_MODE)
grad_all_with_sens = C.GradOperation('grad_all_with_sens', get_all=True, sens_param=True) grad_all_with_sens = C.GradOperation(get_all=True, sens_param=True)
class MeanAggregatorGrad(nn.Cell): class MeanAggregatorGrad(nn.Cell):
......
...@@ -284,9 +284,9 @@ class TrainStepWrap(nn.Cell): ...@@ -284,9 +284,9 @@ class TrainStepWrap(nn.Cell):
self.optimizer_d = Adam( self.optimizer_d = Adam(
self.weights_d, learning_rate=3.5e-4, eps=1e-8, loss_scale=sens) self.weights_d, learning_rate=3.5e-4, eps=1e-8, loss_scale=sens)
self.hyper_map = C.HyperMap() self.hyper_map = C.HyperMap()
self.grad_w = C.GradOperation('grad_w', get_by_list=True, self.grad_w = C.GradOperation(get_by_list=True,
sens_param=True) sens_param=True)
self.grad_d = C.GradOperation('grad_d', get_by_list=True, self.grad_d = C.GradOperation(get_by_list=True,
sens_param=True) sens_param=True)
self.sens = sens self.sens = sens
self.loss_net_w = IthOutputCell(network, output_index=0) self.loss_net_w = IthOutputCell(network, output_index=0)
......
...@@ -647,7 +647,7 @@ class TrainingWrapper(nn.Cell): ...@@ -647,7 +647,7 @@ class TrainingWrapper(nn.Cell):
self.network = network self.network = network
self.weights = ms.ParameterTuple(network.trainable_params()) self.weights = ms.ParameterTuple(network.trainable_params())
self.optimizer = optimizer self.optimizer = optimizer
self.grad = C.GradOperation('grad', get_by_list=True, sens_param=True) self.grad = C.GradOperation(get_by_list=True, sens_param=True)
self.sens = sens self.sens = sens
self.reducer_flag = False self.reducer_flag = False
self.grad_reducer = None self.grad_reducer = None
......
...@@ -271,7 +271,7 @@ class BertTrainOneStepCell(nn.Cell): ...@@ -271,7 +271,7 @@ class BertTrainOneStepCell(nn.Cell):
self.network = network self.network = network
self.weights = ParameterTuple(network.trainable_params()) self.weights = ParameterTuple(network.trainable_params())
self.optimizer = optimizer self.optimizer = optimizer
self.grad = C.GradOperation('grad', get_by_list=True, sens_param=True) self.grad = C.GradOperation(get_by_list=True, sens_param=True)
self.sens = sens self.sens = sens
self.reducer_flag = False self.reducer_flag = False
self.parallel_mode = context.get_auto_parallel_context("parallel_mode") self.parallel_mode = context.get_auto_parallel_context("parallel_mode")
...@@ -351,8 +351,7 @@ class BertTrainOneStepWithLossScaleCell(nn.Cell): ...@@ -351,8 +351,7 @@ class BertTrainOneStepWithLossScaleCell(nn.Cell):
self.network = network self.network = network
self.weights = ParameterTuple(network.trainable_params()) self.weights = ParameterTuple(network.trainable_params())
self.optimizer = optimizer self.optimizer = optimizer
self.grad = C.GradOperation('grad', self.grad = C.GradOperation(get_by_list=True,
get_by_list=True,
sens_param=True) sens_param=True)
self.reducer_flag = False self.reducer_flag = False
self.allreduce = P.AllReduce() self.allreduce = P.AllReduce()
......
...@@ -52,8 +52,7 @@ class BertFinetuneCell(nn.Cell): ...@@ -52,8 +52,7 @@ class BertFinetuneCell(nn.Cell):
self.network = network self.network = network
self.weights = ParameterTuple(network.trainable_params()) self.weights = ParameterTuple(network.trainable_params())
self.optimizer = optimizer self.optimizer = optimizer
self.grad = C.GradOperation('grad', self.grad = C.GradOperation(get_by_list=True,
get_by_list=True,
sens_param=True) sens_param=True)
self.reducer_flag = False self.reducer_flag = False
self.allreduce = P.AllReduce() self.allreduce = P.AllReduce()
......
...@@ -120,7 +120,7 @@ class DistributedGradReducerThor(Cell): ...@@ -120,7 +120,7 @@ class DistributedGradReducerThor(Cell):
>>> self.network.add_flags(defer_inline=True) >>> self.network.add_flags(defer_inline=True)
>>> self.weights = ParameterTuple(network.trainable_params()) >>> self.weights = ParameterTuple(network.trainable_params())
>>> self.optimizer = optimizer >>> self.optimizer = optimizer
>>> self.grad = C.GradOperation('grad', get_by_list=True, sens_param=True) >>> self.grad = C.GradOperation(get_by_list=True, sens_param=True)
>>> self.sens = sens >>> self.sens = sens
>>> self.reducer_flag = False >>> self.reducer_flag = False
>>> self.grad_reducer = None >>> self.grad_reducer = None
......
...@@ -29,7 +29,7 @@ from mindspore.ops import operations as P ...@@ -29,7 +29,7 @@ from mindspore.ops import operations as P
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
grad_all = C.GradOperation('get_all', get_all=True) grad_all = C.GradOperation(get_all=True)
class MulAdd(nn.Cell): class MulAdd(nn.Cell):
...@@ -351,7 +351,7 @@ class MulAddWithParam(nn.Cell): ...@@ -351,7 +351,7 @@ class MulAddWithParam(nn.Cell):
@pytest.mark.platform_x86_ascend_training @pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard @pytest.mark.env_onecard
def test_refkey_bprop(): def test_refkey_bprop():
grad_by_list = C.GradOperation('get_by_list', get_all=True, get_by_list=True) grad_by_list = C.GradOperation(get_all=True, get_by_list=True)
class GradWrap(nn.Cell): class GradWrap(nn.Cell):
def __init__(self, network): def __init__(self, network):
super(GradWrap, self).__init__() super(GradWrap, self).__init__()
......
...@@ -49,7 +49,7 @@ def test_net(): ...@@ -49,7 +49,7 @@ def test_net():
def test_grad_addn_with_list(): def test_grad_addn_with_list():
grad_op = C.GradOperation('get_all', get_all=True) grad_op = C.GradOperation(get_all=True)
class AddN(nn.Cell): class AddN(nn.Cell):
def __init__(self): def __init__(self):
super().__init__() super().__init__()
......
...@@ -29,7 +29,7 @@ context.set_context(device_target="Ascend") ...@@ -29,7 +29,7 @@ context.set_context(device_target="Ascend")
class Grad(nn.Cell): class Grad(nn.Cell):
def __init__(self, network): def __init__(self, network):
super(Grad, self).__init__() super(Grad, self).__init__()
self.grad = GradOperation(name="get_all", get_all=True, sens_param=True) self.grad = GradOperation(get_all=True, sens_param=True)
self.network = network self.network = network
@ms_function @ms_function
......
...@@ -26,7 +26,7 @@ context.set_context(device_target="Ascend") ...@@ -26,7 +26,7 @@ context.set_context(device_target="Ascend")
class Grad(nn.Cell): class Grad(nn.Cell):
def __init__(self, network): def __init__(self, network):
super(Grad, self).__init__() super(Grad, self).__init__()
self.grad = GradOperation(name="get_all", get_all=True, sens_param=True) self.grad = GradOperation(get_all=True, sens_param=True)
self.network = network self.network = network
@ms_function @ms_function
......
...@@ -30,7 +30,7 @@ context.set_context(device_target="Ascend") ...@@ -30,7 +30,7 @@ context.set_context(device_target="Ascend")
class Grad(nn.Cell): class Grad(nn.Cell):
def __init__(self, network): def __init__(self, network):
super(Grad, self).__init__() super(Grad, self).__init__()
self.grad = GradOperation(name="get_all", get_all=True, sens_param=True) self.grad = GradOperation(get_all=True, sens_param=True)
self.network = network self.network = network
@ms_function @ms_function
......
...@@ -27,7 +27,7 @@ context.set_context(device_target="Ascend") ...@@ -27,7 +27,7 @@ context.set_context(device_target="Ascend")
class Grad(nn.Cell): class Grad(nn.Cell):
def __init__(self, network): def __init__(self, network):
super(Grad, self).__init__() super(Grad, self).__init__()
self.grad = GradOperation(name="get_all", get_all=True, sens_param=True) self.grad = GradOperation(get_all=True, sens_param=True)
self.network = network self.network = network
@ms_function @ms_function
......
...@@ -27,7 +27,7 @@ context.set_context(device_target="Ascend") ...@@ -27,7 +27,7 @@ context.set_context(device_target="Ascend")
class Grad(nn.Cell): class Grad(nn.Cell):
def __init__(self, network): def __init__(self, network):
super(Grad, self).__init__() super(Grad, self).__init__()
self.grad = GradOperation(name="get_all", get_all=True, sens_param=True) self.grad = GradOperation(get_all=True, sens_param=True)
self.network = network self.network = network
@ms_function @ms_function
......
...@@ -27,7 +27,7 @@ context.set_context(device_target="Ascend") ...@@ -27,7 +27,7 @@ context.set_context(device_target="Ascend")
class Grad(nn.Cell): class Grad(nn.Cell):
def __init__(self, network): def __init__(self, network):
super(Grad, self).__init__() super(Grad, self).__init__()
self.grad = GradOperation(name="get_all", get_all=True, sens_param=True) self.grad = GradOperation(get_all=True, sens_param=True)
self.network = network self.network = network
@ms_function @ms_function
......
...@@ -27,7 +27,7 @@ context.set_context(device_target="Ascend") ...@@ -27,7 +27,7 @@ context.set_context(device_target="Ascend")
class Grad(nn.Cell): class Grad(nn.Cell):
def __init__(self, network): def __init__(self, network):
super(Grad, self).__init__() super(Grad, self).__init__()
self.grad = GradOperation(name="get_all", get_all=True, sens_param=True) self.grad = GradOperation(get_all=True, sens_param=True)
self.network = network self.network = network
@ms_function @ms_function
......
...@@ -30,7 +30,7 @@ context.set_context(device_target="Ascend") ...@@ -30,7 +30,7 @@ context.set_context(device_target="Ascend")
class Grad(nn.Cell): class Grad(nn.Cell):
def __init__(self, network): def __init__(self, network):
super(Grad, self).__init__() super(Grad, self).__init__()
self.grad = GradOperation(name="get_all", get_all=True, sens_param=True) self.grad = GradOperation(get_all=True, sens_param=True)
self.network = network self.network = network
@ms_function @ms_function
......
...@@ -27,7 +27,7 @@ context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") ...@@ -27,7 +27,7 @@ context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
class Grad(Cell): class Grad(Cell):
def __init__(self, network): def __init__(self, network):
super(Grad, self).__init__() super(Grad, self).__init__()
self.grad = GradOperation(name="get_all", get_all=True, sens_param=True) self.grad = GradOperation(get_all=True, sens_param=True)
self.network = network self.network = network
def construct(self, input_, output_grad): def construct(self, input_, output_grad):
...@@ -71,7 +71,7 @@ class MEGeluLargeIn(Cell): ...@@ -71,7 +71,7 @@ class MEGeluLargeIn(Cell):
class GradLargeIn(Cell): class GradLargeIn(Cell):
def __init__(self, network): def __init__(self, network):
super(GradLargeIn, self).__init__() super(GradLargeIn, self).__init__()
self.grad = GradOperation(name="get_all", get_all=True, sens_param=True) self.grad = GradOperation(get_all=True, sens_param=True)
self.network = network self.network = network
def construct(self, x1, x2, output_grad): def construct(self, x1, x2, output_grad):
......
...@@ -27,7 +27,7 @@ context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") ...@@ -27,7 +27,7 @@ context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
class Grad(Cell): class Grad(Cell):
def __init__(self, network): def __init__(self, network):
super(Grad, self).__init__() super(Grad, self).__init__()
self.grad = GradOperation(name="get_all", get_all=True, sens_param=True) self.grad = GradOperation(get_all=True, sens_param=True)
self.network = network self.network = network
def construct(self, input_, output_grad,): def construct(self, input_, output_grad,):
......
...@@ -21,7 +21,7 @@ from mindspore.ops import composite as C ...@@ -21,7 +21,7 @@ from mindspore.ops import composite as C
from mindspore.ops import operations as P from mindspore.ops import operations as P
context.set_context(device_target="Ascend") context.set_context(device_target="Ascend")
grad = C.GradOperation('get_all', get_all=True, sens_param=True) grad = C.GradOperation(get_all=True, sens_param=True)
class MaxNetMe(Cell): class MaxNetMe(Cell):
......
...@@ -27,7 +27,7 @@ context.set_context(device_target="Ascend") ...@@ -27,7 +27,7 @@ context.set_context(device_target="Ascend")
class Grad(nn.Cell): class Grad(nn.Cell):
def __init__(self, network): def __init__(self, network):
super(Grad, self).__init__() super(Grad, self).__init__()
self.grad = GradOperation(name="get_all", get_all=True, sens_param=True) self.grad = GradOperation(get_all=True, sens_param=True)
self.network = network self.network = network
@ms_function @ms_function
......
...@@ -21,7 +21,7 @@ from mindspore.ops import composite as C ...@@ -21,7 +21,7 @@ from mindspore.ops import composite as C
from mindspore.ops.operations import Minimum from mindspore.ops.operations import Minimum
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
grad = C.GradOperation('get_all', get_all=True, sens_param=True) grad = C.GradOperation(get_all=True, sens_param=True)
class MinNetMe(Cell): class MinNetMe(Cell):
......
...@@ -27,7 +27,7 @@ context.set_context(device_target="Ascend") ...@@ -27,7 +27,7 @@ context.set_context(device_target="Ascend")
class Grad(nn.Cell): class Grad(nn.Cell):
def __init__(self, network): def __init__(self, network):
super(Grad, self).__init__() super(Grad, self).__init__()
self.grad = GradOperation(name="get_all", get_all=True, sens_param=True) self.grad = GradOperation(get_all=True, sens_param=True)
self.network = network self.network = network
@ms_function @ms_function
......
...@@ -27,7 +27,7 @@ context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") ...@@ -27,7 +27,7 @@ context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
class Grad(nn.Cell): class Grad(nn.Cell):
def __init__(self, network): def __init__(self, network):
super(Grad, self).__init__() super(Grad, self).__init__()
self.grad = GradOperation(name="get_all", get_all=True) self.grad = GradOperation(get_all=True)
self.network = network self.network = network
@ms_function @ms_function
......
...@@ -37,7 +37,7 @@ class Net(nn.Cell): ...@@ -37,7 +37,7 @@ class Net(nn.Cell):
class Grad(nn.Cell): class Grad(nn.Cell):
def __init__(self, network): def __init__(self, network):
super(Grad, self).__init__() super(Grad, self).__init__()
self.grad = GradOperation(name="get_all", get_all=True, sens_param=True) self.grad = GradOperation(get_all=True, sens_param=True)
self.network = network self.network = network
@ms_function @ms_function
......
...@@ -37,7 +37,7 @@ class Net(nn.Cell): ...@@ -37,7 +37,7 @@ class Net(nn.Cell):
class Grad(nn.Cell): class Grad(nn.Cell):
def __init__(self, network): def __init__(self, network):
super(Grad, self).__init__() super(Grad, self).__init__()
self.grad = GradOperation(name="get_all", get_all=True, sens_param=True) self.grad = GradOperation(get_all=True, sens_param=True)
self.network = network self.network = network
@ms_function @ms_function
......
...@@ -37,7 +37,7 @@ class Net(nn.Cell): ...@@ -37,7 +37,7 @@ class Net(nn.Cell):
class Grad(nn.Cell): class Grad(nn.Cell):
def __init__(self, network): def __init__(self, network):
super(Grad, self).__init__() super(Grad, self).__init__()
self.grad = GradOperation(name="get_all", get_all=True, sens_param=True) self.grad = GradOperation(get_all=True, sens_param=True)
self.network = network self.network = network
@ms_function @ms_function
......
...@@ -36,7 +36,7 @@ class Net(nn.Cell): ...@@ -36,7 +36,7 @@ class Net(nn.Cell):
class Grad(nn.Cell): class Grad(nn.Cell):
def __init__(self, network): def __init__(self, network):
super(Grad, self).__init__() super(Grad, self).__init__()
self.grad = GradOperation(name="get_all", get_all=True, sens_param=True) self.grad = GradOperation(get_all=True, sens_param=True)
self.network = network self.network = network
def construct(self, pred, gt, dout): def construct(self, pred, gt, dout):
......
...@@ -26,7 +26,7 @@ context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") ...@@ -26,7 +26,7 @@ context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
class Grad(Cell): class Grad(Cell):
def __init__(self, network): def __init__(self, network):
super(Grad, self).__init__() super(Grad, self).__init__()
self.grad = GradOperation(name="get_all", get_all=True, sens_param=True) self.grad = GradOperation(get_all=True, sens_param=True)
self.network = network self.network = network
def construct(self, input_, output_grad): def construct(self, input_, output_grad):
......
...@@ -37,7 +37,7 @@ class Batchnorm_Net(Cell): ...@@ -37,7 +37,7 @@ class Batchnorm_Net(Cell):
class Grad(Cell): class Grad(Cell):
def __init__(self, network): def __init__(self, network):
super(Grad, self).__init__() super(Grad, self).__init__()
self.grad = C.GradOperation(name="get_all", get_all=True, sens_param=True) self.grad = C.GradOperation(get_all=True, sens_param=True)
self.network = network self.network = network
def construct(self, input_data, sens): def construct(self, input_data, sens):
......
...@@ -207,8 +207,7 @@ class Grad(nn.Cell): ...@@ -207,8 +207,7 @@ class Grad(nn.Cell):
super(Grad, self).__init__() super(Grad, self).__init__()
self.network = network self.network = network
self.weights = ParameterTuple(network.trainable_params()) self.weights = ParameterTuple(network.trainable_params())
self.grad = C.GradOperation('grad', self.grad = C.GradOperation(get_by_list=True,
get_by_list=True,
sens_param=True) sens_param=True)
@ms_function @ms_function
......
...@@ -23,7 +23,7 @@ from mindspore.ops import composite as C ...@@ -23,7 +23,7 @@ from mindspore.ops import composite as C
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
grad_with_sens = C.GradOperation('grad_with_sens', sens_param=True) grad_with_sens = C.GradOperation(sens_param=True)
class Net(nn.Cell): class Net(nn.Cell):
......
...@@ -37,7 +37,7 @@ class Batchnorm_Net(Cell): ...@@ -37,7 +37,7 @@ class Batchnorm_Net(Cell):
class Grad(Cell): class Grad(Cell):
def __init__(self, network): def __init__(self, network):
super(Grad, self).__init__() super(Grad, self).__init__()
self.grad = C.GradOperation(name="get_all", get_all=True, sens_param=True) self.grad = C.GradOperation(get_all=True, sens_param=True)
self.network = network self.network = network
def construct(self, input_data, sens): def construct(self, input_data, sens):
......
...@@ -54,7 +54,7 @@ def test_binary_cross_entropy_loss(): ...@@ -54,7 +54,7 @@ def test_binary_cross_entropy_loss():
class Grad(nn.Cell): class Grad(nn.Cell):
def __init__(self, network): def __init__(self, network):
super(Grad, self).__init__() super(Grad, self).__init__()
self.grad = C.GradOperation(name="get_all", get_all=True, sens_param=True) self.grad = C.GradOperation(get_all=True, sens_param=True)
self.network = network self.network = network
def construct(self, x1, x2, sens, weight): def construct(self, x1, x2, sens, weight):
......
...@@ -40,7 +40,7 @@ class Net(nn.Cell): ...@@ -40,7 +40,7 @@ class Net(nn.Cell):
class GradData(nn.Cell): class GradData(nn.Cell):
def __init__(self, network): def __init__(self, network):
super(GradData, self).__init__() super(GradData, self).__init__()
self.grad = GradOperation(name="get_all", get_all=True, sens_param=False) self.grad = GradOperation(get_all=True, sens_param=False)
self.network = network self.network = network
def construct(self, probs, labels, input_lengths, label_lengths): def construct(self, probs, labels, input_lengths, label_lengths):
......
...@@ -65,7 +65,7 @@ def test_biasadd(): ...@@ -65,7 +65,7 @@ def test_biasadd():
class GradData(nn.Cell): class GradData(nn.Cell):
def __init__(self, network): def __init__(self, network):
super(GradData, self).__init__() super(GradData, self).__init__()
self.grad = GradOperation(name="get_all", get_all=True, sens_param=True) self.grad = GradOperation(get_all=True, sens_param=True)
self.network = network self.network = network
def construct(self, inputs, output_grad): def construct(self, inputs, output_grad):
...@@ -77,8 +77,7 @@ class GradWeight(nn.Cell): ...@@ -77,8 +77,7 @@ class GradWeight(nn.Cell):
super(GradWeight, self).__init__() super(GradWeight, self).__init__()
self.network = network self.network = network
self.weights = ParameterTuple(network.trainable_params()) self.weights = ParameterTuple(network.trainable_params())
self.grad = C.GradOperation('grad', self.grad = C.GradOperation(get_by_list=True,
get_by_list=True,
sens_param=True) sens_param=True)
def construct(self, x, output_grad): def construct(self, x, output_grad):
...@@ -169,7 +168,7 @@ def test_dw(): ...@@ -169,7 +168,7 @@ def test_dw():
class Grad(nn.Cell): class Grad(nn.Cell):
def __init__(self, network): def __init__(self, network):
super(Grad, self).__init__() super(Grad, self).__init__()
self.grad = GradOperation(name="get_all", get_all=True, sens_param=True) self.grad = GradOperation(get_all=True, sens_param=True)
self.network = network self.network = network
def construct(self, input_, bias, dy): def construct(self, input_, bias, dy):
......
...@@ -37,7 +37,7 @@ class GeluNet(nn.Cell): ...@@ -37,7 +37,7 @@ class GeluNet(nn.Cell):
class Grad(nn.Cell): class Grad(nn.Cell):
def __init__(self, network): def __init__(self, network):
super(Grad, self).__init__() super(Grad, self).__init__()
self.grad = C.GradOperation(name="get_all", get_all=True, sens_param=True) self.grad = C.GradOperation(get_all=True, sens_param=True)
self.network = network self.network = network
def construct(self, input_data, sens): def construct(self, input_data, sens):
......
...@@ -53,7 +53,7 @@ def test_binary_cross_entropy_loss(): ...@@ -53,7 +53,7 @@ def test_binary_cross_entropy_loss():
class Grad(nn.Cell): class Grad(nn.Cell):
def __init__(self, network): def __init__(self, network):
super(Grad, self).__init__() super(Grad, self).__init__()
self.grad = C.GradOperation(name="get_all", get_all=True, sens_param=True) self.grad = C.GradOperation(get_all=True, sens_param=True)
self.network = network self.network = network
def construct(self, x1, x2, sens): def construct(self, x1, x2, sens):
......
...@@ -52,7 +52,7 @@ class LogSoftmax(nn.Cell): ...@@ -52,7 +52,7 @@ class LogSoftmax(nn.Cell):
class Grad(nn.Cell): class Grad(nn.Cell):
def __init__(self, network): def __init__(self, network):
super(Grad, self).__init__() super(Grad, self).__init__()
self.grad = C.GradOperation(name="get_all", get_all=True, sens_param=True) self.grad = C.GradOperation(get_all=True, sens_param=True)
self.network = network self.network = network
def construct(self, input_data, sens): def construct(self, input_data, sens):
......
...@@ -581,8 +581,7 @@ class Grad(nn.Cell): ...@@ -581,8 +581,7 @@ class Grad(nn.Cell):
super(Grad, self).__init__() super(Grad, self).__init__()
self.network = network self.network = network
self.weights = ParameterTuple(network.trainable_params()) self.weights = ParameterTuple(network.trainable_params())
self.grad = C.GradOperation('grad', self.grad = C.GradOperation(get_by_list=True,
get_by_list=True,
sens_param=True) sens_param=True)
@ms_function @ms_function
......
...@@ -35,7 +35,7 @@ class Net(Cell): ...@@ -35,7 +35,7 @@ class Net(Cell):
class Grad(Cell): class Grad(Cell):
def __init__(self, network): def __init__(self, network):
super(Grad, self).__init__() super(Grad, self).__init__()
self.grad = C.GradOperation(name="get_all", get_all=True, sens_param=True) self.grad = C.GradOperation(get_all=True, sens_param=True)
self.network = network self.network = network
def construct(self, x1, x2, sens): def construct(self, x1, x2, sens):
......
...@@ -36,7 +36,7 @@ class MinimumNet(Cell): ...@@ -36,7 +36,7 @@ class MinimumNet(Cell):
class Grad(Cell): class Grad(Cell):
def __init__(self, network): def __init__(self, network):
super(Grad, self).__init__() super(Grad, self).__init__()
self.grad = C.GradOperation(name="get_all", get_all=True, sens_param=True) self.grad = C.GradOperation(get_all=True, sens_param=True)
self.network = network self.network = network
def construct(self, x1, x2, sens): def construct(self, x1, x2, sens):
......
...@@ -58,7 +58,7 @@ def test_mirror_pad(): ...@@ -58,7 +58,7 @@ def test_mirror_pad():
class Grad(nn.Cell): class Grad(nn.Cell):
def __init__(self, network): def __init__(self, network):
super(Grad, self).__init__() super(Grad, self).__init__()
self.grad = GradOperation(name="get_all", get_all=True, sens_param=True) self.grad = GradOperation(get_all=True, sens_param=True)
self.network = network self.network = network
def construct(self, input_, output_grad): def construct(self, input_, output_grad):
return self.grad(self.network)(input_, output_grad) return self.grad(self.network)(input_, output_grad)
......
...@@ -59,7 +59,7 @@ def test_smoothl1loss(): ...@@ -59,7 +59,7 @@ def test_smoothl1loss():
class Grad(nn.Cell): class Grad(nn.Cell):
def __init__(self, network): def __init__(self, network):
super(Grad, self).__init__() super(Grad, self).__init__()
self.grad = C.GradOperation(name="get_all", get_all=True, sens_param=True) self.grad = C.GradOperation(get_all=True, sens_param=True)
self.network = network self.network = network
def construct(self, x1, x2, sens): def construct(self, x1, x2, sens):
......
...@@ -79,7 +79,7 @@ class Net(nn.Cell): ...@@ -79,7 +79,7 @@ class Net(nn.Cell):
class Grad(nn.Cell): class Grad(nn.Cell):
def __init__(self, network): def __init__(self, network):
super(Grad, self).__init__() super(Grad, self).__init__()
self.grad = C.GradOperation(name="get_all", get_all=True, sens_param=True) self.grad = C.GradOperation(get_all=True, sens_param=True)
self.network = network self.network = network
def construct(self, input_data, sens): def construct(self, input_data, sens):
......
...@@ -36,7 +36,7 @@ class StridedSliceNet(nn.Cell): ...@@ -36,7 +36,7 @@ class StridedSliceNet(nn.Cell):
class GradData(nn.Cell): class GradData(nn.Cell):
def __init__(self, network): def __init__(self, network):
super(GradData, self).__init__() super(GradData, self).__init__()
self.grad = C.GradOperation(name="get_all", get_all=True, sens_param=False) self.grad = C.GradOperation(get_all=True, sens_param=False)
self.network = network self.network = network
def construct(self, x): def construct(self, x):
......
...@@ -37,7 +37,7 @@ class TanhNet(nn.Cell): ...@@ -37,7 +37,7 @@ class TanhNet(nn.Cell):
class Grad(nn.Cell): class Grad(nn.Cell):
def __init__(self, network): def __init__(self, network):
super(Grad, self).__init__() super(Grad, self).__init__()
self.grad = C.GradOperation(name="get_all", get_all=True, sens_param=True) self.grad = C.GradOperation(get_all=True, sens_param=True)
self.network = network self.network = network
def construct(self, input_data, sens): def construct(self, input_data, sens):
......
...@@ -30,7 +30,7 @@ from mindspore.common.initializer import TruncatedNormal ...@@ -30,7 +30,7 @@ from mindspore.common.initializer import TruncatedNormal
context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend") context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend")
grad_all = C.GradOperation('get_all', get_all=True) grad_all = C.GradOperation(get_all=True)
def weight_variable(): def weight_variable():
...@@ -112,7 +112,7 @@ class GradWrap(nn.Cell): ...@@ -112,7 +112,7 @@ class GradWrap(nn.Cell):
def construct(self, x, label): def construct(self, x, label):
weights = self.weights weights = self.weights
return C.GradOperation('get_by_list', get_by_list=True)(self.network, weights)(x, label) return C.GradOperation(get_by_list=True)(self.network, weights)(x, label)
class test_custom_cell_base(): class test_custom_cell_base():
......
...@@ -29,7 +29,7 @@ from mindspore.ops import operations as P ...@@ -29,7 +29,7 @@ from mindspore.ops import operations as P
np.random.seed(1) np.random.seed(1)
grad_by_list = C.GradOperation('get_by_list', get_by_list=True) grad_by_list = C.GradOperation(get_by_list=True)
def weight_variable(): def weight_variable():
......
...@@ -87,7 +87,7 @@ class LeNet(nn.Cell): ...@@ -87,7 +87,7 @@ class LeNet(nn.Cell):
class GradWithSens(Cell): class GradWithSens(Cell):
def __init__(self, network): def __init__(self, network):
super(GradWithSens, self).__init__() super(GradWithSens, self).__init__()
self.grad = GradOperation(name="grad", get_all=False, self.grad = GradOperation(get_all=False,
sens_param=True) sens_param=True)
self.network = network self.network = network
...@@ -99,8 +99,7 @@ class GradWithSens(Cell): ...@@ -99,8 +99,7 @@ class GradWithSens(Cell):
class GradWrapWithLoss(Cell): class GradWrapWithLoss(Cell):
def __init__(self, network): def __init__(self, network):
super(GradWrapWithLoss, self).__init__() super(GradWrapWithLoss, self).__init__()
self._grad_all = GradOperation(name="get_all", self._grad_all = GradOperation(get_all=True,
get_all=True,
sens_param=False) sens_param=False)
self._network = network self._network = network
......
...@@ -40,7 +40,7 @@ np.random.seed(1) ...@@ -40,7 +40,7 @@ np.random.seed(1)
ds.config.set_seed(1) ds.config.set_seed(1)
grad_by_list = CP.GradOperation('get_by_list', get_by_list=True) grad_by_list = CP.GradOperation(get_by_list=True)
def weight_variable(shape): def weight_variable(shape):
......
...@@ -24,7 +24,7 @@ from mindspore.common.parameter import ParameterTuple ...@@ -24,7 +24,7 @@ from mindspore.common.parameter import ParameterTuple
from mindspore.ops import composite as C from mindspore.ops import composite as C
grad_by_list_with_sens = C.GradOperation('grad_by_list_with_sens', get_by_list=True, sens_param=True) grad_by_list_with_sens = C.GradOperation(get_by_list=True, sens_param=True)
def setup_module(): def setup_module():
......
...@@ -32,7 +32,7 @@ class TrainStepWrap(nn.Cell): ...@@ -32,7 +32,7 @@ class TrainStepWrap(nn.Cell):
self.weights = ParameterTuple(network.trainable_params()) self.weights = ParameterTuple(network.trainable_params())
self.optimizer = nn.Momentum(self.weights, 0.1, 0.9) self.optimizer = nn.Momentum(self.weights, 0.1, 0.9)
self.hyper_map = C.HyperMap() self.hyper_map = C.HyperMap()
self.grad = C.GradOperation('grad', get_by_list=True) self.grad = C.GradOperation(get_by_list=True)
def construct(self, x, label): def construct(self, x, label):
weights = self.weights weights = self.weights
...@@ -71,7 +71,7 @@ class TrainStepWrap2(nn.Cell): ...@@ -71,7 +71,7 @@ class TrainStepWrap2(nn.Cell):
self.weights = ParameterTuple(network.get_parameters()) self.weights = ParameterTuple(network.get_parameters())
self.optimizer = nn.Momentum(self.weights, 0.1, 0.9) self.optimizer = nn.Momentum(self.weights, 0.1, 0.9)
self.hyper_map = C.HyperMap() self.hyper_map = C.HyperMap()
self.grad = C.GradOperation('grad', get_by_list=True, sens_param=True) self.grad = C.GradOperation(get_by_list=True, sens_param=True)
self.sens = sens self.sens = sens
def construct(self, x): def construct(self, x):
...@@ -93,7 +93,7 @@ class TrainStepWrapWithoutOpt(nn.Cell): ...@@ -93,7 +93,7 @@ class TrainStepWrapWithoutOpt(nn.Cell):
super(TrainStepWrapWithoutOpt, self).__init__() super(TrainStepWrapWithoutOpt, self).__init__()
self.network = network self.network = network
self.weights = ParameterTuple(network.trainable_params()) self.weights = ParameterTuple(network.trainable_params())
self.grad = C.GradOperation('grad', get_by_list=True) self.grad = C.GradOperation(get_by_list=True)
def construct(self, x, label): def construct(self, x, label):
grads = self.grad(self.network, self.weights)(x, label) grads = self.grad(self.network, self.weights)(x, label)
......
...@@ -31,7 +31,7 @@ from tests.mindspore_test_framework.pipeline.forward.compile_forward \ ...@@ -31,7 +31,7 @@ from tests.mindspore_test_framework.pipeline.forward.compile_forward \
context.set_context(mode=context.GRAPH_MODE) context.set_context(mode=context.GRAPH_MODE)
grad_all = C.GradOperation('get_all', get_all=True) grad_all = C.GradOperation(get_all=True)
def test_list_equal(): def test_list_equal():
......
...@@ -52,8 +52,7 @@ class TrainOneStepWithLarsCell(nn.Cell): ...@@ -52,8 +52,7 @@ class TrainOneStepWithLarsCell(nn.Cell):
self.slice_index, self.params_len, weights = get_net_trainable_reordered_params(self.network) self.slice_index, self.params_len, weights = get_net_trainable_reordered_params(self.network)
self.weights = ParameterTuple(weights) self.weights = ParameterTuple(weights)
self.optimizer = optimizer self.optimizer = optimizer
self.grad = C.GradOperation('grad', self.grad = C.GradOperation(get_by_list=True,
get_by_list=True,
sens_param=True) sens_param=True)
self.sens = Parameter(Tensor([sens], mstype.float32), name='sens', requires_grad=False) self.sens = Parameter(Tensor([sens], mstype.float32), name='sens', requires_grad=False)
self.weight_decay = 1.0 self.weight_decay = 1.0
......
...@@ -248,7 +248,7 @@ def test_row_tensor_attr(): ...@@ -248,7 +248,7 @@ def test_row_tensor_attr():
def test_row_tensor_sparse_gatherv2_grad_all(): def test_row_tensor_sparse_gatherv2_grad_all():
grad_all = C.GradOperation('get_all', get_all=True) grad_all = C.GradOperation(get_all=True)
class GradWrap(nn.Cell): class GradWrap(nn.Cell):
def __init__(self, network): def __init__(self, network):
super(GradWrap, self).__init__() super(GradWrap, self).__init__()
...@@ -269,7 +269,7 @@ def test_row_tensor_sparse_gatherv2_grad_all(): ...@@ -269,7 +269,7 @@ def test_row_tensor_sparse_gatherv2_grad_all():
def test_row_tensor_sparse_gatherv2_grad_with_pram(): def test_row_tensor_sparse_gatherv2_grad_with_pram():
grad_by_list = C.GradOperation('get_by_list', get_by_list=True) grad_by_list = C.GradOperation(get_by_list=True)
class GradWrap(nn.Cell): class GradWrap(nn.Cell):
def __init__(self, network): def __init__(self, network):
super(GradWrap, self).__init__() super(GradWrap, self).__init__()
......
...@@ -28,7 +28,7 @@ from mindspore import Tensor, SparseTensor, context ...@@ -28,7 +28,7 @@ from mindspore import Tensor, SparseTensor, context
context.set_context(mode=context.GRAPH_MODE, enable_sparse=True) context.set_context(mode=context.GRAPH_MODE, enable_sparse=True)
grad_op = C.GradOperation('get_all', get_all=True) grad_op = C.GradOperation(get_all=True)
class MakeSparseTensor(nn.Cell): class MakeSparseTensor(nn.Cell):
def __init__(self, dense_shape): def __init__(self, dense_shape):
......
...@@ -50,7 +50,7 @@ class Func(nn.Cell): ...@@ -50,7 +50,7 @@ class Func(nn.Cell):
return out return out
grad_s = C.GradOperation('grad_with_sens', get_all=True, sens_param=True) grad_s = C.GradOperation(get_all=True, sens_param=True)
class Net(nn.Cell): class Net(nn.Cell):
......
...@@ -166,8 +166,7 @@ class GetParamGrad(nn.Cell): ...@@ -166,8 +166,7 @@ class GetParamGrad(nn.Cell):
super(GetParamGrad, self).__init__(auto_prefix=False) super(GetParamGrad, self).__init__(auto_prefix=False)
self.network = network self.network = network
self.weights = ParameterTuple(network.trainable_params()) self.weights = ParameterTuple(network.trainable_params())
self.grad = C.GradOperation('grad', self.grad = C.GradOperation(get_by_list=True,
get_by_list=True,
sens_param=True) sens_param=True)
def construct(self, data, sens): def construct(self, data, sens):
......
...@@ -22,7 +22,7 @@ from mindspore.ops.operations import BiasAdd, MatMul ...@@ -22,7 +22,7 @@ from mindspore.ops.operations import BiasAdd, MatMul
import mindspore.ops.composite as C import mindspore.ops.composite as C
grad_by_list = C.GradOperation('get_by_list', get_by_list=True) grad_by_list = C.GradOperation(get_by_list=True)
class Net(Cell): class Net(Cell):
......
...@@ -34,7 +34,7 @@ class Net(nn.Cell): ...@@ -34,7 +34,7 @@ class Net(nn.Cell):
class Grad(nn.Cell): class Grad(nn.Cell):
def __init__(self, network): def __init__(self, network):
super(Grad, self).__init__() super(Grad, self).__init__()
self.grad = GradOperation(name="get_all", get_all=True, sens_param=True) self.grad = GradOperation(get_all=True, sens_param=True)
self.network = network self.network = network
@ms_function @ms_function
......
...@@ -28,7 +28,7 @@ from ....mindspore_test_framework.pipeline.forward.compile_forward \ ...@@ -28,7 +28,7 @@ from ....mindspore_test_framework.pipeline.forward.compile_forward \
import pipeline_for_compile_forward_ge_graph_for_case_by_case_config import pipeline_for_compile_forward_ge_graph_for_case_by_case_config
grad_by_list_with_sens = C.GradOperation('grad_by_list_with_sens', get_by_list=True, sens_param=True) grad_by_list_with_sens = C.GradOperation(get_by_list=True, sens_param=True)
class DisOrderTest1(nn.Cell): class DisOrderTest1(nn.Cell):
......
...@@ -30,9 +30,9 @@ from mindspore.common import ms_function ...@@ -30,9 +30,9 @@ from mindspore.common import ms_function
context.set_context(mode=context.GRAPH_MODE) context.set_context(mode=context.GRAPH_MODE)
grad_by_list = C.GradOperation('get_by_list', get_by_list=True) grad_by_list = C.GradOperation(get_by_list=True)
grad_all = C.GradOperation('get_all', get_all=True) grad_all = C.GradOperation(get_all=True)
grad_all_with_sens = C.GradOperation('grad_all_with_sens', get_all=True, sens_param=True) grad_all_with_sens = C.GradOperation(get_all=True, sens_param=True)
def cond_data_test(x_init, y_init): def cond_data_test(x_init, y_init):
...@@ -564,7 +564,7 @@ def test_switch_layer_env_eliminate(): ...@@ -564,7 +564,7 @@ def test_switch_layer_env_eliminate():
class NetGrad(nn.Cell): class NetGrad(nn.Cell):
def __init__(self, net): def __init__(self, net):
super(NetGrad, self).__init__() super(NetGrad, self).__init__()
self.grad_op = C.GradOperation('grad', get_by_list=True, sens_param=False) self.grad_op = C.GradOperation(get_by_list=True, sens_param=False)
self.net = net self.net = net
self.weights = ParameterTuple(self.net.trainable_params()) self.weights = ParameterTuple(self.net.trainable_params())
...@@ -593,7 +593,7 @@ def test_switch_layer_single_layer(): ...@@ -593,7 +593,7 @@ def test_switch_layer_single_layer():
class NetGrad(nn.Cell): class NetGrad(nn.Cell):
def __init__(self, net): def __init__(self, net):
super(NetGrad, self).__init__() super(NetGrad, self).__init__()
self.grad_op = C.GradOperation('grad', get_by_list=True, sens_param=False) self.grad_op = C.GradOperation(get_by_list=True, sens_param=False)
self.net = net self.net = net
self.weights = ParameterTuple(self.net.trainable_params()) self.weights = ParameterTuple(self.net.trainable_params())
......
...@@ -38,7 +38,7 @@ context.set_context(mode=context.GRAPH_MODE) ...@@ -38,7 +38,7 @@ context.set_context(mode=context.GRAPH_MODE)
# W0613: unused-argument # W0613: unused-argument
# W0231: super-init-not-called # W0231: super-init-not-called
grad = C.GradOperation('grad') grad = C.GradOperation()
def test_multiply(): def test_multiply():
""" test_multiply """ """ test_multiply """
......
此差异已折叠。
此差异已折叠。
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册