Unverified commit d2f09015, authored by Yiqun Liu and committed by GitHub

[AMP] Allow to switch whether to use promote strategy to choose kernel for O2 training. (#53742) (#53841)

Pcard-70458

cherry-pick #53742

Chinese docs: PaddlePaddle/docs#5882
Parent 5d10e910
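A minimal usage sketch of the new switch (for reference only; it is not part of the diff and mirrors the eager-mode unit test added below). Under O2, passing use_promote=False keeps an op in float16 even when some of its inputs are float32 and would otherwise promote it back to float32:

import paddle

conv = paddle.nn.Conv2D(in_channels=1, out_channels=6, kernel_size=3, bias_attr=False)
linear = paddle.nn.Linear(in_features=4, out_features=4)

with paddle.amp.auto_cast(level='O2', use_promote=False):
    x = paddle.rand(shape=[1, 1, 6, 6], dtype='float32')
    conv_out = conv(x)
    add_out = conv_out + paddle.rand(shape=conv_out.shape, dtype='float16')
    linear_out = linear(add_out)

# With use_promote=False the matmul keeps running in float16; with the default
# use_promote=True the float32 linear weights would promote it back to float32.
print(linear_out.dtype)  # paddle.float16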
@@ -123,18 +123,35 @@ inline phi::DataType GetAmpDestDtype(
       egr::Controller::Instance().GetCurrentTracer()->GetAmpPhiDtype();
   auto dst_type = amp_setting_dtype;
-  if (paddle::imperative::AmpOperators::Instance().GetMutableAllowOps()->count(
-          op_name)) {
-    dst_type = amp_setting_dtype;
-  } else if (paddle::imperative::AmpOperators::Instance()
-                 .GetMutableBlockOps()
-                 ->count(op_name)) {
-    dst_type = phi::DataType::FLOAT32;
-  } else {
-    if (amp_level == paddle::imperative::AmpLevel::OD) {
-      dst_type = phi::DataType::FLOAT32;
-    } else {
-      dst_type = GetPromoteType(op_name, amp_tensors_vector, amp_setting_dtype);
+  bool use_promote = true;
+  if (amp_level == paddle::imperative::AmpLevel::O2) {
+    use_promote =
+        egr::Controller::Instance().GetCurrentTracer()->GetUsePromote();
+  }
+
+  if (use_promote) {
+    if (paddle::imperative::AmpOperators::Instance()
+            .GetMutableAllowOps()
+            ->count(op_name)) {
+      dst_type = amp_setting_dtype;
+    } else if (paddle::imperative::AmpOperators::Instance()
+                   .GetMutableBlockOps()
+                   ->count(op_name)) {
+      dst_type = phi::DataType::FLOAT32;
+    } else {
+      if (amp_level == paddle::imperative::AmpLevel::OD) {
+        dst_type = phi::DataType::FLOAT32;
+      } else {
+        dst_type =
+            GetPromoteType(op_name, amp_tensors_vector, amp_setting_dtype);
+      }
+    }
+  } else {
+    // use_promote can be set to false only for O2 training.
+    if (paddle::imperative::AmpOperators::Instance()
+            .GetMutableBlockOps()
+            ->count(op_name)) {
+      dst_type = phi::DataType::FLOAT32;
     }
   }
......
@@ -58,6 +58,9 @@ class Controller {
     return tracer_->GetAmpLevel();
   }
 
+  void SetUsePromote(bool use_promote) { tracer_->SetUsePromote(use_promote); }
+
+  bool GetUsePromote() const { return tracer_->GetUsePromote(); }
+
   bool UseLayoutAutoTune() {
     bool use_autotune = false;
 #if defined(PADDLE_WITH_CUDA)
......
@@ -43,6 +43,8 @@ thread_local bool Tracer::enable_program_desc_tracing_ = false;
 
 thread_local bool Tracer::has_grad_ = true;
 
+thread_local bool Tracer::use_promote_ = true;
+
 thread_local bool Tracer::use_layout_autotune_ = false;
 
 thread_local AmpLevel Tracer::amp_level_ = AmpLevel::O0;
......
@@ -156,6 +156,13 @@ class Tracer {
   void SetHasGrad(bool has_grad) { has_grad_ = has_grad; }
 
+  void SetUsePromote(bool use_promote) {
+    VLOG(4) << "set use_promote to " << use_promote;
+    use_promote_ = use_promote;
+  }
+
+  bool GetUsePromote() const { return use_promote_; }
+
   void SetAmpLevel(AmpLevel level) {
     VLOG(4) << "set amp_level to " << static_cast<unsigned int>(level);
     amp_level_ = level;
@@ -220,6 +227,7 @@ class Tracer {
   static thread_local bool enable_program_desc_tracing_;
   static thread_local bool use_layout_autotune_;
   static thread_local bool has_grad_;
+  static thread_local bool use_promote_;
   static thread_local AmpLevel amp_level_;
   static thread_local phi::DataType amp_dtype_;
 };
......
@@ -2185,6 +2185,9 @@ void BindImperative(py::module *m_ptr) {
       .def_property("_enable_program_desc_tracing",
                     &imperative::Tracer::IsProgramDescTracingEnabled,
                     &imperative::Tracer::SetEnableProgramDescTracing)
+      .def_property("_use_promote",
+                    &imperative::Tracer::GetUsePromote,
+                    &imperative::Tracer::SetUsePromote)
       .def_property("_amp_level",
                     &imperative::Tracer::GetAmpLevel,
                     &imperative::Tracer::SetAmpLevel)
......
@@ -274,6 +274,7 @@ def amp_guard(
     custom_black_list=None,
     level='O1',
     dtype='float16',
+    use_promote=True,
 ):
     """
     Create a context which enables auto-mixed-precision(AMP) of operators executed in dynamic graph mode.
@@ -438,6 +439,11 @@ def amp_guard(
         original_amp_dtype = tracer._amp_dtype
         tracer._amp_dtype = amp_dtype
 
+        # switch promote
+        if amp_level == AMP_LEVEL.O2:
+            original_use_promote = tracer._use_promote
+            tracer._use_promote = use_promote
+
     # restore status
     try:
         yield
@@ -448,6 +454,8 @@ def amp_guard(
             tracer._set_amp_op_list(original_white_list, original_black_list)
             # set_flags(original_flags)
             tracer._amp_dtype = original_amp_dtype
+            if amp_level == AMP_LEVEL.O2:
+                tracer._use_promote = original_use_promote
 
 
 class StateDictHook:
@@ -641,6 +649,7 @@ def auto_cast(
     custom_black_list=None,
     level='O1',
     dtype='float16',
+    use_promote=True,
 ):
     """
     Create a context which enables auto-mixed-precision(AMP) of operators executed in dynamic graph mode.
@@ -663,6 +672,7 @@ def auto_cast(
             will be converted to float16/bfloat16, and that have any float32 input will be converted to float32. For the OD level, operators in
             default white list will compute in float16/bfloat16, and the others will compute in float32. Default is O1.
         dtype(str, optional): Whether to use 'float16' or 'bfloat16'. Default is 'float16'.
+        use_promote(bool, optional): Whether to promote to fp32 when the op has any float32 inputs. It is only supported when the amp level is O2. Default is True.
 
     Examples:
@@ -696,7 +706,9 @@ def auto_cast(
             print(d.dtype) # paddle.float16
     """
-    return amp_guard(enable, custom_white_list, custom_black_list, level, dtype)
+    return amp_guard(
+        enable, custom_white_list, custom_black_list, level, dtype, use_promote
+    )
 
 
 def decorate(
......
@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import copy
 import unittest
 
 import numpy as np
@@ -19,6 +20,7 @@
 import paddle
 from paddle import nn
 from paddle.fluid import core
+from paddle.fluid.framework import _non_static_mode
 
 _fixed_add_param = np.random.random(size=[16, 16]).astype("float32")
@@ -30,20 +32,27 @@ def _build_optimizer(
     amp_lists=None,
     use_grad_clip=False,
     use_promote=False,
+    model=None,
 ):
     if use_grad_clip:
         grad_clip = paddle.nn.ClipGradByGlobalNorm(clip_norm=1.0)
     else:
         grad_clip = None
+    if _non_static_mode():
+        assert model is not None
+        parameters = model.parameters()
+    else:
+        parameters = None
     optimizer = paddle.optimizer.AdamW(
         learning_rate=0.01,
+        parameters=parameters,
         grad_clip=grad_clip,
         beta1=0.78,
         beta2=0.836,
         epsilon=1e-4,
         weight_decay=0.01,
     )
-    if use_amp:
+    if not _non_static_mode() and use_amp:
         optimizer = paddle.static.amp.decorate(
             optimizer,
             amp_lists,
@@ -118,7 +127,7 @@ class SimpleConvNet(nn.Layer):
     def forward(self, x):
         out = self.conv(x)
-        out = nn.functional.relu(out)
+        out = nn.functional.relu(out.cast("float32"))
         out = out.flatten(start_axis=1, stop_axis=3)
         out = self.linear(out)
         out = nn.functional.softmax(out)
@@ -128,6 +137,22 @@ class SimpleConvNet(nn.Layer):
 def build_conv_model(
     use_amp, amp_dtype="float16", amp_level="O1", use_promote=False
 ):
+    if _non_static_mode():
+        model = SimpleConvNet()
+        optimizer = _build_optimizer(use_amp=False, model=model)
+        if use_amp and amp_dtype == "float16":
+            scaler = paddle.amp.GradScaler()
+        else:
+            scaler = None
+        if use_amp and amp_level == "O2":
+            model, optimizer = paddle.amp.decorate(
+                models=model,
+                optimizers=optimizer,
+                level=amp_level,
+                dtype=amp_dtype,
+            )
+        return model, optimizer, scaler
+
     main_program = paddle.static.Program()
     startup_program = paddle.static.Program()
     with paddle.utils.unique_name.guard():
@@ -237,19 +262,36 @@ class AmpTestBase(unittest.TestCase):
         self.amp_level = None
 
     def _check_op_calls(
-        self, op_stats_dict, expected_bf16_calls={}, expected_fp16_calls={}
+        self,
+        op_stats_dict,
+        expected_bf16_calls={},
+        expected_fp16_calls={},
+        debug_info=None,
     ):
-        for op_type, value in expected_bf16_calls.items():
+        def _extract_op_call(op_calls_str, pos):
+            return int(copy.copy(op_calls_str).split(",")[pos])
+
+        for op_type, expected_value in expected_bf16_calls.items():
+            if isinstance(op_stats_dict[op_type], str):
+                actual_value = _extract_op_call(op_stats_dict[op_type], 1)
+            else:
+                actual_value = op_stats_dict[op_type].bf16_calls
             self.assertEqual(
-                op_stats_dict[op_type].bf16_calls,
-                value,
-                f"The number of bf16 calls of operator < {op_type} > is expected to be {value}, but recieved {op_stats_dict[op_type].bf16_calls}.",
+                actual_value,
+                expected_value,
+                f"[{debug_info}] The number of bf16 calls of operator < {op_type} > is expected to be {expected_value}, but received {actual_value}.",
             )
-        for op_type, value in expected_fp16_calls.items():
+        for op_type, expected_value in expected_fp16_calls.items():
+            if isinstance(op_stats_dict[op_type], str):
+                actual_value = _extract_op_call(op_stats_dict[op_type], 0)
+            else:
+                actual_value = op_stats_dict[op_type].fp16_calls
             self.assertEqual(
-                op_stats_dict[op_type].fp16_calls,
-                value,
-                f"The number of fp16 calls of operator < {op_type} > is expected to be {value}, but recieved {op_stats_dict[op_type].fp16_calls}.",
+                actual_value,
+                expected_value,
+                f"[{debug_info}] The number of fp16 calls of operator < {op_type} > is expected to be {expected_value}, but received {actual_value}.",
            )

     def run_program(
@@ -263,6 +305,7 @@ class AmpTestBase(unittest.TestCase):
         exe,
         x_np,
         max_iters,
+        dtype,
         level,
     ):
         losses = []
@@ -277,6 +320,8 @@ class AmpTestBase(unittest.TestCase):
                 feed={feed_vars[0].name: x_np},
                 fetch_list=fetch_vars,
             )
-            print(f"-- [BF16 {level}] iter={iter_id}, loss={results[0]}")
+            print(
+                f"-- [AMP {dtype} {level}] iter={iter_id}, loss={results[0]}"
+            )
             losses.append(results[0])
         return losses
@@ -20,15 +20,17 @@ import paddle
 
 class TestAutoCast(AmpTestBase):
-    def test_amp_OD_level(self):
-        conv = paddle.nn.Conv2D(
+    def setUp(self):
+        self._conv = paddle.nn.Conv2D(
             in_channels=1, out_channels=6, kernel_size=3, bias_attr=False
         )
-        linear = paddle.nn.Linear(in_features=4, out_features=4)
+        self._linear = paddle.nn.Linear(in_features=4, out_features=4)
 
+    def test_amp_OD_level(self):
         with paddle.amp.auto_cast(level='OD'):
-            out1 = conv(paddle.rand(shape=[1, 1, 6, 6], dtype='float32'))
+            out1 = self._conv(paddle.rand(shape=[1, 1, 6, 6], dtype='float32'))
             out2 = out1 + paddle.rand(shape=out1.shape, dtype='float16')
-            out3 = linear(out2)
+            out3 = self._linear(out2)
 
         self.assertEqual(out1.dtype, paddle.float16)
         self.assertEqual(out2.dtype, paddle.float32)
......
@@ -131,6 +131,7 @@ class TestUnittedEmbedding(AmpTestBase):
             exe,
             x_np,
             max_iters,
+            "float16",
             level,
         )
         return losses
......
@@ -20,13 +20,12 @@ from amp_base_models import AmpTestBase, build_conv_model
 import paddle
 from paddle.static import amp
 
-paddle.enable_static()
 
-
-class TestAMPPromote(AmpTestBase):
+class TestStaticAmpPromoteStats(AmpTestBase):
     def check_promote_results(
-        self, use_amp, dtype, level, use_promote, expected_op_calls
+        self, use_amp, dtype, level, use_promote, expected_op_calls, debug_info
     ):
+        paddle.enable_static()
         (
             main_program,
             startup_program,
@@ -40,7 +39,9 @@ class TestAMPPromote(AmpTestBase):
         op_stats_list = amp.debugging._get_op_stats_list(main_program)
         self._check_op_calls(
-            op_stats_list[0], expected_fp16_calls=expected_op_calls
+            op_stats_list[0],
+            expected_fp16_calls=expected_op_calls,
+            debug_info=debug_info,
         )
 
         place = paddle.CUDAPlace(0)
@@ -58,8 +59,10 @@ class TestAMPPromote(AmpTestBase):
             exe,
             x_fp32,
             max_iters,
+            dtype,
             level,
         )
+        paddle.disable_static()
 
     def test_static_amp_o1(self):
         expected_fp16_calls = {
@@ -77,13 +80,14 @@ class TestAMPPromote(AmpTestBase):
             'O1',
             use_promote=True,
             expected_op_calls=expected_fp16_calls,
+            debug_info="TestStaticAmpPromoteStats/test_static_amp_o1",
         )
 
     def test_static_amp_o2(self):
         expected_fp16_calls = {
             "conv2d": 1,
             "elementwise_add": 2,
-            "relu": 1,
+            "relu": 0,
             "matmul_v2": 1,
             "softmax": 1,
             "reduce_mean": 1,
@@ -95,7 +99,109 @@ class TestAMPPromote(AmpTestBase):
             'O2',
             use_promote=True,
             expected_op_calls=expected_fp16_calls,
+            debug_info="TestStaticAmpPromoteStats/test_static_amp_o2",
         )
 
+
+class TestEagerAmpPromoteStats(AmpTestBase):
+    def check_promote_results(
+        self, dtype, level, use_promote, expected_op_calls, debug_info
+    ):
+        model, optimizer, scaler = build_conv_model(
+            use_amp=True,
+            amp_dtype=dtype,
+            amp_level=level,
+            use_promote=use_promote,
+        )
+        model.train()
+
+        paddle.amp.debugging.enable_operator_stats_collection()
+        with paddle.amp.auto_cast(
+            enable=True, dtype=dtype, level=level, use_promote=use_promote
+        ):
+            x = paddle.rand(shape=[1, 1, 6, 6], dtype='float32')
+            out = model(x)
+            loss = paddle.mean(out)
+        scaled = scaler.scale(loss)
+        scaled.backward()
+        scaler.minimize(optimizer, scaled)
+        optimizer.clear_grad()
+        paddle.amp.debugging.disable_operator_stats_collection()
+        op_stats = paddle.fluid.core.get_low_precision_op_list()
+
+        self._check_op_calls(
+            op_stats,
+            expected_fp16_calls=expected_op_calls,
+            debug_info=debug_info,
+        )
+
+    def test_o2_promote_on(self):
+        expected_fp16_calls = {
+            "conv2d": 1,
+            "elementwise_add": 2,
+            "relu": 0,
+            "matmul_v2": 1,
+            "softmax": 1,
+            "reduce_mean": 1,
+            "adamw_": 4,
+        }
+        self.check_promote_results(
+            'float16',
+            'O2',
+            use_promote=True,
+            expected_op_calls=expected_fp16_calls,
+            debug_info="TestEagerAmpPromoteStats/test_o2_promote_on",
+        )
+
+    def test_o2_promote_off(self):
+        expected_fp16_calls = {
+            "conv2d": 1,
+            "elementwise_add": 2,
+            "relu": 1,
+            "matmul_v2": 1,
+            "softmax": 1,
+            "reduce_mean": 1,
+            "adamw_": 4,
+        }
+        self.check_promote_results(
+            'float16',
+            'O2',
+            use_promote=False,
+            expected_op_calls=expected_fp16_calls,
+            debug_info="TestEagerAmpPromoteStats/test_o2_promote_off",
+        )
+
+
+class TestEagerAmpPromoteSimple(AmpTestBase):
+    def setUp(self):
+        self._conv = paddle.nn.Conv2D(
+            in_channels=1, out_channels=6, kernel_size=3, bias_attr=False
+        )
+        self._linear = paddle.nn.Linear(in_features=4, out_features=4)
+
+    def test_o2_use_promote_on(self):
+        with paddle.amp.auto_cast(level='O2'):
+            x = paddle.rand(shape=[1, 1, 6, 6], dtype='float32')
+            conv_out = self._conv(x)
+            y = paddle.rand(shape=conv_out.shape, dtype='float16')
+            add_out = conv_out + y
+            linear_out = self._linear(add_out)
+
+        self.assertEqual(conv_out.dtype, paddle.float16)
+        self.assertEqual(add_out.dtype, paddle.float16)
+        self.assertEqual(linear_out.dtype, paddle.float32)
+
+    def test_o2_use_promote_off(self):
+        with paddle.amp.auto_cast(level='O2', use_promote=False):
+            x = paddle.rand(shape=[1, 1, 6, 6], dtype='float32')
+            conv_out = self._conv(x)
+            y = paddle.rand(shape=conv_out.shape, dtype='float16')
+            add_out = conv_out + y
+            linear_out = self._linear(add_out)
+
+        self.assertEqual(conv_out.dtype, paddle.float16)
+        self.assertEqual(add_out.dtype, paddle.float16)
+        self.assertEqual(linear_out.dtype, paddle.float16)
+
 
 if __name__ == '__main__':
......
@@ -310,6 +310,7 @@ class TestStaticBF16(AmpTestBase):
             exe,
             x_np,
             max_iters,
+            "bfloat16",
            level,
         )
         return losses
......