Unverified commit ad76d37e, authored by QingshuChen, committed by GitHub

fix bkcl_all_gather and c_embedding_grad bug for xpu (#51785)

Parent: 97701612
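Summary of the hunks below: (1) `ProcessGroupBKCL::AllGather` now honours its `offset`/`numel` arguments by gathering `in_tensor_maybe_partial` (built with `GetPartialTensor`) instead of always sending the whole input, and adds a `CommStaticCheck::GatherLikeShape` check; (2) the BKCL collective wrappers log a `VLOG(3)` trace and return the BKCL status code explicitly; (3) the XPU `c_embedding_grad` kernel allocates the gradient table through `phi::XPUContext`, zero-fills it with an element count instead of a byte count, and checks the return code of `embedding_grad`. As a rough, host-only illustration of the offset/numel semantics the fixed AllGather follows (plain C++, no BKCL or XPU calls; buffers and sizes are made up for the example):

```cpp
#include <cstdint>
#include <cstdio>
#include <vector>

int main() {
  const int world_size = 2;
  const int64_t offset = 2, numel = 3;  // gather only the slice [2, 5)
  std::vector<std::vector<float>> local = {{0, 1, 2, 3, 4, 5},
                                           {10, 11, 12, 13, 14, 15}};
  std::vector<float> gathered;
  gathered.reserve(world_size * numel);
  for (int rank = 0; rank < world_size; ++rank) {
    // The old XPU path ignored offset/numel and always copied the whole
    // local buffer; the fix below slices it with GetPartialTensor first.
    for (int64_t i = 0; i < numel; ++i) {
      gathered.push_back(local[rank][offset + i]);
    }
  }
  for (float v : gathered) std::printf("%g ", v);  // prints: 2 3 4 12 13 14
  std::printf("\n");
  return 0;
}
```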
@@ -21,6 +21,7 @@
 #include "paddle/fluid/platform/device/xpu/xpu_info.h"
 #include "paddle/phi/api/lib/utils/allocator.h"
 #include "paddle/phi/core/device_context.h"
+#include "paddle/phi/core/distributed/check/static_check.h"
 #include "paddle/phi/core/enforce.h"
 #include "paddle/phi/core/errors.h"
@@ -82,10 +83,12 @@ ProcessGroupBKCL::ProcessGroupBKCL(
     : ProcessGroupWithStream(rank, size, gid), store_(store) {}
 
 void ProcessGroupBKCL::GroupStart() {
+  VLOG(3) << "bkcl_group_start";
   PADDLE_ENFORCE_XPU_SUCCESS(bkcl_group_start());
 }
 
 void ProcessGroupBKCL::GroupEnd() {
+  VLOG(3) << "bkcl_group_end";
   PADDLE_ENFORCE_XPU_SUCCESS(bkcl_group_end());
 }
@@ -112,13 +115,15 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupBKCL::Recv(
           const phi::DenseTensor& input,
           BKCLContext_t comm,
           const XPUStream& stream) {
-        return bkcl_recv(comm,
+        VLOG(3) << "bkcl_recv";
+        int r = bkcl_recv(comm,
                           output->data(),
                           output->numel(),
                           src_rank,
                           platform::ToBKCLDataType(
                               framework::TransToProtoVarType(output->type())),
                           stream);
+        return r;
       },
       CommType::RECV,
       sync_op,
@@ -143,13 +148,15 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupBKCL::Send(
           const phi::DenseTensor& input,
           BKCLContext_t comm,
           const XPUStream& stream) {
-        return bkcl_send(comm,
+        VLOG(3) << "bkcl_send";
+        int r = bkcl_send(comm,
                           input.data(),
                           input.numel(),
                           dst_rank,
                           platform::ToBKCLDataType(
                               framework::TransToProtoVarType(input.type())),
                           stream);
+        return r;
       },
       CommType::SEND,
       sync_op,
@@ -269,8 +276,9 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupBKCL::AllReduce(
           const phi::DenseTensor& input,
           BKCLContext_t comm,
           const XPUStream& stream) {
-        return bkcl_all_reduce(
-            comm,
+        VLOG(3) << "bkcl_all_reduce";
+        int r =
+            bkcl_all_reduce(comm,
                             input.data(),
                             output->data(),
                             input.numel(),
@@ -278,6 +286,7 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupBKCL::AllReduce(
                             framework::TransToProtoVarType(input.type())),
                             ToBKCLRedType(opts.reduce_op),
                             stream);
+        return r;
       },
       CommType::ALLREDUCE,
       sync_op,
@@ -298,7 +307,9 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupBKCL::Broadcast(
           BKCLContext_t comm,
           const XPUStream& stream) {
         int root = opts.source_rank + opts.source_root;
-        return bkcl_broadcast(comm,
+        VLOG(3) << "bkcl_broadcast";
+        int r =
+            bkcl_broadcast(comm,
                            input.data(),
                            output->data(),
                            input.numel(),
@@ -306,6 +317,7 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupBKCL::Broadcast(
                            framework::TransToProtoVarType(input.type())),
                            root,
                            stream);
+        return r;
       },
       CommType::BROADCAST,
       sync_op,
@@ -315,10 +327,18 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupBKCL::Broadcast(
 std::shared_ptr<ProcessGroup::Task> ProcessGroupBKCL::AllGather(
     phi::DenseTensor* out_tensor,
     const phi::DenseTensor& in_tensor,
-    int64_t offset,  // for compatibility, no use now
-    int64_t numel,   // for compatibility, no use now
+    int64_t offset,
+    int64_t numel,
     bool sync_op,
     bool use_calc_stream) {
+  const phi::DenseTensor& in_tensor_maybe_partial =
+      numel > 0 ? GetPartialTensor(in_tensor, offset, numel) : in_tensor;
+  phi::distributed::CommStaticCheck::GatherLikeShape(*out_tensor,
+                                                     in_tensor_maybe_partial,
+                                                     /*dst_rank*/ rank_,
+                                                     /*cur_rank*/ rank_,
+                                                     size_,
+                                                     phi::AllocationType::XPU);
   return Collective(
       out_tensor,
       in_tensor,
@@ -326,14 +346,16 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupBKCL::AllGather(
           const phi::DenseTensor& input,
           BKCLContext_t comm,
           const XPUStream& stream) {
-        return bkcl_all_gather(
-            comm,
-            input.data(),
-            input.numel(),
+        VLOG(3) << "bkcl_all_gather";
+        int r =
+            bkcl_all_gather(comm,
+                            in_tensor_maybe_partial.data(),
+                            in_tensor_maybe_partial.numel(),
                             output->data(),
                             platform::ToBKCLDataType(
                                 framework::TransToProtoVarType(input.type())),
                             stream);
+        return r;
       },
       CommType::ALLGATHER,
       sync_op,
@@ -353,7 +375,8 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupBKCL::Reduce(
           const phi::DenseTensor& input,
           BKCLContext_t comm,
           const XPUStream& stream) {
-        return bkcl_reduce(comm,
+        VLOG(3) << "bkcl_reduce";
+        int r = bkcl_reduce(comm,
                            input.data(),
                            output->data(),
                            input.numel(),
@@ -362,6 +385,7 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupBKCL::Reduce(
                            ToBKCLRedType(opts.reduce_op),
                            opts.root_rank,
                            stream);
+        return r;
       },
       CommType::REDUCE,
       sync_op,
@@ -381,7 +405,8 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupBKCL::ReduceScatter(
           const phi::DenseTensor& input,
           BKCLContext_t comm,
           const XPUStream& stream) {
-        return bkcl_reduce_scatter(
+        VLOG(3) << "bkcl_reduce_scatter";
+        int r = bkcl_reduce_scatter(
             comm,
             input.data(),
             output->data(),
@@ -390,6 +415,7 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupBKCL::ReduceScatter(
                 framework::TransToProtoVarType(input.type())),
             ToBKCLRedType(opts.reduce_op),
             stream);
+        return r;
       },
       CommType::REDUCE_SCATTER,
       sync_op,
@@ -465,8 +491,10 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupBKCL::AllReduce(
           const phi::DenseTensor& input,
           BKCLContext_t comm,
           const XPUStream& stream) {
-        return bkcl_all_reduce(
-            comm,
+        VLOG(3) << "bkcl_all_reduce";
+        int r =
+            bkcl_all_reduce(comm,
                             input.data(),
                             output->data(),
                             input.numel(),
@@ -474,6 +502,7 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupBKCL::AllReduce(
                             framework::TransToProtoVarType(input.type())),
                             ToBKCLRedType(opts.reduce_op),
                             stream);
+        return r;
       },
       CommType::ALLREDUCE,
       /*sync_op*/ true,
@@ -506,8 +535,9 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupBKCL::AllReduce(
           const phi::DenseTensor& input,
           BKCLContext_t comm,
           const XPUStream& stream) {
-        return bkcl_all_reduce(
-            comm,
+        VLOG(3) << "bkcl_all_reduce";
+        int r =
+            bkcl_all_reduce(comm,
                             input.data(),
                             output->data(),
                             input.numel(),
@@ -515,6 +545,7 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupBKCL::AllReduce(
                             framework::TransToProtoVarType(input.type())),
                             ToBKCLRedType(opts.reduce_op),
                             stream);
+        return r;
       },
       CommType::ALLREDUCE,
       sync_op,
@@ -549,7 +580,9 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupBKCL::Broadcast(
           const XPUStream& stream) {
         const auto root =
             opts.source_rank * in_tensors.size() + opts.source_root;
-        return bkcl_broadcast(comm,
+        VLOG(3) << "bkcl_broadcast";
+        int r =
+            bkcl_broadcast(comm,
                            input.data(),
                            output->data(),
                            input.numel(),
@@ -557,6 +590,7 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupBKCL::Broadcast(
                            framework::TransToProtoVarType(input.type())),
                            root,
                            stream);
+        return r;
       },
       CommType::BROADCAST,
       /*sync_op*/ true,
@@ -592,7 +626,9 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupBKCL::Broadcast(
           const XPUStream& stream) {
         const auto root =
             opts.source_rank * in_tensors.size() + opts.source_root;
-        return bkcl_broadcast(comm,
+        VLOG(3) << "bkcl_broadcast";
+        int r =
+            bkcl_broadcast(comm,
                            input.data(),
                            output->data(),
                            input.numel(),
@@ -600,6 +636,7 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupBKCL::Broadcast(
                            framework::TransToProtoVarType(input.type())),
                            root,
                            stream);
+        return r;
       },
       CommType::BROADCAST,
       sync_op,
@@ -634,14 +671,16 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupBKCL::AllGather(
           const phi::DenseTensor& input,
           BKCLContext_t comm,
           const XPUStream& stream) {
-        return bkcl_all_gather(
-            comm,
+        VLOG(3) << "bkcl_all_gather";
+        int r =
+            bkcl_all_gather(comm,
                             input.data(),
                             input.numel(),
                             output->data(),
                             platform::ToBKCLDataType(
                                 framework::TransToProtoVarType(input.type())),
                             stream);
+        return r;
       },
       CommType::ALLGATHER,
       /*sync_op*/ true,
@@ -673,14 +712,16 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupBKCL::AllGather(
           const phi::DenseTensor& input,
           BKCLContext_t comm,
           const XPUStream& stream) {
-        return bkcl_all_gather(
-            comm,
+        VLOG(3) << "bkcl_all_gather";
+        int r =
+            bkcl_all_gather(comm,
                             input.data(),
                             input.numel(),
                             output->data(),
                             platform::ToBKCLDataType(
                                 framework::TransToProtoVarType(input.type())),
                             stream);
+        return r;
       },
       CommType::ALLGATHER,
       sync_op,
......
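Note on the `CommStaticCheck::GatherLikeShape` call added in the AllGather overload above: its implementation lives in `paddle/phi/core/distributed/check/static_check.h` and is not part of this diff, but the precondition it is expected to assert for a gather-like collective is roughly the one sketched below (illustrative helper with hypothetical names, not the Paddle API):

```cpp
#include <cstdint>
#include <stdexcept>

// Rough, stand-alone approximation of a gather-like shape check: the receive
// buffer must hold world_size copies of the (possibly partial) send buffer,
// the dtypes must match, and both tensors must live on the expected device.
struct TensorMeta {
  int64_t numel;
  int dtype;    // stand-in for a dtype enum
  bool on_xpu;  // stand-in for the allocation/place check
};

void CheckGatherLikeShape(const TensorMeta& out,
                          const TensorMeta& in,
                          int world_size) {
  if (!out.on_xpu || !in.on_xpu)
    throw std::invalid_argument("both tensors must be XPU tensors");
  if (out.dtype != in.dtype)
    throw std::invalid_argument("dtype mismatch between in/out tensors");
  if (out.numel != in.numel * static_cast<int64_t>(world_size))
    throw std::invalid_argument("out numel must equal world_size * in numel");
}
```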
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
    http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
@@ -83,8 +80,10 @@ class CEmbeddingGradOpXPUKernel : public framework::OpKernel<T> {
     auto table_grad_t =
         context.Output<phi::DenseTensor>(framework::GradVarName("W"));
-    T* table_grad_data =
-        table_grad_t->mutable_data<T>(table_t->dims(), context.GetPlace());
+    auto& dev_ctx = context.template device_context<phi::XPUContext>();
+    table_grad_t->Resize(table_t->dims());
+    dev_ctx.template Alloc(table_grad_t, table_t->dtype());
+    T* table_grad_data = static_cast<T*>(table_grad_t->data());
 
     size_t table_t_mem_size =
         table_t->numel() * phi::SizeOf(table_grad_t->dtype());
@@ -98,9 +97,8 @@ class CEmbeddingGradOpXPUKernel : public framework::OpKernel<T> {
             << ", table_grad_t memory_size:" << table_grad_t_mem_size
             << ", start_index:" << start_idx;
 
-    auto& dev_ctx = context.template device_context<DeviceContext>();
     int r = xpu::constant(
-        dev_ctx.x_context(), table_grad_data, table_grad_t_mem_size, (T)0);
+        dev_ctx.x_context(), table_grad_data, table_grad_t->numel(), (T)0);
     PADDLE_ENFORCE_XDNN_SUCCESS(r, "constant");
 
     const T* d_output_data = d_output_t->data<T>();
@@ -132,6 +130,7 @@ class CEmbeddingGradOpXPUKernel : public framework::OpKernel<T> {
       PADDLE_THROW(platform::errors::Unavailable(
           "XPU c_embedding ids only support int32 or int64."));
     }
+    PADDLE_ENFORCE_XDNN_SUCCESS(r, "embedding_grad");
   }
 };
......
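The decisive change in the `c_embedding_grad` hunks above is the third argument of `xpu::constant`: the old code passed `table_grad_t_mem_size`, a size in bytes, where an element count is expected, so the zero-fill covered the wrong number of elements (4x too many for float32). A minimal CPU-only illustration of the difference, with a hypothetical `fill_constant` standing in for `xpu::constant`:

```cpp
#include <cstddef>
#include <cstdio>
#include <vector>

// Hypothetical stand-in for xpu::constant: fills n_elements elements.
template <typename T>
void fill_constant(T* data, size_t n_elements, T value) {
  for (size_t i = 0; i < n_elements; ++i) data[i] = value;
}

int main() {
  std::vector<float> table_grad(8, 1.0f);
  const size_t numel = table_grad.size();         // 8 elements
  const size_t mem_size = numel * sizeof(float);  // 32 bytes
  // Buggy pattern (what the old kernel did): the byte count is treated as an
  // element count, writing 4x past the end of the buffer for float32.
  // fill_constant(table_grad.data(), mem_size, 0.0f);  // out of bounds!
  // Fixed pattern, matching the switch to table_grad_t->numel():
  fill_constant(table_grad.data(), numel, 0.0f);
  std::printf("zeroed %zu elements\n", numel);
  return 0;
}
```

The other change in that file, checking `r` after `embedding_grad` with `PADDLE_ENFORCE_XDNN_SUCCESS`, makes XDNN failures surface instead of being silently ignored.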