Unverified commit 56947361, authored by: N niuliling123, committed by: GitHub

Fix ctest error in test_amp_api (#53885)

Parent commit: f7083f47
...@@ -916,8 +916,6 @@ def decorate( ...@@ -916,8 +916,6 @@ def decorate(
warnings.warn( warnings.warn(
"If the Amp level is set to OD, the amp list will not be used." "If the Amp level is set to OD, the amp list will not be used."
) )
amp_lists.white_list = {"conv2d", "matmul_v2"}
amp_lists.black_list = amp_lists.all_list - amp_lists.white_list amp_lists.black_list = amp_lists.all_list - amp_lists.white_list
if use_dynamic_loss_scaling is None: if use_dynamic_loss_scaling is None:
......
...@@ -622,8 +622,6 @@ def cast_model_to_fp16( ...@@ -622,8 +622,6 @@ def cast_model_to_fp16(
if amp_lists is not None: if amp_lists is not None:
dtype = get_low_precision_dtypestr(dest_type) dtype = get_low_precision_dtypestr(dest_type)
amp_lists = AutoMixedPrecisionLists(dtype) amp_lists = AutoMixedPrecisionLists(dtype)
amp_lists.white_list = {"conv2d", "matmul_v2"}
amp_lists.black_list = amp_lists.all_list - amp_lists.white_list amp_lists.black_list = amp_lists.all_list - amp_lists.white_list
global_block = program.global_block() global_block = program.global_block()
......
...@@ -15,9 +15,10 @@ ...@@ -15,9 +15,10 @@
import unittest import unittest
import numpy as np import numpy as np
from amp_base_models import AmpTestBase, build_conv_model from amp_base_models import AmpTestBase
import paddle import paddle
from paddle import nn
from paddle.static import amp from paddle.static import amp
...@@ -39,25 +40,47 @@ class TestAutoCast(AmpTestBase): ...@@ -39,25 +40,47 @@ class TestAutoCast(AmpTestBase):
self.assertEqual(out3.dtype, paddle.float32) self.assertEqual(out3.dtype, paddle.float32)
class SimpleConvNet(nn.Layer):
    """Tiny conv + linear network used by the AMP static-graph tests.

    NOTE(review): formatting reconstructed from a diff view; the forward
    pass deliberately mixes float32 and float16 tensors — presumably so
    the AMP pass has dtype-cast decisions to make for the test to count.
    Confirm against the enclosing test file.
    """

    def __init__(self):
        super().__init__()
        # 1 -> 6 channel, 3x3 conv with no bias.
        self._conv = paddle.nn.Conv2D(
            in_channels=1, out_channels=6, kernel_size=3, bias_attr=False
        )
        self._linear = paddle.nn.Linear(in_features=4, out_features=4)

    def forward(self, x):
        # The incoming `x` is not used; a fresh random float32 tensor is
        # convolved instead — likely so the test controls shape/dtype exactly.
        out1 = self._conv(paddle.rand(shape=[1, 1, 6, 6], dtype='float32'))
        # Adding float16 noise to the float32 conv output forces the AMP
        # machinery to insert casts (the behavior under test).
        out2 = out1 + paddle.rand(shape=out1.shape, dtype='float16')
        out3 = self._linear(out2)
        return out3
class TestStaticDecorate(AmpTestBase): class TestStaticDecorate(AmpTestBase):
def check_results( def check_results(
self, use_amp, dtype, level, use_promote, expected_op_calls self, use_amp, dtype, level, use_promote, expected_op_calls
): ):
( main_program = paddle.static.Program()
main_program, startup_program = paddle.static.Program()
startup_program, with paddle.utils.unique_name.guard():
optimizer, with paddle.static.program_guard(main_program, startup_program):
feed_vars, model = SimpleConvNet()
fetch_vars, x = paddle.static.data(
) = build_conv_model(use_amp, dtype, level, use_promote) name='input', shape=[None, 1, 6, 6], dtype='float32'
)
out = model(x)
loss = paddle.mean(out)
optimizer = paddle.fluid.optimizer.Adadelta(learning_rate=0.001)
optimizer = paddle.static.amp.decorate(
optimizer,
init_loss_scaling=128.0,
use_dynamic_loss_scaling=True,
level=level,
)
optimizer.minimize(loss)
feed_vars = [x]
fetch_vars = [loss]
self.assertEqual(main_program.num_blocks, 1) self.assertEqual(main_program.num_blocks, 1)
optimizer = paddle.fluid.optimizer.Adadelta(learning_rate=0.001)
optimizer = paddle.static.amp.decorate(
optimizer,
init_loss_scaling=128.0,
use_dynamic_loss_scaling=True,
level=level,
)
amp.debugging.collect_operator_stats(main_program) amp.debugging.collect_operator_stats(main_program)
op_stats_list = amp.debugging._get_op_stats_list(main_program) op_stats_list = amp.debugging._get_op_stats_list(main_program)
...@@ -85,16 +108,13 @@ class TestStaticDecorate(AmpTestBase): ...@@ -85,16 +108,13 @@ class TestStaticDecorate(AmpTestBase):
level, level,
) )
def test_static_amp_o1(self): def test_static_amp_OD(self):
paddle.enable_static() paddle.enable_static()
expected_fp16_calls = { expected_fp16_calls = {
"conv2d": 1, "conv2d": 1,
"elementwise_add": 0, "elementwise_add": 0,
"relu": 0,
"matmul_v2": 1, "matmul_v2": 1,
"softmax": 0,
"reduce_mean": 0, "reduce_mean": 0,
"adamw": 0,
} }
self.check_results( self.check_results(
True, True,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
To comment, please register