diff --git a/paddle/fluid/operators/collective/c_allreduce_max_op_xpu.cc b/paddle/fluid/operators/collective/c_allreduce_max_op_xpu.cc
new file mode 100644
index 0000000000000000000000000000000000000000..b0aa51f7cfdfdcd51db6b31a76a6c5c8b77b3d62
--- /dev/null
+++ b/paddle/fluid/operators/collective/c_allreduce_max_op_xpu.cc
@@ -0,0 +1,28 @@
+/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include "paddle/fluid/operators/collective/c_allreduce_op.h"
+
+namespace paddle {
+namespace platform {
+struct XPUPlace;
+struct float16;
+}  // namespace platform
+}  // namespace paddle
+
+namespace ops = paddle::operators;
+namespace plat = paddle::platform;
+
+REGISTER_OP_XPU_KERNEL(c_allreduce_max,
+                       ops::CAllReduceOpXPUKernel<ops::kRedMax, float>)
diff --git a/paddle/fluid/operators/collective/c_allreduce_min_op_xpu.cc b/paddle/fluid/operators/collective/c_allreduce_min_op_xpu.cc
new file mode 100644
index 0000000000000000000000000000000000000000..2f16a89c217dacb1529ab2f57d300aceabb95a85
--- /dev/null
+++ b/paddle/fluid/operators/collective/c_allreduce_min_op_xpu.cc
@@ -0,0 +1,28 @@
+/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include "paddle/fluid/operators/collective/c_allreduce_op.h"
+
+namespace paddle {
+namespace platform {
+struct XPUPlace;
+struct float16;
+}  // namespace platform
+}  // namespace paddle
+
+namespace ops = paddle::operators;
+namespace plat = paddle::platform;
+
+REGISTER_OP_XPU_KERNEL(c_allreduce_min,
+                       ops::CAllReduceOpXPUKernel<ops::kRedMin, float>)
diff --git a/paddle/fluid/operators/collective/c_allreduce_op.h b/paddle/fluid/operators/collective/c_allreduce_op.h
index 2f56f43d793fa941e96e5711ac48eb2899290259..ab1cc508fdf69e8a3b625f3204c2cc87d787363f 100644
--- a/paddle/fluid/operators/collective/c_allreduce_op.h
+++ b/paddle/fluid/operators/collective/c_allreduce_op.h
@@ -20,11 +20,19 @@ limitations under the License. */
 #include "paddle/fluid/framework/lod_tensor.h"
 #include "paddle/fluid/framework/op_registry.h"

-#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
+#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL) || \
+    defined(PADDLE_WITH_XPU_BKCL)
 #include "paddle/fluid/platform/collective_helper.h"
+#endif
+
+#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
 #include "paddle/fluid/platform/nccl_helper.h"
 #endif

+#if defined(PADDLE_WITH_XPU_BKCL)
+#include "paddle/fluid/platform/bkcl_helper.h"
+#endif
+
 #if defined(PADDLE_WITH_GLOO)
 #include <gloo/allreduce.h>
 #include "paddle/fluid/framework/fleet/gloo_wrapper.h"
@@ -105,6 +113,68 @@ class CAllReduceOpCPUKernel : public framework::OpKernel<T> {
   }
 };

+template <ReduceType red_type, typename T>
+class CAllReduceOpXPUKernel : public framework::OpKernel<T> {
+ public:
+  void Compute(const framework::ExecutionContext& ctx) const override {
+#if defined(PADDLE_WITH_XPU_BKCL)
+    auto in = ctx.Input<framework::Tensor>("X");
+    auto out = ctx.Output<framework::Tensor>("Out");
+
+    auto place = ctx.GetPlace();
+    BKCLDataType dtype = platform::ToBKCLDataType(in->type());
+    int64_t numel = in->numel();
+    const void* sendbuff = in->data<void>();
+    out->Resize(in->dims());
+    void* recvbuff = out->mutable_data<T>(place);
+
+    int rid = ctx.Attr<int>("ring_id");
+    auto comm = platform::BKCLCommContext::Instance().Get(rid, place);
+
+    XPUStream stream = nullptr;
+    if (ctx.Attr<bool>("use_calc_stream")) {
+      auto dev_ctx = platform::DeviceContextPool::Instance().Get(place);
+      stream = static_cast<platform::XPUDeviceContext*>(dev_ctx)
+                   ->x_context()
+                   ->xpu_stream;
+    } else {
+      stream = comm->stream();
+    }
+
+    BKCLOp bkcl_red_type = BKCL_ADD;
+    switch (red_type) {
+      case kRedSum:
+        bkcl_red_type = BKCL_ADD;
+        break;
+
+      case kRedMax:
+        bkcl_red_type = BKCL_MAX;
+        break;
+
+      case kRedMin:
+        bkcl_red_type = BKCL_MIN;
+        break;
+
+      case kRedProd:
+        bkcl_red_type = BKCL_PRODUCT;
+        break;
+
+      default:
+        PADDLE_THROW(platform::errors::InvalidArgument(
+            "Invalid reduce type: %d", red_type));
+    }
+
+    PADDLE_ENFORCE_EQ(bkcl_all_reduce(comm->comm(), sendbuff, recvbuff, numel,
+                                      dtype, bkcl_red_type, stream),
+                      BKCL_SUCCESS, platform::errors::PreconditionNotMet(
+                                        "BKCL all reduce failed"));
+#else
+    PADDLE_THROW(platform::errors::PreconditionNotMet(
+        "PaddlePaddle should be compiled with XPU."));
+#endif
+  }
+};
+
 template <ReduceType red_type, typename T>
 class CAllReduceOpCUDAKernel : public framework::OpKernel<T> {
  public:
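The new `CAllReduceOpXPUKernel` above mirrors the existing CUDA kernel: it maps the compile-time `ReduceType` onto a `BKCLOp` and issues a single `bkcl_all_reduce` on either the compute stream or the communicator's stream. For orientation, here is a hedged sketch (not part of this patch) of reaching that kernel from the Python side — it assumes a `PADDLE_WITH_XPU_BKCL` build, two trainers launched via `paddle.distributed.launch`, and the 2.0-era dygraph API; the tensor values are illustrative:

```python
# Hedged usage sketch: exercising the XPU allreduce path end to end.
import numpy as np
import paddle
import paddle.distributed as dist

dist.init_parallel_env()  # sets up the BKCL communicator ring on XPU builds
t = paddle.to_tensor(np.array([1.0, 2.0, 3.0], dtype="float32"))

# op=MAX routes to c_allreduce_max, i.e. CAllReduceOpXPUKernel<kRedMax, float>,
# which the switch above maps to BKCL_MAX.
dist.all_reduce(t, op=dist.ReduceOp.MAX)
print(t.numpy())  # every rank now holds the elementwise maximum
```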
diff --git a/paddle/fluid/operators/collective/c_allreduce_prod_op_xpu.cc b/paddle/fluid/operators/collective/c_allreduce_prod_op_xpu.cc
new file mode 100644
index 0000000000000000000000000000000000000000..92ba00428065bc318a48e7e9e63910716015cbf7
--- /dev/null
+++ b/paddle/fluid/operators/collective/c_allreduce_prod_op_xpu.cc
@@ -0,0 +1,28 @@
+/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include "paddle/fluid/operators/collective/c_allreduce_op.h"
+
+namespace paddle {
+namespace platform {
+struct XPUPlace;
+struct float16;
+}  // namespace platform
+}  // namespace paddle
+
+namespace ops = paddle::operators;
+namespace plat = paddle::platform;
+
+REGISTER_OP_XPU_KERNEL(c_allreduce_prod,
+                       ops::CAllReduceOpXPUKernel<ops::kRedProd, float>)
diff --git a/paddle/fluid/operators/collective/c_allreduce_sum_op_xpu.cc b/paddle/fluid/operators/collective/c_allreduce_sum_op_xpu.cc
new file mode 100644
index 0000000000000000000000000000000000000000..e4ec538cd2323009657ea85e0a5d59db0ea0d3c8
--- /dev/null
+++ b/paddle/fluid/operators/collective/c_allreduce_sum_op_xpu.cc
@@ -0,0 +1,28 @@
+/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include "paddle/fluid/operators/collective/c_allreduce_op.h"
+
+namespace paddle {
+namespace platform {
+struct XPUPlace;
+struct float16;
+}  // namespace platform
+}  // namespace paddle
+
+namespace ops = paddle::operators;
+namespace plat = paddle::platform;
+
+REGISTER_OP_XPU_KERNEL(c_allreduce_sum,
+                       ops::CAllReduceOpXPUKernel<ops::kRedSum, float>)
diff --git a/paddle/fluid/operators/collective/c_reduce_max_op_xpu.cc b/paddle/fluid/operators/collective/c_reduce_max_op_xpu.cc
new file mode 100644
index 0000000000000000000000000000000000000000..6d3af7bb5f258b425a8618e412c3bb5552113bfe
--- /dev/null
+++ b/paddle/fluid/operators/collective/c_reduce_max_op_xpu.cc
@@ -0,0 +1,28 @@
+/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include "paddle/fluid/operators/collective/c_reduce_op.h"
+
+namespace paddle {
+namespace platform {
+struct XPUPlace;
+struct float16;
+}  // namespace platform
+}  // namespace paddle
+
+namespace ops = paddle::operators;
+namespace plat = paddle::platform;
+
+REGISTER_OP_XPU_KERNEL(c_reduce_max,
+                       ops::CReduceOpXPUKernel<ops::kRedMax, float>)
diff --git a/paddle/fluid/operators/collective/c_reduce_min_op_xpu.cc b/paddle/fluid/operators/collective/c_reduce_min_op_xpu.cc
new file mode 100644
index 0000000000000000000000000000000000000000..791e58d8493cec454e2fe74c772ec70999cc8f36
--- /dev/null
+++ b/paddle/fluid/operators/collective/c_reduce_min_op_xpu.cc
@@ -0,0 +1,28 @@
+/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include "paddle/fluid/operators/collective/c_reduce_op.h"
+
+namespace paddle {
+namespace platform {
+struct XPUPlace;
+struct float16;
+}  // namespace platform
+}  // namespace paddle
+
+namespace ops = paddle::operators;
+namespace plat = paddle::platform;
+
+REGISTER_OP_XPU_KERNEL(c_reduce_min,
+                       ops::CReduceOpXPUKernel<ops::kRedMin, float>)
diff --git a/paddle/fluid/operators/collective/c_reduce_op.h b/paddle/fluid/operators/collective/c_reduce_op.h
index 1bce01e13a2ad25638128f4f619f458348d97b5e..e5374781099723c705eee904042551a52fa8f01a 100644
--- a/paddle/fluid/operators/collective/c_reduce_op.h
+++ b/paddle/fluid/operators/collective/c_reduce_op.h
@@ -24,10 +24,19 @@ limitations under the License. */
 #include "paddle/fluid/framework/lod_tensor.h"
 #include "paddle/fluid/framework/op_registry.h"

-#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
+#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL) || \
+    defined(PADDLE_WITH_XPU_BKCL)
 #include "paddle/fluid/platform/collective_helper.h"
+#endif
+
+#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
 #include "paddle/fluid/platform/nccl_helper.h"
 #endif
+
+#if defined(PADDLE_WITH_XPU_BKCL)
+#include "paddle/fluid/platform/bkcl_helper.h"
+#endif
+
 #if defined(PADDLE_WITH_GLOO)
 #include <gloo/reduce.h>
 #include "paddle/fluid/framework/fleet/gloo_wrapper.h"
@@ -110,6 +119,69 @@ class CReduceOpCPUKernel : public framework::OpKernel<T> {
   }
 };

+template <ReduceType red_type, typename T>
+class CReduceOpXPUKernel : public framework::OpKernel<T> {
+ public:
+  void Compute(const framework::ExecutionContext& ctx) const override {
+#if defined(PADDLE_WITH_XPU_BKCL)
+    auto in = ctx.Input<framework::Tensor>("X");
+    auto out = ctx.Output<framework::Tensor>("Out");
+
+    auto place = ctx.GetPlace();
+    BKCLDataType dtype = platform::ToBKCLDataType(in->type());
+    int64_t numel = in->numel();
+    const void* sendbuff = in->data<void>();
+    out->Resize(in->dims());
+    void* recvbuff = out->mutable_data<T>(place);
+
+    int rid = ctx.Attr<int>("ring_id");
+    int root = ctx.Attr<int>("root_id");
+    auto comm = platform::BKCLCommContext::Instance().Get(rid, place);
+
+    XPUStream stream = nullptr;
+    if (ctx.Attr<bool>("use_calc_stream")) {
+      auto dev_ctx = platform::DeviceContextPool::Instance().Get(place);
+      stream = static_cast<platform::XPUDeviceContext*>(dev_ctx)
+                   ->x_context()
+                   ->xpu_stream;
+    } else {
+      stream = comm->stream();
+    }
+
+    BKCLOp bkcl_red_type = BKCL_ADD;
+    switch (red_type) {
+      case kRedSum:
+        bkcl_red_type = BKCL_ADD;
+        break;
+
+      case kRedMax:
+        bkcl_red_type = BKCL_MAX;
+        break;
+
+      case kRedMin:
+        bkcl_red_type = BKCL_MIN;
+        break;
+
+      case kRedProd:
+        bkcl_red_type = BKCL_PRODUCT;
+        break;
+
+      default:
+        PADDLE_THROW(platform::errors::InvalidArgument(
+            "Invalid reduce type: %d", red_type));
+    }
+
+    PADDLE_ENFORCE_EQ(bkcl_reduce(comm->comm(), sendbuff, recvbuff, numel,
+                                  dtype, bkcl_red_type, root, stream),
+                      BKCL_SUCCESS, platform::errors::PreconditionNotMet(
+                                        "BKCL reduce failed"));
+#else
+    PADDLE_THROW(platform::errors::PreconditionNotMet(
+        "PaddlePaddle should be compiled with XPU."));
+#endif
+  }
+};
+
 template <ReduceType red_type, typename T>
 class CReduceOpCUDAKernel : public framework::OpKernel<T> {
  public:
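`CReduceOpXPUKernel` follows the same shape as the allreduce kernel but calls `bkcl_reduce` with an additional `root_id`, so only the root rank receives the reduced buffer. A hedged sketch (not part of this patch) of the corresponding Python call, under the same assumptions as the allreduce example above:

```python
# Hedged usage sketch: the reduce path differs from allreduce only in that
# the result lands on a single destination rank.
import paddle
import paddle.distributed as dist

dist.init_parallel_env()
t = paddle.ones([4], dtype="float32")

# dst maps to the kernel's root_id attribute; the default op (SUM) routes to
# c_reduce_sum and BKCL_ADD. Ranks other than dst keep their input unchanged.
dist.reduce(t, dst=0)
```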
diff --git a/paddle/fluid/operators/collective/c_reduce_prod_op_xpu.cc b/paddle/fluid/operators/collective/c_reduce_prod_op_xpu.cc
new file mode 100644
index 0000000000000000000000000000000000000000..e7e770e8ffdcaf36408a13c7e01a31ebd1eced20
--- /dev/null
+++ b/paddle/fluid/operators/collective/c_reduce_prod_op_xpu.cc
@@ -0,0 +1,28 @@
+/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include "paddle/fluid/operators/collective/c_reduce_op.h"
+
+namespace paddle {
+namespace platform {
+struct XPUPlace;
+struct float16;
+}  // namespace platform
+}  // namespace paddle
+
+namespace ops = paddle::operators;
+namespace plat = paddle::platform;
+
+REGISTER_OP_XPU_KERNEL(c_reduce_prod,
+                       ops::CReduceOpXPUKernel<ops::kRedProd, float>)
diff --git a/paddle/fluid/operators/collective/c_reduce_sum_op_xpu.cc b/paddle/fluid/operators/collective/c_reduce_sum_op_xpu.cc
new file mode 100644
index 0000000000000000000000000000000000000000..a0ec4d2a99cd711f315186f7ce8966585685aeef
--- /dev/null
+++ b/paddle/fluid/operators/collective/c_reduce_sum_op_xpu.cc
@@ -0,0 +1,28 @@
+/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include "paddle/fluid/operators/collective/c_reduce_op.h"
+
+namespace paddle {
+namespace platform {
+struct XPUPlace;
+struct float16;
+}  // namespace platform
+}  // namespace paddle
+
+namespace ops = paddle::operators;
+namespace plat = paddle::platform;
+
+REGISTER_OP_XPU_KERNEL(c_reduce_sum,
+                       ops::CReduceOpXPUKernel<ops::kRedSum, float>)
diff --git a/python/paddle/fluid/tests/unittests/CMakeLists.txt b/python/paddle/fluid/tests/unittests/CMakeLists.txt
index 486ad38ae296f9667d024e4bd6c40a1036bbc270..d0906052c999f3e547a9b1451e1a81013adfd85a 100644
--- a/python/paddle/fluid/tests/unittests/CMakeLists.txt
+++ b/python/paddle/fluid/tests/unittests/CMakeLists.txt
@@ -627,6 +627,12 @@ if (WITH_XPU)
     add_subdirectory(xpu)
 endif()

+# dist xpu tests:
+if (WITH_XPU_BKCL)
+    py_test(test_collective_reduce_api_xpu SRCS "test_collective_reduce_api.py")
+    py_test(test_collective_allreduce_api_xpu SRCS "test_collective_allreduce_api.py")
+endif()
+
 if (WITH_ASCEND_CL)
     add_subdirectory(npu)
 endif()
diff --git a/python/paddle/fluid/tests/unittests/test_collective_allreduce_api.py b/python/paddle/fluid/tests/unittests/test_collective_allreduce_api.py
index a405da80adaf0f2c3b6698bd175797670a748c62..eed2388f36ffe68ab9e5d1ecdf2201525adeb55d 100644
--- a/python/paddle/fluid/tests/unittests/test_collective_allreduce_api.py
+++ b/python/paddle/fluid/tests/unittests/test_collective_allreduce_api.py
@@ -27,8 +27,14 @@ class TestCollectiveAllreduceAPI(TestDistBase):
         pass

     def test_allreduce_nccl(self):
-        self.check_with_place("collective_allreduce_api.py", "allreduce",
-                              "nccl")
+        if paddle.fluid.core.is_compiled_with_cuda():
+            self.check_with_place("collective_allreduce_api.py", "allreduce",
+                                  "nccl")
+
+    def test_allreduce_bkcl(self):
+        if paddle.fluid.core.is_compiled_with_xpu():
+            self.check_with_place("collective_allreduce_api.py", "allreduce",
+                                  "bkcl")

     def test_allreduce_gloo(self):
         self.check_with_place("collective_allreduce_api.py", "allreduce",
diff --git a/python/paddle/fluid/tests/unittests/test_collective_api_base.py b/python/paddle/fluid/tests/unittests/test_collective_api_base.py
index 660018e285a85261e531449119af97bc25cf4e6a..ad85adb2d51978fe659f8cc7eaf05714b19e15c1 100644
--- a/python/paddle/fluid/tests/unittests/test_collective_api_base.py
+++ b/python/paddle/fluid/tests/unittests/test_collective_api_base.py
@@ -50,6 +50,9 @@ class TestCollectiveAPIRunnerBase(object):
             device_id = int(os.getenv("FLAGS_selected_gpus", "0"))
             place = fluid.CUDAPlace(
                 device_id)  #if args.use_gpu else fluid.CPUPlace()
+        elif args['backend'] == 'bkcl':
+            device_id = int(os.getenv("FLAGS_selected_xpus", "0"))
+            place = fluid.XPUPlace(device_id)
         else:
             place = fluid.CPUPlace()
         exe = fluid.Executor(place)
@@ -71,7 +74,6 @@ class TestCollectiveAPIRunnerBase(object):
 def runtime_main(test_class, col_type):
     args = {}
     model = test_class()
-    args["deviceid"] = os.getenv("FLAGS_selected_gpus")
     args["trainerid"] = int(os.getenv("PADDLE_TRAINER_ID"))
     args["trainernum"] = int(os.getenv("PADDLE_TRAINERS_NUM"))
     args["endpoints"] = os.getenv('PADDLE_TRAINER_ENDPOINTS')
@@ -112,21 +114,38 @@ class TestDistBase(unittest.TestCase):
         worker_endpoints = self._ps_endpoints.split(",")
         w0_ep, w1_ep = worker_endpoints
         #print("w0_ep:",w0_ep," w1_ep:",w1_ep)
-        env0 = {
-            "FLAGS_selected_gpus": "0",
-            "PADDLE_TRAINER_ID": "0",
-            "PADDLE_TRAINERS_NUM": "2",
-            "PADDLE_TRAINER_ENDPOINTS": self._ps_endpoints,
-            "PADDLE_CURRENT_ENDPOINT": w0_ep
-        }
+        if core.is_compiled_with_cuda():
+            env0 = {
+                "FLAGS_selected_gpus": "0",
+                "PADDLE_TRAINER_ID": "0",
+                "PADDLE_TRAINERS_NUM": "2",
+                "PADDLE_TRAINER_ENDPOINTS": self._ps_endpoints,
+                "PADDLE_CURRENT_ENDPOINT": w0_ep
+            }

-        env1 = {
-            "FLAGS_selected_gpus": "1",
-            "PADDLE_TRAINER_ID": "1",
-            "PADDLE_TRAINERS_NUM": "2",
-            "PADDLE_TRAINER_ENDPOINTS": self._ps_endpoints,
-            "PADDLE_CURRENT_ENDPOINT": w1_ep
-        }
+            env1 = {
+                "FLAGS_selected_gpus": "1",
+                "PADDLE_TRAINER_ID": "1",
+                "PADDLE_TRAINERS_NUM": "2",
+                "PADDLE_TRAINER_ENDPOINTS": self._ps_endpoints,
+                "PADDLE_CURRENT_ENDPOINT": w1_ep
+            }
+        elif core.is_compiled_with_xpu():
+            env0 = {
+                "FLAGS_selected_xpus": "0",
+                "PADDLE_TRAINER_ID": "0",
+                "PADDLE_TRAINERS_NUM": "2",
+                "PADDLE_TRAINER_ENDPOINTS": self._ps_endpoints,
+                "PADDLE_CURRENT_ENDPOINT": w0_ep
+            }
+
+            env1 = {
+                "FLAGS_selected_xpus": "1",
+                "PADDLE_TRAINER_ID": "1",
+                "PADDLE_TRAINERS_NUM": "2",
+                "PADDLE_TRAINER_ENDPOINTS": self._ps_endpoints,
+                "PADDLE_CURRENT_ENDPOINT": w1_ep
+            }
         #update environment
         env0.update(envs)
         env1.update(envs)
@@ -169,7 +188,10 @@ class TestDistBase(unittest.TestCase):
                         path_id="0",
                         check_error_log=False,
                         need_envs={}):
-        with_gloo = '0' if backend == "nccl" else '1'
+        if backend == "nccl" or backend == "bkcl":
+            with_gloo = '0'
+        else:
+            with_gloo = '1'
         required_envs = {
             "FLAGS_fraction_of_gpu_memory_to_use": "0.15",
             "FLAGS_eager_delete_tensor_gb": "0.0",
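`check_with_place` launches two trainer processes with the env dicts above and compares their fetched outputs against a reference computed in numpy. A hedged sketch (not part of this patch) of what that comparison amounts to for the allreduce case — array shapes and tolerances are illustrative, and `tr0_out`/`tr1_out` stand in for the tensors fetched from the two trainers:

```python
# Hedged sketch: the correctness condition the harness checks for "allreduce" —
# after the collective, both ranks should hold the elementwise sum of inputs.
import numpy as np

input0 = np.random.random((10, 1000)).astype("float32")  # rank 0's input
input1 = np.random.random((10, 1000)).astype("float32")  # rank 1's input
expected = input0 + input1

tr0_out = expected.copy()  # placeholder for rank 0's fetched result
tr1_out = expected.copy()  # placeholder for rank 1's fetched result
np.testing.assert_allclose(tr0_out, expected, rtol=1e-05)
np.testing.assert_allclose(tr1_out, expected, rtol=1e-05)
```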
"PADDLE_TRAINERS_NUM": "2", + "PADDLE_TRAINER_ENDPOINTS": self._ps_endpoints, + "PADDLE_CURRENT_ENDPOINT": w0_ep + } - env1 = { - "FLAGS_selected_gpus": "1", - "PADDLE_TRAINER_ID": "1", - "PADDLE_TRAINERS_NUM": "2", - "PADDLE_TRAINER_ENDPOINTS": self._ps_endpoints, - "PADDLE_CURRENT_ENDPOINT": w1_ep - } + env1 = { + "FLAGS_selected_gpus": "1", + "PADDLE_TRAINER_ID": "1", + "PADDLE_TRAINERS_NUM": "2", + "PADDLE_TRAINER_ENDPOINTS": self._ps_endpoints, + "PADDLE_CURRENT_ENDPOINT": w1_ep + } + elif core.is_compiled_with_xpu(): + env0 = { + "FLAGS_selected_xpus": "0", + "PADDLE_TRAINER_ID": "0", + "PADDLE_TRAINERS_NUM": "2", + "PADDLE_TRAINER_ENDPOINTS": self._ps_endpoints, + "PADDLE_CURRENT_ENDPOINT": w0_ep + } + + env1 = { + "FLAGS_selected_xpus": "1", + "PADDLE_TRAINER_ID": "1", + "PADDLE_TRAINERS_NUM": "2", + "PADDLE_TRAINER_ENDPOINTS": self._ps_endpoints, + "PADDLE_CURRENT_ENDPOINT": w1_ep + } #update environment env0.update(envs) env1.update(envs) @@ -169,7 +188,10 @@ class TestDistBase(unittest.TestCase): path_id="0", check_error_log=False, need_envs={}): - with_gloo = '0' if backend == "nccl" else '1' + if backend == "nccl" or backend == "bkcl": + with_gloo = '0' + else: + with_gloo = '1' required_envs = { "FLAGS_fraction_of_gpu_memory_to_use": "0.15", "FLAGS_eager_delete_tensor_gb": "0.0", diff --git a/python/paddle/fluid/tests/unittests/test_collective_reduce_api.py b/python/paddle/fluid/tests/unittests/test_collective_reduce_api.py index 8d28c794f023a6945893342a53386f6ffb8a6052..721f446c9f09462e622811df81a989928b1509f4 100644 --- a/python/paddle/fluid/tests/unittests/test_collective_reduce_api.py +++ b/python/paddle/fluid/tests/unittests/test_collective_reduce_api.py @@ -27,7 +27,12 @@ class TestCollectiveReduceAPI(TestDistBase): pass def test_reduce_nccl(self): - self.check_with_place("collective_reduce_api.py", "reduce", "nccl") + if paddle.fluid.core.is_compiled_with_cuda(): + self.check_with_place("collective_reduce_api.py", "reduce", "nccl") + + def test_reduce_bkcl(self): + if paddle.fluid.core.is_compiled_with_xpu(): + self.check_with_place("collective_reduce_api.py", "reduce", "bkcl") def test_reduce_gloo(self): self.check_with_place("collective_reduce_api.py", "reduce", "gloo", "1")