Unverified commit 270f25e9, authored by zhangxiaoci, committed by GitHub

support KL2 multi-card training, *test=kunlun (#43889)

* update xccl lib
* use separate streams for compute/comm on XPU
* add broadcast op to xpu2_op_list
Parent 2c8c8419
@@ -24,6 +24,9 @@ else()
   set(XPU_XDNN_BASE_URL "${XPU_XDNN_BASE_URL}")
 endif()
+set(XPU_XCCL_BASE_URL
+    "https://klx-sdk-release-public.su.bcebos.com/xccl/release/1.0.0")
 if(WITH_AARCH64)
   set(XPU_XRE_DIR_NAME "xre-kylin_aarch64")
   set(XPU_XDNN_DIR_NAME "xdnn-kylin_aarch64")
@@ -76,7 +79,7 @@ set(XPU_XRE_URL
   "${XPU_BASE_URL}/${XPU_XRE_DIR_NAME}.tar.gz"
   CACHE STRING "" FORCE)
 set(XPU_XCCL_URL
-  "${XPU_BASE_URL_WITHOUT_DATE}/20220411/${XPU_XCCL_DIR_NAME}.tar.gz"
+  "${XPU_XCCL_BASE_URL}/${XPU_XCCL_DIR_NAME}.tar.gz"
   CACHE STRING "" FORCE)
 set(XPU_PACK_DEPENCE_URL
   "https://baidu-kunlun-public.su.bcebos.com/paddle_depence/pack_paddle_depence.sh"
......
@@ -110,6 +110,10 @@ void BKCLParallelContext::Init() {
                                                  strategy_.local_rank_,
                                                  xpu_id,
                                                  ring_id);
+    compute_events_.emplace_back(
+        platform::XpuEventResourcePool::Instance().New(place_.device));
+    comm_events_.emplace_back(
+        platform::XpuEventResourcePool::Instance().New(place_.device));
   }
 }
@@ -134,6 +138,11 @@ void BKCLParallelContext::InitWithRingID(int ring_id) {
   // it will assign bkcl_comm in XPUDeviceContext within ring_id
   platform::BKCLCommContext::Instance().CreateComm(
       &bkcl_ids[0], strategy_.nranks_, strategy_.local_rank_, xpu_id, ring_id);
+  compute_events_.emplace_back(
+      platform::XpuEventResourcePool::Instance().New(place_.device));
+  comm_events_.emplace_back(
+      platform::XpuEventResourcePool::Instance().New(place_.device));
 }
 
 void BKCLParallelContext::AllReduceByStream(const framework::Variable &src,
@@ -213,9 +222,18 @@ void BKCLParallelContext::WaitCompute(int ring_id) {
                         "but got ring id = %d, nrings = %d",
                         ring_id,
                         strategy_.nrings_));
-  auto compute_dev_ctx = static_cast<platform::XPUDeviceContext *>(
-      platform::DeviceContextPool::Instance().Get(place_));
-  compute_dev_ctx->Wait();
+  auto compute_stream = static_cast<platform::XPUDeviceContext *>(
+                            platform::DeviceContextPool::Instance().Get(place_))
+                            ->stream();
+  auto comm_stream = platform::BKCLCommContext::Instance()
+                         .Get(ring_id, place_)
+                         ->dev_context()
+                         ->stream();
+  auto event = compute_events_[ring_id].get();
+
+  // compute_stream-->event-->comm_stream
+  PADDLE_ENFORCE_XPU_SUCCESS(xpu_event_record(event, compute_stream));
+  PADDLE_ENFORCE_XPU_SUCCESS(xpu_stream_wait_event(comm_stream, event));
 }
 
 void BKCLParallelContext::WaitComm(int ring_id) {
@@ -230,9 +248,18 @@ void BKCLParallelContext::WaitComm(int ring_id) {
                         "but got ring id = %d, nrings = %d",
                         ring_id,
                         strategy_.nrings_));
-  auto comm_dev_ctx =
-      platform::BKCLCommContext::Instance().Get(ring_id, place_)->dev_context();
-  comm_dev_ctx->Wait();
+  auto comm_stream = platform::BKCLCommContext::Instance()
+                         .Get(ring_id, place_)
+                         ->dev_context()
+                         ->stream();
+  auto compute_stream = static_cast<platform::XPUDeviceContext *>(
+                            platform::DeviceContextPool::Instance().Get(place_))
+                            ->stream();
+  auto event = comm_events_[ring_id].get();
+
+  // comm_stream-->event-->compute_stream
+  PADDLE_ENFORCE_XPU_SUCCESS(xpu_event_record(event, comm_stream));
+  PADDLE_ENFORCE_XPU_SUCCESS(xpu_stream_wait_event(compute_stream, event));
 }
 
 void BKCLParallelContext::SynchronizeCompute() {
......
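The two hunks above turn WaitCompute() and WaitComm() from host-blocking waits into event-based ordering between the compute stream and the per-ring BKCL communication stream. A minimal sketch of that pattern follows, using only runtime calls that appear in this diff; the helper name, the header paths, and the XPUEvent handle type are assumptions for illustration, and in the real code the event comes from platform::XpuEventResourcePool rather than being created ad hoc.

#include "paddle/fluid/platform/device/xpu/enforce_xpu.h"  // PADDLE_ENFORCE_XPU_SUCCESS
#include "xpu/runtime.h"  // assumed header for XPUStream/XPUEvent and the xpu_* calls

// Record an event on the producer stream and make the consumer stream wait on
// it: everything queued on `producer` before the record must finish before
// work queued on `consumer` after this call starts, and the host never blocks.
void MakeStreamWait(XPUStream producer, XPUStream consumer, XPUEvent event) {
  PADDLE_ENFORCE_XPU_SUCCESS(xpu_event_record(event, producer));
  PADDLE_ENFORCE_XPU_SUCCESS(xpu_stream_wait_event(consumer, event));
}

// WaitCompute(ring_id): MakeStreamWait(compute_stream, comm_stream, compute_events_[ring_id].get())
// WaitComm(ring_id):    MakeStreamWait(comm_stream, compute_stream, comm_events_[ring_id].get())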
@@ -19,6 +19,7 @@
 #include <vector>
 
 #include "paddle/fluid/imperative/parallel_context.h"
+#include "paddle/fluid/platform/device/xpu/xpu_resource_pool.h"
 #include "xpu/bkcl.h"
 
 namespace paddle {
@@ -52,6 +53,13 @@ class BKCLParallelContext : public ParallelContext {
   void WaitComm(int ring_id) override;
 
   void SynchronizeCompute() override;
+
+ private:
+  // used for comm wait compute, compute_stream-->event-->comm_stream[ring_id]
+  std::vector<std::shared_ptr<platform::XpuEventObject>> compute_events_;
+
+  // used for compute wait comm, comm_stream[ring_id]-->event-->compute_stream
+  std::vector<std::shared_ptr<platform::XpuEventObject>> comm_events_;
 };
 
 }  // namespace imperative
......
@@ -21,6 +21,9 @@
 #include "paddle/fluid/imperative/parallel_context.h"
 #include "paddle/fluid/operators/math/concat_and_split.h"
 #include "paddle/fluid/operators/strided_memcpy.h"
+#ifdef PADDLE_WITH_XPU_BKCL
+#include "paddle/fluid/platform/device/xpu/enforce_xpu.h"
+#endif
 #include "paddle/fluid/string/string_helper.h"
 #include "paddle/phi/core/dense_tensor.h"
 
 namespace paddle {
@@ -431,10 +434,6 @@ Reducer::Reducer(const std::vector<std::shared_ptr<imperative::VarBase>> &vars,
   VLOG(3) << "Start construct the Reducer ...";
   nrings_ = parallel_ctx->GetNRings();
   nranks_ = parallel_ctx->GetNRanks();
-#ifdef PADDLE_WITH_XPU_BKCL
-  comm_pool_.reset(new ::ThreadPool(1));
-  comm_op_count_ = 0;
-#endif
   // initialize groups
   InitializeGroups(group_indices);
   for (size_t global_var_index = 0; global_var_index < vars_.size();
@@ -853,8 +852,23 @@ void Reducer::MarkVarReady(const size_t var_index, const bool is_used_var) {
 #ifdef PADDLE_WITH_XPU_BKCL
       if (platform::is_xpu_place(group_tensor.place())) {
-        // TODO(liuyuhui) support XPU set constant
-        VLOG(3) << "XPU doesn't support set_constant";
+        auto dev_ctx = static_cast<platform::XPUDeviceContext *>(
+            platform::DeviceContextPool::Instance().Get(place_));
+        if (HasGrad(var_index)) {
+          auto var_base = vars_[var_index]->GradVarBase();
+          auto tensor =
+              var_base->MutableVar()->GetMutable<framework::LoDTensor>();
+          group_tensor.ShareDataWith(*tensor).Resize(
+              {static_cast<int64_t>(length)});
+        } else {
+          group_tensor.Resize({static_cast<int64_t>(length)});
+          int r = xpu::constant(dev_ctx->x_context(),
+                                reinterpret_cast<float *>(group_tensor.data()),
+                                group_tensor.numel(),
+                                0.0f);
+          PADDLE_ENFORCE_XDNN_SUCCESS(r, "constant");
+          PADDLE_ENFORCE_XPU_SUCCESS(xpu_wait(dev_ctx->stream()));
+        }
       }
 #elif defined(PADDLE_WITH_CNCL)
       if (platform::is_mlu_place(group_tensor.place())) {
@@ -948,33 +962,7 @@ void Reducer::MarkGroupReady(size_t group_index) {
     // so we expose WaitCompute() interface and call
     // it here.
     parallel_ctx_->WaitCompute(run_order);
-#ifdef PADDLE_WITH_XPU_BKCL
-    {
-      std::lock_guard<std::mutex> lock(mutex_);
-      comm_op_count_ += 1;  // lock
-    }
-    // TODO(liuyuhui): Add try catch to deal with exception later,
-    // otherwise the main thread will continue to run when an exception is
-    // thrown in comm_pool_.
-    auto next_group = next_group_;
-    comm_pool_->enqueue([this, run_order, next_group, &group] {
-      auto dev_id = place_.device;
-      platform::SetXPUDeviceId(dev_id);
-      FusedAllReduceSchedule(run_order, group, next_group);
-      {
-        std::lock_guard<std::mutex> lock(mutex_);
-        comm_op_count_ -= 1;  // lock
-        cv_.notify_all();
-      }
-    });
-#elif defined(PADDLE_WITH_RCCL) || defined(PADDLE_WITH_NCCL) || \
-    defined(PADDLE_WITH_GLOO) || defined(PADDLE_WITH_ASCEND_CL) || \
-    defined(PADDLE_WITH_CNCL)
     FusedAllReduceSchedule(run_order, group, next_group_);
-#else
-    PADDLE_THROW(platform::errors::PreconditionNotMet(
-        "Not compiled with BKCL or NCCL or CNCL or GLOO."));
-#endif
   }
 }
@@ -997,17 +985,6 @@ void Reducer::FusedAllReduceSchedule(const int run_order,
   // group.dense_tensors ---> group.dense_contents_
   group.ConcatTensors(dev_context);
 
-// NOTE(liuyuhui): ConcatTensors use communication stream, but BKCL only support
-// default stream for communicating, so there exist some problems in
-// synchronization. And need to add a WaitComm there.
-// TODO(liuyuhui): If BKCL support non-blocking communication, it should be
-// fixed as multi gpus card training.
-#ifdef PADDLE_WITH_XPU_BKCL
-  if (platform::is_xpu_place(group.dense_tensors_[0].place())) {
-    parallel_ctx_->WaitComm(run_order);
-  }
-#endif
   group.DivNRanks(dev_context, nranks_);
 
   // Start allreduce
   parallel_ctx_->AllReduceByStream(
@@ -1135,12 +1112,6 @@ bool Reducer::HasGrad(size_t var_index) {
 void Reducer::FinalizeBackward() {
   groups_need_finalize_ = false;
   grad_need_hooks_ = false;
-#ifdef PADDLE_WITH_XPU_BKCL
-  {
-    std::unique_lock<std::mutex> lock(mutex_);
-    cv_.wait(lock, [&] { return comm_op_count_ == 0; });
-  }
-#endif
   // Must prevent compute_stream_ starting until all comm streams have finished
   for (int i = 0; i < nrings_; ++i) {
......
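The MarkVarReady hunk above also replaces the old "XPU doesn't support set_constant" stub: when a variable has a gradient, the group buffer shares the gradient's memory; otherwise the slice is zero-filled on the device. A condensed sketch of the zero-fill branch follows, assuming a float buffer and an XPUDeviceContext obtained from the DeviceContextPool as in the diff; the helper name is hypothetical.

// Zero-fill `numel` floats of an XPU buffer on dev_ctx's compute stream, as
// the no-grad branch added to Reducer::MarkVarReady does above.
void FillZeroOnXPU(platform::XPUDeviceContext *dev_ctx, float *data,
                   int64_t numel) {
  int r = xpu::constant(dev_ctx->x_context(), data, numel, 0.0f);
  PADDLE_ENFORCE_XDNN_SUCCESS(r, "constant");
  // Wait on the stream so the buffer is fully written before the group is
  // concatenated and handed to allreduce.
  PADDLE_ENFORCE_XPU_SUCCESS(xpu_wait(dev_ctx->stream()));
}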
@@ -347,6 +347,12 @@ BKCLComm* BKCLCommContext::AssignBKCLComm(
     BKCLContext_t comm, int nranks, int rank, int dev_id, int ring_id) {
   std::unique_ptr<XPUDeviceContext> dev_ctx(
       new XPUDeviceContext(XPUPlace(dev_id)));
+  // used in BKCL as comm_stream, for every dev_id there is
+  // a comm_stream at each ring. this stream is passed as input var
+  // when calling collective comm commands like bkcl_all_reduce
+  XPUStream comm_stream;
+  PADDLE_ENFORCE_XPU_SUCCESS(xpu_stream_create(&comm_stream));
+  dev_ctx->SetXPUStream(comm_stream);
 
   BKCLCommImpl* c = new BKCLCommImpl;
   c->set_ring_id(ring_id);
......
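The comment added above states the design: every (dev_id, ring) pair gets its own communication stream, created here and attached to the ring's XPUDeviceContext, so BKCL collectives no longer run on the default compute stream. A hedged sketch of that setup follows; only xpu_stream_create, SetXPUStream, and stream() are taken from this diff, while the commented bkcl_all_reduce argument list is an assumption based on the comment ("this stream is passed as input var when calling collective comm commands").

// Attach a dedicated communication stream to one ring's device context.
// Error handling mirrors the diff; stream destruction is omitted for brevity.
void AttachCommStream(platform::XPUDeviceContext *dev_ctx) {
  XPUStream comm_stream;
  PADDLE_ENFORCE_XPU_SUCCESS(xpu_stream_create(&comm_stream));
  dev_ctx->SetXPUStream(comm_stream);
  // A collective later launched through this context runs on comm_stream,
  // e.g. (assumed argument order):
  // bkcl_all_reduce(comm, send_buf, recv_buf, count, BKCL_FLOAT, BKCL_ADD,
  //                 dev_ctx->stream());
}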
@@ -60,6 +60,7 @@ XPUOpMap& get_kl2_ops() {
      XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
     {"bilinear_interp_v2_grad",
      XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
+    {"broadcast", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
     {"cast",
      XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace()),
                    pOpKernelType(vartype::FP16, XPUPlace()),
......
@@ -23,7 +23,9 @@ limitations under the License. */
 #include "paddle/phi/backends/cpu/cpu_context.h"
 #include "paddle/phi/backends/custom/custom_context.h"
 #include "paddle/phi/backends/gpu/gpu_context.h"
+#ifdef PADDLE_WITH_XPU
 #include "paddle/phi/backends/xpu/xpu_context.h"
+#endif
 
 #ifndef PADDLE_WITH_CUSTOM_KERNEL
 // TODO(wilber): DeviceContextPool nees include fluid file.
......
@@ -66,6 +66,8 @@ struct XPUContext::Impl {
   const Place& GetPlace() const { return place_; }
 
+  void SetStream(XPUStream stream) { context_->xpu_stream = stream; }
+
   xpu::Context* GetXContext() const {
     PD_CHECK(context_ != nullptr, "the xpu context is nullptr.");
     return context_;
@@ -115,6 +117,8 @@ XPUContext::~XPUContext() = default;
 
 const Place& XPUContext::GetPlace() const { return impl_->GetPlace(); }
 
+void XPUContext::SetXPUStream(XPUStream stream) { impl_->SetStream(stream); }
+
 backends::xpu::XPUVersion XPUContext::xpu_version() const {
   return impl_->xpu_version_;
 }
......
@@ -61,6 +61,8 @@ class XPUContext : public DeviceContext {
   void SetL3Cache(int l3_size = 14155776);
 
+  void SetXPUStream(XPUStream stream);
+
  private:
   struct Impl;
   std::unique_ptr<Impl> impl_;
......
@@ -18,7 +18,9 @@
 #include "paddle/phi/backends/custom/custom_context.h"
 #include "paddle/phi/backends/gpu/gpu_context.h"
 #include "paddle/phi/backends/onednn/onednn_context.h"
+#ifdef PADDLE_WITH_XPU
 #include "paddle/phi/backends/xpu/xpu_context.h"
+#endif
 #include "paddle/phi/common/int_array.h"
 #include "paddle/phi/common/scalar.h"
 #include "paddle/phi/core/dense_tensor.h"
......