Unverified commit 5425ad7f, authored by LiYuRio, committed by GitHub

use macro instead of functor (#56726)

Parent 0bc369ef
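In outline, the change below drops the per-kernel reshard functors (ReshardSplitFunctor, ReshardConcatFunctor, ReshardAllGatherFunctor) in favor of the dtype- and device-dispatching macros added to reshard_utils.h. A minimal sketch of the call-site pattern being swapped, simplified from the hunks that follow rather than quoted verbatim:

// Before: each kernel had a dedicated functor wrapping the CPU/GPU
// dtype-dispatch boilerplate.
std::vector<DenseTensor> split_out_vec =
    ReshardSplitFunctor(*dev_ctx, in.value(), sections, split_axis);

// After: a single macro family visits the dtype and dispatches any
// kernel template (Split, Concat, AllGather, ...).
std::vector<DenseTensor> split_out_vec;
RESHARD_FUNCTOR(dev_ctx, Split, in.value().dtype(),
                in.value(), sections, split_axis, &split_out_vec);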
......@@ -28,8 +28,8 @@ limitations under the License. */
#include <memory>
#include <string>
#include "paddle/phi/core/distributed/auto_parallel/reshard_utils.h"
#include "paddle/phi/core/distributed/comm_context_manager.h"
#include "paddle/phi/core/distributed/store/store_utils.h"
#include "paddle/phi/core/distributed/store/tcp_store.h"
namespace py = pybind11;
......
......@@ -12,8 +12,5 @@ collect_srcs(
dist_meta_tensor.cc
inferspmd_utils.cc
reshard_function.cc
reshard_split_functor.cc
reshard_concat_functor.cc
reshard_all_gather_functor.cc
r_to_s_reshard_function.cc
s_to_r_reshard_function.cc)
......@@ -15,12 +15,10 @@
#include "paddle/phi/core/distributed/auto_parallel/r_to_s_reshard_function.h"
#include "glog/logging.h"
#include "paddle/phi/api/lib/kernel_dispatch.h"
#include "paddle/phi/core/distributed/auto_parallel/dist_attr.h"
#include "paddle/phi/core/distributed/auto_parallel/dist_tensor.h"
#include "paddle/phi/core/distributed/auto_parallel/reshard_split_functor.h"
#include "paddle/phi/core/distributed/auto_parallel/reshard_utils.h"
#include "paddle/phi/core/kernel_factory.h"
#include "paddle/phi/kernels/split_kernel.h"
namespace phi {
namespace distributed {
......@@ -73,17 +71,21 @@ void RToSReshardFunction::Eval(phi::DeviceContext* dev_ctx,
BalancedSplit(in.dims()[split_axis], num_of_process);
IntArray sections(split_num_vec);
std::vector<DenseTensor> split_out_vec = ReshardSplitFunctor(
*dev_ctx, in_physical_tensor_cur_rank, sections, split_axis);
std::vector<DenseTensor> split_out_vec;
auto dtype = in_physical_tensor_cur_rank.dtype();
RESHARD_FUNCTOR(dev_ctx,
Split,
dtype,
in_physical_tensor_cur_rank,
sections,
split_axis,
&split_out_vec);
VLOG(3) << "The current process will remain the idx "
<< coord_in_mesh[mesh_axis] << " piece of tensor";
out_physical_tensor_cur_rank = split_out_vec[coord_in_mesh[mesh_axis]];
VLOG(3) << "The shape of physical tensor after split is "
<< out_physical_tensor_cur_rank.dims();
set_dist_props(out, out_physical_tensor_cur_rank, in.dims(), out_dist_attr);
SetValue(out, split_out_vec[coord_in_mesh[mesh_axis]]);
SetDistProps(out, in.dims(), out_dist_attr);
}
REGISTER_RESHARD_FUNC(RToSReshardFunction);
......
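As a hypothetical walk-through of the replicate-to-shard path above: with num_of_process == 4 and in.dims()[split_axis] == 10, BalancedSplit yields sections {3, 3, 2, 2}; the macro splits the local tensor into those pieces, and the rank whose mesh coordinate on mesh_axis is 2 keeps split_out_vec[2], a slice of length 2 along split_axis.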
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/core/distributed/auto_parallel/reshard_all_gather_functor.h"
#include "paddle/phi/backends/all_context.h"
#include "paddle/phi/core/distributed/auto_parallel/reshard_utils.h"
#include "paddle/phi/core/visit_type.h"
#include "paddle/phi/kernels/all_gather_kernel.h"
namespace phi {
namespace distributed {
DenseTensor ReshardAllGatherFunctor(DeviceContext* dev_ctx,
const DenseTensor& input,
const std::vector<int64_t>& process_ids) {
DenseTensor out;
int64_t world_size = process_ids.size();
auto* comm_context = CreateOrGetCommContext(*dev_ctx, process_ids);
dev_ctx->SetCommContext(comm_context);
if (phi::CPUContext::classof(dev_ctx)) {
PD_VISIT_FLOATING_AND_INTEGRAL_TYPES(
input.dtype(), "AllGather", ([&] {
AllGather<data_t>(static_cast<const CPUContext&>(*dev_ctx),
input,
world_size,
&out);
}));
return out;
}
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
if (phi::GPUContext::classof(dev_ctx)) {
PD_VISIT_FLOATING_AND_INTEGRAL_TYPES(
input.dtype(), "AllGather", ([&] {
AllGather<data_t>(static_cast<const GPUContext&>(*dev_ctx),
input,
world_size,
&out);
}));
return out;
}
#endif
#ifdef PADDLE_WITH_CUSTOM_DEVICE
if (phi::CustomContext::classof(dev_ctx)) {
PD_VISIT_FLOATING_AND_INTEGRAL_TYPES(
input.dtype(), "AllGather", ([&] {
AllGather<data_t>(static_cast<const CustomContext&>(*dev_ctx),
input,
world_size,
&out);
}));
return out;
}
#endif
PADDLE_THROW(phi::errors::Unimplemented(
"The all_gather in reshard only supported on CPU and GPU for now."));
}
} // namespace distributed
} // namespace phi
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <cstdint>
#include <vector>
namespace phi {
class DenseTensor;
class DeviceContext;
namespace distributed {
DenseTensor ReshardAllGatherFunctor(DeviceContext* dev_ctx,
const DenseTensor& input,
const std::vector<int64_t>& process_ids);
} // namespace distributed
} // namespace phi
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/core/distributed/auto_parallel/reshard_concat_functor.h"
#include "paddle/phi/backends/all_context.h"
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/core/device_context.h"
#include "paddle/phi/core/visit_type.h"
#include "paddle/phi/kernels/concat_kernel.h"
namespace phi {
namespace distributed {
DenseTensor ReshardConcatFunctor(const DeviceContext& dev_ctx,
const std::vector<const DenseTensor*>& input,
int64_t axis) {
DenseTensor result;
auto dtype = (*input.begin())->dtype();
if (phi::CPUContext::classof(&dev_ctx)) {
PD_VISIT_ALL_TYPES(
dtype, "Concat", ([&] {
Concat<data_t>(
static_cast<const CPUContext&>(dev_ctx), input, axis, &result);
}));
return result;
}
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
if (phi::GPUContext::classof(&dev_ctx)) {
PD_VISIT_ALL_TYPES(
dtype, "Concat", ([&] {
Concat<data_t>(
static_cast<const GPUContext&>(dev_ctx), input, axis, &result);
}));
return result;
}
#endif
PADDLE_THROW(phi::errors::Unimplemented(
"The concat in reshard only supported on CPU and GPU for now."));
}
} // namespace distributed
} // namespace phi
......@@ -29,8 +29,11 @@ std::shared_ptr<DistTensor> ReshardFunction::Eval(
return out;
}
void ReshardFunction::set_dist_props(DistTensor* tensor,
const DenseTensor& value,
void ReshardFunction::SetValue(DistTensor* tensor, const DenseTensor& value) {
tensor->value_ = value;
}
void ReshardFunction::SetDistProps(DistTensor* tensor,
const DDim& dims,
const TensorDistAttr& dist_attr) {
PADDLE_ENFORCE_EQ(dist_attr.verify(vectorize(dims)),
......@@ -38,11 +41,14 @@ void ReshardFunction::set_dist_props(DistTensor* tensor,
phi::errors::InvalidArgument(
"The input dist_attr and dims are improper."));
tensor->value_ = value;
tensor->dims_ = dims;
tensor->dist_attr_ = dist_attr;
}
DenseTensor* ReshardFunction::GetMutableTensor(DistTensor* tensor) {
return &tensor->value_;
}
ReshardFunction* ChooseProperReshardFunction(
const DistTensor& in, const TensorDistAttr& out_dist_attr) {
for (const auto& func : GetReshardFunctionList()) {
......
......@@ -44,10 +44,11 @@ class ReshardFunction {
DistTensor* out) = 0;
protected:
void set_dist_props(DistTensor* tensor,
const DenseTensor& value,
void SetValue(DistTensor* tensor, const DenseTensor& value);
void SetDistProps(DistTensor* tensor,
const DDim& dims,
const TensorDistAttr& dist_attr);
DenseTensor* GetMutableTensor(DistTensor* tensor);
};
std::vector<std::unique_ptr<ReshardFunction>>& GetReshardFunctionList();
......
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/core/distributed/auto_parallel/reshard_split_functor.h"
#include "paddle/phi/backends/all_context.h"
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/core/device_context.h"
#include "paddle/phi/core/visit_type.h"
#include "paddle/phi/kernels/split_kernel.h"
namespace phi {
namespace distributed {
std::vector<DenseTensor> ReshardSplitFunctor(const DeviceContext& dev_ctx,
const DenseTensor& input,
const IntArray& sections,
int64_t axis) {
std::vector<DenseTensor> result;
if (phi::CPUContext::classof(&dev_ctx)) {
PD_VISIT_ALL_TYPES(input.dtype(), "Split", ([&] {
Split<data_t>(static_cast<const CPUContext&>(dev_ctx),
input,
sections,
axis,
&result);
}));
return result;
}
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
if (phi::GPUContext::classof(&dev_ctx)) {
PD_VISIT_ALL_TYPES(input.dtype(), "Split", ([&] {
Split<data_t>(static_cast<const GPUContext&>(dev_ctx),
input,
sections,
axis,
&result);
}));
return result;
}
#endif
PADDLE_THROW(phi::errors::Unimplemented(
"The split in reshard only supported on CPU and GPU for now."));
}
} // namespace distributed
} // namespace phi
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <map>
#include <vector>
#include "paddle/phi/common/int_array.h"
namespace phi {
class DeviceContext;
class DenseTensor;
namespace distributed {
std::vector<DenseTensor> ReshardSplitFunctor(const DeviceContext& dev_ctx,
const DenseTensor& input,
const IntArray& sections,
int64_t axis);
} // namespace distributed
} // namespace phi
......@@ -14,22 +14,31 @@
#include "paddle/phi/core/distributed/auto_parallel/reshard_utils.h"
#include <cstdlib>
// the <winsock2.h> needs to be included before <winsock.h>, otherwise
// there will be symbol redefinition error on windows
#include "paddle/phi/core/distributed/store/tcp_store.h"
#include "glog/logging.h"
#include "paddle/phi/backends/all_context.h"
#include "paddle/phi/core/device_context.h"
#include "paddle/phi/core/distributed/auto_parallel/process_mesh.h"
#include "paddle/phi/core/distributed/auto_parallel/utils.h"
#include "paddle/phi/core/distributed/comm_context_manager.h"
#include "paddle/phi/core/distributed/store/store_utils.h"
namespace phi {
namespace distributed {
using auto_parallel::str_split;
namespace {
int64_t GetLocalRankInParticipate(const std::vector<int64_t>& process_ids) {
int64_t cur_global_rank = GetCurGlobalRank();
auto iter =
std::find(process_ids.begin(), process_ids.end(), cur_global_rank);
return iter - process_ids.begin();
}
std::string GenUniqueCommKey(const std::vector<int64_t>& process_ids) {
std::string unique_comm_key = "ReshardGroup";
for (const auto& id : process_ids) {
unique_comm_key += "/" + std::to_string(id);
}
return unique_comm_key;
}
} // namespace
bool IsDimsMappingShard(const std::vector<int64_t>& dims_mapping) {
return std::any_of(dims_mapping.begin(),
......@@ -70,93 +79,6 @@ std::vector<int64_t> GetCurRankCoordInMesh(const ProcessMesh& process_mesh) {
return coord;
}
std::map<int64_t, int64_t> GetSplitAxisWithDimsMapping(
const std::vector<int64_t>& dims_mapping) {
std::map<int64_t, int64_t> split_axis_to_mesh_axis;
for (size_t i = 0; i < dims_mapping.size(); ++i) {
if (dims_mapping[i] != -1) {
split_axis_to_mesh_axis.emplace(i, dims_mapping[i]);
}
}
return split_axis_to_mesh_axis;
}
int64_t GetCurGlobalRank() {
const char* cur_rank = std::getenv("PADDLE_TRAINER_ID");
PADDLE_ENFORCE_NOT_NULL(
cur_rank,
phi::errors::NotFound(
"The environment variable 'PADDLE_TRAINER_ID' cannot be found."));
return std::atoi(cur_rank);
}
int64_t GetGlobalWorldSize() {
const char* world_size = std::getenv("PADDLE_TRAINERS_NUM");
PADDLE_ENFORCE_NOT_NULL(
world_size,
phi::errors::NotFound(
"The environment variable 'PADDLE_TRAINERS_NUM' cannot be found."));
return std::atoi(world_size);
}
namespace {
std::string GetMasterEndpoint() {
const char* master_endpoint = std::getenv("PADDLE_MASTER");
if (!master_endpoint) {
const char* trainer_endpoints = std::getenv("PADDLE_TRAINER_ENDPOINTS");
PADDLE_ENFORCE_NOT_NULL(
trainer_endpoints,
phi::errors::NotFound("The environment variable "
"'PADDLE_TRAINER_ENDPOINTS' cannot be found."));
return str_split(trainer_endpoints, ",")[0];
}
PADDLE_ENFORCE_NOT_NULL(
master_endpoint,
phi::errors::NotFound(
"The environment variable 'PADDLE_MASTER' cannot be found."));
return master_endpoint;
}
std::string GenUniqueCommKey(const std::vector<int64_t>& process_ids) {
std::string unique_comm_key = "ReshardGroup";
for (const auto& id : process_ids) {
unique_comm_key += "/" + std::to_string(id);
}
return unique_comm_key;
}
int64_t GetLocalRankInParticipate(const std::vector<int64_t>& process_ids) {
int64_t cur_global_rank = GetCurGlobalRank();
auto iter =
std::find(process_ids.begin(), process_ids.end(), cur_global_rank);
return iter - process_ids.begin();
}
} // namespace
std::string GetMasterAddr() {
std::string master_endpoint = GetMasterEndpoint();
return str_split(master_endpoint, ":")[0];
}
uint16_t GetMasterPort() {
std::string master_endpoint = GetMasterEndpoint();
return std::stoi(str_split(master_endpoint, ":")[1]);
}
std::shared_ptr<TCPStore> CreateOrGetGlobalTCPStore() {
std::string host = GetMasterAddr();
uint16_t port = GetMasterPort();
int64_t cur_rank = GetCurGlobalRank();
int64_t world_size = GetGlobalWorldSize();
bool is_master = (cur_rank == 0);
static std::shared_ptr<TCPStore> store =
std::make_shared<TCPStore>(host, port, is_master, world_size);
return store;
}
CommContext* CreateOrGetCommContext(const DeviceContext& dev_ctx,
const std::vector<int64_t>& process_ids) {
std::string unique_comm_key = GenUniqueCommKey(process_ids);
......@@ -202,6 +124,17 @@ CommContext* CreateOrGetCommContext(const DeviceContext& dev_ctx,
return comm_context;
}
std::map<int64_t, int64_t> GetSplitAxisWithDimsMapping(
const std::vector<int64_t>& dims_mapping) {
std::map<int64_t, int64_t> split_axis_to_mesh_axis;
for (size_t i = 0; i < dims_mapping.size(); ++i) {
if (dims_mapping[i] != -1) {
split_axis_to_mesh_axis.emplace(i, dims_mapping[i]);
}
}
return split_axis_to_mesh_axis;
}
std::vector<int64_t> BalancedSplit(int64_t total_nums, int64_t num_of_pieces) {
std::vector<int64_t> result(num_of_pieces, total_nums / num_of_pieces);
int64_t remain_nums = total_nums % num_of_pieces;
......
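The hunk above is truncated before the remainder handling; below is a minimal standalone sketch of balanced splitting consistent with the {3, 3, 2, 2, 2} example in the header comment (the loop and the name BalancedSplitSketch are assumptions, not the verbatim body):

#include <cstdint>
#include <vector>

// Hypothetical sketch: every piece gets the floor share, and the first
// `remain_nums` pieces absorb one extra element each, so splitting 12
// into 5 pieces gives {3, 3, 2, 2, 2}.
std::vector<int64_t> BalancedSplitSketch(int64_t total_nums,
                                         int64_t num_of_pieces) {
  std::vector<int64_t> result(num_of_pieces, total_nums / num_of_pieces);
  int64_t remain_nums = total_nums % num_of_pieces;
  for (int64_t i = 0; i < remain_nums; ++i) {
    result[i] += 1;
  }
  return result;
}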
......@@ -20,13 +20,14 @@
#include <string>
#include <vector>
#include "paddle/phi/backends/all_context.h"
#include "paddle/phi/core/device_context.h"
#include "paddle/phi/core/visit_type.h"
namespace phi {
class DeviceContext;
namespace distributed {
class CommContext;
class TCPStore;
class ProcessMesh;
bool IsDimsMappingShard(const std::vector<int64_t>& dims_mapping);
......@@ -46,6 +47,11 @@ std::vector<int64_t> GetCurRankCoordInMesh(const ProcessMesh& process_mesh);
std::map<int64_t, int64_t> GetSplitAxisWithDimsMapping(
const std::vector<int64_t>& dims_mapping);
// Given a number, split it into multiple balanced pieces.
// For example, if the input value is 12 and it is split into 5 pieces,
// return {3, 3, 2, 2, 2}.
std::vector<int64_t> BalancedSplit(int64_t total_nums, int64_t num_of_pieces);
// Create a comm context for the input process_ids. Once the new comm context
// is created, it will be cached in the global instance and fetched from the
// global cache later. If the input dev_ctx is GPU, then nccl comm context will be
......@@ -53,20 +59,54 @@ std::map<int64_t, int64_t> GetSplitAxisWithDimsMapping(
CommContext* CreateOrGetCommContext(const DeviceContext& dev_ctx,
const std::vector<int64_t>& process_ids);
int64_t GetCurGlobalRank();
std::string GetMasterAddr();
int64_t GetGlobalWorldSize();
uint16_t GetMasterPort();
std::shared_ptr<TCPStore> CreateOrGetGlobalTCPStore();
// Given a number, split it into multiple balanced pieces.
// For example, if the input value is 12 and it is split into 5 pieces,
// return {3, 3, 2, 2, 2}.
std::vector<int64_t> BalancedSplit(int64_t total_nums, int64_t num_of_pieces);
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
#define RESHARD_FUNCTOR_IMPL(dev_ctx, fn_name, dtype, ...) \
do { \
if (phi::CPUContext::classof(dev_ctx)) { \
PD_VISIT_FLOATING_AND_INTEGRAL_TYPES( \
dtype, #fn_name, ([&] { \
fn_name<data_t>(static_cast<const CPUContext&>(*dev_ctx), \
__VA_ARGS__); \
})); \
} else if (phi::GPUContext::classof(dev_ctx)) { \
PD_VISIT_FLOATING_AND_INTEGRAL_TYPES( \
dtype, #fn_name, ([&] { \
fn_name<data_t>(static_cast<const GPUContext&>(*dev_ctx), \
__VA_ARGS__); \
})); \
} else { \
PADDLE_THROW(phi::errors::Unimplemented( \
"The %s in reshard only supported on CPU and GPU for now.", \
#fn_name)); \
} \
} while (0)
#else
#define RESHARD_FUNCTOR_IMPL(dev_ctx, fn_name, dtype, ...) \
do { \
if (phi::CPUContext::classof(dev_ctx)) { \
PD_VISIT_FLOATING_AND_INTEGRAL_TYPES( \
dtype, #fn_name, ([&] { \
fn_name<data_t>(static_cast<const CPUContext&>(*dev_ctx), \
__VA_ARGS__); \
})); \
} else { \
PADDLE_THROW(phi::errors::Unimplemented( \
"The %s in reshard only supported on CPU for now.", #fn_name)); \
} \
} while (0)
#endif
#define RESHARD_FUNCTOR_WITH_COMM(dev_ctx, fn_name, dtype, process_ids, ...) \
do { \
auto* comm_context = CreateOrGetCommContext(*dev_ctx, process_ids); \
dev_ctx->SetCommContext(comm_context); \
RESHARD_FUNCTOR_IMPL(dev_ctx, fn_name, dtype, __VA_ARGS__); \
} while (0)
#define RESHARD_FUNCTOR(dev_ctx, fn_name, dtype, ...) \
do { \
RESHARD_FUNCTOR_IMPL(dev_ctx, fn_name, dtype, __VA_ARGS__); \
} while (0)
} // namespace distributed
} // namespace phi
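For illustration, a call like RESHARD_FUNCTOR(dev_ctx, Split, dtype, input, sections, axis, &outs) (argument names hypothetical) roughly expands, on the CPU branch, to the visitor-plus-cast pattern that the deleted functors used to spell out by hand:

if (phi::CPUContext::classof(dev_ctx)) {
  PD_VISIT_FLOATING_AND_INTEGRAL_TYPES(
      dtype, "Split", ([&] {
        // __VA_ARGS__ lands here as the trailing kernel arguments.
        Split<data_t>(static_cast<const CPUContext&>(*dev_ctx),
                      input, sections, axis, &outs);
      }));
}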
......@@ -18,10 +18,10 @@
#include "paddle/phi/common/int_array.h"
#include "paddle/phi/core/distributed/auto_parallel/dist_attr.h"
#include "paddle/phi/core/distributed/auto_parallel/dist_tensor.h"
#include "paddle/phi/core/distributed/auto_parallel/reshard_all_gather_functor.h"
#include "paddle/phi/core/distributed/auto_parallel/reshard_concat_functor.h"
#include "paddle/phi/core/distributed/auto_parallel/reshard_split_functor.h"
#include "paddle/phi/core/distributed/auto_parallel/reshard_utils.h"
#include "paddle/phi/kernels/all_gather_kernel.h"
#include "paddle/phi/kernels/concat_kernel.h"
#include "paddle/phi/kernels/split_kernel.h"
namespace phi {
namespace distributed {
......@@ -60,17 +60,22 @@ void SToRReshardFunction::Eval(DeviceContext* dev_ctx,
const DistTensor& in,
const TensorDistAttr& out_dist_attr,
DistTensor* out) {
const DenseTensor& in_physical_tensor_cur_rank = in.value();
const auto& in_dist_attr = in.dist_attr();
const auto& in_dims_mapping = in_dist_attr.dims_mapping();
const auto& in_process_mesh = in_dist_attr.process_mesh();
const auto& in_process_ids = in_process_mesh.process_ids();
auto dtype = in.dtype();
// Since the precondition ensures that out_process_ids equals
// in_process_ids, the participating process ids must equal either
// in_process_ids or out_process_ids.
DenseTensor out_all_gather = ReshardAllGatherFunctor(
dev_ctx, in_physical_tensor_cur_rank, in_process_ids);
RESHARD_FUNCTOR_WITH_COMM(dev_ctx,
AllGather,
dtype,
in_process_ids,
in.value(),
in_process_ids.size(),
GetMutableTensor(out));
std::map<int64_t, int64_t> split_axis_to_mesh_axis =
GetSplitAxisWithDimsMapping(in_dims_mapping);
......@@ -79,7 +84,7 @@ void SToRReshardFunction::Eval(DeviceContext* dev_ctx,
if (split_axis == 0) {
// If the input dist tensor is shard(0), the subsequent split
// and concat are unnecessary.
set_dist_props(out, out_all_gather, out_all_gather.dims(), out_dist_attr);
SetDistProps(out, in.dims(), out_dist_attr);
} else {
// Since the result of all_gather always concatenates the tensors on axis 0,
// first we need to split the result on axis 0,
......@@ -88,21 +93,30 @@ void SToRReshardFunction::Eval(DeviceContext* dev_ctx,
int64_t num_of_process = in_process_ids.size();
IntArray sections(std::vector<int64_t>(
num_of_process,
in_physical_tensor_cur_rank.dims()[default_split_axis]));
std::vector<DenseTensor> split_out_vec = ReshardSplitFunctor(
*dev_ctx, out_all_gather, sections, default_split_axis);
num_of_process, in.value().dims()[default_split_axis]));
std::vector<DenseTensor> split_out_vec;
RESHARD_FUNCTOR(dev_ctx,
Split,
dtype,
out->value(),
sections,
default_split_axis,
&split_out_vec);
// Concat the result after split on correct axis.
std::vector<const DenseTensor*> concat_input_vec;
for (const auto& tensor : split_out_vec) {
concat_input_vec.emplace_back(&tensor);
}
DenseTensor concat_out_tensor =
ReshardConcatFunctor(*dev_ctx, concat_input_vec, split_axis);
set_dist_props(
out, concat_out_tensor, concat_out_tensor.dims(), out_dist_attr);
RESHARD_FUNCTOR(dev_ctx,
Concat,
dtype,
concat_input_vec,
split_axis,
GetMutableTensor(out));
SetDistProps(out, in.dims(), out_dist_attr);
}
}
......
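As a hypothetical shape walk-through of the non-zero split_axis branch: with two processes and a local shard of shape [4, 3] produced by shard(1), AllGather stacks the shards on axis 0 into an [8, 3] tensor; splitting it on axis 0 gives two [4, 3] pieces, and concatenating those pieces on split_axis == 1 restores the replicated [4, 6] tensor, matching the in.dims() passed to SetDistProps.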
......@@ -14,6 +14,7 @@ limitations under the License. */
#pragma once
#include <algorithm>
#include <iterator>
#include <map>
#include <sstream>
#include <string>
......
set(STORE_COMMON_SRCS tcp_store.cc tcp_utils.cc socket.cpp store.cc)
set(STORE_COMMON_SRCS tcp_store.cc tcp_utils.cc socket.cpp store.cc
store_utils.cc)
if(WITH_GLOO)
list(APPEND STORE_COMMON_SRCS gloo_store.cc)
......
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/core/distributed/store/store_utils.h"
#include <cstdlib>
// the <winsock2.h> needs to be included before <winsock.h>, otherwise
// there will be symbol redefinition error on windows
#include "paddle/phi/core/distributed/store/tcp_store.h"
#include "paddle/phi/core/distributed/auto_parallel/utils.h"
namespace phi {
namespace distributed {
using auto_parallel::str_split;
namespace {
std::string GetMasterEndpoint() {
const char* master_endpoint = std::getenv("PADDLE_MASTER");
if (!master_endpoint) {
const char* trainer_endpoints = std::getenv("PADDLE_TRAINER_ENDPOINTS");
PADDLE_ENFORCE_NOT_NULL(
trainer_endpoints,
phi::errors::NotFound("The environment variable "
"'PADDLE_TRAINER_ENDPOINTS' cannot be found."));
return str_split(trainer_endpoints, ",")[0];
}
PADDLE_ENFORCE_NOT_NULL(
master_endpoint,
phi::errors::NotFound(
"The environment variable 'PADDLE_MASTER' cannot be found."));
return master_endpoint;
}
} // namespace
int64_t GetCurGlobalRank() {
const char* cur_rank = std::getenv("PADDLE_TRAINER_ID");
PADDLE_ENFORCE_NOT_NULL(
cur_rank,
phi::errors::NotFound(
"The environment variable 'PADDLE_TRAINER_ID' cannot be found."));
return std::atoi(cur_rank);
}
int64_t GetGlobalWorldSize() {
const char* world_size = std::getenv("PADDLE_TRAINERS_NUM");
PADDLE_ENFORCE_NOT_NULL(
world_size,
phi::errors::NotFound(
"The environment variable 'PADDLE_TRAINERS_NUM' cannot be found."));
return std::atoi(world_size);
}
std::string GetMasterAddr() {
std::string master_endpoint = GetMasterEndpoint();
return str_split(master_endpoint, ":")[0];
}
uint16_t GetMasterPort() {
std::string master_endpoint = GetMasterEndpoint();
return std::stoi(str_split(master_endpoint, ":")[1]);
}
std::shared_ptr<Store> CreateOrGetGlobalTCPStore() {
std::string host = GetMasterAddr();
uint16_t port = GetMasterPort();
int64_t cur_rank = GetCurGlobalRank();
int64_t world_size = GetGlobalWorldSize();
bool is_master = (cur_rank == 0);
static std::shared_ptr<TCPStore> store =
std::make_shared<TCPStore>(host, port, is_master, world_size);
return store;
}
} // namespace distributed
} // namespace phi
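As a hypothetical configuration: with PADDLE_MASTER=127.0.0.1:6170, PADDLE_TRAINER_ID=0 and PADDLE_TRAINERS_NUM=2, GetMasterAddr() returns "127.0.0.1", GetMasterPort() returns 6170, and CreateOrGetGlobalTCPStore() builds the process-wide TCPStore once, with rank 0 acting as the master.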
......@@ -15,16 +15,22 @@
#pragma once
#include <cstdint>
#include <vector>
#include <memory>
#include <string>
namespace phi {
class DeviceContext;
class DenseTensor;
namespace distributed {
class Store;
DenseTensor ReshardConcatFunctor(const DeviceContext& dev_ctx,
const std::vector<const DenseTensor*>& input,
int64_t axis);
int64_t GetCurGlobalRank();
std::string GetMasterAddr();
int64_t GetGlobalWorldSize();
uint16_t GetMasterPort();
std::shared_ptr<Store> CreateOrGetGlobalTCPStore();
} // namespace distributed
} // namespace phi