diff --git a/python/paddle/static/amp/decorator.py b/python/paddle/static/amp/decorator.py
index 06169c668b8e8a0f6a9d7dd7bac7c0de28ea3e05..feadae70022be4e1f2328c7ca6b52cabba1f14e8 100644
--- a/python/paddle/static/amp/decorator.py
+++ b/python/paddle/static/amp/decorator.py
@@ -916,8 +916,6 @@ def decorate(
             warnings.warn(
                 "If the Amp level is set to OD, the amp list will not be used."
             )
-
-        amp_lists.white_list = {"conv2d", "matmul_v2"}
         amp_lists.black_list = amp_lists.all_list - amp_lists.white_list
 
     if use_dynamic_loss_scaling is None:
diff --git a/python/paddle/static/amp/fp16_utils.py b/python/paddle/static/amp/fp16_utils.py
index 1e85770251dc3435ee482598f0b4df3683e878cf..ea73d48cf3a967de22f211009642d8417a6bab7b 100644
--- a/python/paddle/static/amp/fp16_utils.py
+++ b/python/paddle/static/amp/fp16_utils.py
@@ -622,8 +622,6 @@ def cast_model_to_fp16(
     if amp_lists is not None:
         dtype = get_low_precision_dtypestr(dest_type)
         amp_lists = AutoMixedPrecisionLists(dtype)
-
-        amp_lists.white_list = {"conv2d", "matmul_v2"}
         amp_lists.black_list = amp_lists.all_list - amp_lists.white_list
 
     global_block = program.global_block()
diff --git a/test/amp/test_amp_api.py b/test/amp/test_amp_api.py
index 1bd41517f8f48c9ca4542d7fc2ddc1a4bdf47257..179236d909cbe0f0e1dc89fefb5f91ef3e9a11dd 100644
--- a/test/amp/test_amp_api.py
+++ b/test/amp/test_amp_api.py
@@ -15,9 +15,10 @@
 import unittest
 
 import numpy as np
-from amp_base_models import AmpTestBase, build_conv_model
+from amp_base_models import AmpTestBase
 
 import paddle
+from paddle import nn
 from paddle.static import amp
 
 
@@ -39,25 +40,47 @@ class TestAutoCast(AmpTestBase):
         self.assertEqual(out3.dtype, paddle.float32)
 
 
+class SimpleConvNet(nn.Layer):
+    def __init__(self):
+        super().__init__()
+        self._conv = paddle.nn.Conv2D(
+            in_channels=1, out_channels=6, kernel_size=3, bias_attr=False
+        )
+        self._linear = paddle.nn.Linear(in_features=4, out_features=4)
+
+    def forward(self, x):
+        out1 = self._conv(paddle.rand(shape=[1, 1, 6, 6], dtype='float32'))
+        out2 = out1 + paddle.rand(shape=out1.shape, dtype='float16')
+        out3 = self._linear(out2)
+        return out3
+
+
 class TestStaticDecorate(AmpTestBase):
     def check_results(
         self, use_amp, dtype, level, use_promote, expected_op_calls
     ):
-        (
-            main_program,
-            startup_program,
-            optimizer,
-            feed_vars,
-            fetch_vars,
-        ) = build_conv_model(use_amp, dtype, level, use_promote)
+        main_program = paddle.static.Program()
+        startup_program = paddle.static.Program()
+        with paddle.utils.unique_name.guard():
+            with paddle.static.program_guard(main_program, startup_program):
+                model = SimpleConvNet()
+                x = paddle.static.data(
+                    name='input', shape=[None, 1, 6, 6], dtype='float32'
+                )
+                out = model(x)
+                loss = paddle.mean(out)
+                optimizer = paddle.fluid.optimizer.Adadelta(learning_rate=0.001)
+                optimizer = paddle.static.amp.decorate(
+                    optimizer,
+                    init_loss_scaling=128.0,
+                    use_dynamic_loss_scaling=True,
+                    level=level,
+                )
+                optimizer.minimize(loss)
+
+        feed_vars = [x]
+        fetch_vars = [loss]
         self.assertEqual(main_program.num_blocks, 1)
 
-        optimizer = paddle.fluid.optimizer.Adadelta(learning_rate=0.001)
-        optimizer = paddle.static.amp.decorate(
-            optimizer,
-            init_loss_scaling=128.0,
-            use_dynamic_loss_scaling=True,
-            level=level,
-        )
         amp.debugging.collect_operator_stats(main_program)
         op_stats_list = amp.debugging._get_op_stats_list(main_program)
@@ -85,16 +108,13 @@ class TestStaticDecorate(AmpTestBase):
             level,
         )
 
-    def test_static_amp_o1(self):
+    def test_static_amp_OD(self):
         paddle.enable_static()
         expected_fp16_calls = {
             "conv2d": 1,
             "elementwise_add": 0,
-            "relu": 0,
             "matmul_v2": 1,
-            "softmax": 0,
             "reduce_mean": 0,
-            "adamw": 0,
         }
         self.check_results(
             True,