未验证 提交 ad76d37e 编写于 作者: Q QingshuChen 提交者: GitHub

fix bkcl_all_gather and c_embedding_grad bug for xpu (#51785)

上级 97701612
......@@ -21,6 +21,7 @@
#include "paddle/fluid/platform/device/xpu/xpu_info.h"
#include "paddle/phi/api/lib/utils/allocator.h"
#include "paddle/phi/core/device_context.h"
#include "paddle/phi/core/distributed/check/static_check.h"
#include "paddle/phi/core/enforce.h"
#include "paddle/phi/core/errors.h"
......@@ -82,10 +83,12 @@ ProcessGroupBKCL::ProcessGroupBKCL(
: ProcessGroupWithStream(rank, size, gid), store_(store) {}
void ProcessGroupBKCL::GroupStart() {
VLOG(3) << "bkcl_group_start";
PADDLE_ENFORCE_XPU_SUCCESS(bkcl_group_start());
}
void ProcessGroupBKCL::GroupEnd() {
VLOG(3) << "bkcl_group_end";
PADDLE_ENFORCE_XPU_SUCCESS(bkcl_group_end());
}
......@@ -112,13 +115,15 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupBKCL::Recv(
const phi::DenseTensor& input,
BKCLContext_t comm,
const XPUStream& stream) {
return bkcl_recv(comm,
VLOG(3) << "bkcl_recv";
int r = bkcl_recv(comm,
output->data(),
output->numel(),
src_rank,
platform::ToBKCLDataType(
framework::TransToProtoVarType(output->type())),
stream);
return r;
},
CommType::RECV,
sync_op,
......@@ -143,13 +148,15 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupBKCL::Send(
const phi::DenseTensor& input,
BKCLContext_t comm,
const XPUStream& stream) {
return bkcl_send(comm,
VLOG(3) << "bkcl_send";
int r = bkcl_send(comm,
input.data(),
input.numel(),
dst_rank,
platform::ToBKCLDataType(
framework::TransToProtoVarType(input.type())),
stream);
return r;
},
CommType::SEND,
sync_op,
......@@ -269,8 +276,9 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupBKCL::AllReduce(
const phi::DenseTensor& input,
BKCLContext_t comm,
const XPUStream& stream) {
return bkcl_all_reduce(
comm,
VLOG(3) << "bkcl_all_reduce";
int r =
bkcl_all_reduce(comm,
input.data(),
output->data(),
input.numel(),
......@@ -278,6 +286,7 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupBKCL::AllReduce(
framework::TransToProtoVarType(input.type())),
ToBKCLRedType(opts.reduce_op),
stream);
return r;
},
CommType::ALLREDUCE,
sync_op,
......@@ -298,7 +307,9 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupBKCL::Broadcast(
BKCLContext_t comm,
const XPUStream& stream) {
int root = opts.source_rank + opts.source_root;
return bkcl_broadcast(comm,
VLOG(3) << "bkcl_broadcast";
int r =
bkcl_broadcast(comm,
input.data(),
output->data(),
input.numel(),
......@@ -306,6 +317,7 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupBKCL::Broadcast(
framework::TransToProtoVarType(input.type())),
root,
stream);
return r;
},
CommType::BROADCAST,
sync_op,
......@@ -315,10 +327,18 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupBKCL::Broadcast(
std::shared_ptr<ProcessGroup::Task> ProcessGroupBKCL::AllGather(
phi::DenseTensor* out_tensor,
const phi::DenseTensor& in_tensor,
int64_t offset, // for compatibility, no use now
int64_t numel, // for compatibility, no use now
int64_t offset,
int64_t numel,
bool sync_op,
bool use_calc_stream) {
const phi::DenseTensor& in_tensor_maybe_partial =
numel > 0 ? GetPartialTensor(in_tensor, offset, numel) : in_tensor;
phi::distributed::CommStaticCheck::GatherLikeShape(*out_tensor,
in_tensor_maybe_partial,
/*dst_rank*/ rank_,
/*cur_rank*/ rank_,
size_,
phi::AllocationType::XPU);
return Collective(
out_tensor,
in_tensor,
......@@ -326,14 +346,16 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupBKCL::AllGather(
const phi::DenseTensor& input,
BKCLContext_t comm,
const XPUStream& stream) {
return bkcl_all_gather(
comm,
input.data(),
input.numel(),
VLOG(3) << "bkcl_all_gather";
int r =
bkcl_all_gather(comm,
in_tensor_maybe_partial.data(),
in_tensor_maybe_partial.numel(),
output->data(),
platform::ToBKCLDataType(
framework::TransToProtoVarType(input.type())),
stream);
return r;
},
CommType::ALLGATHER,
sync_op,
......@@ -353,7 +375,8 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupBKCL::Reduce(
const phi::DenseTensor& input,
BKCLContext_t comm,
const XPUStream& stream) {
return bkcl_reduce(comm,
VLOG(3) << "bkcl_reduce";
int r = bkcl_reduce(comm,
input.data(),
output->data(),
input.numel(),
......@@ -362,6 +385,7 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupBKCL::Reduce(
ToBKCLRedType(opts.reduce_op),
opts.root_rank,
stream);
return r;
},
CommType::REDUCE,
sync_op,
......@@ -381,7 +405,8 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupBKCL::ReduceScatter(
const phi::DenseTensor& input,
BKCLContext_t comm,
const XPUStream& stream) {
return bkcl_reduce_scatter(
VLOG(3) << "bkcl_reduce_scatter";
int r = bkcl_reduce_scatter(
comm,
input.data(),
output->data(),
......@@ -390,6 +415,7 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupBKCL::ReduceScatter(
framework::TransToProtoVarType(input.type())),
ToBKCLRedType(opts.reduce_op),
stream);
return r;
},
CommType::REDUCE_SCATTER,
sync_op,
......@@ -465,8 +491,10 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupBKCL::AllReduce(
const phi::DenseTensor& input,
BKCLContext_t comm,
const XPUStream& stream) {
return bkcl_all_reduce(
comm,
VLOG(3) << "bkcl_all_reduce";
int r =
bkcl_all_reduce(comm,
input.data(),
output->data(),
input.numel(),
......@@ -474,6 +502,7 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupBKCL::AllReduce(
framework::TransToProtoVarType(input.type())),
ToBKCLRedType(opts.reduce_op),
stream);
return r;
},
CommType::ALLREDUCE,
/*sync_op*/ true,
......@@ -506,8 +535,9 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupBKCL::AllReduce(
const phi::DenseTensor& input,
BKCLContext_t comm,
const XPUStream& stream) {
return bkcl_all_reduce(
comm,
VLOG(3) << "bkcl_all_reduce";
int r =
bkcl_all_reduce(comm,
input.data(),
output->data(),
input.numel(),
......@@ -515,6 +545,7 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupBKCL::AllReduce(
framework::TransToProtoVarType(input.type())),
ToBKCLRedType(opts.reduce_op),
stream);
return r;
},
CommType::ALLREDUCE,
sync_op,
......@@ -549,7 +580,9 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupBKCL::Broadcast(
const XPUStream& stream) {
const auto root =
opts.source_rank * in_tensors.size() + opts.source_root;
return bkcl_broadcast(comm,
VLOG(3) << "bkcl_broadcast";
int r =
bkcl_broadcast(comm,
input.data(),
output->data(),
input.numel(),
......@@ -557,6 +590,7 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupBKCL::Broadcast(
framework::TransToProtoVarType(input.type())),
root,
stream);
return r;
},
CommType::BROADCAST,
/*sync_op*/ true,
......@@ -592,7 +626,9 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupBKCL::Broadcast(
const XPUStream& stream) {
const auto root =
opts.source_rank * in_tensors.size() + opts.source_root;
return bkcl_broadcast(comm,
VLOG(3) << "bkcl_broadcast";
int r =
bkcl_broadcast(comm,
input.data(),
output->data(),
input.numel(),
......@@ -600,6 +636,7 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupBKCL::Broadcast(
framework::TransToProtoVarType(input.type())),
root,
stream);
return r;
},
CommType::BROADCAST,
sync_op,
......@@ -634,14 +671,16 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupBKCL::AllGather(
const phi::DenseTensor& input,
BKCLContext_t comm,
const XPUStream& stream) {
return bkcl_all_gather(
comm,
VLOG(3) << "bkcl_all_gather";
int r =
bkcl_all_gather(comm,
input.data(),
input.numel(),
output->data(),
platform::ToBKCLDataType(
framework::TransToProtoVarType(input.type())),
stream);
return r;
},
CommType::ALLGATHER,
/*sync_op*/ true,
......@@ -673,14 +712,16 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupBKCL::AllGather(
const phi::DenseTensor& input,
BKCLContext_t comm,
const XPUStream& stream) {
return bkcl_all_gather(
comm,
VLOG(3) << "bkcl_all_gather";
int r =
bkcl_all_gather(comm,
input.data(),
input.numel(),
output->data(),
platform::ToBKCLDataType(
framework::TransToProtoVarType(input.type())),
stream);
return r;
},
CommType::ALLGATHER,
sync_op,
......
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
......@@ -83,8 +80,10 @@ class CEmbeddingGradOpXPUKernel : public framework::OpKernel<T> {
auto table_grad_t =
context.Output<phi::DenseTensor>(framework::GradVarName("W"));
T* table_grad_data =
table_grad_t->mutable_data<T>(table_t->dims(), context.GetPlace());
auto& dev_ctx = context.template device_context<phi::XPUContext>();
table_grad_t->Resize(table_t->dims());
dev_ctx.template Alloc(table_grad_t, table_t->dtype());
T* table_grad_data = static_cast<T*>(table_grad_t->data());
size_t table_t_mem_size =
table_t->numel() * phi::SizeOf(table_grad_t->dtype());
......@@ -98,9 +97,8 @@ class CEmbeddingGradOpXPUKernel : public framework::OpKernel<T> {
<< ", table_grad_t memory_size:" << table_grad_t_mem_size
<< ", start_index:" << start_idx;
auto& dev_ctx = context.template device_context<DeviceContext>();
int r = xpu::constant(
dev_ctx.x_context(), table_grad_data, table_grad_t_mem_size, (T)0);
dev_ctx.x_context(), table_grad_data, table_grad_t->numel(), (T)0);
PADDLE_ENFORCE_XDNN_SUCCESS(r, "constant");
const T* d_output_data = d_output_t->data<T>();
......@@ -132,6 +130,7 @@ class CEmbeddingGradOpXPUKernel : public framework::OpKernel<T> {
PADDLE_THROW(platform::errors::Unavailable(
"XPU c_embedding ids only support int32 or int64."));
}
PADDLE_ENFORCE_XDNN_SUCCESS(r, "embedding_grad");
}
};
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册