diff --git a/python/paddle/amp/grad_scaler.py b/python/paddle/amp/grad_scaler.py
index 827a320b2cc9c4976e3c552526577fea167ffb55..18c436a0bb95f794944a07a974031f2b936f8c86 100644
--- a/python/paddle/amp/grad_scaler.py
+++ b/python/paddle/amp/grad_scaler.py
@@ -47,7 +47,7 @@ class GradScaler(AmpScaler):
     Examples:

         .. code-block:: python
-            
+
             import paddle

             model = paddle.nn.Conv2D(3, 2, 3, bias_attr=True)
@@ -91,7 +91,7 @@ class GradScaler(AmpScaler):
         Examples:

             .. code-block:: python
-                
+
                 import paddle

                 model = paddle.nn.Conv2D(3, 2, 3, bias_attr=True)
@@ -156,6 +156,7 @@ class GradScaler(AmpScaler):
         Examples:
             .. code-block:: python

+                # required: gpu,xpu
                 import paddle
                 scaler = paddle.amp.GradScaler(enable=True,
                                                init_loss_scaling=1024,
@@ -178,7 +179,8 @@ class GradScaler(AmpScaler):
         Examples:
             .. code-block:: python
-                
+
+                # required: gpu,xpu
                 import paddle

                 scaler = paddle.amp.GradScaler(enable=True,
                                                init_loss_scaling=1024,
@@ -202,6 +204,7 @@ class GradScaler(AmpScaler):
         Examples:
             .. code-block:: python

+                # required: gpu,xpu
                 import paddle
                 scaler = paddle.amp.GradScaler(enable=True,
                                                init_loss_scaling=1024,
@@ -220,11 +223,12 @@ class GradScaler(AmpScaler):
         Set the initial loss scaling factor by `new_init_loss_scaling`.

         Args:
-            new_init_loss_scaling(int): The new_init_loss_scaling used to update initial loss scaling factor.
+            new_init_loss_scaling(float): The new_init_loss_scaling used to update initial loss scaling factor.

         Examples:
             .. code-block:: python
-                
+
+                # required: gpu,xpu
                 import paddle
                 scaler = paddle.amp.GradScaler(enable=True,
                                                init_loss_scaling=1024,
@@ -250,6 +254,7 @@ class GradScaler(AmpScaler):
         Examples:
             .. code-block:: python

+                # required: gpu,xpu
                 import paddle
                 scaler = paddle.amp.GradScaler(enable=True,
                                                init_loss_scaling=1024,
@@ -273,6 +278,7 @@ class GradScaler(AmpScaler):
         Examples:
             .. code-block:: python

+                # required: gpu,xpu
                 import paddle
                 scaler = paddle.amp.GradScaler(enable=True,
                                                init_loss_scaling=1024,
@@ -298,6 +304,7 @@ class GradScaler(AmpScaler):
         Examples:
             .. code-block:: python

+                # required: gpu,xpu
                 import paddle
                 scaler = paddle.amp.GradScaler(enable=True,
                                                init_loss_scaling=1024,
@@ -321,6 +328,7 @@ class GradScaler(AmpScaler):
         Examples:
             .. code-block:: python

+                # required: gpu,xpu
                 import paddle
                 scaler = paddle.amp.GradScaler(enable=True,
                                                init_loss_scaling=1024,
@@ -346,6 +354,7 @@ class GradScaler(AmpScaler):
         Examples:
             .. code-block:: python

+                # required: gpu,xpu
                 import paddle
                 scaler = paddle.amp.GradScaler(enable=True,
                                                init_loss_scaling=1024,
@@ -369,6 +378,7 @@ class GradScaler(AmpScaler):
         Examples:
             .. code-block:: python

+                # required: gpu,xpu
                 import paddle
                 scaler = paddle.amp.GradScaler(enable=True,
                                                init_loss_scaling=1024,
@@ -394,6 +404,7 @@ class GradScaler(AmpScaler):
         Examples:
             .. code-block:: python

+                # required: gpu,xpu
                 import paddle
                 scaler = paddle.amp.GradScaler(enable=True,
                                                init_loss_scaling=1024,
@@ -417,6 +428,7 @@ class GradScaler(AmpScaler):
         Examples:
             .. code-block:: python

+                # required: gpu,xpu
                 import paddle
                 scaler = paddle.amp.GradScaler(enable=True,
                                                init_loss_scaling=1024,
@@ -432,3 +444,59 @@ class GradScaler(AmpScaler):
         """
         super(GradScaler, self).set_decr_every_n_nan_or_inf(
             new_decr_every_n_nan_or_inf)
+
+    def state_dict(self):
+        """
+        Returns the state of the scaler as a `dict`. If this instance is not enabled, an empty dict is returned.
+
+        Returns:
+            A dict of the scaler state, including:
+
+            init_loss_scaling (float): The initial loss scaling factor.
+            incr_ratio(float): The multiplier to use when increasing the loss scaling.
+            decr_ratio(float): The less-than-one multiplier to use when decreasing the loss scaling.
+            incr_every_n_steps(int): Increases loss scaling every n consecutive steps with finite gradients.
+            decr_every_n_nan_or_inf(int): Decreases loss scaling every n accumulated steps with nan or inf gradients.
+
+        Examples:
+
+            .. code-block:: python
+
+                # required: gpu,xpu
+                import paddle
+
+                scaler = paddle.amp.GradScaler(enable=True,
+                                               init_loss_scaling=1024,
+                                               incr_ratio=2.0,
+                                               decr_ratio=0.5,
+                                               incr_every_n_steps=1000,
+                                               decr_every_n_nan_or_inf=2,
+                                               use_dynamic_loss_scaling=True)
+                scaler_state = scaler.state_dict()
+        """
+        return super(GradScaler, self).state_dict()
+
+    def load_state_dict(self, state_dict):
+        """
+        Loads the scaler state.
+
+        Args:
+            state_dict(dict): scaler state. Should be an object returned from a call to `GradScaler.state_dict()`.
+
+        Examples:
+
+            .. code-block:: python
+
+                # required: gpu,xpu
+                import paddle
+
+                scaler = paddle.amp.GradScaler(enable=True,
+                                               init_loss_scaling=1024,
+                                               incr_ratio=2.0,
+                                               decr_ratio=0.5,
+                                               incr_every_n_steps=1000,
+                                               decr_every_n_nan_or_inf=2,
+                                               use_dynamic_loss_scaling=True)
+                scaler_state = scaler.state_dict()
+                scaler.load_state_dict(scaler_state)
+        """
+        super(GradScaler, self).load_state_dict(state_dict)
diff --git a/python/paddle/fluid/dygraph/amp/loss_scaler.py b/python/paddle/fluid/dygraph/amp/loss_scaler.py
index 96ee4514ac2b936c77083b136f34e0ce2647782b..2065bec8af3bc4fdc3d4bbbc558433b1448e8757 100644
--- a/python/paddle/fluid/dygraph/amp/loss_scaler.py
+++ b/python/paddle/fluid/dygraph/amp/loss_scaler.py
@@ -357,3 +357,55 @@ class AmpScaler(object):
             new_decr_every_n_nan_or_inf(int): The new_decr_every_n_nan_or_inf used to update the num `n`, `n` represent decreases loss scaling every `n` accumulated steps with nan or inf gradients.
         """
         self._decr_every_n_nan_or_inf = new_decr_every_n_nan_or_inf
+
+    def state_dict(self):
+        """
+        Returns the state of the scaler as a `dict`. If this instance is not enabled, an empty dict is returned.
+
+        Returns:
+            A dict of the scaler state, including:
+            scale (tensor): The loss scaling factor.
+            incr_ratio(float): The multiplier to use when increasing the loss scaling.
+            decr_ratio(float): The less-than-one multiplier to use when decreasing the loss scaling.
+            incr_every_n_steps(int): Increases loss scaling every n consecutive steps with finite gradients.
+            decr_every_n_nan_or_inf(int): Decreases loss scaling every n accumulated steps with nan or inf gradients.
+            incr_count(int): The number of recent consecutive unskipped steps.
+            decr_count(int): The number of recent consecutive skipped steps.
+            use_dynamic_loss_scaling(bool): Whether to use dynamic loss scaling. If False, fixed loss_scaling is used. If True, the loss scaling is updated dynamically. Default is True.
+        """
+        return {
+            "scale": self._scale.numpy(),
+            "incr_ratio": self._incr_ratio,
+            "decr_ratio": self._decr_ratio,
+            "incr_every_n_steps": self._incr_every_n_steps,
+            "decr_every_n_nan_or_inf": self._decr_every_n_nan_or_inf,
+            "incr_count": self._incr_count,
+            "decr_count": self._decr_count,
+            "use_dynamic_loss_scaling": self._use_dynamic_loss_scaling
+        } if self._enable else {}
+
+    def load_state_dict(self, state_dict):
+        """
+        Loads the scaler state.
+
+        Args:
+            state_dict(dict): scaler state. Should be an object returned from a call to `AmpScaler.state_dict()`.
+ """ + if not self._enable: + return + + if len(state_dict) == 0: + raise RuntimeError( + "The input state dict is empty, possibly because it was saved " + "from a disabled instance of GradScaler.") + + self._init_loss_scaling = state_dict["scale"][0] + self._scale = to_variable( + np.array([self._init_loss_scaling]).astype(np.float32)) + self._incr_ratio = state_dict["incr_ratio"] + self._decr_ratio = state_dict["decr_ratio"] + self._incr_every_n_steps = state_dict["incr_every_n_steps"] + self._decr_every_n_nan_or_inf = state_dict["decr_every_n_nan_or_inf"] + self._incr_count = state_dict["incr_count"] + self._decr_count = state_dict["decr_count"] + self._use_dynamic_loss_scaling = state_dict["use_dynamic_loss_scaling"] diff --git a/python/paddle/fluid/tests/unittests/test_imperative_auto_mixed_precision.py b/python/paddle/fluid/tests/unittests/test_imperative_auto_mixed_precision.py index e3d2bda89212874fdc97bc05cbc0addb008cd924..17d50ed8c19de0e20afd9c15fbcc08d7cbdf1fe9 100644 --- a/python/paddle/fluid/tests/unittests/test_imperative_auto_mixed_precision.py +++ b/python/paddle/fluid/tests/unittests/test_imperative_auto_mixed_precision.py @@ -237,6 +237,37 @@ class TestAmpScaler(unittest.TestCase): scaler.set_init_loss_scaling(100) self.assertEqual(scaler.get_init_loss_scaling() == 100, True) + def test_state_dict_and_load_state_dict(self): + with fluid.dygraph.guard(): + scaler1 = paddle.amp.GradScaler( + enable=True, + init_loss_scaling=14, + incr_ratio=233.0, + decr_ratio=0.523, + incr_every_n_steps=1090, + decr_every_n_nan_or_inf=20, + use_dynamic_loss_scaling=True) + scaler_state = scaler1.state_dict() + scaler2 = paddle.amp.GradScaler(enable=True) + scaler2.load_state_dict(scaler_state) + self.assertEqual(scaler2.get_init_loss_scaling() == 14, True) + self.assertEqual(scaler2.get_incr_ratio() == 233.0, True) + self.assertEqual(scaler2.get_decr_ratio() == 0.523, True) + self.assertEqual(scaler2.get_incr_every_n_steps() == 1090, True) + self.assertEqual(scaler2.get_decr_every_n_nan_or_inf() == 20, True) + + scaler3 = paddle.amp.GradScaler(enable=False) + scaler3.load_state_dict(scaler_state) + self.assertEqual(scaler3.is_enable() == False, True) + + def test_state_dict_and_load_state_dict_error(self): + def test_error(): + state_empty = {} + scaler = paddle.amp.GradScaler(enable=True) + scaler.load_state_dict(state_empty) + + self.assertRaises(RuntimeError, test_error) + def reader_decorator(reader): def __reader__(): @@ -248,6 +279,112 @@ def reader_decorator(reader): return __reader__ +class TestGradScalerStateDict(unittest.TestCase): + def train_resnet(self, + enable_amp=True, + use_data_loader=True, + use_save_load=True): + seed = 90 + + batch_size = train_parameters["batch_size"] + batch_num = 4 + + paddle.seed(seed) + paddle.framework.random._manual_program_seed(seed) + + resnet = ResNet(use_cudnn=True) + optimizer = optimizer_setting( + train_parameters, parameter_list=resnet.parameters()) + np.random.seed(seed) + train_reader = paddle.batch( + paddle.dataset.flowers.train(use_xmap=False), batch_size=batch_size) + + dy_param_init_value = {} + for param in resnet.parameters(): + dy_param_init_value[param.name] = param.numpy() + + program = None + scaler = paddle.amp.GradScaler( + enable=enable_amp, init_loss_scaling=2.**10) + + if use_data_loader: + train_reader = paddle.batch( + reader_decorator(paddle.dataset.flowers.train(use_xmap=False)), + batch_size=batch_size, + drop_last=True) + train_loader = fluid.io.DataLoader.from_generator( + capacity=4, + use_double_buffer=True, + 
iterable=True, + return_list=True) + train_loader.set_sample_list_generator(train_reader) + train_reader = train_loader + + for batch_id, data in enumerate(train_reader()): + if batch_id >= batch_num: + break + if use_data_loader: + img, label = data + else: + dy_x_data = np.array([x[0].reshape(3, 224, 224) + for x in data]).astype('float32') + if len(np.array([x[1] + for x in data]).astype('int64')) != batch_size: + continue + y_data = np.array( + [x[1] for x in data]).astype('int64').reshape(-1, 1) + + img = paddle.to_tensor(dy_x_data) + label = paddle.to_tensor(y_data) + label.stop_gradient = True + + with paddle.amp.auto_cast(enable=enable_amp): + out = resnet(img) + + loss = paddle.nn.functional.cross_entropy(input=out, label=label) + avg_loss = paddle.mean(x=loss) + + dy_out = avg_loss.numpy() + + scaled_loss = scaler.scale(avg_loss) + scaled_loss.backward() + + scaler.minimize(optimizer, scaled_loss) + + dy_grad_value = {} + for param in resnet.parameters(): + if param.trainable: + np_array = np.array(param._grad_ivar().value().get_tensor()) + dy_grad_value[param.name + fluid.core.grad_var_suffix( + )] = np_array + + resnet.clear_gradients() + + dy_param_value = {} + for param in resnet.parameters(): + dy_param_value[param.name] = param.numpy() + + if use_save_load and batch_id == 2: + paddle.save(scaler.state_dict(), 'ResNet_model.pdparams') + dict_load = paddle.load('ResNet_model.pdparams') + scaler.load_state_dict(dict_load) + if use_data_loader: + train_reader._reset() + return dy_out, dy_param_value, dy_grad_value + + def test_with_state_dict(self): + if fluid.core.is_compiled_with_cuda(): + fluid.set_flags({"FLAGS_cudnn_deterministic": True}) + with fluid.dygraph.guard(): + out_use_state_dict = self.train_resnet( + enable_amp=True, use_data_loader=True, use_save_load=True) + out_no_state_dict = self.train_resnet( + enable_amp=True, use_data_loader=True, use_save_load=False) + print('save_load:', out_use_state_dict[0], out_no_state_dict[0]) + self.assertTrue( + np.allclose(out_use_state_dict[0], out_no_state_dict[0])) + + class TestResnet2(unittest.TestCase): """ Use paddle-2.0 API @@ -338,6 +475,8 @@ class TestResnet2(unittest.TestCase): return dy_out, dy_param_value, dy_grad_value def test_resnet(self): + if fluid.core.is_compiled_with_cuda(): + fluid.set_flags({"FLAGS_cudnn_deterministic": True}) with fluid.dygraph.guard(): out_fp32 = self.train_resnet(enable_amp=False) out_amp = self.train_resnet(enable_amp=True) @@ -345,6 +484,8 @@ class TestResnet2(unittest.TestCase): self.assertTrue(np.allclose(out_fp32[0], out_amp[0], atol=1.e-2)) def test_with_data_loader(self): + if fluid.core.is_compiled_with_cuda(): + fluid.set_flags({"FLAGS_cudnn_deterministic": True}) with fluid.dygraph.guard(): out_fp32 = self.train_resnet(enable_amp=False, use_data_loader=True) out_amp = self.train_resnet(enable_amp=True, use_data_loader=True) @@ -425,6 +566,8 @@ class TestResnet(unittest.TestCase): return dy_out, dy_param_value, dy_grad_value def test_resnet(self): + if fluid.core.is_compiled_with_cuda(): + fluid.set_flags({"FLAGS_cudnn_deterministic": True}) out_fp32 = self.train_resnet(enable_amp=False) out_amp = self.train_resnet(enable_amp=True) print(out_fp32[0], out_amp[0])
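
Usage sketch (not part of the patch above): the snippet below shows one way the new `state_dict()`/`load_state_dict()` API could be combined with `paddle.save`/`paddle.load` to checkpoint the loss-scaler state alongside a model, mirroring what the new unit test does. The `paddle.nn.Linear` model, the SGD optimizer, and the file name `scaler.pdparams` are illustrative assumptions only, and a CUDA (or XPU) build of Paddle is assumed.

.. code-block:: python

    # Minimal checkpointing sketch for GradScaler state (assumes a GPU/XPU build;
    # the toy model, optimizer and file name are illustrative, not from the patch).
    import paddle

    model = paddle.nn.Linear(10, 10)
    optimizer = paddle.optimizer.SGD(learning_rate=0.01,
                                     parameters=model.parameters())
    scaler = paddle.amp.GradScaler(init_loss_scaling=1024)

    data = paddle.rand([4, 10])
    with paddle.amp.auto_cast():
        loss = model(data).mean()
    scaled = scaler.scale(loss)          # scale the loss before backward
    scaled.backward()
    scaler.minimize(optimizer, scaled)   # unscale gradients and apply the update
    optimizer.clear_grad()

    # Persist the scaler state alongside other checkpoint files ...
    paddle.save(scaler.state_dict(), 'scaler.pdparams')

    # ... and restore it later into a freshly constructed scaler.
    new_scaler = paddle.amp.GradScaler(enable=True)
    new_scaler.load_state_dict(paddle.load('scaler.pdparams'))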