未验证 提交 3c7cde95 编写于 作者: L liuruyan 提交者: GitHub

Address bug of open amp after dynamic to static, when control op in program. (#50799)

上级 1a8cc15e
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle
import unittest
import numpy as np
import paddle.nn as nn
class Net_Cond(nn.Layer):
def __init__(self):
super().__init__()
def forward(self):
cond_input_x = paddle.ones(shape=[32, 32], dtype="float32")
cond_input_y = paddle.zeros(shape=[32, 32], dtype="float32")
if paddle.shape(cond_input_x)[0] <= paddle.shape(cond_input_y)[0]:
cond_input_y = paddle.matmul(
cond_input_x,
cond_input_x.T,
)
return cond_input_y.mean()
class Net_While(nn.Layer):
def __init__(self):
super().__init__()
def forward(self):
while_input_x = paddle.ones(shape=[64, 32], dtype="float32")
while_input_y = paddle.zeros(shape=[32, 32], dtype="float32")
while paddle.shape(while_input_x)[1] >= paddle.shape(while_input_y)[1]:
while_input_y = paddle.matmul(
while_input_x,
while_input_x.T,
)
return while_input_y.mean()
class Net_Sub_Block_FP32(nn.Layer):
def __init__(self):
super().__init__()
def forward(self):
cond_input_x = paddle.ones(shape=[32, 32], dtype="float32")
cond_input_y = paddle.zeros(shape=[32, 32], dtype="float32")
if paddle.shape(cond_input_x)[0] <= paddle.shape(cond_input_y)[0]:
cond_input_y = paddle.log(cond_input_x)
return cond_input_y.mean()
class TestD2SAmpWithControlFlowOp(unittest.TestCase):
def test_cond_op(self):
model = Net_Cond()
model = paddle.jit.to_static(model)
model = paddle.amp.decorate(
models=model, level='O2', save_dtype="float32"
)
with paddle.amp.auto_cast(level='O2'):
model()
def test_while_op(self):
model = Net_While()
model = paddle.jit.to_static(model)
model = paddle.amp.decorate(
models=model, level='O2', save_dtype="float32"
)
with paddle.amp.auto_cast(level='O2'):
model()
def test_sub_block_fp32_op(self):
model = Net_Sub_Block_FP32()
model = paddle.jit.to_static(model)
model = paddle.amp.decorate(
models=model, level='O2', save_dtype="float32"
)
with paddle.amp.auto_cast(level='O2'):
model()
if __name__ == '__main__':
unittest.main()
...@@ -430,6 +430,15 @@ def cast_model_to_fp16(program, amp_lists=None, use_fp16_guard=True): ...@@ -430,6 +430,15 @@ def cast_model_to_fp16(program, amp_lists=None, use_fp16_guard=True):
if amp_lists is None: if amp_lists is None:
amp_lists = AutoMixedPrecisionLists() amp_lists = AutoMixedPrecisionLists()
amp_lists.unsupported_list -= {
"conditional_block_grad",
"conditional_block",
"conditional_block_infer",
"select_input",
"while",
"while_grad",
"cast",
}
global_block = program.global_block() global_block = program.global_block()
keep_fp32_ops = set() keep_fp32_ops = set()
to_fp16_var_names = set() to_fp16_var_names = set()
...@@ -454,7 +463,7 @@ def cast_model_to_fp16(program, amp_lists=None, use_fp16_guard=True): ...@@ -454,7 +463,7 @@ def cast_model_to_fp16(program, amp_lists=None, use_fp16_guard=True):
for in_var_name in op.input(in_name): for in_var_name in op.input(in_name):
in_var = None in_var = None
try: try:
in_var = block.var(in_var_name) in_var = block._var_recursive(in_var_name)
except ValueError as e: except ValueError as e:
_logger.debug( _logger.debug(
"-- {}, try to get it in the global block --".format( "-- {}, try to get it in the global block --".format(
...@@ -491,7 +500,7 @@ def cast_model_to_fp16(program, amp_lists=None, use_fp16_guard=True): ...@@ -491,7 +500,7 @@ def cast_model_to_fp16(program, amp_lists=None, use_fp16_guard=True):
for out_var_name in op.output(out_name): for out_var_name in op.output(out_name):
out_var = None out_var = None
try: try:
out_var = block.var(out_var_name) out_var = block._var_recursive(out_var_name)
except ValueError as e: except ValueError as e:
_logger.debug( _logger.debug(
"-- {}, try to get it in the global block --".format( "-- {}, try to get it in the global block --".format(
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册