diff --git a/python/paddle/static/amp/fp16_utils.py b/python/paddle/static/amp/fp16_utils.py index 46c669ba54e46b3be2fc3f0bee92b81c954639ac..b5c78c69f66790d2f4062b3483d35a4e958a26c4 100644 --- a/python/paddle/static/amp/fp16_utils.py +++ b/python/paddle/static/amp/fp16_utils.py @@ -642,6 +642,10 @@ def cast_model_to_fp16( def need_process(op): need_process = True + if op.type in ["set_value"]: + # NOTE(zoooo0820): OP set_value has attribute "dtype", but its output type is + # determined by the input.dtype instead of attribute. So, here we still process it. + return need_process if op.type in ["create_py_reader", "read"]: need_process = False else: diff --git a/test/amp/test_amp_api.py b/test/amp/test_amp_api.py index 607117c84aa043ad86a93574451127846c661a7d..58d7c1dd2f8e23403040840b3cf2302f465d8ab5 100644 --- a/test/amp/test_amp_api.py +++ b/test/amp/test_amp_api.py @@ -258,5 +258,51 @@ class TestFp16Guard(AmpTestBase): paddle.disable_static() +class SimpleModelIncludeSetValue(nn.Layer): + def __init__(self): + super().__init__() + self.norm = nn.LayerNorm(3) + + def forward(self, x): + x = x + 1 + tmp = x * 1 + y = self.norm(tmp) + x[:] = y + + z = x * 1 + return z + + +@unittest.skipIf( + not core.is_compiled_with_cuda() + or paddle.device.cuda.get_device_capability()[0] < 7.0, + "run test when gpu's compute capability is at least 7.0.", +) +class TestDy2STWithSetValue(AmpTestBase): + def test_op_called_as_expected(self): + expected_fp16_calls = { + "cast": 0, + "layer_norm": 1, + "scale": 3, + "set_value": 1, + } + + func = SimpleModelIncludeSetValue() + func = paddle.amp.decorate(func, level='O2') + func = paddle.jit.to_static(func) + input = paddle.randn((2, 3)) + + with paddle.amp.auto_cast(level='O2'): + res = func(input) + loss = res.sum() + prog = func.forward.get_concrete_program(input)[1].forward_program + amp.debugging.collect_operator_stats(prog) + op_stats_list = amp.debugging._get_op_stats_list(prog) + loss.backward() + self._check_op_calls( + op_stats_list[0], expected_fp16_calls=expected_fp16_calls + ) + + if __name__ == '__main__': unittest.main()