From 9bc596730227e576c2adf69f3d14d83e76a86f9b Mon Sep 17 00:00:00 2001
From: Leo Chen
Date: Mon, 19 Jul 2021 19:22:46 +0800
Subject: [PATCH] [amp] pass found_inf to adam to support skip_update (#34176)

* pass found_inf to adam
* add unittest
* fix bug
* refine unittest
* change unit test's directory
* disable unittest on cpu
---
 .../contrib/mixed_precision/decorator.py      |   4 +-
 python/paddle/fluid/optimizer.py              |  21 +++
 .../fluid/tests/unittests/CMakeLists.txt      |   1 +
 .../tests/unittests/test_mixed_precision.py   | 123 ++++++++++++++++++
 4 files changed, 148 insertions(+), 1 deletion(-)
 create mode 100644 python/paddle/fluid/tests/unittests/test_mixed_precision.py

diff --git a/python/paddle/fluid/contrib/mixed_precision/decorator.py b/python/paddle/fluid/contrib/mixed_precision/decorator.py
index 09b8629a97..7a646e069d 100644
--- a/python/paddle/fluid/contrib/mixed_precision/decorator.py
+++ b/python/paddle/fluid/contrib/mixed_precision/decorator.py
@@ -398,7 +398,9 @@ class OptimizerWithMixedPrecision(object):
                     self._incr_ratio,
                     self._decr_ratio,
                     name="update_loss_scaling")
-
+        # Pass found_inf to adam, to skip the update not only for param, but also for momentum and beta_pow
+        if isinstance(self._optimizer, paddle.fluid.optimizer.Adam):
+            self._optimizer._set_auxiliary_var('found_inf', found_inf)
         optimize_ops = self._optimizer.apply_gradients(params_grads)
         return optimize_ops
 
diff --git a/python/paddle/fluid/optimizer.py b/python/paddle/fluid/optimizer.py
index b89793bfd6..ddd9ef2327 100755
--- a/python/paddle/fluid/optimizer.py
+++ b/python/paddle/fluid/optimizer.py
@@ -142,6 +142,11 @@ class Optimizer(object):
         self._opti_name_list = []
         self._accumulators_holder = {}
         self._param_device_map = dict()
+        # NOTE(zhiqiu): sometimes we want to add some variables (Tensor) to the optimizer for a specific optimization,
+        # for example, we want to pass 'found_inf' to the adam optimizer so it can skip the update when found_inf is True.
+        # These variables should not be parameters of Optimizer's constructor (because they are not commonly used).
+        # Use _auxiliary_vars together with _set_auxiliary_var/_get_auxiliary_var to achieve that.
+        self._auxiliary_vars = dict()
 
     @framework.dygraph_only
     def state_dict(self):
@@ -294,6 +299,15 @@ class Optimizer(object):
     def get_opti_var_name_list(self):
         return self._opti_name_list
 
+    def _set_auxiliary_var(self, key, val):
+        self._auxiliary_vars[key] = val
+
+    def _get_auxiliary_var(self, key):
+        if key in self._auxiliary_vars:
+            return self._auxiliary_vars[key]
+        else:
+            return None
+
     def _create_global_learning_rate(self):
         from paddle.optimizer.lr import LRScheduler
         if isinstance(self._learning_rate, LRScheduler):
@@ -2467,6 +2481,13 @@ class AdamOptimizer(Optimizer):
             "Beta1Pow": [beta1_pow_acc],
             "Beta2Pow": [beta2_pow_acc]
         }
+
+        # Pass found_inf to adam, to skip the update not only for param, but also for momentum and beta_pow
+        found_inf = self._get_auxiliary_var('found_inf')
+
+        if found_inf:
+            inputs['SkipUpdate'] = found_inf
+
         outputs = {
             "ParamOut": [param_and_grad[0]],
             "Moment1Out": [moment1],
diff --git a/python/paddle/fluid/tests/unittests/CMakeLists.txt b/python/paddle/fluid/tests/unittests/CMakeLists.txt
index 5914a74c0c..fcb2dbfa2e 100644
--- a/python/paddle/fluid/tests/unittests/CMakeLists.txt
+++ b/python/paddle/fluid/tests/unittests/CMakeLists.txt
@@ -190,6 +190,7 @@ if ((NOT WITH_GPU) AND (NOT WITH_ROCM))
     list(REMOVE_ITEM TEST_OPS test_parallel_dygraph_sharding_parallel)
     list(REMOVE_ITEM TEST_OPS test_parallel_dygraph_mp_layers)
     LIST(REMOVE_ITEM TEST_OPS test_imperative_auto_mixed_precision)
+    LIST(REMOVE_ITEM TEST_OPS test_mixed_precision)
     LIST(REMOVE_ITEM TEST_OPS test_fleet_base_single)
     LIST(REMOVE_ITEM TEST_OPS test_dygraph_recompute)
 elseif(WITH_GPU)
diff --git a/python/paddle/fluid/tests/unittests/test_mixed_precision.py b/python/paddle/fluid/tests/unittests/test_mixed_precision.py
new file mode 100644
index 0000000000..89d40e9314
--- /dev/null
+++ b/python/paddle/fluid/tests/unittests/test_mixed_precision.py
@@ -0,0 +1,123 @@
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import unittest
+import paddle.fluid as fluid
+from paddle.fluid import core
+from paddle.fluid.contrib.mixed_precision import fp16_utils
+import paddle
+import paddle.nn as nn
+import paddle.static as static
+import numpy as np
+
+paddle.enable_static()
+
+
+class SimpleNet(nn.Layer):
+    def __init__(self, input_size, output_size):
+        super(SimpleNet, self).__init__()
+        self.linear1 = nn.Linear(input_size, output_size)
+        self.relu1 = nn.ReLU()
+        self.linear2 = nn.Linear(input_size, output_size)
+        self.relu2 = nn.ReLU()
+        self.linear3 = nn.Linear(input_size, output_size)
+
+    def forward(self, x):
+
+        x = self.linear1(x)
+        # currently, paddle's relu may hide nan/inf: relu(nan) = 0, relu(inf) = inf,
+        # so do not use it here.
+        #x = self.relu1(x)
+        x = self.linear2(x)
+        #x = self.relu2(x)
+        x = self.linear3(x)
+
+        return x
+
+
+class AMPTest(unittest.TestCase):
+    def net(self):
+        input_size = 4096
+        output_size = 4096
+        x = static.data(name='X', shape=[1000, 4096], dtype='float32')
+        label = static.data(name='Y', shape=[1000, 4096], dtype='float32')
+        model = SimpleNet(input_size, output_size)  # define the model
+        mse = paddle.nn.MSELoss()
+
+        out = model(x)
+        loss = mse(out, label)
+
+        opt = paddle.fluid.optimizer.Adam(
+            learning_rate=0.0001, parameter_list=model.parameters())  # define the optimizer
+        opt = paddle.static.amp.decorate(
+            opt, init_loss_scaling=128.0, use_dynamic_loss_scaling=True)
+        opt.minimize(loss)
+        return model, loss, opt
+
+    def test_skip_update(self):
+        input_size = 4096
+        output_size = 4096
+        batch_size = 1000
+        nums_batch = 10
+        startup_prog = paddle.static.Program()
+        main_prog = paddle.static.Program()
+        with static.program_guard(main_prog, startup_prog):
+            model, loss, opt = self.net()
+            weight = model.linear1.weight
+            moment1 = opt._optimizer._get_accumulator(
+                opt._optimizer._moment1_acc_str, weight)
+            beta_pow1 = opt._optimizer._get_accumulator(
+                opt._optimizer._beta1_pow_acc_str, weight)
+            fetch_list = [
+                loss, weight, moment1, beta_pow1, 'find_infinite_scale.tmp_0'
+            ]
+            exe = paddle.static.Executor(paddle.CUDAPlace(0))
+
+            train_data = [
+                np.random.rand(batch_size, input_size).astype(np.float32)
+                for _ in range(nums_batch)
+            ]
+            labels = [
+                np.random.rand(batch_size, output_size).astype(np.float32)
+                for _ in range(nums_batch)
+            ]
+
+            weight_, moment1_, beta_pow1_ = exe.run(
+                startup_prog, fetch_list=[weight, moment1, beta_pow1])
+            pre_weight_, pre_moment1_, pre_beta_pow1_ = weight_, moment1_, beta_pow1_
+            for i in range(nums_batch):
+                if i % 2:
+                    train_data[i][10] = np.inf
+                loss_, weight_, moment1_, beta_pow1_, found_inf = exe.run(
+                    main_prog,
+                    feed={"X": train_data[i],
+                          "Y": labels[i]},
+                    fetch_list=fetch_list)
+                print(loss_, weight_[0][0], moment1_[0][0], beta_pow1_,
+                      found_inf)
+                if i % 2:
+                    self.assertTrue(found_inf)
+                    self.assertTrue(np.array_equal(weight_, pre_weight_))
+                    self.assertTrue(np.array_equal(moment1_, pre_moment1_))
+                    self.assertTrue(np.array_equal(beta_pow1_, pre_beta_pow1_))
+                else:
+                    self.assertFalse(found_inf)
+                    self.assertFalse(np.array_equal(weight_, pre_weight_))
+                    self.assertFalse(np.array_equal(moment1_, pre_moment1_))
+                    self.assertFalse(np.array_equal(beta_pow1_, pre_beta_pow1_))
+                pre_weight_, pre_moment1_, pre_beta_pow1_ = weight_, moment1_, beta_pow1_
+
+
+if __name__ == '__main__':
+    unittest.main()
--
GitLab
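
A minimal, framework-free sketch of the auxiliary-variable round trip this patch introduces, for readers skimming the diff: the AMP decorator stores found_inf on the wrapped optimizer with _set_auxiliary_var, and the optimizer reads it back with _get_auxiliary_var to decide whether to skip the update (in the real code found_inf is a Tensor fed to the adam op as SkipUpdate, not a Python bool). The ToyOptimizer class, its lr/step members, and the plain gradient-descent update below are illustrative assumptions only, not Paddle code.

import numpy as np


class ToyOptimizer(object):
    """Illustrative stand-in for the Optimizer changes in this patch."""

    def __init__(self, lr=0.1):
        self._lr = lr
        # Mirrors the _auxiliary_vars dict added to Optimizer.__init__.
        self._auxiliary_vars = dict()

    def _set_auxiliary_var(self, key, val):
        self._auxiliary_vars[key] = val

    def _get_auxiliary_var(self, key):
        return self._auxiliary_vars.get(key, None)

    def step(self, param, grad):
        # Mirrors the SkipUpdate wiring: when 'found_inf' is truthy, leave the
        # parameter (and, in real Adam, the moments and beta_pow) untouched.
        if self._get_auxiliary_var('found_inf'):
            return param
        return param - self._lr * grad


opt = ToyOptimizer()
p = np.ones(3)
g = np.full(3, 0.5)

opt._set_auxiliary_var('found_inf', True)        # what the AMP decorator does on overflow
assert np.array_equal(opt.step(p, g), p)         # update skipped
opt._set_auxiliary_var('found_inf', False)
assert np.allclose(opt.step(p, g), p - 0.1 * g)  # normal update applied

The same pattern lets other one-off signals reach an optimizer without widening Optimizer's constructor, which is the rationale given in the NOTE(zhiqiu) comment in the optimizer.py hunk above.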