未验证 提交 6f3c9643 编写于 作者: J JZ-LIANG 提交者: GitHub

Eb118 BF16 Adoption (#52827)

* pr1

* pr2

* pr3

* fixed unitest

* adopt for scale
上级 8cbc75ca
...@@ -62,6 +62,8 @@ set_field_default_config(RECOMPUTE, "enable_tuning", False) ...@@ -62,6 +62,8 @@ set_field_default_config(RECOMPUTE, "enable_tuning", False)
######################################### #########################################
AMP = "amp" AMP = "amp"
set_field_default_config(AMP, "enable", False) set_field_default_config(AMP, "enable", False)
set_field_default_config(AMP, "dtype", "float16")
set_field_default_config(AMP, "level", "o1")
set_field_default_config(AMP, "init_loss_scaling", 32768.0) set_field_default_config(AMP, "init_loss_scaling", 32768.0)
set_field_default_config(AMP, "incr_every_n_steps", 1000) set_field_default_config(AMP, "incr_every_n_steps", 1000)
set_field_default_config(AMP, "decr_every_n_nan_or_inf", 2) set_field_default_config(AMP, "decr_every_n_nan_or_inf", 2)
...@@ -71,8 +73,8 @@ set_field_default_config(AMP, "use_dynamic_loss_scaling", True) ...@@ -71,8 +73,8 @@ set_field_default_config(AMP, "use_dynamic_loss_scaling", True)
set_field_default_config(AMP, "custom_white_list", []) set_field_default_config(AMP, "custom_white_list", [])
set_field_default_config(AMP, "custom_black_list", []) set_field_default_config(AMP, "custom_black_list", [])
set_field_default_config(AMP, "custom_black_varnames", []) set_field_default_config(AMP, "custom_black_varnames", [])
set_field_default_config(AMP, "use_pure_fp16", False) set_field_default_config(AMP, "use_fp16_guard", False)
set_field_default_config(AMP, "use_fp16_guard", True) set_field_default_config(AMP, "use_bf16_guard", False)
set_field_default_config(AMP, "use_optimizer_fp16", False) set_field_default_config(AMP, "use_optimizer_fp16", False)
######################################### #########################################
......
...@@ -459,7 +459,7 @@ class DistributedEmbeddingImpl(DistributedOperatorImpl): ...@@ -459,7 +459,7 @@ class DistributedEmbeddingImpl(DistributedOperatorImpl):
check_variable_and_dtype( check_variable_and_dtype(
Out_var, Out_var,
'tensor', 'tensor',
['float16', 'float32', 'float64', 'int32', 'int64'], ['float16', 'float32', 'float64', 'int32', 'int64', 'uint16'],
'c_allreduce_sum', 'c_allreduce_sum',
) )
...@@ -649,7 +649,7 @@ class DistributedEmbeddingImpl(DistributedOperatorImpl): ...@@ -649,7 +649,7 @@ class DistributedEmbeddingImpl(DistributedOperatorImpl):
check_variable_and_dtype( check_variable_and_dtype(
Out_grad, Out_grad,
'tensor', 'tensor',
['float16', 'float32', 'float64', 'int32', 'int64'], ['float16', 'float32', 'float64', 'int32', 'int64', 'uint16'],
'_c_identity', '_c_identity',
) )
...@@ -691,12 +691,15 @@ class DistributedEmbeddingImpl(DistributedOperatorImpl): ...@@ -691,12 +691,15 @@ class DistributedEmbeddingImpl(DistributedOperatorImpl):
}, },
) )
check_variable_and_dtype( check_variable_and_dtype(
intermediate_var_0, 'x', ['float16', 'float32', 'float64'], 'linear' intermediate_var_0,
'x',
['float16', 'float32', 'float64', 'uint16'],
'linear',
) )
check_dtype( check_dtype(
intermediate_var_0.dtype, intermediate_var_0.dtype,
'dtype', 'dtype',
['float16', 'float32', 'float64'], ['float16', 'float32', 'float64', 'uint16'],
'linear', 'linear',
) )
......
...@@ -254,17 +254,26 @@ class Parallelizer: ...@@ -254,17 +254,26 @@ class Parallelizer:
self._dist_context.serial_feed_vars["inputs"] self._dist_context.serial_feed_vars["inputs"]
+ self._dist_context.serial_feed_vars["labels"] + self._dist_context.serial_feed_vars["labels"]
) )
if config["use_pure_fp16"]: self._logger.info(
"Applying AMP-{}-{} ...".format(
config["dtype"], config['level']
),
)
if config['level'] == "o1":
auto_parallel_amp_pass = new_pass("auto_parallel_amp", config)
auto_parallel_amp_pass.apply(
[main_program], [startup_program], self._pass_context
)
loss = auto_parallel_amp_pass.get_loss()
elif config['level'] in ['o2', 'o3']:
config["base_opt"] = optimizer config["base_opt"] = optimizer
auto_parallel_fp16_pass = new_pass("auto_parallel_fp16", config) auto_parallel_fp16_pass = new_pass("auto_parallel_fp16", config)
auto_parallel_fp16_pass.apply( auto_parallel_fp16_pass.apply(
[main_program], [startup_program], self._pass_context [main_program], [startup_program], self._pass_context
) )
loss = auto_parallel_fp16_pass.get_loss()
else: else:
auto_parallel_amp_pass = new_pass("auto_parallel_amp", config) raise ValueError("AMP level should be one of o1, o2, o3")
auto_parallel_amp_pass.apply(
[main_program], [startup_program], self._pass_context
)
# apply recompute pass # apply recompute pass
# recompute is then train-only optimization # recompute is then train-only optimization
......
...@@ -40,6 +40,9 @@ if(WITH_DISTRIBUTE AND WITH_GPU) ...@@ -40,6 +40,9 @@ if(WITH_DISTRIBUTE AND WITH_GPU)
py_test_modules(test_random_ctrl MODULES test_random_ctrl ENVS ${dist_ENVS}) py_test_modules(test_random_ctrl MODULES test_random_ctrl ENVS ${dist_ENVS})
set_tests_properties(test_random_ctrl PROPERTIES LABELS "RUN_TYPE=EXCLUSIVE" set_tests_properties(test_random_ctrl PROPERTIES LABELS "RUN_TYPE=EXCLUSIVE"
TIMEOUT 50) TIMEOUT 50)
py_test_modules(test_amp_o2_pass MODULES test_amp_o2_pass ENVS ${dist_ENVS})
set_tests_properties(test_amp_o2_pass PROPERTIES LABELS "RUN_TYPE=EXCLUSIVE"
TIMEOUT 50)
py_test_modules(test_iterable_dataset MODULES test_iterable_dataset ENVS py_test_modules(test_iterable_dataset MODULES test_iterable_dataset ENVS
${dist_ENVS}) ${dist_ENVS})
set_tests_properties(test_iterable_dataset set_tests_properties(test_iterable_dataset
......
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import random
import re
import unittest
import numpy as np
from get_gpt_model import FakeDataset, generate_model
import paddle
from paddle.distributed.fleet import auto
from paddle.fluid.framework import core
paddle.enable_static()
def get_cuda_version():
result = os.popen("nvcc --version").read()
regex = r'release (\S+),'
match = re.search(regex, result)
if match:
num = str(match.group(1))
integer, decimal = num.split('.')
return int(integer) * 1000 + int(float(decimal) * 10)
else:
return -1
def apply_pass(use_amp=False, amp_dtype="bfloat16"):
strategy = auto.Strategy()
strategy.auto_mode = "semi"
strategy.reinit = True
if use_amp:
amp = strategy.amp
amp.enable = True
amp.dtype = amp_dtype
amp.level = "o2"
amp.custom_black_list = [
'c_softmax_with_cross_entropy',
'elementwise_div',
'reduce_sum',
]
return strategy
def reset_prog():
paddle.fluid.framework.switch_main_program(paddle.static.Program())
paddle.fluid.framework.switch_startup_program(paddle.static.Program())
class TestShardingStage2WithNewEXE(unittest.TestCase):
def setUp(self):
self.batch_size = 2
self.batch_num = 10
self.clip_norm = 0.2
self.dataset = FakeDataset(self.batch_size * self.batch_num)
def init(self, engine):
paddle.seed(2022)
np.random.seed(2022)
random.seed(2022)
place = paddle.fluid.CUDAPlace(paddle.distributed.ParallelEnv().dev_id)
engine._executor = paddle.static.Executor(place)
def get_engine(self, use_amp=False, amp_dtype="bfloat16"):
reset_prog()
strategy = apply_pass(use_amp, amp_dtype)
# clip = paddle.nn.ClipGradByGlobalNorm(self.clip_norm)
clip = None
opt = paddle.optimizer.AdamW(learning_rate=0.00001, grad_clip=clip)
model, loss = generate_model("mp")
engine = auto.Engine(model, loss, opt, strategy=strategy)
self.init(engine)
return engine
def check_bf16(self, program):
num_bf16 = 0
num_fp16 = 0
num_fp32 = 0
for p in program.all_parameters():
if p.dtype == core.VarDesc.VarType.FP32:
num_fp32 += 1
if p.dtype == core.VarDesc.VarType.FP16:
num_fp16 += 1
if p.dtype == core.VarDesc.VarType.BF16:
num_bf16 += 1
self.assertEqual(num_bf16, 25)
self.assertEqual(num_fp16, 0)
self.assertEqual(num_fp32, 11)
def test_param_grad_fuse_overlap(self):
# std
mp_engine = self.get_engine(use_amp=False)
mp_history = mp_engine.fit(
self.dataset,
3,
epochs=1,
steps_per_epoch=self.batch_num,
log_freq=1,
batch_size=self.batch_size,
)
loss0 = mp_history.history['loss'][0]
# bf16
mp_bf16_engine = self.get_engine(use_amp=True)
if not paddle.is_compiled_with_cuda() or get_cuda_version() < 11000:
return
mp_bf16_history = mp_bf16_engine.fit(
self.dataset,
3,
epochs=1,
steps_per_epoch=self.batch_num,
log_freq=1,
batch_size=self.batch_size,
)
loss1 = mp_bf16_history.history['loss'][0]
np.testing.assert_allclose(loss0, loss1, atol=1e-3, rtol=1e-2)
self.check_bf16(mp_bf16_engine.main_program)
if __name__ == "__main__":
unittest.main()
...@@ -38,7 +38,7 @@ def apply_pass(use_amp=False, level=None): ...@@ -38,7 +38,7 @@ def apply_pass(use_amp=False, level=None):
] ]
amp.init_loss_scaling = 32768 amp.init_loss_scaling = 32768
amp.use_fp16_guard = False amp.use_fp16_guard = False
amp.use_pure_fp16 = level in ["o2", "o3"] amp.level = level
amp.use_optimizer_fp16 = level == "o3" amp.use_optimizer_fp16 = level == "o3"
print("amp level: ", level) print("amp level: ", level)
return strategy return strategy
......
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import subprocess
import sys
import tempfile
import unittest
class TestAMPO2(unittest.TestCase):
def test_bf16(self):
file_dir = os.path.dirname(os.path.abspath(__file__))
launch_model_path = os.path.join(file_dir, "amp_o2_pass.py")
if os.environ.get("WITH_COVERAGE", "OFF") == "ON":
coverage_args = ["-m", "coverage", "run", "--branch", "-p"]
else:
coverage_args = []
tmp_dir = tempfile.TemporaryDirectory()
cmd = (
[sys.executable, "-u"]
+ coverage_args
+ [
"-m",
"paddle.distributed.launch",
"--devices",
"0,1",
"--log_dir",
tmp_dir.name,
launch_model_path,
]
)
process = subprocess.Popen(cmd)
process.wait()
self.assertEqual(process.returncode, 0)
tmp_dir.cleanup()
if __name__ == "__main__":
unittest.main()
...@@ -13,13 +13,13 @@ ...@@ -13,13 +13,13 @@
# limitations under the License. # limitations under the License.
import os import os
# import yaml # import yaml
import unittest import unittest
from paddle.distributed.fleet import auto from paddle.distributed.fleet import auto
class TestStrategy(unittest.TestCase): class TestStrategy(unittest.TestCase):
def test_default_config(self): def test_default_config(self):
strategy = auto.Strategy() strategy = auto.Strategy()
...@@ -29,6 +29,8 @@ class TestStrategy(unittest.TestCase): ...@@ -29,6 +29,8 @@ class TestStrategy(unittest.TestCase):
amp = strategy.amp amp = strategy.amp
self.assertEqual(amp.enable, False) self.assertEqual(amp.enable, False)
self.assertAlmostEqual(amp.dtype, "float16")
self.assertAlmostEqual(amp.level, "o1")
self.assertAlmostEqual(amp.init_loss_scaling, 32768.0) self.assertAlmostEqual(amp.init_loss_scaling, 32768.0)
self.assertEqual(amp.incr_every_n_steps, 1000) self.assertEqual(amp.incr_every_n_steps, 1000)
self.assertEqual(amp.decr_every_n_nan_or_inf, 2) self.assertEqual(amp.decr_every_n_nan_or_inf, 2)
...@@ -38,8 +40,7 @@ class TestStrategy(unittest.TestCase): ...@@ -38,8 +40,7 @@ class TestStrategy(unittest.TestCase):
self.assertEqual(amp.custom_black_list, []) self.assertEqual(amp.custom_black_list, [])
self.assertEqual(amp.custom_white_list, []) self.assertEqual(amp.custom_white_list, [])
self.assertEqual(amp.custom_black_varnames, []) self.assertEqual(amp.custom_black_varnames, [])
self.assertEqual(amp.use_pure_fp16, False) self.assertEqual(amp.use_fp16_guard, False)
self.assertEqual(amp.use_fp16_guard, True)
self.assertEqual(amp.use_optimizer_fp16, False) self.assertEqual(amp.use_optimizer_fp16, False)
sharding = strategy.sharding sharding = strategy.sharding
...@@ -92,7 +93,6 @@ class TestStrategy(unittest.TestCase): ...@@ -92,7 +93,6 @@ class TestStrategy(unittest.TestCase):
amp.custom_white_list = ["x"] amp.custom_white_list = ["x"]
amp.custom_black_list = ["y"] amp.custom_black_list = ["y"]
amp.custom_black_varnames = ["z"] amp.custom_black_varnames = ["z"]
amp.use_pure_fp16 = True
amp.use_fp16_guard = False amp.use_fp16_guard = False
amp.use_optimizer_fp16 = True amp.use_optimizer_fp16 = True
self.assertEqual(amp.enable, True) self.assertEqual(amp.enable, True)
...@@ -105,7 +105,6 @@ class TestStrategy(unittest.TestCase): ...@@ -105,7 +105,6 @@ class TestStrategy(unittest.TestCase):
self.assertEqual(amp.custom_white_list, ["x"]) self.assertEqual(amp.custom_white_list, ["x"])
self.assertEqual(amp.custom_black_list, ["y"]) self.assertEqual(amp.custom_black_list, ["y"])
self.assertEqual(amp.custom_black_varnames, ["z"]) self.assertEqual(amp.custom_black_varnames, ["z"])
self.assertEqual(amp.use_pure_fp16, True)
self.assertEqual(amp.use_fp16_guard, False) self.assertEqual(amp.use_fp16_guard, False)
self.assertEqual(amp.use_optimizer_fp16, True) self.assertEqual(amp.use_optimizer_fp16, True)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册