Unverified commit b9defb4f, authored by zhaoyingli, committed by GitHub

[AutoParallel] Save&Load Module (#36558)

* AutoParallel Save&Load

* tiny modi

* update func name

* tiny fix

* add NotImplementedError

* fix doc

* update func name

* update func param

* update interface

* add unitest & modi make_data_unshard

* update unittest

* update unittest

* fix unittest

* fix cmakelist

* update unittest
Parent 53b3f40f
@@ -111,21 +111,6 @@ def find_best_compatible_distributed_operator_impl(name, dist_op, fwd=True):
     return best_compatible_impl, idx

-# def copy_distributed_attr_for_var(src_op_dist_attr, dst_var, src_var):
-#     """
-#     copy src var's dist_attr to dst var
-#     """
-#     import copy
-#     auto_paralle_context = src_op_dist_attr.get_owner_context()
-#     dist_attr = copy.deepcopy(
-#         auto_paralle_context.get_tensor_distributed_attr_for_program(src_var))
-#     dist_attr._owner_tensor = var
-#     dist_attr._owner_context = auto_paralle_context.get_tensor_distributed_attr_for_program(
-#         src_var)._owner_context
-#     auto_paralle_context.set_tensor_distributed_attr_for_program(var, dist_attr)

 def copy_distributed_attr_for_var(dist_context, dst_var, src_var):
     """
     copy src var's dist_attr to dst var
@@ -134,38 +119,6 @@ def copy_distributed_attr_for_var(dist_context, dst_var, src_var):
     dist_context.set_tensor_dist_attr_for_program(dst_var, dist_attr)
-# def copy_distributed_attr_for_dist_op(dist_op, dst_block, src_op_dist_attr):
-#     """
-#     copy src op's dist_attr to dst dist op
-#     """
-#     from ..attribute import OperatorDistributedAttribute
-#     auto_paralle_context = src_op_dist_attr.get_owner_context()
-#     op_dist_attr = OperatorDistributedAttribute(dist_op, auto_paralle_context)
-#     auto_paralle_context._copy_distributed_attr_from_op_desc(dist_op.desc,
-#                                                              op_dist_attr)
-#     auto_paralle_context.set_op_distributed_attr_for_program(dist_op,
-#                                                              op_dist_attr)
-#     op_dist_attr.set_process_mesh(src_op_dist_attr.get_process_mesh())
-#     op_dist_attr.set_impl_idx(src_op_dist_attr.get_impl_idx())
-#     for input_varname in dist_op.desc.input_arg_names():
-#         input_var = dst_block.var(input_varname)
-#         tensor_dist_attr = auto_paralle_context.get_tensor_distributed_attr_for_program(
-#             input_var)
-#         tensor_dims_mapping = tensor_dist_attr.get_dims_mapping()
-#         op_dist_attr.set_input_dims_mapping(input_varname, tensor_dims_mapping)
-#     for output_varname in dist_op.desc.output_arg_names():
-#         output_var = dst_block.var(output_varname)
-#         tensor_dist_attr = auto_paralle_context.get_tensor_distributed_attr_for_program(
-#             output_var)
-#         tensor_dims_mapping = tensor_dist_attr.get_dims_mapping()
-#         op_dist_attr.set_output_dims_mapping(output_varname,
-#                                              tensor_dims_mapping)

 def copy_distributed_attr_for_dist_op(dist_context, dist_op, dst_block,
                                       src_op_dist_attr):
     """
...
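
The removed helpers above went through each attribute's owner context; the refactored code routes all distributed-attribute bookkeeping through a DistributedContext instead. A minimal sketch of that pattern, assuming a dist_context plus two program variables src_var/dst_var already exist (the helper name below is illustrative, not part of the diff):

# Minimal sketch (not part of the diff): copy a tensor's dist_attr via the
# DistributedContext, mirroring copy_distributed_attr_for_var above.
import copy

def copy_var_dist_attr_sketch(dist_context, dst_var, src_var):
    # get/set_tensor_dist_attr_for_program are the context APIs used elsewhere
    # in this PR; deepcopy avoids sharing one attr object between two vars.
    dist_attr = copy.deepcopy(
        dist_context.get_tensor_dist_attr_for_program(src_var))
    dist_context.set_tensor_dist_attr_for_program(dst_var, dist_attr)
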
@@ -17,6 +17,7 @@ from paddle.distributed.fleet import cloud_utils
 import paddle.fluid.core as core
 from .dist_context import DistributedContext
 from .dist_context import get_default_distributed_context
+from .dist_context import set_default_distributed_context
 from .completion import complete_annotation, complete_backward_annotation
 from .partitioner import Partitioner
 from .process_group import get_all_process_groups
@@ -38,8 +39,7 @@ class AutoParallelizer:
         self._fleet = fleet
         self._optimizer = self._fleet.user_defined_optimizer
         self._dist_strategy = self._fleet._user_defined_strategy
-        # self._dist_context = DistributedContext()
-        self._dist_context = get_default_distributed_context()
+        self._dist_context = DistributedContext()

     def _remove_distributed_attrs(self, main_program):
         suffix = core.kAutoParallelSuffix()
@@ -53,23 +53,15 @@ class AutoParallelizer:
     def parallelize(self,
                     loss,
-                    startup_program=None,
+                    startup_program,
                     parameter_list=None,
                     no_grad_set=None):
-        self._original_main_program = loss.block.program
-        # For now, we only allow user to use the default startup and main program
         assert startup_program is not None
-        if startup_program == None:
-            self._original_startup_program = \
-                paddle.static.default_startup_program().clone(for_test=False)
-            startup_program = paddle.static.default_startup_program()
-        else:
-            self._original_startup_program = \
-                startup_program.clone(for_test=False)
+        main_program = loss.block.program

         # Annotation completion
-        completed_main_program = complete_annotation(
-            self._original_main_program, self._dist_context)
+        completed_main_program = complete_annotation(main_program,
+                                                     self._dist_context)
         # Logical partition
         rank = paddle.distributed.get_rank()
         partitioner = Partitioner(self._dist_strategy, self._dist_context, rank)
@@ -93,9 +85,13 @@ class AutoParallelizer:
         # The last step: remove all distributed attributes to be compatiable
         # with inference.
         self._remove_distributed_attrs(partitioned_main_prog)

-        make_data_unshard(partitioned_main_prog, partitioned_startup_prog)
+        make_data_unshard(partitioned_main_prog, partitioned_startup_prog,
+                          self._dist_context)

         reshard(partitioned_main_prog, partitioned_startup_prog, rank,
                 self._dist_context)

+        # Copy distributed info to the default context
+        set_default_distributed_context(self._dist_context)

         return dist_optimize_ops, dist_params_grads, partitioned_startup_prog, partitioned_main_prog
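
With this change, parallelize() requires an explicit startup_program, keeps its own DistributedContext, and publishes that context as the default once partitioning and resharding are done, so utilities such as the checkpoint helpers below can find the final distributed attributes. A hedged end-to-end sketch, modeled on the unit test added in this PR (mlp_forward stands in for any user-defined static-graph model builder):

# Hedged usage sketch based on the new unit test; mlp_forward is a placeholder.
import paddle
import paddle.static as static
from paddle.distributed import fleet

paddle.enable_static()

dist_strategy = fleet.DistributedStrategy()
dist_strategy.semi_auto = True                  # turn on auto-parallel planning
fleet.init(is_collective=True, strategy=dist_strategy)

train_program, startup_program = static.Program(), static.Program()
loss, train_program, startup_program = mlp_forward(train_program, startup_program)

optimizer = paddle.fluid.optimizer.SGDOptimizer(learning_rate=0.01)
optimizer = fleet.distributed_optimizer(optimizer)
# minimize() ends up in AutoParallelizer.parallelize(); startup_program is required.
_, _, dist_startup_prog, dist_main_prog = optimizer.minimize(loss, startup_program)
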
@@ -12,9 +12,16 @@
 # See the License for the specific language governing permissions and
 # limitations under the License

+import os
+import paddle
 import threading
-import paddle.fluid.core as core
 import numpy as np
+import warnings
+import logging
+import paddle.fluid.core as core
+from paddle.fluid.io import is_parameter, is_belong_to_optimizer
+from paddle.framework.io import _to_LodTensor

 def is_valid_list_index(list, index):
@@ -338,9 +345,10 @@ def _get_unshard_dist_shape(var, dist_attr):
     return new_shape

-def make_data_unshard(dist_main_prog, dist_startup_prog):
+def make_data_unshard(dist_main_prog, dist_startup_prog, dist_context=None):
     from .dist_context import get_default_distributed_context
-    dist_context = get_default_distributed_context()
+    if dist_context is None:
+        dist_context = get_default_distributed_context()

     for var in dist_main_prog.list_vars():
         if var.is_data:
@@ -352,3 +360,140 @@ def make_data_unshard(dist_main_prog, dist_startup_prog):
                 dim_mapping = [-1] * len(dim_mapping)
                 tensor_dist_attr.dims_mapping = dim_mapping
                 dist_context.set_tensor_dist_attr_for_program(var, tensor_dist_attr)
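
make_data_unshard now takes the DistributedContext explicitly (falling back to the default one) and clears every data variable's dims_mapping to all -1, i.e. feed data stays replicated instead of sharded. A small hedged sketch of calling it after partitioning, assuming dist_main_prog, dist_startup_prog and dist_context come from the parallelize step above:

# Hedged sketch: keep feed data unsharded after partitioning.
from paddle.distributed.auto_parallel.utils import make_data_unshard

make_data_unshard(dist_main_prog, dist_startup_prog, dist_context)

# Effect per data variable: every sharded axis in dims_mapping becomes -1.
for var in dist_main_prog.list_vars():
    if var.is_data:
        dist_attr = dist_context.get_tensor_dist_attr_for_program(var)
        assert all(dim == -1 for dim in dist_attr.dims_mapping)
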
def _check_addition_info(addition_info):
    """
    Validity check of additional information
    """
    if not addition_info:
        return addition_info
    elif not isinstance(addition_info, dict):
        raise TypeError(
            "The type of addition_info should be 'dict', but got {}".format(
                str(type(addition_info))))
    else:
        return addition_info


def _check_valid_path(file_path):
    """
    Validity check of input file path
    """
    if not file_path:
        return file_path
    elif isinstance(file_path, str):
        if not os.path.exists(file_path):
            raise ValueError("The file_path '{}' does not exist.".format(
                file_path))
        else:
            return [file_path]
    elif isinstance(file_path, list):
        if not all(isinstance(file, str) for file in file_path):
            raise ValueError("The type of each file_path should be str.")
        if not all(os.path.exists(file) for file in file_path):
            raise ValueError("The file_path's file does not exist.")
        return file_path
    else:
        raise TypeError(
            "The type of file_path should be 'str' or 'list', but got '{}'.".
            format(str(type(file_path))))

def save_distributed_checkpoint(program,
                                checkpoint_path,
                                is_integrated=False,
                                addition_info=None,
                                dist_attr_path=None):
    """
    Save model parameter state, optimizer state, distributed attribute and
    additional information of each rank.

    Args:
        program(Program): The program to be saved.
        checkpoint_path(str): The path of the checkpoint file to be saved.
        is_integrated(bool, optional): Whether to integrate param before save. Default: False.
        addition_info(dict, optional): Additional information. Default: None.
        dist_attr_path(str, optional): The path of distributed attribute file to be saved. Default: None.

    Returns:
        None

    Examples:
        .. code-block:: python

            ckpt_path = os.path.join(args.output_dir, "step_%d" % step)
            os.makedirs(ckpt_path, exist_ok=True)
            save_distributed_checkpoint(program, ckpt_path)
    """
    if not is_integrated:
        rank = paddle.distributed.get_rank()
        ckpt_file_name = os.path.join(checkpoint_path,
                                      "model_state_rank{}.pdmodel".format(rank))

        state_dict = {
            "model": program.state_dict(),
            "ranks": paddle.distributed.get_world_size()
        }
        if _check_addition_info(addition_info):
            state_dict["addition_info"] = addition_info

        paddle.save(state_dict, ckpt_file_name)
        logging.info("Already save model to {}".format(checkpoint_path))

        if dist_attr_path:
            raise NotImplementedError(
                "Save distributed attribute has not been implemented.")
    else:
        # TODO: integrate param before save
        raise NotImplementedError(
            "Integrating parameter has not been implemented.")
def load_distributed_checkpoint(checkpoint_path,
                                program=None,
                                dist_attr_path=None):
    """
    Load parameter, optimizer, distributed attribute and addition_info of model.

    Args:
        checkpoint_path(str|list[str]): checkpoint_path's type can be 'str' or 'list', \
            which must be in order of rank id when type is 'list'.
        program(Program, optional): The program to be updated with checkpoint_path. Default: None.
        dist_attr_path(str|list[str], optional): dist_attr_path's type can be 'str' or 'list', \
            which must be in order of rank id when type is 'list'. Default: None.

    Returns:
        None or addition_info which user saved in last train.

    Examples:
        .. code-block:: python

            exe.run(startup_program)
            ckpt_path = ['./output/step_10/model_state_rank0.pdmodel',
                         './output/step_10/model_state_rank1.pdmodel']
            load_distributed_checkpoint(ckpt_path, main_program)
    """
    checkpoint_path = _check_valid_path(checkpoint_path)
    dist_attr_path = _check_valid_path(dist_attr_path)

    if checkpoint_path and dist_attr_path:
        raise NotImplementedError(
            "Merge&Slice parameter with dist_attr has not been implemented.")
    elif checkpoint_path:
        assert len(checkpoint_path) == paddle.distributed.get_world_size(), \
            "The number of checkpoint_path must equal to the number of ranks"
        rank = paddle.distributed.get_rank()
        state_dict_info = paddle.load(checkpoint_path[rank])
        state_dict = state_dict_info["model"]
    else:
        raise ValueError("'checkpoint_path' can not be None.")

    program.set_state_dict(state_dict) if program else \
        warnings.warn("'Program' is None, parameters will not be loaded.")

    if "addition_info" not in state_dict_info:
        return

    return state_dict_info["addition_info"]
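
Loading mirrors the layout above: the checkpoint files are passed in rank order, their count must match the world size, and the function returns the addition_info dict stored at save time (or None if none was saved). A hedged sketch continuing the save example:

# Hedged sketch: restore the per-rank checkpoints written above.
from paddle.distributed.auto_parallel.utils import load_distributed_checkpoint

ckpt_paths = [                                   # one entry per rank, in rank order
    "./output/step_10/model_state_rank0.pdmodel",
    "./output/step_10/model_state_rank1.pdmodel",
]
addition_info = load_distributed_checkpoint(ckpt_paths, dist_main_prog)
if addition_info is not None:
    print("resuming from:", addition_info)       # e.g. {'epoch': 1, 'step': 10}
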
@@ -38,6 +38,7 @@ list(APPEND DIST_TEST_OPS test_hybrid_parallel_inference_helper)
 list(APPEND DIST_TEST_OPS test_parallel_class_center_sample)
 list(APPEND DIST_TEST_OPS test_parallel_margin_cross_entropy)
 list(APPEND DIST_TEST_OPS test_auto_parallel_data_unshard)
+list(APPEND DIST_TEST_OPS test_auto_parallel_save_load)
 set(MIXED_DIST_TEST_OPS ${DIST_TEST_OPS})
 #remove distribute unittests.
 list(APPEND MIXED_DIST_TEST_OPS test_dgc_op)
@@ -253,6 +254,7 @@ if ((NOT WITH_GPU) AND (NOT WITH_ROCM))
     LIST(REMOVE_ITEM TEST_OPS test_auto_parallel_reshard_dpmppp)
     LIST(REMOVE_ITEM TEST_OPS test_auto_parallel_cost_model)
     LIST(REMOVE_ITEM TEST_OPS test_auto_parallel_data_unshard)
+    LIST(REMOVE_ITEM TEST_OPS test_auto_parallel_save_load)
 elseif(WITH_GPU)
     if (${CUDNN_VERSION} VERSION_LESS 7100)
         LIST(REMOVE_ITEM TEST_OPS test_conv2d_fusion_op)
@@ -1032,6 +1034,7 @@ if(WITH_DISTRIBUTE AND WITH_GPU AND WITH_NCCL)
     set_tests_properties(test_parallel_class_center_sample PROPERTIES TIMEOUT 120)
     set_tests_properties(test_parallel_margin_cross_entropy PROPERTIES TIMEOUT 120)
     set_tests_properties(test_auto_parallel_data_unshard PROPERTIES TIMEOUT 120)
+    set_tests_properties(test_auto_parallel_save_load PROPERTIES TIMEOUT 120)
     if(${NCCL_VERSION} VERSION_GREATER_EQUAL 2212)
         set_tests_properties(test_parallel_dygraph_sparse_embedding PROPERTIES TIMEOUT 120)
         set_tests_properties(test_parallel_dygraph_transformer PROPERTIES TIMEOUT 120)
...
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import unittest
import random
import numpy as np
import os
import shutil
import paddle
import paddle.nn as nn
import paddle.utils as utils
import paddle.static as static
import paddle.nn.functional as F
import paddle.distributed.auto_parallel as auto
from paddle.distributed import fleet
from paddle.fluid.initializer import NumpyArrayInitializer
from paddle.distributed.auto_parallel.utils import make_data_unshard
from paddle.distributed.auto_parallel.utils import save_distributed_checkpoint, load_distributed_checkpoint
from paddle.distributed.auto_parallel.reshard import reshard
from paddle.distributed.auto_parallel.partitioner import Partitioner
from paddle.distributed.auto_parallel.dist_context import DistributedContext
from paddle.distributed.auto_parallel.process_group import get_all_process_groups
paddle.enable_static()
_global_parallel_strategy = None
_global_process_mesh = None
PP_MESH_0 = None
PP_MESH_1 = None
class MLPLayer(nn.Layer):
    def __init__(self,
                 hidden_size=64,
                 intermediate_size=4 * 64,
                 initializer_range=0.02):
        super(MLPLayer, self).__init__()
        d_model = hidden_size
        dim_feedforward = intermediate_size
        np.random.seed(2021)
        arr = np.random.normal(0, 0.02, size=(d_model, dim_feedforward))
        weight_attr = paddle.ParamAttr(initializer=NumpyArrayInitializer(arr))
        bias_attr = None
        self.linear0 = nn.Linear(
            d_model, dim_feedforward, weight_attr, bias_attr=bias_attr)
        self.linear1 = nn.Linear(
            dim_feedforward, d_model, weight_attr, bias_attr=bias_attr)
        self.norm = nn.LayerNorm(d_model, epsilon=1e-5)

    def forward(self, input):
        if _global_parallel_strategy == "pp":
            auto.shard_tensor(
                self.linear0.weight,
                dist_attr={
                    "process_mesh": PP_MESH_0,
                    "dims_mapping": [-1, -1]
                })
            auto.shard_tensor(
                self.linear1.weight,
                dist_attr={
                    "process_mesh": PP_MESH_1,
                    "dims_mapping": [-1, -1]
                })
        elif _global_parallel_strategy == "mp":
            auto.shard_tensor(
                self.linear0.weight,
                dist_attr={
                    "process_mesh": _global_process_mesh,
                    "dims_mapping": [-1, 0]
                })
            auto.shard_tensor(
                self.linear1.weight,
                dist_attr={
                    "process_mesh": _global_process_mesh,
                    "dims_mapping": [0, -1]
                })
        elif _global_parallel_strategy == "dp":
            auto.shard_tensor(
                self.linear0.weight,
                dist_attr={
                    "process_mesh": _global_process_mesh,
                    "dims_mapping": [-1, -1]
                })
            auto.shard_tensor(
                self.linear1.weight,
                dist_attr={
                    "process_mesh": _global_process_mesh,
                    "dims_mapping": [-1, -1]
                })

        out = self.norm(input)
        out = self.linear0(out)
        out = F.gelu(out, approximate=True)
        out = self.linear1(out)

        return out

def mlp_forward(train_program, start_program):
    with static.program_guard(train_program, start_program), \
        utils.unique_name.guard():
        batch_size = 4
        hidden_size = 64
        input = static.data(
            name="input", shape=[batch_size, hidden_size], dtype='float32')
        label = static.data(
            name="label", shape=[batch_size, 1], dtype='float32')

        if _global_parallel_strategy == "pp":
            auto.shard_tensor(
                input,
                dist_attr={
                    "process_mesh": PP_MESH_0,
                    "dims_mapping": [-1, -1]
                })
            auto.shard_tensor(
                label,
                dist_attr={
                    "process_mesh": PP_MESH_1,
                    "dims_mapping": [-1, -1]
                })
        elif _global_parallel_strategy == "dp":
            auto.shard_tensor(
                input,
                dist_attr={
                    "process_mesh": _global_process_mesh,
                    "dims_mapping": [0, -1]
                })
        elif _global_parallel_strategy == "mp":
            auto.shard_tensor(
                input,
                dist_attr={
                    "process_mesh": _global_process_mesh,
                    "dims_mapping": [-1, -1]
                })

        mlp = MLPLayer(
            hidden_size=hidden_size,
            intermediate_size=4 * hidden_size,
            initializer_range=0.02)

        predict = mlp(input)
        error_cost = paddle.nn.functional.square_error_cost(predict, label)
        loss = paddle.mean(error_cost)

    return loss, train_program, start_program

def get_distributed_program():
    train_program = static.Program()
    startup_program = static.Program()

    dist_strategy = fleet.DistributedStrategy()
    dist_strategy.semi_auto = True
    fleet.init(is_collective=True, strategy=dist_strategy)

    loss, train_program, startup_program = mlp_forward(train_program,
                                                       startup_program)

    optimizer = paddle.fluid.optimizer.SGDOptimizer(learning_rate=0.01)
    optimizer = fleet.distributed_optimizer(optimizer)
    _, _, dist_startup_prog, dist_main_prog = optimizer.minimize(
        loss, startup_program)

    return dist_main_prog, dist_startup_prog, loss

class TestMLPSaveLoad(unittest.TestCase):
    def setUp(self):
        paddle.seed(2021)
        random.seed(2021)
        np.random.seed(2021)

    def test_mlp_dp(self):
        global _global_parallel_strategy
        _global_parallel_strategy = "dp"
        global _global_process_mesh
        _global_process_mesh = auto.ProcessMesh([0, 1])

        dist_main_prog, dist_start_prog, loss = get_distributed_program()

        place = paddle.set_device("gpu")
        exe = paddle.static.Executor(place)
        exe.run(dist_start_prog)

        input = np.random.random(size=(80, 64)).astype('float32')
        label = np.random.random(size=(80, 1)).astype('float32')
        for step in range(20):
            if step == 10:
                path = "./output_dp{}".format(paddle.distributed.get_rank())
                os.makedirs(path, exist_ok=True)
                save_distributed_checkpoint(dist_main_prog, path)

            res = exe.run(dist_main_prog,
                          feed={
                              "input": input[step * 4:(step + 1) * 4, :],
                              "label": label[step * 4:(step + 1) * 4, :]
                          },
                          fetch_list=[loss])

        last_res = res[0]
        ckpt_path = [
            "./output_dp0/model_state_rank0.pdmodel",
            "./output_dp1/model_state_rank1.pdmodel"
        ]
        load_distributed_checkpoint(ckpt_path, dist_main_prog)
        for step in range(10, 20):
            res = exe.run(dist_main_prog,
                          feed={
                              "input": input[step * 4:(step + 1) * 4, :],
                              "label": label[step * 4:(step + 1) * 4, :]
                          },
                          fetch_list=[loss])

        self.assertEqual(last_res, res[0])
        shutil.rmtree("./output_dp{}".format(paddle.distributed.get_rank()))

    def test_mlp_mp(self):
        global _global_parallel_strategy
        _global_parallel_strategy = "mp"
        global _global_process_mesh
        _global_process_mesh = auto.ProcessMesh([0, 1])

        dist_main_prog, dist_start_prog, loss = get_distributed_program()

        place = paddle.set_device("gpu")
        exe = paddle.static.Executor(place)
        exe.run(dist_start_prog)

        input = np.random.random(size=(80, 64)).astype('float32')
        label = np.random.random(size=(80, 1)).astype('float32')
        for step in range(20):
            if step == 10:
                path = "./output_mp{}".format(paddle.distributed.get_rank())
                os.makedirs(path, exist_ok=True)
                save_distributed_checkpoint(dist_main_prog, path)

            res = exe.run(dist_main_prog,
                          feed={
                              "input": input[step * 4:(step + 1) * 4, :],
                              "label": label[step * 4:(step + 1) * 4, :]
                          },
                          fetch_list=[loss])

        last_res = res[0]
        ckpt_path = [
            "./output_mp0/model_state_rank0.pdmodel",
            "./output_mp1/model_state_rank1.pdmodel"
        ]
        load_distributed_checkpoint(ckpt_path, dist_main_prog)
        for step in range(10, 20):
            res = exe.run(dist_main_prog,
                          feed={
                              "input": input[step * 4:(step + 1) * 4, :],
                              "label": label[step * 4:(step + 1) * 4, :]
                          },
                          fetch_list=[loss])

        self.assertEqual(last_res, res[0])
        shutil.rmtree("./output_mp{}".format(paddle.distributed.get_rank()))

    def test_mlp_pp(self):
        global _global_parallel_strategy
        _global_parallel_strategy = "pp"
        global _global_process_mesh
        _global_process_mesh = auto.ProcessMesh([0, 1])
        global PP_MESH_0
        PP_MESH_0 = auto.ProcessMesh(mesh=[0])
        global PP_MESH_1
        PP_MESH_1 = auto.ProcessMesh(mesh=[1])

        dist_main_prog, dist_start_prog, loss = get_distributed_program()

        place = paddle.set_device("gpu")
        exe = paddle.static.Executor(place)
        exe.run(dist_start_prog)

        input = np.random.random(size=(80, 64)).astype('float32')
        label = np.random.random(size=(80, 1)).astype('float32')
        for step in range(20):
            if step == 10:
                path = "./output_pp{}".format(paddle.distributed.get_rank())
                os.makedirs(path, exist_ok=True)
                save_distributed_checkpoint(dist_main_prog, path)

            if paddle.distributed.get_rank() in [0]:
                res = exe.run(dist_main_prog,
                              feed={
                                  "input": input[step * 4:(step + 1) * 4, :],
                                  "label": label[step * 4:(step + 1) * 4, :]
                              })
            else:
                res = exe.run(dist_main_prog,
                              feed={
                                  "input": input[step * 4:(step + 1) * 4, :],
                                  "label": label[step * 4:(step + 1) * 4, :]
                              },
                              fetch_list=[loss])

        if paddle.distributed.get_rank() in [1]:
            last_res = res[0]

        ckpt_path = [
            "./output_pp0/model_state_rank0.pdmodel",
            "./output_pp1/model_state_rank1.pdmodel"
        ]
        load_distributed_checkpoint(ckpt_path, dist_main_prog)
        for step in range(10, 20):
            if paddle.distributed.get_rank() in [0]:
                res = exe.run(dist_main_prog,
                              feed={
                                  "input": input[step * 4:(step + 1) * 4, :],
                                  "label": label[step * 4:(step + 1) * 4, :]
                              })
            else:
                res = exe.run(dist_main_prog,
                              feed={
                                  "input": input[step * 4:(step + 1) * 4, :],
                                  "label": label[step * 4:(step + 1) * 4, :]
                              },
                              fetch_list=[loss])

        if paddle.distributed.get_rank() in [1]:
            self.assertEqual(last_res, res[0])
        shutil.rmtree("./output_pp{}".format(paddle.distributed.get_rank()))

if __name__ == "__main__":
    unittest.main()

# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import unittest
import paddle.fluid as fluid
from test_parallel_dygraph_dataparallel import TestMultipleGpus
class TestAutoParallelSaveLoad(TestMultipleGpus):
    def test_auto_parallel_save_load(self):
        self.run_mnist_2gpu('auto_parallel_save_load.py')


if __name__ == "__main__":
    unittest.main()