From b1026f64af97e2a1ed42afbf77aedce17ade2dad Mon Sep 17 00:00:00 2001 From: WangXi Date: Wed, 3 Feb 2021 10:45:47 +0800 Subject: [PATCH] =?UTF-8?q?=E3=80=90kunlun=E3=80=91dygraph=20supports=20mu?= =?UTF-8?q?lti=20xpu=20card=20training=20(#30671)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- paddle/fluid/imperative/CMakeLists.txt | 4 + paddle/fluid/imperative/bkcl_context.cc | 172 ++++++++++++++++++ paddle/fluid/imperative/bkcl_context.h | 53 ++++++ paddle/fluid/imperative/reducer.cc | 87 ++++++++- paddle/fluid/imperative/reducer.h | 2 +- paddle/fluid/imperative/tests/CMakeLists.txt | 5 +- .../imperative/tests/bkcl_context_test.cc | 66 +++++++ paddle/fluid/imperative/tests/test_group.cc | 21 ++- .../fluid/operators/collective/CMakeLists.txt | 4 + .../operators/collective/broadcast_op_xpu.cc | 96 ++++++++++ .../fluid/operators/math/concat_and_split.cc | 98 ++++++++++ paddle/fluid/operators/math/math_function.h | 21 +-- .../fluid/operators/math/math_function_impl.h | 5 +- paddle/fluid/platform/collective_helper.cc | 132 +++++++++++++- paddle/fluid/platform/collective_helper.h | 100 +++++++++- paddle/fluid/platform/device_context.cc | 3 + paddle/fluid/platform/gen_comm_id_helper.cc | 4 +- paddle/fluid/platform/gen_comm_id_helper.h | 2 +- paddle/fluid/platform/xpu_info.h | 23 +++ paddle/fluid/pybind/CMakeLists.txt | 10 + paddle/fluid/pybind/imperative.cc | 27 ++- paddle/fluid/pybind/tensor_py.h | 38 +++- python/paddle/distributed/fleet/launch.py | 49 +++-- .../paddle/distributed/fleet/launch_utils.py | 73 +++++++- python/paddle/distributed/parallel.py | 28 ++- python/paddle/fluid/dygraph/parallel.py | 15 +- .../fluid/tests/unittests/detected_xpu.py | 25 +++ .../fluid/tests/unittests/nproc_process.py | 10 +- .../fluid/tests/unittests/test_dist_base.py | 77 ++++++-- .../unittests/test_dist_mnist_fleet_save.py | 4 +- .../unittests/test_dist_sharding_save.py | 15 +- .../unittests/test_fleet_launch_nproc.sh | 57 +++++- .../unittests/test_parallel_dygraph_mnist.py | 19 ++ 33 files changed, 1225 insertions(+), 120 deletions(-) create mode 100644 paddle/fluid/imperative/bkcl_context.cc create mode 100644 paddle/fluid/imperative/bkcl_context.h create mode 100644 paddle/fluid/imperative/tests/bkcl_context_test.cc create mode 100644 paddle/fluid/operators/collective/broadcast_op_xpu.cc create mode 100644 python/paddle/fluid/tests/unittests/detected_xpu.py diff --git a/paddle/fluid/imperative/CMakeLists.txt b/paddle/fluid/imperative/CMakeLists.txt index 2da8169ebd9..7275a176b80 100644 --- a/paddle/fluid/imperative/CMakeLists.txt +++ b/paddle/fluid/imperative/CMakeLists.txt @@ -14,6 +14,10 @@ if(NOT WIN32) cc_library(nccl_context SRCS nccl_context.cc DEPS collective_helper device_context imperative_all_reduce var_type_traits) cc_library(reducer SRCS reducer.cc DEPS layer imperative_all_reduce) endif() + if(WITH_XPU_BKCL) + cc_library(bkcl_context SRCS bkcl_context.cc DEPS collective_helper device_context tensor var_type_traits) + cc_library(reducer SRCS reducer.cc DEPS layer) + endif() cc_library(data_loader SRCS data_loader.cc DEPS enforce) endif(NOT WIN32) diff --git a/paddle/fluid/imperative/bkcl_context.cc b/paddle/fluid/imperative/bkcl_context.cc new file mode 100644 index 00000000000..873068a0d31 --- /dev/null +++ b/paddle/fluid/imperative/bkcl_context.cc @@ -0,0 +1,172 @@ +// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. 
+// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#if defined(PADDLE_WITH_XPU_BKCL) +#include "paddle/fluid/imperative/bkcl_context.h" + +#include +#include +#include + +#include "paddle/fluid/platform/bkcl_helper.h" +#include "paddle/fluid/platform/collective_helper.h" +#include "paddle/fluid/platform/gen_comm_id_helper.h" + +#include "paddle/fluid/framework/variable.h" +#include "paddle/fluid/platform/device_context.h" +#include "paddle/fluid/platform/place.h" +#include "paddle/fluid/string/split.h" +#include "paddle/fluid/string/string_helper.h" + +namespace paddle { +namespace imperative { + +static void AllReduce(const framework::Tensor &src, framework::Tensor *dst, + const XPUStream stream, const platform::BKCLComm *comm) { + const auto &place = src.place(); + PADDLE_ENFORCE_EQ( + platform::is_xpu_place(place), true, + platform::errors::Unimplemented( + "Dynamic graph mode does not support multi-CPU training yet.")); + + const void *src_ptr = src.data(); + dst->Resize(src.dims()); + auto *dst_ptr = dst->mutable_data(src.place(), src.type()); + auto bkcl_dtype = platform::ToBKCLDataType(src.type()); + + PADDLE_ENFORCE_EQ(bkcl_all_reduce(comm->comm(), src_ptr, dst_ptr, src.numel(), + bkcl_dtype, BKCL_ADD, stream), + BKCL_SUCCESS, platform::errors::PreconditionNotMet( + "BKCL all reduce failed")); +} +/* +Baidu Kunlun Communication Library(BKCL) is designed for multi Baidu Kunlun +cards communication +as NVIDIA Collective Communications Library(NCCL) in multi Nvidia GPU cards. +Please refer to bkcl.h in xpu.tar.gz linked in cmake/external/xpu.cmake. 
+*/ +void BKCLParallelContext::BcastBKCLId( + std::vector &bkcl_ids, // NOLINT + int root) { + if (strategy_.local_rank_ == root) { + std::vector other_trainers; + for (auto &ep : strategy_.trainer_endpoints_) { + if (ep != strategy_.current_endpoint_) { + other_trainers.push_back(ep); + } + } + platform::SendBroadCastCommID(other_trainers, &bkcl_ids); + } else { + platform::RecvBroadCastCommID(strategy_.current_endpoint_, &bkcl_ids); + } +} + +void BKCLParallelContext::Init() { + std::vector bkcl_ids; + bkcl_ids.resize(strategy_.nrings_); + + if (strategy_.local_rank_ == 0) { + // generate the unique ncclid on the root worker + for (size_t i = 0; i < bkcl_ids.size(); ++i) { + auto ret = bkcl_get_unique_id(&bkcl_ids[i]); + PADDLE_ENFORCE_EQ(BKCL_SUCCESS, ret, + platform::errors::PreconditionNotMet( + "BKCL get unique id failed [%d]", ret)); + } + } + BcastBKCLId(bkcl_ids, 0); + + int xpu_id = BOOST_GET_CONST(platform::XPUPlace, place_).device; + for (int ring_id = 0; ring_id < strategy_.nrings_; ring_id++) { + VLOG(0) << "init BKCL context nranks: " << strategy_.nranks_ + << " local rank: " << strategy_.local_rank_ << " xpu id: " << xpu_id + << " ring id: " << ring_id; + // it will assign bkcl_comm in XPUDeviceContext within ring_id + platform::BKCLCommContext::Instance().CreateBKCLComm( + &bkcl_ids[ring_id], strategy_.nranks_, strategy_.local_rank_, xpu_id, + ring_id); + } +} + +void BKCLParallelContext::AllReduceByStream(const framework::Variable &src, + framework::Variable *dst, + int ring_id, bool use_calc_stream) { + PADDLE_ENFORCE_EQ( + platform::is_xpu_place(place_), true, + platform::errors::Unimplemented( + "Dynamic graph mode does not support multi-CPU training yet.")); + auto place = place_; + + auto *dev_ctx = static_cast( + platform::DeviceContextPool::Instance().Get(place)); + platform::BKCLComm *comm = + platform::BKCLCommContext::Instance().Get(ring_id, place); + XPUStream stream = + use_calc_stream ? dev_ctx->x_context()->xpu_stream : comm->stream(); + + if (src.IsType()) { + if (!dst->IsType()) { + dst->Clear(); + } + AllReduce(src.Get(), + dst->GetMutable(), stream, comm); + } else { + PADDLE_THROW(platform::errors::InvalidArgument( + "XPU unsupported variable type %s for imperative allreduce, only " + "LoDTensor are supported.", + platform::demangle(framework::ToTypeName(src.Type())))); + } +} + +paddle::platform::DeviceContext *BKCLParallelContext::GetDeviceContext( + int ring_id) { + return static_cast( + platform::BKCLCommContext::Instance() + .Get(ring_id, place_) + ->dev_context()); +} + +void BKCLParallelContext::WaitCompute(int ring_id) { + PADDLE_ENFORCE_GE(ring_id, 0, + platform::errors::OutOfRange( + "Ring id expected >= 0, but got %d", ring_id)); + PADDLE_ENFORCE_LT( + ring_id, strategy_.nrings_, + platform::errors::OutOfRange("Ring id expected < nrings," + "but got ring id = %d, nrings = %d", + ring_id, strategy_.nrings_)); + // TODO(wangxi16): [Performance optimize] Maybe need to put Wait and + // bkcl_allreduce to comm thread, for bkcl_allreduce is blocking now. 
+ auto compute_dev_ctx = static_cast( + platform::DeviceContextPool::Instance().Get(place_)); + compute_dev_ctx->Wait(); +} + +void BKCLParallelContext::WaitComm(int ring_id) { + PADDLE_ENFORCE_GE(ring_id, 0, + platform::errors::OutOfRange( + "Ring id expected >= 0, but got %d", ring_id)); + PADDLE_ENFORCE_LT( + ring_id, strategy_.nrings_, + platform::errors::OutOfRange("Ring id expected < nrings," + "but got ring id = %d, nrings = %d", + ring_id, strategy_.nrings_)); + auto comm_dev_ctx = + platform::BKCLCommContext::Instance().Get(ring_id, place_)->dev_context(); + comm_dev_ctx->Wait(); +} + +} // namespace imperative +} // namespace paddle +#endif diff --git a/paddle/fluid/imperative/bkcl_context.h b/paddle/fluid/imperative/bkcl_context.h new file mode 100644 index 00000000000..d7d917f2008 --- /dev/null +++ b/paddle/fluid/imperative/bkcl_context.h @@ -0,0 +1,53 @@ +// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +#pragma once + +#if defined(PADDLE_WITH_XPU_BKCL) +#include +#include +#include + +#include "paddle/fluid/imperative/parallel_context.h" +#include "xpu/bkcl.h" + +namespace paddle { +namespace imperative { + +class BKCLParallelContext : public ParallelContext { + public: + explicit BKCLParallelContext(const ParallelStrategy& strategy, + const platform::Place& place) + : ParallelContext(strategy, place) {} + + ~BKCLParallelContext() override = default; + + void BcastBKCLId(std::vector& bkcl_ids, int root); // NOLINT + + void Init() override; + + void AllReduceByStream(const framework::Variable& src, + framework::Variable* dst, int ring_id, + bool use_calc_stream) override; + + paddle::platform::DeviceContext* GetDeviceContext(int ring_id) override; + + void WaitCompute(int ring_id) override; + + void WaitComm(int ring_id) override; +}; + +} // namespace imperative +} // namespace paddle + +#endif diff --git a/paddle/fluid/imperative/reducer.cc b/paddle/fluid/imperative/reducer.cc index 0c33cdd7c85..83013d9e796 100644 --- a/paddle/fluid/imperative/reducer.cc +++ b/paddle/fluid/imperative/reducer.cc @@ -30,17 +30,15 @@ #include "paddle/fluid/memory/memory.h" #include "paddle/fluid/string/string_helper.h" -#if defined(PADDLE_WITH_NCCL) #include "paddle/fluid/operators/math/concat_and_split.h" #include "paddle/fluid/operators/strided_memcpy.h" -#endif #include "paddle/fluid/imperative/parallel_context.h" namespace paddle { namespace imperative { -#if defined(PADDLE_WITH_NCCL) +#if (defined PADDLE_WITH_NCCL) || (defined PADDLE_WITH_XPU_BKCL) template static void ConcatTensorsForAllReduce( const DeviceContext &context, @@ -130,6 +128,69 @@ static void SplitTensorsWithType( } } +#ifdef PADDLE_WITH_XPU_BKCL +template <> +void SplitTensorsForAllReduce( + const platform::XPUDeviceContext &context, + framework::Variable *p_dense_contents, + std::vector *p_dense_tensors) { + auto *in = p_dense_contents->GetMutable(); + std::vector outs; + std::vector shape_refer; + + outs.reserve(p_dense_tensors->size()); + 
shape_refer.reserve(p_dense_tensors->size()); + + for (auto &tensor : *p_dense_tensors) { + outs.emplace_back(&tensor); + shape_refer.emplace_back(&tensor); + } + operators::math::SplitFunctor + split_functor_; + split_functor_(context, *in, shape_refer, 0, &outs); +} + +// context is used to select the stream for concat +template <> +void ConcatTensorsWithType( + const platform::XPUDeviceContext &context, + const std::vector &dense_tensors_, + framework::Variable *p_dense_contents, + framework::proto::VarType::Type type) { + switch (type) { + case framework::proto::VarType::FP32: + ConcatTensorsForAllReduce( + context, dense_tensors_, p_dense_contents); + break; + default: + PADDLE_THROW(platform::errors::Unimplemented( + "Data type (%s) is not supported when it concats tensors for " + "allreduce.", + framework::DataTypeToString(type))); + } +} + +// context is used to select the stream for split +template <> +void SplitTensorsWithType( + const platform::XPUDeviceContext &context, + framework::Variable *p_dense_contents, + std::vector *p_dense_tensors, + framework::proto::VarType::Type type) { + switch (type) { + case framework::proto::VarType::FP32: + SplitTensorsForAllReduce( + context, p_dense_contents, p_dense_tensors); + break; + default: + PADDLE_THROW(platform::errors::Unimplemented( + "Data type (%s) is not supported when it splits tensors for " + "allreduce.", + framework::DataTypeToString(type))); + } +} +#endif + void Group::ConcatTensors(const platform::DeviceContext &context) { VLOG(3) << "Before concat, set output tensor size is " << all_length_; auto tensor = dense_contents_.GetMutable(); @@ -146,6 +207,16 @@ void Group::ConcatTensors(const platform::DeviceContext &context) { PADDLE_THROW(platform::errors::PermissionDenied( "Paddle can't concat grad tensors since it's not compiled with NCCL," "Please recompile or reinstall Paddle with NCCL support.")); +#endif + } else if (platform::is_xpu_place(place)) { +#ifdef PADDLE_WITH_XPU_BKCL + ConcatTensorsWithType( + static_cast(context), + dense_tensors_, &dense_contents_, dtype_); +#else + PADDLE_THROW(platform::errors::PermissionDenied( + "Paddle can't concat xpu grads since it's not compiled with BKCL," + "Please recompile or reinstall Paddle with BKCL support.")); #endif } else if (platform::is_cpu_place(place)) { ConcatTensorsWithType( @@ -168,6 +239,16 @@ void Group::SplitTensors(const platform::DeviceContext &context) { PADDLE_THROW(platform::errors::PermissionDenied( "Paddle can't split grad tensor since it's not compiled with NCCL," "Please recompile or reinstall Paddle with NCCL support.")); +#endif + } else if (platform::is_xpu_place(place)) { +#ifdef PADDLE_WITH_XPU_BKCL + SplitTensorsWithType( + static_cast(context), + &dense_contents_, &dense_tensors_, dtype_); +#else + PADDLE_THROW(platform::errors::PermissionDenied( + "Paddle can't split xpu grad since it's not compiled with BKCL," + "Please recompile or reinstall Paddle with BKCL support.")); #endif } else if (platform::is_cpu_place(place)) { SplitTensorsWithType( diff --git a/paddle/fluid/imperative/reducer.h b/paddle/fluid/imperative/reducer.h index 90c4cdb3c6a..0d5d93b5900 100644 --- a/paddle/fluid/imperative/reducer.h +++ b/paddle/fluid/imperative/reducer.h @@ -44,7 +44,7 @@ class VariableWrapper; namespace paddle { namespace imperative { -#if defined(PADDLE_WITH_NCCL) +#if (defined PADDLE_WITH_NCCL) || (defined PADDLE_WITH_XPU_BKCL) class Group { public: // Here, we use dense_contents_ & sparse_contents_ to diff --git 
a/paddle/fluid/imperative/tests/CMakeLists.txt b/paddle/fluid/imperative/tests/CMakeLists.txt index b236ece541e..353c137fbf9 100644 --- a/paddle/fluid/imperative/tests/CMakeLists.txt +++ b/paddle/fluid/imperative/tests/CMakeLists.txt @@ -4,6 +4,9 @@ else() if (WITH_NCCL) cc_test(nccl_context_test SRCS nccl_context_test.cc DEPS nccl_context) endif() + if (WITH_XPU_BKCL) + cc_test(bkcl_context_test SRCS bkcl_context_test.cc DEPS bkcl_context) + endif() endif(WIN32) @@ -13,6 +16,6 @@ cc_test(test_prepare_op SRCS test_prepare_op.cc DEPS prepared_operator op_info s cc_test(test_tracer SRCS test_tracer.cc DEPS tracer layer proto_desc operator op_registry variable_helper mul_op reduce_sum_op elementwise_add_op memcpy) cc_test(test_hooks SRCS test_hooks.cc DEPS tracer basic_engine layer proto_desc operator op_registry variable_helper mul_op elementwise_add_op memcpy) -if (WITH_NCCL) +if (WITH_NCCL OR WITH_XPU_BKCL) cc_test(test_group SRCS test_group.cc DEPS reducer concat_and_split memcpy) endif() diff --git a/paddle/fluid/imperative/tests/bkcl_context_test.cc b/paddle/fluid/imperative/tests/bkcl_context_test.cc new file mode 100644 index 00000000000..580d86b1696 --- /dev/null +++ b/paddle/fluid/imperative/tests/bkcl_context_test.cc @@ -0,0 +1,66 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include // NOLINT + +#include "paddle/fluid/imperative/bkcl_context.h" + +#include "gtest/gtest.h" + +namespace imperative = paddle::imperative; +namespace platform = paddle::platform; + +int nrings = 2; +imperative::ParallelStrategy GetStrategy(int local_rank) { + std::vector eps = {"127.0.0.1:9866", "localhost:9867"}; + imperative::ParallelStrategy strategy; + strategy.trainer_endpoints_ = eps; + strategy.current_endpoint_ = eps[local_rank]; + strategy.nranks_ = 2; + strategy.local_rank_ = local_rank; + strategy.nrings_ = nrings; + return strategy; +} + +#if defined(PADDLE_WITH_XPU_BKCL) +void BcastBKCLId(int local_rank, std::vector* bkcl_ids) { + auto strategy = GetStrategy(local_rank); + platform::XPUPlace xpu(local_rank); + imperative::BKCLParallelContext ctx(strategy, xpu); + ctx.BcastBKCLId(*bkcl_ids, 0); +} + +TEST(BcastBKCLId, Run) { + std::vector bkcl_ids; + bkcl_ids.resize(nrings); + for (int i = 0; i < nrings; ++i) { + bkcl_get_unique_id(&bkcl_ids[i]); + } + + std::thread t(BcastBKCLId, 0, &bkcl_ids); + + std::vector recv_bkcl_ids; + recv_bkcl_ids.resize(nrings); + for (int i = 0; i < nrings; ++i) { + bkcl_get_unique_id(&recv_bkcl_ids[i]); + } + BcastBKCLId(1, &recv_bkcl_ids); + + t.join(); + for (int i = 0; i < nrings; ++i) { + EXPECT_EQ( + 0, std::memcmp(&bkcl_ids[i], &recv_bkcl_ids[i], BKCL_UNIQUE_ID_BYTES)); + } +} +#endif diff --git a/paddle/fluid/imperative/tests/test_group.cc b/paddle/fluid/imperative/tests/test_group.cc index 146ed9396b9..00c3814f913 100644 --- a/paddle/fluid/imperative/tests/test_group.cc +++ b/paddle/fluid/imperative/tests/test_group.cc @@ -20,14 +20,11 @@ #include "glog/logging.h" #include "gtest/gtest.h" -#if defined(PADDLE_WITH_NCCL) #include "paddle/fluid/imperative/reducer.h" -#endif namespace paddle { namespace imperative { -#if defined(PADDLE_WITH_NCCL) TEST(TestGroup, TestPrintGroupMessage) { Group group; std::stringstream stream1, stream2; @@ -80,8 +77,10 @@ void GroupConcatSplit(Place place, size_t size) { } if (std::is_same::value) { +#if defined(PADDLE_WITH_NCCL) paddle::memory::Copy(place, data, cpu_place, value.data(), sizeof(T) * value.size(), 0); +#endif } else { paddle::memory::Copy(place, data, cpu_place, value.data(), sizeof(T) * value.size()); @@ -134,6 +133,7 @@ void GroupConcatSplit(Place place, size_t size) { } } +#if defined(PADDLE_WITH_NCCL) TEST(TestGroup, TestConcatSplit) { platform::CUDAPlace cuda_place(0); platform::CPUPlace cpu_place; @@ -165,5 +165,20 @@ TEST(TestGroup, TestConcatSplitException) { } #endif +#if defined(PADDLE_WITH_XPU_BKCL) +TEST(TestGroup, TestXPUConcatSplit) { + platform::XPUPlace xpu_place(0); + platform::CPUPlace cpu_place; + + int size = 3; + GroupConcatSplit(cpu_place, size); + GroupConcatSplit(xpu_place, size); + + size = 15; + GroupConcatSplit(cpu_place, size); + GroupConcatSplit(xpu_place, size); +} +#endif + } // namespace imperative } // namespace paddle diff --git a/paddle/fluid/operators/collective/CMakeLists.txt b/paddle/fluid/operators/collective/CMakeLists.txt index 2b3c80839f2..2e9d1909a65 100644 --- a/paddle/fluid/operators/collective/CMakeLists.txt +++ b/paddle/fluid/operators/collective/CMakeLists.txt @@ -19,6 +19,10 @@ if(WITH_NCCL) op_library(gen_nccl_id_op DEPS ${COLLECTIVE_DEPS}) endif() +if(WITH_BKCL) + set(COLLECTIVE_DEPS ${COLLECTIVE_DEPS} collective_helper) +endif() + if(WITH_GLOO) set(COLLECTIVE_DEPS ${COLLECTIVE_DEPS} gloo_wrapper) endif() diff --git a/paddle/fluid/operators/collective/broadcast_op_xpu.cc b/paddle/fluid/operators/collective/broadcast_op_xpu.cc new file 
mode 100644 index 00000000000..2bfd77b8c2a --- /dev/null +++ b/paddle/fluid/operators/collective/broadcast_op_xpu.cc @@ -0,0 +1,96 @@ +/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include +#include +#include + +#include "paddle/fluid/framework/data_type.h" +#include "paddle/fluid/framework/lod_tensor.h" +#include "paddle/fluid/framework/op_registry.h" + +#if defined(PADDLE_WITH_XPU_BKCL) +#include "paddle/fluid/platform/bkcl_helper.h" +#include "paddle/fluid/platform/collective_helper.h" +#endif + +namespace ops = paddle::operators; +namespace plat = paddle::platform; + +namespace paddle { +namespace operators { + +template +class BKCLBroadcastOpKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + PADDLE_ENFORCE_EQ(platform::is_xpu_place(ctx.GetPlace()), true, + platform::errors::PreconditionNotMet( + "The place of ExecutionContext should be XPUPlace.")); + +#if defined(PADDLE_WITH_XPU_BKCL) + int dev_id = BOOST_GET_CONST(platform::XPUPlace, ctx.GetPlace()).device; + int root_dev_id = ctx.Attr("root"); + + auto in = ctx.Input("X"); + auto out = ctx.Output("Out"); + PADDLE_ENFORCE_EQ( + out->IsInitialized(), true, + platform::errors::PreconditionNotMet( + "Currently, the output of broadcast op must be initialized," + "because this op can only be an In-Place operation.")); + void* send_recv_buffer = out->mutable_data(ctx.GetPlace()); + PADDLE_ENFORCE_EQ( + send_recv_buffer, in->data(), + platform::errors::PreconditionNotMet("Currently, the broadcast op can " + "only be an In-Place operation.")); + + auto& dev_ctx = ctx.template device_context(); + auto comm = dev_ctx.bkcl_context(); + auto stream = dev_ctx.x_context()->xpu_stream; + + // TODO(wangxi16): bkcl_broadcast only support float type, + // need to converted other type to float before broadcasting. + // Broadcast is equivalent to no type of operation, does not affect + // correctness. 
+ // Once bkcl_broadcast support other type, need chang to: + // BKCLDataType data_type = platform::ToBKCLDataType(in->type()); + BKCLDataType data_type = BKCL_FLOAT; + size_t scale = sizeof(T) / sizeof(float); + auto ret = bkcl_broadcast(comm, send_recv_buffer, send_recv_buffer, + static_cast(in->numel()) * scale, + data_type, root_dev_id, stream); + PADDLE_ENFORCE_EQ(ret, BKCL_SUCCESS, + platform::errors::Unavailable("bkcl_broadcast failed")); + + VLOG(3) << "Bcast " << ctx.InputNames("X")[0] << ", (" << in->numel() << ")" + << " From " << root_dev_id << " to " << dev_id; + + if (ctx.Attr("sync_mode")) { + dev_ctx.Wait(); + } +#else + PADDLE_THROW(platform::errors::PreconditionNotMet( + "PaddlePaddle should compile with XPU.")); +#endif + } +}; + +} // namespace operators +} // namespace paddle + +REGISTER_OP_XPU_KERNEL(broadcast, ops::BKCLBroadcastOpKernel, + ops::BKCLBroadcastOpKernel, + ops::BKCLBroadcastOpKernel, + ops::BKCLBroadcastOpKernel); diff --git a/paddle/fluid/operators/math/concat_and_split.cc b/paddle/fluid/operators/math/concat_and_split.cc index 3b0c3c1686a..7df78b321de 100644 --- a/paddle/fluid/operators/math/concat_and_split.cc +++ b/paddle/fluid/operators/math/concat_and_split.cc @@ -119,12 +119,110 @@ class SplitFunctor { } } }; + +#ifdef PADDLE_WITH_XPU +/* + * All tensors' dimension should be the same and the values of + * each dimension must be the same, except the axis dimension. + */ +template +class ConcatFunctor { + public: + void operator()(const platform::XPUDeviceContext& context, + const std::vector& input, int axis, + framework::Tensor* output) { + int dev_id = + BOOST_GET_CONST(platform::XPUPlace, context.GetPlace()).GetDeviceId(); + platform::XPUDeviceGuard guard(dev_id); + + int num = input.size(); + auto input_dims = input[0].dims(); + + std::vector> xdims_list(num); + for (int i = 0; i < num; ++i) { + std::vector tmp_dims(input_dims.size()); + for (int j = 0; j < input_dims.size(); ++j) { + tmp_dims[j] = input[i].dims()[j]; + } + xdims_list[i] = tmp_dims; + } + + std::vector ptrs; + for (int i = 0; i < num; ++i) { + ptrs.push_back(input[i].data()); + } + + auto r = xpu::concat(context.x_context(), ptrs, output->data(), + xdims_list, axis); + PADDLE_ENFORCE_EQ( + r, XPU_SUCCESS, + platform::errors::External( + "XPU API return wrong value[%d %s], please check whether " + "Baidu Kunlun Card is properly installed.", + r, XPUAPIErrorMsg[r])); + } +}; + +template +class SplitFunctor { + public: + void operator()(const platform::XPUDeviceContext& context, + const framework::Tensor& input, + const std::vector& ref_inputs, + const int axis, std::vector* outputs) { + int dev_id = + BOOST_GET_CONST(platform::XPUPlace, context.GetPlace()).GetDeviceId(); + platform::XPUDeviceGuard guard(dev_id); + + auto& ins = ref_inputs; + + int num = ins.size(); + auto input_dims = ins[0]->dims(); + std::vector split_list(num); + std::vector xdims_list(input_dims.size()); + int total_length = 0; + for (int i = 0; i < num; ++i) { + split_list[i] = ins[i]->dims()[axis]; + total_length += ins[i]->dims()[axis]; + } + + for (int i = 0; i < input_dims.size(); ++i) { + if (i == axis) continue; + xdims_list[i] = input_dims[i]; + } + xdims_list[axis] = total_length; + + std::vector ptrs(num); + for (int i = 0; i < num; ++i) { + ptrs[i] = outputs->at(i)->data(); + } + + auto r = xpu::split(context.x_context(), input.data(), ptrs, + xdims_list, split_list, axis); + PADDLE_ENFORCE_EQ( + r, XPU_SUCCESS, + platform::errors::External( + "XPU API return wrong value[%d %s], please check 
whether " + "Baidu Kunlun Card is properly installed.", + r, XPUAPIErrorMsg[r])); + } +}; +#endif + #define DEFINE_FUNCTOR(type) \ template class ConcatFunctor; \ template class SplitFunctor; FOR_ALL_TYPES(DEFINE_FUNCTOR); +#ifdef PADDLE_WITH_XPU +#define DEFINE_XPU_FUNCTOR(type) \ + template class ConcatFunctor; \ + template class SplitFunctor; + +DEFINE_XPU_FUNCTOR(float) +#endif + } // namespace math } // namespace operators } // namespace paddle diff --git a/paddle/fluid/operators/math/math_function.h b/paddle/fluid/operators/math/math_function.h index 1ad1c29ddd8..ea313cb6169 100644 --- a/paddle/fluid/operators/math/math_function.h +++ b/paddle/fluid/operators/math/math_function.h @@ -88,27 +88,22 @@ struct RowwiseMean { #ifdef PADDLE_WITH_XPU template struct TensorSetConstantXPU { - TensorSetConstantXPU(framework::Tensor* tensor, U value) - : tensor_(tensor), value_(value) {} + TensorSetConstantXPU(framework::Tensor* tensor, U value, + platform::Place place) + : tensor_(tensor), value_(value), place_(place) {} template void apply() const { - int dev_id = -1; - xpu_current_device(&dev_id); - if (dev_id >= 64) { - // if dev_id >= 64, the device is a simulator device, -64 to get real - // dev_id - dev_id -= 64; - } - auto xpu = platform::XPUPlace(dev_id); - auto* begin = tensor_->mutable_data(xpu); + auto* begin = tensor_->mutable_data(place_); int numel = tensor_->numel(); std::unique_ptr data_cpu(new T[numel]); std::fill(data_cpu.get(), data_cpu.get() + numel, static_cast(value_)); - memory::Copy(xpu, begin, platform::CPUPlace(), - static_cast(data_cpu.get()), numel * sizeof(T)); + memory::Copy(BOOST_GET_CONST(platform::XPUPlace, place_), begin, + platform::CPUPlace(), static_cast(data_cpu.get()), + numel * sizeof(T)); } framework::Tensor* tensor_; U value_; + platform::Place place_; }; #endif diff --git a/paddle/fluid/operators/math/math_function_impl.h b/paddle/fluid/operators/math/math_function_impl.h index 68cfdacde2a..0e44f903043 100644 --- a/paddle/fluid/operators/math/math_function_impl.h +++ b/paddle/fluid/operators/math/math_function_impl.h @@ -32,8 +32,9 @@ void SetConstant::operator()(const DeviceContext& context, #ifdef PADDLE_WITH_XPU if (platform::is_xpu_place(context.GetPlace())) { xpu_place = true; - framework::VisitDataType(tensor->type(), - TensorSetConstantXPU(tensor, num)); + framework::VisitDataType( + tensor->type(), + TensorSetConstantXPU(tensor, num, context.GetPlace())); } #endif if (!xpu_place) { diff --git a/paddle/fluid/platform/collective_helper.cc b/paddle/fluid/platform/collective_helper.cc index 08d70404a24..1e0e60eff8c 100644 --- a/paddle/fluid/platform/collective_helper.cc +++ b/paddle/fluid/platform/collective_helper.cc @@ -12,13 +12,12 @@ // See the License for the specific language governing permissions and // limitations under the License. 
-#if defined(PADDLE_WITH_NCCL) #include "paddle/fluid/platform/collective_helper.h" #include namespace paddle { namespace platform { - +#if defined(PADDLE_WITH_NCCL) class NCCLCommImpl : public NCCLComm { public: void set_ring_id(int ring_id) { ring_id_ = ring_id; } @@ -159,7 +158,132 @@ void NCCLCommContext::ReleaseNCCLComms() { } } -} // namespace platform -} // namespace paddle +#endif + +#if defined(PADDLE_WITH_XPU_BKCL) + +class BKCLCommImpl : public BKCLComm { + public: + void set_ring_id(int ring_id) { ring_id_ = ring_id; } + int ring_id() const override { return ring_id_; } + + void set_nranks(int nranks) { nranks_ = nranks; } + int nranks() const override { return nranks_; } + + void set_rank(int rank) { rank_ = rank; } + int rank() const override { return rank_; } + + int device_id() const override { + return BOOST_GET_CONST(XPUPlace, dev_ctx_->GetPlace()).device; + } + + void set_comm(BKCLContext_t comm) { comm_ = comm; } + BKCLContext_t comm() const override { return comm_; } + + XPUStream stream() const override { + return dev_ctx_->x_context()->xpu_stream; + } + + void set_dev_ctx(std::unique_ptr&& dev_ctx) { + dev_ctx_ = std::move(dev_ctx); + } + XPUDeviceContext* dev_context() const override { return dev_ctx_.get(); } + + private: + int ring_id_; + int nranks_; + int rank_; + BKCLContext_t comm_; + std::unique_ptr dev_ctx_; +}; + +BKCLComm* BKCLCommContext::CreateBKCLComm(BKCLUniqueId* bkcl_id, int nranks, + int rank, int dev_id, int ring_id) { + PADDLE_ENFORCE_NOT_NULL(bkcl_id, + platform::errors::InvalidArgument( + "The bkcl unique id should not be null.")); + PADDLE_ENFORCE_GT( + nranks, 1, + platform::errors::InvalidArgument( + "Expected nranks > 1. But received nranks is %d.", nranks)); + PADDLE_ENFORCE_GE(rank, 0, + platform::errors::InvalidArgument( + "Expected rank >= 0. But received rank is %d.", rank)); + PADDLE_ENFORCE_LT( + rank, nranks, + platform::errors::InvalidArgument( + "Expected rank < nranks. But received rank is %d, nranks is %d.", + rank, nranks)); + PADDLE_ENFORCE_GE( + dev_id, 0, + platform::errors::InvalidArgument( + "Expected dev_id >= 0. 
But received dev_id is %d.", dev_id)); + + BKCLContext_t comm = nullptr; + auto ret = xpu_set_device(dev_id); + PADDLE_ENFORCE_EQ( + ret, XPU_SUCCESS, + platform::errors::PreconditionNotMet( + "XPU API return wrong value[%d %s], please check whether " + "Baidu Kunlun Card is properly installed.", + ret, XPUAPIErrorMsg[ret])); + ret = bkcl_init_rank(&comm, rank, nranks, bkcl_id); + PADDLE_ENFORCE_EQ(ret, BKCL_SUCCESS, + platform::errors::PreconditionNotMet( + "bkcl_init_rank failed, got wrong value [%d].", ret)); + + auto* comm_wrapper = AssignBKCLComm(comm, nranks, rank, dev_id, ring_id); + + VLOG(1) << "bkcl communicator of rank " << rank << " in ring " << ring_id + << " has been created on device " << dev_id; + + std::call_once(once_flag_, []() { + std::atexit([]() { BKCLCommContext::Instance().ReleaseBKCLComms(); }); + }); + + return comm_wrapper; +} + +BKCLComm* BKCLCommContext::AssignBKCLComm(BKCLContext_t comm, int nranks, + int rank, int dev_id, int ring_id) { + std::unique_ptr dev_ctx( + new XPUDeviceContext(XPUPlace(dev_id))); + + BKCLCommImpl* c = new BKCLCommImpl; + c->set_ring_id(ring_id); + c->set_nranks(nranks); + c->set_rank(rank); + c->set_comm(comm); + c->set_dev_ctx(std::move(dev_ctx)); + + comm_map_mutex_.lock(); + if (comm_map_.count(ring_id) == 0) { + comm_map_.emplace(ring_id, std::map>()); + } + auto& dev2comm = comm_map_[ring_id]; + + dev2comm.emplace(dev_id, std::unique_ptr(c)); + comm_map_mutex_.unlock(); + + if (ring_id == 0) { + auto* dev_ctx = static_cast( + platform::DeviceContextPool::Instance().Get( + platform::XPUPlace(dev_id))); + dev_ctx->set_bkcl_context(comm); + } + + return comm_map_[ring_id][dev_id].get(); +} + +void BKCLCommContext::ReleaseBKCLComms() { + for (auto& p : comm_map_) { + for (auto& q : p.second) { + q.second.reset(); + } + } +} #endif + +} // namespace platform +} // namespace paddle diff --git a/paddle/fluid/platform/collective_helper.h b/paddle/fluid/platform/collective_helper.h index d44199f309b..82d79c53d0d 100644 --- a/paddle/fluid/platform/collective_helper.h +++ b/paddle/fluid/platform/collective_helper.h @@ -14,7 +14,6 @@ #pragma once -#if defined(PADDLE_WITH_NCCL) #include #include #include @@ -28,6 +27,7 @@ namespace paddle { namespace platform { +#if defined(PADDLE_WITH_NCCL) // In order to apply hierarchical communication with NCCL, we need // a communication ring contains NCCL communicators associated to a global // ncclUniqueId. E.g. for a hierarchical case, @@ -120,8 +120,102 @@ class NCCLCommContext { NCCLCommContext() = default; DISABLE_COPY_AND_ASSIGN(NCCLCommContext); }; +#endif -} // namespace platform -} // namespace paddle +#if defined(PADDLE_WITH_XPU_BKCL) +// In order to apply hierarchical communication with BKCL, we need +// a communication ring contains BKCL communicators associated to a global +// BKCLUniqueId. E.g. for a hierarchical case, +// +// 11 - 12 21 - 22 +// | | | | +// 13 - 14 - 23 - 24 +// | | +// 31 - 32 - 41 - 42 +// | | | | +// 33 - 34 43 - 44 +// +// we group (14,23,32,41) as the top, and (11,12,13,14), (21,22,23,24), +// (31,32,33,34), (41,42,43,44) as bottoms respectively. +// +// We could also use a single communication ring for the flatten case +// +// The BKCLComm instance is created and reversed in the BKCLCommContext +// singleton with a global user specified group id. 
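A minimal usage sketch of this communicator context (names such as nrings, nranks, rank, dev_id and bkcl_ids are illustrative locals, not part of the patch; only CreateBKCLComm, Get and stream come from the interface defined here):

// Rank 0 generates one BKCLUniqueId per ring and broadcasts them to all
// trainers (see BKCLParallelContext::Init); every process then creates its
// per-ring communicator for its own device:
for (int ring_id = 0; ring_id < nrings; ++ring_id) {
  platform::BKCLCommContext::Instance().CreateBKCLComm(
      &bkcl_ids[ring_id], nranks, rank, dev_id, ring_id);
}
// Later, e.g. in the reducer or a collective op, the communicator and its
// stream are looked up by (ring_id, dev_id):
auto* comm = platform::BKCLCommContext::Instance().Get(/*ring_id=*/0, dev_id);
XPUStream stream = comm->stream();
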
+class BKCLComm { + public: + virtual int ring_id() const = 0; + virtual int nranks() const = 0; + virtual int rank() const = 0; + virtual int device_id() const = 0; + virtual BKCLContext_t comm() const = 0; + virtual XPUStream stream() const = 0; + virtual XPUDeviceContext* dev_context() const = 0; + virtual ~BKCLComm() = default; +}; + +// A singleton BKCL communicator context reserves communication ring ids +class BKCLCommContext { + public: + static BKCLCommContext& Instance() { + static BKCLCommContext comm_ctx; + return comm_ctx; + } + + BKCLComm* CreateBKCLComm(BKCLUniqueId* bkcl_id, int nranks, int rank, + int dev_id, int ring_id = 0); + + void CreateAllBKCLComms(const std::vector& dev_ids, int ring_id = 0); + + // a latter comm with the same dev_id and the same ring_id + // will override the former + BKCLComm* AssignBKCLComm(BKCLContext_t comm, int nranks, int rank, int dev_id, + int ring_id = 0); + // retrieve a communicator by the ring id in multiprocessing mode + BKCLComm* Get(int ring_id) const { + PADDLE_ENFORCE_GT( + comm_map_.count(ring_id), 0, + platform::errors::InvalidArgument( + "Communicator in ring id %d has not been initialized.", ring_id)); + PADDLE_ENFORCE_EQ(comm_map_.at(ring_id).size(), 1, + platform::errors::InvalidArgument( + "One device id should be specified to retrieve from " + "multiple communicators.")); + return comm_map_.at(ring_id).begin()->second.get(); + } + + // retrieve a communicator by the ring id and the device id + BKCLComm* Get(int ring_id, int dev_id) const { + PADDLE_ENFORCE_GT( + comm_map_.count(ring_id), 0, + platform::errors::InvalidArgument( + "Communicator of ring id %d has not been initialized.", ring_id)); + PADDLE_ENFORCE_GT( + comm_map_.at(ring_id).count(dev_id), 0, + platform::errors::InvalidArgument( + "Communicator at device id %d has not been initialized in ring %d.", + dev_id, ring_id)); + return comm_map_.at(ring_id).at(dev_id).get(); + } + + // retrieve a communicator by the ring id and place + BKCLComm* Get(int ring_id, Place place) const { + return Get(ring_id, BOOST_GET_CONST(XPUPlace, place).device); + } + + private: + std::once_flag once_flag_; + std::mutex comm_map_mutex_; + // ring id to dev-BKCLComm + std::map>> comm_map_; + + void ReleaseBKCLComms(); + + BKCLCommContext() = default; + DISABLE_COPY_AND_ASSIGN(BKCLCommContext); +}; #endif + +} // namespace platform +} // namespace paddle diff --git a/paddle/fluid/platform/device_context.cc b/paddle/fluid/platform/device_context.cc index b9a8dd98456..51a799c65fb 100644 --- a/paddle/fluid/platform/device_context.cc +++ b/paddle/fluid/platform/device_context.cc @@ -188,6 +188,9 @@ XPUDeviceContext::XPUDeviceContext(XPUPlace place) : place_(place) { "XPU API return wrong value[%d], please check whether " "Baidu Kunlun Card is properly installed.", ret)); + + LOG_FIRST_N(WARNING, 1) << "Please NOTE: xpu device: " << place_.device; + context_ = xpu::create_context(); const int MAX_XPU_NUM = 16; const int l3_size = 13.5 * 1024 * 1024; diff --git a/paddle/fluid/platform/gen_comm_id_helper.cc b/paddle/fluid/platform/gen_comm_id_helper.cc index 08f0af5fc91..732e3e5e5eb 100644 --- a/paddle/fluid/platform/gen_comm_id_helper.cc +++ b/paddle/fluid/platform/gen_comm_id_helper.cc @@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ -#ifdef PADDLE_WITH_NCCL +#if (defined PADDLE_WITH_NCCL) || (defined PADDLE_WITH_XPU_BKCL) #include "paddle/fluid/platform/gen_comm_id_helper.h" #include @@ -339,7 +339,7 @@ void RecvBroadCastCommID(int server_fd, std::string endpoint, INSTANT_TEMPLATE(ncclUniqueId) #endif #ifdef PADDLE_WITH_XPU_BKCL -INSTANT_TEMPLATE(bkclUniqueId) +INSTANT_TEMPLATE(BKCLUniqueId) #endif } // namespace platform } // namespace paddle diff --git a/paddle/fluid/platform/gen_comm_id_helper.h b/paddle/fluid/platform/gen_comm_id_helper.h index 5384d704708..114f5a0b993 100644 --- a/paddle/fluid/platform/gen_comm_id_helper.h +++ b/paddle/fluid/platform/gen_comm_id_helper.h @@ -14,7 +14,7 @@ limitations under the License. */ #pragma once -#ifdef PADDLE_WITH_NCCL +#if (defined PADDLE_WITH_NCCL) || (defined PADDLE_WITH_XPU_BKCL) #include #include #include diff --git a/paddle/fluid/platform/xpu_info.h b/paddle/fluid/platform/xpu_info.h index efaba13453e..2bf7b0b5cb6 100644 --- a/paddle/fluid/platform/xpu_info.h +++ b/paddle/fluid/platform/xpu_info.h @@ -28,6 +28,29 @@ std::vector GetXPUSelectedDevices(); //! Set the XPU device id for next execution. void SetXPUDeviceId(int device_id); +class XPUDeviceGuard { + public: + explicit inline XPUDeviceGuard(int dev_id) { + int prev_id = platform::GetXPUCurrentDeviceId(); + if (prev_id != dev_id) { + prev_id_ = prev_id; + platform::SetXPUDeviceId(dev_id); + } + } + + inline ~XPUDeviceGuard() { + if (prev_id_ != -1) { + platform::SetXPUDeviceId(prev_id_); + } + } + + XPUDeviceGuard(const XPUDeviceGuard& o) = delete; + XPUDeviceGuard& operator=(const XPUDeviceGuard& o) = delete; + + private: + int prev_id_{-1}; +}; + } // namespace platform } // namespace paddle #endif diff --git a/paddle/fluid/pybind/CMakeLists.txt b/paddle/fluid/pybind/CMakeLists.txt index 39e83ab12d5..e4b86a998a9 100644 --- a/paddle/fluid/pybind/CMakeLists.txt +++ b/paddle/fluid/pybind/CMakeLists.txt @@ -5,6 +5,7 @@ set(PYBIND_DEPS pybind python proto_desc memory executor fleet_wrapper box_wrapp if (WITH_GPU) set(PYBIND_DEPS ${PYBIND_DEPS} dynload_cuda) + set(PYBIND_DEPS ${PYBIND_DEPS} cuda_device_guard) endif() if (WITH_NCCL) @@ -12,6 +13,11 @@ if (WITH_NCCL) set(PYBIND_DEPS ${PYBIND_DEPS} reducer) endif() +if (WITH_XPU_BKCL) + set(PYBIND_DEPS ${PYBIND_DEPS} reducer) + set(PYBIND_DEPS ${PYBIND_DEPS} bkcl_context) +endif() + if(NOT WIN32) set(PYBIND_DEPS ${PYBIND_DEPS} data_loader) set(PYBIND_DEPS ${PYBIND_DEPS} mmap_allocator) @@ -79,6 +85,10 @@ if(WITH_PYTHON) list(APPEND OP_FUNCTION_GENERETOR_DEPS nccl_context) endif(WITH_NCCL) + if(WITH_XPU_BKCL) + list(APPEND OP_FUNCTION_GENERETOR_DEPS bkcl_context) + endif(WITH_XPU_BKCL) + add_executable(op_function_generator op_function_generator.cc) target_link_libraries(op_function_generator ${OP_FUNCTION_GENERETOR_DEPS}) get_property (os_dependency_modules GLOBAL PROPERTY OS_DEPENDENCY_MODULES) diff --git a/paddle/fluid/pybind/imperative.cc b/paddle/fluid/pybind/imperative.cc index cceae74f1dc..6185b978511 100644 --- a/paddle/fluid/pybind/imperative.cc +++ b/paddle/fluid/pybind/imperative.cc @@ -32,6 +32,7 @@ limitations under the License. 
*/ #include "paddle/fluid/imperative/all_reduce.h" #include "paddle/fluid/imperative/amp_auto_cast.h" #include "paddle/fluid/imperative/basic_engine.h" +#include "paddle/fluid/imperative/bkcl_context.h" #include "paddle/fluid/imperative/data_loader.h" #include "paddle/fluid/imperative/layer.h" #include "paddle/fluid/imperative/nccl_context.h" @@ -1377,16 +1378,10 @@ void BindImperative(py::module *m_ptr) { }, py::call_guard()); -#if defined(PADDLE_WITH_NCCL) +#if (defined PADDLE_WITH_NCCL) || (defined PADDLE_WITH_XPU_BKCL) py::class_>(m, "ParallelContext"); - py::class_>( - m, "NCCLParallelContext") - .def(py::init()) - .def("init", [](imperative::NCCLParallelContext &self) { self.Init(); }); py::class_>( m, "Reducer", R"DOC()DOC") @@ -1404,6 +1399,24 @@ void BindImperative(py::module *m_ptr) { py::arg("tensor_indices") = std::vector{}, py::call_guard()); #endif + +#if defined(PADDLE_WITH_NCCL) + py::class_>( + m, "NCCLParallelContext") + .def(py::init()) + .def("init", [](imperative::NCCLParallelContext &self) { self.Init(); }); +#endif + +#if defined(PADDLE_WITH_XPU_BKCL) + py::class_>( + m, "BKCLParallelContext") + .def(py::init()) + .def("init", [](imperative::BKCLParallelContext &self) { self.Init(); }); +#endif } } // namespace pybind diff --git a/paddle/fluid/pybind/tensor_py.h b/paddle/fluid/pybind/tensor_py.h index 6d1281d11f1..e5db28c6f3e 100644 --- a/paddle/fluid/pybind/tensor_py.h +++ b/paddle/fluid/pybind/tensor_py.h @@ -27,6 +27,9 @@ limitations under the License. */ #include "paddle/fluid/operators/math/concat_and_split.h" #include "paddle/fluid/operators/strided_memcpy.h" #include "paddle/fluid/platform/bfloat16.h" +#ifdef PADDLE_WITH_CUDA +#include "paddle/fluid/platform/cuda_device_guard.h" +#endif #include "paddle/fluid/platform/device_context.h" #include "paddle/fluid/platform/float16.h" #include "paddle/fluid/platform/profiler.h" @@ -256,6 +259,38 @@ void TensorSetElement(framework::Tensor *self, size_t offset, T elem) { } } +// NOTE(wangxi): When copying data to the accelerator card, +// we need set_device(dev_id) first. +template +static int GetDeviceId(const P &place) { + // for CPUPlace and CUDAPinnedPlace. + PADDLE_THROW(platform::errors::PermissionDenied( + "Paddle can't Get CPUPlace or CUDAPinnedPlace Device Id.")); +} + +template <> +int GetDeviceId(const platform::CUDAPlace &place) { + return place.GetDeviceId(); +} + +template <> +int GetDeviceId(const platform::XPUPlace &place) { + return place.GetDeviceId(); +} + +// NOTE(wangxi16): Used by VarBase __setitem__ +template <> +int GetDeviceId(const platform::Place &place) { + if (paddle::platform::is_gpu_place(place)) { + return GetDeviceId(BOOST_GET_CONST(platform::CUDAPlace, place)); + } else if (paddle::platform::is_xpu_place(place)) { + return GetDeviceId(BOOST_GET_CONST(platform::XPUPlace, place)); + } + // for CPUPlace and CUDAPinnedPlace. + PADDLE_THROW(platform::errors::PermissionDenied( + "Paddle can't Get CPUPlace or CUDAPinnedPlace Device Id.")); +} + template void SetTensorFromPyArrayT( framework::Tensor *self, @@ -279,6 +314,7 @@ void SetTensorFromPyArrayT( } } else if (paddle::platform::is_xpu_place(place)) { #ifdef PADDLE_WITH_XPU + platform::XPUDeviceGuard guard(GetDeviceId(place)); auto dst = self->mutable_data(place); xpu_memcpy(dst, array.data(), array.nbytes(), XPUMemcpyKind::XPU_HOST_TO_DEVICE); @@ -290,7 +326,7 @@ void SetTensorFromPyArrayT( } else { #ifdef PADDLE_WITH_CUDA if (paddle::platform::is_gpu_place(place)) { - // TODO(zhiqiu): set SetDeviceId before calling cuda APIs. 
+ platform::CUDADeviceGuard guard(GetDeviceId(place)); auto dst = self->mutable_data(place); paddle::platform::GpuMemcpySync(dst, array.data(), array.nbytes(), cudaMemcpyHostToDevice); diff --git a/python/paddle/distributed/fleet/launch.py b/python/paddle/distributed/fleet/launch.py index c7c60a3fbde..0f9b13d8a12 100644 --- a/python/paddle/distributed/fleet/launch.py +++ b/python/paddle/distributed/fleet/launch.py @@ -108,16 +108,26 @@ see: http://www.paddlepaddle.org/documentation/docs/zh/1.6/user_guides/howto/tra "In gpu training, it should be less or equal to the gpus number of you system(or you set by --gpus). And so each process can" " bound to one or average number of gpus.") - base_group.add_argument( - "--gpus", - type=str, - default=None, - help="It's for gpu training." - "For example:" - "--gpus=\"0,1,2,3\" will launch four training processes each bound to one gpu." - ) - - base_group.add_argument("--selected_gpus", dest="gpus") + if fluid.core.is_compiled_with_cuda(): + base_group.add_argument( + "--gpus", + type=str, + default=None, + help="It's for gpu training." + "For example:" + "--gpus=\"0,1,2,3\" will launch four training processes each bound to one gpu." + ) + base_group.add_argument("--selected_gpus", dest="gpus") + + if fluid.core.is_compiled_with_xpu(): + base_group.add_argument( + "--xpus", + type=str, + default=None, + help="It's for xpu training. For example: " + "--xpus=\"0,1,2,3\" will launch four training processes each bound to one xpu." + ) + base_group.add_argument("--selected_xpus", dest="xpus") base_group.add_argument( "training_script", @@ -288,14 +298,16 @@ def which_distributed_mode(args): ) if fluid.core.is_compiled_with_cuda(): - cuda_device_num = fluid.core.get_cuda_device_count() + device_count = fluid.core.get_cuda_device_count() + elif fluid.core.is_compiled_with_xpu(): + device_count = fluid.core.get_xpu_device_count() else: - cuda_device_num = 0 + device_count = 0 if len(has_ps_args) > 0: logger.info( - "Run parameter-sever mode. pserver arguments:{}, cuda count:{}". - format(has_ps_args, cuda_device_num)) + "Run parameter-sever mode. pserver arguments:{}, cuda or xpu count:{}". + format(has_ps_args, device_count)) has_ps_heter_args = list(set(has_ps_args) & set(ps_heter_args)) if len(has_ps_heter_args) > 0: return DistributeMode.PS_HETER @@ -303,17 +315,18 @@ def which_distributed_mode(args): return DistributeMode.PS elif len(has_collective_args) > 0: logger.info("Run collective gpu mode. gpu arguments:{}, cuda count:{}". - format(has_collective_args, cuda_device_num)) + format(has_collective_args, device_count)) return DistributeMode.COLLECTIVE else: - if not fluid.core.is_compiled_with_cuda(): + if not fluid.core.is_compiled_with_cuda( + ) and not fluid.core.is_compiled_with_xpu(): logger.warning( - "Not found distinct arguments and not compiled with cuda. Default use ps mode" + "Not found distinct arguments and not compiled with cuda or xpu. Default use ps mode" ) return DistributeMode.PS else: logger.warning( - "Not found distinct arguments and compiled with cuda. Default use collective mode" + "Not found distinct arguments and compiled with cuda or xpu. 
Default use collective mode" ) return DistributeMode.COLLECTIVE diff --git a/python/paddle/distributed/fleet/launch_utils.py b/python/paddle/distributed/fleet/launch_utils.py index 625e8a476b5..b4f1f931490 100644 --- a/python/paddle/distributed/fleet/launch_utils.py +++ b/python/paddle/distributed/fleet/launch_utils.py @@ -47,10 +47,11 @@ class DeviceMode(): """ Training devices type """ + UNKNOWN = -1 CPU = 0 GPU = 1 KUNLUN = 2 - UNKNOWN = 3 + XPU = 2 class Cluster(object): @@ -275,6 +276,11 @@ def get_cluster(node_ips, node_ip, trainer_endpoints, device_mode, trainer.gpus.extend(devices_per_proc[i]) else: trainer.gpus.append(devices_per_proc[i]) + elif device_mode == DeviceMode.XPU: + if isinstance(devices_per_proc[i], (list, tuple)): + trainer.gpus.extend(devices_per_proc[i]) + else: + trainer.gpus.extend(devices_per_proc[i]) trainer.endpoint = "%s" % (cur_node_endpoints[i]) trainer.rank = trainer_rank trainer_rank += 1 @@ -454,9 +460,12 @@ def start_local_trainers(cluster, "PADDLE_TRAINER_ENDPOINTS": ",".join(cluster.trainers_endpoints()) } - if len(t.gpus) > 0: + if fluid.core.is_compiled_with_cuda() and len(t.gpus) > 0: proc_env["FLAGS_selected_gpus"] = "%s" % ",".join( [str(g) for g in t.gpus]) + elif fluid.core.is_compiled_with_xpu() and len(t.gpus) > 0: + proc_env["FLAGS_selected_xpus"] = "%s" % ",".join( + [str(g) for g in t.gpus]) current_env.update(proc_env) @@ -584,15 +593,47 @@ def get_gpus(gpus): return res_gpus +def get_xpus(xpus): + if xpus is None: + xpus_num = fluid.core.get_xpu_device_count() + res_xpus = [str(x) for x in range(0, xpus_num)] + else: + xpu_visible_devices = os.getenv("XPU_VISIBLE_DEVICES") + if xpu_visible_devices is None or xpu_visible_devices == "": + res_xpus = [x.strip() for x in xpus.split(',')] + else: + # change xpus into relative values + # e.g. XPU_VISIBLE_DEVICES=4,5,6,7; args.xpus=4,5,6,7; + # therefore xpus=0,1,2,3 + xpu_visible_devices_list = xpu_visible_devices.split(',') + for x in xpus.split(','): + assert x in xpu_visible_devices_list, "Can't find "\ + "your xpus %s in XPU_VISIBLE_DEVICES[%s]."\ + % (x, xpu_visible_devices) + res_xpus = [ + xpu_visible_devices_list.index(x.strip()) + for x in xpus.split(',') + ] + logger.info("Change selected_xpus into reletive values. 
--ips:{} " + "will change into relative_ips:{} according to your " + "XPU_VISIBLE_DEVICES:{}".format( + xpus, res_xpus, xpu_visible_devices_list)) + + return res_xpus + + def get_device_mode(): - #TODO(gongwb):Add XPU supported - if not fluid.core.is_compiled_with_cuda( - ) or fluid.core.get_cuda_device_count() <= 0: - print("launch train in CPU mode") - return DeviceMode.CPU + if fluid.core.is_compiled_with_cuda() and fluid.core.get_cuda_device_count( + ) > 0: + print("launch train in GPU mode") + return DeviceMode.GPU + elif fluid.core.is_compiled_with_xpu() and fluid.core.get_xpu_device_count( + ) > 0: + print("launch train in XPU mode") + return DeviceMode.XPU - print("launch train in GPU mode") - return DeviceMode.GPU + print("launch train in CPU mode") + return DeviceMode.CPU def get_device_proc_info(args): @@ -613,13 +654,25 @@ def get_device_proc_info(args): ] else: devices_per_proc = gpus + elif device_mode == DeviceMode.XPU: + xpus = get_xpus(args.xpus) + if args.nproc_per_node is not None: + assert (len(xpus) % int(args.nproc_per_node)) == 0, \ + "xpus' number:{} mod args.nproc_per_node:{} must == 0".format(len(xpus), arg.nproc_per_node) + + n = int(len(xpus) / int(args.nproc_per_node)) + devices_per_proc = [ + xpus[i:i + n] for i in six.moves.range(0, len(xpus), n) + ] + else: + devices_per_proc = xpus elif device_mode == DeviceMode.CPU: if args.nproc_per_node is None: devices_per_proc = [0] else: devices_per_proc = [x for x in range(0, args.nproc_per_node)] else: - assert False, "Can't support device_mode:{}, support only cpu and gpu now.".format( + assert False, "Can't support device_mode:{}, support only cpu|gpu|xpu now.".format( device_mode) return (device_mode, devices_per_proc) diff --git a/python/paddle/distributed/parallel.py b/python/paddle/distributed/parallel.py index c41c3663a17..582c0be713f 100644 --- a/python/paddle/distributed/parallel.py +++ b/python/paddle/distributed/parallel.py @@ -120,12 +120,12 @@ def init_parallel_env(): ) return - # 1. gpu check - if not core.is_compiled_with_cuda(): + # 1. gpu xpu check, must be gpu or xpu + if not core.is_compiled_with_cuda() and not core.is_compiled_with_xpu(): raise NotImplementedError( "Cannot initialize parallel environment in CPU-only version, now only " - "supports initializing the GPU parallel environment. Please recompile " - "or reinstall paddle with GPU support.") + "supports initializing the GPU and XPU parallel environment. Please recompile " + "or reinstall paddle with GPU or XPU support.") # 2. check env def _check_var_exists(var_name): @@ -135,7 +135,11 @@ def init_parallel_env(): "environment variable %s is needed, but not set." 
% var_name) - _check_var_exists("FLAGS_selected_gpus") + if core.is_compiled_with_cuda(): + _check_var_exists("FLAGS_selected_gpus") + elif core.is_compiled_with_xpu(): + _check_var_exists('FLAGS_selected_xpus') + _check_var_exists("PADDLE_TRAINER_ID") _check_var_exists("PADDLE_CURRENT_ENDPOINT") _check_var_exists("PADDLE_TRAINERS_NUM") @@ -176,11 +180,19 @@ def init_parallel_env(): # directly, if they want to switch default place, # they need to call a function to change default place, # here just set correctly place to users - place = core.CUDAPlace(parallel_env.device_id) + if core.is_compiled_with_cuda(): + place = core.CUDAPlace(parallel_env.device_id) + elif core.is_compiled_with_xpu(): + place = core.XPUPlace(parallel_env.device_id) _set_expected_place(place) - # init nccl context - parallel_helper._set_parallel_ctx(core.NCCLParallelContext(strategy, place)) + # init nccl or bkcl context + if core.is_compiled_with_cuda(): + parallel_helper._set_parallel_ctx( + core.NCCLParallelContext(strategy, place)) + elif core.is_compiled_with_xpu(): + parallel_helper._set_parallel_ctx( + core.BKCLParallelContext(strategy, place)) parallel_helper._init_parallel_ctx() # 5: init gloo context (step 2: gloo init) diff --git a/python/paddle/fluid/dygraph/parallel.py b/python/paddle/fluid/dygraph/parallel.py index a80f6b3f491..854cb86d925 100644 --- a/python/paddle/fluid/dygraph/parallel.py +++ b/python/paddle/fluid/dygraph/parallel.py @@ -55,9 +55,12 @@ def prepare_context(strategy=None): if isinstance(place, core.CUDAPlace): parallel_helper._set_parallel_ctx( core.NCCLParallelContext(strategy, place)) + elif isinstance(place, core.XPUPlace): + parallel_helper._set_parallel_ctx( + core.BKCLParallelContext(strategy, place)) else: # TODO(Yancey1989): add Gloo Parallel Context to support CPU parallel computation - assert ("Only support CUDAPlace for now.") + assert ("Only support CUDAPlace or XPUPlace for now.") parallel_helper._init_parallel_ctx() return strategy @@ -108,9 +111,13 @@ class ParallelEnv(object): self._rank = int(os.getenv("PADDLE_TRAINER_ID", "0")) self._world_size = int(os.getenv("PADDLE_TRAINERS_NUM", "1")) - # imperative only support one gpu - selected_gpus = os.getenv("FLAGS_selected_gpus", "0").split(",") - self._device_id = int(selected_gpus[0]) + # imperative only support one gpu or xpu + if core.is_compiled_with_cuda(): + selected_gpus = os.getenv("FLAGS_selected_gpus", "0").split(",") + self._device_id = int(selected_gpus[0]) + elif core.is_compiled_with_xpu(): + selected_xpus = os.getenv("FLAGS_selected_xpus", "0").split(",") + self._device_id = int(selected_xpus[0]) self._trainer_endpoints = os.getenv("PADDLE_TRAINER_ENDPOINTS", "").split(",") diff --git a/python/paddle/fluid/tests/unittests/detected_xpu.py b/python/paddle/fluid/tests/unittests/detected_xpu.py new file mode 100644 index 00000000000..d7b6f58c941 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/detected_xpu.py @@ -0,0 +1,25 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +import paddle +import sys +import paddle.fluid as fluid + +print("compile with xpu:", fluid.core.is_compiled_with_xpu()) +print("get_xpu_device_count:", fluid.core.get_xpu_device_count()) + +if fluid.core.is_compiled_with_xpu() and fluid.core.get_xpu_device_count() > 0: + sys.exit(0) +else: + sys.exit(1) diff --git a/python/paddle/fluid/tests/unittests/nproc_process.py b/python/paddle/fluid/tests/unittests/nproc_process.py index c0e60eec458..e8b8ea11440 100644 --- a/python/paddle/fluid/tests/unittests/nproc_process.py +++ b/python/paddle/fluid/tests/unittests/nproc_process.py @@ -15,18 +15,22 @@ import os import sys import time +import paddle.fluid as fluid def train(prefix): - selected_gpus = os.getenv("FLAGS_selected_gpus") + if fluid.core.is_compiled_with_xpu(): + selected_devices = os.getenv("FLAGS_selected_xpus") + else: + selected_devices = os.getenv("FLAGS_selected_gpus") trainer_id = int(os.getenv("PADDLE_TRAINER_ID")) worker_endpoints_env = os.getenv("PADDLE_TRAINER_ENDPOINTS") current_endpoint = os.getenv("PADDLE_CURRENT_ENDPOINT") worker_endpoints = worker_endpoints_env trainers_num = len(worker_endpoints.split(',')) - name = "selected_gpus:{} worker_endpoints:{} trainers_num:{} current_endpoint:{} trainer_id:{}"\ - .format(selected_gpus, worker_endpoints, trainers_num, current_endpoint,trainer_id) + name = "selected_devices:{} worker_endpoints:{} trainers_num:{} current_endpoint:{} trainer_id:{}"\ + .format(selected_devices, worker_endpoints, trainers_num, current_endpoint,trainer_id) print(name) with open("{}.check_{}.log".format(prefix, trainer_id), "w") as f: diff --git a/python/paddle/fluid/tests/unittests/test_dist_base.py b/python/paddle/fluid/tests/unittests/test_dist_base.py index d30de102020..6511ee65c59 100644 --- a/python/paddle/fluid/tests/unittests/test_dist_base.py +++ b/python/paddle/fluid/tests/unittests/test_dist_base.py @@ -464,8 +464,14 @@ class TestParallelDyGraphRunnerBase(object): def run_trainer(self, args): seed = 90 - device_id = int(os.getenv("FLAGS_selected_gpus", "0")) - place = fluid.CUDAPlace(device_id) + if fluid.core.is_compiled_with_cuda(): + device_id = int(os.getenv("FLAGS_selected_gpus", "0")) + place = fluid.CUDAPlace(device_id) + elif fluid.core.is_compiled_with_xpu(): + device_id = int(os.getenv("FLAGS_selected_xpus", "0")) + place = fluid.XPUPlace(device_id) + else: + assert ("Only support CUDAPlace or XPUPlace for now.") with fluid.dygraph.guard(place): fluid.default_startup_program().random_seed = seed @@ -476,7 +482,8 @@ class TestParallelDyGraphRunnerBase(object): model, train_reader, opt = self.get_model() nranks = len(args.endpoints.split(",")) if args.endpoints else 1 - if args.update_method == "nccl2": + #if args.update_method == "nccl2": + if args.update_method == "nccl2" or args.update_method == "bkcl": strategy = dygraph.parallel.ParallelStrategy() strategy.nranks = nranks strategy.local_rank = args.trainer_id @@ -592,7 +599,7 @@ def runtime_main(test_class): '--update_method', type=str, default="local", - choices=["pserver", "nccl2", "local", "nccl2_reduce_layer"]) + choices=["pserver", "nccl2", "bkcl", "local", "nccl2_reduce_layer"]) parser.add_argument('--trainer_id', type=int, required=False, default=0) parser.add_argument('--trainers', type=int, required=False, default=1) parser.add_argument('--nccl_comm_num', type=int, required=False, default=1) @@ -608,6 +615,7 @@ def runtime_main(test_class): '--current_endpoint', 
type=str, required=False, default="") parser.add_argument('--sync_mode', action='store_true') parser.add_argument('--use_cuda', action='store_true') + parser.add_argument('--use_xpu', action='store_true') parser.add_argument('--use_dgc', action='store_true') parser.add_argument('--use_reduce', action='store_true') parser.add_argument('--dc_asgd', action='store_true') @@ -656,9 +664,15 @@ class TestDistBase(unittest.TestCase): def _after_setup_config(self): if self._enforce_place == "CPU": self.__use_cuda = False + self.__use_xpu = False self._use_dgc = False elif self._enforce_place == "GPU": self.__use_cuda = True + self.__use_xpu = False + elif self._enforce_place == "XPU": + self.__use_cuda = False + self.__use_xpu = True + self._use_dgc = False else: if fluid.core.is_compiled_with_cuda(): self.__use_cuda = True @@ -681,6 +695,7 @@ class TestDistBase(unittest.TestCase): self._dc_asgd = False # must use with async mode self._use_reader_alloc = True self._nccl2_mode = False + self._bkcl_mode = False self._pipeline_mode = False self._mp_mode = False # FIXME(typhoonzero): I added this stupid argument to enable @@ -783,7 +798,7 @@ class TestDistBase(unittest.TestCase): batch_size=DEFAULT_BATCH_SIZE, batch_merge_repeat=1, log_name="", - gpus="0"): + devices="0"): cmd = self._python_interp @@ -804,7 +819,14 @@ class TestDistBase(unittest.TestCase): if self.__use_cuda: cmd += " --use_cuda" env_local = { - "CUDA_VISIBLE_DEVICES": gpus, + "CUDA_VISIBLE_DEVICES": devices, + "PADDLE_TRAINERS_NUM": "1", + "PADDLE_TRAINER_ID": "0" + } + elif self.__use_xpu: + cmd += " --use_xpu" + env_local = { + "FLAGS_selected_xpus": devices, "PADDLE_TRAINERS_NUM": "1", "PADDLE_TRAINER_ID": "0" } @@ -812,7 +834,7 @@ class TestDistBase(unittest.TestCase): env_local = {'CPU_NUM': '1'} # not use dgc in single card - if len(gpus) > 1 and self._use_dgc: + if len(devices) > 1 and self._use_dgc: cmd += " --use_dgc" env_local.update(envs) @@ -962,6 +984,19 @@ class TestDistBase(unittest.TestCase): "PADDLE_TRAINER_ENDPOINTS": self._ps_endpoints, "PADDLE_CURRENT_ENDPOINT": ep, }) + # TODO(liuyuhui):XPU_VISIBLE_DEVICES is not working right now, + # will update it after Badiu Kunlun partners' support. 
+ elif self.__use_xpu: + tr_cmd += " --use_xpu" + env.update({ + "FLAGS_selected_xpus": "{}".format(trainer_id), + #"XPU_VISIBLE_DEVICES": "{}".format(trainer_id + 1), + "PADDLE_TRAINERS_NUM": "{}".format(trainer_num), + "PADDLE_TRAINER_ID": "{}".format(trainer_id), + "PADDLE_TRAINER_ENDPOINTS": self._ps_endpoints, + "PADDLE_CURRENT_ENDPOINT": ep, + "GLOG_v": "2", + }) else: env.update({'CPU_NUM': '1'}) @@ -999,8 +1034,8 @@ class TestDistBase(unittest.TestCase): return tr_cmd, env - def _run_cluster_nccl2(self, model, envs, nccl2_reduce_layer, - check_error_log, log_name): + def _run_cluster_nccl2(self, model, envs, update_method, check_error_log, + log_name): if self._use_hallreduce: self._ps_endpoints = "" @@ -1018,10 +1053,6 @@ class TestDistBase(unittest.TestCase): # NOTE: we reuse ps_endpoints as nccl2 worker endpoints worker_endpoints = self._ps_endpoints.split(",") - if nccl2_reduce_layer: - update_method = "nccl2_reduce_layer" - else: - update_method = "nccl2" trainer_num = len(worker_endpoints) @@ -1150,16 +1181,24 @@ class TestDistBase(unittest.TestCase): tr0_losses, tr1_losses = self._run_cluster_nccl2( model_file, required_envs, - True, - check_error_log, + update_method="nccl2_reduce_layer", + check_error_log=check_error_log, log_name=log_name) else: tr0_losses, tr1_losses = self._run_cluster_nccl2( model_file, required_envs, - False, - check_error_log, + update_method='nccl2', + check_error_log=check_error_log, log_name=log_name) + elif self._bkcl_mode: + tr0_losses, tr1_losses = self._run_cluster_nccl2( + model_file, + required_envs, + update_method='bkcl', + check_error_log=check_error_log, + log_name=log_name) + elif self._pipeline_mode: tr0_losses, tr1_losses = self._run_pipeline( model_file, required_envs, check_error_log, log_name=log_name) @@ -1196,7 +1235,7 @@ class TestDistBase(unittest.TestCase): required_envs, check_error_log, log_name=log_name + "_dgc_2cards", - gpus="0,1") + devices="0,1") self._use_dgc = False base_losses = self._run_local( @@ -1204,7 +1243,7 @@ class TestDistBase(unittest.TestCase): required_envs, check_error_log, log_name=log_name + "_base_2cards", - gpus="0,1") + devices="0,1") self._use_dgc = True diff --git a/python/paddle/fluid/tests/unittests/test_dist_mnist_fleet_save.py b/python/paddle/fluid/tests/unittests/test_dist_mnist_fleet_save.py index 7336794578e..2a6af6e3908 100644 --- a/python/paddle/fluid/tests/unittests/test_dist_mnist_fleet_save.py +++ b/python/paddle/fluid/tests/unittests/test_dist_mnist_fleet_save.py @@ -89,8 +89,8 @@ class TestDistMnistFleetSave(TestDistBase): tr0_losses, tr1_losses = self._run_cluster_nccl2( model_file, required_envs, - False, - check_error_log, + update_method='nccl2', + check_error_log=check_error_log, log_name=log_name) dirname = '/tmp' diff --git a/python/paddle/fluid/tests/unittests/test_dist_sharding_save.py b/python/paddle/fluid/tests/unittests/test_dist_sharding_save.py index b4620d7a0c5..e94ad37c6bd 100755 --- a/python/paddle/fluid/tests/unittests/test_dist_sharding_save.py +++ b/python/paddle/fluid/tests/unittests/test_dist_sharding_save.py @@ -32,7 +32,6 @@ class TestDistMnistFleetSave(TestDistBase): self._sharding_save = True self._enforce_place = "GPU" - def _rm_temp_files(self, dirname): shutil.rmtree(dirname) @@ -40,9 +39,13 @@ class TestDistMnistFleetSave(TestDistBase): sharding_save_files = sorted(os.listdir(dirname)) - check_files = ['fc_0.b_0', 'fc_0.b_0_velocity_0', 'fc_0.w_0', 'fc_0.w_0_velocity_0', 'fc_1.b_0', - 'fc_1.b_0_velocity_0', 'fc_1.w_0', 'fc_1.w_0_velocity_0', 'fc_2.b_0', 
- 'fc_2.b_0_velocity_0', 'fc_2.w_0', 'fc_2.w_0_velocity_0', 'learning_rate_0'] + check_files = [ + 'fc_0.b_0', 'fc_0.b_0_velocity_0', 'fc_0.w_0', + 'fc_0.w_0_velocity_0', 'fc_1.b_0', 'fc_1.b_0_velocity_0', + 'fc_1.w_0', 'fc_1.w_0_velocity_0', 'fc_2.b_0', + 'fc_2.b_0_velocity_0', 'fc_2.w_0', 'fc_2.w_0_velocity_0', + 'learning_rate_0' + ] if sharding_save_files != check_files: self._rm_temp_files(dirname) @@ -62,8 +65,8 @@ class TestDistMnistFleetSave(TestDistBase): tr0_losses, tr1_losses = self._run_cluster_nccl2( model_file, required_envs, - False, - check_error_log, + update_method='nccl2', + check_error_log=check_error_log, log_name=log_name) dirname = './ut_sharding_save_model' diff --git a/python/paddle/fluid/tests/unittests/test_fleet_launch_nproc.sh b/python/paddle/fluid/tests/unittests/test_fleet_launch_nproc.sh index 14679c49eae..89f696dee47 100644 --- a/python/paddle/fluid/tests/unittests/test_fleet_launch_nproc.sh +++ b/python/paddle/fluid/tests/unittests/test_fleet_launch_nproc.sh @@ -27,7 +27,7 @@ function test_nproc_0(){ # nproc_per_node=1, each with 2 gpus python -m paddle.distributed.launch ${distributed_args} nproc_process.py fleet_nproc_0 - str0="selected_gpus:${gpus} worker_endpoints:127.0.0.1:35789 trainers_num:1 current_endpoint:127.0.0.1:35789 trainer_id:0" + str0="selected_devices:${gpus} worker_endpoints:127.0.0.1:35789 trainers_num:1 current_endpoint:127.0.0.1:35789 trainer_id:0" if grep -q "$str0" "$file_0"; then echo "find trainer 0" else @@ -50,6 +50,12 @@ if ! python detected_gpu.py ; then test_nproc_0 "" fi +# unittest3:xpu +if python detected_xpu.py ; then + echo "begin ut 3:" + export XPU_VISIBLE_DEVICES=0,1 + test_nproc_0 "0,1" +fi function test_nproc_1_gpu(){ file_0="fleet_nproc_1.check_0.log" @@ -59,7 +65,7 @@ function test_nproc_1_gpu(){ distributed_args="--log_dir=testlog --nproc_per_node=2" python -m paddle.distributed.launch ${distributed_args} nproc_process.py fleet_nproc_1 - str0="selected_gpus:0 worker_endpoints:127.0.0.1:35789,127.0.0.1:35790 trainers_num:2 current_endpoint:127.0.0.1:35789 trainer_id:0" + str0="selected_devices:0 worker_endpoints:127.0.0.1:35789,127.0.0.1:35790 trainers_num:2 current_endpoint:127.0.0.1:35789 trainer_id:0" if grep -q "$str0" "$file_0"; then echo "find trainer 0" else @@ -67,7 +73,7 @@ function test_nproc_1_gpu(){ exit -1 fi - str1="selected_gpus:1 worker_endpoints:127.0.0.1:35789,127.0.0.1:35790 trainers_num:2 current_endpoint:127.0.0.1:35790 trainer_id:1" + str1="selected_devices:1 worker_endpoints:127.0.0.1:35789,127.0.0.1:35790 trainers_num:2 current_endpoint:127.0.0.1:35790 trainer_id:1" if grep -q "$str1" "$file_1"; then echo "find trainer 1" else @@ -76,9 +82,9 @@ function test_nproc_1_gpu(){ fi } -# unittest3: nproc_per_node=2, each with 1 gpus +# unittest4: nproc_per_node=2, each with 1 gpus if python detected_gpu.py ; then - echo "begin ut 3:" + echo "begin ut 4:" export CUDA_VISIBLE_DEVICES=0,1 test_nproc_1_gpu fi @@ -91,7 +97,7 @@ function test_nproc_1_cpu(){ distributed_args="--log_dir=testlog --nproc_per_node=2" python -m paddle.distributed.launch ${distributed_args} nproc_process.py fleet_nproc_1 - str0="selected_gpus: worker_endpoints:127.0.0.1:35789,127.0.0.1:35790 trainers_num:2 current_endpoint:127.0.0.1:35789 trainer_id:0" + str0="selected_devices: worker_endpoints:127.0.0.1:35789,127.0.0.1:35790 trainers_num:2 current_endpoint:127.0.0.1:35789 trainer_id:0" if grep -q "$str0" "$file_0"; then echo "find trainer 0" else @@ -99,7 +105,7 @@ function test_nproc_1_cpu(){ exit -1 fi - 
str1="selected_gpus: worker_endpoints:127.0.0.1:35789,127.0.0.1:35790 trainers_num:2 current_endpoint:127.0.0.1:35790 trainer_id:1" + str1="selected_devices: worker_endpoints:127.0.0.1:35789,127.0.0.1:35790 trainers_num:2 current_endpoint:127.0.0.1:35790 trainer_id:1" if grep -q "$str1" "$file_1"; then echo "find trainer 1" else @@ -108,9 +114,42 @@ function test_nproc_1_cpu(){ fi } -# unittest4: nproc_per_node=2, cpu +# unittest5: nproc_per_node=2, cpu if ! python detected_gpu.py ; then - echo "begin ut 4:" + echo "begin ut 5:" export CUDA_VISIBLE_DEVICES="" test_nproc_1_cpu fi + + +function test_nproc_1_xpu(){ + file_0="fleet_nproc_1.check_0.log" + file_1="fleet_nproc_1.check_1.log" + rm -f ${file_0} ${file_1} + + distributed_args="--log_dir=testlog --nproc_per_node=2" + python -m paddle.distributed.launch ${distributed_args} nproc_process.py fleet_nproc_1 + + str0="selected_devices:0 worker_endpoints:127.0.0.1:35789,127.0.0.1:35790 trainers_num:2 current_endpoint:127.0.0.1:35789 trainer_id:0" + if grep -q "$str0" "$file_0"; then + echo "find trainer 0" + else + echo "not find trainer 0" + exit -1 + fi + + str1="selected_devices:1 worker_endpoints:127.0.0.1:35789,127.0.0.1:35790 trainers_num:2 current_endpoint:127.0.0.1:35790 trainer_id:1" + if grep -q "$str1" "$file_1"; then + echo "find trainer 1" + else + echo "not find trainer 1" + exit -1 + fi +} + +# unittest6: nproc_per_node=2, each with 1 gpus +if python detected_xpu.py ; then + echo "begin ut 6:" + export XPU_VISIBLE_DEVICES=0,1 + test_nproc_1_xpu +fi diff --git a/python/paddle/fluid/tests/unittests/test_parallel_dygraph_mnist.py b/python/paddle/fluid/tests/unittests/test_parallel_dygraph_mnist.py index 9cc507aa9b7..e63d1eedd9d 100644 --- a/python/paddle/fluid/tests/unittests/test_parallel_dygraph_mnist.py +++ b/python/paddle/fluid/tests/unittests/test_parallel_dygraph_mnist.py @@ -41,6 +41,25 @@ class TestParallelDygraphMnist(TestDistBase): log_name=flag_name) +#TODO(liuyuhui): Multi-Card Baidu Kunlun XPU training exist accuracy problems +#it is difficult to find out immediately where the problem is, +#and we will work with frameworkers' help to fix it. +class TestParallelDygraphMnistXPU(TestDistBase): + def _setup_config(self): + self._sync_mode = False + self._bkcl_mode = True + self._dygraph = True + self._enforce_place = "XPU" + + def test_mnist_xpu(self): + if fluid.core.is_compiled_with_xpu(): + self.check_with_place( + "parallel_dygraph_mnist.py", + delta=1e-1, + check_error_log=True, + log_name=flag_name) + + class TestParallelDygraphMnistSpawn(TestDistSpawnRunner): def test_mnist_with_spawn(self): if fluid.core.is_compiled_with_cuda() and sys.version_info >= (3, 4): -- GitLab