Unverified commit 5cb0f3aa authored by zhaoyingli, committed by GitHub

[AutoParallel] BF16-o1/FP16-o1 PASS support training and generation (#51147)

* [AutoParallel] support bloom

* fix import

* align amp and bf16

* update func name

* clipbyglobalnorm and add_n support bf16

* upgrade amp strategy api

* update bf16 unittest

* fix static clip

---------
Co-authored-by: liangjianzhong <liangjianzhong@baidu.com>
Co-authored-by: Aurelius84 <zhangliujie@baidu.com>
Parent 32baca93
......@@ -723,6 +723,14 @@ class Completer:
tensor_dist_attr.process_mesh = (
nearest_tensor_dist_attr.process_mesh
)
for node in while_op_node.inputs:
if node.var().name() == tensor_name:
node_dist_attr = (
self._dist_context.get_dist_attr_for_graph(node)
)
node_dist_attr.process_mesh = (
nearest_tensor_dist_attr.process_mesh
)
# Step 4: set the process meshes of the outputs in while_op to the process meshes of the outside output nodes
while_op_outputs_dist_attrs = while_op_dist_attr.outputs_dist_attrs
......@@ -749,6 +757,14 @@ class Completer:
tensor_dist_attr.process_mesh = (
nearest_tensor_dist_attr.process_mesh
)
for node in while_op_node.outputs:
if node.var().name() == tensor_name:
node_dist_attr = (
self._dist_context.get_dist_attr_for_graph(node)
)
node_dist_attr.process_mesh = (
nearest_tensor_dist_attr.process_mesh
)
# Amend the process meshes related to array
for array_node_list in self._array_nodes.values():
......
......@@ -75,11 +75,6 @@ set_field_default_config(AMP, "custom_white_list", [])
set_field_default_config(AMP, "custom_black_list", [])
set_field_default_config(AMP, "custom_black_varnames", [])
set_field_default_config(AMP, "use_fp16_guard", False)
set_field_default_config(AMP, "use_optimizer_fp16", False)
set_field_default_config(AMP, "custom_bf16_list", [])
set_field_default_config(AMP, "custom_fp32_list", [])
set_field_default_config(AMP, "custom_fp32_varnames", [])
set_field_default_config(AMP, "use_bf16_guard", False)
#########################################
......
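The hunk above removes the separate BF16 knobs (`custom_bf16_list`, `custom_fp32_list`, `custom_fp32_varnames`, `use_bf16_guard`, `use_optimizer_fp16`) from the AMP config; FP16 and BF16 are now both selected through the unified `dtype`/`level` fields. A minimal usage sketch (not part of this diff), assuming only the strategy fields exercised by the unittest changes later in this commit:

```python
from paddle.distributed.fleet import auto

strategy = auto.Strategy()
amp = strategy.amp
amp.enable = True

# FP16-o1: the white/black lists decide which ops run in float16.
amp.dtype = "float16"
amp.level = "o1"
amp.custom_white_list = ["softmax", "layer_norm", "gelu"]
amp.init_loss_scaling = 32768

# BF16-o1 is selected through the same two fields:
# amp.dtype = "bfloat16"
# amp.level = "o1"
```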
......@@ -1557,6 +1557,19 @@ class Engine:
cur_dist_attr = auto_utils.get_dist_attr(program, dist_context)
converter = Converter(state_dict, dist_attr, cur_dist_attr)
state_dict = converter.convert(strict=strict)
for name, param in program.state_dict().items():
param_array = np.array(param)
if name not in state_dict:
continue
if param_array.dtype != state_dict[name].dtype:
self._logger.info(
"cast {}'s dtype from '{}' to '{}'".format(
name,
str(state_dict[name].dtype),
str(param_array.dtype),
)
)
state_dict[name] = state_dict[name].astype(param_array.dtype)
program.set_state_dict(state_dict)
def save(self, path, training=True):
......
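The cast loop added above aligns the dtype of each loaded tensor with the parameter dtype of the current program, which is what allows an FP32 checkpoint to be restored into an FP16-o2 program. A minimal sketch of that flow, assuming `mlp`, `loss`, `optimizer`, `metric`, `inputs_spec`, and `labels_spec` are built as in the `test_engine_save_load` unittest added later in this diff:

```python
from paddle.distributed.fleet import auto

# 1) Prepare and save with a plain fp32 strategy (parameters stored as float32).
fp32_strategy = auto.Strategy()
fp32_strategy.auto_mode = "semi"
engine_fp32 = auto.Engine(mlp, loss, optimizer, metric, strategy=fp32_strategy)
engine_fp32.prepare(inputs_spec, labels_spec, mode="train")
engine_fp32.save("/tmp/mlp")  # hypothetical checkpoint path

# 2) Load the same checkpoint into an AMP-o2 (fp16) program; parameters whose
#    dtype differs from the checkpoint are cast by the loop added above.
fp16_strategy = auto.Strategy()
fp16_strategy.auto_mode = "semi"
fp16_strategy.amp.enable = True
fp16_strategy.amp.dtype = "float16"
fp16_strategy.amp.level = "o2"
engine_fp16 = auto.Engine(mlp, loss, optimizer, metric, strategy=fp16_strategy)
engine_fp16.load("/tmp/mlp")
engine_fp16.prepare(inputs_spec, labels_spec, mode="train")
```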
......@@ -272,7 +272,8 @@ def find_compatible_distributed_operator_impls(dist_op, fwd=True, partial=True):
return best_compatible_impl
def is_parameter_related(varname, block):
def is_parameter_related(varname, block, dist_context=None):
# TODO(zhaoyingli): maintain a dict in dist_context to record all variables which are renamed
if ".subprog_" in varname:
varname = varname[: varname.index(".subprog_")]
if ".cast_fp" in varname:
......@@ -281,10 +282,17 @@ def is_parameter_related(varname, block):
varname = varname[: varname.index(".cast_bf")]
if ".quantized" in varname:
varname = varname[: varname.index(".quantized")]
# if "@RESHARD" in varname:
# varname = varname[: varname.index("@RESHARD")]
assert block._find_var_recursive(varname)
assert block._find_var_recursive(
varname
), "cannot find var {} in cur block".format(varname)
var = block._var_recursive(varname)
# NOTE(hack method): find the param which has been resharded
if dist_context and "@RESHARD" in varname:
varname = varname[: varname.index("@RESHARD")]
serial_program = dist_context.serial_main_program
var = serial_program.global_block()._find_var_recursive(varname)
if var is None:
return False
return var.is_parameter
......
......@@ -28,6 +28,9 @@ class DistributedScale(DistributedOperatorImplContainer):
register_distributed_operator_impl_container(DistributedScale("scale"))
register_distributed_operator_impl_container(DistributedScale("fill_any_like"))
register_distributed_operator_impl_container(DistributedScale("where"))
register_distributed_operator_impl_container(DistributedScale("tanh"))
class DistributedScaleImpl(DistributedOperatorImpl):
......@@ -50,13 +53,17 @@ class DistributedScaleImpl(DistributedOperatorImpl):
op_desc = dist_op.serial_op.desc
op_dist_attr = dist_op.dist_attr
x_name = op_desc.input('X')[0]
out_name = op_desc.output('Out')[0]
x_dims_mapping = op_dist_attr.get_input_dims_mapping(x_name)
out_dims_mapping = op_dist_attr.get_output_dims_mapping(out_name)
if x_dims_mapping != out_dims_mapping:
return False
in_dims_mappings = []
for in_name in op_desc.input_arg_names():
in_dims_mapping = op_dist_attr.get_input_dims_mapping(in_name)
in_dims_mappings.append(in_dims_mapping)
for x_dims_mapping in in_dims_mappings:
if x_dims_mapping != out_dims_mapping:
return False
return True
......@@ -78,10 +85,6 @@ class DistributedScaleImpl(DistributedOperatorImpl):
op_dist_attr.set_output_dims_mapping(out_name, out_dims_mapping)
changed = True
if changed:
op_dist_attr.set_input_dims_mapping(x_name, x_dims_mapping)
op_dist_attr.set_output_dims_mapping(out_name, out_dims_mapping)
return changed
@staticmethod
......@@ -94,3 +97,8 @@ class DistributedScaleImpl(DistributedOperatorImpl):
register_distributed_operator_impl("scale", DistributedScaleImpl("scale"))
register_distributed_operator_impl(
"fill_any_like", DistributedScaleImpl("fill_any_like")
)
register_distributed_operator_impl("where", DistributedScaleImpl("where"))
register_distributed_operator_impl("tanh", DistributedScaleImpl("tanh"))
......@@ -2213,7 +2213,11 @@ class Resharder:
else:
op_input_attrs = self._get_common_op_input_attrs(op, var_name)
assert op_input_attrs
assert (
op_input_attrs
), "The input '{}' of op '{}' has no distibution attributes in subblock".format(
op.name, var_name
)
return op_input_attrs
......
......@@ -18,7 +18,6 @@ from .auto_parallel_gradient_merge import * # noqa: F403
from .auto_parallel_sharding import * # noqa: F403
from .auto_parallel_amp import * # noqa: F403
from .auto_parallel_fp16 import * # noqa: F403
from .auto_parallel_bf16 import * # noqa: F403
from .auto_parallel_recompute import * # noqa: F403
from .auto_parallel_quantization import * # noqa: F403
from .auto_parallel_data_parallel_optimization import * # noqa: F403
......
......@@ -1611,7 +1611,9 @@ def _inference_data_parallel_group_for_operator(rank_id, op, dist_context):
dp_group = None
for input_name in op.input_arg_names:
if not is_parameter_related(input_name, op.block):
# TODO(zhaoyingli): maintain a dict in dist_context to record all variables which are renamed,
# to solve the problem that param@RESHARD cannot be identified.
if not is_parameter_related(input_name, op.block, dist_context):
dist_attr = dist_context.get_op_dist_attr_for_program(op)
process_mesh = dist_attr.process_mesh
input_dim_mapping = dist_attr.get_input_dims_mapping(input_name)
......
......@@ -126,6 +126,7 @@ if(WITH_DISTRIBUTE AND WITH_GPU)
test_convert_to_process_meshes)
py_test_modules(test_pass_bf16 MODULES test_pass_bf16)
py_test_modules(test_dist_saver MODULES test_dist_saver)
py_test_modules(test_engine_save_load MODULES test_engine_save_load)
# End of unittests WITH single card WITHOUT timeout
endif()
......@@ -29,6 +29,8 @@ def apply_pass(use_amp=False, level=None):
if use_amp:
amp = strategy.amp
amp.enable = True
amp.dtype = "float16"
amp.level = level
amp.custom_white_list = ['softmax', 'layer_norm', 'gelu']
amp.custom_black_list = [
'c_softmax_with_cross_entropy',
......@@ -37,8 +39,6 @@ def apply_pass(use_amp=False, level=None):
]
amp.init_loss_scaling = 32768
amp.use_fp16_guard = False
amp.level = level
amp.use_optimizer_fp16 = level == "o3"
print("amp level: ", level)
return strategy
......
......@@ -31,6 +31,8 @@ def apply_pass():
amp = dist_strategy.amp
amp.enable = True
amp.dtype = "float16"
amp.level = "o2"
amp.custom_white_list = ["lookup_table", "lookup_table_v2"]
amp.custom_black_list = [
"reduce_sum",
......@@ -38,8 +40,6 @@ def apply_pass():
"elementwise_div",
]
amp.init_loss_scaling = 32768
amp.use_fp16_guard = False
amp.level = "o2"
qat = dist_strategy.qat
qat.enable = True
......@@ -119,9 +119,6 @@ class TestQuantizationPassExport(unittest.TestCase):
def test_qat_pass_2(self):
batch_size = 1
batch_num = 10
strategy = apply_pass()
model, loss = generate_model("mp")
engine = auto.Engine(model, loss, strategy=strategy)
......
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import tempfile
import unittest
import numpy as np
import paddle
import paddle.nn as nn
import paddle.nn.functional as F
from paddle.distributed.fleet import auto
paddle.enable_static()
batch_size = 2
hidden_size = 1024
# sequence_len = 512
image_size = hidden_size
class_num = 10
class MLPLayer(nn.Layer):
def __init__(
self,
hidden_size=1024,
intermediate_size=4 * 1024,
dropout_ratio=0.1,
initializer_range=0.02,
):
super().__init__()
d_model = hidden_size
dim_feedforward = intermediate_size
weight_attr = paddle.ParamAttr(
initializer=nn.initializer.Normal(mean=0.0, std=initializer_range)
)
bias_attr = None
self.linear0 = nn.Linear(
d_model, dim_feedforward, weight_attr, bias_attr=bias_attr
)
self.linear1 = nn.Linear(
dim_feedforward, d_model, weight_attr, bias_attr=bias_attr
)
self.linear2 = nn.Linear(d_model, 1, weight_attr, bias_attr=bias_attr)
self.norm = nn.LayerNorm(d_model, epsilon=1e-5)
self.dropout = nn.Dropout(dropout_ratio, mode="upscale_in_train")
def forward(self, input):
auto.shard_tensor(input, auto.ProcessMesh([0]), [None, None])
out = self.norm(input)
out = self.linear0(out)
out = F.gelu(out, approximate=True)
out = self.linear1(out)
out = self.dropout(out)
out = self.linear2(out)
return out
class TestSaveLoad(unittest.TestCase):
def test_fp32_save_fp16_load(self):
mlp = MLPLayer(
hidden_size=hidden_size,
intermediate_size=4 * hidden_size,
dropout_ratio=0.1,
initializer_range=0.02,
)
loss = paddle.nn.CrossEntropyLoss()
optimizer = paddle.optimizer.Adam(
learning_rate=0.00001,
beta1=0.9,
beta2=0.999,
epsilon=1e-08,
grad_clip=None,
)
metric = paddle.metric.Accuracy()
inputs_spec = [
paddle.static.InputSpec(
shape=[batch_size, image_size], name="input", dtype="float32"
)
]
labels_spec = [
paddle.static.InputSpec(
shape=[batch_size, 1], name="label", dtype="int64"
)
]
# build fp32 model
strategy = auto.Strategy()
strategy.auto_mode = "semi"
engine_fp32 = auto.Engine(
mlp, loss, optimizer, metric, strategy=strategy
)
engine_fp32.prepare(inputs_spec, labels_spec, mode="train")
fp32_state = {
k: np.array(v)
for k, v in engine_fp32.main_program.state_dict("param").items()
}
# save
temp_dir = tempfile.TemporaryDirectory()
model_filename = os.path.join(temp_dir.name, 'mlp')
engine_fp32.save(model_filename)
# build fp16 model
strategy = auto.Strategy()
strategy.auto_mode = "semi"
amp = strategy.amp
amp.enable = True
amp.dtype = "float16"
amp.level = "o2"
engine_fp16 = auto.Engine(
mlp, loss, optimizer, metric, strategy=strategy
)
engine_fp16.load(model_filename)
engine_fp16.prepare(inputs_spec, labels_spec, mode="train")
fp16_state = {
k: np.array(v)
for k, v in engine_fp16.main_program.state_dict("param").items()
}
# check param
for name, fp32_param in fp32_state.items():
fp16_param = fp16_state[name]
if "layer_norm" in name:
assert fp16_param.dtype == np.float32
else:
assert fp16_param.dtype == np.float16
np.testing.assert_allclose(fp32_param, fp16_param, atol=1e-4)
temp_dir.cleanup()
if __name__ == "__main__":
unittest.main()
......@@ -36,7 +36,8 @@ def apply_pass(use_bf16=False):
if use_bf16:
amp = strategy.amp
amp.enable = True
amp.enable_bf16 = True
amp.dtype = "bfloat16"
amp.level = "o1"
return strategy
......
......@@ -28,8 +28,8 @@ class TestStrategy(unittest.TestCase):
amp = strategy.amp
self.assertEqual(amp.enable, False)
self.assertAlmostEqual(amp.dtype, "float16")
self.assertAlmostEqual(amp.level, "o1")
self.assertEqual(amp.dtype, "float16")
self.assertEqual(amp.level, "o1")
self.assertAlmostEqual(amp.init_loss_scaling, 32768.0)
self.assertEqual(amp.incr_every_n_steps, 1000)
self.assertEqual(amp.decr_every_n_nan_or_inf, 2)
......@@ -40,10 +40,6 @@ class TestStrategy(unittest.TestCase):
self.assertEqual(amp.custom_white_list, [])
self.assertEqual(amp.custom_black_varnames, [])
self.assertEqual(amp.use_fp16_guard, False)
self.assertEqual(amp.use_optimizer_fp16, False)
self.assertEqual(amp.custom_bf16_list, [])
self.assertEqual(amp.custom_fp32_list, [])
self.assertEqual(amp.custom_fp32_varnames, [])
self.assertEqual(amp.use_bf16_guard, False)
sharding = strategy.sharding
......@@ -91,6 +87,8 @@ class TestStrategy(unittest.TestCase):
amp = strategy.amp
amp.enable = True
amp.dtype = "float16"
amp.level = "o2"
amp.init_loss_scaling = 16384.0
amp.incr_every_n_steps = 2000
amp.decr_every_n_nan_or_inf = 4
......@@ -101,8 +99,9 @@ class TestStrategy(unittest.TestCase):
amp.custom_black_list = ["y"]
amp.custom_black_varnames = ["z"]
amp.use_fp16_guard = False
amp.use_optimizer_fp16 = True
self.assertEqual(amp.enable, True)
self.assertEqual(amp.dtype, "float16")
self.assertEqual(amp.level, "o2")
self.assertAlmostEqual(amp.init_loss_scaling, 16384.0)
self.assertEqual(amp.incr_every_n_steps, 2000)
self.assertEqual(amp.decr_every_n_nan_or_inf, 4)
......@@ -113,7 +112,6 @@ class TestStrategy(unittest.TestCase):
self.assertEqual(amp.custom_black_list, ["y"])
self.assertEqual(amp.custom_black_varnames, ["z"])
self.assertEqual(amp.use_fp16_guard, False)
self.assertEqual(amp.use_optimizer_fp16, True)
sharding = strategy.sharding
sharding.enable = True
......