Unverified commit 71a513c2, authored by Zhang Ting, committed by GitHub

[AMP] support promote kernel for static graph (#52514)

* support promote dtype for static amp training

* unify o1 and o2

* update for unittest

* fix op_role

* add use_promote arg

* fix doc

* add promote unittest

* polish unittests

* fix control flow and test
Parent 040f8aa5
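A minimal usage sketch of the new use_promote flag introduced by this PR, assuming the level/dtype form of paddle.static.amp.decorate that the unit tests below exercise; the toy network, data shape, and SGD optimizer are illustrative only and not taken from the PR:

import paddle

paddle.enable_static()

main_program = paddle.static.Program()
startup_program = paddle.static.Program()
with paddle.static.program_guard(main_program, startup_program):
    # Illustrative toy network; any static-graph model is decorated the same way.
    x = paddle.static.data(name='x', shape=[None, 16], dtype='float32')
    out = paddle.static.nn.fc(x, size=4)
    loss = paddle.mean(out)
    optimizer = paddle.optimizer.SGD(learning_rate=0.01)
    # With use_promote=True, an op whose inputs are not all low precision is
    # expected to be kept in float32 instead of being cast down.
    optimizer = paddle.static.amp.decorate(
        optimizer, level='O1', dtype='float16', use_promote=True
    )
    optimizer.minimize(loss)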
......@@ -4982,8 +4982,8 @@ class PipelineOptimizer:
device = post_op.attr(self._op_device_key)
assert device, "The post op must have op_device set."
op._set_attr(self._op_device_key, device)
elif (op.type == "cast" or op.type == "scale") and self._is_backward_op(
op
elif (op.type == "cast" or op.type == "scale") and (
self._is_backward_op(op) or self._is_forward_op(op)
):
prev_op = self._find_prev_op(idx, op.desc.input("X")[0])
op._set_attr(self._op_device_key, prev_op.attr(self._op_device_key))
......
......@@ -356,7 +356,9 @@ class TestAdadeltaMultiPrecision2_0(unittest.TestCase):
exe.run(startup_program)
if use_amp:
optimizer.amp_init(place='gpu', scope=paddle.static.global_scope())
optimizer.amp_init(
place=paddle.CUDAPlace(0), scope=paddle.static.global_scope()
)
x = np.random.random(size=(2, 2)).astype('float16')
else:
x = np.random.random(size=(2, 2)).astype('float32')
......@@ -467,7 +469,9 @@ class TestAdadeltaMultiPrecision1_0(unittest.TestCase):
exe.run(startup_program)
if use_amp:
optimizer.amp_init(place='gpu', scope=paddle.static.global_scope())
optimizer.amp_init(
place=paddle.CUDAPlace(0), scope=paddle.static.global_scope()
)
x = np.random.random(size=(2, 2)).astype('float16')
else:
x = np.random.random(size=(2, 2)).astype('float32')
......
......@@ -322,7 +322,9 @@ class TestAdagradMultiPrecision2_0(unittest.TestCase):
exe.run(startup_program)
if use_amp:
optimizer.amp_init(place='gpu', scope=paddle.static.global_scope())
optimizer.amp_init(
place=paddle.CUDAPlace(0), scope=paddle.static.global_scope()
)
x = np.random.random(size=(2, 2)).astype('float16')
else:
x = np.random.random(size=(2, 2)).astype('float32')
......@@ -431,7 +433,9 @@ class TestAdagradMultiPrecision1_0(unittest.TestCase):
exe.run(startup_program)
if use_amp:
optimizer.amp_init(place='gpu', scope=paddle.static.global_scope())
optimizer.amp_init(
place=paddle.CUDAPlace(0), scope=paddle.static.global_scope()
)
x = np.random.random(size=(2, 2)).astype('float16')
else:
x = np.random.random(size=(2, 2)).astype('float32')
......
......@@ -1235,7 +1235,9 @@ class TestMultiTensorAdam(unittest.TestCase):
optimizer.minimize(loss)
exe.run(startup_program)
if use_amp:
optimizer.amp_init(place=place, scope=paddle.static.global_scope())
optimizer.amp_init(
place=paddle.CUDAPlace(0), scope=paddle.static.global_scope()
)
x = np.random.random(size=(2, 2)).astype('float16')
else:
x = np.random.random(size=(2, 2)).astype('float32')
......
......@@ -352,7 +352,9 @@ class TestAdamaxMultiPrecision2_0(unittest.TestCase):
exe.run(startup_program)
if use_amp:
optimizer.amp_init(place='gpu', scope=paddle.static.global_scope())
optimizer.amp_init(
place=paddle.CUDAPlace(0), scope=paddle.static.global_scope()
)
x = np.random.random(size=(2, 2)).astype('float16')
else:
x = np.random.random(size=(2, 2)).astype('float32')
......@@ -459,7 +461,9 @@ class TestAdamaxMultiPrecision1_0(unittest.TestCase):
exe.run(startup_program)
if use_amp:
optimizer.amp_init(place='gpu', scope=paddle.static.global_scope())
optimizer.amp_init(
place=paddle.CUDAPlace(0), scope=paddle.static.global_scope()
)
x = np.random.random(size=(2, 2)).astype('float16')
else:
x = np.random.random(size=(2, 2)).astype('float32')
......
......@@ -1059,7 +1059,9 @@ class TestMultiTensorMomentumStatic(unittest.TestCase):
optimizer.minimize(loss)
exe.run(startup_program)
if use_amp:
optimizer.amp_init(place=place, scope=paddle.static.global_scope())
optimizer.amp_init(
place=paddle.CUDAPlace(0), scope=paddle.static.global_scope()
)
x = numpy.random.random(size=(2, 2)).astype('float16')
else:
x = numpy.random.random(size=(2, 2)).astype('float32')
......
......@@ -474,7 +474,9 @@ class TestRMSPropMultiPrecision2_0(unittest.TestCase):
exe.run(startup_program)
if use_amp:
optimizer.amp_init(place='gpu', scope=paddle.static.global_scope())
optimizer.amp_init(
place=paddle.CUDAPlace(0), scope=paddle.static.global_scope()
)
x = np.random.random(size=(2, 2)).astype('float16')
else:
x = np.random.random(size=(2, 2)).astype('float32')
......@@ -585,7 +587,9 @@ class TestRMSPropMultiPrecision1_0(unittest.TestCase):
exe.run(startup_program)
if use_amp:
optimizer.amp_init(place='gpu', scope=paddle.static.global_scope())
optimizer.amp_init(
place=paddle.CUDAPlace(0), scope=paddle.static.global_scope()
)
x = np.random.random(size=(2, 2)).astype('float16')
else:
x = np.random.random(size=(2, 2)).astype('float32')
......
......@@ -382,7 +382,9 @@ class TestSGDMultiPrecision2_0(unittest.TestCase):
exe.run(startup_program)
if mp:
optimizer.amp_init(place='gpu', scope=paddle.static.global_scope())
optimizer.amp_init(
place=paddle.CUDAPlace(0), scope=paddle.static.global_scope()
)
x = np.random.random(size=(2, 2)).astype('float16')
else:
x = np.random.random(size=(2, 2)).astype('float32')
......@@ -492,7 +494,9 @@ class TestSGDMultiPrecision1_0(unittest.TestCase):
exe.run(startup_program)
if mp:
optimizer.amp_init(place='gpu', scope=paddle.static.global_scope())
optimizer.amp_init(
place=paddle.CUDAPlace(0), scope=paddle.static.global_scope()
)
x = np.random.random(size=(2, 2)).astype('float16')
else:
x = np.random.random(size=(2, 2)).astype('float32')
......
......@@ -294,8 +294,8 @@ class PartialProgramLayer:
def _create_amp_program(self, is_infer_mode=False):
amp_program = self._origin_main_program.clone(for_test=is_infer_mode)
with program_guard(amp_program):
paddle.static.amp.fp16_utils.rewrite_program(
amp_program, self._amp_list
paddle.static.amp.fp16_utils.cast_model_to_fp16(
amp_program, self._amp_list, use_fp16_guard=False, level='O1'
)
if is_infer_mode:
if self._hooker:
......
......@@ -29,7 +29,6 @@ from .fp16_lists import AutoMixedPrecisionLists, check_amp_dtype
from .fp16_utils import (
cast_model_to_fp16,
cast_parameters_to_fp16,
rewrite_program,
update_role_var_grad,
)
from .function_overload import FunctionType, overload
......@@ -67,6 +66,7 @@ class OptimizerWithMixedPrecision:
the loss scaling.
use_amp_guard(bool): Whether to use `fp16_guard` when constructing the program.
Default None, which means that its value is equal to `use_pure_fp16`.
use_promote(bool): Whether to promote an op to fp32 when any of its inputs is float32. Default is False.
"""
def __init__(
......@@ -82,6 +82,7 @@ class OptimizerWithMixedPrecision:
incr_ratio,
decr_ratio,
use_amp_guard=None,
use_promote=False,
):
self._optimizer = optimizer
self._amp_lists = amp_lists
......@@ -116,6 +117,7 @@ class OptimizerWithMixedPrecision:
self._decr_ratio = decr_ratio
self._num_good_steps = None
self._num_bad_steps = None
self.use_promote = use_promote
def _set_distributed(self, flag):
# if distributed, all cards will communication with each other,
......@@ -231,10 +233,18 @@ class OptimizerWithMixedPrecision:
self._amp_lists,
self._use_fp16_guard,
self._amp_vartype,
level='O2',
use_promote=self.use_promote,
)
else:
rewrite_program(
self._train_program, self._amp_lists, self._amp_vartype
# use_fp16_guard is not supported for amp-o1.
cast_model_to_fp16(
self._train_program,
self._amp_lists,
use_fp16_guard=False,
dest_type=self._amp_vartype,
level='O1',
use_promote=self.use_promote,
)
if loss.dtype != core.VarDesc.VarType.FP32:
......@@ -362,10 +372,18 @@ class OptimizerWithMixedPrecision:
self._amp_lists,
self._use_fp16_guard,
self._amp_vartype,
level='O2',
use_promote=self.use_promote,
)
elif use_fp16_test:
rewrite_program(
test_program, self._amp_lists, self._amp_vartype
# use_fp16_guard is not supported for amp-o1.
cast_model_to_fp16(
test_program,
self._amp_lists,
use_fp16_guard=False,
dest_type=self._amp_vartype,
level='O1',
use_promote=self.use_promote,
)
def apply_gradients(self, params_grads):
......@@ -624,6 +642,7 @@ def decorate(
use_pure_fp16=False,
use_fp16_guard=None,
use_bf16=False,
use_promote=False,
):
"""
Decorate the given optimizer to adapt to the mixed-precision training.
......@@ -736,6 +755,7 @@ def decorate(
incr_ratio=incr_ratio,
decr_ratio=decr_ratio,
use_amp_guard=use_fp16_guard,
use_promote=use_promote,
)
return mp_optimizer
......@@ -754,6 +774,7 @@ def decorate(
decr_ratio=0.8,
use_dynamic_loss_scaling=True,
use_amp_guard=False,
use_promote=False,
):
"""
Decorate the given optimizer to adapt to the mixed-precision training.
......@@ -781,6 +802,7 @@ def decorate(
incr_ratio=incr_ratio,
decr_ratio=decr_ratio,
use_amp_guard=use_amp_guard,
use_promote=use_promote,
)
return mp_optimizer
......@@ -98,6 +98,20 @@ def _get_sys_unsupported_list(dtype):
else:
device = 'GPU'
_, _, sys_unsupported_list = core.op_supported_infos(device, var_type)
# sys_unsupported_list may wrongly include the following ops, which actually support fp16, so remove them from it.
supported_fp16_list = {
"conditional_block",
"conditional_block_infer",
"select_input",
"while",
"cast",
"tensor_array_to_tensor",
"lod_array_length",
"write_to_array",
}
sys_unsupported_list -= supported_fp16_list
return device, sys_unsupported_list
......@@ -108,6 +122,29 @@ def _get_unsupported_list(dtype):
return unsupported_list
# The three sets listed below are changed dynamically. They don't contain all
# paddle ops currently.
# The set of ops that support fp16 calculation and are considered numerically-
# safe and performance-critical. These ops are always converted to fp16.
_only_supported_fp16_list = {'resnet_unit', 'fused_bn_add_activation'}
white_list = {
'conv2d',
'matmul',
'matmul_v2',
'mul',
}
def _get_white_list(dtype):
white_list_for_dtype = copy.copy(white_list)
if dtype == 'float16':
white_list_for_dtype = white_list_for_dtype | _only_supported_fp16_list
return white_list_for_dtype
class AutoMixedPrecisionLists:
"""
AutoMixedPrecisionLists is a class for black/white list. It can update
......@@ -132,7 +169,7 @@ class AutoMixedPrecisionLists:
self.amp_dtype = check_amp_dtype(dtype)
self._custom_white_list = custom_white_list
self._custom_black_list = custom_black_list
self.white_list = copy.copy(white_list)
self.white_list = copy.copy(_get_white_list(self.amp_dtype))
self.black_list = copy.copy(black_list)
self.gray_list = copy.copy(gray_list)
self.unsupported_list = copy.copy(_get_unsupported_list(self.amp_dtype))
......@@ -143,6 +180,9 @@ class AutoMixedPrecisionLists:
"""
Update black and white list according to users' custom list.
"""
_logger.debug(f"---- custom_white_list {self._custom_white_list} ---- ")
_logger.debug(f"---- custom_black_list {self._custom_black_list} ---- ")
_logger.debug(f"---- custom_black_varnames {self.black_varnames} ---- ")
if self._custom_white_list and self._custom_black_list:
for op_name in self._custom_white_list:
if op_name in self._custom_black_list:
......@@ -177,18 +217,6 @@ class AutoMixedPrecisionLists:
)
# The three sets listed below are changed dynamiclly. They don't contain all
# paddle ops currently.
# The set of ops that support fp16 calculation and are considered numerically-
# safe and performance-critical. These ops are always converted to fp16.
white_list = {
'conv2d',
'matmul',
'matmul_v2',
'mul',
}
# The set of ops that support fp16 calculation and are considered numerically-
# dangerous and whose effects may also be observed in downstream ops.
black_list = {
......
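For context, a short sketch of what the reworked list construction above implies, assuming the paddle.static.amp.fp16_lists module path used by the tests below; the membership checks are expectations under this change, not assertions taken from the PR:

from paddle.static.amp.fp16_lists import AutoMixedPrecisionLists

# The fp16-only ops (resnet_unit, fused_bn_add_activation) are expected to be
# merged into the white list only when the AMP dtype is float16.
fp16_lists = AutoMixedPrecisionLists(dtype='float16')
bf16_lists = AutoMixedPrecisionLists(dtype='bfloat16')
print('resnet_unit' in fp16_lists.white_list)  # expected: True
print('resnet_unit' in bf16_lists.white_list)  # expected: False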
......@@ -29,6 +29,7 @@ def _build_optimizer(
amp_level="O1",
amp_lists=None,
use_grad_clip=False,
use_promote=False,
):
if use_grad_clip:
grad_clip = paddle.nn.ClipGradByGlobalNorm(clip_norm=1.0)
......@@ -45,7 +46,11 @@ def _build_optimizer(
)
if use_amp:
optimizer = paddle.static.amp.decorate(
optimizer, amp_lists, level=amp_level, dtype=amp_dtype
optimizer,
amp_lists,
level=amp_level,
dtype=amp_dtype,
use_promote=use_promote,
)
return optimizer
......@@ -67,7 +72,9 @@ class SimpleAddNet(nn.Layer):
return x + self.weight
def build_add_model(use_amp, amp_dtype="float16", amp_level="O1"):
def build_add_model(
use_amp, amp_dtype="float16", amp_level="O1", use_promote=False
):
main_program = paddle.static.Program()
startup_program = paddle.static.Program()
with paddle.utils.unique_name.guard():
......@@ -92,7 +99,11 @@ def build_add_model(use_amp, amp_dtype="float16", amp_level="O1"):
else:
amp_lists = None
optimizer = _build_optimizer(
use_amp, amp_dtype, amp_level, amp_lists
use_amp,
amp_dtype,
amp_level,
amp_lists,
use_promote=use_promote,
)
optimizer.minimize(loss)
feed_vars = [x]
......@@ -104,30 +115,37 @@ class SimpleConvNet(nn.Layer):
def __init__(self):
super().__init__()
self.conv = nn.Conv2D(in_channels=1, out_channels=6, kernel_size=3)
self.linear = nn.Linear(in_features=6, out_features=10)
self.linear = nn.Linear(in_features=96, out_features=4)
def forward(self, x):
out = self.conv(x)
out = nn.functional.relu(out)
out = out.flatten(start_axis=1, stop_axis=3)
out = self.linear(out)
out = nn.functional.softmax(out)
return out
def build_conv_model(use_amp, amp_dtype="float16", amp_level="O1"):
def build_conv_model(
use_amp, amp_dtype="float16", amp_level="O1", use_promote=False
):
main_program = paddle.static.Program()
startup_program = paddle.static.Program()
with paddle.utils.unique_name.guard():
with paddle.static.program_guard(main_program, startup_program):
model = SimpleConvNet()
x = paddle.static.data(
name='input', shape=[None, 1, 28, 28], dtype='float32'
name='input', shape=[None, 1, 6, 6], dtype='float32'
)
out = model(x)
loss = paddle.mean(out)
optimizer = _build_optimizer(use_amp, amp_dtype, amp_level)
optimizer = _build_optimizer(
use_amp, amp_dtype, amp_level, use_promote=use_promote
)
optimizer.minimize(loss)
return main_program, startup_program
feed_vars = [x]
fetch_vars = [loss]
return main_program, startup_program, optimizer, feed_vars, fetch_vars
class SimpleEmbeddingNet(nn.Layer):
......@@ -149,7 +167,9 @@ class SimpleEmbeddingNet(nn.Layer):
return out
def build_embedding_model(use_amp, amp_dtype="float16", amp_level="O1"):
def build_embedding_model(
use_amp, amp_dtype="float16", amp_level="O1", use_promote=False
):
main_program = paddle.static.Program()
startup_program = paddle.static.Program()
with paddle.utils.unique_name.guard():
......@@ -159,7 +179,12 @@ def build_embedding_model(use_amp, amp_dtype="float16", amp_level="O1"):
out = model(x)
loss = paddle.mean(out)
optimizer = _build_optimizer(
use_amp, amp_dtype, amp_level, None, True
use_amp,
amp_dtype,
amp_level,
None,
True,
use_promote=use_promote,
)
optimizer.minimize(loss)
return main_program, startup_program
......@@ -211,3 +236,48 @@ class AmpTestBase(unittest.TestCase):
def setUp(self):
self.amp_dtype = None
self.amp_level = None
def _check_op_calls(
self, op_stats_dict, expected_bf16_calls={}, expected_fp16_calls={}
):
for op_type, value in expected_bf16_calls.items():
self.assertEqual(
op_stats_dict[op_type].bf16_calls,
value,
f"The number of bf16 calls of operator < {op_type} > is expected to be {value}, but recieved {op_stats_dict[op_type].bf16_calls}.",
)
for op_type, value in expected_fp16_calls.items():
self.assertEqual(
op_stats_dict[op_type].fp16_calls,
value,
f"The number of fp16 calls of operator < {op_type} > is expected to be {value}, but recieved {op_stats_dict[op_type].fp16_calls}.",
)
def run_program(
self,
main_program,
startup_program,
optimizer,
feed_vars,
fetch_vars,
place,
exe,
x_np,
max_iters,
level,
):
losses = []
scope = paddle.static.Scope()
with paddle.static.scope_guard(scope):
exe.run(startup_program)
if level == 'O2':
optimizer.amp_init(place)
for iter_id in range(max_iters):
results = exe.run(
program=main_program,
feed={feed_vars[0].name: x_np},
fetch_list=fetch_vars,
)
print(f"-- [BF16 {level}] iter={iter_id}, loss={results[0]}")
losses.append(results[0])
return losses
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import numpy as np
from amp_base_models import AmpTestBase, build_conv_model
import paddle
from paddle.static import amp
paddle.enable_static()
class TestAMPPromote(AmpTestBase):
def check_promote_results(
self, use_amp, dtype, level, use_promote, expected_op_calls
):
(
main_program,
startup_program,
optimizer,
feed_vars,
fetch_vars,
) = build_conv_model(use_amp, dtype, level, use_promote)
self.assertEqual(main_program.num_blocks, 1)
amp.debugging.collect_operator_stats(main_program)
op_stats_list = amp.debugging._get_op_stats_list(main_program)
self._check_op_calls(
op_stats_list[0], expected_fp16_calls=expected_op_calls
)
place = paddle.CUDAPlace(0)
exe = paddle.static.Executor(place)
max_iters = 2
x_fp32 = np.random.random(size=[1, 1, 6, 6]).astype("float32")
print(main_program)
losses_o1 = self.run_program(
main_program,
startup_program,
optimizer,
feed_vars,
fetch_vars,
place,
exe,
x_fp32,
max_iters,
level,
)
def test_static_amp_o1(self):
expected_fp16_calls = {
"conv2d": 1,
"elementwise_add": 0,
"relu": 0,
"matmul_v2": 1,
"softmax": 0,
"reduce_mean": 0,
"adamw": 0,
}
self.check_promote_results(
True,
'float16',
'O1',
use_promote=True,
expected_op_calls=expected_fp16_calls,
)
def test_static_amp_o2(self):
expected_fp16_calls = {
"conv2d": 1,
"elementwise_add": 2,
"relu": 1,
"matmul_v2": 1,
"softmax": 1,
"reduce_mean": 1,
"adamw": 4,
}
self.check_promote_results(
True,
'float16',
'O2',
use_promote=True,
expected_op_calls=expected_fp16_calls,
)
if __name__ == '__main__':
unittest.main()
......@@ -221,14 +221,6 @@ class TestModelCastBF16(unittest.TestCase):
class TestProgramBF16(AmpTestBase):
def _check_bf16_calls(self, op_stats_dict, expected_bf16_calls):
for op_type, value in expected_bf16_calls.items():
self.assertEqual(
op_stats_dict[op_type].bf16_calls,
value,
f"The number of bf16 calls of operator < {op_type} > is expected to be {value}, but recieved {op_stats_dict[op_type].bf16_calls}.",
)
def test_amp_bf16_o1(self):
main_program, startup_program = build_embedding_model(
True, "bfloat16", "O1"
......@@ -245,7 +237,7 @@ class TestProgramBF16(AmpTestBase):
"squared_l2_norm": 0,
"adamw": 0,
}
self._check_bf16_calls(op_stats_list[0], expected_bf16_calls)
self._check_op_calls(op_stats_list[0], expected_bf16_calls)
def test_amp_bf16_o2(self):
main_program, startup_program = build_embedding_model(
......@@ -263,7 +255,7 @@ class TestProgramBF16(AmpTestBase):
"squared_l2_norm": 2,
"adamw": 2,
}
self._check_bf16_calls(op_stats_list[0], expected_bf16_calls)
self._check_op_calls(op_stats_list[0], expected_bf16_calls)
class TestStaticBF16(AmpTestBase):
......@@ -274,60 +266,35 @@ class TestStaticBF16(AmpTestBase):
return x_fp32, x_bf16
def test_compare_o1_o2(self):
def _run_o1(place, exe, x_np, max_iters):
def _run(place, exe, x_np, max_iters, level):
(
main_program,
startup_program,
optimizer,
feed_vars,
fetch_vars,
) = build_add_model(True, "bfloat16", "O1")
losses = []
scope = paddle.static.Scope()
with paddle.static.scope_guard(scope):
exe.run(startup_program)
for iter_id in range(max_iters):
results = exe.run(
program=main_program,
feed={feed_vars[0].name: x_np},
fetch_list=fetch_vars,
)
print(f"-- [BF16 O1] iter={iter_id}, loss={results[0]}")
losses.append(results[0])
return losses
) = build_add_model(True, "bfloat16", level)
def _run_o2(place, exe, x_np, max_iters):
(
losses = self.run_program(
main_program,
startup_program,
optimizer,
feed_vars,
fetch_vars,
) = build_add_model(True, "bfloat16", "O2")
losses = []
scope = paddle.static.Scope()
with paddle.static.scope_guard(scope):
exe.run(startup_program)
optimizer.amp_init(place)
for iter_id in range(max_iters):
results = exe.run(
program=main_program,
feed={feed_vars[0].name: x_np},
fetch_list=fetch_vars,
)
print(f"-- [BF16 O2] iter={iter_id}, loss={results[0]}")
losses.append(results[0])
place,
exe,
x_np,
max_iters,
level,
)
return losses
place = paddle.CUDAPlace(0)
exe = paddle.static.Executor(place)
max_iters = 2
x_fp32, x_bf16 = self._generate_feed_x()
losses_o1 = _run_o1(place, exe, x_fp32, max_iters)
losses_o2 = _run_o2(place, exe, x_bf16, max_iters)
place = paddle.CUDAPlace(0)
exe = paddle.static.Executor(place)
losses_o1 = _run(place, exe, x_fp32, max_iters, 'O1')
losses_o2 = _run(place, exe, x_bf16, max_iters, 'O2')
if __name__ == '__main__':
......
......@@ -314,7 +314,10 @@ class TestImageClassification(unittest.TestCase):
# infer(use_cuda, save_dirname)
def test_amp_lists(self):
white_list = copy.copy(paddle.static.amp.fp16_lists.white_list)
white_list = (
copy.copy(paddle.static.amp.fp16_lists.white_list)
| paddle.static.amp.fp16_lists._only_supported_fp16_list
)
black_list = copy.copy(paddle.static.amp.fp16_lists.black_list)
gray_list = copy.copy(paddle.static.amp.fp16_lists.gray_list)
......@@ -324,7 +327,10 @@ class TestImageClassification(unittest.TestCase):
self.assertEqual(amp_lists.gray_list, gray_list)
def test_amp_lists_1(self):
white_list = copy.copy(paddle.static.amp.fp16_lists.white_list)
white_list = (
copy.copy(paddle.static.amp.fp16_lists.white_list)
| paddle.static.amp.fp16_lists._only_supported_fp16_list
)
black_list = copy.copy(paddle.static.amp.fp16_lists.black_list)
gray_list = copy.copy(paddle.static.amp.fp16_lists.gray_list)
......@@ -338,7 +344,10 @@ class TestImageClassification(unittest.TestCase):
self.assertEqual(amp_lists.gray_list, gray_list)
def test_amp_lists_2(self):
white_list = copy.copy(paddle.static.amp.fp16_lists.white_list)
white_list = (
copy.copy(paddle.static.amp.fp16_lists.white_list)
| paddle.static.amp.fp16_lists._only_supported_fp16_list
)
black_list = copy.copy(paddle.static.amp.fp16_lists.black_list)
gray_list = copy.copy(paddle.static.amp.fp16_lists.gray_list)
......@@ -352,7 +361,10 @@ class TestImageClassification(unittest.TestCase):
self.assertEqual(amp_lists.gray_list, gray_list)
def test_amp_lists_3(self):
white_list = copy.copy(paddle.static.amp.fp16_lists.white_list)
white_list = (
copy.copy(paddle.static.amp.fp16_lists.white_list)
| paddle.static.amp.fp16_lists._only_supported_fp16_list
)
black_list = copy.copy(paddle.static.amp.fp16_lists.black_list)
gray_list = copy.copy(paddle.static.amp.fp16_lists.gray_list)
......@@ -365,7 +377,10 @@ class TestImageClassification(unittest.TestCase):
self.assertEqual(amp_lists.gray_list, gray_list)
def test_amp_lists_4(self):
white_list = copy.copy(paddle.static.amp.fp16_lists.white_list)
white_list = (
copy.copy(paddle.static.amp.fp16_lists.white_list)
| paddle.static.amp.fp16_lists._only_supported_fp16_list
)
black_list = copy.copy(paddle.static.amp.fp16_lists.black_list)
gray_list = copy.copy(paddle.static.amp.fp16_lists.gray_list)
......@@ -381,7 +396,10 @@ class TestImageClassification(unittest.TestCase):
self.assertEqual(amp_lists.gray_list, gray_list)
def test_amp_lists_5(self):
white_list = copy.copy(paddle.static.amp.fp16_lists.white_list)
white_list = (
copy.copy(paddle.static.amp.fp16_lists.white_list)
| paddle.static.amp.fp16_lists._only_supported_fp16_list
)
black_list = copy.copy(paddle.static.amp.fp16_lists.black_list)
gray_list = copy.copy(paddle.static.amp.fp16_lists.gray_list)
......@@ -397,7 +415,10 @@ class TestImageClassification(unittest.TestCase):
self.assertEqual(amp_lists.gray_list, gray_list)
def test_amp_lists_6(self):
white_list = copy.copy(paddle.static.amp.fp16_lists.white_list)
white_list = (
copy.copy(paddle.static.amp.fp16_lists.white_list)
| paddle.static.amp.fp16_lists._only_supported_fp16_list
)
black_list = copy.copy(paddle.static.amp.fp16_lists.black_list)
gray_list = copy.copy(paddle.static.amp.fp16_lists.gray_list)
......
......@@ -39,7 +39,7 @@ class TestFuseResNetUnit(unittest.TestCase):
startup_program = paddle.static.Program()
with paddle.static.amp.fp16_guard():
with paddle.static.program_guard(program, startup_program):
x = paddle.static.data("x", [1, 64, 64, 8])
x = paddle.static.data("x", [1, 64, 64, 8], dtype="float16")
conv2d = paddle.nn.Conv2D(
8, 32, 1, bias_attr=False, data_format='NHWC'
)
......@@ -66,3 +66,7 @@ class TestFuseResNetUnit(unittest.TestCase):
np.testing.assert_allclose(
before_out[0], after_out[0], rtol=1e-05, atol=0.005
)
if __name__ == '__main__':
unittest.main()
......@@ -25,10 +25,10 @@ paddle.enable_static()
def build_resnet50(use_amp=False):
main_program = paddle.static.Program()
startup_program = paddle.static.Program()
dtype = 'float16' if use_amp else 'float32'
with paddle.static.program_guard(main_program, startup_program):
image = paddle.static.data(
name='image', shape=[32, 3, 224, 224], dtype='float32'
name='image', shape=[32, 3, 224, 224], dtype=dtype
)
label = paddle.static.data(name='label', shape=[32], dtype='int64')
model = paddle.vision.models.resnet50()
......