From 83a2fb1f08714d12728292924ea0e07f72451987 Mon Sep 17 00:00:00 2001
From: WangXi
Date: Wed, 10 Mar 2021 19:50:37 +0800
Subject: [PATCH] Add collective async wait op (#31463)

---
 .../operators/collective/c_wait_comm_op.cc    |  91 ++++++++++++++
 .../operators/collective/c_wait_compute_op.cc |  95 +++++++++++++++
 paddle/fluid/platform/collective_helper.cc    |  28 +++++
 paddle/fluid/platform/collective_helper.h     |   2 +
 python/paddle/fluid/framework.py              |   3 +-
 .../fluid/tests/unittests/CMakeLists.txt      |   1 +
 .../unittests/collective_allreduce_op_wait.py | 114 ++++++++++++++++++
 .../tests/unittests/test_collective_wait.py   |  37 ++++++
 8 files changed, 370 insertions(+), 1 deletion(-)
 create mode 100644 paddle/fluid/operators/collective/c_wait_comm_op.cc
 create mode 100644 paddle/fluid/operators/collective/c_wait_compute_op.cc
 create mode 100644 python/paddle/fluid/tests/unittests/collective_allreduce_op_wait.py
 create mode 100644 python/paddle/fluid/tests/unittests/test_collective_wait.py

diff --git a/paddle/fluid/operators/collective/c_wait_comm_op.cc b/paddle/fluid/operators/collective/c_wait_comm_op.cc
new file mode 100644
index 00000000000..d0dfc3bb1c2
--- /dev/null
+++ b/paddle/fluid/operators/collective/c_wait_comm_op.cc
@@ -0,0 +1,91 @@
+/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+*/
+#include <string>
+
+#include "paddle/fluid/framework/op_registry.h"
+namespace paddle {
+namespace framework {
+class Scope;
+}  // namespace framework
+}  // namespace paddle
+#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
+#include "paddle/fluid/platform/collective_helper.h"
+#endif
+
+namespace paddle {
+namespace operators {
+
+class CWaitCommOp : public framework::OperatorBase {
+ public:
+  CWaitCommOp(const std::string& type, const framework::VariableNameMap& inputs,
+              const framework::VariableNameMap& outputs,
+              const framework::AttributeMap& attrs)
+      : OperatorBase(type, inputs, outputs, attrs) {}
+
+  void RunImpl(const framework::Scope& scope,
+               const platform::Place& place) const override {
+    PADDLE_ENFORCE_EQ(is_gpu_place(place), true,
+                      platform::errors::PreconditionNotMet(
+                          "wait_comm op can run on gpu place only for now."));
+
+#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
+    int ring_id = Attr<int>("ring_id");
+
+    auto compute_stream =
+        static_cast<platform::CUDADeviceContext*>(
+            platform::DeviceContextPool::Instance().Get(place))
+            ->stream();
+    auto comm_stream =
+        platform::NCCLCommContext::Instance().Get(ring_id, place)->stream();
+
+    auto event =
+        platform::NCCLCommContext::Instance().Get(ring_id, place)->comm_event();
+
+// comm_stream-->event-->compute_stream
+#ifdef PADDLE_WITH_HIP
+    PADDLE_ENFORCE_CUDA_SUCCESS(hipEventRecord(event, comm_stream));
+    PADDLE_ENFORCE_CUDA_SUCCESS(hipStreamWaitEvent(compute_stream, event, 0));
+#else
+    PADDLE_ENFORCE_CUDA_SUCCESS(cudaEventRecord(event, comm_stream));
+    PADDLE_ENFORCE_CUDA_SUCCESS(cudaStreamWaitEvent(compute_stream, event, 0));
+#endif
+#else
+    PADDLE_THROW(platform::errors::PreconditionNotMet(
+        "PaddlePaddle should compile with GPU."));
+#endif
+  }
+};
+
+class CWaitCommOpMaker : public framework::OpProtoAndCheckerMaker {
+ public:
+  void Make() {
+    AddInput("X", "(Tensor) Dependency of the variable that needs to sync")
+        .AsDuplicable();
+    AddOutput("Out", "(Tensor) Dependency of the variable that needs to sync")
+        .AsDuplicable();
+    AddAttr<int>("ring_id", "(int default 0) ring id.").SetDefault(0);
+    AddComment(R"DOC(
+CWaitComm Operator
+
+Make the compute stream wait on the comm stream with an async event.
+)DOC");
+  }
+};
+
+}  // namespace operators
+}  // namespace paddle
+
+namespace ops = paddle::operators;
+
+REGISTER_OPERATOR(c_wait_comm, ops::CWaitCommOp, ops::CWaitCommOpMaker);
diff --git a/paddle/fluid/operators/collective/c_wait_compute_op.cc b/paddle/fluid/operators/collective/c_wait_compute_op.cc
new file mode 100644
index 00000000000..12a28040ef1
--- /dev/null
+++ b/paddle/fluid/operators/collective/c_wait_compute_op.cc
@@ -0,0 +1,95 @@
+/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+*/
+#include <string>
+
+#include "paddle/fluid/framework/op_registry.h"
+namespace paddle {
+namespace framework {
+class Scope;
+}  // namespace framework
+}  // namespace paddle
+#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
+#include "paddle/fluid/platform/collective_helper.h"
+#endif
+
+namespace paddle {
+namespace operators {
+
+class CWaitComputeOp : public framework::OperatorBase {
+ public:
+  CWaitComputeOp(const std::string& type,
+                 const framework::VariableNameMap& inputs,
+                 const framework::VariableNameMap& outputs,
+                 const framework::AttributeMap& attrs)
+      : OperatorBase(type, inputs, outputs, attrs) {}
+
+  void RunImpl(const framework::Scope& scope,
+               const platform::Place& place) const override {
+    PADDLE_ENFORCE_EQ(
+        is_gpu_place(place), true,
+        platform::errors::PreconditionNotMet(
+            "wait_compute op can run on gpu place only for now."));
+
+#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
+    int ring_id = Attr<int>("ring_id");
+
+    auto compute_stream =
+        static_cast<platform::CUDADeviceContext*>(
+            platform::DeviceContextPool::Instance().Get(place))
+            ->stream();
+    auto comm_stream =
+        platform::NCCLCommContext::Instance().Get(ring_id, place)->stream();
+
+    auto event = platform::NCCLCommContext::Instance()
+                     .Get(ring_id, place)
+                     ->compute_event();
+
+// compute_stream-->event-->comm_stream
+#ifdef PADDLE_WITH_HIP
+    PADDLE_ENFORCE_CUDA_SUCCESS(hipEventRecord(event, compute_stream));
+    PADDLE_ENFORCE_CUDA_SUCCESS(hipStreamWaitEvent(comm_stream, event, 0));
+#else
+    PADDLE_ENFORCE_CUDA_SUCCESS(cudaEventRecord(event, compute_stream));
+    PADDLE_ENFORCE_CUDA_SUCCESS(cudaStreamWaitEvent(comm_stream, event, 0));
+#endif
+#else
+    PADDLE_THROW(platform::errors::PreconditionNotMet(
+        "PaddlePaddle should compile with GPU."));
+#endif
+  }
+};
+
+class CWaitComputeOpMaker : public framework::OpProtoAndCheckerMaker {
+ public:
+  void Make() {
+    AddInput("X", "(Tensor) Dependency of the variable that needs to sync")
+        .AsDuplicable();
+    AddOutput("Out", "(Tensor) Dependency of the variable that needs to sync")
+        .AsDuplicable();
+    AddAttr<int>("ring_id", "(int default 0) ring id.").SetDefault(0);
+    AddComment(R"DOC(
+CWaitCompute Operator
+
+Make the comm stream wait on the compute stream with an async event.
+)DOC"); + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; + +REGISTER_OPERATOR(c_wait_compute, ops::CWaitComputeOp, + ops::CWaitComputeOpMaker); diff --git a/paddle/fluid/platform/collective_helper.cc b/paddle/fluid/platform/collective_helper.cc index 4b16a67b235..f2b478f7d20 100644 --- a/paddle/fluid/platform/collective_helper.cc +++ b/paddle/fluid/platform/collective_helper.cc @@ -15,6 +15,8 @@ #include "paddle/fluid/platform/collective_helper.h" #include +#include "paddle/fluid/platform/cuda_resource_pool.h" + namespace paddle { namespace platform { #if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL) @@ -43,12 +45,31 @@ class NCCLCommImpl : public NCCLComm { } CUDADeviceContext* dev_context() const override { return dev_ctx_.get(); } + gpuEvent_t compute_event() const override { return compute_event_.get(); } + + gpuEvent_t comm_event() const override { return comm_event_.get(); } + + void set_compute_event( + std::shared_ptr&& compute_event) { + compute_event_ = std::move(compute_event); + } + + void set_comm_event(std::shared_ptr&& comm_event) { + comm_event_ = std::move(comm_event); + } + private: int ring_id_; int nranks_; int rank_; ncclComm_t comm_; std::unique_ptr dev_ctx_; + + // used for comm wait compute, compute_stream-->event-->comm_stream + std::shared_ptr compute_event_; + + // used for compute wait comm, comm_stream-->event-->compute_stream + std::shared_ptr comm_event_; }; NCCLComm* NCCLCommContext::CreateNCCLComm(ncclUniqueId* nccl_id, int nranks, @@ -124,12 +145,19 @@ NCCLComm* NCCLCommContext::AssignNCCLComm(ncclComm_t comm, int nranks, int rank, std::unique_ptr dev_ctx( new CUDADeviceContext(CUDAPlace(dev_id))); + std::shared_ptr compute_event( + platform::CudaEventResourcePool::Instance().New(dev_id)); + std::shared_ptr comm_event( + platform::CudaEventResourcePool::Instance().New(dev_id)); + NCCLCommImpl* c = new NCCLCommImpl; c->set_ring_id(ring_id); c->set_nranks(nranks); c->set_rank(rank); c->set_comm(comm); c->set_dev_ctx(std::move(dev_ctx)); + c->set_compute_event(std::move(compute_event)); + c->set_comm_event(std::move(comm_event)); comm_map_mutex_.lock(); if (comm_map_.count(ring_id) == 0) { diff --git a/paddle/fluid/platform/collective_helper.h b/paddle/fluid/platform/collective_helper.h index 8a6719ab685..197f905ba68 100644 --- a/paddle/fluid/platform/collective_helper.h +++ b/paddle/fluid/platform/collective_helper.h @@ -57,6 +57,8 @@ class NCCLComm { virtual int device_id() const = 0; virtual ncclComm_t comm() const = 0; virtual gpuStream_t stream() const = 0; + virtual gpuEvent_t compute_event() const = 0; + virtual gpuEvent_t comm_event() const = 0; virtual CUDADeviceContext* dev_context() const = 0; virtual ~NCCLComm() = default; }; diff --git a/python/paddle/fluid/framework.py b/python/paddle/fluid/framework.py index fd8a39259d9..04ed384846f 100644 --- a/python/paddle/fluid/framework.py +++ b/python/paddle/fluid/framework.py @@ -2121,7 +2121,8 @@ class Operator(object): 'fl_listen_and_serv', 'ncclInit', 'select', 'checkpoint_notify', 'gen_bkcl_id', 'c_gen_bkcl_id', 'gen_nccl_id', 'c_gen_nccl_id', 'c_comm_init', 'c_sync_calc_stream', 'c_sync_comm_stream', - 'queue_generator', 'dequeue', 'enqueue', 'heter_listen_and_serv' + 'queue_generator', 'dequeue', 'enqueue', 'heter_listen_and_serv', + 'c_wait_comm', 'c_wait_compute' } def __init__(self, diff --git a/python/paddle/fluid/tests/unittests/CMakeLists.txt b/python/paddle/fluid/tests/unittests/CMakeLists.txt index 796331e7a5a..b5c554a58cb 100644 --- 
+++ b/python/paddle/fluid/tests/unittests/CMakeLists.txt
@@ -84,6 +84,7 @@ if(((NOT WITH_ROCM) AND (NOT WITH_GPU)) OR WIN32)
     LIST(REMOVE_ITEM TEST_OPS test_collective_allreduce_api)
     LIST(REMOVE_ITEM TEST_OPS test_collective_broadcast_api)
     LIST(REMOVE_ITEM TEST_OPS test_collective_allgather_api)
+    LIST(REMOVE_ITEM TEST_OPS test_collective_wait)
     LIST(REMOVE_ITEM TEST_OPS test_memcpy_op)
 endif()
diff --git a/python/paddle/fluid/tests/unittests/collective_allreduce_op_wait.py b/python/paddle/fluid/tests/unittests/collective_allreduce_op_wait.py
new file mode 100644
index 00000000000..61a0ad3bd76
--- /dev/null
+++ b/python/paddle/fluid/tests/unittests/collective_allreduce_op_wait.py
@@ -0,0 +1,114 @@
+# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import print_function
+
+import numpy as np
+import argparse
+import os
+import sys
+import signal
+import time
+import socket
+from contextlib import closing
+from six import string_types
+import math
+import paddle
+import paddle.fluid as fluid
+import paddle.fluid.profiler as profiler
+import paddle.fluid.unique_name as nameGen
+from paddle.fluid import core
+import unittest
+from multiprocessing import Process
+import paddle.fluid.layers as layers
+from functools import reduce
+from test_collective_base import TestCollectiveRunnerBase, runtime_main
+
+paddle.enable_static()
+
+
+class TestCollectiveAllreduce(TestCollectiveRunnerBase):
+    def __init__(self):
+        self.global_ring_id = 0
+
+    def get_model(self, main_prog, startup_program):
+        ring_id = 0
+        with fluid.program_guard(main_prog, startup_program):
+            tindata = layers.data(
+                name="tindata", shape=[10, 1000], dtype='float32')
+            toutdata = main_prog.current_block().create_var(
+                name="outofallreduce",
+                dtype='float32',
+                type=core.VarDesc.VarType.LOD_TENSOR,
+                persistable=False,
+                stop_gradient=False)
+
+            # tout = tin + tin - tin = tin
+            if True:
+                main_prog.global_block().append_op(
+                    type="elementwise_add",
+                    inputs={
+                        'X': tindata,
+                        'Y': tindata,
+                    },
+                    outputs={'Out': toutdata}, )
+                main_prog.global_block().append_op(
+                    type="elementwise_sub",
+                    inputs={
+                        'X': toutdata,
+                        'Y': tindata,
+                    },
+                    outputs={'Out': toutdata}, )
+
+            main_prog.global_block().append_op(
+                type='c_wait_compute',
+                inputs={'X': toutdata},
+                outputs={'Out': toutdata},
+                attrs={'ring_id': ring_id})
+
+            main_prog.global_block().append_op(
+                type="c_allreduce_sum",
+                inputs={'X': toutdata},
+                outputs={'Out': toutdata},
+                attrs={'ring_id': ring_id,
+                       'use_calc_stream': False})
+
+            main_prog.global_block().append_op(
+                type="c_wait_comm",
+                inputs={'X': toutdata},
+                outputs={'Out': toutdata},
+                attrs={'ring_id': ring_id})
+
+            # tout = tin + tout - tin = tout
+            if True:
+                main_prog.global_block().append_op(
+                    type="elementwise_add",
+                    inputs={
+                        'X': tindata,
+                        'Y': toutdata,
+                    },
+                    outputs={'Out': toutdata}, )
+                main_prog.global_block().append_op(
+                    type="elementwise_sub",
+                    inputs={
+                        'X': toutdata,
+                        'Y': tindata,
+                    },
+                    outputs={'Out': toutdata}, )
+
+            return toutdata
+
+
+if __name__ == "__main__":
+    runtime_main(TestCollectiveAllreduce, "allreduce", 0)
diff --git a/python/paddle/fluid/tests/unittests/test_collective_wait.py b/python/paddle/fluid/tests/unittests/test_collective_wait.py
new file mode 100644
index 00000000000..b34ace80723
--- /dev/null
+++ b/python/paddle/fluid/tests/unittests/test_collective_wait.py
@@ -0,0 +1,37 @@
+# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import print_function
+import unittest
+import numpy as np
+import paddle
+
+from test_collective_base import TestDistBase
+
+paddle.enable_static()
+
+
+class TestCWaitOp(TestDistBase):
+    def _setup_config(self):
+        pass
+
+    def test_allreduce_wait(self):
+        self.check_with_place(
+            "collective_allreduce_op_wait.py",
+            "allreduce",
+            check_error_log=True)
+
+
+if __name__ == '__main__':
+    unittest.main()
-- 
GitLab
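For reference, the stream-ordering pattern that both c_wait_compute and c_wait_comm implement can be sketched with the plain CUDA runtime API alone. This is an illustrative sketch, not part of the patch: the stream variables stand in for Paddle's compute and comm streams, and the single event plays the role of the pooled compute_event/comm_event objects.

    // sketch.cu -- minimal illustration of event-based cross-stream waits
    #include <cuda_runtime.h>

    int main() {
      cudaStream_t compute_stream, comm_stream;
      cudaEvent_t event;
      cudaStreamCreate(&compute_stream);
      cudaStreamCreate(&comm_stream);
      // Timing disabled: the event is used purely for ordering, which is
      // the lightweight kind of event a resource pool would hand out.
      cudaEventCreateWithFlags(&event, cudaEventDisableTiming);

      // c_wait_compute pattern: compute_stream-->event-->comm_stream.
      // Work queued on comm_stream after the wait cannot start until all
      // work queued on compute_stream before the record has completed.
      cudaEventRecord(event, compute_stream);
      cudaStreamWaitEvent(comm_stream, event, 0);

      // c_wait_comm pattern: comm_stream-->event-->compute_stream.
      cudaEventRecord(event, comm_stream);
      cudaStreamWaitEvent(compute_stream, event, 0);

      cudaDeviceSynchronize();
      cudaEventDestroy(event);
      cudaStreamDestroy(comm_stream);
      cudaStreamDestroy(compute_stream);
      return 0;
    }

Both waits are asynchronous with respect to the host: the CPU only enqueues the record/wait pair and returns immediately, which is what lets c_allreduce_sum run on the comm stream while later compute ops continue to be issued.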