From 3c7cde959cd171a6f43110caf89cc262563bc02c Mon Sep 17 00:00:00 2001
From: liuruyan <44316842+liuruyan@users.noreply.github.com>
Date: Fri, 10 Mar 2023 11:12:28 +0800
Subject: [PATCH] Fix AMP bug after dynamic-to-static conversion when control
 flow ops are present in the program (#50799)

---
 .../contrib/tests/test_d2s_amp_controlflow.py | 93 +++++++++++++++++++
 python/paddle/static/amp/fp16_utils.py        | 13 ++-
 2 files changed, 104 insertions(+), 2 deletions(-)
 create mode 100644 python/paddle/fluid/contrib/tests/test_d2s_amp_controlflow.py

diff --git a/python/paddle/fluid/contrib/tests/test_d2s_amp_controlflow.py b/python/paddle/fluid/contrib/tests/test_d2s_amp_controlflow.py
new file mode 100644
index 00000000000..2551922462a
--- /dev/null
+++ b/python/paddle/fluid/contrib/tests/test_d2s_amp_controlflow.py
@@ -0,0 +1,93 @@
+# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import paddle
+import unittest
+import numpy as np
+import paddle.nn as nn
+
+
+class Net_Cond(nn.Layer):
+    def __init__(self):
+        super().__init__()
+
+    def forward(self):
+        cond_input_x = paddle.ones(shape=[32, 32], dtype="float32")
+        cond_input_y = paddle.zeros(shape=[32, 32], dtype="float32")
+        if paddle.shape(cond_input_x)[0] <= paddle.shape(cond_input_y)[0]:
+            cond_input_y = paddle.matmul(
+                cond_input_x,
+                cond_input_x.T,
+            )
+        return cond_input_y.mean()
+
+
+class Net_While(nn.Layer):
+    def __init__(self):
+        super().__init__()
+
+    def forward(self):
+        while_input_x = paddle.ones(shape=[64, 32], dtype="float32")
+        while_input_y = paddle.zeros(shape=[32, 32], dtype="float32")
+        while paddle.shape(while_input_x)[1] >= paddle.shape(while_input_y)[1]:
+            while_input_y = paddle.matmul(
+                while_input_x,
+                while_input_x.T,
+            )
+        return while_input_y.mean()
+
+
+class Net_Sub_Block_FP32(nn.Layer):
+    def __init__(self):
+        super().__init__()
+
+    def forward(self):
+        cond_input_x = paddle.ones(shape=[32, 32], dtype="float32")
+        cond_input_y = paddle.zeros(shape=[32, 32], dtype="float32")
+        if paddle.shape(cond_input_x)[0] <= paddle.shape(cond_input_y)[0]:
+            cond_input_y = paddle.log(cond_input_x)
+        return cond_input_y.mean()
+
+
+class TestD2SAmpWithControlFlowOp(unittest.TestCase):
+    def test_cond_op(self):
+        model = Net_Cond()
+        model = paddle.jit.to_static(model)
+        model = paddle.amp.decorate(
+            models=model, level='O2', save_dtype="float32"
+        )
+        with paddle.amp.auto_cast(level='O2'):
+            model()
+
+    def test_while_op(self):
+        model = Net_While()
+        model = paddle.jit.to_static(model)
+        model = paddle.amp.decorate(
+            models=model, level='O2', save_dtype="float32"
+        )
+        with paddle.amp.auto_cast(level='O2'):
+            model()
+
+    def test_sub_block_fp32_op(self):
+        model = Net_Sub_Block_FP32()
+        model = paddle.jit.to_static(model)
+        model = paddle.amp.decorate(
+            models=model, level='O2', save_dtype="float32"
+        )
+        with paddle.amp.auto_cast(level='O2'):
+            model()
+
+
+if __name__ == '__main__':
+    unittest.main()
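Note on what the new tests exercise: once `paddle.jit.to_static` traces the Layers above, the Python `if`/`while` statements become static-graph control flow ops (`conditional_block`, `select_input`, `while`), and the branch bodies live in sub-blocks whose ops can reference variables owned by the parent block. A minimal standalone sketch of that structure, using only public `paddle.static` APIs (the sketch and its variable names are illustrative and not part of the patch):

import paddle

paddle.enable_static()

main = paddle.static.Program()
startup = paddle.static.Program()
with paddle.static.program_guard(main, startup):
    x = paddle.ones(shape=[32, 32], dtype="float32")
    y = paddle.zeros(shape=[32, 32], dtype="float32")
    # Static-graph counterpart of the `if` in Net_Cond: each branch is
    # lowered into its own sub-block behind a conditional_block op.
    out = paddle.static.nn.cond(
        paddle.shape(x)[0] <= paddle.shape(y)[0],
        true_fn=lambda: paddle.matmul(x, paddle.transpose(x, perm=[1, 0])),
        false_fn=lambda: paddle.add(x, y),
    )

# The matmul/add ops sit in sub-blocks; the global block holds the control
# flow ops that the AMP pass previously treated as unsupported.
print(main.num_blocks)  # > 1 once control flow is present
print([op.type for op in main.global_block().ops])
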
diff --git a/python/paddle/static/amp/fp16_utils.py b/python/paddle/static/amp/fp16_utils.py
index 281d3638ee2..bfe0a146f23 100644
--- a/python/paddle/static/amp/fp16_utils.py
+++ b/python/paddle/static/amp/fp16_utils.py
@@ -430,6 +430,15 @@ def cast_model_to_fp16(program, amp_lists=None, use_fp16_guard=True):
 
     if amp_lists is None:
         amp_lists = AutoMixedPrecisionLists()
+    amp_lists.unsupported_list -= {
+        "conditional_block_grad",
+        "conditional_block",
+        "conditional_block_infer",
+        "select_input",
+        "while",
+        "while_grad",
+        "cast",
+    }
     global_block = program.global_block()
     keep_fp32_ops = set()
     to_fp16_var_names = set()
@@ -454,7 +463,7 @@ def cast_model_to_fp16(program, amp_lists=None, use_fp16_guard=True):
                 for in_var_name in op.input(in_name):
                     in_var = None
                     try:
-                        in_var = block.var(in_var_name)
+                        in_var = block._var_recursive(in_var_name)
                     except ValueError as e:
                         _logger.debug(
                             "-- {}, try to get it in the global block --".format(
@@ -491,7 +500,7 @@ def cast_model_to_fp16(program, amp_lists=None, use_fp16_guard=True):
                 for out_var_name in op.output(out_name):
                     out_var = None
                     try:
-                        out_var = block.var(out_var_name)
+                        out_var = block._var_recursive(out_var_name)
                     except ValueError as e:
                         _logger.debug(
                             "-- {}, try to get it in the global block --".format(
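Taken together, the fp16_utils.py changes do two things: the first hunk subtracts the control flow ops (and `cast`) from `unsupported_list`, so the pass no longer forces them to stay in FP32, and the two `_var_recursive` hunks fix variable lookup when the pass visits ops inside a sub-block. `Block.var` searches only the block it is called on and raises `ValueError` for a variable owned by an ancestor block, while `Block._var_recursive` walks up the parent chain. A rough sketch of that difference, reusing the illustrative `main` program from the sketch above:

# Rough sketch of the lookup behavior the patched pass relies on;
# `main` is the illustrative multi-block program built above.
for block in main.blocks:
    for op in block.ops:
        for name in op.input_arg_names:
            try:
                block.var(name)  # searches the current block only
            except ValueError:
                # The variable is owned by an ancestor block, e.g. `x` used
                # inside the cond's sub-block; _var_recursive resolves it.
                var = block._var_recursive(name)
                print(f"block {block.idx}: {name} comes from a parent block "
                      f"(dtype={var.dtype})")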