Unverified commit 5425ad7f, authored by LiYuRio, committed by GitHub

use macro instead of functor (#56726)

Parent 0bc369ef
@@ -28,8 +28,8 @@ limitations under the License. */
 #include <memory>
 #include <string>
-#include "paddle/phi/core/distributed/auto_parallel/reshard_utils.h"
 #include "paddle/phi/core/distributed/comm_context_manager.h"
+#include "paddle/phi/core/distributed/store/store_utils.h"
 #include "paddle/phi/core/distributed/store/tcp_store.h"
 namespace py = pybind11;
......
@@ -12,8 +12,5 @@ collect_srcs(
   dist_meta_tensor.cc
   inferspmd_utils.cc
   reshard_function.cc
-  reshard_split_functor.cc
-  reshard_concat_functor.cc
-  reshard_all_gather_functor.cc
   r_to_s_reshard_function.cc
   s_to_r_reshard_function.cc)
@@ -15,12 +15,10 @@
 #include "paddle/phi/core/distributed/auto_parallel/r_to_s_reshard_function.h"
 #include "glog/logging.h"
-#include "paddle/phi/api/lib/kernel_dispatch.h"
 #include "paddle/phi/core/distributed/auto_parallel/dist_attr.h"
 #include "paddle/phi/core/distributed/auto_parallel/dist_tensor.h"
-#include "paddle/phi/core/distributed/auto_parallel/reshard_split_functor.h"
 #include "paddle/phi/core/distributed/auto_parallel/reshard_utils.h"
-#include "paddle/phi/core/kernel_factory.h"
+#include "paddle/phi/kernels/split_kernel.h"
 namespace phi {
 namespace distributed {
@@ -73,17 +71,21 @@ void RToSReshardFunction::Eval(phi::DeviceContext* dev_ctx,
       BalancedSplit(in.dims()[split_axis], num_of_process);
   IntArray sections(split_num_vec);
-  std::vector<DenseTensor> split_out_vec = ReshardSplitFunctor(
-      *dev_ctx, in_physical_tensor_cur_rank, sections, split_axis);
+  std::vector<DenseTensor> split_out_vec;
+  auto dtype = in_physical_tensor_cur_rank.dtype();
+  RESHARD_FUNCTOR(dev_ctx,
+                  Split,
+                  dtype,
+                  in_physical_tensor_cur_rank,
+                  sections,
+                  split_axis,
+                  &split_out_vec);
   VLOG(3) << "The current process will remain the idx "
           << coord_in_mesh[mesh_axis] << " piece of tensor";
-  out_physical_tensor_cur_rank = split_out_vec[coord_in_mesh[mesh_axis]];
-  VLOG(3) << "The shape of physical tensor after split is "
-          << out_physical_tensor_cur_rank.dims();
-  set_dist_props(out, out_physical_tensor_cur_rank, in.dims(), out_dist_attr);
+  SetValue(out, split_out_vec[coord_in_mesh[mesh_axis]]);
+  SetDistProps(out, in.dims(), out_dist_attr);
 }
 REGISTER_RESHARD_FUNC(RToSReshardFunction);
......
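The RESHARD_FUNCTOR call above replaces the deleted ReshardSplitFunctor with a dtype-dispatched invocation of the Split kernel. Below is a minimal, self-contained sketch of that dispatch idiom; every name in it is a hypothetical stand-in rather than a Paddle API. A runtime dtype tag binds data_t to a concrete type and selects the kernel template instantiation, the way PD_VISIT_FLOATING_AND_INTEGRAL_TYPES does inside the macro.

#include <cstdint>
#include <iostream>
#include <stdexcept>
#include <vector>

enum class DataType { FLOAT32, INT64 };

// Stand-in for a typed kernel such as phi::Split<T>: split evenly into
// `pieces` chunks (sizes assumed divisible for brevity).
template <typename T>
std::vector<std::vector<T>> SplitEvenly(const std::vector<T>& in,
                                        int64_t pieces) {
  std::vector<std::vector<T>> out;
  const int64_t chunk = static_cast<int64_t>(in.size()) / pieces;
  for (int64_t i = 0; i < pieces; ++i) {
    out.emplace_back(in.begin() + i * chunk, in.begin() + (i + 1) * chunk);
  }
  return out;
}

// Stand-in for PD_VISIT_FLOATING_AND_INTEGRAL_TYPES: the lambda is pasted
// into a scope where `data_t` is an alias for the concrete element type.
#define VISIT_DTYPE(dtype, ...)                        \
  do {                                                 \
    switch (dtype) {                                   \
      case DataType::FLOAT32: {                        \
        using data_t = float;                          \
        __VA_ARGS__();                                 \
        break;                                         \
      }                                                \
      case DataType::INT64: {                          \
        using data_t = int64_t;                        \
        __VA_ARGS__();                                 \
        break;                                         \
      }                                                \
      default:                                         \
        throw std::runtime_error("unsupported dtype"); \
    }                                                  \
  } while (0)

int main() {
  std::vector<float> data = {1.f, 2.f, 3.f, 4.f};
  DataType dtype = DataType::FLOAT32;  // known only at runtime
  VISIT_DTYPE(dtype, ([&] {
                // data_t is float in this instantiation.
                std::vector<data_t> buf(data.begin(), data.end());
                std::cout << SplitEvenly<data_t>(buf, 2).size() << "\n";  // 2
              }));
}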
paddle/phi/core/distributed/auto_parallel/reshard_all_gather_functor.cc (deleted file)
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/core/distributed/auto_parallel/reshard_all_gather_functor.h"
#include "paddle/phi/backends/all_context.h"
#include "paddle/phi/core/distributed/auto_parallel/reshard_utils.h"
#include "paddle/phi/core/visit_type.h"
#include "paddle/phi/kernels/all_gather_kernel.h"
namespace phi {
namespace distributed {
DenseTensor ReshardAllGatherFunctor(DeviceContext* dev_ctx,
const DenseTensor& input,
const std::vector<int64_t>& process_ids) {
DenseTensor out;
int64_t world_size = process_ids.size();
auto* comm_context = CreateOrGetCommContext(*dev_ctx, process_ids);
dev_ctx->SetCommContext(comm_context);
if (phi::CPUContext::classof(dev_ctx)) {
PD_VISIT_FLOATING_AND_INTEGRAL_TYPES(
input.dtype(), "AllGather", ([&] {
AllGather<data_t>(static_cast<const CPUContext&>(*dev_ctx),
input,
world_size,
&out);
}));
return out;
}
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
if (phi::GPUContext::classof(dev_ctx)) {
PD_VISIT_FLOATING_AND_INTEGRAL_TYPES(
input.dtype(), "AllGather", ([&] {
AllGather<data_t>(static_cast<const GPUContext&>(*dev_ctx),
input,
world_size,
&out);
}));
return out;
}
#endif
#ifdef PADDLE_WITH_CUSTOM_DEVICE
if (phi::CustomContext::classof(dev_ctx)) {
PD_VISIT_FLOATING_AND_INTEGRAL_TYPES(
input.dtype(), "AllGather", ([&] {
AllGather<data_t>(static_cast<const CustomContext&>(*dev_ctx),
input,
world_size,
&out);
}));
return out;
}
#endif
PADDLE_THROW(phi::errors::Unimplemented(
"The all_gather in reshard only supported on CPU and GPU for now."));
}
} // namespace distributed
} // namespace phi
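This deleted functor, like the concat and split functors below, repeated the same backend ladder: test the runtime context type with classof, then visit the dtype. A compact sketch of the LLVM-style classof test follows, using hypothetical stand-in types rather than phi's:

#include <iostream>

// Hypothetical miniature of phi's context hierarchy.
struct DeviceContext {
  enum class Kind { CPU, GPU };
  explicit DeviceContext(Kind k) : kind(k) {}
  virtual ~DeviceContext() = default;
  Kind kind;
};

struct CPUContext : DeviceContext {
  CPUContext() : DeviceContext(Kind::CPU) {}
  // classof checks a kind tag instead of using dynamic_cast.
  static bool classof(const DeviceContext* ctx) {
    return ctx->kind == Kind::CPU;
  }
};

struct GPUContext : DeviceContext {
  GPUContext() : DeviceContext(Kind::GPU) {}
  static bool classof(const DeviceContext* ctx) {
    return ctx->kind == Kind::GPU;
  }
};

// Each deleted functor repeated a ladder of this shape once per operation;
// RESHARD_FUNCTOR_IMPL now stamps it out from a single definition.
void AllGatherOn(const DeviceContext* dev_ctx) {
  if (CPUContext::classof(dev_ctx)) {
    std::cout << "dispatch to CPU all_gather\n";
  } else if (GPUContext::classof(dev_ctx)) {
    std::cout << "dispatch to GPU all_gather\n";
  }
}

int main() {
  CPUContext cpu;
  AllGatherOn(&cpu);  // prints: dispatch to CPU all_gather
}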
paddle/phi/core/distributed/auto_parallel/reshard_all_gather_functor.h (deleted file)
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <cstdint>
#include <vector>
namespace phi {
class DenseTensor;
class DeviceContext;
namespace distributed {
DenseTensor ReshardAllGatherFunctor(DeviceContext* dev_ctx,
const DenseTensor& input,
const std::vector<int64_t>& process_ids);
} // namespace distributed
} // namespace phi
paddle/phi/core/distributed/auto_parallel/reshard_concat_functor.cc (deleted file)
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/core/distributed/auto_parallel/reshard_concat_functor.h"
#include "paddle/phi/backends/all_context.h"
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/core/device_context.h"
#include "paddle/phi/core/visit_type.h"
#include "paddle/phi/kernels/concat_kernel.h"
namespace phi {
namespace distributed {
DenseTensor ReshardConcatFunctor(const DeviceContext& dev_ctx,
const std::vector<const DenseTensor*>& input,
int64_t axis) {
DenseTensor result;
auto dtype = (*input.begin())->dtype();
if (phi::CPUContext::classof(&dev_ctx)) {
PD_VISIT_ALL_TYPES(
dtype, "Concat", ([&] {
Concat<data_t>(
static_cast<const CPUContext&>(dev_ctx), input, axis, &result);
}));
return result;
}
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
if (phi::GPUContext::classof(&dev_ctx)) {
PD_VISIT_ALL_TYPES(
dtype, "Concat", ([&] {
Concat<data_t>(
static_cast<const GPUContext&>(dev_ctx), input, axis, &result);
}));
return result;
}
#endif
PADDLE_THROW(phi::errors::Unimplemented(
"The concat in reshard only supported on CPU and GPU for now."));
}
} // namespace distributed
} // namespace phi
@@ -29,20 +29,26 @@ std::shared_ptr<DistTensor> ReshardFunction::Eval(
   return out;
 }
-void ReshardFunction::set_dist_props(DistTensor* tensor,
-                                     const DenseTensor& value,
-                                     const DDim& dims,
-                                     const TensorDistAttr& dist_attr) {
+void ReshardFunction::SetValue(DistTensor* tensor, const DenseTensor& value) {
+  tensor->value_ = value;
+}
+
+void ReshardFunction::SetDistProps(DistTensor* tensor,
+                                   const DDim& dims,
+                                   const TensorDistAttr& dist_attr) {
   PADDLE_ENFORCE_EQ(dist_attr.verify(vectorize(dims)),
                     true,
                     phi::errors::InvalidArgument(
                         "The input dist_attr and dims are improper."));
-  tensor->value_ = value;
   tensor->dims_ = dims;
   tensor->dist_attr_ = dist_attr;
 }
+
+DenseTensor* ReshardFunction::GetMutableTensor(DistTensor* tensor) {
+  return &tensor->value_;
+}
+
 ReshardFunction* ChooseProperReshardFunction(
     const DistTensor& in, const TensorDistAttr& out_dist_attr) {
   for (const auto& func : GetReshardFunctionList()) {
......
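The old set_dist_props combined value assignment and metadata update; it is now split into SetValue, SetDistProps, and GetMutableTensor so kernels invoked by the macros can write directly into the output's storage while metadata is validated separately. A minimal sketch of this accessor split, with simplified stand-in classes (plain vectors instead of DenseTensor, a trivial check instead of dist_attr.verify):

#include <cassert>
#include <cstdint>
#include <utility>
#include <vector>

// Stand-in for DistTensor: its state is private, and only the reshard base
// class (a friend) may mutate it.
class DistTensorLike {
  std::vector<float> value_;
  std::vector<int64_t> dims_;
  friend class ReshardFunctionLike;

 public:
  const std::vector<float>& value() const { return value_; }
};

class ReshardFunctionLike {
 protected:
  void SetValue(DistTensorLike* t, std::vector<float> v) {
    t->value_ = std::move(v);
  }
  void SetDistProps(DistTensorLike* t, std::vector<int64_t> dims) {
    // Plays the role of dist_attr.verify(vectorize(dims)).
    assert(!dims.empty() && "The input dims are improper.");
    t->dims_ = std::move(dims);
  }
  // Lets a kernel write straight into the output's storage, as the
  // RESHARD_FUNCTOR calls do with GetMutableTensor(out).
  std::vector<float>* GetMutableTensor(DistTensorLike* t) {
    return &t->value_;
  }
};

class IdentityReshard : public ReshardFunctionLike {
 public:
  void Eval(const std::vector<float>& in, DistTensorLike* out) {
    SetValue(out, in);
    SetDistProps(out, {static_cast<int64_t>(in.size())});
  }
};

int main() {
  DistTensorLike out;
  IdentityReshard().Eval({1.f, 2.f}, &out);
  assert(out.value().size() == 2);
}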
...@@ -44,10 +44,11 @@ class ReshardFunction { ...@@ -44,10 +44,11 @@ class ReshardFunction {
DistTensor* out) = 0; DistTensor* out) = 0;
protected: protected:
void set_dist_props(DistTensor* tensor, void SetValue(DistTensor* tensor, const DenseTensor& value);
const DenseTensor& value, void SetDistProps(DistTensor* tensor,
const DDim& dims, const DDim& dims,
const TensorDistAttr& dist_attr); const TensorDistAttr& dist_attr);
DenseTensor* GetMutableTensor(DistTensor* tensor);
}; };
std::vector<std::unique_ptr<ReshardFunction>>& GetReshardFunctionList(); std::vector<std::unique_ptr<ReshardFunction>>& GetReshardFunctionList();
......
paddle/phi/core/distributed/auto_parallel/reshard_split_functor.cc (deleted file)
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/core/distributed/auto_parallel/reshard_split_functor.h"
#include "paddle/phi/backends/all_context.h"
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/core/device_context.h"
#include "paddle/phi/core/visit_type.h"
#include "paddle/phi/kernels/split_kernel.h"
namespace phi {
namespace distributed {
std::vector<DenseTensor> ReshardSplitFunctor(const DeviceContext& dev_ctx,
const DenseTensor& input,
const IntArray& sections,
int64_t axis) {
std::vector<DenseTensor> result;
if (phi::CPUContext::classof(&dev_ctx)) {
PD_VISIT_ALL_TYPES(input.dtype(), "Split", ([&] {
Split<data_t>(static_cast<const CPUContext&>(dev_ctx),
input,
sections,
axis,
&result);
}));
return result;
}
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
if (phi::GPUContext::classof(&dev_ctx)) {
PD_VISIT_ALL_TYPES(input.dtype(), "Split", ([&] {
Split<data_t>(static_cast<const GPUContext&>(dev_ctx),
input,
sections,
axis,
&result);
}));
return result;
}
#endif
PADDLE_THROW(phi::errors::Unimplemented(
"The split in reshard only supported on CPU and GPU for now."));
}
} // namespace distributed
} // namespace phi
paddle/phi/core/distributed/auto_parallel/reshard_split_functor.h (deleted file)
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <map>
#include <vector>
#include "paddle/phi/common/int_array.h"
namespace phi {
class DeviceContext;
class DenseTensor;
namespace distributed {
std::vector<DenseTensor> ReshardSplitFunctor(const DeviceContext& dev_ctx,
const DenseTensor& input,
const IntArray& sections,
int64_t axis);
} // namespace distributed
} // namespace phi
@@ -14,22 +14,31 @@
 #include "paddle/phi/core/distributed/auto_parallel/reshard_utils.h"
-#include <cstdlib>
-
-// the <winsock2.h> needs to be included before <winsock.h>, otherwise
-// there will be symbol redefinition error on windows
-#include "paddle/phi/core/distributed/store/tcp_store.h"
 #include "glog/logging.h"
-#include "paddle/phi/backends/all_context.h"
 #include "paddle/phi/core/device_context.h"
 #include "paddle/phi/core/distributed/auto_parallel/process_mesh.h"
-#include "paddle/phi/core/distributed/auto_parallel/utils.h"
 #include "paddle/phi/core/distributed/comm_context_manager.h"
+#include "paddle/phi/core/distributed/store/store_utils.h"
 namespace phi {
 namespace distributed {
-using auto_parallel::str_split;
+namespace {
+int64_t GetLocalRankInParticipate(const std::vector<int64_t>& process_ids) {
+  int64_t cur_global_rank = GetCurGlobalRank();
+  auto iter =
+      std::find(process_ids.begin(), process_ids.end(), cur_global_rank);
+  return iter - process_ids.begin();
+}
+
+std::string GenUniqueCommKey(const std::vector<int64_t>& process_ids) {
+  std::string unique_comm_key = "ReshardGroup";
+  for (const auto& id : process_ids) {
+    unique_comm_key += "/" + std::to_string(id);
+  }
+  return unique_comm_key;
+}
+}  // namespace
 bool IsDimsMappingShard(const std::vector<int64_t>& dims_mapping) {
   return std::any_of(dims_mapping.begin(),
@@ -70,93 +79,6 @@ std::vector<int64_t> GetCurRankCoordInMesh(const ProcessMesh& process_mesh) {
   return coord;
 }
-std::map<int64_t, int64_t> GetSplitAxisWithDimsMapping(
-    const std::vector<int64_t>& dims_mapping) {
-  std::map<int64_t, int64_t> split_axis_to_mesh_axis;
-  for (size_t i = 0; i < dims_mapping.size(); ++i) {
-    if (dims_mapping[i] != -1) {
-      split_axis_to_mesh_axis.emplace(i, dims_mapping[i]);
-    }
-  }
-  return split_axis_to_mesh_axis;
-}
-
-int64_t GetCurGlobalRank() {
-  const char* cur_rank = std::getenv("PADDLE_TRAINER_ID");
-  PADDLE_ENFORCE_NOT_NULL(
-      cur_rank,
-      phi::errors::NotFound(
-          "The environment variable 'PADDLE_TRAINER_ID' cannot be found."));
-  return std::atoi(cur_rank);
-}
-
-int64_t GetGlobalWorldSize() {
-  const char* world_size = std::getenv("PADDLE_TRAINERS_NUM");
-  PADDLE_ENFORCE_NOT_NULL(
-      world_size,
-      phi::errors::NotFound(
-          "The environment variable 'PADDLE_TRAINERS_NUM' cannot be found."));
-  return std::atoi(world_size);
-}
-
-namespace {
-std::string GetMasterEndpoint() {
-  const char* master_endpoint = std::getenv("PADDLE_MASTER");
-  if (!master_endpoint) {
-    const char* trainer_endpoints = std::getenv("PADDLE_TRAINER_ENDPOINTS");
-    PADDLE_ENFORCE_NOT_NULL(
-        trainer_endpoints,
-        phi::errors::NotFound("The environment variable "
-                              "'PADDLE_TRAINER_ENDPOINTS' cannot be found."));
-    return str_split(trainer_endpoints, ",")[0];
-  }
-  PADDLE_ENFORCE_NOT_NULL(
-      master_endpoint,
-      phi::errors::NotFound(
-          "The environment variable 'PADDLE_MASTER' cannot be found."));
-  return master_endpoint;
-}
-
-std::string GenUniqueCommKey(const std::vector<int64_t>& process_ids) {
-  std::string unique_comm_key = "ReshardGroup";
-  for (const auto& id : process_ids) {
-    unique_comm_key += "/" + std::to_string(id);
-  }
-  return unique_comm_key;
-}
-
-int64_t GetLocalRankInParticipate(const std::vector<int64_t>& process_ids) {
-  int64_t cur_global_rank = GetCurGlobalRank();
-  auto iter =
-      std::find(process_ids.begin(), process_ids.end(), cur_global_rank);
-  return iter - process_ids.begin();
-}
-}  // namespace
-
-std::string GetMasterAddr() {
-  std::string master_endpoint = GetMasterEndpoint();
-  return str_split(master_endpoint, ":")[0];
-}
-
-uint16_t GetMasterPort() {
-  std::string master_endpoint = GetMasterEndpoint();
-  return std::stoi(str_split(master_endpoint, ":")[1]);
-}
-
-std::shared_ptr<TCPStore> CreateOrGetGlobalTCPStore() {
-  std::string host = GetMasterAddr();
-  uint16_t port = GetMasterPort();
-  int64_t cur_rank = GetCurGlobalRank();
-  int64_t world_size = GetGlobalWorldSize();
-  bool is_master = (cur_rank == 0);
-  static std::shared_ptr<TCPStore> store =
-      std::make_shared<TCPStore>(host, port, is_master, world_size);
-  return store;
-}
-
 CommContext* CreateOrGetCommContext(const DeviceContext& dev_ctx,
                                     const std::vector<int64_t>& process_ids) {
   std::string unique_comm_key = GenUniqueCommKey(process_ids);
@@ -202,6 +124,17 @@ CommContext* CreateOrGetCommContext(const DeviceContext& dev_ctx,
   return comm_context;
 }
+std::map<int64_t, int64_t> GetSplitAxisWithDimsMapping(
+    const std::vector<int64_t>& dims_mapping) {
+  std::map<int64_t, int64_t> split_axis_to_mesh_axis;
+  for (size_t i = 0; i < dims_mapping.size(); ++i) {
+    if (dims_mapping[i] != -1) {
+      split_axis_to_mesh_axis.emplace(i, dims_mapping[i]);
+    }
+  }
+  return split_axis_to_mesh_axis;
+}
+
 std::vector<int64_t> BalancedSplit(int64_t total_nums, int64_t num_of_pieces) {
   std::vector<int64_t> result(num_of_pieces, total_nums / num_of_pieces);
   int64_t remain_nums = total_nums % num_of_pieces;
......
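For a concrete feel of the helpers kept in this file, here is a standalone reimplementation. GenUniqueCommKey is copied from the diff; BalancedSplit's remainder loop is an assumed completion (the diff truncates the body) chosen to match the documented {3, 3, 2, 2, 2} example:

#include <cstdint>
#include <iostream>
#include <string>
#include <vector>

std::string GenUniqueCommKey(const std::vector<int64_t>& process_ids) {
  std::string unique_comm_key = "ReshardGroup";
  for (const auto& id : process_ids) {
    unique_comm_key += "/" + std::to_string(id);
  }
  return unique_comm_key;
}

std::vector<int64_t> BalancedSplit(int64_t total_nums, int64_t num_of_pieces) {
  std::vector<int64_t> result(num_of_pieces, total_nums / num_of_pieces);
  int64_t remain_nums = total_nums % num_of_pieces;
  // Assumed completion: spread the remainder one element at a time over the
  // leading pieces, matching the example in the header comment.
  for (int64_t i = 0; i < remain_nums; ++i) {
    result[i] += 1;
  }
  return result;
}

int main() {
  std::cout << GenUniqueCommKey({0, 1, 3}) << "\n";  // ReshardGroup/0/1/3
  for (int64_t n : BalancedSplit(12, 5)) {
    std::cout << n << " ";  // 3 3 2 2 2
  }
  std::cout << "\n";
}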
@@ -20,13 +20,14 @@
 #include <string>
 #include <vector>
+#include "paddle/phi/backends/all_context.h"
+#include "paddle/phi/core/device_context.h"
+#include "paddle/phi/core/visit_type.h"
 namespace phi {
 class DeviceContext;
 namespace distributed {
-class CommContext;
-class TCPStore;
 class ProcessMesh;
 bool IsDimsMappingShard(const std::vector<int64_t>& dims_mapping);
@@ -46,6 +47,11 @@ std::vector<int64_t> GetCurRankCoordInMesh(const ProcessMesh& process_mesh);
 std::map<int64_t, int64_t> GetSplitAxisWithDimsMapping(
     const std::vector<int64_t>& dims_mapping);
+// If given a number, balance split it to multiple pieces.
+// For example, the input value is 12, split it to 5 pieces, then return
+// {3, 3, 2, 2, 2}.
+std::vector<int64_t> BalancedSplit(int64_t total_nums, int64_t num_of_pieces);
 // Create a comm context of the input process_ids. Once the newly comm context
 // created, it will be cached in the global instance, and get from the global
 // cache later. If the input dev_ctx is GPU, then nccl comm context will be
@@ -53,20 +59,54 @@ std::map<int64_t, int64_t> GetSplitAxisWithDimsMapping(
 CommContext* CreateOrGetCommContext(const DeviceContext& dev_ctx,
                                     const std::vector<int64_t>& process_ids);
-int64_t GetCurGlobalRank();
-
-std::string GetMasterAddr();
-
-int64_t GetGlobalWorldSize();
-
-uint16_t GetMasterPort();
-
-std::shared_ptr<TCPStore> CreateOrGetGlobalTCPStore();
-
-// If given a number, balance split it to multiple pieces.
-// For example, the input value is 12, split it to 5 pieces, then return
-// {3, 3, 2, 2, 2}.
-std::vector<int64_t> BalancedSplit(int64_t total_nums, int64_t num_of_pieces);
+#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
+#define RESHARD_FUNCTOR_IMPL(dev_ctx, fn_name, dtype, ...)             \
+  do {                                                                 \
+    if (phi::CPUContext::classof(dev_ctx)) {                           \
+      PD_VISIT_FLOATING_AND_INTEGRAL_TYPES(                            \
+          dtype, #fn_name, ([&] {                                      \
+            fn_name<data_t>(static_cast<const CPUContext&>(*dev_ctx),  \
+                            __VA_ARGS__);                              \
+          }));                                                         \
+    } else if (phi::GPUContext::classof(dev_ctx)) {                    \
+      PD_VISIT_FLOATING_AND_INTEGRAL_TYPES(                            \
+          dtype, #fn_name, ([&] {                                      \
+            fn_name<data_t>(static_cast<const GPUContext&>(*dev_ctx),  \
+                            __VA_ARGS__);                              \
+          }));                                                         \
+    } else {                                                           \
+      PADDLE_THROW(phi::errors::Unimplemented(                         \
+          "The %s in reshard only supported on CPU and GPU for now.",  \
+          #fn_name));                                                  \
+    }                                                                  \
+  } while (0)
+#else
+#define RESHARD_FUNCTOR_IMPL(dev_ctx, fn_name, dtype, ...)                \
+  do {                                                                    \
+    if (phi::CPUContext::classof(dev_ctx)) {                              \
+      PD_VISIT_FLOATING_AND_INTEGRAL_TYPES(                               \
+          dtype, #fn_name, ([&] {                                         \
+            fn_name<data_t>(static_cast<const CPUContext&>(*dev_ctx),     \
+                            __VA_ARGS__);                                 \
+          }));                                                            \
+    } else {                                                              \
+      PADDLE_THROW(phi::errors::Unimplemented(                            \
+          "The %s in reshard only supported on CPU for now.", #fn_name)); \
+    }                                                                     \
+  } while (0)
+#endif
+
+#define RESHARD_FUNCTOR_WITH_COMM(dev_ctx, fn_name, dtype, process_ids, ...) \
+  do {                                                                       \
+    auto* comm_context = CreateOrGetCommContext(*dev_ctx, process_ids);      \
+    dev_ctx->SetCommContext(comm_context);                                   \
+    RESHARD_FUNCTOR_IMPL(dev_ctx, fn_name, dtype, __VA_ARGS__);              \
+  } while (0)
+
+#define RESHARD_FUNCTOR(dev_ctx, fn_name, dtype, ...)           \
+  do {                                                          \
+    RESHARD_FUNCTOR_IMPL(dev_ctx, fn_name, dtype, __VA_ARGS__); \
+  } while (0)
 } // namespace distributed
 } // namespace phi
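Two idioms in the macros above deserve a note: wrapping the body in do { ... } while (0) makes a multi-statement expansion behave as a single statement after the caller's semicolon, so it nests safely under an unbraced if/else, and #fn_name stringifies the kernel name into the error message. A tiny standalone illustration with hypothetical names:

#include <cstdio>

void Square(int x) { std::printf("%d\n", x * x); }

// Multi-statement body: without do/while(0), expanding this under an
// unbraced `if` would orphan the second statement and break the `else`.
#define LOG_AND_RUN(fn_name, arg)          \
  do {                                     \
    std::printf("running %s\n", #fn_name); \
    fn_name(arg);                          \
  } while (0)

int main() {
  bool use_kernel = true;
  if (use_kernel)
    LOG_AND_RUN(Square, 3);  // expands safely as a single statement
  else
    std::printf("skipped\n");
}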
@@ -18,10 +18,10 @@
 #include "paddle/phi/common/int_array.h"
 #include "paddle/phi/core/distributed/auto_parallel/dist_attr.h"
 #include "paddle/phi/core/distributed/auto_parallel/dist_tensor.h"
-#include "paddle/phi/core/distributed/auto_parallel/reshard_all_gather_functor.h"
-#include "paddle/phi/core/distributed/auto_parallel/reshard_concat_functor.h"
-#include "paddle/phi/core/distributed/auto_parallel/reshard_split_functor.h"
 #include "paddle/phi/core/distributed/auto_parallel/reshard_utils.h"
+#include "paddle/phi/kernels/all_gather_kernel.h"
+#include "paddle/phi/kernels/concat_kernel.h"
+#include "paddle/phi/kernels/split_kernel.h"
 namespace phi {
 namespace distributed {
@@ -60,17 +60,22 @@ void SToRReshardFunction::Eval(DeviceContext* dev_ctx,
                                const DistTensor& in,
                                const TensorDistAttr& out_dist_attr,
                                DistTensor* out) {
-  const DenseTensor& in_physical_tensor_cur_rank = in.value();
   const auto& in_dist_attr = in.dist_attr();
   const auto& in_dims_mapping = in_dist_attr.dims_mapping();
   const auto& in_process_mesh = in_dist_attr.process_mesh();
   const auto& in_process_ids = in_process_mesh.process_ids();
+  auto dtype = in.dtype();
 
   // Since the precondition ensure the out_process_ids is equal to the
   // in_process_ids, so the participate process ids mush equal to either
   // in_process_ids or out_process_ids.
-  DenseTensor out_all_gather = ReshardAllGatherFunctor(
-      dev_ctx, in_physical_tensor_cur_rank, in_process_ids);
+  RESHARD_FUNCTOR_WITH_COMM(dev_ctx,
+                            AllGather,
+                            dtype,
+                            in_process_ids,
+                            in.value(),
+                            in_process_ids.size(),
+                            GetMutableTensor(out));
 
   std::map<int64_t, int64_t> split_axis_to_mesh_axis =
       GetSplitAxisWithDimsMapping(in_dims_mapping);
@@ -79,7 +84,7 @@ void SToRReshardFunction::Eval(DeviceContext* dev_ctx,
   if (split_axis == 0) {
     // If the input dist tensor is shard(0), the subsequent split
     // and concat is unnecessary.
-    set_dist_props(out, out_all_gather, out_all_gather.dims(), out_dist_attr);
+    SetDistProps(out, in.dims(), out_dist_attr);
   } else {
     // Since the result of all_gather always concat the tensor on axis 0,
     // first we need to split the result on axis 0,
@@ -88,21 +93,30 @@ void SToRReshardFunction::Eval(DeviceContext* dev_ctx,
     int64_t num_of_process = in_process_ids.size();
     IntArray sections(std::vector<int64_t>(
-        num_of_process,
-        in_physical_tensor_cur_rank.dims()[default_split_axis]));
-    std::vector<DenseTensor> split_out_vec = ReshardSplitFunctor(
-        *dev_ctx, out_all_gather, sections, default_split_axis);
+        num_of_process, in.value().dims()[default_split_axis]));
+    std::vector<DenseTensor> split_out_vec;
+    RESHARD_FUNCTOR(dev_ctx,
+                    Split,
+                    dtype,
+                    out->value(),
+                    sections,
+                    default_split_axis,
+                    &split_out_vec);
 
     // Concat the result after split on correct axis.
     std::vector<const DenseTensor*> concat_input_vec;
     for (const auto& tensor : split_out_vec) {
      concat_input_vec.emplace_back(&tensor);
     }
-    DenseTensor concat_out_tensor =
-        ReshardConcatFunctor(*dev_ctx, concat_input_vec, split_axis);
 
-    set_dist_props(
-        out, concat_out_tensor, concat_out_tensor.dims(), out_dist_attr);
+    RESHARD_FUNCTOR(dev_ctx,
+                    Concat,
+                    dtype,
+                    concat_input_vec,
+                    split_axis,
+                    GetMutableTensor(out));
+    SetDistProps(out, in.dims(), out_dist_attr);
   }
 }
......
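The comment in the hunk explains why shard(0) is special: all_gather always stacks the per-rank shards along axis 0, so for any other shard axis the gathered tensor must be split on axis 0 and re-concatenated on the original axis. A worked miniature in plain C++ (no Paddle types) for a 2x4 tensor sharded on axis 1 across two ranks:

#include <iostream>
#include <vector>

using Mat = std::vector<std::vector<int>>;

int main() {
  // Global 2x4 tensor sharded on axis 1 across two ranks:
  Mat shard0 = {{1, 2}, {5, 6}};  // rank 0 holds columns 0..1
  Mat shard1 = {{3, 4}, {7, 8}};  // rank 1 holds columns 2..3

  // all_gather concatenates along axis 0, giving a 4x2 tensor in the
  // wrong layout.
  Mat gathered = shard0;
  gathered.insert(gathered.end(), shard1.begin(), shard1.end());

  // Split on axis 0 back into per-rank pieces...
  Mat piece0(gathered.begin(), gathered.begin() + 2);
  Mat piece1(gathered.begin() + 2, gathered.end());

  // ...then concat on axis 1, the original shard axis, to restore 2x4.
  Mat restored(2);
  for (int r = 0; r < 2; ++r) {
    restored[r] = piece0[r];
    restored[r].insert(restored[r].end(), piece1[r].begin(), piece1[r].end());
  }
  for (const auto& row : restored) {
    for (int v : row) std::cout << v << " ";  // 1 2 3 4 / 5 6 7 8
    std::cout << "\n";
  }
}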
@@ -14,6 +14,7 @@ limitations under the License. */
 #pragma once
 #include <algorithm>
+#include <iterator>
 #include <map>
 #include <sstream>
 #include <string>
......
-set(STORE_COMMON_SRCS tcp_store.cc tcp_utils.cc socket.cpp store.cc)
+set(STORE_COMMON_SRCS tcp_store.cc tcp_utils.cc socket.cpp store.cc
+    store_utils.cc)
 if(WITH_GLOO)
   list(APPEND STORE_COMMON_SRCS gloo_store.cc)
......
paddle/phi/core/distributed/store/store_utils.cc (new file)
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/core/distributed/store/store_utils.h"
#include <cstdlib>
// the <winsock2.h> needs to be included before <winsock.h>, otherwise
// there will be symbol redefinition error on windows
#include "paddle/phi/core/distributed/store/tcp_store.h"
#include "paddle/phi/core/distributed/auto_parallel/utils.h"
namespace phi {
namespace distributed {
using auto_parallel::str_split;
namespace {
std::string GetMasterEndpoint() {
const char* master_endpoint = std::getenv("PADDLE_MASTER");
if (!master_endpoint) {
const char* trainer_endpoints = std::getenv("PADDLE_TRAINER_ENDPOINTS");
PADDLE_ENFORCE_NOT_NULL(
trainer_endpoints,
phi::errors::NotFound("The environment variable "
"'PADDLE_TRAINER_ENDPOINTS' cannot be found."));
return str_split(trainer_endpoints, ",")[0];
}
PADDLE_ENFORCE_NOT_NULL(
master_endpoint,
phi::errors::NotFound(
"The environment variable 'PADDLE_MASTER' cannot be found."));
return master_endpoint;
}
} // namespace
int64_t GetCurGlobalRank() {
const char* cur_rank = std::getenv("PADDLE_TRAINER_ID");
PADDLE_ENFORCE_NOT_NULL(
cur_rank,
phi::errors::NotFound(
"The environment variable 'PADDLE_TRAINER_ID' cannot be found."));
return std::atoi(cur_rank);
}
int64_t GetGlobalWorldSize() {
const char* world_size = std::getenv("PADDLE_TRAINERS_NUM");
PADDLE_ENFORCE_NOT_NULL(
world_size,
phi::errors::NotFound(
"The environment variable 'PADDLE_TRAINERS_NUM' cannot be found."));
return std::atoi(world_size);
}
std::string GetMasterAddr() {
std::string master_endpoint = GetMasterEndpoint();
return str_split(master_endpoint, ":")[0];
}
uint16_t GetMasterPort() {
std::string master_endpoint = GetMasterEndpoint();
return std::stoi(str_split(master_endpoint, ":")[1]);
}
std::shared_ptr<Store> CreateOrGetGlobalTCPStore() {
std::string host = GetMasterAddr();
uint16_t port = GetMasterPort();
int64_t cur_rank = GetCurGlobalRank();
int64_t world_size = GetGlobalWorldSize();
bool is_master = (cur_rank == 0);
static std::shared_ptr<TCPStore> store =
std::make_shared<TCPStore>(host, port, is_master, world_size);
return store;
}
} // namespace distributed
} // namespace phi
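The helpers above read launcher-exported environment variables to discover the rank and master endpoint. Below is a standalone sketch of the same lookup; the PADDLE_ENFORCE_NOT_NULL check is swapped for a plain exception, and setenv (POSIX) only simulates what a launcher such as paddle.distributed.launch would export:

#include <cstdlib>
#include <iostream>
#include <stdexcept>

// Same lookup as GetCurGlobalRank above, simplified for a standalone build.
long GetTrainerRank() {
  const char* cur_rank = std::getenv("PADDLE_TRAINER_ID");
  if (cur_rank == nullptr) {
    throw std::runtime_error(
        "The environment variable 'PADDLE_TRAINER_ID' cannot be found.");
  }
  return std::atol(cur_rank);
}

int main() {
  // The distributed launcher would normally export this.
  setenv("PADDLE_TRAINER_ID", "2", /*overwrite=*/1);
  std::cout << "rank = " << GetTrainerRank() << "\n";  // rank = 2
}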
paddle/phi/core/distributed/auto_parallel/reshard_concat_functor.h → paddle/phi/core/distributed/store/store_utils.h (renamed)
@@ -15,16 +15,22 @@
 #pragma once
 #include <cstdint>
-#include <vector>
+#include <memory>
+#include <string>
 namespace phi {
-class DeviceContext;
-class DenseTensor;
 namespace distributed {
+class Store;
 
-DenseTensor ReshardConcatFunctor(const DeviceContext& dev_ctx,
-                                 const std::vector<const DenseTensor*>& input,
-                                 int64_t axis);
+int64_t GetCurGlobalRank();
+
+std::string GetMasterAddr();
+
+int64_t GetGlobalWorldSize();
+
+uint16_t GetMasterPort();
+
+std::shared_ptr<Store> CreateOrGetGlobalTCPStore();
 } // namespace distributed
 } // namespace phi
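CreateOrGetGlobalTCPStore relies on a function-local static, so the store is constructed on the first call and every later caller receives the same shared_ptr. A minimal sketch of that singleton pattern with a hypothetical stand-in Store type:

#include <iostream>
#include <memory>
#include <string>

// Hypothetical stand-in for the TCP store.
struct Store {
  explicit Store(const std::string& host) {
    std::cout << "constructed once, host=" << host << "\n";
  }
};

std::shared_ptr<Store> CreateOrGetGlobalStore() {
  // The function-local static runs its initializer on the first call only
  // (thread-safe since C++11); later calls return the cached instance.
  static std::shared_ptr<Store> store = std::make_shared<Store>("127.0.0.1");
  return store;
}

int main() {
  auto a = CreateOrGetGlobalStore();
  auto b = CreateOrGetGlobalStore();
  std::cout << (a == b) << "\n";  // 1: both callers share one store
}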