Unverified commit 0449b841, authored by R Roc, committed by GitHub

Save dygraph model for auto inference (#47463)

Parent 21277904
...@@ -144,6 +144,8 @@ class VocabParallelEmbedding(Layer):
)
self.weight.is_distributed = True if self.is_mp else False
if self.weight.is_distributed:
setattr(self.weight, "split_axis", 0)
def forward(self, x):
if self.is_mp:
...@@ -276,6 +278,9 @@ class ColumnParallelLinear(Layer):
self.weight.is_distributed = True if self.is_mp else False
if self.weight.is_distributed:
setattr(self.weight, "split_axis", 1)
if has_bias:
# initialize bias to zero like Megatron
self.bias = self.create_parameter(
...@@ -285,6 +290,8 @@ class ColumnParallelLinear(Layer):
is_bias=True,
)
self.bias.is_distributed = True if self.is_mp else False
if self.bias.is_distributed:
setattr(self.bias, "split_axis", 0)
else:
self.bias = None
...@@ -437,6 +444,8 @@ class RowParallelLinear(Layer):
)
self.weight.is_distributed = True if self.is_mp else False
if self.weight.is_distributed:
setattr(self.weight, "split_axis", 0)
if has_bias:
self.bias = self.create_parameter(
...
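The split_axis tag added above records, for each sharded parameter, the tensor axis that is split across the model-parallel group; save_for_auto_inference later turns it into a dims_mapping. A minimal sketch of that convention (the helper below is hypothetical and only mirrors _get_dims_mapping in save_for_auto.py; it is not part of this diff):

def _split_axis_to_dims_mapping(ndim, split_axis=None):
    # -1 means "not partitioned"; 1 marks the model-parallel axis of the (dp, mp) mesh
    mapping = [-1] * ndim
    if split_axis is not None:
        mapping[split_axis] = 1
    return mapping

assert _split_axis_to_dims_mapping(2, 1) == [-1, 1]   # ColumnParallelLinear.weight
assert _split_axis_to_dims_mapping(2, 0) == [1, -1]   # RowParallelLinear / VocabParallelEmbedding weight
assert _split_axis_to_dims_mapping(1) == [-1]         # replicated parameter (no split_axis)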
...@@ -62,6 +62,17 @@ _first_cap_re = re.compile('(.)([A-Z][a-z]+)')
_all_cap_re = re.compile('([a-z])([A-Z])')
def _scope_dist2single(dist_scope):
mapping = {
"row_parallel_linear": "linear",
"column_parallel_linear": "linear",
"vocab_parallel_embedding": "embedding",
# "parallel_cross_entropy": "cross_entropy", while mp_layer has parallel_cross_entropy,
# but there is no parameters so the mapping of parallel_cross_entropy is not neccessary.
}
return mapping.get(dist_scope, dist_scope)
def _convert_camel_to_snake(name):
s1 = _first_cap_re.sub(r'\1_\2', name)
return _all_cap_re.sub(r'\1_\2', s1).lower()
...@@ -137,6 +148,7 @@ class Layer(object):
self.training = True
if name_scope is None:
name_scope = _convert_camel_to_snake(self.__class__.__name__)
name_scope = _scope_dist2single(name_scope)
self._full_name = unique_name.generate(name_scope)
self._helper = LayerObjectHelper(self._full_name)
self._built = False
...
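With the scope mapping above, a distributed layer gets the same name scope as its single-card counterpart, so parameters saved from a distributed model can be matched by name against a single-card model. A minimal sketch, assuming the private helpers _scope_dist2single and _convert_camel_to_snake are importable from the patched layers module:

assert _scope_dist2single("column_parallel_linear") == "linear"
assert _scope_dist2single("row_parallel_linear") == "linear"
assert _scope_dist2single("vocab_parallel_embedding") == "embedding"
assert _scope_dist2single("linear") == "linear"  # unmapped scopes pass through unchanged
# so a ColumnParallelLinear instance ends up with a name scope like "linear_0":
assert _scope_dist2single(_convert_camel_to_snake("ColumnParallelLinear")) == "linear"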
...@@ -952,3 +952,11 @@ if((WITH_GPU) AND (LINUX))
set_tests_properties(test_dygraph_dist_save_load
PROPERTIES TIMEOUT "200" LABELS "RUN_TYPE=DIST")
endif()
if((WITH_GPU) AND (LINUX))
py_test_modules(
test_dygraph_save_for_auto_infer MODULES test_dygraph_save_for_auto_infer
ENVS
"http_proxy=;https_proxy=;PYTHONPATH=../..:${PADDLE_BINARY_DIR}/python")
set_tests_properties(test_dygraph_save_for_auto_infer
PROPERTIES TIMEOUT "300" LABELS "RUN_TYPE=DIST")
endif()
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import shutil
import numpy as np
import tempfile
import paddle
import paddle.fluid as fluid
from paddle.fluid.dygraph.nn import Linear, Embedding
from paddle.distributed import fleet
from paddle.distributed.fleet.layers.mpu.mp_layers import (
RowParallelLinear,
ColumnParallelLinear,
VocabParallelEmbedding,
)
from paddle.distributed.auto_parallel import engine
from paddle.distributed.sharding.group_sharded import group_sharded_parallel
from paddle.distributed.fleet.meta_parallel.parallel_layers.pp_layers import (
PipelineLayer,
LayerDesc,
)
import sys
import subprocess
import argparse
import copy
from paddle import distributed as dist
from paddle.distributed.utils.log_utils import get_logger
from paddle.fluid.dataloader.dataset import IterableDataset
from paddle.incubate.distributed.utils.io import save_for_auto_inference
logger = get_logger("INFO", __file__)
epoch = 2
linear_size = 1000
class MLP_pipe(PipelineLayer):
def __init__(
self,
embedding_size=1000,
linear_size=1000,
param_attr=None,
bias_attr=None,
):
desc = [
LayerDesc(
VocabParallelEmbedding,
num_embeddings=embedding_size,
embedding_dim=linear_size,
),
LayerDesc(
RowParallelLinear,
in_features=linear_size,
out_features=linear_size,
has_bias=True,
),
LayerDesc(
ColumnParallelLinear,
in_features=linear_size,
out_features=linear_size,
gather_output=True,
has_bias=True,
),
LayerDesc(Linear, input_dim=linear_size, output_dim=10),
]
super(MLP_pipe, self).__init__(
desc,
num_stages=2,
loss_fn=paddle.nn.CrossEntropyLoss(),
topology=fleet.get_hybrid_communicate_group()._topo,
)
class MLP_Hybrid(fluid.Layer):
def __init__(
self,
embedding_size=1000,
linear_size=1000,
param_attr=None,
bias_attr=None,
):
super(MLP_Hybrid, self).__init__()
self.embedding = VocabParallelEmbedding(embedding_size, linear_size)
self._linear1 = RowParallelLinear(
linear_size, linear_size, has_bias=True, input_is_parallel=True
)
self._linear2 = ColumnParallelLinear(
linear_size, linear_size, gather_output=True, has_bias=True
)
self._linear3 = Linear(linear_size, 10)
def forward(self, src):
inputs = self.embedding(src)
# slice the input to work around a bug in row parallel linear
mp_group = (
fleet.get_hybrid_communicate_group().get_model_parallel_group()
)
step = inputs.shape[-1] // mp_group.nranks
mp_rank = dist.get_rank(mp_group)
mp_rank = mp_rank if mp_rank >= 0 else 0
inputs = inputs[..., step * mp_rank : step * mp_rank + step]
y = self._linear1(inputs)
y = self._linear2(y)
y = self._linear3(y)
return y
class MLP(fluid.Layer):
def __init__(
self,
embedding_size=1000,
linear_size=1000,
param_attr=None,
bias_attr=None,
):
super(MLP, self).__init__()
self.embedding = Embedding((embedding_size, linear_size))
self._linear1 = Linear(linear_size, linear_size)
self._linear2 = Linear(linear_size, linear_size)
self._linear3 = Linear(linear_size, 10)
def forward(self, src):
inputs = self.embedding(src)
y = self._linear1(inputs)
y = self._linear2(y)
y = self._linear3(y)
return y
def gen_uniq_random_numbers(low, high, size, seed):
assert np.prod(size) <= high - low
pool = list(range(low, high))
data = np.zeros(size).astype("int32").reshape(-1)
np.random.seed(10245)
for i in range(np.prod(size)):
pos = int(np.random.randint(0, len(pool)))
data[i] = pool[pos]
pool.remove(pool[pos])
np.random.seed(seed)
return data.reshape(size)
class RangeIterableDataset(IterableDataset):
def __init__(
self, data_path, ebd=1000, start=0, end=100, linear_size=1000, seed=1024
):
self.start = start
self.end = end
self.img = gen_uniq_random_numbers(0, 1000, (100, 1), seed)
def __iter__(self):
for idx in range(self.start, self.end):
label = np.ones(1).astype('int32')
yield self.img[idx], label
def optimizer_setting(args, model):
optimizer = paddle.optimizer.SGD(
learning_rate=0.0 if args.strategy == "static" else 0.01,
parameters=model.parameters(),
weight_decay=0.01,
)
return optimizer
def train_mlp(args, model, loss, opt_state=None, save_model=False):
optimizer = optimizer_setting(args, model=model)
if args.strategy in ["mp", "dp", "pp"]:
model = fleet.distributed_model(model)
optimizer = fleet.distributed_optimizer(optimizer)
elif args.strategy == "sharding_stage2":
model, optimizer, _ = wrap_sharding_2_3(
model, optimizer, None, False, 2
)
elif args.strategy == "sharding_stage3":
model, optimizer, _ = wrap_sharding_2_3(
model, optimizer, None, False, 3
)
elif args.strategy != "single":
raise ValueError(f"not supported strategy: {args.strategy}")
dataset = RangeIterableDataset(
data_path=os.path.join(args.output_dir, "data.npy"), seed=args.seed
)
train_loader = paddle.io.DataLoader(dataset, batch_size=100, drop_last=True)
if dist.get_world_size() > 1:
pp_degree = (
fleet.get_hybrid_communicate_group().get_pipe_parallel_world_size()
)
else:
pp_degree = 0
model.train()
for epo in range(epoch):
for step, data in enumerate(train_loader()):
img, label = data
label.stop_gradient = True
img.stop_gradient = True
if pp_degree <= 1:
out = model(img)
avg_loss = loss(out, label)
paddle.device.cuda.synchronize()
avg_loss.backward()
optimizer.step()
else:
avg_loss = model.train_batch(data, optimizer)
model.eval()
print("=============== predict in dygraph mode =================")
for step, data in enumerate(train_loader()):
img, label = data
if pp_degree <= 1:
out = model(img)
out = out.numpy()
else:
out = model.eval_batch(data)
out = np.array(out)
paddle.device.cuda.synchronize()
if save_model:
return model, optimizer, out
return None
def train_mlp_static(args, model, loss, opt_state=None, save_model=False):
optimizer = optimizer_setting(args, model=model)
model = engine.Engine(model, loss=loss, optimizer=optimizer, strategy=None)
dataset = RangeIterableDataset(
data_path=os.path.join(args.output_dir, "data.npy"), seed=args.seed
)
model.load(os.path.join(args.load_dir, "saved"), load_optimizer=False)
model.fit(dataset, epochs=1)
model.save(os.path.join(args.output_dir, "static_save"))
paddle.device.cuda.synchronize()
print("=============== predict in static mode =================")
out = model.predict(dataset, verbose=1000)
if save_model:
return model, optimizer
return out
def step_check(output_dir):
p1 = os.path.join(output_dir, "static.npy")
p2 = os.path.join(output_dir, "dygraph.npy")
m1 = np.load(p1).reshape(-1)
m2 = np.load(p2).reshape(-1)
try:
assert np.allclose(m1, m2, rtol=1e-5, atol=1e-6)
except:
diff = m1 - m2
logger.error(f"max diff{diff.max()}, min diff: {diff.min()}")
logger.error(f"{m1[:10]}")
logger.error(f"{m2[:10]}")
raise ValueError("diff is too large")
def step_save(strategy, output_dir, seed):
python_exe = sys.executable
# save data
os.makedirs(output_dir + "/logs", exist_ok=True)
filename = os.path.basename(__file__)
if strategy != "single":
cmd = (
f"{python_exe} -m paddle.distributed.launch --log_dir {output_dir}/logs"
f" --gpus 0,1 {filename} --cmd save --strategy {strategy} --output_dir {output_dir} --seed {seed}"
)
else:
cmd = f"{python_exe} {filename} --cmd save --strategy {strategy} --output_dir {output_dir} --seed {seed}"
logger.info(f"exe: {cmd}")
p = subprocess.Popen(cmd.split())
p.communicate()
assert p.poll() == 0
def step_load(current_strategy, saved_dir, seed):
python_exe = sys.executable
os.makedirs(f"{saved_dir}/load/logs", exist_ok=True)
filename = os.path.basename(__file__)
# load dp
cmd = (
f"{python_exe} -m paddle.distributed.launch --log_dir {saved_dir}/load/logs"
f" --gpus 0 {filename} --cmd load --strategy {curent_strateggy} --output_dir {saved_dir} --load_dir {saved_dir} --seed {seed}"
)
logger.info(f"exe: {cmd}")
env = copy.copy(os.environ)
env["CUDA_VISIBLE_DEVICES"] = "0"
p = subprocess.Popen(cmd.split(), env=env)
p.communicate()
assert p.poll() == 0
def wrap_sharding_2_3(model, optimizer, scaler, sharding_offload, stage):
group = fleet.get_hybrid_communicate_group().get_sharding_parallel_group()
level = "p_g_os" if stage == 3 else "os_g"
return group_sharded_parallel(
model=model,
optimizer=optimizer,
level=level,
scaler=scaler,
group=group,
offload=sharding_offload,
)
def test_save_load(args):
np.random.seed(args.seed)
paddle.seed(args.seed)
if args.cmd == "main":
run_case(args)
return
paddle.distributed.init_parallel_env()
strategy = fleet.DistributedStrategy()
if args.strategy == "dp":
strategy.hybrid_configs = {
"dp_degree": 2,
"mp_degree": 1,
"pp_degree": 1,
"sharding_degree": 1,
}
elif args.strategy in ["sharding_stage2", "sharding_stage3"]:
strategy.hybrid_configs = {
"dp_degree": 1,
"mp_degree": 1,
"pp_degree": 1,
"sharding_degree": 2,
}
elif args.strategy == "mp":
strategy.hybrid_configs = {
"dp_degree": 1,
"mp_degree": 2,
"pp_degree": 1,
"sharding_degree": 1,
}
elif args.strategy == "pp":
strategy.hybrid_configs = {
"dp_degree": 1,
"mp_degree": 1,
"pp_degree": 2,
"sharding_degree": 1,
}
strategy.pipeline_configs = {
"accumulate_steps": 10,
"micro_batch_size": 10,
}
elif args.strategy == "static":
paddle.enable_static()
elif args.strategy != "single":
raise ValueError(f"Not supported strategy: {args.strategy}")
loss = paddle.nn.CrossEntropyLoss()
fleet.set_log_level("INFO")
if dist.get_world_size() <= 1:
mlp1 = MLP()
if args.strategy == "static":
out_static = train_mlp_static(args, mlp1, loss, save_model=False)
np.save(os.path.join(args.output_dir, "static.npy"), out_static)
else:
model, _, out_dygraph = train_mlp(args, mlp1, loss, save_model=True)
np.save(os.path.join(args.output_dir, "dygraph.npy"), out_dygraph)
else:
fleet.init(is_collective=True, strategy=strategy)
pp_group = (
fleet.get_hybrid_communicate_group().get_pipe_parallel_group()
)
if pp_group.nranks > 1:
mlp1 = MLP_pipe()
else:
mlp1 = MLP_Hybrid()
model, _, out_dygraph = train_mlp(args, mlp1, loss, save_model=True)
if (
dist.get_world_size() == 0
or dist.get_rank() == dist.get_world_size() - 1
):
np.save(os.path.join(args.output_dir, "dygraph.npy"), out_dygraph)
if args.cmd == "save":
save_for_auto_inference(os.path.join(args.output_dir, "saved"), model)
def run_case(args):
saving_strategy = args.test_case.split(":")[0]
loading_strategy = args.test_case.split(":")[1]
output_dir = tempfile.mkdtemp()
if os.path.isdir(output_dir):
shutil.rmtree(output_dir)
os.makedirs(output_dir, exist_ok=True)
try:
step_save(saving_strategy, output_dir, args.seed)
step_load(loading_strategy, output_dir, args.seed + 1)
step_check(output_dir)
except Exception as e:
shutil.rmtree(output_dir)
raise RuntimeError(f"Test failed.\n {e.__str__()}")
shutil.rmtree(output_dir)
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument(
"--cmd", default="main", choices=["main", "save", "load"]
)
parser.add_argument(
"--strategy",
required=False,
choices=[
"single",
"dp",
"mp",
"pp",
"sharding_stage2",
"sharding_stage3",
"static",
],
)
parser.add_argument(
"--load_way", choices=["paddle.load", "load"], required=False
)
parser.add_argument("--load_dir", required=False)
parser.add_argument("--output_dir", required=False)
parser.add_argument("--output_param_path", required=False)
parser.add_argument(
"--test_case",
required=False,
choices=[
"dp:static",
"mp:static",
"pp:static",
"sharding_stage2:static",
"sharding_stage3:static",
"single:static",
],
)
parser.add_argument("--gather_to", required=False, default=0)
parser.add_argument("--seed", type=int, default=2022)
args = parser.parse_args()
test_save_load(args)
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import unittest
import subprocess
import sys
def strategy_test(saving, seed=1024, loading="static"):
cmd = f"{sys.executable} dygraph_save_for_auto_infer.py --test_case {saving}:{loading} --cmd main --seed {seed}"
p = subprocess.Popen(cmd.split())
p.communicate()
assert p.poll() == 0
class TestHybrid(unittest.TestCase):
def test_dygraph_save_load_hybrid(self):
strategy_test("dp")
strategy_test("mp")
strategy_test("pp")
class TestSharding(unittest.TestCase):
def test_dygraph_save_load_sharding_stage2_stage3(self):
strategy_test("sharding_stage2")
strategy_test("sharding_stage3")
class TestSingleCard(unittest.TestCase):
def test_dygraph_save_load_single_card(self):
strategy_test("single")
if __name__ == "__main__":
os.environ["FLAGS_enable_eager_mode"] = "1"
os.environ["CUDA_VISIBLE_DEVICES"] = "0,1"
unittest.main()
...@@ -84,3 +84,4 @@ test_hdfs3,LINUX,,200,EXCLUSIVE:NIGHTLY,../../dist_test.sh,2,,http_proxy=;https_
test_fleet_checkpoint,LINUX,GPU;ROCM,200,EXCLUSIVE:NIGHTLY,test_runner.py,,,http_proxy=;https_proxy=;PYTHONPATH=../..,
test_fleet_log,,,200,DIST,test_runner.py,,,http_proxy=;https_proxy=;PYTHONPATH=../..,
test_dygraph_dist_save_load,LINUX,GPU,200,DIST,test_runner.py,,,http_proxy=;https_proxy=;PYTHONPATH=../..,
test_dygraph_save_for_auto_infer,LINUX,GPU,300,DIST,test_runner.py,,,http_proxy=;https_proxy=;PYTHONPATH=../..,
...@@ -12,5 +12,5 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from .dist_save import save, save_for_auto_inference
from .dist_load import load
...@@ -20,10 +20,10 @@ from paddle.distributed.fleet.utils.log_util import logger
from paddle.fluid.framework import dygraph_only
import copy
import sys
from .save_for_auto import save_for_auto_inference
from paddle.distributed.fleet.utils.log_util import logger
__all__ = ["save", "save_for_auto_inference"]
@dygraph_only
...
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle.distributed as dist
import paddle.distributed.fleet as fleet
import re
import paddle
from paddle.distributed.fleet.utils.log_util import logger
import os
import pickle
from paddle.distributed.fleet.meta_parallel.sharding.group_sharded_stage3 import (
GroupShardedStage3,
)
from paddle.fluid.framework import dygraph_only
import copy
import numpy as np
__all__ = ["save_for_auto_inference"]
@dygraph_only
def save_for_auto_inference(path_prefix, dist_model, cvt2cpu=False):
"""
Description:
Save model parameters for auto-parallel inference.
Supports dp + mp + pp + sharding (stage 1), and dp + sharding stage 2/3.
MoE is not supported until MoE is supported in auto parallel mode.
Args:
path_prefix: path prefix for saving.
If `path_prefix` ends with a path separator, it is treated as a
directory and the parameters are saved in it, automatically named
saved_parameters.
Otherwise, the parameters are saved as
path_prefix_dist{global_rank}.pdparams and path_prefix_dist{global_rank}.pdattr.
dist_model:
the model in distributed mode
cvt2cpu: whether to move parameters to CPU when using sharding stage 3.
The argument is ignored unless sharding stage 3 is used.
Returns:
None
Examples:
dist_model = build_distributed_model()
path_prefix = "path/to/save_infer"
save_for_auto_inference(path_prefix, dist_model=dist_model, cvt2cpu=False)
Outputs:
path/to/save_infer_dist0.pdparams path/to/save_infer_dist1.pdparams path/to/save_infer_dist2.pdparams ...
path/to/save_infer_dist0.pdattr path/to/save_infer_dist1.pdattr path/to/save_infer_dist2.pdattr ...
"""
save_dir, basename_prefix = _get_abs_saved_prefix(path_prefix)
if isinstance(dist_model, GroupShardedStage3):
dist_model.get_all_parameters(cvt2cpu)
wrapped_dict = _get_wrapped_dist_state_dict(dist_model.state_dict())
global_rank = paddle.distributed.get_rank()
# save parameters
paddle.save(
wrapped_dict,
os.path.join(save_dir, f"{basename_prefix}_dist{global_rank}.pdparams"),
)
# save attributes
_save_param_attr(
wrapped_dict,
os.path.join(save_dir, f"{basename_prefix}_dist{global_rank}.pdattr"),
)
# unset dims mapping after saving attrs
for _, dist_param in wrapped_dict.items():
_unset_dims_mapping(dist_param)
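# Usage sketch (illustrative only; build_distributed_model is a placeholder
# for user code that builds a dygraph model under fleet hybrid parallelism),
# mirroring what the unit test dygraph_save_for_auto_infer.py does:
#
#     dist_model = build_distributed_model()
#     save_for_auto_inference("ckpt/saved", dist_model)
#     # later, in auto-parallel static mode:
#     #     engine_model.load("ckpt/saved", load_optimizer=False)
#
# Each rank writes ckpt/saved_dist{rank}.pdparams and ckpt/saved_dist{rank}.pdattr.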
def _is_first_used(param):
return not hasattr(param, "is_firstly_shared") or param.is_firstly_shared
def _get_all_ranks_of_pp(pp_rank, dp_degree, mp_degree, pp_degree):
"""
Description:
Get all global ranks that belong to the given pipeline stage pp_rank.
"""
process_group = []
world_size = dp_degree * mp_degree * pp_degree
for i in range(dp_degree):
for k in range(mp_degree):
process_group.append(
i * world_size // dp_degree
+ pp_rank * world_size // dp_degree // pp_degree
+ k
)
return process_group
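# Worked example: with dp_degree=2, mp_degree=2, pp_degree=2 (world size 8),
# each rank is i * 8 // 2 + pp_rank * 8 // 2 // 2 + k for i in {0, 1}, k in {0, 1}, so
#   _get_all_ranks_of_pp(0, 2, 2, 2) -> [0, 1, 4, 5]
#   _get_all_ranks_of_pp(1, 2, 2, 2) -> [2, 3, 6, 7]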
def _save_param_attr(state_dict_, path, dims_mapping_dict=None):
"""
Description:
save params' attr dict
Args:
state_dict_:
the state dict whose attributes are saved; if it is an optimizer state,
the master weights and LR_Scheduler entries are removed first.
path:
path to save to
dims_mapping_dict:
dims mapping dict, mapping parameter names in state_dict_ to their dims_mapping.
If a parameter in state_dict_ has a 'dims_mapping' attribute, this dict is ignored for it.
If a parameter has no 'dims_mapping' attribute, the dict must contain the parameter's name.
"""
state_dict = copy.copy(state_dict_)
# remove master_weights and LR_Scheduler, which need no parameter attributes
state_dict.pop("master_weights", None)
state_dict.pop("LR_Scheduler", None)
if dims_mapping_dict is not None:
assert isinstance(
dims_mapping_dict, dict
), "dims_mapping_dict must be an instance of dict"
for k in state_dict.keys():
assert (
k in dims_mapping_dict
), f"param {k} cannot find dims mapping in dims_mapping_dict"
if dist.get_world_size() > 1:
hcg = fleet.get_hybrid_communicate_group()
dp_degree = hcg.get_data_parallel_world_size()
mp_degree = hcg.get_model_parallel_world_size()
pp_degree = hcg.get_pipe_parallel_world_size()
sharding_degree = hcg.get_sharding_parallel_world_size()
dp_degree = dp_degree * sharding_degree
pp_group = hcg.get_pipe_parallel_group()
else:
pp_degree = 1
dp_degree = 1
mp_degree = 1
pp_group = None
hcg = None
logger.debug(f"dp degree * sharding degree : {dp_degree}")
logger.debug(f"mp degree: {mp_degree}")
logger.debug(f"pp degree: {pp_degree}")
pp_rank = dist.get_rank(pp_group)
# pp_rank is set to -1 when pp_degree == 1, so clamp it to 0 here
pp_rank = 0 if pp_rank <= 0 else pp_rank
if dist.get_world_size() > 1:
process_group = _get_all_ranks_of_pp(
pp_rank, dp_degree, mp_degree, pp_degree
)
else:
process_group = [0]
attr_dict = {}
for k, v in state_dict.items():
dims = len(v.shape)
logger.debug(f"shape: , {k}, {dims}")
attr_d = {
"process_shape": [dp_degree, mp_degree] if hcg else [1],
"process_group": process_group,
"dims_mapping": v.dims_mapping
if hasattr(v, "dims_mapping")
else [-1 for _ in v.shape],
}
attr_dict[k] = attr_d
with open(path, "wb") as f:
pickle.dump(attr_dict, f)
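# Example of one saved attribute entry (illustrative names and values) for a
# ColumnParallelLinear weight under dp_degree=2, mp_degree=2, pp_degree=1:
#   attr_dict["linear_0.w_0"] = {
#       "process_shape": [2, 2],        # (dp * sharding, mp) mesh
#       "process_group": [0, 1, 2, 3],  # global ranks of this pipeline stage
#       "dims_mapping": [-1, 1],        # split along output features (mp axis)
#   }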
def _unset_dims_mapping(param):
if hasattr(param, "dims_mapping"):
delattr(param, "dims_mapping")
def _get_dims_mapping(dist_parameter, mp_group):
"""
Description:
Return the splitting strategy (dims mapping) of a distributed parameter:
a list with one entry per tensor dim.
Args:
dist_parameter: a distributed model parameter
mp_group(ProcessGroup): the model-parallel communication group
Return:
The dims mapping list
Examples:
a dims mapping of (-1, -1, -1, 1) means the tensor has 4 dims and its last
dim is split along axis 1 of the process mesh (the mp axis).
Mesh example: (2, 4) means dp=2, mp=4
"""
import numpy as np
dist_shape = np.array(dist_parameter.shape)
if hasattr(dist_parameter, "split_axis"):
axis = dist_parameter.split_axis
mapping = [-1 for _ in dist_shape]
mapping[axis] = 1
logger.debug(
f"{dist_parameter.name} has attr split_axis: mapping: {mapping}"
)
else:
mapping = [-1 for _ in dist_shape]
logger.debug(f"normal parameter: {dist_parameter.name}")
return mapping
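# Example: a ColumnParallelLinear weight of shape [1024, 4096] carrying
# split_axis == 1 yields dims_mapping [-1, 1]; a parameter without the
# attribute (e.g. a LayerNorm weight) yields [-1, -1], i.e. fully replicated.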
def _get_abs_saved_prefix(path_prefix):
"""
Description:
Get the absolute directory path and basename prefix of path_prefix, creating path_prefix's directories if needed.
If path_prefix ends with a path separator, the basename is set to 'saved_parameters'.
Otherwise, the basename is extracted from path_prefix.
Args:
path_prefix: str
Return:
(dirpath: str, basename: str)
"""
abs_prefix = os.path.abspath(path_prefix)
# os.path.abspath strips a trailing separator, so check the raw prefix instead
if path_prefix.endswith(os.path.sep):
save_dir = abs_prefix
basename_prefix = "saved_parameters"
else:
save_dir = os.path.dirname(abs_prefix)
basename_prefix = os.path.basename(abs_prefix)
os.makedirs(save_dir, exist_ok=True)
return save_dir, basename_prefix
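# Examples (assuming the current working directory is /work):
#   _get_abs_saved_prefix("ckpt/save_infer") -> ("/work/ckpt", "save_infer")
#   _get_abs_saved_prefix("ckpt/")           -> ("/work/ckpt", "saved_parameters")
# Both calls create the target directory if it does not already exist.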
def _name_mapping_dist2single(state_dict, pp_group):
key_list = []
param_keys = [
v.name
for _, v in state_dict.items()
if isinstance(v, paddle.Tensor) and _is_first_used(v)
]
if pp_group.nranks == 1:
return {k: k for k in param_keys}
dist.all_gather_object(key_list, param_keys, pp_group)
# count how many layers of each type each pp stage owns, e.g.
# {"linear": [0, 2, 0, 1, 1, ...]}
param_types = {}
matcher = re.compile(r"^\w+_\d+(?=\.)")
for pp, keys in enumerate(key_list):
param_type_idx = {}
for k in keys:
matched = matcher.search(k)
logger.debug(f"matched: {k}: {matched}")
assert (
matched is not None
), f"the name of param, '{k}', is not satisfyied the format 'name_idx.xxx'"
name_idx = k[matched.start() : matched.end()]
logger.debug(f"get param_type_idx: {name_idx}")
if name_idx in param_type_idx:
continue
name = "_".join(name_idx.split("_")[:-1])
idx = int(name_idx.split("_")[-1])
param_type_idx.update({name_idx: (name, idx)})
if name not in param_types:
param_types[name] = [0] * pp_group.nranks
param_types[name][pp] += 1
# check that the indices of each type are contiguous
types_idx = {}
for _, v in param_type_idx.items():
if v[0] not in types_idx:
types_idx.update({v[0]: [v[1]]})
else:
types_idx[v[0]].append(v[1])
for k, v in types_idx.items():
assert v == list(
range(v[0], v[-1] + 1)
), f"{k} is not continous: {v}"
logger.debug(f"param type: {param_types}")
# compute each type's starting index offset for every pp stage
for k in param_types.keys():
param_types[k] = np.cumsum([0] + param_types[k][:-1])
logger.debug(f"params type: {param_types}")
name_mapping = {}
pp_rank = dist.get_rank(pp_group)
for k in key_list[pp_rank]:
matched = matcher.search(k)
name_idx = k[matched.start() : matched.end()]
name = "_".join(name_idx.split("_")[:-1])
idx = int(name_idx.split("_")[-1])
logger.debug(f"idx: {idx}")
new_idx = param_types[name][pp_rank] + idx
logger.debug(f"new idx: {new_idx}")
new_name_idx = name + "_" + str(new_idx)
name_mapping[k] = new_name_idx + k[matched.end() :]
return name_mapping
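# Example: with two pipeline stages, stage 0 may own "embedding_0.w_0" and
# "linear_0.w_0" while stage 1 owns "linear_0.w_0" and "linear_1.w_0" (each
# stage numbers its layers locally from 0). The gathered per-stage counts for
# "linear" are [1, 2], so stage 1's offset is 1 and its parameters are renamed
# to "linear_1.w_0" and "linear_2.w_0", giving a collision-free state dict
# whose names line up with the single-card model.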
def _get_wrapped_dist_state_dict(dist_state_dict):
wrapped_state_dict = dict()
if dist.get_world_size() <= 1:
for _, v in dist_state_dict.items():
wrapped_state_dict[v.name] = v
return wrapped_state_dict
hcg = fleet.get_hybrid_communicate_group()
pp_group = hcg.get_pipe_parallel_group()
mp_group = hcg.get_model_parallel_group()
logger.debug("execute _name_mapping_dist2single")
name_mapping = _name_mapping_dist2single(dist_state_dict, pp_group)
for _, v in dist_state_dict.items():
if not _is_first_used(v):
logger.debug(f"not first used : {v.name}")
continue
wrapped_state_dict[name_mapping[v.name]] = v
setattr(v, "dims_mapping", _get_dims_mapping(v, mp_group))
logger.debug(
f"saving param: {v.name} -> {name_mapping[v.name]} shape: {v.shape}"
)
return wrapped_state_dict