diff --git a/python/paddle/hapi/callbacks.py b/python/paddle/hapi/callbacks.py
index ac95fea151ed01e06369511d5f8cba684004bb41..cd4b35ea29a83b1962ee93cdad152f82d30b362c 100644
--- a/python/paddle/hapi/callbacks.py
+++ b/python/paddle/hapi/callbacks.py
@@ -364,7 +364,7 @@ class ProgBarLogger(Callback):
         }
         if self._is_print():
             print(
-                "The loss value printed in the log is the current step, and the metric is the average value of previous step."
+                "The loss value printed in the log is the loss of the current step, and the metric is the average value of previous steps."
             )
 
     def on_epoch_begin(self, epoch=None, logs=None):
diff --git a/python/paddle/hapi/model.py b/python/paddle/hapi/model.py
index 137ca186d7946a426b263b6b902e101be4744135..4f3d73b22e39027896885b1835447190c7e1b655 100644
--- a/python/paddle/hapi/model.py
+++ b/python/paddle/hapi/model.py
@@ -41,8 +41,6 @@ from paddle.fluid.dygraph.dygraph_to_static.program_translator import ProgramTra
 from paddle.fluid.dygraph.io import INFER_MODEL_SUFFIX, INFER_PARAMS_SUFFIX
 from paddle.fluid.layers.utils import flatten
 from paddle.fluid.layers import collective
-from paddle.fluid.incubate.fleet.collective import fleet, DistributedStrategy
-from paddle.fluid.incubate.fleet.base import role_maker
 from paddle.io import DataLoader, Dataset, DistributedBatchSampler
 
 from paddle.fluid.executor import scope_guard, Executor
@@ -50,6 +48,8 @@ from paddle.fluid.dygraph.layers import Layer
 from paddle.metric import Metric
 from paddle.static import InputSpec as Input
 import paddle.distributed as dist
+import paddle.distributed.fleet as fleet
+from paddle.distributed.fleet.base import role_maker
 
 from .callbacks import config_callbacks, EarlyStopping
 from .model_summary import summary
@@ -252,6 +252,11 @@ class StaticGraphAdapter(object):
         self._nranks = ParallelEnv().nranks
         self._local_rank = ParallelEnv().local_rank
 
+        self._amp_level = "O0"
+        self._amp_configs = {}
+        self._amp_custom_lists = {}
+        self._use_fp16_guard = True
+
     @property
     def mode(self):
         return self.model.mode
@@ -550,11 +555,26 @@ class StaticGraphAdapter(object):
             if self._nranks > 1:
                 role = role_maker.PaddleCloudRoleMaker(is_collective=True)
                 fleet.init(role)
-                dist_strategy = DistributedStrategy()
-                dist_strategy.mode = "collective"
-                dist_strategy.collective_mode = "grad_allreduce"
+                dist_strategy = fleet.DistributedStrategy()
+                if self._amp_level != 'O0':
+                    dist_strategy.amp = True
+                    dist_strategy.amp_configs = self._amp_configs.copy()
+                    dist_strategy.amp_configs.update(self._amp_custom_lists)
+                    dist_strategy.amp_configs[
+                        'use_pure_fp16'] = self._amp_level == 'O2'
                 self.model._optimizer = fleet.distributed_optimizer(
                     self.model._optimizer, strategy=dist_strategy)
+            elif self._amp_level != "O0" and core.is_compiled_with_cuda():
+                amp_lists = paddle.static.amp.AutoMixedPrecisionLists(
+                    **self.
+                    _amp_custom_lists) if self._amp_custom_lists else None
+
+                self.model._optimizer = paddle.static.amp.decorate(
+                    self.model._optimizer,
+                    amp_lists=amp_lists,
+                    use_pure_fp16=self._amp_level == "O2",
+                    use_fp16_guard=self._use_fp16_guard,
+                    **self._amp_configs)
 
             self.model._optimizer.minimize(self._loss_endpoint)
 
@@ -598,6 +618,10 @@ class StaticGraphAdapter(object):
             startup_prog = self._startup_prog._prune(uninitialized)
             self._executor.run(startup_prog)
 
+        if self._amp_level == "O2" and mode == 'train' and core.is_compiled_with_cuda(
+        ):
+            self.model._optimizer.amp_init(place)
+
         if self._nranks < 2:
             compiled_prog = fluid.CompiledProgram(prog)
         else:
@@ -620,6 +644,11 @@ class DynamicGraphAdapter(object):
         }
         self._input_info = None
 
+        self._amp_level = "O0"
+        self._amp_configs = {}
+        self._amp_custom_lists = {}
+        self._use_fp16_guard = True
+
         if self._nranks > 1:
             dist.init_parallel_env()
             stradegy = fluid.dygraph.parallel.ParallelStrategy()
@@ -649,19 +678,30 @@ class DynamicGraphAdapter(object):
         labels = labels or []
         labels = [to_variable(l) for l in to_list(labels)]
 
-        if self._nranks > 1:
-            outputs = self.ddp_model.forward(* [to_variable(x) for x in inputs])
-        else:
-            outputs = self.model.network.forward(
-                * [to_variable(x) for x in inputs])
+        if self._amp_level != "O0":
+            scaler = paddle.amp.GradScaler(**self._amp_configs)
+        with paddle.amp.auto_cast(
+                enable=self._amp_level != 'O0', **self._amp_custom_lists):
+            if self._nranks > 1:
+                outputs = self.ddp_model.forward(
+                    * [to_variable(x) for x in inputs])
+            else:
+                outputs = self.model.network.forward(
+                    * [to_variable(x) for x in inputs])
 
-        losses = self.model._loss(*(to_list(outputs) + labels))
-        losses = to_list(losses)
-        final_loss = fluid.layers.sum(losses)
-        final_loss.backward()
+            losses = self.model._loss(*(to_list(outputs) + labels))
+            losses = to_list(losses)
+            final_loss = fluid.layers.sum(losses)
 
-        self.model._optimizer.minimize(final_loss)
-        self.model.network.clear_gradients()
+        if self._amp_level != "O0":
+            scaled = scaler.scale(final_loss)
+            scaled.backward()
+            scaler.minimize(self.model._optimizer, scaled)
+            self.model.network.clear_gradients()
+        else:
+            final_loss.backward()
+            self.model._optimizer.minimize(final_loss)
+            self.model.network.clear_gradients()
 
         metrics = []
         for metric in self.model._metrics:
@@ -816,6 +856,16 @@ class Model(object):
     instantiating a Model. The input description, i.e, paddle.static.InputSpec,
     must be required for static graph.
 
+    When training on GPU, auto mixed precision (AMP) training is supported, and
+    pure float16 training is also supported in static mode while using Adam,
+    AdamW and Momentum optimizers. Before using pure float16 training,
+    `multi_precision` could be set to True when creating the optimizer, which
+    helps avoid poor accuracy or slow convergence, and inputs of dtype float
+    should be cast to float16 by users. Users should also use the
+    `paddle.static.amp.fp16_guard` API to limit the range of pure float16
+    training; otherwise, 'use_fp16_guard' should be set to False. Note that
+    limiting the range with `fp16_guard` is not supported in AMP ('O1') training.
+
     Args:
         network (paddle.nn.Layer): The network is an instance of
             paddle.nn.Layer.
@@ -830,6 +880,8 @@ class Model(object):
     Examples:
 
+        1. A common example
+
         .. code-block:: python

           import paddle
           import paddle.nn as nn
@@ -838,7 +890,7 @@ class Model(object):
           import paddle.vision.transforms as T
           from paddle.static import InputSpec

           device = paddle.set_device('cpu') # or 'gpu'
-
+
           net = nn.Sequential(
               nn.Flatten(1),
               nn.Linear(784, 200),
               nn.Tanh(),
               nn.Linear(200, 10))
@@ -852,6 +904,7 @@ class Model(object):
           model = paddle.Model(net, input, label)
           optim = paddle.optimizer.SGD(learning_rate=1e-3,
               parameters=model.parameters())
+
           model.prepare(optim,
                         paddle.nn.CrossEntropyLoss(),
                         paddle.metric.Accuracy())
@@ -862,6 +915,43 @@ class Model(object):
           ])
           data = paddle.vision.datasets.MNIST(mode='train', transform=transform)
           model.fit(data, epochs=2, batch_size=32, verbose=1)
+
+
+        2. An example using mixed precision training.
+
+        .. code-block:: python
+
+          import paddle
+          import paddle.nn as nn
+          import paddle.vision.transforms as T
+
+          def run_example_code():
+              device = paddle.set_device('gpu')
+
+              net = nn.Sequential(nn.Flatten(1), nn.Linear(784, 200), nn.Tanh(),
+                                  nn.Linear(200, 10))
+
+              model = paddle.Model(net)
+              optim = paddle.optimizer.SGD(learning_rate=1e-3, parameters=model.parameters())
+
+              amp_configs = {
+                  "level": "O1",
+                  "custom_white_list": {'conv2d'},
+                  "use_dynamic_loss_scaling": True
+              }
+              model.prepare(optim,
+                            paddle.nn.CrossEntropyLoss(),
+                            paddle.metric.Accuracy(),
+                            amp_configs=amp_configs)
+
+              transform = T.Compose([T.Transpose(), T.Normalize([127.5], [127.5])])
+              data = paddle.vision.datasets.MNIST(mode='train', transform=transform)
+              model.fit(data, epochs=2, batch_size=32, verbose=1)
+
+          # mixed precision training is only supported on GPU now.
+          if paddle.is_compiled_with_cuda():
+              run_example_code()
+
     """

     def __init__(self, network, inputs=None, labels=None):
@@ -1241,7 +1331,94 @@ class Model(object):
         """
         return self._adapter.parameters()
 
-    def prepare(self, optimizer=None, loss=None, metrics=None):
+    def _prepare_amp(self, amp_configs):
+        def _check_pure_fp16_configs():
+            # pure float16 training has some restrictions now
+            if self._adapter._amp_level == "O2":
+                if in_dygraph_mode():
+                    warnings.warn("Pure float16 training is not supported in dygraph mode now, "\
+                        "and it will be supported in future version.")
+                else:
+                    # grad clip is not supported in pure fp16 training now
+                    assert self._optimizer._grad_clip is None, \
+                        "Grad clip is not supported in pure float16 training now, and it will be supported in future version."
+
+        self._adapter._amp_custom_lists = {}
+        self._adapter._amp_configs = {}
+
+        # check and get level of mixed precision training
+        if not amp_configs:
+            self._adapter._amp_level = 'O0'
+            return
+        elif isinstance(amp_configs, str):
+            if amp_configs not in ('O0', 'O1', 'O2'):
+                raise ValueError(
+                    "The level of amp_configs should be 'O0', 'O1' or 'O2'.")
+            self._adapter._amp_level = amp_configs
+            _check_pure_fp16_configs()
+            return
+        else:
+            if 'level' not in amp_configs:
+                self._adapter._amp_level = 'O1'
+            elif amp_configs['level'] not in ('O0', 'O1', 'O2'):
+                raise ValueError(
+                    "amp_configs['level'] should be 'O0', 'O1' or 'O2'.")
+            else:
+                self._adapter._amp_level = amp_configs['level']
+        amp_config_key_set = set(amp_configs.keys()) - {'level'}
+        if not amp_config_key_set or self._adapter._amp_level == 'O0':
+            return
+
+        if 'use_pure_fp16' in amp_configs:
+            raise ValueError(
+                "'use_pure_fp16' is an invalid parameter, "
+                "the level of mixed precision training only depends on 'O1' or 'O2'."
+            )
+
+        _check_pure_fp16_configs()
+
+        # construct amp_custom_lists
+        if self._adapter._amp_level != 'O0' and amp_config_key_set:
+            for param_name in [
+                    'custom_white_list', 'custom_black_list',
+                    'custom_black_varnames'
+            ]:
+                if param_name in amp_config_key_set:
+                    self._adapter._amp_custom_lists[param_name] = amp_configs[
+                        param_name]
+                    amp_config_key_set -= {param_name}
+
+        def _check_amp_configs(amp_config_key_set):
+            accepted_param_set = {
+                'init_loss_scaling',
+                'incr_ratio',
+                'decr_ratio',
+                'incr_every_n_steps',
+                'decr_every_n_nan_or_inf',
+                'use_dynamic_loss_scaling',
+                'use_fp16_guard',
+            }
+            if amp_config_key_set - accepted_param_set:
+                raise ValueError(
+                    "Except for 'level', the keys of 'amp_configs' must be accepted by mixed precision APIs, "
+                    "but {} could not be recognized.".format(
+                        tuple(amp_config_key_set - accepted_param_set)))
+
+            if 'use_fp16_guard' in amp_config_key_set:
+                if in_dygraph_mode():
+                    raise ValueError(
+                        "'use_fp16_guard' is supported in static mode only.")
+                self._adapter._use_fp16_guard = amp_configs['use_fp16_guard']
+                amp_config_key_set.remove('use_fp16_guard')
+
+            return amp_config_key_set
+
+        amp_configs_set = _check_amp_configs(amp_config_key_set)
+        for key in amp_configs_set:
+            self._adapter._amp_configs[key] = amp_configs[key]
+
+    def prepare(self, optimizer=None, loss=None, metrics=None,
+                amp_configs=None):
         """
         Configures the model before running.
 
@@ -1255,7 +1432,23 @@ class Model(object):
                 It can be None when there is no loss.
             metrics (Metric|list of Metric|None): If metrics is set, all
                 metrics will be calculated and output in train/eval mode.
-
+            amp_configs (str|dict|None): AMP configurations. If AMP or pure
+                float16 training is used, the key 'level' of 'amp_configs'
+                should be set to 'O1' or 'O2' respectively. Otherwise, the
+                value of 'level' defaults to 'O0', which means float32
+                training. In addition to 'level', users could pass in more
+                parameters consistent with mixed precision API. The supported
+                keys are: 'init_loss_scaling', 'incr_ratio', 'decr_ratio',
+                'incr_every_n_steps', 'decr_every_n_nan_or_inf',
+                'use_dynamic_loss_scaling', 'custom_white_list',
+                'custom_black_list', and 'custom_black_varnames'. In addition,
+                'use_fp16_guard' is supported in static mode only. Users could
+                refer to the mixed precision API documentations
+                :ref:`api_paddle_amp_auto_cast` and
+                :ref:`api_paddle_amp_GradScaler` for details. For convenience,
+                'amp_configs' could be set to 'O1' or 'O2' if no more
+                parameters are needed. 'amp_configs' could be None in float32
+                training. Default: None.
         Returns:
             None
         """
@@ -1292,6 +1485,7 @@ class Model(object):
                 "{} is not sub class of Metric".format(
                     metric.__class__.__name__)
         self._metrics = to_list(metrics)
+        self._prepare_amp(amp_configs)
 
         if not in_dygraph_mode():
             self._adapter.prepare()
diff --git a/python/paddle/tests/dist_hapi_pure_fp16_static.py b/python/paddle/tests/dist_hapi_pure_fp16_static.py
new file mode 100644
index 0000000000000000000000000000000000000000..0174e4f54e3416c174768051c8191304b01d2f2d
--- /dev/null
+++ b/python/paddle/tests/dist_hapi_pure_fp16_static.py
@@ -0,0 +1,60 @@
+# copyright (c) 2021 paddlepaddle authors. all rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import division
+from __future__ import print_function
+
+import unittest
+
+import numpy as np
+
+import paddle
+from paddle import fluid
+
+from paddle import Model
+from paddle.static import InputSpec
+from paddle.nn.layer.loss import CrossEntropyLoss
+from paddle.vision.models import LeNet
+
+
+@unittest.skipIf(not fluid.is_compiled_with_cuda(),
+                 'CPU testing is not supported')
+class TestDistTraningWithPureFP16(unittest.TestCase):
+    def test_amp_training_purefp16(self):
+        if not fluid.is_compiled_with_cuda():
+            self.skipTest('module not tested when ONLY_CPU compiling')
+        data = np.random.random(size=(4, 1, 28, 28)).astype(np.float32)
+        label = np.random.randint(0, 10, size=(4, 1)).astype(np.int64)
+
+        paddle.enable_static()
+        paddle.set_device('gpu')
+        net = LeNet()
+        amp_level = "O2"
+        inputs = InputSpec([None, 1, 28, 28], "float32", 'x')
+        labels = InputSpec([None, 1], "int64", "y")
+        model = Model(net, inputs, labels)
+        optim = paddle.optimizer.Adam(
+            learning_rate=0.001,
+            parameters=model.parameters(),
+            multi_precision=True)
+        amp_configs = {"level": amp_level, "use_fp16_guard": False}
+        model.prepare(
+            optimizer=optim,
+            loss=CrossEntropyLoss(reduction="sum"),
+            amp_configs=amp_configs)
+        model.train_batch([data], [label])
+
+
+if __name__ == '__main__':
+    unittest.main()
diff --git a/python/paddle/tests/test_dist_hapi_model.py b/python/paddle/tests/test_dist_hapi_model.py
index db5b63c5ae0e29fa6f1274befd277c4e46c3a1b1..16788e4656192e43f17e09464d1d53ab6dda3ce7 100644
--- a/python/paddle/tests/test_dist_hapi_model.py
+++ b/python/paddle/tests/test_dist_hapi_model.py
@@ -129,6 +129,9 @@ class TestMultipleGpus(unittest.TestCase):
     def test_hapi_multiple_gpus_dynamic(self):
         self.run_mnist_2gpu('dist_hapi_mnist_dynamic.py')
 
+    def test_hapi_amp_static(self):
+        self.run_mnist_2gpu('dist_hapi_pure_fp16_static.py')
+
 
 if __name__ == "__main__":
     unittest.main()
diff --git a/python/paddle/tests/test_hapi_amp.py b/python/paddle/tests/test_hapi_amp.py
new file mode 100644
index 0000000000000000000000000000000000000000..ecab4db7516d752c93c5ed74f28ba17232bec115
--- /dev/null
+++ b/python/paddle/tests/test_hapi_amp.py
@@ -0,0 +1,115 @@
+# copyright (c) 2020 paddlepaddle authors. all rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import division
+from __future__ import print_function
+
+import unittest
+
+import numpy as np
+
+import paddle
+from paddle import fluid
+
+from paddle import Model
+from paddle.static import InputSpec
+from paddle.nn.layer.loss import CrossEntropyLoss
+from paddle.vision.models import LeNet
+
+
+@unittest.skipIf(not fluid.is_compiled_with_cuda(),
+                 'CPU testing is not supported')
+class TestDistTraningUsingAMP(unittest.TestCase):
+    def test_amp_training(self):
+        if not fluid.is_compiled_with_cuda():
+            self.skipTest('module not tested when ONLY_CPU compiling')
+        data = np.random.random(size=(4, 1, 28, 28)).astype(np.float32)
+        label = np.random.randint(0, 10, size=(4, 1)).astype(np.int64)
+        amp_level = "O1"
+        for dynamic in [True, False]:
+            if not fluid.is_compiled_with_cuda():
+                self.skipTest('module not tested when ONLY_CPU compiling')
+            paddle.enable_static() if not dynamic else None
+            paddle.set_device('gpu')
+            net = LeNet()
+            inputs = InputSpec([None, 1, 28, 28], "float32", 'x')
+            labels = InputSpec([None, 1], "int64", "y")
+            model = Model(net, inputs, labels)
+            optim = paddle.optimizer.Adam(
+                learning_rate=0.001, parameters=model.parameters())
+            amp_configs = {"level": amp_level}
+            model.prepare(
+                optimizer=optim,
+                loss=CrossEntropyLoss(reduction="sum"),
+                amp_configs=amp_configs)
+            model.train_batch([data], [label])
+
+    def test_dynamic_check_input(self):
+        paddle.disable_static()
+        amp_configs_list = [
+            {
+                "level": "O3"
+            },
+            {
+                "level": "O1",
+                "test": 0
+            },
+            {
+                "level": "O1",
+                "use_fp16_guard": True
+            },
+            "O3",
+        ]
+        if not fluid.is_compiled_with_cuda():
+            self.skipTest('module not tested when ONLY_CPU compiling')
+        paddle.set_device('gpu')
+        net = LeNet()
+        model = Model(net)
+        optim = paddle.optimizer.Adam(
+            learning_rate=0.001, parameters=model.parameters())
+        loss = CrossEntropyLoss(reduction="sum")
+        with self.assertRaises(ValueError):
+            for amp_configs in amp_configs_list:
+                model.prepare(
+                    optimizer=optim, loss=loss, amp_configs=amp_configs)
+        model.prepare(optimizer=optim, loss=loss, amp_configs="O2")
+        model.prepare(
+            optimizer=optim,
+            loss=loss,
+            amp_configs={
+                "custom_white_list": {"matmul"},
+                "init_loss_scaling": 1.0
+            })
+
+    def test_static_check_input(self):
+        paddle.enable_static()
+        amp_configs = {"level": "O2", "use_pure_fp16": True}
+        if not fluid.is_compiled_with_cuda():
+            self.skipTest('module not tested when ONLY_CPU compiling')
+        paddle.set_device('gpu')
+
+        net = LeNet()
+        inputs = InputSpec([None, 1, 28, 28], "float32", 'x')
+        labels = InputSpec([None, 1], "int64", "y")
+        model = Model(net, inputs, labels)
+
+        optim = paddle.optimizer.Adam(
+            learning_rate=0.001, parameters=model.parameters())
+        loss = CrossEntropyLoss(reduction="sum")
+        with self.assertRaises(ValueError):
+            model.prepare(optimizer=optim, loss=loss, amp_configs=amp_configs)
+
+
+if __name__ == '__main__':
+    unittest.main()
diff --git a/python/paddle/tests/test_model.py b/python/paddle/tests/test_model.py
index af54b046fe699fa29cf6948f990a5cb9d44ddcda..10ceb48796903864b979cc21534206d2d936cbcd 100644
--- a/python/paddle/tests/test_model.py
+++ b/python/paddle/tests/test_model.py
@@ -622,6 +622,8 @@ class TestModelFunction(unittest.TestCase):
         paddle.enable_static()
 
     def test_dygraph_export_deploy_model_about_inputs(self):
+        self.set_seed()
+        np.random.seed(201)
         mnist_data = MnistDataset(mode='train')
         paddle.disable_static()
         # without inputs
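For quick reference outside the patch itself, the snippet below is a minimal usage sketch of pure float16 ('O2') training through the new `amp_configs` argument of `Model.prepare` in static graph mode. It is not part of the diff; it mirrors the new `dist_hapi_pure_fp16_static.py` test above, the helper name `run_pure_fp16_example` is illustrative only, and a CUDA build of Paddle with this patch applied is assumed.

import numpy as np
import paddle
from paddle import Model
from paddle.static import InputSpec
from paddle.nn import CrossEntropyLoss
from paddle.vision.models import LeNet


def run_pure_fp16_example():
    # Pure float16 training is only supported in static graph mode on GPU.
    paddle.enable_static()
    paddle.set_device('gpu')

    inputs = InputSpec([None, 1, 28, 28], 'float32', 'x')
    labels = InputSpec([None, 1], 'int64', 'y')
    model = Model(LeNet(), inputs, labels)

    # multi_precision=True keeps float32 master weights alongside the float16
    # parameters, which helps avoid accuracy loss or slow convergence.
    optim = paddle.optimizer.Adam(
        learning_rate=1e-3,
        parameters=model.parameters(),
        multi_precision=True)

    # 'O2' selects pure float16; use_fp16_guard=False disables the
    # paddle.static.amp.fp16_guard range restriction (static mode only).
    model.prepare(
        optimizer=optim,
        loss=CrossEntropyLoss(reduction="sum"),
        amp_configs={"level": "O2", "use_fp16_guard": False})

    data = np.random.random(size=(4, 1, 28, 28)).astype(np.float32)
    label = np.random.randint(0, 10, size=(4, 1)).astype(np.int64)
    model.train_batch([data], [label])


if paddle.is_compiled_with_cuda():
    run_pure_fp16_example()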