Unverified commit 418edae5, authored by xu98bin, committed by GitHub

auto parallel bf16 (#49079)

* auto parallel bf16
Parent 1078e064
@@ -134,4 +134,7 @@ REGISTER_OP_CUDA_KERNEL(c_concat,
                         ops::CConcatOpCUDAKernel<double>,
                         ops::CConcatOpCUDAKernel<int>,
                         ops::CConcatOpCUDAKernel<int64_t>,
+#if NCCL_VERSION_CODE >= 21000
+                        ops::CConcatOpCUDAKernel<plat::bfloat16>,
+#endif
                         ops::CConcatOpCUDAKernel<plat::float16>);
@@ -22,4 +22,7 @@ REGISTER_OP_CUDA_KERNEL(c_identity,
                         ops::CIdentityOpKernel<double>,
                         ops::CIdentityOpKernel<int>,
                         ops::CIdentityOpKernel<int64_t>,
+#if NCCL_VERSION_CODE >= 21000
+                        ops::CIdentityOpKernel<plat::bfloat16>,
+#endif
                         ops::CIdentityOpKernel<plat::float16>);
@@ -76,6 +76,13 @@ set_field_default_config(AMP, "use_pure_fp16", False)
 set_field_default_config(AMP, "use_fp16_guard", True)
 set_field_default_config(AMP, "use_optimizer_fp16", False)
+set_field_default_config(AMP, "enable_bf16", False)
+set_field_default_config(AMP, "custom_bf16_list", [])
+set_field_default_config(AMP, "custom_fp32_list", [])
+set_field_default_config(AMP, "custom_fp32_varnames", [])
+set_field_default_config(AMP, "use_pure_bf16", False)
+set_field_default_config(AMP, "use_bf16_guard", False)

 #########################################
 # sharding configuration
 #########################################
...
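These new AMP defaults surface as attributes on the auto-parallel strategy object. A minimal sketch of how a user turns them on, mirroring the test added later in this commit (the `custom_fp32_list` override is a hypothetical illustration; all three custom lists default to empty):

```python
from paddle.distributed.fleet import auto

strategy = auto.Strategy()
amp = strategy.amp
amp.enable = True        # turn the AMP pass family on
amp.enable_bf16 = True   # route training through the auto_parallel_bf16 pass
# Hypothetical override: keep a numerically sensitive op in fp32.
amp.custom_fp32_list = ["softmax_with_cross_entropy"]
```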
@@ -266,8 +266,12 @@ def is_parameter_related(varname, block):
         varname = varname[: varname.index(".subprog_")]
     if ".cast_fp" in varname:
         varname = varname[: varname.index(".cast_fp")]
+    if ".cast_bf" in varname:
+        varname = varname[: varname.index(".cast_bf")]
     if ".quantized" in varname:
         varname = varname[: varname.index(".quantized")]
+    # if "@RESHARD" in varname:
+    #     varname = varname[: varname.index("@RESHARD")]
     assert block._find_var_recursive(varname)
     var = block._var_recursive(varname)
     return var.is_parameter
...
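The intent of the added `.cast_bf` branch is easier to see in isolation. A standalone sketch of the same suffix stripping (the helper name and the sample variable names are hypothetical):

```python
def strip_autocast_suffixes(varname):
    """Cut a derived variable name back to the underlying parameter name."""
    for marker in (".subprog_", ".cast_fp", ".cast_bf", ".quantized"):
        if marker in varname:
            varname = varname[: varname.index(marker)]
    return varname


print(strip_autocast_suffixes("linear_0.w_0.cast_bf16"))  # -> linear_0.w_0
print(strip_autocast_suffixes("linear_0.w_0.cast_fp16"))  # -> linear_0.w_0
```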
@@ -376,7 +376,7 @@ def _right_operand_parameter_matmul_backward(ctx, *args, **kwargs):
         check_variable_and_dtype(
             Out_grad,
             'tensor',
-            ['float16', 'float32', 'float64', 'int32', 'int64'],
+            ['float16', 'float32', 'float64', 'int32', 'int64', 'uint16'],
             '_c_identity',
         )
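The `'uint16'` entries exist because Paddle's static graph has no separate NumPy dtype for bfloat16: BF16 variables are stored and reported as uint16, so the dtype whitelists must accept that string. A small sketch under that assumption (variable name and shape are arbitrary):

```python
import paddle
from paddle.fluid.data_feeder import check_variable_and_dtype

paddle.enable_static()

# dtype='uint16' is Paddle's proxy for bfloat16 in the static graph.
x = paddle.static.data(name="x", shape=[4, 8], dtype="uint16")
print(x.dtype)  # expected: VarType.BF16 (printed form varies by version)

# Passes only because 'uint16' is in the whitelist, as in the hunks above.
check_variable_and_dtype(
    x, "x", ["float16", "float32", "float64", "uint16"], "linear"
)
```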
@@ -417,13 +417,13 @@ def _right_operand_parameter_matmul_backward(ctx, *args, **kwargs):
         check_variable_and_dtype(
             intermediate_var_0,
             'x',
-            ['float16', 'float32', 'float64'],
+            ['float16', 'float32', 'float64', 'uint16'],
             'linear',
         )
         check_dtype(
             intermediate_var_0.dtype,
             'dtype',
-            ['float16', 'float32', 'float64'],
+            ['float16', 'float32', 'float64', 'uint16'],
             'linear',
         )
         set_comm_op_dist_attr_for_program(
@@ -835,7 +835,7 @@ class DistributedMatmulImpl0(DistributedOperatorImpl):
         check_variable_and_dtype(
             X_var,
             'tensor',
-            ['float16', 'float32', 'float64', 'int32', 'int64'],
+            ['float16', 'float32', 'float64', 'int32', 'int64', 'uint16'],
             '_c_identity',
         )
@@ -854,12 +854,15 @@ class DistributedMatmulImpl0(DistributedOperatorImpl):
             intermediate_var_0.desc.set_shape(ref_shape_x)
         check_variable_and_dtype(
-            intermediate_var_0, 'x', ['float16', 'float32', 'float64'], 'linear'
+            intermediate_var_0,
+            'x',
+            ['float16', 'float32', 'float64', 'uint16'],
+            'linear',
         )
         check_dtype(
             intermediate_var_0.dtype,
             'dtype',
-            ['float16', 'float32', 'float64'],
+            ['float16', 'float32', 'float64', 'uint16'],
             'linear',
         )
         attrs = {
@@ -1183,10 +1186,13 @@ class DistributedMatmulImpl1(DistributedOperatorImpl):
         group = new_process_group(group_ranks)

         check_variable_and_dtype(
-            X_var, 'x', ['float16', 'float32', 'float64'], 'linear'
+            X_var, 'x', ['float16', 'float32', 'float64', 'uint16'], 'linear'
         )
         check_dtype(
-            X_var.dtype, 'dtype', ['float16', 'float32', 'float64'], 'linear'
+            X_var.dtype,
+            'dtype',
+            ['float16', 'float32', 'float64', 'uint16'],
+            'linear',
         )
         attrs = {
             'transpose_X': trans_x,
@@ -1731,7 +1737,7 @@ class DistributedMatmulV2Impl0(DistributedOperatorImpl):
         check_variable_and_dtype(
             X_var,
             'tensor',
-            ['float16', 'float32', 'float64', 'int32', 'int64'],
+            ['float16', 'float32', 'float64', 'int32', 'int64', 'uint16'],
             '_c_identity',
         )
         c_identity_op = main_block.append_op(
@@ -1749,12 +1755,15 @@ class DistributedMatmulV2Impl0(DistributedOperatorImpl):
             intermediate_var_0.desc.set_shape(ref_shape_x)
         check_variable_and_dtype(
-            intermediate_var_0, 'x', ['float16', 'float32', 'float64'], 'linear'
+            intermediate_var_0,
+            'x',
+            ['float16', 'float32', 'float64', 'uint16'],
+            'linear',
         )
         check_dtype(
             intermediate_var_0.dtype,
             'dtype',
-            ['float16', 'float32', 'float64'],
+            ['float16', 'float32', 'float64', 'uint16'],
             'linear',
         )
         attrs = {
@@ -2077,10 +2086,13 @@ class DistributedMatmulV2Impl1(DistributedOperatorImpl):
         group = new_process_group(group_ranks)

         check_variable_and_dtype(
-            X_var, 'x', ['float16', 'float32', 'float64'], 'linear'
+            X_var, 'x', ['float16', 'float32', 'float64', 'uint16'], 'linear'
         )
         check_dtype(
-            X_var.dtype, 'dtype', ['float16', 'float32', 'float64'], 'linear'
+            X_var.dtype,
+            'dtype',
+            ['float16', 'float32', 'float64', 'uint16'],
+            'linear',
         )
         attrs = {
             'trans_x': trans_x,
@@ -2610,7 +2622,7 @@ class DistributedMulImpl0(DistributedOperatorImpl):
         check_variable_and_dtype(
             X_var,
             'tensor',
-            ['float16', 'float32', 'float64', 'int32', 'int64'],
+            ['float16', 'float32', 'float64', 'int32', 'int64', 'uint16'],
             '_c_identity',
         )
         c_identity_op = main_block.append_op(
@@ -2628,12 +2640,15 @@ class DistributedMulImpl0(DistributedOperatorImpl):
             intermediate_var_0.desc.set_shape(ref_shape_x)
         check_variable_and_dtype(
-            intermediate_var_0, 'x', ['float16', 'float32', 'float64'], 'linear'
+            intermediate_var_0,
+            'x',
+            ['float16', 'float32', 'float64', 'uint16'],
+            'linear',
        )
         check_dtype(
             intermediate_var_0.dtype,
             'dtype',
-            ['float16', 'float32', 'float64'],
+            ['float16', 'float32', 'float64', 'uint16'],
             'linear',
         )
         # attrs = {'trans_x': False, 'trans_y': False}
@@ -2965,10 +2980,13 @@ class DistributedMulImpl1(DistributedOperatorImpl):
         group = new_process_group(group_ranks)

         check_variable_and_dtype(
-            X_var, 'x', ['float16', 'float32', 'float64'], 'linear'
+            X_var, 'x', ['float16', 'float32', 'float64', 'uint16'], 'linear'
         )
         check_dtype(
-            X_var.dtype, 'dtype', ['float16', 'float32', 'float64'], 'linear'
+            X_var.dtype,
+            'dtype',
+            ['float16', 'float32', 'float64', 'uint16'],
+            'linear',
         )
         # attrs = {'trans_x': False, 'trans_y': False}
         attrs = {
...
@@ -221,13 +221,21 @@ class Parallelizer:
                 self._dist_context.serial_feed_vars["inputs"]
                 + self._dist_context.serial_feed_vars["labels"]
             )
-            if config["use_pure_fp16"]:
+            if config["enable_bf16"]:
+                auto_parallel_bf16_pass = new_pass("auto_parallel_bf16", config)
+                auto_parallel_bf16_pass.apply(
+                    [main_program], [startup_program], self._pass_context
+                )
+                loss = auto_parallel_bf16_pass.get_loss()
+            elif config["use_pure_fp16"]:
                 config["base_opt"] = optimizer
                 auto_parallel_fp16_pass = new_pass("auto_parallel_fp16", config)
                 auto_parallel_fp16_pass.apply(
                     [main_program], [startup_program], self._pass_context
                 )
                 loss = auto_parallel_fp16_pass.get_loss()
             else:
                 auto_parallel_amp_pass = new_pass("auto_parallel_amp", config)
                 auto_parallel_amp_pass.apply(
...
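Outside the Parallelizer, the same pass is reachable through the generic pass registry. A hedged sketch of that wiring (in real use `config` is the full AMP strategy dict plus the dist_context, params_grads and loss, and the programs come from the partitioned distributed graph; the bare default programs and minimal config below are placeholders):

```python
import paddle
from paddle.distributed.passes import PassContext, new_pass

paddle.enable_static()
main_program = paddle.static.default_main_program()
startup_program = paddle.static.default_startup_program()

# Placeholder config; see the Parallelizer above for the real contents.
config = {"enable_bf16": True}

bf16_pass = new_pass("auto_parallel_bf16", config)
bf16_pass.apply([main_program], [startup_program], PassContext())
loss = bf16_pass.get_loss()
```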
@@ -18,6 +18,7 @@ from .auto_parallel_gradient_merge import *  # noqa: F403
 from .auto_parallel_sharding import *  # noqa: F403
 from .auto_parallel_amp import *  # noqa: F403
 from .auto_parallel_fp16 import *  # noqa: F403
+from .auto_parallel_bf16 import *  # noqa: F403
 from .auto_parallel_recompute import *  # noqa: F403
 from .auto_parallel_quantization import *  # noqa: F403
 from .auto_parallel_data_parallel_optimization import *  # noqa: F403
...
This diff is collapsed.
@@ -128,5 +128,6 @@ if(WITH_DISTRIBUTE AND WITH_GPU)
   py_test_modules(test_cluster_partition MODULES test_cluster_partition)
   py_test_modules(test_convert_to_process_meshes MODULES
                   test_convert_to_process_meshes)
+  py_test_modules(test_pass_bf16 MODULES test_pass_bf16)
   py_test_modules(test_dist_saver MODULES test_dist_saver)
 endif()
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import random
import unittest
import numpy as np
import paddle
import paddle.fluid.core as core
import paddle.nn as nn
from paddle.distributed.fleet import auto
from paddle.fluid.contrib.mixed_precision.bf16.amp_utils import _valid_types
from paddle.fluid.contrib.mixed_precision.fp16_utils import find_true_prev_op
from paddle.fluid.dygraph.parallel import ParallelEnv
from paddle.static import InputSpec
from paddle.vision.datasets import MNIST

paddle.enable_static()


def apply_pass(use_bf16=False):
    strategy = auto.Strategy()
    strategy.auto_mode = "semi"
    strategy.reinit = True
    if use_bf16:
        amp = strategy.amp
        amp.enable = True
        amp.enable_bf16 = True
    return strategy


class MnistDataset(MNIST):
    def __init__(self, mode, return_label=True):
        super().__init__(mode=mode)
        self.return_label = return_label

    def __getitem__(self, idx):
        img = np.reshape(self.images[idx], [1, 28, 28])
        if self.return_label:
            return img, np.array(self.labels[idx]).astype('int64')
        return (img,)

    def __len__(self):
        return len(self.images)


def reset_prog():
    paddle.fluid.framework.switch_main_program(paddle.static.Program())
    paddle.fluid.framework.switch_startup_program(paddle.static.Program())


class Model(nn.Layer):
    def __init__(self):
        super().__init__()
        self.flatten = nn.Flatten()
        self.fc1 = nn.Linear(784, 120)
        self.relu1 = nn.ReLU()
        self.fc2 = nn.Linear(120, 10)

    def forward(self, input):
        input.stop_gradient = True
        x = self.flatten(input)
        x = self.relu1(self.fc1(x))
        x = self.fc2(x)
        return x


class TestBF16Pass(unittest.TestCase):
    def setUp(self):
        self.rtol = 1e-5
        self.atol = 1e-8
        self.batch_size = 256
        self.batch_num = 10
        self.dataset = MnistDataset("train")
        self.eval_dataset = MnistDataset("test")

    def init(self, engine):
        paddle.seed(2021)
        np.random.seed(2021)
        random.seed(2021)
        place = paddle.fluid.CUDAPlace(ParallelEnv().dev_id)
        engine._executor = paddle.static.Executor(place)

    def get_engine(self, use_bf16=False):
        reset_prog()
        strategy = apply_pass(use_bf16)
        model = Model()
        opt = paddle.optimizer.SGD(0.001, parameters=model.parameters())
        loss = nn.CrossEntropyLoss()
        engine = auto.Engine(model, loss, opt, strategy=strategy)
        self.init(engine)
        return engine

    def check_program(self, program):
        bf16_op_list = {
            "matmul_v2",
            "elementwise_add",
            "relu",
            "elementwise_add_grad",
            "matmul_v2_grad",
            "relu_grad",
        }
        fp32_op_list = {
            "flatten_contiguous_range",
            "reduce_mean",
            "softmax_with_cross_entropy",
            "fill_constant",
            "reduce_mean_grad",
            "softmax_with_cross_entropy_grad",
        }
        for block in program.blocks:
            for op in block.ops:
                # The sets above hold op type names, so compare op.type.
                if (
                    op.type not in bf16_op_list
                    and op.type not in fp32_op_list
                ):
                    continue
                for in_name in op.input_names:
                    for in_var_name in op.input(in_name):
                        var = None
                        try:
                            var = block.var(in_var_name)
                        except ValueError:
                            var = block._var_recursive(in_var_name)
                        if var is None or var.type not in _valid_types:
                            break
                        if op.type in bf16_op_list:
                            assert var.dtype == core.VarDesc.VarType.BF16
                            if "cast_bf16" in in_var_name:
                                if "@GRAD" in in_var_name:
                                    tmp_in_var_name = in_var_name[
                                        : in_var_name.find("@GRAD")
                                    ]
                                else:
                                    tmp_in_var_name = in_var_name
                                prev_op = find_true_prev_op(
                                    block.ops, op, tmp_in_var_name
                                )
                                assert prev_op is not None
                                assert prev_op.type == "cast"
                                for in_name in prev_op.input_names:
                                    for in_var_name in prev_op.input(in_name):
                                        var = block.var(in_var_name)
                                        assert (
                                            var.dtype
                                            == core.VarDesc.VarType.FP32
                                        )
                        elif op.type in fp32_op_list:
                            if (
                                op.type == "softmax_with_cross_entropy"
                                or op.type == "softmax_with_cross_entropy_grad"
                            ) and in_var_name == "label0":
                                continue
                            assert var.dtype == core.VarDesc.VarType.FP32
                            if "cast_fp32" in in_var_name:
                                # Mirror the bf16 branch: strip a @GRAD suffix
                                # before looking up the producing cast op.
                                if "@GRAD" in in_var_name:
                                    tmp_in_var_name = in_var_name[
                                        : in_var_name.find("@GRAD")
                                    ]
                                else:
                                    tmp_in_var_name = in_var_name
                                prev_op = find_true_prev_op(
                                    block.ops, op, tmp_in_var_name
                                )
                                assert prev_op is not None
                                assert prev_op.type == "cast"
                                for in_name in prev_op.input_names:
                                    for in_var_name in prev_op.input(in_name):
                                        var = block.var(in_var_name)
                                        assert (
                                            var.dtype
                                            == core.VarDesc.VarType.BF16
                                        )
                for out_name in op.output_names:
                    for out_var_name in op.output(out_name):
                        var = None
                        try:
                            var = block.var(out_var_name)
                        except ValueError:
                            var = block._var_recursive(out_var_name)
                        if var is None or var.type not in _valid_types:
                            break
                        if op.type in bf16_op_list:
                            assert var.dtype == core.VarDesc.VarType.BF16
                        elif op.type in fp32_op_list:
                            assert var.dtype == core.VarDesc.VarType.FP32

    def test_bf16_pass(self):
        bf16_o1_engine = self.get_engine(True)
        inputs_spec = [InputSpec([None, 1, 28, 28], 'float32', 'input0')]
        labels_spec = [InputSpec([None, 1], 'int64', 'label0')]
        bf16_o1_engine.prepare(
            inputs_spec=inputs_spec, labels_spec=labels_spec, mode="train"
        )
        self.check_program(bf16_o1_engine._dist_main_progs["train"][0])
        print("BF16!check program successfully!")


if __name__ == "__main__":
    unittest.main()
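For context, end-to-end training with the bf16 strategy exercised by this test would look roughly like the sketch below; `Engine.fit` and its `epochs`/`batch_size` arguments follow the usual auto-parallel Engine API but are not part of this diff, so treat them as assumptions:

```python
model = Model()
opt = paddle.optimizer.SGD(0.001, parameters=model.parameters())
engine = auto.Engine(
    model, nn.CrossEntropyLoss(), opt, strategy=apply_pass(use_bf16=True)
)
# Assumed API: auto.Engine.fit accepts a Dataset plus epochs/batch_size.
engine.fit(MnistDataset("train"), epochs=1, batch_size=256)
```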
@@ -41,6 +41,13 @@ class TestStrategy(unittest.TestCase):
         self.assertEqual(amp.use_fp16_guard, True)
         self.assertEqual(amp.use_optimizer_fp16, False)
+        self.assertEqual(amp.enable_bf16, False)
+        self.assertEqual(amp.custom_bf16_list, [])
+        self.assertEqual(amp.custom_fp32_list, [])
+        self.assertEqual(amp.custom_fp32_varnames, [])
+        self.assertEqual(amp.use_pure_bf16, False)
+        self.assertEqual(amp.use_bf16_guard, False)

         sharding = strategy.sharding
         self.assertEqual(sharding.enable, False)
         self.assertEqual(sharding.stage, 1)
...