提交 6600a49a 编写于 作者: M mindspore-ci-bot 提交者: Gitee

!751 modify code support pynative mode

Merge pull request !751 from jinyaohui/master
...@@ -95,6 +95,7 @@ class TrainForwardBackward(Cell): ...@@ -95,6 +95,7 @@ class TrainForwardBackward(Cell):
def __init__(self, network, optimizer, grad_sum, sens=1.0): def __init__(self, network, optimizer, grad_sum, sens=1.0):
super(TrainForwardBackward, self).__init__(auto_prefix=False) super(TrainForwardBackward, self).__init__(auto_prefix=False)
self.network = network self.network = network
self.network.set_grad()
self.network.add_flags(defer_inline=True) self.network.add_flags(defer_inline=True)
self.weights = ParameterTuple(network.trainable_params()) self.weights = ParameterTuple(network.trainable_params())
self.optimizer = optimizer self.optimizer = optimizer
......
import argparse import argparse
import os import os
from collections.abc import Iterable
import mindspore.nn as nn import mindspore.nn as nn
from mindspore import ParameterTuple from mindspore import ParameterTuple
...@@ -37,6 +36,7 @@ class TrainForwardBackward(Cell): ...@@ -37,6 +36,7 @@ class TrainForwardBackward(Cell):
def __init__(self, network, optimizer, grad_sum, sens=1.0): def __init__(self, network, optimizer, grad_sum, sens=1.0):
super(TrainForwardBackward, self).__init__(auto_prefix=False) super(TrainForwardBackward, self).__init__(auto_prefix=False)
self.network = network self.network = network
self.network.set_grad()
self.network.add_flags(defer_inline=True) self.network.add_flags(defer_inline=True)
self.weights = ParameterTuple(network.trainable_params()) self.weights = ParameterTuple(network.trainable_params())
self.optimizer = optimizer self.optimizer = optimizer
...@@ -88,17 +88,6 @@ class GradientAccumulation: ...@@ -88,17 +88,6 @@ class GradientAccumulation:
self._train_optim = self._build_train_optim() self._train_optim = self._build_train_optim()
self._train_clear = self._build_train_clear() self._train_clear = self._build_train_clear()
@staticmethod
def _transform_callbacks(callbacks):
"""Transform callback to a list."""
if callbacks is None:
return []
if isinstance(callbacks, Iterable):
return list(callbacks)
return [callbacks]
def _build_train_forward_backward_network(self): def _build_train_forward_backward_network(self):
"""Build forward and backward network""" """Build forward and backward network"""
network = self._network network = self._network
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册