...
 
Commits (4)
https://gitcode.net/paddlepaddle/Paddle/-/commit/8d42540f595de44146715687e23dd1b82e7b3e3e
add paddle.async_save to reduce time cost by checkpoint saving (#55115)
2023-07-21T11:50:52+08:00 · Tian <121000916+SylarTiaNII@users.noreply.github.com>
* add paddle.async_save to reduce time cost by checkpoint saving
* adapt save_for_auto_inference to paddle.async_save
* modify UT
* modify UT
* fix on cpu only version
* revert commit on save_auto_inference
* fix threading

https://gitcode.net/paddlepaddle/Paddle/-/commit/9daba606d9e6c3c0b8f6acdb5fba505240e5bc2b
make sharding reduce mode by default (#55529)
2023-07-22T10:20:15+08:00 · sneaxiy <32832641+sneaxiy@users.noreply.github.com>
* make sharding reduce mode by default
* Update dygraph_sharding_optimizer.py
* Update hybrid_parallel_optimizer.py
* Update pipeline_parallel.py

https://gitcode.net/paddlepaddle/Paddle/-/commit/8520a5b352d4e44d8c3cf796df5b7d4b107ef640
add check for cembedding (#55621)
2023-07-22T14:48:54+08:00 · ShenLiang <1422485404@qq.com>

https://gitcode.net/paddlepaddle/Paddle/-/commit/f275ad2becd4c56475166ae1ee0bb14e0127e82a
[Distributed] Support dp/sharding overlap in virtual pp (#55651)
2023-07-26T11:39:29+08:00 · ShenLiang <1422485404@qq.com>
* Add virtual pp and dp overlap
* add sharding/dp overlap
* add dp/vpp overlap
* fix code
* fix log
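Taken together, these commits add an asynchronous checkpointing API (paddle.async_save), flip the sharding communication defaults to reduce mode, add an id bounds check to c_embedding, and let dp/sharding communication overlap with the interleaved (virtual) pipeline schedule. A minimal usage sketch of the new save API, assuming a dygraph-mode build (names and behavior follow the diff and tests below):

import paddle

# dygraph mode is required; async_save raises ValueError under static graph
paddle.disable_static()

emb = paddle.nn.Embedding(10, 10)

# Saving runs on a background thread; on CUDA builds tensors are first
# copied to pinned memory so training can continue during pickling/IO.
paddle.async_save(emb.state_dict(), "emb.pdparams")

# ... run more training steps here ...

# Block until every queued async save task has finished.
paddle.clear_async_save_task_queue()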
......@@ -87,6 +87,10 @@ class CEmbeddingOpMaker : public framework::OpProtoAndCheckerMaker {
"(int64, default 0), The starting index is indeed, "
"and the out-of-bounds will be set to 0 ")
.SetDefault(0);
AddAttr<int64_t>("vocab_size",
"(int64, default -1), The total vocabulary size to check"
"the out-of-bounds ids. If it is -1, no check will be ")
.SetDefault(-1);
AddComment(R"DOC(
c_embedding Operator.
......
......@@ -42,21 +42,25 @@ __global__ void CEmbedding(T *out,
const int64_t N,
const int64_t start_idx,
const int64_t end_idx,
- const int64_t limit) {
+ const int64_t limit,
+ const int64_t vocab_size) {
CUDA_KERNEL_LOOP(i, limit) {
size_t row = i / columns;
size_t col = i % columns;
auto id = ids[row];
PADDLE_ENFORCE(
id >= 0 && (vocab_size < 0 || id < vocab_size),
"The index is out of bounds, "
"please check whether the dimensions of index and "
"input meet the requirements. It should "
"be less than [%d] and greater than or equal to 0, but received [%d]",
vocab_size,
id);
if (id >= start_idx && id < end_idx) {
auto real_idx = id - start_idx;
PADDLE_ENFORCE(real_idx < N,
"The index is out of bounds, "
"please check whether the dimensions of index and "
"input meet the requirements. It should "
"be less than [%d], but received [%d]",
N,
real_idx);
out[i] = table[real_idx * columns + col];
} else {
out[i] = static_cast<T>(0);
......@@ -95,6 +99,8 @@ class CEmbeddingCUDAKernel : public framework::OpKernel<T> {
const auto &dev_ctx = context.template device_context<phi::GPUContext>();
const int64_t start_idx = context.Attr<int64_t>("start_index");
const int64_t vocab_size = context.Attr<int64_t>("vocab_size");
size_t N = table_t->dims()[0];
size_t D = table_t->dims()[1];
size_t K = ids_t->numel();
......@@ -119,7 +125,8 @@ class CEmbeddingCUDAKernel : public framework::OpKernel<T> {
N,
start_idx,
end_idx,
- limit);
+ limit,
+ vocab_size);
} else if (index_type == framework::proto::VarType::INT64) {
CEmbedding<T, int64_t>
......@@ -131,7 +138,8 @@ class CEmbeddingCUDAKernel : public framework::OpKernel<T> {
N,
start_idx,
end_idx,
- limit);
+ limit,
+ vocab_size);
} else {
PADDLE_THROW(platform::errors::Unavailable(
"GPU c_embedding ids only support int32 or int64."));
......
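For readers skimming the kernel, here is a NumPy reference of the intended semantics (a sketch of this PR's behavior, not Paddle API): each rank holds rows [start_idx, start_idx + N) of the global table, ids outside the shard produce zero rows (later summed back by the mp allreduce), and the new vocab_size attribute rejects ids outside the full vocabulary instead of reading out of bounds.

import numpy as np

def c_embedding_ref(table, ids, start_idx, vocab_size=-1):
    # table: this rank's shard with N rows; ids: int indices of any shape
    N, D = table.shape
    end_idx = start_idx + N
    flat = ids.reshape(-1)
    out = np.zeros((flat.size, D), dtype=table.dtype)
    for i, idx in enumerate(flat):
        # new check from this PR: id must lie inside the full vocabulary
        assert idx >= 0 and (vocab_size < 0 or idx < vocab_size), (
            f"id {idx} out of bounds for vocab_size {vocab_size}"
        )
        if start_idx <= idx < end_idx:
            out[i] = table[idx - start_idx]  # local row
        # else: leave zeros; the mp allreduce restores the full embedding
    return out.reshape(*ids.shape, D)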
......@@ -345,6 +345,7 @@ from .autograd import set_grad_enabled # noqa: F401
from .autograd import is_grad_enabled # noqa: F401
from .framework import save # noqa: F401
from .framework import load # noqa: F401
from .framework import async_save, clear_async_save_task_queue # noqa: F401
from .distributed import DataParallel # noqa: F401
from .framework import set_default_dtype # noqa: F401
......
......@@ -124,6 +124,7 @@ class VocabParallelEmbedding(paddle.nn.Layer):
self._size = [per_part_size, embedding_dim]
self._weight_attr = weight_attr
self._name = name
self.num_embeddings = num_embeddings
if self.is_mp and paddle.in_dynamic_mode():
with get_rng_state_tracker().rng_state():
......@@ -151,6 +152,7 @@ class VocabParallelEmbedding(paddle.nn.Layer):
self.weight,
x,
start_index=self.vocab_start_index,
vocab_size=self.num_embeddings,
name=self._name,
)
output = mp_ops._mp_allreduce(
......
......@@ -295,7 +295,7 @@ def _mp_allreduce(
return out
- def _c_lookup_table(table, index, start_index=0, name=None):
+ def _c_lookup_table(table, index, start_index=0, vocab_size=-1, name=None):
"""
Lookup table according to index.
......@@ -311,7 +311,7 @@ def _c_lookup_table(table, index, start_index=0, name=None):
"""
if in_dygraph_mode():
return _legacy_C_ops.c_embedding(
table, index, "start_index", start_index
table, index, "start_index", start_index, "vocab_size", vocab_size
)
else:
op_type = 'c_embedding'
......@@ -323,7 +323,7 @@ def _c_lookup_table(table, index, start_index=0, name=None):
type='c_embedding',
inputs={'Ids': index, 'W': table},
outputs={'Out': tmp},
attrs={"start_index": start_index},
attrs={"start_index": start_index, "vocab_size": vocab_size},
)
return tmp
......@@ -655,7 +655,11 @@ def _parallel_embedding(
main_block.vars[weight.name].is_distributed = True
output_parallel = _c_lookup_table(
- weight, x, start_index=vocab_start_index, name=name
+ weight,
+ x,
+ start_index=vocab_start_index,
+ vocab_size=origin_size[0],
+ name=name,
)
out = _mp_allreduce(
output_parallel,
......
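The plumbing above simply threads the full vocabulary size down to the kernel. The shard arithmetic it interacts with, shown with hypothetical sizes (a sketch, not Paddle API):

num_embeddings, world_size, rank = 10, 2, 1   # hypothetical sizes
per_part_size = num_embeddings // world_size  # 5 rows per rank
vocab_start_index = rank * per_part_size      # rank 1 owns ids [5, 10)

for i in (3, 7, 12):
    valid = 0 <= i < num_embeddings           # the new vocab_size check
    local = vocab_start_index <= i < vocab_start_index + per_part_size
    print(i, "valid" if valid else "OUT OF BOUNDS",
          "local" if local else "remote -> zeros")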
......@@ -24,10 +24,8 @@ from paddle.fluid.dygraph import base as imperative_base
from ...utils.log_util import logger
- g_shard_use_reduce = int(os.environ.get("FLAGS_shard_use_reduce", 0))
- logger.info(f"g_shard_use_reduce {g_shard_use_reduce}")
- g_shard_norm_align_dp = int(os.environ.get("FLAGS_shard_norm_align_dp", 1))
- logger.info(f"g_shard_norm_align_dp {g_shard_norm_align_dp}")
+ g_shard_use_reduce = int(os.environ.get("FLAGS_shard_use_reduce", 1))
+ g_shard_norm_align_dp = int(os.environ.get("FLAGS_shard_norm_align_dp", 0))
if g_shard_norm_align_dp:
assert (
......
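These module-level flags are read once at import time, so restoring the old pre-#55529 behavior has to happen before the distributed modules are imported. A sketch:

import os

# Opt back into the previous defaults (all-reduce instead of reduce, and
# grad-norm computation aligned with pure dp) -- set these before any
# paddle.distributed import:
os.environ["FLAGS_shard_use_reduce"] = "0"
os.environ["FLAGS_shard_norm_align_dp"] = "1"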
......@@ -41,8 +41,7 @@ from ...utils.mix_precision_utils import MixPrecisionOptimizer
__all__ = []
- g_shard_norm_align_dp = int(os.environ.get("FLAGS_shard_norm_align_dp", 1))
- logger.info(f"g_shard_norm_align_dp {g_shard_norm_align_dp}")
+ g_shard_norm_align_dp = int(os.environ.get("FLAGS_shard_norm_align_dp", 0))
class HybridParallelClipGrad:
......
......@@ -12,6 +12,8 @@
# See the License for the specific language governing permissions and
import os
import sys
from collections import defaultdict
import paddle
from paddle import framework
......@@ -31,8 +33,7 @@ from .pp_utils.utils import HOOK_ACTION, FusedCommBuffer, assign_group_by_size
__all__ = []
- g_shard_use_reduce = int(os.environ.get("FLAGS_shard_use_reduce", 0))
- logger.info(f"g_shard_use_reduce {g_shard_use_reduce}")
+ g_shard_use_reduce = int(os.environ.get("FLAGS_shard_use_reduce", 1))
# assume only the first stage and the last stage need data, and data consumption is ordered;
......@@ -182,7 +183,7 @@ class PipelineParallel(MetaParallelBase):
self._dp_comm_overlap and self._sharding_comm_overlap
), "Cannot use dp pp overlap and sharding pp overlap at the same time."
- self._comm_buffers = []
+ self._chunk_2_comm_buffers = defaultdict(list)
self._comm_overlap = (
self._dp_comm_overlap or self._sharding_comm_overlap
)
......@@ -256,7 +257,9 @@ class PipelineParallel(MetaParallelBase):
return fused_allreduce
- def register_allreduce_overlap_hook(self, model, comm_group, acc_steps, dp):
+ def register_allreduce_overlap_hook(
+     self, model, comm_group, acc_steps, dp, group_size=128 * 1024 * 1024
+ ):
if model.get_num_virtual_stages() > 1:
models = model.get_model_chunks()
else:
......@@ -273,7 +276,7 @@ class PipelineParallel(MetaParallelBase):
else HOOK_ACTION.REDUCE
)
- for model in models:
+ for chunk_idx, model in enumerate(models):
# For virtual pipeline. Will separate parameters in different chunk into
# different groups to get the best performance.
......@@ -302,12 +305,12 @@ class PipelineParallel(MetaParallelBase):
if not dp:
# parse the relative dst rank to absolute dst rank for sharding
dst = comm_group.ranks[dst]
- var_groups = assign_group_by_size(parameter_list)
+ var_groups = assign_group_by_size(parameter_list, group_size)
for group_idx, parameters in var_groups.items():
buffer = FusedCommBuffer(
group_idx, parameters, comm_group, acc_steps, act, dst
)
- self._comm_buffers.append(buffer)
+ self._chunk_2_comm_buffers[chunk_idx].append(buffer)
for param in parameters:
param._register_backward_hook(
self.bw_hook_func(buffer, param)
......@@ -403,9 +406,12 @@ class PipelineParallel(MetaParallelBase):
p2p.send_backward(input_tensor_grad, self.is_pipeline_first_stage())
if self._comm_overlap:
- assert len(self._comm_buffers) > 0
- for buffer in self._comm_buffers:
-     buffer.scale_and_split_grads()
+ assert (
+     len(self._chunk_2_comm_buffers) > 0
+ ), "comm buffers should be created"
+ for _, buffers in self._chunk_2_comm_buffers.items():
+     for buffer in buffers:
+         buffer.scale_and_split_grads()
if self._enable_timer:
self.timers("allreduce_shared_weight_gradients").start()
......@@ -446,7 +452,7 @@ class PipelineParallel(MetaParallelBase):
self._layers.train()
- if self._sharding_comm_overlap and len(self._comm_buffers) == 0:
+ if self._sharding_comm_overlap and len(self._chunk_2_comm_buffers) == 0:
self.register_allreduce_overlap_hook(
self._layers, self.sharding_group, self.accumulate_steps, False
)
......@@ -767,6 +773,40 @@ class PipelineParallelWithInterleave(PipelineParallel):
return output_tensor
def _overlap_comm_grads(self):
if self._comm_overlap:
self._backward_step_count += 1
sync_step = self._backward_step_count - self.stage_id
if sync_step > 0 and sync_step % self.accumulate_steps == 0:
chunk_idx = self._virtual_pp_world_size - (
sync_step // self.accumulate_steps
)
for buffer in self._chunk_2_comm_buffers[chunk_idx]:
buffer.comm_grads()
if self.stage_id != 0:
if (
self._backward_step_count
== self.accumulate_steps * self._virtual_pp_world_size
):
for buffer in self._chunk_2_comm_buffers[0]:
buffer.comm_grads()
def _sync_overlap_grads(self):
if self._comm_overlap:
assert (
    self._backward_step_count
    == self.accumulate_steps * self._virtual_pp_world_size
), (
    "backward step count should equal accumulate steps * virtual pp "
    "world size, but got {}, expected {}".format(
        self._backward_step_count,
        self.accumulate_steps * self._virtual_pp_world_size,
    )
)
for _, buffers in self._chunk_2_comm_buffers.items():
for buffer in buffers:
buffer.scale_and_split_grads()
def _backward_step_helper(self, micro_step):
virtual_pp_rank = self._get_virtual_pp_rank(micro_step, forward=False)
self.set_virtual_pipeline_rank(virtual_pp_rank)
......@@ -787,8 +827,24 @@ class PipelineParallelWithInterleave(PipelineParallel):
input_tensor, output_tensor, output_tensor_grad
)
self._overlap_comm_grads()
return input_tensor_grad
def bw_hook_func(self, buffer, param):
# For the interleaved pipeline, gradients are added to the buffer without
# triggering communication here; the scheduler launches communication at
# appropriate points to avoid conflicts between dp communication and pp scheduling.
@paddle.autograd.no_grad()
def fused_allreduce(*_):
buffer.add_grad(param, use_comm=False)
return fused_allreduce
def register_allreduce_overlap_hook(self, model, comm_group, acc_steps, dp):
super().register_allreduce_overlap_hook(
model, comm_group, acc_steps, dp, group_size=sys.maxsize
)
def forward_backward_pipeline(
self, data, scaler, forward_only=False, compute_loss=True
):
......@@ -806,6 +862,9 @@ class PipelineParallelWithInterleave(PipelineParallel):
self.micro_batch_id = 0
self._forward_only = forward_only
# store the number of backward steps
self._backward_step_count = 0
# init some data buffers for interleave scheduler
self.input_tensors = [[] for _ in range(self.num_model_chunks)]
self.output_tensors = [[] for _ in range(self.num_model_chunks)]
......@@ -1012,10 +1071,7 @@ class PipelineParallelWithInterleave(PipelineParallel):
)
)
- if self._comm_overlap:
-     assert len(self._comm_buffers) > 0
-     for buffer in self._comm_buffers:
-         buffer.scale_and_split_grads()
+ self._sync_overlap_grads()
if self._enable_timer:
self.timers("allreduce_shared_weight_gradients").start()
......
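The chunk-index arithmetic in _overlap_comm_grads is easiest to check with concrete numbers. A standalone replay of the schedule (a sketch with hypothetical sizes, not Paddle API):

def replay_comm_schedule(accumulate_steps, virtual_pp_world_size, stage_id):
    # returns (backward_step_count, chunk_idx) pairs at which a chunk's
    # comm buffers launch their reduce/all-reduce
    fired = []
    total = accumulate_steps * virtual_pp_world_size
    for backward_step_count in range(1, total + 1):
        sync_step = backward_step_count - stage_id
        if sync_step > 0 and sync_step % accumulate_steps == 0:
            chunk_idx = virtual_pp_world_size - sync_step // accumulate_steps
            fired.append((backward_step_count, chunk_idx))
        # non-zero stages lag by stage_id, so chunk 0 is flushed explicitly
        # once the final backward step has run
        if stage_id != 0 and backward_step_count == total:
            fired.append((backward_step_count, 0))
    return fired

print(replay_comm_schedule(4, 2, 0))  # [(4, 1), (8, 0)]
print(replay_comm_schedule(4, 2, 1))  # [(5, 1), (8, 0)]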
......@@ -218,7 +218,7 @@ class FusedCommBuffer:
and len(self._params_step_dict) == 0
)
- def add_grad(self, param):
+ def add_grad(self, param, use_comm=True):
assert param.name in self._params_step_dict
current_ptr = (
param.main_grad.data_ptr()
......@@ -239,12 +239,17 @@ class FusedCommBuffer:
self._params_checked_in += 1
self._params_step_dict.pop(param.name)
- if self._all_params_checked_in:
-     self._comm_grads()
+ if self._all_params_checked_in and use_comm:
+     self.comm_grads()
@imperative_base.no_grad
- def _comm_grads(self):
-     assert self._all_params_checked_in
+ def comm_grads(self):
+     assert self._all_params_checked_in, (
+         "Not all params checked in. "
+         "Parameter number: {}, Check-in number: {}".format(
+             len(self._params), self._params_checked_in
+         )
+     )
if self._act == HOOK_ACTION.ALL_REDUCE:
task = paddle.distributed.all_reduce(
......@@ -263,9 +268,8 @@ class FusedCommBuffer:
@imperative_base.no_grad
def scale_and_split_grads(self):
- assert self._task is not None
+ assert self._task is not None, "Task is not initialized."
self._task.wait()
scale_factor = 1.0 / self._comm_group.nranks
self.grad_storage.scale_(scale_factor)
......
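The contract behind add_grad(use_comm=...) and the now-public comm_grads: backward hooks check gradients in; under the interleaved schedule the buffer only records them and the scheduler launches communication itself at a pipeline-safe point. A minimal stand-in illustrating the pattern (hypothetical, not the real FusedCommBuffer):

class TinyCommBuffer:
    def __init__(self, params):
        self._pending = set(params)

    def add_grad(self, param, use_comm=True):
        # hooks check grads in; with use_comm=False (interleave case) the
        # buffer only records them and waits for an explicit comm_grads()
        self._pending.discard(param)
        if not self._pending and use_comm:
            self.comm_grads()

    def comm_grads(self):
        assert not self._pending, "Not all params checked in."
        print("launch fused reduce/all-reduce")

buf = TinyCommBuffer({"w", "b"})
buf.add_grad("w", use_comm=False)
buf.add_grad("b", use_comm=False)  # buffered, but no comm yet
buf.comm_grads()                   # the scheduler triggers it later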
......@@ -1049,6 +1049,88 @@ class TestSaveLoad(unittest.TestCase):
)
class TestAsyncSaveLoad(unittest.TestCase):
def setUp(self):
# enable dygraph mode
paddle.disable_static()
# config seed
paddle.seed(SEED)
paddle.framework.random._manual_program_seed(SEED)
self.temp_dir = tempfile.TemporaryDirectory()
def tearDown(self):
self.temp_dir.cleanup()
def build_and_train_model(self):
# create network
layer = LinearNet()
loss_fn = nn.CrossEntropyLoss()
adam = opt.Adam(learning_rate=0.001, parameters=layer.parameters())
# create data loader
# TODO: using the new DataLoader causes an unknown timeout on Windows; replace it
loader = random_batch_reader()
# train
train(layer, loader, loss_fn, adam)
return layer, adam
def check_load_state_dict(self, orig_dict, load_dict):
for var_name, value in orig_dict.items():
load_value = (
load_dict[var_name].numpy()
if hasattr(load_dict[var_name], 'numpy')
else np.array(load_dict[var_name])
)
np.testing.assert_array_equal(value.numpy(), load_value)
def test_async_save_load(self):
layer, opt = self.build_and_train_model()
# save
layer_save_path = os.path.join(
self.temp_dir.name, "test_paddle_async_save_load.linear.pdparams"
)
opt_save_path = os.path.join(
self.temp_dir.name, "test_paddle_async_save_load.linear.pdopt"
)
layer_state_dict = layer.state_dict()
opt_state_dict = opt.state_dict()
paddle.async_save(
layer_state_dict, layer_save_path, sync_other_task=True
)
paddle.async_save(opt_state_dict, opt_save_path)
paddle.clear_async_save_task_queue()
# load
load_layer_state_dict = paddle.load(layer_save_path)
load_opt_state_dict = paddle.load(opt_save_path)
self.check_load_state_dict(layer_state_dict, load_layer_state_dict)
self.check_load_state_dict(opt_state_dict, load_opt_state_dict)
# test assertion on illegal object
some_tuple_obj = (1, 2, 3)
tuple_save_path = os.path.join(
self.temp_dir.name, "test_paddle_async_save_load.tuple.pdparams"
)
with self.assertRaises(TypeError):
paddle.async_save(some_tuple_obj, tuple_save_path)
# test assertion on static graph
paddle.enable_static()
static_save_path = os.path.join(
self.temp_dir.name,
"static_mode_test/test_paddle_async_save_load.linear.pdparams",
)
with self.assertRaises(ValueError):
paddle.async_save(layer_state_dict, static_save_path)
class TestSaveLoadProgram(unittest.TestCase):
def test_save_load_program(self):
paddle.enable_static()
......
......@@ -33,6 +33,7 @@ from ..fluid.dygraph.base import no_grad_ as no_grad # noqa: F401
from ..fluid.dygraph.base import grad # noqa: F401
from .io import save # noqa: F401
from .io import load # noqa: F401
from .io import async_save, clear_async_save_task_queue # noqa: F401
from .io_utils import _open_file_buffer # noqa: F401
from .io_utils import is_parameter # noqa: F401
......
......@@ -17,6 +17,7 @@ import copyreg
import os
import pickle
import sys
import threading
import warnings
from collections.abc import Iterable
......@@ -48,6 +49,81 @@ from .io_utils import (
)
__all__ = []
async_save_queue = []
def clear_async_save_task_queue():
'''
Wait until all pending async save tasks are done.
'''
while len(async_save_queue) > 0:
task = async_save_queue.pop()
if task and task.is_alive():
task.join()
def async_save(obj, path, protocol=4, sync_other_task=False, **configs):
'''
Asynchronous version of paddle.save.
Note:
    Currently only supports dygraph mode.
Note:
    Any argument passed through **configs will be overridden by the default settings.
Args:
    obj(Object) : The object to be saved.
    path(str|BytesIO) : The path/buffer where the object is saved.
        If saved in the current directory, the input path string is used as the file name.
    protocol(int, optional): The pickle protocol version; must be greater than 1 and less than 5.
        Default: 4.
    sync_other_task(bool) : Whether to wait for all previously queued async save tasks to finish before this one is enqueued.
    **configs(dict, optional): Arguments compatible with paddle.save; they will be overridden by the default settings.
Examples:
    .. code-block:: python
        :name: code-example-1
        import paddle
        emb = paddle.nn.Embedding(10, 10)
        layer_state_dict = emb.state_dict()
        # call paddle.async_save in the same style as paddle.save
        paddle.async_save(layer_state_dict, "emb.pdparams")
        for i in range(10):
            # do some calculations here
            pass
        # wait until all pending async_save tasks are done
        paddle.clear_async_save_task_queue()
'''
if not _non_static_mode():
raise ValueError(
"async_save currently is not supported in static mode."
)
if len(configs) > 0:
warnings.warn(
"configs are not supported in async mode, will be overided by default settings."
)
# TODO: make this part async
def move_state_dict_to_cpu(sd):
    # Despite the name, CUDA tensors are placed in pinned (page-locked)
    # host memory, which keeps the device-to-host copy fast; on CPU-only
    # builds tensors are left untouched.
    for k, v in sd.items():
if isinstance(v, dict):
move_state_dict_to_cpu(v)
elif isinstance(v, core.eager.Tensor):
sd[k] = v.pin_memory() if core.is_compiled_with_cuda() else v
return
if isinstance(obj, dict):
move_state_dict_to_cpu(obj)
elif isinstance(obj, core.eager.Tensor):
obj = obj.pin_memory() if core.is_compiled_with_cuda() else obj
else:
# other types are currently not supported
raise TypeError(
f"currently async_save does not support this type: {type(obj)}"
)
if sync_other_task:
clear_async_save_task_queue()
t = threading.Thread(target=save, args=(obj, path, protocol))
t.start()
async_save_queue.append(t)
def _build_saved_state_dict(state_dict):
......