未验证 提交 3650c4a8 编写于 作者: zhenhailiu's avatar zhenhailiu 提交者: GitHub

pp 策略调整后,模型转换,以便模型热启 (#52927)

* polish

* polish

* polish

* polish

* polish

* polish

* polish

* polish

* polish

* polish

* polish

* polish

* polish

* polish

* polish

* polish

* polish

* polish

* polish

* polish

* polish
上级 0e30d56a
......@@ -564,7 +564,7 @@ class PipelineLayer(nn.Layer):
self.segment_parts = seg.do_segment()
logger.info(
"segment result:"
f"segment with method: {seg_method}; result: "
+ ", ".join(str(arg) for arg in self.segment_parts)
)
......@@ -594,7 +594,7 @@ class PipelineLayer(nn.Layer):
self.segment_parts = seg.do_segment()
logger.info(
"segment result:"
f"segment with method: {seg_method}; result: "
+ ", ".join(str(arg) for arg in self.segment_parts)
)
......
......@@ -168,6 +168,19 @@ if((WITH_GPU) AND LOCAL_ALL_PLAT)
test_parallel_dygraph_pipeline_parallel_with_virtual_stage
PROPERTIES TIMEOUT "500")
endif()
if((WITH_GPU) AND LOCAL_ALL_PLAT)
bash_test_modules(
test_parallel_dygraph_pp_adaptor
START_BASH
../../dist_test.sh
LABELS
"RUN_TYPE=DIST"
ENVS
"PADDLE_DIST_UT_PORT=21976;http_proxy=;https_proxy=;PYTHONPATH=../..:${PADDLE_BINARY_DIR}/python"
)
set_tests_properties(test_parallel_dygraph_pp_adaptor PROPERTIES TIMEOUT
"500")
endif()
if((WITH_GPU OR WITH_XPU) AND (LINUX))
py_test_modules(
test_fleet_localsgd_meta_optimizer MODULES
......
......@@ -121,11 +121,11 @@ class CriterionPipe(Layer):
class ModelPipe(PipelineLayer):
def __init__(self, topology):
def __init__(self, topology, transformer_layer_num: int = 6):
self.descs = []
self.descs.append(LayerDesc(EmbeddingPipe))
for x in range(6):
for x in range(transformer_layer_num):
self.descs.append(LayerDesc(TransformerNetPipe))
self.descs.append(lambda x: x[0])
......
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import unittest
import numpy as np
from hybrid_parallel_pp_transformer import ModelPipe, set_random_seed
import paddle
import paddle.distributed as dist
from paddle.distributed import fleet
batch_size = 8
length = 8
micro_batch_size = 2
vocab_size = 128
transformer_layer_num = 8
class TestDistPPSaveTraning(unittest.TestCase):
def setUp(self):
strategy = fleet.DistributedStrategy()
self.model_parallel_size = 1
self.data_parallel_size = 1
self.pipeline_parallel_size = 2
strategy.hybrid_configs = {
"dp_degree": self.data_parallel_size,
"mp_degree": self.model_parallel_size,
"pp_degree": self.pipeline_parallel_size,
}
strategy.pipeline_configs = {
"accumulate_steps": batch_size // micro_batch_size,
"micro_batch_size": micro_batch_size,
}
fleet.init(is_collective=True, strategy=strategy)
def test_pp_model(self):
print(f"pwd {os.getcwd()}")
hcg = fleet.get_hybrid_communicate_group()
word_size = hcg.get_model_parallel_world_size()
dp_id = hcg.get_data_parallel_rank()
pp_id = hcg.get_stage_id()
rank_id = dist.get_rank()
topology = hcg.topology()
set_random_seed(1024, dp_id, rank_id)
model = ModelPipe(topology, transformer_layer_num=transformer_layer_num)
scheduler = paddle.optimizer.lr.PiecewiseDecay(
boundaries=[2], values=[0.001, 0.002], verbose=True
)
optimizer = paddle.optimizer.SGD(
learning_rate=scheduler, parameters=model.parameters()
)
model = fleet.distributed_model(model)
optimizer = fleet.distributed_optimizer(optimizer)
output_dir = "{}/mp_00_sharding_00_pp_{:0>2d}".format(
"./pp_transformer", pp_id
)
try:
os.makedirs(output_dir)
except:
# dir is already created, do nothing
pass
for step_id in range(2):
x_data = np.random.randint(0, vocab_size, size=[batch_size, length])
x = paddle.to_tensor(x_data)
x.stop_gradient = True
loss = model.train_batch([x, x], optimizer, scheduler)
paddle.save(
model.state_dict(),
os.path.join(output_dir, "model.pdparams"),
)
paddle.save(
optimizer.state_dict(),
os.path.join(output_dir, "model_state.pdopt"),
)
meta_dict = {
"epoch": 0,
"step": 2,
"cuda_rng_state": paddle.get_cuda_rng_state(),
}
paddle.save(meta_dict, os.path.join(output_dir, "meta_state.pdopt"))
if __name__ == "__main__":
unittest.main()
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import unittest
import numpy as np
from hybrid_parallel_pp_transformer_with_virtual_stage import (
ModelPipe,
set_random_seed,
)
import paddle
import paddle.distributed as dist
from paddle.distributed import fleet
batch_size = 8
length = 8
micro_batch_size = 2
vocab_size = 128
transformer_layer_num = 8
class TestDistPPSaveTraning(unittest.TestCase):
def setUp(self):
strategy = fleet.DistributedStrategy()
self.model_parallel_size = 1
self.data_parallel_size = 1
self.pipeline_parallel_size = 2
strategy.hybrid_configs = {
"dp_degree": self.data_parallel_size,
"mp_degree": self.model_parallel_size,
"pp_degree": self.pipeline_parallel_size,
}
strategy.pipeline_configs = {
"accumulate_steps": batch_size // micro_batch_size,
"micro_batch_size": micro_batch_size,
}
fleet.init(is_collective=True, strategy=strategy)
def test_pp_model(self):
print(f"pwd {os.getcwd()}")
hcg = fleet.get_hybrid_communicate_group()
word_size = hcg.get_model_parallel_world_size()
dp_id = hcg.get_data_parallel_rank()
pp_id = hcg.get_stage_id()
rank_id = dist.get_rank()
topology = hcg.topology()
set_random_seed(1024, dp_id, rank_id)
model = ModelPipe(topology, transformer_layer_num=transformer_layer_num)
scheduler = paddle.optimizer.lr.PiecewiseDecay(
boundaries=[2], values=[0.001, 0.002], verbose=True
)
optimizer = paddle.optimizer.SGD(
learning_rate=scheduler, parameters=model.parameters()
)
model = fleet.distributed_model(model)
optimizer = fleet.distributed_optimizer(optimizer)
output_dir = "{}/mp_00_sharding_00_pp_{:0>2d}".format(
"./pp_transformer_vp", pp_id
)
try:
os.makedirs(output_dir)
except:
# dir is already created, do nothing
pass
for step_id in range(2):
x_data = np.random.randint(0, vocab_size, size=[batch_size, length])
x = paddle.to_tensor(x_data)
x.stop_gradient = True
loss = model.train_batch([x, x], optimizer, scheduler)
paddle.save(
model.state_dict(),
os.path.join(output_dir, "model.pdparams"),
)
paddle.save(
optimizer.state_dict(),
os.path.join(output_dir, "model_state.pdopt"),
)
meta_dict = {
"epoch": 0,
"step": 2,
"cuda_rng_state": paddle.get_cuda_rng_state(),
}
paddle.save(meta_dict, os.path.join(output_dir, "meta_state.pdopt"))
if __name__ == "__main__":
unittest.main()
......@@ -120,11 +120,10 @@ class CriterionPipe(Layer):
class ModelPipe(PipelineLayer):
def __init__(self, topology):
def __init__(self, topology, transformer_layer_num: int = 8):
self.descs = []
self.descs.append(LayerDesc(EmbeddingPipe))
for x in range(8):
for x in range(transformer_layer_num):
self.descs.append(LayerDesc(TransformerNetPipe))
self.descs.append(lambda x: x[0])
......
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import shutil
import unittest
from test_parallel_dygraph_dataparallel import TestMultipleGpus
import paddle
from paddle.distributed.fleet.utils.pp_parallel_adaptor import (
ParallelConfig,
PipeLineModelAdaptor,
adaptor_from_args,
parse_args,
)
class TestPPAdaptor(TestMultipleGpus):
def test_parse_args(self):
args = parse_args()
self.assertEqual(args.src_mp, args.dst_mp)
adaptor = adaptor_from_args(args)
self.assertTrue(adaptor is not None)
def test_hybrid_parallel_transformer_unbalanced_data(self):
print(f"pwd {os.getcwd()}")
self.run_mnist_2gpu('hybrid_parallel_pp_transformer_save.py')
self.run_mnist_2gpu(
'hybrid_parallel_pp_transformer_save_with_virtual_stage.py'
)
# test pp adaptor
dir1 = "./pp_transformer"
p_config1 = ParallelConfig(mp=1, pp=2, vpp=1, sharding=1)
dir2 = "./pp_transformer_vp"
p_config2 = ParallelConfig(mp=1, pp=2, vpp=2, sharding=1)
pp_to_vp = PipeLineModelAdaptor(
src_parallel_config=p_config1,
dst_parallel_config=p_config2,
transformer_layer_num=8,
segment_method="layer",
)
vp_to_pp = PipeLineModelAdaptor(
src_parallel_config=p_config2,
dst_parallel_config=p_config1,
transformer_layer_num=8,
segment_method="layer",
)
def check_converted_model(converted_model_dir, expected_model_dir):
# for compatibility, converted_model_dir may contain more key than
# expected model, which does not hinder model recovering
for i in range(p_config1.pp):
sub_converted_model_dir = (
"{}/mp_00_sharding_00_pp_{:0>2d}".format(
converted_model_dir, i
)
)
sub_expected_model_dir = (
"{}/mp_00_sharding_00_pp_{:0>2d}".format(
expected_model_dir, i
)
)
print(
f"converted_model_dir: {sub_converted_model_dir}; expected_model_dir: {sub_expected_model_dir}"
)
def check_names(dict_1, dict_2):
for (k, v) in dict_2.items():
self.assertTrue(k in dict_1)
self.assertEqual(
getattr(v, "name", ""),
getattr(dict_1[k], "name", ""),
)
# check param
params_1 = paddle.load(
f"{sub_converted_model_dir}/model.pdparams"
)
params_2 = paddle.load(
f"{sub_expected_model_dir}/model.pdparams"
)
check_names(params_1, params_2)
del params_1
del params_2
# check opt
opt_1 = paddle.load(
f"{sub_converted_model_dir}/model_state.pdopt"
)
opt_2 = paddle.load(
f"{sub_expected_model_dir}/model_state.pdopt"
)
check_names(opt_1, opt_2)
# check master wieghts
if "master_weights" in opt_2:
self.assertTrue("master_weights" in opt_1)
check_names(
opt_2["master_weights"], opt_1["master_weights"]
)
def create_dir_if_nonexist(dir: str):
if not os.path.exists(dir):
os.makedirs(dir)
# check pp to vp
tmp_dir1 = "./tmp_pp_to_vp"
create_dir_if_nonexist(tmp_dir1)
pp_to_vp.apply(dir1, tmp_dir1)
# browse the converted model
pp_to_vp.peek_model(tmp_dir1)
# check
check_converted_model(tmp_dir1, dir2)
# check vp to pp
tmp_dir2 = "./tmp_vp_to_pp"
create_dir_if_nonexist(tmp_dir2)
vp_to_pp.apply(dir2, tmp_dir2)
vp_to_pp.peek_model(tmp_dir2)
check_converted_model(tmp_dir2, dir1)
# check uniform segment
tmp_dir3 = "./tmp_vp_to_pp_uniform"
create_dir_if_nonexist(tmp_dir3)
vp_to_pp_uniform = PipeLineModelAdaptor(
src_parallel_config=p_config2,
dst_parallel_config=p_config1,
transformer_layer_num=8,
segment_method="uniform",
)
vp_to_pp_uniform.apply(dir2, tmp_dir3)
vp_to_pp_uniform.peek_model(tmp_dir3)
tmp_dir4 = "./tmp_pp_to_pp_uniform"
create_dir_if_nonexist(tmp_dir4)
pp_to_pp_uniform = PipeLineModelAdaptor(
src_parallel_config=p_config1,
dst_parallel_config=p_config1,
transformer_layer_num=8,
segment_method="uniform",
)
pp_to_pp_uniform.apply(dir1, tmp_dir4)
pp_to_pp_uniform.peek_model(tmp_dir4)
check_converted_model(tmp_dir3, tmp_dir4)
# rm dirs
for d in [dir1, dir2, tmp_dir1, tmp_dir2, tmp_dir3, tmp_dir4]:
shutil.rmtree(d, ignore_errors=True)
if __name__ == "__main__":
unittest.main()
......@@ -14,6 +14,7 @@ test_dygraph_sharding_stage3_for_eager,,,350,DIST,../../dist_test.sh,2,,http_pro
test_communicator_half_async,,,120,DIST,test_runner.py,2,,FLAGS_communicator_send_queue_size=1;FLAGS_communicator_max_merge_var_num=1;http_proxy=;https_proxy=;PYTHONPATH=../..,WITH_NCCL
test_parallel_dygraph_pipeline_parallel,,GPU,500,DIST,../../dist_test.sh,2,,http_proxy=;https_proxy=;PYTHONPATH=../..,
test_parallel_dygraph_pipeline_parallel_with_virtual_stage,,GPU,500,DIST,../../dist_test.sh,2,,http_proxy=;https_proxy=;PYTHONPATH=../..,
test_parallel_dygraph_pp_adaptor,,GPU,500,DIST,../../dist_test.sh,2,,http_proxy=;https_proxy=;PYTHONPATH=../..,
test_fleet_localsgd_meta_optimizer,LINUX,GPU;XPU,,,test_runner.py,2,,http_proxy=;https_proxy=;PYTHONPATH=../..,
test_parallel_class_center_sample,,GPU,120,DIST,../../dist_test.sh,2,,http_proxy=;https_proxy=;PYTHONPATH=../..,WITH_NCCL
test_pipeline,,,120,DIST,../../dist_test.sh,2,,http_proxy=;https_proxy=;PYTHONPATH=../..,
......
......@@ -1071,11 +1071,18 @@ def load(path, **configs):
# paddle2.0: paddle.save/load
if "StructuredToParameterName@@" in load_result:
for key in load_result["StructuredToParameterName@@"]:
for (key, name) in load_result[
"StructuredToParameterName@@"
].items():
if isinstance(load_result[key], np.ndarray):
load_result[key] = _ndarray_to_tensor(
load_result[key], config.return_numpy
)
# default name is "generatedxxx" which is set in Tensor init, if not set
if not config.return_numpy and getattr(
load_result[key], "name", ""
):
load_result[key].name = name
if (
not config.keep_name_table
......
......@@ -76,7 +76,7 @@ if(WITH_DISTRIBUTE AND WITH_GPU)
py_test_modules(test_pass_quantization MODULES test_pass_quantization)
set_tests_properties(test_pass_quantization PROPERTIES TIMEOUT 60)
py_test_modules(test_tuning_recompute MODULES test_tuning_recompute)
set_tests_properties(test_tuning_recompute PROPERTIES TIMEOUT 240)
set_tests_properties(test_tuning_recompute PROPERTIES TIMEOUT 300)
py_test_modules(test_fused_linear_pass MODULES test_fused_linear_pass)
set_tests_properties(test_fused_linear_pass PROPERTIES TIMEOUT 20)
py_test_modules(test_align_tool MODULES test_align_tool)
......
......@@ -29,6 +29,6 @@ endif()
foreach(TEST_OP ${TEST_OPS})
py_test_modules(${TEST_OP} MODULES ${TEST_OP})
list(APPEND DIST_TEST_OPS ${TEST_OP})
set_tests_properties(${TEST_OP} PROPERTIES TIMEOUT 120)
set_tests_properties(${TEST_OP} PROPERTIES TIMEOUT 200)
set_tests_properties(${TEST_OP} PROPERTIES LABELS "RUN_TYPE=DIST")
endforeach()
......@@ -197,7 +197,7 @@ if(WITH_GPU AND TENSORRT_FOUND)
set_tests_properties(test_trt_fc_fuse_quant_dequant_pass PROPERTIES TIMEOUT
100)
set_tests_properties(test_trt_conv_quant_dequant_pass PROPERTIES TIMEOUT 100)
set_tests_properties(test_trt_matmul_quant_dequant PROPERTIES TIMEOUT 100)
set_tests_properties(test_trt_matmul_quant_dequant PROPERTIES TIMEOUT 180)
set_tests_properties(test_trt_conv3d_op PROPERTIES TIMEOUT 60)
set_tests_properties(test_trt_conv3d_transpose_op PROPERTIES TIMEOUT 60)
set_tests_properties(test_trt_nearest_interp_v2_op PROPERTIES TIMEOUT 30)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册