From 3317cf013b53c029d8530dd0cd9f9162588ff0ec Mon Sep 17 00:00:00 2001
From: huangxu96 <46740794+huangxu96@users.noreply.github.com>
Date: Wed, 20 Jan 2021 17:10:42 +0800
Subject: [PATCH] [cherry pick] Add pure fp16 amp_init for fleet API. (#30592)

* add fleet amp.init()

* add unittest for fleet_amp_init
---
 .../distributed/fleet/base/fleet_base.py      | 64 +++++++++++++++
 .../contrib/mixed_precision/fp16_lists.py     |  8 +-
 .../tests/unittests/test_fleet_amp_init.py    | 80 +++++++++++++++++++
 3 files changed, 150 insertions(+), 2 deletions(-)
 create mode 100644 python/paddle/fluid/tests/unittests/test_fleet_amp_init.py

diff --git a/python/paddle/distributed/fleet/base/fleet_base.py b/python/paddle/distributed/fleet/base/fleet_base.py
index a45cdd6f38f..3a631edb921 100644
--- a/python/paddle/distributed/fleet/base/fleet_base.py
+++ b/python/paddle/distributed/fleet/base/fleet_base.py
@@ -958,6 +958,70 @@ class Fleet(object):
         # imitate target optimizer retrieval
         return self.user_defined_optimizer.clear_grad()
 
+    def amp_init(self,
+                 place,
+                 scope=None,
+                 test_program=None,
+                 use_fp16_test=False):
+        """
+        Initialize AMP training, such as casting fp32 parameters to fp16 type.
+
+        Args:
+            place(CUDAPlace): The place is used to initialize
+                fp16 parameters with fp32 values.
+            scope(Scope): The scope is used to find fp32 parameters.
+            test_program(Program): The program is used for testing.
+            use_fp16_test(bool): Whether to use fp16 testing.
+
+        Examples:
+            .. code-block:: python
+
+                import numpy as np
+                import paddle
+                import paddle.nn.functional as F
+                paddle.enable_static()
+
+                def run_example_code():
+                    place = paddle.CUDAPlace(0)
+                    exe = paddle.static.Executor(place)
+                    data = paddle.static.data(name='X', shape=[None, 1, 28, 28], dtype='float32')
+                    conv2d = paddle.static.nn.conv2d(input=data, num_filters=6, filter_size=3)
+                    # 1) Use fp16_guard to control the range of fp16 kernels used.
+                    with paddle.static.amp.fp16_guard():
+                        bn = paddle.static.nn.batch_norm(input=conv2d, act="relu")
+                        pool = F.max_pool2d(bn, kernel_size=2, stride=2)
+                        hidden = paddle.static.nn.fc(pool, size=10)
+                        loss = paddle.mean(hidden)
+                    # 2) Create the optimizer and set `multi_precision` to True.
+                    # Setting `multi_precision` to True helps avoid poor accuracy
+                    # or slow convergence in some cases.
+                    optimizer = paddle.optimizer.Momentum(learning_rate=0.01, multi_precision=True)
+                    # 3) The ops in `custom_black_list` will stay in float32 computation.
+                    amp_list = paddle.static.amp.CustomOpLists(
+                        custom_black_list=['pool2d'])
+                    # 4) The entry point of Paddle AMP.
+                    # Enable pure fp16 training by setting `use_pure_fp16` to True.
+                    optimizer = paddle.static.amp.decorate(
+                        optimizer,
+                        amp_list,
+                        init_loss_scaling=128.0,
+                        use_dynamic_loss_scaling=True,
+                        use_pure_fp16=True)
+                    # If you don't use the default_startup_program(), you should pass
+                    # your own `startup_program` into `minimize`.
+                    optimizer.minimize(loss)
+                    exe.run(paddle.static.default_startup_program())
+                    # 5) Use `amp_init` after the FP32 parameters have been initialized (e.g. by `exe.run(startup_program)`).
+                    # If you want to run the testing process, pass `test_program` into `amp_init` as well.
+                    optimizer.amp_init(place, scope=paddle.static.global_scope())
+
+                if paddle.is_compiled_with_cuda() and len(paddle.static.cuda_places()) > 0:
+                    run_example_code()
+        """
+        # imitate target optimizer retrieval and forward the arguments
+        return self.user_defined_optimizer.amp_init(
+            place, scope, test_program, use_fp16_test)
+
     def _final_strategy(self):
         if "valid_strategy" not in self._context:
             print(
diff --git a/python/paddle/fluid/contrib/mixed_precision/fp16_lists.py b/python/paddle/fluid/contrib/mixed_precision/fp16_lists.py
index 1e428624853..c88ae2d9cbf 100644
--- a/python/paddle/fluid/contrib/mixed_precision/fp16_lists.py
+++ b/python/paddle/fluid/contrib/mixed_precision/fp16_lists.py
@@ -95,6 +95,9 @@ black_list = {
     'sigmoid_cross_entropy_with_logits',
     'cross_entropy',
     'cross_entropy2',
+    # fp16 is slower than fp32, though fp16 is supported.
+    'lookup_table',
+    'lookup_table_v2',
 }
 
 # This set contains two types of ops. All ops supported fp16 calculation. One
@@ -115,8 +118,6 @@ gray_list = {
     'layer_norm',
     'tanh',
     'sigmoid',
-    'lookup_table',
-    'lookup_table_v2',
     'top_k',
     'pool2d',
     'pool3d',
@@ -284,6 +285,9 @@ unsupported_fp16_list = {
     'generate_proposals',
     'generate_proposal_labels',
     'generate_mask_labels',
+    # fp16 is slower than fp32, though fp16 is supported.
+    'lookup_table',
+    'lookup_table_v2',
 }
 
 CustomOpLists = AutoMixedPrecisionLists
diff --git a/python/paddle/fluid/tests/unittests/test_fleet_amp_init.py b/python/paddle/fluid/tests/unittests/test_fleet_amp_init.py
new file mode 100644
index 00000000000..d7da4ead1b0
--- /dev/null
+++ b/python/paddle/fluid/tests/unittests/test_fleet_amp_init.py
@@ -0,0 +1,80 @@
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import paddle
+import paddle.distributed.fleet.base.role_maker as role_maker
+import paddle.distributed.fleet as fleet
+import paddle.fluid as fluid
+import unittest
+import paddle.nn.functional as F
+import numpy as np
+
+paddle.enable_static()
+
+
+def gen_data():
+    return {
+        "x": np.random.random(size=(128, 32)).astype('float32'),
+        "y": np.random.randint(
+            2, size=(128, 1)).astype('int64')
+    }
+
+
+def mlp(input_x, input_y, hid_dim=128, label_dim=2):
+    fc_1 = paddle.static.nn.fc(x=input_x, size=hid_dim, activation='tanh')
+    fc_2 = paddle.static.nn.fc(x=fc_1, size=hid_dim, activation='tanh')
+    prediction = paddle.static.nn.fc(x=[fc_2],
+                                     size=label_dim,
+                                     activation='softmax')
+    cost = F.cross_entropy(input=prediction, label=input_y)
+    avg_cost = paddle.mean(x=cost)
+    return avg_cost
+
+
+class TestFleetAMPInit(unittest.TestCase):
+    def test_fleet_amp_init(self):
+        if not fluid.core.is_compiled_with_cuda():
+            return
+        input_x = paddle.static.data(
+            name="x", shape=[None, 32], dtype='float32')
+        input_y = paddle.static.data(name="y", shape=[None, 1], dtype='int64')
+
+        cost = mlp(input_x, input_y)
+        optimizer = paddle.optimizer.Momentum(
+            learning_rate=0.001,
+            momentum=0.9,
+            weight_decay=fluid.regularizer.L2Decay(1e-4),
+            multi_precision=True)
+
+        role = role_maker.PaddleCloudRoleMaker(is_collective=True)
+        fleet.init(role)
+
+        optimizer = paddle.static.amp.decorate(optimizer)
+        optimizer = fleet.distributed_optimizer(optimizer)
+        optimizer.minimize(cost)
+        place = paddle.CUDAPlace(0)
+
+        exe = paddle.static.Executor(place)
+        exe.run(paddle.static.default_startup_program())
+        optimizer.amp_init(place, use_fp16_test=True)
+
+        step = 1
+        for i in range(step):
+            cost_val = exe.run(program=paddle.static.default_main_program(),
+                               feed=gen_data(),
+                               fetch_list=[cost.name])
+
+
+if __name__ == '__main__':
+    unittest.main()
-- 
GitLab
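
For reference, below is a condensed usage sketch of the new fleet.amp_init entry point with pure fp16 training, pieced together from the docstring example and the unit test above. It assumes a CUDA build of Paddle and a collective launch (for example via python -m paddle.distributed.launch); the toy network, batch size, and feed data are illustrative only, not part of the patch.

    import numpy as np
    import paddle
    import paddle.distributed.fleet as fleet
    import paddle.nn.functional as F

    paddle.enable_static()

    # Build a small static-graph model (illustrative sizes).
    x = paddle.static.data(name='x', shape=[None, 32], dtype='float32')
    y = paddle.static.data(name='y', shape=[None, 1], dtype='int64')
    hidden = paddle.static.nn.fc(x=x, size=128, activation='tanh')
    prediction = paddle.static.nn.fc(x=hidden, size=2, activation='softmax')
    loss = paddle.mean(F.cross_entropy(input=prediction, label=y))

    fleet.init(is_collective=True)

    # multi_precision=True keeps fp32 master weights alongside the fp16 parameters.
    optimizer = paddle.optimizer.Momentum(
        learning_rate=0.001, momentum=0.9, multi_precision=True)
    # Decorate for pure fp16, then wrap with the fleet distributed optimizer.
    optimizer = paddle.static.amp.decorate(
        optimizer, init_loss_scaling=128.0, use_pure_fp16=True)
    optimizer = fleet.distributed_optimizer(optimizer)
    optimizer.minimize(loss)

    place = paddle.CUDAPlace(0)
    exe = paddle.static.Executor(place)
    exe.run(paddle.static.default_startup_program())
    # New in this patch: cast the freshly initialized fp32 parameters to fp16
    # through the fleet API instead of the underlying optimizer.
    optimizer.amp_init(place, scope=paddle.static.global_scope())

    feed = {
        'x': np.random.random(size=(16, 32)).astype('float32'),
        'y': np.random.randint(2, size=(16, 1)).astype('int64'),
    }
    loss_val, = exe.run(paddle.static.default_main_program(),
                        feed=feed, fetch_list=[loss.name])

The ordering constraint is the one the docstring spells out: amp_init runs after the startup program has initialized the fp32 parameters and before the first call to exe.run on the main program.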