From 0900a7902efb467d3b4c445c127305ea74938f54 Mon Sep 17 00:00:00 2001
From: JYChen
Date: Wed, 6 Sep 2023 16:05:13 +0800
Subject: [PATCH] fix setvalue dtype error when using dy2st and amp O2 (#56868)

* fix setvalue dtype error when using dy2st and amp O2

* add one test

* remove the share_buffer check since win/linux report different numbers
---
 python/paddle/static/amp/fp16_utils.py |  4 +++
 test/amp/test_amp_api.py               | 46 ++++++++++++++++++++++++++
 2 files changed, 50 insertions(+)

diff --git a/python/paddle/static/amp/fp16_utils.py b/python/paddle/static/amp/fp16_utils.py
index 46c669ba54e..b5c78c69f66 100644
--- a/python/paddle/static/amp/fp16_utils.py
+++ b/python/paddle/static/amp/fp16_utils.py
@@ -642,6 +642,10 @@ def cast_model_to_fp16(
 
     def need_process(op):
         need_process = True
+        if op.type in ["set_value"]:
+            # NOTE(zoooo0820): set_value has a "dtype" attribute, but its output dtype
+            # comes from its input's dtype, not the attribute, so we still process it.
+            return need_process
         if op.type in ["create_py_reader", "read"]:
             need_process = False
         else:
diff --git a/test/amp/test_amp_api.py b/test/amp/test_amp_api.py
index 607117c84aa..58d7c1dd2f8 100644
--- a/test/amp/test_amp_api.py
+++ b/test/amp/test_amp_api.py
@@ -258,5 +258,51 @@ class TestFp16Guard(AmpTestBase):
         paddle.disable_static()
 
 
+class SimpleModelIncludeSetValue(nn.Layer):
+    def __init__(self):
+        super().__init__()
+        self.norm = nn.LayerNorm(3)
+
+    def forward(self, x):
+        x = x + 1
+        tmp = x * 1
+        y = self.norm(tmp)
+        x[:] = y  # in-place slice assignment, lowered to set_value under to_static
+
+        z = x * 1
+        return z
+
+
+@unittest.skipIf(
+    not core.is_compiled_with_cuda()
+    or paddle.device.cuda.get_device_capability()[0] < 7.0,
+    "Run this test only when the GPU's compute capability is at least 7.0.",
+)
+class TestDy2STWithSetValue(AmpTestBase):
+    def test_op_called_as_expected(self):
+        expected_fp16_calls = {
+            "cast": 0,
+            "layer_norm": 1,
+            "scale": 3,
+            "set_value": 1,
+        }
+
+        func = SimpleModelIncludeSetValue()
+        func = paddle.amp.decorate(func, level='O2')
+        func = paddle.jit.to_static(func)
+        input = paddle.randn((2, 3))
+
+        with paddle.amp.auto_cast(level='O2'):
+            res = func(input)
+            loss = res.sum()
+        prog = func.forward.get_concrete_program(input)[1].forward_program
+        amp.debugging.collect_operator_stats(prog)
+        op_stats_list = amp.debugging._get_op_stats_list(prog)
+        loss.backward()
+        self._check_op_calls(
+            op_stats_list[0], expected_fp16_calls=expected_fp16_calls
+        )
+
+
 if __name__ == '__main__':
     unittest.main()
--
GitLab
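
For reviewers who want to reproduce the failure outside the test suite, here is a
condensed sketch mirroring the new unit test. It uses the same public Paddle APIs
as the test; the class and variable names are illustrative and not part of the
patch, and it assumes a CUDA build with compute capability >= 7.0.

    import paddle
    import paddle.nn as nn

    class SliceAssignModel(nn.Layer):
        # The in-place slice assignment in forward() is lowered to a
        # set_value op when the model runs under paddle.jit.to_static.
        def __init__(self):
            super().__init__()
            self.norm = nn.LayerNorm(3)

        def forward(self, x):
            y = self.norm(x + 1)
            x[:] = y  # produces set_value in the static program
            return x * 1

    model = paddle.amp.decorate(SliceAssignModel(), level='O2')
    model = paddle.jit.to_static(model)

    with paddle.amp.auto_cast(level='O2'):
        out = model(paddle.randn((2, 3)))
    # Per the NOTE in the patch: before this fix, set_value's "dtype" attribute
    # led need_process to skip the op, so its output stayed fp32 while its input
    # was cast to fp16 under O2, producing a dtype mismatch. The early return in
    # need_process makes cast_model_to_fp16 process set_value like other ops.
    out.sum().backward()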