未验证 提交 db407bf0 编写于 作者: Y Yiqun Liu 提交者: GitHub

[AMP] Allow to switch whether to use promote strategy to choose kernel for O2 training. (#53742)

* Allow to switch whether to use promote strategy to choose kernel for O2 training.

* Fix comparing error and add unittest.
上级 2a94b817
......@@ -120,8 +120,16 @@ inline phi::DataType GetAmpDestDtype(
egr::Controller::Instance().GetCurrentTracer()->GetAmpPhiDtype();
auto dst_type = amp_setting_dtype;
if (paddle::imperative::AmpOperators::Instance().GetMutableAllowOps()->count(
op_name)) {
bool use_promote = true;
if (amp_level == paddle::imperative::AmpLevel::O2) {
use_promote =
egr::Controller::Instance().GetCurrentTracer()->GetUsePromote();
}
if (use_promote) {
if (paddle::imperative::AmpOperators::Instance()
.GetMutableAllowOps()
->count(op_name)) {
dst_type = amp_setting_dtype;
} else if (paddle::imperative::AmpOperators::Instance()
.GetMutableBlockOps()
......@@ -131,7 +139,16 @@ inline phi::DataType GetAmpDestDtype(
if (amp_level == paddle::imperative::AmpLevel::OD) {
dst_type = phi::DataType::FLOAT32;
} else {
dst_type = GetPromoteType(op_name, amp_tensors_vector, amp_setting_dtype);
dst_type =
GetPromoteType(op_name, amp_tensors_vector, amp_setting_dtype);
}
}
} else {
// use_promote can be set to false only for O2 training.
if (paddle::imperative::AmpOperators::Instance()
.GetMutableBlockOps()
->count(op_name)) {
dst_type = phi::DataType::FLOAT32;
}
}
......
......@@ -58,6 +58,9 @@ class Controller {
return tracer_->GetAmpLevel();
}
void SetUsePromote(bool use_promote) { tracer_->SetUsePromote(use_promote); }
bool GetUsePromote() const { return tracer_->GetUsePromote(); }
bool UseLayoutAutoTune() {
bool use_autotune = false;
#if defined(PADDLE_WITH_CUDA)
......
......@@ -44,6 +44,8 @@ thread_local bool Tracer::enable_program_desc_tracing_ = false;
thread_local bool Tracer::has_grad_ = true;
thread_local bool Tracer::use_promote_ = true;
thread_local bool Tracer::use_layout_autotune_ = false;
thread_local AmpLevel Tracer::amp_level_ = AmpLevel::O0;
......
......@@ -156,6 +156,13 @@ class Tracer {
void SetHasGrad(bool has_grad) { has_grad_ = has_grad; }
void SetUsePromote(bool use_promote) {
VLOG(4) << "set use_promote to " << use_promote;
use_promote_ = use_promote;
}
bool GetUsePromote() const { return use_promote_; }
void SetAmpLevel(AmpLevel level) {
VLOG(4) << "set amp_level to " << static_cast<unsigned int>(level);
amp_level_ = level;
......@@ -220,6 +227,7 @@ class Tracer {
static thread_local bool enable_program_desc_tracing_;
static thread_local bool use_layout_autotune_;
static thread_local bool has_grad_;
static thread_local bool use_promote_;
static thread_local AmpLevel amp_level_;
static thread_local phi::DataType amp_dtype_;
};
......
......@@ -2156,6 +2156,9 @@ void BindImperative(py::module *m_ptr) {
.def_property("_enable_program_desc_tracing",
&imperative::Tracer::IsProgramDescTracingEnabled,
&imperative::Tracer::SetEnableProgramDescTracing)
.def_property("_use_promote",
&imperative::Tracer::GetUsePromote,
&imperative::Tracer::SetUsePromote)
.def_property("_amp_level",
&imperative::Tracer::GetAmpLevel,
&imperative::Tracer::SetAmpLevel)
......
......@@ -274,6 +274,7 @@ def amp_guard(
custom_black_list=None,
level='O1',
dtype='float16',
use_promote=True,
):
"""
Create a context which enables auto-mixed-precision(AMP) of operators executed in dynamic graph mode.
......@@ -438,6 +439,11 @@ def amp_guard(
original_amp_dtype = tracer._amp_dtype
tracer._amp_dtype = amp_dtype
# switch promote
if amp_level == AMP_LEVEL.O2:
original_use_promote = tracer._use_promote
tracer._use_promote = use_promote
# restore status
try:
yield
......@@ -448,6 +454,8 @@ def amp_guard(
tracer._set_amp_op_list(original_white_list, original_black_list)
# set_flags(original_flags)
tracer._amp_dtype = original_amp_dtype
if amp_level == AMP_LEVEL.O2:
tracer._use_promote = original_use_promote
class StateDictHook:
......@@ -641,6 +649,7 @@ def auto_cast(
custom_black_list=None,
level='O1',
dtype='float16',
use_promote=True,
):
"""
Create a context which enables auto-mixed-precision(AMP) of operators executed in dynamic graph mode.
......@@ -663,6 +672,7 @@ def auto_cast(
will be converted to float16/bfloat16, and that have any float32 input will be converted to float32. For the OD level, operators in
default white list will compute in float16/bfloat16, and the others will compute in float32. Default is O1.
dtype(str, optional): Whether to use 'float16' or 'bfloat16'. Default is 'float16'.
use_promote(bool, optional): Whether to promotes to fp32 when op has any float32 inputs. It is only supported when amp level is O2. Default is True.
Examples:
......@@ -696,7 +706,9 @@ def auto_cast(
print(d.dtype) # paddle.float16
"""
return amp_guard(enable, custom_white_list, custom_black_list, level, dtype)
return amp_guard(
enable, custom_white_list, custom_black_list, level, dtype, use_promote
)
def decorate(
......
......@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import copy
import unittest
import numpy as np
......@@ -19,6 +20,7 @@ import numpy as np
import paddle
from paddle import nn
from paddle.fluid import core
from paddle.fluid.framework import _non_static_mode
_fixed_add_param = np.random.random(size=[16, 16]).astype("float32")
......@@ -30,20 +32,27 @@ def _build_optimizer(
amp_lists=None,
use_grad_clip=False,
use_promote=False,
model=None,
):
if use_grad_clip:
grad_clip = paddle.nn.ClipGradByGlobalNorm(clip_norm=1.0)
else:
grad_clip = None
if _non_static_mode():
assert model is not None
parameters = model.parameters()
else:
parameters = None
optimizer = paddle.optimizer.AdamW(
learning_rate=0.01,
parameters=parameters,
grad_clip=grad_clip,
beta1=0.78,
beta2=0.836,
epsilon=1e-4,
weight_decay=0.01,
)
if use_amp:
if not _non_static_mode() and use_amp:
optimizer = paddle.static.amp.decorate(
optimizer,
amp_lists,
......@@ -118,7 +127,7 @@ class SimpleConvNet(nn.Layer):
def forward(self, x):
out = self.conv(x)
out = nn.functional.relu(out)
out = nn.functional.relu(out.cast("float32"))
out = out.flatten(start_axis=1, stop_axis=3)
out = self.linear(out)
out = nn.functional.softmax(out)
......@@ -128,6 +137,22 @@ class SimpleConvNet(nn.Layer):
def build_conv_model(
use_amp, amp_dtype="float16", amp_level="O1", use_promote=False
):
if _non_static_mode():
model = SimpleConvNet()
optimizer = _build_optimizer(use_amp=False, model=model)
if use_amp and amp_dtype == "float16":
scaler = paddle.amp.GradScaler()
else:
scaler = None
if use_amp and amp_level == "O2":
model, optimizer = paddle.amp.decorate(
models=model,
optimizers=optimizer,
level=amp_level,
dtype=amp_dtype,
)
return model, optimizer, scaler
main_program = paddle.static.Program()
startup_program = paddle.static.Program()
with paddle.utils.unique_name.guard():
......@@ -237,19 +262,36 @@ class AmpTestBase(unittest.TestCase):
self.amp_level = None
def _check_op_calls(
self, op_stats_dict, expected_bf16_calls={}, expected_fp16_calls={}
self,
op_stats_dict,
expected_bf16_calls={},
expected_fp16_calls={},
debug_info=None,
):
for op_type, value in expected_bf16_calls.items():
def _extract_op_call(op_calls_str, pos):
return int(copy.copy(op_calls_str).split(",")[pos])
for op_type, expected_value in expected_bf16_calls.items():
# print(f"[BF16] op_type={op_type}, value={value}")
if isinstance(op_stats_dict[op_type], str):
actual_value = _extract_op_call(op_stats_dict[op_type], 1)
else:
actual_value = op_stats_dict[op_type].bf16_calls
self.assertEqual(
op_stats_dict[op_type].bf16_calls,
value,
f"The number of bf16 calls of operator < {op_type} > is expected to be {value}, but recieved {op_stats_dict[op_type].bf16_calls}.",
actual_value,
expected_value,
f"[{debug_info}] The number of bf16 calls of operator < {op_type} > is expected to be {expected_value}, but recieved {actual_value}.",
)
for op_type, value in expected_fp16_calls.items():
for op_type, expected_value in expected_fp16_calls.items():
# print(f"[FP16] op_type={op_type}, value={value}")
if isinstance(op_stats_dict[op_type], str):
actual_value = _extract_op_call(op_stats_dict[op_type], 0)
else:
actual_value = op_stats_dict[op_type].fp16_calls
self.assertEqual(
op_stats_dict[op_type].fp16_calls,
value,
f"The number of fp16 calls of operator < {op_type} > is expected to be {value}, but recieved {op_stats_dict[op_type].fp16_calls}.",
actual_value,
expected_value,
f"[debug_info] The number of fp16 calls of operator < {op_type} > is expected to be {expected_value}, but recieved {actual_value}.",
)
def run_program(
......@@ -263,6 +305,7 @@ class AmpTestBase(unittest.TestCase):
exe,
x_np,
max_iters,
dtype,
level,
):
losses = []
......@@ -277,6 +320,8 @@ class AmpTestBase(unittest.TestCase):
feed={feed_vars[0].name: x_np},
fetch_list=fetch_vars,
)
print(f"-- [BF16 {level}] iter={iter_id}, loss={results[0]}")
print(
f"-- [AMP {dtype} {level}] iter={iter_id}, loss={results[0]}"
)
losses.append(results[0])
return losses
......@@ -20,15 +20,17 @@ import paddle
class TestAutoCast(AmpTestBase):
def test_amp_OD_level(self):
conv = paddle.nn.Conv2D(
def setUp(self):
self._conv = paddle.nn.Conv2D(
in_channels=1, out_channels=6, kernel_size=3, bias_attr=False
)
linear = paddle.nn.Linear(in_features=4, out_features=4)
self._linear = paddle.nn.Linear(in_features=4, out_features=4)
def test_amp_OD_level(self):
with paddle.amp.auto_cast(level='OD'):
out1 = conv(paddle.rand(shape=[1, 1, 6, 6], dtype='float32'))
out1 = self._conv(paddle.rand(shape=[1, 1, 6, 6], dtype='float32'))
out2 = out1 + paddle.rand(shape=out1.shape, dtype='float16')
out3 = linear(out2)
out3 = self._linear(out2)
self.assertEqual(out1.dtype, paddle.float16)
self.assertEqual(out2.dtype, paddle.float32)
......
......@@ -131,6 +131,7 @@ class TestUnittedEmbedding(AmpTestBase):
exe,
x_np,
max_iters,
"float16",
level,
)
return losses
......
......@@ -20,13 +20,12 @@ from amp_base_models import AmpTestBase, build_conv_model
import paddle
from paddle.static import amp
paddle.enable_static()
class TestAMPPromote(AmpTestBase):
class TestStaticAmpPromoteStats(AmpTestBase):
def check_promote_results(
self, use_amp, dtype, level, use_promote, expected_op_calls
self, use_amp, dtype, level, use_promote, expected_op_calls, debug_info
):
paddle.enable_static()
(
main_program,
startup_program,
......@@ -40,7 +39,9 @@ class TestAMPPromote(AmpTestBase):
op_stats_list = amp.debugging._get_op_stats_list(main_program)
self._check_op_calls(
op_stats_list[0], expected_fp16_calls=expected_op_calls
op_stats_list[0],
expected_fp16_calls=expected_op_calls,
debug_info=debug_info,
)
place = paddle.CUDAPlace(0)
......@@ -58,8 +59,10 @@ class TestAMPPromote(AmpTestBase):
exe,
x_fp32,
max_iters,
dtype,
level,
)
paddle.disable_static()
def test_static_amp_o1(self):
expected_fp16_calls = {
......@@ -77,13 +80,14 @@ class TestAMPPromote(AmpTestBase):
'O1',
use_promote=True,
expected_op_calls=expected_fp16_calls,
debug_info="TestStaticAmpPromoteStats/test_static_amp_o1",
)
def test_static_amp_o2(self):
expected_fp16_calls = {
"conv2d": 1,
"elementwise_add": 2,
"relu": 1,
"relu": 0,
"matmul_v2": 1,
"softmax": 1,
"reduce_mean": 1,
......@@ -95,8 +99,110 @@ class TestAMPPromote(AmpTestBase):
'O2',
use_promote=True,
expected_op_calls=expected_fp16_calls,
debug_info="TestStaticAmpPromoteStats/test_static_amp_o2",
)
class TestEagerAmpPromoteStats(AmpTestBase):
def check_promote_results(
self, dtype, level, use_promote, expected_op_calls, debug_info
):
model, optimizer, scaler = build_conv_model(
use_amp=True,
amp_dtype=dtype,
amp_level=level,
use_promote=use_promote,
)
model.train()
paddle.amp.debugging.enable_operator_stats_collection()
with paddle.amp.auto_cast(
enable=True, dtype=dtype, level=level, use_promote=use_promote
):
x = paddle.rand(shape=[1, 1, 6, 6], dtype='float32')
out = model(x)
loss = paddle.mean(out)
scaled = scaler.scale(loss)
scaled.backward()
scaler.minimize(optimizer, scaled)
optimizer.clear_grad()
paddle.amp.debugging.disable_operator_stats_collection()
op_stats = paddle.fluid.core.get_low_precision_op_list()
self._check_op_calls(
op_stats,
expected_fp16_calls=expected_op_calls,
debug_info=debug_info,
)
def test_o2_promote_on(self):
expected_fp16_calls = {
"conv2d": 1,
"elementwise_add": 2,
"relu": 0,
"matmul_v2": 1,
"softmax": 1,
"reduce_mean": 1,
"adamw_": 4,
}
self.check_promote_results(
'float16',
'O2',
use_promote=True,
expected_op_calls=expected_fp16_calls,
debug_info="TestEagerAmpPromoteStats/test_o2_promote_on",
)
def test_o2_promote_off(self):
expected_fp16_calls = {
"conv2d": 1,
"elementwise_add": 2,
"relu": 1,
"matmul_v2": 1,
"softmax": 1,
"reduce_mean": 1,
"adamw_": 4,
}
self.check_promote_results(
'float16',
'O2',
use_promote=False,
expected_op_calls=expected_fp16_calls,
debug_info="TestEagerAmpPromoteStats/test_o2_promote_off",
)
class TestEagerAmpPromoteSimple(AmpTestBase):
def setUp(self):
self._conv = paddle.nn.Conv2D(
in_channels=1, out_channels=6, kernel_size=3, bias_attr=False
)
self._linear = paddle.nn.Linear(in_features=4, out_features=4)
def test_o2_use_promote_on(self):
with paddle.amp.auto_cast(level='O2'):
x = paddle.rand(shape=[1, 1, 6, 6], dtype='float32')
conv_out = self._conv(x)
y = paddle.rand(shape=conv_out.shape, dtype='float16')
add_out = conv_out + y
linear_out = self._linear(add_out)
self.assertEqual(conv_out.dtype, paddle.float16)
self.assertEqual(add_out.dtype, paddle.float16)
self.assertEqual(linear_out.dtype, paddle.float32)
def test_o2_use_promote_off(self):
with paddle.amp.auto_cast(level='O2', use_promote=False):
x = paddle.rand(shape=[1, 1, 6, 6], dtype='float32')
conv_out = self._conv(x)
y = paddle.rand(shape=conv_out.shape, dtype='float16')
add_out = conv_out + y
linear_out = self._linear(add_out)
self.assertEqual(conv_out.dtype, paddle.float16)
self.assertEqual(add_out.dtype, paddle.float16)
self.assertEqual(linear_out.dtype, paddle.float16)
if __name__ == '__main__':
unittest.main()
......@@ -310,6 +310,7 @@ class TestStaticBF16(AmpTestBase):
exe,
x_np,
max_iters,
"bfloat16",
level,
)
return losses
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册