Unverified · Commit e29c2d12 · Authored by: Leo Chen · Committed by: GitHub

[amp] dygraph amp support param_group (#34899)

* dygraph amp support param_group

* remove unused code

* fix doc
Parent b0cb4148
@@ -146,6 +146,49 @@ class GradScaler(AmpScaler):
"""
return super(GradScaler, self).minimize(optimizer, *args, **kwargs)
+
+    def step(self, optimizer):
+        """
+        This function is similar to `optimizer.step()`, which performs parameter updating.
+
+        If the scaled gradients of parameters contain NaN or Inf, the parameter update is
+        skipped; otherwise, it first unscales the scaled gradients of parameters and then
+        updates the parameters.
+
+        Args:
+            optimizer(Optimizer):  The optimizer used to update parameters.
+
+        Examples:
+
+            .. code-block:: python
+
+                # required: gpu
+                import paddle
+
+                model = paddle.nn.Conv2D(3, 2, 3, bias_attr=True)
+                optimizer = paddle.optimizer.SGD(learning_rate=0.01, parameters=model.parameters())
+                scaler = paddle.amp.GradScaler(init_loss_scaling=1024)
+                data = paddle.rand([10, 3, 32, 32])
+                with paddle.amp.auto_cast():
+                    conv = model(data)
+                    loss = paddle.mean(conv)
+                scaled = scaler.scale(loss)  # scale the loss
+                scaled.backward()            # do backward
+                scaler.step(optimizer)
+                optimizer.clear_grad()
+        """
+        if not self._enable:
+            return optimizer.step()
+
+        # unscale the gradients
+        self._unscale(optimizer)
+
+        if self._found_inf:
+            self._cache_found_inf = True
+        else:
+            optimizer.step()
+            self._cache_found_inf = False
+
+        if self._use_dynamic_loss_scaling:
+            # update the loss scale
+            self._update()
+
     def is_enable(self):
         """
         Whether loss scaling is enabled.
......
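As an aside on the control flow just added: when unscaling finds an Inf/NaN, step() skips optimizer.step() and, with dynamic loss scaling on, _update() shrinks the loss scale. A minimal sketch of that skip path (GPU assumed; the oversized init_loss_scaling and input magnitudes are contrived to provoke an overflow and are not from the patch):

    import paddle

    model = paddle.nn.Linear(4, 4)
    optimizer = paddle.optimizer.SGD(learning_rate=0.01, parameters=model.parameters())
    scaler = paddle.amp.GradScaler(init_loss_scaling=2.**32)  # deliberately huge

    data = paddle.rand([8, 4]) * 1e4  # large inputs to push fp16 gradients toward overflow
    with paddle.amp.auto_cast():
        loss = paddle.mean(model(data))
    scaler.scale(loss).backward()
    scaler.step(optimizer)  # on Inf/NaN the update is skipped and the scale shrinks;
                            # otherwise this behaves like a plain optimizer.step()
    optimizer.clear_grad()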
@@ -212,10 +212,19 @@ class AmpScaler(object):
     def _unscale(self, optimizer):
         if not self._enable:
             return
-        param_grads = [
-            param._grad_ivar() for param in optimizer._parameter_list
-            if param._grad_ivar() is not None
-        ]
+        if getattr(optimizer, '_param_groups', None) and isinstance(
+                optimizer._param_groups[0], dict):
+            param_grads = []
+            for group in optimizer._param_groups:
+                for param in group['params']:
+                    if param._grad_ivar() is not None:
+                        param_grads.append(param._grad_ivar())
+        else:
+            param_grads = [
+                param._grad_ivar() for param in optimizer._parameter_list
+                if param._grad_ivar() is not None
+            ]
         _C_ops.check_finite_and_unscale(param_grads, self._scale, param_grads,
                                         self._found_inf)
......
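The new branch is what lets the scaler unscale gradients for optimizers built from parameter groups, where _param_groups holds dicts with a 'params' key. A minimal sketch of the now-supported pattern (the layers and shapes are illustrative, not from the patch):

    import paddle

    conv = paddle.nn.Conv2D(3, 2, 3)
    linear = paddle.nn.Linear(1800, 10)  # 2 * 30 * 30 features after the conv on 32x32 input
    optimizer = paddle.optimizer.Momentum(parameters=[
        {'params': conv.parameters(), 'learning_rate': 0.01},
        {'params': linear.parameters(), 'learning_rate': 0.001},
    ])
    scaler = paddle.amp.GradScaler(init_loss_scaling=1024)

    data = paddle.rand([4, 3, 32, 32])
    with paddle.amp.auto_cast():
        loss = paddle.mean(linear(paddle.flatten(conv(data), start_axis=1)))
    scaler.scale(loss).backward()
    scaler.step(optimizer)  # _unscale now walks every group's 'params'
    optimizer.clear_grad()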
@@ -19,6 +19,9 @@ import numpy as np
 import six
 from test_imperative_resnet import ResNet, BottleneckBlock, ConvBNLayer, train_parameters, optimizer_setting
 
+if fluid.core.is_compiled_with_cuda():
+    fluid.set_flags({"FLAGS_cudnn_deterministic": True})
+
 class SimpleConv(fluid.dygraph.Layer):
     def __init__(self,
@@ -373,8 +376,6 @@ class TestGradScalerStateDict(unittest.TestCase):
         return dy_out, dy_param_value, dy_grad_value
 
     def test_with_state_dict(self):
-        if fluid.core.is_compiled_with_cuda():
-            fluid.set_flags({"FLAGS_cudnn_deterministic": True})
         with fluid.dygraph.guard():
             out_use_state_dict = self.train_resnet(
                 enable_amp=True, use_data_loader=True, use_save_load=True)
@@ -390,18 +391,43 @@ class TestResnet2(unittest.TestCase):
     Use paddle-2.0 API
     """
 
-    def train_resnet(self, enable_amp=True, use_data_loader=False):
+    def train_resnet(self,
+                     enable_amp=True,
+                     use_data_loader=False,
+                     use_param_group=False):
         seed = 90
 
         batch_size = train_parameters["batch_size"]
-        batch_num = 1
+        batch_num = 10
 
         paddle.seed(seed)
         paddle.framework.random._manual_program_seed(seed)
 
         resnet = ResNet(use_cudnn=True)
-        optimizer = optimizer_setting(
-            train_parameters, parameter_list=resnet.parameters())
+        if use_param_group:
+            conv_params = resnet.conv.parameters()
+            other_params = []
+            for p in resnet.parameters():
+                contains = False
+                for q in conv_params:
+                    if p is q:
+                        contains = True
+                if not contains:
+                    other_params.append(p)
+            # NOTE(zhiqiu): The membership test operations (in / not in) call "is" and then "==",
+            # see details: https://docs.python.org/3/reference/expressions.html#membership-test-operations.
+            # So do not use: other_params = [p for p in resnet.parameters() if p not in conv_params]
+            optimizer = paddle.optimizer.Momentum(parameters=[{
+                'params': conv_params,
+                'learning_rate': 0.01
+            }, {
+                'params': other_params,
+                'learning_rate': 0.001
+            }])
+        else:
+            optimizer = paddle.optimizer.SGD(parameters=resnet.parameters())
 
         np.random.seed(seed)
         train_reader = paddle.batch(
             paddle.dataset.flowers.train(use_xmap=False), batch_size=batch_size)
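The identity loop above works around the pitfall the NOTE describes: `p not in conv_params` would fall back to Tensor `==`, which compares elementwise and does not yield a plain bool. A shorter alternative with the same identity semantics (a sketch, not part of the patch) compares object ids instead:

    conv_ids = {id(q) for q in conv_params}
    other_params = [p for p in resnet.parameters() if id(p) not in conv_ids]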
@@ -456,7 +482,7 @@ class TestResnet2(unittest.TestCase):
                 scaled_loss = scaler.scale(avg_loss)
                 scaled_loss.backward()
-                scaler.minimize(optimizer, scaled_loss)
+                scaler.step(optimizer)
 
         dy_grad_value = {}
         for param in resnet.parameters():
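The one-line change above swaps the combined minimize() call for the step() API added in the first hunk; per step()'s implementation, it still unscales, updates or skips on Inf/NaN, and then adjusts the loss scale, so the surrounding training loop is unchanged. A sketch of the before/after:

    # before: minimize() unscales gradients, steps, and updates the scale in one call
    scaler.minimize(optimizer, scaled_loss)
    # after: step() covers broadly the same sequence here, since dynamic loss scaling is on
    scaler.step(optimizer)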
@@ -475,22 +501,27 @@ class TestResnet2(unittest.TestCase):
         return dy_out, dy_param_value, dy_grad_value
 
     def test_resnet(self):
-        if fluid.core.is_compiled_with_cuda():
-            fluid.set_flags({"FLAGS_cudnn_deterministic": True})
         with fluid.dygraph.guard():
             out_fp32 = self.train_resnet(enable_amp=False)
             out_amp = self.train_resnet(enable_amp=True)
         print(out_fp32[0], out_amp[0])
-        self.assertTrue(np.allclose(out_fp32[0], out_amp[0], atol=1.e-2))
+        self.assertTrue(np.allclose(out_fp32[0], out_amp[0], atol=1.e-5))
 
     def test_with_data_loader(self):
-        if fluid.core.is_compiled_with_cuda():
-            fluid.set_flags({"FLAGS_cudnn_deterministic": True})
         with fluid.dygraph.guard():
             out_fp32 = self.train_resnet(enable_amp=False, use_data_loader=True)
             out_amp = self.train_resnet(enable_amp=True, use_data_loader=True)
         print(out_fp32[0], out_amp[0])
-        self.assertTrue(np.allclose(out_fp32[0], out_amp[0], atol=1.e-2))
+        self.assertTrue(np.allclose(out_fp32[0], out_amp[0], atol=1.e-5))
 
+    def test_param_group(self):
+        with fluid.dygraph.guard():
+            out_fp32 = self.train_resnet(
+                enable_amp=False, use_data_loader=True, use_param_group=True)
+            out_amp = self.train_resnet(
+                enable_amp=True, use_data_loader=True, use_param_group=True)
+        print(out_fp32[0], out_amp[0])
+        self.assertTrue(np.allclose(out_fp32[0], out_amp[0], atol=1.e-5))
 
 class TestResnet(unittest.TestCase):
@@ -566,8 +597,6 @@ class TestResnet(unittest.TestCase):
         return dy_out, dy_param_value, dy_grad_value
 
     def test_resnet(self):
-        if fluid.core.is_compiled_with_cuda():
-            fluid.set_flags({"FLAGS_cudnn_deterministic": True})
         out_fp32 = self.train_resnet(enable_amp=False)
         out_amp = self.train_resnet(enable_amp=True)
         print(out_fp32[0], out_amp[0])
......