Unverified commit 9bc59673, authored by Leo Chen, committed by GitHub

[amp] pass found_inf to adam to support skip_update (#34176)

* pass found_inf to adam

* add unittest

* fix bug

* refine unittest

* change unit test's directory

* disable unittest on cpu
Parent cc007dce
......@@ -398,7 +398,9 @@ class OptimizerWithMixedPrecision(object):
                self._incr_ratio,
                self._decr_ratio,
                name="update_loss_scaling")
        # Pass found_inf to adam so that the update is skipped not only for param, but also for momentum and beta_pow.
        if isinstance(self._optimizer, paddle.fluid.optimizer.Adam):
            self._optimizer._set_auxiliary_var('found_inf', found_inf)
        optimize_ops = self._optimizer.apply_gradients(params_grads)
        return optimize_ops
......
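For context, the user-facing path that exercises this change is the static-graph AMP decorator. A minimal sketch, mirroring the unit test added below (the toy network and hyperparameters here are illustrative, not part of this diff):

import paddle
import paddle.static as static

paddle.enable_static()
main_prog, startup_prog = static.Program(), static.Program()
with static.program_guard(main_prog, startup_prog):
    x = static.data(name='X', shape=[None, 16], dtype='float32')
    loss = paddle.mean(static.nn.fc(x, size=1))
    opt = paddle.fluid.optimizer.Adam(learning_rate=1e-4)
    # decorate() wraps Adam in OptimizerWithMixedPrecision; with this change,
    # its apply_gradients hands found_inf to the inner Adam via _set_auxiliary_var.
    opt = paddle.static.amp.decorate(
        opt, init_loss_scaling=128.0, use_dynamic_loss_scaling=True)
    opt.minimize(loss)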
......@@ -142,6 +142,11 @@ class Optimizer(object):
        self._opti_name_list = []
        self._accumulators_holder = {}
        self._param_device_map = dict()
        # NOTE(zhiqiu): sometimes we want to attach extra variables (Tensors) to the optimizer for a specific
        # optimization; for example, we pass 'found_inf' to the adam optimizer so it can skip the update when
        # found_inf is True. Such variables should not be parameters of the Optimizer's constructor, since they
        # are not commonly used. Use _auxiliary_vars together with _set_auxiliary_var/_get_auxiliary_var for this.
        self._auxiliary_vars = dict()

    @framework.dygraph_only
    def state_dict(self):
......@@ -294,6 +299,15 @@ class Optimizer(object):
    def get_opti_var_name_list(self):
        return self._opti_name_list

    def _set_auxiliary_var(self, key, val):
        self._auxiliary_vars[key] = val

    def _get_auxiliary_var(self, key):
        if key in self._auxiliary_vars:
            return self._auxiliary_vars[key]
        else:
            return None

    def _create_global_learning_rate(self):
        from paddle.optimizer.lr import LRScheduler
        if isinstance(self._learning_rate, LRScheduler):
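Taken by itself, the auxiliary-variable hook added above is just a dict-backed side channel for state that does not belong in the constructor signature. A minimal standalone sketch of the same pattern (plain Python, not Paddle code):

class AuxContainer:
    """Dict-backed side channel keyed by name; unset keys read as None."""

    def __init__(self):
        self._auxiliary_vars = dict()

    def _set_auxiliary_var(self, key, val):
        self._auxiliary_vars[key] = val

    def _get_auxiliary_var(self, key):
        # dict.get already returns None for a missing key, so the
        # if/else in the diff could be collapsed to this one-liner.
        return self._auxiliary_vars.get(key)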
......@@ -2467,6 +2481,13 @@ class AdamOptimizer(Optimizer):
            "Beta1Pow": [beta1_pow_acc],
            "Beta2Pow": [beta2_pow_acc]
        }

        # Pass found_inf to adam so that the update is skipped not only for param, but also for momentum and beta_pow.
        found_inf = self._get_auxiliary_var('found_inf')
        if found_inf:
            inputs['SkipUpdate'] = found_inf

        outputs = {
            "ParamOut": [param_and_grad[0]],
            "Moment1Out": [moment1],
......
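The behavior the new SkipUpdate input gives the adam op, and which the test below asserts, can be sketched in plain Python. This is an illustration of the intended semantics (the standard Adam update with a skip branch), not the actual C++ kernel:

import numpy as np

def adam_step(param, grad, m1, m2, beta1_pow, beta2_pow, lr=1e-4,
              beta1=0.9, beta2=0.999, eps=1e-8, skip_update=False):
    # When skip_update is True (i.e., found_inf), nothing changes: not
    # just param, but also the moment and beta_pow accumulators stay put.
    if skip_update:
        return param, m1, m2, beta1_pow, beta2_pow
    m1 = beta1 * m1 + (1 - beta1) * grad
    m2 = beta2 * m2 + (1 - beta2) * grad * grad
    lr_t = lr * np.sqrt(1 - beta2_pow) / (1 - beta1_pow)
    param = param - lr_t * m1 / (np.sqrt(m2) + eps)
    return param, m1, m2, beta1_pow * beta1, beta2_pow * beta2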
......@@ -190,6 +190,7 @@ if ((NOT WITH_GPU) AND (NOT WITH_ROCM))
    list(REMOVE_ITEM TEST_OPS test_parallel_dygraph_sharding_parallel)
    list(REMOVE_ITEM TEST_OPS test_parallel_dygraph_mp_layers)
    LIST(REMOVE_ITEM TEST_OPS test_imperative_auto_mixed_precision)
    LIST(REMOVE_ITEM TEST_OPS test_mixed_precision)
    LIST(REMOVE_ITEM TEST_OPS test_fleet_base_single)
    LIST(REMOVE_ITEM TEST_OPS test_dygraph_recompute)
elseif(WITH_GPU)
......
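The CMake change above excludes the new test from CPU-only builds. An equivalent guard at the Python level (hypothetical, not part of this diff) would skip it at runtime instead:

import unittest
import paddle

# Hypothetical alternative to the CMake exclusion: skip at runtime
# when Paddle was built without CUDA support.
@unittest.skipIf(not paddle.is_compiled_with_cuda(),
                 "test_mixed_precision requires a CUDA build")
class AMPTestGPU(unittest.TestCase):
    pass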
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest

import paddle.fluid as fluid
from paddle.fluid import core
from paddle.fluid.contrib.mixed_precision import fp16_utils
import paddle
import paddle.nn as nn
import paddle.static as static
import numpy as np

paddle.enable_static()
class SimpleNet(nn.Layer):
    def __init__(self, input_size, output_size):
        super(SimpleNet, self).__init__()
        self.linear1 = nn.Linear(input_size, output_size)
        self.relu1 = nn.ReLU()
        self.linear2 = nn.Linear(input_size, output_size)
        self.relu2 = nn.ReLU()
        self.linear3 = nn.Linear(input_size, output_size)

    def forward(self, x):
        x = self.linear1(x)
        # Currently, paddle's relu may hide nan/inf: relu(nan) = 0, relu(inf) = inf,
        # so do not use it here (see the note after this class).
        # x = self.relu1(x)
        x = self.linear2(x)
        # x = self.relu2(x)
        x = self.linear3(x)
        return x
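# Side note on the relu comment in forward() above: per the behavior
# described there, relu(nan) == 0 while relu(inf) == inf, e.g. in dygraph
#   paddle.nn.functional.relu(paddle.to_tensor([float('nan'), float('inf')]))
# would yield [0., inf], so an injected nan would be masked before it
# could trip the inf/nan check that this test relies on.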
class AMPTest(unittest.TestCase):
    def net(self):
        input_size = 4096
        output_size = 4096
        x = static.data(name='X', shape=[1000, 4096], dtype='float32')
        label = static.data(name='Y', shape=[1000, 4096], dtype='float32')
        model = SimpleNet(input_size, output_size)  # define the model
        mse = paddle.nn.MSELoss()
        out = model(x)
        loss = mse(out, label)
        opt = paddle.fluid.optimizer.Adam(
            learning_rate=0.0001,
            parameter_list=model.parameters())  # define the optimizer
        opt = paddle.static.amp.decorate(
            opt, init_loss_scaling=128.0, use_dynamic_loss_scaling=True)
        opt.minimize(loss)
        return model, loss, opt
    def test_skip_update(self):
        input_size = 4096
        output_size = 4096
        batch_size = 1000
        nums_batch = 10

        startup_prog = paddle.static.Program()
        main_prog = paddle.static.Program()

        with static.program_guard(main_prog, startup_prog):
            model, loss, opt = self.net()
            weight = model.linear1.weight
            moment1 = opt._optimizer._get_accumulator(
                opt._optimizer._moment1_acc_str, weight)
            beta_pow1 = opt._optimizer._get_accumulator(
                opt._optimizer._beta1_pow_acc_str, weight)
            fetch_list = [
                loss, weight, moment1, beta_pow1, 'find_infinite_scale.tmp_0'
            ]

            exe = paddle.static.Executor(paddle.CUDAPlace(0))
            train_data = [
                np.random.rand(batch_size, input_size).astype(np.float32)
                for _ in range(nums_batch)
            ]
            labels = [
                np.random.rand(batch_size, output_size).astype(np.float32)
                for _ in range(nums_batch)
            ]

            weight_, moment1_, beta_pow1_ = exe.run(
                startup_prog, fetch_list=[weight, moment1, beta_pow1])
            pre_weight_, pre_moment1_, pre_beta_pow1_ = weight_, moment1_, beta_pow1_
            for i in range(nums_batch):
                if i % 2:
                    train_data[i][10] = np.inf
                loss_, weight_, moment1_, beta_pow1_, found_inf = exe.run(
                    main_prog,
                    feed={"X": train_data[i],
                          "Y": labels[i]},
                    fetch_list=fetch_list)
                print(loss_, weight_[0][0], moment1_[0][0], beta_pow1_,
                      found_inf)
                if i % 2:
                    self.assertTrue(found_inf)
                    self.assertTrue(np.array_equal(weight_, pre_weight_))
                    self.assertTrue(np.array_equal(moment1_, pre_moment1_))
                    self.assertTrue(np.array_equal(beta_pow1_, pre_beta_pow1_))
                else:
                    self.assertFalse(found_inf)
                    self.assertFalse(np.array_equal(weight_, pre_weight_))
                    self.assertFalse(np.array_equal(moment1_, pre_moment1_))
                    self.assertFalse(np.array_equal(beta_pow1_, pre_beta_pow1_))
                pre_weight_, pre_moment1_, pre_beta_pow1_ = weight_, moment1_, beta_pow1_


if __name__ == '__main__':
    unittest.main()