未验证 提交 e29c2d12 编写于 作者: L Leo Chen 提交者: GitHub

[amp] dygraph amp support param_group (#34899)

* dygraph amp support param_group

* remove unused code

* fix doc
上级 b0cb4148
...@@ -146,6 +146,49 @@ class GradScaler(AmpScaler): ...@@ -146,6 +146,49 @@ class GradScaler(AmpScaler):
""" """
return super(GradScaler, self).minimize(optimizer, *args, **kwargs) return super(GradScaler, self).minimize(optimizer, *args, **kwargs)
def step(self, optimizer):
    """
    This function is similar as `optimizer.step()`, which performs parameters updating.

    If the scaled gradients of parameters contains NAN or INF, the parameters updating is skipped.
    Otherwise, it first unscales the scaled gradients of parameters, then updates the parameters.

    Args:
        optimizer(Optimizer): The optimizer used to update parameters.

    Examples:
        .. code-block:: python

            # required: gpu
            import paddle

            model = paddle.nn.Conv2D(3, 2, 3, bias_attr=True)
            optimizer = paddle.optimizer.SGD(learning_rate=0.01, parameters=model.parameters())
            scaler = paddle.amp.GradScaler(init_loss_scaling=1024)
            data = paddle.rand([10, 3, 32, 32])
            with paddle.amp.auto_cast():
                conv = model(data)
                loss = paddle.mean(conv)
            scaled = scaler.scale(loss)  # scale the loss
            scaled.backward()            # do backward
            scaler.step(optimizer)
            optimizer.clear_grad()
    """
    # When scaling is disabled, behave exactly like a plain optimizer step.
    if not self._enable:
        return optimizer.step()

    # unscale the grad (divides grads by the current loss scale in place,
    # and sets self._found_inf if any grad contains NAN/INF)
    self._unscale(optimizer)
    if self._found_inf:
        # Skip the update this iteration; remember the result so the
        # dynamic-scale update below can shrink the loss scale.
        # NOTE(review): attribute name carries a historical typo ("founf");
        # kept as-is because other methods of this class read it.
        self._cache_founf_inf = True
    else:
        optimizer.step()
        self._cache_founf_inf = False

    if self._use_dynamic_loss_scaling:
        # update the scale
        self._update()
def is_enable(self): def is_enable(self):
""" """
Enable loss scaling or not. Enable loss scaling or not.
......
...@@ -212,10 +212,19 @@ class AmpScaler(object): ...@@ -212,10 +212,19 @@ class AmpScaler(object):
def _unscale(self, optimizer): def _unscale(self, optimizer):
if not self._enable: if not self._enable:
return return
param_grads = [
param._grad_ivar() for param in optimizer._parameter_list if getattr(optimizer, '_param_groups', None) and isinstance(
if param._grad_ivar() is not None optimizer._param_groups[0], dict):
] param_grads = []
for group in optimizer._param_groups:
for param in group['params']:
if param._grad_ivar() is not None:
param_grads.append(param._grad_ivar())
else:
param_grads = [
param._grad_ivar() for param in optimizer._parameter_list
if param._grad_ivar() is not None
]
_C_ops.check_finite_and_unscale(param_grads, self._scale, param_grads, _C_ops.check_finite_and_unscale(param_grads, self._scale, param_grads,
self._found_inf) self._found_inf)
......
...@@ -19,6 +19,9 @@ import numpy as np ...@@ -19,6 +19,9 @@ import numpy as np
import six import six
from test_imperative_resnet import ResNet, BottleneckBlock, ConvBNLayer, train_parameters, optimizer_setting from test_imperative_resnet import ResNet, BottleneckBlock, ConvBNLayer, train_parameters, optimizer_setting
# Make cuDNN deterministic for the whole test module, so fp32 and AMP
# training runs can be compared with a tight numerical tolerance.
if fluid.core.is_compiled_with_cuda():
    fluid.set_flags({"FLAGS_cudnn_deterministic": True})
class SimpleConv(fluid.dygraph.Layer): class SimpleConv(fluid.dygraph.Layer):
def __init__(self, def __init__(self,
...@@ -373,8 +376,6 @@ class TestGradScalerStateDict(unittest.TestCase): ...@@ -373,8 +376,6 @@ class TestGradScalerStateDict(unittest.TestCase):
return dy_out, dy_param_value, dy_grad_value return dy_out, dy_param_value, dy_grad_value
def test_with_state_dict(self): def test_with_state_dict(self):
if fluid.core.is_compiled_with_cuda():
fluid.set_flags({"FLAGS_cudnn_deterministic": True})
with fluid.dygraph.guard(): with fluid.dygraph.guard():
out_use_state_dict = self.train_resnet( out_use_state_dict = self.train_resnet(
enable_amp=True, use_data_loader=True, use_save_load=True) enable_amp=True, use_data_loader=True, use_save_load=True)
...@@ -390,18 +391,43 @@ class TestResnet2(unittest.TestCase): ...@@ -390,18 +391,43 @@ class TestResnet2(unittest.TestCase):
Use paddle-2.0 API Use paddle-2.0 API
""" """
def train_resnet(self, enable_amp=True, use_data_loader=False): def train_resnet(self,
enable_amp=True,
use_data_loader=False,
use_param_group=False):
seed = 90 seed = 90
batch_size = train_parameters["batch_size"] batch_size = train_parameters["batch_size"]
batch_num = 1 batch_num = 10
paddle.seed(seed) paddle.seed(seed)
paddle.framework.random._manual_program_seed(seed) paddle.framework.random._manual_program_seed(seed)
resnet = ResNet(use_cudnn=True) resnet = ResNet(use_cudnn=True)
optimizer = optimizer_setting(
train_parameters, parameter_list=resnet.parameters()) if use_param_group:
conv_params = resnet.conv.parameters()
other_params = []
for p in resnet.parameters():
contains = False
for q in conv_params:
if p is q:
contains = True
if not contains:
other_params.append(p)
# NOTE(zhiqiu): The Membership test operations(in / not in) calls "is" and "equal",
# see details: https://docs.python.org/3/reference/expressions.html#membership-test-operations.
# So do not use other_params = [p for p in resnet.parameters() if p not in conv_params]
optimizer = paddle.optimizer.Momentum(parameters=[{
'params': conv_params,
'learning_rate': 0.01
}, {
'params': other_params,
'learning_rate': 0.001
}])
else:
optimizer = paddle.optimizer.SGD(parameters=resnet.parameters())
np.random.seed(seed) np.random.seed(seed)
train_reader = paddle.batch( train_reader = paddle.batch(
paddle.dataset.flowers.train(use_xmap=False), batch_size=batch_size) paddle.dataset.flowers.train(use_xmap=False), batch_size=batch_size)
...@@ -456,7 +482,7 @@ class TestResnet2(unittest.TestCase): ...@@ -456,7 +482,7 @@ class TestResnet2(unittest.TestCase):
scaled_loss = scaler.scale(avg_loss) scaled_loss = scaler.scale(avg_loss)
scaled_loss.backward() scaled_loss.backward()
scaler.minimize(optimizer, scaled_loss) scaler.step(optimizer)
dy_grad_value = {} dy_grad_value = {}
for param in resnet.parameters(): for param in resnet.parameters():
...@@ -475,22 +501,27 @@ class TestResnet2(unittest.TestCase): ...@@ -475,22 +501,27 @@ class TestResnet2(unittest.TestCase):
return dy_out, dy_param_value, dy_grad_value return dy_out, dy_param_value, dy_grad_value
def test_resnet(self): def test_resnet(self):
if fluid.core.is_compiled_with_cuda():
fluid.set_flags({"FLAGS_cudnn_deterministic": True})
with fluid.dygraph.guard(): with fluid.dygraph.guard():
out_fp32 = self.train_resnet(enable_amp=False) out_fp32 = self.train_resnet(enable_amp=False)
out_amp = self.train_resnet(enable_amp=True) out_amp = self.train_resnet(enable_amp=True)
print(out_fp32[0], out_amp[0]) print(out_fp32[0], out_amp[0])
self.assertTrue(np.allclose(out_fp32[0], out_amp[0], atol=1.e-2)) self.assertTrue(np.allclose(out_fp32[0], out_amp[0], atol=1.e-5))
def test_with_data_loader(self): def test_with_data_loader(self):
if fluid.core.is_compiled_with_cuda():
fluid.set_flags({"FLAGS_cudnn_deterministic": True})
with fluid.dygraph.guard(): with fluid.dygraph.guard():
out_fp32 = self.train_resnet(enable_amp=False, use_data_loader=True) out_fp32 = self.train_resnet(enable_amp=False, use_data_loader=True)
out_amp = self.train_resnet(enable_amp=True, use_data_loader=True) out_amp = self.train_resnet(enable_amp=True, use_data_loader=True)
print(out_fp32[0], out_amp[0]) print(out_fp32[0], out_amp[0])
self.assertTrue(np.allclose(out_fp32[0], out_amp[0], atol=1.e-2)) self.assertTrue(np.allclose(out_fp32[0], out_amp[0], atol=1.e-5))
def test_param_group(self):
    """Train resnet with a param-group optimizer, once in fp32 and once
    under AMP, and check the final losses agree within tolerance."""
    shared_kwargs = dict(use_data_loader=True, use_param_group=True)
    with fluid.dygraph.guard():
        out_fp32 = self.train_resnet(enable_amp=False, **shared_kwargs)
        out_amp = self.train_resnet(enable_amp=True, **shared_kwargs)
    print(out_fp32[0], out_amp[0])
    self.assertTrue(np.allclose(out_fp32[0], out_amp[0], atol=1.e-5))
class TestResnet(unittest.TestCase): class TestResnet(unittest.TestCase):
...@@ -566,8 +597,6 @@ class TestResnet(unittest.TestCase): ...@@ -566,8 +597,6 @@ class TestResnet(unittest.TestCase):
return dy_out, dy_param_value, dy_grad_value return dy_out, dy_param_value, dy_grad_value
def test_resnet(self): def test_resnet(self):
if fluid.core.is_compiled_with_cuda():
fluid.set_flags({"FLAGS_cudnn_deterministic": True})
out_fp32 = self.train_resnet(enable_amp=False) out_fp32 = self.train_resnet(enable_amp=False)
out_amp = self.train_resnet(enable_amp=True) out_amp = self.train_resnet(enable_amp=True)
print(out_fp32[0], out_amp[0]) print(out_fp32[0], out_amp[0])
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册