提交 7abe692f 编写于 作者: L lvchangquan

add set_grad() for pynative mode in model_zoo network

上级 d05c22a1
...@@ -160,6 +160,7 @@ class TrainOneStepCell(nn.Cell): ...@@ -160,6 +160,7 @@ class TrainOneStepCell(nn.Cell):
def __init__(self, network, network_backbone, optimizer, sens=1.0, reduce_flag=False, mean=True, degree=None): def __init__(self, network, network_backbone, optimizer, sens=1.0, reduce_flag=False, mean=True, degree=None):
super(TrainOneStepCell, self).__init__(auto_prefix=False) super(TrainOneStepCell, self).__init__(auto_prefix=False)
self.network = network self.network = network
self.network.set_grad()
self.backbone = network_backbone self.backbone = network_backbone
self.weights = ParameterTuple(network.trainable_params()) self.weights = ParameterTuple(network.trainable_params())
self.optimizer = optimizer self.optimizer = optimizer
......
...@@ -168,6 +168,7 @@ class TrainOneStepCell(nn.Cell): ...@@ -168,6 +168,7 @@ class TrainOneStepCell(nn.Cell):
def __init__(self, network, network_backbone, optimizer, sens=1.0, reduce_flag=False, mean=True, degree=None): def __init__(self, network, network_backbone, optimizer, sens=1.0, reduce_flag=False, mean=True, degree=None):
super(TrainOneStepCell, self).__init__(auto_prefix=False) super(TrainOneStepCell, self).__init__(auto_prefix=False)
self.network = network self.network = network
self.network.set_grad()
self.backbone = network_backbone self.backbone = network_backbone
self.weights = ParameterTuple(network.trainable_params()) self.weights = ParameterTuple(network.trainable_params())
self.optimizer = optimizer self.optimizer = optimizer
......
...@@ -382,6 +382,7 @@ class TrainingWrapper(nn.Cell): ...@@ -382,6 +382,7 @@ class TrainingWrapper(nn.Cell):
def __init__(self, network, optimizer, sens=1.0): def __init__(self, network, optimizer, sens=1.0):
super(TrainingWrapper, self).__init__(auto_prefix=False) super(TrainingWrapper, self).__init__(auto_prefix=False)
self.network = network self.network = network
self.network.set_grad()
self.weights = ms.ParameterTuple(network.trainable_params()) self.weights = ms.ParameterTuple(network.trainable_params())
self.optimizer = optimizer self.optimizer = optimizer
self.grad = C.GradOperation(get_by_list=True, sens_param=True) self.grad = C.GradOperation(get_by_list=True, sens_param=True)
......
...@@ -411,6 +411,7 @@ class TrainingWrapper(nn.Cell): ...@@ -411,6 +411,7 @@ class TrainingWrapper(nn.Cell):
def __init__(self, network, optimizer, sens=1.0): def __init__(self, network, optimizer, sens=1.0):
super(TrainingWrapper, self).__init__(auto_prefix=False) super(TrainingWrapper, self).__init__(auto_prefix=False)
self.network = network self.network = network
self.network.set_grad()
self.weights = optimizer.parameters self.weights = optimizer.parameters
self.optimizer = optimizer self.optimizer = optimizer
self.grad = C.GradOperation(get_by_list=True, sens_param=True) self.grad = C.GradOperation(get_by_list=True, sens_param=True)
......
...@@ -411,6 +411,7 @@ class TrainingWrapper(nn.Cell): ...@@ -411,6 +411,7 @@ class TrainingWrapper(nn.Cell):
def __init__(self, network, optimizer, sens=1.0): def __init__(self, network, optimizer, sens=1.0):
super(TrainingWrapper, self).__init__(auto_prefix=False) super(TrainingWrapper, self).__init__(auto_prefix=False)
self.network = network self.network = network
self.network.set_grad()
self.weights = optimizer.parameters self.weights = optimizer.parameters
self.optimizer = optimizer self.optimizer = optimizer
self.grad = C.GradOperation(get_by_list=True, sens_param=True) self.grad = C.GradOperation(get_by_list=True, sens_param=True)
......
...@@ -646,6 +646,7 @@ class TrainingWrapper(nn.Cell): ...@@ -646,6 +646,7 @@ class TrainingWrapper(nn.Cell):
def __init__(self, network, optimizer, sens=1.0): def __init__(self, network, optimizer, sens=1.0):
super(TrainingWrapper, self).__init__(auto_prefix=False) super(TrainingWrapper, self).__init__(auto_prefix=False)
self.network = network self.network = network
self.network.set_grad()
self.weights = ms.ParameterTuple(network.trainable_params()) self.weights = ms.ParameterTuple(network.trainable_params())
self.optimizer = optimizer self.optimizer = optimizer
self.grad = C.GradOperation(get_by_list=True, sens_param=True) self.grad = C.GradOperation(get_by_list=True, sens_param=True)
......
...@@ -167,6 +167,7 @@ class TrainGAT(nn.Cell): ...@@ -167,6 +167,7 @@ class TrainGAT(nn.Cell):
def __init__(self, network, num_class, label, mask, learning_rate, l2_coeff): def __init__(self, network, num_class, label, mask, learning_rate, l2_coeff):
super(TrainGAT, self).__init__(auto_prefix=False) super(TrainGAT, self).__init__(auto_prefix=False)
self.network = network self.network = network
self.network.set_grad()
loss_net = LossNetWrapper(network, num_class, label, mask, l2_coeff) loss_net = LossNetWrapper(network, num_class, label, mask, l2_coeff)
optimizer = nn.Adam(loss_net.trainable_params(), optimizer = nn.Adam(loss_net.trainable_params(),
learning_rate=learning_rate) learning_rate=learning_rate)
......
...@@ -147,6 +147,7 @@ class TrainOneStepCell(nn.Cell): ...@@ -147,6 +147,7 @@ class TrainOneStepCell(nn.Cell):
def __init__(self, network, optimizer, sens=1.0): def __init__(self, network, optimizer, sens=1.0):
super(TrainOneStepCell, self).__init__(auto_prefix=False) super(TrainOneStepCell, self).__init__(auto_prefix=False)
self.network = network self.network = network
self.network.set_grad()
self.network.add_flags(defer_inline=True) self.network.add_flags(defer_inline=True)
self.weights = ParameterTuple(network.trainable_params()) self.weights = ParameterTuple(network.trainable_params())
self.optimizer = optimizer self.optimizer = optimizer
......
...@@ -55,6 +55,7 @@ class BertFinetuneCell(nn.Cell): ...@@ -55,6 +55,7 @@ class BertFinetuneCell(nn.Cell):
super(BertFinetuneCell, self).__init__(auto_prefix=False) super(BertFinetuneCell, self).__init__(auto_prefix=False)
self.network = network self.network = network
self.network.set_grad()
self.weights = optimizer.parameters self.weights = optimizer.parameters
self.optimizer = optimizer self.optimizer = optimizer
self.grad = C.GradOperation(get_by_list=True, self.grad = C.GradOperation(get_by_list=True,
...@@ -157,6 +158,7 @@ class BertSquadCell(nn.Cell): ...@@ -157,6 +158,7 @@ class BertSquadCell(nn.Cell):
def __init__(self, network, optimizer, scale_update_cell=None): def __init__(self, network, optimizer, scale_update_cell=None):
super(BertSquadCell, self).__init__(auto_prefix=False) super(BertSquadCell, self).__init__(auto_prefix=False)
self.network = network self.network = network
self.network.set_grad()
self.weights = optimizer.parameters self.weights = optimizer.parameters
self.optimizer = optimizer self.optimizer = optimizer
self.grad = C.GradOperation(get_by_list=True, sens_param=True) self.grad = C.GradOperation(get_by_list=True, sens_param=True)
......
...@@ -273,6 +273,7 @@ class BertTrainOneStepCell(nn.Cell): ...@@ -273,6 +273,7 @@ class BertTrainOneStepCell(nn.Cell):
def __init__(self, network, optimizer, sens=1.0): def __init__(self, network, optimizer, sens=1.0):
super(BertTrainOneStepCell, self).__init__(auto_prefix=False) super(BertTrainOneStepCell, self).__init__(auto_prefix=False)
self.network = network self.network = network
self.network.set_grad()
self.weights = optimizer.parameters self.weights = optimizer.parameters
self.optimizer = optimizer self.optimizer = optimizer
self.grad = C.GradOperation(get_by_list=True, sens_param=True) self.grad = C.GradOperation(get_by_list=True, sens_param=True)
...@@ -352,6 +353,7 @@ class BertTrainOneStepWithLossScaleCell(nn.Cell): ...@@ -352,6 +353,7 @@ class BertTrainOneStepWithLossScaleCell(nn.Cell):
def __init__(self, network, optimizer, scale_update_cell=None, enable_global_norm=False): def __init__(self, network, optimizer, scale_update_cell=None, enable_global_norm=False):
super(BertTrainOneStepWithLossScaleCell, self).__init__(auto_prefix=False) super(BertTrainOneStepWithLossScaleCell, self).__init__(auto_prefix=False)
self.network = network self.network = network
self.network.set_grad()
self.weights = optimizer.parameters self.weights = optimizer.parameters
self.optimizer = optimizer self.optimizer = optimizer
self.enable_global_norm = enable_global_norm self.enable_global_norm = enable_global_norm
...@@ -482,6 +484,7 @@ class BertTrainAccumulateStepsWithLossScaleCell(nn.Cell): ...@@ -482,6 +484,7 @@ class BertTrainAccumulateStepsWithLossScaleCell(nn.Cell):
def __init__(self, network, optimizer, scale_update_cell=None, accumulation_steps=1, enable_global_norm=False): def __init__(self, network, optimizer, scale_update_cell=None, accumulation_steps=1, enable_global_norm=False):
super(BertTrainAccumulateStepsWithLossScaleCell, self).__init__(auto_prefix=False) super(BertTrainAccumulateStepsWithLossScaleCell, self).__init__(auto_prefix=False)
self.network = network self.network = network
self.network.set_grad()
self.weights = optimizer.parameters self.weights = optimizer.parameters
self.optimizer = optimizer self.optimizer = optimizer
self.accumulation_steps = accumulation_steps self.accumulation_steps = accumulation_steps
......
...@@ -291,6 +291,7 @@ class BertTrainOneStepCell(nn.Cell): ...@@ -291,6 +291,7 @@ class BertTrainOneStepCell(nn.Cell):
def __init__(self, network, optimizer, sens=1.0): def __init__(self, network, optimizer, sens=1.0):
super(BertTrainOneStepCell, self).__init__(auto_prefix=False) super(BertTrainOneStepCell, self).__init__(auto_prefix=False)
self.network = network self.network = network
self.network.set_grad()
self.weights = optimizer.parameters self.weights = optimizer.parameters
self.optimizer = optimizer self.optimizer = optimizer
self.grad = C.GradOperation(get_by_list=True, sens_param=True) self.grad = C.GradOperation(get_by_list=True, sens_param=True)
...@@ -371,6 +372,7 @@ class BertTrainOneStepWithLossScaleCell(nn.Cell): ...@@ -371,6 +372,7 @@ class BertTrainOneStepWithLossScaleCell(nn.Cell):
def __init__(self, network, optimizer, scale_update_cell=None): def __init__(self, network, optimizer, scale_update_cell=None):
super(BertTrainOneStepWithLossScaleCell, self).__init__(auto_prefix=False) super(BertTrainOneStepWithLossScaleCell, self).__init__(auto_prefix=False)
self.network = network self.network = network
self.network.set_grad()
self.weights = optimizer.parameters self.weights = optimizer.parameters
self.optimizer = optimizer self.optimizer = optimizer
self.grad = C.GradOperation(get_by_list=True, self.grad = C.GradOperation(get_by_list=True,
......
...@@ -236,6 +236,7 @@ class TransformerTrainOneStepWithLossScaleCell(nn.Cell): ...@@ -236,6 +236,7 @@ class TransformerTrainOneStepWithLossScaleCell(nn.Cell):
super(TransformerTrainOneStepWithLossScaleCell, self).__init__(auto_prefix=False) super(TransformerTrainOneStepWithLossScaleCell, self).__init__(auto_prefix=False)
self.network = network self.network = network
self.network.set_grad()
self.network.add_flags(defer_inline=True) self.network.add_flags(defer_inline=True)
self.weights = optimizer.parameters self.weights = optimizer.parameters
self.optimizer = optimizer self.optimizer = optimizer
......
...@@ -216,6 +216,7 @@ class BertTrainWithLossScaleCell(nn.Cell): ...@@ -216,6 +216,7 @@ class BertTrainWithLossScaleCell(nn.Cell):
def __init__(self, network, optimizer, scale_update_cell=None): def __init__(self, network, optimizer, scale_update_cell=None):
super(BertTrainWithLossScaleCell, self).__init__(auto_prefix=False) super(BertTrainWithLossScaleCell, self).__init__(auto_prefix=False)
self.network = network self.network = network
self.network.set_grad()
self.weights = optimizer.parameters self.weights = optimizer.parameters
self.optimizer = optimizer self.optimizer = optimizer
self.grad = C.GradOperation(get_by_list=True, self.grad = C.GradOperation(get_by_list=True,
...@@ -306,6 +307,7 @@ class BertTrainCell(nn.Cell): ...@@ -306,6 +307,7 @@ class BertTrainCell(nn.Cell):
def __init__(self, network, optimizer, sens=1.0): def __init__(self, network, optimizer, sens=1.0):
super(BertTrainCell, self).__init__(auto_prefix=False) super(BertTrainCell, self).__init__(auto_prefix=False)
self.network = network self.network = network
self.network.set_grad()
self.weights = optimizer.parameters self.weights = optimizer.parameters
self.optimizer = optimizer self.optimizer = optimizer
self.sens = sens self.sens = sens
...@@ -470,6 +472,7 @@ class BertEvaluationWithLossScaleCell(nn.Cell): ...@@ -470,6 +472,7 @@ class BertEvaluationWithLossScaleCell(nn.Cell):
def __init__(self, network, optimizer, scale_update_cell=None): def __init__(self, network, optimizer, scale_update_cell=None):
super(BertEvaluationWithLossScaleCell, self).__init__(auto_prefix=False) super(BertEvaluationWithLossScaleCell, self).__init__(auto_prefix=False)
self.network = network self.network = network
self.network.set_grad()
self.weights = optimizer.parameters self.weights = optimizer.parameters
self.optimizer = optimizer self.optimizer = optimizer
self.grad = C.GradOperation(get_by_list=True, self.grad = C.GradOperation(get_by_list=True,
...@@ -556,6 +559,7 @@ class BertEvaluationCell(nn.Cell): ...@@ -556,6 +559,7 @@ class BertEvaluationCell(nn.Cell):
def __init__(self, network, optimizer, sens=1.0): def __init__(self, network, optimizer, sens=1.0):
super(BertEvaluationCell, self).__init__(auto_prefix=False) super(BertEvaluationCell, self).__init__(auto_prefix=False)
self.network = network self.network = network
self.network.set_grad()
self.weights = optimizer.parameters self.weights = optimizer.parameters
self.optimizer = optimizer self.optimizer = optimizer
self.sens = sens self.sens = sens
......
...@@ -156,6 +156,7 @@ class TransformerTrainOneStepCell(nn.Cell): ...@@ -156,6 +156,7 @@ class TransformerTrainOneStepCell(nn.Cell):
def __init__(self, network, optimizer, sens=1.0): def __init__(self, network, optimizer, sens=1.0):
super(TransformerTrainOneStepCell, self).__init__(auto_prefix=False) super(TransformerTrainOneStepCell, self).__init__(auto_prefix=False)
self.network = network self.network = network
self.network.set_grad()
self.weights = ParameterTuple(network.trainable_params()) self.weights = ParameterTuple(network.trainable_params())
self.optimizer = optimizer self.optimizer = optimizer
self.grad = C.GradOperation(get_by_list=True, sens_param=True) self.grad = C.GradOperation(get_by_list=True, sens_param=True)
...@@ -241,6 +242,7 @@ class TransformerTrainOneStepWithLossScaleCell(nn.Cell): ...@@ -241,6 +242,7 @@ class TransformerTrainOneStepWithLossScaleCell(nn.Cell):
def __init__(self, network, optimizer, scale_update_cell=None): def __init__(self, network, optimizer, scale_update_cell=None):
super(TransformerTrainOneStepWithLossScaleCell, self).__init__(auto_prefix=False) super(TransformerTrainOneStepWithLossScaleCell, self).__init__(auto_prefix=False)
self.network = network self.network = network
self.network.set_grad()
self.network.add_flags(defer_inline=True) self.network.add_flags(defer_inline=True)
self.weights = ParameterTuple(network.trainable_params()) self.weights = ParameterTuple(network.trainable_params())
self.optimizer = optimizer self.optimizer = optimizer
......
...@@ -282,6 +282,7 @@ class TrainStepWrap(nn.Cell): ...@@ -282,6 +282,7 @@ class TrainStepWrap(nn.Cell):
def __init__(self, network, lr=5e-8, eps=1e-8, loss_scale=1000.0): def __init__(self, network, lr=5e-8, eps=1e-8, loss_scale=1000.0):
super(TrainStepWrap, self).__init__(auto_prefix=False) super(TrainStepWrap, self).__init__(auto_prefix=False)
self.network = network self.network = network
self.network.set_grad()
self.network.set_train() self.network.set_train()
self.weights = ParameterTuple(network.trainable_params()) self.weights = ParameterTuple(network.trainable_params())
self.optimizer = Adam(self.weights, learning_rate=lr, eps=eps, loss_scale=loss_scale) self.optimizer = Adam(self.weights, learning_rate=lr, eps=eps, loss_scale=loss_scale)
......
...@@ -328,6 +328,7 @@ class TrainStepWrap(nn.Cell): ...@@ -328,6 +328,7 @@ class TrainStepWrap(nn.Cell):
parallel_mode = context.get_auto_parallel_context("parallel_mode") parallel_mode = context.get_auto_parallel_context("parallel_mode")
is_auto_parallel = parallel_mode in (ParallelMode.SEMI_AUTO_PARALLEL, ParallelMode.AUTO_PARALLEL) is_auto_parallel = parallel_mode in (ParallelMode.SEMI_AUTO_PARALLEL, ParallelMode.AUTO_PARALLEL)
self.network = network self.network = network
self.network.set_grad()
self.network.set_train() self.network.set_train()
self.trainable_params = network.trainable_params() self.trainable_params = network.trainable_params()
weights_w = [] weights_w = []
......
...@@ -510,6 +510,7 @@ class TrainStepWrap(nn.Cell): ...@@ -510,6 +510,7 @@ class TrainStepWrap(nn.Cell):
def __init__(self, network, config, sens=1000.0): def __init__(self, network, config, sens=1000.0):
super(TrainStepWrap, self).__init__() super(TrainStepWrap, self).__init__()
self.network = network self.network = network
self.network.set_grad()
self.network.set_train() self.network.set_train()
self.trainable_params = network.trainable_params() self.trainable_params = network.trainable_params()
weights_w = [] weights_w = []
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册