Unverified commit 418edae5, authored by xu98bin and committed by GitHub

auto parallel bf16 (#49079)

* auto parallel bf16
Parent 1078e064
......@@ -134,4 +134,7 @@ REGISTER_OP_CUDA_KERNEL(c_concat,
ops::CConcatOpCUDAKernel<double>,
ops::CConcatOpCUDAKernel<int>,
ops::CConcatOpCUDAKernel<int64_t>,
#if NCCL_VERSION_CODE >= 21000
ops::CConcatOpCUDAKernel<plat::bfloat16>,
#endif
ops::CConcatOpCUDAKernel<plat::float16>);
......@@ -22,4 +22,7 @@ REGISTER_OP_CUDA_KERNEL(c_identity,
ops::CIdentityOpKernel<double>,
ops::CIdentityOpKernel<int>,
ops::CIdentityOpKernel<int64_t>,
#if NCCL_VERSION_CODE >= 21000
ops::CIdentityOpKernel<plat::bfloat16>,
#endif
ops::CIdentityOpKernel<plat::float16>);
......@@ -76,6 +76,13 @@ set_field_default_config(AMP, "use_pure_fp16", False)
set_field_default_config(AMP, "use_fp16_guard", True)
set_field_default_config(AMP, "use_optimizer_fp16", False)
set_field_default_config(AMP, "enable_bf16", False)
set_field_default_config(AMP, "custom_bf16_list", [])
set_field_default_config(AMP, "custom_fp32_list", [])
set_field_default_config(AMP, "custom_fp32_varnames", [])
set_field_default_config(AMP, "use_pure_bf16", False)
set_field_default_config(AMP, "use_bf16_guard", False)
#########################################
# sharding configuration
#########################################
......
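The new fields are exposed on the auto-parallel `Strategy` object (exercised by the test added later in this diff). Below is a minimal sketch of opting a semi-auto-parallel job into the bf16 path; the field names follow the defaults above, and the comments assume the fields behave analogously to their fp16 counterparts.

```python
from paddle.distributed.fleet import auto

# Sketch: enable the bf16 mixed-precision path through the auto-parallel strategy.
strategy = auto.Strategy()
strategy.auto_mode = "semi"

amp = strategy.amp
amp.enable = True            # turn the mixed-precision pass family on
amp.enable_bf16 = True       # select auto_parallel_bf16 instead of fp16/amp
amp.custom_bf16_list = []    # op types to additionally run in bf16 (empty: built-in lists)
amp.custom_fp32_list = []    # op types to keep in fp32
amp.use_bf16_guard = False   # if True, only cast inside bf16 guard scopes
```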
......@@ -266,8 +266,12 @@ def is_parameter_related(varname, block):
varname = varname[: varname.index(".subprog_")]
if ".cast_fp" in varname:
varname = varname[: varname.index(".cast_fp")]
if ".cast_bf" in varname:
varname = varname[: varname.index(".cast_bf")]
if ".quantized" in varname:
varname = varname[: varname.index(".quantized")]
# if "@RESHARD" in varname:
# varname = varname[: varname.index("@RESHARD")]
assert block._find_var_recursive(varname)
var = block._var_recursive(varname)
return var.is_parameter
......
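The new branch strips a `.cast_bf` suffix so that bf16 cast copies of a parameter still resolve to the underlying parameter variable. A standalone sketch of the suffix-stripping idea (illustrative helper with hypothetical variable names, not Paddle API):

```python
# Illustrative helper: resolve a derived variable name back to its base name
# by stripping the suffixes handled above, in the same order as
# is_parameter_related.
def strip_derived_suffixes(varname):
    for marker in (".subprog_", ".cast_fp", ".cast_bf", ".quantized"):
        if marker in varname:
            varname = varname[: varname.index(marker)]
    return varname


# Hypothetical names, for illustration only.
assert strip_derived_suffixes("linear_0.w_0.cast_bf16") == "linear_0.w_0"
assert strip_derived_suffixes("linear_0.w_0.quantized") == "linear_0.w_0"
```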
......@@ -376,7 +376,7 @@ def _right_operand_parameter_matmul_backward(ctx, *args, **kwargs):
check_variable_and_dtype(
Out_grad,
'tensor',
['float16', 'float32', 'float64', 'int32', 'int64'],
['float16', 'float32', 'float64', 'int32', 'int64', 'uint16'],
'_c_identity',
)
......@@ -417,13 +417,13 @@ def _right_operand_parameter_matmul_backward(ctx, *args, **kwargs):
check_variable_and_dtype(
intermediate_var_0,
'x',
['float16', 'float32', 'float64'],
['float16', 'float32', 'float64', 'uint16'],
'linear',
)
check_dtype(
intermediate_var_0.dtype,
'dtype',
['float16', 'float32', 'float64'],
['float16', 'float32', 'float64', 'uint16'],
'linear',
)
set_comm_op_dist_attr_for_program(
......@@ -835,7 +835,7 @@ class DistributedMatmulImpl0(DistributedOperatorImpl):
check_variable_and_dtype(
X_var,
'tensor',
['float16', 'float32', 'float64', 'int32', 'int64'],
['float16', 'float32', 'float64', 'int32', 'int64', 'uint16'],
'_c_identity',
)
......@@ -854,12 +854,15 @@ class DistributedMatmulImpl0(DistributedOperatorImpl):
intermediate_var_0.desc.set_shape(ref_shape_x)
check_variable_and_dtype(
intermediate_var_0, 'x', ['float16', 'float32', 'float64'], 'linear'
intermediate_var_0,
'x',
['float16', 'float32', 'float64', 'uint16'],
'linear',
)
check_dtype(
intermediate_var_0.dtype,
'dtype',
['float16', 'float32', 'float64'],
['float16', 'float32', 'float64', 'uint16'],
'linear',
)
attrs = {
......@@ -1183,10 +1186,13 @@ class DistributedMatmulImpl1(DistributedOperatorImpl):
group = new_process_group(group_ranks)
check_variable_and_dtype(
X_var, 'x', ['float16', 'float32', 'float64'], 'linear'
X_var, 'x', ['float16', 'float32', 'float64', 'uint16'], 'linear'
)
check_dtype(
X_var.dtype, 'dtype', ['float16', 'float32', 'float64'], 'linear'
X_var.dtype,
'dtype',
['float16', 'float32', 'float64', 'uint16'],
'linear',
)
attrs = {
'transpose_X': trans_x,
......@@ -1731,7 +1737,7 @@ class DistributedMatmulV2Impl0(DistributedOperatorImpl):
check_variable_and_dtype(
X_var,
'tensor',
['float16', 'float32', 'float64', 'int32', 'int64'],
['float16', 'float32', 'float64', 'int32', 'int64', 'uint16'],
'_c_identity',
)
c_identity_op = main_block.append_op(
......@@ -1749,12 +1755,15 @@ class DistributedMatmulV2Impl0(DistributedOperatorImpl):
intermediate_var_0.desc.set_shape(ref_shape_x)
check_variable_and_dtype(
intermediate_var_0, 'x', ['float16', 'float32', 'float64'], 'linear'
intermediate_var_0,
'x',
['float16', 'float32', 'float64', 'uint16'],
'linear',
)
check_dtype(
intermediate_var_0.dtype,
'dtype',
['float16', 'float32', 'float64'],
['float16', 'float32', 'float64', 'uint16'],
'linear',
)
attrs = {
......@@ -2077,10 +2086,13 @@ class DistributedMatmulV2Impl1(DistributedOperatorImpl):
group = new_process_group(group_ranks)
check_variable_and_dtype(
X_var, 'x', ['float16', 'float32', 'float64'], 'linear'
X_var, 'x', ['float16', 'float32', 'float64', 'uint16'], 'linear'
)
check_dtype(
X_var.dtype, 'dtype', ['float16', 'float32', 'float64'], 'linear'
X_var.dtype,
'dtype',
['float16', 'float32', 'float64', 'uint16'],
'linear',
)
attrs = {
'trans_x': trans_x,
......@@ -2610,7 +2622,7 @@ class DistributedMulImpl0(DistributedOperatorImpl):
check_variable_and_dtype(
X_var,
'tensor',
['float16', 'float32', 'float64', 'int32', 'int64'],
['float16', 'float32', 'float64', 'int32', 'int64', 'uint16'],
'_c_identity',
)
c_identity_op = main_block.append_op(
......@@ -2628,12 +2640,15 @@ class DistributedMulImpl0(DistributedOperatorImpl):
intermediate_var_0.desc.set_shape(ref_shape_x)
check_variable_and_dtype(
intermediate_var_0, 'x', ['float16', 'float32', 'float64'], 'linear'
intermediate_var_0,
'x',
['float16', 'float32', 'float64', 'uint16'],
'linear',
)
check_dtype(
intermediate_var_0.dtype,
'dtype',
['float16', 'float32', 'float64'],
['float16', 'float32', 'float64', 'uint16'],
'linear',
)
# attrs = {'trans_x': False, 'trans_y': False}
......@@ -2965,10 +2980,13 @@ class DistributedMulImpl1(DistributedOperatorImpl):
group = new_process_group(group_ranks)
check_variable_and_dtype(
X_var, 'x', ['float16', 'float32', 'float64'], 'linear'
X_var, 'x', ['float16', 'float32', 'float64', 'uint16'], 'linear'
)
check_dtype(
X_var.dtype, 'dtype', ['float16', 'float32', 'float64'], 'linear'
X_var.dtype,
'dtype',
['float16', 'float32', 'float64', 'uint16'],
'linear',
)
# attrs = {'trans_x': False, 'trans_y': False}
attrs = {
......
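Each dtype whitelist above gains `'uint16'` because NumPy has no native bfloat16: Paddle stores bf16 tensors in uint16 buffers, and its dtype-check helpers report `VarDesc.VarType.BF16` variables as the string `'uint16'`. A minimal sketch of that convention, assuming `paddle.static.data` accepts a `'uint16'` dtype string and maps it to BF16:

```python
import paddle
import paddle.fluid.core as core
from paddle.fluid.data_feeder import check_variable_and_dtype

paddle.enable_static()

main = paddle.static.Program()
with paddle.static.program_guard(main):
    # NumPy has no bfloat16, so 'uint16' is used as the carrier dtype and the
    # resulting static variable is typed VarType.BF16 internally.
    x = paddle.static.data(name="x_bf16", shape=[2, 2], dtype="uint16")
    assert x.dtype == core.VarDesc.VarType.BF16
    # With 'uint16' added to the whitelist, bf16 inputs now pass the check.
    check_variable_and_dtype(
        x, "x", ["float16", "float32", "float64", "uint16"], "linear"
    )
```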
......@@ -221,13 +221,21 @@ class Parallelizer:
self._dist_context.serial_feed_vars["inputs"]
+ self._dist_context.serial_feed_vars["labels"]
)
if config["use_pure_fp16"]:
if config["enable_bf16"]:
auto_parallel_bf16_pass = new_pass("auto_parallel_bf16", config)
auto_parallel_bf16_pass.apply(
[main_program], [startup_program], self._pass_context
)
loss = auto_parallel_bf16_pass.get_loss()
elif config["use_pure_fp16"]:
config["base_opt"] = optimizer
auto_parallel_fp16_pass = new_pass("auto_parallel_fp16", config)
auto_parallel_fp16_pass.apply(
[main_program], [startup_program], self._pass_context
)
loss = auto_parallel_fp16_pass.get_loss()
else:
auto_parallel_amp_pass = new_pass("auto_parallel_amp", config)
auto_parallel_amp_pass.apply(
......
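With this change the Parallelizer selects the mixed-precision pass in a fixed order: bf16 first, then pure fp16, then the generic AMP pass. A short sketch of that dispatch (illustrative helper only; the pass names are the ones registered in `paddle.distributed.passes`):

```python
# Illustrative only: the order in which Parallelizer now chooses a
# mixed-precision pass when the amp strategy is enabled.
def select_amp_pass_name(config):
    if config.get("enable_bf16"):
        return "auto_parallel_bf16"
    if config.get("use_pure_fp16"):
        return "auto_parallel_fp16"
    return "auto_parallel_amp"


# e.g. new_pass(select_amp_pass_name(config), config)
```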
......@@ -18,6 +18,7 @@ from .auto_parallel_gradient_merge import * # noqa: F403
from .auto_parallel_sharding import * # noqa: F403
from .auto_parallel_amp import * # noqa: F403
from .auto_parallel_fp16 import * # noqa: F403
from .auto_parallel_bf16 import * # noqa: F403
from .auto_parallel_recompute import * # noqa: F403
from .auto_parallel_quantization import * # noqa: F403
from .auto_parallel_data_parallel_optimization import * # noqa: F403
......
This diff is collapsed.
......@@ -128,5 +128,6 @@ if(WITH_DISTRIBUTE AND WITH_GPU)
py_test_modules(test_cluster_partition MODULES test_cluster_partition)
py_test_modules(test_convert_to_process_meshes MODULES
test_convert_to_process_meshes)
py_test_modules(test_pass_bf16 MODULES test_pass_bf16)
py_test_modules(test_dist_saver MODULES test_dist_saver)
endif()
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import random
import unittest
import numpy as np
import paddle
import paddle.fluid.core as core
import paddle.nn as nn
from paddle.distributed.fleet import auto
from paddle.fluid.contrib.mixed_precision.bf16.amp_utils import _valid_types
from paddle.fluid.contrib.mixed_precision.fp16_utils import find_true_prev_op
from paddle.fluid.dygraph.parallel import ParallelEnv
from paddle.static import InputSpec
from paddle.vision.datasets import MNIST

paddle.enable_static()


def apply_pass(use_bf16=False):
strategy = auto.Strategy()
strategy.auto_mode = "semi"
strategy.reinit = True
if use_bf16:
amp = strategy.amp
amp.enable = True
amp.enable_bf16 = True
    return strategy


class MnistDataset(MNIST):
def __init__(self, mode, return_label=True):
super().__init__(mode=mode)
self.return_label = return_label
def __getitem__(self, idx):
img = np.reshape(self.images[idx], [1, 28, 28])
if self.return_label:
return img, np.array(self.labels[idx]).astype('int64')
return (img,)
def __len__(self):
        return len(self.images)


def reset_prog():
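    # Swap in fresh global programs so each Engine builds its graph from scratch.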
paddle.fluid.framework.switch_main_program(paddle.static.Program())
    paddle.fluid.framework.switch_startup_program(paddle.static.Program())


class Model(nn.Layer):
def __init__(self):
super().__init__()
self.flatten = nn.Flatten()
self.fc1 = nn.Linear(784, 120)
self.relu1 = nn.ReLU()
self.fc2 = nn.Linear(120, 10)
def forward(self, input):
input.stop_gradient = True
x = self.flatten(input)
x = self.relu1(self.fc1(x))
x = self.fc2(x)
        return x


class TestBF16Pass(unittest.TestCase):
def setUp(self):
self.rtol = 1e-5
self.atol = 1e-8
self.batch_size = 256
self.batch_num = 10
self.dataset = MnistDataset("train")
self.eval_dataset = MnistDataset("test")
def init(self, engine):
paddle.seed(2021)
np.random.seed(2021)
random.seed(2021)
place = paddle.fluid.CUDAPlace(ParallelEnv().dev_id)
engine._executor = paddle.static.Executor(place)
def get_engine(self, use_bf16=False):
reset_prog()
strategy = apply_pass(use_bf16)
model = Model()
opt = paddle.optimizer.SGD(0.001, parameters=model.parameters())
loss = nn.CrossEntropyLoss()
engine = auto.Engine(model, loss, opt, strategy=strategy)
self.init(engine)
return engine
def check_program(self, program):
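        """Check the dtypes assigned by the bf16 pass.

        Inputs/outputs of ops in ``bf16_op_list`` must be BF16 (any
        ``cast_bf16`` input must come from a ``cast`` op fed by an FP32
        variable), while ops in ``fp32_op_list`` must stay in FP32.
        """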
bf16_op_list = {
"matmul_v2",
"elementwise_add",
"relu",
"elementwise_add_grad",
"matmul_v2_grad",
"relu_grad",
}
fp32_op_list = {
"flatten_contiguous_range",
"reduce_mean",
"softmax_with_cross_entropy",
"fill_constant",
"reduce_mean_grad",
"softmax_with_cross_entropy_grad",
}
for block in program.blocks:
for op in block.ops:
                if op.type not in bf16_op_list and op.type not in fp32_op_list:
continue
for in_name in op.input_names:
for in_var_name in op.input(in_name):
var = None
try:
var = block.var(in_var_name)
except ValueError as e:
var = block._var_recursive(in_var_name)
if var is None or var.type not in _valid_types:
break
if op.type in bf16_op_list:
assert var.dtype == core.VarDesc.VarType.BF16
if "cast_bf16" in in_var_name:
if "@GRAD" in in_var_name:
tmp_in_var_name = in_var_name[
: in_var_name.find("@GRAD")
]
else:
tmp_in_var_name = in_var_name
prev_op = find_true_prev_op(
block.ops, op, tmp_in_var_name
)
assert prev_op is not None
assert prev_op.type == "cast"
for in_name in prev_op.input_names:
for in_var_name in prev_op.input(in_name):
var = block.var(in_var_name)
assert (
var.dtype
== core.VarDesc.VarType.FP32
)
elif op.type in fp32_op_list:
if (
op.type == "softmax_with_cross_entropy"
or op.type == "softmax_with_cross_entropy_grad"
) and in_var_name == "label0":
continue
assert var.dtype == core.VarDesc.VarType.FP32
if "cast_fp32" in in_var_name:
prev_op = find_true_prev_op(
block.ops, op, tmp_in_var_name
)
assert prev_op is not None
assert prev_op.type == "cast"
for in_name in prev_op.input_names:
for in_var_name in prev_op.input(in_name):
var = block.var(in_var_name)
assert (
var.dtype
== core.VarDesc.VarType.BF16
)
for out_name in op.output_names:
for out_var_name in op.output(out_name):
var = None
try:
var = block.var(out_var_name)
except ValueError as e:
var = block._var_recursive(out_var_name)
if var is None or var.type not in _valid_types:
break
if op.type in bf16_op_list:
assert var.dtype == core.VarDesc.VarType.BF16
elif op.type in fp32_op_list:
assert var.dtype == core.VarDesc.VarType.FP32
def test_bf16_pass(self):
bf16_o1_engine = self.get_engine(True)
inputs_spec = [InputSpec([None, 1, 28, 28], 'float32', 'input0')]
labels_spec = [InputSpec([None, 1], 'int64', 'label0')]
bf16_o1_engine.prepare(
inputs_spec=inputs_spec, labels_spec=labels_spec, mode="train"
)
self.check_program(bf16_o1_engine._dist_main_progs["train"][0])
print("BF16!check program successfully!")
if __name__ == "__main__":
unittest.main()
......@@ -41,6 +41,13 @@ class TestStrategy(unittest.TestCase):
self.assertEqual(amp.use_fp16_guard, True)
self.assertEqual(amp.use_optimizer_fp16, False)
self.assertEqual(amp.enable_bf16, False)
self.assertEqual(amp.custom_bf16_list, [])
self.assertEqual(amp.custom_fp32_list, [])
self.assertEqual(amp.custom_fp32_varnames, [])
self.assertEqual(amp.use_pure_bf16, False)
self.assertEqual(amp.use_bf16_guard, False)
sharding = strategy.sharding
self.assertEqual(sharding.enable, False)
self.assertEqual(sharding.stage, 1)
......