Commit 7399d560 authored by zxcd

fix scaler save and load.

Parent 2f4414a5
@@ -82,6 +82,7 @@ class U2Trainer(Trainer):
         with context():
             if scaler:
                 scaler.scale(loss).backward()
+                scaler.unscale_(self.optimizer)
             else:
                 loss.backward()
             layer_tools.print_grads(self.model, print_func=None)
...
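The added scaler.unscale_(self.optimizer) converts the loss-scaled gradients back to their true values right after the backward pass, presumably so that the subsequent layer_tools.print_grads call (and any gradient clipping in the optimizer step) sees real gradients rather than scaled ones. A minimal standalone sketch of this Paddle AMP pattern; the toy model, optimizer, and data below are illustrative, not from this repo:

    import paddle

    # Toy model/optimizer for illustration only; the step pattern mirrors the diff.
    model = paddle.nn.Linear(10, 1)
    optimizer = paddle.optimizer.Adam(parameters=model.parameters())
    scaler = paddle.amp.GradScaler()

    x = paddle.randn([4, 10])
    y = paddle.randn([4, 1])

    with paddle.amp.auto_cast():
        loss = paddle.nn.functional.mse_loss(model(x), y)

    scaler.scale(loss).backward()  # gradients are multiplied by the loss scale
    scaler.unscale_(optimizer)     # divide them back: grads now hold true values
    # ... gradients can be printed or clipped here, as the diff does ...
    scaler.step(optimizer)         # applies the update, skipping it on inf/nan
    scaler.update()                # adjusts the loss scale for the next iteration
    optimizer.clear_grad()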
@@ -11,6 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import os
 import sys
 import time
 from collections import OrderedDict
@@ -189,8 +190,12 @@ class Trainer():
             "step": self.iteration,
             "epoch": self.epoch,
             "lr": self.optimizer.get_lr(),
-            "scaler": self.scaler.state_dict()
         })
+        if self.scaler:
+            scaler_path = os.path.join(self.checkpoint_dir,
+                                       "{}".format(self.epoch)) + '.scaler'
+            paddle.save(self.scaler.state_dict(), scaler_path)
         self.checkpoint.save_parameters(self.checkpoint_dir, self.iteration
                                         if tag is None else tag, self.model,
                                         self.optimizer, infos)
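Instead of stuffing self.scaler.state_dict() into infos (checkpoint metadata that this codebase appears to persist as plain JSON, which cannot reliably round-trip a state dict), the scaler state is now written as its own <epoch>.scaler file via paddle.save. A minimal sketch of that save path; the directory and epoch values are illustrative:

    import os
    import paddle

    checkpoint_dir = "exp/checkpoints"  # illustrative path
    epoch = 12                          # illustrative epoch counter
    os.makedirs(checkpoint_dir, exist_ok=True)

    scaler = paddle.amp.GradScaler()

    # Mirrors the diff: one file per epoch, e.g. exp/checkpoints/12.scaler
    scaler_path = os.path.join(checkpoint_dir, "{}".format(epoch)) + ".scaler"
    paddle.save(scaler.state_dict(), scaler_path)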
@@ -213,8 +218,13 @@ class Trainer():
             # lr will restore from optimizer ckpt
             self.iteration = infos["step"]
             self.epoch = infos["epoch"]
-            self.scaler = paddle.amp.GradScaler()
-            self.scaler.load_state_dict(infos["scaler"])
+            scaler_path = os.path.join(self.checkpoint_dir,
+                                       "{}".format(self.epoch)) + '.scaler'
+            if os.path.exists(scaler_path):
+                scaler_state_dict = paddle.load(scaler_path)
+                self.scaler.load_state_dict(scaler_state_dict)
             scratch = False
             logger.info(
                 f"Restore ckpt: epoch {self.epoch}, step {self.iteration}!")
...
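On resume, the counterpart logic rebuilds the same per-epoch path and loads the scaler state only if the file exists, so checkpoints written before this change, or without AMP enabled, still restore cleanly. A minimal sketch, using the same illustrative paths as above:

    import os
    import paddle

    checkpoint_dir = "exp/checkpoints"  # illustrative path
    epoch = 12                          # restored from the checkpoint infos

    scaler = paddle.amp.GradScaler()
    scaler_path = os.path.join(checkpoint_dir, "{}".format(epoch)) + ".scaler"

    # Loading is optional: older or non-AMP checkpoints simply lack a .scaler file.
    if os.path.exists(scaler_path):
        scaler.load_state_dict(paddle.load(scaler_path))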