Unverified commit 4569ae13 authored by LiYuRio, committed by GitHub

Implement reshard from s to r with same process_mesh (#56039)

Parent bfa65993
@@ -18,16 +18,20 @@
#include "paddle/fluid/framework/op_desc.h"
#include "paddle/fluid/framework/var_desc.h"
#include "paddle/fluid/pybind/auto_parallel_py.h"
#include "paddle/phi/core/device_context.h"
#include "paddle/phi/core/distributed/auto_parallel/device_mesh.h"
#include "paddle/phi/core/distributed/auto_parallel/dist_attr.h"
#include "paddle/phi/core/distributed/auto_parallel/dist_mapper.h"
#include "paddle/phi/core/distributed/auto_parallel/process_mesh.h"
#include "paddle/utils/optional.h"
#include "paddle/utils/pybind.h"
#include "paddle/fluid/distributed/auto_parallel/spmd_rules/common.h"
#include "paddle/fluid/distributed/auto_parallel/spmd_rules/dist_tensor_spec.h"
#ifdef PADDLE_WITH_DISTRIBUTE
#include "paddle/phi/core/distributed/auto_parallel/dist_tensor.h"
#include "paddle/phi/core/distributed/auto_parallel/r_to_s_reshard_function.h"
#include "paddle/phi/core/distributed/auto_parallel/s_to_r_reshard_function.h"
#endif
namespace py = pybind11;
@@ -111,7 +115,43 @@ static inline void reset_operator_dist_attr(OperatorDistAttr *dist_attr) {
void BindAutoParallel(py::module *m) {
#ifdef PADDLE_WITH_DISTRIBUTE
py::class_<phi::distributed::RToSReshardFunction>(*m, "RToSReshardFunction")
auto ReshardFunction =
py::class_<phi::distributed::ReshardFunction>(*m, "ReshardFunction")
.def(
"is_suitable",
[](phi::distributed::ReshardFunction &self,
py::handle py_tensor,
const std::shared_ptr<phi::distributed::TensorDistAttr>
&dist_attr) {
auto tensor = CastPyArg2Tensor(py_tensor.ptr(), 0);
auto p_dist =
std::dynamic_pointer_cast<phi::distributed::DistTensor>(
tensor.impl());
return self.IsSuitable(*p_dist, dist_attr);
},
py::call_guard<py::gil_scoped_release>())
.def(
"eval",
[](phi::distributed::ReshardFunction &self,
phi::DeviceContext *dev_ctx,
py::handle py_tensor,
const std::shared_ptr<phi::distributed::TensorDistAttr>
&dist_attr) {
auto tensor = CastPyArg2Tensor(py_tensor.ptr(), 0);
auto p_dist =
std::dynamic_pointer_cast<phi::distributed::DistTensor>(
tensor.impl());
auto res_dist = self.Eval(dev_ctx, *p_dist, dist_attr);
return paddle::Tensor(res_dist);
},
py::call_guard<py::gil_scoped_release>());
py::class_<phi::distributed::RToSReshardFunction>(
*m, "RToSReshardFunction", ReshardFunction)
.def(py::init<>());
py::class_<phi::distributed::SToRReshardFunction>(
*m, "SToRReshardFunction", ReshardFunction)
.def(py::init<>());
#endif
......
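For reference, the new bindings are driven from Python through paddle.fluid.core. A minimal sketch mirroring the unit test added in this commit (assumptions: two ranks launched via paddle.distributed.launch, GPU build):

import paddle
import paddle.distributed as dist
from paddle.fluid import core

mesh = dist.ProcessMesh([0, 1], dim_names=["x"])
in_attr = dist.DistAttr(mesh=mesh, sharding_specs=["x", None])    # axis 0 sharded over mesh dim "x"
out_attr = dist.DistAttr(mesh=mesh, sharding_specs=[None, None])  # replicated

input_tensor = dist.shard_tensor(paddle.ones([10, 20]), dist_attr=in_attr)
dev_ctx = core.DeviceContext.create(paddle.CUDAPlace(dist.get_rank()))

reshard_func = core.SToRReshardFunction()
if reshard_func.is_suitable(input_tensor, out_attr):
    out = reshard_func.eval(dev_ctx, input_tensor, out_attr)  # every rank now holds the replicated tensor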
@@ -3,8 +3,15 @@ proto_library(auto_parallel_proto SRCS auto_parallel.proto)
set(DISTRIBUTED_SRCS "")
if(WITH_DISTRIBUTE)
list(APPEND DISTRIBUTED_SRCS dist_tensor.cc reshard_function.cc
reshard_split_functor.cc r_to_s_reshard_function.cc)
list(
APPEND
DISTRIBUTED_SRCS
dist_tensor.cc
reshard_function.cc
reshard_split_functor.cc
reshard_all_gather_functor.cc
r_to_s_reshard_function.cc
s_to_r_reshard_function.cc)
endif()
collect_srcs(
......
@@ -48,7 +48,7 @@ bool RToSReshardFunction::IsSuitable(
}
std::shared_ptr<DistTensor> RToSReshardFunction::Eval(
const phi::DeviceContext& dev_ctx,
phi::DeviceContext* dev_ctx,
const DistTensor& in,
const std::shared_ptr<TensorDistAttr>& out_dist_attr) {
const auto& out_dims_mapping = out_dist_attr->dims_mapping();
@@ -85,7 +85,7 @@ std::shared_ptr<DistTensor> RToSReshardFunction::Eval(
num_of_process, in.dims()[split_axis] / num_of_process));
std::vector<DenseTensor> split_out_vec = ReshardSplitFunctor(
dev_ctx, in_physical_tensor_cur_rank, sections, split_axis);
*dev_ctx, in_physical_tensor_cur_rank, sections, split_axis);
VLOG(3) << "The current process will remain the idx "
<< coord_in_mesh[mesh_axis] << " piece of tensor";
......
@@ -29,7 +29,7 @@ class RToSReshardFunction final : public ReshardFunction {
const std::shared_ptr<TensorDistAttr>& out_dist_attr) override;
std::shared_ptr<DistTensor> Eval(
const DeviceContext& dev_ctx,
DeviceContext* dev_ctx,
const DistTensor& in,
const std::shared_ptr<TensorDistAttr>& out_dist_attr) override;
};
......
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/core/distributed/auto_parallel/reshard_all_gather_functor.h"
#include "paddle/phi/backends/all_context.h"
#include "paddle/phi/core/distributed/auto_parallel/reshard_utils.h"
#include "paddle/phi/core/visit_type.h"
#include "paddle/phi/infermeta/unary.h"
#include "paddle/phi/kernels/all_gather_kernel.h"
namespace phi {
namespace distributed {
DenseTensor ReshardAllGatherFunctor(DeviceContext* dev_ctx,
const DenseTensor& input,
const std::vector<int64_t>& process_ids) {
DenseTensor out;
int64_t world_size = process_ids.size();
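// Create (or fetch the cached) communicator for the participating ranks and
// attach it to the device context so the AllGather call below can use it.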
auto* comm_context = CreateOrGetCommContext(*dev_ctx, process_ids);
dev_ctx->SetCommContext(comm_context);
if (phi::CPUContext::classof(dev_ctx)) {
PD_VISIT_FLOATING_AND_INTEGRAL_TYPES(
input.dtype(), "AllGather", ([&] {
AllGather<data_t>(static_cast<const CPUContext&>(*dev_ctx),
input,
world_size,
&out);
}));
return out;
}
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
if (phi::GPUContext::classof(dev_ctx)) {
PD_VISIT_FLOATING_AND_INTEGRAL_TYPES(
input.dtype(), "AllGather", ([&] {
AllGather<data_t>(static_cast<const GPUContext&>(*dev_ctx),
input,
world_size,
&out);
}));
return out;
}
#endif
PADDLE_THROW(phi::errors::Unimplemented(
"The all_gather in reshard only supported on CPU and GPU for now."));
}
} // namespace distributed
} // namespace phi
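For intuition, the functor returns the concatenation of every participating rank's local tensor along axis 0. A rough numpy analogue (illustration only, not Paddle code; shapes are hypothetical):

import numpy as np

# Two hypothetical ranks, each contributing a (5, 20) local shard.
local_shards = [np.full((5, 20), rank, dtype=np.float32) for rank in range(2)]
gathered = np.concatenate(local_shards, axis=0)  # each rank receives this (10, 20) tensor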
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <cstdint>
#include <vector>
namespace phi {
class DenseTensor;
class DeviceContext;
namespace distributed {
DenseTensor ReshardAllGatherFunctor(DeviceContext* dev_ctx,
const DenseTensor& input,
const std::vector<int64_t>& process_ids);
} // namespace distributed
} // namespace phi
@@ -36,7 +36,7 @@ class ReshardFunction {
const std::shared_ptr<TensorDistAttr>& out_dist_attr) = 0;
virtual std::shared_ptr<DistTensor> Eval(
const DeviceContext& dev_ctx,
DeviceContext* dev_ctx,
const DistTensor& in,
const std::shared_ptr<TensorDistAttr>& out_dist_attr) = 0;
};
......
@@ -16,7 +16,12 @@
#include <cstdlib>
#include "glog/logging.h"
#include "paddle/phi/backends/all_context.h"
#include "paddle/phi/core/device_context.h"
#include "paddle/phi/core/distributed/auto_parallel/process_mesh.h"
#include "paddle/phi/core/distributed/auto_parallel/utils.h"
#include "paddle/phi/core/distributed/comm_context_manager.h"
#include "paddle/phi/core/distributed/store/tcp_store.h"
namespace phi {
namespace distributed {
@@ -109,6 +114,21 @@ std::string GetMasterEndpoint() {
return master_endpoint;
}
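// Build a deterministic cache key from the participating ranks,
// e.g. process_ids {0, 1} -> "ReshardGroup/0/1".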
std::string GenUniqueCommKey(const std::vector<int64_t>& process_ids) {
std::string unique_comm_key = "ReshardGroup";
for (const auto& id : process_ids) {
unique_comm_key += "/" + std::to_string(id);
}
return unique_comm_key;
}
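// Position of the current global rank inside the participating process_ids;
// this is the rank passed to the newly created comm context.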
int64_t GetLocalRankInParticipate(const std::vector<int64_t>& process_ids) {
int64_t cur_global_rank = GetCurGlobalRank();
auto iter =
std::find(process_ids.begin(), process_ids.end(), cur_global_rank);
return iter - process_ids.begin();
}
} // namespace
std::string GetMasterAddr() {
@@ -133,5 +153,41 @@ std::shared_ptr<TCPStore> CreateOrGetGlobalTCPStore() {
return store;
}
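// See the comment in reshard_utils.h: one comm context is created per unique
// set of process_ids and cached in CommContextManager for reuse.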
CommContext* CreateOrGetCommContext(const DeviceContext& dev_ctx,
const std::vector<int64_t>& process_ids) {
std::string unique_comm_key = GenUniqueCommKey(process_ids);
if (!CommContextManager::GetInstance().Has(unique_comm_key)) {
int64_t world_size = process_ids.size();
int64_t rank = GetLocalRankInParticipate(process_ids);
VLOG(3) << "local world size: " << world_size << " local rank: " << rank;
auto store = CreateOrGetGlobalTCPStore();
if (phi::CPUContext::classof(&dev_ctx)) {
#if defined(PADDLE_WITH_GLOO)
CommContextManager::CreateGlooCommContext(
store, unique_comm_key, rank, world_size);
#else
PADDLE_THROW(phi::errors::Unimplemented(
"Cannot use gloo on CPU, please turn PADDLE_WITH_GLOO flag on."));
#endif
} else {
#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
if (phi::GPUContext::classof(&dev_ctx)) {
CommContextManager::CreateNCCLCommContext(
store, unique_comm_key, rank, world_size);
}
#else
PADDLE_THROW(phi::errors::Unimplemented(
"CommContext is only supported on CPU and GPU for now, other devices "
"will be supported later."));
#endif
}
}
auto* comm_context = CommContextManager::GetInstance().Get(unique_comm_key);
return comm_context;
}
} // namespace distributed
} // namespace phi
@@ -23,7 +23,11 @@
#include "paddle/phi/core/distributed/store/tcp_store.h"
namespace phi {
class DeviceContext;
namespace distributed {
class CommContext;
namespace auto_parallel {
class ProcessMesh;
@@ -48,6 +52,13 @@ std::vector<int64_t> GetCurRankCoordInMesh(const ProcessMesh& process_mesh);
std::map<int64_t, int64_t> GetSplitAxisWithDimsMapping(
const std::vector<int64_t>& dims_mapping);
// Create a comm context for the input process_ids. Once the new comm context
// is created, it is cached in the global instance and fetched from that cache
// on later calls. If the input dev_ctx is GPU, an NCCL comm context is
// created; if the input dev_ctx is CPU, a Gloo comm context is created.
CommContext* CreateOrGetCommContext(const DeviceContext& dev_ctx,
const std::vector<int64_t>& process_ids);
int64_t GetCurGlobalRank();
std::string GetMasterAddr();
......
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/core/distributed/auto_parallel/s_to_r_reshard_function.h"
#include "glog/logging.h"
#include "paddle/phi/core/distributed/auto_parallel/dist_attr.h"
#include "paddle/phi/core/distributed/auto_parallel/dist_tensor.h"
#include "paddle/phi/core/distributed/auto_parallel/reshard_all_gather_functor.h"
#include "paddle/phi/core/distributed/auto_parallel/reshard_utils.h"
#include "paddle/phi/core/distributed/comm_context_manager.h"
#include "paddle/phi/core/distributed/store/tcp_store.h"
namespace phi {
namespace distributed {
bool SToRReshardFunction::IsSuitable(
const DistTensor& in,
const std::shared_ptr<TensorDistAttr>& out_dist_attr) {
bool flag = true;
const auto& in_dist_attr = in.dist_attr();
const auto& in_dims_mapping = in_dist_attr->dims_mapping();
const auto& out_dims_mapping = out_dist_attr->dims_mapping();
flag &= IsDimsMappingShard(in_dims_mapping);
flag &= IsDimsMappingReplicated(out_dims_mapping);
const auto& in_process_mesh = in_dist_attr->process_mesh();
const auto& out_process_mesh = out_dist_attr->process_mesh();
flag &= (in_process_mesh.ndim() == 1);
flag &= (out_process_mesh.ndim() == 1);
flag &= (in_process_mesh == out_process_mesh);
return flag;
}
std::shared_ptr<DistTensor> SToRReshardFunction::Eval(
DeviceContext* dev_ctx,
const DistTensor& in,
const std::shared_ptr<TensorDistAttr>& out_dist_attr) {
// TODO(liyurui): Only transferring shard(0) to replicate is supported for now.
// Concat is needed when transferring shard(x) to replicate; it will be
// supported later.
const DenseTensor& in_physical_tensor_cur_rank = in.value();
const auto& in_dist_attr = in.dist_attr();
const auto& in_process_mesh = in_dist_attr->process_mesh();
const auto& in_process_ids = in_process_mesh.process_ids();
// Since the precondition ensures that out_process_ids equals in_process_ids,
// the participating process ids must be equal to either in_process_ids or
// out_process_ids.
DenseTensor out_all_gather = ReshardAllGatherFunctor(
dev_ctx, in_physical_tensor_cur_rank, in_process_ids);
return std::make_shared<DistTensor>(
std::make_shared<DenseTensor>(out_all_gather), out_dist_attr);
}
} // namespace distributed
} // namespace phi
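Concretely, with the 1-D mesh [0, 1] used in the tests added below, the input dims_mapping [0, -1] (axis 0 sharded over mesh dim "x") becomes the output dims_mapping [-1, -1] (replicated), and each rank's local shard is all-gathered so every rank ends up with the full tensor. A small sketch of the shape bookkeeping the test relies on (sizes are hypothetical):

world_size = 2
local_shape = [10, 20]   # per-rank shard, i.e. in.value() on each rank
replicated_shape = [local_shape[0] * world_size, local_shape[1]]  # [20, 20] after all_gather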
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/phi/core/distributed/auto_parallel/reshard_function.h"
namespace phi {
namespace distributed {
class SToRReshardFunction final : public ReshardFunction {
public:
SToRReshardFunction() = default;
~SToRReshardFunction() = default;
bool IsSuitable(
const DistTensor& in,
const std::shared_ptr<TensorDistAttr>& out_dist_attr) override;
std::shared_ptr<DistTensor> Eval(
DeviceContext* dev_ctx,
const DistTensor& in,
const std::shared_ptr<TensorDistAttr>& out_dist_attr) override;
};
} // namespace distributed
} // namespace phi
@@ -15,6 +15,7 @@
#pragma once
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/infermeta/unary.h"
namespace phi {
@@ -24,4 +25,16 @@ void AllGatherKernel(const Context& dev_ctx,
int nranks,
DenseTensor* out);
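// Convenience wrapper used by ReshardAllGatherFunctor: run AllGatherInferMeta
// to set the output's shape and dtype, then dispatch AllGatherKernel.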
template <typename T, typename Context>
void AllGather(const Context& dev_ctx,
const DenseTensor& x,
int nranks,
DenseTensor* out) {
MetaTensor out_meta(*out);
MetaTensor* out_meta_ptr = &out_meta;
AllGatherInferMeta(phi::MetaTensor(x), nranks, out_meta_ptr);
AllGatherKernel<T, Context>(dev_ctx, x, nranks, out);
}
} // namespace phi
@@ -61,5 +61,6 @@ PD_REGISTER_KERNEL(all_gather,
bool,
int8_t,
uint8_t,
int16_t,
int64_t,
phi::dtype::float16) {}
@@ -69,6 +69,7 @@ PD_REGISTER_KERNEL(all_gather,
int,
uint8_t,
int8_t,
int16_t,
int64_t,
bool,
phi::dtype::bfloat16,
@@ -83,6 +84,7 @@ PD_REGISTER_KERNEL(all_gather,
int,
uint8_t,
int8_t,
int16_t,
int64_t,
bool,
phi::dtype::float16) {}
......
@@ -77,6 +77,9 @@ if(WITH_DISTRIBUTE AND WITH_GPU)
py_test_modules(test_pass_quantization MODULES test_pass_quantization)
set_tests_properties(test_pass_quantization
PROPERTIES LABELS "RUN_TYPE=EXECLUSIVE" TIMEOUT 60)
py_test_modules(test_reshard_s_to_r MODULES test_reshard_s_to_r)
set_tests_properties(test_reshard_s_to_r
PROPERTIES LABELS "RUN_TYPE=EXECLUSIVE" TIMEOUT 100)
# End of unittests WITH multi cards and timeout
# NOTE(zyl): unittests WITH multi cards and WITHOUT timeout
......
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import numpy as np
import paddle
import paddle.distributed as dist
from paddle.fluid import core
class TestReshardSToR:
    def __init__(self):
        self._shape = eval(os.getenv("shape"))
        self._dtype = os.getenv("dtype")
        self._seeds = eval(os.getenv("seeds"))
        self._shard = eval(os.getenv("shard"))
        self._backend = os.getenv("backend")
        self._mesh = dist.ProcessMesh([0, 1], dim_names=["x"])

    def run_test_case(self):
        if self._backend == "cpu":
            paddle.set_device("cpu")
            place = paddle.CPUPlace()
        elif self._backend == "gpu":
            place = paddle.CUDAPlace(dist.get_rank())

        dev_ctx = core.DeviceContext.create(place)
        a = paddle.ones(self._shape)

        in_shard_specs = [None for i in range(len(self._shape))]
        in_shard_specs[self._shard] = "x"
        out_shard_specs = [None for i in range(len(self._shape))]

        dist_attr = dist.DistAttr(
            mesh=self._mesh, sharding_specs=in_shard_specs
        )
        out_dist_attr = dist.DistAttr(
            mesh=self._mesh, sharding_specs=out_shard_specs
        )

        input_tensor = dist.shard_tensor(a, dist_attr=dist_attr)
        reshard_func = core.SToRReshardFunction()
        assert reshard_func.is_suitable(input_tensor, out_dist_attr)

        out = reshard_func.eval(dev_ctx, input_tensor, out_dist_attr)
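        # After the s-to-r reshard every rank holds the replicated tensor, so the
        # sharded axis is world_size (= 2) times the per-rank local length.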
        out_shape = list(self._shape)
        out_shape[self._shard] = out_shape[self._shard] * 2
        assert np.equal(out.shape, out_shape).all()


if __name__ == '__main__':
    TestReshardSToR().run_test_case()
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import collective.test_communication_api_base as test_base
class TestReshardSToR(test_base.CommunicationTestDistBase):
    def setUp(self):
        super().setUp(num_of_devices=2, timeout=120)
        self._default_envs = {
            "shape": "(10, 20)",
            "dtype": "float32",
            "seeds": str(self._seeds),
            "shard": "0",
        }
        self._changeable_envs = {
            "backend": ["cpu", "gpu"],
        }

    def test_reshard_s_to_r(self):
        envs_list = test_base.gen_product_envs_list(
            self._default_envs, self._changeable_envs
        )
        for envs in envs_list:
            self.run_test_case(
                "reshard_s_to_r.py",
                user_defined_envs=envs,
            )


if __name__ == "__main__":
    unittest.main()
@@ -34,6 +34,7 @@ class CommunicationTestDistBase(unittest.TestCase):
    def run_test_case(self, script_file, user_defined_envs=None):
        runtime_envs = os.environ
        if user_defined_envs is not None:
            runtime_envs.update(user_defined_envs)
        runtime_envs["CUDA_VISIBLE_DEVICES"] = self._devices
        start_command = f"{self._python_interp} -u -m paddle.distributed.launch --log_dir {self._log_dir.name} --devices {self._devices} {script_file}"
......
@@ -114,7 +114,7 @@ TEST(reshard_r_to_s, r_to_s_same_placement_cpu_1d_mesh) {
RToSReshardFunction r_to_s_func;
std::shared_ptr<DistTensor> output =
r_to_s_func.Eval(*context, *input, out_dist_attr);
r_to_s_func.Eval(context, *input, out_dist_attr);
CHECK_EQ(r_to_s_func.IsSuitable(*input, out_dist_attr), true);
CHECK_EQ(output->numel(), 12);
@@ -136,7 +136,7 @@ TEST(reshard_r_to_s, r_to_s_same_placement_gpu_1d_mesh) {
std::shared_ptr<TensorDistAttr> out_dist_attr =
std::make_shared<TensorDistAttr>(tensor_shape);
std::vector<int64_t> out_dims_mapping = {0, -1};
std::vector<int64_t> out_dims_mapping = {0, -1, -1};
out_dist_attr->set_dims_mapping(out_dims_mapping);
out_dist_attr->set_process_mesh(mesh);
@@ -145,7 +145,7 @@ TEST(reshard_r_to_s, r_to_s_same_placement_gpu_1d_mesh) {
RToSReshardFunction r_to_s_func;
std::shared_ptr<DistTensor> output =
r_to_s_func.Eval(*context, *input, out_dist_attr);
r_to_s_func.Eval(context, *input, out_dist_attr);
CHECK_EQ(r_to_s_func.IsSuitable(*input, out_dist_attr), true);
CHECK_EQ(output->numel(), 32);
......