未验证 提交 8c214b6a 编写于 作者: R ronnywang 提交者: GitHub

[CustomDevice] add MOE support, PART2 (#54573)

上级 8771fff3
......@@ -298,15 +298,6 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupCustom::AllGather(
use_calc_stream);
}
std::shared_ptr<ProcessGroup::Task> ProcessGroupCustom::AllGather(
phi::DenseTensor* out_tensor,
const phi::DenseTensor& in_tensor,
int64_t offset,
int64_t numel,
bool sync_op) {
return AllGather(out_tensor, in_tensor, offset, numel, sync_op, false);
}
// TODO(sunyilun): methods below will be removed later
std::shared_ptr<ProcessGroup::Task> ProcessGroupCustom::AllGather(
std::vector<phi::DenseTensor>& in_tensors,
......@@ -382,15 +373,6 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupCustom::AllReduce(
use_calc_stream);
}
std::shared_ptr<ProcessGroup::Task> ProcessGroupCustom::AllReduce(
phi::DenseTensor* out_tensor,
const phi::DenseTensor& in_tensor,
const AllreduceOptions& opts,
bool sync_op // for compatibility, no use now
) {
return AllReduce(out_tensor, in_tensor, opts, sync_op, false);
}
std::shared_ptr<ProcessGroup::Task> ProcessGroupCustom::AllReduce(
std::vector<phi::DenseTensor>& in_tensors, // NOLINT
std::vector<phi::DenseTensor>& out_tensors, // NOLINT
......@@ -478,14 +460,6 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupCustom::Broadcast(
use_calc_stream);
}
std::shared_ptr<ProcessGroup::Task> ProcessGroupCustom::Broadcast(
phi::DenseTensor* out_tensor,
const phi::DenseTensor& in_tensor,
const BroadcastOptions& opts,
bool sync_op) {
return Broadcast(out_tensor, in_tensor, opts, sync_op, false);
}
std::shared_ptr<ProcessGroup::Task> ProcessGroupCustom::Barrier(
const BarrierOptions& opts) {
// Only support single card single process
......@@ -759,6 +733,397 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupCustom::Reduce(
use_calc_stream);
}
std::shared_ptr<ProcessGroup::Task> ProcessGroupCustom::Reduce(
std::vector<phi::DenseTensor>& in_tensors,
std::vector<phi::DenseTensor>& out_tensors,
const ReduceOptions& opts) {
PADDLE_ENFORCE_EQ(
CheckTensorsInCustomPlace(in_tensors, device_type_),
true,
phi::errors::InvalidArgument("All inputs should be in CustomPlace."));
return Collective(
in_tensors,
out_tensors,
[&](phi::DenseTensor& input,
phi::DenseTensor& output,
phi::ccl::CCLComm comm,
const phi::stream::Stream& stream) {
phi::DeviceManager::CCLReduce(device_type_,
input.data(),
output.data(),
input.numel(),
phi::ccl::ToCCLDataType(input.dtype()),
ToCustomCCLRedType(opts.reduce_op),
opts.root_rank,
comm,
stream);
},
CommType::REDUCE,
false,
false);
}
std::shared_ptr<ProcessGroup::Task> ProcessGroupCustom::AllToAll(
phi::DenseTensor* out_tensor,
const phi::DenseTensor& in_tensor,
const std::vector<int64_t>& out_size_each_rank,
const std::vector<int64_t>& in_size_each_rank,
bool sync_op,
bool use_calc_stream) {
const phi::DDim& out_dim = out_tensor->dims();
const phi::DDim& in_dim = in_tensor.dims();
CheckSizeOnEachRank(out_dim, out_size_each_rank, size_);
CheckSizeOnEachRank(in_dim, in_size_each_rank, size_);
// NOTE: Since `all_to_all` needs other processes' participation, it cannot
// simply be covered by static checks. Factors are set to 0 here to skip the
// shape check. Its shape check will be done by dynamic checks with
// FLAGS_enable_nccl_dynamic_check.
phi::distributed::CommStaticCheck::CheckShape(*out_tensor,
in_tensor,
/*dst_rank*/ rank_,
/*cur_rank*/ rank_,
size_,
/*out_size_factor*/ 0,
/*in_size_factor*/ 0,
phi::AllocationType::CUSTOM);
std::vector<phi::DenseTensor> in_wrapper{in_tensor};
std::vector<phi::DenseTensor> out_wrapper{*out_tensor};
return Collective(
in_wrapper,
out_wrapper,
[&](phi::DenseTensor& input,
phi::DenseTensor& output,
phi::ccl::CCLComm comm,
const phi::stream::Stream& stream) {
int64_t in_row_size = in_tensor.numel() / in_dim[0],
out_row_size = out_tensor->numel() / out_dim[0];
int64_t in_offset = 0, in_numel = 0, out_offset = 0, out_numel = 0;
phi::DenseTensor input_partial, output_partial;
std::vector<void*> send_buf, recv_buf;
std::vector<size_t> send_count, recv_count;
std::vector<phi::ccl::CCLDataType> send_dtype, recv_dtype;
for (auto i = 0; i < size_; i++) {
in_numel = in_size_each_rank[i] * in_row_size;
input_partial = GetPartialTensor(in_tensor, in_offset, in_numel);
out_numel = out_size_each_rank[i] * out_row_size;
output_partial = GetPartialTensor(*out_tensor, out_offset, out_numel);
in_offset += in_numel;
out_offset += out_numel;
send_buf.push_back(input_partial.data());
recv_buf.push_back(output_partial.data());
send_count.push_back(in_numel);
recv_count.push_back(out_numel);
send_dtype.push_back(phi::ccl::ToCCLDataType(input_partial.dtype()));
recv_dtype.push_back(phi::ccl::ToCCLDataType(output_partial.dtype()));
}
phi::DeviceManager::CCLAllToAll(
device_type_,
const_cast<const void**>(send_buf.data()),
send_count.data(),
send_dtype.data(),
recv_buf.data(),
recv_count.data(),
recv_dtype.data(),
rank_,
size_,
comm,
stream);
},
CommType::ALLTOALL,
sync_op,
use_calc_stream);
}
std::shared_ptr<ProcessGroup::Task> ProcessGroupCustom::AllToAll(
std::vector<phi::DenseTensor>& in_tensors,
std::vector<phi::DenseTensor>& out_tensors) {
PADDLE_ENFORCE_EQ(
CheckTensorsInCustomPlace(in_tensors, device_type_),
true,
phi::errors::InvalidArgument("All inputs should be in CustomPlace."));
PADDLE_ENFORCE_EQ(
CheckTensorsInCustomPlace(out_tensors, device_type_),
true,
phi::errors::InvalidArgument("All inputs should be in CustomPlace."));
return Collective(
in_tensors,
out_tensors,
[&](phi::DenseTensor& input,
phi::DenseTensor& output,
phi::ccl::CCLComm comm,
const phi::stream::Stream& stream) {
size_t offset = 0;
std::vector<void*> send_buf, recv_buf;
std::vector<size_t> send_count(size_, input.numel() / size_),
recv_count(size_, input.numel() / size_);
std::vector<phi::ccl::CCLDataType> send_dtype(
size_, phi::ccl::ToCCLDataType(input.dtype())),
recv_dtype(size_, phi::ccl::ToCCLDataType(input.dtype()));
for (auto i = 0; i < size_; i++) {
send_buf.push_back(
GetPointerByOffset(input.data(), offset, input.dtype()));
recv_buf.push_back(
GetPointerByOffset(output.data(), offset, input.dtype()));
offset += input.numel() / size_;
}
phi::DeviceManager::CCLAllToAll(
device_type_,
const_cast<const void**>(send_buf.data()),
send_count.data(),
send_dtype.data(),
recv_buf.data(),
recv_count.data(),
recv_dtype.data(),
rank_,
size_,
comm,
stream);
},
CommType::ALLTOALL,
false,
false);
}
std::shared_ptr<ProcessGroup::Task> ProcessGroupCustom::ReduceScatter(
phi::DenseTensor* out_tensor,
const phi::DenseTensor& in_tensor,
const ReduceScatterOptions& opts,
bool sync_op,
bool use_calc_stream) {
phi::distributed::CommStaticCheck::ScatterLikeShape(
*out_tensor,
in_tensor,
/*dst_rank*/ rank_,
/*cur_rank*/ rank_,
size_,
phi::AllocationType::CUSTOM);
std::vector<phi::DenseTensor> in_wrapper{in_tensor};
std::vector<phi::DenseTensor> out_wrapper{*out_tensor};
return Collective(
in_wrapper,
out_wrapper,
[&](phi::DenseTensor& input,
phi::DenseTensor& output,
phi::ccl::CCLComm comm,
const phi::stream::Stream& stream) {
phi::DeviceManager::CCLReduceScatter(
device_type_,
const_cast<void*>(in_tensor.data()),
out_tensor->data(),
out_tensor->numel(),
phi::ccl::ToCCLDataType(in_tensor.dtype()),
paddle::distributed::ToCustomCCLRedType(opts.reduce_op),
comm,
stream);
},
CommType::REDUCE_SCATTER,
false,
false);
}
std::shared_ptr<ProcessGroup::Task> ProcessGroupCustom::Scatter(
phi::DenseTensor* out_tensor,
const phi::DenseTensor& in_tensor,
const ScatterOptions& opts,
bool sync_op,
bool use_calc_stream) {
phi::distributed::CommStaticCheck::ScatterLikeShape(
*out_tensor,
in_tensor,
/*dst_rank*/ opts.root_rank,
/*cur_rank*/ rank_,
size_,
phi::AllocationType::CUSTOM);
std::vector<phi::DenseTensor> in_wrapper{in_tensor};
std::vector<phi::DenseTensor> out_wrapper{*out_tensor};
return Collective(
in_wrapper,
out_wrapper,
[&](phi::DenseTensor& input,
phi::DenseTensor& output,
phi::ccl::CCLComm comm,
const phi::stream::Stream& stream) {
int64_t numel = in_tensor.numel() / size_;
if (rank_ == opts.root_rank) {
int64_t offset = 0;
phi::DenseTensor partial_tensor;
for (auto i = 0; i < size_; i++) {
partial_tensor = GetPartialTensor(in_tensor, offset, numel);
if (i != rank_) {
phi::DeviceManager::CCLSend(
device_type_,
partial_tensor.data(),
numel,
phi::ccl::ToCCLDataType(partial_tensor.dtype()),
i,
comm,
stream);
} else {
phi::DeviceManager::GetDeviceWithPlace(stream.GetPlace())
->MemoryCopyD2D(out_tensor->data(),
partial_tensor.data(),
numel * phi::SizeOf(partial_tensor.dtype()),
&stream);
}
offset += numel;
}
} else {
phi::DeviceManager::CCLRecv(
device_type_,
out_tensor->data(),
numel,
phi::ccl::ToCCLDataType(out_tensor->dtype()),
opts.root_rank,
comm,
stream);
}
},
CommType::SCATTER,
sync_op,
use_calc_stream);
}
std::shared_ptr<ProcessGroup::Task> ProcessGroupCustom::Scatter(
std::vector<phi::DenseTensor>& in_tensors,
std::vector<phi::DenseTensor>& out_tensors,
const ScatterOptions& opts) {
PADDLE_ENFORCE_EQ(
CheckTensorsInCustomPlace(in_tensors, device_type_),
true,
phi::errors::InvalidArgument("All inputs should be in CudaPlace."));
PADDLE_ENFORCE_EQ(
CheckTensorsInCustomPlace(out_tensors, device_type_),
true,
phi::errors::InvalidArgument("All inputs should be in CudaPlace."));
return Collective(
in_tensors,
out_tensors,
[&](phi::DenseTensor& input,
phi::DenseTensor& output,
phi::ccl::CCLComm comm,
const phi::stream::Stream& stream) {
int64_t numel = input.numel() / size_;
if (rank_ == opts.root_rank) {
int64_t offset = 0;
phi::DenseTensor partial_tensor;
for (auto i = 0; i < size_; i++) {
partial_tensor = GetPartialTensor(input, offset, numel);
if (i != rank_) {
phi::DeviceManager::CCLSend(
device_type_,
partial_tensor.data(),
numel,
phi::ccl::ToCCLDataType(partial_tensor.dtype()),
i,
comm,
stream);
} else {
phi::DeviceManager::GetDeviceWithPlace(stream.GetPlace())
->MemoryCopyD2D(output.data(),
partial_tensor.data(),
numel * phi::SizeOf(partial_tensor.dtype()),
&stream);
}
offset += numel;
}
} else {
phi::DeviceManager::CCLRecv(device_type_,
output.data(),
numel,
phi::ccl::ToCCLDataType(output.dtype()),
opts.root_rank,
comm,
stream);
}
},
CommType::SCATTER,
false,
false);
}
std::shared_ptr<ProcessGroup::Task> ProcessGroupCustom::Gather(
phi::DenseTensor* out_tensor,
const phi::DenseTensor& in_tensor,
const GatherOptions& opts,
bool sync_op,
bool use_calc_stream) {
std::vector<phi::DenseTensor> partial_tensors;
if (rank_ == opts.root_rank) {
partial_tensors.reserve(size_);
size_t offset = 0;
size_t numel = out_tensor->numel() / size_;
for (auto i = 0; i < size_; i++) {
partial_tensors.push_back(GetPartialTensor(*out_tensor, offset, numel));
offset += numel;
}
}
return Gather(&partial_tensors, in_tensor, opts, sync_op, use_calc_stream);
}
std::shared_ptr<ProcessGroup::Task> ProcessGroupCustom::Gather(
std::vector<phi::DenseTensor>* gather_tensors_ptr,
const phi::DenseTensor& in_tensor,
const GatherOptions& opts,
bool sync_op,
bool use_calc_stream) {
auto& gather_tensors = *gather_tensors_ptr;
PADDLE_ENFORCE_GT(size_,
opts.root_rank,
phi::errors::InvalidArgument(
"root world size [%d] is less than root rank [%d]",
size_,
opts.root_rank));
std::vector<phi::DenseTensor> in_wrapper{in_tensor};
return Collective(
in_wrapper,
in_wrapper,
[&](phi::DenseTensor& input,
phi::DenseTensor& output,
phi::ccl::CCLComm comm,
const phi::stream::Stream& stream) {
// root receive from all devices
if (rank_ == opts.root_rank) {
for (auto i = 0; i < size_; i++) {
auto& gather_tensor = gather_tensors[i];
if (i != rank_) {
phi::DeviceManager::CCLRecv(
device_type_,
gather_tensor.data(),
gather_tensor.numel(),
phi::ccl::ToCCLDataType(gather_tensor.dtype()),
i,
comm,
stream);
} else {
phi::DeviceManager::GetDeviceWithPlace(stream.GetPlace())
->MemoryCopyD2D(
gather_tensor.data(),
in_tensor.data(),
in_tensor.numel() * phi::SizeOf(in_tensor.dtype()),
&stream);
}
}
} else {
// send to root
phi::DeviceManager::CCLSend(
device_type_,
const_cast<void*>(in_tensor.data()),
in_tensor.numel(),
phi::ccl::ToCCLDataType(in_tensor.dtype()),
opts.root_rank,
comm,
stream);
}
},
CommType::GATHER,
sync_op,
use_calc_stream);
}
std::shared_ptr<ProcessGroupCustom>
ProcessGroupCustom::CreateProcessGroupCustom(
const std::shared_ptr<phi::distributed::Store>& store,
......
......@@ -100,13 +100,6 @@ class ProcessGroupCustom : public ProcessGroupWithStream {
bool sync_op,
bool use_calc_stream) override;
std::shared_ptr<ProcessGroup::Task> AllGather(
phi::DenseTensor* out_tensor,
const phi::DenseTensor& in_tensor,
int64_t offset,
int64_t numel,
bool sync_op) override;
std::shared_ptr<ProcessGroup::Task> AllReduce(
std::vector<phi::DenseTensor>& in_tensors,
std::vector<phi::DenseTensor>& out_tensors,
......@@ -119,12 +112,6 @@ class ProcessGroupCustom : public ProcessGroupWithStream {
bool sync_op,
bool use_calc_stream) override;
std::shared_ptr<ProcessGroup::Task> AllReduce(
phi::DenseTensor* out_tensor,
const phi::DenseTensor& in_tensor,
const AllreduceOptions& opts,
bool sync_op) override;
std::shared_ptr<ProcessGroup::Task> Broadcast(
std::vector<phi::DenseTensor>& in_tensors,
std::vector<phi::DenseTensor>& out_tensors,
......@@ -137,12 +124,6 @@ class ProcessGroupCustom : public ProcessGroupWithStream {
bool sync_op,
bool use_calc_stream) override;
std::shared_ptr<ProcessGroup::Task> Broadcast(
phi::DenseTensor* out_tensor,
const phi::DenseTensor& in_tensor,
const BroadcastOptions& opts,
bool sync_op) override;
std::shared_ptr<ProcessGroup::Task> Send(const phi::DenseTensor& tensor,
int dst_rank,
int64_t offset,
......@@ -169,6 +150,54 @@ class ProcessGroupCustom : public ProcessGroupWithStream {
bool sync_op,
bool use_calc_stream) override;
std::shared_ptr<ProcessGroup::Task> Reduce(
std::vector<phi::DenseTensor>& tensors,
std::vector<phi::DenseTensor>& out_tensors,
const ReduceOptions& opts) override;
std::shared_ptr<ProcessGroup::Task> AllToAll(
phi::DenseTensor* out_tensor,
const phi::DenseTensor& in_tensor,
const std::vector<int64_t>& out_size_each_rank,
const std::vector<int64_t>& in_size_each_rank,
bool sync_op,
bool use_calc_stream) override;
std::shared_ptr<ProcessGroup::Task> AllToAll(
std::vector<phi::DenseTensor>& in_tensors,
std::vector<phi::DenseTensor>& out_tensors) override;
std::shared_ptr<ProcessGroup::Task> ReduceScatter(
phi::DenseTensor* out_tensor,
const phi::DenseTensor& in_tensor,
const ReduceScatterOptions& opts,
bool sync_op,
bool use_calc_stream) override;
std::shared_ptr<ProcessGroup::Task> Scatter(phi::DenseTensor* out_tensor,
const phi::DenseTensor& in_tensor,
const ScatterOptions& opts,
bool sync_op,
bool use_calc_stream) override;
std::shared_ptr<ProcessGroup::Task> Scatter(
std::vector<phi::DenseTensor>& in_tensors,
std::vector<phi::DenseTensor>& out_tensors,
const ScatterOptions& opts) override;
std::shared_ptr<ProcessGroup::Task> Gather(phi::DenseTensor* out_tensor,
const phi::DenseTensor& in_tensor,
const GatherOptions& opts,
bool sync_op,
bool use_calc_stream) override;
std::shared_ptr<ProcessGroup::Task> Gather(
std::vector<phi::DenseTensor>* gather_tensors_ptr,
const phi::DenseTensor& in_tensor,
const GatherOptions& opts,
bool sync_op,
bool use_calc_stream) override;
protected:
virtual std::shared_ptr<ProcessGroupCustom::CustomTask> CreateTask(
std::vector<Place> places,
......@@ -206,6 +235,12 @@ class ProcessGroupCustom : public ProcessGroupWithStream {
bool sync_op,
bool use_calc_stream);
template <typename Fn>
std::shared_ptr<ProcessGroup::Task> Collective(Fn fn,
CommType op_type,
bool sync_op,
bool use_calc_stream);
void CreateCustomManagerCache(const std::string& places_key,
const std::vector<Place>& places);
const std::string device_type_;
......
......@@ -205,25 +205,6 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::AllReduce(
use_calc_stream);
}
void CheckSizeOnEachRank(const phi::DDim& tensor_dim,
const std::vector<int64_t>& size_on_each_rank,
int world_size) {
int length_size_on_each_rank = size_on_each_rank.size();
PADDLE_ENFORCE_EQ(
length_size_on_each_rank,
world_size,
phi::errors::InvalidArgument(
"The length of size_on_each_rank must be equal to world_size."));
int64_t sum_size_on_each_rank =
std::accumulate(size_on_each_rank.begin(), size_on_each_rank.end(), 0);
PADDLE_ENFORCE_EQ(
sum_size_on_each_rank,
tensor_dim[0],
phi::errors::InvalidArgument(
"The sum of size_on_each_rank must be equal to tensor's dim[0]."));
}
std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::AllToAll(
phi::DenseTensor* out_tensor,
const phi::DenseTensor& in_tensor,
......@@ -1059,41 +1040,6 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::AllGather(
CommType::ALLGATHER);
}
void* GetPointerByOffset(void* raw_pointer, size_t offset, phi::DataType type) {
if (type == phi::DataType::FLOAT32) {
return reinterpret_cast<void*>(reinterpret_cast<float*>(raw_pointer) +
offset);
} else if (type == phi::DataType::FLOAT64) {
return reinterpret_cast<void*>(reinterpret_cast<double*>(raw_pointer) +
offset);
} else if (type == phi::DataType::FLOAT16) {
return reinterpret_cast<void*>(reinterpret_cast<int16_t*>(raw_pointer) +
offset);
} else if (type == phi::DataType::INT32) {
return reinterpret_cast<void*>(reinterpret_cast<int32_t*>(raw_pointer) +
offset);
} else if (type == phi::DataType::INT64) {
return reinterpret_cast<void*>(reinterpret_cast<int64_t*>(raw_pointer) +
offset);
} else if (type == phi::DataType::INT8) {
return reinterpret_cast<void*>(reinterpret_cast<int8_t*>(raw_pointer) +
offset);
} else if (type == phi::DataType::UINT8) {
return reinterpret_cast<void*>(reinterpret_cast<uint8_t*>(raw_pointer) +
offset);
} else if (type == phi::DataType::BOOL) {
return reinterpret_cast<void*>(reinterpret_cast<bool*>(raw_pointer) +
offset);
} else if (type == phi::DataType::BFLOAT16) {
return reinterpret_cast<void*>(reinterpret_cast<uint16_t*>(raw_pointer) +
offset);
} else {
PADDLE_THROW(phi::errors::Unimplemented(
"Datatype %s in NCCL is not supported.", type));
}
return nullptr;
}
std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::AllToAll(
std::vector<phi::DenseTensor>& in_tensors,
std::vector<phi::DenseTensor>& out_tensors) {
......
......@@ -28,5 +28,60 @@ inline phi::DenseTensor GetPartialTensor(const phi::DenseTensor& tensor,
return tensor_flattened.Slice(offset, offset + numel);
}
inline void* GetPointerByOffset(void* raw_pointer,
size_t offset,
phi::DataType type) {
if (type == phi::DataType::FLOAT32) {
return reinterpret_cast<void*>(reinterpret_cast<float*>(raw_pointer) +
offset);
} else if (type == phi::DataType::FLOAT64) {
return reinterpret_cast<void*>(reinterpret_cast<double*>(raw_pointer) +
offset);
} else if (type == phi::DataType::FLOAT16) {
return reinterpret_cast<void*>(reinterpret_cast<int16_t*>(raw_pointer) +
offset);
} else if (type == phi::DataType::INT32) {
return reinterpret_cast<void*>(reinterpret_cast<int32_t*>(raw_pointer) +
offset);
} else if (type == phi::DataType::INT64) {
return reinterpret_cast<void*>(reinterpret_cast<int64_t*>(raw_pointer) +
offset);
} else if (type == phi::DataType::INT8) {
return reinterpret_cast<void*>(reinterpret_cast<int8_t*>(raw_pointer) +
offset);
} else if (type == phi::DataType::UINT8) {
return reinterpret_cast<void*>(reinterpret_cast<uint8_t*>(raw_pointer) +
offset);
} else if (type == phi::DataType::BOOL) {
return reinterpret_cast<void*>(reinterpret_cast<bool*>(raw_pointer) +
offset);
} else if (type == phi::DataType::BFLOAT16) {
return reinterpret_cast<void*>(reinterpret_cast<uint16_t*>(raw_pointer) +
offset);
} else {
PADDLE_THROW(phi::errors::Unimplemented(
"Datatype %s in NCCL is not supported.", type));
}
return nullptr;
}
inline void CheckSizeOnEachRank(const phi::DDim& tensor_dim,
const std::vector<int64_t>& size_on_each_rank,
int world_size) {
int length_size_on_each_rank = size_on_each_rank.size();
PADDLE_ENFORCE_EQ(
length_size_on_each_rank,
world_size,
phi::errors::InvalidArgument(
"The length of size_on_each_rank must be equal to world_size."));
int64_t sum_size_on_each_rank =
std::accumulate(size_on_each_rank.begin(), size_on_each_rank.end(), 0);
PADDLE_ENFORCE_EQ(
sum_size_on_each_rank,
tensor_dim[0],
phi::errors::InvalidArgument(
"The sum of size_on_each_rank must be equal to tensor's dim[0]."));
}
} // namespace distributed
} // namespace paddle
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册