Unverified commit 3e6950d5 authored by Z zhangbo9674, committed by GitHub

[Optimizer] Add master weight for opt state_dict (#39121)

* add master weight for opt state_dict

* check empty of master weight

* strict gpu test

* refine unittest
Parent 80dfa010
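With this change, an optimizer created with multi_precision=True exposes its FP32 master weights through state_dict(), so they survive a save/load round trip. Below is a minimal dygraph sketch of the intended usage, condensed from the unit test in this commit (tensor sizes and the file name are illustrative, and a CUDA build is assumed because master weights are only created for AMP-O2 on GPU):

import numpy as np
import paddle

paddle.set_device('gpu')  # AMP-O2 master weights require a CUDA build
model = paddle.nn.Linear(2, 2)
opt = paddle.optimizer.Momentum(
    learning_rate=0.0001,
    parameters=model.parameters(),
    multi_precision=True)  # keep FP32 master copies of the FP16 parameters
model = paddle.amp.decorate(models=model, level='O2')
scaler = paddle.amp.GradScaler(init_loss_scaling=1024)

data = paddle.to_tensor(np.random.random([2, 2]).astype('float16'))
with paddle.amp.auto_cast(level='O2'):
    loss = model(data).mean()
scaler.scale(loss).backward()
scaler.step(opt)
scaler.update()

state = opt.state_dict()
assert 'master_weights' in state  # the new entry added by this commit
paddle.save(state, 'opt.pdopt')
opt.set_state_dict(paddle.load('opt.pdopt'))  # accumulators and master weights are restored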
...@@ -25,6 +25,8 @@ import numpy as np
from paddle.fluid.backward import append_backward
from paddle.fluid.framework import Program, program_guard, convert_np_dtype_to_dtype_
import paddle
from paddle.io import Dataset
import numpy
paddle.enable_static()
...@@ -1113,5 +1115,90 @@ class TestOptimizerDtype(unittest.TestCase):
        self.check_with_dtype('float32')
class TestMasterWeightSaveForFP16(unittest.TestCase):
    '''
    For AMP-O2, some optimizers (Momentum, Adam, ...) create master weights for parameters to improve accuracy.
    Master weights are saved by Optimizer.state_dict.
    '''
    def check_with_opt_state_dict(self, use_save_load=True):
        paddle.seed(100)
        numpy.random.seed(100)

        class SimpleNet(paddle.nn.Layer):
            def __init__(self, input_size, output_size):
                super(SimpleNet, self).__init__()
                self.linears = paddle.nn.LayerList([
                    paddle.nn.Linear(input_size, output_size) for i in range(1)
                ])

            def forward(self, x):
                for i, l in enumerate(self.linears):
                    x = self.linears[i](x)
                return x
        input_size = 2   # can be set to a larger value
        output_size = 2  # can be set to a larger value
        batch_size = 2   # batch_size should be a multiple of 8
        nums_batch = 10
        class RandomDataset(Dataset):
            def __init__(self, num_samples):
                self.num_samples = num_samples

            def __getitem__(self, idx):
                data = numpy.random.random([input_size]).astype('float16')
                label = numpy.random.random([output_size]).astype('float16')
                return data, label

            def __len__(self):
                return self.num_samples
        dataset = RandomDataset(nums_batch * batch_size)
        loader = paddle.io.DataLoader(
            dataset,
            batch_size=batch_size,
            shuffle=False,
            drop_last=True,
            num_workers=0)

        mse = paddle.nn.MSELoss()
        model = SimpleNet(input_size, output_size)  # define the model
        optimizer = paddle.optimizer.Momentum(
            learning_rate=0.0001,
            parameters=model.parameters(),
            multi_precision=True)  # define the optimizer
        scaler = paddle.amp.GradScaler(init_loss_scaling=1024)
        model = paddle.amp.decorate(models=model, level='O2')

        for i, (data, label) in enumerate(loader):
            with paddle.amp.auto_cast(level='O2'):
                output = model(data)
                loss = mse(output, label)
            scaled = scaler.scale(loss)
            scaled.backward()
            scaler.step(optimizer)
            scaler.update()
            optimizer.clear_grad(set_to_zero=False)

            if use_save_load and i == 5:
                paddle.save(model.state_dict(), "model.pdparams")
                paddle.save(optimizer.state_dict(), "opt.pdopt")
                model.set_state_dict(paddle.load("model.pdparams"))
                optimizer.set_state_dict(paddle.load("opt.pdopt"))

        return loss.numpy()
    def test_with_state_dict(self):
        if core.is_compiled_with_cuda():
            with fluid.dygraph.guard():
                out_use_state_dict = self.check_with_opt_state_dict(
                    use_save_load=True)
                out_no_state_dict = self.check_with_opt_state_dict(
                    use_save_load=False)
                self.assertTrue(
                    np.array_equal(out_use_state_dict, out_no_state_dict))
if __name__ == '__main__':
    unittest.main()
...@@ -256,6 +256,10 @@ class Optimizer(object):
        for k, v in self._accumulators.items():
            for para_name, var_tmp in v.items():
                state_dict[var_tmp.name] = var_tmp
        # if the optimizer has non-empty master weights, save them as well
        if hasattr(self, "_master_weights"):
            if len(self._master_weights) != 0:
                state_dict["master_weights"] = self._master_weights
        # global step if use lr decay
        if isinstance(self._learning_rate, LRScheduler):
            state_dict["LR_Scheduler"] = self._learning_rate.state_dict()
...@@ -304,6 +308,10 @@ class Optimizer(object):
        state_dict = state_dict.copy()
        if "LR_Scheduler" in state_dict:
            state_dict.pop("LR_Scheduler")
        if "master_weights" in state_dict:
            if hasattr(self, "_master_weights"):
                self._master_weights = state_dict["master_weights"]
            state_dict.pop("master_weights")
        self._accumulators_holder = state_dict
        for k, v in self._accumulators.items():
            for para_name, var_tmp in v.items():
...
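For reference, the structure that ends up in the saved optimizer state can be inspected like this (a hedged sketch: it assumes the opt.pdopt file written by the test above, and the key names shown in the comments are illustrative since they depend on parameter names):

import paddle

state = paddle.load('opt.pdopt')  # optimizer state saved by the test above
print(sorted(state.keys()))
# e.g. ['linear_0.b_0_velocity_0', 'linear_0.w_0_velocity_0', 'master_weights']

# 'master_weights' maps each parameter name to its FP32 master copy
for name, master in state['master_weights'].items():
    print(name, master.shape, master.dtype)  # e.g. linear_0.w_0 [2, 2] paddle.float32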